[tensrolfow] 텐서플로 데이터 증대 ImageDataGeneration 사용

전체 코드


증대된 영상 확인

from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

#CIFAR 10 부류
class_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

#데이터를 신경망에 입력할 형태로 변환
(x_train, y_train),(x_test,y_test) = cifar10.load_data()
x_train = x_train.astype('float32')/255.0
x_train=x_train[0:12,]; y_train=y_train[0:12,] #앞 12개만 가져옴

#앞 12개 영상 그리기
plt.figure(figsize=(16,2))
plt.suptitle("First 12 images in the train set")

for i in range(12):
    plt.subplot(1,12,i+1)
    plt.imshow(x_train[i])
    plt.xticks([]); plt.yticks([])
    plt.title(class_names[int(y_train[i])])

#영상 증대기
batch_siz = 6 #한번 생성 양
generator = ImageDataGenerator(rotation_range=30.0, width_shift_range=0.2, height_shift_range=0.2, horizontal_flip=True)
gen = generator.flow(x_train, y_train, batch_size= batch_siz)

#첫번째 증대하고그리기
img,label = gen.next()
plt.figure(figsize=(16,3))
plt.suptitle('Generator trial 1')
for i in range(batch_siz):
    plt.subplot(1, batch_siz, i+1)
    plt.imshow(img[i])
    plt.xticks([]); plt.yticks([])
    plt.title(class_names[int(label[i])])


#두번째 증대하고그리기
img,label = gen.next()
plt.figure(figsize=(16,3))
plt.suptitle('Generator trial 2')
for i in range(batch_siz):
    plt.subplot(1, batch_siz, i+1)
    plt.imshow(img[i])
    plt.xticks([]); plt.yticks([])
    plt.title(class_names[int(label[i])])

맨 윗줄의 원본 데이터를 가공한 결과를 보여줌

 

 

ImageDataGerator로 생성한 이미지로 학습하기

#신경망 모델 학습(영상 증대기 활용)
cnn.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
batch_size=128
generator= ImageDataGenerator(width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True)
hist = cnn.fit_generator(generator.flow(x_train,y_train, batch_size=batch_siz), epochs=50, validation_data=(x_test, y_test), verbose=1)

 

 

코드 해설


증대된 영상 확인

ImageDataGenerator는 데이터에 있는 샘플을 무작위로 골라 변형을 시도한다.
generator = ImageDataGenerator(rotation_range=30.0, width_shift_range=0.2, height_shift_range=0.2, horizontal_flip=True)

변형 방식을 설정하는 함수이다. 매개변수 4개를 통해, 30도 이내에서 회전하고, 가로/세로 방향 20%이내에서 이동하고, 좌우 반전을 시도하라고 지시한다. 

gen = generator.flow(x_train, y_train, batch_size= batch_siz)

x_train, y_train에 있는 데이터를 한번에 batch_siz크기만큼 뽑아 변형하라는 지시

img,label = gen.next()

next함수를 호출할 때마다 변형된 샘플이 6개씩 생성된다. 

ImageDataGenerator함수의 API

 

 

ImageDataGerator로 생성한 이미지로 학습하기

generator= ImageDataGenerator(width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True)

ImageDataGenerator함수로 영상을 변형하는 방식 설정, 수평으로 0.1, 수직으로 0.1 범위 안에서 이동을 허용하고, 좌우 반전을 허용했다.

hist = cnn.fit_generator(generator.flow(x_train,y_train, batch_size=batch_siz), epochs=50, validation_data=(x_test, y_test), verbose=1)

여기를 주목!

데이터 증대를 활용할 때는 fit함수가 아닌 fit_generator를 사용해야 한다. 

이 함수에선 flow함수를 써서 학습 도중에 실시간으로 변형된 샘플이 생성되도록 설정한다.

 

 

 

 


참고교재 파이썬으로 만드는 인공지능 - 한빛미디어