Support for Together AI models (#130)
Added support for Together AI models for inference
MartBakler authored Feb 3, 2024
1 parent 21672ec commit 11d1b09
Showing 11 changed files with 264 additions and 2 deletions.
54 changes: 54 additions & 0 deletions docs/together_ai.md
@@ -0,0 +1,54 @@
# Together AI models

Tanuki now supports all models accessible through the Together AI API. Out of the box, the following hosted models are currently supported (more to be added soon):
* teknium/OpenHermes-2p5-Mistral-7B
* togethercomputer/llama-2-13b-chat
* openchat/openchat-3.5-1210
* NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO
* zero-one-ai/Yi-34B-Chat
* mistralai/Mistral-7B-Instruct-v0.2
* mistralai/Mixtral-8x7B-Instruct-v0.1


To use Together AI models, first install the Together AI extras package with `pip install tanuki.py[together_ai]`. Once the package is installed, the teacher model needs to be specified in the `@tanuki.patch` decorator, as shown in the examples section below.

**NB** Model distillation is currently turned off for Together AI models. Model alignment, inference, and saving datapoints to local datasets are still carried out as expected.
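
The Together AI provider reads its credentials from the `TOGETHER_API_KEY` environment variable and raises an error if no key is set. A minimal sketch of providing the key from Python (the key value is a placeholder; you can equally export the variable in your shell or load it from a `.env` file):

```python
import os

# Placeholder value -- replace with your own Together AI API key.
os.environ["TOGETHER_API_KEY"] = "your-together-ai-api-key"
```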

## Examples

### Using the mistralai/Mixtral-8x7B-Instruct-v0.1
```python
@tanuki.patch(teacher_models = ["Mixtral-8x7B"])
def example_function(input: TypedInput) -> TypedOutput:
    """(Optional) Include the description of how your function will be used."""


@tanuki.align
def test_example_function():
    assert example_function(example_typed_input) == example_typed_output

```

To use the other pre-implemented models, pass the corresponding configuration string to the `teacher_models` argument of the `@tanuki.patch` decorator (a short sketch follows the list below):
* To use teknium/OpenHermes-2p5-Mistral-7B, teacher_models = ["OpenHermes-2p5-Mistral"]
* To use togethercomputer/llama-2-13b-chat, teacher_models = ["llama13b-togetherai"]
* To use openchat/openchat-3.5-1210, teacher_models = ["openchat-3.5"]
* To use NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, teacher_models = ["Mixtral-8x7B-DPO"]
* To use zero-one-ai/Yi-34B-Chat, teacher_models = ["Yi-34B-Chat"]
* To use mistralai/Mistral-7B-Instruct-v0.2, teacher_models = ["Mistral-7B-Instruct-v0.2"]
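
For instance, a minimal sketch selecting one of the configuration strings above (the function name, input and output types are illustrative placeholders, as in the earlier example):

```python
@tanuki.patch(teacher_models = ["OpenHermes-2p5-Mistral"])
def example_function(input: TypedInput) -> TypedOutput:
    """(Optional) Include the description of how your function will be used."""
```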

### Using another Together AI model that is not in the pre-implemented model list
```python
from tanuki.language_models.llm_configs import TogetherAIConfig
model_config = TogetherAIConfig(model_name = "Open-Orca/Mistral-7B-OpenOrca", context_length = 8192)

@tanuki.patch(teacher_models = [model_config])
def example_function(input: TypedInput) -> TypedOutput:
    """(Optional) Include the description of how your function will be used."""


@tanuki.align
def test_example_function():
    assert example_function(example_typed_input) == example_typed_output

```
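
### Passing generation parameters

Generation parameters can also be passed to Together AI teacher models through the `generation_params` argument of `@tanuki.patch`. The Together AI provider currently recognises `temperature`, `top_p`, `max_new_tokens`, `frequency_penalty` and `presence_penalty`; any other parameters are ignored with a warning. A minimal sketch (the input and output types are illustrative placeholders):

```python
@tanuki.patch(teacher_models = ["Mixtral-8x7B"],
              generation_params = {"max_new_tokens": 256, "temperature": 0.1})
def example_function(input: TypedInput) -> TypedOutput:
    """(Optional) Include the description of how your function will be used."""
```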
6 changes: 6 additions & 0 deletions setup.py
@@ -31,6 +31,12 @@
            # Add any additional dependencies for the optional feature here
        ],
    },
    extras_require={
        'together_ai': [
            "together==0.2.11",
            # Add any additional dependencies for the optional feature here
        ],
    },
    classifiers=[
        'Development Status :: 3 - Alpha',
        # Choose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
1 change: 1 addition & 0 deletions src/tanuki/constants.py
@@ -37,6 +37,7 @@
OPENAI_PROVIDER = "openai"
LLAMA_BEDROCK_PROVIDER = "llama_bedrock"
TITAN_BEDROCK_PROVIDER = "aws_titan_bedrock"
TOGETHER_AI_PROVIDER = "together_ai"

# model type strings
TEACHER_MODEL = "teacher"
2 changes: 1 addition & 1 deletion src/tanuki/language_models/language_model_manager.py
@@ -201,7 +201,7 @@ def construct_prompt(self, f, args, kwargs, examples, model):
            example_input = ""

        instruction_prompt = model.instructions
        content = f"{instruction_prompt}\nFunction: {f}\n{example_input}---\nInputs:\nArgs: {args}\nKwargs: {kwargs}\nOutput:{model.parsing_helper_tokens['start_token']}"
        content = f"{instruction_prompt}\nFunction: {f}\n{example_input}---\nInputs:\nArgs: {args}\nKwargs: {kwargs}\nOutput:"
        return content

    def repair_generate(self, args, kwargs, f, failed_outputs_list, aligns, models, llm_parameters):
2 changes: 2 additions & 0 deletions src/tanuki/language_models/llama_bedrock_api.py
@@ -43,6 +43,8 @@ def generate(self, model: BaseModelConfig, system_message: str, prompt: str, **k
raise Exception("Chat prompt is not defined for this model"\
"Please define it in the model config")
final_prompt = chat_prompt.format(system_message=system_message, user_prompt=prompt)
if model.parsing_helper_tokens["start_token"]:
final_prompt += model.parsing_helper_tokens["start_token"]
body = json.dumps({
"prompt": final_prompt,
"max_gen_len": max_tokens_to_sample,
16 changes: 16 additions & 0 deletions src/tanuki/language_models/llm_configs/__init__.py
@@ -2,6 +2,7 @@
from tanuki.language_models.llm_configs.claude_config import ClaudeConfig
from tanuki.language_models.llm_configs.llama_config import LlamaBedrockConfig
from tanuki.language_models.llm_configs.titan_config import TitanBedrockConfig
from tanuki.language_models.llm_configs.togetherai_config import TogetherAIConfig
DEFAULT_TEACHER_MODELS = {
"gpt-4-1106-preview": OpenAIConfig(model_name = "gpt-4-1106-preview", context_length = 128000),
"gpt-4": OpenAIConfig(model_name = "gpt-4", context_length = 8192),
Expand All @@ -17,6 +18,21 @@
"anthropic.claude-v2:1": ClaudeConfig(model_name = "anthropic.claude-v2:1", context_length = 200000),
"llama_70b_chat_aws": LlamaBedrockConfig(model_name = "meta.llama2-70b-chat-v1", context_length = 4096),
"llama_13b_chat_aws": LlamaBedrockConfig(model_name = "meta.llama2-13b-chat-v1", context_length = 4096),
"Mixtral-8x7B": TogetherAIConfig(model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1",
chat_template = "{user_prompt}", # for some reason this worked better than using their own supplied chat template
context_length = 32768),
"OpenHermes-2p5-Mistral": TogetherAIConfig(model_name = "teknium/OpenHermes-2p5-Mistral-7B",
context_length = 4096),
"llama13b-togetherai": TogetherAIConfig(model_name = "togethercomputer/llama-2-13b-chat",
context_length = 4096),
"openchat-3.5": TogetherAIConfig(model_name = "openchat/openchat-3.5-1210",
context_length = 8192),
"Mixtral-8x7B-DPO": TogetherAIConfig(model_name = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
context_length = 32768),
"Yi-34B-Chat": TogetherAIConfig(model_name = "zero-one-ai/Yi-34B-Chat",
context_length = 4096),
"Mistral-7B-Instruct-v0.2": TogetherAIConfig(model_name = "mistralai/Mistral-7B-Instruct-v0.2",
context_length = 32768),
}

DEFAULT_STUDENT_MODELS = {
9 changes: 9 additions & 0 deletions src/tanuki/language_models/llm_configs/togetherai_config.py
@@ -0,0 +1,9 @@
from tanuki.language_models.llm_configs.abc_base_config import BaseModelConfig
from tanuki.constants import TOGETHER_AI_PROVIDER

class TogetherAIConfig(BaseModelConfig):
    model_name: str
    provider: str = TOGETHER_AI_PROVIDER
    context_length: int
    instructions: str = "You are given below a function description and input data. The function description of what the function must carry out can be found in the Function section, with input and output type hints. The input data can be found in Input section. Using the function description, apply the function to the Input and return a valid output type, that is acceptable by the output_class_definition and output_class_hint.\nINCREDIBLY IMPORTANT: Only output a JSON-compatible string in the correct response format. The outputs will be between |START| and |END| tokens, the |START| token will be given in the prompt, use the |END| token to specify when the output ends. Only return the output to this input."
    parsing_helper_tokens: dict = {"start_token": "|START|", "end_token": "|END|"}
2 changes: 2 additions & 0 deletions src/tanuki/language_models/openai_api.py
@@ -89,6 +89,8 @@ def generate(self, model, system_message, prompt, **kwargs):
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
}
if model.parsing_helper_tokens["start_token"]:
prompt += model.parsing_helper_tokens["start_token"]
messages = [
{
"role": "system",
118 changes: 118 additions & 0 deletions src/tanuki/language_models/togetherai_api.py
@@ -0,0 +1,118 @@
import logging
import os
import time

import requests
import together

# import abstract base class
from tanuki.language_models.llm_api_abc import LLM_API

TOGETHER_AI_URL = "https://api.together.xyz/inference"
LLM_GENERATION_PARAMETERS = ["temperature", "top_p", "max_new_tokens", "frequency_penalty", "presence_penalty"]

class TogetherAI_API(LLM_API):
    def __init__(self) -> None:
        # initialise the abstract base class
        super().__init__()

        self.api_key = os.environ.get("TOGETHER_API_KEY")
        self.model_configs = {}


    def generate(self, model, system_message, prompt, **kwargs):
        """
        The main generation function: given the model config, system message and prompt, generate a response.
        Args:
            model (TogetherAIConfig): The model to use for generation.
            system_message (str): The system message to use for generation.
            prompt (str): The prompt to use for generation.
            kwargs (dict): Additional generation parameters.
        """

        self.check_api_key()
        if model.model_name not in self.model_configs:
            self.model_configs[model.model_name] = together.Models.info(model.model_name)['config']
        temperature = kwargs.get("temperature", 0.1)
        top_p = kwargs.get("top_p", 1)
        frequency_penalty = kwargs.get("frequency_penalty", 0)
        presence_penalty = kwargs.get("presence_penalty", 0)
        max_new_tokens = kwargs.get("max_new_tokens")
        # check if there are any generation parameters that are not supported
        unsupported_params = [param for param in kwargs.keys() if param not in LLM_GENERATION_PARAMETERS]
        if len(unsupported_params) > 0:
            # log warning
            logging.warning(f"Unused generation parameters sent as input: {unsupported_params}. "
                            f"For Together AI, only the following parameters are supported: {LLM_GENERATION_PARAMETERS}")
        params = {
            "model": model.model_name,
            "temperature": temperature,
            "max_tokens": max_new_tokens,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty
        }
        if "stop" in self.model_configs[model.model_name]:
            params["stop"] = list(self.model_configs[model.model_name]["stop"])
        if model.parsing_helper_tokens["end_token"]:
            params["stop"] = model.parsing_helper_tokens["end_token"]
        chat_prompt = model.chat_template
        if chat_prompt is None:
            try:
                prompt_format = str(self.model_configs[model.model_name]['prompt_format'])
                final_prompt = prompt_format.format(system_message=system_message, prompt=prompt)
            except Exception:
                logging.warning("Chat prompt is not defined for this model. "
                                "Please define it in the model config. Using default chat prompt")
                chat_prompt = "[INST]{system_message}[/INST]\n{user_prompt}"
                final_prompt = chat_prompt.format(system_message=system_message, user_prompt=prompt)
        else:
            final_prompt = chat_prompt.format(system_message=system_message, user_prompt=prompt)
        if model.parsing_helper_tokens["start_token"]:
            final_prompt += model.parsing_helper_tokens["start_token"]
        params["prompt"] = final_prompt

        counter = 0
        choice = None
        # initialise response so the exception logic doesn't error out when checking for an error in the response
        response = {}
        # retry the request up to 5 times with exponential backoff before giving up
        while counter <= 5:
            try:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                }
                response = requests.post(
                    TOGETHER_AI_URL, headers=headers, json=params, timeout=50
                )
                response = response.json()
                choice = response["output"]["choices"][0]["text"].strip("'")
                break
            except Exception as e:
                if ("error" in response and
                        "code" in response["error"] and
                        response["error"]["code"] == 'invalid_api_key'):
                    raise Exception(f"The supplied Together AI API key {self.api_key} is invalid")
                if counter == 5:
                    raise Exception(f"Together AI API failed to generate a response: {e}")
                counter += 1
                time.sleep(2 ** counter)
                continue

        if not choice:
            raise Exception("TogetherAI API failed to generate a response")

        if model.parsing_helper_tokens["end_token"]:
            # remove the end token from the choice
            choice = choice.split(model.parsing_helper_tokens["end_token"])[0]
            # check if starting token is in choice
            if model.parsing_helper_tokens["start_token"] in choice:
                # remove the starting token from the choice
                choice = choice.split(model.parsing_helper_tokens["start_token"])[-1]
        return choice.strip()

    def check_api_key(self):
        # check if api key is not none
        if not self.api_key:
            # try to get the api key from the environment, maybe it has been set later
            self.api_key = os.getenv("TOGETHER_API_KEY")
            if not self.api_key:
                raise ValueError("TogetherAI API key is not set")
9 changes: 8 additions & 1 deletion src/tanuki/models/api_manager.py
@@ -1,6 +1,6 @@
import json
from typing import Any, Dict
from tanuki.constants import OPENAI_PROVIDER, LLAMA_BEDROCK_PROVIDER, TITAN_BEDROCK_PROVIDER
from tanuki.constants import OPENAI_PROVIDER, LLAMA_BEDROCK_PROVIDER, TITAN_BEDROCK_PROVIDER, TOGETHER_AI_PROVIDER


class APIManager(object):
@@ -50,6 +50,13 @@ def add_api_provider(self, provider):
            except ImportError:
                raise Exception("You need to install the Tanuki aws_bedrock package to use the titan_bedrock api provider. "\
                                "Please install it as pip install tanuki.py[aws_bedrock]")
        elif provider == TOGETHER_AI_PROVIDER:
            try:
                from tanuki.language_models.togetherai_api import TogetherAI_API
                self.api_providers[provider] = TogetherAI_API()
            except ImportError:
                raise Exception("You need to install the Tanuki together_ai package to use the Together AI api provider. "\
                                "Please install it as pip install tanuki.py[together_ai]")
        else:
            raise Exception(f"Model provider {provider} is currently not supported. "\
                            "If you have integrated a new provider, please add it to the api manager in the APIManager object "\
47 changes: 47 additions & 0 deletions tests/test_patch/test_classification_togetherai.py
@@ -0,0 +1,47 @@
from typing import Optional, Literal
from dotenv import load_dotenv
import tanuki

load_dotenv()

@tanuki.patch(teacher_models=["openchat-3.5"], generation_params={"max_new_tokens": 10})
def classify_sentiment_2(input: str, input_2: str) -> Optional[Literal['Good', 'Bad']]:
"""
Determine if the inputs are positive or negative sentiment, or None
"""


@tanuki.patch(teacher_models=["openchat-3.5"])
def classify_sentiment(input: str) -> Optional[Literal['Good', 'Bad']]:
"""
Determine if the input is positive or negative sentiment
"""

@tanuki.align
def align_classify_sentiment():
"""We can test the function as normal using Pytest or Unittest"""

i_love_you = "I love you"
assert classify_sentiment_2(i_love_you, "I love woo") == 'Good'
assert classify_sentiment_2("I hate you", "You're discusting") == 'Bad'
assert classify_sentiment_2("Today is wednesday", "The dogs are running outside") == None


assert classify_sentiment("I love you") == 'Good'
assert classify_sentiment("I hate you") == 'Bad'
assert classify_sentiment("Wednesdays are in the middle of the week") == None

def test_classify_sentiment():
align_classify_sentiment()
bad_input = "I find you awful"
good_input = "I really really like you"
good_input_2 = "I adore you"
assert classify_sentiment("I like you") == 'Good'
assert classify_sentiment(bad_input) == 'Bad'
assert classify_sentiment("I am neutral") == None

assert classify_sentiment_2(good_input, good_input_2) == 'Good'
assert classify_sentiment_2("I do not like you you", bad_input) == 'Bad'
assert classify_sentiment_2("I am neutral", "I am neutral too") == None

test_classify_sentiment()
