From 7ff08431093aa2de5fae6d258c620b5855716f5e Mon Sep 17 00:00:00 2001 From: Janito Vaqueiro Ferreira Filho Date: Thu, 6 Feb 2025 14:06:57 -0300 Subject: [PATCH] Implement service queries for chat completions (#1) * Add a `chat_log` field to the state Prepare to keep a chat transcript on chain. * Add an operation to log a chat interaction Update the on-chain chat transcript with the new interaction. * Test logging some chat interactions Ensure that they are stored in the state on chain. * Change the service to handle GraphQL requests For now no requests are actually handled. * Add a `runtime` field to the `ApplicationService` Prepare to perform HTTP queries. * Add a mutation to perform a chat interaction Creates an operation to log a chat interaction, but for now the response is empty. * Implement querying of chat completion API Send an HTTP request to the Atoma network to retrieve a chat completion. * Create a `ChatInteractionResponse` type A helper type to parse the `ChatCompletionResponse` and later build a `ChatInteraction`. * Fetch chat completion to log to chat transcript Include a real response from the Atoma network in the operation to log a chat interaction. * Test if `chat` mutation performs HTTP request Ensure that a `chat` mutation leads to the service sending an HTTP request to the Atoma proxy, and returning an operation to log the chat interaction. * Downgrade Rust version to workaround Wasm issue Use a version known to be able to run Wasm contracts. * Add an integration test for the API query Ensure that the API queries from the service work correctly. --- Cargo.lock | 149 ++++++++++++++++++++++++++---- Cargo.toml | 13 +++ rust-toolchain.toml | 2 +- src/contract.rs | 18 +++- src/contract_unit_tests.rs | 36 ++++++++ src/lib.rs | 24 ++++- src/service.rs | 184 +++++++++++++++++++++++++++++++++++-- src/service_unit_tests.rs | 84 +++++++++++++++++ src/state.rs | 6 +- tests/chat_transcript.rs | 63 +++++++++++++ 10 files changed, 544 insertions(+), 35 deletions(-) create mode 100644 src/contract_unit_tests.rs create mode 100644 src/service_unit_tests.rs create mode 100644 tests/chat_transcript.rs diff --git a/Cargo.lock b/Cargo.lock index aaa8561..d8e7775 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -263,7 +263,14 @@ version = "0.1.0" dependencies = [ "async-graphql", "async-graphql-derive", + "atoma-demo", "linera-sdk", + "proptest", + "serde", + "serde_json", + "test-log", + "test-strategy 0.4.0", + "tokio", ] [[package]] @@ -368,6 +375,21 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + [[package]] name = "bitflags" version = "1.3.2" @@ -830,7 +852,7 @@ dependencies = [ "cranelift-codegen 0.112.3", "cranelift-entity 0.112.3", "cranelift-frontend 0.112.3", - "itertools 0.12.1", + "itertools", "log", "smallvec", "wasmparser 0.217.0", @@ -1182,6 +1204,27 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "env_filter" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" +dependencies = [ + "log", +] + +[[package]] +name = "env_logger" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "log", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -2000,15 +2043,6 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" -[[package]] -name = "itertools" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.12.1" @@ -2107,7 +2141,7 @@ dependencies = [ "serde_bytes", "serde_json", "sha3", - "test-strategy", + "test-strategy 0.3.1", "thiserror 1.0.69", "tokio", "tokio-stream", @@ -2168,7 +2202,7 @@ dependencies = [ "serde", "serde_json", "test-log", - "test-strategy", + "test-strategy 0.3.1", "thiserror 1.0.69", "tokio", "tokio-stream", @@ -3030,11 +3064,17 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14cae93065090804185d3b75f0bf93b8eeda30c7a9b4a33d3bdb3988d6229e50" dependencies = [ + "bit-set", + "bit-vec", "bitflags 2.8.0", + "lazy_static", "num-traits", "rand", "rand_chacha", "rand_xorshift", + "regex-syntax 0.8.5", + "rusty-fork", + "tempfile", "unarray", ] @@ -3054,8 +3094,8 @@ version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0f3e5beed80eb580c68e2c600937ac2c4eedabdfd5ef1e5b7ea4f3fba84497b" dependencies = [ - "heck 0.4.1", - "itertools 0.10.5", + "heck 0.5.0", + "itertools", "log", "multimap", "once_cell", @@ -3075,7 +3115,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools", "proc-macro2", "quote", "syn 2.0.96", @@ -3125,6 +3165,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.38" @@ -3468,6 +3514,18 @@ version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" +[[package]] +name = "rusty-fork" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb3dcc6e454c328bb824492db107ab7c0ae8fcffe4ad210136ef014458c1bc4f" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + [[package]] name = "ruzstd" version = "0.7.3" @@ -3767,7 +3825,19 @@ checksum = "78ad9e09554f0456d67a69c1584c9798ba733a5b50349a6c0d0948710523922d" dependencies = [ "proc-macro2", "quote", - "structmeta-derive", + "structmeta-derive 0.2.0", + "syn 2.0.96", +] + +[[package]] +name = "structmeta" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e1575d8d40908d70f6fd05537266b90ae71b15dbbe7a8b7dffa2b759306d329" +dependencies = [ + "proc-macro2", + "quote", + "structmeta-derive 0.3.0", "syn 2.0.96", ] @@ -3782,6 +3852,17 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "structmeta-derive" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "strum" version = "0.25.0" @@ -3928,6 +4009,7 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e7f46083d221181166e5b6f6b1e5f1d499f3a76888826e6cb1d057554157cd0f" dependencies = [ + "env_logger", "test-log-macros", "tracing-subscriber", ] @@ -3951,7 +4033,19 @@ checksum = "b8361c808554228ad09bfed70f5c823caf8a3450b6881cc3a38eb57e8c08c1d9" dependencies = [ "proc-macro2", "quote", - "structmeta", + "structmeta 0.2.0", + "syn 2.0.96", +] + +[[package]] +name = "test-strategy" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf41af45e3f54cc184831d629d41d5b2bda8297e29c81add7ae4f362ed5e01b" +dependencies = [ + "proc-macro2", + "quote", + "structmeta 0.3.0", "syn 2.0.96", ] @@ -4261,6 +4355,17 @@ dependencies = [ "valuable", ] +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + [[package]] name = "tracing-serde" version = "0.2.0" @@ -4287,6 +4392,7 @@ dependencies = [ "thread_local", "tracing", "tracing-core", + "tracing-log", "tracing-serde", ] @@ -4406,6 +4512,15 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "wait-timeout" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6" +dependencies = [ + "libc", +] + [[package]] name = "want" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index db107b5..d52ebee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,14 +3,27 @@ name = "atoma-demo" version = "0.1.0" edition = "2021" +[features] +test = ["proptest", "test-strategy"] + [dependencies] async-graphql = { version = "=7.0.2", default-features = false } async-graphql-derive = { version = "=7.0.2", default-features = false } linera-sdk = { git = "https://github.com/jvff/linera-protocol", rev = "26a5299" } +proptest = { version = "1.6.0", optional = true } +serde = { version = "1.0.217", features = ["derive"] } +serde_json = "1.0.137" +test-strategy = { version = "0.4.0", optional = true } [dev-dependencies] +atoma-demo = { path = ".", features = ["test"] } linera-sdk = { git = "https://github.com/jvff/linera-protocol", rev = "26a5299", features = ["test"] } +[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] +linera-sdk = { git = "https://github.com/jvff/linera-protocol", rev = "26a5299", features = ["test", "wasmer", "unstable-oracles"] } +tokio = "1.39.3" +test-log = "*" + [[bin]] name = "atoma_demo_contract" path = "src/contract.rs" diff --git a/rust-toolchain.toml b/rust-toolchain.toml index e1fe6da..2ddd8cb 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "1.84.0" +channel = "1.81.0" components = [ "clippy", "rustfmt", "rust-src" ] targets = [ "wasm32-unknown-unknown" ] profile = "minimal" diff --git a/src/contract.rs b/src/contract.rs index be5fa3a..4235275 100644 --- a/src/contract.rs +++ b/src/contract.rs @@ -4,7 +4,11 @@ #![cfg_attr(target_arch = "wasm32", no_main)] mod state; +#[cfg(test)] +#[path = "./contract_unit_tests.rs"] +mod tests; +use atoma_demo::{ChatInteraction, Operation}; use linera_sdk::{ base::WithContractAbi, views::{RootView, View}, @@ -38,7 +42,11 @@ impl Contract for ApplicationContract { async fn instantiate(&mut self, _argument: Self::InstantiationArgument) {} - async fn execute_operation(&mut self, _operation: Self::Operation) -> Self::Response {} + async fn execute_operation(&mut self, operation: Self::Operation) -> Self::Response { + let Operation::LogChatInteraction { interaction } = operation; + + self.log_chat_interaction(interaction); + } async fn execute_message(&mut self, _message: Self::Message) {} @@ -46,3 +54,11 @@ impl Contract for ApplicationContract { self.state.save().await.expect("Failed to save state"); } } + +impl ApplicationContract { + /// Handles an [`Operation::LogChatInteraction`] by adding a [`ChatInteraction`] to the chat + /// log. + fn log_chat_interaction(&mut self, interaction: ChatInteraction) { + self.state.chat_log.push(interaction); + } +} diff --git a/src/contract_unit_tests.rs b/src/contract_unit_tests.rs new file mode 100644 index 0000000..f8605a0 --- /dev/null +++ b/src/contract_unit_tests.rs @@ -0,0 +1,36 @@ +// Copyright (c) Zefchain Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use atoma_demo::{ChatInteraction, Operation}; +use linera_sdk::{util::BlockingWait, Contract, ContractRuntime}; +use test_strategy::proptest; + +use super::ApplicationContract; + +/// Tests if chat interactions are logged on chain. +#[proptest] +fn chat_interactions_are_logged_on_chain(interactions: Vec) { + let mut contract = setup_contract(); + + for interaction in interactions.clone() { + contract + .execute_operation(Operation::LogChatInteraction { interaction }) + .blocking_wait(); + } + + let logged_interactions = contract + .state + .chat_log + .read(..) + .blocking_wait() + .expect("Failed to read logged chat interactions from the state"); + + assert_eq!(logged_interactions, interactions); +} + +/// Creates a [`ApplicationContract`] instance to be tested. +fn setup_contract() -> ApplicationContract { + let runtime = ContractRuntime::new(); + + ApplicationContract::load(runtime).blocking_wait() +} diff --git a/src/lib.rs b/src/lib.rs index dd8c9da..4a1d52d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,15 +2,33 @@ // SPDX-License-Identifier: Apache-2.0 use linera_sdk::base::{ContractAbi, ServiceAbi}; +use serde::{Deserialize, Serialize}; pub struct ApplicationAbi; impl ContractAbi for ApplicationAbi { - type Operation = (); + type Operation = Operation; type Response = (); } impl ServiceAbi for ApplicationAbi { - type Query = (); - type QueryResponse = (); + type Query = async_graphql::Request; + type QueryResponse = async_graphql::Response; +} + +/// Operations that the contract can execute. +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +pub enum Operation { + /// Log an interaction with the AI. + LogChatInteraction { interaction: ChatInteraction }, +} + +/// A single interaction with the AI chat. +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize, async_graphql::SimpleObject)] +#[cfg_attr(feature = "test", derive(test_strategy::Arbitrary))] +pub struct ChatInteraction { + #[cfg_attr(feature = "test", strategy("[A-Za-z0-9., ]*"))] + pub prompt: String, + #[cfg_attr(feature = "test", strategy("[A-Za-z0-9., ]*"))] + pub response: String, } diff --git a/src/service.rs b/src/service.rs index e0c7fec..ea779ff 100644 --- a/src/service.rs +++ b/src/service.rs @@ -4,13 +4,20 @@ #![cfg_attr(target_arch = "wasm32", no_main)] mod state; +#[cfg(test)] +#[path = "./service_unit_tests.rs"] +mod tests; -use self::state::Application; -use linera_sdk::{base::WithServiceAbi, views::View, Service, ServiceRuntime}; +use std::sync::{Arc, Mutex}; +use async_graphql::{connection::EmptyFields, EmptySubscription, Schema}; +use atoma_demo::{ChatInteraction, Operation}; +use linera_sdk::{base::WithServiceAbi, bcs, ensure, http, Service, ServiceRuntime}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone)] pub struct ApplicationService { - state: Application, - runtime: ServiceRuntime, + runtime: Arc>>, } linera_sdk::service!(ApplicationService); @@ -23,13 +30,170 @@ impl Service for ApplicationService { type Parameters = (); async fn new(runtime: ServiceRuntime) -> Self { - let state = Application::load(runtime.root_view_storage_context()) - .await - .expect("Failed to load state"); - ApplicationService { state, runtime } + ApplicationService { + runtime: Arc::new(Mutex::new(runtime)), + } + } + + async fn handle_query(&self, query: Self::Query) -> Self::QueryResponse { + Schema::build( + EmptyFields, + Mutation { + runtime: self.runtime.clone(), + }, + EmptySubscription, + ) + .finish() + .execute(query) + .await + } +} + +/// Root type that defines all the GraphQL mutations available from the service. +pub struct Mutation { + runtime: Arc>>, +} + +#[async_graphql::Object] +impl Mutation { + /// Executes a chat completion using the Atoma Network. + async fn chat( + &self, + api_token: String, + message: ChatMessage, + model: Option, + max_tokens: Option, + atoma_proxy_url: Option, + ) -> async_graphql::Result> { + let request = ChatCompletionRequest { + stream: false, + messages: &[&message], + model: model.unwrap_or_else(|| "meta-llama/Llama-3.3-70B-Instruct".to_owned()), + max_tokens: max_tokens.unwrap_or(128), + }; + + let response = self.query_chat_completion( + atoma_proxy_url.as_deref().unwrap_or(ATOMA_CLOUD_URL), + &api_token, + &request, + )?; + + let interaction = ChatInteractionResponse::parse_from_completion_response(response)? + .with_prompt(message.content); + + Ok( + bcs::to_bytes(&Operation::LogChatInteraction { interaction }) + .expect("`LogChatInteraction` should be serializable"), + ) + } +} + +/// A message to be sent to the AI chat. +#[derive(Clone, Debug, Deserialize, Serialize, async_graphql::InputObject)] +pub struct ChatMessage { + content: String, + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, +} + +impl Mutation { + /// Queries the Atoma network for a chat completion. + fn query_chat_completion( + &self, + base_url: &str, + api_token: &str, + request: &ChatCompletionRequest, + ) -> async_graphql::Result { + let mut runtime = self + .runtime + .lock() + .expect("Locking should never fail because service runs in a single thread"); + + let body = serde_json::to_vec(request)?; + + let response = runtime.http_request( + http::Request::post(format!("{base_url}/v1/chat/completions"), body) + .with_header("Content-Type", b"application/json") + .with_header("Authorization", format!("Bearer {api_token}").as_bytes()), + ); + + ensure!( + response.status == 200, + async_graphql::Error::new(format!( + "Failed to perform chat completion API query. Status code: {}", + response.status + )) + ); + + serde_json::from_slice::(&response.body).map_err(|error| { + async_graphql::Error::new(format!( + "Failed to deserialize chat completion response: {error}\n{:?}", + String::from_utf8_lossy(&response.body), + )) + }) } +} - async fn handle_query(&self, _query: Self::Query) -> Self::QueryResponse { - panic!("Queries not supported by application"); +/// The POST body to be sent to the chat completion API. +#[derive(Clone, Debug, Serialize)] +pub struct ChatCompletionRequest<'message> { + stream: bool, + messages: &'message [&'message ChatMessage], + model: String, + max_tokens: usize, +} + +/// The response received from the chat completion API. +#[derive(Clone, Debug, Deserialize)] +pub struct ChatCompletionResponse { + choices: Vec, +} + +/// A choice received in the response from a chat completion API. +#[derive(Clone, Debug, Deserialize)] +pub struct ChatCompletionChoice { + message: ChatMessage, +} + +/// Only the response for a [`ChatInteraction`]. +#[derive(Clone, Debug)] +pub struct ChatInteractionResponse { + response: String, +} + +impl ChatInteractionResponse { + /// Parses the first choice from a [`ChatCompletionResponse`] to extract the + /// [`ChatInteractionResponse`]. + pub fn parse_from_completion_response( + response: ChatCompletionResponse, + ) -> async_graphql::Result { + ensure!( + !response.choices.is_empty(), + async_graphql::Error::new( + "Chat completion response has an empty `choices` list".to_owned() + ) + ); + + let first_choice = response + .choices + .into_iter() + .next() + .expect("Response should have at least one choice element"); + + Ok(ChatInteractionResponse { + response: first_choice.message.content, + }) + } + + /// Builds a [`ChatInteraction`] using this response and the provided `prompt`. + pub fn with_prompt(self, prompt: String) -> ChatInteraction { + ChatInteraction { + prompt, + response: self.response, + } } } + +/// The base URL to access the Atoma Cloud proxy. +const ATOMA_CLOUD_URL: &str = "https://api.atoma.network"; diff --git a/src/service_unit_tests.rs b/src/service_unit_tests.rs new file mode 100644 index 0000000..26f2e66 --- /dev/null +++ b/src/service_unit_tests.rs @@ -0,0 +1,84 @@ +// Copyright (c) Zefchain Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use atoma_demo::{ChatInteraction, Operation}; +use linera_sdk::{bcs, http, util::BlockingWait, Service, ServiceRuntime}; +use serde_json::json; +use test_strategy::proptest; + +use super::{ApplicationService, ATOMA_CLOUD_URL}; + +/// Tests if `chat` mutations perform an HTTP request to the Atoma proxy, and generates the +/// operation to log a chat interaction. +#[proptest] +fn performs_http_query( + #[strategy("[A-Za-z0-9%=]*")] api_token: String, + interaction: ChatInteraction, +) { + let service = setup_service(); + + let prompt = &interaction.prompt; + let request = async_graphql::Request::new(format!( + "mutation {{ \ + chat(\ + apiToken: \"{api_token}\", \ + message: {{ \ + content: {prompt:?}, \ + role: \"user\" + }}\ + ) \ + }}" + )); + + let expected_body = format!( + "{{\ + \"stream\":false,\ + \"messages\":[\ + {{\"content\":{prompt:?},\"role\":\"user\"}}\ + ],\ + \"model\":\"meta-llama/Llama-3.3-70B-Instruct\",\ + \"max_tokens\":128\ + }}" + ); + let mock_response = format!( + "{{ \ + \"choices\": [\ + {{ + \"message\": {{\ + \"content\": {:?}, + \"role\": \"\" + }}\ + }}\ + ] \ + }}", + interaction.response + ); + + service.runtime.lock().unwrap().add_expected_http_request( + http::Request::post( + format!("{ATOMA_CLOUD_URL}/v1/chat/completions"), + expected_body, + ) + .with_header("Content-Type", b"application/json") + .with_header("Authorization", format!("Bearer {api_token}").as_bytes()), + http::Response::ok(mock_response), + ); + + let response = service.handle_query(request).blocking_wait(); + + let expected_operation = Operation::LogChatInteraction { interaction }; + let expected_bytes = + bcs::to_bytes(&expected_operation).expect("`Operation` should be serializable"); + let expected_response = async_graphql::Response::new( + async_graphql::Value::from_json(json!({"chat": expected_bytes})).unwrap(), + ); + + assert_eq!(response, expected_response); +} + +/// Creates a [`ApplicationService`] instance to be tested. +fn setup_service() -> ApplicationService { + let runtime = ServiceRuntime::new(); + + ApplicationService::new(runtime).blocking_wait() +} diff --git a/src/state.rs b/src/state.rs index 902dc26..4aae513 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,11 +1,11 @@ // Copyright (c) Zefchain Labs, Inc. // SPDX-License-Identifier: Apache-2.0 -use linera_sdk::views::{linera_views, RegisterView, RootView, ViewStorageContext}; +use atoma_demo::ChatInteraction; +use linera_sdk::views::{linera_views, LogView, RootView, ViewStorageContext}; #[derive(RootView, async_graphql::SimpleObject)] #[view(context = "ViewStorageContext")] pub struct Application { - pub value: RegisterView, - // Add fields here. + pub chat_log: LogView, } diff --git a/tests/chat_transcript.rs b/tests/chat_transcript.rs new file mode 100644 index 0000000..95425be --- /dev/null +++ b/tests/chat_transcript.rs @@ -0,0 +1,63 @@ +// Copyright (c) Zefchain Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 + +#![cfg(not(target_arch = "wasm32"))] + +use std::env; + +use atoma_demo::{ApplicationAbi, ChatInteraction, Operation}; +use linera_sdk::{bcs, test::TestValidator}; + +/// Tests if the service queries the Atoma network when handling a `chat` mutation. +#[test_log::test(tokio::test)] +async fn service_queries_atoma() { + let (_validator, application_id, chain) = + TestValidator::with_current_application::((), ()).await; + + let api_token = env::var("ATOMA_API_TOKEN") + .expect("Missing ATOMA_API_TOKEN environment variable to run integration test"); + + let query = format!( + "mutation {{ \ + chat(\ + apiToken: \"{api_token}\", \ + message: {{ + content: \"What was the capital of Brazil in 1940\", + role: \"user\" + }}\ + ) \ + }}" + ); + + let response = chain.graphql_query(application_id, query).await; + + let response_object = response + .as_object() + .expect("Unexpected response from service"); + + let operation_list = response_object["chat"] + .as_array() + .expect("Unexpected operation representation returned from service"); + + let operation_bytes = operation_list + .iter() + .map(|value| { + let byte_integer = value + .as_u64() + .expect("Invalid byte type in serialized operation"); + + byte_integer + .try_into() + .expect("Invalid byte value in serialized operation") + }) + .collect::>(); + + let operation = + bcs::from_bytes::(&operation_bytes).expect("Failed to deserialize operation"); + + let Operation::LogChatInteraction { + interaction: ChatInteraction { response, .. }, + } = operation; + + assert!(response.contains("Rio de Janeiro")); +}