Skip to content

Commit

Permalink
Merge pull request #12 from tylerjthomas9/dev-tyler
Browse files Browse the repository at this point in the history
https_kwargs -> http_kwargs, fix bug with http_kwargs
  • Loading branch information
tylerjthomas9 authored Mar 26, 2024
2 parents 4291897 + df65cb6 commit 0749448
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 35 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/Downgrade.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: Downgrade
on:
pull_request:
branches:
- main
paths-ignore:
- 'docs/**'
push:
branches:
- main
paths-ignore:
- 'docs/**'
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
version: ['1']
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
- uses: cjdoris/julia-downgrade-compat-action@v1
with:
skip: Pkg,TOML
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GoogleGenAI"
uuid = "903d41d1-eaca-47dd-943b-fee3930375ab"
authors = ["Tyler Thomas <[email protected]>"]
version = "0.2.0"
version = "0.3.0"

[deps]
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Expand All @@ -10,12 +10,12 @@ HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"

[compat]
Aqua = "0.8"
Base64 = "1"
Dates = "1"
Aqua = "0.8.4"
Base64 = "<0.0.1, 1"
Dates = "<0.0.1, 1"
HTTP = "1"
JSON3 = "1"
Test = "1"
Test = "<0.0.1, 1"
julia = "1"

[extras]
Expand Down
56 changes: 28 additions & 28 deletions src/GoogleGenAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ end

#TODO: Should we use different function names?
"""
generate_content(provider::AbstractGoogleProvider, model_name::String, prompt::String, image_path::String; api_kwargs=NamedTuple(), https_kwargs=NamedTuple()) -> NamedTuple
generate_content(api_key::String, model_name::String, prompt::String, image_path::String; api_kwargs=NamedTuple(), https_kwargs=NamedTuple()) -> NamedTuple
generate_content(provider::AbstractGoogleProvider, model_name::String, prompt::String, image_path::String; api_kwargs=NamedTuple(), http_kwargs=NamedTuple()) -> NamedTuple
generate_content(api_key::String, model_name::String, prompt::String, image_path::String; api_kwargs=NamedTuple(), http_kwargs=NamedTuple()) -> NamedTuple
generate_content(provider::AbstractGoogleProvider, model_name::String, conversation::Vector{Dict{Symbol,Any}}; api_kwargs=NamedTuple(), https_kwargs=NamedTuple()) -> NamedTuple
generate_content(api_key::String, model_name::String, conversation::Vector{Dict{Symbol,Any}}; api_kwargs=NamedTuple(), https_kwargs=NamedTuple()) -> NamedTuple
generate_content(provider::AbstractGoogleProvider, model_name::String, conversation::Vector{Dict{Symbol,Any}}; api_kwargs=NamedTuple(), http_kwargs=NamedTuple()) -> NamedTuple
generate_content(api_key::String, model_name::String, conversation::Vector{Dict{Symbol,Any}}; api_kwargs=NamedTuple(), http_kwargs=NamedTuple()) -> NamedTuple
Generate content based on a combination of text prompt and an image (optional).
Expand Down Expand Up @@ -122,7 +122,7 @@ function generate_content(
model_name::String,
prompt::String;
api_kwargs=NamedTuple(),
https_kwargs=NamedTuple(),
http_kwargs=NamedTuple(),
)
endpoint = "models/$model_name:generateContent"

Expand All @@ -141,18 +141,18 @@ function generate_content(
"safetySettings" => safety_settings,
)

response = _request(provider, endpoint, :POST, body; https_kwargs...)
response = _request(provider, endpoint, :POST, body; http_kwargs...)
return _parse_response(response)
end
function generate_content(
api_key::String,
model_name::String,
prompt::String;
api_kwargs=NamedTuple(),
https_kwargs=NamedTuple(),
http_kwargs=NamedTuple(),
)
return generate_content(
GoogleProvider(; api_key), model_name, prompt; api_kwargs, https_kwargs
GoogleProvider(; api_key), model_name, prompt; api_kwargs, http_kwargs
)
end

Expand All @@ -162,7 +162,7 @@ function generate_content(
prompt::String,
image_path::String;
api_kwargs=NamedTuple(),
https_kwargs=NamedTuple(),
http_kwargs=NamedTuple(),
)
image_data = open(base64encode, image_path)
safety_settings = get(api_kwargs, :safety_settings, nothing)
Expand All @@ -189,7 +189,7 @@ function generate_content(
)

response = _request(
provider, "models/$model_name:generateContent", :POST, body; https_kwargs...
provider, "models/$model_name:generateContent", :POST, body; http_kwargs...
)
return _parse_response(response)
end
Expand All @@ -199,10 +199,10 @@ function generate_content(
prompt::String,
image_path::String;
api_kwargs=NamedTuple(),
https_kwargs=NamedTuple(),
http_kwargs=NamedTuple(),
)
return generate_content(
GoogleProvider(; api_key), model_name, prompt, image_path; api_kwargs, https_kwargs
GoogleProvider(; api_key), model_name, prompt, image_path; api_kwargs, http_kwargs
)
end

Expand All @@ -211,7 +211,7 @@ function generate_content(
model_name::String,
conversation::Vector{Dict{Symbol,Any}};
api_kwargs=NamedTuple(),
https_kwargs=NamedTuple(),
http_kwargs=NamedTuple(),
)
endpoint = "models/$model_name:generateContent"

Expand All @@ -236,18 +236,18 @@ function generate_content(
"safetySettings" => safety_settings,
)

response = _request(provider, endpoint, :POST, body; https_kwargs)
response = _request(provider, endpoint, :POST, body; http_kwargs...)
return _parse_response(response)
end
function generate_content(
api_key::String,
model_name::String,
conversation::Vector{Dict{Symbol,Any}};
api_kwargs=NamedTuple(),
https_kwargs=NamedTuple(),
http_kwargs=NamedTuple(),
)
return generate_content(
GoogleProvider(; api_key), model_name, conversation; api_kwargs, https_kwargs
GoogleProvider(; api_key), model_name, conversation; api_kwargs, http_kwargs
)
end

Expand Down Expand Up @@ -278,10 +278,10 @@ function count_tokens(api_key::String, model_name::String, prompt::String)
end

"""
embed_content(provider::AbstractGoogleProvider, model_name::String, prompt::String https_kwargs=NamedTuple()) -> NamedTuple
embed_content(api_key::String, model_name::String, prompt::String https_kwargs=NamedTuple()) -> NamedTuple
embed_content(provider::AbstractGoogleProvider, model_name::String, prompts::Vector{String} https_kwargs=NamedTuple()) -> Vector{NamedTuple}
embed_content(api_key::String, model_name::String, prompts::Vector{String}, https_kwargs=NamedTuple()) -> Vector{NamedTuple}
embed_content(provider::AbstractGoogleProvider, model_name::String, prompt::String http_kwargs=NamedTuple()) -> NamedTuple
embed_content(api_key::String, model_name::String, prompt::String http_kwargs=NamedTuple()) -> NamedTuple
embed_content(provider::AbstractGoogleProvider, model_name::String, prompts::Vector{String} http_kwargs=NamedTuple()) -> Vector{NamedTuple}
embed_content(api_key::String, model_name::String, prompts::Vector{String}, http_kwargs=NamedTuple()) -> Vector{NamedTuple}
Generate an embedding for the given prompt text using the specified model.
Expand All @@ -303,30 +303,30 @@ function embed_content(
provider::AbstractGoogleProvider,
model_name::String,
prompt::String;
https_kwargs=NamedTuple(),
http_kwargs=NamedTuple(),
)
endpoint = "models/$model_name:embedContent"
body = Dict(
"model" => "models/$model_name",
"content" => Dict("parts" => [Dict("text" => prompt)]),
)
response = _request(provider, endpoint, :POST, body; https_kwargs...)
response = _request(provider, endpoint, :POST, body; http_kwargs...)
embedding_values = get(
get(JSON3.read(response.body), "embedding", Dict()), "values", Vector{Float64}()
)
return (values=embedding_values, response_status=response.status)
end
function embed_content(
api_key::String, model_name::String, prompt::String, https_kwargs=NamedTuple()
api_key::String, model_name::String, prompt::String, http_kwargs=NamedTuple()
)
return embed_content(GoogleProvider(; api_key), model_name, prompt; https_kwargs...)
return embed_content(GoogleProvider(; api_key), model_name, prompt; http_kwargs...)
end

