[tensorflow] 생성 적대 신경망의 구현 + 코드 해설 (MNIST 이용)

MNIST 데이터를 이용해 생성 적대 신경망을 구현하는 코드이다.

생성망과 분별망을 설계하고 학습하며 생성망에서 진짜 같은 가짜 샘플을 출력해내도록 하는 것이 목표!

 

전체 코드


import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Activation, Dense, Flatten, Reshape, Conv2D, Conv2DTranspose, Dropout, BatchNormalization, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.losses import mse
import matplotlib.pyplot as plt

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = (x_train.astype('float32')/255.0)*2.0-1.0 #[-1,1]구간
x_test = (x_test.astype('float32')/255.0)*2.0-1.0 
x_train = np.reshape(x_train, (len(x_train), 28,28,1))
x_test = np.reshape(x_test, (len(x_test), 28,28,1))

batch_siz = 64
epochs = 5000
dropout_rate = 0.4
batch_norm = 0.9
zdim = 100 #잠복공간의 차원

discriminator_input = Input(shape=(28,28,1)) #분별망 D 설계
x= Conv2D(64,(5,5), activation = 'relu', padding = 'same', strides = (2,2))(discriminator_input)
x = Dropout(dropout_rate)(x)
x=Conv2D(64,(5,5), activation='relu', padding = 'same', strides=(2,2))(x)
x = Dropout(dropout_rate)(x)
x=Conv2D(128,(5,5), activation='relu', padding = 'same', strides=(2,2))(x)
x = Dropout(dropout_rate)(x)
x=Conv2D(128,(5,5), activation='relu', padding = 'same', strides=(1,1))(x)
x = Dropout(dropout_rate)(x)
x=Flatten()(x)
discriminator_output = Dense(1, activation = 'sigmoid')(x)
discriminator = Model(discriminator_input, discriminator_output)

generator_input = Input(shape=(zdim,))
x=Dense(3136)(generator_input)
x=BatchNormalization(momentum=batch_norm)(x)
x=Activation('relu')(x)
x=Reshape((7,7,64))(x)
x=UpSampling2D()(x)
x=Conv2D(128,(5,5),padding='same')(x)
x=BatchNormalization(momentum=batch_norm)(x)
x=Activation('relu')(x)
x=UpSampling2D()(x)
x=Conv2D(64,(5,5),padding='same')(x)
x=BatchNormalization(momentum=batch_norm)(x)
x=Activation('relu')(x)
x=Conv2D(64,(5,5),padding='same')(x)
x=BatchNormalization(momentum=batch_norm)(x)
x=Activation('relu')(x)
x=Conv2D(1,(5,5), activation='tanh',padding='same')(x)
generator_output= x
generator = Model(generator_input, generator_output)

discriminator.compile(optimizer='Adam', loss='binary_crossentropy', metrics = ['accuracy'])

discriminator.trainable = False
gan_input = Input(shape=(zdim,))
gan_output = discriminator(generator(gan_input))
gan = Model(gan_input, gan_output)
gan.compile(optimizer='Adam', loss='binary_crossentropy', metrics =['accuracy'])

def train_discriminator(x_train):
    c= np.random.randint(0, x_train.shape[0], batch_siz)
    real = x_train[c]
    discriminator.train_on_batch(real, np.ones((batch_siz,1)))

    p=np.random.normal(0,1, (batch_siz,zdim))
    fake = generator.predict(p)
    discriminator.train_on_batch(fake, np.zeros((batch_siz,1)))

def train_generator():
    p=np.random.normal(0,1, (batch_siz,zdim))
    gan.train_on_batch(p, np.ones((batch_siz,1)))

for i in range(epochs+1): # 학습을 수행
    train_discriminator(x_train)
    train_generator()
    if(i%100==0): # 학습 도중 100세대마다 중간 상황 출력
        plt.figure(figsize=(20, 4))
        plt.suptitle('epoch '+str(i))
        for k in range(20):
            plt.subplot(2,10,k+1)
            img=generator.predict(np.random.normal(0,1,(1,zdim)))
            plt.imshow(img[0].reshape(28,28),cmap='gray')
            plt.xticks([]); plt.yticks([])
        plt.show()
    
imgs = generator.predcit(np.random.normal(0,1,(50,zdim)))
plt.figure(figsize=(20,10))
for i in range(50):
    plt.subplot(5,10,i+1)
    plt.imshow(imgs[i].reshape(28,28), cmap='gray')
    plt.xticks([]);plt.yticks([])

 

실행 결과


학습 중간 결과

100세대마다 출력하게 했지만, 길어져서 1000세대 마다의 결과사진 업로드함

샘플 50개 생성

 

  • 데이터가 긴 획으로 구성된다는 사실 학습
  • 반 정도는 제대로 된 패턴, 나머지는 상당히 왜곡됨
  • 획이 두꺼운 패턴 , 얇은 패턴 등 다양한 패턴을 생성했음

 

