-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_toy2D.py
106 lines (95 loc) · 3.53 KB
/
main_toy2D.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 10 11:24:49 2017
@author: zhaoxm
"""
import tensorflow as tf
from skimage import io
from gan_model import *
from gan_solver import *
from utils import *
import seaborn as sb
from sampler import generate_lut, sample_2d
from visualizer import GANDemoVisualizer
import os
# Pin TensorFlow to the second GPU; must be set before the TF session starts.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# ---- Hyper-parameters --------------------------------------------------
model_type = 'Vanilla'       # one of: Vanilla, Wasserstein, EMGAN, EBGAN, BEGAN
N = 2000                     # minibatch size (was "20 00" -- a SyntaxError)
D = 2                        # dimensionality of each toy data point
hidden_num = 128             # hidden-layer width passed to the GAN models
z_dim = 2                    # latent-noise dimensionality
init_learning_rate = 1e-3
max_iter = 100000
verbose_interval = 100       # iterations between loss printouts
show_interval = 1000         # iterations between scatter-plot refreshes
snapshot = 1000              # iterations between checkpoint saves

# ---- Target density ----------------------------------------------------
# The "real" 2-D distribution is defined by an image: sample_2d() draws
# points via a lookup table built from the pixel intensities.
train_dir = './img/triangle.jpg'
density_img = io.imread(train_dir, True)   # second arg presumably as_gray -- TODO confirm
lut_2d = generate_lut(density_img)
visualizer = GANDemoVisualizer('GAN 2D Example Visualization')
tf.reset_default_graph()

# Placeholder for a minibatch of real D-dimensional points.
data = tf.placeholder(tf.float32, [None, D])

# Discriminator updates per generator update (the same for every variant).
K = 1

if model_type == 'Vanilla':
    model = VanillaGAN(data, hidden_num=hidden_num, z_dim=z_dim)
    train_op = BaseSolver(model, init_learning_rate=init_learning_rate)
    d_fetches = [train_op.d_solver]
elif model_type in ('Wasserstein', 'EMGAN'):
    use_gp = True  # gradient penalty instead of weight clipping
    if model_type == 'Wasserstein':
        model = WassersteinGAN(data, hidden_num=hidden_num, z_dim=z_dim,
                               use_gp=use_gp)
    else:
        model = EMGAN(data, hidden_num=hidden_num, z_dim=z_dim,
                      use_gp=use_gp, embedding_dim=10)
    if use_gp:
        train_op = BaseSolver(model, init_learning_rate=init_learning_rate)
        d_fetches = [train_op.d_solver]
    else:
        # Weight-clipped critic: the solver exposes an extra clip op to run
        # alongside each discriminator step.
        train_op = WassersteinSolver(model, init_learning_rate=init_learning_rate)
        d_fetches = [train_op.d_solver, train_op.clip_grad]
elif model_type == 'EBGAN':
    model = EBGAN(data, hidden_num=hidden_num, z_dim=z_dim)
    train_op = BaseSolver(model, init_learning_rate=init_learning_rate)
    d_fetches = [train_op.d_solver]
elif model_type == 'BEGAN':
    model = BEGAN(data, hidden_num=hidden_num, z_dim=z_dim)
    train_op = BaseSolver(model, init_learning_rate=init_learning_rate)
    # BEGAN additionally evaluates the equilibrium balance and the k_t
    # update on every discriminator step.
    d_fetches = [train_op.d_solver, model.balance, model.k_update]
else:
    raise NotImplementedError('model_type is wrong.')
# Session that grows GPU memory on demand instead of grabbing it all upfront.
sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True
sess = tf.Session(config=sess_config)

# Checkpointing: keep the 20 most recent snapshots on disk.
saver = tf.train.Saver(max_to_keep=20)
sess.run(tf.global_variables_initializer())
# To resume training from a saved checkpoint instead of a fresh init:
# saver.restore(sess, "./trial/trial-100000")
# Main training loop: alternate K discriminator steps with one generator
# step, logging losses, refreshing the visualization, and checkpointing on
# their respective intervals.  (`step` replaces the original loop variable
# `iter`, which shadowed the builtin.)
for step in range(max_iter):
    for _ in range(K):
        # sample_2d draws from the image-derived LUT; shift to be centered
        # around the origin for training.
        x_batch = sample_2d(lut_2d, N) - 0.5
        sess.run(d_fetches, feed_dict={data: x_batch})
    # One generator update, reusing the last real batch for the feed.
    sess.run(train_op.g_solver, feed_dict={data: x_batch})

    if step % verbose_interval == 0:
        d_loss, g_loss, lr = sess.run(
            [model.d_loss, model.g_loss, train_op.learning_rate],
            feed_dict={data: x_batch})
        # print() as a function call -- the original Python-2-only
        # "print('fmt') % args" statement form is a TypeError on Python 3.
        print('iter=%d, lr=%f, d_loss=%f, g_loss=%f'
              % (step, lr, d_loss, g_loss))
        if model_type == 'BEGAN':
            messure, kt = sess.run([model.messure, model.kt],
                                   feed_dict={data: x_batch})
            print('messure=%f, k=%f' % (messure, kt))

    if step % show_interval == 0:
        real_samples = sample_2d(lut_2d, 2000)
        # NOTE(review): real_samples is fed here without the -0.5 shift used
        # during training; the generator fetch likely ignores the `data`
        # feed, but confirm the offset is intentional.
        gen_samples = sess.run(model.g_data,
                               feed_dict={data: real_samples}) + 0.5
        visualizer.draw(real_samples, gen_samples)
        plt.show()
        print('')

    if step % snapshot == 0:
        saver.save(sess, './trial/trial', global_step=step)