From c4a95e391225a728237d45a3a173378f903cb485 Mon Sep 17 00:00:00 2001 From: SjoerdGn <38759766+SjoerdGn@users.noreply.github.com> Date: Tue, 29 Oct 2024 22:28:11 +0100 Subject: [PATCH] fix: OPENAI_API_KEY not needed for Azure OpenAI --- tests/test_basics.py | 15 +++++++++++++-- textgrad/engine/openai.py | 40 ++++++++++++++++++++++----------------- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/tests/test_basics.py b/tests/test_basics.py index 0f06b8c..bbac019 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -6,7 +6,7 @@ from textgrad import Variable, TextualGradientDescent, BlackboxLLM, sum from textgrad.engine.base import EngineLM -from textgrad.engine.openai import ChatOpenAI +from textgrad.engine.openai import AzureChatOpenAI, ChatOpenAI from textgrad.autograd import LLMCall, FormattedLLMCall logging.disable(logging.CRITICAL) @@ -247,4 +247,15 @@ def test_multimodal_from_url(): image_variable_2 = Variable(image_data, role_description="image to answer a question about", requires_grad=False) - assert image_variable_2.value == image_variable.value \ No newline at end of file + assert image_variable_2.value == image_variable.value + +def test_azure_openai_engine(): + if os.environ.get("OPENAI_API_KEY"): + os.environ.pop("OPENAI_API_KEY") + + with pytest.raises(ValueError): + engine = AzureChatOpenAI() + + os.environ['AZURE_OPENAI_API_KEY'] = "fake_key" + os.environ['AZURE_OPENAI_API_BASE'] = "fake_base" + engine = AzureChatOpenAI() diff --git a/textgrad/engine/openai.py b/textgrad/engine/openai.py index 723f04a..3247881 100644 --- a/textgrad/engine/openai.py +++ b/textgrad/engine/openai.py @@ -33,11 +33,13 @@ def __init__( system_prompt: str=DEFAULT_SYSTEM_PROMPT, is_multimodal: bool=False, base_url: str=None, + azure_openai: bool=False, **kwargs): """ :param model_string: :param system_prompt: :param base_url: Used to support Ollama + :param azure_openai: Set to True if you use Azure OpenAI. """ root = platformdirs.user_cache_dir("textgrad") cache_path = os.path.join(root, f"cache_openai_{model_string}.db") @@ -47,20 +49,21 @@ def __init__( self.system_prompt = system_prompt self.base_url = base_url - if not base_url: - if os.getenv("OPENAI_API_KEY") is None: - raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.") - - self.client = OpenAI( - api_key=os.getenv("OPENAI_API_KEY") - ) - elif base_url and base_url == OLLAMA_BASE_URL: - self.client = OpenAI( - base_url=base_url, - api_key="ollama" - ) - else: - raise ValueError("Invalid base URL provided. Please use the default OLLAMA base URL or None.") + if not azure_openai: + if not base_url: + if os.getenv("OPENAI_API_KEY") is None: + raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.") + + self.client = OpenAI( + api_key=os.getenv("OPENAI_API_KEY") + ) + elif base_url and base_url == OLLAMA_BASE_URL: + self.client = OpenAI( + base_url=base_url, + api_key="ollama" + ) + else: + raise ValueError("Invalid base URL provided. Please use the default OLLAMA base URL or None.") self.model_string = model_string self.is_multimodal = is_multimodal @@ -184,11 +187,14 @@ def __init__( root = platformdirs.user_cache_dir("textgrad") cache_path = os.path.join(root, f"cache_azure_{model_string}.db") # Changed cache path to differentiate from OpenAI cache - super().__init__(cache_path=cache_path, system_prompt=system_prompt, **kwargs) + super().__init__(cache_path=cache_path, + system_prompt=system_prompt, + azure_openai=True, + **kwargs) self.system_prompt = system_prompt api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2023-07-01-preview") - if os.getenv("AZURE_OPENAI_API_KEY") is None: + if (os.getenv("AZURE_OPENAI_API_KEY") is None) or (os.getenv("AZURE_OPENAI_API_BASE") is None): raise ValueError("Please set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_BASE, and AZURE_OPENAI_API_VERSION environment variables if you'd like to use Azure OpenAI models.") self.client = AzureOpenAI( @@ -197,4 +203,4 @@ def __init__( azure_endpoint=os.getenv("AZURE_OPENAI_API_BASE"), azure_deployment=model_string, ) - self.model_string = model_string + self.model_string = model_string \ No newline at end of file