diff --git a/scripts/datasets/__main__.py b/scripts/datasets/__main__.py index 3558a1424c..301c7036e9 100644 --- a/scripts/datasets/__main__.py +++ b/scripts/datasets/__main__.py @@ -7,7 +7,8 @@ from .general_nlp_benchmark import prepare_glue from gluonnlp.registry import DATA_PARSER_REGISTRY, DATA_MAIN_REGISTRY - +# TODO(zheyuye), lazy_import theses data parser functions and data main function +# and their dependencies by a dictionary mapping the datasets names to the functions. def list_all_subcommands(): out = [] for key in DATA_PARSER_REGISTRY.list_keys(): diff --git a/scripts/datasets/pretrain_corpus/README.md b/scripts/datasets/pretrain_corpus/README.md index 48d3a55810..49ace8d8eb 100644 --- a/scripts/datasets/pretrain_corpus/README.md +++ b/scripts/datasets/pretrain_corpus/README.md @@ -2,36 +2,41 @@ We provide a series of shared scripts for downloading/preparing the text corpus for pretraining NLP models. This helps create a unified text corpus for studying the performance of different pretraining algorithms. -When releasing the datasets, we follow the [FAIR principle](https://www.go-fair.org/fair-principles/), -i.e., the dataset needs to be findable, accessible, interoperable, and reusable. +When releasing the datasets, we follow the [FAIR principle](https://www.go-fair.org/fair-principles/), +i.e., the dataset needs to be findable, accessible, interoperable, and reusable. ## BookCorpus Unfortunately, we are unable to provide the original [Toronto BookCorpus dataset](https://yknzhu.wixsite.com/mbweb) due to licensing issues. There are some open source efforts for reproducing the dataset, e.g., - using [soskek/bookcorpus](https://github.com/soskek/bookcorpus) or directly downloading the [preprocessed version](https://drive.google.com/file/d/16KCjV9z_FHm8LgZw05RSuk4EsAWPOP_z/view). - + using [soskek/bookcorpus](https://github.com/soskek/bookcorpus) or directly downloading the [preprocessed version](https://drive.google.com/file/d/16KCjV9z_FHm8LgZw05RSuk4EsAWPOP_z/view). + Nevertheless, we utilize the [Project Gutenberg](https://www.gutenberg.org/) as an alternative to Toronto BookCorpus. -You can use the following command to download and prepare the Gutenberg dataset. +You can use the following command to download and prepare the Gutenberg dataset. ```bash python prepare_bookcorpus.py --dataset gutenberg ``` -Also, you should follow the [license](https://www.gutenberg.org/wiki/Gutenberg:The_Project_Gutenberg_License) for using the data. +Also, you should follow the [license](https://www.gutenberg.org/wiki/Gutenberg:The_Project_Gutenberg_License) for using the data. ## Wikipedia Please install [attardi/wikiextractor](https://github.com/attardi/wikiextractor) for preparing the data. -``` +```bash # Download python prepare_wikipedia.py --mode download --lang en --date latest -o ./ # Properly format the text files python prepare_wikipedia.py --mode format -i [path-to-wiki.xml.bz2] -o ./ +``` +The process of downloading and formatting is time consuming, and we offer an alternative solution to download the prepared raw text file from S3 bucket. This raw text file is in English and was dumped at 2020-06-20 being formated by the above very process (` --lang en --date 20200620`). + +```bash +python prepare_wikipedia.py --mode download_prepared -o ./ ``` ### References - [NVIDIA/DeepLearningExamples](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT) @@ -43,7 +48,7 @@ You can download the OpenWebText from [link](https://skylion007.github.io/OpenWe After downloading and extracting the OpenWebText (i.e., `tar xf openwebtext.tar.xz`), you can use the following command to preprocess the dataset. ```bash -python prepare_openwebtext.py --input openwebtext/ --output prepared_owt +python prepare_openwebtext.py --input openwebtext/ --output prepared_owt --shuffle ``` In this step, the archived txt are directly read without decompressing. diff --git a/scripts/datasets/pretrain_corpus/prepare_bookcorpus.py b/scripts/datasets/pretrain_corpus/prepare_bookcorpus.py index 07eb2ce603..7e00f73a98 100644 --- a/scripts/datasets/pretrain_corpus/prepare_bookcorpus.py +++ b/scripts/datasets/pretrain_corpus/prepare_bookcorpus.py @@ -75,6 +75,7 @@ def main(args): filename = os.path.basename(name) f.extract(name, os.path.join(save_dir, filename)) else: + # TODO(zheyuye), format for pretraining raise NotImplementedError else: raise NotImplementedError diff --git a/scripts/datasets/pretrain_corpus/prepare_openwebtext.py b/scripts/datasets/pretrain_corpus/prepare_openwebtext.py index 8c2c2b79b4..3a60fe52de 100644 --- a/scripts/datasets/pretrain_corpus/prepare_openwebtext.py +++ b/scripts/datasets/pretrain_corpus/prepare_openwebtext.py @@ -51,8 +51,9 @@ def extract_files(full_name, output_dir, shuffle=False): """ if not full_name.endswith(".xz"): return - file_prefix = re.split('\.|/',full_name)[1] - with open("{}.txt".format(os.path.join(output_dir, file_prefix)),"w") as fp: + file_prefix = re.split(r'\.|/', full_name)[-2] + file_prefix = file_prefix.replace('urlsf_subset', 'openwebtext-prepared-') + with open("{}.txt".format(os.path.join(output_dir, file_prefix)), "w") as fp: with tarfile.open(full_name) as t: txt_names = t.getnames() if shuffle: @@ -63,9 +64,9 @@ def extract_files(full_name, output_dir, shuffle=False): # skip empty line line = line.strip() if line: - fp.write(line.decode()+'\n') + fp.write(line.decode() + '\n') # Two extra line break to mark the document separation - fp.write('\n\n') + fp.write('\n') @DATA_MAIN_REGISTRY.register('prepare_openwebtext') @@ -76,11 +77,16 @@ def main(args): fnames = sorted(os.listdir(args.input)) fnames = [os.path.join(args.input, fname) for fname in fnames] if args.shuffle: - fnames = random.shuffle(fnames) + random.shuffle(fnames) print('Start extracting {} files with {} cores'.format(len(fnames), num_process)) start_time = time.time() with multiprocessing.Pool(num_process) as pool: - iter = pool.imap(functools.partial(extract_files, output_dir=args.output, shuffle=args.shuffle), fnames) + iter = pool.imap( + functools.partial( + extract_files, + output_dir=args.output, + shuffle=args.shuffle), + fnames) for f_index, _ in enumerate(iter): if f_index > 0 and f_index % 250 == 0: elapsed = time.time() - start_time diff --git a/scripts/datasets/pretrain_corpus/prepare_wikipedia.py b/scripts/datasets/pretrain_corpus/prepare_wikipedia.py index b312533108..481598c22e 100644 --- a/scripts/datasets/pretrain_corpus/prepare_wikipedia.py +++ b/scripts/datasets/pretrain_corpus/prepare_wikipedia.py @@ -1,10 +1,15 @@ """Prepare the Wikipedia dataset that contain cleaned articles of all languages.""" import os import sys -import argparse import glob -from gluonnlp.utils.misc import download -from gluonnlp.registry import DATA_PARSER_REGISTRY, DATA_MAIN_REGISTRY +import math +import time +import tarfile +import argparse +import multiprocessing + +from gluonnlp.registry import DATA_MAIN_REGISTRY, DATA_PARSER_REGISTRY +from gluonnlp.utils.misc import download, load_checksum_stats _CITATION = """\ @ONLINE {wikidump, @@ -47,6 +52,13 @@ _BASE_URL_TMPL\ = "https://dumps.wikimedia.org/{lang}wiki/{date}/{lang}wiki-{date}-pages-articles.xml.bz2" _CURR_DIR = os.path.realpath(os.path.dirname(os.path.realpath(__file__))) +_URL_FILE_STATS_PATH = os.path.join(_CURR_DIR, '..', 'url_checksums', 'wikipedia.txt') +_URL_FILE_STATS = load_checksum_stats(_URL_FILE_STATS_PATH) + +_URLS = { + 'wikipedia-en-20200620': + 'https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/pretrain_corpus/wikipedia-en-20200620.tar.gz', +} def get_url(lang, date): @@ -55,51 +67,58 @@ def get_url(lang, date): def try_import_wikiextractor(): try: + sys.path.append(_CURR_DIR) import WikiExtractor except ImportError: try: download( - 'https://raw.githubusercontent.com/attardi/wikiextractor/' - '16186e290d9eb0eb3a3784c6c0635a9ed7e855c3/WikiExtractor.py', + 'https://raw.githubusercontent.com/attardi/wikiextractor/master/WikiExtractor.py', path=os.path.join(_CURR_DIR, 'WikiExtractor.py'), sha1_hash='3c4896a837b75c476d23c037e8d6c7fdfd9a29eb') + sys.path.append(_CURR_DIR) import WikiExtractor - except: + except BaseException: raise ImportError('Cannot import WikiExtractor! You can download the "WikiExtractor.py"' ' in https://github.com/attardi/wikiextractor to {}' .format(_CURR_DIR)) return WikiExtractor -class WikicorpusTextFormatting: - def __init__(self, wiki_path, output_filename, recursive=False): - self.wiki_path = wiki_path - self.recursive = recursive - self.output_filename = output_filename - - # This puts one article per line - def merge(self): - with open(self.output_filename, mode='w', newline='\n') as ofile: - for dirname in glob.glob(os.path.join(self.wiki_path, '*'), recursive=False): - for filename in glob.glob(os.path.join(dirname, 'wiki_*'), recursive=self.recursive): - print(filename) - article_lines = [] - article_open = False - - with open(filename, mode='r', newline='\n') as file: - for line in file: - if '' in line: - article_open = False - for oline in article_lines[1:]: - if oline != '\n': - ofile.write(oline.rstrip() + " ") - ofile.write("\n\n") - article_lines = [] - else: - if article_open: - article_lines.append(line) +def get_formatting_list(wiki_path, recursive=False): + """ + get formatting list of file names from extracted content + """ + filenames = [] + for dirname in glob.glob(os.path.join(wiki_path, '*'), recursive=False): + for filename in glob.glob(os.path.join(dirname, 'wiki_*'), recursive=recursive): + filenames.append(filename) + return filenames + + +def merge(x): + """ + Puts one article per line + """ + file_list, output_filename = x + article_lines = [] + article_open = False + + with open(output_filename, mode='w', newline='\n') as ofile: + for filename in file_list: + with open(filename, mode='r', newline='\n') as file: + for line in file: + if '' in line: + article_open = False + for oline in article_lines[1:]: + if oline != '\n': + ofile.write(oline.rstrip() + " ") + ofile.write("\n\n") + article_lines = [] + else: + if article_open: + article_lines.append(line) @DATA_PARSER_REGISTRY.register('prepare_wikipedia') @@ -107,12 +126,13 @@ def get_parser(): parser = argparse.ArgumentParser(description='Download and Prepare the Wikipedia') parser.add_argument('--mode', type=str, default='download+format', - choices=['download', 'format', 'download+format'], + choices=['download', 'format', 'download+format', 'download_prepared'], help='Specify the action you want the app to take. ' '"download" means to download the Wikipedia dump. ' '"format" means to extract the content and ' 'format it for pretraining. "download+format" means to combine ' - 'these two options') + 'these two options' + '"download_prepared" downloads the prepared txt from S3 directly') parser.add_argument('--lang', type=str, default='en', help='Language of the wikipedia dump file.' 'We only support English and Chinese for current version') @@ -124,8 +144,13 @@ def get_parser(): parser.add_argument("-o", "--output", default="wikicorpus", help="directory for downloaded or formatted files") parser.add_argument("-b", "--bytes", default="100M", - help="maximum bytes per output file (default %(default)s)", + help="maximum bytes per extracted file (default %(default)s)", metavar="n[KMG]") + parser.add_argument("--num_process", type=int, default=8, + help="number of processes for multiprocessing") + parser.add_argument("--num_out_files", type=int, default=1000, + help="Number of desired output files, where each is processed" + " independently by a worker.") return parser @@ -145,32 +170,75 @@ def download_wikicorpus(lang, date, output): return output_file -def format_wikicorpus(input, output, bytes): +def format_wikicorpus(input, output, bytes, num_process, num_out_files): if input is None: raise ValueError('input file is empty.') if not input.endswith('xml.bz2'): raise ValueError('input file not *.xml.bz2.') if not os.path.exists(output): os.makedirs(output) + # Use WikiExtractor to extract the content WikiExtractor = try_import_wikiextractor() wiki_path = os.path.join(output, 'extracted') sys.argv = ['prog', '-b', bytes, '-o', wiki_path, input] WikiExtractor.main() - output_filename = os.path.join(output, 'wikicorpus_one_article_per_line.txt') - wiki_formatter = WikicorpusTextFormatting(wiki_path, output_filename, recursive=True) - wiki_formatter.merge() + + # Merge extracted content into txt files + prepared_path = os.path.join(output, 'prepared_wikipedia') + if not os.path.exists(prepared_path): + os.makedirs(prepared_path) + filenames = get_formatting_list(wiki_path, recursive=True) + num_files = len(filenames) + num_out_files = min(num_out_files, num_files) + file_volume = math.ceil(num_files / num_out_files) + splited_files = [filenames[i: i + file_volume] for i in range(0, num_files, file_volume)] + num_out_files = len(splited_files) + output_files = [ + os.path.join( + prepared_path, + "wikipedia-prepared-{}.txt".format( + str(i).zfill(4))) for i in range(num_out_files)] + print("All prepared raw text will be saved in {} txt files".format(num_out_files)) + num_process = min(num_process, num_out_files) + print('Start preprocessing {} text files with {} cores'.format(num_files, num_process)) + process_args = [(splited_files[i], output_files[i]) for i in range(num_out_files)] + + start_time = time.time() + with multiprocessing.Pool(num_process) as pool: + f_read = 0 + for i, _ in enumerate(pool.imap(merge, process_args)): + elapsed = time.time() - start_time + f_read += len(splited_files[i]) + print("prepared {:} files, Elapsed: {:.2f}s, ETA: {:.2f}s, ".format( + f_read, elapsed, (num_files - f_read) / (num_files / elapsed))) + print("Done preparation within {:.2f} seconds".format(elapsed)) @DATA_MAIN_REGISTRY.register('prepare_wikipedia') def main(args): + num_process = min(multiprocessing.cpu_count(), args.num_process) if args.mode == 'download': download_wikicorpus(args.lang, args.date, args.output) elif args.mode == 'format': - format_wikicorpus(args.input, args.output, args.bytes) + format_wikicorpus(args.input, args.output, args.bytes, num_process, args.num_out_files) elif args.mode == 'download+format': downloaded_file = download_wikicorpus(args.lang, args.date, args.output) - format_wikicorpus(downloaded_file, args.output, args.bytes) + format_wikicorpus(downloaded_file, args.output, args.bytes, num_process, args.num_out_files) + elif args.mode == 'download_prepared': + url = _URLS['wikipedia-en-20200620'] + file_hash = _URL_FILE_STATS[url] + target_download_location = os.path.join(args.output, + os.path.basename(url)) + download(url, target_download_location, sha1_hash=file_hash) + tar = tarfile.open(target_download_location) + names = tar.getnames() + print('Start unarchiving raw text files') + start_time = time.time() + for name in names: + tar.extract(name, path=args.output) + tar.close() + print("Done unarchiving within {:.2f} seconds".format(time.time() - start_time)) else: raise NotImplementedError diff --git a/scripts/datasets/url_checksums/wikipedia.txt b/scripts/datasets/url_checksums/wikipedia.txt new file mode 100644 index 0000000000..2f4c117a9e --- /dev/null +++ b/scripts/datasets/url_checksums/wikipedia.txt @@ -0,0 +1 @@ +https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/pretrain_corpus/wikipedia-en-20200620.tar.gz 1e1d77c31622744aaa45ff5bfbfca397154d9186 5068070627 diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index 089d0267c0..3b0a6565e5 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -367,7 +367,7 @@ def train(args): seed=args.seed) else: raise NotImplementedError - + batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack()) train_data_loader = gluon.data.DataLoader(data_train, batch_sampler=train_batch_sampler, @@ -387,7 +387,6 @@ def train(args): log_start_time = time.time() num_params, num_fixed_params = None, None # TODO(sxjscience) Add a log metric class - accum_count = 0 loss_denom = 0 n_train_iters = 0 @@ -471,12 +470,10 @@ def train(args): deduplicate=True) if args.max_update > 0 and n_train_iters >= args.max_update: break - if args.epochs > 0: model.save_parameters(os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)), deduplicate=True) - avg_valid_loss = validation(model, val_data_loader, ctx_l) logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}' .format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss))) diff --git a/scripts/pretraining/README.md b/scripts/pretraining/README.md index 700ca7962a..3354d792c4 100644 --- a/scripts/pretraining/README.md +++ b/scripts/pretraining/README.md @@ -3,25 +3,27 @@ Following the instruction of [Prepare OpenWebTextCorpus](../datasets/pretrain_corpus#openwebtext), download and prepare the dataset, obtaining a total of 20610 text files in the folder `prepared_owt`. ```bash -python preprocesse_owt.py --input prepared_owt --output preprocessed_owt --shuffle +python data_preprocessing.py --input prepared_owt --output preprocessed_owt --max_seq_length 128 --shuffle ``` The above command allows us to generate the preprocessed Numpy features saved in `.npz`. # Pretrain Model ## ELECTRA +Following [Official Quickstart](https://github.com/google-research/electra#quickstart-pre-train-a-small-electra-model), pretrain a small model using OpenWebText as pretraining corpus. Note that [horovod](https://github.com/horovod/horovod) needs to be installed in advance, if `comm_backend` is set to `horovod`. ```bash -horovodrun -np 8 -H localhost:8 python -m run_electra \ +horovodrun -np 2 -H localhost:2 python -m run_electra \ --model_name google_electra_small \ - --data `preprocessed_owt/*.npz` \ - --gpus 0,1,2,3,4,5,6,7 \ + --data 'preprocessed_owt/*.npz' \ + --generator_units_scale 0.25 \ + --gpus 0,1 \ --do_train \ --do_eval \ --output_dir ${OUTPUT} \ - --num_accumulated ${ACCMULATE} \ - --batch_size ${BS} \ - --lr ${LR} \ - --wd ${WD} \ - --max_seq_len ${MSL} \ + --num_accumulated 1 \ + --batch_size 64 \ + --lr 5e-4 \ + --wd 0.01 \ + --max_seq_len 128 \ --max_grad_norm 1 \ --warmup_steps 10000 \ --num_train_steps 1000000 \ @@ -31,21 +33,22 @@ horovodrun -np 8 -H localhost:8 python -m run_electra \ --comm_backend horovod \ ``` -Or we could preprocessing the features on the fly based on the `.txt` files like +Alternatively, we could preprocessing the features on the fly and train this model with raw text directly like ```bash -horovodrun -np 8 -H localhost:8 python -m run_electra \ +horovodrun -np 2 -H localhost:2 python -m run_electra \ --model_name google_electra_small \ - --data `prepared_owt/*.txt` \ + --generator_units_scale 0.25 \ + --data 'prepared_owt/*.txt' \ --from_raw \ - --gpus 0,1,2,3,4,5,6,7 \ + --gpus 0,1 \ --do_train \ --do_eval \ --output_dir ${OUTPUT} \ - --num_accumulated ${ACCMULATE} \ - --batch_size ${BS} \ - --lr ${LR} \ - --wd ${WD} \ - --max_seq_len ${MSL} \ + --num_accumulated 1 \ + --batch_size 64 \ + --lr 5e-4 \ + --wd 0.01 \ + --max_seq_len 128 \ --max_grad_norm 1 \ --warmup_steps 10000 \ --num_train_steps 1000000 \ @@ -54,3 +57,45 @@ horovodrun -np 8 -H localhost:8 python -m run_electra \ --mask_prob 0.15 \ --comm_backend horovod \ ``` + +For the convenience of verification, the pretrained small model trained on OpenWebText named `gluon_electra_small_owt` is released and uploaded to S3 with directory structure as + +``` +gluon_electra_small_owt +├── vocab-{short_hash}.json +├── model-{short_hash}.params +├── model-{short_hash}.yml +├── gen_model-{short_hash}.params +├── disc_model-{short_hash}.params +``` + +After pretraining, several downstream NLP tasks such as Question Answering are available to fine-tune. Here is an example of fine-tuning a local pretrained model on [SQuAD 1.1/2.0](../question_answering#squad). + +```bash +python run_squad.py \ + --model_name google_electra_small \ + --data_dir squad \ + --backbone_path ${OUTPUT}/model-{short_hash}.params \ + --output_dir ${FINE-TUNE_OUTPUT} \ + --version ${VERSION} \ + --do_eval \ + --do_train \ + --batch_size 32 \ + --num_accumulated 1 \ + --gpus 0 \ + --epochs 2 \ + --lr 3e-4 \ + --layerwise_decay 0.8 \ + --warmup_ratio 0.1 \ + --max_saved_ckpt 6 \ + --all_evaluate \ + --wd 0 \ + --max_seq_length 128 \ + --max_grad_norm 0.1 \ +``` + +Resulting in the following output + +| Model Name | SQuAD1.1 dev | SQuAD2.0 dev | +|--------------------------|---------------|--------------| +|gluon_electra_small_owt | 69.40/76.98 | 67.63/69.89 | diff --git a/scripts/pretraining/preprocesse_owt.py b/scripts/pretraining/data_preprocessing.py similarity index 78% rename from scripts/pretraining/preprocesse_owt.py rename to scripts/pretraining/data_preprocessing.py index e75d97ea41..067fbf1634 100644 --- a/scripts/pretraining/preprocesse_owt.py +++ b/scripts/pretraining/data_preprocessing.py @@ -4,11 +4,14 @@ import os import time import math +import random import argparse import multiprocessing +import numpy as np + from pretraining_utils import get_all_features -from gluonnlp.data.tokenizers import HuggingFaceWordPieceTokenizer +from gluonnlp.models import get_backbone def get_parser(): @@ -19,14 +22,15 @@ def get_parser(): help="directory for preprocessed features") parser.add_argument("--num_process", type=int, default=8, help="number of processes for multiprocessing") - parser.add_argument("--vocab_file", default="vocab-c3b41053.json", - help="vocabulary file of HuggingFaceWordPieceTokenizer" - " for electra small model") parser.add_argument("--max_seq_length", type=int, default=128, help="the maximum length of the pretraining sequence") parser.add_argument("--num_out_files", type=int, default=1000, help="Number of desired output files, where each is processed" " independently by a worker.") + parser.add_argument('--model_name', type=str, default='google_electra_small', + help='Name of the pretrained model.') + parser.add_argument("--shuffle", action="store_true", + help="Wether to shuffle the data order") parser.add_argument("--do_lower_case", dest='do_lower_case', action="store_true", help="Lower case input text.") parser.add_argument("--no_lower_case", dest='do_lower_case', @@ -40,23 +44,17 @@ def get_parser(): def main(args): num_process = min(multiprocessing.cpu_count(), args.num_process) - assert os.path.isfile(args.vocab_file), 'Cannot find vocab file' - # TODO(zheyuye), download the vocab_file from zoos and check it with sha1 hash. - tokenizer = HuggingFaceWordPieceTokenizer( - vocab_file=args.vocab_file, - unk_token='[UNK]', - pad_token='[PAD]', - cls_token='[CLS]', - sep_token='[SEP]', - mask_token='[MASK]', - lowercase=args.do_lower_case) + _, cfg, tokenizer, _, _ = \ + get_backbone(args.model_name, load_backbone=False) fnames = sorted(os.listdir(args.input)) fnames = [os.path.join(args.input, fname) for fname in fnames] + if args.shuffle: + random.shuffle(fnames) num_files = len(fnames) num_out_files = min(args.num_out_files, num_files) file_volume = math.ceil(num_files / num_out_files) - splited_files = [fnames[i: i + file_volume] for i in range(0, num_files, file_volume)] + splited_files = np.array_split(fnames, file_volume) num_out_files = len(splited_files) output_files = [os.path.join( args.output, "owt-pretrain-record-{}.npz".format(str(i).zfill(4))) for i in range(num_out_files)] @@ -83,7 +81,7 @@ def main(args): fea_written += len(np_features[0]) f_read += len(splited_files[i]) print("Processed {:} files, Elapsed: {:.2f}s, ETA: {:.2f}s, ".format( - fea_written, elapsed, (num_files - f_read) / (num_files / elapsed))) + fea_written, elapsed, (num_files - f_read) / (f_read / elapsed))) print("Done processing within {:.2f} seconds".format(elapsed)) diff --git a/scripts/pretraining/pretraining_utils.py b/scripts/pretraining/pretraining_utils.py index c5ea5289ea..cdb2d6d380 100644 --- a/scripts/pretraining/pretraining_utils.py +++ b/scripts/pretraining/pretraining_utils.py @@ -41,7 +41,6 @@ def tokenize_lines_to_ids(lines, tokenizer): """ results = [] # tag line delimiters or doc delimiters - line_delimiters = False for line in lines: if not line: break @@ -49,9 +48,7 @@ def tokenize_lines_to_ids(lines, tokenizer): # Single empty lines are used as line delimiters # Double empty lines are used as document delimiters if not line: - if not line_delimiters: - results.append([]) - line_delimiters = not line_delimiters + results.append([]) else: token_ids = tokenizer.encode(line, int) if token_ids: @@ -125,8 +122,9 @@ def process_a_text(text_file, tokenizer, max_seq_length, short_seq_prob=0.05): for tokenized_line in tokenized_lines: current_sentences.append(tokenized_line) current_length += len(tokenized_line) - # Create feature when meets the empty line or reaches the target length - if (not tokenized_line and current_length != 0) or (current_length >= target_seq_length): + # Create feature when meets the empty line or reaches the target length + if (not tokenized_line and current_length != 0) or ( + current_length >= target_seq_length): first_segment, second_segment = \ sentenceize(current_sentences, max_seq_length, target_seq_length) @@ -265,11 +263,11 @@ def prepare_pretrain_text_dataset( """Create dataset based on the raw text files""" if not isinstance(filenames, (list, tuple)): filenames = [filenames] - # generate a filename based on the input filename ensuring no crash. - # filename example: urlsf_subset00-130_data.txt - suffix = re.findall(r'\d+-\d+', filenames[0])[0] if cached_file_path: - output_file = os.path.join(cached_file_path, "owt-pretrain-record-{}.npz".format(suffix)) + # generate a filename based on the input filename ensuring no crash. + # filename example: urlsf_subset00-130_data.txt + suffix = re.split(r'\.|/', filenames[0])[-2] + output_file = os.path.join(cached_file_path, "{}-pretrain-record.npz".format(suffix)) else: output_file = None np_features = get_all_features( @@ -496,13 +494,17 @@ def dynamic_masking(self, F, input_ids, valid_lengths): valid_candidates = valid_candidates.astype(np.float32) num_masked_position = F.np.maximum( 1, F.np.minimum(N, round(valid_lengths * self._mask_prob))) - # The categorical distribution takes normalized probabilities as input - # softmax is used here instead of log_softmax + + # Get the masking probability of each position sample_probs = F.npx.softmax( self._proposal_distribution * valid_candidates, axis=-1) # (B, L) - # Top-k Sampling is an alternative solution to avoid duplicates positions - masked_positions = F.npx.random.categorical( - sample_probs, shape=N, dtype=np.int32) + sample_probs = F.npx.stop_gradient(sample_probs) + gumbels = F.np.random.gumbel(F.np.zeros_like(sample_probs)) + # Following the instruction of official repo to avoid deduplicate postions + # with Top_k Sampling as https://github.com/google-research/electra/issues/41 + masked_positions = F.npx.topk( + sample_probs + gumbels, k=N, axis=-1, ret_typ='indices', dtype=np.int32) + masked_weights = F.npx.sequence_mask( F.np.ones_like(masked_positions), sequence_length=num_masked_position, @@ -511,21 +513,27 @@ def dynamic_masking(self, F, input_ids, valid_lengths): length_masks = F.npx.sequence_mask( F.np.ones_like(input_ids, dtype=np.float32), sequence_length=valid_lengths, - use_sequence_length=True, axis=1, value=0).astype(np.float32) + use_sequence_length=True, axis=1, value=0) unmasked_tokens = select_vectors_by_position( F, input_ids, masked_positions) * masked_weights masked_weights = masked_weights.astype(np.float32) - replaced_positions = ( F.np.random.uniform( F.np.zeros_like(masked_positions), F.np.ones_like(masked_positions)) > self._mask_prob) * masked_positions - # deal with multiple zeros + # dealling with multiple zero values in replaced_positions which causes + # the [CLS] being replaced filled = F.np.where( replaced_positions, self.vocab.mask_id, - masked_positions).astype(np.int32) - masked_input_ids, _ = updated_vectors_by_position(F, input_ids, filled, replaced_positions) + self.vocab.cls_id).astype( + np.int32) + # Masking token by replacing with [MASK] + masked_input_ids = updated_vectors_by_position(F, input_ids, filled, replaced_positions) + + # Note: It is likely have multiple zero values in masked_positions if number of masked of + # positions not reached the maximum. However, this example hardly exists since valid_length + # is almost always equal to max_seq_length masked_input = self.MaskedInput(input_ids=masked_input_ids, masks=length_masks, unmasked_tokens=unmasked_tokens, diff --git a/scripts/pretraining/run_electra.py b/scripts/pretraining/run_electra.py index 60f8277268..2a67017106 100644 --- a/scripts/pretraining/run_electra.py +++ b/scripts/pretraining/run_electra.py @@ -1,7 +1,6 @@ """Pretraining Example for Electra Model on the OpenWebText dataset""" import os -import sys import time import shutil import logging @@ -15,7 +14,7 @@ from sklearn import metrics from pretraining_utils import ElectraMasker, get_pretrain_data_npz, get_pretrain_data_text -from gluonnlp.utils.misc import grouper, set_seed, naming_convention, logging_config +from gluonnlp.utils.misc import repeat, grouper, set_seed, init_comm, logging_config, naming_convention from gluonnlp.initializer import TruncNorm from gluonnlp.models.electra import ElectraModel, ElectraForPretrain, get_pretrained_electra from gluonnlp.utils.parameter import clip_grad_global_norm @@ -72,9 +71,10 @@ def parse_args(): help='If set, both training and dev samples are generated on-the-fly ' 'from raw texts instead of pre-processed npz files. ') parser.add_argument("--short_seq_prob", type=float, default=0.05, - help="The probability of sampling sequences shorter than the max_seq_length.") + help='The probability of sampling sequences ' + 'shorter than the max_seq_length.') parser.add_argument("--cached_file_path", default=None, - help="Directory for saving preprocessed features") + help='Directory for saving preprocessed features') parser.add_argument('--circle_length', type=int, default=2, help='Number of files to be read for a single GPU at the same time.') parser.add_argument('--repeat', type=int, default=8, @@ -168,38 +168,6 @@ def get_pretraining_model(model_name, ctx_l, 'corrupted_tokens']) -def init_comm(backend, gpus): - """Init communication backend""" - # backend specific implementation - if backend == 'horovod': - try: - import horovod.mxnet as hvd # pylint: disable=import-outside-toplevel - except ImportError: - logging.info('horovod must be installed.') - sys.exit(1) - hvd.init() - store = None - num_workers = hvd.size() - rank = hvd.rank() - local_rank = hvd.local_rank() - is_master_node = rank == local_rank - ctx_l = [mx.gpu(local_rank)] - logging.info('GPU communication supported by horovod') - else: - store = mx.kv.create(backend) - num_workers = store.num_workers - rank = store.rank - local_rank = 0 - is_master_node = rank == local_rank - if gpus == '-1' or gpus == '': - ctx_l = [mx.cpu()] - logging.info('Runing on CPU') - else: - ctx_l = [mx.gpu(int(x)) for x in gpus.split(',')] - logging.info('GPU communication supported by KVStore') - - return store, num_workers, rank, local_rank, is_master_node, ctx_l - def final_save(model, save_dir, tokenizer): if not os.path.exists(save_dir): os.makedirs(save_dir) @@ -258,6 +226,9 @@ def states_option(step_num, trainer, ckpt_dir, local_rank=0, option='Saving'): def train(args): store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm( args.comm_backend, args.gpus) + logging.info('Training info: num_buckets: {}, ' + 'num_workers: {}, rank: {}'.format( + args.num_buckets, num_workers, rank)) cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l, args.max_seq_length, args.hidden_dropout_prob, @@ -266,11 +237,8 @@ def train(args): args.generator_layers_scale) data_masker = ElectraMasker( tokenizer, args.max_seq_length, args.mask_prob) - logging.info('Training info: num_buckets: {}, ' - 'num_workers: {}, rank: {}'.format( - args.num_buckets, num_workers, rank)) if args.from_raw_text: - if not os.path.exists(args.cached_file_path): + if args.cached_file_path and not os.path.exists(args.cached_file_path): os.mkdir(args.cached_file_path) get_dataset_fn = functools.partial(get_pretrain_data_text, max_seq_length=args.max_seq_length, @@ -339,8 +307,6 @@ def train(args): 'epsilon': 1e-6, 'correct_bias': False, }) - # TODO(zheyuye), absentance of layer-wise decay, although the decay power - # is 1.0 in electra model if args.comm_backend == 'horovod': trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params) else: @@ -362,7 +328,7 @@ def train(args): # prepare the records writer writer = None - if args.do_eval: + if args.do_eval and local_rank == 0: from tensorboardX import SummaryWriter record_path = os.path.join(args.output_dir, 'records') logging.info('Evaluation records saved in {}'.format(record_path)) @@ -381,13 +347,13 @@ def train(args): if args.num_accumulated != 1: # set grad to zero for gradient accumulation model.collect_params().zero_grad() - while not finish_flag: + + # start training + train_loop_dataloader = grouper(repeat(data_train), len(ctx_l)) + while step_num < num_train_steps: tic = time.time() - batch_id = 0 - is_last_batch = False - train_dataloader = grouper(data_train, len(ctx_l)) - sample_l = next(train_dataloader) - while not is_last_batch: + for accum_idx in range(args.num_accumulated): + sample_l = next(train_loop_dataloader) loss_l = [] mlm_loss_l = [] rtd_loss_l = [] @@ -437,77 +403,64 @@ def train(args): for ele in rtd_loss_l]).asnumpy() log_total_loss += sum([ele.as_in_ctx(ctx_l[0]) for ele in loss_l]).asnumpy() * loss_denom - # pre fetch next batch - try: - sample_l = next(train_dataloader) - except StopIteration: - is_last_batch = True - - # update - if (batch_id + 1) % args.num_accumulated == 0 or is_last_batch: - trainer.allreduce_grads() - # Here, the accumulated gradients are - # \sum_{n=1}^N g_n / loss_denom - # Thus, in order to clip the average gradient - # \frac{1}{N} \sum_{n=1}^N --> clip to args.max_grad_norm - # We need to change the ratio to be - # \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom - total_norm, ratio, is_finite = clip_grad_global_norm( - params, args.max_grad_norm * num_samples_per_update / loss_denom) - total_norm = total_norm / (num_samples_per_update / loss_denom) - trainer.update(num_samples_per_update / loss_denom, ignore_stale_grad=True) - step_num += 1 - if args.num_accumulated != 1: - # set grad to zero for gradient accumulation - model.collect_params().zero_grad() - - # saving - if step_num % save_interval == 0 or step_num >= num_train_steps: - if is_master_node: - states_option( - step_num, trainer, args.output_dir, local_rank, 'Saving') - if local_rank == 0: - param_path = parameters_option( - step_num, model, args.output_dir, 'Saving') - - # logging - if step_num % log_interval == 0 and local_rank == 0: - # Output the loss of per step - log_mlm_loss /= log_interval - log_rtd_loss /= log_interval - log_total_loss /= log_interval - toc = time.time() - logging.info( - '[step {}], Loss mlm/rtd/total={:.4f}/{:.4f}/{:.4f},' - ' LR={:.6f}, grad_norm={:.4f}. Time cost={:.2f},' - ' Throughput={:.2f} samples/s, ETA={:.2f}h'.format( - step_num, log_mlm_loss, log_rtd_loss, log_total_loss, - trainer.learning_rate, total_norm, toc - tic, log_sample_num / (toc - tic), - (num_train_steps - step_num) / (step_num / (toc - train_start_time)) / 3600)) - tic = time.time() - - if args.do_eval: - evaluation(writer, step_num, masked_input, output) - writer.add_scalars('loss', - {'total_loss': log_total_loss, - 'mlm_loss': log_mlm_loss, - 'rtd_loss': log_rtd_loss}, - step_num) - log_mlm_loss = 0 - log_rtd_loss = 0 - log_total_loss = 0 - log_sample_num = 0 - - num_samples_per_update = 0 - - if step_num >= num_train_steps: - logging.info('Finish training step: %d', step_num) - finish_flag = True - break - - batch_id += 1 - + # update + trainer.allreduce_grads() + # Here, the accumulated gradients are + # \sum_{n=1}^N g_n / loss_denom + # Thus, in order to clip the average gradient + # \frac{1}{N} \sum_{n=1}^N --> clip to args.max_grad_norm + # We need to change the ratio to be + # \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom + total_norm, ratio, is_finite = clip_grad_global_norm( + params, args.max_grad_norm * num_samples_per_update / loss_denom) + total_norm = total_norm / (num_samples_per_update / loss_denom) + trainer.update(num_samples_per_update / loss_denom, ignore_stale_grad=True) + step_num += 1 + if args.num_accumulated != 1: + # set grad to zero for gradient accumulation + model.collect_params().zero_grad() + + # saving + if step_num % save_interval == 0 or step_num >= num_train_steps: + if is_master_node: + states_option( + step_num, trainer, args.output_dir, local_rank, 'Saving') + if local_rank == 0: + param_path = parameters_option( + step_num, model, args.output_dir, 'Saving') + + # logging + if step_num % log_interval == 0 and local_rank == 0: + # Output the loss of per step + log_mlm_loss /= log_interval + log_rtd_loss /= log_interval + log_total_loss /= log_interval + toc = time.time() + logging.info( + '[step {}], Loss mlm/rtd/total={:.4f}/{:.4f}/{:.4f},' + ' LR={:.6f}, grad_norm={:.4f}. Time cost={:.2f},' + ' Throughput={:.2f} samples/s, ETA={:.2f}h'.format( + step_num, log_mlm_loss, log_rtd_loss, log_total_loss, + trainer.learning_rate, total_norm, toc - tic, log_sample_num / (toc - tic), + (num_train_steps - step_num) / (step_num / (toc - train_start_time)) / 3600)) + tic = time.time() + + if args.do_eval: + evaluation(writer, step_num, masked_input, output) + writer.add_scalars('loss', + {'total_loss': log_total_loss, + 'mlm_loss': log_mlm_loss, + 'rtd_loss': log_rtd_loss}, + step_num) + log_mlm_loss = 0 + log_rtd_loss = 0 + log_total_loss = 0 + log_sample_num = 0 + + num_samples_per_update = 0 + + logging.info('Finish training step: %d', step_num) if is_master_node: state_path = states_option(step_num, trainer, args.output_dir, local_rank, 'Saving') if local_rank == 0: @@ -522,7 +475,8 @@ def train(args): model_name = args.model_name.replace('google', 'gluon') save_dir = os.path.join(args.output_dir, model_name) final_save(model, save_dir, tokenizer) - return param_path, state_path + +# TODO(zheyuye), Directly implement a metric for weighted accuracy def accuracy(labels, predictions, weights=None): @@ -532,19 +486,21 @@ def accuracy(labels, predictions, weights=None): acc = (is_correct * weights).sum() / (weights.sum() + 1e-6) return acc +# TODO(zheyuye), Directly implement a metric for weighted AUC + -def auc(labels, probs, wights=None): +def auc(labels, probs, weights=None): if isinstance(labels, mx.np.ndarray): labels = labels.asnumpy() if isinstance(probs, mx.np.ndarray): probs = probs.asnumpy() - if isinstance(wights, mx.np.ndarray): - wights = wights.asnumpy() + if isinstance(weights, mx.np.ndarray): + weights = weights.asnumpy() labels = labels.reshape(-1) probs = probs.reshape(-1) - wights = wights.reshape(-1) + weights = weights.reshape(-1) - fpr, tpr, thresholds = metrics.roc_curve(labels, probs, sample_weight=wights) + fpr, tpr, thresholds = metrics.roc_curve(labels, probs, sample_weight=weights) return metrics.auc(fpr, tpr) @@ -569,13 +525,14 @@ def evaluation(writer, step_num, masked_input, eval_input): rtd_recall = accuracy(rtd_labels, rtd_preds, rtd_labels * rtd_preds) rtd_auc = auc(rtd_labels, rtd_probs, length_masks) writer.add_scalars('results', - {'mlm_accuracy': mlm_accuracy.asnumpy().item(), - 'corrupted_mlm_accuracy': corrupted_mlm_accuracy.asnumpy().item(), - 'rtd_accuracy': rtd_accuracy.asnumpy().item(), - 'rtd_precision': rtd_precision.asnumpy().item(), - 'rtd_recall': rtd_recall.asnumpy().item(), - 'rtd_auc':rtd_auc}, - step_num) + {'mlm_accuracy': mlm_accuracy.asnumpy().item(), + 'corrupted_mlm_accuracy': corrupted_mlm_accuracy.asnumpy().item(), + 'rtd_accuracy': rtd_accuracy.asnumpy().item(), + 'rtd_precision': rtd_precision.asnumpy().item(), + 'rtd_recall': rtd_recall.asnumpy().item(), + 'rtd_auc': rtd_auc}, + step_num) + if __name__ == '__main__': os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round' diff --git a/scripts/question_answering/README.md b/scripts/question_answering/README.md index 3a24cda2d6..e7a8d1432b 100644 --- a/scripts/question_answering/README.md +++ b/scripts/question_answering/README.md @@ -70,6 +70,13 @@ python run_squad.py \ --overwrite_cache \ ``` +We could speed up multi-GPU training via horovod. Compared to KVStore, training RoBERTa Large model on SQuAD 2.0 with 3 epochs will save roughly 1/4 training resources (8.48 vs 11.32 hours). Results may vary depending on the training instances. + +```bash +mpirun -np 4 -H localhost:4 python run_squad.py \ + --comm_backend horovod \ + ... +``` As for ELECTRA model, we fine-tune it with layer-wise learning rate decay as ```bash @@ -165,6 +172,3 @@ For reference, we have also included the results of original version from Google |Google ELECTRA base | - /86.8 | - /83.7 | |Google ELECTRA large | - /89.7 | - /88.1 | |Fairseq RoBERTa large | 94.6/88.9 | 89.4/86.5 | - - -All experiments done on AWS P3.8xlarge (4 x NVIDIA Tesla V100 16 GB) diff --git a/scripts/question_answering/eval_utils.py b/scripts/question_answering/eval_utils.py index 4f9db4916e..e28aecb7af 100644 --- a/scripts/question_answering/eval_utils.py +++ b/scripts/question_answering/eval_utils.py @@ -179,12 +179,30 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_h main_eval['best_f1_thresh'] = f1_thresh -def get_revised_results(preds, na_probs, thresh): - results = copy.deepcopy(preds) +def revise_unanswerable(preds, na_probs, na_prob_thresh): + """ + Revise the predictions results and return a null string for unanswerable question + whose unanswerable probability above the threshold. + + Parameters + ---------- + preds: dict + A dictionary of full prediction of spans + na_probs: dict + A dictionary of unanswerable probabilities + na_prob_thresh: float + threshold of the unanswerable probability + + Returns + ------- + revised: dict + A dictionary of revised prediction + """ + revised = copy.deepcopy(preds) for q_id in na_probs.keys(): - if na_probs[q_id] > thresh: - results[q_id] = "" - return results + if na_probs[q_id] > na_prob_thresh: + revised[q_id] = "" + return revised def squad_eval(data_file, preds, na_probs, na_prob_thresh=0.0, revise=False): @@ -197,7 +215,7 @@ def squad_eval(data_file, preds, na_probs, na_prob_thresh=0.0, revise=False): preds predictions dictionary na_probs - probabilities dict of unanswerable + probabilities dictionary of unanswerable na_prob_thresh threshold of unanswerable revise @@ -205,10 +223,10 @@ def squad_eval(data_file, preds, na_probs, na_prob_thresh=0.0, revise=False): with null string '' Returns ------- - out_eval - A dictionary of output results - (preds_out) - A dictionary of final predictions + out_eval + A dictionary of output results + (preds_out) + A dictionary of final predictions """ if isinstance(data_file, str): with open(data_file) as f: @@ -243,7 +261,7 @@ def squad_eval(data_file, preds, na_probs, na_prob_thresh=0.0, revise=False): if revise: thresh = (out_eval['best_exact_thresh'] + out_eval['best_f1_thresh']) * 0.5 - preds_out = get_revised_results(preds, na_probs, thresh) + preds_out = revise_unanswerable(preds, na_probs, thresh) return out_eval, preds_out else: return out_eval, preds diff --git a/scripts/question_answering/models.py b/scripts/question_answering/models.py index cb85cb7abb..58b156cbf3 100644 --- a/scripts/question_answering/models.py +++ b/scripts/question_answering/models.py @@ -74,6 +74,60 @@ def hybrid_forward(self, F, tokens, token_types, valid_length, p_mask): end_logits = masked_logsoftmax(F, end_scores, mask=p_mask, axis=-1) return start_logits, end_logits + def inference(self, tokens, token_types, valid_length, p_mask, + start_top_n: int = 5, end_top_n: int = 5): + """Get the inference result with beam search + + Parameters + ---------- + tokens + The input tokens. Shape (batch_size, sequence_length) + token_types + The input token types. Shape (batch_size, sequence_length) + valid_length + The valid length of the tokens. Shape (batch_size,) + p_mask + The mask which indicates that some tokens won't be used in the calculation. + Shape (batch_size, sequence_length) + start_top_n + The number of candidates to select for the start position. + end_top_n + The number of candidates to select for the end position. + + Returns + ------- + start_top_logits + The top start logits + Shape (batch_size, start_top_n) + start_top_index + Index of the top start logits + Shape (batch_size, start_top_n) + end_top_logits + The top end logits. + Shape (batch_size, end_top_n) + end_top_index + Index of the top end logits + Shape (batch_size, end_top_n) + """ + # Shape (batch_size, sequence_length, C) + if self.use_segmentation: + contextual_embeddings = self.backbone(tokens, token_types, valid_length) + else: + contextual_embeddings = self.backbone(tokens, valid_length) + scores = self.qa_outputs(contextual_embeddings) + start_scores = scores[:, :, 0] + end_scores = scores[:, :, 1] + start_logits = masked_logsoftmax(mx.nd, start_scores, mask=p_mask, axis=-1) + end_logits = masked_logsoftmax(mx.nd, end_scores, mask=p_mask, axis=-1) + # The shape of start_top_index will be (..., start_top_n) + start_top_logits, start_top_index = mx.npx.topk(start_logits, k=start_top_n, axis=-1, + ret_typ='both') + # Note that end_top_index and end_top_log_probs have shape (bsz, start_n_top, end_n_top) + # So that for each start position, there are end_n_top end positions on the third dim. + end_top_logits, end_top_index = mx.npx.topk(end_logits, k=end_top_n, axis=-1, + ret_typ='both') + return start_top_logits, start_top_index, end_top_logits, end_top_index + @use_np class ModelForQAConditionalV1(HybridBlock): diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py index 1484aeccd2..1ced1b444c 100644 --- a/scripts/question_answering/run_squad.py +++ b/scripts/question_answering/run_squad.py @@ -21,9 +21,16 @@ from eval_utils import squad_eval from squad_utils import SquadFeature, get_squad_examples, convert_squad_example_to_feature from gluonnlp.models import get_backbone -from gluonnlp.utils.misc import grouper, set_seed, parse_ctx, logging_config, count_parameters +from gluonnlp.utils.misc import repeat, grouper, set_seed, init_comm, \ + logging_config, count_parameters, parse_ctx from gluonnlp.initializer import TruncNorm -from gluonnlp.utils.parameter import clip_grad_global_norm +from gluonnlp.data.sampler import SplitSampler +from gluonnlp.utils.parameter import grad_global_norm, clip_grad_global_norm + +try: + import horovod.mxnet as hvd +except ImportError: + pass mx.npx.set_np() @@ -33,8 +40,9 @@ def parse_args(): - parser = argparse.ArgumentParser(description='Question Answering example.' - ' We fine-tune the pretrained model on SQuAD dataset.') + parser = argparse.ArgumentParser( + description='Question Answering example. ' + 'We fine-tune the pretrained model on SQuAD dataset.') parser.add_argument('--model_name', type=str, default='google_albert_base_v2', help='Name of the pretrained model.') parser.add_argument('--do_train', action='store_true', @@ -47,11 +55,18 @@ def parse_args(): parser.add_argument('--output_dir', type=str, default='squad_out', help='The output directory where the model params will be written.' ' default is squad_out') + # Communication + parser.add_argument('--comm_backend', type=str, default='device', + choices=['horovod', 'dist_sync_device', 'device'], + help='Communication backend.') parser.add_argument('--gpus', type=str, default='0', help='list of gpus to run, e.g. 0 or 0,2,5. -1 means using cpu.') # Training hyperparameters parser.add_argument('--seed', type=int, default=100, help='Random seed') - parser.add_argument('--log_interval', type=int, default=100, help='The logging interval.') + parser.add_argument('--log_interval', type=int, default=50, + help='The logging interval for training') + parser.add_argument('--eval_log_interval', type=int, default=10, + help='The logging interval for evaluation') parser.add_argument('--save_interval', type=int, default=None, help='the number of steps to save model parameters.' 'default is every epoch') @@ -69,6 +84,10 @@ def parse_args(): help='Max gradient norm.') parser.add_argument('--optimizer', type=str, default='adamw', help='optimization algorithm. default is adamw') + parser.add_argument('--adam_epsilon', type=float, default=1e-6, + help='epsilon of AdamW optimizer') + parser.add_argument('--adam_betas', default='(0.9, 0.999)', metavar='B', + help='betas for Adam optimizer') parser.add_argument('--num_accumulated', type=int, default=1, help='The number of batches for gradients accumulation to ' 'simulate large batch size.') @@ -80,7 +99,8 @@ def parse_args(): help='warmup steps. Note that either warmup_steps or warmup_ratio is set.') parser.add_argument('--wd', type=float, default=0.01, help='weight decay') parser.add_argument('--layerwise_decay', type=float, default=-1, help='Layer-wise lr decay') - parser.add_argument('--untunable_depth', type=float, default=-1, help='Depth of untunable parameters') + parser.add_argument('--untunable_depth', type=float, default=-1, + help='Depth of untunable parameters') parser.add_argument('--classifier_dropout', type=float, default=0.1, help='dropout of classifier') # Data pre/post processing @@ -108,16 +128,17 @@ def parse_args(): 'to a start position') parser.add_argument('--n_best_size', type=int, default=20, help='Top N results written to file') parser.add_argument('--max_answer_length', type=int, default=30, - help='The maximum length of an answer that can be generated. This is needed ' - 'because the start and end predictions are not conditioned on one another.' - ' default is 30') + help='The maximum length of an answer that can be generated. This is ' + 'needed because the start and end predictions are not conditioned ' + 'on one another. default is 30') parser.add_argument('--param_checkpoint', type=str, default=None, help='The parameter checkpoint for evaluating the model') parser.add_argument('--backbone_path', type=str, default=None, help='The parameter checkpoint of backbone model') parser.add_argument('--all_evaluate', action='store_true', - help='Whether to evaluate all intermediate checkpoints instead of only last one') - parser.add_argument('--max_saved_ckpt', type=int, default=10, + help='Whether to evaluate all intermediate checkpoints ' + 'instead of only last one') + parser.add_argument('--max_saved_ckpt', type=int, default=5, help='The maximum number of saved checkpoints') parser.add_argument('--eval_dtype', type=str, default='float32', help='Data type used for evaluation. Either float32 or float16') @@ -126,31 +147,6 @@ def parse_args(): class SquadDatasetProcessor: - # TODO(sxjscience) Consider to combine the NamedTuple and batchify functionality. - ChunkFeature = collections.namedtuple('ChunkFeature', - ['qas_id', - 'data', - 'valid_length', - 'segment_ids', - 'masks', - 'is_impossible', - 'gt_start', - 'gt_end', - 'context_offset', - 'chunk_start', - 'chunk_length']) - BatchifyFunction = bf.NamedTuple(ChunkFeature, - {'qas_id': bf.List(), - 'data': bf.Pad(), - 'valid_length': bf.Stack(), - 'segment_ids': bf.Pad(), - 'masks': bf.Pad(val=1), - 'is_impossible': bf.Stack(), - 'gt_start': bf.Stack(), - 'gt_end': bf.Stack(), - 'context_offset': bf.Stack(), - 'chunk_start': bf.Stack(), - 'chunk_length': bf.Stack()}) def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length): """ @@ -177,6 +173,32 @@ def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length): self.cls_id = vocab.bos_id if 'cls_token' not in vocab.special_token_keys else vocab.cls_id self.sep_id = vocab.eos_id if 'sep_token' not in vocab.special_token_keys else vocab.sep_id + # TODO(sxjscience) Consider to combine the NamedTuple and batchify functionality. + ChunkFeature = collections.namedtuple('ChunkFeature', + ['qas_id', + 'data', + 'valid_length', + 'segment_ids', + 'masks', + 'is_impossible', + 'gt_start', + 'gt_end', + 'context_offset', + 'chunk_start', + 'chunk_length']) + BatchifyFunction = bf.NamedTuple(ChunkFeature, + {'qas_id': bf.List(), + 'data': bf.Pad(val=self.pad_id), + 'valid_length': bf.Stack(), + 'segment_ids': bf.Pad(), + 'masks': bf.Pad(val=1), + 'is_impossible': bf.Stack(), + 'gt_start': bf.Stack(), + 'gt_end': bf.Stack(), + 'context_offset': bf.Stack(), + 'chunk_start': bf.Stack(), + 'chunk_length': bf.Stack()}) + def process_sample(self, feature: SquadFeature): """Process the data to the following format. @@ -242,17 +264,17 @@ def process_sample(self, feature: SquadFeature): # Here, we increase the start and end because we put query before context start_pos = chunk.gt_start_pos + context_offset end_pos = chunk.gt_end_pos + context_offset - chunk_feature = self.ChunkFeature(qas_id=feature.qas_id, - data=data, - valid_length=valid_length, - segment_ids=segment_ids, - masks=masks, - is_impossible=chunk.is_impossible, - gt_start=start_pos, - gt_end=end_pos, - context_offset=context_offset, - chunk_start=chunk.start, - chunk_length=chunk.length) + chunk_feature = ChunkFeature(qas_id=feature.qas_id, + data=data, + valid_length=valid_length, + segment_ids=segment_ids, + masks=masks, + is_impossible=chunk.is_impossible, + gt_start=start_pos, + gt_end=end_pos, + context_offset=context_offset, + chunk_start=chunk.start, + chunk_length=chunk.length) ret.append(chunk_feature) return ret @@ -288,6 +310,50 @@ def get_train(self, features, skip_unreliable=True): return train_dataset, num_token_answer_mismatch, num_unreliable +def get_squad_features(args, tokenizer, segment): + """ + Get processed data features of SQuADExampls + + Parameters + ---------- + args : argparse.Namespace + tokenizer: + Tokenizer instance + segment: str + train or dev + + Returns + ------- + data_features + The list of processed data features + """ + data_cache_path = os.path.join(CACHE_PATH, + '{}_{}_squad_{}.ndjson'.format( + segment, args.model_name, args.version)) + is_training = (segment == 'train') + if os.path.exists(data_cache_path) and not args.overwrite_cache: + data_features = [] + with open(data_cache_path, 'r') as f: + for line in f: + data_features.append(SquadFeature.from_json(line)) + logging.info('Found cached data features, load from {}'.format(data_cache_path)) + else: + data_examples = get_squad_examples(args.data_dir, segment=segment, version=args.version) + start = time.time() + num_process = min(cpu_count(), 8) + logging.info('Tokenize Data:') + with Pool(num_process) as pool: + data_features = pool.map(functools.partial(convert_squad_example_to_feature, + tokenizer=tokenizer, + is_training=is_training), data_examples) + logging.info('Done! Time spent:{:.2f} seconds'.format(time.time() - start)) + with open(data_cache_path, 'w') as f: + for feature in data_features: + f.write(feature.to_json() + '\n') + + return data_features + + def get_network(model_name, ctx_l, dropout=0.1, @@ -348,107 +414,17 @@ def get_network(model_name, return cfg, tokenizer, qa_net, use_segmentation -def untune_params(model, untunable_depth, not_included=[]): - """Froze part of parameters according to layer depth. - - That is, make all layer that shallower than `untunable_depth` untunable - to stop the gradient backward computation and accelerate the training. - - Parameters: - ---------- - model - qa_net - untunable_depth: int - the depth of the neural network starting from 1 to number of layers - not_included: list of str - A list or parameter names that not included in the untunable parameters - """ - all_layers = model.backbone.encoder.all_encoder_layers - for _, v in model.collect_params('.*embed*').items(): - model.grad_req = 'null' - - for layer in all_layers[:untunable_depth]: - for key, value in layer.collect_params().items(): - for pn in not_included: - if pn in key: - continue - value.grad_req = 'null' - - -def apply_layerwise_decay(model, layerwise_decay, not_included=[]): - """Apply the layer-wise gradient decay - - .. math:: - lr = lr * layerwise_decay^(max_depth - layer_depth) - - Parameters: - ---------- - model - qa_net - layerwise_decay: int - layer-wise decay power - not_included: list of str - A list or parameter names that not included in the layer-wise decay - """ - # consider the task specific finetuning layer as the last layer, following with pooler - # In addition, the embedding parameters have the smaller learning rate based on this setting. - all_layers = model.backbone.encoder.all_encoder_layers - max_depth = len(all_layers) - if 'pool' in model.collect_params().keys(): - max_depth += 1 - for key, value in model.collect_params().items(): - if 'scores' in key: - value.lr_mult = layerwise_decay**(0) - if 'pool' in key: - value.lr_mult = layerwise_decay**(1) - if 'embed' in key: - value.lr_mult = layerwise_decay**(max_depth + 1) - - for (layer_depth, layer) in enumerate(all_layers): - layer_params = layer.collect_params() - for key, value in layer_params.items(): - for pn in not_included: - if pn in key: - continue - value.lr_mult = layerwise_decay**(max_depth - layer_depth) - - def train(args): - ctx_l = parse_ctx(args.gpus) - cfg, tokenizer, qa_net, use_segmentation \ - = get_network(args.model_name, ctx_l, - args.classifier_dropout, - args.param_checkpoint, - args.backbone_path) - # Load the data - train_examples = get_squad_examples(args.data_dir, segment='train', version=args.version) - logging.info('Load data from {}, Version={}'.format(args.data_dir, args.version)) - num_process = min(cpu_count(), 8) - train_cache_path = os.path.join( - CACHE_PATH, 'train_{}_squad_{}.ndjson'.format( - args.model_name, args.version)) - if os.path.exists(train_cache_path) and not args.overwrite_cache: - train_features = [] - with open(train_cache_path, 'r') as f: - for line in f: - train_features.append(SquadFeature.from_json(line)) - logging.info('Found cached training features, load from {}'.format(train_cache_path)) - - else: - start = time.time() - logging.info('Tokenize Training Data:') - with Pool(num_process) as pool: - train_features = pool.map( - functools.partial( - convert_squad_example_to_feature, - tokenizer=tokenizer, - is_training=True), - train_examples) - logging.info('Done! Time spent:{:.2f} seconds'.format(time.time() - start)) - with open(train_cache_path, 'w') as f: - for feature in train_features: - f.write(feature.to_json() + '\n') - + store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm( + args.comm_backend, args.gpus) + cfg, tokenizer, qa_net, use_segmentation = \ + get_network(args.model_name, ctx_l, + args.classifier_dropout, + args.param_checkpoint, + args.backbone_path) + + logging.info('Prepare training data') + train_features = get_squad_features(args, tokenizer, segment='train') dataset_processor = SquadDatasetProcessor(tokenizer=tokenizer, doc_stride=args.doc_stride, max_seq_length=args.max_seq_length, @@ -468,31 +444,39 @@ def train(args): sum([ele.is_impossible for ele in train_features]))) logging.info('After Chunking, #Train Sample/Is Impossible = {}/{}' .format(len(train_dataset), num_impossible)) + sampler = SplitSampler(len(train_dataset), num_parts=num_workers, + part_index=rank, even_size=True) train_dataloader = mx.gluon.data.DataLoader( train_dataset, batchify_fn=dataset_processor.BatchifyFunction, batch_size=args.batch_size, num_workers=0, - shuffle=True) - # Froze parameters + sampler=sampler) if 'electra' in args.model_name: - # does not work for albert model since parameters in all layers are shared + # Froze parameters, does not work for albert model since parameters in all layers are shared if args.untunable_depth > 0: - untune_params(qa_net, args.untunable_depth) + qa_net.backbone.frozen_params(args.untunable_depth) if args.layerwise_decay > 0: - apply_layerwise_decay(qa_net, args.layerwise_decay) + qa_net.backbone.apply_layerwise_decay(args.layerwise_decay) + logging.info('Creating distributed trainer...') + # Collect differentiable parameters + param_dict = qa_net.collect_params() # Do not apply weight decay to all the LayerNorm and bias for _, v in qa_net.collect_params('.*beta|.*gamma|.*bias').items(): v.wd_mult = 0.0 - # Collect differentiable parameters - params = [p for p in qa_net.collect_params().values() if p.grad_req != 'null'] + params = [p for p in param_dict.values() if p.grad_req != 'null'] # Set grad_req if gradient accumulation is required if args.num_accumulated > 1: logging.info('Using gradient accumulation. Effective global batch size = {}' - .format(args.num_accumulated * args.batch_size * len(ctx_l))) + .format(args.num_accumulated * args.batch_size * len(ctx_l) * num_workers)) for p in params: p.grad_req = 'add' + # backend specific implementation + if args.comm_backend == 'horovod': + # Horovod: fetch and broadcast parameters + hvd.broadcast_parameters(param_dict, root_rank=0) + epoch_size = (len(train_dataloader) + len(ctx_l) - 1) // len(ctx_l) if args.num_train_steps is not None: num_train_steps = args.num_train_steps @@ -509,6 +493,7 @@ def train(args): logging.info('#Total Training Steps={}, Warmup={}, Save Interval={}' .format(num_train_steps, warmup_steps, save_interval)) + # set up optimization lr_scheduler = PolyScheduler(max_update=num_train_steps, base_lr=args.lr, warmup_begin_lr=0, @@ -520,18 +505,24 @@ def train(args): 'wd': args.wd, 'lr_scheduler': lr_scheduler, } + adam_betas = eval(args.adam_betas) if args.optimizer == 'adamw': - optimizer_params.update({'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-6, + optimizer_params.update({'beta1': adam_betas[0], + 'beta2': adam_betas[1], + 'epsilon': args.adam_epsilon, 'correct_bias': False, }) - trainer = mx.gluon.Trainer(qa_net.collect_params(), - args.optimizer, optimizer_params, - update_on_kvstore=False) - step_num = 0 - finish_flag = False - epoch_id = 0 + elif args.optimizer == 'adam': + optimizer_params.update({'beta1': adam_betas[0], + 'beta2': adam_betas[1], + 'epsilon': args.adam_epsilon, + }) + if args.comm_backend == 'horovod': + trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params) + else: + trainer = mx.gluon.Trainer(param_dict, args.optimizer, optimizer_params, + update_on_kvstore=False) + num_samples_per_update = 0 loss_denom = float(len(ctx_l) * args.num_accumulated) @@ -543,22 +534,19 @@ def train(args): # set grad to zero for gradient accumulation qa_net.zero_grad() global_tic = time.time() - while not finish_flag: - epoch_tic = time.time() - tic = time.time() - epoch_sample_num = 0 - for batch_id, sample_l in enumerate(grouper(train_dataloader, len(ctx_l))): + tic = time.time() + for step_num, batch_data in enumerate( + grouper(repeat(train_dataloader), len(ctx_l) * args.num_accumulated)): + for sample_l in grouper(batch_data, len(ctx_l)): loss_l = [] span_loss_l = [] answerable_loss_l = [] - is_last_batch = (batch_id == epoch_size - 1) for sample, ctx in zip(sample_l, ctx_l): if sample is None: continue # Copy the data to device tokens = sample.data.as_in_ctx(ctx) log_sample_num += len(tokens) - epoch_sample_num += len(tokens) num_samples_per_update += len(tokens) segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None valid_length = sample.valid_length.as_in_ctx(ctx) @@ -589,70 +577,72 @@ def train(args): for ele in loss_l]).asnumpy() * loss_denom log_answerable_loss += sum([ele.as_in_ctx(ctx_l[0]) for ele in answerable_loss_l]).asnumpy() - # update - if (batch_id + 1) % args.num_accumulated == 0 or is_last_batch: - trainer.allreduce_grads() - # Here, the accumulated gradients are - # \sum_{n=1}^N g_n / loss_denom - # Thus, in order to clip the average gradient - # \frac{1}{N} \sum_{n=1}^N --> clip to args.max_grad_norm - # We need to change the ratio to be - # \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom - total_norm, ratio, is_finite = clip_grad_global_norm( - params, args.max_grad_norm * num_samples_per_update / loss_denom) - total_norm = total_norm / (num_samples_per_update / loss_denom) - - trainer.update(num_samples_per_update / loss_denom, ignore_stale_grad=True) - step_num += 1 - if args.num_accumulated != 1: - # set grad to zero for gradient accumulation - qa_net.zero_grad() - - # saving - if step_num % save_interval == 0 or step_num >= num_train_steps: - version_prefix = 'squad' + args.version - ckpt_name = '{}_{}_{}.params'.format(args.model_name, - version_prefix, - step_num) - params_saved = os.path.join(args.output_dir, ckpt_name) - qa_net.save_parameters(params_saved) - ckpt_candidates = [ - f for f in os.listdir( - args.output_dir) if f.endswith('.params')] - # keep last 10 checkpoints - if len(ckpt_candidates) > args.max_saved_ckpt: - ckpt_candidates.sort(key=lambda ele: (len(ele), ele)) - os.remove(os.path.join(args.output_dir, ckpt_candidates[0])) - logging.info('Params saved in: {}'.format(params_saved)) - - # logging - if step_num % log_interval == 0: - log_span_loss /= log_sample_num - log_answerable_loss /= log_sample_num - log_total_loss /= log_sample_num - toc = time.time() - logging.info( - 'Epoch: {}, Batch: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},' - ' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f}, Throughput={:.2f} samples/s' - ' ETA={:.2f}h'.format(epoch_id + 1, batch_id + 1, epoch_size, log_span_loss, - log_answerable_loss, log_total_loss, trainer.learning_rate, total_norm, - toc - tic, log_sample_num / (toc - tic), - (num_train_steps - step_num) / (step_num / (toc - global_tic)) / 3600)) - tic = time.time() - log_span_loss = 0 - log_answerable_loss = 0 - log_total_loss = 0 - log_sample_num = 0 - num_samples_per_update = 0 - - if step_num >= num_train_steps: - logging.info('Finish training step: %d', step_num) - finish_flag = True - break - logging.info('Epoch: {}, #Samples: {}, Throughput={:.2f} samples/s' - .format(epoch_id + 1, epoch_sample_num, - epoch_sample_num / (time.time() - epoch_tic))) - epoch_id += 1 + # update + trainer.allreduce_grads() + + if args.max_grad_norm > 0: + # Here, the accumulated gradients are + # \sum_{n=1}^N g_n / loss_denom + # Thus, in order to clip the average gradient + # \frac{1}{N} \sum_{n=1}^N --> clip to args.max_grad_norm + # We need to change the ratio to be + # \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom + total_norm, ratio, is_finite = clip_grad_global_norm( + params, args.max_grad_norm * num_samples_per_update / loss_denom) + else: + total_norm = grad_global_norm(params) + + total_norm = total_norm / (num_samples_per_update / loss_denom) + trainer.update(num_samples_per_update / loss_denom) + if args.num_accumulated != 1: + # set grad to zero for gradient accumulation + qa_net.zero_grad() + + # saving + if local_rank == 0 and (step_num + 1) % save_interval == 0 or ( + step_num + 1) >= num_train_steps: + version_prefix = 'squad' + args.version + ckpt_name = '{}_{}_{}.params'.format(args.model_name, + version_prefix, + (step_num + 1)) + params_saved = os.path.join(args.output_dir, ckpt_name) + qa_net.save_parameters(params_saved) + ckpt_candidates = [ + f for f in os.listdir( + args.output_dir) if f.endswith('.params')] + # keep last `max_saved_ckpt` checkpoints + if len(ckpt_candidates) > args.max_saved_ckpt: + ckpt_candidates.sort(key=lambda ele: (len(ele), ele)) + os.remove(os.path.join(args.output_dir, ckpt_candidates[0])) + logging.info('Params saved in: {}'.format(params_saved)) + + # logging + if local_rank == 0 and (step_num + 1) % log_interval == 0: + log_span_loss /= log_sample_num + log_answerable_loss /= log_sample_num + log_total_loss /= log_sample_num + toc = time.time() + logging.info( + 'Step: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},' + ' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f}, Throughput={:.2f} samples/s' + ' ETA={:.2f}h'.format((step_num + 1), num_train_steps, log_span_loss, + log_answerable_loss, log_total_loss, trainer.learning_rate, + total_norm, toc - tic, log_sample_num / (toc - tic), + (num_train_steps - (step_num + 1)) / ((step_num + 1) / (toc - global_tic)) / 3600)) + tic = time.time() + log_span_loss = 0 + log_answerable_loss = 0 + log_total_loss = 0 + log_sample_num = 0 + num_samples_per_update = 0 + + if (step_num + 1) >= num_train_steps: + toc = time.time() + logging.info( + 'Finish training step: {} within {} hours'.format( + step_num + 1, (toc - global_tic) / 3600)) + break + return params_saved @@ -721,7 +711,7 @@ def predict_extended(original_feature, # TODO investigate the impact token_max_context_score[i, j] = min(j - chunk_start, chunk_start + chunk_length - 1 - j) \ - + 0.01 * chunk_length + + 0.01 * chunk_length token_max_chunk_id = token_max_context_score.argmax(axis=0) for chunk_id, (result, chunk_feature) in enumerate(zip(results, chunked_features)): @@ -787,34 +777,22 @@ def predict_extended(original_feature, def evaluate(args, last=True): + store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm( + args.comm_backend, args.gpus) + # only evaluate once + if rank != 0: + logging.info('Skipping node {}'.format(rank)) + return ctx_l = parse_ctx(args.gpus) + logging.info( + 'Srarting inference without horovod on the first node on device {}'.format( + str(ctx_l))) + cfg, tokenizer, qa_net, use_segmentation = get_network( - args.model_name, ctx_l, args.classifier_dropout, dtype=args.eval_dtype) - if args.eval_dtype == 'float16': - qa_net.cast('float16') - # Prepare dev set - dev_cache_path = os.path.join(CACHE_PATH, - 'dev_{}_squad_{}.ndjson'.format(args.model_name, - args.version)) - if os.path.exists(dev_cache_path) and not args.overwrite_cache: - dev_features = [] - with open(dev_cache_path, 'r') as f: - for line in f: - dev_features.append(SquadFeature.from_json(line)) - logging.info('Found cached dev features, load from {}'.format(dev_cache_path)) - else: - dev_examples = get_squad_examples(args.data_dir, segment='dev', version=args.version) - start = time.time() - num_process = min(cpu_count(), 8) - logging.info('Tokenize Dev Data:') - with Pool(num_process) as pool: - dev_features = pool.map(functools.partial(convert_squad_example_to_feature, - tokenizer=tokenizer, - is_training=False), dev_examples) - logging.info('Done! Time spent:{:.2f} seconds'.format(time.time() - start)) - with open(dev_cache_path, 'w') as f: - for feature in dev_features: - f.write(feature.to_json() + '\n') + args.model_name, ctx_l, args.classifier_dropout) + + logging.info('Prepare dev data') + dev_features = get_squad_features(args, tokenizer, segment='dev') dev_data_path = os.path.join(args.data_dir, 'dev-v{}.json'.format(args.version)) dataset_processor = SquadDatasetProcessor(tokenizer=tokenizer, doc_stride=args.doc_stride, @@ -831,8 +809,6 @@ def eval_validation(ckpt_name, best_eval): """ Model inference during validation or final evaluation. """ - ctx_l = parse_ctx(args.gpus) - # We process all the chunk features and also dev_dataloader = mx.gluon.data.DataLoader( dev_all_chunk_features, batchify_fn=dataset_processor.BatchifyFunction, @@ -840,7 +816,7 @@ def eval_validation(ckpt_name, best_eval): num_workers=0, shuffle=False) - log_interval = args.log_interval + log_interval = args.eval_log_interval all_results = [] epoch_tic = time.time() tic = time.time() @@ -861,8 +837,8 @@ def eval_validation(ckpt_name, best_eval): p_mask = sample.masks.as_in_ctx(ctx) p_mask = 1 - p_mask # In the network, we use 1 --> no_mask, 0 --> mask start_top_logits, start_top_index, end_top_logits, end_top_index, answerable_logits \ - = qa_net.inference(tokens, segment_ids, valid_length, p_mask, - args.start_top_n, args.end_top_n) + = qa_net.inference(tokens, segment_ids, valid_length, p_mask, + args.start_top_n, args.end_top_n) for i, qas_id in enumerate(sample.qas_id): result = RawResultExtended(qas_id=qas_id, start_top_logits=start_top_logits[i].asnumpy(), diff --git a/src/gluonnlp/data/tokenizers.py b/src/gluonnlp/data/tokenizers.py index d9579b2d55..5a086df98b 100644 --- a/src/gluonnlp/data/tokenizers.py +++ b/src/gluonnlp/data/tokenizers.py @@ -21,16 +21,18 @@ 'HuggingFaceWordPieceTokenizer', 'create', 'create_with_json', 'list_all'] -from typing import List, Tuple, Union, Optional import os -import json -from collections import OrderedDict import abc +import json import warnings import itertools -from typing import NewType -import sacremoses from uuid import uuid4 +from typing import List, Tuple, Union, NewType, Optional +from collections import OrderedDict + +import jieba +import sacremoses + from .vocab import Vocab from ..registry import TOKENIZER_REGISTRY from ..utils.lazy_imports import try_import_subword_nmt,\ diff --git a/src/gluonnlp/models/albert.py b/src/gluonnlp/models/albert.py index 1b4efa16e2..3f77b734ef 100644 --- a/src/gluonnlp/models/albert.py +++ b/src/gluonnlp/models/albert.py @@ -731,7 +731,8 @@ def list_pretrained_albert(): def get_pretrained_albert(model_name: str = 'google_albert_base_v2', root: str = get_model_zoo_home_dir(), - load_backbone=True, load_mlm=False)\ + load_backbone: str = True, + load_mlm: str = False)\ -> Tuple[CN, SentencepieceTokenizer, str, str]: """Get the pretrained Albert weights diff --git a/src/gluonnlp/models/bert.py b/src/gluonnlp/models/bert.py index 84a1d5ee2e..68e002fea3 100644 --- a/src/gluonnlp/models/bert.py +++ b/src/gluonnlp/models/bert.py @@ -756,7 +756,8 @@ def list_pretrained_bert(): def get_pretrained_bert(model_name: str = 'google_en_cased_bert_base', root: str = get_model_zoo_home_dir(), - load_backbone=True, load_mlm=False)\ + load_backbone: str = True, + load_mlm: str = False)\ -> Tuple[CN, HuggingFaceWordPieceTokenizer, str, str]: """Get the pretrained bert weights diff --git a/src/gluonnlp/models/electra.py b/src/gluonnlp/models/electra.py index b8d4e44029..0be1cfd99a 100644 --- a/src/gluonnlp/models/electra.py +++ b/src/gluonnlp/models/electra.py @@ -36,7 +36,7 @@ from mxnet import use_np from mxnet.gluon import HybridBlock, nn from ..registry import BACKBONE_REGISTRY -from ..op import gumbel_softmax, select_vectors_by_position, updated_vectors_by_position +from ..op import gumbel_softmax, select_vectors_by_position, add_vectors_by_position, updated_vectors_by_position from ..base import get_model_zoo_home_dir, get_repo_model_zoo_url, get_model_zoo_checksum_dir from ..layers import PositionalEmbedding, get_activation from .transformer import TransformerEncoderLayer @@ -157,6 +157,14 @@ def google_electra_large(): 'disc_model': 'google_electra_large/disc_model-5b820c02.params', 'gen_model': 'google_electra_large/gen_model-82c1b17b.params', 'lowercase': True, + }, + 'gluon_electra_small_owt':{ + 'cfg': 'gluon_electra_small_owt/model-6e276d98.yml', + 'vocab': 'gluon_electra_small_owt/vocab-e6d2b21d.json', + 'params': 'gluon_electra_small_owt/model-e9636891.params', + 'disc_model': 'gluon_electra_small_owt/disc_model-87836017.params', + 'gen_model': 'gluon_electra_small_owt/gen_model-45a6fb67.params', + 'lowercase': True, } } @@ -334,6 +342,8 @@ def __init__(self, self.pos_embed_type = pos_embed_type self.num_token_types = num_token_types self.vocab_size = vocab_size + self.num_layers = num_layers + self.num_heads = num_heads self.embed_size = embed_size self.units = units self.max_length = max_length @@ -499,6 +509,58 @@ def get_initial_embedding(self, F, inputs, token_types=None): embedding = self.embed_dropout(embedding) return embedding + def apply_layerwise_decay(self, layerwise_decay, not_included=None): + """Apply the layer-wise gradient decay + + .. math:: + lr = lr * layerwise_decay^(max_depth - layer_depth) + + Parameters: + ---------- + layerwise_decay: int + layer-wise decay power + not_included: list of str + A list or parameter names that not included in the layer-wise decay + """ + + # consider the task specific finetuning layer as the last layer, following with pooler + # In addition, the embedding parameters have the smaller learning rate based on this setting. + max_depth = self.num_layers + 2 + for _, value in self.collect_params('.*embed*').items(): + value.lr_mult = layerwise_decay**(max_depth) + + for (layer_depth, layer) in enumerate(self.encoder.all_encoder_layers): + layer_params = layer.collect_params() + for key, value in layer_params.items(): + for pn in not_included: + if pn in key: + continue + value.lr_mult = layerwise_decay**(max_depth - (layer_depth + 1)) + + def frozen_params(self, untunable_depth, not_included=None): + """Froze part of parameters according to layer depth. + + That is, make all layer that shallower than `untunable_depth` untunable + to stop the gradient backward computation and accelerate the training. + + Parameters: + ---------- + untunable_depth: int + the depth of the neural network starting from 1 to number of layers + not_included: list of str + A list or parameter names that not included in the untunable parameters + """ + all_layers = self.encoder.all_encoder_layers + for _, value in self.collect_params('.*embed*').items(): + value.grad_req = 'null' + + for layer in all_layers[:untunable_depth]: + for key, value in layer.collect_params().items(): + for pn in not_included: + if pn in key: + continue + value.grad_req = 'null' + @staticmethod def get_cfg(key=None): if key is not None: @@ -913,7 +975,7 @@ def get_corrupted_tokens(self, F, inputs, unmasked_tokens, masked_positions, log Returns ------- corrupted_tokens - The corrupted tokens + Shape (batch_size, ) fake_data - layout = 'NT' Shape (batch_size, seq_length) @@ -941,13 +1003,15 @@ def get_corrupted_tokens(self, F, inputs, unmasked_tokens, masked_positions, log if self.disc_backbone.layout == 'TN': inputs = inputs.T - # Following the Official electra to deal with duplicate positions as - # https://github.com/google-research/electra/issues/41 - original_data, updates_mask = updated_vectors_by_position(F, + original_data = updated_vectors_by_position(F, inputs, unmasked_tokens, masked_positions) - fake_data, _ = updated_vectors_by_position(F, + fake_data = updated_vectors_by_position(F, inputs, corrupted_tokens, masked_positions) - + updates_mask = add_vectors_by_position(F, F.np.zeros_like(inputs), + F.np.ones_like(masked_positions), masked_positions) + # Dealing with multiple zeros in masked_positions which + # results in a non-zero value in the first index [CLS] + updates_mask = F.np.minimum(updates_mask, 1) labels = updates_mask * F.np.not_equal(fake_data, original_data) if self.disc_backbone.layout == 'TN': return corrupted_tokens, fake_data.T, labels.T diff --git a/src/gluonnlp/models/mobilebert.py b/src/gluonnlp/models/mobilebert.py index 5a81de7c64..96ada137f3 100644 --- a/src/gluonnlp/models/mobilebert.py +++ b/src/gluonnlp/models/mobilebert.py @@ -108,8 +108,8 @@ def google_uncased_mobilebert(): @use_np class MobileBertEncoderLayer(HybridBlock): """The Transformer Encoder Layer in Mobile Bert""" - # TODO(zheyuye), use stacked groups for single ffn layer in transformer.TransformerEncoderLayer - # and revise the other models and scripts, making sure they are compatible. + # TODO(zheyuye), use stacked groups for single ffn layer in TransformerEncoderLayer + # and revise the other models and scripts, masking sure their are compatible. def __init__(self, use_bottleneck: bool = True, @@ -267,13 +267,11 @@ def __init__(self, is_last_ffn = (ffn_idx == (num_stacked_ffn - 1)) # only apply dropout on last ffn layer if use bottleneck dropout = float(hidden_dropout_prob * (not use_bottleneck) * is_last_ffn) - activation_dropout = float(activation_dropout_prob * (not use_bottleneck) - * is_last_ffn) self.stacked_ffn.add( PositionwiseFFN(units=real_units, hidden_size=hidden_size, dropout=dropout, - activation_dropout=activation_dropout, + activation_dropout=activation_dropout_prob, weight_initializer=weight_initializer, bias_initializer=bias_initializer, activation=activation, @@ -1021,7 +1019,8 @@ def list_pretrained_mobilebert(): def get_pretrained_mobilebert(model_name: str = 'google_uncased_mobilebert', root: str = get_model_zoo_home_dir(), - load_backbone=True, load_mlm=True)\ + load_backbone: str = True, + load_mlm: str = False)\ -> Tuple[CN, HuggingFaceWordPieceTokenizer, str, str]: """Get the pretrained mobile bert weights @@ -1077,6 +1076,7 @@ def get_pretrained_mobilebert(model_name: str = 'google_uncased_mobilebert', sha1_hash=FILE_STATS[mlm_params_path]) else: local_mlm_params_path = None + do_lower = True if 'lowercase' in PRETRAINED_URL[model_name]\ and PRETRAINED_URL[model_name]['lowercase'] else False tokenizer = HuggingFaceWordPieceTokenizer( diff --git a/src/gluonnlp/models/model_zoo_checksums/electra.txt b/src/gluonnlp/models/model_zoo_checksums/electra.txt index b2f7548c73..2d66960466 100644 --- a/src/gluonnlp/models/model_zoo_checksums/electra.txt +++ b/src/gluonnlp/models/model_zoo_checksums/electra.txt @@ -13,3 +13,8 @@ google_electra_large/model-9baf9ff5.params 9baf9ff55cee0195b7754aee7fc google_electra_large/gen_model-82c1b17b.params 82c1b17b4b5ac19700c272858b0b211437f72855 205211944 google_electra_large/model-31b7dfdd.yml 31b7dfdd343bd2b2e43e200a735c83b0af1963f1 476 google_electra_large/disc_model-5b820c02.params 5b820c026aa2ad779c1e9a41ff4ff1408fefacbf 1340602227 +gluon_electra_small_owt/vocab-e6d2b21d.json e6d2b21d910ccb356aa18f27a1c7d70660edc058 323235 +gluon_electra_small_owt/model-e9636891.params e9636891daae9f2940b2b3210cca3c34c3d8f21e 53748654 +gluon_electra_small_owt/model-6e276d98.yml 6e276d98360fbb7c379d28bac34a3ca2918a90ab 473 +gluon_electra_small_owt/gen_model-45a6fb67.params 45a6fb67e1e6cb65d22b80498f2152ce9780d579 33926624 +gluon_electra_small_owt/disc_model-87836017.params 878360174ac71c3fdc7071be7835bea532c09b8d 54015367 diff --git a/src/gluonnlp/op.py b/src/gluonnlp/op.py index a4762b4ad3..6be0e19262 100644 --- a/src/gluonnlp/op.py +++ b/src/gluonnlp/op.py @@ -100,7 +100,7 @@ def updated_vectors_by_position(F, base, data, positions): """ Update each batch with the given positions. Considered as a reversed process of "select_vectors_by_position", this is an advanced operator of add_vectors_by_position - that updates the results instead of add and avoids duplicate positions. + that updates the results instead of adding. Once advanced indexing can be hybridized, we can revise the implementation. updates[i, positions[i, j], :] = data[i, j, :] @@ -127,22 +127,16 @@ def updated_vectors_by_position(F, base, data, positions): out The updated result. Shape (batch_size, seq_length) - updates_mask - The state of the updated for the whole sequence - 1 -> updated, 0 -> not updated. - Shape (batch_size, seq_length) """ - # TODO(zheyuye), update when npx.index_update implemented - updates = add_vectors_by_position(F, F.np.zeros_like(base), data, positions) - updates_mask = add_vectors_by_position(F, F.np.zeros_like(base), - F.np.ones_like(positions), positions) - updates = (updates / F.np.maximum(1, updates_mask)).astype(np.int32) - - out = F.np.where(updates, updates, base) - updates_mask = F.np.minimum(updates_mask, 1) - - return out, updates_mask + positions = positions.astype(np.int32) + # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...] + batch_idx = F.np.expand_dims(F.npx.arange_like(positions, axis=0), + axis=1).astype(np.int32) + batch_idx = batch_idx + F.np.zeros_like(positions) + indices = F.np.stack([batch_idx.reshape(-1), positions.reshape(-1)]) + out = F.npx.index_update(base, indices, data.reshape(-1)) + return out @use_np def gumbel_softmax(F, logits, temperature: float = 1.0, eps: float = 1E-10, diff --git a/src/gluonnlp/utils/misc.py b/src/gluonnlp/utils/misc.py index 7a1a7880a9..51ee1999d1 100644 --- a/src/gluonnlp/utils/misc.py +++ b/src/gluonnlp/utils/misc.py @@ -262,6 +262,15 @@ def grouper(iterable, n, fillvalue=None): args = [iter(iterable)] * n return itertools.zip_longest(*args, fillvalue=fillvalue) +def repeat(iterable, count=None): + if count is None: + while True: + for sample in iterable: + yield sample + else: + for i in range(count): + for sample in iterable: + yield sample def parse_ctx(data_str): import mxnet as mx @@ -536,3 +545,36 @@ def check_version(min_version: str, warnings.warn(msg) else: raise AssertionError(msg) + +def init_comm(backend, gpus): + """Init communication backend""" + # backend specific implementation + import mxnet as mx + if backend == 'horovod': + try: + import horovod.mxnet as hvd # pylint: disable=import-outside-toplevel + except ImportError: + logging.info('horovod must be installed.') + sys.exit(1) + hvd.init() + store = None + num_workers = hvd.size() + rank = hvd.rank() + local_rank = hvd.local_rank() + is_master_node = rank == local_rank + ctx_l = [mx.gpu(local_rank)] + logging.info('GPU communication supported by horovod') + else: + store = mx.kv.create(backend) + num_workers = store.num_workers + rank = store.rank + local_rank = 0 + is_master_node = rank == local_rank + if gpus == '-1' or gpus == '': + ctx_l = [mx.cpu()] + logging.info('Runing on CPU') + else: + ctx_l = [mx.gpu(int(x)) for x in gpus.split(',')] + logging.info('GPU communication supported by KVStore') + + return store, num_workers, rank, local_rank, is_master_node, ctx_l diff --git a/src/gluonnlp/utils/parameter.py b/src/gluonnlp/utils/parameter.py index b0175e63f4..2898933c93 100644 --- a/src/gluonnlp/utils/parameter.py +++ b/src/gluonnlp/utils/parameter.py @@ -127,7 +127,6 @@ def clip_grad_global_norm(parameters: Iterable[Parameter], If the gradient norm is larger than max_norm, it will be clipped to have max_norm check_isfinite If True, check whether the total_norm is finite (not nan or inf). - Returns ------- total_norm diff --git a/tests/test_models_albert.py b/tests/test_models_albert.py index f428a85569..71170500c1 100644 --- a/tests/test_models_albert.py +++ b/tests/test_models_albert.py @@ -169,7 +169,8 @@ def test_albert_get_pretrained(model_name): albert_model = AlbertModel.from_cfg(cfg) albert_model.load_parameters(backbone_params_path) albert_mlm_model = AlbertForMLM(cfg) - albert_mlm_model.load_parameters(mlm_params_path) + if mlm_params_path is not None: + albert_mlm_model.load_parameters(mlm_params_path) # Just load the backbone albert_mlm_model = AlbertForMLM(cfg) albert_mlm_model.backbone_model.load_parameters(backbone_params_path)