Skip to content

Commit

Permalink
Fix FancyIterator behavior during evaluation (#51)
Browse files Browse the repository at this point in the history
* Added truncate option to FancyIterator

* Added checks + tests related to truncation in FancyIterator

* Raise ConfigurationError if batch size is too big
* Test FancyIterator handles truncation properly

* Updated configs to stop truncating data during validation

* Handle edge case in splitter + cleanup

* Disable truncation in kglm tests to ensure nothing weird happens on padded batches
  • Loading branch information
rloganiv authored Sep 19, 2019
1 parent 0d2332c commit dc08cf2
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 35 deletions.
29 changes: 22 additions & 7 deletions experiments/kglm-disc.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
"vocabulary": {
"type": "extended",
"extend": false,
"directory_path": "data/enhanced-wikitext-2/vocab"
"directory_path": "data/linked-wikitext-2/vocab"
},
"dataset_reader": {
"type": "enhanced-wikitext-kglm",
"alias_database_path": "data/enhanced-wikitext-2/alias.pkl",
"alias_database_path": "data/linked-wikitext-2/alias.pkl",
"mode": "discriminative"
},
"train_data_path": "data/enhanced-wikitext-2/train.jsonl",
"validation_data_path": "data/enhanced-wikitext-2/valid.jsonl",
"train_data_path": "data/linked-wikitext-2/train.jsonl",
"validation_data_path": "data/linked-wikitext-2/valid.jsonl",
"model": {
"type": "kglm-disc",
"token_embedder": {
Expand All @@ -26,7 +26,7 @@
"token_embedders": {
"entity_ids": {
"type": "embedding",
"pretrained_file": "data/enhanced-wikitext-2/embeddings.entities.txt",
"pretrained_file": "data/linked-wikitext-2/embeddings.entities.txt",
"embedding_dim": 256,
"trainable": false,
"vocab_namespace": "entity_ids"
Expand All @@ -37,14 +37,14 @@
"token_embedders": {
"relations": {
"type": "embedding",
"pretrained_file": "data/enhanced-wikitext-2/embeddings.relations.txt",
"pretrained_file": "data/linked-wikitext-2/embeddings.relations.txt",
"embedding_dim": 256,
"trainable": true,
"vocab_namespace": "relations"
}
}
},
"knowledge_graph_path": "data/enhanced-wikitext-2/knowledge_graph.pkl",
"knowledge_graph_path": "data/linked-wikitext-2/knowledge_graph.pkl",
"use_shortlist": false,
"hidden_size": 1150,
"num_layers": 3,
Expand All @@ -69,6 +69,21 @@
"shortlist_inds"
]
},
"validation_iterator": {
"type": "fancy",
"batch_size": 60,
"split_size": 70,
"splitting_keys": [
"source",
"mention_type",
"raw_entity_ids",
"entity_ids",
"parent_ids",
"relations",
"shortlist_inds"
],
"truncate": false
},
"trainer": {
"type": "lm",
"num_epochs": 500,
Expand Down
31 changes: 24 additions & 7 deletions experiments/kglm.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
"vocabulary": {
"type": "extended",
"extend": false,
"directory_path": "data/enhanced-wikitext-2/vocab"
"directory_path": "data/linked-wikitext-2/vocab"
},
"dataset_reader": {
"type": "enhanced-wikitext-kglm",
"alias_database_path": "data/enhanced-wikitext-2/alias.pkl"
"alias_database_path": "data/linked-wikitext-2/alias.pkl"
},
"train_data_path": "data/enhanced-wikitext-2/train.jsonl",
"validation_data_path": "data/enhanced-wikitext-2/valid.jsonl",
"train_data_path": "data/linked-wikitext-2/train.jsonl",
"validation_data_path": "data/linked-wikitext-2/valid.jsonl",
"model": {
"type": "kglm",
"token_embedder": {
Expand All @@ -25,7 +25,7 @@
"token_embedders": {
"entity_ids": {
"type": "embedding",
"pretrained_file": "data/enhanced-wikitext-2/embeddings.entities.txt",
"pretrained_file": "data/linked-wikitext-2/embeddings.entities.txt",
"embedding_dim": 256,
"trainable": false,
"vocab_namespace": "entity_ids"
Expand All @@ -36,7 +36,7 @@
"token_embedders": {
"relations": {
"type": "embedding",
"pretrained_file": "data/enhanced-wikitext-2/embeddings.relations.txt",
"pretrained_file": "data/linked-wikitext-2/embeddings.relations.txt",
"embedding_dim": 256,
"trainable": true,
"vocab_namespace": "relations"
Expand All @@ -48,7 +48,7 @@
"input_size": 400,
"hidden_size": 400
},
"knowledge_graph_path": "data/enhanced-wikitext-2/knowledge_graph.pkl",
"knowledge_graph_path": "data/linked-wikitext-2/knowledge_graph.pkl",
"use_shortlist": false,
"hidden_size": 1150,
"num_layers": 3,
Expand All @@ -75,6 +75,23 @@
"alias_copy_inds"
]
},
"validation_iterator": {
"type": "fancy",
"batch_size": 60,
"split_size": 70,
"splitting_keys": [
"source",
"target",
"mention_type",
"raw_entity_ids",
"entity_ids",
"parent_ids",
"relations",
"shortlist_inds",
"alias_copy_inds"
],
"truncate": false
},
"trainer": {
"type": "lm",
"num_epochs": 500,
Expand Down
45 changes: 28 additions & 17 deletions kglm/data/iterators/fancy_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import random
from typing import Deque, Dict, Iterable, Iterator, List, Tuple, Union

from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset import Batch
from allennlp.data.fields import Field, ListField, MetadataField, TextField
from allennlp.data.instance import Instance
Expand All @@ -23,45 +24,51 @@
class FancyIterator(DataIterator):
"""Fancy cause it's really expensive."""
def __init__(self,
splitting_keys: List[str],
batch_size: int,
split_size: int,
batch_size: int = 32,
splitting_keys: List[str],
truncate: bool = True,
instances_per_epoch: int = None,
max_instances_in_memory: int = None,
cache_instances: bool = False,
track_epoch: bool = False,
maximum_samples_per_batch: Tuple[str, int] = None) -> None:
super(FancyIterator, self).__init__(
batch_size=batch_size,
instances_per_epoch=instances_per_epoch,
max_instances_in_memory=max_instances_in_memory,
cache_instances=cache_instances,
track_epoch=track_epoch,
maximum_samples_per_batch=maximum_samples_per_batch)
batch_size=batch_size,
instances_per_epoch=instances_per_epoch,
max_instances_in_memory=max_instances_in_memory,
cache_instances=cache_instances,
track_epoch=track_epoch,
maximum_samples_per_batch=maximum_samples_per_batch)
self._splitting_keys = splitting_keys
self._split_size = split_size
self._eval = False

def eval(self):
self._eval = True
self._truncate = truncate

def __call__(self,
instances: Iterable[Instance],
num_epochs: int = None,
shuffle: bool = False) -> Iterator[TensorDict]:

key = id(instances)
starting_epoch = self._epochs[key]

# In order to ensure that we are (almost) constantly streaming data to the model we
# need to have all of the instances in memory ($$$)
instance_list = list(instances)

if (self._batch_size > len(instance_list)) and self._truncate:
raise ConfigurationError('FancyIterator will not return any data when the batch size '
'is larger than number of instances and truncation is enabled. '
'To fix this either use a smaller batch size (better for '
'training) or disable truncation (better for validation).')

if num_epochs is None:
epochs: Iterable[int] = itertools.count(starting_epoch)
else:
epochs = range(starting_epoch, starting_epoch + num_epochs)

for epoch in epochs:

# In order to ensure that we are (almost) constantly streaming data to the model we
# need to have all of the instances in memory ($$$)
instance_list = list(instances)
if shuffle:
random.shuffle(instance_list)

Expand Down Expand Up @@ -104,7 +111,11 @@ def __call__(self,
def _split(self, instance: Instance) -> Tuple[List[Instance], int]:
# Determine the size of the sequence inside the instance.
true_length = len(instance['source'])
padded_length = self._split_size * (true_length // self._split_size)
if (true_length % self._split_size) != 0:
offset = 1
else:
offset = 0
padded_length = self._split_size * (true_length // self._split_size + offset)

# Determine the split indices.
split_indices = list(range(0, true_length, self._split_size))
Expand Down Expand Up @@ -162,7 +173,7 @@ def _generate_batches(self,
except IndexError: # A queue is depleted
# If we're training, we break to avoid densely padded inputs (since this biases
# the model to overfit the longer sequences).
if not self._eval:
if self._truncate:
return
# But if we're evaluating we do want the padding, so that we don't skip anything.
else:
Expand Down
82 changes: 82 additions & 0 deletions kglm/tests/data/fancy_iterator_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# pylint: disable=no-self-use
from typing import List

from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase
from allennlp.data import Instance, Token, Vocabulary
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer

from kglm.data.iterators import FancyIterator


class FancyIteratorTest(AllenNlpTestCase):
def setUp(self):
super().setUp()
self.token_indexers = {"tokens": SingleIdTokenIndexer()}
self.vocab = Vocabulary()
self.this_index = self.vocab.add_token_to_namespace('this')
self.is_index = self.vocab.add_token_to_namespace('is')
self.a_index = self.vocab.add_token_to_namespace('a')
self.sentence_index = self.vocab.add_token_to_namespace('sentence')
self.another_index = self.vocab.add_token_to_namespace('another')
self.yet_index = self.vocab.add_token_to_namespace('yet')
self.very_index = self.vocab.add_token_to_namespace('very')
self.long_index = self.vocab.add_token_to_namespace('long')
instances = [
self.create_instance(["this", "is", "a", "sentence"]),
self.create_instance(["this", "is", "another", "sentence"]),
self.create_instance(["yet", "another", "sentence"]),
self.create_instance(["this", "is", "a", "very", "very", "very", "very", "long", "sentence"]),
self.create_instance(["sentence"]),
]

self.instances = instances

def create_instance(self, str_tokens: List[str]):
tokens = [Token(t) for t in str_tokens]
instance = Instance({'source': TextField(tokens, self.token_indexers)})
return instance

def test_truncate(self):
# Checks that the truncate parameter works as intended.

# Since split size is less than the length of the "very ... very long" sentence, the
# iterator should return one batch when the truncation is enabled.
split_size = 4
truncated_iterator = FancyIterator(batch_size=5,
split_size=split_size,
splitting_keys=['source'],
truncate=True)
truncated_iterator.index_with(self.vocab)
batches = list(truncated_iterator(self.instances, num_epochs=1))
assert len(batches) == 1

# When truncation is disabled the iterator should return 3 batches instead.
non_truncated_iterator = FancyIterator(batch_size=5,
split_size=split_size,
splitting_keys=['source'],
truncate=False)
non_truncated_iterator.index_with(self.vocab)
batches = list(non_truncated_iterator(self.instances, num_epochs=1))
assert len(batches) == 3

# When the batch size is larger than the number of instances, truncation will the iterator
# to return zero batches of data (since some of the instances in the batch would consist
# entirely of padding). Check that the iterator raises an error in this case.
invalid_iterator = FancyIterator(batch_size=6,
split_size=split_size,
splitting_keys=['source'],
truncate=True)
invalid_iterator.index_with(self.vocab)
with self.assertRaises(ConfigurationError):
batches = list(invalid_iterator(self.instances, num_epochs=1))

# If truncation is disabled then this should not cause an issue
valid_iterator = FancyIterator(batch_size=6,
split_size=split_size,
splitting_keys=['source'],
truncate=False)
valid_iterator.index_with(self.vocab)
batches = list(valid_iterator(self.instances, num_epochs=1))
assert len(batches) == 3
3 changes: 2 additions & 1 deletion kglm/tests/fixtures/training_config/kglm-disc.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@
"parent_ids",
"relations",
"shortlist_inds"
]
],
"truncate": false
},
"trainer": {
"type": "lm",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@
"parent_ids",
"relations",
"shortlist_inds"
]
],
"truncate": false
},
"trainer": {
"type": "lm",
Expand Down
3 changes: 2 additions & 1 deletion kglm/tests/fixtures/training_config/kglm.json
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@
"relations",
"shortlist_inds",
"alias_copy_inds"
]
],
"truncate": false
},
"trainer": {
"type": "lm",
Expand Down
3 changes: 2 additions & 1 deletion kglm/tests/fixtures/training_config/kglm.no-shortlist.json
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@
"relations",
"shortlist_inds",
"alias_copy_inds"
]
],
"truncate": false
},
"trainer": {
"type": "lm",
Expand Down

0 comments on commit dc08cf2

Please sign in to comment.