Skip to content

Commit

Permalink
merge from main to experiments (#164)
Browse files Browse the repository at this point in the history
* Update uniprot.py (#157)

* various fixes (#161)

* more various fixes (rgy) (#162)

---------

Co-authored-by: Sam Cox <[email protected]>
  • Loading branch information
qcampbel and SamCox822 authored Jan 15, 2025
1 parent 58251d8 commit 7a65c57
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 147 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,28 @@ repos:
- id: mixed-line-ending
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.0.270"
rev: "v0.7.1"
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
- repo: https://github.com/psf/black
rev: "23.3.0"
rev: "24.10.0"
hooks:
- id: black
language_version: python3
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.3.0"
rev: "v1.13.0"
hooks:
- id: mypy
args: [--pretty, --ignore-missing-imports]
additional_dependencies: [types-requests]
- repo: https://github.com/PyCQA/isort
rev: "5.12.0"
rev: "5.13.2"
hooks:
- id: isort
args: [--profile=black]
- repo: https://github.com/Yelp/detect-secrets
rev: v1.0.3
rev: v1.5.0
hooks:
- id: detect-secrets
args: [--exclude-files, ".github/workflows/"]
10 changes: 2 additions & 8 deletions mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
from .analysis_tools.plot_tools import SimulationOutputFigures
from .analysis_tools.ppi_tools import PPIDistance
from .analysis_tools.rdf_tool import RDFTool
from .analysis_tools.rgy import (
RadiusofGyrationAverage,
RadiusofGyrationPerFrame,
RadiusofGyrationPlot,
)
from .analysis_tools.rgy import RadiusofGyrationTool
from .analysis_tools.rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .analysis_tools.sasa import SolventAccessibleSurfaceArea
from .analysis_tools.secondary_structure import (
Expand Down Expand Up @@ -80,9 +76,7 @@
"PCATool",
"PPIDistance",
"ProteinName2PDBTool",
"RadiusofGyrationAverage",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"RadiusofGyrationTool",
"RDFTool",
"RMSDCalculator",
"Scholar2ResultLLM",
Expand Down
6 changes: 2 additions & 4 deletions mdagent/tools/base_tools/analysis_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .pca_tools import PCATool
from .plot_tools import SimulationOutputFigures
from .ppi_tools import PPIDistance
from .rgy import RadiusofGyrationAverage, RadiusofGyrationPerFrame, RadiusofGyrationPlot
from .rgy import RadiusofGyrationTool
from .rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .sasa import SolventAccessibleSurfaceArea
from .vis_tools import VisFunctions, VisualizeProtein
Expand All @@ -17,9 +17,7 @@
"MomentOfInertia",
"PCATool",
"PPIDistance",
"RadiusofGyrationAverage",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"RadiusofGyrationTool",
"RMSDCalculator",
"SimulationOutputFigures",
"SolventAccessibleSurfaceArea",
Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/base_tools/analysis_tools/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _run(self, file_id: str) -> str:
plotting_tools._find_file(file_id)
plotting_tools.process_csv()
plot_result = plotting_tools.plot_data()
if type(plot_result) == str:
if isinstance(plot_result, str):
return "Succeeded. IDs of figures created: " + plot_result
else:
return "Failed. No figures created."
Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/base_tools/analysis_tools/rdf_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def validate_input(self, input):
)

if stride:
if type(stride) != int:
if not isinstance(stride, int):
try:
stride = int(stride)
if stride <= 0:
Expand Down
145 changes: 44 additions & 101 deletions mdagent/tools/base_tools/analysis_tools/rgy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, path_registry):
self.top_file = ""
self.traj_file = ""
self.traj = None
self.rgy_file = ""

def _load_traj(self, top_file: str, traj_file: str):
self.traj_file = traj_file
Expand All @@ -25,38 +26,36 @@ def _load_traj(self, top_file: str, traj_file: str):
traj_required=True,
)

def rgy_per_frame(self, force_recompute: bool = False) -> str:
def rgy_per_frame(self) -> str:
rg_per_frame = md.compute_rg(self.traj)
self.rgy_file = (
f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv"
)
rgy_id = f"rgy_{self.traj_file}"
if rgy_id in self.path_registry.list_path_names() and force_recompute is False:
print("RGY already computed, skipping re-compute")
# todo -> maybe allow re-compute & save under different id/path
else:
np.savetxt(
self.rgy_file,
rg_per_frame,
delimiter=",",
header="Radius of Gyration (nm)",
)
self.path_registry.map_path(
f"rgy_{self.traj_file}",
self.rgy_file,
description=f"Radii of gyration per frame for {self.traj_file}",
)
np.savetxt(
self.rgy_file,
rg_per_frame,
delimiter=",",
header="Radius of Gyration (nm)",
)
self.path_registry.map_path(
f"rgy_{self.traj_file}",
self.rgy_file,
description=f"Radii of gyration per frame for {self.traj_file}",
)
return f"Radii of gyration saved to {self.rgy_file} with id {rgy_id}."

def rgy_average(self) -> str:
_ = self.rgy_per_frame()
if not self.rgy_file:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
avg_rg = rg_per_frame.mean()

return f"Average radius of gyration: {avg_rg:.2f} nm"

def plot_rgy(self) -> str:
_ = self.rgy_per_frame()
if not self.rgy_file:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
fig_analysis = f"rgy_{self.traj_file}"
plot_name = self.path_registry.write_file_name(
Expand All @@ -66,9 +65,8 @@ def plot_rgy(self) -> str:
plot_id = self.path_registry.get_fileid(
file_name=plot_name, type=FileType.FIGURE
)
if plot_name.endswith(".png"):
plot_name = plot_name.split(".png")[0]
plot_path = f"{self.path_registry.ckpt_figures}/{plot_name}"
plot_path = plot_path if plot_path.endswith(".png") else plot_path + ".png"
print("plot_path", plot_path)
plt.plot(rg_per_frame)
plt.xlabel("Frame")
Expand All @@ -78,106 +76,51 @@ def plot_rgy(self) -> str:
plt.savefig(f"{plot_path}")
self.path_registry.map_path(
plot_id,
plot_path + ".png",
plot_path,
description=f"Plot of radii of gyration over time for {self.traj_file}",
)
plt.close()
plt.clf()
return "Plot saved as: " + f"{plot_name}.png with plot ID {plot_id}"


class RadiusofGyrationAverage(BaseTool):
name = "RadiusofGyrationAverage"
description = """This tool calculates the average radius of gyration
for a trajectory. Give this tool BOTH the trajectory file ID and the
topology file ID."""
return "Plot saved as: " + f"{plot_name} with plot ID {plot_id}"

path_registry: Optional[PathRegistry]
def compute_plot_return_avg(self) -> str:
rgy_per_frame = self.rgy_per_frame()
avg_rgy = self.rgy_average()
plot_rgy = self.plot_rgy()
return rgy_per_frame + plot_rgy + avg_rgy

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.rgy_average()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")


class RadiusofGyrationPerFrame(BaseTool):
name = "RadiusofGyrationPerFrame"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory.
class RadiusofGyrationTool(BaseTool):
name = "RadiusofGyrationTool"
description = """This tool calculates and plots
the radius of gyration
at each frame of a given trajectory and retuns the average.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the radii of gyration to a csv file and
map it to the registry."""
topology file ID."""

path_registry: Optional[PathRegistry]
rgy: Optional[RadiusofGyration]
load_traj: bool = True

def __init__(self, path_registry):
def __init__(self, path_registry, load_traj=True):
super().__init__()
self.path_registry = path_registry
self.rgy = RadiusofGyration(path_registry)
self.load_traj = load_traj # only for testing

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.rgy_per_frame()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")


class RadiusofGyrationPlot(BaseTool):
name = "RadiusofGyrationPlot"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory file and plots it.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the plot to a png file and map it to the registry."""

path_registry: Optional[PathRegistry]
assert self.rgy is not None, "RadiusofGyration instance is not initialized"

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
if self.load_traj:
try:
self.rgy._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.plot_rgy()
except ValueError as e:
return f"Failed. ValueError: {e}"
return "Succeeded. " + self.rgy.compute_plot_return_avg()
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"
return f"Failed Computing RGY: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
Expand Down
10 changes: 6 additions & 4 deletions mdagent/tools/base_tools/preprocess_tools/uniprot.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def get_sequence_info(self, query: str, primary_accession: str) -> dict:
- 'crc64': The CRC64 hash of the protein sequence (probably not useful)
- 'md5': The MD5 hash of the protein sequence (probably not useful)
"""
seq_info = self.data = self.get_data(query, desired_field="sequence")
seq_info = self.get_data(query, desired_field="sequence")
if not seq_info:
return {}
seq_info_specific = self._match_primary_accession(seq_info, primary_accession)[
Expand Down Expand Up @@ -693,9 +693,11 @@ def get_ids(
if include_uniprotkbids:
all_ids + [entry["uniProtkbId"] for entry in ids_] if ids_ else []
accession = self.get_data(query, desired_field="accession")
all_ids + [
entry["primaryAccession"] for entry in accession
] if accession else []
(
all_ids + [entry["primaryAccession"] for entry in accession]
if accession
else []
)
if single_id:
return [all_ids[0]] if all_ids else []
return list(set(all_ids))
Expand Down
Loading

0 comments on commit 7a65c57

Please sign in to comment.