In [0]:
#from __future__ import absolute_import, division, print_function, unicode_literals
#!pip install -q tensorflow-gpu==2.0.0-rc1

Variational Autoencoder

대체 텍스트

케라스 및 기타 라이브러리 가져오기

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import os

from keras.datasets import mnist
from keras.layers import Input, Lambda, Dense
from keras.models import Model
from keras import backend as K
from keras.utils import plot_model
from keras.losses import mse, binary_crossentropy
Using TensorFlow backend.

The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.
We recommend you upgrade now or ensure your notebook will continue to use TensorFlow 1.x via the %tensorflow_version 1.x magic: more info.

In [0]:
# network parameters
reconstruct_dim=784
input_shape = (reconstruct_dim,)
intermediate_dim = 512
batch_size = 64
latent_dim = 2
epochs = 50

MNIST 데이터 셋 로드

In [4]:
# Load the MNIST data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
 # 이미지 픽셀의 모든 값을 0~1 사이로 정규화
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
 # 28x28 이미지를 크기 784의 벡터로 만듭니다
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

print(x_train.shape,x_test.shape)
Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
11493376/11490434 [==============================] - 2s 0us/step
(60000, 784) (10000, 784)

VAE 모델생성

대체 텍스트

(1) Encoder 생성

In [5]:
# Z 샘플링 함수 생성
def sampling(args):
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean = 0 and std = 1.0
    # Reparameterization Trick사용을 위해 Gussian(=normal)분포에서 랜덤변수(sample) ε추출 
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
# 인코더 모델을 생성합니다.
# 입력단에는 Input shape를 지정해줍니다.
inputs = Input(shape=input_shape, name='input_shape')
encoder_hidden = Dense(intermediate_dim, activation='relu', name='encoder_hidden1')(inputs)
# 잠재변수Z에 사용할 평균(mean)과 표준편차(log_sigma) 정의합니다.  
z_mean = Dense(latent_dim, name='z_mean')(encoder_hidden)
z_log_sigma = Dense(latent_dim, name='z_log_sigma')(encoder_hidden)
# Lambda단에는 래핑에 사용할 함수와 출력 shape를 지정해줍니다.
z_sampling = Lambda(sampling, (latent_dim,), name='z')([z_mean, z_log_sigma])
# 인코더 모델은 다중 출력 모델이기 때문에 리스트를 사용하여 모델을 정의합니다.
encoder = Model(inputs,[z_mean,z_log_sigma,z_sampling], name='encoder')
# 인코더 네트워크를 확인합니다.
encoder.summary()
plot_model(encoder, to_file='vae_mlp_encoder.png', show_shapes=True)
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:541: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:4432: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:4409: The name tf.random_normal is deprecated. Please use tf.random.normal instead.

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_shape (InputLayer)        (None, 784)          0                                            
__________________________________________________________________________________________________
encoder_hidden1 (Dense)         (None, 512)          401920      input_shape[0][0]                
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 2)            1026        encoder_hidden1[0][0]            
__________________________________________________________________________________________________
z_log_sigma (Dense)             (None, 2)            1026        encoder_hidden1[0][0]            
__________________________________________________________________________________________________
z (Lambda)                      (None, 2)            0           z_mean[0][0]                     
                                                                 z_log_sigma[0][0]                
==================================================================================================
Total params: 403,972
Trainable params: 403,972
Non-trainable params: 0
__________________________________________________________________________________________________
Out[5]:

(2) Decoder 생성

In [6]:
# 디코더 모델을 생성합니다.
# Z가 입력으로 들어오기 때문에 Z의 shape를 Input shape로 지정해줍니다.
input_z = Input(shape=(latent_dim,), name='input_z')
decoder_hidden = Dense(intermediate_dim, activation='relu', name='decoder_hidden')(input_z)
outputs = Dense(reconstruct_dim, activation='sigmoid',name='output')(decoder_hidden)

