This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[Bug Fix] trainer.update(1) should be used after loss.mean() is called #1000

Open: wants to merge 49 commits into base v0.x.

Changes shown below are from 1 commit.

Commits (49):
de7b23d  clean slate for 1.x (szha, Mar 18, 2020)
01122db  [Numpy] Numpy version of GluonNLP (#1225) (sxjscience, Jun 10, 2020)
982a416  Fix bert cfg (#1245) (zheyuye, Jun 11, 2020)
789e2b9  fix download (sxjscience, Jun 11, 2020)
b714eac  [Numpy] Try to fix the CI (#1248) (sxjscience, Jun 11, 2020)
85b6f09  [Numpy] Add "match_tokens_with_char_spans" + Enable downloading from … (sxjscience, Jun 16, 2020)
ee1f0e3  [Numpy] Update QA Dataset and revise run_squad (#1250) (zheyuye, Jun 18, 2020)
e06ff01  Pin mxnet version range on CI (#1257) (leezu, Jul 7, 2020)
689eba9  [CI] AWS batch job tool for GluonNLP (Part I) (#1251) (szha, Jul 7, 2020)
cd48efd  Update codecov action to handle different OS and Python versions (#1254) (leezu, Jul 8, 2020)
83e1f13  Use Amazon S3 Transfer Acceleration (#1260) (leezu, Jul 10, 2020)
a646c34  [FEATURE] update backtranslation and add multinomial sampler (#1259) (hutao965, Jul 11, 2020)
ea9152b  Fixes to make the CI more stable (#1265) (sxjscience, Jul 16, 2020)
70a1887  Update for Block API (#1261) (leezu, Jul 17, 2020)
9d83fe6  Fix parameter share regex (#1267) (leezu, Jul 17, 2020)
4743afc  Add fp16 support for Bert QA inference (#1264) (MoisesHer, Jul 17, 2020)
e78a24e  [CI] update batch to gluonnlp-dev (#1268) (szha, Jul 18, 2020)
3a0ed9f  [Numpy] Refactor Roberta (#1269) (zheyuye, Jul 21, 2020)
f407b8e  [CI] Batch cpu version (#1275) (szha, Jul 22, 2020)
57eb411  [Numpy] Fix conversion toolkits (#1274) (zheyuye, Jul 23, 2020)
74bd2ce  [Feature] Add FP16 inference support to NMT + Add BoundedBudgetSample… (hutao965, Jul 23, 2020)
d76897b  Add embedding related methods in numpy version (#1263) (acphile, Jul 28, 2020)
4d43f82  add subversion/wget to docker, add readme (#1279) (szha, Jul 28, 2020)
3c87457  Add layout + compute_layout support: TransformerNMT, BERT, ALBERT, EL… (sxjscience, Jul 29, 2020)
033214e  [Numpy] Fix SQuAD + Fix GLUE downloading (#1280) (sxjscience, Jul 29, 2020)
2294421  [Numpy Refactor] BART (#1282) (zheyuye, Jul 30, 2020)
1f9ad44  Horovod support for pretraining and fune-tuning squad (#1276) (zheyuye, Aug 1, 2020)
7e1f9d0  [DOC] Add the basic documentation for the embedding API (#1281) (acphile, Aug 4, 2020)
20af58f  Fix gelu (#1287) (zheyuye, Aug 5, 2020)
ded0f99  fix prepare_openwebtext (#1289) (ZiyueHuang, Aug 6, 2020)
c33e62e  [FEATURE]Horovod support for training transformer + add mirror data f… (hutao965, Aug 7, 2020)
9e268c0  Fix electra (#1291) (zheyuye, Aug 8, 2020)
32e87d4  [Numpy] Benchmark the backbone models + Some fixes + Always use pytho… (sxjscience, Aug 14, 2020)
6ae558e  [FEATURE]Horovod support for training transformer (PART 2) (#1301) (hutao965, Aug 20, 2020)
d8b68c6  [Numpy] Fix AWS Batch + Add Docker Support (#1302) (sxjscience, Aug 20, 2020)
d17ec4c  minor fix for run_electra.py & remove hybridization in the constructi… (ZiyueHuang, Aug 22, 2020)
99b35d8  Add Intro for batch + upload squad traininng command (#1305) (zheyuye, Aug 22, 2020)
d93356f  [MODEL] make beam search a hybrid block (#1310) (szha, Aug 23, 2020)
210dd0c  [Numpy] [Fix] Update README.md (#1306) (sxjscience, Aug 23, 2020)
b324ee6  [CI] Add GPU pytest + Append AWS Batch job submission to current pipe… (barry-jin, Aug 24, 2020)
3b14d69  [CI] Update unittests-gpu (#1313) (barry-jin, Aug 24, 2020)
dca17ee  automatically generate date suffix for dev versions (#1314) (szha, Aug 25, 2020)
39ec921  fix typo (#1317) (liuzh47, Aug 26, 2020)
970318d  fix typo (#1318) (liuzh47, Aug 26, 2020)
bba8697  [CI] Update GPU Test Workflow + Update Some Tests and README (#1316) (barry-jin, Aug 28, 2020)
66e5e05  fix https://github.com/dmlc/gluon-nlp/issues/1315 (#1319) (ZiyueHuang, Aug 28, 2020)
ff95fb4  [CI] Fix Source Reference Issues (#1332) (barry-jin, Sep 1, 2020)
1bd85b6  [BUGFIX] fix valid candidates issue (#1323) (liuzh47, Sep 1, 2020)
189bbdc  [MODEL] convert gpt2 model (#1328) (hutao965, Sep 1, 2020)
[FEATURE]Horovod support for training transformer (PART 2) (#1301)
* set default shuffle=True for boundedbudgetsampler

* fix

* fix log condition

* use horovod to train transformer

* fix

* add mirror wmt dataset

* fix

* rename wmt.txt to wmt.json and remove part of urls

* fix

* tuning params

* use get_repo_url()

* update average checkpoint cli

* paste result of transformer large

* fix

* fix logging in train_transformer

* fix

* fix

* fix

* add transformer base config

* fix

* change to wmt14/full

* print more sacrebleu info

* fix

* add test for num_parts and update behavior of boundedbudgetsampler with even_size

* fix

* fix

* fix

* fix logging when using horovod

* update doc of train transformer

* add test case for fail downloading

* add a ShardedIterator

* fix

* fix

* fix

* change mpirun to horovodrun

* make the horovod command complete

* use print(sampler) to cover the code of the __repr__ func

* empty commit

* add test case test_sharded_iterator_even_size

Co-authored-by: Hu <[email protected]>
hutao965 and Hu authored Aug 20, 2020

commit 6ae558e1932ceaecf54409fa484c8eeede241d98
5 changes: 3 additions & 2 deletions scripts/datasets/machine_translation/wmt2014_ende.sh
@@ -12,8 +12,8 @@ nlp_data prepare_wmt \
# We use sacrebleu to fetch the dev set (newstest2013) and test set (newstest2014)
sacrebleu -t wmt13 -l ${SRC}-${TGT} --echo src > ${SAVE_PATH}/dev.raw.${SRC}
sacrebleu -t wmt13 -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/dev.raw.${TGT}
- sacrebleu -t wmt14 -l ${SRC}-${TGT} --echo src > ${SAVE_PATH}/test.raw.${SRC}
- sacrebleu -t wmt14 -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/test.raw.${TGT}
+ sacrebleu -t wmt14/full -l ${SRC}-${TGT} --echo src > ${SAVE_PATH}/test.raw.${SRC}
+ sacrebleu -t wmt14/full -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/test.raw.${TGT}


# Clean and tokenize the training + dev corpus
@@ -34,6 +34,7 @@ nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \
--tgt-corpus dev.raw.${TGT} \
--min-num-words 1 \
--max-num-words 100 \
+ --max-ratio 1.5 \
--src-save-path dev.tok.${SRC} \
--tgt-save-path dev.tok.${TGT}
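As a side note on the new `--max-ratio 1.5` flag: it presumably bounds the source/target length ratio during cleaning, which is the conventional meaning of such a flag in parallel-corpus filtering. A rough Python sketch of that idea follows; it is not the actual `nlp_preprocess` implementation, and the file names and thresholds are illustrative only.

```python
# Rough sketch of a length-ratio filter like the one --max-ratio 1.5 presumably enables.
# NOT the nlp_preprocess implementation; file names and thresholds are illustrative.

def keep_pair(src_line: str, tgt_line: str,
              min_num_words: int = 1, max_num_words: int = 100,
              max_ratio: float = 1.5) -> bool:
    """Return True if the sentence pair passes the cleaning thresholds."""
    src_len = len(src_line.split())
    tgt_len = len(tgt_line.split())
    if not (min_num_words <= src_len <= max_num_words):
        return False
    if not (min_num_words <= tgt_len <= max_num_words):
        return False
    # drop pairs whose longer side is more than max_ratio times the shorter side
    return max(src_len, tgt_len) / max(1, min(src_len, tgt_len)) <= max_ratio

with open('dev.raw.en', encoding='utf-8') as f_src, open('dev.raw.de', encoding='utf-8') as f_tgt:
    kept = [(s, t) for s, t in zip(f_src, f_tgt) if keep_pair(s, t)]
print('kept {} sentence pairs'.format(len(kept)))
```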

50 changes: 38 additions & 12 deletions scripts/machine_translation/README.md
@@ -30,9 +30,36 @@ python3 train_transformer.py \
--save_dir transformer_base_wmt2014_en_de_${SUBWORD_ALGO} \
--cfg transformer_base \
--lr 0.002 \
- --batch_size 2700 \
- --num_averages 5 \
- --warmup_steps 4000 \
+ --sampler BoundedBudgetSampler \
+ --max_num_tokens 2700 \
+ --max_update 15000 \
+ --save_interval_update 500 \
+ --warmup_steps 6000 \
+ --warmup_init_lr 0.0 \
--seed 123 \
--gpus 0,1,2,3
```

Or training via horovod
```
horovodrun -np 4 -H localhost:4 python3 train_transformer.py \
--comm_backend horovod \
--train_src_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${SRC} \
--train_tgt_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${TGT} \
--dev_src_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${SRC} \
--dev_tgt_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${TGT} \
--src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
--src_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
--tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
--tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
--save_dir transformer_base_wmt2014_en_de_${SUBWORD_ALGO} \
--cfg transformer_base \
--lr 0.002 \
--sampler BoundedBudgetSampler \
--max_num_tokens 2700 \
--max_update 15000 \
--save_interval_update 500 \
--warmup_steps 6000 \
--warmup_init_lr 0.0 \
--seed 123 \
--gpus 0,1,2,3
@@ -42,18 +69,16 @@ Use the average_checkpoint cli to average the last 10 checkpoints

```bash
gluon_average_checkpoint --checkpoints transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch*.params \
- --begin 21 \
- --end 30 \
- --save-path transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/avg_21_30.params
+ --begin 30 \
+ --end 39 \
+ --save-path transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch_avg_30_39.params
```


Use the following command to inference/evaluate the Transformer model:

```bash
SUBWORD_MODEL=yttm
python3 evaluate_transformer.py \
- --param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/average_21_30.params \
+ --param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/epoch_avg_30_39.params \
--src_lang en \
--tgt_lang de \
--cfg transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/config.yml \
@@ -110,7 +135,6 @@ gluon_average_checkpoint --checkpoints transformer_big_wmt2014_en_de_${SUBWORD_A
Use the following command to inference/evaluate the Transformer model:

```bash
SUBWORD_MODEL=yttm
python3 evaluate_transformer.py \
--param_path transformer_big_wmt2014_en_de_${SUBWORD_MODEL}/average_21_30.params \
--src_lang en \
@@ -131,16 +155,18 @@ Test BLEU score with 3 seeds (evaluated via sacre BLEU):

- transformer_base

(test bleu / valid bleu)
| Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | Mean±std |
|---------------|------------|-------------|-------------|--------------|-------------|
- | yttm | | - | - | - | - |
+ | yttm | | 26.50/26.29 | - | - | - |
| hf_bpe | | - | - | - | - |
| spm | | - | - | - | - |

- transformer_wmt_en_de_big

(test bleu / valid bleu)
| Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | Mean±std |
|---------------|------------|-------------|-------------|--------------|-------------|
- | yttm | | 27.99 | - | - | - |
+ | yttm | | 27.93/26.82 | - | - | - |
| hf_bpe | | - | - | - | - |
| spm | | - | - | - | - |
13 changes: 10 additions & 3 deletions scripts/machine_translation/evaluate_transformer.py
@@ -247,10 +247,17 @@ def evaluate(args):
of.write('\n'.join(pred_sentences))
of.write('\n')

- sacrebleu_out = sacrebleu.corpus_bleu(sys_stream=pred_sentences, ref_streams=[all_tgt_lines])
- logging.info('Time Spent: {}, #Sent={}, SacreBlEU={} Avg NLL={}, Perplexity={}'
+ sacrebleu_out = sacrebleu.corpus_bleu(sys_stream=pred_sentences, ref_streams=[all_tgt_lines])
+ logging.info('Time Spent: {}, #Sent={}, SacreBlEU={} '
+ '({:2.1f} {:2.1f} {:2.1f} {:2.1f}) '
+ '(BP={:.3f}, ratio={:.3f}, syslen={}, reflen={}), '
+ 'Avg NLL={}, Perplexity={}'
.format(end_eval_time - start_eval_time, len(all_tgt_lines),
- sacrebleu_out.score, avg_nll_loss, np.exp(avg_nll_loss)))
+ sacrebleu_out.score,
+ *sacrebleu_out.precisions,
+ sacrebleu_out.bp, sacrebleu_out.sys_len / sacrebleu_out.ref_len,
+ sacrebleu_out.sys_len, sacrebleu_out.ref_len,
+ avg_nll_loss, np.exp(avg_nll_loss)))
# inference only
else:
with open(os.path.join(args.save_dir, 'pred_sentences.txt'), 'w', encoding='utf-8') as of:
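For reference, here is a minimal standalone sketch of the richer SacreBLEU report produced by the logging change above. It assumes the sacrebleu 1.x-style `corpus_bleu(sys_stream=..., ref_streams=...)` API that the script already uses; the toy hypotheses and references are illustrative only.

```python
# Standalone sketch of the extended SacreBLEU reporting; toy data, sacrebleu 1.x-style API assumed.
import sacrebleu

pred_sentences = ['the cat sits on the mat .', 'hello world !']  # illustrative hypotheses
all_tgt_lines = ['the cat sat on the mat .', 'hello world !']    # illustrative references

sacrebleu_out = sacrebleu.corpus_bleu(sys_stream=pred_sentences, ref_streams=[all_tgt_lines])
print('SacreBLEU={:.2f} ({:2.1f} {:2.1f} {:2.1f} {:2.1f}) '
      '(BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})'
      .format(sacrebleu_out.score,
              *sacrebleu_out.precisions,                      # 1- to 4-gram precisions
              sacrebleu_out.bp,                               # brevity penalty
              sacrebleu_out.sys_len / sacrebleu_out.ref_len,  # hypothesis/reference length ratio
              sacrebleu_out.sys_len, sacrebleu_out.ref_len))
```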
45 changes: 27 additions & 18 deletions scripts/machine_translation/train_transformer.py
@@ -50,7 +50,8 @@
LinearWidthBucket,
ExpWidthBucket,
FixedBucketSampler,
- BoundedBudgetSampler
+ BoundedBudgetSampler,
+ ShardedIterator
)
import gluonnlp.data.batchify as bf
from gluonnlp.data import Vocab
@@ -179,6 +180,7 @@ def parse_args():
logging.info(args)
return args


def validation(model, data_loader, ctx_l):
"""Validate the model on the dataset

@@ -231,14 +233,16 @@ def load_dataset_with_cache(src_corpus_path: str,
tgt_corpus_path: str,
src_tokenizer: BaseTokenizerWithVocab,
tgt_tokenizer: BaseTokenizerWithVocab,
- overwrite_cache: bool):
+ overwrite_cache: bool,
+ local_rank: int):
# TODO online h5py multi processing encode (Tao)
src_md5sum = md5sum(src_corpus_path)
tgt_md5sum = md5sum(tgt_corpus_path)
cache_filepath = os.path.join(CACHE_PATH,
'{}_{}.cache.npz'.format(src_md5sum[:6], tgt_md5sum[:6]))
if os.path.exists(cache_filepath) and not overwrite_cache:
- logging.info('Load cache from {}'.format(cache_filepath))
+ if local_rank == 0:
+     logging.info('Load cache from {}'.format(cache_filepath))
npz_data = np.load(cache_filepath, allow_pickle=True)
src_data, tgt_data = npz_data['src_data'][:], npz_data['tgt_data'][:]
else:
@@ -288,7 +292,7 @@ def create_tokenizer(tokenizer_type, model_path, vocab_path):


def train(args):
- store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm(
+ _, num_parts, rank, local_rank, _, ctx_l = init_comm(
args.comm_backend, args.gpus)
src_tokenizer = create_tokenizer(args.src_tokenizer,
args.src_subword_model_path,
@@ -302,12 +306,14 @@ def train(args):
args.train_tgt_corpus,
src_tokenizer,
tgt_tokenizer,
- args.overwrite_cache)
+ args.overwrite_cache,
+ local_rank)
dev_src_data, dev_tgt_data = load_dataset_with_cache(args.dev_src_corpus,
args.dev_tgt_corpus,
src_tokenizer,
tgt_tokenizer,
- args.overwrite_cache)
+ args.overwrite_cache,
+ local_rank)
data_train = gluon.data.SimpleDataset(
[(src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i)
for i, (src_tokens, tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data))])
@@ -363,9 +369,9 @@ def train(args):
train_batch_sampler = BoundedBudgetSampler(lengths=[(ele[2], ele[3]) for ele in data_train],
max_num_tokens=args.max_num_tokens,
max_num_sentences=args.max_num_sentences,
- seed=args.seed,
- num_parts=num_parts,
- part_index=rank)
+ seed=args.seed)
+ if num_parts > 1:
+     train_batch_sampler = ShardedIterator(train_batch_sampler, num_parts=num_parts, part_index=rank)
elif args.sampler == 'FixedBucketSampler':
if args.comm_backend == 'horovod':
raise NotImplementedError('FixedBucketSampler does not support horovod at present')
@@ -390,8 +396,7 @@ def train(args):
else:
raise NotImplementedError

- if local_rank == 0:
-     logging.info(train_batch_sampler)
+ logging.info(train_batch_sampler)

batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack())
train_data_loader = gluon.data.DataLoader(data_train,
@@ -483,27 +488,31 @@ def train(args):
log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy()
logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
'throughput={:.2f}K wps, wc={:.2f}K, LR={}'
- .format(epoch_id, processed_batch_num * num_parts, len(train_data_loader),
- log_avg_loss, np.exp(log_avg_loss),
+ .format(epoch_id, processed_batch_num * num_parts,
+ len(train_data_loader), log_avg_loss, np.exp(log_avg_loss),
wps / 1000, log_wc / 1000, trainer.learning_rate))
log_start_time = time.time()
log_avg_loss = 0
log_loss_denom = 0
log_wc = 0
if local_rank == 0 and \
(args.max_update > 0 and n_train_iters % args.save_interval_update == 0):
+ n_update = n_train_iters // args.save_interval_update
model.save_parameters(os.path.join(args.save_dir,
- 'update{:d}.params'.format(n_train_iters // args.save_interval_update)),
+ 'update{:d}.params'.format(n_update)),
deduplicate=True)
+ avg_valid_loss = validation(model, val_data_loader, ctx_l)
+ logging.info('[Update {}] validation loss/ppl={:.4f}/{:.4f}'
+ .format(n_update, avg_valid_loss, np.exp(avg_valid_loss)))
if args.max_update > 0 and n_train_iters >= args.max_update:
break
- if local_rank == 0 and args.epochs > 0:
+ if local_rank == 0:
model.save_parameters(os.path.join(args.save_dir,
'epoch{:d}.params'.format(epoch_id)),
deduplicate=True)
- avg_valid_loss = validation(model, val_data_loader, ctx_l)
- logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'
- .format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))
+ avg_valid_loss = validation(model, val_data_loader, ctx_l)
+ logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'
+ .format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))

if args.max_update > 0 and n_train_iters >= args.max_update:
break
2 changes: 1 addition & 1 deletion scripts/question_answering/README.md
@@ -74,7 +74,7 @@ We could speed up multi-GPU training via horovod.
Compared to KVStore, training RoBERTa Large model on SQuAD 2.0 with 3 epochs will save roughly 1/4 training resources (8.48 vs 11.32 hours). Results may vary depending on the training instances.

```bash
- mpirun -np 4 -H localhost:4 python3 run_squad.py \
+ horovodrun -np 4 -H localhost:4 python3 run_squad.py \
--comm_backend horovod \
...
```
89 changes: 71 additions & 18 deletions src/gluonnlp/data/sampler.py
@@ -23,7 +23,6 @@
import math
import random
import warnings
- import random
import numpy as np
import abc
from typing import Union, Sequence, Optional, List
@@ -285,20 +284,14 @@ class BoundedBudgetSampler(BaseSampler):
Whether to shuffle the batches.
seed
The seed of the sampler
- num_parts
- Number of partitions which the data is split into (default: 1)
- part_index
- The index of the part to read from
"""
def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]],
max_num_tokens: int = -1, max_num_sentences: int = -1,
required_batch_size_multiple: int = 1,
- shuffle: bool = True, seed: Optional[int] = None,
- num_parts: int = 1, part_index: int = 0):
+ shuffle: bool = True, seed: Optional[int] = None):
assert len(lengths) > 0, 'BoundedBudgetSampler does not support empty lengths.'
assert max_num_tokens > 0 or max_num_sentences > 0, \
'One of max_num_tokens and max_num_sentences must be larger than 0'
- assert part_index < num_parts, 'part_index should be less than num_parts'
self._lengths = np.array(lengths)
if self._lengths.ndim == 2:
self._lengths = self._lengths.max(axis=1)
@@ -308,8 +301,6 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]],
self._batches = []
self._shuffle = shuffle
self._rng = np.random.RandomState(seed)
- self._num_parts = num_parts
- self._part_index = part_index
# sort
self._indices = self._indices[np.argsort(self._lengths, kind='mergesort')]
batch = []
@@ -335,16 +326,12 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]],
)
batch.append(index)
if len(batch) > 0:
- self._batches.append(np.array(batch))
+ self._batches.append(np.array(batch))

def __iter__(self):
if self._shuffle:
self._rng.shuffle(self._batches)
- part_batches = []
- for i in range(len(self._batches)):
-     if i % self._num_parts == self._part_index:
-         part_batches.append(self._batches[i])
- for batch in part_batches:
+ for batch in self._batches:
yield batch

def __len__(self):
@@ -353,7 +340,7 @@ def __len__(self):
def __repr__(self):
ret = '{name}(\n' \
' sample_num={sample_num},\n' \
- ' batch_num={batch_num}\n'\
+ ' batch_num={batch_num},\n' \
')'\
.format(name=self.__class__.__name__,
sample_num=len(self._lengths),
@@ -671,3 +658,69 @@ def __iter__(self):

def __len__(self):
return self._len * self._repeat


class ShardedIterator(BaseSampler):
r"""A sharded wrapper around an iterable (padded to length).
Parameters
----------
sampler
num_parts
Number of partitions which the data is split into (default: 1)
part_index
The index of the part to read from
even_size
If the number of batches is not even across all partitions, sample a few extra batches
for the ones with fewer batches.
"""
def __init__(self, sampler: BaseSampler,
num_parts: int = 1,
part_index: int = 0,
even_size: bool = False):
assert part_index < num_parts, 'part_index should be less than num_parts'
self._sampler = sampler
self._num_parts = num_parts
self._part_index = part_index
self._even_size = even_size

length = len(sampler)
if not even_size:
part_len = length // num_parts
remaining = length % num_parts
self._start = part_len * part_index + min(part_index, remaining)
self._end = self._start + part_len + (part_index < remaining)
self._part_len = self._end - self._start
else:
part_len = int(length + num_parts - 1) // num_parts
self._start = part_len * part_index
self._end = self._start + part_len
self._start = self._start if self._start < length else length
self._end = self._end if self._end < length else length
self._part_len = part_len

def __iter__(self):
batches = list(self._sampler)
part_batches = batches[self._start:self._end]
if self._even_size and len(part_batches) < self._part_len:
candidates = random.sample(batches, k=self._part_len-len(part_batches))
part_batches.extend(candidates)
for batch in part_batches:
yield batch

def __len__(self):
return len(self._sampler)

def __repr__(self):
ret = '{name}(\n' \
' batch_num={batch_num},\n' \
' part_batch_num={part_batch_num},\n' \
' num_parts={num_parts},\n' \
' part_index={part_index},\n' \
')'\
.format(name=self.__class__.__name__,
batch_num=len(self._sampler),
part_batch_num=self._part_len,
num_parts=self._num_parts,
part_index=self._part_index)
return ret
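For context, here is a minimal usage sketch of the new ShardedIterator wrapping a BoundedBudgetSampler, mirroring the train_transformer.py change above. The toy lengths, the 4-way split, and the token budget are illustrative only, and the import path is assumed from the file location; in real training each Horovod worker would build just its own part_index.

```python
# Usage sketch of ShardedIterator + BoundedBudgetSampler; toy data and a 4-way split for illustration.
import numpy as np
from gluonnlp.data.sampler import BoundedBudgetSampler, ShardedIterator

# (src_len, tgt_len) pairs for 1000 fake sentence pairs
lengths = [(np.random.randint(10, 100), np.random.randint(10, 100)) for _ in range(1000)]

num_parts = 4  # e.g. one part per Horovod worker
for rank in range(num_parts):  # a real worker would only build its own part_index
    sampler = BoundedBudgetSampler(lengths=lengths, max_num_tokens=2700,
                                   shuffle=True, seed=100)
    sharded = ShardedIterator(sampler, num_parts=num_parts, part_index=rank,
                              even_size=True)  # pad so every worker gets the same number of batches
    print('rank {}: {} batches'.format(rank, sum(1 for _ in sharded)))
```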
56 changes: 56 additions & 0 deletions tests/test_data_sampler.py
@@ -34,6 +34,7 @@ def test_fixed_bucket_sampler(seq_lengths, ratio, shuffle, num_buckets, bucket_s
ratio=ratio, shuffle=shuffle,
use_average_length=use_average_length,
bucket_scheme=bucket_scheme)
# here we print sampler to cover the __repr__ func of the sampler
print(sampler)
total_sampled_ids = []
for batch_sample_ids in sampler:
@@ -147,3 +148,58 @@ def test_bounded_budget_sampler(seq_lengths, max_num_tokens, max_num_sentences,
total_sampled_ids.extend(batch_sample_ids)
assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N
assert sorted(total_sampled_ids) == list(range(len(total_sampled_ids)))


@pytest.mark.parametrize('seq_lengths', [[np.random.randint(10, 100) for _ in range(N)],
[(np.random.randint(10, 100), np.random.randint(10, 100)) for _ in range(N)]])
@pytest.mark.parametrize('max_num_tokens', [200, 500])
@pytest.mark.parametrize('max_num_sentences', [-1, 5])
@pytest.mark.parametrize('required_batch_size_multiple', [1, 5])
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('num_parts', [1, 4])
@pytest.mark.parametrize('even_size', [False])
def test_sharded_iterator(seq_lengths, max_num_tokens, max_num_sentences,
required_batch_size_multiple, shuffle,
num_parts, even_size):
total_sampled_ids = []
for part_index in range(num_parts):
# we use independent (but same) sampler to simulate multi process situation
sampler = s.BoundedBudgetSampler(seq_lengths, max_num_tokens, max_num_sentences,
required_batch_size_multiple, shuffle, seed=100)
sharded_iter = s.ShardedIterator(sampler, num_parts, part_index, even_size)
print(sharded_iter)
for batch_sample_ids in sharded_iter:
total_sampled_ids.extend(batch_sample_ids)
assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N
assert sorted(total_sampled_ids) == list(range(len(total_sampled_ids)))


@pytest.mark.parametrize('seq_lengths', [[np.random.randint(10, 100) for _ in range(N)],
[(np.random.randint(10, 100), np.random.randint(10, 100)) for _ in range(N)]])
@pytest.mark.parametrize('max_num_tokens', [200, 500])
@pytest.mark.parametrize('max_num_sentences', [-1, 5])
@pytest.mark.parametrize('required_batch_size_multiple', [1, 5])
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('num_parts', [1, 4])
@pytest.mark.parametrize('even_size', [True])
def test_sharded_iterator_even_size(seq_lengths, max_num_tokens, max_num_sentences,
required_batch_size_multiple, shuffle,
num_parts, even_size):
total_sampled_ids = []
first_batch_num = None
for part_index in range(num_parts):
batch_num = 0
# we use independent (but same) sampler to simulate multi process situation
sampler = s.BoundedBudgetSampler(seq_lengths, max_num_tokens, max_num_sentences,
required_batch_size_multiple, shuffle, seed=100)
sharded_iter = s.ShardedIterator(sampler, num_parts, part_index, even_size)
print(sharded_iter)
for batch_sample_ids in sharded_iter:
total_sampled_ids.extend(batch_sample_ids)
batch_num += 1
# assert batch num of each parts equals
if first_batch_num is None:
first_batch_num = batch_num
else:
assert first_batch_num == batch_num
assert len(set(total_sampled_ids)) == N
9 changes: 9 additions & 0 deletions tests/test_utils_misc.py
@@ -109,6 +109,15 @@ def test_download_https(overwrite):
overwrite=overwrite)


@pytest.mark.remote_required
@pytest.mark.parametrize('overwrite', [False, True])
def test_download_non_existing(overwrite):
with pytest.raises(RuntimeError, match='Failed downloading url'):
verify_download(url='https://commoncrawl.s3.amazonaws.com/crawl-data/CC-MAIN-2014-41/non_existing',
sha1_hash='foo',
overwrite=overwrite)


def test_logging_config():
logger = logging.getLogger(__name__)
with tempfile.TemporaryDirectory() as root: