diff --git a/.gitignore b/.gitignore index 3aac549..3584f13 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ #scratch chatbot.ipynb +*.pkl +*.png # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/README.md b/README.md index 12d61df..ba7acc8 100644 --- a/README.md +++ b/README.md @@ -24,14 +24,32 @@ Install the chatbot package. ```bash pip install -e . ``` - To develop the code, run - ```bash pip install -e .[dev] ``` +Or simply: +```bash +pip install metadata-chatbot +``` + +## High Level Overview + +The project's main goal is to develop a chatbot that is able to ingest, analyze and query metadata. Metadata is accumulated alongside experiments and consists of information about the data description, subject, equipment and session. To maintain reproducibility standards, it is important for metadata to be documented well. + +## Model Overview + +The current chatbot model uses Anthropic's Claude 3 Sonnet, hosted on AWS' Bedrock service. Since the primary goal is to use natural language to query the database, the user provides prompts about the metadata specifically. The framework is built on LangChain. Claude's system prompt has been configured to understand the metadata schema format and craft MongoDB queries based on the prompt. Given a natural language query about the metadata, the model produces a MongoDB query, its reasoning, and an answer. This method of answering follows chain-of-thought reasoning, where a complex task is broken into manageable steps, allowing the model to work through the problem logically. + +## Data Retrieval + +### Vector Embeddings + +To improve retrieval accuracy and decrease hallucinations, we use vector embeddings to access relevant chunks of information found across the database. This process starts with accessing assets and chunking each JSON file into chunks of 1000 tokens; each chunk preserves the hierarchy found in the JSON files. These chunks are converted to vectors of size 1024 through an embedding model (Amazon's Titan 2.0 embedding model). The user's query is converted to a vector and projected onto the same latent space. The most relevant chunks are then retrieved through a cosine similarity search. + +### AIND-data-schema-access REST API -## Contributing +For queries that require surveying the entire database, such as count-based questions, information is accessed through a MongoDB aggregation pipeline generated by one of the LLM agents and executed over the API connection.
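The snippet below is a minimal, illustrative sketch of this cosine-similarity retrieval step (it is not part of the package API): it embeds a couple of hypothetical metadata chunks and a query with the Titan 2.0 model on Bedrock, then ranks the chunks by similarity. It assumes AWS credentials with Bedrock access are configured; the chunk strings and the query are placeholders.

```python
import boto3
import numpy as np
from langchain_aws import BedrockEmbeddings

# Titan 2.0 embeddings served through Bedrock (1024-dimensional vectors)
bedrock = boto3.client(service_name="bedrock-runtime", region_name="us-west-2")
embedder = BedrockEmbeddings(model_id="amazon.titan-embed-text-v2:0", client=bedrock)

# Hypothetical metadata chunks; in practice these are ~1000-token JSON chunks
chunks = [
    "'subject': {'subject_id': '675387', 'genotype': 'wt/wt'}",
    "'instrument': {'instrument_id': 'SmartSPIM1-2', 'manufacturer': 'LifeCanvas'}",
]
chunk_vectors = np.array(embedder.embed_documents(chunks))  # shape (n_chunks, 1024)

# Embed the user's query and rank the chunks by cosine similarity
query_vector = np.array(embedder.embed_query("Which instrument imaged subject 675387?"))
scores = chunk_vectors @ query_vector / (
    np.linalg.norm(chunk_vectors, axis=1) * np.linalg.norm(query_vector)
)
print(chunks[int(np.argmax(scores))])  # most relevant chunk
```

In the repository itself this ranking is not computed in memory; it is delegated to DocumentDB's `$search`/`vectorSearch` aggregation stage, as shown in `docdb_retriever.py` later in this diff.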
### Linters and testing diff --git a/demo.ipynb b/demo.ipynb new file mode 100644 index 0000000..689a7a6 --- /dev/null +++ b/demo.ipynb @@ -0,0 +1,24 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "from metadata_chatbot.agents" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/embeddings/embeddings.py b/embeddings/embeddings.py new file mode 100644 index 0000000..ffb041f --- /dev/null +++ b/embeddings/embeddings.py @@ -0,0 +1,191 @@ +from urllib.parse import quote_plus +import pymongo, os, boto3, re, pickle +from pymongo import MongoClient +from langchain_community.document_loaders.mongodb import MongodbLoader +from langchain_aws import BedrockEmbeddings +from sshtunnel import SSHTunnelForwarder +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_experimental.text_splitter import SemanticChunker +from tqdm import tqdm +from transformers import AutoTokenizer +from datetime import datetime +from tqdm.contrib.logging import logging_redirect_tqdm + +import logging + +logging.basicConfig(filename='embeddings.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + +# Constants +TOKEN_LIMIT = 8192 # TODO: update the value +BATCH_SIZE = 100 + + +# Establishing bedrock client and embedding model +bedrock = boto3.client( + service_name="bedrock-runtime", + region_name = 'us-west-2' +) + +embeddings_model = BedrockEmbeddings(model_id="amazon.titan-embed-text-v2:0",client=bedrock) + +logging.info("Embedding model instantiated") + +#Establishing tokenizer +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + +logging.info("Tokenizer instantiated") + + +#TODO : CONVERT FUNCTION TO normal tokenizer +def count_tokens(id, text): + tokens = tokenizer.encode(text, truncation=False) + print(f"{id}:{len(tokens)} tokens") + if len(tokens) > TOKEN_LIMIT: + logging.info(f"{id} has {len(tokens)} tokens. 
Too large.") + return None + return tokens + +def create_ssh_tunnel(): + """Create an SSH tunnel to the Document Database.""" + try: + return SSHTunnelForwarder( + ssh_address_or_host=( + os.getenv("DOC_DB_SSH_HOST"), + 22, + ), + ssh_username=os.getenv("DOC_DB_SSH_USERNAME"), + ssh_password=os.getenv("DOC_DB_SSH_PASSWORD"), + remote_bind_address=(os.getenv("DOC_DB_HOST"), 27017), + local_bind_address=( + "localhost", + 27017, + ), + ) + except Exception as e: + logging.error(f"Error creating SSH tunnel: {e}") + + +def generate_embeddings_for_batch(client: MongoClient, batch: list) -> dict: + """Generates embeddings vectors for a batch of loaded documents + """ + logging.info("Embedding documents...") + db = client["metadata_vector_index"] + result_collection = db["data_assets_vectors"] + + skipped_ids = [] + failed_ids = [] + batch_vectors = dict() + with logging_redirect_tqdm(): + for doc in tqdm(batch, desc="Embeddings in progress", total = len(batch)): + + doc_text = doc.page_content + + pattern = r"'_id':\s*'([^']+)'" + match = re.search(pattern, doc_text) + if match: + id_value = match.group(1) + else: + # TODO: log warning + continue + + if result_collection.count_documents({"_id": id_value, "vector_embeddings": {"$exists": True}}): + skipped_ids.append(id_value) + continue + + tokens = count_tokens(id_value, doc_text) + if tokens is None: + failed_ids.append(id_value) + continue + vector = embeddings_model.embed_documents([doc_text])[0] # Embed a single document + + batch_vectors[id_value] = vector + logging.info("Embedding finished for batch") + logging.info(f"Succesfully embedded {len(batch_vectors)}/{len(batch)} documents.") + logging.warning(f"Failed for {len(failed_ids)} documents: {failed_ids}") + logging.info(f"Skipped {len(skipped_ids)} documents: {skipped_ids}") + return batch_vectors + +def write_embeddings_to_docdb_for_batch(client: MongoClient, batch_vectors: dict) -> None: + db = client["metadata_vector_index"] + result_collection = db["data_assets_vectors"] + + for id, vector in batch_vectors.items(): + logging.info(f"Adding vector embeddings for {id} to docdb") + filter={"_id": id} + update={ + "$set": { + "vector_embeddings": vector + } + } + result = result_collection.update_one(filter, update, upsert=False) + logging.info(result.raw_result) + return + +database_name = "metadata_vector_index" + + # Escape username and password to handle special characters +escaped_username = quote_plus(os.getenv("DOC_DB_USERNAME")) +escaped_password = quote_plus(os.getenv('DOC_DB_PASSWORD')) + +connection_string = f"mongodb://{escaped_username}:{escaped_password}@localhost:27017/?directConnection=true&authMechanism=SCRAM-SHA-1&retryWrites=false" + +try: + #print(f"Attempting to connect with: {connection_string}") + + ssh_server = create_ssh_tunnel() + ssh_server.start() + logging.info("SSH tunnel opened") + + client = MongoClient(connection_string) + + # Force a server check + server_info = client.server_info() + print(f"Server info: {server_info}") + + logging.info("Successfully connected to MongoDB") + + + #possibly filter criteria subject/data desc etc + loader = MongodbLoader( + connection_string = connection_string, + db_name = 'metadata_vector_index', + collection_name='data_assets' + ) + + logging.info("Loading collection...") + + documents = loader.load() + total_docs = len(documents) + logging.info(f"Loaded {total_docs} documents..") + + #text_splitter = RecursiveCharacterTextSplitter( + # chunk_size=1000, + # chunk_overlap=100 + # ) + + #text_splitter = 
SemanticChunker(embeddings, breakpoint_threshold_type="gradient") + #docs = text_splitter.split_documents(documents) + #docs_text = [doc.page_content for doc in docs] + + + for i in range(0, total_docs, BATCH_SIZE): + end = i+BATCH_SIZE if i+BATCH_SIZE bool: + + PHYSIO_modalities = ["behavior", "Other", "FIP", "phys", "HSFP"] + #SPIM_modalities = ["SPIM", "HCR"] + + PHYSIO_pattern = '(' + '|'.join(re.escape(word) for word in PHYSIO_modalities) + ')_' + regex = re.compile(PHYSIO_pattern) + + return bool(regex.search(record_name)) + + +def json_to_langchain_doc(json_doc: dict) -> tuple[list, list]: + + docs = [] + large_docs = [] + + PHYSIO_fields_to_embed = ["rig", "session"] + + SPIM_fields_To_embed = ["instrument", "acquisition"] + + general_fields_to_embed = ["data_description", "subject", "procedures"] + + if regex_modality_PHYSIO(json_doc["name"]): + fields_to_embed = [*PHYSIO_fields_to_embed, *general_fields_to_embed] + else: + fields_to_embed = [*SPIM_fields_To_embed, *general_fields_to_embed] + + #fields_to_metadata = ["_id", "created", "describedBy", "external_links", "last_modified", "location", "metadata_status", "name", "processing", "schema_version"] + + to_metadata = dict() + values_to_embed = dict() + + for item, value in json_doc.items(): + if item == "_id": + item = "original_id" + if item in fields_to_embed: + values_to_embed[item] = value + else: + to_metadata[item] = value + + subject = json_doc.get("subject") + + if subject is not None: + to_metadata["subject_id"] = subject.get("subject_id", None) # Default if subject_id is missing + else: + #print("Subject key is missing or None.") + to_metadata["subject_id"] = "null" + + data_description = json_doc.get("data_description") + + if data_description is not None: + to_metadata["modality"] = data_description.get("modality", None) # Default if subject_id is missing + else: + #print("Subject key is missing or None.") + to_metadata["modality"] = "null" + + json_chunks = JSON_SPLITTER.split_text(json_data=values_to_embed, convert_lists=True) + + for chunk in json_chunks: + newDoc = Document(page_content=chunk, metadata=to_metadata) + if len(chunk) < TOKEN_LIMIT: + docs.append(newDoc) + else: + large_docs.append(newDoc) + + return docs, large_docs + +#INDEX_NAME = 'ALL_curated_embeddings_index' +INDEX_NAME = 'TOKEN_LIMIT_curated_embeddings_index' +NAMESPACE = 'metadata_vector_index.curated_assets' +DB_NAME, COLLECTION_NAME = NAMESPACE.split(".") + + +with ResourceManager() as RM: + + collection = RM.client[DB_NAME][COLLECTION_NAME] + langchain_collection = RM.client[DB_NAME]['bigger_LANGCHAIN_curated_chunks'] + LANGCHAIN_NAMESPACE = 'metadata_vector_index.bigger_LANGCHAIN_curated_chunks' + + logging.info(f"Finding assets that are already embedded...") + + if langchain_collection is not None: + existing_ids = set(doc['original_id'] for doc in langchain_collection.find({}, {'original_id': 1})) + + logging.info(f"Skipped {len(existing_ids)} assets, which are already in the new collection") + + docs_to_vectorize = collection.count_documents({'_id': {'$nin': list(existing_ids)}}) + + logging.info(f"{docs_to_vectorize} assets need to be vectorized") + + if docs_to_vectorize != 0: + + cursor = collection.find({'_id': {'$nin': list(existing_ids)}}) + + docs = [] + skipped_docs = [] + + logging.info("Chunking documents...") + + document_no = 0 + + for document in tqdm(cursor, desc="Chunking in progress"): + if document_no % 100 == 0: + logging.info(f"Currently on asset number {document_no}") + json_doc = 
json.loads(json_util.dumps(document)) + chunked_docs, large_docs = json_to_langchain_doc(json_doc) + docs.extend(chunked_docs) + skipped_docs.extend(large_docs) + document_no += 1 + + logging.info(f"Successfully chunked {document_no} documents") + + logging.info(f"Adding {len(docs)} chunked documents to collection") + logging.info(f"Skipping {len(skipped_docs)} due to token limitations") + + try: + vectorstore = DocumentDBVectorSearch( + embedding=BEDROCK_EMBEDDINGS, + collection=langchain_collection, + index_name=INDEX_NAME, + ) + + batch_size = 100 + for i in range(0, len(docs), batch_size): + batch = docs[i:i + batch_size] + vectorstore.add_documents(batch) + logging.info(f"Added batch {i // batch_size + 1} of documents") + + dimensions = 1024 + similarity_algorithm = DocumentDBSimilarityType.COS + + logging.info("Creating vector index with chunked documents") + vectorstore.create_index(dimensions, similarity_algorithm) + + except Exception as e: + logging.error(f"Error processing documents: {str(e)}") + + else: + logging.info("Vectorstore is up to date!") + diff --git a/embeddings/umap_visualization.py b/embeddings/umap_visualization.py new file mode 100644 index 0000000..19185c7 --- /dev/null +++ b/embeddings/umap_visualization.py @@ -0,0 +1,169 @@ +from langchain_community.vectorstores.documentdb import DocumentDBVectorSearch +from urllib.parse import quote_plus +import pymongo, os, boto3, sys, umap +from pymongo import MongoClient +from langchain_aws import BedrockEmbeddings +import logging +import matplotlib.pyplot as plt +import umap.umap_ as umap + + +import numpy as np +import pandas as pd +from sklearn.manifold import TSNE +import plotly.graph_objs as go +from plotly.subplots import make_subplots +import plotly.io as pio + +sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot")) +from metadata_chatbot.utils import create_ssh_tunnel, ALL_CURATED_VECTORSTORE, BEDROCK_EMBEDDINGS, CONNECTION_STRING + + +logging.basicConfig(filename='vector_visualization.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w") + +client = MongoClient(CONNECTION_STRING) +langchain_collection = client['metadata_vector_index']['LANGCHAIN_ALL_curated_assets'] + + + +ssh_server = create_ssh_tunnel() +ssh_server.start() +logging.info("SSH tunnel opened") + +logging.info("Successfully connected to MongoDB") +logging.info("Initializing connection vector store") +vectorstore = ALL_CURATED_VECTORSTORE + + # query = "subject" + # logging.info("Starting to vectorize query...") + # query_embedding = BEDROCK_EMBEDDINGS.embed_query(query) + + # total_documents = langchain_collection.count_documents({}) + # print(f"Total documents in collection: {total_documents}") + + # # Check indexes on the collection + # indexes = langchain_collection.index_information() + # print(f"Indexes on collection: {indexes}") + + # result = langchain_collection.aggregate([ + # { + # '$search': { + # 'vectorSearch': { + # 'vector': query_embedding, + # 'path': 'vectorContent', + # 'similarity': 'cosine', + # 'k': 22100 + # } + # } + # }, + # { + # '$limit': 22100 # Ensure the pipeline limits the results to 22100 + # } + # ]) + +logging.info("Finding vectors...") +documents = langchain_collection.find({}, {"vectorContent": 1, "_id": 0}) +logging.info("Extracting vectors...") \ + +embeddings_list = [] +for doc in documents: + embeddings_list.append(doc["vectorContent"]) +logging.info(f"Number of vectors retrieved: {len(embeddings_list)}") +embeddings_array = 
np.array(embeddings_list) +logging.info("Plotting...") +reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42) +embedding = reducer.fit_transform(embeddings_array) +# Plot the results +plt.figure(figsize=(12, 10)) +plt.scatter(embedding[:, 0], embedding[:, 1], s=3, alpha=0.5) +plt.title(f'UMAP projection of {len(embeddings_list)} embeddings') +plt.xlabel('UMAP1') +plt.ylabel('UMAP2') +plt.colorbar() +plt.show() + + +''' + embeddings_list = [] + modalities_list = [] + + for document in result: + embeddings_list.append(document['vectorContent']) + modalities_list.append(document['modality']) + + #embeddings_list.insert(0,query_embedding) + + print(len(embeddings_list)) + print(len(modalities_list)) + + n_components = 3 #3D + embeddings_list = np.array(embeddings_list) #converting to numpy array + + print(np.shape(embeddings_list)) + + + + tsne = TSNE(n_components=n_components, random_state=42, perplexity=20) + reduced_vectors = tsne.fit_transform(embeddings_list) + print(len(reduced_vectors)) + #reduced_vectors[0:10] + + # Create a 3D scatter plot + scatter_plot = go.Scatter3d( + x=reduced_vectors[:, 0], + y=reduced_vectors[:, 1], + z=reduced_vectors[:, 2], + mode='markers', + marker=dict(size=5, color='grey', opacity=0.5, line=dict(color='lightgray', width=1)), + text=[f"Point {i}" for i in range(len(reduced_vectors))] + ) + + # Highlight the first point with a different color + highlighted_point = go.Scatter3d( + x=[reduced_vectors[0, 0]], + y=[reduced_vectors[0, 1]], + z=[reduced_vectors[0, 2]], + mode='markers', + marker=dict(size=8, color='red', opacity=0.8, line=dict(color='lightgray', width=1)), + text=["Question"] + + ) + + blue_points = go.Scatter3d( + x=reduced_vectors[1:4, 0], + y=reduced_vectors[1:4, 1], + z=reduced_vectors[1:4, 2], + z1=reduced_vectors[1:4, 3], + z2=reduced_vectors[1:4, 4], + z3=reduced_vectors[1:4, 5], + mode='markers', + marker=dict(size=8, color='blue', opacity=0.8, line=dict(color='black', width=1)), + text=["Top 1 Data Asset","Top 2 Data Asset","Top 3 Data Asset"] + ) + + # Create the layout for the plot + layout = go.Layout( + scene=dict( + xaxis=dict(title='X'), + yaxis=dict(title='Y'), + zaxis=dict(title='Z'), + ), + title=f'3D Representation after t-SNE (Perplexity=5)' + ) + + + fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'scatter3d'}]]) + + # Add the scatter plots to the Figure + fig.add_trace(scatter_plot) + fig.add_trace(highlighted_point) + fig.add_trace(blue_points) + + fig.update_layout(layout) + + pio.write_html(fig, 'interactive_plot.html') + fig.show() +''' + +client.close() +ssh_server.stop() diff --git a/pyproject.toml b/pyproject.toml index ee0eb1f..7ba75e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,13 +22,16 @@ dependencies = [ 'langchain-aws', 'langchain-community', 'langchain_core', + 'langchain', + 'langgraph', 'motor', 'nest-asyncio', 'pymongo', - 'ragas', 'fastapi', 'uvicorn', - 'logging' + 'logging', + 'sshtunnel', + 'transformers' ] [project.optional-dependencies] diff --git a/queries..txt b/queries..txt new file mode 100644 index 0000000..14de420 --- /dev/null +++ b/queries..txt @@ -0,0 +1,10 @@ +- finding assets where injections were performed in the same region/virus constructs +- what channels are being imaged +- what was imaged in a channels +- reagent querying - what those are? 
+- injection site (nsp target/actual localization) +- target coordinates/ what that region is supposed to be +- evaluation of injection target +- whats the sequence of procedures, recreating a timeline/age of a mouse +- are my injections more on target at certain ages +- name of the rig \ No newline at end of file diff --git a/src/metadata_chatbot/agents/GAMER.py b/src/metadata_chatbot/agents/GAMER.py new file mode 100644 index 0000000..8ba0784 --- /dev/null +++ b/src/metadata_chatbot/agents/GAMER.py @@ -0,0 +1,120 @@ +from typing import Any, Dict, Iterator, List, Mapping, Optional + +from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.language_models.llms import LLM +from langchain_core.outputs import GenerationChunk + +from langgraph.graph import StateGraph + +import logging, asyncio + +from async_workflow import async_app +from workflow import app + + +class GAMER(LLM): + + def _call( + self, + query: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """ + Args: + query: Natural language query. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of the stop substrings. + If stop tokens are not supported consider raising NotImplementedError. + run_manager: Callback manager for the run. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + The model output as a string. + """ + inputs = {"query" : query} + answer = app.invoke(inputs) + return answer['generation'] + + async def _acall( + self, + query: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """ + Asynchronous call. + + Args: + query: Natural language query. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of the stop substrings. + If stop tokens are not supported consider raising NotImplementedError. + run_manager: Callback manager for the run. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + The model output as a string. + """ + inputs = {"query" : query} + async for output in async_app.astream(inputs, stream_mode="updates"): + for key, value in output.items(): + logging.info(f"Currently on node '{key}':") + return value['generation'] if value else None + + def _stream( + self, + query: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + """Stream the LLM on the given prompt. + + This method should be overridden by subclasses that support streaming. + + If not implemented, the default behavior of calls to stream will be to + fallback to the non-streaming version of the model and return + the output as a single chunk. + + Args: + query: The prompt to generate from. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. + run_manager: Callback manager for the run. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + An iterator of GenerationChunks. 
+ """ + for char in query[: self.n]: + chunk = GenerationChunk(text=char) + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + + yield chunk + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Return a dictionary of identifying parameters.""" + return { + "model_name": "Anthropic Claude 3 Sonnet", + } + + @property + def _llm_type(self) -> str: + """Get the type of language model used by this chat model. Used for logging purposes only.""" + return "Claude 3 Sonnet" + + + +# async def main(): +# result = await llm.ainvoke("Can you give me a timeline of events for subject 675387?") +# print(result) + +# asyncio.run(main()) \ No newline at end of file diff --git a/src/metadata_chatbot/agents/GAMER_workbook.ipynb b/src/metadata_chatbot/agents/GAMER_workbook.ipynb new file mode 100644 index 0000000..dd61430 --- /dev/null +++ b/src/metadata_chatbot/agents/GAMER_workbook.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GAMER: Generative Analysis of Metadata Retrieval" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This model uses a multi agent framework on Langraph to retrieve and summarize metadata information based on a user's natural language query. \n", + "\n", + "This workflow consists of 6 agents, or nodes, where a decision is made and there is new context provided to either the model or the user. Here are some decisions incorporated into the framework:\n", + "1. To best answer the query, does the entire database need to be queried, or the vector index?\n", + "- Input: `x (query)`\n", + "- Decides best data to query against\n", + "- Output: `entire_database, vector_embeddings`\n", + "2. If querying against the vector embeddings, does the index need to be filtered further with metdata tags, to improve optimization of retrieval?\n", + "- Input: `x (query)`\n", + "- Decides whether database can be further filtered by applying a MongoDB query\n", + "- Output: `MongoDB query, None`\n", + "3. Are the documents retrieved during retrieval relevant to the question?\n", + "- Input: `x (query)`\n", + "- Decides whether document should be kept or tossed during summarization\n", + "- Output: `yes, no`\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![title](graph_workflow.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Calling the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Synchronous calling" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The retrieved information provides details about an imaging experiment conducted on a SmartSPIM instrument for the subject with ID 675387. 
Here is a summary of the key details:\n", + "\n", + "Instrument Details:\n", + "- Instrument ID: SmartSPIM1-2\n", + "- Instrument Type: SmartSPIM\n", + "- Manufacturer: LifeCanvas\n", + "- Location: 615 Westlake\n", + "- Temperature control: True\n", + "- Optical table: VIS3648-PG4-325A (MKS Newport), 36 x 48 inch, vibration control\n", + "- Objective: TL2X-SAP (Thorlabs), 0.1 NA, 1.6x magnification, multi-immersion\n", + "\n", + "Acquisition Details:\n", + "- Subject ID: 675387 \n", + "- Session start time: 2023-05-23T23:05:56\n", + "- Session end time: 2023-05-24T04:10:10\n", + "- Experimenter: John Rohde\n", + "- Storage directory: D:/SmartSPIM_Data\n", + "- Imaging axes: X (left-right, μm), Y (posterior-anterior, μm), Z (superior-inferior, μm)\n", + "- Chamber immersion medium: Cargille Oil 1.5200 (RI 1.5207)\n", + "- Sample immersion medium: EasyIndex (RI 1.513)\n", + "\n", + "Tile Details:\n", + "- Tile 83: Ex 639 nm, Em 660 nm, file path Ex_639_Em_660/540340/540340_563980/, translation (54034, 56398, 4.2 μm), scale (1.8, 1.8, 2)\n", + "- Tile 53: Ex 639 nm, Em 660 nm, file path Ex_639_Em_660/475540/475540_512140/, translation (47554, 51214, 4.2 μm), scale (1.8, 1.8, 2)\n", + "- Tile 54: Ex 445 nm, Em 469 nm, 30 mW, file path Ex_445_Em_469/507940/507940_512140/, translation (50794, 51214, 4.2 μm), scale (1.8, 1.8, 2)\n", + "- Tile 23: Ex 639 nm, Em 660 nm, file path Ex_639_Em_660/540340/540340_434380/, translation (54034, 43438, 4.2 μm), scale (1.8, 1.8, 2)\n", + "- Tile 24: Ex 445 nm, Em 469 nm, 30 mW, file path Ex_445_Em_469/443140/443140_460300/, translation (44314, 46030, 4.2 μm), scale (1.8, 1.8, 2)\n", + "\n", + "The data was collected as part of the \"Thalamus in the middle\" project at the Allen Institute for Neural Dynamics, funded by NINDS grant NIH1U19NS123714-01.\n" + ] + } + ], + "source": [ + "from GAMER import GAMER\n", + "query = \"give me a summary of SmartSPIM_675387_2023-05-23_23-05-56\"\n", + "\n", + "model = GAMER()\n", + "result = model.invoke(query)\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Asynchronous calling" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'The retrieved information provides details about an imaging experiment conducted on a SmartSPIM instrument for the subject with ID 675387. Here is a summary of the key points:\\n\\n1. The experiment was performed on the SmartSPIM1-2 instrument located at 615 Westlake, manufactured by LifeCanvas. The instrument has temperature control but no humidity control.\\n\\n2. The imaging session started on 2023-05-23 at 23:05:56 and ended on 2023-05-24 at 04:10:10. The experimenter was John Rohde.\\n\\n3. The data was acquired with a Thorlabs TL2X-SAP objective with a numerical aperture of 0.1 and 1.6x magnification.\\n\\n4. The imaging was performed with a 639nm excitation laser and 660nm emission filter.\\n\\n5. A single tile (tile 83) was imaged, with coordinates (54034, 56398, 4.2) and a scaling factor of (1.8, 1.8, 2).\\n\\n6. The specimen (675387) underwent several procedures, including active delipidation, 50% EasyIndex refractive index matching, and 100% EasyIndex refractive index matching, performed by experimenter DT.\\n\\n7. 
The specimen was obtained through a perfusion procedure following a surgery on 2023-04-28, performed by experimenter 30509.'" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from GAMER import GAMER\n", + "llm = GAMER()\n", + "query = \"give me a summary of SmartSPIM_675387_2023-05-23_23-05-56\"\n", + "\n", + "await llm.ainvoke(query)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/metadata_chatbot/agents/__init__.py b/src/metadata_chatbot/agents/__init__.py new file mode 100644 index 0000000..a41897c --- /dev/null +++ b/src/metadata_chatbot/agents/__init__.py @@ -0,0 +1,3 @@ +"""Init package""" +__version__ = "0.0.12" + diff --git a/src/metadata_chatbot/agents/agentic_graph.py b/src/metadata_chatbot/agents/agentic_graph.py new file mode 100644 index 0000000..4afe30a --- /dev/null +++ b/src/metadata_chatbot/agents/agentic_graph.py @@ -0,0 +1,112 @@ +from pydantic import BaseModel, Field +from langchain_aws import ChatBedrock +from langchain import hub +import logging +from typing import Literal +from langchain_core.output_parsers import StrOutputParser +from langchain_core.tools import tool +from aind_data_access_api.document_db_ssh import DocumentDbSSHClient, DocumentDbSSHCredentials +from langchain.agents import AgentExecutor, create_tool_calling_agent + +logging.basicConfig(filename='agentic_graph.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w") + +MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0" +LLM = ChatBedrock( + model_id= MODEL_ID, + model_kwargs= { + "temperature": 0 + } +) + +#determining if entire database needs to be surveyed +class RouteQuery(BaseModel): + """Route a user query to the most relevant datasource.""" + + datasource: Literal["vectorstore", "direct_database"] = Field( + description="Given a user question choose to route it to the direct database or its vectorstore.", + ) + +structured_llm_router = LLM.with_structured_output(RouteQuery) +router_prompt = hub.pull("eden19/query_rerouter") +datasource_router = router_prompt | structured_llm_router + + +# Queries that require surveying the entire database (like count based questions) +credentials = DocumentDbSSHCredentials() +credentials.database = "metadata_vector_index" +credentials.collection = "curated_assets" +@tool +def aggregation_retrieval(agg_pipeline: list) -> list: + """Given a MongoDB query and list of projections, this function retrieves and returns the + relevant information in the documents. + Use a project stage as the first stage to minimize the size of the queries before proceeding with the remaining steps. + The input to $map must be an array not a string, avoid using it in the $project stage. 
+ + Parameters + ---------- + agg_pipeline + MongoDB aggregation pipeline + + Returns + ------- + list + List of retrieved documents + """ + with DocumentDbSSHClient(credentials=credentials) as doc_db_client: + + result = list(doc_db_client.collection.aggregate( + pipeline=agg_pipeline + )) + return result + +tools = [aggregation_retrieval] +prompt = hub.pull("eden19/entire_db_retrieval") + +db_surveyor_agent = create_tool_calling_agent(LLM, tools, prompt) +db_surveyor = AgentExecutor(agent=db_surveyor_agent, tools=tools, verbose=False) + + +# Processing query +class ProcessQuery(BaseModel): + """Binary score to check whether query requires retrieval to be filtered with metadata information to achieve accurate results.""" + + binary_score: str = Field( + description="Query requires further filtering during retrieval process, 'yes' or 'no'" + ) + +query_grader = LLM.with_structured_output(ProcessQuery) +query_grade_prompt = hub.pull("eden19/processquery") +query_grader = query_grade_prompt | query_grader +# query_grade = query_grader.invoke({"query": question}).binary_score + +# Generating appropriate filter +class FilterGenerator(BaseModel): + """MongoDB filter to be applied before vector retrieval""" + + filter_query: dict = Field(description="MongoDB filter") + top_k: int = Field(description="Number of documents to retrieve from the database") + +filter_prompt = hub.pull("eden19/filtergeneration") +filter_generator_llm = LLM.with_structured_output(FilterGenerator) + +filter_generation_chain = filter_prompt | filter_generator_llm +# filter = filter_generation_chain.invoke({"query": question}).filter_query + +# Check if retrieved documents answer question +class RetrievalGrader(BaseModel): + """Binary score to check whether retrieved documents are relevant to the question""" + binary_score: str = Field( + description="Retrieved documents are relevant to the query, 'yes' or 'no'" + ) + +retrieval_grader = LLM.with_structured_output(RetrievalGrader) +retrieval_grade_prompt = hub.pull("eden19/retrievalgrader") +doc_grader = retrieval_grade_prompt | retrieval_grader +# doc_grade = doc_grader.invoke({"query": question, "document": doc}).binary_score +# logging.info(f"Retrieved document matched query: {doc_grade}") + +# Generating response to documents +answer_generation_prompt = hub.pull("eden19/answergeneration") +rag_chain = answer_generation_prompt | LLM | StrOutputParser() +# generation = rag_chain.invoke({"documents": doc, "query": question}) +# logging.info(f"Final answer: {generation}") \ No newline at end of file diff --git a/src/metadata_chatbot/agents/async_workflow.py b/src/metadata_chatbot/agents/async_workflow.py new file mode 100644 index 0000000..b534bfd --- /dev/null +++ b/src/metadata_chatbot/agents/async_workflow.py @@ -0,0 +1,207 @@ +import logging, sys, os, asyncio +from typing import List, Optional +from typing_extensions import TypedDict +from langchain_core.documents import Document +from langgraph.graph import END, StateGraph, START +from langgraph.checkpoint.memory import MemorySaver + + +sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot")) +from metadata_chatbot.utils import ResourceManager + +from docdb_retriever import DocDBRetriever +from agentic_graph import datasource_router, db_surveyor, query_grader, filter_generation_chain, doc_grader, rag_chain + +logging.basicConfig(filename='async_workflow.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w") + +class 
GraphState(TypedDict): + """ + Represents the state of our graph. + + Attributes: + query: question asked by user + generation: LLM generation + documents: list of documents + """ + + query: str + generation: str + documents: List[str] + filter: Optional[dict] + top_k: Optional[int] + +async def route_question_async(state): + """ + Route question to database or vectorstore + Args: + state (dict): The current graph state + + Returns: + str: Next node to call + """ + query = state["query"] + + source = await datasource_router.ainvoke({"query": query}) + if source.datasource == "direct_database": + logging.info("Entire database needs to be queried.") + return "direct_database" + elif source.datasource == "vectorstore": + logging.info("Querying against vector embeddings...") + return "vectorstore" + +async def generate_for_whole_db_async(state): + """ + Filter database + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key may be added to state, generation, which contains the answer for query asked + """ + + query = state["query"] + chat_history = [] + + logging.info("Generating answer...") + + generation = await db_surveyor.ainvoke({'query': query, 'chat_history': chat_history}) + return {"query": query, "generation": generation} + +async def filter_generator_async(state): + """ + Filter database + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key may be added to state, filter, which contains the MongoDB query that will be applied before retrieval + """ + logging.info("Determining whether filter is required...") + + query = state["query"] + + result = await query_grader.ainvoke({"query": query}) + query_grade = result.binary_score + logging.info(f"Database needs to be further filtered: {query_grade}") + + if query_grade == "yes": + result = await filter_generation_chain.ainvoke({"query": query}) + filter = result.filter_query + logging.info(f"Database will be filtered using: {filter}") + return {"filter": filter, "query": query} + else: + return {"filter": None, "query": query} + +async def retrieve_async(state): + """ + Retrieve documents + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key added to state, documents, that contains retrieved documents + """ + logging.info("Retrieving documents...") + query = state["query"] + filter = state["filter"] + + # Retrieval + with ResourceManager() as RM: + db = RM.async_client.get_database('metadata_vector_index') + collection = db.get_collection('bigger_LANGCHAIN_curated_chunks') + retriever = DocDBRetriever(collection = collection, k = 10) + documents = await retriever.aget_relevant_documents(query = query, query_filter = filter) + return {"documents": documents, "query": query} + +async def grade_doc_async(query, doc: Document): + score = await doc_grader.ainvoke({"query": query, "document": doc.page_content}) + grade = score.binary_score + logging.info(f"Retrieved document matched query: {grade}") + if grade == "yes": + logging.info("Document is relevant to the query") + return doc + else: + logging.info("Document is not relevant and will be removed") + return None + + +async def grade_documents_async(state): + """ + Determines whether the retrieved documents are relevant to the question. 
+ + Args: + state (dict): The current graph state + + Returns: + state (dict): Updates documents key with only filtered relevant documents + """ + + logging.info("Checking relevance of documents to question asked...") + query = state["query"] + documents = state["documents"] + + filtered_docs = await asyncio.gather(*[grade_doc_async(query, doc) for doc in documents]) + filtered_docs = [doc for doc in filtered_docs if doc is not None] + return {"documents": filtered_docs, "query": query} + +async def generate_async(state): + """ + Generate answer + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key added to state, generation, that contains LLM generation + """ + logging.info("Generating answer...") + query = state["query"] + documents = state["documents"] + + doc_text = "\n\n".join(doc.page_content for doc in documents) + + # RAG generation + generation = await rag_chain.ainvoke({"documents": doc_text, "query": query}) + return {"documents": documents, "query": query, "generation": generation, "filter": state.get("filter", None)} + +async_workflow = StateGraph(GraphState) +async_workflow.add_node("database_query", generate_for_whole_db_async) +async_workflow.add_node("filter_generation", filter_generator_async) +async_workflow.add_node("retrieve", retrieve_async) +async_workflow.add_node("document_grading", grade_documents_async) +async_workflow.add_node("generate", generate_async) + +async_workflow.add_conditional_edges( + START, + route_question_async, + { + "direct_database": "database_query", + "vectorstore": "filter_generation", + }, +) +async_workflow.add_edge("filter_generation", "retrieve") +async_workflow.add_edge("retrieve", "document_grading") +async_workflow.add_edge("document_grading","generate") +async_workflow.add_edge("generate", END) + + +async_app = async_workflow.compile() + +async def main(): + query = "What was the age of the subject when receiving injections in asset SmartSPIM_675388_2023-05-24_04-10-19_stitched_2023-05-28_18-07-46?" 
+ inputs = {"query": query} + result = async_app.astream(inputs) + + value = None + async for output in result: + for key, value in output.items(): + logging.info(f"Currently on node '{key}':") + + if value: + print(value['generation']) + +# Run the async function +asyncio.run(main()) diff --git a/src/metadata_chatbot/agents/docdb_retriever.py b/src/metadata_chatbot/agents/docdb_retriever.py new file mode 100644 index 0000000..c329386 --- /dev/null +++ b/src/metadata_chatbot/agents/docdb_retriever.py @@ -0,0 +1,122 @@ +import sys, os, json +from typing import List, Optional, Any, Union, Annotated +from pymongo.collection import Collection +from motor.motor_asyncio import AsyncIOMotorCollection +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever +from bson import json_util +from langsmith import trace as langsmith_trace +from pydantic import Field + +sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot")) +from metadata_chatbot.utils import BEDROCK_EMBEDDINGS + + + +class DocDBRetriever(BaseRetriever): + """A retriever that contains the top k documents, retrieved from the DocDB index, aligned with the user's query.""" + collection: Any = Field(description="DocDB collection to retrieve from") + k: int = Field(default=10, description="Number of documents to retrieve") + + def _get_relevant_documents( + self, + query: str, + query_filter: Optional[dict] = None, + run_manager: Optional[CallbackManagerForRetrieverRun] = None, + **kwargs: Any, + ) -> List[Document]: + + #Embed query + embedded_query = BEDROCK_EMBEDDINGS.embed_query(query) + + #Construct aggregation pipeline + vector_search = { + "$search": { + "vectorSearch": { + "vector": embedded_query, + "path": 'vectorContent', + "similarity": 'euclidean', + "k": self.k + } + } + } + + pipeline = [vector_search] + if query_filter: + pipeline.insert(0, query_filter) + + cursor = self.collection.aggregate(pipeline) + + page_content_field = 'textContent' + + results = [] + + #Transform retrieved docs to langchain Documents + for document in cursor: + values_to_metadata = dict() + + json_doc = json.loads(json_util.dumps(document)) + + for key, value in json_doc.items(): + if key == page_content_field: + page_content = value + else: + values_to_metadata[key] = value + + new_doc = Document(page_content=page_content, metadata=values_to_metadata) + results.append(new_doc) + + return results + + async def _aget_relevant_documents( + self, + query: str, + query_filter: Optional[dict] = None, + run_manager: Optional[CallbackManagerForRetrieverRun] = None, + **kwargs: Any, + ) -> List[Document]: + + #Embed query + embedded_query = BEDROCK_EMBEDDINGS.embed_query(query) + + #Construct aggregation pipeline + vector_search = { + "$search": { + "vectorSearch": { + "vector": embedded_query, + "path": 'vectorContent', + "similarity": 'euclidean', + "k": self.k, + "efSearch": 40 + } + } + } + + pipeline = [vector_search] + if query_filter: + pipeline.insert(0, query_filter) + + cursor = self.collection.aggregate(pipeline, allowDiskUse=True) + #results = await cursor.to_list(length=1000) + + page_content_field = 'textContent' + + results = [] + + #Transform retrieved docs to langchain Documents + async for document in cursor: + values_to_metadata = dict() + + json_doc = json.loads(json_util.dumps(document)) + + for key, value in json_doc.items(): + if key == page_content_field: + page_content = value + else: + 
values_to_metadata[key] = value + + new_doc = Document(page_content=page_content, metadata=values_to_metadata) + results.append(new_doc) + + return results \ No newline at end of file diff --git a/src/metadata_chatbot/agents/workflow.py b/src/metadata_chatbot/agents/workflow.py new file mode 100644 index 0000000..a6ec7ac --- /dev/null +++ b/src/metadata_chatbot/agents/workflow.py @@ -0,0 +1,193 @@ +import logging, sys, os +from typing import List, Optional +from typing_extensions import TypedDict +from langgraph.graph import END, StateGraph, START +from langgraph.checkpoint.memory import MemorySaver + + +sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot")) +from metadata_chatbot.utils import ResourceManager + +from docdb_retriever import DocDBRetriever +from agentic_graph import datasource_router, db_surveyor, query_grader, filter_generation_chain, doc_grader, rag_chain + +logging.basicConfig(filename='async_workflow.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w") + +class GraphState(TypedDict): + """ + Represents the state of our graph. + + Attributes: + query: question asked by user + generation: LLM generation + documents: list of documents + """ + + query: str + generation: str + documents: List[str] + filter: Optional[dict] + top_k: Optional[int] + +def route_question(state): + """ + Route question to database or vectorstore + Args: + state (dict): The current graph state + + Returns: + str: Next node to call + """ + query = state["query"] + + source = datasource_router.invoke({"query": query}) + if source.datasource == "direct_database": + logging.info("Entire database needs to be queried.") + return "direct_database" + elif source.datasource == "vectorstore": + logging.info("Querying against vector embeddings...") + return "vectorstore" + +def generate_for_whole_db(state): + """ + Filter database + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key may be added to state, generation, which contains the answer for query asked + """ + + query = state["query"] + chat_history = [] + + logging.info("Generating answer...") + + generation = db_surveyor.invoke({'query': query, 'chat_history': chat_history}) + return {"query": query, "generation": generation} + +def filter_generator(state): + """ + Filter database + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key may be added to states, filter and top_k, which contains the MongoDB query that will be applied before retrieval + """ + logging.info("Determining whether filter is required...") + + query = state["query"] + + query_grade = query_grader.invoke({"query": query}).binary_score + logging.info(f"Database needs to be further filtered: {query_grade}") + + if query_grade == "yes": + filter = filter_generation_chain.invoke({"query": query}).filter_query + top_k = filter_generation_chain.invoke({"query": query}).top_k + logging.info(f"Database will be filtered using: {filter}") + return {"filter": filter, "top_k": top_k, "query": query} + else: + return {"filter": None, "top_k": None, "query": query} + +def retrieve(state): + """ + Retrieve documents + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key added to state, documents, that contains retrieved documents + """ + logging.info("Retrieving documents...") + query = state["query"] + filter = state["filter"] + top_k = state["top_k"] + + # Retrieval + with ResourceManager() as RM: + collection = 
RM.client['metadata_vector_index']['LANGCHAIN_ALL_curated_assets'] + retriever = DocDBRetriever(collection = collection, k = top_k) + documents = retriever.get_relevant_documents(query = query, query_filter = filter) + return {"documents": documents, "query": query} + +def grade_documents(state): + """ + Determines whether the retrieved documents are relevant to the question. + + Args: + state (dict): The current graph state + + Returns: + state (dict): Updates documents key with only filtered relevant documents + """ + + logging.info("Checking relevance of documents to question asked...") + query = state["query"] + documents = state["documents"] + + # Score each doc + filtered_docs = [] + for doc in documents: + score = doc_grader.invoke({"query": query, "document": doc.page_content}) + grade = score.binary_score + logging.info(f"Retrieved document matched query: {grade}") + if grade == "yes": + logging.info("Document is relevant to the query") + filtered_docs.append(doc) + else: + logging.info("Document is not relevant and will be removed") + continue + return {"documents": filtered_docs, "query": query} + +def generate(state): + """ + Generate answer + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key added to state, generation, that contains LLM generation + """ + logging.info("Generating answer...") + query = state["query"] + documents = state["documents"] + + doc_text = "\n\n".join(doc.page_content for doc in documents) + + # RAG generation + generation = rag_chain.invoke({"documents": doc_text, "query": query}) + return {"documents": documents, "query": query, "generation": generation, "filter": state.get("filter", None)} + +workflow = StateGraph(GraphState) +workflow.add_node("database_query", generate_for_whole_db) +workflow.add_node("filter_generation", filter_generator) +workflow.add_node("retrieve", retrieve) +workflow.add_node("document_grading", grade_documents) +workflow.add_node("generate", generate) + +workflow.add_conditional_edges( + START, + route_question, + { + "direct_database": "database_query", + "vectorstore": "filter_generation", + }, +) +workflow.add_edge("filter_generation", "retrieve") +workflow.add_edge("retrieve", "document_grading") +workflow.add_edge("document_grading","generate") +workflow.add_edge("generate", END) + +memory = MemorySaver() +app = workflow.compile(checkpointer=memory) + +# query = "How old was the subject in SmartSPIM_675388_2023-05-24_04-10-19_stitched_2023-05-28_18-07-46" + +# inputs = {"query" : query} +# answer = app.invoke(inputs) +# print(answer['generation']) \ No newline at end of file diff --git a/src/metadata_chatbot/Metamorph.py b/src/metadata_chatbot/bedrock_model/Metamorph.py similarity index 93% rename from src/metadata_chatbot/Metamorph.py rename to src/metadata_chatbot/bedrock_model/Metamorph.py index e148325..d0786f7 100644 --- a/src/metadata_chatbot/Metamorph.py +++ b/src/metadata_chatbot/bedrock_model/Metamorph.py @@ -39,5 +39,5 @@ def _llm_type(self) -> str: if __name__ == '__main__': llm = Metamorph() - prompt = "Give me the count of genotypes in each modality in the database?" + prompt = "Give me the count of genotypes in the ecephys modality in the database." 
llm.invoke(prompt) diff --git a/src/metadata_chatbot/bedrock_model/__init__.py b/src/metadata_chatbot/bedrock_model/__init__.py new file mode 100644 index 0000000..a41897c --- /dev/null +++ b/src/metadata_chatbot/bedrock_model/__init__.py @@ -0,0 +1,3 @@ +"""Init package""" +__version__ = "0.0.12" + diff --git a/src/metadata_chatbot/chat.py b/src/metadata_chatbot/bedrock_model/chat.py similarity index 99% rename from src/metadata_chatbot/chat.py rename to src/metadata_chatbot/bedrock_model/chat.py index 7ca1f88..7113996 100644 --- a/src/metadata_chatbot/chat.py +++ b/src/metadata_chatbot/bedrock_model/chat.py @@ -267,6 +267,6 @@ def simple_chat(bedrock_client = bedrock, system_prompt = system_prompt): if __name__ == '__main__': #simple_chat(bedrock) - prompt = "How many experiments of each unique modality exists in the database?" + prompt = "What is the experimental history for subject 664956" response = get_completion(prompt, bedrock) print(response) \ No newline at end of file diff --git a/src/metadata_chatbot/config.py b/src/metadata_chatbot/bedrock_model/config.py similarity index 100% rename from src/metadata_chatbot/config.py rename to src/metadata_chatbot/bedrock_model/config.py diff --git a/src/metadata_chatbot/ref/acquisition_schema.json b/src/metadata_chatbot/bedrock_model/ref/acquisition_schema.json similarity index 100% rename from src/metadata_chatbot/ref/acquisition_schema.json rename to src/metadata_chatbot/bedrock_model/ref/acquisition_schema.json diff --git a/src/metadata_chatbot/ref/data_description_schema.json b/src/metadata_chatbot/bedrock_model/ref/data_description_schema.json similarity index 100% rename from src/metadata_chatbot/ref/data_description_schema.json rename to src/metadata_chatbot/bedrock_model/ref/data_description_schema.json diff --git a/src/metadata_chatbot/ref/instrument_schema.json b/src/metadata_chatbot/bedrock_model/ref/instrument_schema.json similarity index 100% rename from src/metadata_chatbot/ref/instrument_schema.json rename to src/metadata_chatbot/bedrock_model/ref/instrument_schema.json diff --git a/src/metadata_chatbot/ref/metadata.json b/src/metadata_chatbot/bedrock_model/ref/metadata.json similarity index 100% rename from src/metadata_chatbot/ref/metadata.json rename to src/metadata_chatbot/bedrock_model/ref/metadata.json diff --git a/src/metadata_chatbot/ref/procedures_schema.json b/src/metadata_chatbot/bedrock_model/ref/procedures_schema.json similarity index 100% rename from src/metadata_chatbot/ref/procedures_schema.json rename to src/metadata_chatbot/bedrock_model/ref/procedures_schema.json diff --git a/src/metadata_chatbot/ref/processing_schema.json b/src/metadata_chatbot/bedrock_model/ref/processing_schema.json similarity index 100% rename from src/metadata_chatbot/ref/processing_schema.json rename to src/metadata_chatbot/bedrock_model/ref/processing_schema.json diff --git a/src/metadata_chatbot/ref/rig_schema.json b/src/metadata_chatbot/bedrock_model/ref/rig_schema.json similarity index 100% rename from src/metadata_chatbot/ref/rig_schema.json rename to src/metadata_chatbot/bedrock_model/ref/rig_schema.json diff --git a/src/metadata_chatbot/ref/session_schema.json b/src/metadata_chatbot/bedrock_model/ref/session_schema.json similarity index 100% rename from src/metadata_chatbot/ref/session_schema.json rename to src/metadata_chatbot/bedrock_model/ref/session_schema.json diff --git a/src/metadata_chatbot/ref/subject_609281_metadata.json b/src/metadata_chatbot/bedrock_model/ref/subject_609281_metadata.json similarity index 100% rename 
from src/metadata_chatbot/ref/subject_609281_metadata.json rename to src/metadata_chatbot/bedrock_model/ref/subject_609281_metadata.json diff --git a/src/metadata_chatbot/ref/subject_schema.json b/src/metadata_chatbot/bedrock_model/ref/subject_schema.json similarity index 100% rename from src/metadata_chatbot/ref/subject_schema.json rename to src/metadata_chatbot/bedrock_model/ref/subject_schema.json diff --git a/src/metadata_chatbot/system_prompt.py b/src/metadata_chatbot/bedrock_model/system_prompt.py similarity index 91% rename from src/metadata_chatbot/system_prompt.py rename to src/metadata_chatbot/bedrock_model/system_prompt.py index e991cd6..2c16cd4 100644 --- a/src/metadata_chatbot/system_prompt.py +++ b/src/metadata_chatbot/bedrock_model/system_prompt.py @@ -1,12 +1,12 @@ -import os, json +import os, json, re, logging from pathlib import Path -import re cwd = os.path.dirname(os.path.realpath(__file__)) folder = Path(f"{cwd}\\ref") schema_types = [] + for name in os.listdir(folder): #loading in schema files f = open(f'{folder}\\{name}') @@ -19,10 +19,10 @@ metadata_schema = file system_prompt = f""" -You are a neuroscientist with extensive knowledge about processes involves in neuroscience research. -You are also an expert in crafting queries for MongoDB. +You are a neuroscientist with extensive knowledge about processes involved in neuroscience research. +You are also an expert in crafting queries and projections in MongoDB. -I will provide you with a list of schemas that contains information about the accepted inputs of variable names in a JSON file. +Here is a list of schemas that contains information about the structure of a JSON file. Each schema is provided in a specified format and each file corresponds to a different section of an experiment. List of schemas: {schema_types} @@ -34,9 +34,7 @@ You can use it as a guide to better structure your queries. Sample metadata: {sample_metadata} -Your task is to read the user's question, which will adhere to certain guidelines or formats. -You maybe prompted to determine missing information in the sample metadata. -You maybe prompted to retrieve information from an external database, the information will be stored in json files. +Your task is to read the user's question, which will adhere to certain guidelines or formats, and create a MongoDB query and projection to answer it. Here are some examples: Input: Give me the query to find subject's whose breeding group is Chat-IRES-Cre_Jax006410 @@ -74,7 +72,11 @@ For example, do not end your answer with: The query first projects to include only the `data_description.modality` field, then unwinds the modality array to get individual modality objects. It groups the documents by the modality name and counts them using the `$sum` accumulator. Finally, it projects to include only the modality name and count fields. The results show the count of each modality present in the database. -I want to see the actual summary of results retrieved, for example: +I want to see the actual summary of results retrieved and be straightforward in your answer. Each sentence produced should directly answer the question asked. +When asked about each modality or each type of something, provide examples for ALL modalities, do NOT say "...and so on for the other modalities present in the database" or any version of this phrase. +Provide a summary of the retrieved input, including numerical values. +When asked a question like how many experiments of each modality are there, I want to see an answer like this.
+For example: Optical Physiology: 40, Frame-projected independent-fiber photometry: 383, Behavior videos: 4213, Hyperspectral fiber photometry: 105, Extracellular electrophysiology: 2618, Electrophysiology: 12, Multiplane optical physiology: 13, Fiber photometry: 1761, Selective plane illumination microscopy: 3485, Planar optical physiology: 1330, Trained behavior: 32, None: 1481, Dual inverted selective plane illumination microscopy: 6, Behavior: 11016 @@ -84,7 +86,7 @@ Do not hallucinate. """ - +print(system_prompt) summary_system_prompt = f""" You are a neuroscientist with extensive knowledge about processes involves in neuroscience research. You are also an expert in crafting queries for MongoDB. diff --git a/src/metadata_chatbot/tools.py b/src/metadata_chatbot/bedrock_model/tools.py similarity index 100% rename from src/metadata_chatbot/tools.py rename to src/metadata_chatbot/bedrock_model/tools.py diff --git a/src/metadata_chatbot/embedding files/embeddings.py b/src/metadata_chatbot/embedding files/embeddings.py deleted file mode 100644 index 631ae4c..0000000 --- a/src/metadata_chatbot/embedding files/embeddings.py +++ /dev/null @@ -1,53 +0,0 @@ -import boto3, os -from langchain_community.document_loaders.mongodb import MongodbLoader -from aind_data_access_api.document_db_ssh import DocumentDbSSHClient, DocumentDbSSHCredentials -from pymongo import MongoClient -from urllib.parse import quote_plus - - -#establishing embedding model -model_id = "amazon.titan-embed-text-v2:0" - -bedrock = boto3.client( - service_name="bedrock-runtime", - region_name = 'us-west-2' -) - -print("hi") -from pymongo.errors import ConnectionFailure - - -connection_string = f"mongodb://{escaped_username}:{escaped_password}@localhost:27017/metadata_vector_index" - -print(connection_string) -def test_connection(connection_string): - try: - print("connecting") - client = MongoClient(connection_string, serverSelectionTimeoutMS=5000) - client.admin.command('ismaster') - print("MongoDB connection successful!") - return True - except ConnectionFailure: - print("MongoDB connection failed!") - return False - -test_connection(connection_string) - -# connecting to MongoDB - -db = 'metadata_vector_index' -collection = 'data_assets_dev' - -client = MongoClient(os.environ['CONNECTION_STRING']) -db = client['metadata_vector_index'] -collection = db['data_assets_dev'] - -loader = MongodbLoader( - connection_string = os.environ['CONNECTION_STRING'], - db_name = 'metadata_vector_index', - collection_name='data_assets_dev' -) - -docs = loader.load() - -len(docs) \ No newline at end of file diff --git a/src/metadata_chatbot/embedding files/trial.py b/src/metadata_chatbot/embedding files/trial.py deleted file mode 100644 index 910e1af..0000000 --- a/src/metadata_chatbot/embedding files/trial.py +++ /dev/null @@ -1,34 +0,0 @@ -from urllib.parse import quote_plus -import pymongo -from pymongo import MongoClient -from langchain_community.document_loaders.mongodb import MongodbLoader - -username = "sreya.kumar" -password = "Nimalja3min!23" -database_name = "metadata_vector_index" - - # Escape username and password to handle special characters -escaped_username = quote_plus(username) -escaped_password = quote_plus(password) - -connection_string = 'mongodb://localhost:27018/' - -try: - print(f"Attempting to connect with: {connection_string}") - client = MongoClient('mongodb://localhost:27018/', serverSelectionTimeoutMS=5000) - print("Initial connection successful") - - # Force a server check - server_info = client.server_info() - 
print(f"Server info: {server_info}") - - print("Connected successfully!") - -except pymongo.errors.ServerSelectionTimeoutError as e: - print(f"Server selection timeout error: {e}") - print(f"Current topology description: {client.topology_description}") -except Exception as e: - print(f"An error occurred: {e}") -finally: - if 'client' in locals(): - client.close() \ No newline at end of file diff --git a/src/metadata_chatbot/langchain_app.ipynb b/src/metadata_chatbot/langchain_app.ipynb deleted file mode 100644 index 14edcfa..0000000 --- a/src/metadata_chatbot/langchain_app.ipynb +++ /dev/null @@ -1,71 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "None\n" - ] - } - ], - "source": [ - "\n", - "import os\n", - "\n", - "print(os.getenv(\"LANGCHAIN_API_KEY\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from Metamorph import Metamorph\n", - "model = Metamorph()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_core.messages import HumanMessage, SystemMessage\n", - "\n", - "messages = [\n", - " SystemMessage(content=\"Translate the following from English into Italian\"),\n", - " HumanMessage(content=\"hi!\"),\n", - "]\n", - "\n", - "model.invoke(messages)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/src/metadata_chatbot/main.py b/src/metadata_chatbot/main.py index e8d3ef4..b25eaf9 100644 --- a/src/metadata_chatbot/main.py +++ b/src/metadata_chatbot/main.py @@ -1,12 +1,13 @@ from fastapi import FastAPI import uvicorn -from metadata_chatbot.chat import get_summary +from metadata_chatbot.bedrock_model.chat import get_summary app = FastAPI() @app.get("/summary/{_id}") def REST_summary(_id: str): - return get_summary(_id) + result = get_summary(_id) + return result if __name__ == "__main__": uvicorn.run(app, host="127.0.0.1", port=8000) \ No newline at end of file diff --git a/src/metadata_chatbot/utils.py b/src/metadata_chatbot/utils.py new file mode 100644 index 0000000..15a2975 --- /dev/null +++ b/src/metadata_chatbot/utils.py @@ -0,0 +1,93 @@ +from sshtunnel import SSHTunnelForwarder +import logging, os, boto3 +from urllib.parse import quote_plus +from langchain_community.vectorstores.documentdb import DocumentDBVectorSearch +from langchain_aws import BedrockEmbeddings +from pymongo import MongoClient +from motor.motor_asyncio import AsyncIOMotorClient + + +BEDROCK_CLIENT = boto3.client( + service_name="bedrock-runtime", + region_name = 'us-west-2' +) + +BEDROCK_EMBEDDINGS = BedrockEmbeddings(model_id="amazon.titan-embed-text-v2:0",client=BEDROCK_CLIENT) + +escaped_username = quote_plus(os.getenv("DOC_DB_USERNAME")) +escaped_password = quote_plus(os.getenv('DOC_DB_PASSWORD')) + +CONNECTION_STRING = f"mongodb://{escaped_username}:{escaped_password}@localhost:27017/?directConnection=true&authMechanism=SCRAM-SHA-1&retryWrites=false" + +CURATED_NAMESPACE = 'metadata_vector_index.LANGCHAIN_curated_assets' +CURATED_INDEX_NAME = 
"curated_embeddings_index" + +CURATED_VECTORSTORE = DocumentDBVectorSearch.from_connection_string( + connection_string=CONNECTION_STRING, + namespace=CURATED_NAMESPACE, + embedding=BEDROCK_EMBEDDINGS, + index_name=CURATED_INDEX_NAME + ) + +client = MongoClient(CONNECTION_STRING) +LANGCHAIN_COLLECTION = client['metadata_vector_index']['LANGCHAIN_ALL_curated_assets'] + +ALL_curated_namespace = 'metadata_vector_index.LANGCHAIN_ALL_curated_assets' +ALL_curated_index_name = "ALL_curated_embeddings_index" +ALL_CURATED_VECTORSTORE = DocumentDBVectorSearch.from_connection_string( + connection_string=CONNECTION_STRING, + namespace=ALL_curated_namespace, + embedding=BEDROCK_EMBEDDINGS, + index_name=ALL_curated_index_name + ) + + +def create_ssh_tunnel(): + """Create an SSH tunnel to the Document Database.""" + try: + return SSHTunnelForwarder( + ssh_address_or_host=( + os.getenv("DOC_DB_SSH_HOST"), + 22, + ), + ssh_username=os.getenv("DOC_DB_SSH_USERNAME"), + ssh_password=os.getenv("DOC_DB_SSH_PASSWORD"), + remote_bind_address=(os.getenv("DOC_DB_HOST"), 27017), + local_bind_address=( + "localhost", + 27017, + ), + ) + except Exception as e: + logging.error(f"Error creating SSH tunnel: {e}") + +class ResourceManager: + def __init__(self): + self.ssh_server = None + self.client = None + self.async_client = None + + def __enter__(self): + try: + self.ssh_server = create_ssh_tunnel() + self.ssh_server.start() + logging.info("SSH tunnel opened") + + self.client = MongoClient(CONNECTION_STRING) + self.async_client = AsyncIOMotorClient(CONNECTION_STRING) + logging.info("Successfully connected to MongoDB") + + return self + except Exception as e: + logging.exception(e) + self.__exit__(None, None, None) + raise + + def __exit__(self, exc_type, exc_value, traceback): + if self.client: + self.client.close() + if self.async_client: + self.async_client.close() + if self.ssh_server: + self.ssh_server.stop() + logging.info("Resources cleaned up") \ No newline at end of file