코드 해설


📢 프로그램에 필요한 라이브러리

import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Activation, Dense, Flatten, Reshape, Conv2D, Conv2DTranspose, Dropout, BatchNormalization, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.losses import mse
import matplotlib.pyplot as plt

 

데이터 준비

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = (x_train.astype('float32')/255.0)*2.0-1.0 #[-1,1]구간
x_test = (x_test.astype('float32')/255.0)*2.0-1.0 
x_train = np.reshape(x_train, (len(x_train), 28,28,1))
x_test = np.reshape(x_test, (len(x_test), 28,28,1))

MNIST 데이터를 신경망에 입력할 수 있는 상태로 변환한다.

화소의 값을 [-1,1]사이로 정규화한다.

(x_train.astype('float32')/255.0) 까지 했으면 [0,1] 구간의 값이 되고
여기에 *2 -1을 진행해 구간을 [-1,1]로 만든다.
[0,1]로 정규화해도 무방하지만 생성망 출력층의 활성함수를 
[-1,1]로 정규화한 경우에는 tanh로, [0,1]의 경우 sigmoid로 설정해 값의 범위를 맞추면 된다.

 

하이퍼 파라미터 설정

batch_siz = 64
epochs = 5000
dropout_rate = 0.4
batch_norm = 0.9
zdim = 100 #잠복공간의 차원

 

분별망 구축

discriminator_input = Input(shape=(28,28,1)) #분별망 D 설계
x= Conv2D(64,(5,5), activation = 'relu', padding = 'same', strides = (2,2))(discriminator_input)
x = Dropout(dropout_rate)(x)
x=Conv2D(64,(5,5), activation='relu', padding = 'same', strides=(2,2))(x)
x = Dropout(dropout_rate)(x)
x=Conv2D(128,(5,5), activation='relu', padding = 'same', strides=(2,2))(x)
x = Dropout(dropout_rate)(x)
x=Conv2D(128,(5,5), activation='relu', padding = 'same', strides=(1,1))(x)
x = Dropout(dropout_rate)(x)
x=Flatten()(x)
discriminator_output = Dense(1, activation = 'sigmoid')(x)
discriminator = Model(discriminator_input, discriminator_output)

분별망을 구축하는 부분이다.

코드를 자세히 살펴보자.

 

x= Conv2D(64,(5,5), activation = 'relu', padding = 'same', strides = (2,2))(discriminator_input)
x = Dropout(dropout_rate)(x)
x=Conv2D(64,(5,5), activation='relu', padding = 'same', strides=(2,2))(x)
x = Dropout(dropout_rate)(x)
x=Conv2D(128,(5,5), activation='relu', padding = 'same', strides=(2,2))(x)
x = Dropout(dropout_rate)(x)
x=Conv2D(128,(5,5), activation='relu', padding = 'same', strides=(1,1))(x)
x = Dropout(dropout_rate)(x)

컨볼루션층마다 드롭아웃층을 덧붙이고, 컨볼루션층을 보폭을 이용해 맵 크기를 줄여가면서 맵의 개수는 늘린다.

discriminator_output = Dense(1, activation = 'sigmoid')(x)

출력층을 샇는데, 노드는 하나이고 활성함수로 [0,1] 사이의 값을 출력하는 sigmoid 사용함

분별망은 가짜는 0, 진짜는 1을 출력해야 하므로 sigmoid 활성 함수를 사용함

 

생성망 구축

generator_input = Input(shape=(zdim,))
x=Dense(3136)(generator_input)
x=BatchNormalization(momentum=batch_norm)(x)
x=Activation('relu')(x)
x=Reshape((7,7,64))(x)
x=UpSampling2D()(x)
x=Conv2D(128,(5,5),padding='same')(x)
x=BatchNormalization(momentum=batch_norm)(x)
x=Activation('relu')(x)
x=UpSampling2D()(x)
x=Conv2D(64,(5,5),padding='same')(x)
x=BatchNormalization(momentum=batch_norm)(x)
x=Activation('relu')(x)
x=Conv2D(64,(5,5),padding='same')(x)
x=BatchNormalization(momentum=batch_norm)(x)
x=Activation('relu')(x)
x=Conv2D(1,(5,5), activation='tanh',padding='same')(x)
generator_output= x
generator = Model(generator_input, generator_output)

생성망을 설계하는 코드이다.

코드를 자세히 살펴보자.

 

x=Conv2D(128,(5,5),padding='same')(x)
x=BatchNormalization(momentum=batch_norm)(x)
x=Activation('relu')(x)
x=UpSampling2D()(x)
x=Conv2D(64,(5,5),padding='same')(x)
x=BatchNormalization(momentum=batch_norm)(x)
x=Activation('relu')(x)
x=Conv2D(64,(5,5),padding='same')(x)
x=BatchNormalization(momentum=batch_norm)(x)
x=Activation('relu')(x)

