From 0ba66adc47c5bec0e98bdf98d1b72e280f0ad4b1 Mon Sep 17 00:00:00 2001
From: vinid
Date: Sun, 14 Jul 2024 18:08:20 -0400
Subject: [PATCH 1/2] make the engine implementation independent of the cache

---
 textgrad/engine/base.py   | 70 +++++++++++++++++++++++++++++++
 textgrad/engine/openai.py | 88 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 157 insertions(+), 1 deletion(-)

diff --git a/textgrad/engine/base.py b/textgrad/engine/base.py
index b57a719..bcd7e92 100644
--- a/textgrad/engine/base.py
+++ b/textgrad/engine/base.py
@@ -1,6 +1,8 @@
 import hashlib
 import diskcache as dc
 from abc import ABC, abstractmethod
+from typing import Union, List
+import json
 
 class EngineLM(ABC):
     system_prompt: str = "You are a helpful, creative, and smart assistant."
@@ -41,3 +43,71 @@ def __setstate__(self, state):
         # Restore the cache after unpickling
         self.__dict__.update(state)
         self.cache = dc.Cache(self.cache_path)
+
+import platformdirs
+import os
+
+class CachedLLM(CachedEngine, EngineLM):
+    def __init__(self, model_string, is_multimodal=False, do_cache=False):
+        root = platformdirs.user_cache_dir("textgrad")
+        cache_path = os.path.join(root, f"cache_openai_{model_string}.db")
+
+        super().__init__(cache_path=cache_path)
+        self.model_string = model_string
+        self.is_multimodal = is_multimodal
+        self.do_cache = do_cache
+
+    def __call__(self, prompt, **kwargs):
+        return self.generate(prompt, **kwargs)
+
+    @abstractmethod
+    def _generate_from_single_prompt(self, prompt: str, system_prompt: str = None, **kwargs):
+        pass
+
+    @abstractmethod
+    def _generate_from_multiple_input(self, content: List[Union[str, bytes]], system_prompt: str = None, **kwargs):
+        pass
+
+    def single_prompt_generate(self, prompt: str, system_prompt: str = None, **kwargs):
+        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
+
+        if self.do_cache:
+            cache_or_none = self._check_cache(sys_prompt_arg + prompt)
+            if cache_or_none is not None:
+                return cache_or_none
+
+        response = self._generate_from_single_prompt(prompt, system_prompt=sys_prompt_arg, **kwargs)
+
+        if self.do_cache:
+            self._save_cache(sys_prompt_arg + prompt, response)
+        return response
+
+    def multimodal_generate(self, content: List[Union[str, bytes]], system_prompt: str = None, **kwargs):
+
+        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
+        if self.do_cache:
+            key = "".join([str(k) for k in content])
+
+            cache_key = sys_prompt_arg + key
+            cache_or_none = self._check_cache(cache_key)
+            if cache_or_none is not None:
+                return cache_or_none
+
+        response = self._generate_from_multiple_input(content, system_prompt=sys_prompt_arg, **kwargs)
+
+        if self.do_cache:
+            self._save_cache(cache_key, response)
+
+        return response
+
+    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str = None, **kwargs):
+        if isinstance(content, str):
+            return self.single_prompt_generate(content, system_prompt=system_prompt, **kwargs)
+
+        elif isinstance(content, list):
+            has_multimodal_input = any(isinstance(item, bytes) for item in content)
+            if has_multimodal_input and not self.is_multimodal:
+                raise NotImplementedError("This engine does not support multimodal inputs.")
+
+            return self.multimodal_generate(content, system_prompt=system_prompt, **kwargs)
+
diff --git a/textgrad/engine/openai.py b/textgrad/engine/openai.py
index 723f04a..fcfc2a8 100644
--- a/textgrad/engine/openai.py
+++ b/textgrad/engine/openai.py
@@ -14,7 +14,8 @@
 )
 from typing import List, Union
 
