Finetune: load tasks with AttentionalDecoder + fix missing sent RNN encoder loading
OrianeN committed Apr 18, 2024
1 parent f2bee55 commit 2e571b0
Showing 1 changed file with 76 additions and 14 deletions.
90 changes: 76 additions & 14 deletions pie/models/model.py
@@ -3,6 +3,7 @@
import json
import logging
import os
import re
import tarfile
import warnings
from collections import OrderedDict
@@ -234,14 +235,32 @@ def load_state_dict_from_pretrained(self, pretrained, exclude=[]):
exclude : list (optional), List of model parts to exclude from loading. Defaults to [].
Available values are: "wemb", "cemb", "cemb_rnn", "sent_rnn", "lm", as well as any decoding task
"""
def get_block_state_dict(block_name, full_state_dict):
block_state_dict = OrderedDict({
k.replace(block_name, ""): v for k, v in full_state_dict.items()
if k.startswith(block_name)
})
return block_state_dict
def get_module_state_dict(module_str, full_state_dict):
"""Creates the state_dict of a specific module given the full state_dict.
Used to get a state_dict that can be loaded into a module
(the full state_dict can only be loaded at the root of a model)
Args:
module_str (str): Name of the module (e.g. "cemb.rnn")
full_state_dict (OrderedDict): state_dict of a model (meant for the pretrained model)
Returns:
OrderedDict: state_dict corresponding to the module
"""
module_state_dict = OrderedDict({
k.replace(module_str, ""): v for k, v in full_state_dict.items()
if k.startswith(module_str)
})
return module_state_dict
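# Illustration only (hypothetical keys, not part of this commit): given
#     full_state_dict = OrderedDict([("cemb.rnn.weight_ih_l0", t1),
#                                    ("cemb.rnn.weight_hh_l0", t2),
#                                    ("wemb.weight", t3)])
# get_module_state_dict("cemb.rnn.", full_state_dict) returns
#     OrderedDict([("weight_ih_l0", t1), ("weight_hh_l0", t2)]),
# which self.cemb.rnn.load_state_dict() accepts: load_state_dict() expects
# keys relative to the module it is called on.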

def load_state_dict_label_by_label(module, module_str, labels_table_key, label_encoder_pretrained, state_dict_pretrained, is_task=False):
def load_state_dict_label_by_label(
module,
module_str,
labels_table_key,
label_encoder_pretrained,
state_dict_pretrained,
is_task=False,
exclude_params_regex=None):
"""Load pretrained state dict parameters iteratively for each label (vocab/classification labels...)
This is done label by label as the label tables may differ between the pretrained and current model.
@@ -260,8 +279,15 @@ def load_state_dict_label_by_label(module, module_str, labels_table_key, label_e
is_task (bool): Set to True if the labels are stored in the dicts MultiLabelEncoder.tasks
(instead of direct attributes of MultiLabelEncoder)
"""
params_to_update = [p[0] for p in module.named_buffers()] + [p[0] for p in module.named_parameters()]

params_to_update = []
for param_tuples in [module.named_buffers(), module.named_parameters()]:
names = [
p[0] for p in param_tuples
if not (exclude_params_regex and re.search(exclude_params_regex, p[0]))
]
if names:
params_to_update.extend(names)

if is_task:
labels_dict_pretrained = label_encoder_pretrained.tasks[labels_table_key].table
labels_dict_current = self.label_encoder.tasks[labels_table_key].table
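# Illustration of the label-by-label idea (hypothetical labels and indices,
# not part of this commit): if "dog" maps to row 12 in the pretrained table
# and to row 7 in the current table, the copy for an embedding weight is
# roughly current["emb.weight"][7] = pretrained["emb.weight"][12], so shared
# labels transfer even when the two tables are ordered differently, while
# labels unknown to the pretrained model keep their fresh initialisation.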
@@ -315,13 +341,18 @@ def load_state_dict_label_by_label(module, module_str, labels_table_key, label_e
total_updated, nb_labels_new_model, nb_pretrained_labels = load_state_dict_label_by_label(
self.wemb, "wemb", "word", label_encoder_pretrained, state_dict_pretrained)
print(f"Initialized {total_updated}/{nb_labels_new_model} word embs ({nb_pretrained_labels} in pretrained model)")
# Character embeddings (cemb.emb)
total_updated, nb_labels_new_model, nb_pretrained_labels = load_state_dict_label_by_label(
self.cemb.emb, "cemb.emb", "char", label_encoder_pretrained, state_dict_pretrained)
print(f"Initialized {total_updated}/{nb_labels_new_model} char embs ({nb_pretrained_labels} in pretrained model)")

# Load state_dict of the character-level RNN
# Load state_dict of the character-level RNN encoder
if "cemb_rnn" in model_parts_to_load:
self.cemb.rnn.load_state_dict(get_block_state_dict("cemb.rnn.", state_dict_pretrained))
self.cemb.rnn.load_state_dict(get_module_state_dict("cemb.rnn.", state_dict_pretrained))

# Load state_dict of the sentence-level RNN encoder
if "sent_rnn" in model_parts_to_load:
self.encoder.rnn.load_state_dict(get_module_state_dict("encoder.rnn.", state_dict_pretrained))
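# (Note: the two RNN encoders can be loaded in one block: unlike the embedding
# and decoder layers, their parameter shapes do not depend on the label tables,
# so the pretrained tensors fit the new model directly.)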

# Load state_dict of the language model (lm)
# This is done words by word as the vocabularies may differ between both models
@@ -330,11 +361,21 @@ def load_state_dict_label_by_label(module, module_str, labels_table_key, label_e
if "lm" in model_parts_to_load:
# fwd
total_updated, nb_labels_new_model, nb_pretrained_labels = load_state_dict_label_by_label(
self.lm_fwd_decoder, "lm_fwd_decoder", "word", label_encoder_pretrained, state_dict_pretrained)
self.lm_fwd_decoder,
"lm_fwd_decoder",
"word",
label_encoder_pretrained,
state_dict_pretrained
)
print(f"Initialized {total_updated}/{nb_labels_new_model} fwd LM word parameters ({nb_pretrained_labels} in pretrained model)")
# bwd
total_updated, nb_labels_new_model, nb_pretrained_labels = load_state_dict_label_by_label(
self.lm_bwd_decoder, "lm_bwd_decoder", "word", label_encoder_pretrained, state_dict_pretrained)
self.lm_bwd_decoder,
"lm_bwd_decoder",
"word",
label_encoder_pretrained,
state_dict_pretrained
)
print(f"Initialized {total_updated}/{nb_labels_new_model} bwd LM word parameters ({nb_pretrained_labels} in pretrained model)")

# Fill task-specific params - WARNING: DEPENDS ON THE TYPE OF TASK=DECODER
@@ -356,8 +397,29 @@ def load_state_dict_label_by_label(module, module_str, labels_table_key, label_e

if isinstance(self.decoders[tname], LinearDecoder):
total_updated, nb_labels_new_model, nb_pretrained_labels = load_state_dict_label_by_label(
self.decoders[tname], tname+"_decoder", tname, label_encoder_pretrained, state_dict_pretrained, is_task=True)
self.decoders[tname],
tname+"_decoder",
tname,
label_encoder_pretrained,
state_dict_pretrained,
is_task=True
)
print(f"Initialized {total_updated}/{nb_labels_new_model} {tname} labels parameters ({nb_pretrained_labels} in pretrained model)")
elif isinstance(self.decoders[tname], AttentionalDecoder):
# Load label-related params iteratively
total_updated, nb_labels_new_model, nb_pretrained_labels = load_state_dict_label_by_label(
self.decoders[tname],
tname+"_decoder",
tname,
label_encoder_pretrained,
state_dict_pretrained,
is_task=True,
exclude_params_regex="^rnn"
)
print(f"Initialized {total_updated}/{nb_labels_new_model} {tname} labels parameters ({nb_pretrained_labels} in pretrained model)")
# Load decoder RNN and attn params as whole blocks
self.decoders[tname].rnn.load_state_dict(get_module_state_dict(f"{tname}_decoder.rnn.", state_dict_pretrained))
self.decoders[tname].attn.load_state_dict(get_module_state_dict(f"{tname}_decoder.attn.", state_dict_pretrained))
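# Illustration of the "^rnn" exclusion above (hypothetical parameter names):
# a decoder parameter named "rnn.weight_ih_l0" matches re.search("^rnn", name),
# so the label-by-label pass skips it and it is loaded instead by the block
# load_state_dict call above; a name such as "embs.weight" does not match and
# is transferred label by label.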
else:
raise NotImplementedError(f"Can only load decoder parameters for tasks with Linear or Attentional decoders (found {type(self.decoders[tname])})")

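For context, a minimal usage sketch of the method this commit extends (the variable names and the exclude value below are assumptions for illustration, not taken from the diff):

pretrained = ...  # a previously trained pie model, however it is obtained elsewhere
model = ...       # freshly initialised model built for the fine-tuning tasks
# Skip the language-model heads; embeddings and task decoders transfer label by
# label, while the character/sentence RNN encoders load as whole blocks.
model.load_state_dict_from_pretrained(pretrained, exclude=["lm"])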
