Skip to content

Commit

Permalink
Added OpenAI embedding agent for relevance scoring only. (#316)
Browse files Browse the repository at this point in the history
  • Loading branch information
stuhlmueller authored Aug 12, 2023
2 parents 36188ad + b821d8a commit 62f2635
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ice/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ice.agents.human import HumanAgent
from ice.agents.openai import OpenAIAgent
from ice.agents.openai import OpenAIChatCompletionAgent
from ice.agents.openai import OpenAIEmbeddingAgent
from ice.agents.openai_reasoning import OpenAIReasoningAgent
from ice.agents.ought_inference import OughtInferenceAgent
from ice.agents.squad import SquadAgent
Expand All @@ -27,6 +28,7 @@ def __init__(self, *args, **kwargs):
MACHINE_AGENTS = {
"chatgpt": lambda: OpenAIChatCompletionAgent(model="gpt-3.5-turbo"),
"gpt-4": lambda: OpenAIChatCompletionAgent(model="gpt-4"),
"embedding-ada": lambda: OpenAIEmbeddingAgent(model="text-embedding-ada-002"),
"instruct": lambda: OpenAIAgent(),
"instruct-reasoning": lambda: OpenAIReasoningAgent(),
"instruct-reasoning-crowd": lambda: OpenAIReasoningAgent(num_workers=8),
Expand Down
51 changes: 51 additions & 0 deletions ice/agents/openai.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import math
from typing import Any
from typing import Optional
from typing import Union

import numpy as np
from structlog.stdlib import get_logger

from ice.agents.base import Agent
from ice.agents.base import Stop
from ice.apis.openai import openai_chatcomplete
from ice.apis.openai import openai_complete
from ice.apis.openai import openai_embedding
from ice.environment import env
from ice.utils import longest_common_prefix

Expand Down Expand Up @@ -227,3 +230,51 @@ def _extract_completion(self, response: dict) -> str:
def _print_markdown(self, obj: Any):
"""Print the text with markdown formatting."""
env().print(obj, format_markdown=True)


class OpenAIEmbeddingAgent(Agent):
"""An agent that uses the OpenAI API to generate a relevance score by cosine similarity between two text embeddings."""

def __init__(
self,
model: str = "text-embedding-ada-002",
):
self.model = model

async def relevance(
self,
*,
context: str,
question: str,
verbose: bool = False,
default: Optional[float] = None,
) -> float:
"""Generate a relevance score (cosine similarity) from a context and a question."""
if verbose:
self._print_markdown(context)
self._print_markdown(question)
context_embedding_response = await openai_embedding(context, model=self.model)
question_embedding_response = await openai_embedding(question, model=self.model)

context_embedding = self._extract_embedding(context_embedding_response)
question_embedding = self._extract_embedding(question_embedding_response)

relevance = self._cosine_similarity(context_embedding, question_embedding)

if verbose:
self._print_markdown(relevance)
return relevance

def _extract_embedding(self, response: dict) -> list:
"""Extract the embedding from the response."""
return response["data"][0]["embedding"]

def _cosine_similarity(
self, a: Union[list, np.ndarray], b: Union[list, np.ndarray]
) -> float:
"""Compute the cosine similarity between two vectors."""
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def _print_markdown(self, obj: Any):
"""Print the text with markdown formatting."""
env().print(obj, format_markdown=True)
18 changes: 18 additions & 0 deletions ice/apis/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,21 @@ async def openai_chatcomplete(
raise response
add_fields(total_tokens=extract_total_tokens(response))
return response


@trace
async def openai_embedding(
input: Union[str, list[str]],
model: str = "text-embedding-ada-002",
cache_id: int = 0, # for repeated non-deterministic sampling using caching
) -> dict:
"""Send an embedding request to the OpenAI API and return the JSON response."""
params = {
"input": input,
"model": model,
}
response = await _post("embeddings", json=params, cache_id=cache_id)
if isinstance(response, TooLongRequestError):
raise response
add_fields(total_tokens=extract_total_tokens(response))
return response

0 comments on commit 62f2635

Please sign in to comment.