fix mtr finetuning (deepchem#4143)
riya-singh28 authored Oct 17, 2024
1 parent ec22084 commit c614f9e
Showing 2 changed files with 38 additions and 0 deletions.
25 changes: 25 additions & 0 deletions deepchem/models/torch_models/hf_models.py
@@ -189,6 +189,19 @@ def load_from_pretrained( # type: ignore
>>> finetune_model = HuggingFaceModel(model=model, task='classification', tokenizer=tokenizer, model_dir='model-dir')
>>> finetune_model.load_from_pretrained()
Note
----
1. Use the `load_from_pretrained` method only to load a pretrained model - a
model trained on a different task, such as Masked Language Modeling or
Multitask Regression. To restore a model, use the `restore` method.
2. A pretrained model and a finetuned model can have different numbers of
target tasks, and hence different numbers of projection outputs in the
last layer. To avoid a mismatch in the weights of the output projection
layer (last layer) between the pretrained model and the current model,
we delete the projection layer's weights from the loaded state dict.
"""
if model_dir is None:
model_dir = self.model_dir
@@ -212,6 +225,18 @@ def load_from_pretrained( # type: ignore
else:
checkpoint = checkpoints[0]
data = torch.load(checkpoint, map_location=self.device)
# Delete the keys of the output projection layer (last layer), as the
# number of tasks (projections) in the pretrained model and the current
# model might differ.
keys = data['model_state_dict'].keys()
if 'classifier.out_proj.weight' in keys:
del data['model_state_dict']['classifier.out_proj.weight']
if 'classifier.out_proj.bias' in keys:
del data['model_state_dict']['classifier.out_proj.bias']
if 'classifier.dense.bias' in keys:
del data['model_state_dict']['classifier.dense.bias']
if 'classifier.dense.weight' in keys:
del data['model_state_dict']['classifier.dense.weight']
self.model.load_state_dict(data['model_state_dict'],
strict=False)

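The key-deletion logic above follows a general PyTorch pattern: when a checkpoint's output head has a different shape than the current model's head, drop the head's entries from the state dict and load the remainder with strict=False. A minimal standalone sketch of that pattern (the toy Head module below is only an illustration, not part of this commit or of DeepChem):

import torch.nn as nn

class Head(nn.Module):
    """Toy projection head; n_tasks controls the output width."""

    def __init__(self, n_tasks):
        super().__init__()
        self.dense = nn.Linear(8, 8)
        self.out_proj = nn.Linear(8, n_tasks)

# "Pretrained" head with 10 tasks.
state = Head(n_tasks=10).state_dict()

# Drop the output projection entries whose shapes will not match.
for key in ('out_proj.weight', 'out_proj.bias'):
    state.pop(key, None)

# strict=False tolerates the missing keys; the remaining weights load.
missing, unexpected = Head(n_tasks=20).load_state_dict(state, strict=False)
print(missing)     # ['out_proj.weight', 'out_proj.bias']
print(unexpected)  # []
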
13 changes: 13 additions & 0 deletions deepchem/models/torch_models/tests/test_hf_models.py
@@ -266,3 +266,16 @@ def test_fill_mask_fidelity(tmpdir, hf_tokenizer):

# Test 3. Check that the infilling went to the right spot
assert filled['sequence'].startswith(f'<s>{filled["token_str"]}')


@pytest.mark.torch
def test_load_from_pretrained_with_diff_task(tmpdir):
# Tests loading a pretrained model where the weight shape of the last layer
# (the final projection layer) of the pretrained model does not match
# the weight shape in the new model.
from deepchem.models.torch_models import Chemberta
model = Chemberta(task='mtr', n_tasks=10, model_dir=tmpdir)
model.save_checkpoint()

model = Chemberta(task='regression', n_tasks=20)
model.load_from_pretrained(model_dir=tmpdir)
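
For context, the finetuning workflow this test exercises looks roughly like the sketch below; the commented-out fit calls and the dataset names are assumptions for illustration, not part of this commit.

from deepchem.models.torch_models import Chemberta

# Pretrain on multitask regression (mtr) and save a checkpoint.
pretrain_model = Chemberta(task='mtr', n_tasks=10, model_dir='pretrain-dir')
# pretrain_model.fit(pretrain_dataset)  # assumes a featurized pretraining dataset
pretrain_model.save_checkpoint()

# Finetune on a task with a different number of targets; the mismatched
# projection-head weights are dropped and re-initialized by load_from_pretrained.
finetune_model = Chemberta(task='regression', n_tasks=20)
finetune_model.load_from_pretrained(model_dir='pretrain-dir')
# finetune_model.fit(finetune_dataset)  # assumes a featurized finetuning dataset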
