Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace RXN4Chem: Running retrosynthesis and reaction prediction locally. #52

Merged
merged 10 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,10 @@ dmypy.json
local/
*ipynb
query/

# Models in docker directory
**/docker/molecular-transformer/*txt
*.pt
*onnx
**/docker/*/files/*csv.gz
**/docker/*/files/*hdf5
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,31 @@ chem_model = ChemCrow(model="gpt-4-0613", temp=0.1, streaming=False)
chem_model.run("What is the molecular weight of tylenol?")
```


## 🛠️ Self-hosting of some tools.

By default, ChemCrow relies on the RXN4Chem API for retrosynthetic planning and reaction product prediction. This can however be slow and depends on you having an API key.

Optionally, you can also self host these tools by running some pre-made docker images.

Run

```
docker run --gpus all -d -p 8051:5000 doncamilom/rxnpred:latest
docker run --gpus all -d -p 8052:5000 doncamilom/retrosynthesis:latest
```


Now ChemCrow can be used like this:

```python
from chemcrow.agents import ChemCrow

chem_model = ChemCrow(model="gpt-4-0613", temp=0.1, streaming=False, local_rxn=True)
chem_model.run("What is the product of the reaction between styrene and dibromine?")
```


## ✅ Citation
Bran, Andres M., et al. "ChemCrow: Augmenting large-language models with chemistry tools." arXiv preprint arXiv:2304.05376 (2023).

Expand Down
3 changes: 2 additions & 1 deletion chemcrow/agents/chemcrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
streaming: bool = True,
openai_api_key: Optional[str] = None,
api_keys: dict = {},
local_rxn: bool = False,
):
"""Initialize ChemCrow agent."""

Expand All @@ -58,7 +59,7 @@ def __init__(
if tools is None:
api_keys["OPENAI_API_KEY"] = openai_api_key
tools_llm = _make_llm(tools_model, temp, openai_api_key, streaming)
tools = make_tools(tools_llm, api_keys=api_keys, verbose=verbose)
tools = make_tools(tools_llm, api_keys=api_keys, local_rxn=local_rxn, verbose=verbose)

# Initialize agent
self.agent_executor = RetryAgentExecutor.from_agent_and_tools(
Expand Down
9 changes: 7 additions & 2 deletions chemcrow/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from chemcrow.tools import *


def make_tools(llm: BaseLanguageModel, api_keys: dict = {}, verbose=True):
def make_tools(llm: BaseLanguageModel, api_keys: dict = {}, local_rxn: bool=False, verbose=True):
serp_api_key = api_keys.get("SERP_API_KEY") or os.getenv("SERP_API_KEY")
rxn4chem_api_key = api_keys.get("RXN4CHEM_API_KEY") or os.getenv("RXN4CHEM_API_KEY")
openai_api_key = api_keys.get("OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY")
Expand Down Expand Up @@ -48,10 +48,15 @@ def make_tools(llm: BaseLanguageModel, api_keys: dict = {}, verbose=True):
all_tools += [GetMoleculePrice(chemspace_api_key)]
if serp_api_key:
all_tools += [WebSearch(serp_api_key)]
if rxn4chem_api_key:
if (not local_rxn) and rxn4chem_api_key:
all_tools += [
RXNPredict(rxn4chem_api_key),
RXNRetrosynthesis(rxn4chem_api_key, openai_api_key),
]
elif local_rxn:
all_tools += [
RXNPredictLocal(),
RXNRetrosynthesisLocal()
]

return all_tools
30 changes: 30 additions & 0 deletions chemcrow/docker/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@

# Tools of organic chemistry

A docker container was prepared for each tool, which exposes an api for requests.

> docker run -d -p 8052:5000 doncamilom/rxnpred:latest

Where 5000 is fixed, and 8082 is the port to be exposed.

A request in curl can look like this

> curl -X POST -H "Content-Type: application/json" -d '{"smiles": "O=C(OC(C)(C)C)c1ccc(C(=O)Nc2ccc(Cl)cc2)cc1"}' http://localhost:8082/api/v1/run

Or in Python

```python

import json
import requests

def reaction_predict(reactants):
response = requests.post(
"http://localhost:8052/api/v1/run",
headers={"Content-Type": "application/json"},
data=json.dumps({"smiles": reactants})
)
return response.json()['product'][0]

product = reaction_predict('CCOCCCCO.CC(=O)Cl')
```
9 changes: 9 additions & 0 deletions chemcrow/docker/aizynthfinder/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
FROM python:3.9
RUN pip install aizynthfinder[all] flask
COPY files/ .
COPY . .

EXPOSE 5000
ENTRYPOINT ["python"]

CMD ["app.py"]
29 changes: 29 additions & 0 deletions chemcrow/docker/aizynthfinder/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import json
import subprocess

from flask import Flask, jsonify, request

app = Flask(__name__)

@app.route('/api/v1/run', methods=['POST'])
def rxnfp():
data = request.get_json()
target = data.get("target", [])

command = ["aizynthcli", "--config", "config.yml", "--smiles", f"{target}"]

print(command)
result = subprocess.run(
command, check=True, capture_output=True, text=True
)
print(result)

# Read output trees.json
with open("trees.json", "r") as f:
tree = json.load(f)

return tree


if __name__ == "__main__":
app.run(host="0.0.0.0", port=5000)
11 changes: 11 additions & 0 deletions chemcrow/docker/aizynthfinder/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
expansion:
uspto:
- files/uspto_model.onnx
- files/uspto_templates.csv.gz
ringbreaker:
- files/uspto_ringbreaker_model.onnx
- files/uspto_ringbreaker_templates.csv.gz
filter:
uspto: files/uspto_filter_model.onnx
stock:
zinc: files/zinc_stock.hdf5
11 changes: 11 additions & 0 deletions chemcrow/docker/aizynthfinder/files/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Download al lthe important files for runnig aizynthfinder using

```
download_public_data .
```

Which comes by installing aizynthfinder

```
pip install aizynthfinder[all]
```
11 changes: 11 additions & 0 deletions chemcrow/docker/molecular-transformer/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
FROM python:3.10
WORKDIR /app

RUN pip install rdkit-pypi==2022.3.1
RUN pip install OpenNMT-py==2.2.0 "numpy<2.0.0"

COPY . .
COPY input.txt .
COPY models/ .
CMD ["python", "app.py"]

75 changes: 75 additions & 0 deletions chemcrow/docker/molecular-transformer/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import re
import subprocess
from flask import Flask, request, jsonify
from rdkit import Chem

app = Flask(__name__)


SMI_REGEX_PATTERN = r"(\%\([0-9]{3}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\||\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"

def canonicalize_smiles(smiles, verbose=False): # will raise an Exception if invalid SMILES
mol = Chem.MolFromSmiles(smiles)
if mol is not None:
return Chem.MolToSmiles(mol)
else:
if verbose:
print(f'{smiles} is invalid.')
return ''

def smiles_tokenizer(smiles):
"""Canonicalize and tokenize input smiles"""

smiles = canonicalize_smiles(smiles)
smiles_regex = re.compile(SMI_REGEX_PATTERN)
tokens = [token for token in smiles_regex.findall(smiles)]
return ' '.join(tokens)


@app.route('/api/v1/run', methods=['POST'])
def f():
request_data = request.get_json()
input = request_data['smiles']

# Write the input to 'inp.txt'
with open('input.txt', 'w') as f:
# Tokenize smiles
smi = smiles_tokenizer(input)
f.write(smi)

model_path = 'models/USPTO480k_model_step_400000.pt'

src_path = 'input.txt'
output_path = 'output.txt'
n_best = 5
beam_size = 10
max_length = 300
batch_size = 1

try:
# Construct the command to execute
cmd = f"onmt_translate -model {model_path} " \
f"--src {src_path} " \
f"--output {output_path} --n_best {n_best} " \
f"--beam_size {beam_size} --max_length {max_length} " \
f"--batch_size {batch_size}"

# Execute the command using subprocess.check_call()
subprocess.check_call(cmd, shell=True)

# Read produced output
with open('output.txt', 'r') as f:
prods = f.read()
prods = re.sub(' ', '', prods).split('\n')


# Return a success message
return jsonify({'status': 'SUCCESS', 'product': prods})

except:
return jsonify({'status': 'ERROR', 'product': None})

if __name__ == '__main__':
# Run the Flask app
app.run(debug=True, host='0.0.0.0')

1 change: 1 addition & 0 deletions chemcrow/docker/molecular-transformer/models/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Download model from https://drive.google.com/uc?id=1ywJCJHunoPTB5wr6KdZ8aLv7tMFMBHNy
1 change: 1 addition & 0 deletions chemcrow/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .safety import * # noqa
from .chemspace import * # noqa
from .converters import * # noqa
from .reactions import * # noqa
122 changes: 122 additions & 0 deletions chemcrow/tools/reactions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""Self-hosted reaction tools. Retrosynthesis, reaction forward prediction."""

import abc
import ast
import re
from time import sleep
from typing import Optional

import requests

import json
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
from langchain.tools import BaseTool

from chemcrow.utils import is_smiles

__all__ = ["RXNPredictLocal", "RXNRetrosynthesisLocal"]


class RXNPredictLocal(BaseTool):
"""Predict reaction."""

name = "ReactionPredict"
description = (
"Predict the outcome of a chemical reaction. "
"Takes as input the SMILES of the reactants separated by a dot '.', "
"returns SMILES of the products."
)

def _run(self, reactants: str) -> str:
"""Run reaction prediction."""
if not is_smiles(reactants):
return "Incorrect input."

product = self.predict_reaction(reactants)
return product

def predict_reaction(self, reactants: str) -> str:
"""Make api request."""
try:
response = requests.post(
"http://localhost:8051/api/v1/run",
headers={"Content-Type": "application/json"},
data=json.dumps({"smiles": reactants})
)
return response.json()['product'][0]
except:
return "Error in prediction."


class RXNRetrosynthesisLocal(BaseTool):
"""Predict retrosynthesis."""

name = "ReactionRetrosynthesis"
description = (
"Obtain the synthetic route to a chemical compound. "
"Takes as input the SMILES of the product, returns recipe."
)
openai_api_key: str = ""

def _run(self, reactants: str) -> str:
"""Run reaction prediction."""
# Check that input is smiles
if not is_smiles(reactants):
return "Incorrect input."

paths = self.retrosynthesis(reactants)
procedure = self.get_action_sequence(paths[0])
return procedure

def retrosynthesis(self, reactants: str) -> str:
"""Make api request."""
response = requests.post(
"http://localhost:8052/api/v1/run",
headers={"Content-Type": "application/json"},
data=json.dumps({"smiles": reactants})
)
return response.json()

def get_action_sequence(self, path):
"""Get sequence of actions."""
actions = path
json_actions = self._preproc_actions(actions)
llm_sum = self._summary_gpt(json_actions)
return llm_sum

def _preproc_actions(self, path):
"""Preprocess actions."""
def _clean_actions(d):
if 'metadata' in d:
if 'mapped_reaction_smiles' in d['metadata']:
r = d['metadata']['mapped_reaction_smiles'].split(">>")
yield {"reactants": r[1], "products": r[0]}
if 'children' in d:
for c in d['children']:
yield from _clean_actions(c)

rxns = list(_clean_actions(path))
return rxns

def _summary_gpt(self, json: dict) -> str:
"""Describe synthesis."""
llm = ChatOpenAI( # type: ignore
temperature=0.05,
model_name="gpt-3.5-turbo-16k",
request_timeout=2000,
max_tokens=2000,
openai_api_key=self.openai_api_key,
)
prompt = (
"Here is a chemical synthesis described as a json.\nYour task is "
"to describe the synthesis, as if you were giving instructions for"
"a recipe. Use only the substances, quantities, temperatures and "
"in general any action mentioned in the json file. This is your "
"only source of information, do not make up anything else. Also, "
"add 15mL of DCM as a solvent in the first step. If you ever need "
'to refer to the json file, refer to it as "(by) the tool". '
"However avoid references to it. \nFor this task, give as many "
f"details as possible.\n {str(json)}"
)
return llm([HumanMessage(content=prompt)]).content
Loading
Loading