add model selection based on validation loss
tigranfah committed Jun 8, 2024
1 parent 5d7a258 commit be4096d
Showing 3 changed files with 30 additions and 32 deletions.
7 changes: 1 addition & 6 deletions chemlactica/mol_opt/optimization.py
@@ -12,7 +12,7 @@
 import numpy as np
 from transformers import OPTForCausalLM
 from chemlactica.mol_opt.utils import OptimEntry, MoleculeEntry, Pool, generate_random_number, tanimoto_dist_func
-from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler, CustomEarlyStopCallback
+from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler
 
 
 def create_similar_mol_entries(pool, mol_entry, num_similars):
@@ -209,10 +209,6 @@ def optimize(
     validation_dataset.shuffle(seed=42)
 
     model.train()
-    early_stopping_callback = CustomEarlyStopCallback(
-        early_stopping_patience=1,
-        early_stopping_threshold=0.0001
-    )
     trainer = SFTTrainer(
         model=model,
         train_dataset=train_dataset,
@@ -223,7 +219,6 @@
         tokenizer=tokenizer,
         max_seq_length=config["rej_sample_config"]["max_seq_length"],
         # data_collator=collator,
-        callbacks=[early_stopping_callback],
         optimizers=[optimizer, lr_scheduler],
     )
     trainer.train()
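With the callback gone, the fine-tuning step in optimize() is a plain SFTTrainer run; checkpoint selection is delegated entirely to the training arguments produced by get_training_arguments. A minimal sketch of the resulting call, assuming the same config keys as above; the eval_dataset/args wiring and the run_finetuning wrapper are illustrative, since those lines are collapsed in the diff:

from trl import SFTTrainer

def run_finetuning(model, tokenizer, train_dataset, validation_dataset,
                   training_args, optimizer, lr_scheduler, config):
    # No early-stopping callback is registered any more; the Trainer saves a
    # checkpoint per epoch and reloads the best one (lowest eval_loss) at the
    # end, as configured in get_training_arguments.
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        args=training_args,
        tokenizer=tokenizer,
        max_seq_length=config["rej_sample_config"]["max_seq_length"],
        optimizers=(optimizer, lr_scheduler),
    )
    trainer.train()
    return trainer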
51 changes: 26 additions & 25 deletions chemlactica/mol_opt/tunning.py
@@ -8,30 +8,30 @@
 from chemlactica.mol_opt.utils import generate_random_number
 
 
-class CustomEarlyStopCallback(TrainerCallback):
+# class CustomEarlyStopCallback(TrainerCallback):
 
-    def __init__(self, early_stopping_patience: int, early_stopping_threshold: float) -> None:
-        super().__init__()
-        self.best_valid_loss = math.inf
-        self.early_stopping_patience = early_stopping_patience
-        self.current_patiance = 0
-        self.early_stopping_threshold = early_stopping_threshold
+#     def __init__(self, early_stopping_patience: int, early_stopping_threshold: float) -> None:
+#         super().__init__()
+#         self.best_valid_loss = math.inf
+#         self.early_stopping_patience = early_stopping_patience
+#         self.current_patiance = 0
+#         self.early_stopping_threshold = early_stopping_threshold
 
-    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
-        self.best_valid_loss = math.inf
-        self.current_patiance = 0
-        return super().on_train_begin(args, state, control, **kwargs)
+#     def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+#         self.best_valid_loss = math.inf
+#         self.current_patiance = 0
+#         return super().on_train_begin(args, state, control, **kwargs)
 
-    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs):
-        if metrics["eval_loss"] >= self.best_valid_loss - self.early_stopping_threshold:
-            self.current_patiance += 1
-        else:
-            self.current_patiance = 0
-            self.best_valid_loss = metrics["eval_loss"]
-        print(f"Early Stopping patiance: {self.current_patiance}/{self.early_stopping_patience}")
-        if self.current_patiance >= self.early_stopping_patience:
-            control.should_training_stop = True
-        return super().on_evaluate(args, state, control, **kwargs)
+#     def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs):
+#         if metrics["eval_loss"] >= self.best_valid_loss - self.early_stopping_threshold:
+#             self.current_patiance += 1
+#         else:
+#             self.current_patiance = 0
+#             self.best_valid_loss = metrics["eval_loss"]
+#         print(f"Early Stopping patiance: {self.current_patiance}/{self.early_stopping_patience}")
+#         if self.current_patiance >= self.early_stopping_patience:
+#             control.should_training_stop = True
+#         return super().on_evaluate(args, state, control, **kwargs)
 
 
 # class CustomSFTTrainer(SFTTrainer):
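The callback above reimplemented patience-based early stopping by hand before being commented out. As a side note, not part of this commit, transformers ships an equivalent built-in EarlyStoppingCallback; a sketch of how it could be wired up if early stopping were ever wanted again (it needs load_best_model_at_end=True and metric_for_best_model set, which this commit now configures):

from transformers import EarlyStoppingCallback

# Stops training when eval_loss fails to improve by more than
# early_stopping_threshold for early_stopping_patience consecutive evaluations,
# mirroring the custom callback above.
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=1,
    early_stopping_threshold=0.0001,
)
# Passed to the trainer via callbacks=[early_stopping].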
@@ -60,23 +60,24 @@
 
 
 def get_training_arguments(config):
-    checkpoints_dir = config["checkpoints_dir"] + "_" + str(time.time())
+    checkpoints_dir = f"{config['checkpoints_dir']}/checkpoint-{time.time():.4f}"
     return TrainingArguments(
         output_dir=checkpoints_dir,
         per_device_train_batch_size=config["train_batch_size"],
         per_device_eval_batch_size=config["train_batch_size"],
         max_grad_norm=config["global_gradient_norm"],
         num_train_epochs=config["num_train_epochs"],
         evaluation_strategy="epoch",
-        # save_strategy="epoch",
+        save_strategy="epoch",
         dataloader_drop_last=False,
         dataloader_pin_memory=True,
         dataloader_num_workers=config["dataloader_num_workers"],
         gradient_accumulation_steps=config["gradient_accumulation_steps"],
         logging_steps=1,
         save_safetensors=False,
         metric_for_best_model="loss",
-        # load_best_model_at_end=True,
-        # save_total_limit=1
+        load_best_model_at_end=True,
+        save_total_limit=1
     )


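Taken together, these arguments implement the model selection named in the commit message: the trainer computes the validation loss once per epoch, saves a checkpoint on the same schedule, tracks the checkpoint with the lowest eval_loss, and reloads it when training finishes. A condensed sketch of just the selection-relevant fields, with an illustrative output_dir and the config-driven fields omitted:

from transformers import TrainingArguments

selection_args = TrainingArguments(
    output_dir="checkpoints/example-run",  # illustrative path
    evaluation_strategy="epoch",   # compute eval_loss at the end of every epoch
    save_strategy="epoch",         # save a checkpoint on the same schedule
    metric_for_best_model="loss",  # resolved to eval_loss; treated as lower-is-better
    load_best_model_at_end=True,   # reload the best checkpoint after the last epoch
    save_total_limit=1,            # prune old checkpoints; the best one is always retained
)

save_strategy has to match evaluation_strategy for load_best_model_at_end to work, which is why the previously commented-out save_strategy line is re-enabled here.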
4 changes: 3 additions & 1 deletion chemlactica/mol_opt/utils.py
@@ -223,7 +223,9 @@ def get_train_valid_entries(self):
         return train_entries, valid_entries
 
     def random_subset(self, subset_size):
-        rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size * 2))
+        # rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size * 2))
+        rand_inds = np.random.permutation(len(self.optim_entries))
+        rand_inds = rand_inds[:subset_size]
         return [self.optim_entries[i] for i in rand_inds]
 
     def __len__(self):
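The new random_subset permutes the whole pool and then slices, so each call draws exactly subset_size entries uniformly without replacement from all optimization entries, whereas the old version returned a permutation of only the first min(len(pool), 2 * subset_size) entries. A standalone sketch of the new sampling behaviour on illustrative data:

import numpy as np

pool = [f"entry_{i}" for i in range(10)]  # stand-in for self.optim_entries
subset_size = 3

# Permute all indices, then keep the first subset_size of them:
# a uniform random sample without replacement from the entire pool.
rand_inds = np.random.permutation(len(pool))[:subset_size]
subset = [pool[i] for i in rand_inds]
print(subset)  # e.g. ['entry_7', 'entry_2', 'entry_9']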
