feat: DIA-1716: Google AI Studio (Gemini) support in Prompts (#307)
Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: hakan458 <[email protected]>
Co-authored-by: matt-bernstein <[email protected]>
Co-authored-by: Hakan Erol <[email protected]>
Co-authored-by: niklub <[email protected]>
Co-authored-by: Matt Bernstein <[email protected]>
Co-authored-by: fern-api <115122769+fern-api[bot]@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Nikita Belonogov <[email protected]>
Co-authored-by: nik <[email protected]>
Co-authored-by: niklub <[email protected]>
Co-authored-by: nikitabelonogov <[email protected]>
Co-authored-by: pakelley <[email protected]>
Co-authored-by: Sergei Ivashchenko <[email protected]>
Co-authored-by: triklozoid <[email protected]>
15 people authored Jan 29, 2025
1 parent 0c35b76 commit eb0febf
Showing 4 changed files with 18 additions and 14 deletions.
6 changes: 4 additions & 2 deletions adala/environments/code_env.py
@@ -33,8 +33,10 @@ def execute_code(self, code: str, input_string) -> Dict:
         stdin = io.StringIO(input_string)
 
         try:
-            with redirect_stdin(stdin), redirect_stdout(stdout), redirect_stderr(
-                stderr
+            with (
+                redirect_stdin(stdin),
+                redirect_stdout(stdout),
+                redirect_stderr(stderr),
             ):
                 exec(code, {"__builtins__": __builtins__})
                 out["success"] = True
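The new form relies on Python 3.10+ parenthesized context-manager groups. A minimal standalone sketch of the same pattern (not from this repo; redirect_stdin in the diff is adala's own helper, so this sticks to the stdlib's redirect_stdout/redirect_stderr):

import io
from contextlib import redirect_stdout, redirect_stderr

# Capture output the same way the hunk above does, with the managers
# grouped in parentheses (valid syntax from Python 3.10 onward).
stdout, stderr = io.StringIO(), io.StringIO()
with (
    redirect_stdout(stdout),
    redirect_stderr(stderr),
):
    print("captured")  # lands in the StringIO buffer, not the terminal

assert stdout.getvalue() == "captured\n"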
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -42,7 +42,7 @@ celery = {version = "^5.3.6", extras = ["redis"]}
 kombu = ">=5.4.0rc2" # Pin version to fix https://github.com/celery/celery/issues/8030. TODO: remove when this fix will be included in celery
 uvicorn = "*"
 pydantic-settings = "^2.2.1"
-label-studio-sdk = {url = "https://github.com/HumanSignal/label-studio-sdk/archive/b3c9b27b1e2c0162f49bc3f6f6c18a543510bdcf.zip"}
+label-studio-sdk = {url = "https://github.com/HumanSignal/label-studio-sdk/archive/3b41cd9b1965c65b8ebaa37db6916a5006eecc23.zip"}
 kafka-python-ng = "^2.2.3"
 requests = "^2.32.0"
 # Using litellm from forked repo until vertex fix is released: https://github.com/BerriAI/litellm/issues/7904
16 changes: 9 additions & 7 deletions server/app.py
@@ -248,21 +248,22 @@ async def validate_connection(request: ValidateConnectionRequest):
     multi_model_provider_test_models = {
         "openai": "gpt-4o-mini",
         "vertexai": "vertex_ai/gemini-1.5-flash",
+        "gemini": "gemini/gemini-1.5-flash",
     }
     provider = request.provider.lower()
     messages = [{"role": "user", "content": "Hey, how's it going?"}]
 
     # For multi-model providers use a model that every account should have access to
     if provider in multi_model_provider_test_models.keys():
         model = multi_model_provider_test_models[provider]
-        if provider == "openai":
-            model_extra = {"api_key": request.api_key}
-        elif provider == "vertexai":
+        if provider == "vertexai":
             model_extra = {"vertex_credentials": request.vertex_credentials}
             if request.vertex_location:
                 model_extra["vertex_location"] = request.vertex_location
             if request.vertex_project:
                 model_extra["vertex_project"] = request.vertex_project
+        else:
+            model_extra = {"api_key": request.api_key}
         try:
             response = litellm.completion(
                 messages=messages,
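With "gemini" added to the test-model map, vertexai is now the only provider that needs special credentials in validate_connection; every other provider, Gemini included, falls through to the plain api_key branch. A hedged sketch of the equivalent check (the wrapper name and max_tokens cap are illustrative, not the server's exact code; the model name comes from the diff):

import litellm

def check_gemini_connection(api_key: str) -> bool:
    # Illustrative helper: send the same throwaway prompt the endpoint uses
    # and report whether the provider accepted the key.
    try:
        litellm.completion(
            model="gemini/gemini-1.5-flash",  # test model from the diff
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            api_key=api_key,  # Gemini takes the else branch above
            max_tokens=10,
        )
        return True
    except Exception:
        return False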
@@ -325,11 +326,12 @@ async def models_list(request: ModelsListRequest):
     # https://docs.litellm.ai/docs/set_keys#get_valid_models
     # https://github.com/BerriAI/litellm/blob/b9280528d368aced49cb4d287c57cd0b46168cb6/litellm/utils.py#L5705
     # Ultimately just uses litellm.models_by_provider - setting API key is not needed
-    lse_provider_to_litellm_provider = {"openai": "openai", "vertexai": "vertex_ai"}
+    lse_provider_to_litellm_provider = {"vertexai": "vertex_ai"}
     provider = request.provider.lower()
-    valid_models = litellm.models_by_provider[
-        lse_provider_to_litellm_provider[provider]
-    ]
+    litellm_provider = lse_provider_to_litellm_provider.get(provider, provider)
+    valid_models = litellm.models_by_provider[litellm_provider]
+    # some providers include the prefix in this list and others don't
+    valid_models = [model.replace(f"{litellm_provider}/", "") for model in valid_models]
 
     return Response[ModelsListResponse](
         data=ModelsListResponse(models_list=valid_models)
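This hunk drops the redundant "openai": "openai" entry and falls back to the request's provider name via .get(provider, provider), so "gemini" resolves directly to litellm's "gemini" provider; the new comprehension then strips the provider prefix that some providers include in their model names. A minimal sketch of that logic under the same assumptions (the wrapper function is illustrative):

import litellm

def list_models(lse_provider: str) -> list[str]:
    # Only vertexai's litellm name differs from the Label Studio name.
    mapping = {"vertexai": "vertex_ai"}
    litellm_provider = mapping.get(lse_provider.lower(), lse_provider.lower())
    models = litellm.models_by_provider[litellm_provider]
    # Some providers list models as "gemini/gemini-1.5-flash"; normalize them.
    return [m.replace(f"{litellm_provider}/", "") for m in models]

For example, list_models("gemini") would return litellm's known Gemini model names without the "gemini/" prefix.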
