diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 100644 index 096d60d1..00000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,91 +0,0 @@ -# Contributor's Guide - -Welcome to the md-agent repository! We're thrilled that you're interested in contributing. This guide will help you understand how you can contribute to the project effectively. - -## Setup Instructions -To get started with contributing to md-agent, follow these steps: - -### Setting up the Repository and Environment -``` -git clone https://github.com/ur-whitelab/md-agent.git -cd md-agent -conda env create -n mdagent -f environment.yaml -conda activate mdagent -``` - -### Installing the Package and Dependencies and Configuring Pre-commit Hooks -``` -pip install -e . -pip install -r dev-requirements.txt -pre-commit install -``` - -## Code Guidelines - -- Use meaningful variable and function names. -- Maintain consistency with the existing codebase. -- Write clear and concise comments to explain your code. -- Run pre-commit before committing your changes to ensure code quality and pass the automated checks. - -## Feature Development Guidelines - -When developing new features for md-agent, please follow these guidelines: - -- If your feature uses a new package, ensure that you add the package to the project's setup and also include it in the ignore list in the `mypy.ini` file to avoid type checking errors. - -- If your feature requires the use of API keys or other sensitive information, follow the appropriate steps: - - Add a placeholder or example entry for the required information in the `.env.example` file. - - Open an issue to discuss and coordinate the addition of the actual keys or secrets to the project's secure environment. - -- New features should include the following components: - - Implement the feature functionality in the codebase. - - Write unit test functions to ensure proper functionality. 
- - If applicable, create a notebook demonstration or example showcasing the usage and benefits of the new feature. - -These guidelines help maintain consistency, security, and thoroughness in the development of new features for the project. - - -## Pull Request Process - -1. Fork the repository to your own GitHub account. -2. Create a new branch: `git checkout -b my-feature-branch` -3. Make your changes, following the code guidelines mentioned above. -4. Test your changes thoroughly. Pytest workflows must pass in order for PR to be approved. -5. Commit your changes: `git commit -am 'Add new feature'` -6. Push your branch to your forked repository: `git push origin my-feature-branch` -7. Submit a pull request, providing a detailed description of your changes and their purpose. -8. Request a review from the project maintainers or assigned reviewers. -9. Address any feedback or comments received during the review process. -10. Once your changes are approved, they will be merged into the main branch. - - -## Issue Guidelines - -If you encounter any bugs or have feature requests, please follow these guidelines when submitting an issue: - -1. Before creating a new issue, search the existing issues to avoid duplicates. -2. Use a clear and descriptive title. -3. Provide steps to reproduce the issue (if applicable). -4. Include any relevant error messages or logs. -5. Explain the expected behavior and the actual behavior you encountered. - -In addition, if you have any questions, need help, or want to discuss ideas, please submit an issue. - - -## Code Review Etiquette - -When participating in code reviews, please follow these guidelines: - - -1. Be respectful and constructive in your feedback. -2. Provide specific and actionable feedback. -3. Explain the reasoning behind your suggestions. -4. Be open to receiving feedback and engage in discussions. -5. Be responsive and timely in addressing comments. 
- -By adhering to these guidelines, we can foster a positive and collaborative environment for code reviews and contribute to the project's success. - - -## Acknowledgment - -We value and appreciate all contributions to md-agent. Your efforts are highly valued and have a positive impact on the project's development. diff --git a/README.md b/README.md index 088a57be..ba6c9eae 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ -MDAgent is a LLM-agent based toolset for Molecular Dynamics. +MDCrow is an LLM-agent based toolset for Molecular Dynamics. It's built using Langchain and uses a collection of tools to set up and execute molecular dynamics simulations, particularly in OpenMM. ## Environment Setup To use the OpenMM features in the agent, please set up a conda environment, following these steps. ``` -conda env create -n mdagent -f environment.yaml -conda activate mdagent +conda env create -n mdcrow -f environment.yaml +conda activate mdcrow ``` If you already have a conda environment, you can install dependencies before you activate it with the following step. @@ -16,7 +16,7 @@ If you already have a conda environment, you can install dependencies before you ## Installation ``` -pip install git+https://github.com/ur-whitelab/md-agent.git +pip install git+https://github.com/ur-whitelab/MDCrow.git ``` ## Usage @@ -25,14 +25,14 @@ We recommend setting up api keys in a .env file. You can use the provided .env.e 1. Copy the `.env.example` file and rename it to `.env`: `cp .env.example .env` 2.
Replace the placeholder values in `.env` with your actual keys -You can ask MDAgent to conduct molecular dynamics tasks using OpenAI's GPT model +You can ask MDCrow to conduct molecular dynamics tasks using OpenAI's GPT model ``` -from mdagent import MDAgent +from mdcrow import MDCrow -agent = MDAgent(model="gpt-3.5-turbo") +agent = MDCrow(model="gpt-3.5-turbo") agent.run("Simulate protein 1ZNI at 300 K for 0.1 ps and calculate the RMSD over time.") ``` -Note: to distinguish Together models from the rest, you'll need to add "together\" prefix in model flag, such as `agent = MDAgent(model="together/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo")` +Note: to distinguish Together models from the rest, you'll need to add the "together/" prefix in the model flag, such as `agent = MDCrow(model="together/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo")` ## LLM Providers By default, we support LLMs through OpenAI API. However, feel free to use other LLM providers. Make sure to install the necessary package for it. Here's list of packages required for alternative LLM providers we support: @@ -43,6 +43,6 @@ By default, we support LLMs through OpenAI API. However, feel free to use other ## Contributing -We welcome contributions to MDAgent! If you're interested in contributing to the project, please check out our [Contributor's Guide](CONTRIBUTING.md) for detailed instructions on getting started, feature development, and the pull request process. +We welcome contributions to MDCrow! If you're interested in contributing to the project, please check out our [Contributor's Guide](CONTRIBUTING.md) for detailed instructions on getting started, feature development, and the pull request process. -We value and appreciate all contributions to MDAgent. +We value and appreciate all contributions to MDCrow.
diff --git a/mdagent/__init__.py b/mdagent/__init__.py deleted file mode 100644 index cc336143..00000000 --- a/mdagent/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .agent import Evaluator, MDAgent - -__all__ = ["MDAgent", "Evaluator"] diff --git a/mdagent/agent/README_evaluate.md b/mdagent/agent/README_evaluate.md deleted file mode 100644 index 0e3ecd34..00000000 --- a/mdagent/agent/README_evaluate.md +++ /dev/null @@ -1,109 +0,0 @@ -# README for Evaluate.py - -## Overview -`evaluate.py` is a Python script designed to facilitate the evaluation of MDAgent. It supports multiple use cases, including automated evaluation across different agent settings, loading previous evaluation results, and creating structured output tables. A list of generated or loaded evaluations can be accessed via `evaluator.evaluations`. Use `evaluator.reset()` to clear the evaluations list. - -## Getting Started -To use `evaluate.py`, ensure that `mdagent` package is installed in your Python environment (see main README for installation instructions). - -## Usage Examples - -### Example 1: Evaluate Prompts with default MD-Agent Parameters -Evaluate specific prompts with default settings: -```python -from mdagent import Evaluator - -evaluator = Evaluator() -prompts = [ - 'Download and clean fibronectin.', - 'Simulate 1A3N in water for 100 ns.', -] -df = evaluator.automate(prompts) -df # this displays DataFrame table in Jupyter notebook -``` -This will run MDAgent and evaluate the prompts using the default settings. The results will be -saved to json file in the "evaluation_results" directory and used to create pandas DataFrame. 
- -### Example 2: Evaluate Prompts with specific MD-Agent Parameters -Evaluate specific prompts using single agent instance with specified parameters: -```python -from mdagent import Evaluator - -evaluator = Evaluator() -agent_params = { - "agent_type": "Structured", - 'model': 'gpt-3.5-turbo', - 'tools_model': 'gpt-3.5-turbo', -} -prompts = [ - 'Download and clean fibronectin.', - 'Simulate 1A3N in water for 100 ns.', -] -df = evaluator.automate(prompts, agent_params=agent_params) -df_full = evaluator.create_table(simple=False) # to get a table with all details -``` - -### Example 3: Evaluate Prompts with Multiple Agent Parameters -Evaluate specific prompts using multiple agent settings with `automate_all` method: -```python -from mdagent import Evaluator - -evaluator = Evaluator() -prompts = [ - 'Download and clean fibronectin.', - 'Simulate 1A3N in water for 100 ns.', -] -agent_params_list = [ - { - "agent_type": "OpenAIFunctionsAgent", - "model": "gpt-4-1106-preview", - "ckpt_dir": "ckpt_openaifxn_gpt4", - }, - { - "agent_type": "Structured", - "model": "gpt-3.5-turbo", - "ckpt_dir": "ckpt_structured_gpt3.5", - }, -] -df = evaluator.automate_all(prompts, agent_params_list=agent_params_list) -``` - -### Example 4: Load Previous Evaluation Results and Create a Table -Load previous evaluation results from a JSON file: -```python -from mdagent import Evaluator - -evaluator = Evaluator() -evaluator.load('evaluation_results/mega_eval_20240422-181241.json') -df = evaluator.create_table() -df.to_latex('evaluation_results/eval_table.tex') # Optional: save table to a LaTeX file -``` -You can load multiple evaluation files by calling `evaluator.load()` multiple times. The results will be appended to the `evaluator.evaluations` list. 
- -### Example 5: Make Multi-Prompt Query with Agent Memory -Use agent memory to link multiple prompts: -```python -from mdagent import Evaluator - -evaluator = Evaluator() -agent_params1 = { - "use_memory": True -} -prompt_set1 = ['Simulate 1A3N in water for 100 ns.'] -df1 = evaluator.automate(prompt_set1, agent_params1) -df1 # display the table containing run ID - -agent_params2 = { - "use_memory": True, - "run_id": "U4831GA3", # <---- insert run_id from prompt 1 table -} -prompt_set2 = ['Calculate RMSD for 1A3N simulation over time.'] -df2 = evaluator.automate(prompt_set2, agent_params2) -df2 # display results from both prompts -``` -Another way to get run_id: you can access `evaluator.evaluations` list and pull the key -`run_id` from the dictionary that contains the results of the first prompt. - -## Additional Information -- For whatever reason, instead of `evaluator.automate()`, you can manually call `evaluator.run_and_evaluate(prompts, agent_params=params)` once or several times, `evaluator.save()` to save all evaluations to a json file, then use `evaluator.create_table()` to get DataFrame object. -- `evaluate.py` is designed to be used in a Jupyter notebook environment. 
diff --git a/mdagent/agent/__init__.py b/mdagent/agent/__init__.py deleted file mode 100644 index 00afc395..00000000 --- a/mdagent/agent/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .agent import MDAgent -from .evaluate import Evaluator - -__all__ = ["MDAgent", "Evaluator"] diff --git a/mdagent/agent/evaluate.py b/mdagent/agent/evaluate.py deleted file mode 100644 index fc0c4639..00000000 --- a/mdagent/agent/evaluate.py +++ /dev/null @@ -1,346 +0,0 @@ -import json -import os -import time - -import pandas as pd - -from .agent import MDAgent - - -# TODO: turn off verbose for MD-Agent -- verbose option doesn't work -# TODO: later, add write_to_notebooks option -class Evaluator: - def __init__(self, eval_dir="evaluation_results"): - # find root directory - eval_path = eval_dir - current_dir = os.getcwd() - while current_dir != "/": - if "setup.py" in os.listdir(current_dir): - root_dir = os.path.abspath(current_dir) - eval_path = os.path.join(root_dir, eval_dir) - break - else: - current_dir = os.path.dirname(current_dir) - self.base_dir = eval_path - os.makedirs(self.base_dir, exist_ok=True) - self.evaluations = [] - - def create_agent(self, agent_params={}): - """ - initializes MDAgent with given parameters - - Parameters: - - agent_params (dict): dictionary of parameters to initialize MDAgent - - Returns: - - initialized MDAgent object - """ - # initialize MDAgent with given parameters. 
- if agent_params is None: # this shouldn't happen though - agent_params = {} - return MDAgent(**agent_params) - - def reset(self): - """ - empties the evaluations list - - Parameters: - - None - - Returns: - - None - """ - self.evaluations = [] - - def save(self, filename="mega_eval", add_timestamp=True): - """ - save all evaluations to a json file - - Parameters: - - filename (str): name of the file to save evaluations to - - add_timestamp (bool): whether to add a timestamp to the filename - - Returns: - - None - """ - if filename.endswith(".json"): - filename = filename[:-5] - if add_timestamp: - timestamp = time.strftime("%Y%m%d-%H%M%S") - filename = f"{filename}_{timestamp}" - full_path = os.path.join(self.base_dir, f"{filename}.json") - with open(full_path, "w") as f: - json.dump(self.evaluations, f, indent=4) - print(f"All evaluations saved to {full_path}.") - - def load(self, filename): - """ - load past evaluations from a json file. Appends to the current evaluations list. - - Parameters: - - filename (str): name of the file to load evaluations from - - Returns: - - None - """ - if not os.path.exists(filename): - print(f"File {filename} not found. Please provide a valid file path.") - return - with open(filename, "r") as f: - data = json.load(f) - self.evaluations.extend(data) - - def _flatten_dict(self, d, sep="_"): - """ - flattens evaluations dictionary up to 3 levels deep. - Used in create_table method.
- - Parameters: - - d (dict): dictionary to flatten - - sep (str): separator to use when flattening - - Returns: - - flattened dictionary - """ - flat_dict = {} - for k1, v1 in d.items(): - if isinstance(v1, dict): - for k2, v2 in v1.items(): - if isinstance(v2, dict): - for k3, v3 in v2.items(): - flat_key = f"{k1}{sep}{k2}{sep}{k3}" - flat_dict[flat_key] = v3 - else: - flat_key = f"{k1}{sep}{k2}" - flat_dict[flat_key] = v2 - else: - flat_key = k1 - flat_dict[flat_key] = v1 - return flat_dict - - def _evaluate_all_steps(self, agent, user_prompt): - """ - core function that evaluates while iterating every step of - MDAgent's response to a user prompt. Evaluation details are - saved to json file. - Used in run_and_evaluate method. - - NOTE: This is not meant to be used directly. Use run_and_evaluate - instead, since that method can capture exceptions and save to evaluations, - which can be used to create a table. - - Parameters: - - agent (MDAgent): MDAgent object - - user_prompt (str): user prompt to evaluate - - Returns: - - evaluation report (dict) containing details of the evaluation - """ - num_steps = 0 - tools_used = {} - tools_details = {} - failed_steps = 0 - status_complete = "Unclear" - step_start_time = start_time = time.time() - for step in agent.iter(user_prompt): - step_output = step.get("intermediate_step") - if step_output: - num_steps += 1 - action, observation = step_output[0] - current_time = time.time() - step_elapsed_time = current_time - step_start_time - step_start_time = current_time - tools_used[action.tool] = tools_used.get(action.tool, 0) + 1 - - # determine success or failure from the first sentence of the output - first_sentence = observation.split(".")[0] - if "Failed" in first_sentence or "Error" in first_sentence: - status_complete = False - failed_steps += 1 - elif "Succeeded" in first_sentence: - status_complete = True - else: - status_complete = "Unclear" - - tools_details[f"Step {num_steps}"] = { - "tool": action.tool, - 
"tool_input": action.tool_input, - "observation": observation, - "status_complete": status_complete, - "step_elapsed_time (sec)": f"{step_elapsed_time:.3f}", - "timestamp_from_start (sec)": f"{current_time - start_time:.3f}", - } - final_output = step.get("output", "") - if "Succeeded" in final_output.split(".")[0]: - prompt_passed = True - elif "Failed" in final_output.split(".")[0]: - prompt_passed = False - else: - # If the last step output doesn't explicitly state "Succeeded" or "Failed", - # determine the success of the prompt based on the previous step' status. - prompt_passed = status_complete - - run_id = agent.run_id - total_seconds = time.time() - start_time - total_mins = total_seconds / 60 - agent_settings = { - "llm": agent.llm.model_name, - "agent_type": agent.agent_type, - "tools_llm": agent.tools_llm.model_name, - "use_memory": agent.use_memory, - } - print("\n----- Evaluation Summary -----") - print("Run ID: ", run_id) - print("Prompt success: ", prompt_passed) - print(f"Total Steps: {num_steps+1}") - print(f"Total Time: {total_seconds:.2f} seconds ({total_mins:.2f} minutes)") - - eval_report = { - "agent_settings": agent_settings, - "user_prompt": user_prompt, - "prompt_success": prompt_passed, - "total_steps": num_steps, - "failed_steps": failed_steps, - "total_time_seconds": f"{total_seconds:.3f}", - "total_time_minutes": f"{total_mins:.3f}", - "final_answer": final_output, - "tools_used": tools_used, - "tools_details": tools_details, - "run_id": run_id, - } - timestamp = time.strftime("%Y%m%d-%H%M%S") - os.makedirs(f"{agent.ckpt_dir}/evals", exist_ok=True) - filename = f"{agent.ckpt_dir}/evals/individual_eval_{timestamp}.json" - with open(filename, "w") as f: - json.dump(eval_report, f, indent=4) - return eval_report - - def run_and_evaluate(self, prompts, agent_params={}): - """ - run and evaluate the agent with given parameters across multiple - prompts. Appends to the evaluations list. 
- - Parameters: - - prompts (list): list of prompts to evaluate - - agent_params (dict): dictionary of parameters to initialize MDAgent - - Returns: - - None - """ - agent = self.create_agent(agent_params) - for prompt in prompts: - print(f"Evaluating prompt: {prompt}") - try: - eval_report = self._evaluate_all_steps(agent, prompt) - eval_report["execution_success"] = True - self.evaluations.append(eval_report) - except Exception as e: - agent_settings = { - "llm": agent.llm.model_name, - "agent_type": agent.agent_type, - "memory": agent.use_memory, - } - self.evaluations.append( - { - "agent_settings": agent_settings, - "prompt": prompt, - "execution_success": False, - "error_msg": f"{type(e).__name__}: {e}", - } - ) - print(f"Error occurred while running MDAgent. {type(e).__name__}: {e}") - - def create_table(self, simple=True): - """ - creates DataFrame table from evaluations list. Note that evaluations - have to be loaded or generated first. - - Parameters: - - simple (bool): whether to return a simplified table with fewer columns - - Returns: - - DataFrame table of evaluations - """ - evals = [self._flatten_dict(eval) for eval in self.evaluations] - if not simple: - return pd.DataFrame(evals) - data = [] - for eval in evals: - data.append( - { - "LLM": eval.get("agent_settings_llm"), - "Agent Type": eval.get("agent_settings_agent_type"), - "User Prompt": eval.get("prompt"), - "Prompt Success": eval.get("prompt_success"), - "Execution Success": eval.get("execution_success"), - "Error Message": eval.get("error_msg"), - "Total Steps": eval.get("total_steps"), - "Failed Steps": eval.get("failed_steps"), - "Time (s)": eval.get("total_time_seconds"), - "Time (min)": eval.get("total_time_minutes"), - "Run ID": eval.get("run_id"), - } - ) - return pd.DataFrame(data) - - def automate(self, prompts, agent_params={}): - """ - this automates the entire evaluation process for a given agent - and prompts. 
It runs and evaluates, save the evaluations to a - json file, and creates a table. - - Parameters: - - prompts (list): list of prompts to evaluate - - agent_params (dict): dictionary of parameters to initialize MDAgent - - Returns: - - DataFrame table of evaluations - """ - self.run_and_evaluate(prompts, agent_params) - self.save() - dataframe = self.create_table() - return dataframe - - def automate_all(self, prompts, agent_params_list=None): - """ - it automates the entire evaluation process for a list of agents. - After evaluating all prompts with each agent, it saves the evaluations - to a json file and creates a table containing all evaluations. - - Parameters: - - prompts (list): list of prompts to evaluate - - agent_params_list (list): list of dictionaries containing parameters - to initialize MDAgent. If None, it will evaluate with default agents. - - Returns: - - DataFrame table of evaluations - """ - if agent_params_list is None: - agent_params_list = [ - { - "agent_type": "OpenAIFunctionsAgent", - "model": "gpt-4-1106-preview", - "ckpt_dir": "ckpt_openaifxn_gpt4", - }, - { - "agent_type": "Structured", - "model": "gpt-4-1106-preview", - "ckpt_dir": "ckpt_structured_gpt4", - }, - { - "agent_type": "OpenAIFunctionsAgent", - "model": "gpt-3.5-turbo", - "ckpt_dir": "ckpt_openaifxn_gpt3.5", - }, - { - "agent_type": "Structured", - "model": "gpt-3.5-turbo", - "ckpt_dir": "ckpt_structured_gpt3.5", - }, - ] - - for agent in agent_params_list: - self.run_and_evaluate(prompts, agent) - self.save() - dataframe = self.create_table() - return dataframe diff --git a/mdcrow/__init__.py b/mdcrow/__init__.py new file mode 100644 index 00000000..533d5466 --- /dev/null +++ b/mdcrow/__init__.py @@ -0,0 +1,3 @@ +from .agent import MDCrow + +__all__ = ["MDCrow"] diff --git a/mdcrow/agent/__init__.py b/mdcrow/agent/__init__.py new file mode 100644 index 00000000..533d5466 --- /dev/null +++ b/mdcrow/agent/__init__.py @@ -0,0 +1,3 @@ +from .agent import MDCrow + +__all__ = 
["MDCrow"] diff --git a/mdagent/agent/agent.py b/mdcrow/agent/agent.py similarity index 96% rename from mdagent/agent/agent.py rename to mdcrow/agent/agent.py index 2cea0c61..db1428de 100644 --- a/mdagent/agent/agent.py +++ b/mdcrow/agent/agent.py @@ -1,6 +1,6 @@ import os from datetime import datetime -from time import time + from dotenv import load_dotenv from langchain.agents import AgentExecutor, OpenAIFunctionsAgent from langchain.agents.structured_chat.base import StructuredChatAgent @@ -31,7 +31,7 @@ def get_agent(cls, model_name: str = "OpenAIFunctionsAgent"): ) -class MDAgent: +class MDCrow: def __init__( self, tools=None, @@ -63,7 +63,7 @@ def __init__( self.uploaded_files = uploaded_files # for file in uploaded_files: # todo -> allow users to add descriptions? - # self.path_registry.map_path(file, file, description="User uploaded file") + # self.path_registry.map_path(file, file, description="User uploaded file") self.agent = None self.agent_type = agent_type @@ -75,15 +75,16 @@ def __init__( if self.uploaded_files: self.add_file(self.uploaded_files) self.safe_mode = safe_mode + def _add_single_file(self, file_path, description=None): now = datetime.now() # Format the date and time as "YYYYMMDD_HHMMSS" timestamp = now.strftime("%Y%m%d_%H%M%S") i = 0 - ID = "UPL_"+str(i) + timestamp - while ID in self.path_registry.list_path_names(): # check if ID already exists + ID = "UPL_" + str(i) + timestamp + while ID in self.path_registry.list_path_names(): # check if ID already exists i += 1 - ID = "UPL_"+str(i) + timestamp + ID = "UPL_" + str(i) + timestamp if not description: # asks for user input to add description for file file_path # wait for 20 seconds or set up a default description diff --git a/mdagent/agent/memory.py b/mdcrow/agent/memory.py similarity index 99% rename from mdagent/agent/memory.py rename to mdcrow/agent/memory.py index f450e47c..37ced874 100644 --- a/mdagent/agent/memory.py +++ b/mdcrow/agent/memory.py @@ -6,7 +6,7 @@ from 
langchain.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser -from mdagent.utils import PathRegistry +from mdcrow.utils import PathRegistry agent_summary_template = PromptTemplate( input_variables=["agent_trace"], diff --git a/mdagent/agent/prompt.py b/mdcrow/agent/prompt.py similarity index 100% rename from mdagent/agent/prompt.py rename to mdcrow/agent/prompt.py diff --git a/mdagent/tools/__init__.py b/mdcrow/tools/__init__.py similarity index 100% rename from mdagent/tools/__init__.py rename to mdcrow/tools/__init__.py diff --git a/mdagent/tools/base_tools/__init__.py b/mdcrow/tools/base_tools/__init__.py similarity index 97% rename from mdagent/tools/base_tools/__init__.py rename to mdcrow/tools/base_tools/__init__.py index 04c4fa35..89b03d3f 100644 --- a/mdagent/tools/base_tools/__init__.py +++ b/mdcrow/tools/base_tools/__init__.py @@ -1,4 +1,5 @@ from .analysis_tools.distance_tools import ContactsTool, DistanceMatrixTool +from .analysis_tools.hydrogen_bonding_tools import HydrogenBondTool from .analysis_tools.inertia import MomentOfInertia from .analysis_tools.pca_tools import PCATool from .analysis_tools.plot_tools import SimulationOutputFigures @@ -67,6 +68,7 @@ "ComputeRMSF", "ContactsTool", "DistanceMatrixTool", + "HydrogenBondTool", "ListRegistryPaths", "MapPath2Name", "MapProteinRepresentation", diff --git a/mdagent/tools/base_tools/analysis_tools/__init__.py b/mdcrow/tools/base_tools/analysis_tools/__init__.py similarity index 90% rename from mdagent/tools/base_tools/analysis_tools/__init__.py rename to mdcrow/tools/base_tools/analysis_tools/__init__.py index 26d58784..810cf503 100644 --- a/mdagent/tools/base_tools/analysis_tools/__init__.py +++ b/mdcrow/tools/base_tools/analysis_tools/__init__.py @@ -1,4 +1,5 @@ from .distance_tools import ContactsTool, DistanceMatrixTool +from .hydrogen_bonding_tools import HydrogenBondTool from .inertia import MomentOfInertia from .pca_tools import PCATool from .plot_tools import 
SimulationOutputFigures @@ -14,6 +15,7 @@ "ComputeRMSF", "ContactsTool", "DistanceMatrixTool", + "HydrogenBondTool", "MomentOfInertia", "PCATool", "PPIDistance", diff --git a/mdagent/tools/base_tools/analysis_tools/distance_tools.py b/mdcrow/tools/base_tools/analysis_tools/distance_tools.py similarity index 99% rename from mdagent/tools/base_tools/analysis_tools/distance_tools.py rename to mdcrow/tools/base_tools/analysis_tools/distance_tools.py index 3cdf349b..a4c14c8f 100644 --- a/mdagent/tools/base_tools/analysis_tools/distance_tools.py +++ b/mdcrow/tools/base_tools/analysis_tools/distance_tools.py @@ -10,7 +10,7 @@ from matplotlib.animation import FuncAnimation from pydantic import BaseModel, Field -from mdagent.utils import FileType, PathRegistry, load_single_traj +from mdcrow.utils import FileType, PathRegistry, load_single_traj class DistanceToolsUtils: diff --git a/mdcrow/tools/base_tools/analysis_tools/hydrogen_bonding_tools.py b/mdcrow/tools/base_tools/analysis_tools/hydrogen_bonding_tools.py new file mode 100644 index 00000000..42c90fa1 --- /dev/null +++ b/mdcrow/tools/base_tools/analysis_tools/hydrogen_bonding_tools.py @@ -0,0 +1,128 @@ +import matplotlib.pyplot as plt +import mdtraj as md +from langchain.tools import BaseTool + +from mdcrow.utils import FileType, PathRegistry, load_single_traj + + +class HydrogenBondTool(BaseTool): + """Note that this tool only uses the Baker-Hubbard method for identifying hydrogen bonds. + Other methods (kabsch-sander, wernet-nilsson) can be implemented later, if desired. + """ + + name = "hydrogen_bond_tool" + description = ( + "Identifies hydrogen bonds and plots the results from the " + "provided trajectory data. " + "Input the File ID for the trajectory file and optionally the topology file. " + "The tool will output the file ID of the results and plot."
+ ) + + path_registry: PathRegistry | None = None + freq: float = 0.3 + + def __init__(self, path_registry, freq=0.1): + super().__init__() + self.path_registry = path_registry + self.freq = freq + + def compute_hbonds_traj(self, traj): + hbond_counts = [] + for frame in range(traj.n_frames): + hbonds = md.baker_hubbard(traj[frame], freq=self.freq) + hbond_counts.append(len(hbonds)) + return hbond_counts + + def write_hbond_counts_to_file(self, hbond_counts, traj_id): + output_file = f"{traj_id}_hbond_counts" + + file_name = self.path_registry.write_file_name( + type=FileType.RECORD, fig_analysis=output_file, file_format="csv" + ) + file_id = self.path_registry.get_fileid( + file_name=file_name, type=FileType.FIGURE + ) + + file_path = f"{self.path_registry.ckpt_records}/{file_name}" + file_path = file_path if file_path.endswith(".csv") else file_path + ".csv" + + with open(file_path, "w") as f: + f.write("Frame,Hydrogen Bonds\n") + for frame, count in enumerate(hbond_counts): + f.write(f"{frame},{count}\n") + self.path_registry.map_path( + file_id, + file_path, + description=f"Hydrogen bond counts for {traj_id}", + ) + return f"Data saved to: {file_id}, full path: {file_path}" + + def plot_hbonds_over_time(self, hbond_counts, traj, traj_id): + fig_analysis = f"hbonds_over_time_{traj_id}" + plot_name = self.path_registry.write_file_name( + type=FileType.FIGURE, fig_analysis=fig_analysis, file_format="png" + ) + plot_id = self.path_registry.get_fileid( + file_name=plot_name, type=FileType.FIGURE + ) + plot_path = f"{self.path_registry.ckpt_figures}/{plot_name}" + plot_path = plot_path if plot_path.endswith(".png") else plot_path + ".png" + plt.plot(range(traj.n_frames), hbond_counts, marker="o") + plt.xlabel("Frame") + plt.ylabel("Number of Hydrogen Bonds") + plt.title(f"Hydrogen Bonds Over Time for traj {traj_id}") + plt.grid(True) + plt.savefig(f"{plot_path}") + + self.path_registry.map_path( + plot_id, + plot_path, + description=f"Plot of hydrogen bonds over time 
for {traj_id}",
+        )
+        plt.close()
+        plt.clf()
+        return f"plot saved to: {plot_id}, full path: {plot_path}"
+
+    def _run(
+        self,
+        top_file: str,
+        traj_file: str | None = None,
+    ) -> str:
+        try:
+            traj_file = (
+                traj_file
+                if (traj_file is not None) and (traj_file != top_file)
+                else None
+            )
+            traj = load_single_traj(
+                path_registry=self.path_registry,
+                top_fileid=top_file,
+                traj_fileid=traj_file,
+                traj_required=False,
+            )
+            if not traj:
+                raise ValueError("Trajectory could not be loaded.")
+        except Exception as e:
+            return f"Error loading traj: {e}"
+
+        try:
+            hbond_counts = self.compute_hbonds_traj(traj)
+            rtrn_msg = ""
+            if all(count == 0 for count in hbond_counts):
+                rtrn_msg += (
+                    "No hydrogen bonds found in the trajectory. "
+                    "Did you forget to add missing hydrogens? "
+                )
+            traj_file = top_file if not traj_file else traj_file
+            plot_id = self.plot_hbonds_over_time(hbond_counts, traj, traj_file)
+            data_id = self.write_hbond_counts_to_file(hbond_counts, traj_file)
+            return f"Hydrogen bond analysis completed. {data_id}, {plot_id} {rtrn_msg}."
+        except Exception as e:
+            return f"Error during hydrogen bond analysis: {e}"
+
+    async def _arun(
+        self,
+        top_file: str,
+        traj_file: str | None = None,
+    ) -> str:
+        raise NotImplementedError
diff --git a/mdagent/tools/base_tools/analysis_tools/inertia.py b/mdcrow/tools/base_tools/analysis_tools/inertia.py
similarity index 98%
rename from mdagent/tools/base_tools/analysis_tools/inertia.py
rename to mdcrow/tools/base_tools/analysis_tools/inertia.py
index 62d0124e..64a3f5c2 100644
--- a/mdagent/tools/base_tools/analysis_tools/inertia.py
+++ b/mdcrow/tools/base_tools/analysis_tools/inertia.py
@@ -5,7 +5,7 @@
 import numpy as np
 from langchain.tools import BaseTool

-from mdagent.utils import FileType, PathRegistry, load_single_traj, save_to_csv
+from mdcrow.utils import FileType, PathRegistry, load_single_traj, save_to_csv


 class MOIFunctions:
diff --git a/mdagent/tools/base_tools/analysis_tools/pca_tools.py b/mdcrow/tools/base_tools/analysis_tools/pca_tools.py
similarity index 99%
rename from mdagent/tools/base_tools/analysis_tools/pca_tools.py
rename to mdcrow/tools/base_tools/analysis_tools/pca_tools.py
index 90880f56..fa86a108 100644
--- a/mdagent/tools/base_tools/analysis_tools/pca_tools.py
+++ b/mdcrow/tools/base_tools/analysis_tools/pca_tools.py
@@ -10,7 +10,7 @@
 from pydantic import BaseModel, Field
 from sklearn.decomposition import PCA

-from mdagent.utils import FileType, PathRegistry, load_single_traj
+from mdcrow.utils import FileType, PathRegistry, load_single_traj


 class PCA_analysis:
diff --git a/mdagent/tools/base_tools/analysis_tools/plot_tools.py b/mdcrow/tools/base_tools/analysis_tools/plot_tools.py
similarity index 98%
rename from mdagent/tools/base_tools/analysis_tools/plot_tools.py
rename to mdcrow/tools/base_tools/analysis_tools/plot_tools.py
index dfd5d693..767fd3fe 100644
--- a/mdagent/tools/base_tools/analysis_tools/plot_tools.py
+++ b/mdcrow/tools/base_tools/analysis_tools/plot_tools.py
@@ -5,7 +5,7 @@
 import matplotlib.pyplot as plt
 from langchain.tools import BaseTool

-from mdagent.utils import FileType, PathRegistry
+from mdcrow.utils import FileType, PathRegistry


 class PlottingTools:
diff --git a/mdagent/tools/base_tools/analysis_tools/ppi_tools.py b/mdcrow/tools/base_tools/analysis_tools/ppi_tools.py
similarity index 98%
rename from mdagent/tools/base_tools/analysis_tools/ppi_tools.py
rename to mdcrow/tools/base_tools/analysis_tools/ppi_tools.py
index 7e44e042..ea282051 100644
--- a/mdagent/tools/base_tools/analysis_tools/ppi_tools.py
+++ b/mdcrow/tools/base_tools/analysis_tools/ppi_tools.py
@@ -6,7 +6,7 @@
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Field

-from mdagent.utils import PathRegistry
+from mdcrow.utils import PathRegistry


 def ppi_distance(file_path, binding_site="protein"):
diff --git a/mdagent/tools/base_tools/analysis_tools/rdf_tool.py b/mdcrow/tools/base_tools/analysis_tools/rdf_tool.py
similarity index 99%
rename from mdagent/tools/base_tools/analysis_tools/rdf_tool.py
rename to mdcrow/tools/base_tools/analysis_tools/rdf_tool.py
index e6fa24cb..a0a168cd 100644
--- a/mdagent/tools/base_tools/analysis_tools/rdf_tool.py
+++ b/mdcrow/tools/base_tools/analysis_tools/rdf_tool.py
@@ -5,7 +5,7 @@
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Field

-from mdagent.utils import FileType, PathRegistry
+from mdcrow.utils import FileType, PathRegistry


 class RDFToolInput(BaseModel):
diff --git a/mdagent/tools/base_tools/analysis_tools/rgy.py b/mdcrow/tools/base_tools/analysis_tools/rgy.py
similarity index 98%
rename from mdagent/tools/base_tools/analysis_tools/rgy.py
rename to mdcrow/tools/base_tools/analysis_tools/rgy.py
index 5976d9f5..a648cdf7 100644
--- a/mdagent/tools/base_tools/analysis_tools/rgy.py
+++ b/mdcrow/tools/base_tools/analysis_tools/rgy.py
@@ -5,7 +5,7 @@
 import numpy as np
 from langchain.tools import BaseTool

-from mdagent.utils import FileType, PathRegistry, load_single_traj
+from mdcrow.utils import FileType, PathRegistry, load_single_traj


 class RadiusofGyration:
diff --git a/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py b/mdcrow/tools/base_tools/analysis_tools/rmsd_tools.py
similarity index 99%
rename from mdagent/tools/base_tools/analysis_tools/rmsd_tools.py
rename to mdcrow/tools/base_tools/analysis_tools/rmsd_tools.py
index b4f960c0..a7f2377d 100644
--- a/mdagent/tools/base_tools/analysis_tools/rmsd_tools.py
+++ b/mdcrow/tools/base_tools/analysis_tools/rmsd_tools.py
@@ -5,7 +5,7 @@
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Field

-from mdagent.utils import PathRegistry, load_traj_with_ref, save_plot, save_to_csv
+from mdcrow.utils import PathRegistry, load_traj_with_ref, save_plot, save_to_csv


 def rmsd(path_registry, traj, ref_traj, mol_name, select="protein"):
diff --git a/mdagent/tools/base_tools/analysis_tools/sasa.py b/mdcrow/tools/base_tools/analysis_tools/sasa.py
similarity index 98%
rename from mdagent/tools/base_tools/analysis_tools/sasa.py
rename to mdcrow/tools/base_tools/analysis_tools/sasa.py
index f51eee6f..ece2499b 100644
--- a/mdagent/tools/base_tools/analysis_tools/sasa.py
+++ b/mdcrow/tools/base_tools/analysis_tools/sasa.py
@@ -5,7 +5,7 @@
 import numpy as np
 from langchain.tools import BaseTool

-from mdagent.utils import FileType, PathRegistry, load_single_traj, save_to_csv
+from mdcrow.utils import FileType, PathRegistry, load_single_traj, save_to_csv


 class SASAFunctions:
diff --git a/mdagent/tools/base_tools/analysis_tools/secondary_structure.py b/mdcrow/tools/base_tools/analysis_tools/secondary_structure.py
similarity index 99%
rename from mdagent/tools/base_tools/analysis_tools/secondary_structure.py
rename to mdcrow/tools/base_tools/analysis_tools/secondary_structure.py
index f00decb4..d3a062cb 100644
--- a/mdagent/tools/base_tools/analysis_tools/secondary_structure.py
+++ b/mdcrow/tools/base_tools/analysis_tools/secondary_structure.py
@@ -5,7 +5,7 @@
 import numpy as np
 from langchain.tools import BaseTool

-from mdagent.utils import FileType, PathRegistry, load_single_traj
+from mdcrow.utils import FileType, PathRegistry, load_single_traj


 def write_raw_x(
diff --git a/mdagent/tools/base_tools/analysis_tools/vis_tools.py b/mdcrow/tools/base_tools/analysis_tools/vis_tools.py
similarity index 99%
rename from mdagent/tools/base_tools/analysis_tools/vis_tools.py
rename to mdcrow/tools/base_tools/analysis_tools/vis_tools.py
index 3223c1ac..cda47ca5 100644
--- a/mdagent/tools/base_tools/analysis_tools/vis_tools.py
+++ b/mdcrow/tools/base_tools/analysis_tools/vis_tools.py
@@ -6,7 +6,7 @@
 import nbformat as nbf
 from langchain.tools import BaseTool

-from mdagent.utils import PathRegistry
+from mdcrow.utils import PathRegistry


 class VisFunctions:
diff --git a/mdagent/tools/base_tools/preprocess_tools/__init__.py b/mdcrow/tools/base_tools/preprocess_tools/__init__.py
similarity index 100%
rename from mdagent/tools/base_tools/preprocess_tools/__init__.py
rename to mdcrow/tools/base_tools/preprocess_tools/__init__.py
diff --git a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py b/mdcrow/tools/base_tools/preprocess_tools/clean_tools.py
similarity index 99%
rename from mdagent/tools/base_tools/preprocess_tools/clean_tools.py
rename to mdcrow/tools/base_tools/preprocess_tools/clean_tools.py
index 1be094f3..ed03786b 100644
--- a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py
+++ b/mdcrow/tools/base_tools/preprocess_tools/clean_tools.py
@@ -5,7 +5,7 @@
 from pdbfixer import PDBFixer
 from pydantic import BaseModel, Field

-from mdagent.utils import FileType, PathRegistry
+from mdcrow.utils import FileType, PathRegistry


 class CleaningToolFunctionInput(BaseModel):
diff --git a/mdagent/tools/base_tools/preprocess_tools/elements.py b/mdcrow/tools/base_tools/preprocess_tools/elements.py
similarity index 100%
rename from mdagent/tools/base_tools/preprocess_tools/elements.py
rename to mdcrow/tools/base_tools/preprocess_tools/elements.py
diff --git a/mdagent/tools/base_tools/preprocess_tools/packing.py b/mdcrow/tools/base_tools/preprocess_tools/packing.py
similarity index 99%
rename from mdagent/tools/base_tools/preprocess_tools/packing.py
rename to mdcrow/tools/base_tools/preprocess_tools/packing.py
index ccc38b5d..e07d8cc9 100644
--- a/mdagent/tools/base_tools/preprocess_tools/packing.py
+++ b/mdcrow/tools/base_tools/preprocess_tools/packing.py
@@ -7,7 +7,7 @@
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Field, ValidationError

-from mdagent.utils import PathRegistry
+from mdcrow.utils import PathRegistry

 from .pdb_fix import Validate_Fix_PDB
 from .pdb_get import MolPDB
diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py b/mdcrow/tools/base_tools/preprocess_tools/pdb_fix.py
similarity index 99%
rename from mdagent/tools/base_tools/preprocess_tools/pdb_fix.py
rename to mdcrow/tools/base_tools/preprocess_tools/pdb_fix.py
index 8c4c25f9..2b722ee6 100644
--- a/mdagent/tools/base_tools/preprocess_tools/pdb_fix.py
+++ b/mdcrow/tools/base_tools/preprocess_tools/pdb_fix.py
@@ -8,7 +8,7 @@
 from pdbfixer import PDBFixer
 from pydantic import BaseModel, Field, ValidationError

-from mdagent.utils import PathRegistry
+from mdcrow.utils import PathRegistry

 from .elements import list_of_elements
diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_get.py b/mdcrow/tools/base_tools/preprocess_tools/pdb_get.py
similarity index 99%
rename from mdagent/tools/base_tools/preprocess_tools/pdb_get.py
rename to mdcrow/tools/base_tools/preprocess_tools/pdb_get.py
index a44f5075..c6606f1d 100644
--- a/mdagent/tools/base_tools/preprocess_tools/pdb_get.py
+++ b/mdcrow/tools/base_tools/preprocess_tools/pdb_get.py
@@ -5,7 +5,7 @@
 from rdkit import Chem
 from rdkit.Chem import AllChem

-from mdagent.utils import FileType, PathRegistry
+from mdcrow.utils import FileType, PathRegistry


 def get_pdb(query_string: str, path_registry: PathRegistry):
diff --git a/mdagent/tools/base_tools/preprocess_tools/uniprot.py b/mdcrow/tools/base_tools/preprocess_tools/uniprot.py
similarity index 100%
rename from mdagent/tools/base_tools/preprocess_tools/uniprot.py
rename to mdcrow/tools/base_tools/preprocess_tools/uniprot.py
diff --git a/mdagent/tools/base_tools/simulation_tools/__init__.py b/mdcrow/tools/base_tools/simulation_tools/__init__.py
similarity index 100%
rename from mdagent/tools/base_tools/simulation_tools/__init__.py
rename to mdcrow/tools/base_tools/simulation_tools/__init__.py
diff --git a/mdagent/tools/base_tools/simulation_tools/create_simulation.py b/mdcrow/tools/base_tools/simulation_tools/create_simulation.py
similarity index 94%
rename from mdagent/tools/base_tools/simulation_tools/create_simulation.py
rename to mdcrow/tools/base_tools/simulation_tools/create_simulation.py
index cc3a63d4..163a7af5 100644
--- a/mdagent/tools/base_tools/simulation_tools/create_simulation.py
+++ b/mdcrow/tools/base_tools/simulation_tools/create_simulation.py
@@ -8,7 +8,7 @@
 from langchain_core.output_parsers import StrOutputParser
 from pydantic import BaseModel, Field

-from mdagent.utils import FileType, PathRegistry
+from mdcrow.utils import FileType, PathRegistry


 class ModifyScriptUtils:
@@ -84,9 +84,9 @@ class ModifyBaseSimulationScriptTool(BaseTool):
     args_schema = ModifyScriptInput
     llm: Optional[BaseLanguageModel]
     path_registry: Optional[PathRegistry]
-    safe_mode: Optional[bool] 
+    safe_mode: Optional[bool]

-    def __init__(self, path_registry, llm, safe_mode = False):
+    def __init__(self, path_registry, llm, safe_mode=False):
         super().__init__()
         self.path_registry = path_registry
         self.llm = llm
@@ -152,18 +152,20 @@ def _run(self, script_id: str, query: str) -> str:
             file.write(script_content)
         self.path_registry.map_path(file_id, f"{directory}/(unknown)", description)

-        #if safe mode is on, return the file id
+        # if safe mode is on, return the file id
         if self.safe_mode:
             return f"Succeeded. Script modified successfully. Modified Script ID: {file_id}"

-        #if safe mode is off, try to run the script
+        # if safe mode is off, try to run the script
         try:
             exec(script_content)
             return f"Succeeded. Script modified and ran \
                 successfully. Modified Script ID: {file_id}"
         except Exception as e:
-            return (f"Failed. Error running the script: {e}."
+            return (
+                f"Failed. Error running the script: {e}."
                 "Modified Script ID: {file_id}. If you want to try to correct the "
-                "script, use the file id of the modified to correct the script.")
+                "script, use the file id of the modified to correct the script."
+            )

     async def _arun(self, query) -> str:
         """Use the tool asynchronously."""
diff --git a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py b/mdcrow/tools/base_tools/simulation_tools/setup_and_run.py
similarity index 99%
rename from mdagent/tools/base_tools/simulation_tools/setup_and_run.py
rename to mdcrow/tools/base_tools/simulation_tools/setup_and_run.py
index 07631243..0f83af4c 100644
--- a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py
+++ b/mdcrow/tools/base_tools/simulation_tools/setup_and_run.py
@@ -43,7 +43,7 @@
 from rdkit import Chem

 # Local Library/Application Imports
-from mdagent.utils import FileType, PathRegistry
+from mdcrow.utils import FileType, PathRegistry

 # TODO delete files created from the simulation if not needed.
@@ -608,7 +608,7 @@ def _construct_script_content(
         integrator_type,
     ):
         script_content = f"""
