Skip to content

Commit

Permalink
correct properties order
Browse files Browse the repository at this point in the history
  • Loading branch information
tigranfah committed May 21, 2024
1 parent 188f384 commit d95e4d4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
3 changes: 1 addition & 2 deletions chemlactica/mol_opt/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def create_molecule_entry(output_text):
def optimize(
model, tokenizer,
oracle, config,
additional_properties=[]
additional_properties={}
):
file = open(config["log_dir"], "w")
print("config", config)
Expand Down Expand Up @@ -94,7 +94,6 @@ def optimize(
optim_entry.to_prompt(is_generation=True, include_oracle_score=prev_train_iter != 0, config=config)
for optim_entry in optim_entries
]

output_texts = []
for i in range(0, len(prompts), config["generation_batch_size"]):
prompt_batch = prompts[i: min(len(prompts), i + config["generation_batch_size"])]
Expand Down
21 changes: 12 additions & 9 deletions chemlactica/mol_opt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,18 @@ def to_prompt(
prompt = config["eos_token"]
for mol_entry in self.mol_entries:
# prompt += config["eos_token"]
prompt += create_prompt_with_similars(mol_entry=mol_entry)

for prop_name, prop_spec in mol_entry.add_props.items():
prompt += f"{prop_spec['start_tag']}{prop_spec['value']}{prop_spec['end_tag']}"

if "default" in config["strategy"]:
prompt += create_prompt_with_similars(mol_entry=mol_entry)
pass
elif "rej-sample-v2" in config["strategy"]:
prompt += create_prompt_with_similars(mol_entry=mol_entry)
if include_oracle_score:
prompt += f"[PROPERTY]oracle_score {mol_entry.score:.2f}[/PROPERTY]"
else:
raise Exception(f"Strategy {config['strategy']} not known.")
for prop_name, prop_spec in mol_entry.add_props.items():
prompt += f"{prop_spec['start_tag']}{prop_spec['value']}{prop_spec['end_tag']}"
prompt += f"[START_SMILES]{mol_entry.smiles}[END_SMILES]"

assert self.last_entry
Expand All @@ -218,10 +220,14 @@ def to_prompt(
else:
prompt_with_similars = create_prompt_with_similars(self.last_entry)

prompt += prompt_with_similars

for prop_name, prop_spec in self.last_entry.add_props.items():
prompt += prop_spec["start_tag"] + prop_spec["infer_value"](self.last_entry) + prop_spec["end_tag"]

if "default" in config["strategy"]:
prompt += prompt_with_similars
pass
elif "rej-sample-v2" in config["strategy"]:
prompt += prompt_with_similars
if is_generation:
oracle_scores_of_mols_in_prompt = [e.score for e in self.mol_entries]
q_0_9 = (
Expand All @@ -240,9 +246,6 @@ def to_prompt(
else:
raise Exception(f"Strategy {config['strategy']} not known.")

for prop_name, prop_spec in self.last_entry.add_props.items():
prompt += prop_spec["start_tag"] + prop_spec["infer_value"](self.last_entry) + prop_spec["end_tag"]

if is_generation:
prompt += "[START_SMILES]"
else:
Expand Down

0 comments on commit d95e4d4

Please sign in to comment.