
Horovod support for pretraining and fine-tuning SQuAD (#1276)
* fix roberta

* fix xlmr

* fix token_ids

* fix

* use_segmentation

* fix roberta

* update

* fix

* fix mobilebert

* repeat

* repeat for pretraining

* revise

* revise train_transformer

* upload gluon_electra_small_owt

* fix openwebtext

* fix wiki

* fix bookcorpus

* multiprocessing for wiki

* update

* rename

* index_update

* topk

* revise

* layer-wise decay

* fix mobilebert

* try

* update hyper-parameters of adamw

* fix roberta

* clip_grad_global_norm with zeros max_grad_norm

* fix ModelForQABasic

* multiply_grads

* remove multiply_grads

* fix

* horovod for squad

* update

* inference without horovod

* fix

* update

* re-upload roberta

* fix get_pretrained

* re-upload xlmr

* update testings

* tiny update on run_squad

* test

* lowercase

* CharTokenizer

* Squashed commit of the following:

commit 35a586676036f627bffd0d3c753c6cd0a70d63cf
Author: ZheyuYe <[email protected]>
Date:   Fri Jul 17 10:10:14 2020 +0800

    Squashed commit of the following:

    commit 673344d
    Author: ZheyuYe <[email protected]>
    Date:   Wed Jul 15 22:43:07 2020 +0800

        CharTokenizer

    commit 8dabfd6
    Author: ZheyuYe <[email protected]>
    Date:   Wed Jul 15 15:47:24 2020 +0800

        lowercase

    commit f5c94a6
    Author: ZheyuYe <[email protected]>
    Date:   Tue Jul 14 17:45:28 2020 +0800

        test

    commit dc55fc9
    Author: ZheyuYe <[email protected]>
    Date:   Tue Jul 14 05:45:01 2020 +0800

        tiny update on run_squad

    commit 4defc7a
    Author: ZheyuYe <[email protected]>
    Date:   Mon Jul 13 23:18:08 2020 +0800

        update testings

    commit 2719e81
    Author: ZheyuYe <[email protected]>
    Date:   Mon Jul 13 23:08:32 2020 +0800

        re-upload xlmr

    commit cd0509d
    Author: ZheyuYe <[email protected]>
    Date:   Mon Jul 13 22:30:47 2020 +0800

        fix get_pretrained

    commit 8ed8a72
    Author: ZheyuYe <[email protected]>
    Date:   Mon Jul 13 22:28:13 2020 +0800

        re-upload roberta

    commit 5811d40
    Author: ZheyuYe <[email protected]>
    Date:   Mon Jul 13 18:27:23 2020 +0800

        update

    commit 44a09a3
    Author: ZheyuYe <[email protected]>
    Date:   Sat Jul 11 15:06:33 2020 +0800

        fix

    commit 4074a26
    Author: ZheyuYe <[email protected]>
    Date:   Fri Jul 10 16:08:49 2020 +0800

        inference without horovod

    commit 31cb953
    Author: ZheyuYe <[email protected]>
    Date:   Thu Jul 9 18:41:55 2020 +0800

        update

    commit 838be2a
    Author: ZheyuYe <[email protected]>
    Date:   Thu Jul 9 15:14:39 2020 +0800

        horovod for squad

    commit 1d374a2
    Author: ZheyuYe <[email protected]>
    Date:   Thu Jul 9 12:09:19 2020 +0800

        fix

    commit e4fba39
    Author: ZheyuYe <[email protected]>
    Date:   Thu Jul 9 10:35:08 2020 +0800

        remove multiply_grads

    commit 007f07e
    Author: ZheyuYe <[email protected]>
    Date:   Tue Jul 7 11:26:38 2020 +0800

        multiply_grads

    commit b8c85bb
    Author: ZheyuYe <[email protected]>
    Date:   Mon Jul 6 12:28:56 2020 +0800

        fix ModelForQABasic

    commit 0e13a58
    Author: ZheyuYe <[email protected]>
    Date:   Sat Jul 4 18:42:12 2020 +0800

        clip_grad_global_norm with zeros max_grad_norm

    commit bd270f2
    Author: ZheyuYe <[email protected]>
    Date:   Fri Jul 3 20:21:31 2020 +0800

        fix roberta

    commit 4fc564c
    Author: ZheyuYe <[email protected]>
    Date:   Fri Jul 3 19:36:08 2020 +0800

        update hyper-parameters of adamw

    commit 59cffbf
    Author: ZheyuYe <[email protected]>
    Date:   Fri Jul 3 16:25:46 2020 +0800

        try

    commit a84f782
    Author: ZheyuYe <[email protected]>
    Date:   Thu Jul 2 20:39:03 2020 +0800

        fix mobilebert

    commit 4bc3a96
    Author: ZheyuYe <[email protected]>
    Date:   Thu Jul 2 11:14:39 2020 +0800

        layer-wise decay

    commit 07186d5
    Author: ZheyuYe <[email protected]>
    Date:   Thu Jul 2 02:14:43 2020 +0800

        revise

    commit a5a6475
    Author: ZheyuYe <[email protected]>
    Date:   Wed Jul 1 19:50:20 2020 +0800

        topk

    commit 34ee884
    Author: ZheyuYe <[email protected]>
    Date:   Wed Jul 1 19:25:09 2020 +0800

        index_update

    commit 74178e2
    Author: ZheyuYe <[email protected]>
    Date:   Wed Jul 1 00:48:32 2020 +0800

        rename

    commit fa011aa
    Author: ZheyuYe <[email protected]>
    Date:   Tue Jun 30 23:40:28 2020 +0800

        update

    commit 402d625
    Author: ZheyuYe <[email protected]>
    Date:   Tue Jun 30 21:40:30 2020 +0800

        multiprocessing for wiki

    commit ddbde75
    Author: ZheyuYe <[email protected]>
    Date:   Tue Jun 30 20:41:35 2020 +0800

        fix bookcorpus

    commit 6cc5ccd
    Author: ZheyuYe <[email protected]>
    Date:   Tue Jun 30 16:39:12 2020 +0800

        fix wiki

    commit 9773efd
    Author: ZheyuYe <[email protected]>
    Date:   Tue Jun 30 15:52:13 2020 +0800

        fix openwebtext

    commit 1fb8eb8
    Author: ZheyuYe <[email protected]>
    Date:   Mon Jun 29 19:51:25 2020 +0800

        upload gluon_electra_small_owt

    commit ca83fac
    Author: ZheyuYe <[email protected]>
    Date:   Mon Jun 29 18:09:48 2020 +0800

        revise train_transformer

    commit 1450f5c
    Author: ZheyuYe <[email protected]>
    Date:   Mon Jun 29 18:07:04 2020 +0800

        revise

    commit b460bbe
    Author: ZheyuYe <[email protected]>
    Date:   Mon Jun 29 17:24:00 2020 +0800

        repeat for pretraining

    commit 8ee381b
    Author: ZheyuYe <[email protected]>
    Date:   Mon Jun 29 17:06:43 2020 +0800

        repeat

    commit aea936f
    Author: ZheyuYe <[email protected]>
    Date:   Mon Jun 29 16:39:22 2020 +0800

        fix mobilebert

    commit eead164
    Author: ZheyuYe <[email protected]>
    Date:   Sun Jun 28 18:44:28 2020 +0800

        fix

    commit 8645115
    Author: ZheyuYe <[email protected]>
    Date:   Sun Jun 28 17:27:43 2020 +0800

        update

    commit 2b7f7a3
    Author: ZheyuYe <[email protected]>
    Date:   Sun Jun 28 17:18:00 2020 +0800

        fix roberta

    commit 86702fe
    Author: ZheyuYe <[email protected]>
    Date:   Sun Jun 28 16:27:43 2020 +0800

        use_segmentation

    commit 6d03d7a
    Author: ZheyuYe <[email protected]>
    Date:   Sun Jun 28 15:52:40 2020 +0800

        fix

    commit 5c0ca43
    Author: ZheyuYe <[email protected]>
    Date:   Sun Jun 28 15:49:48 2020 +0800

        fix token_ids

    commit ff7aae8
    Author: ZheyuYe <[email protected]>
    Date:   Sun Jun 28 13:56:07 2020 +0800

        fix xlmr

    commit 2070b86
    Author: ZheyuYe <[email protected]>
    Date:   Sun Jun 28 13:54:26 2020 +0800

        fix roberta

commit 70a1887
Author: Leonard Lausen <[email protected]>
Date:   Fri Jul 17 00:07:08 2020 +0000

    Update for Block API (#1261)

    - Remove params and prefix arguments for MXNet 2 and update
      parameter sharing implementation
    - Remove Block.name_scope() for MXNet 2
    - Remove self.params.get() and self.params.get_constant()

commit ea9152b
Author: Xingjian Shi <[email protected]>
Date:   Thu Jul 16 15:42:04 2020 -0700

    Fixes to make the CI more stable (#1265)

    * Some fixes to make the CI more stable

    * add retries

    * Update tokenizers.py

commit a646c34
Author: ht <[email protected]>
Date:   Sun Jul 12 02:49:53 2020 +0800

    [FEATURE] update backtranslation and add multinomial sampler (#1259)

    * back translation bash

    * split "lang-pair" para in clean_tok_para_corpus

    * added clean_tok_mono_corpus

    * fix

    * add num_process para

    * fix

    * fix

    * add yml

    * rm yml

    * update cfg name

    * update evaluate

    * added max_update / save_interval_update params

    * fix

    * fix

    * multi gpu inference

    * fix

    * update

    * update multi gpu inference

    * fix

    * fix

    * split evaluate and parallel infer

    * fix

    * test

    * fix

    * update

    * add comments

    * fix

    * remove todo comment

    * revert remove todo comment

    * raw lines remove duplicated '\n'

    * update multinomial sampler

    * fix

    * fix

    * fix

    * fix

    * sampling

    * update script

    * fix

    * add test_case with k > 1 in topk sampling

    * fix multinomial sampler

    * update docs

    * comments situation eos_id = None

    * fix

    Co-authored-by: Hu <[email protected]>

commit 83e1f13
Author: Leonard Lausen <[email protected]>
Date:   Thu Jul 9 20:57:55 2020 -0700

    Use Amazon S3 Transfer Acceleration (#1260)

commit cd48efd
Author: Leonard Lausen <[email protected]>
Date:   Tue Jul 7 17:39:42 2020 -0700

    Update codecov action to handle different OS and Python versions (#1254)

    codecov/codecov-action#80 (comment)

commit 689eba9
Author: Sheng Zha <[email protected]>
Date:   Tue Jul 7 09:55:34 2020 -0700

    [CI] AWS batch job tool for GluonNLP (Part I) (#1251)

    * AWS batch job tool for GluonNLP

    * limit range

    Co-authored-by: Xingjian Shi <[email protected]>

commit e06ff01
Author: Leonard Lausen <[email protected]>
Date:   Tue Jul 7 08:36:24 2020 -0700

    Pin mxnet version range on CI (#1257)

* frozen_params

* remove conversion to a separate pr

* fix

* fix

* update

* test

* revise

* update performance numbers

* update apply_layerwisw_decay

* use shuffle

* fix mobilebert

* fix vocab_file
zheyuye authored Aug 1, 2020
1 parent 2294421 commit 1f9ad44
Showing 25 changed files with 815 additions and 567 deletions.
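
For readers unfamiliar with the Horovod wiring this commit introduces, the following is a minimal sketch of the usual horovod.mxnet setup (one process per GPU, parameter broadcast from rank 0, gradient averaging through DistributedTrainer). It is an illustration only, not the actual run_squad.py or pretraining code; `build_qa_model()` and `get_sharded_dataloader()` are hypothetical placeholders for the real GluonNLP model and data pipeline.

```python
# Hedged sketch of a horovod.mxnet training loop; the placeholder names below
# are NOT functions from this repository.
import mxnet as mx
import horovod.mxnet as hvd

hvd.init()                                  # launched with horovodrun, one process per GPU
ctx = mx.gpu(hvd.local_rank())              # bind each worker to its local GPU

net = build_qa_model()                      # placeholder for the real backbone + QA head
net.initialize(ctx=ctx)
params = net.collect_params()

# Every worker must start from identical weights.
hvd.broadcast_parameters(params, root_rank=0)

# DistributedTrainer all-reduces (averages) gradients across workers each step.
# The learning rate is a placeholder; it is often rescaled with hvd.size().
trainer = hvd.DistributedTrainer(params, 'adam', {'learning_rate': 2e-5})

# placeholder loader that yields only this rank's shard of the data
for batch in get_sharded_dataloader(rank=hvd.rank(), num_shards=hvd.size()):
    with mx.autograd.record():
        loss = net(batch).mean()            # placeholder forward pass / loss
    loss.backward()
    trainer.step(1)                         # gradient averaging happens inside the trainer
```

As the "inference without horovod" commit above suggests, evaluation is kept outside this distributed loop.
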
3 changes: 2 additions & 1 deletion scripts/datasets/__main__.py
@@ -7,7 +7,8 @@
from .general_nlp_benchmark import prepare_glue
from gluonnlp.registry import DATA_PARSER_REGISTRY, DATA_MAIN_REGISTRY


# TODO(zheyuye), lazy_import these data parser functions and data main functions
# and their dependencies by a dictionary mapping the dataset names to the functions.
def list_all_subcommands():
out = []
for key in DATA_PARSER_REGISTRY.list_keys():
21 changes: 13 additions & 8 deletions scripts/datasets/pretrain_corpus/README.md
@@ -2,36 +2,41 @@

We provide a series of shared scripts for downloading/preparing the text corpus for pretraining NLP models.
This helps create a unified text corpus for studying the performance of different pretraining algorithms.
When releasing the datasets, we follow the [FAIR principle](https://www.go-fair.org/fair-principles/),
i.e., the dataset needs to be findable, accessible, interoperable, and reusable.

## BookCorpus
Unfortunately, we are unable to provide the original [Toronto BookCorpus dataset](https://yknzhu.wixsite.com/mbweb) due to licensing issues.

There are some open source efforts for reproducing the dataset, e.g.,
using [soskek/bookcorpus](https://github.com/soskek/bookcorpus) or directly downloading the [preprocessed version](https://drive.google.com/file/d/16KCjV9z_FHm8LgZw05RSuk4EsAWPOP_z/view).

Nevertheless, we use [Project Gutenberg](https://www.gutenberg.org/) as an alternative to the Toronto BookCorpus.

You can use the following command to download and prepare the Gutenberg dataset.

```bash
python prepare_bookcorpus.py --dataset gutenberg
```

Also, you should follow the [license](https://www.gutenberg.org/wiki/Gutenberg:The_Project_Gutenberg_License) for using the data.

## Wikipedia

Please install [attardi/wikiextractor](https://github.com/attardi/wikiextractor) for preparing the data.

```bash
# Download
python prepare_wikipedia.py --mode download --lang en --date latest -o ./

# Properly format the text files
python prepare_wikipedia.py --mode format -i [path-to-wiki.xml.bz2] -o ./

```
The process of downloading and formatting is time consuming, so we offer an alternative: download the prepared raw text file from an S3 bucket. This raw text file is in English, was dumped on 2020-06-20, and was formatted by the very process above (`--lang en --date 20200620`).

```bash
python prepare_wikipedia.py --mode download_prepared -o ./
```
### References
- [NVIDIA/DeepLearningExamples](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT)
@@ -43,7 +48,7 @@ You can download the OpenWebText from [link](https://skylion007.github.io/OpenWe
After downloading and extracting the OpenWebText (i.e., `tar xf openwebtext.tar.xz`), you can use the following command to preprocess the dataset.

```bash
python prepare_openwebtext.py --input openwebtext/ --output prepared_owt
python prepare_openwebtext.py --input openwebtext/ --output prepared_owt --shuffle
```

In this step, the archived txt files are read directly without being decompressed to disk, as shown in the sketch below.
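
As a rough illustration of that step, the sketch below streams the member `.txt` files of one `urlsf_subset*.xz` archive through `tarfile` without writing anything to disk; `iter_archive_lines` is a hypothetical helper, not part of the script.

```python
# Minimal sketch: read the members of an xz-compressed tar archive in memory,
# mirroring what extract_files() in prepare_openwebtext.py does.
import tarfile

def iter_archive_lines(archive_path):
    """Yield decoded, non-empty lines from every member of a .xz tar archive."""
    with tarfile.open(archive_path) as tar:      # compression is auto-detected
        for name in tar.getnames():
            member = tar.extractfile(name)       # file-like object, nothing written to disk
            if member is None:                   # skip directories / special members
                continue
            for raw in member.readlines():
                line = raw.strip()
                if line:                         # skip empty lines, as the script does
                    yield line.decode()

# usage:
# for line in iter_archive_lines('openwebtext/urlsf_subset00-1_data.xz'):
#     ...
```
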
1 change: 1 addition & 0 deletions scripts/datasets/pretrain_corpus/prepare_bookcorpus.py
@@ -75,6 +75,7 @@ def main(args):
filename = os.path.basename(name)
f.extract(name, os.path.join(save_dir, filename))
else:
# TODO(zheyuye), format for pretraining
raise NotImplementedError
else:
raise NotImplementedError
18 changes: 12 additions & 6 deletions scripts/datasets/pretrain_corpus/prepare_openwebtext.py
@@ -51,8 +51,9 @@ def extract_files(full_name, output_dir, shuffle=False):
"""
if not full_name.endswith(".xz"):
return
file_prefix = re.split('\.|/',full_name)[1]
with open("{}.txt".format(os.path.join(output_dir, file_prefix)),"w") as fp:
file_prefix = re.split(r'\.|/', full_name)[-2]
file_prefix = file_prefix.replace('urlsf_subset', 'openwebtext-prepared-')
with open("{}.txt".format(os.path.join(output_dir, file_prefix)), "w") as fp:
with tarfile.open(full_name) as t:
txt_names = t.getnames()
if shuffle:
@@ -63,9 +64,9 @@
# skip empty line
line = line.strip()
if line:
fp.write(line.decode()+'\n')
fp.write(line.decode() + '\n')
# Two extra line break to mark the document separation
fp.write('\n\n')
fp.write('\n')


@DATA_MAIN_REGISTRY.register('prepare_openwebtext')
@@ -76,11 +77,16 @@ def main(args):
fnames = sorted(os.listdir(args.input))
fnames = [os.path.join(args.input, fname) for fname in fnames]
if args.shuffle:
fnames = random.shuffle(fnames)
random.shuffle(fnames)
print('Start extracting {} files with {} cores'.format(len(fnames), num_process))
start_time = time.time()
with multiprocessing.Pool(num_process) as pool:
iter = pool.imap(functools.partial(extract_files, output_dir=args.output, shuffle=args.shuffle), fnames)
iter = pool.imap(
functools.partial(
extract_files,
output_dir=args.output,
shuffle=args.shuffle),
fnames)
for f_index, _ in enumerate(iter):
if f_index > 0 and f_index % 250 == 0:
elapsed = time.time() - start_time
156 changes: 112 additions & 44 deletions scripts/datasets/pretrain_corpus/prepare_wikipedia.py
@@ -1,10 +1,15 @@
"""Prepare the Wikipedia dataset that contain cleaned articles of all languages."""
import os
import sys
import argparse
import glob
from gluonnlp.utils.misc import download
from gluonnlp.registry import DATA_PARSER_REGISTRY, DATA_MAIN_REGISTRY
import math
import time
import tarfile
import argparse
import multiprocessing

from gluonnlp.registry import DATA_MAIN_REGISTRY, DATA_PARSER_REGISTRY
from gluonnlp.utils.misc import download, load_checksum_stats

_CITATION = """\
@ONLINE {wikidump,
@@ -47,6 +52,13 @@
_BASE_URL_TMPL\
= "https://dumps.wikimedia.org/{lang}wiki/{date}/{lang}wiki-{date}-pages-articles.xml.bz2"
_CURR_DIR = os.path.realpath(os.path.dirname(os.path.realpath(__file__)))
_URL_FILE_STATS_PATH = os.path.join(_CURR_DIR, '..', 'url_checksums', 'wikipedia.txt')
_URL_FILE_STATS = load_checksum_stats(_URL_FILE_STATS_PATH)

_URLS = {
'wikipedia-en-20200620':
'https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/pretrain_corpus/wikipedia-en-20200620.tar.gz',
}


def get_url(lang, date):
@@ -55,64 +67,72 @@ def get_url(lang, date):

def try_import_wikiextractor():
try:
sys.path.append(_CURR_DIR)
import WikiExtractor
except ImportError:
try:
download(
'https://raw.githubusercontent.com/attardi/wikiextractor/'
'16186e290d9eb0eb3a3784c6c0635a9ed7e855c3/WikiExtractor.py',
'https://raw.githubusercontent.com/attardi/wikiextractor/master/WikiExtractor.py',
path=os.path.join(_CURR_DIR, 'WikiExtractor.py'),
sha1_hash='3c4896a837b75c476d23c037e8d6c7fdfd9a29eb')
sys.path.append(_CURR_DIR)
import WikiExtractor
except:
except BaseException:
raise ImportError('Cannot import WikiExtractor! You can download the "WikiExtractor.py"'
' in https://github.com/attardi/wikiextractor to {}'
.format(_CURR_DIR))
return WikiExtractor


class WikicorpusTextFormatting:
def __init__(self, wiki_path, output_filename, recursive=False):
self.wiki_path = wiki_path
self.recursive = recursive
self.output_filename = output_filename

# This puts one article per line
def merge(self):
with open(self.output_filename, mode='w', newline='\n') as ofile:
for dirname in glob.glob(os.path.join(self.wiki_path, '*'), recursive=False):
for filename in glob.glob(os.path.join(dirname, 'wiki_*'), recursive=self.recursive):
print(filename)
article_lines = []
article_open = False

with open(filename, mode='r', newline='\n') as file:
for line in file:
if '<doc id=' in line:
article_open = True
elif '</doc>' in line:
article_open = False
for oline in article_lines[1:]:
if oline != '\n':
ofile.write(oline.rstrip() + " ")
ofile.write("\n\n")
article_lines = []
else:
if article_open:
article_lines.append(line)
def get_formatting_list(wiki_path, recursive=False):
"""
    Get the list of file names from the extracted content.
"""
filenames = []
for dirname in glob.glob(os.path.join(wiki_path, '*'), recursive=False):
for filename in glob.glob(os.path.join(dirname, 'wiki_*'), recursive=recursive):
filenames.append(filename)
return filenames


def merge(x):
"""
Puts one article per line
"""
file_list, output_filename = x
article_lines = []
article_open = False

with open(output_filename, mode='w', newline='\n') as ofile:
for filename in file_list:
with open(filename, mode='r', newline='\n') as file:
for line in file:
if '<doc id=' in line:
article_open = True
elif '</doc>' in line:
article_open = False
for oline in article_lines[1:]:
if oline != '\n':
ofile.write(oline.rstrip() + " ")
ofile.write("\n\n")
article_lines = []
else:
if article_open:
article_lines.append(line)


@DATA_PARSER_REGISTRY.register('prepare_wikipedia')
def get_parser():
parser = argparse.ArgumentParser(description='Download and Prepare the Wikipedia')
parser.add_argument('--mode', type=str,
default='download+format',
choices=['download', 'format', 'download+format'],
choices=['download', 'format', 'download+format', 'download_prepared'],
help='Specify the action you want the app to take. '
'"download" means to download the Wikipedia dump. '
'"format" means to extract the content and '
'format it for pretraining. "download+format" means to combine '
'these two options')
'these two options'
'"download_prepared" downloads the prepared txt from S3 directly')
parser.add_argument('--lang', type=str, default='en',
help='Language of the wikipedia dump file.'
'We only support English and Chinese for current version')
@@ -124,8 +144,13 @@ def get_parser():
parser.add_argument("-o", "--output", default="wikicorpus",
help="directory for downloaded or formatted files")
parser.add_argument("-b", "--bytes", default="100M",
help="maximum bytes per output file (default %(default)s)",
help="maximum bytes per extracted file (default %(default)s)",
metavar="n[KMG]")
parser.add_argument("--num_process", type=int, default=8,
help="number of processes for multiprocessing")
parser.add_argument("--num_out_files", type=int, default=1000,
help="Number of desired output files, where each is processed"
" independently by a worker.")
return parser


@@ -145,32 +170,75 @@ def download_wikicorpus(lang, date, output):
return output_file


def format_wikicorpus(input, output, bytes):
def format_wikicorpus(input, output, bytes, num_process, num_out_files):
if input is None:
raise ValueError('input file is empty.')
if not input.endswith('xml.bz2'):
raise ValueError('input file not *.xml.bz2.')
if not os.path.exists(output):
os.makedirs(output)

# Use WikiExtractor to extract the content
WikiExtractor = try_import_wikiextractor()
wiki_path = os.path.join(output, 'extracted')
sys.argv = ['prog', '-b', bytes, '-o', wiki_path, input]
WikiExtractor.main()
output_filename = os.path.join(output, 'wikicorpus_one_article_per_line.txt')
wiki_formatter = WikicorpusTextFormatting(wiki_path, output_filename, recursive=True)
wiki_formatter.merge()

# Merge extracted content into txt files
prepared_path = os.path.join(output, 'prepared_wikipedia')
if not os.path.exists(prepared_path):
os.makedirs(prepared_path)
filenames = get_formatting_list(wiki_path, recursive=True)
num_files = len(filenames)
num_out_files = min(num_out_files, num_files)
file_volume = math.ceil(num_files / num_out_files)
splited_files = [filenames[i: i + file_volume] for i in range(0, num_files, file_volume)]
num_out_files = len(splited_files)
output_files = [
os.path.join(
prepared_path,
"wikipedia-prepared-{}.txt".format(
str(i).zfill(4))) for i in range(num_out_files)]
print("All prepared raw text will be saved in {} txt files".format(num_out_files))
num_process = min(num_process, num_out_files)
print('Start preprocessing {} text files with {} cores'.format(num_files, num_process))
process_args = [(splited_files[i], output_files[i]) for i in range(num_out_files)]

start_time = time.time()
with multiprocessing.Pool(num_process) as pool:
f_read = 0
for i, _ in enumerate(pool.imap(merge, process_args)):
elapsed = time.time() - start_time
f_read += len(splited_files[i])
print("prepared {:} files, Elapsed: {:.2f}s, ETA: {:.2f}s, ".format(
f_read, elapsed, (num_files - f_read) / (num_files / elapsed)))
print("Done preparation within {:.2f} seconds".format(elapsed))


@DATA_MAIN_REGISTRY.register('prepare_wikipedia')
def main(args):
num_process = min(multiprocessing.cpu_count(), args.num_process)
if args.mode == 'download':
download_wikicorpus(args.lang, args.date, args.output)
elif args.mode == 'format':
format_wikicorpus(args.input, args.output, args.bytes)
format_wikicorpus(args.input, args.output, args.bytes, num_process, args.num_out_files)
elif args.mode == 'download+format':
downloaded_file = download_wikicorpus(args.lang, args.date, args.output)
format_wikicorpus(downloaded_file, args.output, args.bytes)
format_wikicorpus(downloaded_file, args.output, args.bytes, num_process, args.num_out_files)
elif args.mode == 'download_prepared':
url = _URLS['wikipedia-en-20200620']
file_hash = _URL_FILE_STATS[url]
target_download_location = os.path.join(args.output,
os.path.basename(url))
download(url, target_download_location, sha1_hash=file_hash)
tar = tarfile.open(target_download_location)
names = tar.getnames()
print('Start unarchiving raw text files')
start_time = time.time()
for name in names:
tar.extract(name, path=args.output)
tar.close()
print("Done unarchiving within {:.2f} seconds".format(time.time() - start_time))
else:
raise NotImplementedError

1 change: 1 addition & 0 deletions scripts/datasets/url_checksums/wikipedia.txt
@@ -0,0 +1 @@
https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/pretrain_corpus/wikipedia-en-20200620.tar.gz 1e1d77c31622744aaa45ff5bfbfca397154d9186 5068070627
5 changes: 1 addition & 4 deletions scripts/machine_translation/train_transformer.py
@@ -367,7 +367,7 @@ def train(args):
seed=args.seed)
else:
raise NotImplementedError

batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack())
train_data_loader = gluon.data.DataLoader(data_train,
batch_sampler=train_batch_sampler,
@@ -387,7 +387,6 @@ def train(args):
log_start_time = time.time()
num_params, num_fixed_params = None, None
# TODO(sxjscience) Add a log metric class

accum_count = 0
loss_denom = 0
n_train_iters = 0
@@ -471,12 +470,10 @@
deduplicate=True)
if args.max_update > 0 and n_train_iters >= args.max_update:
break

if args.epochs > 0:
model.save_parameters(os.path.join(args.save_dir,
'epoch{:d}.params'.format(epoch_id)),
deduplicate=True)

avg_valid_loss = validation(model, val_data_loader, ctx_l)
logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'
.format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))