# 디코더 모델은 잠재변수 z가 입력으로 들어가고 복원값을 출력합니다. 
decoder = Model(input_z, outputs, name='decoder')

# 디코더의 네트워크를 확인합니다.
decoder.summary()
plot_model(decoder, to_file='vae_mlp_decoder.png', show_shapes=True)
Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_z (InputLayer)         (None, 2)                 0         
_________________________________________________________________
decoder_hidden (Dense)       (None, 512)               1536      
_________________________________________________________________
output (Dense)               (None, 784)               402192    
=================================================================
Total params: 403,728
Trainable params: 403,728
Non-trainable params: 0
_________________________________________________________________
Out[6]:

(3) VAE모델 생성

In [7]:
def vae_loss(x,recon_x):
    # (1)Reconstruct loss (Marginal_likelihood) : Cross-entropy 
    inputs, outputs = x, recon_x
    reconstruction_loss = binary_crossentropy(inputs,outputs)
    #reconstruction_loss = mse(inputs, outputs)
    reconstruction_loss *= reconstruct_dim
    # (2) KL divergence(Latent_loss)
    kl_loss = 0.5 * K.sum(K.square(z_mean)+ K.exp(z_log_sigma)- 2*z_log_sigma -1, 1)    
    return K.mean(reconstruction_loss + kl_loss) #ELBO(=VAE_loss)

