Fix MoLFormer for finetuning (deepchem#4195)
* fix finetuning test

* add test for random weight initialization

* add test for multitask regression

* fix molformer tests

* add docstrings
riya-singh28 authored Dec 17, 2024
1 parent 348ca63 commit 216faf6
Showing 2 changed files with 213 additions and 24 deletions.
84 changes: 60 additions & 24 deletions deepchem/models/torch_models/molformer.py
@@ -1,5 +1,10 @@
 from deepchem.models.torch_models.hf_models import HuggingFaceModel
 from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig, AutoModelForSequenceClassification
+try:
+    import torch
+    has_torch = True
+except:
+    has_torch = False
 
 
 class MoLFormer(HuggingFaceModel):
@@ -74,42 +79,73 @@ def __init__(self,
         tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                                   trust_remote_code=True)
         molformer_config = AutoConfig.from_pretrained(
-            "ibm/MoLFormer-XL-both-10pct",
-            deterministic_eval=True,
-            trust_remote_code=True)
+            "ibm/MoLFormer-XL-both-10pct", trust_remote_code=True)
         if task == 'mlm':
             model = AutoModelForMaskedLM.from_config(molformer_config,
                                                      trust_remote_code=True)
         elif task == 'mtr':
-            problem_type = 'regression'
-            model = AutoModelForSequenceClassification.from_pretrained(
-                "ibm/MoLFormer-XL-both-10pct",
-                problem_type=problem_type,
-                num_labels=n_tasks,
-                deterministic_eval=True,
-                trust_remote_code=True)
+            molformer_config.problem_type = 'regression'
+            molformer_config.num_labels = n_tasks
+            model = AutoModelForSequenceClassification.from_config(
+                config=molformer_config, trust_remote_code=True)
         elif task == 'regression':
-            problem_type = 'regression'
-            model = AutoModelForSequenceClassification.from_pretrained(
-                "ibm/MoLFormer-XL-both-10pct",
-                problem_type=problem_type,
-                num_labels=n_tasks,
-                deterministic_eval=True,
-                trust_remote_code=True)
+            molformer_config.problem_type = 'regression'
+            molformer_config.num_labels = n_tasks
+            model = AutoModelForSequenceClassification.from_config(
+                config=molformer_config, trust_remote_code=True)
         elif task == 'classification':
             if n_tasks == 1:
-                problem_type = 'single_label_classification'
+                molformer_config.problem_type = 'single_label_classification'
             else:
-                problem_type = 'multi_label_classification'
-            model = AutoModelForSequenceClassification.from_pretrained(
-                "ibm/MoLFormer-XL-both-10pct",
-                problem_type=problem_type,
-                deterministic_eval=True,
-                trust_remote_code=True)
+                molformer_config.num_labels = n_tasks
+                molformer_config.problem_type = 'multi_label_classification'
+            model = AutoModelForSequenceClassification.from_config(
+                molformer_config, trust_remote_code=True)
         else:
             raise ValueError('invalid task specification')
 
         super(MoLFormer, self).__init__(model=model,
                                         task=task,
                                         tokenizer=tokenizer,
                                         **kwargs)
+
+    def _prepare_batch(self, batch):
+        """
+        Prepares a batch of data for the model based on the specified task. It
+        overrides the parent class's _prepare_batch for the following conditions:
+        - When n_tasks == 1 and task == 'classification', CrossEntropyLoss is
+          used, which expects targets as long integers.
+        - When n_tasks > 1 and task == 'classification', BCEWithLogitsLoss is
+          used, which expects targets as floats.
+        """
+        smiles_batch, y, w = batch
+        tokens = self.tokenizer(smiles_batch[0].tolist(),
+                                padding=True,
+                                return_tensors="pt")
+
+        if self.task == 'mlm':
+            inputs, labels = self.data_collator.torch_mask_tokens(
+                tokens['input_ids'])
+            inputs = {
+                'input_ids': inputs.to(self.device),
+                'labels': labels.to(self.device),
+                'attention_mask': tokens['attention_mask'].to(self.device),
+            }
+            return inputs, None, w
+        elif self.task in ['regression', 'classification', 'mtr']:
+            if y is not None:
+                # y is None during predict
+                y = torch.from_numpy(y[0])
+                if self.task == 'regression' or self.task == 'mtr':
+                    y = y.float().to(self.device)
+                elif self.task == 'classification':
+                    if self.n_tasks == 1:
+                        y = y.long().to(self.device)
+                    else:
+                        y = y.float().to(self.device)
+            for key, value in tokens.items():
+                tokens[key] = value.to(self.device)
+
+            inputs = {**tokens, 'labels': y}
+            return inputs, y, w
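Note: the dtype casts in `_prepare_batch` mirror the contracts of the underlying PyTorch losses. A minimal standalone sketch of that contract (illustrative only, not part of this commit):

import torch

# Single-label classification (n_tasks == 1): CrossEntropyLoss expects
# class indices with dtype long.
ce = torch.nn.CrossEntropyLoss()
logits = torch.randn(4, 3)           # (batch, num_classes)
labels = torch.tensor([0, 2, 1, 0])  # long dtype required
print(ce(logits, labels))

# Multi-label classification (n_tasks > 1): BCEWithLogitsLoss expects
# float targets with one column per task, same shape as the logits.
bce = torch.nn.BCEWithLogitsLoss()
logits = torch.randn(4, 2)           # (batch, n_tasks)
targets = torch.tensor([[1., 0.], [0., 1.], [1., 1.], [0., 0.]])
print(bce(logits, targets))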
153 changes: 153 additions & 0 deletions deepchem/models/torch_models/tests/test_molformer.py
@@ -1,4 +1,5 @@
 import os
+import random
 import deepchem as dc
 import numpy as np
 import pytest
@@ -113,3 +114,155 @@ def test_molformer_save_reload(tmpdir):
 
     # all keys values should match
     assert all(matches)
+
+
+@pytest.mark.hf
+def test_molformer_finetuning_multitask_classification():
+    # test multitask classification
+    loader = dc.molnet.load_clintox(featurizer=dc.feat.DummyFeaturizer())
+    tasks, dataset, transformers = loader
+    train, val, test = dataset
+    train_sample = train.select(range(10))
+    test_sample = test.select(range(10))
+
+    model = MoLFormer(task='classification', n_tasks=len(tasks))
+    loss = model.fit(train_sample, nb_epoch=1)
+    eval_score = model.evaluate(test_sample,
+                                metrics=dc.metrics.Metric(
+                                    dc.metrics.roc_auc_score))
+    assert eval_score, loss
+    prediction = model.predict(test_sample)
+    # logit scores
+    assert prediction.shape == (test_sample.y.shape[0], len(tasks))
+
+
+@pytest.mark.hf
+def test_molformer_finetuning_multitask_regression():
+    # test multitask regression
+
+    cwd = os.path.dirname(os.path.abspath(__file__))
+    input_file = os.path.join(cwd,
+                              '../../tests/assets/multitask_regression.csv')
+
+    loader = dc.data.CSVLoader(tasks=['task0', 'task1'],
+                               feature_field='smiles',
+                               featurizer=dc.feat.DummyFeaturizer())
+    dataset = loader.create_dataset(input_file)
+    model = MoLFormer(task='regression', n_tasks=2)
+    loss = model.fit(dataset, nb_epoch=1)
+    eval_score = model.evaluate(dataset,
+                                metrics=dc.metrics.Metric(
+                                    dc.metrics.mean_absolute_error))
+
+    assert loss, eval_score
+    prediction = model.predict(dataset)
+    assert prediction.shape == dataset.y.shape
+
+
+def set_seed(seed=42):
+    """Seed all RNGs so that model weight initialization is reproducible."""
+    random.seed(seed)  # Python random module
+    np.random.seed(seed)  # NumPy
+    torch.manual_seed(seed)  # PyTorch CPU
+    torch.cuda.manual_seed(seed)  # PyTorch GPU
+    torch.cuda.manual_seed_all(seed)  # All GPUs (if using multi-GPU)
+
+    # Ensures reproducibility in convolution operations
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+@pytest.mark.hf
+def test_random_weight_initialization_regression():
+    # test model initialization for regression
+
+    set_seed(10)
+    model1 = MoLFormer(task='regression', n_tasks=2)
+
+    set_seed(25)
+    model2 = MoLFormer(task='regression', n_tasks=2)
+
+    model1_state_dict = model1.model.state_dict()
+    model2_state_dict = model2.model.state_dict()
+
+    model1_keys = [
+        key for key in model1_state_dict.keys() if 'molformer' in key
+    ]
+    matches = [
+        torch.allclose(model1_state_dict[key], model2_state_dict[key])
+        for key in model1_keys
+    ]
+    print(matches)
+    assert not all(matches)
+
+
+@pytest.mark.hf
+def test_random_weight_initialization_mlm():
+    # test model initialization for mlm
+
+    set_seed(10)
+    model1 = MoLFormer(task='mlm', n_tasks=2)
+
+    set_seed(25)
+    model2 = MoLFormer(task='mlm', n_tasks=2)
+
+    model1_state_dict = model1.model.state_dict()
+    model2_state_dict = model2.model.state_dict()
+
+    model1_keys = [
+        key for key in model1_state_dict.keys() if 'molformer' in key
+    ]
+    matches = [
+        torch.allclose(model1_state_dict[key], model2_state_dict[key])
+        for key in model1_keys
+    ]
+    print(matches)
+    assert not all(matches)
+
+
+@pytest.mark.hf
+def test_random_weight_initialization_mtr():
+    # test model initialization for mtr
+
+    set_seed(10)
+    model1 = MoLFormer(task='mtr', n_tasks=2)
+
+    set_seed(25)
+    model2 = MoLFormer(task='mtr', n_tasks=2)
+
+    model1_state_dict = model1.model.state_dict()
+    model2_state_dict = model2.model.state_dict()
+
+    model1_keys = [
+        key for key in model1_state_dict.keys() if 'molformer' in key
+    ]
+    matches = [
+        torch.allclose(model1_state_dict[key], model2_state_dict[key])
+        for key in model1_keys
+    ]
+    print(matches)
+    assert not all(matches)
+
+
+@pytest.mark.hf
+def test_random_weight_initialization_classification():
+    # test model initialization for classification
+
+    set_seed(10)
+    model1 = MoLFormer(task='classification', n_tasks=2)
+
+    set_seed(25)
+    model2 = MoLFormer(task='classification', n_tasks=2)
+
+    model1_state_dict = model1.model.state_dict()
+    model2_state_dict = model2.model.state_dict()
+
+    model1_keys = [
+        key for key in model1_state_dict.keys() if 'molformer' in key
+    ]
+    matches = [
+        torch.allclose(model1_state_dict[key], model2_state_dict[key])
+        for key in model1_keys
+    ]
+    print(matches)
+    assert not all(matches)
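Note: these tests rely on the main change above, the switch from `from_pretrained` to `from_config`. `from_config` builds a model with freshly initialized, seed-dependent weights, while `from_pretrained` loads the published checkpoint regardless of seed. A minimal sketch of the distinction (illustrative only, not part of this commit; requires access to the Hugging Face Hub):

from transformers import AutoConfig, AutoModelForSequenceClassification

config = AutoConfig.from_pretrained("ibm/MoLFormer-XL-both-10pct",
                                    trust_remote_code=True)
config.problem_type = 'regression'
config.num_labels = 2

# Randomly initialized weights: two different seeds yield different
# parameters, which is what test_random_weight_initialization_* asserts.
random_model = AutoModelForSequenceClassification.from_config(
    config, trust_remote_code=True)

# Pretrained weights: identical no matter the seed, so finetuning starts
# from the published checkpoint.
pretrained_model = AutoModelForSequenceClassification.from_pretrained(
    "ibm/MoLFormer-XL-both-10pct", problem_type='regression',
    num_labels=2, trust_remote_code=True)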
