diff --git a/experiments/kglm-disc.jsonnet b/experiments/kglm-disc.jsonnet index 9ad5d6c..c87de9a 100644 --- a/experiments/kglm-disc.jsonnet +++ b/experiments/kglm-disc.jsonnet @@ -2,15 +2,15 @@ "vocabulary": { "type": "extended", "extend": false, - "directory_path": "data/enhanced-wikitext-2/vocab" + "directory_path": "data/linked-wikitext-2/vocab" }, "dataset_reader": { "type": "enhanced-wikitext-kglm", - "alias_database_path": "data/enhanced-wikitext-2/alias.pkl", + "alias_database_path": "data/linked-wikitext-2/alias.pkl", "mode": "discriminative" }, - "train_data_path": "data/enhanced-wikitext-2/train.jsonl", - "validation_data_path": "data/enhanced-wikitext-2/valid.jsonl", + "train_data_path": "data/linked-wikitext-2/train.jsonl", + "validation_data_path": "data/linked-wikitext-2/valid.jsonl", "model": { "type": "kglm-disc", "token_embedder": { @@ -26,7 +26,7 @@ "token_embedders": { "entity_ids": { "type": "embedding", - "pretrained_file": "data/enhanced-wikitext-2/embeddings.entities.txt", + "pretrained_file": "data/linked-wikitext-2/embeddings.entities.txt", "embedding_dim": 256, "trainable": false, "vocab_namespace": "entity_ids" @@ -37,14 +37,14 @@ "token_embedders": { "relations": { "type": "embedding", - "pretrained_file": "data/enhanced-wikitext-2/embeddings.relations.txt", + "pretrained_file": "data/linked-wikitext-2/embeddings.relations.txt", "embedding_dim": 256, "trainable": true, "vocab_namespace": "relations" } } }, - "knowledge_graph_path": "data/enhanced-wikitext-2/knowledge_graph.pkl", + "knowledge_graph_path": "data/linked-wikitext-2/knowledge_graph.pkl", "use_shortlist": false, "hidden_size": 1150, "num_layers": 3, @@ -69,6 +69,21 @@ "shortlist_inds" ] }, + "validation_iterator": { + "type": "fancy", + "batch_size": 60, + "split_size": 70, + "splitting_keys": [ + "source", + "mention_type", + "raw_entity_ids", + "entity_ids", + "parent_ids", + "relations", + "shortlist_inds" + ], + "truncate": false + }, "trainer": { "type": "lm", 
"num_epochs": 500, diff --git a/experiments/kglm.jsonnet b/experiments/kglm.jsonnet index 2346b54..490234c 100644 --- a/experiments/kglm.jsonnet +++ b/experiments/kglm.jsonnet @@ -2,14 +2,14 @@ "vocabulary": { "type": "extended", "extend": false, - "directory_path": "data/enhanced-wikitext-2/vocab" + "directory_path": "data/linked-wikitext-2/vocab" }, "dataset_reader": { "type": "enhanced-wikitext-kglm", - "alias_database_path": "data/enhanced-wikitext-2/alias.pkl" + "alias_database_path": "data/linked-wikitext-2/alias.pkl" }, - "train_data_path": "data/enhanced-wikitext-2/train.jsonl", - "validation_data_path": "data/enhanced-wikitext-2/valid.jsonl", + "train_data_path": "data/linked-wikitext-2/train.jsonl", + "validation_data_path": "data/linked-wikitext-2/valid.jsonl", "model": { "type": "kglm", "token_embedder": { @@ -25,7 +25,7 @@ "token_embedders": { "entity_ids": { "type": "embedding", - "pretrained_file": "data/enhanced-wikitext-2/embeddings.entities.txt", + "pretrained_file": "data/linked-wikitext-2/embeddings.entities.txt", "embedding_dim": 256, "trainable": false, "vocab_namespace": "entity_ids" @@ -36,7 +36,7 @@ "token_embedders": { "relations": { "type": "embedding", - "pretrained_file": "data/enhanced-wikitext-2/embeddings.relations.txt", + "pretrained_file": "data/linked-wikitext-2/embeddings.relations.txt", "embedding_dim": 256, "trainable": true, "vocab_namespace": "relations" @@ -48,7 +48,7 @@ "input_size": 400, "hidden_size": 400 }, - "knowledge_graph_path": "data/enhanced-wikitext-2/knowledge_graph.pkl", + "knowledge_graph_path": "data/linked-wikitext-2/knowledge_graph.pkl", "use_shortlist": false, "hidden_size": 1150, "num_layers": 3, @@ -75,6 +75,23 @@ "alias_copy_inds" ] }, + "validation_iterator": { + "type": "fancy", + "batch_size": 60, + "split_size": 70, + "splitting_keys": [ + "source", + "target", + "mention_type", + "raw_entity_ids", + "entity_ids", + "parent_ids", + "relations", + "shortlist_inds", + "alias_copy_inds" + ], + 
"truncate": false + }, "trainer": { "type": "lm", "num_epochs": 500, diff --git a/kglm/data/iterators/fancy_iterator.py b/kglm/data/iterators/fancy_iterator.py index e918612..e54643b 100644 --- a/kglm/data/iterators/fancy_iterator.py +++ b/kglm/data/iterators/fancy_iterator.py @@ -5,6 +5,7 @@ import random from typing import Deque, Dict, Iterable, Iterator, List, Tuple, Union +from allennlp.common.checks import ConfigurationError from allennlp.data.dataset import Batch from allennlp.data.fields import Field, ListField, MetadataField, TextField from allennlp.data.instance import Instance @@ -23,35 +24,44 @@ class FancyIterator(DataIterator): """Fancy cause it's really expensive.""" def __init__(self, - splitting_keys: List[str], + batch_size: int, split_size: int, - batch_size: int = 32, + splitting_keys: List[str], + truncate: bool = True, instances_per_epoch: int = None, max_instances_in_memory: int = None, cache_instances: bool = False, track_epoch: bool = False, maximum_samples_per_batch: Tuple[str, int] = None) -> None: super(FancyIterator, self).__init__( - batch_size=batch_size, - instances_per_epoch=instances_per_epoch, - max_instances_in_memory=max_instances_in_memory, - cache_instances=cache_instances, - track_epoch=track_epoch, - maximum_samples_per_batch=maximum_samples_per_batch) + batch_size=batch_size, + instances_per_epoch=instances_per_epoch, + max_instances_in_memory=max_instances_in_memory, + cache_instances=cache_instances, + track_epoch=track_epoch, + maximum_samples_per_batch=maximum_samples_per_batch) self._splitting_keys = splitting_keys self._split_size = split_size - self._eval = False - - def eval(self): - self._eval = True + self._truncate = truncate def __call__(self, instances: Iterable[Instance], num_epochs: int = None, shuffle: bool = False) -> Iterator[TensorDict]: + key = id(instances) starting_epoch = self._epochs[key] + # In order to ensure that we are (almost) constantly streaming data to the model we + # need to have all of the 
instances in memory ($$$) + instance_list = list(instances) + + if (self._batch_size > len(instance_list)) and self._truncate: + raise ConfigurationError('FancyIterator will not return any data when the batch size ' + 'is larger than number of instances and truncation is enabled. ' + 'To fix this either use a smaller batch size (better for ' + 'training) or disable truncation (better for validation).') + if num_epochs is None: epochs: Iterable[int] = itertools.count(starting_epoch) else: @@ -59,9 +69,6 @@ def __call__(self, for epoch in epochs: - # In order to ensure that we are (almost) constantly streaming data to the model we - # need to have all of the instances in memory ($$$) - instance_list = list(instances) if shuffle: random.shuffle(instance_list) @@ -104,7 +111,11 @@ def __call__(self, def _split(self, instance: Instance) -> Tuple[List[Instance], int]: # Determine the size of the sequence inside the instance. true_length = len(instance['source']) - padded_length = self._split_size * (true_length // self._split_size) + if (true_length % self._split_size) != 0: + offset = 1 + else: + offset = 0 + padded_length = self._split_size * (true_length // self._split_size + offset) # Determine the split indices. split_indices = list(range(0, true_length, self._split_size)) @@ -162,7 +173,7 @@ def _generate_batches(self, except IndexError: # A queue is depleted # If we're training, we break to avoid densely padded inputs (since this biases # the model to overfit the longer sequences). - if not self._eval: + if self._truncate: return # But if we're evaluating we do want the padding, so that we don't skip anything. 
else: diff --git a/kglm/tests/data/fancy_iterator_test.py b/kglm/tests/data/fancy_iterator_test.py new file mode 100644 index 0000000..08b889c --- /dev/null +++ b/kglm/tests/data/fancy_iterator_test.py @@ -0,0 +1,82 @@ +# pylint: disable=no-self-use +from typing import List + +from allennlp.common.checks import ConfigurationError +from allennlp.common.testing import AllenNlpTestCase +from allennlp.data import Instance, Token, Vocabulary +from allennlp.data.fields import TextField +from allennlp.data.token_indexers import SingleIdTokenIndexer + +from kglm.data.iterators import FancyIterator + + +class FancyIteratorTest(AllenNlpTestCase): + def setUp(self): + super().setUp() + self.token_indexers = {"tokens": SingleIdTokenIndexer()} + self.vocab = Vocabulary() + self.this_index = self.vocab.add_token_to_namespace('this') + self.is_index = self.vocab.add_token_to_namespace('is') + self.a_index = self.vocab.add_token_to_namespace('a') + self.sentence_index = self.vocab.add_token_to_namespace('sentence') + self.another_index = self.vocab.add_token_to_namespace('another') + self.yet_index = self.vocab.add_token_to_namespace('yet') + self.very_index = self.vocab.add_token_to_namespace('very') + self.long_index = self.vocab.add_token_to_namespace('long') + instances = [ + self.create_instance(["this", "is", "a", "sentence"]), + self.create_instance(["this", "is", "another", "sentence"]), + self.create_instance(["yet", "another", "sentence"]), + self.create_instance(["this", "is", "a", "very", "very", "very", "very", "long", "sentence"]), + self.create_instance(["sentence"]), + ] + + self.instances = instances + + def create_instance(self, str_tokens: List[str]): + tokens = [Token(t) for t in str_tokens] + instance = Instance({'source': TextField(tokens, self.token_indexers)}) + return instance + + def test_truncate(self): + # Checks that the truncate parameter works as intended. + + # Since split size is less than the length of the "very ... 
very long" sentence, the + # iterator should return one batch when the truncation is enabled. + split_size = 4 + truncated_iterator = FancyIterator(batch_size=5, + split_size=split_size, + splitting_keys=['source'], + truncate=True) + truncated_iterator.index_with(self.vocab) + batches = list(truncated_iterator(self.instances, num_epochs=1)) + assert len(batches) == 1 + + # When truncation is disabled the iterator should return 3 batches instead. + non_truncated_iterator = FancyIterator(batch_size=5, + split_size=split_size, + splitting_keys=['source'], + truncate=False) + non_truncated_iterator.index_with(self.vocab) + batches = list(non_truncated_iterator(self.instances, num_epochs=1)) + assert len(batches) == 3 + + # When the batch size is larger than the number of instances, truncation will cause the iterator + # to return zero batches of data (since some of the instances in the batch would consist + # entirely of padding). Check that the iterator raises an error in this case. + invalid_iterator = FancyIterator(batch_size=6, + split_size=split_size, + splitting_keys=['source'], + truncate=True) + invalid_iterator.index_with(self.vocab) + with self.assertRaises(ConfigurationError): + batches = list(invalid_iterator(self.instances, num_epochs=1)) + + # If truncation is disabled then this should not cause an issue + valid_iterator = FancyIterator(batch_size=6, + split_size=split_size, + splitting_keys=['source'], + truncate=False) + valid_iterator.index_with(self.vocab) + batches = list(valid_iterator(self.instances, num_epochs=1)) + assert len(batches) == 3 diff --git a/kglm/tests/fixtures/training_config/kglm-disc.json b/kglm/tests/fixtures/training_config/kglm-disc.json index bdbb6d4..821c2dc 100644 --- a/kglm/tests/fixtures/training_config/kglm-disc.json +++ b/kglm/tests/fixtures/training_config/kglm-disc.json @@ -62,7 +62,8 @@ "parent_ids", "relations", "shortlist_inds" - ] + ], + "truncate": false }, "trainer": { "type": "lm", diff --git 
a/kglm/tests/fixtures/training_config/kglm-disc.no-shortlist.json b/kglm/tests/fixtures/training_config/kglm-disc.no-shortlist.json index 4f4a221..8504011 100644 --- a/kglm/tests/fixtures/training_config/kglm-disc.no-shortlist.json +++ b/kglm/tests/fixtures/training_config/kglm-disc.no-shortlist.json @@ -62,7 +62,8 @@ "parent_ids", "relations", "shortlist_inds" - ] + ], + "truncate": false }, "trainer": { "type": "lm", diff --git a/kglm/tests/fixtures/training_config/kglm.json b/kglm/tests/fixtures/training_config/kglm.json index 5116a82..d5edc2f 100644 --- a/kglm/tests/fixtures/training_config/kglm.json +++ b/kglm/tests/fixtures/training_config/kglm.json @@ -68,7 +68,8 @@ "relations", "shortlist_inds", "alias_copy_inds" - ] + ], + "truncate": false }, "trainer": { "type": "lm", diff --git a/kglm/tests/fixtures/training_config/kglm.no-shortlist.json b/kglm/tests/fixtures/training_config/kglm.no-shortlist.json index c86f9ee..b8cee9d 100644 --- a/kglm/tests/fixtures/training_config/kglm.no-shortlist.json +++ b/kglm/tests/fixtures/training_config/kglm.no-shortlist.json @@ -68,7 +68,8 @@ "relations", "shortlist_inds", "alias_copy_inds" - ] + ], + "truncate": false }, "trainer": { "type": "lm",