Commit: Add cohere model provider

arkadyark-cohere committed Mar 25, 2024
1 parent b74c060 commit 39ef809
Showing 4 changed files with 65 additions and 4 deletions.
3 changes: 2 additions & 1 deletion needlehaystack/providers/__init__.py
@@ -1,3 +1,4 @@
from .anthropic import Anthropic
from .cohere import Cohere
from .model import ModelProvider
from .openai import OpenAI
57 changes: 57 additions & 0 deletions needlehaystack/providers/cohere.py
@@ -0,0 +1,57 @@
import os

from typing import Optional

from cohere import Client, AsyncClient

from .model import ModelProvider

class Cohere(ModelProvider):
    DEFAULT_MODEL_KWARGS: dict = dict(max_tokens=50,
                                      temperature=0.3)

def __init__(self,
model_name: str = "command-r",
model_kwargs: dict = DEFAULT_MODEL_KWARGS):
"""
:param model_name: The name of the model. Default is 'command-r'.
:param model_kwargs: Model configuration. Default is {max_tokens_to_sample: 300, temperature: 0}
"""

api_key = os.getenv('NIAH_MODEL_API_KEY')
        if not api_key:
            raise ValueError("NIAH_MODEL_API_KEY must be set in the environment.")

self.model_name = model_name
self.model_kwargs = model_kwargs
self.api_key = api_key

self.client = AsyncClient(api_key=self.api_key)

    async def evaluate_model(self, prompt: list[dict[str, str]]) -> str:
        # The final entry is the user question; the preceding entries are chat history.
        response = await self.client.chat(message=prompt[-1]["message"],
                                          chat_history=prompt[:-1],
                                          model=self.model_name,
                                          **self.model_kwargs)
        return response.text

    def generate_prompt(self, context: str, retrieval_question: str) -> str | list[dict[str, str]]:
        # Build a chat-style prompt: system instructions, the haystack context,
        # then the retrieval question.
        return [
            {
                "role": "System",
                "message": "You are a helpful AI bot that answers questions for a user. Keep your response short and direct"
            },
            {
                "role": "User",
                "message": context
            },
            {
                "role": "User",
                "message": f"{retrieval_question} Don't give information outside the document or repeat your findings"
            }
        ]

    def encode_text_to_tokens(self, text: str) -> list[int]:
        if not text:
            return []
        # Tokenization uses a synchronous client; pass the provider's key explicitly.
        return Client(api_key=self.api_key).tokenize(text=text, model=self.model_name).tokens

    def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
        # Truncate to context_length tokens (None keeps the full sequence) before detokenizing.
        return Client(api_key=self.api_key).detokenize(tokens=tokens[:context_length], model=self.model_name).text
6 changes: 4 additions & 2 deletions needlehaystack/run.py
@@ -6,7 +6,7 @@

from . import LLMNeedleHaystackTester, LLMMultiNeedleHaystackTester
from .evaluators import Evaluator, LangSmithEvaluator, OpenAIEvaluator
from .providers import Anthropic, ModelProvider, OpenAI, Cohere

load_dotenv()

@@ -63,6 +63,8 @@ def get_model_to_test(args: CommandArgs) -> ModelProvider:
return OpenAI(model_name=args.model_name)
case "anthropic":
return Anthropic(model_name=args.model_name)
case "cohere":
return Cohere(model_name=args.model_name)
case _:
raise ValueError(f"Invalid provider: {args.provider}")

@@ -109,4 +111,4 @@ def main():
tester.start_test()

if __name__ == "__main__":
    main()
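To illustrate the new dispatch arm, a hypothetical call follows; CommandArgs fields other than provider and model_name are not shown in this diff, so the construction below is an assumption.

# Hypothetical: select the Cohere provider through the existing dispatcher.
args = CommandArgs(provider="cohere", model_name="command-r")
model = get_model_to_test(args)  # returns Cohere(model_name="command-r")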
3 changes: 2 additions & 1 deletion requirements.txt
@@ -6,6 +6,7 @@ anyio==3.7.1
attrs==23.1.0
certifi==2023.11.17
charset-normalizer==3.3.2
cohere>=5.0.0
dataclasses-json==0.6.3
distro==1.8.0
filelock==3.13.1
@@ -46,4 +47,4 @@ tqdm==4.66.1
typing-inspect==0.9.0
typing_extensions==4.8.0
urllib3==2.1.0
yarl==1.9.3
