Add DocQA Tool (#330)
* added embeddings for stella

* add cached embeddings

* add embedding to module

* import fixes

* fix imports

* fix category for query for tools

* updated doc qa tool docs

* added doc qa category, fixed prompts for tools

* fixed mypy flake8

* updated example

* updated docs
dillonalaird authored Dec 19, 2024
1 parent 4c4d4e2 commit 8dce01a
Showing 14 changed files with 121 additions and 79 deletions.
12 changes: 4 additions & 8 deletions README.md
@@ -44,18 +44,15 @@ To get started with the python library, you can install it using pip:
pip install vision-agent
```

Ensure you have both an Anthropic API key and an OpenAI API key set in your environment
variables (if you are using Azure OpenAI, please see the Azure setup section):

```bash
export ANTHROPIC_API_KEY="your-api-key"
export OPENAI_API_KEY="your-api-key"
```

---
**NOTE**
You must have both Anthropic and OpenAI API keys set in your environment variables to
use VisionAgent. If you don't have an Anthropic key, you can use Ollama as a backend.
You must have the Anthropic API key set in your environment variables to use
VisionAgent. If you don't have an Anthropic key, you can use another provider like
OpenAI or Ollama.
---

#### Chatting with VisionAgent
@@ -116,8 +113,7 @@ Anthropic/OpenAI models.
### Chatting and Message Formats
`VisionAgent` is an agent that can chat with you and call other tools or agents to
write vision code for you. You can interact with it like you would ChatGPT or any other
chatbot. The agent uses Claude-3.5 for its LMM and OpenAI for embeddings for searching
for tools.
chatbot. The agent uses Claude-3.5 for its LMM.

The message format is:
```json
25 changes: 4 additions & 21 deletions docs/index.md
@@ -1,16 +1,3 @@
<div align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://github.com/landing-ai/vision-agent/blob/main/assets/logo_light.svg?raw=true">
<source media="(prefers-color-scheme: light)" srcset="https://github.com/landing-ai/vision-agent/blob/main/assets/logo_dark.svg?raw=true">
<img alt="VisionAgent" height="200px" src="https://github.com/landing-ai/vision-agent/blob/main/assets/logo_light.svg?raw=true">
</picture>

[![](https://dcbadge.vercel.app/api/server/wPdN8RCYew?compact=true&style=flat)](https://discord.gg/wPdN8RCYew)
![ci_status](https://github.com/landing-ai/vision-agent/actions/workflows/ci_cd.yml/badge.svg)
[![PyPI version](https://badge.fury.io/py/vision-agent.svg)](https://badge.fury.io/py/vision-agent)
![version](https://img.shields.io/pypi/pyversions/vision-agent)
</div>

VisionAgent is a library that helps you utilize agent frameworks to generate code to
solve your vision tasks. Check out our Discord for updates and roadmaps!

@@ -44,18 +44,15 @@ To get started with the python library, you can install it using pip:
pip install vision-agent
```

Ensure you have both an Anthropic API key and an OpenAI API key set in your environment
variables (if you are using Azure OpenAI, please see the Azure setup section):

```bash
export ANTHROPIC_API_KEY="your-api-key"
export OPENAI_API_KEY="your-api-key"
```

---
**NOTE**
You must have both Anthropic and OpenAI API keys set in your environment variables to
use VisionAgent. If you don't have an Anthropic key, you can use Ollama as a backend.
You must have the Anthropic API key set in your environment variables to use
VisionAgent. If you don't have an Anthropic key, you can use another provider like
OpenAI or Ollama.
---

#### Chatting with VisionAgent
@@ -116,8 +100,7 @@ Anthropic/OpenAI models.
### Chatting and Message Formats
`VisionAgent` is an agent that can chat with you and call other tools or agents to
write vision code for you. You can interact with it like you would ChatGPT or any other
chatbot. The agent uses Claude-3.5 for its LMM and OpenAI for embeddings for searching
for tools.
chatbot. The agent uses Claude-3.5 for its LMM.

The message format is:
```json
16 changes: 5 additions & 11 deletions examples/notebooks/counting_cans.ipynb
@@ -132,12 +132,7 @@
"source": [
"## Prerequisite\n",
"\n",
"In order to run below example, you will need to provide below API keys from your OpenAI and Anthropic account.\n",
"\n",
"1. OPENAI_API_KEY\n",
"2. ANTHROPIC_API_KEY\n",
"\n",
"Supply your keys in below cell. We will set them as environment variables so VisionAgent can load them later."
"In order to run below example, you will need to provide an Anthropic API key from your Anthropic account. Supply your key in below cell. We will set it as an environment variable so VisionAgent can load it later."
]
},
{
@@ -148,8 +143,7 @@
"source": [
"import os\n",
"\n",
"# TODO: fill below with your API keys\n",
"os.environ[\"OPENAI_API_KEY\"] = \"YOUR_OPENAI_API_KEY\"\n",
"# TODO: fill below with your API key\n",
"os.environ[\"ANTHROPIC_API_KEY\"] = \"YOUR_ANTHROPIC_API_KEY\""
]
},
@@ -2679,7 +2673,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -2693,9 +2687,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
24 changes: 21 additions & 3 deletions vision_agent/.sim_tools/df.csv
@@ -460,19 +460,37 @@ desc,doc,name
-------
>>> document_analysis(image)
{'pages':
[{'bbox': [0, 0, 1700, 2200],
'chunks': [{'bbox': [1371, 75, 1503, 112],
[{'bbox': [0, 0, 1.0, 1.0],
'chunks': [{'bbox': [0.8, 0.1, 1.0, 0.2],
'label': 'page_header',
'order': 75,
'caption': 'Annual Report 2024',
'summary': 'This annual report summarizes ...' },
{'bbox': [201, 1119, 1497, 1647],
{'bbox': [0.2, 0.9, 0.9, 1.0],
'label': 'table',
'order': 1119,
'caption': [{'Column 1': 'Value 1', 'Column 2': 'Value 2'}],
'summary': 'This table illustrates a trend of ...'},
],
",document_extraction
"'document_qa' is a tool that can answer any questions about arbitrary documents, presentations, or tables. It's very useful for document QA tasks, you can ask it a specific question or ask it to return a JSON object answering multiple questions about the document.","document_qa(prompt: str, image: numpy.ndarray) -> str:
'document_qa' is a tool that can answer any questions about arbitrary documents,
presentations, or tables. It's very useful for document QA tasks, you can ask it a
specific question or ask it to return a JSON object answering multiple questions
about the document.

Parameters:
prompt (str): The question to be answered about the document image.
image (np.ndarray): The document image to analyze.

Returns:
str: The answer to the question based on the document's context.

Example
-------
>>> document_qa(question, image)
'The answer to the question ...'
",document_qa
'video_temporal_localization' will run qwen2vl on each chunk_length_frames value selected for the video. It can detect multiple objects independently per chunk_length_frames given a text prompt such as a referring expression but does not track objects across frames. It returns a list of floats with a value of 1.0 if the objects are found in a given chunk_length_frames of the video.,"video_temporal_localization(prompt: str, frames: List[numpy.ndarray], model: str = 'qwen2vl', chunk_length_frames: Optional[int] = 2) -> List[float]:
'video_temporal_localization' will run qwen2vl on each chunk_length_frames
value selected for the video. It can detect multiple objects independently per
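To make the new tool's interface concrete, here is a minimal usage sketch (not part of the commit). It assumes `document_qa` and `load_image` are importable from `vision_agent.tools.tools`, that a local file named `invoice.png` exists, and that credentials for the hosted document-analysis endpoint and Anthropic are configured.

```python
# Minimal sketch of calling the new document_qa tool (assumes valid API keys
# and a local document image named invoice.png).
from vision_agent.tools.tools import document_qa, load_image

image = load_image("invoice.png")  # hypothetical document image
answer = document_qa("What is the total amount due?", image)
print(answer)
```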
Binary file modified vision_agent/.sim_tools/embs.npy
Binary file not shown.
6 changes: 3 additions & 3 deletions vision_agent/agent/vision_agent_coder_v2.py
@@ -5,7 +5,7 @@
from rich.console import Console
from rich.markup import escape

import vision_agent.tools as T
import vision_agent.tools.tools as T
from vision_agent.agent import AgentCoder, AgentPlanner
from vision_agent.agent.agent_utils import (
DefaultImports,
@@ -34,7 +34,7 @@
CodeInterpreterFactory,
Execution,
)
from vision_agent.utils.sim import Sim
from vision_agent.utils.sim import Sim, get_tool_recommender

_CONSOLE = Console()

@@ -316,7 +316,7 @@ def __init__(
elif isinstance(tool_recommender, Sim):
self.tool_recommender = tool_recommender
else:
self.tool_recommender = T.get_tool_recommender()
self.tool_recommender = get_tool_recommender()

self.verbose = verbose
self.code_sandbox_runtime = code_sandbox_runtime
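For context, the hunk above changes where the default tool recommender comes from. A rough sketch of the fallback behavior this touches, under the assumption that `get_tool_recommender()` returns a cached `Sim` built from the bundled `df.csv` (the string-path branch shown elsewhere in `__init__` is omitted here):

```python
# Sketch of the resolution logic affected by this change: a caller-supplied Sim
# is used as-is, otherwise the cached recommender from vision_agent.utils.sim is used.
from typing import Optional

from vision_agent.utils.sim import Sim, get_tool_recommender


def resolve_tool_recommender(tool_recommender: Optional[Sim] = None) -> Sim:
    if isinstance(tool_recommender, Sim):
        return tool_recommender
    return get_tool_recommender()  # assumed to be a cached Sim over df.csv
```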
7 changes: 4 additions & 3 deletions vision_agent/agent/vision_agent_planner_prompts_v2.py
@@ -440,16 +440,17 @@ def merge_bounding_box_list(bboxes):
"""

CATEGORIZE_TOOL_REQUEST = """
You are given a task: {task} from the user. Your task is to extract the type of category this task belongs to; it can be one or more of the following:
You are given a task: "{task}" from the user. You must extract the type of category this task belongs to; it can be one or more of the following:
- "object detection and counting" - detecting objects or counting objects from a text prompt in an image or video.
- "classification" - classifying objects in an image given a text prompt.
- "segmentation" - segmenting objects in an image or video given a text prompt.
- "OCR" - extracting text from an image.
- "VQA" - answering questions about an image or video, can also be used for text extraction.
- "DocQA" - answering questions about a document or extracting information from a document.
- "video object tracking" - tracking objects in a video.
- "depth and pose estimation" - estimating the depth or pose of objects in an image.
Return the category or categories (comma separated) inside tags <category># your categories here</category>.
Return the category or categories (comma separated) inside tags <category># your categories here</category>. If you are unsure about a task, it is better to include more categories rather than fewer.
"""

TEST_TOOLS = """
Expand All @@ -473,7 +474,7 @@ def merge_bounding_box_list(bboxes):
{examples}
**Instructions**:
1. List all the tools under **Tools** and the user request. Write a program to load the media and call every tool in parallel and print its output along with other relevant information.
1. List all the tools under **Tools** and the user request. Write a program to load the media and call the most relevant tools in parallel and print their outputs along with other relevant information.
2. Create a dictionary where the keys are the tool name and the values are the tool outputs. Remove numpy arrays from the printed dictionary.
3. Your test case MUST run only on the given images which are {media}
4. Print this final dictionary.
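As an illustration of the contract this prompt establishes (not part of the commit), the planner expects the model's reply to wrap its answer in a `<category>` tag, which is then parsed out, e.g. by a helper like the project's `extract_tag`. The response string below is a made-up example.

```python
# Illustrative parsing of the <category> tag the prompt asks the model to emit.
response = "The task involves a scanned form.\n<category>DocQA, OCR</category>"
start = response.index("<category>") + len("<category>")
end = response.index("</category>")
categories = [c.strip() for c in response[start:end].split(",")]
print(categories)  # ['DocQA', 'OCR']
```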
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
@@ -43,7 +43,6 @@
flux_image_inpainting,
generate_pose_image,
get_tool_documentation,
get_tool_recommender,
gpt4o_image_vqa,
gpt4o_video_vqa,
load_image,
@@ -63,6 +62,7 @@
save_json,
save_video,
siglip_classification,
stella_embeddings,
template_match,
video_temporal_localization,
vit_image_classification,
9 changes: 4 additions & 5 deletions vision_agent/tools/planner_tools.py
@@ -32,6 +32,7 @@
MimeType,
)
from vision_agent.utils.image_utils import convert_to_b64
from vision_agent.utils.sim import get_tool_recommender

TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}

@@ -116,13 +117,11 @@ def run_tool_testing(
query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
category = extract_tag(query, "category") # type: ignore
if category is None:
category = task
query = task
else:
category = (
f"I need models from the {category.strip()} category of tools. {task}"
)
query = f"{category.strip()}. {task}"

tool_docs = T.get_tool_recommender().top_k(category, k=10, thresh=0.2)
tool_docs = get_tool_recommender().top_k(query, k=5, thresh=0.3)
if exclude_tools is not None and len(exclude_tools) > 0:
cleaned_tool_docs = []
for tool_doc in tool_docs:
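A hypothetical illustration of the retrieval step after this change: the LMM-produced category is prepended to the task, and the cached recommender returns the five closest tool rows above the 0.3 similarity threshold. The `"doc"` key on the returned rows is an assumption based on the `desc,doc,name` columns of `df.csv`.

```python
# Sketch of the new tool-retrieval query (assumes API keys and the bundled
# df.csv / embs.npy are available so the cached Sim can be loaded).
from vision_agent.utils.sim import get_tool_recommender

task = "answer questions about a scanned invoice"
query = f"DocQA. {task}"  # category prefix mirrors the new query construction
tool_docs = get_tool_recommender().top_k(query, k=5, thresh=0.3)
for doc in tool_docs:
    print(doc["doc"].splitlines()[0])  # first line of each recommended tool's docstring
```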
45 changes: 28 additions & 17 deletions vision_agent/tools/tools.py
@@ -7,7 +7,6 @@
from base64 import b64encode
from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
from functools import lru_cache
from importlib import resources
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
@@ -49,7 +48,6 @@
rle_decode,
rle_decode_array,
)
from vision_agent.utils.sim import Sim, load_cached_sim
from vision_agent.utils.video import (
extract_frames_from_video,
frames_to_bytes,
@@ -85,11 +83,6 @@
_LOGGER = logging.getLogger(__name__)


@lru_cache(maxsize=1)
def get_tool_recommender() -> Sim:
return load_cached_sim(TOOLS_DF)


def _display_tool_trace(
function_name: str,
request: Dict[str, Any],
@@ -2178,13 +2171,14 @@ def document_qa(
prompt: str,
image: np.ndarray,
) -> str:
"""'document_qa' is a tool that can answer any questions about arbitrary
images of documents or presentations. It answers by analyzing the contextual document data
and then using a model to answer specific questions. It returns text as an answer to the question.
"""'document_qa' is a tool that can answer any questions about arbitrary documents,
presentations, or tables. It's very useful for document QA tasks, you can ask it a
specific question or ask it to return a JSON object answering multiple questions
about the document.
Parameters:
prompt (str): The question to be answered about the document image
image (np.ndarray): The document image to analyze
prompt (str): The question to be answered about the document image.
image (np.ndarray): The document image to analyze.
Returns:
str: The answer to the question based on the document's context.
@@ -2203,7 +2197,7 @@ def document_qa(
"model": "document-analysis",
}

data: dict[str, Any] = send_inference_request(
data: Dict[str, Any] = send_inference_request(
payload=payload,
endpoint_name="document-analysis",
files=files,
Expand All @@ -2225,10 +2219,10 @@ def normalize(data: Any) -> Dict[str, Any]:
data = normalize(data)

prompt = f"""
Document Context:
{data}\n
Question: {prompt}\n
Please provide a clear, concise answer using only the information from the document. If the answer is not definitively contained in the document, say "I cannot find the answer in the provided document."
Document Context:
{data}\n
Question: {prompt}\n
Answer the question directly using only the information from the document, do not answer with any additional text besides the answer. If the answer is not definitively contained in the document, say "I cannot find the answer in the provided document."
"""

lmm = AnthropicLMM()
@@ -2245,6 +2239,22 @@ def normalize(data: Any) -> Dict[str, Any]:
return llm_output


def stella_embeddings(prompts: List[str]) -> List[np.ndarray]:
payload = {
"input": prompts,
"model": "stella1.5b",
}

data: Dict[str, Any] = send_inference_request(
payload=payload,
endpoint_name="embeddings",
v2=True,
metadata_payload={"function_name": "get_embeddings"},
is_form=True,
)
return [d["embedding"] for d in data] # type: ignore


# Utility and visualization functions


@@ -2781,6 +2791,7 @@ def _plot_counting(
qwen2_vl_images_vqa,
qwen2_vl_video_vqa,
document_extraction,
document_qa,
video_temporal_localization,
flux_image_inpainting,
siglip_classification,
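Finally, a small sketch of how the new `stella_embeddings` helper added above might be used (not from the commit). It calls the hosted "stella1.5b" embedding endpoint, so valid credentials for the inference service are assumed; the cosine-similarity math is only for illustration.

```python
# Hypothetical usage of the new stella_embeddings tool: embed two prompts and
# compare them with cosine similarity.
import numpy as np

from vision_agent.tools import stella_embeddings

embs = stella_embeddings([
    "answer questions about a scanned document",
    "count the cans on a shelf",
])
a, b = np.asarray(embs[0], dtype=float), np.asarray(embs[1], dtype=float)
print(float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b))))
```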
1 change: 0 additions & 1 deletion vision_agent/utils/__init__.py
@@ -7,4 +7,3 @@
Result,
)
from .sim import AzureSim, OllamaSim, Sim, load_sim, merge_sim
from .video import extract_frames_from_video, video_writer
4 changes: 2 additions & 2 deletions vision_agent/utils/execute.py
@@ -28,10 +28,10 @@
from nbclient.exceptions import CellTimeoutError, DeadKernelError
from nbclient.util import run_sync
from nbformat.v4 import new_code_cell
from opentelemetry.context import get_current
from opentelemetry.trace import SpanKind, Status, StatusCode, get_tracer
from pydantic import BaseModel, field_serializer
from typing_extensions import Self
from opentelemetry.trace import get_tracer, Status, StatusCode, SpanKind
from opentelemetry.context import get_current

from vision_agent.utils.exceptions import (
RemoteSandboxCreationError,
2 changes: 1 addition & 1 deletion vision_agent/utils/image_utils.py
@@ -11,7 +11,7 @@
from PIL import Image, ImageDraw, ImageFont
from PIL.Image import Image as ImageType

from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.video import extract_frames_from_video

COLORS = [
(158, 218, 229),