Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving path registry #66

Merged
merged 16 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mdagent/tools/base_tools/analysis_tools/plot_tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import csv
import re
from typing import Optional

import matplotlib.pyplot as plt
from langchain.tools import BaseTool

from mdagent.utils import PathRegistry


def process_csv(file_name):
with open(file_name, "r") as f:
Expand Down Expand Up @@ -64,13 +67,15 @@ def plot_data(data, headers, matched_headers):
class SimulationOutputFigures(BaseTool):
name = "PostSimulationFigures"
description = """This tool will take
a csv file output from an openmm
a csv file id output from an openmm
simulation and create figures for
all physical parameters
versus timestep of the simulation.
Give this tool the path to the
csv file output from the simulation."""

path_registry: Optional[PathRegistry]

def _run(self, file_path: str) -> str:
"""use the tool."""
try:
Expand Down
72 changes: 47 additions & 25 deletions mdagent/tools/base_tools/preprocess_tools/clean_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pdbfixer import PDBFixer
from pydantic import BaseModel, Field, root_validator

from mdagent.utils import PathRegistry
from mdagent.utils import FileType, PathRegistry


class CleaningTools:
Expand Down Expand Up @@ -226,7 +226,7 @@ async def _arun(self, query: str) -> str:
class CleaningToolFunctionInput(BaseModel):
"""Input model for CleaningToolFunction"""

pdb_path: str = Field(..., description="Path to PDB or CIF file")
pdb_id: str = Field(..., description="ID of the pdb/cif file in the path registry")
output_path: Optional[str] = Field(..., description="Path to the output file")
replace_nonstandard_residues: bool = Field(
True, description="Whether to replace nonstandard residues with standard ones. "
Expand Down Expand Up @@ -277,10 +277,10 @@ def _run(self, **input_args) -> str:
input_args = input_args["input_args"]
else:
input_args = input_args
pdbfile_path = input_args.get("pdb_path", None)
if pdbfile_path is None:
return """No file path provided.
The input has to be a dictionary with the key 'pdb_path'"""
pdbfile_id = input_args.get("pdb_id", None)
if pdbfile_id is None:
return """No file was provided.
The input has to be a dictionary with the key 'pdb_id'"""
remove_heterogens = input_args.get("remove_heterogens", True)
remove_water = input_args.get("remove_water", True)
add_hydrogens = input_args.get("add_hydrogens", True)
Expand All @@ -289,17 +289,23 @@ def _run(self, **input_args) -> str:
"replace_nonstandard_residues", True
)
add_missing_atoms = input_args.get("add_missing_atoms", True)
output_path = input_args.get("output_path", None)
input_args.get("output_path", None)

if self.path_registry is None:
return "Path registry not initialized"
file_description = "Cleaned File: "
clean_tools = CleaningTools()
pdbfile = clean_tools._extract_path(pdbfile_path, self.path_registry)
name = pdbfile.split(".")[0]
end = pdbfile.split(".")[1]
CleaningTools()
try:
pdbfile = self.path_registry.get_mapped_path(pdbfile_id)
if "/" in pdbfile:
pdbfile_name = pdbfile.split("/")[-1]
name = pdbfile_name.split("_")[0]
end = pdbfile_name.split(".")[1]
print(f"pdbfile: {pdbfile}", f"name: {name}", f"end: {end}")
except Exception as e:
print(f"error retrieving from path_registry, trying to read file {e}")
return "File not found in path registry. "
fixer = PDBFixer(filename=pdbfile)

try:
fixer.findMissingResidues()
except Exception:
Expand All @@ -321,6 +327,7 @@ def _run(self, **input_args) -> str:
try:
if replace_nonstandard_residues:
fixer.replaceNonstandardResidues()
file_description += " Replaced Nonstandard Residues. "
except Exception:
print("error at replaceNonstandardResidues")
try:
Expand All @@ -343,26 +350,41 @@ def _run(self, **input_args) -> str:
"Missing Atoms Added and replaces nonstandard residues. "
)
file_mode = "w" if add_hydrogens else "a"
if output_path:
file_name = output_path
else:
version = 1
while os.path.exists(f"tidy_{name}v{version}.{end}"):
version += 1

file_name = f"tidy_{name}v{version}.{end}"

file_name = self.path_registry.write_file_name(
type=FileType.PROTEIN,
protein_name=name,
description="Clean",
file_format=end,
)
file_id = self.path_registry.get_fileid(file_name, FileType.PROTEIN)
# if output_path:
SamCox822 marked this conversation as resolved.
Show resolved Hide resolved
# file_name = output_path
# else:
# version = 1
# while os.path.exists(f"tidy_{name}v{version}.{end}"):
# version += 1
#
# file_name = f"tidy_{name}v{version}.{end}"
directory = "files/pdb"
if not os.path.exists(directory):
os.makedirs(directory)
if end == "pdb":
PDBFile.writeFile(
fixer.topology, fixer.positions, open(file_name, file_mode)
fixer.topology,
fixer.positions,
open(f"{directory}/{file_name}", file_mode),
)
elif end == "cif":
PDBxFile.writeFile(
fixer.topology, fixer.positions, open(file_name, file_mode)
fixer.topology,
fixer.positions,
open(f"{directory}/{file_name}", file_mode),
)