컨볼루션 연산 바로 다음에 배치 정규화(batch_normalization)을 적용함

Upsampling
downsampling은 차원이 축소되어 Dense한 데이터를 만드는 것이고,
이것을 다시 해독해서 원본의 sparse한 데이터를 만드는 것이 upsampling
(padding의 반대 개념이라고 생각하자!)

Batch normalization
학습 과정에서 각 배치 단위 별로 데이터가 다양한 분포를 가지더라도 각 배치별로 평균과 분산을 이용해 정규화
x=Conv2D(1,(5,5), activation='tanh',padding='same')(x)

출력층을 쌓는데, 데이터 준비단계에서 화소의 값을 [-1,1]로 정규화했기 때문에 활성함수로 tanh 설정

 

학습에 필요한 사항 준비

discriminator.compile(optimizer='Adam', loss='binary_crossentropy', metrics = ['accuracy'])

분별망을 위한 코드!

옵티마이저로 Adma을 쓰고 손실 함수로 이진 교차 엔트로피 사용

 

discriminator.trainable = False
gan_input = Input(shape=(zdim,))
gan_output = discriminator(generator(gan_input))
gan = Model(gan_input, gan_output)
gan.compile(optimizer='Adam', loss='binary_crossentropy', metrics =['accuracy'])

생성망을 위한 코드!

코드를 자세히 살펴보자

 

discriminator.trainable = False

분별망의 가중치를 동결함

gan_input = Input(shape=(zdim,))

잠복공간의 벡터를 모델의 입력으로 지정

gan_output = discriminator(generator(gan_input))

입력을 생성망에 통과시켜 샘플을 생성하고 (generator(gan_input))

생성된 샘플을 분별망에 통과시켜 출력을 만든다.

gan = Model(gan_input, gan_output)

Model함수를 이용해 모델을 생성한다.

gan.compile(optimizer='Adam', loss='binary_crossentropy', metrics =['accuracy'])

생성망 학습 준비

옵티마이저로 Adam을 스고 손실 함수로 이진 교차 엔트로피 사용

 

분별망 학습

def train_discriminator(x_train):
    c= np.random.randint(0, x_train.shape[0], batch_siz)
    real = x_train[c]
    discriminator.train_on_batch(real, np.ones((batch_siz,1)))

    p=np.random.normal(0,1, (batch_siz,zdim))
    fake = generator.predict(p)
    discriminator.train_on_batch(fake, np.zeros((batch_siz,1)))

코드를 자세히 살펴보자.

 

    c= np.random.randint(0, x_train.shape[0], batch_siz)
    real = x_train[c]
    discriminator.train_on_batch(real, np.ones((batch_siz,1)))

훈련집합 x_train에서 batch_siz만큼 랜덤하게 샘플을 뽑아 진짜 샘플집합 real을 만듦

이때 레이블은 1을 붙여 학습을 진행한다.

    p=np.random.normal(0,1, (batch_siz,zdim))
    fake = generator.predict(p)
    discriminator.train_on_batch(fake, np.zeros((batch_siz,1)))

잠복공간에서 랜덤하게 batch_siz만큼 벡터를 생성해 가짜 샘플 집합 fake를 만듦

이때 레이블은 0을 붙여 학습을 진행한다.

 

생성망 학습

def train_generator():
    p=np.random.normal(0,1, (batch_siz,zdim))
    gan.train_on_batch(p, np.ones((batch_siz,1)))

잠복공간에서 랜덤하게 batch_siz만큼 생성해 p에 저장하고

p에 레이블 1을 붙여 모델 gan을 학습함

 

✅ 메인 프로그램

for i in range(epochs+1): # 학습을 수행
    train_discriminator(x_train)
    train_generator()
    if(i%100==0): # 학습 도중 100세대마다 중간 상황 출력
        plt.figure(figsize=(20, 4))
        plt.suptitle('epoch '+str(i))
        for k in range(20):
            plt.subplot(2,10,k+1)
            img=generator.predict(np.random.normal(0,1,(1,zdim)))
            plt.imshow(img[0].reshape(28,28),cmap='gray')
            plt.xticks([]); plt.yticks([])
        plt.show()

epoch만큼 for문을 반복하면서 분별망 학습과 생성망 학습을 반복함

아래 if문은 학습하는 도중에 얼마나 잘되고 있는지 확인하기 위해 100세마다 샘플 20개를 출력하는 부분이다.

생성적대 신경망의 학습은 오래 걸리기 때문에 중간중간 확인하자!

 

샘플 생성

imgs = generator.predict(np.random.normal(0,1,(50,zdim)))
plt.figure(figsize=(20,10))
for i in range(50):
    plt.subplot(5,10,i+1)
    plt.imshow(imgs[i].reshape(28,28), cmap='gray')
    plt.xticks([]);plt.yticks([])

샘플 50개를 생성하고 출력함.

 

 

 


참고교재 파이썬으로 만드는 인공지능 - 한빛미디어