function embed_content(
provider::AbstractGoogleProvider,
model_name::String,
prompts::Vector{String},
https_kwargs=NamedTuple(),
http_kwargs=NamedTuple(),
)
endpoint = "models/$model_name:batchEmbedContents"
body = Dict(
Expand All @@ -337,17 +337,17 @@ function embed_content(
) for prompt in prompts
],
)
response = _request(provider, endpoint, :POST, body; https_kwargs...)
response = _request(provider, endpoint, :POST, body; http_kwargs...)
embedding_values = [
get(embedding, "values", Vector{Float64}()) for
embedding in JSON3.read(response.body)["embeddings"]
]
return (values=embedding_values, response_status=response.status)
end
function embed_content(
api_key::String, model_name::String, prompts::Vector{String}, https_kwargs=NamedTuple()
api_key::String, model_name::String, prompts::Vector{String}, http_kwargs=NamedTuple()
)
return embed_content(GoogleProvider(; api_key), model_name, prompts; https_kwargs...)
return embed_content(GoogleProvider(; api_key), model_name, prompts; http_kwargs...)
end

"""
Expand Down
10 changes: 8 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ if haskey(ENV, "GOOGLE_API_KEY")

@testset "GoogleGenAI.jl" begin
api_kwargs = (max_output_tokens=50,)
http_kwargs = (retries=2,)
# Generate text from text
response = generate_content(secret_key, "gemini-pro", "Hello"; api_kwargs)
response = generate_content(
secret_key, "gemini-pro", "Hello"; api_kwargs, http_kwargs
)

# Generate text from text+image
response = generate_content(
Expand All @@ -17,11 +20,14 @@ if haskey(ENV, "GOOGLE_API_KEY")
"What is this picture?",
"example.jpg";
api_kwargs,
http_kwargs,
)

# Multi-turn conversation
conversation = [Dict(:role => "user", :parts => [Dict(:text => "Hello")])]
response = generate_content(secret_key, "gemini-pro", conversation; api_kwargs)
response = generate_content(
secret_key, "gemini-pro", conversation; api_kwargs, http_kwargs
)

n_tokens = count_tokens(secret_key, "gemini-pro", "Hello")
@test n_tokens == 1
Expand Down

2 comments on commit 0749448

@tylerjthomas9
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Rename https_kwargs to http_kwargs
  • Fix bug with passing http_kwargs to _request in generate_content

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/103681

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.0 -m "<description of version>" 07494484e5b693bb0f242032bcdf682652476049
git push origin v0.3.0

Please sign in to comment.