diff --git a/pie/models/model.py b/pie/models/model.py
index 155dedc..6553207 100644
--- a/pie/models/model.py
+++ b/pie/models/model.py
@@ -3,6 +3,7 @@
 import json
 import logging
 import os
+import re
 import tarfile
 import warnings
 from collections import OrderedDict
@@ -234,14 +235,32 @@ def load_state_dict_from_pretrained(self, pretrained, exclude=[]):
         exclude : list (optional), List of model parts to exclude from loading. Defaults to [].
             Available values are: "wemb", "cemb", "cemb_rnn", "sent_rnn", "lm", as well as any decoding task
         """
-        def get_block_state_dict(block_name, full_state_dict):
-            block_state_dict = OrderedDict({
-                k.replace(block_name, ""): v for k, v in full_state_dict.items()
-                if k.startswith(block_name)
-            })
-            return block_state_dict
+        def get_module_state_dict(module_str, full_state_dict):
+            """Creates the state_dict of a specific module given the full state_dict.
+            Used to get a state_dict that can be loaded into a module
+            (the full state_dict can only be loaded at the root of a model)
+
+            Args:
+                module_str (str): Name of the module (e.g. "cemb.rnn")
+                full_state_dict (OrderedDict): state_dict of a model (meant for the pretrained model)

-        def load_state_dict_label_by_label(module, module_str, labels_table_key, label_encoder_pretrained, state_dict_pretrained, is_task=False):
+            Returns:
+                OrderedDict: state_dict corresponding to the module
+            """
+            module_state_dict = OrderedDict({
+                k.replace(module_str, ""): v for k, v in full_state_dict.items()
+                if k.startswith(module_str)
+            })
+            return module_state_dict
+
+        def load_state_dict_label_by_label(
+                module,
+                module_str,
+                labels_table_key,
+                label_encoder_pretrained,
+                state_dict_pretrained,
+                is_task=False,
+                exclude_params_regex=None):
             """Load pretrained state dict parameters iteratively for each label (vocab/classification labels...)
             This is done label by label as the label tables may differ between the pretrained and current model.
@@ -260,8 +279,15 @@ def load_state_dict_label_by_label(module, module_str, labels_table_key, label_e
                 is_task (bool): Set to True if the labels are stored in the dicts MultiLabelEncoder.tasks
                     (instead of direct attributes of MultiLabelEncoder)
             """
-            params_to_update = [p[0] for p in module.named_buffers()] + [p[0] for p in module.named_parameters()]
-
+            params_to_update = []
+            for param_tuples in [module.named_buffers(), module.named_parameters()]:
+                names = [
+                    p[0] for p in param_tuples
+                    if not (exclude_params_regex and re.search(exclude_params_regex, p[0]))
+                ]
+                if names:
+                    params_to_update.extend(names)
+
             if is_task:
                 labels_dict_pretrained = label_encoder_pretrained.tasks[labels_table_key].table
                 labels_dict_current = self.label_encoder.tasks[labels_table_key].table
@@ -315,13 +341,18 @@ def load_state_dict_label_by_label(module, module_str, labels_table_key, label_e
         total_updated, nb_labels_new_model, nb_pretrained_labels = load_state_dict_label_by_label(
             self.wemb, "wemb", "word", label_encoder_pretrained, state_dict_pretrained)
         print(f"Initialized {total_updated}/{nb_labels_new_model} word embs ({nb_pretrained_labels} in pretrained model)")
+        # Character embeddings (cemb.emb)
         total_updated, nb_labels_new_model, nb_pretrained_labels = load_state_dict_label_by_label(
             self.cemb.emb, "cemb.emb", "char", label_encoder_pretrained, state_dict_pretrained)
         print(f"Initialized {total_updated}/{nb_labels_new_model} char embs ({nb_pretrained_labels} in pretrained model)")

-        # Load state_dict of the character-level RNN
+        # Load state_dict of the character-level RNN encoder
         if "cemb_rnn" in model_parts_to_load:
-            self.cemb.rnn.load_state_dict(get_block_state_dict("cemb.rnn.", state_dict_pretrained))
+            self.cemb.rnn.load_state_dict(get_module_state_dict("cemb.rnn.", state_dict_pretrained))
+
+        # Load state_dict of the sentence-level RNN encoder
+        if "sent_rnn" in model_parts_to_load:
+            self.encoder.rnn.load_state_dict(get_module_state_dict("encoder.rnn.", state_dict_pretrained))

         # Load state_dict of the language model (lm)
         # This is done words by word as the vocabularies may differ between both models
@@ -330,11 +361,21 @@ def load_state_dict_label_by_label(module, module_str, labels_table_key, label_e
         if "lm" in model_parts_to_load:
             # fwd
             total_updated, nb_labels_new_model, nb_pretrained_labels = load_state_dict_label_by_label(
-                self.lm_fwd_decoder, "lm_fwd_decoder", "word", label_encoder_pretrained, state_dict_pretrained)
+                self.lm_fwd_decoder,
+                "lm_fwd_decoder",
+                "word",
+                label_encoder_pretrained,
+                state_dict_pretrained
+            )
             print(f"Initialized {total_updated}/{nb_labels_new_model} fwd LM word parameters ({nb_pretrained_labels} in pretrained model)")
             # bwd
             total_updated, nb_labels_new_model, nb_pretrained_labels = load_state_dict_label_by_label(
-                self.lm_bwd_decoder, "lm_bwd_decoder", "word", label_encoder_pretrained, state_dict_pretrained)
+                self.lm_bwd_decoder,
+                "lm_bwd_decoder",
+                "word",
+                label_encoder_pretrained,
+                state_dict_pretrained
+            )
             print(f"Initialized {total_updated}/{nb_labels_new_model} bwd LM word parameters ({nb_pretrained_labels} in pretrained model)")

         # Fill tasks-specific params - WARNING DEPENDS ON THE TYPE OF TASK=DECODER
@@ -356,8 +397,29 @@ def load_state_dict_label_by_label(module, module_str, labels_table_key, label_e
             if isinstance(self.decoders[tname], LinearDecoder):
                 total_updated, nb_labels_new_model, nb_pretrained_labels = load_state_dict_label_by_label(
-                    self.decoders[tname], tname+"_decoder", tname, label_encoder_pretrained, state_dict_pretrained,
-                    is_task=True)
+                    self.decoders[tname],
+                    tname+"_decoder",
+                    tname,
+                    label_encoder_pretrained,
+                    state_dict_pretrained,
+                    is_task=True
+                )
+                print(f"Initialized {total_updated}/{nb_labels_new_model} {tname} labels parameters ({nb_pretrained_labels} in pretrained model)")
+            elif isinstance(self.decoders[tname], AttentionalDecoder):
+                # Load label-related params iteratively
+                total_updated, nb_labels_new_model, nb_pretrained_labels = load_state_dict_label_by_label(
+                    self.decoders[tname],
+                    tname+"_decoder",
+                    tname,
+                    label_encoder_pretrained,
+                    state_dict_pretrained,
+                    is_task=True,
+                    exclude_params_regex="^rnn"
+                )
                 print(f"Initialized {total_updated}/{nb_labels_new_model} {tname} labels parameters ({nb_pretrained_labels} in pretrained model)")
+                # Load decoder RNN and attn params in block
+                self.decoders[tname].rnn.load_state_dict(get_module_state_dict(f"{tname}_decoder.rnn.", state_dict_pretrained))
+                self.decoders[tname].attn.load_state_dict(get_module_state_dict(f"{tname}_decoder.attn.", state_dict_pretrained))
             else:
                 raise NotImplementedError(f"Can only load decoder parameters for tasks with Linear Decoders (found {type(self.decoders[tname])})")
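For reference, the two mechanisms the patch relies on can be illustrated with a minimal, self-contained sketch: prefix filtering of `state_dict` keys so a block can be loaded directly into a submodule (what `get_module_state_dict` does), and regex-based exclusion of parameter names (what `exclude_params_regex="^rnn"` does for attentional decoders). The `Toy` module below is a hypothetical stand-in, not code from this repository:

```python
import re
from collections import OrderedDict

import torch.nn as nn


def get_module_state_dict(module_str, full_state_dict):
    # Keep only the entries whose key starts with `module_str` and strip that
    # prefix, so the result can be loaded directly into the submodule.
    return OrderedDict({
        k.replace(module_str, ""): v for k, v in full_state_dict.items()
        if k.startswith(module_str)
    })


class Toy(nn.Module):
    # Hypothetical model standing in for e.g. a cemb-like block (embedding + RNN).
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(10, 8)
        self.rnn = nn.GRU(8, 8)


pretrained, current = Toy(), Toy()

# Extract only the RNN weights of the pretrained model ("rnn." prefix) and
# load them into the matching submodule of the current model.
current.rnn.load_state_dict(get_module_state_dict("rnn.", pretrained.state_dict()))

# Regex-based exclusion: keep only parameter names that do NOT match "^rnn",
# mirroring exclude_params_regex="^rnn" so label-by-label loading skips the RNN.
label_params = [name for name, _ in pretrained.named_parameters()
                if not re.search("^rnn", name)]
print(label_params)  # e.g. ['emb.weight']
```

The prefix stripping matters because a full state_dict can only be loaded at the root of a model: to call `load_state_dict` on a submodule such as `cemb.rnn`, the keys must be expressed relative to that submodule, which is exactly what the helper produces.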