From d95e4d4dcce4d28fe7f58ac471a90ceae2a69769 Mon Sep 17 00:00:00 2001 From: tigranfah Date: Tue, 21 May 2024 14:44:09 +0400 Subject: [PATCH] correct properties order --- chemlactica/mol_opt/optimization.py | 3 +-- chemlactica/mol_opt/utils.py | 21 ++++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/chemlactica/mol_opt/optimization.py b/chemlactica/mol_opt/optimization.py index 5c8ad33..0e5e530 100644 --- a/chemlactica/mol_opt/optimization.py +++ b/chemlactica/mol_opt/optimization.py @@ -62,7 +62,7 @@ def create_molecule_entry(output_text): def optimize( model, tokenizer, oracle, config, - additional_properties=[] + additional_properties={} ): file = open(config["log_dir"], "w") print("config", config) @@ -94,7 +94,6 @@ def optimize( optim_entry.to_prompt(is_generation=True, include_oracle_score=prev_train_iter != 0, config=config) for optim_entry in optim_entries ] - output_texts = [] for i in range(0, len(prompts), config["generation_batch_size"]): prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])] diff --git a/chemlactica/mol_opt/utils.py b/chemlactica/mol_opt/utils.py index 84d7589..256b4c2 100644 --- a/chemlactica/mol_opt/utils.py +++ b/chemlactica/mol_opt/utils.py @@ -197,16 +197,18 @@ def to_prompt( prompt = config["eos_token"] for mol_entry in self.mol_entries: # prompt += config["eos_token"] + prompt += create_prompt_with_similars(mol_entry=mol_entry) + + for prop_name, prop_spec in mol_entry.add_props.items(): + prompt += f"{prop_spec['start_tag']}{prop_spec['value']}{prop_spec['end_tag']}" + if "default" in config["strategy"]: - prompt += create_prompt_with_similars(mol_entry=mol_entry) + pass elif "rej-sample-v2" in config["strategy"]: - prompt += create_prompt_with_similars(mol_entry=mol_entry) if include_oracle_score: prompt += f"[PROPERTY]oracle_score {mol_entry.score:.2f}[/PROPERTY]" else: raise Exception(f"Strategy {config['strategy']} not known.") - for prop_name, prop_spec in mol_entry.add_props.items(): - prompt += f"{prop_spec['start_tag']}{prop_spec['value']}{prop_spec['end_tag']}" prompt += f"[START_SMILES]{mol_entry.smiles}[END_SMILES]" assert self.last_entry @@ -218,10 +220,14 @@ def to_prompt( else: prompt_with_similars = create_prompt_with_similars(self.last_entry) + prompt += prompt_with_similars + + for prop_name, prop_spec in self.last_entry.add_props.items(): + prompt += prop_spec["start_tag"] + prop_spec["infer_value"](self.last_entry) + prop_spec["end_tag"] + if "default" in config["strategy"]: - prompt += prompt_with_similars + pass elif "rej-sample-v2" in config["strategy"]: - prompt += prompt_with_similars if is_generation: oracle_scores_of_mols_in_prompt = [e.score for e in self.mol_entries] q_0_9 = ( @@ -240,9 +246,6 @@ def to_prompt( else: raise Exception(f"Strategy {config['strategy']} not known.") - for prop_name, prop_spec in self.last_entry.add_props.items(): - prompt += prop_spec["start_tag"] + prop_spec["infer_value"](self.last_entry) + prop_spec["end_tag"] - if is_generation: prompt += "[START_SMILES]" else: