diff --git a/snowflake-api/Cargo.toml b/snowflake-api/Cargo.toml index 43ae88e..17db81d 100644 --- a/snowflake-api/Cargo.toml +++ b/snowflake-api/Cargo.toml @@ -19,7 +19,7 @@ cert-auth = ["dep:snowflake-jwt"] polars = ["dep:polars-core", "dep:polars-io"] [dependencies] -arrow = "50" +arrow = "51" async-trait = "0.1" base64 = "0.21" bytes = "1" @@ -52,7 +52,7 @@ tokio = { version = "1", features = ["macros", "rt-multi-thread"] } [dev-dependencies] anyhow = "1" -arrow = { version = "50", features = ["prettyprint"] } +arrow = { version = "51", features = ["prettyprint"] } clap = { version = "4", features = ["derive"] } pretty_env_logger = "0.5" tokio = { version = "1.35", features = ["macros", "rt-multi-thread"] } diff --git a/snowflake-api/examples/tracing/Cargo.toml b/snowflake-api/examples/tracing/Cargo.toml index 01a07ec..6e009c3 100644 --- a/snowflake-api/examples/tracing/Cargo.toml +++ b/snowflake-api/examples/tracing/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] anyhow = "1.0.79" -arrow = { version = "50.0.0", features = ["prettyprint"] } +arrow = { version = "51", features = ["prettyprint"] } dotenv = "0.15.0" snowflake-api = { path = "../../../snowflake-api" } diff --git a/snowflake-api/examples/tracing/src/main.rs b/snowflake-api/examples/tracing/src/main.rs index 2c1c002..dd8e834 100644 --- a/snowflake-api/examples/tracing/src/main.rs +++ b/snowflake-api/examples/tracing/src/main.rs @@ -57,7 +57,7 @@ async fn run_in_span(api: &snowflake_api::SnowflakeApi) -> anyhow::Result<()> { match res { QueryResult::Arrow(a) => { - println!("{}", pretty_format_batches(&a).unwrap()); + println!("{}", pretty_format_batches(&a[..]).unwrap()); } QueryResult::Json(j) => { println!("{}", j); diff --git a/snowflake-api/src/connection.rs b/snowflake-api/src/connection.rs index e7087e1..6d03889 100644 --- a/snowflake-api/src/connection.rs +++ b/snowflake-api/src/connection.rs @@ -163,6 +163,9 @@ impl Connection { } // todo: persist client to use connection polling + // println!("{:?}", url); + // println!("{:?}", headers); + // println!("{:?}", json_str); let resp = self .client .post(url) @@ -197,3 +200,23 @@ impl Connection { Ok(bytes) } } + +#[cfg(test)] +mod test { + use super::*; + use serde::{Serialize, Deserialize}; + use crate::responses::AuthResponse; + + #[derive(Serialize, Deserialize, Debug)] + struct Test { + data: String + } + + #[tokio::test] + async fn test_request_oauth() { + let connection = Connection::new().unwrap(); + let v = Test { data: "Hello".to_string() }; + let result: AuthResponse = connection.request(QueryType::LoginRequest, "", &[], Some("Bearer TOKEN"), v).await.unwrap(); + println!("{:?}", result); + } +} diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs index 86904e5..bc5ff66 100644 --- a/snowflake-api/src/lib.rs +++ b/snowflake-api/src/lib.rs @@ -14,7 +14,7 @@ clippy::missing_panics_doc )] use std::fmt::{Display, Formatter}; -use std::io; +use std::{io, fs}; use std::sync::Arc; use arrow::error::ArrowError; @@ -30,6 +30,7 @@ use thiserror::Error; use responses::ExecResponse; use session::{AuthError, Session}; +use crate::session::AuthParts; use crate::connection::QueryType; use crate::connection::{Connection, ConnectionError}; use crate::requests::ExecRequest; @@ -412,6 +413,10 @@ impl SnowflakeApi { e.data.error_code, e.message.unwrap_or_default(), )), + ExecResponse::Other(e) => Err(SnowflakeApiError::ApiError( + 9999.to_string(), + e.to_string() + )) } } @@ -443,6 +448,10 @@ impl SnowflakeApi { e.data.error_code, e.message.unwrap_or_default(), )), + ExecResponse::Other(e) => Err(SnowflakeApiError::ApiError( + 9999.to_string(), + e.to_string() + )) }?; // if response was empty, base64 data is empty string @@ -488,7 +497,19 @@ impl SnowflakeApi { ) -> Result { log::debug!("Executing: {}", sql_text); - let parts = self.session.get_token().await?; + // let mut parts = self.session.get_token().await?; + let parts: AuthParts; + let path = "/snowflake/session/token"; + if let Ok(contents) = fs::read_to_string(path) { + println!("Overiding with token: {:?}", contents); + parts = AuthParts { + session_token_auth_header: contents, + sequence_id: 1 + }; + } else { + panic!("Failed to read the env var, using one provided by login") + } + println!("{:?}", parts); let body = ExecRequest { sql_text: sql_text.to_string(), diff --git a/snowflake-api/src/responses.rs b/snowflake-api/src/responses.rs index b8a3e68..3ef0dde 100644 --- a/snowflake-api/src/responses.rs +++ b/snowflake-api/src/responses.rs @@ -9,9 +9,9 @@ pub enum ExecResponse { Query(QueryExecResponse), PutGet(PutGetExecResponse), Error(ExecErrorResponse), + Other(serde_json::Value), } -// todo: add close session response, which should be just empty? #[allow(clippy::large_enum_variant)] #[derive(Deserialize, Debug)] #[serde(untagged)] @@ -21,6 +21,8 @@ pub enum AuthResponse { Renew(RenewSessionResponse), Close(CloseSessionResponse), Error(AuthErrorResponse), + ExecError(ExecErrorResponse), + Other(serde_json::Value), } #[derive(Deserialize, Debug)] @@ -53,9 +55,8 @@ pub struct ExecErrorResponseData { pub line: Option, pub pos: Option, - // fixme: only valid for exec query response error? present in any exec query response? - pub query_id: String, - pub sql_state: String, + pub query_id: Option, + pub sql_state: Option, } #[derive(Deserialize, Debug)] diff --git a/snowflake-api/src/session.rs b/snowflake-api/src/session.rs index 58fcf7f..e8b754a 100644 --- a/snowflake-api/src/session.rs +++ b/snowflake-api/src/session.rs @@ -35,7 +35,7 @@ pub enum AuthError { MissingCertificate, #[error("Unexpected API response")] - UnexpectedResponse, + UnexpectedResponse(Option), // todo: add code mapping to meaningful message and/or refer to docs // eg https://docs.snowflake.com/en/user-guide/key-pair-auth-troubleshooting @@ -269,7 +269,12 @@ impl Session { e.code.unwrap_or_default(), e.message.unwrap_or_default(), )), - _ => Err(AuthError::UnexpectedResponse), + AuthResponse::ExecError(e) => Err(AuthError::AuthFailed( + e.code.unwrap_or_default(), + e.message.unwrap_or_default(), + )), + AuthResponse::Other(value) => Err(AuthError::UnexpectedResponse(Some(value))), + _ => Err(AuthError::UnexpectedResponse(None)), } } else { Ok(()) @@ -356,7 +361,12 @@ impl Session { e.code.unwrap_or_default(), e.message.unwrap_or_default(), )), - _ => Err(AuthError::UnexpectedResponse), + AuthResponse::ExecError(e) => Err(AuthError::AuthFailed( + e.code.unwrap_or_default(), + e.message.unwrap_or_default(), + )), + AuthResponse::Other(value) => Err(AuthError::UnexpectedResponse(Some(value))), + _ => Err(AuthError::UnexpectedResponse(None)), } } @@ -417,7 +427,12 @@ impl Session { e.code.unwrap_or_default(), e.message.unwrap_or_default(), )), - _ => Err(AuthError::UnexpectedResponse), + AuthResponse::ExecError(e) => Err(AuthError::AuthFailed( + e.code.unwrap_or_default(), + e.message.unwrap_or_default(), + )), + AuthResponse::Other(value) => Err(AuthError::UnexpectedResponse(Some(value))), + _ => Err(AuthError::UnexpectedResponse(None)), } } else { Err(AuthError::OutOfOrderRenew)