-    # This script was generated by MDagent-Setup.
+    # This script was generated by MDCrow-Setup.

     from openmm import *
     from openmm.app import *
diff --git a/mdagent/tools/base_tools/util_tools/__init__.py b/mdcrow/tools/base_tools/util_tools/__init__.py
similarity index 100%
rename from mdagent/tools/base_tools/util_tools/__init__.py
rename to mdcrow/tools/base_tools/util_tools/__init__.py
diff --git a/mdagent/tools/base_tools/util_tools/registry_tools.py b/mdcrow/tools/base_tools/util_tools/registry_tools.py
similarity index 98%
rename from mdagent/tools/base_tools/util_tools/registry_tools.py
rename to mdcrow/tools/base_tools/util_tools/registry_tools.py
index 6e8bdf2c..5bff7c2c 100644
--- a/mdagent/tools/base_tools/util_tools/registry_tools.py
+++ b/mdcrow/tools/base_tools/util_tools/registry_tools.py
@@ -3,7 +3,7 @@
 from langchain.tools import BaseTool

-from mdagent.utils import PathRegistry
+from mdcrow.utils import PathRegistry


 class MapPath2Name(BaseTool):
diff --git a/mdagent/tools/base_tools/util_tools/search_tools.py b/mdcrow/tools/base_tools/util_tools/search_tools.py
similarity index 98%
rename from mdagent/tools/base_tools/util_tools/search_tools.py
rename to mdcrow/tools/base_tools/util_tools/search_tools.py
index 20424e07..dc204bf0 100644
--- a/mdagent/tools/base_tools/util_tools/search_tools.py
+++ b/mdcrow/tools/base_tools/util_tools/search_tools.py
@@ -5,7 +5,7 @@
 from langchain.base_language import BaseLanguageModel
 from langchain.tools import BaseTool

