Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

code and unit test for salt bridge tool #121

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
e6f086d
added some rough code draft for salt bridge
brittyscience Mar 27, 2024
f1ee5ce
Merge branch 'main' into saltbridge_code
brittyscience Mar 31, 2024
0ff52c6
Prompting change in CNT
Jgmedina95 Apr 1, 2024
8ed1734
merging CNT_rdf to saltbridge_code
brittyscience Apr 3, 2024
3d90dd3
Merge branch 'main' into saltbridge_code
brittyscience Apr 10, 2024
9228c2a
pushing latest updates. Unit test and codes WIP
brittyscience Apr 12, 2024
f67e42b
Another changes updates to my GT and Unit test. WIP
brittyscience Apr 15, 2024
1103b11
latest changes. pytest is still not working.
brittyscience Apr 15, 2024
679f61a
updating my saltbridge tool code
brittyscience Apr 22, 2024
562d164
Just another update of my changes from two days ago.
brittyscience Apr 25, 2024
92ad2aa
more updates. agent should have my tools now
brittyscience Apr 25, 2024
4be782f
fixed init bug
brittyscience May 7, 2024
0df1242
added arg schema, making code more flexible, path registery
brittyscience May 11, 2024
5655b4d
unit test now works and salt bridge tool revision update
brittyscience May 13, 2024
86e333d
Merge branch 'main' of https://github.com/ur-whitelab/md-agent into s…
brittyscience May 13, 2024
e6d792c
pre commit hiccup fix
brittyscience Jun 9, 2024
2daed31
updating with changes from main
brittyscience Jun 9, 2024
8833cf7
resolved merging conflict
brittyscience Jun 12, 2024
f154a69
moved my saltbridge unit test to test_analysis folder
brittyscience Jun 13, 2024
5ddbba2
added ID to my description in line 67 and 70
brittyscience Jun 20, 2024
012c3b4
Merge branch 'main' of https://github.com/ur-whitelab/md-agent into s…
brittyscience Jun 26, 2024
e94f81a
M)erge branch 'main' of https://github.com/ur-whitelab/md-agent into …
brittyscience Jul 10, 2024
694de51
Update mdagent/tools/base_tools/analysis_tools/salt_bridge_tool.py
brittyscience Jul 10, 2024
e058c77
Update tests/test_analysis/test_saltbridge_tools.py
brittyscience Jul 10, 2024
0d6fe0d
Update tests/test_analysis/test_saltbridge_tools.py
brittyscience Jul 10, 2024
46580ff
adjustments based on feedback
brittyscience Jul 10, 2024
6841574
merging changes from remote branch
brittyscience Jan 23, 2025
ed756af
salt bridge code update
brittyscience Jan 23, 2025
41c75ed
merged from main to saltbridge_code
qcampbel Jan 24, 2025
2be80a8
added salt bridge counts and new unit tests
qcampbel Jan 29, 2025
2c36751
added neutral ph warning & refactored a bit
qcampbel Feb 5, 2025
670aca1
fixed code error (line 68)
qcampbel Feb 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
RadiusofGyrationPlot,
)
from .analysis_tools.rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .analysis_tools.salt_bridge_tool import SaltBridgeTool
from .analysis_tools.sasa import SolventAccessibleSurfaceArea
from .analysis_tools.vis_tools import VisFunctions, VisualizeProtein
from .preprocess_tools.clean_tools import CleaningToolFunction
Expand Down Expand Up @@ -44,7 +45,7 @@
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"RDFTool",
"RMSDCalculator",
"SaltBridgeTool",
"Scholar2ResultLLM",
"SerpGitTool",
"SetUpandRunFunction",
Expand Down
3 changes: 2 additions & 1 deletion mdagent/tools/base_tools/analysis_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .ppi_tools import PPIDistance
from .rgy import RadiusofGyrationAverage, RadiusofGyrationPerFrame, RadiusofGyrationPlot
from .rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .salt_bridge_tool import SaltBridgeTool
from .sasa import SolventAccessibleSurfaceArea
from .vis_tools import VisFunctions, VisualizeProtein