-from .base import EngineLM, CachedEngine
+
+from .base import EngineLM, CachedEngine, CachedLLM
 from .engine_utils import get_image_type_from_bytes
 
 # Default base URL for OLLAMA
@@ -158,6 +159,91 @@ def _generate_from_multiple_input(
         self._save_cache(cache_key, response_text)
         return response_text
 
+
+class OpenAIWithCachedLLM(CachedLLM):
+    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
+
+    def __init__(self, model_string, is_multimodal=False, system_prompt: str = DEFAULT_SYSTEM_PROMPT, do_cache=False):
+        """
+        :param model_string: name of the OpenAI model to call
+        :param system_prompt: default system prompt for generations
+        :param do_cache: whether to cache responses on disk
+        """
+        super().__init__(model_string=model_string, is_multimodal=is_multimodal, do_cache=do_cache)
+
+        self.system_prompt = system_prompt
+
+        if os.getenv("OPENAI_API_KEY") is None:
+            raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")
+
+        self.client = OpenAI(
+            api_key=os.getenv("OPENAI_API_KEY")
+        )
+
+    def _generate_from_single_prompt(
+        self, prompt: str, system_prompt: str = None, temperature=0, max_tokens=2000, top_p=0.99
+    ):
+
+        response = self.client.chat.completions.create(
+            model=self.model_string,
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": prompt},
+            ],
+            frequency_penalty=0,
+            presence_penalty=0,
+            stop=None,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+        )
+
+        response = response.choices[0].message.content
+        return response
+
+    def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]:
+        """Helper function to format a list of strings and bytes into a list of dictionaries to pass as messages to the API.
+        """
+        formatted_content = []
+        for item in content:
+            if isinstance(item, bytes):
+                # For now, bytes are assumed to be images
+                image_type = get_image_type_from_bytes(item)
+                base64_image = base64.b64encode(item).decode('utf-8')
+                formatted_content.append({
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/{image_type};base64,{base64_image}"
+                    }
+                })
+            elif isinstance(item, str):
+                formatted_content.append({
+                    "type": "text",
+                    "text": item
+                })
+            else:
+                raise ValueError(f"Unsupported input type: {type(item)}")
+        return formatted_content
+
+    def _generate_from_multiple_input(
+        self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
+    ):
+        formatted_content = self._format_content(content)
+
+        response = self.client.chat.completions.create(
+            model=self.model_string,
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": formatted_content},
+            ],
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+        )
+
+        response_text = response.choices[0].message.content
+        return response_text
+
 class AzureChatOpenAI(ChatOpenAI):
     def __init__(
         self,

From 569fca65b1b60d94c97062cb1df1f0f9bc0aa84c Mon Sep 17 00:00:00 2001
From: vinid
Date: Sun, 14 Jul 2024 18:10:38 -0400
Subject: [PATCH 2/2] just redeclare the key

---
 textgrad/engine/base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/textgrad/engine/base.py b/textgrad/engine/base.py
index bcd7e92..a25962c 100644
--- a/textgrad/engine/base.py
+++ b/textgrad/engine/base.py
@@ -85,9 +85,8 @@ def single_prompt_generate(self, prompt: str, system_prompt: str=None, **kwargs)
     def multimodal_generate(self, content: List[Union[str, bytes]], system_prompt: str = None, **kwargs):
 
         sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
+        key = "".join([str(k) for k in content])
         if self.do_cache:
self.do_cache: - key = "".join([str(k) for k in content]) - cache_key = sys_prompt_arg + key cache_or_none = self._check_cache(cache_key) if cache_or_none is not None: @@ -96,6 +95,7 @@ def multimodal_generate(self, content: List[Union[str, bytes]], system_prompt: s response = self._generate_from_multiple_input(content, system_prompt=sys_prompt_arg, **kwargs) if self.do_cache: + cache_key = sys_prompt_arg + key self._save_cache(cache_key, response) return response