Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

code and unit test for salt bridge tool #121

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
e6f086d
added some rough code draft for salt bridge
brittyscience Mar 27, 2024
f1ee5ce
Merge branch 'main' into saltbridge_code
brittyscience Mar 31, 2024
0ff52c6
Prompting change in CNT
Jgmedina95 Apr 1, 2024
8ed1734
merging CNT_rdf to saltbridge_code
brittyscience Apr 3, 2024
3d90dd3
Merge branch 'main' into saltbridge_code
brittyscience Apr 10, 2024
9228c2a
pushing latest updates. Unit test and codes WIP
brittyscience Apr 12, 2024
f67e42b
Another changes updates to my GT and Unit test. WIP
brittyscience Apr 15, 2024
1103b11
latest changes. pytest is still not working.
brittyscience Apr 15, 2024
679f61a
updating my saltbridge tool code
brittyscience Apr 22, 2024
562d164
Just another update of my changes from two days ago.
brittyscience Apr 25, 2024
92ad2aa
more updates. agent should have my tools now
brittyscience Apr 25, 2024
4be782f
fixed init bug
brittyscience May 7, 2024
0df1242
added arg schema, making code more flexible, path registry
brittyscience May 11, 2024
5655b4d
unit test now works and salt bridge tool revision update
brittyscience May 13, 2024
86e333d
Merge branch 'main' of https://github.com/ur-whitelab/md-agent into s…
brittyscience May 13, 2024
e6d792c
pre commit hiccup fix
brittyscience Jun 9, 2024
2daed31
updating with changes from main
brittyscience Jun 9, 2024
8833cf7
resolved merging conflict
brittyscience Jun 12, 2024
f154a69
moved my saltbridge unit test to test_analysis folder
brittyscience Jun 13, 2024
5ddbba2
added ID to my description in line 67 and 70
brittyscience Jun 20, 2024
012c3b4
Merge branch 'main' of https://github.com/ur-whitelab/md-agent into s…
brittyscience Jun 26, 2024
e94f81a
Merge branch 'main' of https://github.com/ur-whitelab/md-agent into …
brittyscience Jul 10, 2024
694de51
Update mdagent/tools/base_tools/analysis_tools/salt_bridge_tool.py
brittyscience Jul 10, 2024
e058c77
Update tests/test_analysis/test_saltbridge_tools.py
brittyscience Jul 10, 2024
0d6fe0d
Update tests/test_analysis/test_saltbridge_tools.py
brittyscience Jul 10, 2024
46580ff
adjustments based on feedback
brittyscience Jul 10, 2024
6841574
merging changes from remote branch
brittyscience Jan 23, 2025
ed756af
salt bridge code update
brittyscience Jan 23, 2025
41c75ed
merged from main to saltbridge_code
qcampbel Jan 24, 2025
2be80a8
added salt bridge counts and new unit tests
qcampbel Jan 29, 2025
2c36751
added neutral ph warning & refactored a bit
qcampbel Feb 5, 2025
670aca1
fixed code error (line 68)
qcampbel Feb 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .analysis_tools.rdf_tool import RDFTool
from .analysis_tools.rgy import RadiusofGyrationTool
from .analysis_tools.rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .analysis_tools.salt_bridge_tool import SaltBridgeTool
from .analysis_tools.sasa import SolventAccessibleSurfaceArea
from .analysis_tools.secondary_structure import (
ComputeAcylindricity,
Expand Down Expand Up @@ -80,7 +81,7 @@
"ProteinName2PDBTool",
"RadiusofGyrationTool",
"RDFTool",
"RMSDCalculator",
"SaltBridgeTool",
"Scholar2ResultLLM",
"SetUpandRunFunction",
"SimulationOutputFigures",
Expand Down
2 changes: 2 additions & 0 deletions mdagent/tools/base_tools/analysis_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .ppi_tools import PPIDistance
from .rgy import RadiusofGyrationTool
from .rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .salt_bridge_tool import SaltBridgeTool
from .sasa import SolventAccessibleSurfaceArea
from .vis_tools import VisFunctions, VisualizeProtein

Expand All @@ -21,6 +22,7 @@
"PPIDistance",
"RadiusofGyrationTool",
"RMSDCalculator",
"SaltBridgeTool",
"SimulationOutputFigures",
"SolventAccessibleSurfaceArea",
"VisFunctions",
Expand Down
236 changes: 236 additions & 0 deletions mdagent/tools/base_tools/analysis_tools/salt_bridge_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
import warnings

import matplotlib.pyplot as plt
import mdtraj as md
import numpy as np
import pandas as pd
from langchain.tools import BaseTool

from mdagent.utils import FileType, PathRegistry, load_single_traj, save_plot


class SaltBridgeFunction:
def __init__(self, path_registry):
brittyscience marked this conversation as resolved.
Show resolved Hide resolved
self.path_registry = path_registry
self.salt_bridge_pairs = [] # stores paired salt bridges
self.salt_bridge_counts = []
self.traj = None
self.traj_file = ""

def _load_traj(self, traj_file, top_file):
self.traj = load_single_traj(
self.path_registry, top_fileid=top_file, traj_fileid=traj_file
)
self.traj_file = traj_file if traj_file else top_file

def find_salt_bridges(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function might be easier to test if you break it up into some sub-functions

self,
threshold_distance: float = 0.4,
residue_pairs=[],
):
"""
Find Salt Bridge in molecular dynamics simulation trajectory, using
threshold distance (in nm) between N and O atoms for salt bridge formation,
based on Barlow and Thornton's original definition of salt bridges
(https://doi.org/10.1016/S0022-2836(83)80079-5)


threshold_distance: maximum distance (in nm) between N and O atoms
residue_pairs (optional): list of tuples (donor_residue, acceptor_residue)
"""
if self.traj is None:
raise Exception("MDTrajectory hasn't been loaded")

if not residue_pairs:
residue_pairs = [
# (postive-charged, negative-charged)
# pairs from https://doi.org/10.1002/prot.22927
("ARG", "ASP"),
("ARG", "GLU"),
("LYS", "ASP"),
("LYS", "GLU"),
("HIS", "ASP"),
("HIS", "GLU"),
]
warnings.warn(
"No residue pairs provided. Default charged residues "
"are being used, assuming physiological pH. "
f"Default pairs: {residue_pairs}",
UserWarning,
)

donor_acceptor_pairs = []
for pair in residue_pairs:
print(f"Looking for salt bridges between {pair[0]} and {pair[1]} pairs...")
donor_atoms = self.traj.topology.select(f'resname == "{pair[0]}"')
acceptor_atoms = self.traj.topology.select(f'resname == "{pair[1]}"')

if donor_atoms.size == 0 or acceptor_atoms.size == 0:
continue

donor_nitrogens = [ # N atoms in the donor residues (e.g. Arg, Lys, His)
atom.index
for atom in self.traj.topology.atoms
if atom.index in donor_atoms and atom.element.symbol == "N"
]
acceptor_oxygens = [ # O atoms in the acceptor residues (e.g. Asp, Glu)
atom.index
for atom in self.traj.topology.atoms
if atom.index in acceptor_atoms and atom.element.symbol == "O"
]

# generate all possible donor-acceptor pairs
pairs = np.array(np.meshgrid(donor_nitrogens, acceptor_oxygens)).T.reshape(
-1, 2
)
donor_acceptor_pairs.append(pairs)

if not donor_acceptor_pairs:
return None

donor_acceptor_pairs = np.vstack(donor_acceptor_pairs) # combine into one list
all_distances = md.compute_distances(self.traj, donor_acceptor_pairs)

salt_bridge_counts = []
salt_bridge_pairs = []
for frame_idx in range(self.traj.n_frames):
frame_distances = all_distances[frame_idx]
within_threshold = frame_distances <= threshold_distance
salt_bridge_counts.append(np.sum(within_threshold))

filtered_pairs = donor_acceptor_pairs[within_threshold]
if filtered_pairs.size > 0:
salt_bridge_pairs.append((frame_idx, filtered_pairs))
self.salt_bridge_counts = salt_bridge_counts
self.salt_bridge_pairs = salt_bridge_pairs

def plot_salt_bridge_counts(self):
if not self.salt_bridge_pairs or self.traj.n_frames == 1:
return None

plt.figure(figsize=(10, 6))
plt.plot(
range(self.traj.n_frames),
self.salt_bridge_counts,
marker="o",
linestyle="-",
color="b",
)
plt.title(f"Salt Bridge Count Over Time - {self.traj_file}")
plt.xlabel("Frame")
plt.ylabel("Total Salt Bridge Count")
plt.grid(True)
fig_id = save_plot(
self.path_registry,
"salt_bridge",
f"figure of salt bridge counts for {self.traj_file}",
)
plt.close()
return fig_id

def save_results_to_file(self):
if self.traj is None:
raise Exception("Trajectory is None")
if not self.salt_bridge_pairs:
return None

if self.traj.n_frames == 1:
num_sb = self.salt_bridge_counts[0]
print(f"We found {num_sb} salt bridges for {self.traj_file}.")
print(
(
"Since the trajectory has only one frame, we saved a "
"list of salt bridges instead of plotting."
)
)

salt_bridge_data = []
frame_idx, bridges = self.salt_bridge_pairs[0]
for bridge in bridges:
donor_residue = self.traj.topology.atom(bridge[0]).residue
acceptor_residue = self.traj.topology.atom(bridge[1]).residue
salt_bridge_data.append(
{
"Donor": f"{donor_residue.name} ({donor_residue.index + 1})",
"Acceptor": f"{acceptor_residue.name} ({acceptor_residue.index + 1})",
}
)
df = pd.DataFrame(salt_bridge_data)

else:
df = pd.DataFrame(
{
"Frame": range(self.traj.n_frames),
"Salt Bridge Count": self.salt_bridge_counts,
}
)

# save to file, add to path registry
file_name = self.path_registry.write_file_name(
FileType.RECORD,
record_type="salt_bridges",
file_format="csv",
)
file_id = self.path_registry.get_fileid(file_name, FileType.RECORD)
file_path = f"{self.path_registry.ckpt_records}/{file_name}"
df.to_csv(file_path, index=False)
self.path_registry.map_path(
file_id, file_path, description=f"salt bridge analysis for {self.traj_file}"
)
return file_id

def compute_salt_bridges(
self,
traj_file,
top_file,
threshold_distance,
residue_pairs,
):
self._load_traj(traj_file, top_file)
self.find_salt_bridges(threshold_distance, residue_pairs)
file_id = self.save_results_to_file()
fig_id = self.plot_salt_bridge_counts()
return file_id, fig_id


class SaltBridgeTool(BaseTool):
    """Agent-facing tool that finds and counts salt bridges in a trajectory.

    Wraps ``SaltBridgeFunction`` and reports the outcome as a plain message
    string instead of raising.
    """

    name = "SaltBridgeTool"
    description = (
        "A tool to find and count salt bridges in a protein trajectory. "
        "You need to provide either PDB file or trajectory and topology files. "
        "Optional: provide threshold distance (default:0.4) and a custom list "
        "of residue pairs as tuples of positive-charged and negative-charged. "
    )
    path_registry: PathRegistry | None = None

    def __init__(self, path_registry: PathRegistry):
        super().__init__()
        self.path_registry = path_registry

    def _run(
        self,
        traj_file: str,
        top_file: str | None = None,
        threshold_distance=0.4,
        residue_pairs=None,
    ):
        """Run the salt bridge analysis and describe the result.

        Args:
            traj_file: file id of the trajectory.
            top_file: file id of the topology, if any.
            threshold_distance: maximum N-O distance passed to the analysis.
            residue_pairs: custom (positive, negative) residue-name pairs;
                None or an empty list makes the analysis use its defaults.

        Returns:
            A message string. Any exception raised during the analysis is
            caught and returned as a "Failed. ..." message.
        """
        try:
            if self.path_registry is None:
                return "Path registry is not set"

            salt_bridge_function = SaltBridgeFunction(self.path_registry)
            results_file_id, fig_id = salt_bridge_function.compute_salt_bridges(
                traj_file, top_file, threshold_distance, residue_pairs
            )
            if not results_file_id:
                return (
                    "Succeeded. No salt bridges are found in "
                    f"{salt_bridge_function.traj_file}."
                )

            message = f"Saved results with file id: {results_file_id} "
            if fig_id:
                message += f"and figure with fig id {fig_id}."
            return "Succeeded. " + message
        except Exception as e:
            return f"Failed. {type(e).__name__}: {e}"
2 changes: 2 additions & 0 deletions mdagent/tools/maketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
ProteinName2PDBTool,
RadiusofGyrationTool,
RDFTool,
SaltBridgeTool,
Scholar2ResultLLM,
SetUpandRunFunction,
SimulationOutputFigures,
Expand Down Expand Up @@ -101,6 +102,7 @@ def make_all_tools(
ProteinName2PDBTool(path_registry=path_instance),
RadiusofGyrationTool(path_registry=path_instance),
RDFTool(path_registry=path_instance),
SaltBridgeTool(path_registry=path_instance),
SetUpandRunFunction(path_registry=path_instance),
SimulationOutputFigures(path_registry=path_instance),
SmallMolPDB(path_registry=path_instance),
Expand Down
Empty file added tests/__init__.py
Empty file.
99 changes: 99 additions & 0 deletions tests/test_analysis/test_saltbridge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import mdtraj as md
import pytest

from mdagent.tools.base_tools.analysis_tools.salt_bridge_tool import SaltBridgeFunction

# Mock PDB text used by the get_salt_bridge_function fixture below, which
# writes it to disk and loads it with mdtraj.
# It holds 14 ATOM records spread over four residue names (ARG, ASP, LYS, GLU)
# on chain A, providing N atoms in ARG/LYS and O atoms in ASP/GLU.
# NOTE(review): PDB is a fixed-column format; confirm that the column
# alignment of these records is intact in the repository copy.
pdb_data = """
HEADER MOCK SALT BRIDGE EXAMPLE
ATOM 1 N ARG A 1 0.000 0.000 0.000
ATOM 2 CA ARG A 1 1.000 0.000 0.000
ATOM 3 C ARG A 1 1.500 1.000 0.000
ATOM 4 O ASP A 2 2.000 1.000 0.000
ATOM 5 N LYS A 3 0.000 1.000 1.000
ATOM 6 CA LYS A 3 1.000 1.000 1.000
ATOM 7 C LYS A 3 1.500 2.000 1.000
ATOM 8 O GLU A 4 2.000 2.000 1.000
ATOM 9 N ASP A 2 3.000 1.000 0.000
ATOM 10 O GLU A 4 4.000 2.000 1.000
ATOM 11 N GLU A 4 2.000 2.000 0.000
ATOM 12 O GLU A 4 4.000 2.000 1.000
ATOM 13 N LYS A 3 0.0 3.0 0.000
ATOM 14 O LYS A 3 0.0 4.0 0.000
END
"""


@pytest.fixture
def get_salt_bridge_function(get_registry):
    """Return a SaltBridgeFunction whose trajectory is the mock PDB above."""
    registry = get_registry("raw", True)
    pdb_path = f"{registry.ckpt_dir}/sb_residues.pdb"
    with open(pdb_path, "w") as handle:
        handle.write(pdb_data)

    # The trajectory is attached directly instead of going through _load_traj.
    salt_bridge_function = SaltBridgeFunction(registry)
    salt_bridge_function.traj = md.load(pdb_path)
    salt_bridge_function.traj_file = "sb_residues"
    return salt_bridge_function


@pytest.fixture
def get_salt_bridge_function_with_butane(get_registry):
    """Return a SaltBridgeFunction loaded from the registry's butane file ids."""
    salt_bridge_function = SaltBridgeFunction(get_registry("raw", True))
    salt_bridge_function._load_traj(
        "rec0_butane_123456",  # trajectory file id
        "top_sim0_butane_123456",  # topology file id
    )
    return salt_bridge_function


def test_find_salt_bridges_with_salt_bridges(get_salt_bridge_function):
    """The single-frame mock PDB yields six bridged atom pairs."""
    sb = get_salt_bridge_function
    sb.find_salt_bridges()

    counts, pairs = sb.salt_bridge_counts, sb.salt_bridge_pairs
    assert len(counts) == 1
    assert len(pairs) == 1  # Only 1 frame
    bridged = pairs[0][1]
    assert len(bridged) == 6
    assert counts == [6]


def test_salt_bridge_files_single_frame(get_salt_bridge_function):
    """A single frame produces a results file but no plot."""
    sb = get_salt_bridge_function
    sb.find_salt_bridges()

    saved_id = sb.save_results_to_file()
    plotted_id = sb.plot_salt_bridge_counts()

    assert saved_id is not None
    assert plotted_id is None


def test_salt_bridge_files_multiple_frames(get_salt_bridge_function):
    """Several frames produce both a results file and a count plot."""
    sb = get_salt_bridge_function
    frame_count = 5
    # repeat the single mock frame to get a multi-frame trajectory
    sb.traj = md.join([sb.traj] * frame_count)
    sb.find_salt_bridges()

    saved_id = sb.save_results_to_file()
    plotted_id = sb.plot_salt_bridge_counts()

    assert saved_id is not None
    assert plotted_id is not None


def test_no_salt_bridges(get_salt_bridge_function_with_butane):
    """The butane fixture must give no bridges, so nothing is stored or written."""
    salt_bridge_function = get_salt_bridge_function_with_butane
    salt_bridge_function.find_salt_bridges()
    file_id = salt_bridge_function.save_results_to_file()
    fig_id = salt_bridge_function.plot_salt_bridge_counts()

    # each expectation is checked once
    assert file_id is None
    assert fig_id is None
    assert len(salt_bridge_function.salt_bridge_counts) == 0
    assert salt_bridge_function.salt_bridge_pairs == []


def test_invalid_trajectory(get_salt_bridge_function):
    """Searching without a loaded trajectory raises the documented error."""
    sb = get_salt_bridge_function
    sb.traj = None  # simulate a trajectory that was never loaded
    expected_message = "MDTrajectory hasn't been loaded"
    with pytest.raises(Exception, match=expected_message):
        sb.find_salt_bridges()
Loading