Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for temporary IAM credentials in chat_bedrock() #266

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
disambiguate it from `chat_snowflake()` (which *also* uses "Cortex") (#275,
@atheriel).

* `chat_bedrock()` now handles temporary IAM credentials better (#261,
@atheriel).

# ellmer 0.1.0

* New `chat_vllm()` to chat with models served by vLLM (#140).
Expand Down
55 changes: 41 additions & 14 deletions R/provider-bedrock.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ chat_bedrock <- function(system_prompt = NULL,
echo = NULL) {

check_installed("paws.common", "AWS authentication")
credentials <- paws_credentials(profile)
cache <- aws_creds_cache(profile)
credentials <- paws_credentials(profile, cache = cache)

turns <- normalize_turns(turns, system_prompt)
model <- set_default(model, "anthropic.claude-3-5-sonnet-20240620-v1:0")
Expand All @@ -45,7 +46,8 @@ chat_bedrock <- function(system_prompt = NULL,
base_url = "",
model = model,
profile = profile,
credentials = credentials
region = credentials$region,
cache = cache
)

Chat$new(provider = provider, turns = turns, echo = echo)
Expand All @@ -57,7 +59,8 @@ ProviderBedrock <- new_class(
properties = list(
model = prop_string(),
profile = prop_string(allow_null = TRUE),
credentials = class_list
region = prop_string(),
cache = class_list
)
)

Expand All @@ -69,19 +72,20 @@ method(chat_request, ProviderBedrock) <- function(provider,
extra_args = list()) {

req <- request(paste0(
"https://bedrock-runtime.", provider@credentials$region, ".amazonaws.com"
"https://bedrock-runtime.", provider@region, ".amazonaws.com"
))
req <- req_url_path_append(
req,
"model",
provider@model,
if (stream) "converse-stream" else "converse"
)
creds <- paws_credentials(provider@profile, provider@cache)
req <- req_auth_aws_v4(
req,
aws_access_key_id = provider@credentials$access_key_id,
aws_secret_access_key = provider@credentials$secret_access_key,
aws_session_token = provider@credentials$session_token
aws_access_key_id = creds$access_key_id,
aws_secret_access_key = creds$secret_access_key,
aws_session_token = creds$session_token
)

req <- req_error(req, body = function(resp) {
Expand Down Expand Up @@ -297,15 +301,38 @@ method(as_json, list(ProviderBedrock, ToolDef)) <- function(provider, x) {

# Helpers ----------------------------------------------------------------

paws_credentials <- function(profile) {
if (is_testing()) {
tryCatch(
paws.common::locate_credentials(profile),
paws_credentials <- function(profile, cache = aws_creds_cache(profile),
reauth = FALSE) {
creds <- cache$get()
if (reauth || is.null(creds) || creds$expiration < Sys.time()) {
cache$clear()
try_fetch(
creds <- locate_aws_credentials(profile),
error = function(cnd) {
testthat::skip("Failed to locate AWS credentails")
if (is_testing()) {
testthat::skip("Failed to locate AWS credentails")
}
cli::cli_abort("No IAM credentials found.", parent = cnd)
}
)
} else {
paws.common::locate_credentials(profile)
cache$set(creds)
}
creds
}

# Wrapper for paws.common::locate_credentials() so we can mock it in tests.
locate_aws_credentials <- function(profile) {
paws.common::locate_credentials(profile)
}

# In-memory cache for AWS credentials. Analogous to httr2:::cache_mem().
aws_creds_cache <- function(profile) {
key <- hash(profile)
list(
get = function() env_get(the$aws_credentials_cache, key, default = NULL),
set = function(creds) env_poke(the$aws_credentials_cache, key, creds),
clear = function() env_unbind(the$aws_credentials_cache, key)
)
}

the$aws_credentials_cache <- new_environment()
45 changes: 45 additions & 0 deletions tests/testthat/test-provider-bedrock.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,48 @@ test_that("can use images", {
test_images_inline(chat_fun)
test_images_remote_error(chat_fun)
})

# Auth --------------------------------------------------------------------

test_that("AWS credential caching works as expected", {
# Mock AWS credentials for different profiles.
local_mocked_bindings(
locate_aws_credentials = function(profile) {
if (!is.null(profile) && profile == "test") {
list(
access_key = "key1",
secret_key = "secret1",
expiration = Sys.time() + 3600
)
} else {
list(
access_key = "key2",
secret_key = "secret2",
expiration = Sys.time() + 3600
)
}
}
)

creds1 <- paws_credentials(profile = "test")
creds2 <- paws_credentials(profile = NULL)

# Verify different credentials were returned.
expect_false(identical(creds1, creds2))
expect_equal(creds1$access_key, "key1")
expect_equal(creds2$access_key, "key2")

# Verify cached credentials match original ones.
expect_identical(creds1, paws_credentials(profile = "test"))
expect_identical(creds2, paws_credentials(profile = NULL))

# Simulate a cache entry that has expired.
creds_modified <- creds1
creds_modified$expiration <- Sys.time() - 5
aws_creds_cache(profile = "test")$set(creds_modified)

# Ensure the new credentials have been updated.
expect_false(identical(creds_modified, paws_credentials(profile = "test")))
expect_false(identical(creds1, paws_credentials(profile = "test")))
expect_false(identical(creds2, paws_credentials(profile = "test")))
})
Loading