-from mdagent.utils import PathRegistry
+from mdcrow.utils import PathRegistry


 def scholar2result_llm(llm, query, path_registry):
diff --git a/mdagent/tools/maketools.py b/mdcrow/tools/maketools.py
similarity index 95%
rename from mdagent/tools/maketools.py
rename to mdcrow/tools/maketools.py
index 77596333..86b5ec6c 100644
--- a/mdagent/tools/maketools.py
+++ b/mdcrow/tools/maketools.py
@@ -8,7 +8,7 @@
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity

-from mdagent.utils import PathRegistry
+from mdcrow.utils import PathRegistry

 from .base_tools import (
     CleaningToolFunction,
@@ -38,6 +38,7 @@
     GetSubunitStructure,
     GetTurnsBetaSheetsHelices,
     GetUniprotID,
+    HydrogenBondTool,
     ListRegistryPaths,
     MapProteinRepresentation,
     ModifyBaseSimulationScriptTool,
@@ -71,10 +72,9 @@ def make_all_tools(
         all_tools += agents.load_tools(["llm-math"], llm)
         # all_tools += [PythonREPLTool()]
         all_tools += [
-            ModifyBaseSimulationScriptTool(path_registry=path_instance,
-                                           llm=llm,
-                                           safe_mode=safe_mode
-                                           ),
+            ModifyBaseSimulationScriptTool(
+                path_registry=path_instance, llm=llm, safe_mode=safe_mode
+            ),
         ]
         if path_instance.ckpt_papers:
             all_tools += [Scholar2ResultLLM(llm=llm, path_registry=path_instance)]
@@ -95,6 +95,7 @@ def make_all_tools(
         ComputeRMSF(path_registry=path_instance),
         ContactsTool(path_registry=path_instance),
         DistanceMatrixTool(path_registry=path_instance),
+        HydrogenBondTool(path_registry=path_instance),
         ListRegistryPaths(path_registry=path_instance),
         MomentOfInertia(path_registry=path_instance),
         PackMolTool(path_registry=path_instance),
diff --git a/mdagent/utils/__init__.py b/mdcrow/utils/__init__.py
similarity index 100%
rename from mdagent/utils/__init__.py
rename to mdcrow/utils/__init__.py
diff --git a/mdagent/utils/data_handling.py b/mdcrow/utils/data_handling.py
similarity index 99%
rename from mdagent/utils/data_handling.py
rename to mdcrow/utils/data_handling.py
index 22647c8b..4a9dc4f7 100644
--- a/mdagent/utils/data_handling.py
+++ b/mdcrow/utils/data_handling.py
@@ -53,6 +53,7 @@ def load_single_traj(
                 ),
                 UserWarning,
             )
+        return md.load(top_path)
     else:
         raise ValueError("Trajectory File ID is required, and it's not provided.")
@@ -88,6 +89,7 @@ def load_traj_with_ref(
         ref_traj = load_single_traj(
             path_registry, ref_top_id, ref_traj_id, traj_required, ignore_warnings
         )
+    return traj, ref_traj
diff --git a/mdagent/utils/makellm.py b/mdcrow/utils/makellm.py
similarity index 100%
rename from mdagent/utils/makellm.py
rename to mdcrow/utils/makellm.py
diff --git a/mdagent/utils/path_registry.py b/mdcrow/utils/path_registry.py
similarity index 99%
rename from mdagent/utils/path_registry.py
rename to mdcrow/utils/path_registry.py
index c7ebbcc9..2f0714d8 100644
--- a/mdagent/utils/path_registry.py
+++ b/mdcrow/utils/path_registry.py
@@ -4,7 +4,7 @@
 from enum import Enum
 from typing import Optional

-from mdagent.utils.set_ckpt import SetCheckpoint
+from mdcrow.utils.set_ckpt import SetCheckpoint


 ##TODO: add method to get description from simulation inputs
diff --git a/mdagent/utils/set_ckpt.py b/mdcrow/utils/set_ckpt.py
similarity index 100%
rename from mdagent/utils/set_ckpt.py
rename to mdcrow/utils/set_ckpt.py
diff --git a/mdagent/version.py b/mdcrow/version.py
similarity index 100%
rename from mdagent/version.py
rename to mdcrow/version.py
diff --git a/notebooks/lit_search.ipynb b/notebooks/lit_search.ipynb
new file mode 100644
index 00000000..5ed3ccb3
--- /dev/null
+++ b/notebooks/lit_search.ipynb
@@ -0,0 +1,121 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/samcox/anaconda3/envs/mda_feb21/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
+   "source": [
+    "from mdcrow import MDCrow"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#until we update to new version\n",
+    "import nest_asyncio\n",
+    "nest_asyncio.apply()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mda = MDCrow()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prompt = \"Are there any studies that show that the use of a mask can reduce the spread of COVID-19?\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\"Masks COVID-19 transmission reduction studies\"\n",
+      "Search: \"Masks COVID-19 transmission reduction studies\"\n",
+      "\n",
+      "Found 14 papers but couldn't load 0\n",
+      "Yes, there are studies that show that the use of a mask can reduce the spread of COVID-19. The review by Howard et al. (2021) indicates that mask-wearing reduces the transmissibility of COVID-19 by limiting the spread of infected respiratory particles. This conclusion is supported by evidence from both laboratory and clinical studies."
+     ]
+    }
+   ],
+   "source": [
+    "answer = mda.run(prompt)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'Yes, there are studies that show that the use of a mask can reduce the spread of COVID-19. The review by Howard et al. (2021) indicates that mask-wearing reduces the transmissibility of COVID-19 by limiting the spread of infected respiratory particles. This conclusion is supported by evidence from both laboratory and clinical studies.'"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "answer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "mdcrow",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.8"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/mdagent/simulated_annealing.ipynb b/notebooks/mdcrow_extrapolation/simulated_annealing.ipynb
similarity index 99%
rename from mdagent/simulated_annealing.ipynb
rename to notebooks/mdcrow_extrapolation/simulated_annealing.ipynb
index 71557c83..5cf6dc6b 100644
--- a/mdagent/simulated_annealing.ipynb
+++ b/notebooks/mdcrow_extrapolation/simulated_annealing.ipynb
@@ -6,8 +6,6 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "import datetime\n",
-    "import os\n",
     "from mdagent import MDAgent\n"
   ]
  },