Expand All @@ -20,7 +21,7 @@
"RadiusofGyrationAverage",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"RMSDCalculator",
"SaltBridgeTool",
"SimulationOutputFigures",
"SolventAccessibleSurfaceArea",
"VisFunctions",
Expand Down
106 changes: 106 additions & 0 deletions mdagent/tools/base_tools/analysis_tools/salt_bridge_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Optional

import mdtraj as md
from langchain.tools import BaseTool
from pydantic import BaseModel, Field

from mdagent.utils import PathRegistry


class SaltBridgeFunction: # this class defines a method called find_salt_bridge
# using MD traj and top files and threshold distance default, residue pair list
brittyscience marked this conversation as resolved.
Show resolved Hide resolved
# used to account for salt bridge analysis
brittyscience marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, path_registry):
brittyscience marked this conversation as resolved.
Show resolved Hide resolved
self.path_registry = path_registry
self.salt_bridges = [] # stores paired salt bridges
self.traj = None

def find_salt_bridges(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function might be easier to test if you break it up into some sub-functions

self, traj_file, top_file=None, threshold_distance=0.4, residue_pairs=None
brittyscience marked this conversation as resolved.
Show resolved Hide resolved
):
# add two files here in similar format as line 14 above
traj_file_path = self.path_registry.get_mapped_path(traj_file)
ending = traj_file_path.split(".")[-1]
if ending in ["dcd", "xtc", "xyz"] and top_file is not None:
top_file_path = self.path_registry.get_mapped_path(top_file)
self.traj = md.load(traj_file_path, top=top_file_path)
else:
self.traj = md.load(traj_file_path)
qcampbel marked this conversation as resolved.
Show resolved Hide resolved