# build VAE model
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae_mlp')
vae.compile(optimizer='adam', loss=vae_loss)
vae.summary()
plot_model(vae,to_file='vae_mlp.png',show_shapes=True)
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/optimizers.py:793: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3657: The name tf.log is deprecated. Please use tf.math.log instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "vae_mlp"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_shape (InputLayer)     (None, 784)               0         
_________________________________________________________________
encoder (Model)              [(None, 2), (None, 2), (N 403972    
_________________________________________________________________
decoder (Model)              (None, 784)               403728    
=================================================================
Total params: 807,700
Trainable params: 807,700
Non-trainable params: 0
_________________________________________________________________
Out[7]:

모델학습시작

In [8]:
history = vae.fit(x_train, x_train,
        shuffle=True,
        epochs=20,
        batch_size=batch_size,
        validation_data=(x_test, x_test))

#학습된 VAE 모델 저장 
vae.save_weights('vae_mlp_mnist.h5')
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:1033: The name tf.assign_add is deprecated. Please use tf.compat.v1.assign_add instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:1020: The name tf.assign is deprecated. Please use tf.compat.v1.assign instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3005: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.

Train on 60000 samples, validate on 10000 samples
Epoch 1/20
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:190: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:197: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:207: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:216: The name tf.is_variable_initialized is deprecated. Please use tf.compat.v1.is_variable_initialized instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:223: The name tf.variables_initializer is deprecated. Please use tf.compat.v1.variables_initializer instead.

60000/60000 [==============================] - 13s 221us/step - loss: 187.5353 - val_loss: 172.3581
Epoch 2/20
60000/60000 [==============================] - 8s 137us/step - loss: 170.4814 - val_loss: 168.9592
Epoch 3/20
60000/60000 [==============================] - 8s 137us/step - loss: 167.3968 - val_loss: 166.3576
Epoch 4/20
60000/60000 [==============================] - 8s 137us/step - loss: 165.1836 - val_loss: 164.3051
Epoch 5/20
60000/60000 [==============================] - 8s 138us/step - loss: 163.4123 - val_loss: 162.6355
Epoch 6/20
60000/60000 [==============================] - 8s 138us/step - loss: 161.9806 - val_loss: 161.8190
Epoch 7/20
60000/60000 [==============================] - 8s 136us/step - loss: 160.9835 - val_loss: 160.8684
Epoch 8/20
60000/60000 [==============================] - 8s 140us/step - loss: 160.2054 - val_loss: 160.0840
Epoch 9/20
60000/60000 [==============================] - 8s 139us/step - loss: 159.5080 - val_loss: 159.7317
Epoch 10/20
60000/60000 [==============================] - 8s 136us/step - loss: 158.9402 - val_loss: 159.4449
Epoch 11/20
60000/60000 [==============================] - 8s 137us/step - loss: 158.4187 - val_loss: 158.8626
Epoch 12/20
60000/60000 [==============================] - 8s 137us/step - loss: 158.0519 - val_loss: 158.3506
Epoch 13/20
60000/60000 [==============================] - 8s 137us/step - loss: 157.5735 - val_loss: 158.3949
Epoch 14/20
60000/60000 [==============================] - 8s 137us/step - loss: 157.2698 - val_loss: 157.7272
Epoch 15/20
60000/60000 [==============================] - 8s 137us/step - loss: 156.9556 - val_loss: 157.8739
Epoch 16/20
60000/60000 [==============================] - 8s 136us/step - loss: 156.6503 - val_loss: 157.4800
Epoch 17/20
60000/60000 [==============================] - 8s 137us/step - loss: 156.3112 - val_loss: 157.1618
Epoch 18/20
60000/60000 [==============================] - 8s 138us/step - loss: 156.0790 - val_loss: 157.3379
Epoch 19/20
60000/60000 [==============================] - 8s 137us/step - loss: 155.8171 - val_loss: 157.0190
Epoch 20/20
60000/60000 [==============================] - 8s 138us/step - loss: 155.5948 - val_loss: 156.7502

학습 그래프 확인

In [9]:
def plt_hist(hist):
    # summarize history for loss
    plt.plot(hist.history['loss'])
    plt.plot(hist.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper right')
    
plt_hist(history)

원본이미지와 복원이미지 비교

In [10]:
recon_x_test = vae.predict(x_test)

n = 10  # how many digits we will display
plt.figure(figsize=(15, 4))
for i in range(10):
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28), vmin=0, vmax=1, cmap="gray")
    plt.title("Input"+str(i))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    #plt.colorbar()

    # display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(recon_x_test[i].reshape(28, 28),vmin=0, vmax=1, cmap="gray")
    plt.title("Recon"+str(i))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    #plt.colorbar()
plt.show()

2D공간에 잠재된 데이터 출력

In [0]:
"""Plots labels and MNIST digits as a function of the 2D latent vector
# Arguments
    models (tuple): encoder and decoder models
    data (tuple): test data and label
    batch_size (int): prediction batch size
    model_name (string): which model is using this function
"""

# 학습모델이 생성한 Manifold를 plot하는 함수 정의
def plot_results(models,
                 data,
                 batch_size=128,
                 model_name="vae_mnist"):
    encoder, decoder = models
    x_test, y_test = data
    filename = "digits_over_latent.png"
    # display a 30x30 2D manifold of digits
    n = 30
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-2, 2, n)
    grid_y = np.linspace(-2, 2, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size,
                   j * digit_size: (j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    start_range = digit_size // 2
    end_range = (n - 1) * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.savefig(filename)
    plt.show()
In [12]:
plot_results(models = (encoder, decoder),
                 data = (x_test, y_test),
                 batch_size=batch_size,
                 model_name="vae_mlp")
In [17]:
# display a 30x30 2D manifold of digits
n = 10
digit_size =14
figure = np.zeros((digit_size * n, digit_size * n))
# linearly spaced coordinates corresponding to the 2D plot
# of digit classes in the latent space
grid_x = np.linspace(-2, 2, n)
grid_y = np.linspace(-2, 2, n)[::-1]

canvas = np.empty((28, 28*n))
for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi]])
        x_hat = decoder.predict(z_sample)
        canvas[:, i*28: (i+1)*28] = x_hat[0].reshape(28, 28)
fig, ax = plt.subplots(figsize=(10,10))             
plt.title("interpolation")
ax.imshow(canvas, cmap="gray")
Out[17]:
<matplotlib.image.AxesImage at 0x7f83a4724b00>
In [0]: