Bug fixes, remove lightning dependency as it's not used yet
tomhosking committed Mar 10, 2022
1 parent d1a1e7d commit 6cde433
Showing 10 changed files with 123 additions and 23 deletions.
requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@ tensorboard==2.7.0
torch==1.10.2
tqdm>=4.62
scipy>=1.5
-pytorch-lightning==1.5.9
+# pytorch-lightning==1.5.9

nltk>=3.6.7
transformers==4.16.2
setup.py (2 changes: 1 addition & 1 deletion)
@@ -36,6 +36,6 @@
'sacrebleu>=2.0',
'py-rouge',
'wandb==0.12.10',
-'pytorch-lightning==1.5.9'
+# 'pytorch-lightning==1.5.9'
],
)
torchseq/agents/model_agent.py (16 changes: 8 additions & 8 deletions)
@@ -717,7 +717,7 @@ def validate(

self.all_metrics_at_best = {"nll": test_loss.item(), **all_metrics}

-wandb_log({split_slug + "/" + k: v for k, v in self.all_metrics_at_best.items()}, step=self.global_step)
+wandb_log({split_slug + "/" + k: v for k, v in self.all_metrics_at_best.items()})

if self.run_id is not None:
with open(os.path.join(self.run_output_path, f"output.{split_slug}.txt"), "w") as f:
@@ -758,13 +758,13 @@ def update_dashboard(self):
"""
Update McKenzie with the latest metric values
"""
-wandb_log(
-{
-"bleu": self.all_metrics_at_best.get("bleu", None),
-"nll": self.all_metrics_at_best.get("nll", None),
-},
-step=self.global_step,
-)
+# wandb_log(
+# {
+# "bleu": self.all_metrics_at_best.get("bleu", None),
+# "nll": self.all_metrics_at_best.get("nll", None),
+# },
+# step=self.global_step,
+# )
wandb_log({"progress": self.current_epoch / self.config.training.num_epochs * 100}, step=self.global_step)

if "bleu" in self.all_metrics_at_best:
torchseq/main.py (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@

from torchseq.datasets.builder import dataloader_from_config

-from pytorch_lightning.lite import LightningLite
+# from pytorch_lightning.lite import LightningLite
import transformers


torchseq/metric_hooks/hrq_agg.py (78 changes: 78 additions & 0 deletions)
@@ -10,6 +10,7 @@
from torchseq.metric_hooks.base import MetricHook
from torchseq.datasets.json_loader import JsonDataLoader
from torchseq.utils.config import Config
import sacrebleu


import logging
@@ -51,6 +52,15 @@ def on_end_epoch(self, agent, use_test=False):
)
logger.info("...done")

if self.config.eval.metrics.hrq_agg.get("run_masked_generation", False):
logger.info("Running generation with masking")
self.scores["hrq_agg"], _ = HRQAggregationMetricHook.eval_masked_generation(
self.config,
agent,
test=use_test,
)
logger.info("...done")

return self.scores

@abstractmethod
@@ -143,3 +153,71 @@ def eval_generate_summaries(config, agent, test=False):
agent.config.eval.data["sample_outputs"] = sample_outputs

return {}, output

@abstractmethod
def eval_masked_generation(config, agent, test=False, dev_samples=None, test_samples=None, skip_scores=False):
config_gen_masked = copy.deepcopy(config.data)
config_gen_masked["dataset"] = "json"
config_gen_masked["json_dataset"] = {
"path": config.eval.metrics.hrq_agg.dataset_all,
"filename": "space_reviews.{split}",
"field_map": [
{"type": "copy", "from": "sentence", "to": "s2"},
{"type": "copy", "from": "sentence", "to": "s1"},
{"type": "copy", "from": "head_mask", "to": "head_mask"},
{"type": "copy", "from": "residual_mask", "to": "residual_mask"},
],
}

data_loader = JsonDataLoader(
data_path=agent.data_path,
config=Config(config_gen_masked),
dev_samples=dev_samples,
test_samples=test_samples,
)

bneck_types = [x.type for x in agent.config.bottleneck.modules]
if "hrqvae" not in bneck_types:
logger.warning("Tried to run oracle masked eval on a model without a quantizer!")
return {}
quantizer_index = bneck_types.index("hrqvae")
num_heads = agent.config.bottleneck.modules[quantizer_index].quantizer.num_heads

scores = {}
outputs = {}

split = "test" if test else "dev"

if not skip_scores:
with jsonlines.open(
os.path.join(
agent.data_path,
config.eval.metrics.hrq_agg.dataset_all,
f"space_reviews.{split}.jsonl",
)
) as f:
rows = [row for row in f][: config_gen_masked["eval"].get("truncate_dataset", None)]
refs = [[q["sentence"]] for q in rows]

for mask_length in range(0, num_heads + 1):
mask = [1] * (num_heads - mask_length) + [0] * mask_length
samples = (data_loader._test if test else data_loader._valid).samples
samples = [{**x, "head_mask": mask, "residual_mask": [0]} for x in samples]
masked_loader = JsonDataLoader(
data_path=agent.data_path, config=Config(config_gen_masked), dev_samples=samples
)

_, _, (output, _, _), _ = agent.inference(masked_loader.valid_loader)

if not skip_scores:
# refs = [x["paras"] for x in qs_by_para_split]
max_num_refs = max([len(x) for x in refs])
refs_padded = [x + [x[0]] * (max_num_refs - len(x)) for x in refs]

tgt_bleu = sacrebleu.corpus_bleu(output, list(zip(*refs_padded)), lowercase=True).score

scores[mask_length] = tgt_bleu

outputs[mask_length] = output

return scores, outputs
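
Aside (not part of the diff): the new eval_masked_generation hook above sweeps over progressively longer suffix masks, keeping only the first num_heads - mask_length quantizer heads and always zeroing the residual. A minimal runnable sketch of that mask construction, with num_heads = 4 assumed purely for illustration:

# Sketch only, not from the commit: mask construction as used in eval_masked_generation,
# with num_heads assumed to be 4 for illustration.
num_heads = 4

for mask_length in range(0, num_heads + 1):
    # Keep the first (num_heads - mask_length) HRQ heads, drop the rest,
    # and always suppress the residual term.
    head_mask = [1] * (num_heads - mask_length) + [0] * mask_length
    residual_mask = [0]
    print(mask_length, head_mask, residual_mask)
# mask_length = 0 keeps every code ([1, 1, 1, 1]); mask_length = 4 masks them all ([0, 0, 0, 0])
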
torchseq/models/bottleneck.py (2 changes: 1 addition & 1 deletion)
@@ -88,7 +88,7 @@ def __init__(self, config, embeddings=None):
**quantizer_kwargs,
)

-def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=None):
+def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=None, residual_mask=None):

# Pool
encoding_pooled = (
torchseq/models/bottleneck_autoencoder.py (3 changes: 3 additions & 0 deletions)
@@ -114,6 +114,7 @@ def forward(self, batch, output, memory=None, tgt_field=None):
batch["_global_step"],
head_mask=batch.get("head_mask", None),
forced_codes=batch.get("forced_codes", None),
residual_mask=batch.get("residual_mask", None),
)

prebn_encoding_pooled = memory["encoding_pooled"]
@@ -147,6 +148,7 @@ def forward(self, batch, output, memory=None, tgt_field=None):
batch["_global_step"],
forced_codes=batch.get("forced_codes", None),
head_mask=batch.get("head_mask", None),
residual_mask=batch.get("residual_mask", None),
)

# TODO: Instead of 2x full size encoders + down projection, change to 2x half size encoders
@@ -177,6 +179,7 @@ def forward(self, batch, output, memory=None, tgt_field=None):
batch["_global_step"],
forced_codes=batch.get("forced_codes", None),
head_mask=batch.get("head_mask", None),
residual_mask=batch.get("residual_mask", None),
)

if "loss" in memory:
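
Aside (not part of the diff): head_mask and residual_mask are optional per-batch fields, read with batch.get(..., None) so configs that never set them are unaffected. A hypothetical batch for illustration:

# Sketch only, not from the commit: a hypothetical batch dict with the optional mask fields.
batch = {
    "_global_step": 100,
    "head_mask": [1, 1, 0, 0],   # optional: which HRQ heads stay active
    "residual_mask": [0],        # optional: zero out the residual term
}

# Absent keys resolve to None, so the quantizer falls back to its default behaviour.
head_mask = batch.get("head_mask", None)
residual_mask = batch.get("residual_mask", None)
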
torchseq/models/encoder.py (7 changes: 3 additions & 4 deletions)
@@ -180,10 +180,9 @@ def __init__(self, config, embeddings=None):
self.embeddings.weight.data = Tokenizer().get_embeddings(config.prepro.tokenizer)
self.embeddings.weight.requires_grad = not config.freeze_embeddings

-if self.config.raw_embedding_dim != self.config.encoder.embedding_dim:
-self.embedding_projection = nn.utils.weight_norm(
-nn.Linear(config.raw_embedding_dim, config.encoder.embedding_dim, bias=False)
-)
+self.embedding_projection = nn.utils.weight_norm(
+nn.Linear(config.raw_embedding_dim, config.encoder.embedding_dim, bias=False)
+)

self.bert_embedding_projection = nn.utils.weight_norm(
nn.Linear(
torchseq/models/hrq_vae.py (12 changes: 11 additions & 1 deletion)
@@ -27,6 +27,8 @@ def __init__(
head_dropout=0.3,
head_dropout_keep_first=False,
learnable_priors=False,
include_residual=False,
residual_penalty=0.0,
):

super(HierarchicalRefinementQuantizer, self).__init__()
@@ -48,6 +50,9 @@

self._norm_loss_weight = norm_loss_weight

self._include_residual = include_residual
self._residual_penalty = residual_penalty

if head_dropout is not None and head_dropout > 0:
self._head_dropout = torch.distributions.Bernoulli(1 - head_dropout)
else:
@@ -86,7 +91,7 @@ def __init__(
def encoding_to_logits(self, input, head_ix, prev_codes):
pass

-def forward(self, inputs, global_step=None, forced_codes=None, head_mask=None):
+def forward(self, inputs, global_step=None, forced_codes=None, head_mask=None, residual_mask=None):
input_shape = inputs.shape

quantized_list = []
@@ -236,6 +241,11 @@ def forward(self, inputs, global_step=None, forced_codes=None, head_mask=None):
quantized = quantized * mask

quantized = torch.sum(quantized, dim=1)
if self._include_residual:
quantized += resid_error * (residual_mask if residual_mask is not None else 1.0)
if self._residual_penalty > 0:
loss += torch.linalg.norm(resid_error, dim=-1) * self._residual_penalty

quantized = quantized.view(input_shape)

return loss, quantized, vq_codes
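
Aside (not part of the diff): a minimal sketch of how the new include_residual and residual_penalty options combine the residual with the summed head embeddings; the shapes, names, and residual_mask format here are assumptions for illustration only.

import torch

# Sketch only, not from the commit; shapes and names are assumed for illustration.
batch_size, num_heads, dim = 2, 3, 8
quantized_heads = torch.randn(batch_size, num_heads, dim)  # per-head quantized vectors
resid_error = torch.randn(batch_size, dim)                 # input encoding minus the summed codes
residual_mask = torch.tensor([[1.0], [0.0]])               # per-example gate: 1 keeps the residual, 0 drops it
include_residual, residual_penalty = True, 0.1

quantized = quantized_heads.sum(dim=1)
loss = torch.zeros(batch_size)
if include_residual:
    # Add the residual back in, gated per example (a missing mask would keep it everywhere).
    quantized = quantized + resid_error * residual_mask
    if residual_penalty > 0:
        # Penalise large residuals so the discrete codes carry most of the signal.
        loss = loss + torch.linalg.norm(resid_error, dim=-1) * residual_penalty
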
torchseq/models/modular_bottleneck.py (22 changes: 16 additions & 6 deletions)
@@ -31,7 +31,7 @@ def __init__(self, config, embeddings=None):

self.module_list = nn.ModuleList(modules)

-def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=None):
+def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=None, residual_mask=None):

# if head_mask is not None:
# print('hm in bottleneck')
@@ -50,7 +50,12 @@ def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=No
any_pooled = any_pooled | module.config.get("pooling", False)

sub_encoding_post, sub_encoding_pooled, memory = module(
-sub_encoding_pre, memory, global_step, forced_codes=forced_codes, head_mask=head_mask
+sub_encoding_pre,
+memory,
+global_step,
+forced_codes=forced_codes,
+head_mask=head_mask,
+residual_mask=residual_mask,
)
encodings_post.append(sub_encoding_post)
encodings_pooled.append(sub_encoding_pooled)
@@ -136,7 +141,7 @@ def __init__(self, config, global_config, embedding_dim, num_heads):
**quantizer_kwargs,
)

-def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=None):
+def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=None, residual_mask=None):
# if head_mask is not None:
# print('hm in bottleneck part')

@@ -152,9 +157,14 @@ def forward(self, encoding, memory, global_step, forced_codes=None, head_mask=No
# Quantize
if self.config.get("type", None) in ["vqvae", "hrqvae"]:

-vq_loss, encoding_post, quantizer_indices = self.quantizer(
-encoding_post, global_step, forced_codes, head_mask
-)
+if self.config.get("type", None) == "hrqvae":
+vq_loss, encoding_post, quantizer_indices = self.quantizer(
+encoding_post, global_step, forced_codes, head_mask, residual_mask
+)
+else:
+vq_loss, encoding_post, quantizer_indices = self.quantizer(
+encoding_post, global_step, forced_codes, head_mask
+)

if "loss" not in memory:
memory["loss"] = 0