if residue_pairs is None:
residue_pairs = [
("ARG", "ASP"),
("ARG", "GLU"),
("LYS", "ASP"),
("LYS", "GLU"),
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if residue_pairs is None:
residue_pairs = [
("ARG", "ASP"),
("ARG", "GLU"),
("LYS", "ASP"),
("LYS", "GLU"),
]
residue_pairs = residue_pairs if residue_pairs else [
("ARG", "ASP"),
("ARG", "GLU"),
("LYS", "ASP"),
("LYS", "GLU"),
]


for pair in residue_pairs:
donor_residues = self.traj.topology.select(f'residue_name == "{pair[0]}"')
acceptor_residues = self.traj.topology.select(
f'residue_name == "{pair[1]}"'
)

for donor_idx in donor_residues:
for acceptor_idx in acceptor_residues:
distances = md.compute_distances(
self.traj, [[donor_idx, acceptor_idx]]
)
if any(d <= threshold_distance for d in distances):
self.salt_bridges.append((donor_idx, acceptor_idx))
return self.salt_bridges

def get_results_string(self):
msg = "Salt bridges found: "
for bridge in self.salt_bridges:
msg += (
f"Residue {self.traj.topology.atom(bridge[0]).residue.index + 1} "
f"({self.traj.topology.atom(bridge[0]).residue.name}) - "
f"Residue {self.traj.topology.atom(bridge[1]).residue.index + 1} "
f"({self.traj.topology.atom(bridge[1]).residue.name})"
)
return msg
qcampbel marked this conversation as resolved.
Show resolved Hide resolved


class SaltBridgeToolInput(BaseModel):
trajectory_fileid: str = Field(
None, description="Trajectory file. Either dcd, hdf5, xtc, or xyz"
brittyscience marked this conversation as resolved.
Show resolved Hide resolved
)

topology_fileid: Optional[str] = Field(None, description="Topology file")

threshold_distance: Optional[float] = Field(
0.4,
description=(
"maximum distance between residues for salt bridge formation in angstrom"
),
)

residue_pairs: Optional[dict] = Field(
None, description=("Identifies the amino acid residues for salt bridge")
)


class SaltBridgeTool(BaseTool):
name = "SaltBridgeTool"
description = "A tool to find salt bridge in a protein trajectory"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
description = "A tool to find salt bridge in a protein trajectory"
description = "A tool to find salt bridge in a protein trajectory. Input a trajectory file ID, a threshold distance (@brittyscience add a range here), and optionally a topology file ID and a list of residue pairs. If no residue pairs are provided...(finish this)"

you should be more descriptive about the inputs, outputs

args_schema = SaltBridgeToolInput
path_registry: Optional[PathRegistry]

def __init__(self, path_registry):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __init__(self, path_registry):
def __init__(self, path_registry:PathRegistry):

super().__init__()
self.path_registry = path_registry

def _run(
self, traj_file, top_file=None, threshold_distance=0.4, residue_pairs=None
):
try:
# calls the salt bridge function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# calls the salt bridge function

salt_bridge_function = SaltBridgeFunction(self.path_registry)
salt_bridge_function.find_salt_bridges(
traj_file, top_file, threshold_distance, residue_pairs
)
message = salt_bridge_function.get_results_string()
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"
return "Succeeded. " + message
2 changes: 2 additions & 0 deletions mdagent/tools/maketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
RadiusofGyrationPerFrame,
RadiusofGyrationPlot,
RDFTool,
SaltBridgeTool,
Scholar2ResultLLM,
SetUpandRunFunction,
SimulationOutputFigures,
Expand Down Expand Up @@ -69,6 +70,7 @@ def make_all_tools(
RadiusofGyrationPerFrame(path_registry=path_instance),
RadiusofGyrationPlot(path_registry=path_instance),
RDFTool(path_registry=path_instance),
SaltBridgeTool(path_registry=path_instance),
SetUpandRunFunction(path_registry=path_instance),
SimulationOutputFigures(path_registry=path_instance),
SmallMolPDB(path_registry=path_instance),
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# This is an empty __init__.py file
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete this file, we don't need tests to be a module

50 changes: 50 additions & 0 deletions tests/test_analysis/test_saltbridge_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from unittest.mock import MagicMock, patch

import pytest

from mdagent.tools.base_tools.analysis_tools.salt_bridge_tool import (
SaltBridgeFunction,
SaltBridgeTool,
)


@pytest.fixture
def fake_path_registry():
# Mock PathRegistry to return a specific file path when asked
mock_registry = MagicMock()
mock_registry.get_mapped_path.side_effect = lambda x: f"/fake/path/{x}"
return mock_registry


@pytest.fixture
def mock_md_load():
with patch("mdtraj.load", autospec=True) as mock:
yield mock


def test_saltbridge_tool_init(get_registry):
registry = get_registry("raw", False)
tool = SaltBridgeTool(path_registry=registry)
qcampbel marked this conversation as resolved.
Show resolved Hide resolved
assert tool.name == "SaltBridgeTool"
assert tool.path_registry == registry


def test_salt_bridge_function_init(get_registry):
path_registry = get_registry("raw", False)
sbf = SaltBridgeFunction(path_registry)
assert sbf.path_registry == path_registry

brittyscience marked this conversation as resolved.
Show resolved Hide resolved

def test_find_salt_bridges(fake_path_registry, mock_md_load):
sbf = SaltBridgeFunction(fake_path_registry)
sbf.find_salt_bridges("traj_file.dcd", "top_file.top")
mock_md_load.assert_called_once_with(
"/fake/path/traj_file.dcd", top="/fake/path/top_file.top"
)
# mock_md_compute_distances.assert_called()
brittyscience marked this conversation as resolved.
Show resolved Hide resolved


def test_salt_bridge_function_without_top(fake_path_registry, mock_md_load):
sbf = SaltBridgeFunction(fake_path_registry)
sbf.find_salt_bridges("traj_file.hdf5")
mock_md_load.assert_called_once_with("/fake/path/traj_file.hdf5")
Loading