[Deep Learning]
Image Classification - Classifying Handwritten Images with an RNN
회색세계
2021. 3. 29. 11:43
0. Import
import tensorflow as tf
from tensorflow import keras
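# Import layers/models through tf.keras so they match the TensorFlow version in use
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential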
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from IPython.display import Image
warnings.filterwarnings('ignore')
%matplotlib inline
SEED = 34
1. Exploring the MNIST data
mnist = keras.datasets.mnist
((train_images, train_labels), (test_images, test_labels)) = mnist.load_data()
2. Print the shape of the data
print(f"train_images: {train_images.shape}")
print(f"train_labels: {train_labels.shape}")
print(f"test_images: {test_images.shape}")
print(f"test_labels: {test_labels.shape}")
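Because the goal of this post is classification with an RNN, it can help to see how these shapes are usually read by a recurrent model. The sketch below assumes the common convention of treating each image row as one time step and each row's 28 pixels as that step's features; the model itself is not shown in this section, and the array is random stand-in data rather than MNIST.

```python
import numpy as np

# Random stand-in shaped like a small batch of MNIST images (not real data)
rng = np.random.default_rng(34)
images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Assumed RNN reading: (samples, time steps, features per step)
num_samples, timesteps, features = images.shape
print(num_samples, timesteps, features)  # 5 28 28

# Row t of image 0 is the input vector the RNN would see at step t
first_step = images[0, 0]
print(first_step.shape)  # (28,)
```

Under this reading no reshaping is needed: a (N, 28, 28) array already has the (samples, time steps, features) layout that Keras recurrent layers expect.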
3. Display one (28, 28) image with plt
plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)  # keep grid lines from being drawn over the pixels
plt.show()
print(train_labels[0])
4. Print the first 10 non-zero pixel values of train_images[0]
print(list(filter(lambda x: x != 0, train_images[0].reshape(-1)))[:10])
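The `filter` + `lambda` pattern works, but NumPy can select the same values with a boolean mask, in the same row-major order. A minimal sketch, using a tiny hypothetical 3x3 array in place of `train_images[0]`:

```python
import numpy as np

# Hypothetical 3x3 "image" standing in for train_images[0]
img = np.array([[0, 0, 3],
                [0, 18, 0],
                [126, 0, 0]], dtype=np.uint8)

# Pure-Python version used in the post
via_filter = list(filter(lambda x: x != 0, img.reshape(-1)))[:10]

# Vectorized version: the mask keeps non-zero pixels in row-major order
via_mask = img[img != 0][:10]

print([int(v) for v in via_filter])  # [3, 18, 126]
print(via_mask.tolist())             # [3, 18, 126]
```

The masked version avoids a Python-level loop over all 784 pixels, which matters once the check is run over the whole dataset rather than a single image.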
5. Print the dtype of the train/test images and labels
print(train_images.dtype)
print(train_labels.dtype)
print(test_images.dtype)
print(test_labels.dtype)
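The printed dtype determines which values the arrays can hold. The MNIST arrays from `keras.datasets.mnist` are typically `uint8`; the sketch below uses a zero-filled stand-in of that dtype (an assumption, since the printed output is not shown here) and asks NumPy for the representable range and memory footprint.

```python
import numpy as np

# Zero-filled stand-in with the dtype MNIST is typically loaded as
batch = np.zeros((2, 28, 28), dtype=np.uint8)

# np.iinfo reports the representable range of an integer dtype
info = np.iinfo(batch.dtype)
print(batch.dtype, info.min, info.max)  # uint8 0 255

# uint8 uses one byte per pixel: 2 * 28 * 28 bytes
print(batch.nbytes)  # 1568
```

This range (0 to 255) is what steps 6 and 7 check empirically, and it is why step 9 divides by 255.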
6. Check the range of the train/test data
print(list(filter(lambda x: x != 0, train_images[0].reshape(-1)))[:10])
print(list(filter(lambda x: x != 0, train_labels.reshape(-1)))[:10])
print(list(filter(lambda x: x != 0, test_images[0].reshape(-1)))[:10])
print(list(filter(lambda x: x != 0, test_labels.reshape(-1)))[:10])
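For the labels, filtering out zeros hides the class `0` entirely, so it says little about their range. A sketch of an alternative check with `np.unique`, using a short hypothetical label array instead of the real `train_labels`:

```python
import numpy as np

# Hypothetical sample labels (stand-in for train_labels)
labels = np.array([5, 0, 4, 1, 9, 2, 1, 3], dtype=np.uint8)

# The non-zero filter from above silently drops class 0
filtered = [int(x) for x in filter(lambda x: x != 0, labels)]
print(filtered)  # [5, 4, 1, 9, 2, 1, 3]

# np.unique lists every distinct label (sorted) and how often it appears
values, counts = np.unique(labels, return_counts=True)
print(values.tolist())  # [0, 1, 2, 3, 4, 5, 9]
print(counts.tolist())  # [1, 2, 1, 1, 1, 1, 1]
```

On the full MNIST label arrays, the distinct values should cover the digit classes 0 through 9.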
7. Print the maximum/minimum values of the train/test image data
print(train_images.max(), train_images.min())
print(test_images.max(), test_images.min())
8. Convert from integer to floating point, then compare the dtype
train_images = train_images.astype(np.float64)
test_images = test_images.astype(np.float64)
print(train_images.dtype, test_images.dtype)
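One reason to convert before doing arithmetic is that `uint8` operations stay in `uint8` and wrap around modulo 256. The sketch below shows that wraparound on two small hypothetical arrays, then compares the memory cost of `float64` and `float32` for 60,000 images of 28x28 (computed from the item size, not allocated). `float32` is only an alternative worth knowing about; it is also the default float type of Keras layers, while the post itself uses `float64`.

```python
import numpy as np

a = np.array([250, 5], dtype=np.uint8)
b = np.array([10, 10], dtype=np.uint8)

# Integer arithmetic stays in uint8 and wraps modulo 256
wrapped = (a + b).tolist()
print(wrapped)  # [4, 15]

# After astype, the same sum is computed in floating point
exact = (a.astype(np.float64) + b).tolist()
print(exact)  # [260.0, 15.0]

# Bytes needed for 60000 images of 28x28 in each float dtype
n_pixels = 60000 * 28 * 28
for dt in (np.float64, np.float32):
    print(np.dtype(dt).name, n_pixels * np.dtype(dt).itemsize)
# float64 376320000
# float32 188160000
```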
9. Normalize the data to the 0-1 range
train_images = train_images / 255
test_images = test_images / 255
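Since pixel values lie in 0 to 255, dividing by 255 maps them onto [0, 1]. The same step can be wrapped in a small helper; `normalize_0_1` below is a hypothetical name, and it casts to `float32` as an assumption (the post uses `float64`).

```python
import numpy as np

def normalize_0_1(images: np.ndarray) -> np.ndarray:
    """Scale uint8 pixel values (0-255) into [0.0, 1.0] as float32."""
    return images.astype(np.float32) / 255.0

# Hypothetical pixels covering the extremes and a midpoint
pixels = np.array([[0, 128, 255]], dtype=np.uint8)
scaled = normalize_0_1(pixels)
print(scaled.dtype)                # float32
print(scaled.min(), scaled.max())  # 0.0 1.0
```

Casting before dividing makes the output dtype explicit instead of relying on NumPy's default promotion of integer division results to `float64`.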
10. After 0-1 normalization, use the methods from steps 6 and 7 to check whether the data values changed
print(list(filter(lambda x: x != 0, train_images[0].reshape(-1)))[:10])
print(list(filter(lambda x: x != 0, train_labels.reshape(-1)))[:10])
print(list(filter(lambda x: x != 0, test_images[0].reshape(-1)))[:10])
print(list(filter(lambda x: x != 0, test_labels.reshape(-1)))[:10])
print(train_images.shape, train_labels.shape, test_images.shape, test_labels.shape)
print(train_images.dtype, train_labels.dtype, test_images.dtype, test_labels.dtype)
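The checks above reuse step 6 but not the min/max check from step 7. A compact way to compare before and after is a small summary of shape, dtype, min and max. The sketch below uses a hypothetical `summarize` helper on random stand-in data, with one pixel forced to 0 and one to 255 so the extremes are known.

```python
import numpy as np

def summarize(name, arr):
    """Return (name, shape, dtype, min, max) for a quick sanity check."""
    return name, arr.shape, str(arr.dtype), float(arr.min()), float(arr.max())

# Random stand-in images (not MNIST); force the extremes to be present
rng = np.random.default_rng(34)
raw = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
raw[0, 0, 0] = 0
raw[0, 0, 1] = 255

# Same conversion and scaling as steps 8 and 9
normalized = raw.astype(np.float64) / 255

for row in (summarize("raw", raw), summarize("normalized", normalized)):
    print(row)
# ('raw', (4, 28, 28), 'uint8', 0.0, 255.0)
# ('normalized', (4, 28, 28), 'float64', 0.0, 1.0)
```

The shape stays the same while the dtype and range change, which is exactly what the normalization step is supposed to do.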