From 00417edb5a6162435d2f027617f86b8855cb31cc Mon Sep 17 00:00:00 2001 From: olgavrou Date: Fri, 2 Feb 2024 19:21:19 +0200 Subject: [PATCH] Custom Model Client support (#1345) * add client interface, response protocol, and move code into openai client class * add ability to register custom client * tidy up code * adding checks and errors, and more unit tests * remove code * fix error msg * add use_docer False in notebook * better error message * add another example to custom model notebook * rename and have register_client take model name too * make Client protocol and remove inheritance * renames * update notebook * add link * rename and more error checking for registered agents * adding message retrieval to client protocol for more flexible response * fix failing openai test * api_type req made model_client_cls requirement * notebook cleanup and added blog * remove raise error if client list is empty - client list will never be empty it will have placeholders * rename Client -> ModelClient * add forgotten file * fix test by fetching internal client * Update autogen/oai/client.py Co-authored-by: Eric Zhu * don't add retrieval function to cache * added placeholder cllient class during initial client init, and rewrote registration * fix spelling * Update autogen/agentchat/conversable_agent.py Co-authored-by: Chi Wang * type hints, small fixes, docstr comment * fix api type checking --------- Co-authored-by: Eric Zhu Co-authored-by: Chi Wang --- .../agentchat/contrib/gpt_assistant_agent.py | 2 +- autogen/agentchat/conversable_agent.py | 11 +- autogen/oai/__init__.py | 3 +- autogen/oai/client.py | 571 +++++++----- notebook/agentchat_custom_model.ipynb | 848 ++++++++++++++++++ test/agentchat/contrib/test_gpt_assistant.py | 6 +- test/oai/test_custom_client.py | 166 ++++ .../blog/2024-01-26-Custom-Models/index.mdx | 170 ++++ 8 files changed, 1554 insertions(+), 223 deletions(-) create mode 100644 notebook/agentchat_custom_model.ipynb create mode 100644 test/oai/test_custom_client.py create mode 100644 website/blog/2024-01-26-Custom-Models/index.mdx diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py index a8419f0d2ad3..b588b2b59f5a 100644 --- a/autogen/agentchat/contrib/gpt_assistant_agent.py +++ b/autogen/agentchat/contrib/gpt_assistant_agent.py @@ -56,7 +56,7 @@ def __init__( oai_wrapper = OpenAIWrapper(**llm_config) if len(oai_wrapper._clients) > 1: logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.") - self._openai_client = oai_wrapper._clients[0] + self._openai_client = oai_wrapper._clients[0]._oai_client openai_assistant_id = llm_config.get("assistant_id", None) if openai_assistant_id is None: # try to find assistant by name first diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index ecb1c996599c..e3fb3d8be4a5 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -9,7 +9,7 @@ from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union import warnings -from .. import OpenAIWrapper +from .. import OpenAIWrapper, ModelClient from ..cache.cache import Cache from ..code_utils import ( DEFAULT_MODEL, @@ -1946,6 +1946,15 @@ def _decorator(func: F) -> F: return _decorator + def register_model_client(self, model_client_cls: ModelClient, **kwargs): + """Register a model client. + + Args: + model_client_cls: A custom client class that follows the Client interface + **kwargs: The kwargs for the custom client class to be initialized with + """ + self.client.register_model_client(model_client_cls, **kwargs) + def register_hook(self, hookable_method: Callable, hook: Callable): """ Registers a hook to be called by a hookable method, in order to add a capability to the agent. diff --git a/autogen/oai/__init__.py b/autogen/oai/__init__.py index 92791fd5c0ec..1cf57f04456a 100644 --- a/autogen/oai/__init__.py +++ b/autogen/oai/__init__.py @@ -1,4 +1,4 @@ -from autogen.oai.client import OpenAIWrapper +from autogen.oai.client import OpenAIWrapper, ModelClient from autogen.oai.completion import Completion, ChatCompletion from autogen.oai.openai_utils import ( get_config_list, @@ -12,6 +12,7 @@ __all__ = [ "OpenAIWrapper", + "ModelClient", "Completion", "ChatCompletion", "get_config_list", diff --git a/autogen/oai/client.py b/autogen/oai/client.py index b9b07d569ce5..56167f978c66 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -8,6 +8,7 @@ from flaml.automl.logger import logger_formatter from pydantic import BaseModel +from typing import Protocol from autogen.cache.cache import Cache from autogen.oai import completion @@ -52,6 +53,246 @@ LEGACY_CACHE_DIR = ".cache" +class ModelClient(Protocol): + """ + A client class must implement the following methods: + - create must return a response object that implements the ModelClientResponseProtocol + - cost must return the cost of the response + - get_usage must return a dict with the following keys: + - prompt_tokens + - completion_tokens + - total_tokens + - cost + - model + + This class is used to create a client that can be used by OpenAIWrapper. + The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed. + The message_retrieval method must be implemented to return a list of str or a list of messages from the response. + """ + + RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] + + class ModelClientResponseProtocol(Protocol): + class Choice(Protocol): + class Message(Protocol): + content: Optional[str] + + choices: List[Choice] + model: str + + def create(self, **params: Any) -> ModelClientResponseProtocol: + ... # pragma: no cover + + def message_retrieval( + self, response: ModelClientResponseProtocol + ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]: + """ + Retrieve and return a list of strings or a list of Choice.Message from the response. + + NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, + since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. + """ + ... # pragma: no cover + + def cost(self, response: ModelClientResponseProtocol) -> float: + ... # pragma: no cover + + @staticmethod + def get_usage(response: ModelClientResponseProtocol) -> Dict: + """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" + ... # pragma: no cover + + +class PlaceHolderClient: + def __init__(self, config): + self.config = config + + +class OpenAIClient: + """Follows the Client protocol and wraps the OpenAI client.""" + + def __init__(self, client: Union[OpenAI, AzureOpenAI]): + self._oai_client = client + + def message_retrieval( + self, response: Union[ChatCompletion, Completion] + ) -> Union[List[str], List[ChatCompletionMessage]]: + """Retrieve the messages from the response.""" + choices = response.choices + if isinstance(response, Completion): + return [choice.text for choice in choices] # type: ignore [union-attr] + + if TOOL_ENABLED: + return [ # type: ignore [return-value] + choice.message # type: ignore [union-attr] + if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr] + else choice.message.content # type: ignore [union-attr] + for choice in choices + ] + else: + return [ # type: ignore [return-value] + choice.message if choice.message.function_call is not None else choice.message.content # type: ignore [union-attr] + for choice in choices + ] + + def create(self, params: Dict[str, Any]) -> ChatCompletion: + """Create a completion for a given config using openai's client. + + Args: + client: The openai client. + params: The params for the completion. + + Returns: + The completion. + """ + completions: Completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined] + # If streaming is enabled and has messages, then iterate over the chunks of the response. + if params.get("stream", False) and "messages" in params: + response_contents = [""] * params.get("n", 1) + finish_reasons = [""] * params.get("n", 1) + completion_tokens = 0 + + # Set the terminal text color to green + print("\033[32m", end="") + + # Prepare for potential function call + full_function_call: Optional[Dict[str, Any]] = None + full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None + + # Send the chat completion request to OpenAI's API and process the response in chunks + for chunk in completions.create(**params): + if chunk.choices: + for choice in chunk.choices: + content = choice.delta.content + tool_calls_chunks = choice.delta.tool_calls + finish_reasons[choice.index] = choice.finish_reason + + # todo: remove this after function calls are removed from the API + # the code should work regardless of whether function calls are removed or not, but test_chat_functions_stream should fail + # begin block + function_call_chunk = ( + choice.delta.function_call if hasattr(choice.delta, "function_call") else None + ) + # Handle function call + if function_call_chunk: + # Handle function call + if function_call_chunk: + full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk( + function_call_chunk, full_function_call, completion_tokens + ) + if not content: + continue + # end block + + # Handle tool calls + if tool_calls_chunks: + for tool_calls_chunk in tool_calls_chunks: + # the current tool call to be reconstructed + ix = tool_calls_chunk.index + if full_tool_calls is None: + full_tool_calls = [] + if ix >= len(full_tool_calls): + # in case ix is not sequential + full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1) + + full_tool_calls[ix], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk( + tool_calls_chunk, full_tool_calls[ix], completion_tokens + ) + if not content: + continue + + # End handle tool calls + + # If content is present, print it to the terminal and update response variables + if content is not None: + print(content, end="", flush=True) + response_contents[choice.index] += content + completion_tokens += 1 + else: + # print() + pass + + # Reset the terminal text color + print("\033[0m\n") + + # Prepare the final ChatCompletion object based on the accumulated data + model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API + prompt_tokens = count_token(params["messages"], model) + response = ChatCompletion( + id=chunk.id, + model=chunk.model, + created=chunk.created, + object="chat.completion", + choices=[], + usage=CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + for i in range(len(response_contents)): + if OPENAIVERSION >= "1.5": # pragma: no cover + # OpenAI versions 1.5.0 and above + choice = Choice( + index=i, + finish_reason=finish_reasons[i], + message=ChatCompletionMessage( + role="assistant", + content=response_contents[i], + function_call=full_function_call, + tool_calls=full_tool_calls, + ), + logprobs=None, + ) + else: + # OpenAI versions below 1.5.0 + choice = Choice( # type: ignore [call-arg] + index=i, + finish_reason=finish_reasons[i], + message=ChatCompletionMessage( + role="assistant", + content=response_contents[i], + function_call=full_function_call, + tool_calls=full_tool_calls, + ), + ) + + response.choices.append(choice) + else: + # If streaming is not enabled, send a regular chat completion request + params = params.copy() + params["stream"] = False + response = completions.create(**params) + + return response + + def cost(self, response: Union[ChatCompletion, Completion]) -> float: + """Calculate the cost of the response.""" + model = response.model + if model not in OAI_PRICE1K: + # TODO: add logging to warn that the model is not found + logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True) + return 0 + + n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr] + n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr] + tmp_price1K = OAI_PRICE1K[model] + # First value is input token rate, second value is output token rate + if isinstance(tmp_price1K, tuple): + return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 # type: ignore [no-any-return] + return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator] + + @staticmethod + def get_usage(response: Union[ChatCompletion, Completion]) -> Dict: + return { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + "cost": response.cost, + "model": response.model, + } + + class OpenAIWrapper: """A wrapper class for openai client.""" @@ -106,17 +347,19 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base openai_config, extra_kwargs = self._separate_openai_config(base_config) if type(config_list) is list and len(config_list) == 0: logger.warning("openai client was provided with an empty config_list, which may not be intended.") + + self._clients: List[ModelClient] = [] + self._config_list: List[Dict[str, Any]] = [] + if config_list: config_list = [config.copy() for config in config_list] # make a copy before modifying - self._clients: List[OpenAI] = [ - self._client(config, openai_config) for config in config_list - ] # could modify the config - self._config_list = [ - {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}} - for config in config_list - ] + for config in config_list: + self._register_default_client(config, openai_config) # could modify the config + self._config_list.append( + {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}} + ) else: - self._clients = [self._client(extra_kwargs, openai_config)] + self._register_default_client(extra_kwargs, openai_config) self._config_list = [extra_kwargs] def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -131,7 +374,13 @@ def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs} return create_config, extra_kwargs - def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI: + def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None: + openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model")) + if openai_config["azure_deployment"] is not None: + openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "") + openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None)) + + def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None: """Create a client with the given config to override openai_config, after removing extra kwargs. @@ -142,15 +391,49 @@ def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> Open """ openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}} api_type = config.get("api_type") - if api_type is not None and api_type.startswith("azure"): - openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model")) - if openai_config["azure_deployment"] is not None: - openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "") - openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None)) - client = AzureOpenAI(**openai_config) + model_client_cls_name = config.get("model_client_cls") + if model_client_cls_name is not None: + # a config for a custom client is set + # adding placeholder until the register_model_client is called with the appropriate class + self._clients.append(PlaceHolderClient(config)) + logger.info( + f"Detected custom model client in config: {model_client_cls_name}, model client can not be used until register_model_client is called." + ) else: - client = OpenAI(**openai_config) - return client + if api_type is not None and api_type.startswith("azure"): + self._configure_azure_openai(config, openai_config) + self._clients.append(OpenAIClient(AzureOpenAI(**openai_config))) + else: + self._clients.append(OpenAIClient(OpenAI(**openai_config))) + + def register_model_client(self, model_client_cls: ModelClient, **kwargs): + """Register a model client. + + Args: + model_client_cls: A custom client class that follows the ModelClient interface + **kwargs: The kwargs for the custom client class to be initialized with + """ + existing_client_class = False + for i, client in enumerate(self._clients): + if isinstance(client, PlaceHolderClient): + placeholder_config = client.config + + if placeholder_config in self._config_list: + if placeholder_config.get("model_client_cls") == model_client_cls.__name__: + self._clients[i] = model_client_cls(placeholder_config, **kwargs) + return + elif isinstance(client, model_client_cls): + existing_client_class = True + + if existing_client_class: + logger.warn( + f"Model client {model_client_cls.__name__} is already registered. Add more entries in the config_list to use multiple model clients." + ) + else: + raise ValueError( + f'Model client "{model_client_cls.__name__}" is being registered but was not found in the config_list. ' + f'Please make sure to include an entry in the config_list with "model_client_cls": "{model_client_cls.__name__}"' + ) @classmethod def instantiate( @@ -196,9 +479,9 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: ] return params - def create(self, **config: Any) -> ChatCompletion: - """Make a completion for a given config using openai's clients. - Besides the kwargs allowed in openai's client, we allow the following additional kwargs. + def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol: + """Make a completion for a given config using available clients. + Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs. The config in each client will be overridden by the config. Args: @@ -228,10 +511,21 @@ def yes_or_no_filter(context, response): - allow_format_str_template (bool | None): Whether to allow format string template in the config. Default to false. - api_version (str | None): The api version. Default to None. E.g., "2023-08-01-preview". + Raises: + - RuntimeError: If all declared custom model clients are not registered + - APIError: If any model client create call raises an APIError """ if ERROR: raise ERROR last = len(self._clients) - 1 + # Check if all configs in config list are activated + non_activated = [ + client.config["model_client_cls"] for client in self._clients if isinstance(client, PlaceHolderClient) + ] + if non_activated: + raise RuntimeError( + f"Model client(s) {non_activated} are not activated. Please register the custom model clients using `register_model_client` or filter them out form the config list." + ) for i, client in enumerate(self._clients): # merge the input config with the i-th config in the config list full_config = {**config, **self._config_list[i]} @@ -248,6 +542,9 @@ def yes_or_no_filter(context, response): filter_func = extra_kwargs.get("filter_func") context = extra_kwargs.get("context") + total_usage = None + actual_usage = None + cache_client = None if cache is not None: # Use the cache object if provided. @@ -260,26 +557,28 @@ def yes_or_no_filter(context, response): with cache_client as cache: # Try to get the response from cache key = get_key(params) - response: ChatCompletion = cache.get(key, None) + response: ModelClient.ModelClientResponseProtocol = cache.get(key, None) if response is not None: + response.message_retrieval_function = client.message_retrieval try: response.cost # type: ignore [attr-defined] except AttributeError: # update attribute if cost is not calculated - response.cost = self.cost(response) + response.cost = client.cost(response) cache.set(key, response) - self._update_usage_summary(response, use_cache=True) + total_usage = client.get_usage(response) # check the filter pass_filter = filter_func is None or filter_func(context=context, response=response) if pass_filter or i == last: # Return the response if it passes the filter or it is the last client response.config_id = i response.pass_filter = pass_filter + self._update_usage(actual_usage=actual_usage, total_usage=total_usage) return response continue # filter is not passed; try the next config try: - response = self._completions_create(client, params) + response = client.create(params) except APIError as err: error_code = getattr(err, "code", None) if error_code == "content_filter": @@ -290,13 +589,16 @@ def yes_or_no_filter(context, response): raise else: # add cost calculation before caching no matter filter is passed or not - response.cost = self.cost(response) - self._update_usage_summary(response, use_cache=False) + response.cost = client.cost(response) + actual_usage = client.get_usage(response) + total_usage = actual_usage.copy() if actual_usage is not None else total_usage + self._update_usage(actual_usage=actual_usage, total_usage=total_usage) if cache_client is not None: # Cache the response with cache_client as cache: cache.set(key, response) + response.message_retrieval_function = client.message_retrieval # check the filter pass_filter = filter_func is None or filter_func(context=context, response=response) if pass_filter or i == last: @@ -417,170 +719,36 @@ def _update_tool_calls_from_chunk( else: raise RuntimeError("Tool call is not found, this should not happen.") - def _completions_create(self, client: OpenAI, params: Dict[str, Any]) -> ChatCompletion: - """Create a completion for a given config using openai's client. - - Args: - client: The openai client. - params: The params for the completion. - - Returns: - The completion. - """ - completions: Completions = client.chat.completions if "messages" in params else client.completions # type: ignore [attr-defined] - # If streaming is enabled and has messages, then iterate over the chunks of the response. - if params.get("stream", False) and "messages" in params: - response_contents = [""] * params.get("n", 1) - finish_reasons = [""] * params.get("n", 1) - completion_tokens = 0 - - # Set the terminal text color to green - print("\033[32m", end="") + def _update_usage(self, actual_usage, total_usage): + def update_usage(usage_summary, response_usage): + # go through RESPONSE_USAGE_KEYS and check that they are in response_usage and if not just return usage_summary + for key in ModelClient.RESPONSE_USAGE_KEYS: + if key not in response_usage: + return usage_summary - # Prepare for potential function call - full_function_call: Optional[Dict[str, Any]] = None - full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None - - # Send the chat completion request to OpenAI's API and process the response in chunks - for chunk in completions.create(**params): - if chunk.choices: - for choice in chunk.choices: - content = choice.delta.content - tool_calls_chunks = choice.delta.tool_calls - finish_reasons[choice.index] = choice.finish_reason + model = response_usage["model"] + cost = response_usage["cost"] + prompt_tokens = response_usage["prompt_tokens"] + completion_tokens = response_usage["completion_tokens"] + total_tokens = response_usage["total_tokens"] - # todo: remove this after function calls are removed from the API - # the code should work regardless of whether function calls are removed or not, but test_chat_functions_stream should fail - # begin block - function_call_chunk = ( - choice.delta.function_call if hasattr(choice.delta, "function_call") else None - ) - # Handle function call - if function_call_chunk: - # Handle function call - if function_call_chunk: - full_function_call, completion_tokens = self._update_function_call_from_chunk( - function_call_chunk, full_function_call, completion_tokens - ) - if not content: - continue - # end block - - # Handle tool calls - if tool_calls_chunks: - for tool_calls_chunk in tool_calls_chunks: - # the current tool call to be reconstructed - ix = tool_calls_chunk.index - if full_tool_calls is None: - full_tool_calls = [] - if ix >= len(full_tool_calls): - # in case ix is not sequential - full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1) - - full_tool_calls[ix], completion_tokens = self._update_tool_calls_from_chunk( - tool_calls_chunk, full_tool_calls[ix], completion_tokens - ) - if not content: - continue - - # End handle tool calls - - # If content is present, print it to the terminal and update response variables - if content is not None: - print(content, end="", flush=True) - response_contents[choice.index] += content - completion_tokens += 1 - else: - # print() - pass - - # Reset the terminal text color - print("\033[0m\n") - - # Prepare the final ChatCompletion object based on the accumulated data - model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API - prompt_tokens = count_token(params["messages"], model) - response = ChatCompletion( - id=chunk.id, - model=chunk.model, - created=chunk.created, - object="chat.completion", - choices=[], - usage=CompletionUsage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - for i in range(len(response_contents)): - if OPENAIVERSION >= "1.5": # pragma: no cover - # OpenAI versions 1.5.0 and above - choice = Choice( - index=i, - finish_reason=finish_reasons[i], - message=ChatCompletionMessage( - role="assistant", - content=response_contents[i], - function_call=full_function_call, - tool_calls=full_tool_calls, - ), - logprobs=None, - ) - else: - # OpenAI versions below 1.5.0 - choice = Choice( # type: ignore [call-arg] - index=i, - finish_reason=finish_reasons[i], - message=ChatCompletionMessage( - role="assistant", - content=response_contents[i], - function_call=full_function_call, - tool_calls=full_tool_calls, - ), - ) - - response.choices.append(choice) - else: - # If streaming is not enabled, send a regular chat completion request - params = params.copy() - params["stream"] = False - response = completions.create(**params) - - return response - - def _update_usage_summary(self, response: Union[ChatCompletion, Completion], use_cache: bool) -> None: - """Update the usage summary. - - Usage is calculated no matter filter is passed or not. - """ - try: - usage = response.usage - assert usage is not None - usage.prompt_tokens = 0 if usage.prompt_tokens is None else usage.prompt_tokens - usage.completion_tokens = 0 if usage.completion_tokens is None else usage.completion_tokens - usage.total_tokens = 0 if usage.total_tokens is None else usage.total_tokens - except (AttributeError, AssertionError): - logger.debug("Usage attribute is not found in the response.", exc_info=True) - return - - def update_usage(usage_summary: Optional[Dict[str, Any]]) -> Dict[str, Any]: if usage_summary is None: - usage_summary = {"total_cost": response.cost} # type: ignore [union-attr] + usage_summary = {"total_cost": cost} else: - usage_summary["total_cost"] += response.cost # type: ignore [union-attr] - - usage_summary[response.model] = { - "cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost, # type: ignore [union-attr] - "prompt_tokens": usage_summary.get(response.model, {}).get("prompt_tokens", 0) + usage.prompt_tokens, - "completion_tokens": usage_summary.get(response.model, {}).get("completion_tokens", 0) - + usage.completion_tokens, - "total_tokens": usage_summary.get(response.model, {}).get("total_tokens", 0) + usage.total_tokens, + usage_summary["total_cost"] += cost + + usage_summary[model] = { + "cost": usage_summary.get(model, {}).get("cost", 0) + cost, + "prompt_tokens": usage_summary.get(model, {}).get("prompt_tokens", 0) + prompt_tokens, + "completion_tokens": usage_summary.get(model, {}).get("completion_tokens", 0) + completion_tokens, + "total_tokens": usage_summary.get(model, {}).get("total_tokens", 0) + total_tokens, } return usage_summary - self.total_usage_summary = update_usage(self.total_usage_summary) - if not use_cache: - self.actual_usage_summary = update_usage(self.actual_usage_summary) + if total_usage is not None: + self.total_usage_summary = update_usage(self.total_usage_summary, total_usage) + if actual_usage is not None: + self.actual_usage_summary = update_usage(self.actual_usage_summary, actual_usage) def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None: """Print the usage summary.""" @@ -639,26 +807,10 @@ def clear_usage_summary(self) -> None: self.total_usage_summary = None self.actual_usage_summary = None - def cost(self, response: Union[ChatCompletion, Completion]) -> float: - """Calculate the cost of the response.""" - model = response.model - if model not in OAI_PRICE1K: - # TODO: add logging to warn that the model is not found - logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True) - return 0 - - n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr] - n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr] - tmp_price1K = OAI_PRICE1K[model] - # First value is input token rate, second value is output token rate - if isinstance(tmp_price1K, tuple): - return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 # type: ignore [no-any-return] - return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator] - @classmethod def extract_text_or_completion_object( - cls, response: Union[ChatCompletion, Completion] - ) -> Union[List[str], List[ChatCompletionMessage]]: + cls, response: ModelClient.ModelClientResponseProtocol + ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]: """Extract the text or ChatCompletion objects from a completion or chat response. Args: @@ -667,22 +819,7 @@ def extract_text_or_completion_object( Returns: A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present. """ - choices = response.choices - if isinstance(response, Completion): - return [choice.text for choice in choices] # type: ignore [union-attr] - - if TOOL_ENABLED: - return [ # type: ignore [return-value] - choice.message # type: ignore [union-attr] - if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr] - else choice.message.content # type: ignore [union-attr] - for choice in choices - ] - else: - return [ # type: ignore [return-value] - choice.message if choice.message.function_call is not None else choice.message.content # type: ignore [union-attr] - for choice in choices - ] + return response.message_retrieval_function(response) # TODO: logging diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb new file mode 100644 index 000000000000..b58b5d93a055 --- /dev/null +++ b/notebook/agentchat_custom_model.ipynb @@ -0,0 +1,848 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Agent Chat with custom model loading\n", + "\n", + "In this notebook, we demonstrate how a custom model can be defined and loaded, and what protocol it needs to comply to.\n", + "\n", + "**NOTE: Depending on what model you use, you may need to play with the default prompts of the Agent's**\n", + "\n", + "## Requirements\n", + "\n", + "AutoGen requires `Python>=3.8`. To run this notebook example, please install:\n", + "```bash\n", + "pip install pyautogen torch transformers sentencepiece\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2023-02-13T23:40:52.317406Z", + "iopub.status.busy": "2023-02-13T23:40:52.316561Z", + "iopub.status.idle": "2023-02-13T23:40:52.321193Z", + "shell.execute_reply": "2023-02-13T23:40:52.320628Z" + } + }, + "outputs": [], + "source": [ + "# %pip install pyautogen~=0.2.0b4 torch git+https://github.com/huggingface/transformers sentencepiece" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import autogen\n", + "from autogen import AssistantAgent, UserProxyAgent\n", + "from transformers import AutoTokenizer, GenerationConfig, AutoModelForCausalLM\n", + "from types import SimpleNamespace" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create and configure the custom model\n", + "\n", + "A custom model class can be created in many ways, but needs to adhere to the `ModelClient` protocol and response structure which is defined in client.py and shown below.\n", + "\n", + "The response protocol has some minimum requirements, but can be extended to include any additional information that is needed.\n", + "Message retrieval therefore can be customized, but needs to return a list of strings or a list of `ModelClientResponseProtocol.Choice.Message` objects.\n", + "\n", + "\n", + "```python\n", + "class ModelClient(Protocol):\n", + " \"\"\"\n", + " A client class must implement the following methods:\n", + " - create must return a response object that implements the ModelClientResponseProtocol\n", + " - cost must return the cost of the response\n", + " - get_usage must return a dict with the following keys:\n", + " - prompt_tokens\n", + " - completion_tokens\n", + " - total_tokens\n", + " - cost\n", + " - model\n", + "\n", + " This class is used to create a client that can be used by OpenAIWrapper.\n", + " The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed.\n", + " The message_retrieval method must be implemented to return a list of str or a list of messages from the response.\n", + " \"\"\"\n", + "\n", + " RESPONSE_USAGE_KEYS = [\"prompt_tokens\", \"completion_tokens\", \"total_tokens\", \"cost\", \"model\"]\n", + "\n", + " class ModelClientResponseProtocol(Protocol):\n", + " class Choice(Protocol):\n", + " class Message(Protocol):\n", + " content: str | None\n", + "\n", + " choices: List[Choice]\n", + " model: str\n", + "\n", + " def create(self, params) -> ModelClientResponseProtocol:\n", + " ...\n", + "\n", + " def message_retrieval(\n", + " self, response: ModelClientResponseProtocol\n", + " ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:\n", + " \"\"\"\n", + " Retrieve and return a list of strings or a list of Choice.Message from the response.\n", + "\n", + " NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,\n", + " since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.\n", + " \"\"\"\n", + " ...\n", + "\n", + " def cost(self, response: ModelClientResponseProtocol) -> float:\n", + " ...\n", + "\n", + " @staticmethod\n", + " def get_usage(response: ModelClientResponseProtocol) -> Dict:\n", + " \"\"\"Return usage summary of the response using RESPONSE_USAGE_KEYS.\"\"\"\n", + " ...\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example of simple custom client\n", + "\n", + "Following the huggingface example for using [Mistral's Open-Orca](https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca)\n", + "\n", + "For the response object, python's `SimpleNamespace` is used to create a simple object that can be used to store the response data, but any object that follows the `ClientResponseProtocol` can be used.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# custom client with custom model loader\n", + "\n", + "\n", + "class CustomModelClient:\n", + " def __init__(self, config, **kwargs):\n", + " print(f\"CustomModelClient config: {config}\")\n", + " self.device = config.get(\"device\", \"cpu\")\n", + " self.model = AutoModelForCausalLM.from_pretrained(config[\"model\"]).to(self.device)\n", + " self.model_name = config[\"model\"]\n", + " self.tokenizer = AutoTokenizer.from_pretrained(config[\"model\"], use_fast=False)\n", + " self.tokenizer.pad_token_id = self.tokenizer.eos_token_id\n", + "\n", + " # params are set by the user and consumed by the user since they are providing a custom model\n", + " # so anything can be done here\n", + " gen_config_params = config.get(\"params\", {})\n", + " self.max_length = gen_config_params.get(\"max_length\", 256)\n", + "\n", + " print(f\"Loaded model {config['model']} to {self.device}\")\n", + "\n", + " def create(self, params):\n", + " if params.get(\"stream\", False) and \"messages\" in params:\n", + " raise NotImplementedError(\"Local models do not support streaming.\")\n", + " else:\n", + " num_of_responses = params.get(\"n\", 1)\n", + "\n", + " # can create my own data response class\n", + " # here using SimpleNamespace for simplicity\n", + " # as long as it adheres to the ClientResponseProtocol\n", + "\n", + " response = SimpleNamespace()\n", + "\n", + " inputs = self.tokenizer.apply_chat_template(\n", + " params[\"messages\"], return_tensors=\"pt\", add_generation_prompt=True\n", + " ).to(self.device)\n", + " inputs_length = inputs.shape[-1]\n", + "\n", + " # add inputs_length to max_length\n", + " max_length = self.max_length + inputs_length\n", + " generation_config = GenerationConfig(\n", + " max_length=max_length,\n", + " eos_token_id=self.tokenizer.eos_token_id,\n", + " pad_token_id=self.tokenizer.eos_token_id,\n", + " )\n", + "\n", + " response.choices = []\n", + " response.model = self.model_name\n", + "\n", + " for _ in range(num_of_responses):\n", + " outputs = self.model.generate(inputs, generation_config=generation_config)\n", + " # Decode only the newly generated text, excluding the prompt\n", + " text = self.tokenizer.decode(outputs[0, inputs_length:])\n", + " choice = SimpleNamespace()\n", + " choice.message = SimpleNamespace()\n", + " choice.message.content = text\n", + " choice.message.function_call = None\n", + " response.choices.append(choice)\n", + "\n", + " return response\n", + "\n", + " def message_retrieval(self, response):\n", + " \"\"\"Retrieve the messages from the response.\"\"\"\n", + " choices = response.choices\n", + " return [choice.message.content for choice in choices]\n", + "\n", + " def cost(self, response) -> float:\n", + " \"\"\"Calculate the cost of the response.\"\"\"\n", + " response.cost = 0\n", + " return 0\n", + "\n", + " @staticmethod\n", + " def get_usage(response):\n", + " # returns a dict of prompt_tokens, completion_tokens, total_tokens, cost, model\n", + " # if usage needs to be tracked, else None\n", + " return {}" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set your API Endpoint\n", + "\n", + "The [`config_list_from_json`](https://microsoft.github.io/autogen/docs/reference/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file.\n", + "\n", + "It first looks for an environment variable of a specified name (\"OAI_CONFIG_LIST\" in this example), which needs to be a valid json string. If that variable is not found, it looks for a json file with the same name. It filters the configs by models (you can filter by other keys as well).\n", + "\n", + "The json looks like the following:\n", + "```json\n", + "[\n", + " {\n", + " \"model\": \"gpt-4\",\n", + " \"api_key\": \"\"\n", + " },\n", + " {\n", + " \"model\": \"gpt-4\",\n", + " \"api_key\": \"\",\n", + " \"base_url\": \"\",\n", + " \"api_type\": \"azure\",\n", + " \"api_version\": \"2023-08-01-preview\"\n", + " },\n", + " {\n", + " \"model\": \"gpt-4-32k\",\n", + " \"api_key\": \"\",\n", + " \"base_url\": \"\",\n", + " \"api_type\": \"azure\",\n", + " \"api_version\": \"2023-08-01-preview\"\n", + " }\n", + "]\n", + "```\n", + "\n", + "You can set the value of config_list in any way you prefer. Please refer to this [notebook](https://github.com/microsoft/autogen/blob/main/notebook/oai_openai_utils.ipynb) for full code examples of the different methods." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set the config for the custom model\n", + "\n", + "You can add any paramteres that are needed for the custom model loading in the same configuration list.\n", + "\n", + "It is important to add the `model_client_cls` field and set it to a string that corresponds to the class name: `\"CustomModelClient\"`.\n", + "\n", + "```json\n", + "{\n", + " \"model\": \"Open-Orca/Mistral-7B-OpenOrca\",\n", + " \"model_client_cls\": \"CustomModelClient\",\n", + " \"device\": \"cuda\",\n", + " \"n\": 1,\n", + " \"params\": {\n", + " \"max_length\": 1000,\n", + " }\n", + "},\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config_list_custom = autogen.config_list_from_json(\n", + " \"OAI_CONFIG_LIST\",\n", + " filter_dict={\"model_client_cls\": [\"CustomModelClient\"]},\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Construct Agents\n", + "\n", + "Consturct a simple conversation between a User proxy and an Assistent agent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assistant = AssistantAgent(\"assistant\", llm_config={\"config_list\": config_list_custom})\n", + "user_proxy = UserProxyAgent(\n", + " \"user_proxy\",\n", + " code_execution_config={\n", + " \"work_dir\": \"coding\",\n", + " \"use_docker\": False, # Please set use_docker=True if docker is available to run the generated code. Using docker is safer than running the generated code directly.\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Register the custom client class to the assistant agent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assistant.register_model_client(model_client_cls=CustomModelClient)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "user_proxy.initiate_chat(assistant, message=\"Write python code to print Hello World!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Register a custom client class with a pre-loaded model\n", + "\n", + "If you want to have more control over when the model gets loaded, you can load the model yourself and pass it as an argument to the CustomClient during registration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# custom client with custom model loader\n", + "\n", + "\n", + "class CustomModelClientWithArguments(CustomModelClient):\n", + " def __init__(self, config, loaded_model, tokenizer, **kwargs):\n", + " print(f\"CustomModelClientWithArguments config: {config}\")\n", + "\n", + " self.model_name = config[\"model\"]\n", + " self.model = loaded_model\n", + " self.tokenizer = tokenizer\n", + "\n", + " self.device = config.get(\"device\", \"cpu\")\n", + "\n", + " gen_config_params = config.get(\"params\", {})\n", + " self.max_length = gen_config_params.get(\"max_length\", 256)\n", + " print(f\"Loaded model {config['model']} to {self.device}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load model here\n", + "\n", + "config = config_list_custom[0]\n", + "device = config.get(\"device\", \"cpu\")\n", + "loaded_model = AutoModelForCausalLM.from_pretrained(config[\"model\"]).to(device)\n", + "tokenizer = AutoTokenizer.from_pretrained(config[\"model\"], use_fast=False)\n", + "tokenizer.pad_token_id = tokenizer.eos_token_id" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Add the config of the new custom model\n", + "\n", + "```json\n", + "{\n", + " \"model\": \"Open-Orca/Mistral-7B-OpenOrca\",\n", + " \"model_client_cls\": \"CustomModelClientWithArguments\",\n", + " \"device\": \"cuda\",\n", + " \"n\": 1,\n", + " \"params\": {\n", + " \"max_length\": 1000,\n", + " }\n", + "},\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config_list_custom = autogen.config_list_from_json(\n", + " \"OAI_CONFIG_LIST\",\n", + " filter_dict={\"model_client_cls\": [\"CustomModelClientWithArguments\"]},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assistant = AssistantAgent(\"assistant\", llm_config={\"config_list\": config_list_custom})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assistant.register_model_client(\n", + " model_client_cls=CustomModelClientWithArguments,\n", + " loaded_model=loaded_model,\n", + " tokenizer=tokenizer,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "user_proxy.initiate_chat(assistant, message=\"Write python code to print Hello World!\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + }, + "vscode": { + "interpreter": { + "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1" + } + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": { + "2d910cfd2d2a4fc49fc30fbbdc5576a7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "454146d0f7224f038689031002906e6f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e4ae2b6f5a974fd4bafb6abb9d12ff26", + "IPY_MODEL_577e1e3cc4db4942b0883577b3b52755", + "IPY_MODEL_b40bdfb1ac1d4cffb7cefcb870c64d45" + ], + "layout": "IPY_MODEL_dc83c7bff2f241309537a8119dfc7555", + "tabbable": null, + "tooltip": null + } + }, + "577e1e3cc4db4942b0883577b3b52755": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_2d910cfd2d2a4fc49fc30fbbdc5576a7", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_74a6ba0c3cbc4051be0a83e152fe1e62", + "tabbable": null, + "tooltip": null, + "value": 1 + } + }, + "6086462a12d54bafa59d3c4566f06cb2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "74a6ba0c3cbc4051be0a83e152fe1e62": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "7d3f3d9e15894d05a4d188ff4f466554": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "b40bdfb1ac1d4cffb7cefcb870c64d45": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_f1355871cc6f4dd4b50d9df5af20e5c8", + "placeholder": "​", + "style": "IPY_MODEL_ca245376fd9f4354af6b2befe4af4466", + "tabbable": null, + "tooltip": null, + "value": " 1/1 [00:00<00:00, 44.69it/s]" + } + }, + "ca245376fd9f4354af6b2befe4af4466": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "dc83c7bff2f241309537a8119dfc7555": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e4ae2b6f5a974fd4bafb6abb9d12ff26": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_6086462a12d54bafa59d3c4566f06cb2", + "placeholder": "​", + "style": "IPY_MODEL_7d3f3d9e15894d05a4d188ff4f466554", + "tabbable": null, + "tooltip": null, + "value": "100%" + } + }, + "f1355871cc6f4dd4b50d9df5af20e5c8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + }, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/agentchat/contrib/test_gpt_assistant.py b/test/agentchat/contrib/test_gpt_assistant.py index a9fb901edc81..92e12558afc7 100644 --- a/test/agentchat/contrib/test_gpt_assistant.py +++ b/test/agentchat/contrib/test_gpt_assistant.py @@ -225,7 +225,7 @@ def test_get_assistant_files() -> None: and assert that the retrieved instructions match the set instructions. """ current_file_path = os.path.abspath(__file__) - openai_client = OpenAIWrapper(config_list=config_list)._clients[0] + openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client file = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants") name = f"For test_get_assistant_files {uuid.uuid4()}" @@ -274,7 +274,7 @@ def test_assistant_retrieval() -> None: "description": "This is a test function 2", } - openai_client = OpenAIWrapper(config_list=config_list)._clients[0] + openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client current_file_path = os.path.abspath(__file__) file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants") @@ -350,7 +350,7 @@ def test_assistant_mismatch_retrieval() -> None: "description": "This is a test function 3", } - openai_client = OpenAIWrapper(config_list=config_list)._clients[0] + openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client current_file_path = os.path.abspath(__file__) file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants") file_2 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants") diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py new file mode 100644 index 000000000000..8e536921795f --- /dev/null +++ b/test/oai/test_custom_client.py @@ -0,0 +1,166 @@ +import pytest +from autogen import OpenAIWrapper +from autogen.oai import ModelClient +from typing import Dict + +try: + from openai import OpenAI +except ImportError: + skip = True +else: + skip = False + + +@pytest.mark.skipif(skip, reason="openai>=1 not installed") +def test_custom_model_client(): + TEST_COST = 20000000 + TEST_CUSTOM_RESPONSE = "This is a custom response." + TEST_DEVICE = "cpu" + TEST_LOCAL_MODEL_NAME = "local_model_name" + TEST_OTHER_PARAMS_VAL = "other_params" + TEST_MAX_LENGTH = 1000 + + class CustomModel: + def __init__(self, config: Dict, test_hook): + self.test_hook = test_hook + self.device = config["device"] + self.model = config["model"] + self.other_params = config["params"]["other_params"] + self.max_length = config["params"]["max_length"] + self.test_hook["called"] = True + # set all params to test hook + self.test_hook["device"] = self.device + self.test_hook["model"] = self.model + self.test_hook["other_params"] = self.other_params + self.test_hook["max_length"] = self.max_length + + def create(self, params): + from types import SimpleNamespace + + response = SimpleNamespace() + # need to follow Client.ClientResponseProtocol + response.choices = [] + choice = SimpleNamespace() + choice.message = SimpleNamespace() + choice.message.content = TEST_CUSTOM_RESPONSE + response.choices.append(choice) + response.model = self.model + return response + + def message_retrieval(self, response): + return [response.choices[0].message.content] + + def cost(self, response) -> float: + """Calculate the cost of the response.""" + response.cost = TEST_COST + return TEST_COST + + @staticmethod + def get_usage(response) -> Dict: + return {} + + config_list = [ + { + "model": TEST_LOCAL_MODEL_NAME, + "model_client_cls": "CustomModel", + "device": TEST_DEVICE, + "params": { + "max_length": TEST_MAX_LENGTH, + "other_params": TEST_OTHER_PARAMS_VAL, + }, + }, + ] + + test_hook = {"called": False} + + client = OpenAIWrapper(config_list=config_list) + client.register_model_client(model_client_cls=CustomModel, test_hook=test_hook) + + response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) + assert response.choices[0].message.content == TEST_CUSTOM_RESPONSE + assert response.cost == TEST_COST + + assert test_hook["called"] + assert test_hook["device"] == TEST_DEVICE + assert test_hook["model"] == TEST_LOCAL_MODEL_NAME + assert test_hook["other_params"] == TEST_OTHER_PARAMS_VAL + assert test_hook["max_length"] == TEST_MAX_LENGTH + + +@pytest.mark.skipif(skip, reason="openai>=1 not installed") +def test_registering_with_wrong_class_name_raises_error(): + class CustomModel: + def __init__(self, config: Dict): + pass + + def create(self, params): + return None + + def message_retrieval(self, response): + return [] + + def cost(self, response) -> float: + return 0 + + @staticmethod + def get_usage(response) -> Dict: + return {} + + config_list = [ + { + "model": "local_model_name", + "model_client_cls": "CustomModelWrongName", + }, + ] + client = OpenAIWrapper(config_list=config_list) + + with pytest.raises(ValueError): + client.register_model_client(model_client_cls=CustomModel) + + +@pytest.mark.skipif(skip, reason="openai>=1 not installed") +def test_not_all_clients_registered_raises_error(): + class CustomModel: + def __init__(self, config: Dict): + pass + + def create(self, params): + return None + + def message_retrieval(self, response): + return [] + + def cost(self, response) -> float: + return 0 + + @staticmethod + def get_usage(response) -> Dict: + return {} + + config_list = [ + { + "model": "local_model_name", + "model_client_cls": "CustomModel", + "device": "cpu", + "params": { + "max_length": 1000, + "other_params": "other_params", + }, + }, + { + "model": "local_model_name_2", + "model_client_cls": "CustomModel", + "device": "cpu", + "params": { + "max_length": 1000, + "other_params": "other_params", + }, + }, + ] + + client = OpenAIWrapper(config_list=config_list) + + client.register_model_client(model_client_cls=CustomModel) + + with pytest.raises(RuntimeError): + client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) diff --git a/website/blog/2024-01-26-Custom-Models/index.mdx b/website/blog/2024-01-26-Custom-Models/index.mdx new file mode 100644 index 000000000000..796b0c00d20e --- /dev/null +++ b/website/blog/2024-01-26-Custom-Models/index.mdx @@ -0,0 +1,170 @@ +--- +title: "AutoGen with Custom Models: Empowering Users to Use Their Own Inference Mechanism" +authors: + - olgavrou +tags: [AutoGen] +--- + +## TL;DR + +AutoGen now supports custom models! This feature empowers users to define and load their own models, allowing for a more flexible and personalized inference mechanism. By adhering to a specific protocol, you can integrate your custom model for use with AutoGen and respond to prompts any way needed by using any model/API call/hardcoded response you want. + +**NOTE: Depending on what model you use, you may need to play with the default prompts of the Agent's** + +## Quickstart + +An interactive and easy way to get started is by following the notebook [here](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_custom_model.ipynb) which loads a local model from HuggingFace into AutoGen and uses it for inference, and making changes to the class provided. + +### Step 1: Create the custom model client class + +To get started with using custom models in AutoGen, you need to create a model client class that adheres to the `ModelClient` protocol defined in `client.py`. The new model client class should implement these methods: + +- `create()`: Returns a response object that implements the `ModelClientResponseProtocol` (more details in the Protocol section). +- `message_retrieval()`: Processes the response object and returns a list of strings or a list of message objects (more details in the Protocol section). +- `cost()`: Returns the cost of the response. +- `get_usage()`: Returns a dictionary with keys from `RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]`. + +E.g. of a bare bones dummy custom class: + +```python +class CustomModelClient: + def __init__(self, config, **kwargs): + print(f"CustomModelClient config: {config}") + + def create(self, params): + num_of_responses = params.get("n", 1) + + # can create my own data response class + # here using SimpleNamespace for simplicity + # as long as it adheres to the ModelClientResponseProtocol + + response = SimpleNamespace() + response.choices = [] + response.model = "model_name" # should match the OAI_CONFIG_LIST registration + + for _ in range(num_of_responses): + text = "this is a dummy text response" + choice = SimpleNamespace() + choice.message = SimpleNamespace() + choice.message.content = text + choice.message.function_call = None + response.choices.append(choice) + return response + + def message_retrieval(self, response): + choices = response.choices + return [choice.message.content for choice in choices] + + def cost(self, response) -> float: + response.cost = 0 + return 0 + + @staticmethod + def get_usage(response): + return {} +``` + +### Step 2: Add the configuration to the OAI_CONFIG_LIST + +The field that is necessary is setting `model_client_cls` to the name of the new class (as a string) `"model_client_cls":"CustomModelClient"`. Any other fields will be forwarded to the class constructor, so you have full control over what parameters to specify and how to use them. E.g.: + +```json +{ + "model": "Open-Orca/Mistral-7B-OpenOrca", + "model_client_cls": "CustomModelClient", + "device": "cuda", + "n": 1, + "params": { + "max_length": 1000, + } +} +``` + +### Step 3: Register the new custom model to the agent that will use it + +If a configuration with the field `"model_client_cls":""` has been added to an Agent's config list, then the corresponding model with the desired class must be registered after the agent is created and before the conversation is initialized: + +```python +my_agent.register_model_client(model_client_cls=CustomModelClient, [other args that will be forwarded to CustomModelClient constructor]) +``` + +`model_client_cls=CustomModelClient` arg matches the one specified in the `OAI_CONFIG_LIST` and `CustomModelClient` is the class that adheres to the `ModelClient` protocol (more details on the protocol below). + +If the new model client is in the config list but not registered by the time the chat is initialized, then an error will be raised. + +## Protocol details + +A custom model class can be created in many ways, but needs to adhere to the `ModelClient` protocol and response structure which is defined in `client.py` and shown below. + +The response protocol is currently using the minimum required fields from the autogen codebase that match the OpenAI response structure. Any response protocol that matches the OpenAI response structure will probably be more resilient to future changes, but we are starting off with minimum requirements to make adpotion of this feature easier. + +```python + +class ModelClient(Protocol): + """ + A client class must implement the following methods: + - create must return a response object that implements the ModelClientResponseProtocol + - cost must return the cost of the response + - get_usage must return a dict with the following keys: + - prompt_tokens + - completion_tokens + - total_tokens + - cost + - model + + This class is used to create a client that can be used by OpenAIWrapper. + The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed. + The message_retrieval method must be implemented to return a list of str or a list of messages from the response. + """ + + RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] + + class ModelClientResponseProtocol(Protocol): + class Choice(Protocol): + class Message(Protocol): + content: str | None + + choices: List[Choice] + model: str + + def create(self, params) -> ModelClientResponseProtocol: + ... + + def message_retrieval( + self, response: ModelClientResponseProtocol + ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]: + """ + Retrieve and return a list of strings or a list of Choice.Message from the response. + + NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, + since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. + """ + ... + + def cost(self, response: ModelClientResponseProtocol) -> float: + ... + + @staticmethod + def get_usage(response: ModelClientResponseProtocol) -> Dict: + """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" + ... + +``` + +## Troubleshooting steps + +If something doesn't work then run through the checklist: + +- Make sure you have followed the client protocol and client response protocol when creating the custom model class + - `create()` method: `ModelClientResponseProtocol` must be followed when returning an inference response during `create` call. + - `message_retrieval()` method: returns a list of strings or a list of message objects. If a list of message objects is returned, they currently must contain the fields of OpenAI's ChatCompletion Message object, since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. + - `cost()`method: returns an integer, and if you don't care about cost tracking you can just return `0`. + - `get_usage()`: returns a dictionary, and if you don't care about usage tracking you can just return an empty dictionary `{}`. +- Make sure you have a corresponding entry in the `OAI_CONFIG_LIST` and that that entry has the `"model_client_cls":""` field. +- Make sure you have registered the client using the corresponding config entry and your new class `agent.register_model_client(model_client_cls=, [other optional args])` +- Make sure that all of the custom models defined in the `OAI_CONFIG_LIST` have been registered. +- Any other troubleshooting might need to be done in the custom code itself. + +## Conclusion + +With the ability to use custom models, AutoGen now offers even more flexibility and power for your AI applications. Whether you've trained your own model or want to use a specific pre-trained model, AutoGen can accommodate your needs. Happy coding!