In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Embedding, Flatten, multiply
from tensorflow.keras.models import Model
import time

from IPython import display

print(tf.__version__)
2.0.0

Loading the images and setting the parameters

In [2]:
# Load the MNIST dataset
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
# Normalize the images to the range [-1, 1].
train_images = (train_images - 127.5) / 127.5 

BUFFER_SIZE = 60000
BATCH_SIZE = 200
latent_dim = 100
num_classes = 10
img_shape = (28,28,1)

# Create batched datasets for the images and the labels.
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).batch(BATCH_SIZE)
label_dataset = tf.data.Dataset.from_tensor_slices(train_labels).batch(BATCH_SIZE)
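The images and labels are batched as two separate datasets here and zipped together during training. A minimal alternative sketch (same BUFFER_SIZE and BATCH_SIZE, not used below) that yields (image, label) pairs from a single dataset and reshuffles each epoch:

# Sketch only: one dataset that yields (image, label) batches, reshuffled every epoch.
combined_dataset = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)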

Building the generator network

In [3]:
def make_generator_model():
    model = tf.keras.Sequential()
        
    model.add(layers.Dense(256, input_dim=latent_dim))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.Dense(1024))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.Dense(np.prod(img_shape), activation='tanh'))
    model.add(layers.Reshape(img_shape))

    model.summary()
    
    noise = Input(shape=(latent_dim,))
    label = Input(shape=(1,), dtype='int32')
    
    # Embed the label so it matches the dimensionality of the noise vector
    label_embedding = Flatten()(Embedding(num_classes, latent_dim)(label))

    # Combine the noise and the label embedding (element-wise multiply) to form the model input
    model_input = multiply([noise, label_embedding])
    img = model(model_input)

    return Model([noise, label], img)
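The conditioning works by embedding the integer label into a latent_dim-sized vector and multiplying it element-wise with the noise. A standalone shape check (a sketch, not part of the model above; the label value 3 is arbitrary):

# The embedding output has shape (batch, 1, latent_dim); Flatten turns it into
# (batch, latent_dim), the same shape as the noise, so an element-wise multiply works.
emb_layer = layers.Embedding(num_classes, latent_dim)
sample_label = tf.constant([[3]])
print(emb_layer(sample_label).shape)   # (1, 1, 100)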
In [4]:
# Create random noise, pass it (with a label) through the generator, and display the resulting image

generator = make_generator_model()

noise = tf.random.normal([1, 100])

# Condition on digit class 1; the label is an integer array of shape (1, 1)
generated_image = generator([noise, np.array([[1]])], training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 256)               25856     
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 256)               0         
_________________________________________________________________
batch_normalization (BatchNo (None, 256)               1024      
_________________________________________________________________
dense_1 (Dense)              (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
_________________________________________________________________
dense_2 (Dense)              (None, 1024)              525312    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_3 (Dense)              (None, 784)               803600    
_________________________________________________________________
reshape (Reshape)            (None, 28, 28, 1)         0         
=================================================================
Total params: 1,493,520
Trainable params: 1,489,936
Non-trainable params: 3,584
_________________________________________________________________
Out[4]:
<matplotlib.image.AxesImage at 0x7f71e83e5f60>

Building the discriminator network

In [5]:
def make_discriminator_model():
    model = tf.keras.Sequential()
        
    model.add(layers.Dense(512, input_dim=np.prod(img_shape)))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dropout(0.4))
    model.add(layers.Dense(256))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dropout(0.4))
    model.add(layers.Dense(1, activation='sigmoid'))
    model.summary()
    
    # Create the image and label input tensors
    img = Input(shape=img_shape)
    label = Input(shape=(1,), dtype='int32')

    # Embed the label into a dense vector of size 28*28 = 784 and flatten it
    label_embedding = Flatten()(Embedding(num_classes, np.prod(img_shape))(label))
    # Flatten the image as well
    flat_img = Flatten()(img)

    # Combine the flattened image with the label embedding (element-wise multiply)
    model_input = multiply([flat_img, label_embedding])
    
    # Feeding the conditioned input to the model yields the validity score
    validity = model(model_input)

    return Model([img, label], validity)
In [27]:
# Feed the generated image (with a sampled label) into the discriminator to get a validity score
discriminator = make_discriminator_model()
# Class labels must be integers in [0, num_classes); sampling from a normal distribution
# would produce out-of-range indices for the Embedding layer.
sampled_labels = np.random.randint(0, num_classes, 1).reshape(-1, 1)
decision = discriminator([generated_image, sampled_labels])
print (decision[0])
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_8 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_9 (Dense)              (None, 512)               262656    
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_10 (Dense)             (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_11 (Dense)             (None, 1)                 257       
=================================================================
Total params: 796,161
Trainable params: 796,161
Non-trainable params: 0
_________________________________________________________________
tf.Tensor([0.49998978], shape=(1,), dtype=float32)
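The discriminator conditions on the label in the same way, except the embedding is 784-dimensional so it can scale the flattened 28x28 image. A standalone sketch of that conditioned input (the fresh Embedding layer and the label 7 are illustrative, not the ones inside the model):

# Flattened image and label embedding are both (1, 784); their element-wise
# product is what the Sequential classifier actually sees.
flat_img = tf.reshape(generated_image, [1, -1])
label_emb = layers.Embedding(num_classes, np.prod(img_shape))(tf.constant([[7]]))
conditioned = flat_img * tf.reshape(label_emb, [1, -1])
print(conditioned.shape)   # (1, 784)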
In [7]:
# Helper for computing the binary cross-entropy loss.
# The discriminator ends with a sigmoid activation, so it outputs probabilities rather
# than logits; from_logits must therefore be False.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

Computing the losses

In [8]:
# Train the discriminator to output 1 for real images and 0 for fake images
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
In [9]:
# Train the generator so that the discriminator outputs 1 for fake images
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)
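A quick sanity check of both losses on dummy discriminator outputs (a sketch; 0.9 and 0.1 are arbitrary probabilities):

# A confident discriminator (real -> 0.9, fake -> 0.1) gives a small discriminator
# loss, while the generator loss is large because its fakes are being rejected.
real_output = tf.constant([[0.9]])
fake_output = tf.constant([[0.1]])
print(discriminator_loss(real_output, fake_output).numpy())
print(generator_loss(fake_output).numpy())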
In [10]:
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
In [25]:
# Hyperparameter settings
EPOCHS = 500
noise_dim = 100
num_examples_to_generate = 10

# Fixed random noise, reused for visualizing training progress
seed = tf.random.normal([num_examples_to_generate, noise_dim])

Training functions

In [12]:
# The `tf.function` decorator "compiles" the function into a TensorFlow graph.
@tf.function
def train_step(images, labels):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        
        # Generate fake images from noise, conditioned on the batch labels
        generated_images = generator([noise, labels], training=True)
        
        # Get discriminator scores for the real and the fake images
        real_output = discriminator([images, labels], training=True)
        fake_output = discriminator([generated_images, labels], training=True)

        # Generator loss: push the discriminator to score fake images as real (1)
        gen_loss = generator_loss(fake_output)
        # Discriminator loss: sum of the real-image loss and the fake-image loss
        disc_loss = discriminator_loss(real_output, fake_output)

    # Compute gradients from the operations recorded on the tapes
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
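Before launching the full loop, train_step can be smoke-tested on a single batch (a sketch mirroring how the training loop calls it):

# Run one image/label batch through train_step to confirm shapes and dtypes line up.
for image_batch, label_batch in zip(train_dataset.take(1), label_dataset.take(1)):
    train_step(image_batch, label_batch)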
In [24]:
def train(dataset, labeldataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    # Run a training step on each batch of images and labels.
    for image_batch, label in zip(dataset, labeldataset):
      train_step(image_batch, label)

    # Output sample images every 100 epochs
#     display.clear_output(wait=True)
    if epoch % 100 == 0:
        generate_and_save_images(generator,epoch + 1,seed)
    
    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate images after the final epoch.
  display.clear_output(wait=True)
  generate_and_save_images(generator,epochs,seed)
In [18]:
# Generate and display images from the fixed noise, one per digit class
def generate_and_save_images(model, epoch, test_input):
    # With `training=False`, all layers (including batch normalization) run in inference mode.

    r, c = 2, 5
    sampled_labels = np.arange(0, 10).reshape(-1, 1)

    gen_imgs = model([test_input,sampled_labels], training=False)

    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r,c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt,:,:,0])
            axs[i,j].set_title("Digit: %d" %sampled_labels[cnt])
            axs[i,j].axis('off')
            cnt += 1
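generate_and_save_images can also be called directly, e.g. to preview the still-untrained generator with the fixed evaluation noise (sketch):

# Plot digits 0-9 from the current generator weights; the epoch argument (0) is only a label here.
generate_and_save_images(generator, 0, seed)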
In [26]:
%%time
train(train_dataset, label_dataset, EPOCHS)
CPU times: user 34min 43s, sys: 5min 45s, total: 40min 28s
Wall time: 20min 7s