Variational Autoencoder: training with model.fit

Importing Keras and other libraries from TensorFlow 2.0

In [1]:
import tensorflow as tf
print(tf.__version__)
2.3.0
In [9]:
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2 as cv
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import shape,math
from tensorflow.keras import Input,layers,Model
from tensorflow.keras.losses import mse,binary_crossentropy
from tensorflow.keras.utils import plot_model

print(tf.__version__)
2.3.0

Loading the data

In [10]:
import os
def DataLoad(src):
    # Collect only the .jpg filenames in the source directory
    temp = []
    files = os.listdir(src)
    for f in files:
        if f.endswith('.jpg'):
            temp.append(f)
    print(len(temp))
    return temp

src = './data_3000/'
filenames = DataLoad(src)
3000
In [11]:
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
import os 

# Find the path to the unzipped data and paste it here (the trailing '/' is required)
src = './data_3000/'

# Display an image
def img_plot(img):
    plt.imshow(img)
    plt.show()

# Read an image as grayscale (cv.imread takes an IMREAD_* flag)
def img_read(src, file):
    img = cv.imread(src + file, cv.IMREAD_GRAYSCALE)
    return img

X,Y = [],[]
count = 0

# For each filename, read the image from the path
# and normalize it to the 0-1 range by dividing by 255 before appending it to X.
for name in filenames: 
    X.append(img_read(src,name)/255.)
    Y.append(float(name[:-4]))

# Convert the lists to NumPy arrays
X = np.asarray(X)
Y = np.asarray(Y)

# Inspect a few samples from X
for i in range(10):
    img = X[i]
    img_plot(img)
print('X_list shape:',np.shape(X),'Y_list shape:',np.shape(Y))
X_list shape: (3000, 56, 56) Y_list shape: (3000,)
In [12]:
# Split into train and test sets
x_train, x_test, y_train, y_test = train_test_split(X,Y, test_size=0.2, random_state=1,shuffle=True)
x_train = np.array(x_train)
x_test = np.array(x_test)


# Flatten each (56, 56) image into a vector of length 56*56 = 3136
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:]))).astype('float32')
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:]))).astype('float32')

print("Datasets_train_shape:{}  Datasets_test_shape:{}".format(np.shape(x_train),np.shape(x_test)))
Datasets_train_shape:(2400, 3136)  Datasets_test_shape:(600, 3136)

Building the VAE model

In [13]:
from tensorflow import shape,math
from tensorflow.keras import Input,layers,Model
from tensorflow.keras.losses import mse,binary_crossentropy

# network parameters
input_shape = np.shape(x_train[0])[0]
original_dim= input_shape
intermediate_dim = 512
latent_dim = 2

Creating the encoder

In [14]:
def encoder():
  # Create the encoder input layer.
  inputs = Input(shape=(input_shape,), name='input_shape')

  # Create the encoder hidden layer; it uses 512 units (intermediate_dim).
  encoder_hidden = layers.Dense(intermediate_dim, activation='relu', name='encoder_hidden1')(inputs)

  # Define the mean and log-variance layers.
  # Log variance is used instead of sigma because the network output can be negative, while sigma must always be positive.
  # Each layer has latent_dim (2) units.
  z_mean = layers.Dense(latent_dim, name='z_mean')(encoder_hidden)
  z_log_var = layers.Dense(latent_dim, name='z_log_var')(encoder_hidden)

  # Wrap the mean and log variance into a z-sampling layer.

  # Sampling function for z
  def sampling(args):
      z_mean, z_log_var = args
      batch = shape(z_mean)[0]
      dim = shape(z_mean)[1]

      # by default, random_normal has mean = 0 and std = 1.0
      # Reparameterization trick: draw a random sample epsilon from a Gaussian (normal) distribution
      epsilon = tf.random.normal(shape=(batch, dim))
      return z_mean + tf.math.exp(0.5 * z_log_var) * epsilon

  # Wrap the sampling function with layers.Lambda and specify the output shape (latent_dim,).
  z_sampling = layers.Lambda(sampling, (latent_dim,), name='z_sample')([z_mean, z_log_var])

  # Build the encoder model with one input and multiple outputs.
  return Model(inputs,[z_mean,z_log_var,z_sampling], name='encoder')

encoder = encoder()

# Check the encoder network.
encoder.summary()

from tensorflow.keras.utils import plot_model
plot_model(encoder, to_file='vae_mlp_encoder.png', show_shapes=True)
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_shape (InputLayer)        [(None, 3136)]       0                                            
__________________________________________________________________________________________________
encoder_hidden1 (Dense)         (None, 512)          1606144     input_shape[0][0]                
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 2)            1026        encoder_hidden1[0][0]            
__________________________________________________________________________________________________
z_log_var (Dense)               (None, 2)            1026        encoder_hidden1[0][0]            
__________________________________________________________________________________________________
z_sample (Lambda)               (None, 2)            0           z_mean[0][0]                     
                                                                 z_log_var[0][0]                  
==================================================================================================
Total params: 1,608,196
Trainable params: 1,608,196
Non-trainable params: 0
__________________________________________________________________________________________________
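
The Lambda layer above implements the reparameterization trick: z = z_mean + exp(0.5 * z_log_var) * epsilon, so the randomness comes only from epsilon and gradients can still flow through z_mean and z_log_var. A minimal standalone sketch (the tensor values below are made-up toy numbers, not encoder outputs):

# Toy stand-ins for one encoder output: batch of 2, latent_dim of 2
z_mean_demo = tf.constant([[0.0, 1.0], [2.0, -1.0]])
z_log_var_demo = tf.constant([[0.0, 0.0], [-2.0, 2.0]])  # log(sigma^2)

# Same computation as the z_sample Lambda layer
epsilon = tf.random.normal(shape=tf.shape(z_mean_demo))
z_demo = z_mean_demo + tf.exp(0.5 * z_log_var_demo) * epsilon
print(z_demo)  # one random draw per row, centered on z_mean_demo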

Creating the decoder

In [15]:
def decoder():
  
  # Create the decoder input layer (the decoder's input is the latent vector z).
  input_z = Input(shape=(latent_dim,), name='input_z')

  # Create the decoder hidden layer; like the encoder, it uses 512 units (intermediate_dim).
  decoder_hidden = layers.Dense(intermediate_dim, activation='relu', name='decoder_hidden')(input_z)

  # The decoder output layer has as many units as the encoder input vector (original_dim).
  outputs = layers.Dense(original_dim, activation='sigmoid',name='output')(decoder_hidden)

  return Model(input_z, outputs, name='decoder')

decoder = decoder()

# Check the decoder network.
decoder.summary()
plot_model(decoder, to_file='vae_mlp_decoder.png', show_shapes=True)
Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_z (InputLayer)         [(None, 2)]               0         
_________________________________________________________________
decoder_hidden (Dense)       (None, 512)               1536      
_________________________________________________________________
output (Dense)               (None, 3136)              1608768   
=================================================================
Total params: 1,610,304
Trainable params: 1,610,304
Non-trainable params: 0
_________________________________________________________________

Creating the VAE model

In [16]:
def vae():
  # The VAE takes an image as input, samples z through the encoder, and reconstructs the image through the decoder.
  inputs = Input(shape=(input_shape,), name='input_shape')
  outputs = decoder(encoder(inputs)[2]) #[0]:z_mean, [1]:z_log_var,[2]:z_sampling
  
  return Model(inputs,outputs, name='vae_mlp')

# Define the VAE model
model = vae()

# Check the model network
model.summary()
plot_model(model,to_file='vae_mlp.png',show_shapes=True)
Model: "vae_mlp"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_shape (InputLayer)     [(None, 3136)]            0         
_________________________________________________________________
encoder (Functional)         [(None, 2), (None, 2), (N 1608196   
_________________________________________________________________
decoder (Functional)         (None, 3136)              1610304   
=================================================================
Total params: 3,218,500
Trainable params: 3,218,500
Non-trainable params: 0
_________________________________________________________________

Training the model with model.fit

In [17]:
# Hyperparameters
num_epochs = 100
batch_size = 20
learning_rate = 1e-3
In [18]:
# Define the loss and the optimizer used for training
adam = tf.keras.optimizers.Adam(learning_rate=learning_rate)
def vae_loss(x, recon_x):
    # (1) Reconstruction loss (marginal likelihood): cross-entropy
    # recon_x is recomputed here so that z_mean and z_log_var are available for the KL term
    z_mean, z_log_var, z_sampling = encoder(x)
    recon_x = decoder(z_sampling)
    reconstruction_loss = binary_crossentropy(x, recon_x)
    #reconstruction_loss = mse(x, recon_x)
    reconstruction_loss *= original_dim
    # (2) KL divergence (latent loss)
    kl_loss = 0.5 * tf.reduce_sum(tf.square(z_mean) + tf.exp(z_log_var) - z_log_var - 1, 1)
    return tf.reduce_mean(reconstruction_loss + kl_loss)  # negative ELBO (= VAE loss)

model.compile(optimizer=adam,loss=vae_loss)
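
As a quick sanity check on the KL term above: when z_mean = 0 and z_log_var = 0 the approximate posterior equals the standard normal prior, so the KL divergence should be 0, and it grows as the mean or variance moves away from that. A minimal sketch with toy values (not model outputs):

def kl_term(z_mean, z_log_var):
    # Same closed-form KL(q(z|x) || N(0, I)) as in vae_loss
    return 0.5 * tf.reduce_sum(tf.square(z_mean) + tf.exp(z_log_var) - z_log_var - 1, 1)

print(kl_term(tf.zeros((1, 2)), tf.zeros((1, 2))))  # ~0: posterior matches the prior
print(kl_term(tf.ones((1, 2)), tf.zeros((1, 2))))   # > 0: mean shifted away from 0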
In [20]:
# Train the model (input and target are both x_train, since the VAE reconstructs its own input)
hist = model.fit(x_train, x_train, epochs=num_epochs, batch_size=batch_size)
# Save the trained VAE weights
model.save_weights('vae_bracket.h5')
Epoch 1/100
120/120 [==============================] - 1s 11ms/step - loss: 1451.2714
Epoch 2/100
120/120 [==============================] - 1s 12ms/step - loss: 1318.1204
Epoch 3/100
120/120 [==============================] - 1s 12ms/step - loss: 1291.7410
Epoch 4/100
120/120 [==============================] - 1s 12ms/step - loss: 1275.3472
Epoch 5/100
120/120 [==============================] - 2s 14ms/step - loss: 1264.1218
Epoch 6/100
120/120 [==============================] - 1s 12ms/step - loss: 1255.7401
Epoch 7/100
120/120 [==============================] - 1s 12ms/step - loss: 1249.8816
Epoch 8/100
120/120 [==============================] - 1s 12ms/step - loss: 1247.6521
Epoch 9/100
120/120 [==============================] - 2s 13ms/step - loss: 1243.4037
Epoch 10/100
120/120 [==============================] - 2s 13ms/step - loss: 1243.8610
Epoch 11/100
120/120 [==============================] - 1s 12ms/step - loss: 1238.7444
Epoch 12/100
120/120 [==============================] - 2s 13ms/step - loss: 1239.0159
Epoch 13/100
120/120 [==============================] - 2s 13ms/step - loss: 1238.7141
Epoch 14/100
120/120 [==============================] - 1s 12ms/step - loss: 1235.8914
Epoch 15/100
120/120 [==============================] - 2s 13ms/step - loss: 1235.8959
Epoch 16/100
120/120 [==============================] - 1s 12ms/step - loss: 1233.6572
Epoch 17/100
120/120 [==============================] - 1s 12ms/step - loss: 1231.0713
Epoch 18/100
120/120 [==============================] - 2s 13ms/step - loss: 1232.0835
Epoch 19/100
120/120 [==============================] - 1s 12ms/step - loss: 1231.8053
Epoch 20/100
120/120 [==============================] - 2s 13ms/step - loss: 1228.7012
Epoch 21/100
120/120 [==============================] - 1s 12ms/step - loss: 1230.1541
Epoch 22/100
120/120 [==============================] - 1s 12ms/step - loss: 1227.3281
Epoch 23/100
120/120 [==============================] - 1s 11ms/step - loss: 1226.1062
Epoch 24/100
120/120 [==============================] - 1s 12ms/step - loss: 1226.9886
Epoch 25/100
120/120 [==============================] - 1s 11ms/step - loss: 1223.7666
Epoch 26/100
120/120 [==============================] - 1s 11ms/step - loss: 1223.3406
Epoch 27/100
120/120 [==============================] - 1s 11ms/step - loss: 1221.9435
Epoch 28/100
120/120 [==============================] - 1s 11ms/step - loss: 1221.2096
Epoch 29/100
120/120 [==============================] - 1s 12ms/step - loss: 1223.0601
Epoch 30/100
120/120 [==============================] - 1s 12ms/step - loss: 1221.0608
Epoch 31/100
120/120 [==============================] - 1s 12ms/step - loss: 1220.4130
Epoch 32/100
120/120 [==============================] - 1s 12ms/step - loss: 1218.6760
Epoch 33/100
120/120 [==============================] - 1s 12ms/step - loss: 1218.1255
Epoch 34/100
120/120 [==============================] - 1s 12ms/step - loss: 1218.2906
Epoch 35/100
120/120 [==============================] - 1s 12ms/step - loss: 1215.5580
Epoch 36/100
120/120 [==============================] - 1s 12ms/step - loss: 1214.3503
Epoch 37/100
120/120 [==============================] - 1s 12ms/step - loss: 1215.1799
Epoch 38/100
120/120 [==============================] - 1s 12ms/step - loss: 1215.0861
Epoch 39/100
120/120 [==============================] - 1s 12ms/step - loss: 1214.6652
Epoch 40/100
120/120 [==============================] - 1s 12ms/step - loss: 1213.1713
Epoch 41/100
120/120 [==============================] - 1s 12ms/step - loss: 1212.8325
Epoch 42/100
120/120 [==============================] - 1s 12ms/step - loss: 1211.3782
Epoch 43/100
120/120 [==============================] - 1s 12ms/step - loss: 1212.6320
Epoch 44/100
120/120 [==============================] - 1s 12ms/step - loss: 1211.1757
Epoch 45/100
120/120 [==============================] - 1s 12ms/step - loss: 1211.4059
Epoch 46/100
120/120 [==============================] - 1s 12ms/step - loss: 1212.9310
Epoch 47/100
120/120 [==============================] - 2s 13ms/step - loss: 1208.0105
Epoch 48/100
120/120 [==============================] - 1s 11ms/step - loss: 1209.7262
Epoch 49/100
120/120 [==============================] - 1s 11ms/step - loss: 1209.1075
Epoch 50/100
120/120 [==============================] - 1s 11ms/step - loss: 1208.3210
Epoch 51/100
120/120 [==============================] - 1s 11ms/step - loss: 1210.5281
Epoch 52/100
120/120 [==============================] - 1s 11ms/step - loss: 1210.4613
Epoch 53/100
120/120 [==============================] - 2s 14ms/step - loss: 1209.9965
Epoch 54/100
120/120 [==============================] - 2s 15ms/step - loss: 1210.5814
Epoch 55/100
120/120 [==============================] - 1s 11ms/step - loss: 1208.8220
Epoch 56/100
120/120 [==============================] - 1s 11ms/step - loss: 1207.5226
Epoch 57/100
120/120 [==============================] - 2s 14ms/step - loss: 1205.6234
Epoch 58/100
120/120 [==============================] - 1s 12ms/step - loss: 1205.9741 
Epoch 59/100
120/120 [==============================] - 1s 12ms/step - loss: 1205.1418
Epoch 60/100
120/120 [==============================] - 1s 11ms/step - loss: 1204.0105
Epoch 61/100
120/120 [==============================] - 1s 11ms/step - loss: 1202.7397
Epoch 62/100
120/120 [==============================] - 2s 13ms/step - loss: 1204.3744
Epoch 63/100
120/120 [==============================] - 2s 14ms/step - loss: 1204.0844
Epoch 64/100
120/120 [==============================] - 1s 11ms/step - loss: 1203.3358
Epoch 65/100
120/120 [==============================] - 1s 11ms/step - loss: 1201.4772
Epoch 66/100
120/120 [==============================] - 1s 11ms/step - loss: 1202.6218
Epoch 67/100
120/120 [==============================] - 1s 11ms/step - loss: 1202.1835
Epoch 68/100
120/120 [==============================] - 1s 11ms/step - loss: 1202.0806
Epoch 69/100
120/120 [==============================] - 1s 10ms/step - loss: 1198.7412
Epoch 70/100
120/120 [==============================] - 1s 10ms/step - loss: 1199.5820
Epoch 71/100
120/120 [==============================] - 1s 11ms/step - loss: 1200.1707
Epoch 72/100
120/120 [==============================] - 1s 11ms/step - loss: 1197.8625
Epoch 73/100
120/120 [==============================] - 1s 10ms/step - loss: 1197.8823
Epoch 74/100
120/120 [==============================] - 1s 11ms/step - loss: 1200.5217
Epoch 75/100
120/120 [==============================] - 1s 11ms/step - loss: 1197.0145
Epoch 76/100
120/120 [==============================] - 1s 11ms/step - loss: 1196.4495
Epoch 77/100
120/120 [==============================] - 1s 11ms/step - loss: 1195.2098
Epoch 78/100
120/120 [==============================] - 1s 10ms/step - loss: 1196.5583
Epoch 79/100
120/120 [==============================] - 1s 11ms/step - loss: 1194.6842
Epoch 80/100
120/120 [==============================] - 1s 11ms/step - loss: 1195.2152
Epoch 81/100
120/120 [==============================] - 1s 10ms/step - loss: 1196.3453
Epoch 82/100
120/120 [==============================] - 1s 11ms/step - loss: 1193.9816
Epoch 83/100
120/120 [==============================] - 1s 11ms/step - loss: 1193.9949
Epoch 84/100
120/120 [==============================] - 1s 10ms/step - loss: 1192.7084
Epoch 85/100
120/120 [==============================] - 1s 11ms/step - loss: 1193.6130
Epoch 86/100
120/120 [==============================] - 1s 10ms/step - loss: 1194.2993
Epoch 87/100
120/120 [==============================] - 1s 11ms/step - loss: 1194.0914
Epoch 88/100
120/120 [==============================] - 1s 12ms/step - loss: 1191.5453
Epoch 89/100
120/120 [==============================] - 1s 10ms/step - loss: 1191.8341
Epoch 90/100
120/120 [==============================] - 1s 10ms/step - loss: 1190.6289
Epoch 91/100
120/120 [==============================] - 1s 11ms/step - loss: 1189.4753
Epoch 92/100
120/120 [==============================] - 1s 10ms/step - loss: 1191.8456
Epoch 93/100
120/120 [==============================] - 1s 10ms/step - loss: 1191.5165
Epoch 94/100
120/120 [==============================] - 1s 11ms/step - loss: 1190.7710
Epoch 95/100
120/120 [==============================] - 1s 11ms/step - loss: 1188.9766
Epoch 96/100
120/120 [==============================] - 1s 10ms/step - loss: 1187.8401
Epoch 97/100
120/120 [==============================] - 1s 11ms/step - loss: 1188.5903
Epoch 98/100
120/120 [==============================] - 1s 10ms/step - loss: 1190.6775
Epoch 99/100
120/120 [==============================] - 1s 10ms/step - loss: 1188.4381
Epoch 100/100
120/120 [==============================] - 1s 11ms/step - loss: 1187.9968
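
Since hist stores the per-epoch loss returned by model.fit, the learning curve can be plotted directly from it. A minimal sketch (assuming the training cell above has been run so that hist exists):

# Plot the training loss recorded by model.fit
plt.figure(figsize=(8, 4))
plt.plot(hist.history['loss'], label='train loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()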

Comparing original and reconstructed images

In [21]:
recon_x_test = model.predict(x_test)

n = 10  # how many images we will display
plt.figure(figsize=(15, 4))
for i in range(10):
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(56,56), vmin=0, vmax=1, cmap="gray")
    plt.title("Input"+str(i))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    #plt.colorbar()

    # display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(recon_x_test[i].reshape(56, 56),vmin=0, vmax=1, cmap="gray")
    plt.title("Recon"+str(i))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    #plt.colorbar()
plt.show()
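
The visual comparison above can also be summarized with a single number per image. A minimal sketch (assuming x_test and recon_x_test from the cell above) that computes the mean squared reconstruction error on the test set:

# Per-image mean squared reconstruction error on the flattened test images
recon_error = np.mean(np.square(x_test - recon_x_test), axis=1)
print('mean recon error:', recon_error.mean(), ' max recon error:', recon_error.max())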

Visualizing the data in the 2D latent space

In [24]:
# Define a function that plots the manifold learned by the trained model
def plot_results(models,
                 data,
                 batch_size=batch_size,
                 model_name="vae_mnist"):
    encoder, decoder = models
    x_test, y_test = data
    filename = "digits_over_latent.png"
    # display an n x n (10x10) 2D manifold of decoded images
    n = 10
    digit_size = 56
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates covering the region of the
    # 2D latent space that will be decoded into images
    grid_x = np.linspace(-2, 2, n)
    grid_y = np.linspace(-2, 2, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size,
                   j * digit_size: (j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    start_range = digit_size // 2
    end_range = (n - 1) * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.savefig(filename)
    plt.show()
    
plot_results(models = (encoder, decoder),
                 data = (x_test, y_test),
                 batch_size=batch_size,
                 model_name="vae_mlp")
In [25]:
# Despite its name, this plots the raw 2D latent means from the encoder (no t-SNE is applied)
def tsne_plot_results(models,
                 data,
                 batch_size=128,
                 model_name="vae_mnist"):

    encoder, decoder = models
    x_test, y_test = data
    os.makedirs(model_name, exist_ok=True)

    #filename = os.path.join(model_name, "vae_mean.png")
    filename = "vae_mean.png"
    # display a 2D plot of the test samples in the latent space, colored by their label
    z_mean, _, _ = encoder.predict(x_test,
                                   batch_size=batch_size)
    plt.figure(figsize=(10, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.savefig(filename)
    plt.show()

tsne_plot_results(models = (encoder, decoder),
                 data = (x_test, y_test),
                 batch_size=batch_size,
                 model_name="vae_mlp")
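
Since the weights were saved earlier with model.save_weights('vae_bracket.h5'), they can be reloaded into a model with the identical architecture for later inference. A minimal sketch (assuming the model-building cells above have been re-run so that model has the same structure):

# Reload the trained weights into a model with the identical architecture
model.load_weights('vae_bracket.h5')
recon_check = model.predict(x_test[:1])  # quick check that inference works
print(recon_check.shape)  # expected: (1, 3136)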