Commit
Fix chemberta for Multitask Classification (deepchem#4194)
* fix chemberta code

* fix finetuning test

* add test for multitask regression

* add docstrings
riya-singh28 authored Dec 17, 2024
1 parent a855d16 commit 348ca63
Showing 2 changed files with 92 additions and 1 deletion.
50 changes: 49 additions & 1 deletion deepchem/models/torch_models/chemberta.py
@@ -1,10 +1,15 @@
-from typing import Dict, Any
+from typing import Dict, Any, Tuple
 from deepchem.models.torch_models.hf_models import HuggingFaceModel
 from transformers.models.roberta.modeling_roberta import (
     RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification)
 from transformers.models.roberta.tokenization_roberta_fast import \
     RobertaTokenizerFast
 from transformers.modeling_utils import PreTrainedModel
+try:
+    import torch
+    has_torch = True
+except ModuleNotFoundError:
+    has_torch = False
 
 
 class Chemberta(HuggingFaceModel):
@@ -107,6 +112,7 @@ def __init__(self,
                 chemberta_config.problem_type = 'single_label_classification'
             else:
                 chemberta_config.problem_type = 'multi_label_classification'
+                chemberta_config.num_labels = n_tasks
             model = RobertaForSequenceClassification(chemberta_config)
         else:
             raise ValueError('invalid task specification')
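
For reference, a minimal standalone sketch (not part of this commit; config values are illustrative) of how problem_type steers the loss inside HuggingFace's RobertaForSequenceClassification: with 'multi_label_classification', the forward pass applies BCEWithLogitsLoss to float multi-hot labels.

import torch
from transformers.models.roberta.modeling_roberta import (
    RobertaConfig, RobertaForSequenceClassification)

# Illustrative config: 3 tasks treated as 3 labels, analogous to the
# commit's chemberta_config.num_labels = n_tasks.
config = RobertaConfig(num_labels=3,
                       problem_type='multi_label_classification')
model = RobertaForSequenceClassification(config)

input_ids = torch.randint(0, config.vocab_size, (2, 8))  # dummy token ids
labels = torch.tensor([[0., 1., 1.], [1., 0., 0.]])  # float multi-hot labels
outputs = model(input_ids=input_ids, labels=labels)
print(outputs.loss)  # computed internally with BCEWithLogitsLoss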
@@ -115,3 +121,45 @@ def __init__(self,
                                         task=task,
                                         tokenizer=tokenizer,
                                         **kwargs)

+    def _prepare_batch(self, batch: Tuple[Any, Any, Any]):
+        """Prepares a batch of data for the model based on the specified task.
+
+        It overrides _prepare_batch of the parent class because the label
+        dtype depends on the task:
+
+        - When n_tasks == 1 and task == 'classification', CrossEntropyLoss
+          is used, which expects labels as long integers.
+        - When n_tasks > 1 and task == 'classification', BCEWithLogitsLoss
+          is used, which expects labels as floats.
+        """
+
+        smiles_batch, y, w = batch
+        tokens = self.tokenizer(smiles_batch[0].tolist(),
+                                padding=True,
+                                return_tensors="pt")
+
+        if self.task == 'mlm':
+            inputs, labels = self.data_collator.torch_mask_tokens(
+                tokens['input_ids'])
+            inputs = {
+                'input_ids': inputs.to(self.device),
+                'labels': labels.to(self.device),
+                'attention_mask': tokens['attention_mask'].to(self.device),
+            }
+            return inputs, None, w
+        elif self.task in ['regression', 'classification', 'mtr']:
+            if y is not None:
+                # y is None during predict
+                y = torch.from_numpy(y[0])
+                if self.task == 'regression' or self.task == 'mtr':
+                    y = y.float().to(self.device)
+                elif self.task == 'classification':
+                    if self.n_tasks == 1:
+                        y = y.long().to(self.device)
+                    else:
+                        y = y.float().to(self.device)
+            for key, value in tokens.items():
+                tokens[key] = value.to(self.device)
+
+            inputs = {**tokens, 'labels': y}
+            return inputs, y, w
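
The dtype handling above mirrors what the underlying PyTorch losses require; a quick standalone check (illustrative, not from the commit):

import torch

logits = torch.randn(4, 2)

# Single-task classification: CrossEntropyLoss expects long class indices.
ce = torch.nn.CrossEntropyLoss()
print(ce(logits, torch.tensor([0, 1, 1, 0])))

# Multitask classification: BCEWithLogitsLoss expects float targets per task.
bce = torch.nn.BCEWithLogitsLoss()
print(bce(logits, torch.tensor([[0., 1.], [1., 0.], [1., 1.], [0., 0.]])))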
43 changes: 43 additions & 0 deletions deepchem/models/torch_models/tests/test_chemberta.py
@@ -136,3 +136,46 @@ def test_chemberta_load_weights_from_hf_hub():
     # new model's model attribute is an entirely new model instantiated by
     # AutoModel.load_from_pretrained and hence it should have a different identifier
     assert old_model_id != new_model_id


+@pytest.mark.hf
+def test_chemberta_finetuning_multitask_classification():
+    # test multitask classification
+    loader = dc.molnet.load_clintox(featurizer=dc.feat.DummyFeaturizer())
+    tasks, dataset, transformers = loader
+    train, val, test = dataset
+
+    train_sample = train.select(range(10))
+    test_sample = test.select(range(10))
+    model = Chemberta(task='classification', n_tasks=len(tasks))
+    loss = model.fit(train_sample, nb_epoch=1)
+    eval_score = model.evaluate(test_sample,
+                                metrics=dc.metrics.Metric(
+                                    dc.metrics.roc_auc_score))
+    assert loss
+    assert eval_score
+    prediction = model.predict(test_sample)
+    # predictions are logit scores, one per task
+    assert prediction.shape == (test_sample.y.shape[0], len(tasks))
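
Since predict returns raw logits in the multi-label case, a sigmoid maps them to per-task probabilities; a possible post-processing step (illustrative, not part of the test):

import numpy as np

# prediction has shape (n_samples, n_tasks) and holds logits
probabilities = 1.0 / (1.0 + np.exp(-prediction))  # elementwise sigmoid
predicted_labels = (probabilities > 0.5).astype(int)  # one binary label per task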


+@pytest.mark.hf
+def test_chemberta_finetuning_multitask_regression():
+    # test multitask regression
+
+    cwd = os.path.dirname(os.path.abspath(__file__))
+    input_file = os.path.join(cwd,
+                              '../../tests/assets/multitask_regression.csv')
+
+    loader = dc.data.CSVLoader(tasks=['task0', 'task1'],
+                               feature_field='smiles',
+                               featurizer=dc.feat.DummyFeaturizer())
+    dataset = loader.create_dataset(input_file)
+    model = Chemberta(task='regression', n_tasks=2)
+    loss = model.fit(dataset, nb_epoch=1)
+    eval_score = model.evaluate(dataset,
+                                metrics=dc.metrics.Metric(
+                                    dc.metrics.mean_absolute_error))
+
+    assert loss
+    assert eval_score
+    prediction = model.predict(dataset)
+    assert prediction.shape == dataset.y.shape
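
For context, a CSV in the shape this CSVLoader call expects would look like the following (column names taken from the test; SMILES strings and values purely illustrative, not the contents of the actual asset file):

smiles,task0,task1
CCO,0.21,-1.37
c1ccccc1,1.05,0.44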
