[딥러닝]

이미지 분류 - RNN으로 손글씨 이미지 분류하기

회색세계 2021. 3. 29. 11:43

0. Import

import tensorflow as tf
from tensorflow import keras
from keras.layers import Dense
from keras.models import Sequential

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings 
from IPython.display import Image

# Install a global filter that ignores every warning raised from here on.
warnings.filterwarnings('ignore')
%matplotlib inline

# Fixed seed value. It is not referenced anywhere in the visible part of this
# notebook -- presumably meant for reproducible random state later (TODO confirm).
SEED = 34

1. MNIST 데이터 살펴보기

# Load the MNIST dataset that ships with Keras. load_data() yields two
# (images, labels) pairs: the training split followed by the test split.
mnist = keras.datasets.mnist
train_split, test_split = mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

2. 데이터의 shape을 출력

# Print the dimensions of each image/label array, training split first and
# test split second, each line labelled with the variable it describes.
print(f"train_images: {train_images.shape}")
print(f"train_labels: {train_labels.shape}")

print(f"test_images: {test_images.shape}")
print(f"test_labels: {test_labels.shape}")

3. (28, 28) 형태의 이미지를 plt를 이용하여 출력

# Draw the first training image on a fresh figure with a colour scale next
# to it, then print the label that belongs to that image.
fig, ax = plt.subplots()
first_image = ax.imshow(train_images[0])
fig.colorbar(first_image, ax=ax)
# grid() without arguments toggles grid-line visibility (see matplotlib docs).
ax.grid()
plt.show()
print(train_labels[0])

4. train_images의 첫 번째 이미지에서 0이 아닌 값들 중 앞의 10개를 출력

# Flatten the first training image and print its first ten non-zero values.
first_image_flat = train_images[0].reshape(-1)
print(list(first_image_flat[first_image_flat != 0][:10]))

5. train_images의 dtype을 출력

# Print the element type of each array, one per line, in the order
# train images, train labels, test images, test labels.
for dataset in (train_images, train_labels, test_images, test_labels):
    print(dataset.dtype)

6. train/test 이미지 데이터의 범위 확인

# For the first train image, all train labels, the first test image and all
# test labels (in that order): flatten, drop zeros, print the first ten values.
for values in (train_images[0], train_labels, test_images[0], test_labels):
    flat = values.reshape(-1)
    print(list(flat[flat != 0][:10]))

7. train/test 이미지 데이터의 최소/최대값을 출력

# Print the largest and smallest value of each image array (train first,
# then test). The ndarray reductions scan the data in native code; the
# builtin max()/min() would walk every element through the interpreter.
print(train_images.max(), train_images.min())
print(test_images.max(), test_images.min())

8. 정수형을 실수형으로 변경 후 dtype으로 비교

# Rebind both image arrays to float64 copies of themselves.
train_images, test_images = (
    images.astype(np.float64) for images in (train_images, test_images)
)

9. 데이터 0-1 노말라이즈 수행

# Divide every pixel by 255 so values land in the 0-1 range
# (assumes the raw intensities span 0-255 -- confirm with the min/max check).
PIXEL_SCALE = 255
train_images, test_images = train_images / PIXEL_SCALE, test_images / PIXEL_SCALE

10. 0-1 노말라이즈 후 데이터의 값이 변경되었는지 문제 6, 7의 방법을 이용하여 확인

# Repeat the non-zero sample check: first ten non-zero values of the first
# train image, the train labels, the first test image and the test labels.
for values in (train_images[0], train_labels, test_images[0], test_labels):
    flat = values.reshape(-1)
    print(list(flat[flat != 0][:10]))

# Then print the shapes and the dtypes of all four arrays, one line each.
all_arrays = (train_images, train_labels, test_images, test_labels)
print(*(array.shape for array in all_arrays))
print(*(array.dtype for array in all_arrays))