import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import normalizing_flows.utils as utils
from normalizing_flows.flows import Transform
from tqdm import tqdm
from .trackable_module import TrackableModule
from .losses import wasserstein_loss
class JointFlowLVM(TrackableModule):
"""
Flow-based latent variable model for joint distribution inference.
Given random variables X, Y; JointFlowLVM attempts to learn a latent variable
model P(X,Y,Z) enabling conditional inference P(X|Y) and P(Y|X) via implicit
integration over the shared latent variables Z. This is done by learning two
bijective mappings X<->Z and Y<->Z via maximum likelihood in conjunction with
adversarial losses on P(X) and P(Y). See AlignFlow (Grover et al. 2019) for details.
"""
def __init__(self,
G_zx: Transform,
G_zy: Transform,
D_x: tf.keras.Model,
D_y: tf.keras.Model,
prior: tfp.distributions.Distribution=tfp.distributions.Normal(loc=0.0, scale=1.0),
input_shape=None,
num_bins=None,
Gx_aux_loss=lambda x,y: tf.constant(0.),
Gy_aux_loss=lambda x,y: tf.constant(0.),
adversarial_loss_ctor=wasserstein_loss,
optimizer_g=tf.keras.optimizers.Adam(learning_rate=1.0E-4, beta_1=0.5, beta_2=0.9),
optimizer_dx=tf.keras.optimizers.Adam(learning_rate=1.0E-4, beta_1=0.5, beta_2=0.9),
optimizer_dy=tf.keras.optimizers.Adam(learning_rate=1.0E-4, beta_1=0.5, beta_2=0.9),
clip_grads=10.0,
name='joint_flvm'):
assert G_zx.name != G_zy.name, 'generators must have unique names'
super().__init__({'optimizer_g': optimizer_g, 'optimizer_dx': optimizer_dx, 'optimizer_dy': optimizer_dy}, name=name)
self.G_zx = G_zx
self.G_zy = G_zy
self.D_x = D_x
self.D_y = D_y
self.prior = prior
self.input_shape = input_shape
self.num_bins = num_bins
self.Gx_aux_loss = Gx_aux_loss
self.Gy_aux_loss = Gy_aux_loss
self.adv_loss_ctor = adversarial_loss_ctor
self.optimizer_g = optimizer_g
self.optimizer_dx = optimizer_dx
self.optimizer_dy = optimizer_dy
self.clip_grads = clip_grads
self.scale_factor = np.log2(num_bins) if num_bins is not None else 0.0
if self._is_initialized():
self.initialize(self.input_shape)
def _is_initialized(self):
return self.input_shape is not None
def initialize(self, input_shape):
self.input_shape = input_shape
self.Dx_loss, self.Gx_loss = self.adv_loss_ctor(self.D_x)
self.Dy_loss, self.Gy_loss = self.adv_loss_ctor(self.D_y)
with tf.init_scope():
self.G_zx.initialize(input_shape)
self.G_zy.initialize(input_shape)
self._init_checkpoint()
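# adds uniform dequantization noise in [0, 1/num_bins) when the inputs are discretized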
def _preprocess(self, x):
if self.num_bins is not None:
x += tf.random.uniform(x.shape, 0, 1./self.num_bins)
return x
def predict_y(self, x, return_log_prob=False):
if return_log_prob:
z, p_x = self.encode_x(x, return_log_prob=return_log_prob)
y, p_y = self.decode_y(z, return_log_prob=return_log_prob)
return y, p_y, p_x
else:
return self.decode_y(self.encode_x(x))
def predict_x(self, y, return_log_prob=False):
if return_log_prob:
z, p_y = self.encode_y(y, return_log_prob=return_log_prob)
x, p_x = self.decode_x(z, return_log_prob=return_log_prob)
return x, p_x, p_y
else:
return self.decode_x(self.encode_y(y))
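# evaluates per-dimension NLLs, adversarial losses, and auxiliary losses for both generators on a batch of (x, y) pairs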
@tf.function
def eval_generators_on_batch(self, x, y):
assert self.input_shape is not None, 'model not initialized'
num_elements = tf.cast(x.shape[1]*x.shape[2]*x.shape[3], tf.float32)
x = self._preprocess(x)
y = self._preprocess(y)
# compute generator outputs
z_x, ildj_x = self.G_zx.inverse(x)
y_x, _ = self.G_zy.forward(z_x)
z_y, ildj_y = self.G_zy.inverse(y)
x_y, _ = self.G_zx.forward(z_y)
# compute adversarial losses
gx_loss = self.Gx_loss(x, x_y)
gy_loss = self.Gy_loss(y, y_x)
# compute auxiliary loss
gx_aux = self.Gx_aux_loss(y, x_y)
gy_aux = self.Gy_aux_loss(x, y_x)
# compute likelihood losses
prior_logp_x = self.prior.log_prob(z_x)
prior_logp_y = self.prior.log_prob(z_y)
if prior_logp_x.shape.rank > 1:
# reduce log probs along non-batch dimensions
prior_logp_x = tf.math.reduce_sum(prior_logp_x, axis=[i for i in range(1,prior_logp_x.shape.rank)])
prior_logp_y = tf.math.reduce_sum(prior_logp_y, axis=[i for i in range(1,prior_logp_y.shape.rank)])
nll_x = -tf.math.reduce_mean((prior_logp_x + ildj_x - self.scale_factor*num_elements) / num_elements)
nll_y = -tf.math.reduce_mean((prior_logp_y + ildj_y - self.scale_factor*num_elements) / num_elements)
return nll_x, nll_y, gx_loss, gy_loss, gx_aux, gy_aux
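# evaluates the discriminator losses on real samples and cross-domain translations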
@tf.function
def eval_discriminators_on_batch(self, x, y):
x_pred = self.predict_x(y)
y_pred = self.predict_y(x)
# evaluate discriminators
dx_loss = self.Dx_loss(x, x_pred)
dy_loss = self.Dy_loss(y, y_pred)
return dx_loss, dy_loss
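# performs a single gradient step on both generators; lam weights the likelihood terms and alpha the auxiliary losses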
@tf.function
def train_generators_on_batch(self, x, y, lam=1.0, alpha=0.0):
assert self.input_shape is not None, 'model not initialized'
nll_x, nll_y, gx_loss, gy_loss, gx_aux, gy_aux = self.eval_generators_on_batch(x, y)
# compute losses
reg_losses = [self.G_zx._regularization_loss()]
reg_losses += [self.G_zy._regularization_loss()]
g_obj = gx_loss + gy_loss + lam*(nll_x + nll_y) + alpha*(gx_aux + gy_aux) + tf.math.add_n(reg_losses)
# generator gradient update
generator_variables = list(self.G_zx.trainable_variables) + list(self.G_zy.trainable_variables)
g_grads = tf.gradients(g_obj, generator_variables)
if self.clip_grads:
g_grads, grad_norm = tf.clip_by_global_norm(g_grads, self.clip_grads)
self.optimizer_g.apply_gradients(zip(g_grads, generator_variables))
return g_obj, nll_x, nll_y, gx_loss, gy_loss, gx_aux, gy_aux
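# performs a single gradient step on both discriminators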
@tf.function
def train_discriminators_on_batch(self, x, y):
dx_loss, dy_loss = self.eval_discriminators_on_batch(x, y)
dx_grads = tf.gradients(dx_loss, self.D_x.trainable_variables)
dy_grads = tf.gradients(dy_loss, self.D_y.trainable_variables)
self.optimizer_dx.apply_gradients(zip(dx_grads, self.D_x.trainable_variables))
self.optimizer_dy.apply_gradients(zip(dy_grads, self.D_y.trainable_variables))
return dx_loss, dy_loss
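# alternates discriminator and generator updates over num_epochs epochs of steps_per_epoch batches; lam is decayed by lam_decay after each epoch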
def train(self, train_data: tf.data.Dataset, steps_per_epoch, num_epochs=1,
lam=1.0, lam_decay=0.0, alpha=0.0, **flow_kwargs):
train_gen_data = train_data.take(steps_per_epoch).repeat(num_epochs)
with tqdm(total=steps_per_epoch*num_epochs, desc='train') as prog:
hist = dict()
lam = tf.Variable(lam, dtype=tf.float32)
for epoch in range(num_epochs):
for x,y in train_gen_data.take(steps_per_epoch):
# train discriminators
dx_loss, dy_loss = self.train_discriminators_on_batch(x, y)
# train generators
g_obj, nll_x, nll_y, gx_loss, gy_loss, gx_aux, gy_aux = self.train_generators_on_batch(x, y, alpha=alpha, lam=utils.var(lam))
utils.update_metrics(hist, g_obj=g_obj.numpy(), gx_loss=gx_loss.numpy(), gy_loss=gy_loss.numpy(),
nll_x=nll_x.numpy(), nll_y=nll_y.numpy())
prog.update(1)
prog.set_postfix(utils.get_metrics(hist))
lam.assign_sub(lam_decay)
return hist
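# evaluates generator and discriminator losses over validation_steps batches without updating any weights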
def evaluate(self, validation_data: tf.data.Dataset, validation_steps, **flow_kwargs):
validation_data = validation_data.take(validation_steps)
with tqdm(total=validation_steps, desc='eval') as prog:
hist = dict()
for x,y in validation_data:
# evaluate discriminators
dx_loss, dy_loss = self.eval_discriminators_on_batch(x, y)
# evaluate generators
nll_x, nll_y, gx_loss, gy_loss, gx_aux, gy_aux = self.eval_generators_on_batch(x, y)
utils.update_metrics(hist,
nll_x=nll_x.numpy(),
nll_y=nll_y.numpy(),
gx_loss=gx_loss.numpy(),
gy_loss=gy_loss.numpy(),
dx_loss=dx_loss.numpy(),
dy_loss=dy_loss.numpy(),
gx_aux=gx_aux.numpy(),
gy_aux=gy_aux.numpy())
prog.update(1)
prog.set_postfix(utils.get_metrics(hist))
return hist
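# encode_*/decode_* map between data space and the shared latent space; log-probabilities, when requested, are normalized by the number of data dimensions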
def encode_x(self, x, return_log_prob=False):
z, ildj = self.G_zx.inverse(x)
if return_log_prob:
num_elements = tf.cast(x.shape[1]*x.shape[2]*x.shape[3], tf.float32)
log_prob = tf.math.reduce_sum(self.prior.log_prob(z), axis=[i for i in range(1,z.shape.rank)])
log_prob += ildj
log_prob /= num_elements
return z, log_prob
else:
return z
def decode_x(self, z, return_log_prob=False):
x, fldj = self.G_zx.forward(z)
if return_log_prob:
num_elements = tf.cast(x.shape[1]*x.shape[2]*x.shape[3], tf.float32)
log_prob = tf.math.reduce_sum(self.prior.log_prob(z), axis=[i for i in range(1,z.shape.rank)])
log_prob -= fldj
log_prob /= num_elements
return x, log_prob
else:
return x
def encode_y(self, y, return_log_prob=False):
z, ildj = self.G_zy.inverse(y)
if return_log_prob:
num_elements = tf.cast(y.shape[1]*y.shape[2]*y.shape[3], tf.float32)
log_prob = tf.math.reduce_sum(self.prior.log_prob(z), axis=[i for i in range(1,z.shape.rank)])
log_prob += ildj
log_prob /= num_elements
return z, log_prob
else:
return z
def decode_y(self, z, return_log_prob=False):
y, fldj = self.G_zy.forward(z)
if return_log_prob:
num_elements = tf.cast(y.shape[1]*y.shape[2]*y.shape[3], tf.float32)
log_prob = tf.math.reduce_sum(self.prior.log_prob(z), axis=[i for i in range(1,z.shape.rank)])
log_prob -= fldj
log_prob /= num_elements
return y, log_prob
else:
return y
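# draws n latent samples from the prior and decodes them through both generators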
def sample(self, n=1, return_log_prob=False):
assert self.input_shape is not None, 'model not initialized'
event_ndims = self.prior.event_shape.rank
z_shape = self.input_shape[1:]
if self.prior.is_scalar_batch():
z = self.prior.sample((n,*z_shape[:len(z_shape)-event_ndims]))
else:
z = self.prior.sample((n,))
return self.decode_x(z, return_log_prob=return_log_prob), self.decode_y(z, return_log_prob=return_log_prob)
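# draws n conditional samples of y given x by perturbing the encoded latent with Gaussian noise scaled by temperature; sample_predict_x below is the symmetric counterpart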
def sample_predict_y(self, x, n=1, temperature=0.5, return_log_prob=False):
assert self.input_shape is not None, 'model not initialized'
z = self.encode_x(x)
eps = tf.random.normal((z.shape[0], n,*z.shape[1:]), stddev=temperature)
zs = eps + tf.expand_dims(z, axis=1)
zs = tf.reshape(zs, (n*z.shape[0], *z.shape[1:]))
x_, p_zx = self.decode_x(zs, return_log_prob=True)
y_, p_zy = self.decode_y(zs, return_log_prob=True)
x_ = tf.reshape(x_, (x.shape[0], n, *x.shape[1:]))
y_ = tf.reshape(y_, (x.shape[0], n, *x.shape[1:]))
p_zx = tf.reshape(p_zx, (z.shape[0], n))
p_zy = tf.reshape(p_zy, (z.shape[0], n))
if return_log_prob:
return y_, x_, p_zx, p_zy
else:
return y_, x_
def sample_predict_x(self, y, n=1, temperature=0.5, return_log_prob=False):
assert self.input_shape is not None, 'model not initialized'
z = self.encode_y(y)
eps = tf.random.normal((z.shape[0], n,*z.shape[1:]), stddev=temperature)
zs = eps + tf.expand_dims(z, axis=1)
zs = tf.reshape(zs, (n*z.shape[0], *z.shape[1:]))
x_, p_zx = self.decode_x(zs, return_log_prob=True)
y_, p_zy = self.decode_y(zs, return_log_prob=True)
x_ = tf.reshape(x_, (y.shape[0], n, *y.shape[1:]))
y_ = tf.reshape(y_, (y.shape[0], n, *y.shape[1:]))
p_zx = tf.reshape(p_zx, (z.shape[0], n))
p_zy = tf.reshape(p_zy, (z.shape[0], n))
if return_log_prob:
return x_, y_, p_zx, p_zy
else:
return x_, y_
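# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# A minimal example of how a JointFlowLVM might be assembled and trained for a
# pair of 32x32 single-channel image domains. The `Glow` import path, its
# constructor arguments, and the `make_critic` helper are assumptions standing
# in for whichever `Transform` bijectors and discriminators your project uses;
# only the JointFlowLVM calls themselves (constructor, train, predict_y) come
# from this file.
#
# from normalizing_flows.flows.glow import Glow  # hypothetical import path
#
# def make_critic(shape):
#     # simple convolutional critic producing a scalar score per sample,
#     # compatible with the default wasserstein_loss
#     return tf.keras.Sequential([
#         tf.keras.layers.Input(shape=shape),
#         tf.keras.layers.Conv2D(32, 3, strides=2, activation='relu'),
#         tf.keras.layers.Conv2D(64, 3, strides=2, activation='relu'),
#         tf.keras.layers.Flatten(),
#         tf.keras.layers.Dense(1)
#     ])
#
# input_shape = (None, 32, 32, 1)
# model = JointFlowLVM(G_zx=Glow(name='G_zx'), G_zy=Glow(name='G_zy'),
#                      D_x=make_critic(input_shape[1:]), D_y=make_critic(input_shape[1:]),
#                      input_shape=input_shape)
# # paired_dataset is a tf.data.Dataset yielding (x, y) batches of matching shape
# hist = model.train(paired_dataset, steps_per_epoch=100, num_epochs=10, lam=1.0)
# y_from_x = model.predict_y(x_batch)
# ---------------------------------------------------------------------------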