From 6cde433dfb76556527be1d3fb99a29285460d97e Mon Sep 17 00:00:00 2001 From: Tom Date: Thu, 10 Mar 2022 11:00:50 +0000 Subject: [PATCH] Bug fixes, remove lightning dependency as its not used yet --- requirements.txt | 2 +- setup.py | 2 +- torchseq/agents/model_agent.py | 16 ++--- torchseq/main.py | 2 +- torchseq/metric_hooks/hrq_agg.py | 78 +++++++++++++++++++++++ torchseq/models/bottleneck.py | 2 +- torchseq/models/bottleneck_autoencoder.py | 3 + torchseq/models/encoder.py | 7 +- torchseq/models/hrq_vae.py | 12 +++- torchseq/models/modular_bottleneck.py | 22 +++++-- 10 files changed, 123 insertions(+), 23 deletions(-) diff --git a/requirements.txt b/requirements.txt index ae33fcf..6492350 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ tensorboard==2.7.0 torch==1.10.2 tqdm>=4.62 scipy>=1.5 -pytorch-lightning==1.5.9 +# pytorch-lightning==1.5.9 nltk>=3.6.7 transformers==4.16.2 diff --git a/setup.py b/setup.py index bc956dd..a5ced37 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,6 @@ 'sacrebleu>=2.0', 'py-rouge', 'wandb==0.12.10', - 'pytorch-lightning==1.5.9' + # 'pytorch-lightning==1.5.9' ], ) \ No newline at end of file diff --git a/torchseq/agents/model_agent.py b/torchseq/agents/model_agent.py index 701c30f..d311795 100644 --- a/torchseq/agents/model_agent.py +++ b/torchseq/agents/model_agent.py @@ -717,7 +717,7 @@ def validate( self.all_metrics_at_best = {"nll": test_loss.item(), **all_metrics} - wandb_log({split_slug + "/" + k: v for k, v in self.all_metrics_at_best.items()}, step=self.global_step) + wandb_log({split_slug + "/" + k: v for k, v in self.all_metrics_at_best.items()}) if self.run_id is not None: with open(os.path.join(self.run_output_path, f"output.{split_slug}.txt"), "w") as f: @@ -758,13 +758,13 @@ def update_dashboard(self): """ Update McKenzie with the latest metric values """ - wandb_log( - { - "bleu": self.all_metrics_at_best.get("bleu", None), - "nll": self.all_metrics_at_best.get("nll", None), - }, - step=self.global_step, - ) + # wandb_log( + # { + # "bleu": self.all_metrics_at_best.get("bleu", None), + # "nll": self.all_metrics_at_best.get("nll", None), + # }, + # step=self.global_step, + # ) wandb_log({"progress": self.current_epoch / self.config.training.num_epochs * 100}, step=self.global_step) if "bleu" in self.all_metrics_at_best: diff --git a/torchseq/main.py b/torchseq/main.py index a70e6b0..5cb7766 100644 --- a/torchseq/main.py +++ b/torchseq/main.py @@ -18,7 +18,7 @@ from torchseq.datasets.builder import dataloader_from_config -from pytorch_lightning.lite import LightningLite +# from pytorch_lightning.lite import LightningLite import transformers diff --git a/torchseq/metric_hooks/hrq_agg.py b/torchseq/metric_hooks/hrq_agg.py index d9fc135..a0a6f2a 100644 --- a/torchseq/metric_hooks/hrq_agg.py +++ b/torchseq/metric_hooks/hrq_agg.py @@ -10,6 +10,7 @@ from torchseq.metric_hooks.base import MetricHook from torchseq.datasets.json_loader import JsonDataLoader from torchseq.utils.config import Config +import sacrebleu import logging @@ -51,6 +52,15 @@ def on_end_epoch(self, agent, use_test=False): ) logger.info("...done") + if self.config.eval.metrics.hrq_agg.get("run_masked_generation", False): + logger.info("Running generation with masking") + self.scores["hrq_agg"], _ = HRQAggregationMetricHook.eval_masked_generation( + self.config, + agent, + test=use_test, + ) + logger.info("...done") + return self.scores @abstractmethod @@ -143,3 +153,71 @@ def eval_generate_summaries(config, agent, test=False): 
agent.config.eval.data["sample_outputs"] = sample_outputs return {}, output + + @abstractmethod + def eval_masked_generation(config, agent, test=False, dev_samples=None, test_samples=None, skip_scores=False): + config_gen_masked = copy.deepcopy(config.data) + config_gen_masked["dataset"] = "json" + config_gen_masked["json_dataset"] = { + "path": config.eval.metrics.hrq_agg.dataset_all, + "filename": "space_reviews.{split}", + "field_map": [ + {"type": "copy", "from": "sentence", "to": "s2"}, + {"type": "copy", "from": "sentence", "to": "s1"}, + {"type": "copy", "from": "head_mask", "to": "head_mask"}, + {"type": "copy", "from": "residual_mask", "to": "residual_mask"}, + ], + } + + data_loader = JsonDataLoader( + data_path=agent.data_path, + config=Config(config_gen_masked), + dev_samples=dev_samples, + test_samples=test_samples, + ) + + bneck_types = [x.type for x in agent.config.bottleneck.modules] + if "hrqvae" not in bneck_types: + logger.warning("Tried to run oracle masked eval on a model without a quantizer!") + return {} + quantizer_index = bneck_types.index("hrqvae") + num_heads = agent.config.bottleneck.modules[quantizer_index].quantizer.num_heads + + scores = {} + outputs = {} + + split = "test" if test else "dev" + + if not skip_scores: + with jsonlines.open( + os.path.join( + agent.data_path, + config.eval.metrics.hrq_agg.dataset_all, + f"space_reviews.{split}.jsonl", + ) + ) as f: + rows = [row for row in f][: config_gen_masked["eval"].get("truncate_dataset", None)] + refs = [[q["sentence"]] for q in rows] + + for mask_length in range(0, num_heads + 1): + mask = [1] * (num_heads - mask_length) + [0] * mask_length + samples = (data_loader._test if test else data_loader._valid).samples + samples = [{**x, "head_mask": mask, "residual_mask": [0]} for x in samples] + masked_loader = JsonDataLoader( + data_path=agent.data_path, config=Config(config_gen_masked), dev_samples=samples + ) + + _, _, (output, _, _), _ = agent.inference(masked_loader.valid_loader) + + if not skip_scores: + # refs = [x["paras"] for x in qs_by_para_split] + max_num_refs = max([len(x) for x in refs]) + refs_padded = [x + [x[0]] * (max_num_refs - len(x)) for x in refs] + + tgt_bleu = sacrebleu.corpus_bleu(output, list(zip(*refs_padded)), lowercase=True).score + + scores[mask_length] = tgt_bleu + + outputs[mask_length] = output + + return scores, outputs diff --git a/torchseq/models/bottleneck.py b/torchseq/models/bottleneck.py index 3f063bb..f71f967 100644 --- a/torchseq/models/bottleneck.py +++ b/torchseq/models/bottleneck.py @@ -88,7 +88,7 @@ def __init__(self, config, embeddings=None): **quantizer_kwargs, ) - def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=None): + def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=None, residual_mask=None): # Pool encoding_pooled = ( diff --git a/torchseq/models/bottleneck_autoencoder.py b/torchseq/models/bottleneck_autoencoder.py index 0ef0664..0d6142a 100644 --- a/torchseq/models/bottleneck_autoencoder.py +++ b/torchseq/models/bottleneck_autoencoder.py @@ -114,6 +114,7 @@ def forward(self, batch, output, memory=None, tgt_field=None): batch["_global_step"], head_mask=batch.get("head_mask", None), forced_codes=batch.get("forced_codes", None), + residual_mask=batch.get("residual_mask", None), ) prebn_encoding_pooled = memory["encoding_pooled"] @@ -147,6 +148,7 @@ def forward(self, batch, output, memory=None, tgt_field=None): batch["_global_step"], forced_codes=batch.get("forced_codes", None), 
head_mask=batch.get("head_mask", None), + residual_mask=batch.get("residual_mask", None), ) # TODO: Instead of 2x full size encoders + down projection, change to 2x half size encoders @@ -177,6 +179,7 @@ def forward(self, batch, output, memory=None, tgt_field=None): batch["_global_step"], forced_codes=batch.get("forced_codes", None), head_mask=batch.get("head_mask", None), + residual_mask=batch.get("residual_mask", None), ) if "loss" in memory: diff --git a/torchseq/models/encoder.py b/torchseq/models/encoder.py index dac4929..c42ab10 100644 --- a/torchseq/models/encoder.py +++ b/torchseq/models/encoder.py @@ -180,10 +180,9 @@ def __init__(self, config, embeddings=None): self.embeddings.weight.data = Tokenizer().get_embeddings(config.prepro.tokenizer) self.embeddings.weight.requires_grad = not config.freeze_embeddings - if self.config.raw_embedding_dim != self.config.encoder.embedding_dim: - self.embedding_projection = nn.utils.weight_norm( - nn.Linear(config.raw_embedding_dim, config.encoder.embedding_dim, bias=False) - ) + self.embedding_projection = nn.utils.weight_norm( + nn.Linear(config.raw_embedding_dim, config.encoder.embedding_dim, bias=False) + ) self.bert_embedding_projection = nn.utils.weight_norm( nn.Linear( diff --git a/torchseq/models/hrq_vae.py b/torchseq/models/hrq_vae.py index bf7c90f..a8404a2 100644 --- a/torchseq/models/hrq_vae.py +++ b/torchseq/models/hrq_vae.py @@ -27,6 +27,8 @@ def __init__( head_dropout=0.3, head_dropout_keep_first=False, learnable_priors=False, + include_residual=False, + residual_penalty=0.0, ): super(HierarchicalRefinementQuantizer, self).__init__() @@ -48,6 +50,9 @@ def __init__( self._norm_loss_weight = norm_loss_weight + self._include_residual = include_residual + self._residual_penalty = residual_penalty + if head_dropout is not None and head_dropout > 0: self._head_dropout = torch.distributions.Bernoulli(1 - head_dropout) else: @@ -86,7 +91,7 @@ def __init__( def encoding_to_logits(self, input, head_ix, prev_codes): pass - def forward(self, inputs, global_step=None, forced_codes=None, head_mask=None): + def forward(self, inputs, global_step=None, forced_codes=None, head_mask=None, residual_mask=None): input_shape = inputs.shape quantized_list = [] @@ -236,6 +241,11 @@ def forward(self, inputs, global_step=None, forced_codes=None, head_mask=None): quantized = quantized * mask quantized = torch.sum(quantized, dim=1) + if self._include_residual: + quantized += resid_error * (residual_mask if residual_mask is not None else 1.0) + if self._residual_penalty > 0: + loss += torch.linalg.norm(resid_error, dim=-1) * self._residual_penalty + quantized = quantized.view(input_shape) return loss, quantized, vq_codes diff --git a/torchseq/models/modular_bottleneck.py b/torchseq/models/modular_bottleneck.py index 505f421..be58707 100644 --- a/torchseq/models/modular_bottleneck.py +++ b/torchseq/models/modular_bottleneck.py @@ -31,7 +31,7 @@ def __init__(self, config, embeddings=None): self.module_list = nn.ModuleList(modules) - def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=None): + def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=None, residual_mask=None): # if head_mask is not None: # print('hm in bottleneck') @@ -50,7 +50,12 @@ def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=No any_pooled = any_pooled | module.config.get("pooling", False) sub_encoding_post, sub_encoding_pooled, memory = module( - sub_encoding_pre, memory, global_step, forced_codes=forced_codes, 
head_mask=head_mask + sub_encoding_pre, + memory, + global_step, + forced_codes=forced_codes, + head_mask=head_mask, + residual_mask=residual_mask, ) encodings_post.append(sub_encoding_post) encodings_pooled.append(sub_encoding_pooled) @@ -136,7 +141,7 @@ def __init__(self, config, global_config, embedding_dim, num_heads): **quantizer_kwargs, ) - def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=None): + def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=None, residual_mask=None): # if head_mask is not None: # print('hm in bottleneck part') @@ -152,9 +157,14 @@ def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=No # Quantize if self.config.get("type", None) in ["vqvae", "hrqvae"]: - vq_loss, encoding_post, quantizer_indices = self.quantizer( - encoding_post, global_step, forced_codes, head_mask - ) + if self.config.get("type", None) == "hrqvae": + vq_loss, encoding_post, quantizer_indices = self.quantizer( + encoding_post, global_step, forced_codes, head_mask, residual_mask + ) + else: + vq_loss, encoding_post, quantizer_indices = self.quantizer( + encoding_post, global_step, forced_codes, head_mask + ) if "loss" not in memory: memory["loss"] = 0
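
Editor's note, a minimal self-contained sketch (not part of the patch, and not torchseq API): it illustrates how the new head_mask / residual_mask options introduced above combine per-head HRQ-VAE code embeddings with an optional residual term, mirroring the change to torchseq/models/hrq_vae.py. The helper name combine_heads and the tensor shapes are assumptions for demonstration only.

import torch

def combine_heads(quantized_per_head, resid_error, head_mask=None,
                  residual_mask=None, include_residual=True):
    # quantized_per_head: (batch, num_heads, dim) - one code embedding per head
    # resid_error: (batch, dim) - encoding left over after the final head
    bsz, num_heads, _ = quantized_per_head.shape
    if head_mask is None:
        head_mask = torch.ones(bsz, num_heads)
    # Zero out masked heads, then sum the remaining code embeddings
    quantized = (quantized_per_head * head_mask.unsqueeze(-1)).sum(dim=1)
    if include_residual:
        # A residual_mask of 0 drops the residual term, 1 keeps it,
        # matching the `resid_error * residual_mask` logic in the patch
        scale = residual_mask if residual_mask is not None else 1.0
        quantized = quantized + resid_error * scale
    return quantized

# Usage sketch: keep the first two of three heads and drop the residual term,
# analogous to the oracle-masked evaluation loop added in hrq_agg.py
q = torch.randn(2, 3, 8)
resid = torch.randn(2, 8)
out = combine_heads(q, resid,
                    head_mask=torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]),
                    residual_mask=torch.tensor([[0.0], [0.0]]))
print(out.shape)  # torch.Size([2, 8])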