Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: streamline OpenAI_Chat initialization and deprecate old parameters #734

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 38 additions & 88 deletions src/vanna/openai/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,120 +9,70 @@ class OpenAI_Chat(VannaBase):
def __init__(self, client=None, config=None):
    """Initialize the OpenAI chat backend.

    Args:
        client: An already-constructed OpenAI client. When provided, it is
            used as-is and all client-related config keys are ignored.
        config: Optional dict of settings. Recognized keys:
            - "temperature": sampling temperature (default 0.7)
            - "api_key": OpenAI API key (falls back to the SDK's own
              OPENAI_API_KEY env-var handling when omitted)
            - "base_url": alternate API endpoint

    Raises:
        ValueError: If a deprecated key ("api_type", "api_base",
            "api_version") is present in config; callers must pass a
            pre-configured OpenAI client instead.
    """
    VannaBase.__init__(self, config=config)

    # Normalize so the .get() calls below are safe when config is None.
    config = config or {}

    # Default parameters - can be overridden using config.
    self.temperature = config.get("temperature", 0.7)

    # Fail fast on parameters that used to configure the legacy SDK.
    for deprecated_param in ["api_type", "api_base", "api_version"]:
        if deprecated_param in config:
            raise ValueError(
                f"Passing {deprecated_param} is now deprecated. Please pass an OpenAI client instead."
            )

    # A caller-supplied client always wins over config-driven construction.
    if client is not None:
        self.client = client
        return

    # Build a client from config; api_key=None lets the OpenAI SDK fall
    # back to its environment-variable lookup (OPENAI_API_KEY).
    self.client = OpenAI(
        api_key=config.get("api_key"),
        base_url=config.get("base_url")
    )

def system_message(self, message: str) -> dict:
    """Wrap *message* as a system-role chat message dict."""
    return {"role": "system", "content": message}

def user_message(self, message: str) -> dict:
    """Wrap *message* as a user-role chat message dict."""
    return {"role": "user", "content": message}

def assistant_message(self, message: str) -> dict:
    """Wrap *message* as an assistant-role chat message dict."""
    return {"role": "assistant", "content": message}

def generate_response(self, prompt, num_tokens):
    """Send *prompt* to the chat-completions API and return the raw response.

    The model comes from config ("model" key, default "gpt-4o-mini");
    *num_tokens* is only used for the informational print.
    """
    chosen_model = self.config.get("model", "gpt-4o-mini")
    print(f"Using model {chosen_model} for {num_tokens} tokens (approx)")
    return self.client.chat.completions.create(
        model=chosen_model,
        messages=prompt,
        stop=None,
        temperature=self.temperature,
    )

def submit_prompt(self, prompt, **kwargs) -> str:
    """Submit a chat message log and return the model's text reply.

    Args:
        prompt: Non-empty list of message dicts, each with a "content" key.
        **kwargs: Accepted for backward compatibility; unused.

    Returns:
        The text of the first response choice that carries text, else the
        content of the first choice's message (which may be empty).

    Raises:
        ValueError: If prompt is None or empty (ValueError subclasses
            Exception, so callers catching the old Exception still work).
    """
    if prompt is None:
        raise ValueError("Prompt is None")

    if len(prompt) == 0:
        raise ValueError("Prompt is empty")

    # Approximate token count: ~4 characters per token.
    num_tokens = sum(len(message["content"]) / 4 for message in prompt)

    response = self.generate_response(prompt, num_tokens)

    # Find the first response from the chatbot that has text in it
    # (some responses may not have text).
    for choice in response.choices:
        if "text" in choice:
            return choice.text

    # If no response with text is found, return the first response's content.
    return response.choices[0].message.content