Support for Together AI models (#130)
Added support for Together AI models for inference
1 parent 21672ec · commit 11d1b09
Showing 11 changed files with 264 additions and 2 deletions.
@@ -0,0 +1,54 @@
# Together AI models

Tanuki now supports all models accessible through the Together AI API. The following hosted models are currently supported out of the box (more to be added soon):
* teknium/OpenHermes-2p5-Mistral-7B
* togethercomputer/llama-2-13b-chat
* openchat/openchat-3.5-1210
* NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO
* zero-one-ai/Yi-34B-Chat
* mistralai/Mistral-7B-Instruct-v0.2
* mistralai/Mixtral-8x7B-Instruct-v0.1

To use Together AI models, first install the Together AI extra package with `pip install tanuki.py[together_ai]` and make sure a valid API key is available in the `TOGETHER_API_KEY` environment variable. Once the package is installed, pass a configuration flag for the teacher model to the `@tanuki.patch` decorator, as shown in the Examples section below.

**NB** Model distillation is currently turned off for Together AI models. Model alignment, inference and saving datapoints to local datasets are still carried out as expected.

## Examples

### Using mistralai/Mixtral-8x7B-Instruct-v0.1
```python
@tanuki.patch(teacher_models = ["Mixtral-8x7B"])
def example_function(input: TypedInput) -> TypedOutput:
    """(Optional) Include the description of how your function will be used."""

@tanuki.align
def test_example_function():

    assert example_function(example_typed_input) == example_typed_output

```

To use the other pre-implemented models, pass the following value to the `teacher_models` argument of the `@tanuki.patch` decorator (see the sketch after this list):
* To use teknium/OpenHermes-2p5-Mistral-7B, `teacher_models = ["OpenHermes-2p5-Mistral"]`
* To use togethercomputer/llama-2-13b-chat, `teacher_models = ["llama13b-togetherai"]`
* To use openchat/openchat-3.5-1210, `teacher_models = ["openchat-3.5"]`
* To use NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, `teacher_models = ["Mixtral-8x7B-DPO"]`
* To use zero-one-ai/Yi-34B-Chat, `teacher_models = ["Yi-34B-Chat"]`
* To use mistralai/Mistral-7B-Instruct-v0.2, `teacher_models = ["Mistral-7B-Instruct-v0.2"]`
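
For instance, a sketch of wiring up the hosted llama-2-13b-chat model through its `llama13b-togetherai` alias. The function name, type hints and align assertions below are placeholders for illustration, not part of this commit:

```python
from typing import Optional, Literal
import tanuki

@tanuki.patch(teacher_models=["llama13b-togetherai"])  # alias for togethercomputer/llama-2-13b-chat
def is_spam(message: str) -> Optional[Literal['Spam', 'Not spam']]:
    """Classify whether the message is spam."""

@tanuki.align
def align_is_spam():
    # placeholder alignment pairs, shown only to illustrate the wiring
    assert is_spam("WIN A FREE CRUISE!!! Click here now") == 'Spam'
    assert is_spam("Are we still on for lunch tomorrow?") == 'Not spam'
```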

### Using another Together AI model that is not in the pre-implemented model list
```python
from tanuki.language_models.llm_configs import TogetherAIConfig
model_config = TogetherAIConfig(model_name = "Open-Orca/Mistral-7B-OpenOrca", context_length = 8192)

@tanuki.patch(teacher_models = [model_config])
def example_function(input: TypedInput) -> TypedOutput:
    """(Optional) Include the description of how your function will be used."""

@tanuki.align
def test_example_function():

    assert example_function(example_typed_input) == example_typed_output

```
@@ -0,0 +1,9 @@
from tanuki.language_models.llm_configs.abc_base_config import BaseModelConfig
from tanuki.constants import TOGETHER_AI_PROVIDER

class TogetherAIConfig(BaseModelConfig):
    model_name: str
    provider: str = TOGETHER_AI_PROVIDER
    context_length: int
    instructions: str = "You are given below a function description and input data. The function description of what the function must carry out can be found in the Function section, with input and output type hints. The input data can be found in Input section. Using the function description, apply the function to the Input and return a valid output type, that is acceptable by the output_class_definition and output_class_hint.\nINCREDIBLY IMPORTANT: Only output a JSON-compatible string in the correct response format. The outputs will be between |START| and |END| tokens, the |START| token will be given in the prompt, use the |END| token to specify when the output ends. Only return the output to this input."
    parsing_helper_tokens: dict = {"start_token": "|START|", "end_token": "|END|"}
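
The `instructions` prompt asks the model to wrap its output between the `|START|` and `|END|` parsing helper tokens, and the API class below strips them back out. As a rough illustration of that contract (a simplified sketch with a made-up `raw_completion` string, not the library's actual parsing code):

```python
# simplified sketch of how the parsing helper tokens are consumed;
# the real logic lives in TogetherAI_API.generate further down in this commit
raw_completion = '|START| {"sentiment": "Good"} |END|'
start_token, end_token = "|START|", "|END|"

# keep only what comes before the end token ...
output = raw_completion.split(end_token)[0]
# ... and what comes after the start token, if it is present
if start_token in output:
    output = output.split(start_token)[-1]
print(output.strip())  # {"sentiment": "Good"}
```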
@@ -0,0 +1,118 @@
import logging
import os
import time

import requests
import together

# import abstract base class
from tanuki.language_models.llm_api_abc import LLM_API

TOGETHER_AI_URL = "https://api.together.xyz/inference"
LLM_GENERATION_PARAMETERS = ["temperature", "top_p", "max_new_tokens", "frequency_penalty", "presence_penalty"]


class TogetherAI_API(LLM_API):
    def __init__(self) -> None:
        # initialise the abstract base class
        super().__init__()

        self.api_key = os.environ.get("TOGETHER_API_KEY")
        self.model_configs = {}

    def generate(self, model, system_message, prompt, **kwargs):
        """
        The main generation function: given the model config, system message, prompt and generation kwargs, generate a response from the Together AI API.
        Args
            model (TogetherAIConfig): The model to use for generation.
            system_message (str): The system message to use for generation.
            prompt (str): The prompt to use for generation.
            kwargs (dict): Additional generation parameters.
        """

        self.check_api_key()
        if model.model_name not in self.model_configs:
            self.model_configs[model.model_name] = together.Models.info(model.model_name)['config']
        temperature = kwargs.get("temperature", 0.1)
        top_p = kwargs.get("top_p", 1)
        frequency_penalty = kwargs.get("frequency_penalty", 0)
        presence_penalty = kwargs.get("presence_penalty", 0)
        max_new_tokens = kwargs.get("max_new_tokens")
        # check if there are any generation parameters that are not supported
        unsupported_params = [param for param in kwargs.keys() if param not in LLM_GENERATION_PARAMETERS]
        if len(unsupported_params) > 0:
            # log warning
            logging.warning(f"Unused generation parameters sent as input: {unsupported_params}. "
                            f"For Together AI, only the following parameters are supported: {LLM_GENERATION_PARAMETERS}")
        params = {
            "model": model.model_name,
            "temperature": temperature,
            "max_tokens": max_new_tokens,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty
        }
        # use the model's configured stop tokens; the parsing end token takes precedence when set
        if "stop" in self.model_configs[model.model_name]:
            params["stop"] = list(self.model_configs[model.model_name]["stop"])
        if model.parsing_helper_tokens["end_token"]:
            params["stop"] = model.parsing_helper_tokens["end_token"]
        chat_prompt = model.chat_template
        if chat_prompt is None:
            try:
                # fall back to the prompt format published in the Together AI model config
                prompt_format = str(self.model_configs[model.model_name]['prompt_format'])
                final_prompt = prompt_format.format(system_message=system_message, prompt=prompt)
            except Exception:
                logging.warning("Chat prompt is not defined for this model. "
                                "Please define it in the model config. Using default chat prompt")
                chat_prompt = "[INST]{system_message}[/INST]\n{user_prompt}"
                final_prompt = chat_prompt.format(system_message=system_message, user_prompt=prompt)
        else:
            final_prompt = chat_prompt.format(system_message=system_message, user_prompt=prompt)
        if model.parsing_helper_tokens["start_token"]:
            final_prompt += model.parsing_helper_tokens["start_token"]
        params["prompt"] = final_prompt

        counter = 0
        choice = None
        # initialise response so the exception logic doesn't error out when checking for an error in the response
        response = {}
        while counter <= 5:
            try:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                }
                response = requests.post(
                    TOGETHER_AI_URL, headers=headers, json=params, timeout=50
                )
                response = response.json()
                choice = response["output"]["choices"][0]["text"].strip("'")
                break
            except Exception as e:
                if ("error" in response and
                        "code" in response["error"] and
                        response["error"]["code"] == 'invalid_api_key'):
                    raise Exception(f"The supplied Together AI API key {self.api_key} is invalid")
                if counter == 5:
                    raise Exception(f"Together AI API failed to generate a response: {e}")
                counter += 1
                # exponential backoff before retrying
                time.sleep(2 ** counter)
                continue

        if not choice:
            raise Exception("TogetherAI API failed to generate a response")

        if model.parsing_helper_tokens["end_token"]:
            # remove the end token from the choice
            choice = choice.split(model.parsing_helper_tokens["end_token"])[0]
            # check if the starting token is in the choice
            if model.parsing_helper_tokens["start_token"] in choice:
                # remove the starting token from the choice
                choice = choice.split(model.parsing_helper_tokens["start_token"])[-1]
        return choice.strip()

    def check_api_key(self):
        # check if the api key is not None
        if not self.api_key:
            # try to get the api key from the environment, maybe it has been set later
            self.api_key = os.getenv("TOGETHER_API_KEY")
        if not self.api_key:
            raise ValueError("TogetherAI API key is not set")
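
How Tanuki calls this class from the `@tanuki.patch` flow is not part of this diff. Purely as an illustrative sketch of the class above (the import path for `TogetherAI_API` is assumed rather than shown in this commit, and `TOGETHER_API_KEY` must be set in the environment), direct use might look like:

```python
# illustrative sketch only; not how Tanuki invokes the class internally
from tanuki.language_models.llm_configs import TogetherAIConfig
from tanuki.language_models.together_ai_api import TogetherAI_API  # hypothetical module path

config = TogetherAIConfig(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1", context_length=32768)
api = TogetherAI_API()
output = api.generate(
    model=config,
    system_message="You are a helpful assistant.",
    prompt="Reply with the single word: pong",
    max_new_tokens=16,
)
print(output)
```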
@@ -0,0 +1,47 @@
from typing import Optional, Literal
from dotenv import load_dotenv
import tanuki

load_dotenv()

@tanuki.patch(teacher_models=["openchat-3.5"], generation_params={"max_new_tokens": 10})
def classify_sentiment_2(input: str, input_2: str) -> Optional[Literal['Good', 'Bad']]:
    """
    Determine if the inputs are positive or negative sentiment, or None
    """


@tanuki.patch(teacher_models=["openchat-3.5"])
def classify_sentiment(input: str) -> Optional[Literal['Good', 'Bad']]:
    """
    Determine if the input is positive or negative sentiment
    """

@tanuki.align
def align_classify_sentiment():
    """We can test the function as normal using Pytest or Unittest"""

    i_love_you = "I love you"
    assert classify_sentiment_2(i_love_you, "I love woo") == 'Good'
    assert classify_sentiment_2("I hate you", "You're disgusting") == 'Bad'
    assert classify_sentiment_2("Today is wednesday", "The dogs are running outside") == None

    assert classify_sentiment("I love you") == 'Good'
    assert classify_sentiment("I hate you") == 'Bad'
    assert classify_sentiment("Wednesdays are in the middle of the week") == None

def test_classify_sentiment():
    align_classify_sentiment()
    bad_input = "I find you awful"
    good_input = "I really really like you"
    good_input_2 = "I adore you"
    assert classify_sentiment("I like you") == 'Good'
    assert classify_sentiment(bad_input) == 'Bad'
    assert classify_sentiment("I am neutral") == None

    assert classify_sentiment_2(good_input, good_input_2) == 'Good'
    assert classify_sentiment_2("I do not like you", bad_input) == 'Bad'
    assert classify_sentiment_2("I am neutral", "I am neutral too") == None

test_classify_sentiment()
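
A usage note on the example above: because `test_classify_sentiment()` runs at import time, a valid Together AI key has to be available before the module executes, either via the `.env` file picked up by `load_dotenv()` or directly in the environment. A minimal guard might look like this sketch (the error message is illustrative):

```python
import os

# TogetherAI_API.check_api_key in this commit reads the key from this environment variable
if not os.environ.get("TOGETHER_API_KEY"):
    raise RuntimeError("Set TOGETHER_API_KEY before running the openchat-3.5 example")
```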