self.path_registry.map_path(file_name, file_name, file_description)
return f"{file_description} written to {file_name}"
self.path_registry.map_path(
file_id, f"{directory}/{file_name}", file_description
)
return f"{file_id} written to {directory}/{file_name}"
except FileNotFoundError:
return "Check your file path. File not found."
except Exception as e:
Expand Down
34 changes: 24 additions & 10 deletions mdagent/tools/base_tools/preprocess_tools/pdb_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pdbfixer import PDBFixer
from pydantic import BaseModel, Field, ValidationError, root_validator

from mdagent.utils import PathRegistry
from mdagent.utils import FileType, PathRegistry


def get_pdb(query_string, path_registry=None):
Expand Down Expand Up @@ -41,13 +41,22 @@ def get_pdb(query_string, path_registry=None):
print(f"PDB file found with this ID: {pdbid}")
url = f"https://files.rcsb.org/download/{pdbid}.{filetype}"
pdb = requests.get(url)
filename = f"{pdbid}.{filetype}"
with open(filename, "w") as file:
filename = path_registry.write_file_name(
FileType.PROTEIN,
protein_name=pdbid,
description="raw",
file_format=filetype,
)
file_id = path_registry.get_fileid(filename, FileType.PROTEIN)
directory = "files/pdb"
# Create the directory if it does not exist
if not os.path.exists(directory):
os.makedirs(directory)

with open(f"{directory}/{filename}", "w") as file:
file.write(pdb.text)
print(f"{filename} is created.")
file_description = f"PDB file downloaded from RSCB, PDB ID: {pdbid}"
path_registry.map_path(filename, filename, file_description)
return filename

return filename, file_id
return None


Expand All @@ -73,11 +82,16 @@ def _run(self, query: str) -> str:
try:
if self.path_registry is None: # this should not happen
return "Path registry not initialized"
pdb = get_pdb(query, self.path_registry)
if pdb is None:
filename, pdbfile_id = get_pdb(query, self.path_registry)
if pdbfile_id is None:
return "Name2PDB tool failed to find and download PDB file."
else:
return f"Name2PDB tool successfully downloaded the PDB file: {pdb}"
self.path_registry.map_path(
pdbfile_id,
f"files/pdb/{filename}",
f"PDB file downloaded from RSCB, PDBFile ID: {pdbfile_id}",
)
return f"Name2PDB tool successful. downloaded the PDB file:{pdbfile_id}"
except Exception as e:
return f"Something went wrong. {e}"

Expand Down
35 changes: 27 additions & 8 deletions mdagent/tools/base_tools/simulation_tools/create_simulation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import textwrap
from typing import Optional

Expand Down Expand Up @@ -131,8 +132,13 @@ def remove_leading_spaces(self, text):


class ModifyScriptInput(BaseModel):
query: str = Field(..., description="Simmulation required by the user")
script: str = Field(..., description=" path to the base script file")
query: str = Field(
...,
description="""Simmulation required by the user.You MUST
specify the objective, requirements of the simulation as well
as on what protein you are working.""",
)
script: str = Field(..., description=" simulation ID of the base script file")


class ModifyBaseSimulationScriptTool(BaseTool):
Expand All @@ -150,10 +156,17 @@ def __init__(self, path_registry: Optional[PathRegistry], llm: BaseLanguageModel
self.llm = llm

def _run(self, **input):
base_script_path = input.get("script")
if not base_script_path:
return """No script provided. The keys for the input are:
base_script_id = input.get("script")
if not base_script_id:
return """No id provided. The keys for the input are:
'query' and 'script'"""
try:
base_script_path = self.path_registry.get_mapped_path(base_script_id)
parts = base_script_path.split("/")
if len(parts) > 1:
parts[-1]
except Exception as e:
return f"Error getting path from file id: {e}"
with open(base_script_path, "r") as file:
base_script = file.read()
base_script = "".join(base_script)
Expand All @@ -172,11 +185,17 @@ def _run(self, **input):
script_content = script_content.replace("```", "#")
script_content = textwrap.dedent(script_content).strip()
# Write to file
filename = "modified_simul.py"
with open(filename, "w") as file:
filename = self.path_registry.write_file_name(
type="SIMULATION", Sim_id=base_script_id, modified=True
)
file_id = self.path_registry.get_fileid(filename, type="SIMULATION")
directory = "files/simulations"
if not os.path.exists(directory):
os.makedirs(directory)
with open(f"{directory}/{filename}", "w") as file:
file.write(script_content)

self.path_registry.map_path(filename, filename, description)
self.path_registry.map_path(file_id, filename, description)
return "Script modified successfully"

async def _arun(self, query) -> str:
Expand Down
Loading
Loading