diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py
index d2bf603..7457614 100644
--- a/chemlactica/mol_opt/optimization.py
+++ b/chemlactica/mol_opt/optimization.py
@@ -12,7 +12,7 @@
 import numpy as np
 from transformers import OPTForCausalLM
 from chemlactica.mol_opt.utils import OptimEntry, MoleculeEntry, Pool, generate_random_number, tanimoto_dist_func
-from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler, CustomEarlyStopCallback
+from chemlactica.mol_opt.tunning import get_training_arguments, get_optimizer_and_lr_scheduler
 
 
 def create_similar_mol_entries(pool, mol_entry, num_similars):
@@ -209,10 +209,6 @@ def optimize(
     validation_dataset.shuffle(seed=42)
 
     model.train()
-    early_stopping_callback = CustomEarlyStopCallback(
-        early_stopping_patience=1,
-        early_stopping_threshold=0.0001
-    )
     trainer = SFTTrainer(
         model=model,
         train_dataset=train_dataset,
@@ -223,7 +219,6 @@
         tokenizer=tokenizer,
         max_seq_length=config["rej_sample_config"]["max_seq_length"],
         # data_collator=collator,
-        callbacks=[early_stopping_callback],
         optimizers=[optimizer, lr_scheduler],
     )
     trainer.train()
diff --git a/chemlactica/mol_opt/tunning.py b/chemlactica/mol_opt/tunning.py
index fc7d0ed..eeac662 100644
--- a/chemlactica/mol_opt/tunning.py
+++ b/chemlactica/mol_opt/tunning.py
@@ -8,30 +8,30 @@
 from chemlactica.mol_opt.utils import generate_random_number
 
 
-class CustomEarlyStopCallback(TrainerCallback):
+# class CustomEarlyStopCallback(TrainerCallback):
 
-    def __init__(self, early_stopping_patience: int, early_stopping_threshold: float) -> None:
-        super().__init__()
-        self.best_valid_loss = math.inf
-        self.early_stopping_patience = early_stopping_patience
-        self.current_patiance = 0
-        self.early_stopping_threshold = early_stopping_threshold
+#     def __init__(self, early_stopping_patience: int, early_stopping_threshold: float) -> None:
+#         super().__init__()
+#         self.best_valid_loss = math.inf
+#         self.early_stopping_patience = early_stopping_patience
+#         self.current_patiance = 0
+#         self.early_stopping_threshold = early_stopping_threshold
 
-    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
-        self.best_valid_loss = math.inf
-        self.current_patiance = 0
-        return super().on_train_begin(args, state, control, **kwargs)
+#     def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+#         self.best_valid_loss = math.inf
+#         self.current_patiance = 0
+#         return super().on_train_begin(args, state, control, **kwargs)
 
-    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs):
-        if metrics["eval_loss"] >= self.best_valid_loss - self.early_stopping_threshold:
-            self.current_patiance += 1
-        else:
-            self.current_patiance = 0
-            self.best_valid_loss = metrics["eval_loss"]
-        print(f"Early Stopping patiance: {self.current_patiance}/{self.early_stopping_patience}")
-        if self.current_patiance >= self.early_stopping_patience:
-            control.should_training_stop = True
-        return super().on_evaluate(args, state, control, **kwargs)
+#     def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs):
+#         if metrics["eval_loss"] >= self.best_valid_loss - self.early_stopping_threshold:
+#             self.current_patiance += 1
+#         else:
+#             self.current_patiance = 0
+#             self.best_valid_loss = metrics["eval_loss"]
+#         print(f"Early Stopping patiance: {self.current_patiance}/{self.early_stopping_patience}")
+#         if self.current_patiance >= self.early_stopping_patience:
+#             control.should_training_stop = True
+#         return super().on_evaluate(args, state, control, **kwargs)
 
 
 # class CustomSFTTrainer(SFTTrainer):
@@ -60,7 +60,7 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
 
 
 def get_training_arguments(config):
-    checkpoints_dir = config["checkpoints_dir"] + "_" + str(time.time())
+    checkpoints_dir = f"{config['checkpoints_dir']}/checkpoint-{time.time():.4f}"
     return TrainingArguments(
         output_dir=checkpoints_dir,
         per_device_train_batch_size=config["train_batch_size"],
@@ -68,15 +68,16 @@
         max_grad_norm=config["global_gradient_norm"],
         num_train_epochs=config["num_train_epochs"],
         evaluation_strategy="epoch",
-        # save_strategy="epoch",
+        save_strategy="epoch",
         dataloader_drop_last=False,
         dataloader_pin_memory=True,
         dataloader_num_workers=config["dataloader_num_workers"],
         gradient_accumulation_steps=config["gradient_accumulation_steps"],
         logging_steps=1,
+        save_safetensors=False,
         metric_for_best_model="loss",
-        # load_best_model_at_end=True,
-        # save_total_limit=1
+        load_best_model_at_end=True,
+        save_total_limit=1
     )
 
 
diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py
index 0c9d9cc..16b36e3 100644
--- a/chemlactica/mol_opt/utils.py
+++ b/chemlactica/mol_opt/utils.py
@@ -223,7 +223,9 @@ def get_train_valid_entries(self):
         return train_entries, valid_entries
 
     def random_subset(self, subset_size):
-        rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size * 2))
+        # rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size * 2))
+        rand_inds = np.random.permutation(len(self.optim_entries))
+        rand_inds = rand_inds[:subset_size]
         return [self.optim_entries[i] for i in rand_inds]
 
     def __len__(self):