-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_with_rnn.py
349 lines (296 loc) · 16.5 KB
/
train_with_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Train/validation script for a Wasserstein autoencoder with a Gaussian-mixture
# prior and RNN encoder/decoder. This header wires up stdlib, third-party, and
# project-local dependencies, and pins the working directory to the script dir.
import os, sys, io
from copy import deepcopy
import argparse
from contextlib import ExitStack
import importlib
from typing import List, Dict, Union, Any, Optional
import numpy as np
import progressbar
import pprint
import torch
from torch import nn
# chdir to the script's own directory so that relative paths in the config
# (corpora, dictionaries, log files) resolve regardless of the invocation cwd
wd = os.path.dirname(__file__)
wd = "." if wd == "" else wd
os.chdir(wd)
from common.loader.text import TextLoader
from preprocess.tokenizer import CharacterTokenizer
from preprocess.corpora import Dictionary
from preprocess.dataset_feeder import GeneralSentenceFeeder
from preprocess import utils
# encoders
from model.multi_layer import MultiDenseLayer, IdentityLayer
from model.encoder import GMMLSTMEncoder
# decoder
from model.attention import SimpleGlobalAttention, MultiHeadedAttention, PassTuru
from model.decoder import SelfAttentiveLSTMDecoder
from model.decoder import SimplePredictor
# regularizers
## sampler
from model.noise_layer import GMMSampler
## prior distribution
from distribution.mixture import MultiVariateGaussianMixture
from utility import generate_random_orthogonal_vectors, calculate_prior_dist_params, calculate_mean_l2_between_sample
## loss functions
from model.loss import EmpiricalSlicedWassersteinDistance, GMMSinkhornWassersteinDistance, GMMApproxKLDivergence
from model.loss import MaskedKLDivLoss
from model.loss import PaddedNLLLoss
# variational autoencoder
from model.vae import VariationalAutoEncoder
## used for evaluation
from utility import enumerate_optional_metrics, write_log_and_progress
# estimator
from train import Estimator
def _parse_args():
parser = argparse.ArgumentParser(description="Wasserstein AutoEncoder using Gaussian Mixture and RNN Encoder/Decoder: train/validation script")
parser.add_argument("--config_module", "-c", required=True, type=str, help="config module name. example: `config.default`")
parser.add_argument("--save_dir", "-s", required=True, type=str, help="directory for saving trained model")
parser.add_argument("--device", "-d", required=False, type=str, default="cpu", help="computing device. DEFAULT:cpu")
parser.add_argument("--gpus", required=False, type=str, default="0", help="GPU device ids to be used for dataparallel processing. DEFAULT:0")
parser.add_argument("--save_every_epoch", action="store_true", help="save trained model at every epoch. DEFAULT:False")
parser.add_argument("--log_validation_only", action="store_true", help="record validation metrics only. DEFAULT:False")
parser.add_argument("--checkpoint", required=False, type=str, default=None, help="restart checkpoint. specify saved model parameters. DEFAULT:None")
parser.add_argument("--checkpoint_prior", required=False, type=str, default=None, help="prior distribution that corresponds to specified checkpoint.")
parser.add_argument("--checkpoint_epoch", required=False, type=int, default=None, help="restart epoch number.")
parser.add_argument("--verbose", action="store_true", help="output verbosity")
args = parser.parse_args()
if args.device.find("cuda") != -1:
assert torch.cuda.is_available(), "GPU is unavailable but cuda device was specified."
args.gpus = [int(i) for i in args.gpus.split(",")]
else:
args.gpus = []
args.device = torch.device(args.device)
if args.checkpoint is not None:
assert args.checkpoint_epoch is not None, "you must specify `checkpoint_epoch` argument when you restart from the checkpoint."
assert args.checkpoint_prior is not None, "you must specify `checkpoint_prior` argument when you restart from the checkpoint."
assert os.path.exists(args.checkpoint_prior), f"specified file does not exist: {args.checkpoint_prior}"
return args
def main():
    """Train (and optionally evaluate) the GMM-prior RNN variational autoencoder.

    Workflow: parse CLI arguments, import the configuration module named by
    ``--config_module``, build per-split data feeders and per-phase loggers,
    restore or initialize the Gaussian-mixture prior, instantiate the VAE
    model plus loss layers, then run the epoch loop of training, optional
    prior updates, and test-set evaluation.
    """
    args = _parse_args()
    # import configuration: the module must expose cfg_auto_encoder,
    # cfg_corpus and cfg_optimizer dictionaries
    config = importlib.import_module(name=args.config_module)
    file_name_suffix = args.config_module  # tags saved model / prior file names
    cfg_auto_encoder = config.cfg_auto_encoder
    cfg_corpus = config.cfg_corpus
    cfg_optimizer = config.cfg_optimizer
    # show important configurations
    print("corpus:")
    pprint.pprint(cfg_corpus)
    print("optimization:")
    for k, v in cfg_optimizer.items():
        print(f"\t{k}:{v}")
    print("prior distribution:")
    pprint.pprint(cfg_auto_encoder["prior"])
    # instanciate corpora
    tokenizer = CharacterTokenizer()
    dictionary = Dictionary.load(file_path=cfg_corpus["dictionary"])
    # unpack special tokens; the dictionary carries either (bos, eos) or
    # (bos, eos, oov) — presumably in that order (TODO confirm against
    # Dictionary.special_tokens ordering)
    if len(dictionary.special_tokens) == 2:
        bos, eos = list(dictionary.special_tokens)
    elif len(dictionary.special_tokens) == 3:
        bos, eos, oov = list(dictionary.special_tokens)
    else:
        raise NotImplementedError("unexpected dictionary type.")
    # build one sentence feeder per corpus split declared in the config
    dict_data_feeder = {}
    for corpus_type in "train,dev,test".split(","):
        if not corpus_type in cfg_corpus:
            continue
        cfg_corpus_t = cfg_corpus[corpus_type]
        corpus_t = TextLoader(file_path=cfg_corpus_t["corpus"])
        data_feeder_t = GeneralSentenceFeeder(corpus = corpus_t,
                                              tokenizer = tokenizer, dictionary = dictionary,
                                              n_minibatch=cfg_optimizer["n_minibatch"], validation_split=0.,
                                              min_seq_len=cfg_corpus_t["min_seq_len"],
                                              max_seq_len=cfg_corpus_t["max_seq_len"],
                                              # append `<eos>` at the end of each sentence
                                              bos_symbol=None, eos_symbol=eos)
        dict_data_feeder[corpus_type] = data_feeder_t
    # setup logger: one log file per phase; pre-existing logs are removed
    dict_logger = {}
    for phase in "train,test".split(","):
        path_log_file = cfg_corpus["log_file_path"] + f".{phase}"
        if os.path.exists(path_log_file):
            os.remove(path_log_file)
        logger = io.open(path_log_file, mode="w")
        dict_logger[phase] = logger
    # instanciate variational autoencoder
    ## prior distribution
    n_dim_gmm = cfg_auto_encoder["prior"]["n_dim"]
    n_prior_gmm_component = cfg_auto_encoder["prior"]["n_gmm_component"]
    ## regularizer type: the first key of the `reg` config selects which
    ## regularizer loss is instantiated below
    regularizer_name = next(iter(cfg_auto_encoder["loss"]["reg"]))
    print(f"regularizer type: {regularizer_name}")
    is_sliced_wasserstein = regularizer_name.find("sliced_wasserstein") != -1
    # calculate l2 norm and stdev of the mean and stdev of prior distribution
    if args.checkpoint is not None:
        # restart: restore the prior that was saved alongside the checkpoint
        path_prior = args.checkpoint_prior
        print(f"prior distribution will be restored from file: {path_prior}")
        prior_distribution = MultiVariateGaussianMixture.load(file_path=path_prior)
        assert prior_distribution.n_dim == n_dim_gmm, f"configuration mismatch detected: {n_dim_gmm}"
        assert prior_distribution.n_component == n_prior_gmm_component, f"configuration mismatch detected: {n_prior_gmm_component}"
    else:
        # fresh start: derive prior parameters from the expected wasserstein
        # distance, then let explicit config values override the derived ones
        expected_wd = cfg_auto_encoder["prior"].get("expected_wd", 1.0)
        l2_norm, std = calculate_prior_dist_params(expected_wd=expected_wd, n_dim_latent=n_dim_gmm, sliced_wasserstein=is_sliced_wasserstein)
        ## overwrite auto values with user-defined values
        l2_norm = cfg_auto_encoder["prior"].get("l2_norm", l2_norm)
        std = cfg_auto_encoder["prior"].get("std", std)
        print("prior distribution parameters.")
        print(f"\tl2_norm:{l2_norm:2.3f}, stdev:{std:2.3f}")
        # uniform mixture weights, mutually orthogonal component means of the
        # given l2 norm, and a shared isotropic stdev per component
        vec_alpha = np.full(shape=n_prior_gmm_component, fill_value=1./n_prior_gmm_component)
        mat_mu = generate_random_orthogonal_vectors(n_dim=n_dim_gmm, n_vector=n_prior_gmm_component, l2_norm=l2_norm)
        vec_std = np.ones(shape=n_prior_gmm_component) * std
        prior_distribution = MultiVariateGaussianMixture(vec_alpha=vec_alpha, mat_mu=mat_mu, vec_std=vec_std)
        path_prior = os.path.join(args.save_dir, f"prior_distribution.gmm.{file_name_suffix}.pickle")
        prior_distribution.save(file_path=path_prior)
    # instanciate variational autoencoder model
    model = Estimator.instanciate_variational_autoencoder(cfg_auto_encoder=cfg_auto_encoder, n_vocab=dictionary.max_id+1,
                                                          device=args.device, path_state_dict=args.checkpoint, encoder_only=False)
    ## wrap with DataParallel class for parallel processing
    if len(args.gpus) > 1:
        model = nn.DataParallel(model, device_ids = args.gpus)
    model.to(device=args.device)
    ## loss layers
    ### regularizer between posteriors and prior; wasserstein distance or kullback-leibler divergence: d(p(z|x), q(z))
    cfg_regularizer = cfg_auto_encoder["loss"]["reg"]
    if regularizer_name == "empirical_sliced_wasserstein":
        loss_regularizer = EmpiricalSlicedWassersteinDistance(device=args.device, **cfg_regularizer["empirical_sliced_wasserstein"])
    elif regularizer_name == "sinkhorn_wasserstein":
        loss_regularizer = GMMSinkhornWassersteinDistance(device=args.device, **cfg_regularizer["sinkhorn_wasserstein"])
    elif regularizer_name == "kullback_leibler":
        loss_regularizer = GMMApproxKLDivergence(device=args.device, **cfg_regularizer["kullback_leibler"])
    else:
        raise NotImplementedError("unsupported regularizer type:", regularizer_name)
    ### kullback-leibler divergence on \alpha: KL(p(\alpha|x), q(\alpha))
    if cfg_auto_encoder["loss"]["kldiv"]["enabled"]:
        loss_kldiv = MaskedKLDivLoss(scale=cfg_auto_encoder["loss"]["kldiv"]["scale"], reduction="samplewise_mean")
    else:
        loss_kldiv = None
    ### negtive log likelihood: -lnp(x|z); z~p(z|x)
    loss_reconst = PaddedNLLLoss(reduction="samplewise_mean")
    ### instanciate estimator ###
    estimator = Estimator(model=model, loss_reconst=loss_reconst, loss_layer_reg=loss_regularizer, loss_layer_kldiv=loss_kldiv,
                          device=args.device, verbose=args.verbose)
    # optimizer for variational autoencoder
    optimizer = cfg_optimizer["optimizer"](model.parameters(), lr=cfg_optimizer["lr"])
    # start training
    n_epoch_total = cfg_optimizer["n_epoch"]
    # iterate over epoch
    # when restarting from a checkpoint, fast-forward the epoch / iteration /
    # processed-sample counters so annealing schedules resume consistently
    n_epoch_init = 0 if args.checkpoint is None else args.checkpoint_epoch
    n_iteration = 0 if args.checkpoint is None else int(np.ceil(cfg_corpus["train"]["size"] / cfg_optimizer["n_minibatch"]))
    n_processed = 0 if args.checkpoint is None else cfg_corpus["train"]["size"] * n_epoch_init
    for n_epoch in range(n_epoch_init, n_epoch_total):
        print(f"epoch:{n_epoch}")
        #### train phase ####
        phase = "train"
        print(f"phase:{phase}")
        model.train()
        logger = dict_logger[phase]
        cfg_corpus_t = cfg_corpus[phase]
        lst_eval_metrics = enumerate_optional_metrics(cfg_metrics=cfg_corpus[phase].get("evaluation_metrics",[]), n_epoch=n_epoch+1)
        q = progressbar.ProgressBar(max_value=cfg_corpus_t["size"])
        n_progress = 0
        q.update(n_progress)
        ## iterate over mini-batch
        for train, _ in dict_data_feeder[phase]:
            n_iteration += 1
            # update scale parameter of wasserstein distance layer
            estimator.loss_reg.update_scale_parameter(n_processed=n_processed)
            # update annealing parameter of sampler layer
            estimator.model._sampler.update_anneal_parameter(n_processed=n_processed)
            # training
            # every `validation_interval`-th iteration is run as validation
            if cfg_optimizer["validation_interval"] is not None:
                train_mode = not(n_iteration % cfg_optimizer["validation_interval"] == 0)
            else:
                train_mode = True
            lst_seq_len, lst_seq = utils.len_pad_sort(lst_seq=train)
            metrics_batch = estimator.train_single_step(lst_seq=lst_seq, lst_seq_len=lst_seq_len, optimizer=optimizer,
                                                        prior_distribution=prior_distribution,
                                                        clip_gradient_value=cfg_optimizer["gradient_clip"],
                                                        evaluation_metrics=lst_eval_metrics)
            n_processed += len(lst_seq_len)
            n_progress += len(lst_seq_len)
            # logging and reporting
            write_log_and_progress(n_epoch=n_epoch,n_processed=n_processed,
                                   mode="train" if train_mode else "val",
                                   dict_metrics=metrics_batch,
                                   logger = logger,
                                   output_log= not(args.log_validation_only) or not(train_mode),
                                   output_std=args.verbose
                                   )
            # next iteration
            q.update(n_progress)
        # save progress
        if args.save_every_epoch:
            path_trained_model_e = os.path.join(args.save_dir, f"lstm_vae.{file_name_suffix}.model." + str(n_epoch))
            print(f"saving...:{path_trained_model_e}")
            torch.save(model.state_dict(), path_trained_model_e)
        # clean up GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        #### (optional) update prior distribution ###
        # the prior itself can be re-fit on selected epochs listed in the config
        if "update" in cfg_auto_encoder["prior"]:
            cfg_update_prior = cfg_auto_encoder["prior"]["update"]
            if n_epoch in cfg_update_prior["target_epoch"]:
                print("update prior distribution. wait for a while...")
                prior_distribution_new = estimator.train_prior_distribution(
                    cfg_optimizer=cfg_update_prior["optimizer"],
                    cfg_regularizer=cfg_update_prior["regularizer"],
                    prior_distribution=prior_distribution,
                    data_feeder=dict_data_feeder["train"]
                )
                # renew prior distribution
                prior_distribution = prior_distribution_new
                prior_distribution.save(file_path=path_prior)
            else:
                print("we do not update prior distribution. skip training.")
        #### test phase ####
        if not "test" in dict_data_feeder:
            print("we do not have testset. skip evaluation.")
            continue
        phase = "test"
        print(f"phase:{phase}")
        model.eval()
        logger = dict_logger[phase]
        lst_eval_metrics = enumerate_optional_metrics(cfg_metrics=cfg_corpus[phase].get("evaluation_metrics",[]), n_epoch=n_epoch+1)
        lst_metrics_batch = []
        ## iterate over mini-batch
        for batch, _ in dict_data_feeder[phase]:
            lst_seq_len, lst_seq = utils.len_pad_sort(lst_seq=batch)
            metrics_batch = estimator.test_single_step(lst_seq=lst_seq, lst_seq_len=lst_seq_len,
                                                       prior_distribution=prior_distribution,
                                                       evaluation_metrics=lst_eval_metrics)
            lst_metrics_batch.append(metrics_batch)
            # clean up GPU memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        # calculate whole metrics
        # aggregate per-batch metrics; None values become NaN so they
        # propagate through the weighted means below
        metrics = {}
        vec_n_sentence = np.array([m["n_sentence"] for m in lst_metrics_batch])
        vec_n_token = np.array([m["n_token"] for m in lst_metrics_batch])
        metrics["n_sentence"] = np.sum(vec_n_sentence)
        metrics["n_token"] = np.sum(vec_n_token)
        metrics["n_token_per_sentence"] = metrics["n_token"] / metrics["n_sentence"]
        for metric_name in lst_metrics_batch[0].keys():
            vec_values = np.array([np.nan if m[metric_name] is None else m[metric_name] for m in lst_metrics_batch])
            if metric_name in ["n_sentence","n_token"]:
                continue
            elif metric_name.startswith("mean_"): # sentence-wise mean
                metrics[metric_name] = np.sum(vec_n_sentence * vec_values) / np.sum(vec_n_sentence)
            elif metric_name == "nll_token": # token-wise mean
                metrics[metric_name] = np.sum(vec_n_token * vec_values) / np.sum(vec_n_token)
            else: # token-wise mean * average sentence length
                metrics[metric_name] = np.sum(vec_n_sentence * vec_values) * metrics["n_token_per_sentence"] / metrics["n_token"]
        # logging and reporting
        write_log_and_progress(n_epoch=n_epoch, n_processed=n_processed, mode="test", dict_metrics=metrics,
                               logger=logger, output_log=True, output_std=True)
        ### proceed to next epoch ###
    # end of epoch
    for logger in dict_logger.values():
        logger.close()
    # save trained model
    path_trained_model_e = os.path.join(args.save_dir, f"lstm_vae.{file_name_suffix}.model." + str(n_epoch_total))
    print(f"saving...:{path_trained_model_e}")
    torch.save(model.state_dict(), path_trained_model_e)
# script entry point: run the full train/validation pipeline
if __name__ == "__main__":
    main()
    print("finished. good-bye.")