@@ -21,7 +19,6 @@
     "import matplotlib.pyplot as plt\n",
     "import matplotlib.font_manager as font_manager\n",
     "import urllib.request\n",
-    "import numpy as np\n",
     "\n",
     "urllib.request.urlretrieve(\n",
     "    \"https://github.com/google/fonts/raw/main/ofl/ibmplexmono/IBMPlexMono-Regular.ttf\",\n",
@@ -42,9 +39,7 @@
     "    \"ytick.left\": True,\n",
     "    \"xtick.bottom\": True,\n",
     "  }\n",
-    ")\n",
-    "\n",
-    "import random"
+    ")\n"
   ]
  },
 {
diff --git a/notebooks/memory_demo.ipynb b/notebooks/memory_demo.ipynb
index fe6c6746..fdea9df6 100644
--- a/notebooks/memory_demo.ipynb
+++ b/notebooks/memory_demo.ipynb
@@ -15,7 +15,7 @@
    }
   ],
   "source": [
-    "from mdagent import MDAgent"
+    "from mdcrow import MDCrow"
   ]
  },
 {
@@ -24,7 +24,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "agent = MDAgent(use_memory=True)"
+    "agent = MDCrow(use_memory=True)"
   ]
  },
 {
@@ -99,7 +99,7 @@
    }
   ],
   "source": [
-    "#agent = MDAgent(use_memory=True, run_id=run_id_run1)\n",
+    "#agent = MDCrow(use_memory=True, run_id=run_id_run1)\n",
     "\n",
     "prompt_with_mem = \"Now do the same for hemoglobin.\"\n",
     "\n",
diff --git a/setup.py b/setup.py
index d23addcc..b878aa2d 100644
--- a/setup.py
+++ b/setup.py
@@ -2,18 +2,18 @@

 # fake to satisfy mypy
 __version__ = "0.0.0"
-exec(open("mdagent/version.py").read())
+exec(open("mdcrow/version.py").read())

 with open("README.md", "r", encoding="utf-8") as fh:
     long_description = fh.read()

 setup(
-    name="md-agent",
+    name="MDCrow",
     version=__version__,
     description="Collection of MD tools for use with language models",
     author="Andrew White",
     author_email="andrew.white@rochester.edu",
-    url="https://github.com/ur-whitelab/md-agent",
+    url="https://github.com/ur-whitelab/MDCrow",
     license="MIT",
     packages=find_packages(),
     install_requires=[
diff --git a/tests/conftest.py b/tests/conftest.py
index 38a35ffa..1515c5ca 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,7 +6,7 @@
 import numpy as np
 import pytest

-from mdagent.utils import PathRegistry
+from mdcrow.utils import PathRegistry


 def safe_remove(file_path):
diff --git a/tests/test_analysis/test_distance_tools.py b/tests/test_analysis/test_distance_tools.py
index b264b950..d095fc0e 100644
--- a/tests/test_analysis/test_distance_tools.py
+++ b/tests/test_analysis/test_distance_tools.py
@@ -3,7 +3,7 @@
 import pandas as pd
 import pytest

-from mdagent.tools.base_tools.analysis_tools.distance_tools import DistanceToolsUtils
+from mdcrow.tools.base_tools.analysis_tools.distance_tools import DistanceToolsUtils


 @pytest.fixture(scope="module")
diff --git a/tests/test_analysis/test_hydrogen_bonding.py b/tests/test_analysis/test_hydrogen_bonding.py
new file mode 100644
index 00000000..6e3bf7ec
--- /dev/null
+++ b/tests/test_analysis/test_hydrogen_bonding.py
@@ -0,0 +1,55 @@
+import mdtraj as md
+import numpy as np
+import pytest
+
+from mdcrow.tools.base_tools.analysis_tools.hydrogen_bonding_tools import (
+    HydrogenBondTool,
+)
+
+
+@pytest.fixture
+def hydrogen_bond_tool(get_registry):
+    path_registry = get_registry("raw", True)
+    return HydrogenBondTool(path_registry)
+
+
+@pytest.fixture
+def dummy_traj():
+    topology = md.Topology()
+    chain = topology.add_chain()
+    residue = topology.add_residue("ALA", chain)
+    atom1 = topology.add_atom("N", element=md.element.nitrogen, residue=residue)
+    atom2 = topology.add_atom("H", element=md.element.hydrogen, residue=residue)
+    atom3 = topology.add_atom("O", element=md.element.oxygen, residue=residue)
+    topology.add_bond(atom1, atom2)
+    topology.add_bond(atom1, atom3)
+
+    n_atoms = topology.n_atoms
+    n_frames = 3
+    coordinates = np.zeros((n_frames, n_atoms, 3))
+
+    coordinates[0, :, :] = [[0, 0, 0], [1, 0, 0], [0, 1, 0]]
+    coordinates[1, :, :] = [[0, 0, 0], [1.1, 0, 0], [0, 1.1, 0]]
+    coordinates[2, :, :] = [[0, 0, 0], [1.2, 0, 0], [0, 1.2, 0]]
+
+    traj = md.Trajectory(coordinates, topology)
+    return traj
+
+
+def test_compute_hbonds_traj(hydrogen_bond_tool, dummy_traj):
+    hbond_counts = hydrogen_bond_tool.compute_hbonds_traj(dummy_traj)
+    assert hbond_counts == [0, 0, 0]
+
+
+def test_plot_hbonds_over_time(hydrogen_bond_tool, dummy_traj):
+    hbond_counts = hydrogen_bond_tool.compute_hbonds_traj(dummy_traj)
+    result = hydrogen_bond_tool.plot_hbonds_over_time(hbond_counts, dummy_traj, "dummy")
+    assert "plot saved to" in result
+    assert ".png" in result
+
+
+def test_write_hbond_counts_to_file(hydrogen_bond_tool, dummy_traj):
+    hbond_counts = hydrogen_bond_tool.compute_hbonds_traj(dummy_traj)
+    result = hydrogen_bond_tool.write_hbond_counts_to_file(hbond_counts, "dummy")
+    assert "Data saved to" in result
+    assert ".csv" in result
diff --git a/tests/test_analysis/test_inertia.py b/tests/test_analysis/test_inertia.py
index 67a7c91a..adb29863 100644
--- a/tests/test_analysis/test_inertia.py
+++ b/tests/test_analysis/test_inertia.py
@@ -3,10 +3,7 @@
 import numpy as np
 import pytest

-from mdagent.tools.base_tools.analysis_tools.inertia import (
-    MOIFunctions,
-    MomentOfInertia,
-)
+from mdcrow.tools.base_tools.analysis_tools.inertia import MOIFunctions, MomentOfInertia


 @pytest.fixture
@@ -46,8 +43,8 @@ def test_plot_moi_one_frame(moi_functions):
     assert "Only one frame in trajectory, no plot generated." in result


-@patch("mdagent.tools.base_tools.analysis_tools.inertia.plt.savefig")
-@patch("mdagent.tools.base_tools.analysis_tools.inertia.plt.close")
+@patch("mdcrow.tools.base_tools.analysis_tools.inertia.plt.savefig")
+@patch("mdcrow.tools.base_tools.analysis_tools.inertia.plt.close")
 def test_plot_moi_multiple_frames(mock_close, mock_savefig, moi_functions):
     # Simulate multiple frames of inertia tensor data
     moi_functions.moments_of_inertia = np.array([[1.0, 2.0, 3.0], [1.1, 2.1, 3.1]])
diff --git a/tests/test_analysis/test_misc.py b/tests/test_analysis/test_misc.py
index f2a6656f..3ce340a9 100644
--- a/tests/test_analysis/test_misc.py
+++ b/tests/test_analysis/test_misc.py
@@ -3,8 +3,8 @@
 import pytest

-from mdagent.tools.base_tools import VisFunctions
-from mdagent.tools.base_tools.analysis_tools.plot_tools import PlottingTools
+from mdcrow.tools.base_tools import VisFunctions
+from mdcrow.tools.base_tools.analysis_tools.plot_tools import PlottingTools


 @pytest.fixture
diff --git a/tests/test_analysis/test_pca_tool.py b/tests/test_analysis/test_pca_tool.py
index 59d33e45..3f0ae6e4 100644
--- a/tests/test_analysis/test_pca_tool.py
+++ b/tests/test_analysis/test_pca_tool.py
@@ -3,8 +3,8 @@
 import mdtraj as md
 import numpy as np

-from mdagent.tools.base_tools import PCATool
-from mdagent.tools.base_tools.analysis_tools.pca_tools import PCA_analysis
+from mdcrow.tools.base_tools import PCATool
+from mdcrow.tools.base_tools.analysis_tools.pca_tools import PCA_analysis


 def test_pca_tool_bad_inputs(get_registry):
diff --git a/tests/test_analysis/test_rdftool.py b/tests/test_analysis/test_rdftool.py
index a15bac4b..02287c0b 100644
--- a/tests/test_analysis/test_rdftool.py
+++ b/tests/test_analysis/test_rdftool.py
@@ -2,7 +2,7 @@
 import pytest

-from mdagent.tools.base_tools.analysis_tools.rdf_tool import RDFTool
+from mdcrow.tools.base_tools.analysis_tools.rdf_tool import RDFTool

 # TODO add dcd files in testing file for testing
diff --git a/tests/test_analysis/test_rgy_tool.py b/tests/test_analysis/test_rgy_tool.py
index c131d5be..21859334 100644
--- a/tests/test_analysis/test_rgy_tool.py
+++ b/tests/test_analysis/test_rgy_tool.py
@@ -1,6 +1,6 @@
 import pytest

-from mdagent.tools.base_tools.analysis_tools.rgy import (
+from mdcrow.tools.base_tools.analysis_tools.rgy import (
     RadiusofGyration,
     RadiusofGyrationTool,
 )
diff --git a/tests/test_analysis/test_rmsd_ppi.py b/tests/test_analysis/test_rmsd_ppi.py
index eb449fbe..ba4f7032 100644
--- a/tests/test_analysis/test_rmsd_ppi.py
+++ b/tests/test_analysis/test_rmsd_ppi.py
@@ -1,8 +1,8 @@
 import pytest

-from mdagent.tools.base_tools.analysis_tools.ppi_tools import ppi_distance
-from mdagent.tools.base_tools.analysis_tools.rmsd_tools import lprmsd, rmsd, rmsf
-from mdagent.utils import load_traj_with_ref
+from mdcrow.tools.base_tools.analysis_tools.ppi_tools import ppi_distance
+from mdcrow.tools.base_tools.analysis_tools.rmsd_tools import lprmsd, rmsd, rmsf
+from mdcrow.utils import load_traj_with_ref

 # pdb with two chains
 pdb_string = """
diff --git a/tests/test_analysis/test_sasa.py b/tests/test_analysis/test_sasa.py
index 82abb593..e3eed9ff 100644
--- a/tests/test_analysis/test_sasa.py
+++ b/tests/test_analysis/test_sasa.py
@@ -4,7 +4,7 @@
 import numpy as np
 import pytest

-from mdagent.tools.base_tools.analysis_tools.sasa import (
+from mdcrow.tools.base_tools.analysis_tools.sasa import (
     SASAFunctions,
     SolventAccessibleSurfaceArea,
 )
@@ -55,7 +55,7 @@ def test_sasa_tool_init(get_registry):


 @patch("os.path.exists", return_value=False)
-@patch("mdagent.tools.base_tools.analysis_tools.sasa.np.savetxt")
+@patch("mdcrow.tools.base_tools.analysis_tools.sasa.np.savetxt")
 def test_calculate_sasa(mock_savetxt, mock_exists, get_sasa_functions_with_files):
     analysis = get_sasa_functions_with_files
     result = analysis.calculate_sasa()
@@ -65,8 +65,8 @@ def test_calculate_sasa(mock_savetxt, mock_exists, get_sasa_functions_with_files
     assert analysis.total_sasa is not None


-@patch("mdagent.tools.base_tools.analysis_tools.sasa.plt.savefig")
-@patch("mdagent.tools.base_tools.analysis_tools.sasa.plt.close")
+@patch("mdcrow.tools.base_tools.analysis_tools.sasa.plt.savefig")
+@patch("mdcrow.tools.base_tools.analysis_tools.sasa.plt.close")
 def test_plot_sasa(mock_close, mock_savefig, get_sasa_functions_with_files):
     analysis = get_sasa_functions_with_files
     analysis.residue_sasa = np.array([[1, 2], [3, 4]])
diff --git a/tests/test_analysis/test_secondary_structure.py b/tests/test_analysis/test_secondary_structure.py
index 699e5aef..67516a51 100644
--- a/tests/test_analysis/test_secondary_structure.py
+++ b/tests/test_analysis/test_secondary_structure.py
@@ -2,7 +2,7 @@
 import numpy as np
 import pytest

-from mdagent.tools.base_tools.analysis_tools.secondary_structure import (
+from mdcrow.tools.base_tools.analysis_tools.secondary_structure import (
     ComputeAcylindricity,
     ComputeAsphericity,
     ComputeDSSP,
diff --git a/tests/test_general_tools/test_search_tools.py b/tests/test_general_tools/test_search_tools.py
index e457ee6e..84a19737 100644
--- a/tests/test_general_tools/test_search_tools.py
+++ b/tests/test_general_tools/test_search_tools.py
@@ -1,7 +1,7 @@
 import pytest
 from langchain_openai import ChatOpenAI

-from mdagent.tools.base_tools import Scholar2ResultLLM
+from mdcrow.tools.base_tools import Scholar2ResultLLM


 @pytest.fixture
diff --git a/tests/test_general_tools/test_util_tools.py b/tests/test_general_tools/test_util_tools.py
index 1e840979..52fe0fcb 100644
--- a/tests/test_general_tools/test_util_tools.py
+++ b/tests/test_general_tools/test_util_tools.py
@@ -7,8 +7,8 @@
 import pytest

-from mdagent.agent.agent import MDAgent
-from mdagent.utils import FileType, PathRegistry, SetCheckpoint
+from mdcrow.agent.agent import MDCrow
+from mdcrow.utils import FileType, PathRegistry, SetCheckpoint


 @pytest.fixture
@@ -231,10 +231,10 @@ def test_path_registry_ckpt(get_registry):
     assert os.path.isdir(ckpt)


-def test_mdagent_w_ckpt():
+def test_mdcrow_w_ckpt():
     dummy_test_dir = "ckpt_test"
-    mdagent = MDAgent(ckpt_dir=dummy_test_dir)
-    dummy_test_path = mdagent.path_registry.ckpt_dir
+    mdcrow = MDCrow(ckpt_dir=dummy_test_dir)
+    dummy_test_path = mdcrow.path_registry.ckpt_dir
     assert os.path.exists(dummy_test_path)
     assert dummy_test_dir in dummy_test_path
diff --git a/tests/test_preprocess/test_cleaning.py b/tests/test_preprocess/test_cleaning.py
index 890e12ee..c329016c 100644
--- a/tests/test_preprocess/test_cleaning.py
+++ b/tests/test_preprocess/test_cleaning.py
@@ -1,4 +1,4 @@
-from mdagent.tools.base_tools import CleaningToolFunction
+from mdcrow.tools.base_tools import CleaningToolFunction


 def test_cleaning_function(get_registry):
diff --git a/tests/test_preprocess/test_packing.py b/tests/test_preprocess/test_packing.py
index c0f108d6..6aebcd50 100644
--- a/tests/test_preprocess/test_packing.py
+++ b/tests/test_preprocess/test_packing.py
@@ -2,7 +2,7 @@
 import pytest

-from mdagent.tools.base_tools.preprocess_tools.packing import (
+from mdcrow.tools.base_tools.preprocess_tools.packing import (
     Molecule,
     PackmolBox,
     PackMolTool,
diff --git a/tests/test_preprocess/test_pdb_tools.py b/tests/test_preprocess/test_pdb_tools.py
index 7652197d..49ff198b 100644
--- a/tests/test_preprocess/test_pdb_tools.py
+++ b/tests/test_preprocess/test_pdb_tools.py
@@ -4,9 +4,9 @@
 import pytest

-from mdagent.tools.base_tools import get_pdb
-from mdagent.tools.base_tools.preprocess_tools.packing import PackMolTool
-from mdagent.tools.base_tools.preprocess_tools.pdb_get import MolPDB
+from mdcrow.tools.base_tools import get_pdb
+from mdcrow.tools.base_tools.preprocess_tools.packing import PackMolTool
+from mdcrow.tools.base_tools.preprocess_tools.pdb_get import MolPDB


 @pytest.fixture
@@ -93,7 +93,7 @@ def test_packmol_sm_download_called(packmol):
         "1A3N_144150", f"{packmol.path_registry.ckpt_pdb}/1A3N_144150.pdb", "pdb"
     )
     with patch(
-        "mdagent.tools.base_tools.preprocess_tools.packing.PackMolTool._get_sm_pdbs",
+        "mdcrow.tools.base_tools.preprocess_tools.packing.PackMolTool._get_sm_pdbs",
         new=MagicMock(),
     ) as mock_get_sm_pdbs:
         test_values = {
diff --git a/tests/test_preprocess/test_uniprot.py b/tests/test_preprocess/test_uniprot.py
index 64db4887..119cf95d 100644
--- a/tests/test_preprocess/test_uniprot.py
+++ b/tests/test_preprocess/test_uniprot.py
@@ -1,6 +1,6 @@
 import pytest

-from mdagent.tools.base_tools.preprocess_tools.uniprot import (
+from mdcrow.tools.base_tools.preprocess_tools.uniprot import (
     GetAllKnownSites,
     QueryUniprot,
 )
diff --git a/tests/test_sim/test_setup.py b/tests/test_sim/test_setup.py
index f4f58220..edc08256 100644
--- a/tests/test_sim/test_setup.py
+++ b/tests/test_sim/test_setup.py
@@ -2,7 +2,7 @@
 from openmm import unit
 from openmm.app import PME, HBonds

-from mdagent.tools.base_tools.simulation_tools import SetUpandRunFunction
+from mdcrow.tools.base_tools.simulation_tools import SetUpandRunFunction


 @pytest.fixture
diff --git a/tests/test_sim/test_setupandrun.py b/tests/test_sim/test_setupandrun.py
index b3924742..3414ecb3 100644
--- a/tests/test_sim/test_setupandrun.py
+++ b/tests/test_sim/test_setupandrun.py
@@ -5,7 +5,7 @@
 from openmm import unit
 from openmm.app import PME, HBonds

-from mdagent.tools.base_tools.simulation_tools.setup_and_run import (
+from mdcrow.tools.base_tools.simulation_tools.setup_and_run import (
     OpenMMSimulation,
     SetUpandRunFunction,
 )
diff --git a/tests/test_sim/test_write_script.py b/tests/test_sim/test_write_script.py
index 09a2453f..4c12f4c3 100644
--- a/tests/test_sim/test_write_script.py
+++ b/tests/test_sim/test_write_script.py
@@ -2,7 +2,7 @@
 from openmm import unit
 from openmm.app import PME, NoCutoff

-from mdagent.tools.base_tools.simulation_tools.setup_and_run import OpenMMSimulation
+from mdcrow.tools.base_tools.simulation_tools.setup_and_run import OpenMMSimulation


 @pytest.fixture
diff --git a/tests/test_utils/test_datahandling.py b/tests/test_utils/test_datahandling.py
index 54a37058..9f454f02 100644
--- a/tests/test_utils/test_datahandling.py
+++ b/tests/test_utils/test_datahandling.py
@@ -4,12 +4,12 @@
 import numpy as np
 import pytest

-from mdagent.utils import load_single_traj, load_traj_with_ref, save_plot, save_to_csv
+from mdcrow.utils import load_single_traj, load_traj_with_ref, save_plot, save_to_csv


 @pytest.fixture
 def load_single_traj_mock():
-    with patch("mdagent.utils.data_handling.load_single_traj", return_value="MockTraj"):
+    with patch("mdcrow.utils.data_handling.load_single_traj", return_value="MockTraj"):
         yield
@@ -72,7 +72,7 @@ def test_save_plot_success(get_registry):
     path_registry = get_registry("raw", False)
     fig, ax = plt.subplots()
     ax.plot([1, 2, 3], [1, 2, 3])  # Create a plot
-    with patch("mdagent.utils.data_handling.plt.savefig"):
+    with patch("mdcrow.utils.data_handling.plt.savefig"):
         fig_id = save_plot(path_registry, "test_data", "Test plot")
     assert "fig0_" in fig_id
diff --git a/tests/test_utils/test_eval.py b/tests/test_utils/test_eval.py
deleted file mode 100644
index 149b30b6..00000000
--- a/tests/test_utils/test_eval.py
+++ /dev/null
@@ -1,149 +0,0 @@
-from unittest.mock import MagicMock, mock_open, patch
-
-import pytest
-
-from mdagent.agent.evaluate import Evaluator
-
-
-@pytest.fixture
-def evaluator():
-    with patch("os.makedirs"):
-        yield Evaluator()
-
-
-@pytest.fixture
-def mock_os_makedirs():
-    with patch("os.makedirs", MagicMock()) as mock:
-        yield mock
-
-
-@pytest.fixture
-def mock_open_json():
-    with patch("builtins.open", 
mock_open(read_data='[{"key": "value"}]')) as mock: - yield mock - - -@pytest.fixture -def mock_json_load(): - with patch("json.load", return_value=[{"key": "value"}]) as mock: - yield mock - - -@pytest.fixture -def mock_json_dump(): - with patch("json.dump", MagicMock()) as mock: - yield mock - - -@pytest.fixture -def mock_os_path_exists(): - with patch("os.path.exists", return_value=True) as mock: - yield mock - - -@pytest.fixture -def mock_agent(tmp_path): - mock_action = MagicMock() - mock_action.tool = "some_tool" - mock_action.tool_input = "some_input" - agent = MagicMock() - agent.iter.return_value = iter( - [ - {"intermediate_step": [(mock_action, "Succeeded. some obervation.")]}, - {"output": "Succeeded. Some final answer."}, - ] - ) - agent.ckpt_dir = tmp_path / "fake_dir" - agent.llm.model_name = "test_model" - agent.tools_llm.model_name = "some_tool_model" - agent.agent_type = "test_agent_type" - agent.use_memory = False - agent.run_id = "test_run_id" - return agent - - -@patch("mdagent.agent.evaluate.MDAgent") -def test_create_agent(mock_mdagent, evaluator): - agent_params = {"model_name": "test_model"} - evaluator.create_agent(agent_params) - mock_mdagent.assert_called_once_with(**agent_params) - - -def test_reset(evaluator): - evaluator.evaluations = ["dummy"] - evaluator.reset() - assert evaluator.evaluations == [] - - -def test_save(evaluator, mock_open_json, mock_json_dump): - evaluator.evaluations = [{"test_key": "test_value"}] - evaluator.save("test_file") - mock_open_json.assert_called() - mock_json_dump.assert_called() - - -def test_load(evaluator, mock_open_json, mock_json_load, mock_os_path_exists): - filename = "dummy_data.json" - evaluator.load(filename) - mock_os_path_exists.assert_called_once_with(filename) - mock_open_json.assert_called_once_with(filename, "r") - mock_json_load.assert_called_once() - assert evaluator.evaluations == [{"key": "value"}] - - -def test_evaluate_all_steps(evaluator, mock_agent, mock_os_makedirs, 
mock_open_json): - user_prompt = "Test prompt" - result = evaluator._evaluate_all_steps(mock_agent, user_prompt) - assert result["prompt_success"] is True, "The prompt should be marked as succeeded." - - -def test_evaluate_all_steps_contents( - evaluator, mock_agent, mock_os_makedirs, mock_open_json, mock_json_dump -): - user_prompt = "Test some prompt" - evaluator._evaluate_all_steps(mock_agent, user_prompt) - assert mock_json_dump.call_count == 1 - args, kwargs = mock_json_dump.call_args - data_to_dump = args[0] - assert data_to_dump["prompt_success"] is True - assert "Step 1" in data_to_dump["tools_details"] - assert data_to_dump["tools_details"]["Step 1"]["status_complete"] is True - assert data_to_dump["total_steps"] == 1 - - -def test_run_and_evaluate(evaluator, mock_os_makedirs, mock_open_json): - with patch( - "mdagent.agent.evaluate.Evaluator._evaluate_all_steps" - ) as mock_evaluate_all_steps: - mock_evaluate_all_steps.side_effect = [ - {"prompt_success": True}, - Exception("Test error"), - ] - prompts = ["Prompt 1", "Prompt 2"] - evaluator.run_and_evaluate(prompts) - assert len(evaluator.evaluations) == 2 - assert evaluator.evaluations[0]["execution_success"] is True - assert evaluator.evaluations[1]["execution_success"] is False - assert "Test error" in evaluator.evaluations[1]["error_msg"] - - -@patch("pandas.DataFrame.to_json", MagicMock()) -def test_create_table(evaluator): - evaluator.evaluations = [ - { - "execution_success": True, - "total_steps": 1, - "failed_steps": 0, - "prompt_success": True, - "total_time_seconds": "10.0", - }, - { - "execution_success": True, - "total_steps": 2, - "failed_steps": 1, - "prompt_success": False, - "total_time_seconds": "20.0", - }, - ] - df = evaluator.create_table() - assert len(df) == 2 diff --git a/tests/test_utils/test_memory.py b/tests/test_utils/test_memory.py index 1f233000..39249c4a 100644 --- a/tests/test_utils/test_memory.py +++ b/tests/test_utils/test_memory.py @@ -4,8 +4,8 @@ import pytest from 
langchain_openai import ChatOpenAI -from mdagent.agent.agent import MDAgent -from mdagent.agent.memory import MemoryManager +from mdcrow.agent.agent import MDCrow +from mdcrow.agent.memory import MemoryManager @pytest.fixture @@ -14,17 +14,17 @@ def memory_manager(get_registry): return MemoryManager(get_registry("raw", False), llm) -def test_mdagent_memory(): - mdagent_memory = MDAgent(use_memory=True) - mdagent_no_memory = MDAgent(use_memory=False) - assert mdagent_memory.use_memory is True - assert mdagent_no_memory.use_memory is False +def test_mdcrow_memory(): + mdcrow_memory = MDCrow(use_memory=True) + mdcrow_no_memory = MDCrow(use_memory=False) + assert mdcrow_memory.use_memory is True + assert mdcrow_no_memory.use_memory is False - mdagent_memory = MDAgent(use_memory=True, run_id="TESTRUNN") - assert mdagent_memory.run_id == "TESTRUNN" + mdcrow_memory = MDCrow(use_memory=True, run_id="TESTRUNN") + assert mdcrow_memory.run_id == "TESTRUNN" - mdagent_memory = MDAgent(use_memory=True, run_id="") - assert mdagent_memory.run_id + mdcrow_memory = MDCrow(use_memory=True, run_id="") + assert mdcrow_memory.run_id def test_memory_init(memory_manager, get_registry): @@ -41,27 +41,27 @@ def test_memory_init(memory_manager, get_registry): def test_force_clear_mem(monkeypatch): dummy_test_dir = "ckpt_test" - mdagent = MDAgent(ckpt_dir=dummy_test_dir) + mdcrow = MDCrow(ckpt_dir=dummy_test_dir) monkeypatch.setattr("builtins.input", lambda _: "yes") - print(mdagent.path_registry.ckpt_dir) - print(mdagent.path_registry.json_file_path) - print(os.path.dirname(mdagent.path_registry.ckpt_dir)) - mdagent.force_clear_mem(all=False) - assert not os.path.exists(mdagent.path_registry.ckpt_dir) - assert not os.path.exists(mdagent.path_registry.json_file_path) + print(mdcrow.path_registry.ckpt_dir) + print(mdcrow.path_registry.json_file_path) + print(os.path.dirname(mdcrow.path_registry.ckpt_dir)) + mdcrow.force_clear_mem(all=False) + assert not 
os.path.exists(mdcrow.path_registry.ckpt_dir) + assert not os.path.exists(mdcrow.path_registry.json_file_path) assert os.path.exists( - os.path.basename(os.path.dirname(mdagent.path_registry.ckpt_dir)) + os.path.basename(os.path.dirname(mdcrow.path_registry.ckpt_dir)) ) - mdagent = MDAgent(ckpt_dir=dummy_test_dir) + mdcrow = MDCrow(ckpt_dir=dummy_test_dir) monkeypatch.setattr("builtins.input", lambda _: "yes") - mdagent.force_clear_mem(all=True) - print(mdagent.path_registry.ckpt_dir) - print(mdagent.path_registry.json_file_path) - print(os.path.dirname(mdagent.path_registry.ckpt_dir)) - assert not os.path.exists(mdagent.path_registry.ckpt_dir) - assert not os.path.exists(mdagent.path_registry.json_file_path) - assert not os.path.exists(os.path.dirname(mdagent.path_registry.ckpt_dir)) + mdcrow.force_clear_mem(all=True) + print(mdcrow.path_registry.ckpt_dir) + print(mdcrow.path_registry.json_file_path) + print(os.path.dirname(mdcrow.path_registry.ckpt_dir)) + assert not os.path.exists(mdcrow.path_registry.ckpt_dir) + assert not os.path.exists(mdcrow.path_registry.json_file_path) + assert not os.path.exists(os.path.dirname(mdcrow.path_registry.ckpt_dir)) def test_write_to_json_new_file(tmp_path, memory_manager): diff --git a/tests/test_utils/test_top_k_tools.py b/tests/test_utils/test_top_k_tools.py index 2f3dda3b..6a5a323b 100644 --- a/tests/test_utils/test_top_k_tools.py +++ b/tests/test_utils/test_top_k_tools.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from mdagent.tools.maketools import get_relevant_tools +from mdcrow.tools.maketools import get_relevant_tools @pytest.fixture @@ -21,8 +21,8 @@ def mock_tools(): return [tool1, tool2, tool3] -@patch("mdagent.tools.maketools.make_all_tools") -@patch("mdagent.tools.maketools.OpenAIEmbeddings") +@patch("mdcrow.tools.maketools.make_all_tools") +@patch("mdcrow.tools.maketools.OpenAIEmbeddings") def test_get_relevant_tools_with_openai_embeddings( mock_openai_embeddings, mock_make_all_tools, mock_llm, mock_tools 
): @@ -41,8 +41,8 @@ def test_get_relevant_tools_with_openai_embeddings( assert relevant_tools[1] in mock_tools -@patch("mdagent.tools.maketools.make_all_tools") -@patch("mdagent.tools.maketools.TfidfVectorizer") +@patch("mdcrow.tools.maketools.make_all_tools") +@patch("mdcrow.tools.maketools.TfidfVectorizer") def test_get_relevant_tools_with_tfidf( mock_tfidf_vectorizer, mock_make_all_tools, mock_llm, mock_tools ): @@ -58,7 +58,7 @@ def test_get_relevant_tools_with_tfidf( assert relevant_tools[1] in mock_tools -@patch("mdagent.tools.maketools.make_all_tools") +@patch("mdcrow.tools.maketools.make_all_tools") def test_get_relevant_tools_with_no_tools(mock_make_all_tools, mock_llm): mock_make_all_tools.return_value = [] @@ -67,8 +67,8 @@ def test_get_relevant_tools_with_no_tools(mock_make_all_tools, mock_llm): assert relevant_tools is None -@patch("mdagent.tools.maketools.make_all_tools") -@patch("mdagent.tools.maketools.OpenAIEmbeddings") +@patch("mdcrow.tools.maketools.make_all_tools") +@patch("mdcrow.tools.maketools.OpenAIEmbeddings") def test_get_relevant_tools_with_openai_exception( mock_openai_embeddings, mock_make_all_tools, mock_llm, mock_tools ): @@ -83,7 +83,7 @@ def test_get_relevant_tools_with_openai_exception( assert relevant_tools is None -@patch("mdagent.tools.maketools.make_all_tools") +@patch("mdcrow.tools.maketools.make_all_tools") def test_get_relevant_tools_top_k(mock_make_all_tools, mock_llm, mock_tools): mock_make_all_tools.return_value = mock_tools