diff --git a/scripts/language_model/index.rst b/scripts/language_model/index.rst index 9a69f347e0..318050c8a3 100644 --- a/scripts/language_model/index.rst +++ b/scripts/language_model/index.rst @@ -18,65 +18,17 @@ The dataset used for training the models is wikitext-2. +---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ | Model | awd_lstm_lm_1150_wikitext-2 | awd_lstm_lm_600_wikitext-2 | standard_lstm_lm_1500_wikitext-2 | standard_lstm_lm_650_wikitext-2 | standard_lstm_lm_200_wikitext-2 | +===============+============================================================================================================================+===========================================================================================================================+=================================================================================================================================+================================================================================================================================+================================================================================================================================+ -| Mode | LSTM | LSTM | LSTM | LSTM | LSTM | -+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ -| Num_layers | 3 | 3 | 2 | 2 | 2 | -+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ -| Embed size | 400 | 200 | 1500 | 650 | 200 | 
-+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ -| Hidden size | 1150 | 600 | 1500 | 650 | 200 | -+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ -| Dropout | 0.4 | 0.2 | 0.65 | 0.5 | 0.2 | -+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ -| Dropout_h | 0.2 | 0.1 | 0 | 0 | 0 | -+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ -| Dropout_i | 0.65 | 0.3 | 0 | 0 | 0 | -+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ -| Dropout_e | 0.1 | 0.05 | 0 | 0 | 0 | 
-+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ -| Weight_drop | 0.5 | 0.2 | 0 | 0 | 0 | -+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ | Val PPL | 68.71 | 84.89 | 86.51 | 90.96 | 107.59 | +---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ | Test PPL | 65.62 | 80.67 | 82.29 | 86.91 | 101.64 | +---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ -| Command | [1] | [2] | [3] | [4] | [5] | +| Command | `command `__ | `command `__ | `command `__ | `command `__ | `command `__ | 
+---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ | Training logs | `log `__ | `log `__ | `log `__ | `log `__ | `log `__ | +---------------+----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+ For all the above model settings, we set Tied = True and NTASGD = True . -[1] awd_lstm_lm_1150_wikitext-2 (Val PPL 68.71 Test PPL 65.62 ) - -.. code-block:: console - - $ python word_language_model.py --gpu 0 --tied --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save awd_lstm_lm_1150_wikitext-2 - -[2] awd_lstm_lm_600_wikitext-2 (Val PPL 84.89 Test PPL 80.67) - -.. code-block:: console - - $ python word_language_model.py --gpu 0 --emsize 200 --nhid 600 --epochs 750 --dropout 0.2 --dropout_h 0.1 --dropout_i 0.3 --dropout_e 0.05 --weight_drop 0.2 --tied --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save awd_lstm_lm_600_wikitext-2 - -[3] standard_lstm_lm_1500_wikitext-2 (Val PPL 86.51 Test PPL 82.29) - -.. code-block:: console - - $ python word_language_model.py --gpu 0 --emsize 1500 --nhid 1500 --nlayers 2 --lr 20 --epochs 750 --batch_size 20 --bptt 35 --dropout 0.65 --dropout_h 0 --dropout_i 0 --dropout_e 0 --weight_drop 0 --tied --wd 0 --alpha 0 --beta 0 --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save standard_lstm_lm_1500_wikitext-2 - -[4] standard_lstm_lm_650_wikitext-2 (Val PPL 90.96 Test PPL 86.91) - -.. code-block:: console - - $ python word_language_model.py --gpu 0 --emsize 650 --nhid 650 --nlayers 2 --lr 20 --epochs 750 --batch_size 20 --bptt 35 --dropout 0.5 --dropout_h 0 --dropout_i 0 --dropout_e 0 --weight_drop 0 --tied --wd 0 --alpha 0 --beta 0 --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save standard_lstm_lm_650_wikitext-2 - -[5] standard_lstm_lm_200_wikitext-2 (Val PPL 107.59 Test PPL 101.64) - -.. code-block:: console - - $ python word_language_model.py --gpu 0 --emsize 200 --nhid 200 --nlayers 2 --lr 20 --epochs 750 --batch_size 20 --bptt 35 --dropout 0.2 --dropout_h 0 --dropout_i 0 --dropout_e 0 --weight_drop 0 --tied --wd 0 --alpha 0 --beta 0 --ntasgd --lr_update_interval 30 --lr_update_factor 0.1 --save standard_lstm_lm_200_wikitext-2 - Cache Language Model ~~~~~~~~~~~~~~~~~~~~~ @@ -97,43 +49,13 @@ The dataset used for training the models is wikitext-2. 
+---------------------+-----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | Test PPL | 51.46 | 62.19 | 62.79 | 65.85 | 73.74 | +---------------------+-----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| Command | [1] | [2] | [3] | [4] | [5] | +| Command | `command `__ | `command `__ | `command `__ | `command `__ | `command `__ | +---------------------+-----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | Training logs | `log `__ | `log `__ | `log `__ | `log `__ | `log `__ | +---------------------+-----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ For all the above model settings, we set lambdas = 0.1279, theta = 0.662, window = 2000 and bptt= 2000 . -[1] cache_awd_lstm_lm_1150_wikitext-2 (Val PPL 53.41 Test PPL 51.46) - -.. code-block:: console - - $ python cache_language_model.py --gpus 0 --model_name awd_lstm_lm_1150 - -[2] cache_awd_lstm_lm_600_wikitext-2 (Val PPL 64.51 Test PPL 62.19) - -.. 
code-block:: console - - $ python cache_language_model.py --gpus 0 --model_name awd_lstm_lm_600 - -[3] cache_standard_lstm_lm_1500_wikitext-2 (Val PPL 65.54 Test PPL 62.79) - -.. code-block:: console - - $ python cache_language_model.py --gpus 0 --model_name standard_lstm_lm_1500 - -[4] cache_standard_lstm_lm_650_wikitext-2 (Val PPL 68.47 Test PPL 65.85) - -.. code-block:: console - - $ python cache_language_model.py --gpus 0 --model_name standard_lstm_lm_650 - -[5] cache_standard_lstm_lm_200_wikitext-2 (Val PPL 77.51 Test PPL 73.74) - -.. code-block:: console - - $ python cache_language_model.py --gpus 0 --model_name standard_lstm_lm_200 - Large Scale Word Language Model ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -148,42 +70,17 @@ The dataset used for training the models is Google's 1 billion words dataset. +-----------------+------------------------------------------------------------------------------------------------------------------------------+ | Model | LSTM-2048-512 | +=================+==============================================================================================================================+ -| Mode | LSTMP | -+-----------------+------------------------------------------------------------------------------------------------------------------------------+ -| Num layers | 1 | -+-----------------+------------------------------------------------------------------------------------------------------------------------------+ -| Embed size | 512 | -+-----------------+------------------------------------------------------------------------------------------------------------------------------+ -| Hidden size | 2048 | -+-----------------+------------------------------------------------------------------------------------------------------------------------------+ -| Projection size | 512 | -+-----------------+------------------------------------------------------------------------------------------------------------------------------+ -| Dropout | 0.1 | -+-----------------+------------------------------------------------------------------------------------------------------------------------------+ -| Learning rate | 0.2 | +| Test perplexity | 43.80 | +-----------------+------------------------------------------------------------------------------------------------------------------------------+ -| Num samples | 8192 | +| Command | `command `__ | +-----------------+------------------------------------------------------------------------------------------------------------------------------+ -| Batch size | 128 | -+-----------------+------------------------------------------------------------------------------------------------------------------------------+ -| Gradient clip | 10.0 | -+-----------------+------------------------------------------------------------------------------------------------------------------------------+ -| Test perplexity | 43.62 | -+-----------------+------------------------------------------------------------------------------------------------------------------------------+ -| Num epochs | 50 | +| Command | `command `__ | +-----------------+------------------------------------------------------------------------------------------------------------------------------+ | Training logs | `log `__ | +-----------------+------------------------------------------------------------------------------------------------------------------------------+ | Evaluation logs | `log `__ | 
+-----------------+------------------------------------------------------------------------------------------------------------------------------+ -[1] LSTM-2048-512 (Test PPL 43.62) - -.. code-block:: console - - $ python large_word_language_model.py --gpus 0,1,2,3 --clip=10 - $ python large_word_language_model.py --gpus 4 --eval-only --batch-size=1 - XLNet: Generalized Autoregressive Pretraining for Language Understanding ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/scripts/language_model/large_word_language_model_estimator.py b/scripts/language_model/large_word_language_model_estimator.py new file mode 100644 index 0000000000..070184f9b8 --- /dev/null +++ b/scripts/language_model/large_word_language_model_estimator.py @@ -0,0 +1,238 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" large word language model train script """ + +import os +import sys +import argparse +import re + +import numpy as np +import mxnet as mx +from mxnet import gluon +from mxnet.gluon.contrib.estimator import CheckpointHandler, LoggingHandler +import gluonnlp as nlp +from gluonnlp.estimator import ParallelLanguageModelBatchProcessor +from gluonnlp.estimator import HiddenStateHandler, MetricResetHandler +from gluonnlp.estimator import LargeRNNGradientUpdateHandler +from gluonnlp.estimator import LanguageModelEstimator +from gluonnlp.estimator import ParallelLoggingHandler +from gluonnlp.metric.length_normalized_loss import LengthNormalizedLoss +from sampler import LogUniformSampler + +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.append(os.path.join(curr_path, '..', '..')) + +nlp.utils.check_version('0.7.0') + +############################################################################### +# Arg parser +############################################################################### +parser = argparse.ArgumentParser(description= + 'Gluon-NLP Big LSTM 2048-512 Language Model on GBW') +parser.add_argument('--save', type=str, default='model.params', + help='path to save the final model.') +parser.add_argument('--emsize', type=int, default=512, + help='size of word embeddings') +parser.add_argument('--nhid', type=int, default=2048, + help='number of hidden units per layer') +parser.add_argument('--nproj', type=int, default=512, + help='number of projection units per layer. 
Could be different from emsize') +parser.add_argument('--nlayers', type=int, default=1, + help='number of layers') +parser.add_argument('--from-epoch', type=int, default=None, + help='start training or testing from the provided epoch') +parser.add_argument('--epochs', type=int, default=50, + help='number of epochs for training') +parser.add_argument('--batch-size', type=int, default=128, + help='batch size per gpu') +parser.add_argument('--dropout', type=float, default=0.1, + help='dropout applied to layers (0 = no dropout)') +parser.add_argument('--eps', type=float, default=1, + help='initial history accumulation for adagrad') +parser.add_argument('--bptt', type=int, default=20, + help='sequence length') +parser.add_argument('--k', type=int, default=8192, + help='number of noise samples for estimation') +parser.add_argument('--gpus', type=str, + help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.') +parser.add_argument('--log-interval', type=int, default=1000, + help='report interval') +parser.add_argument('--seed', type=int, default=0, + help='random seed') +parser.add_argument('--lr', type=float, default=0.2, + help='initial learning rate') +parser.add_argument('--clip', type=float, default=1.0, + help='gradient clipping by global norm.') +parser.add_argument('--test-mode', action='store_true', + help='Whether to run through the script with few examples') +parser.add_argument('--eval-only', action='store_true', + help='Whether to only run evaluation for the trained model') +args = parser.parse_args() + +segments = ['train', 'test'] +max_nbatch_eval = None + +if args.test_mode: + args.emsize = 200 + args.log_interval = 1 + args.nhid = 200 + args.nlayers = 1 + args.epochs = 20 + max_nbatch_eval = 3 + segments = ['test', 'test'] + +print(args) +mx.random.seed(args.seed) +np.random.seed(args.seed) + +context = [mx.cpu()] if args.gpus is None or args.gpus == '' else \ + [mx.gpu(int(x)) for x in args.gpus.split(',')] + +os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round' +os.environ['MXNET_CPU_PARALLEL_RAND_COPY'] = str(len(context)) +os.environ['MXNET_CPU_WORKER_NTHREADS'] = str(len(context)) + +############################################################################### +# Data stream +############################################################################### +train_data_stream, test_data_stream = \ + [nlp.data.GBWStream(segment=segment, skip_empty=True, bos=None, eos='') + for segment in segments] +vocab = train_data_stream.vocab +ntokens = len(vocab) + +# Sampler for generating negative classes during training with importance sampling +sampler = LogUniformSampler(ntokens, args.k) + +# Given a list of (array, context) pairs, load array[i] on context[i] +def _load(xs): + ret = [] + for x, ctx in zip(xs, context): + if isinstance(x, tuple): + ret.append([y.as_in_context(ctx) for y in x]) + else: + ret.append(x.as_in_context(ctx)) + return ret + +# Transformation for a data batch for training. +# First, load the data, target and mask onto the target contexts. +# Second, because the LSTM-2048-512 model performs importance sampling for decoding +# during training, we need to sample negative candidate classes by invoking the +# log uniform sampler.
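+# The transform returns one shard of (data, target, mask, sampled candidates) per
+# device, with every shard already copied onto its device by _load.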
+def _split_and_sample(x, y): + m = x != vocab[vocab.padding_token] # mask padding + num_ctx = len(context) + if num_ctx > 1: + xs = gluon.utils.split_data(x, num_ctx, batch_axis=1, even_split=True) + ys = gluon.utils.split_data(y, num_ctx, batch_axis=1, even_split=True) + ms = gluon.utils.split_data(m, num_ctx, batch_axis=1, even_split=True) + else: + xs, ys, ms = [x], [y], [m] + xs = _load(xs) + ys = _load(ys) + ms = _load(ms) + ss = [sampler(y) for y in ys] + ss = _load(ss) + return xs, ys, ms, ss + +train_batch_size = args.batch_size * len(context) +train_batchify = nlp.data.batchify.StreamBPTTBatchify(vocab, args.bptt, train_batch_size) +train_data = train_batchify(train_data_stream) +train_data = train_data.transform(_split_and_sample) + +test_batch_size = args.batch_size +test_batchify = nlp.data.batchify.StreamBPTTBatchify(vocab, args.bptt, test_batch_size) +test_data = test_batchify(test_data_stream) +test_data = nlp.data.PrefetchingStream(test_data) + +############################################################################### +# Build the model +############################################################################### + +model = nlp.model.language_model.train.BigRNN(ntokens, args.emsize, args.nhid, + args.nlayers, args.nproj, args.k, + embed_dropout=args.dropout, + encode_dropout=args.dropout) +eval_model = nlp.model.language_model.BigRNN(ntokens, args.emsize, args.nhid, + args.nlayers, args.nproj, + embed_dropout=args.dropout, + encode_dropout=args.dropout) + +loss = gluon.loss.SoftmaxCrossEntropyLoss() +model.initialize(mx.init.Xavier(factor_type='out'), ctx=context) +trainer_params = {'learning_rate': args.lr, 'wd': 0, 'eps': args.eps} +trainer = gluon.Trainer(model.collect_params(), 'adagrad', trainer_params) +if args.from_epoch: + from_epoch = args.from_epoch + checkpoint_name = '%s.%s'%(args.save, format(from_epoch - 1, '02d')) + model.load_parameters(checkpoint_name) + trainer.load_states('%s.state'%args.save) + print('Loaded parameters from checkpoint %s'%(checkpoint_name)) + + +model.hybridize(static_alloc=True, static_shape=True) + +train_metric = mx.metric.Loss(loss) +val_metric = LengthNormalizedLoss(loss) +batch_processor = ParallelLanguageModelBatchProcessor(loss=loss, + vocab=vocab, + batch_size=args.batch_size, + val_batch_size=args.batch_size) +lm_estimator = LanguageModelEstimator(net=model, loss=loss, + train_metrics=train_metric, + val_metrics=val_metric, + trainer=trainer, + context=context, + val_loss=loss, + val_net=eval_model, + batch_processor=batch_processor, + bptt=args.bptt) + +hidden_state_handler = HiddenStateHandler() +gradient_handler = LargeRNNGradientUpdateHandler(batch_size=args.batch_size, clip=args.clip) +metric_handler = MetricResetHandler(metrics=lm_estimator.train_metrics, + log_interval=args.log_interval) +checkpoint_handler = CheckpointHandler(model_dir=args.save, model_prefix='largeRNN') +logging_handler = ParallelLoggingHandler(log_interval=args.log_interval, + metrics=lm_estimator.train_metrics) +val_logging_handler = LoggingHandler(log_interval=args.log_interval, + metrics=lm_estimator.val_metrics) + +event_handlers = [hidden_state_handler, gradient_handler, + metric_handler, checkpoint_handler, logging_handler] + +if not args.eval_only: + lm_estimator.fit(train_data=train_data, + epochs=args.epochs, + event_handlers=event_handlers, + #batches=5, + batch_axis=0) + +val_metric_handler = MetricResetHandler(metrics=lm_estimator.val_metrics) +lm_estimator.val_net.initialize(mx.init.Xavier(), ctx=context[0]) 
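+# Evaluation: each checkpoint saved during training is reloaded below into the
+# single-context evaluation network and scored on the GBW test stream.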
+lm_estimator.val_net.hybridize(static_alloc=True, static_shape=True) + +for epoch_id in range(args.epochs): + for filename in os.listdir(args.save): + file_pattern = r'largeRNN-epoch%dbatch\d+.params' % (epoch_id) + if re.match(file_pattern + '', filename): + checkpoint_path = args.save + '/' + filename + lm_estimator.val_net.load_parameters(checkpoint_path) + lm_estimator.evaluate(val_data=test_data, + event_handlers=[val_metric_handler, val_logging_handler]) diff --git a/scripts/language_model/word_language_model_estimator.py b/scripts/language_model/word_language_model_estimator.py new file mode 100644 index 0000000000..49b3291c6d --- /dev/null +++ b/scripts/language_model/word_language_model_estimator.py @@ -0,0 +1,266 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" word language model training script """ + +import argparse +import os +import sys + +import mxnet as mx +from mxnet import gluon +from mxnet.gluon.contrib.estimator import LoggingHandler +from mxnet.gluon.data.sampler import BatchSampler +import gluonnlp as nlp +from gluonnlp.loss.joint_loss import JointActivationRegularizationLoss +from gluonnlp.estimator import LanguageModelEstimator +from gluonnlp.estimator import HiddenStateHandler, AvgParamHandler +from gluonnlp.estimator import LearningRateHandler, RNNGradientUpdateHandler +from gluonnlp.estimator import WordLanguageModelCheckpointHandler +from gluonnlp.estimator import LanguageModelBatchProcessor +from gluonnlp.estimator import MetricResetHandler + + +class BatchVariableLenTextSampler(BatchSampler): + """Sample text of variable length + + Generate batch of text of variable length from the training dataset + + Parameters + ---------- + bptt : int + bptt variable + length : int + base sequence length for sampling + use_variable_length : bool + generate sequence of variable length or not + """ + def __init__(self, bptt, length, use_variable_length=True): + super(BatchVariableLenTextSampler, self).__init__() + self.bptt = bptt + self.length = length + self.index = 0 + self.use_variable_length = use_variable_length + + def __iter__(self): + self.index = 0 + while self.index < self.length - 2: + if self.use_variable_length: + bptt = self.bptt if mx.nd.random.uniform().asscalar() < .95 else self.bptt / 2 + seq_len = max(5, int(mx.nd.random.normal(bptt, 5).asscalar())) + else: + seq_len = self.bptt + seq_len = min(seq_len, self.length - self.index - 1) + # batch_size = seq_len + 1 + batch = [] + for i in range(self.index, self.index + seq_len + 1): + batch.append(i) + self.index += seq_len + yield batch + + def __len__(self): + # you may never get real size of the data sampler beforehand. 
May need some + # postprocessing after fetching the data batch + return int(self.length / 5) + 1 + +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.append(os.path.join(curr_path, '..', '..')) + +nlp.utils.check_version('0.7.0') + +parser = argparse.ArgumentParser(description= + 'MXNet Autograd RNN/LSTM Language Model on Wikitext-2.') +parser.add_argument('--model', type=str, default='lstm', + help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)') +parser.add_argument('--emsize', type=int, default=400, + help='size of word embeddings') +parser.add_argument('--nhid', type=int, default=1150, + help='number of hidden units per layer') +parser.add_argument('--nlayers', type=int, default=3, + help='number of layers') +parser.add_argument('--lr', type=float, default=30, + help='initial learning rate') +parser.add_argument('--clip', type=float, default=0.25, + help='gradient clipping') +parser.add_argument('--epochs', type=int, default=750, + help='upper epoch limit') +parser.add_argument('--batch_size', type=int, default=80, metavar='N', + help='batch size') +parser.add_argument('--bptt', type=int, default=70, + help='sequence length') +parser.add_argument('--dropout', type=float, default=0.4, + help='dropout applied to layers (0 = no dropout)') +parser.add_argument('--dropout_h', type=float, default=0.2, + help='dropout applied to hidden layer (0 = no dropout)') +parser.add_argument('--dropout_i', type=float, default=0.65, + help='dropout applied to input layer (0 = no dropout)') +parser.add_argument('--dropout_e', type=float, default=0.1, + help='dropout applied to embedding layer (0 = no dropout)') +parser.add_argument('--weight_dropout', type=float, default=0.5, + help='weight dropout applied to h2h weight matrix (0 = no weight dropout)') +parser.add_argument('--tied', action='store_true', + help='tie the word embedding and softmax weights') +parser.add_argument('--log-interval', type=int, default=200, metavar='N', + help='report interval') +parser.add_argument('--save', type=str, default='model.params', + help='path to save the final model') +parser.add_argument('--eval_only', action='store_true', + help='Whether to only evaluate the trained model') +parser.add_argument('--gpu', type=str, help='single gpu id') +parser.add_argument('--optimizer', type=str, default='sgd', + help='optimizer to use (sgd, adam)') +parser.add_argument('--wd', type=float, default=1.2e-6, + help='weight decay applied to all weights') +parser.add_argument('--alpha', type=float, default=2, + help='alpha L2 regularization on RNN activation ' + '(alpha = 0 means no regularization)') +parser.add_argument('--beta', type=float, default=1, + help='beta slowness regularization applied on RNN activation ' + '(beta = 0 means no regularization)') +parser.add_argument('--ntasgd', action='store_true', + help='Whether to apply ntasgd') +parser.add_argument('--test_mode', action='store_true', + help='Whether to run through the script with few examples') +parser.add_argument('--lr_update_interval', type=int, default=30, + help='lr udpate interval') +parser.add_argument('--lr_update_factor', type=float, default=0.1, + help='lr udpate factor') +args = parser.parse_args() + +############################################################################### +# Load data +############################################################################### + +context = [mx.cpu()] if not args.gpu else [mx.gpu(int(args.gpu))] + +assert args.batch_size % len(context) == 0, \ + 'Total batch size must be 
multiple of the number of devices' + +assert args.weight_dropout > 0 or (args.weight_dropout == 0 and args.alpha == 0), \ + 'The alpha L2 regularization cannot be used with standard RNN, please set alpha to 0' + +train_dataset, val_dataset, test_dataset = \ + [nlp.data.WikiText2(segment=segment, + skip_empty=False, bos=None, eos='') + for segment in ['train', 'val', 'test']] + +vocab = nlp.Vocab(counter=nlp.data.Counter(train_dataset), padding_token=None, bos_token=None) +train_batchify = nlp.data.batchify.CorpusBatchify(vocab, args.batch_size) +train_data = train_batchify(train_dataset) +val_batch_size = 10 +val_batchify = nlp.data.batchify.CorpusBatchify(vocab, val_batch_size) +val_data = val_batchify(val_dataset) +test_batch_size = 1 +test_batchify = nlp.data.batchify.CorpusBatchify(vocab, test_batch_size) +test_data = test_batchify(test_dataset) + +if args.test_mode: + args.emsize = 200 + args.nhid = 200 + args.nlayers = 1 + args.epochs = 3 + train_data = train_data[0:100] + val_data = val_data[0:100] + test_data = test_data[0:100] + +print(args) + +############################################################################### +# Build the model +############################################################################### + +ntokens = len(vocab) + +if args.weight_dropout > 0: + print('Use AWDRNN') + model = nlp.model.train.AWDRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers, + args.tied, args.dropout, args.weight_dropout, + args.dropout_h, args.dropout_i, args.dropout_e) + model.initialize(mx.init.Xavier(), ctx=context) + model_eval = nlp.model.AWDRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers, + args.tied, args.dropout, args.weight_dropout, + args.dropout_h, args.dropout_i, args.dropout_e, + params=model.collect_params()) +else: + model = nlp.model.train.StandardRNN(args.model, len(vocab), args.emsize, + args.nhid, args.nlayers, args.dropout, args.tied) + model.initialize(mx.init.Xavier(), ctx=context) + model_eval = nlp.model.StandardRNN(args.model, len(vocab), args.emsize, + args.nhid, args.nlayers, args.dropout, args.tied, + params=model.collect_params()) + + +model.hybridize(static_alloc=True) + +print(model) + +if args.optimizer == 'sgd': + trainer_params = {'learning_rate': args.lr, + 'momentum': 0, + 'wd': args.wd} +elif args.optimizer == 'adam': + trainer_params = {'learning_rate': args.lr, + 'wd': args.wd, + 'beta1': 0, + 'beta2': 0.999, + 'epsilon': 1e-9} + +trainer = gluon.Trainer(model.collect_params(), args.optimizer, trainer_params, + update_on_kvstore=False) + +loss = gluon.loss.SoftmaxCrossEntropyLoss() +train_loss = JointActivationRegularizationLoss(loss, args.alpha, args.beta) + +sampler = BatchVariableLenTextSampler(bptt=70, length=len(train_data)) +val_sampler = BatchVariableLenTextSampler(bptt=70, length=len(val_data), use_variable_length=False) +test_sampler = BatchVariableLenTextSampler(bptt=70, length=len(test_data), + use_variable_length=False) +train_data_loader = mx.gluon.data.DataLoader(train_data, + batch_sampler=sampler) +val_data_loader = mx.gluon.data.DataLoader(val_data, + batch_sampler=val_sampler) +test_data_loader = mx.gluon.data.DataLoader(test_data, + batch_sampler=test_sampler) + +train_metric = mx.metric.Loss(train_loss) +val_metric = mx.metric.Loss(loss) +batch_processor = LanguageModelBatchProcessor() +est = LanguageModelEstimator(net=model, loss=train_loss, + train_metrics=train_metric, + val_metrics=val_metric, + trainer=trainer, context=context, + val_loss=loss, + val_net=model_eval, + 
batch_processor=batch_processor) +event_handlers = [HiddenStateHandler(), + AvgParamHandler(data_length=len(train_data)), + LearningRateHandler(lr_update_interval=args.lr_update_interval, + lr_update_factor=args.lr_update_factor), + RNNGradientUpdateHandler(clip=args.clip), + LoggingHandler(log_interval=args.log_interval, + metrics=est.train_metrics + est.val_metrics), + MetricResetHandler(metrics=est.train_metrics, + log_interval=args.log_interval), + WordLanguageModelCheckpointHandler(args.save)] +est.fit(train_data=train_data_loader, val_data=val_data_loader, + epochs=args.epochs, + event_handlers=event_handlers, + batch_axis=1) + +est.net.load_parameters(args.save) +est.evaluate(val_data=val_data_loader, event_handlers=[HiddenStateHandler()], batch_axis=1) +est.evaluate(val_data=test_data_loader, event_handlers=[HiddenStateHandler()], batch_axis=1) diff --git a/src/gluonnlp/__init__.py b/src/gluonnlp/__init__.py index 7a588e8233..f9772b95fc 100644 --- a/src/gluonnlp/__init__.py +++ b/src/gluonnlp/__init__.py @@ -30,6 +30,7 @@ from . import vocab from . import optimizer from . import initializer +from . import estimator from .vocab import Vocab __version__ = '0.10.0.dev' @@ -43,7 +44,8 @@ 'initializer', 'optimizer', 'utils', - 'metric'] + 'metric', + 'estimator'] warnings.filterwarnings(module='gluonnlp', action='default', category=DeprecationWarning) utils.version.check_version('1.6.0', warning_only=True, library=mxnet) diff --git a/src/gluonnlp/estimator/__init__.py b/src/gluonnlp/estimator/__init__.py new file mode 100644 index 0000000000..e8de9fa8e5 --- /dev/null +++ b/src/gluonnlp/estimator/__init__.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import, unused-variable + +""" Gluon NLP Estimator Module """ +from . import language_model_estimator, language_model_event_handler +from . import language_model_batch_processor + +from .language_model_estimator import * +from .language_model_event_handler import * +from .language_model_batch_processor import * + +__all__ = (language_model_estimator.__all__ + language_model_event_handler.__all__ + + language_model_batch_processor.__all__) diff --git a/src/gluonnlp/estimator/language_model_batch_processor.py b/src/gluonnlp/estimator/language_model_batch_processor.py new file mode 100644 index 0000000000..74b26cde23 --- /dev/null +++ b/src/gluonnlp/estimator/language_model_batch_processor.py @@ -0,0 +1,167 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import, unused-variable +""" Gluon Languange Model Estimator """ + +import mxnet as mx +from mxnet.gluon.contrib.estimator import BatchProcessor +from mxnet.gluon.utils import split_and_load +from ..utils import Parallel +from ..model.train.language_model import ParallelBigRNN + +__all__ = ['LanguageModelBatchProcessor', 'ParallelLanguageModelBatchProcessor'] + +class LanguageModelBatchProcessor(BatchProcessor): + """Word language model batch processor + + Batch training and validation for word language model + """ + def __init__(self): + super(LanguageModelBatchProcessor, self).__init__() + + def fit_batch(self, estimator, train_batch, batch_axis=0): + data = train_batch[:-1] + target = train_batch[1:] + batch_size = train_batch.shape[batch_axis] + data = split_and_load(data, estimator.context, batch_axis=batch_axis, even_split=True) + target = split_and_load(target, estimator.context, batch_axis=batch_axis, even_split=True) + if estimator.hiddens is None: + estimator.hiddens = [estimator.net.begin_state(batch_size // len(estimator.context), + func=mx.nd.zeros, + ctx=ctx) for ctx in estimator.context] + else: + estimator.hiddens = estimator.detach(estimator.hiddens) + + Ls = [] + outputs = [] + data_size = 0 + with mx.autograd.record(): + for i, (X, y, h) in enumerate(zip(data, target, estimator.hiddens)): + data_size = X.size + output, h, encoder_hs, dropped_encoder_hs = estimator.net(X, h) + l = estimator.loss(output, y, encoder_hs, dropped_encoder_hs) + Ls.append(l / (len(estimator.context) * X.size)) + estimator.hiddens[i] = h + outputs.append(output) + + for L in Ls: + L.backward() + + Ls = [l * (len(estimator.context) * data_size) for l in Ls] + return data, target, outputs, Ls + + def evaluate_batch(self, estimator, val_batch, batch_axis=0): + data = val_batch[:-1] + target = val_batch[1:] + batch_size = val_batch.shape[batch_axis] + data = split_and_load(data, estimator.context, batch_axis=batch_axis, even_split=True) + target = split_and_load(target, estimator.context, batch_axis=batch_axis, even_split=True) + + Ls = [] + outputs = [] + if estimator.val_hiddens is None: + estimator.val_hiddens = \ + [estimator.val_net.begin_state(batch_size // + len(estimator.context), + func=mx.nd.zeros, ctx=ctx) for ctx + in estimator.context] + else: + estimator.val_hiddens = estimator.detach(estimator.val_hiddens) + for i, (X, y, h) in enumerate(zip(data, target, estimator.val_hiddens)): + output, h = estimator.val_net(X, h) + L = estimator.val_loss(output.reshape(-3, -1), y.reshape(-1,)) + estimator.val_hiddens[i] = h + Ls.append(L) + outputs.append(output) + + return data, target, outputs, Ls + +class ParallelLanguageModelBatchProcessor(BatchProcessor): + """Parallel large RNN batch processor + + Batch training and validation for parallel large RNN model + + Parameters + ---------- + loss : mxnet.gluon.loss.Loss + Training loss function for parallel large rnn model + vocab : 
gluonnlp.vocab + Vocab of training and validation dataset + batch_size : int + Training batch size. It is used to construct the initial hidden states of + model + val_batch_size : int + Validation batch size. It is used to construct the initial hidden states + of validation model. + """ + def __init__(self, loss, vocab, batch_size, val_batch_size): + super(ParallelLanguageModelBatchProcessor, self).__init__() + self.loss = loss + self.parallel_model = None + self.batch_size = batch_size + self.val_batch_size = val_batch_size + self.vocab = vocab + + def _get_parallel_model(self, estimator): + if self.parallel_model is None: + self.parallel_model = ParallelBigRNN(estimator.net, self.loss, self.batch_size) + self.parallel_model = Parallel(len(estimator.context), self.parallel_model) + + def fit_batch(self, estimator, train_batch, batch_axis=0): + self._get_parallel_model(estimator) + data, target, mask, sample = train_batch + if estimator.hiddens is None: + estimator.hiddens = [estimator.net.begin_state(batch_size=self.batch_size, + func=mx.nd.zeros, + ctx=ctx) for ctx in estimator.context] + else: + estimator.hiddens = estimator.detach(estimator.hiddens) + Ls = [] + for _, batch in enumerate(zip(data, target, mask, sample, estimator.hiddens)): + self.parallel_model.put(batch) + + for _ in range(len(data)): + hidden, ls = self.parallel_model.get() + index = estimator.context.index(hidden[0].context) + estimator.hiddens[index] = hidden + Ls.append(ls) + + Ls = [l / estimator.bptt for l in Ls] + Ls = [mx.nd.sum(l) for l in Ls] + return data, target, None, Ls + + def evaluate_batch(self, estimator, val_batch, batch_axis=0): + data, target = val_batch + ctx = estimator.context[0] + data = data.as_in_context(ctx) + target = target.as_in_context(ctx) + if estimator.val_hiddens is None: + estimator.val_hiddens = \ + estimator.val_net.begin_state(batch_size=self.val_batch_size, + func=mx.nd.zeros, ctx=ctx) + else: + estimator.val_hiddens = estimator.detach(estimator.val_hiddens) + + mask = data != self.vocab[self.vocab.padding_token] + mask = mask.reshape(-1) + output, estimator.val_hiddens = estimator.val_net(data, estimator.val_hiddens) + output = output.reshape((-3, -1)) + L = estimator.val_loss(output, target.reshape(-1, ) * mask.reshape(-1)) + + return data, [target, mask], output, L diff --git a/src/gluonnlp/estimator/language_model_estimator.py b/src/gluonnlp/estimator/language_model_estimator.py new file mode 100644 index 0000000000..4eb120ea28 --- /dev/null +++ b/src/gluonnlp/estimator/language_model_estimator.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# coding: utf-8 +# pylint: disable=wildcard-import, unused-variable +""" Gluon Language Model Estimator """ + +from mxnet.gluon.contrib.estimator import Estimator +from .language_model_batch_processor import LanguageModelBatchProcessor + +__all__ = ['LanguageModelEstimator'] + +class LanguageModelEstimator(Estimator): + """Language Model Estimator + + Estimator class to facilitate the language model training and validation process. + + Parameters + ---------- + net : gluon.Block + The model used for training. + loss : gluon.loss.Loss + Loss (objective) function to calculate during training. + train_metrics : EvalMetric or list of EvalMetric + Training metrics for evaluating models on training dataset. + val_metrics : EvalMetric or list of EvalMetric + Validation metrics for evaluating models on validation dataset. + initializer : Initializer + Initializer to initialize the network. + trainer : Trainer + Trainer to apply optimizer on network parameters. + context : Context or list of Context + Device(s) to run the training on. + val_net : gluon.Block + The model used for validation. The validation model does not necessarily belong to + the same model class as the training model. + val_loss : gluon.loss.Loss + Loss (objective) function to calculate during validation. If val_loss is set to + None, it will use the same loss function as self.loss. + batch_processor : BatchProcessor + BatchProcessor provides customized fit_batch() and evaluate_batch() methods. + bptt : int + bptt value for the language model training. It decides how many time steps + to backpropagate. + """ + def __init__(self, net, loss, train_metrics=None, + val_metrics=None, + initializer=None, + trainer=None, + context=None, + val_loss=None, + val_net=None, + batch_processor=LanguageModelBatchProcessor(), + bptt=70): + super().__init__(net=net, loss=loss, + train_metrics=train_metrics, + val_metrics=val_metrics, + initializer=initializer, + trainer=trainer, + context=context, + val_loss=val_loss, + val_net=val_net, + batch_processor=batch_processor) + self.hiddens = None + self.val_hiddens = None + self.avg_param = None + self.bptt = bptt + self.ntasgd = False + + def detach(self, hidden): + if isinstance(hidden, (tuple, list)): + hidden = [self.detach(h) for h in hidden] + else: + hidden = hidden.detach() + return hidden diff --git a/src/gluonnlp/estimator/language_model_event_handler.py b/src/gluonnlp/estimator/language_model_event_handler.py new file mode 100644 index 0000000000..77ad23d1b1 --- /dev/null +++ b/src/gluonnlp/estimator/language_model_event_handler.py @@ -0,0 +1,317 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ +# coding: utf-8 +# pylint: disable=wildcard-import, unused-variable +""" Gluon Language Model Event Handler """ + +import time + +import mxnet as mx +from mxnet.gluon.contrib.estimator import EpochBegin, EpochEnd +from mxnet.gluon.contrib.estimator import BatchBegin, BatchEnd +from mxnet.gluon.contrib.estimator import GradientUpdateHandler, LoggingHandler +from mxnet.gluon.contrib.estimator import MetricHandler +from mxnet.gluon.utils import clip_global_norm +from mxnet.metric import Loss as MetricLoss +from ..metric.length_normalized_loss import LengthNormalizedLoss + +__all__ = ['HiddenStateHandler', 'AvgParamHandler', 'LearningRateHandler', + 'RNNGradientUpdateHandler', 'MetricResetHandler', + 'WordLanguageModelCheckpointHandler', 'ParallelLoggingHandler', + 'LargeRNNGradientUpdateHandler'] + +class HiddenStateHandler(EpochBegin): + """Hidden state reset event handler + + Reset hidden states for language model at each epoch + """ + def __init__(self): + pass + + def epoch_begin(self, estimator, *args, **kwargs): + estimator.hiddens = None + estimator.val_hiddens = None + +class AvgParamHandler(BatchEnd, EpochEnd): + """NTASGD average parameter event handler + + Average model parameters used in word language model estimator + + Parameters + ---------- + data_length: int + Length of training data, i.e., len(train_data). It is used to normalize the weight + average coefficient. + """ + def __init__(self, data_length): + self.epoch_id = 0 + self.batch_id = 0 + self.avg_trigger = 0 + self.t = 0 + self.n = 5 + self.valid_losses = [] + self.data_length = data_length + + def batch_end(self, estimator, *args, **kwargs): + parameters = estimator.net.collect_params() + if estimator.ntasgd: + if estimator.avg_param is None: + estimator.avg_param = \ + {k.split(estimator.net._prefix)[1]: + v.data(estimator.context[0]).copy() + for k, v in parameters.items()} + else: + gamma = 1. 
/ max(1, self.epoch_id * (self.data_length // estimator.bptt) + + self.batch_id - self.avg_trigger + 2) + for key, val in estimator.avg_param.items(): + val[:] += gamma * (parameters['{}{}'.format(estimator.net._prefix, key)] + .data(estimator.context[0]) - val) + self.batch_id += 1 + + def epoch_end(self, estimator, *args, **kwargs): + if not isinstance(estimator.val_metrics, list): + val_metrics = [estimator.val_metrics] + else: + val_metrics = estimator.val_metrics + parameters = estimator.net.collect_params() + if self.avg_trigger == 0: + if self.t > self.n and val_metrics[0].get()[1] > min(self.valid_losses[-self.n:]): + if estimator.avg_param is None: + estimator.avg_param = \ + {k.split(estimator.net._prefix)[1]: + v.data(estimator.context[0]).copy() + for k, v in + parameters.items()} + else: + for key, val in parameters.items(): + estimator.avg_param[key.split(estimator.net._prefix)[1]] \ + = val.data(estimator.context[0]).copy() + self.avg_trigger = (self.epoch_id + 1) * (self.data_length // estimator.bptt) + print('Switching to NTASGD and avg_trigger is : %d' % self.avg_trigger) + estimator.ntasgd = True + self.valid_losses.append(val_metrics[0].get()[1]) + self.t += 1 + self.batch_id = 0 + self.epoch_id += 1 + +class LearningRateHandler(BatchBegin, BatchEnd, EpochEnd): + """NTASGD learning rate event handler + + Dynamically adjust the learning rate during word language model training. + TODO: Investigate whether the learning rate event handler can be replaced with + a learning rate scheduler + + Parameters + ---------- + lr_update_interval : int + Epoch interval for updating the learning rate when training the word + language model + lr_update_factor : float + Learning rate decay factor used when updating the learning rate + """ + def __init__(self, lr_update_interval=30, lr_update_factor=0.1): + self.lr_batch_start = 0 + self.best_val = float('Inf') + self.update_lr_epoch = 0 + self.lr_update_interval = lr_update_interval + self.lr_update_factor = lr_update_factor + + def batch_begin(self, estimator, *args, **kwargs): + batch = kwargs['batch'] + self.lr_batch_start = estimator.trainer.learning_rate + seq_len = batch.shape[0] - 1 + estimator.trainer.set_learning_rate(self.lr_batch_start * seq_len / estimator.bptt) + + def batch_end(self, estimator, *args, **kwargs): + estimator.trainer.set_learning_rate(self.lr_batch_start) + + def epoch_end(self, estimator, *args, **kwargs): + if not isinstance(estimator.val_metrics, list): + val_metrics = [estimator.val_metrics] + else: + val_metrics = estimator.val_metrics + + if val_metrics[0].get()[1] < self.best_val: + self.update_lr_epoch = 0 + self.best_val = val_metrics[0].get()[1] + else: + self.update_lr_epoch += 1 + if self.update_lr_epoch % self.lr_update_interval == 0 and self.update_lr_epoch != 0: + lr_scale = estimator.trainer.learning_rate * self.lr_update_factor + estimator.trainer.set_learning_rate(lr_scale) + self.update_lr_epoch = 0 + +class RNNGradientUpdateHandler(GradientUpdateHandler): + """NTASGD gradient clipping update event handler + + Clip gradients during word language model training. + Parameters + ---------- + clip : float + Gradient clipping threshold. Gradient norms exceeding this value are scaled + down to the valid range.
+ """ + def __init__(self, clip=None, **kwargs): + super().__init__(**kwargs) + self.clip = clip + + def batch_end(self, estimator, *args, **kwargs): + loss = kwargs['loss'] + loss_size = sum([l.size for l in loss]) + parameters = estimator.net.collect_params() + grads = [p.grad(ctx) for p in parameters.values() for ctx in estimator.context] + if self.clip is not None: + clip_global_norm(grads, self.clip) + + estimator.trainer.step(1) + +class LargeRNNGradientUpdateHandler(GradientUpdateHandler): + """Parallel Large RNN gradient clipping update event handler + + Rescale gradients of embedding parameters and clip gradients of encoder parameters + when training the parallel large RNN. + + Parameters + ---------- + batch_size : int + Batch size per GPU used when training the parallel large RNN + clip : float + Gradient clipping threshold. Encoder gradients whose global norm exceeds this value + are scaled down to the valid range. + """ + def __init__(self, batch_size, clip=None, **kwargs): + super().__init__(**kwargs) + self.batch_size = batch_size + self.clip = clip + + def batch_end(self, estimator, *args, **kwargs): + encoder_params = estimator.net.encoder.collect_params().values() + embedding_params = list(estimator.net.embedding.collect_params().values()) + + for ctx in estimator.context: + x = embedding_params[0].grad(ctx) + x[:] *= self.batch_size + encoder_grad = [p.grad(ctx) for p in encoder_params] + clip_global_norm(encoder_grad, self.clip) + + estimator.trainer.step(len(estimator.context)) + +class MetricResetHandler(BatchBegin, MetricHandler): + """Event handler for resetting local metrics + + Reset local metrics every few iterations and add support for LengthNormalizedLoss + to compute both local and global metrics. + TODO: Move this event handler to be reusable by other estimators, e.g., + MachineTranslationEstimator + + Parameters + ---------- + metrics : mxnet.metric + Metrics to be reset during training + log_interval : int or None + If log_interval is of int type, it represents the interval of resetting local + metrics. Otherwise, metrics do not need to be reset.
+ """ + def __init__(self, metrics, log_interval=None): + super().__init__(metrics=metrics) + self.batch_id = 0 + self.log_interval = log_interval + + def epoch_begin(self, estimator, *args, **kwargs): + self.batch_id = 0 + for metric in self.metrics: + metric.reset() + + def batch_begin(self, estimator, *args, **kwargs): + if self.log_interval is not None: + if self.batch_id % self.log_interval == 0: + for metric in self.metrics: + metric.reset_local() + self.batch_id += 1 + + def batch_end(self, estimator, *args, **kwargs): + pred = kwargs['pred'] + label = kwargs['label'] + loss = kwargs['loss'] + for metric in self.metrics: + if isinstance(metric, MetricLoss): + metric.update(0, loss) + elif isinstance(metric, LengthNormalizedLoss): + metric.update(label, loss) + else: + metric.update(label, pred) + +class WordLanguageModelCheckpointHandler(EpochEnd): + """Checkpoint Event handler of word language model + + Save the model checkpoint of word language model + + Parameters + ---------- + save : string + The model checkpoint save path prefix + """ + def __init__(self, save): + self.save = save + self.best_val = float('Inf') + + def epoch_end(self, estimator, *args, **kwargs): + if not isinstance(estimator.val_metrics, list): + val_metrics = [estimator.val_metrics] + else: + val_metrics = estimator.val_metrics + + if estimator.ntasgd: + mx.nd.save('{}.val.params'.format(self.save), estimator.avg_param) + else: + estimator.net.save_parameters('{}.val.params'.format(self.save)) + + if val_metrics[0].get()[1] < self.best_val: + self.best_val = val_metrics[0].get()[1] + if estimator.ntasgd: + mx.nd.save(self.save, estimator.avg_param) + else: + estimator.net.save_parameters(self.save) + + +class ParallelLoggingHandler(LoggingHandler): + """Logging handler of Parallel language model training + + Generating logging information of parallel large RNN training. This event handler + is designed specifically to handle the batches taken from multiple gpus. + """ + def __init__(self, *args, **kwargs): + super(ParallelLoggingHandler, self).__init__(*args, **kwargs) + + def batch_end(self, estimator, *args, **kwargs): + if isinstance(self.log_interval, int): + batch_time = time.time() - self.batch_start + msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index) + cur_batches = kwargs['batch'][0] + for batch in cur_batches: + self.processed_samples += batch.shape[0] + msg += '[Samples %s]' % (self.processed_samples) + self.log_interval_time += batch_time + if self.batch_index % self.log_interval == self.log_interval - 1: + msg += 'time/interval %.3fs ' % self.log_interval_time + self.log_interval_time = 0 + for metric in self.metrics: + name, val = metric.get() + msg += '%s: %.4f, ' % (name, val) + estimator.logger.info(msg.rstrip(', ')) + self.batch_index += 1 diff --git a/src/gluonnlp/loss/joint_loss.py b/src/gluonnlp/loss/joint_loss.py new file mode 100644 index 0000000000..c62010cbf2 --- /dev/null +++ b/src/gluonnlp/loss/joint_loss.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
diff --git a/src/gluonnlp/loss/joint_loss.py b/src/gluonnlp/loss/joint_loss.py
new file mode 100644
index 0000000000..c62010cbf2
--- /dev/null
+++ b/src/gluonnlp/loss/joint_loss.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Joint activation regularization loss"""
+from mxnet import gluon
+from . import ActivationRegularizationLoss, TemporalActivationRegularizationLoss
+
+__all__ = ['JointActivationRegularizationLoss']
+
+class JointActivationRegularizationLoss(gluon.loss.Loss):
+    r"""Computes the joint regularization loss together with a standard loss.
+
+    For the activation regularization, see
+    gluonnlp.loss.ActivationRegularizationLoss.
+
+    For the temporal activation regularization, see
+    gluonnlp.loss.TemporalActivationRegularizationLoss.
+
+    Parameters
+    ----------
+    l : gluon.loss.Loss
+        The standard loss.
+    alpha : float
+        The activation regularization parameter in gluonnlp.loss.ActivationRegularizationLoss.
+    beta : float
+        The temporal activation regularization parameter in
+        gluonnlp.loss.TemporalActivationRegularizationLoss.
+
+    Inputs:
+        - **out**: NDArray
+          output tensor with shape `(sequence_length, batch_size, input_size)`
+          when `layout` is "TNC".
+        - **target**: NDArray
+          target tensor with shape `(sequence_length, batch_size)`
+          when `layout` is "TNC".
+        - **states**: the stacked outputs of the RNN,
+          one output per time step (TNC).
+        - **dropped_states**: the stacked outputs of the RNN with dropout applied,
+          one output per time step (TNC).
+
+    Outputs:
+        - **loss**: loss tensor with shape (batch_size,). Dimensions other than
+          batch_axis are averaged out.
+    """
+
+    def __init__(self, l, alpha, beta, weight=None, batch_axis=None, **kwargs):
+        super(JointActivationRegularizationLoss, self).__init__(weight, batch_axis, **kwargs)
+        self._loss = l
+        self._alpha, self._beta = alpha, beta
+        if alpha:
+            self._ar_loss = ActivationRegularizationLoss(alpha)
+        if beta:
+            self._tar_loss = TemporalActivationRegularizationLoss(beta)
+
+    def __repr__(self):
+        s = 'JointActivationTemporalActivationRegularizationLoss'
+        return s
+
+    def hybrid_forward(self, F, out, target, states, dropped_states):  # pylint: disable=arguments-differ
+        # pylint: disable=unused-argument
+        l = self._loss(out.reshape(-3, -1), target.reshape(-1,))
+        if self._alpha:
+            l = l + self._ar_loss(*dropped_states)
+        if self._beta:
+            l = l + self._tar_loss(*states)
+        return l
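
A usage sketch for the joint loss (illustrative only; model, data, target, and hidden are assumed to come from an AWD-LSTM-style training model that returns the decoded output plus the raw and dropped encoder states per time step, and the import assumes the class is re-exported from gluonnlp.loss): the wrapped standard loss is applied to the flattened output and target, the AR term is computed from the dropped states, and the TAR term from the raw states.

    import mxnet as mx
    from mxnet import gluon
    from gluonnlp.loss import JointActivationRegularizationLoss

    base_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    # alpha/beta values here are example settings, not defaults
    joint_loss = JointActivationRegularizationLoss(base_loss, alpha=2, beta=1)

    with mx.autograd.record():
        # assumed model output: (output, hidden, encoded_raw, encoded_dropped), TNC layout
        output, hidden, encoded_raw, encoded_dropped = model(data, hidden)
        L = joint_loss(output, target, encoded_raw, encoded_dropped)
    L.backward()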