Commit

Merge pull request #3 from xmakro/main
Support fetching chunks
tbaums authored Oct 9, 2023
2 parents a95100f + 516c4c4 commit c128f19
Showing 4 changed files with 102 additions and 60 deletions.
14 changes: 8 additions & 6 deletions snowflake-api/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "snowflake-api"
version = "0.3.0"
version = "0.3.1"
edition = "2021"
description = "Snowflake API bindings"
authors = ["Andrew Korzhuev <[email protected]>"]
@@ -14,23 +14,25 @@ license = "Apache-2.0"
[dependencies]
thiserror = "1"
snowflake-jwt = "0.3.0"
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "json"] }
reqwest = { version = "0.11", default-features = false, features = ["gzip", "rustls-tls", "json"] }
reqwest-middleware = "0.2"
reqwest-retry = "0.2"
reqwest-retry = "0.3"
log = "0.4"
serde_json = "1"
serde = { version = "1", features = ["derive"] }
url = "2"
uuid = { version = "1.4", features = ["v4"] }
arrow = "42"
arrow = "47"
base64 = "0.21"
regex = "1"
object_store = { version = "0.6", features = ["aws"] }
object_store = { version = "0.7", features = ["aws"] }
async-trait = "0.1"
bytes = "1"
futures = "0.3"

[dev-dependencies]
anyhow = "1"
pretty_env_logger = "0.5.0"
clap = { version = "4", features = ["derive"] }
arrow = { version = "42", features = ["prettyprint"] }
arrow = { version = "47", features = ["prettyprint"] }
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
33 changes: 28 additions & 5 deletions snowflake-api/src/connection.rs
@@ -1,8 +1,8 @@
use reqwest::header;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
use reqwest_middleware::ClientWithMiddleware;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use thiserror::Error;
use url::Url;
@@ -82,6 +82,7 @@ impl Connection {
// use builder to fail safely, unlike client new
let client = reqwest::ClientBuilder::new()
.user_agent("Rust/0.0.1")
.gzip(true)
.referer(false);

#[cfg(debug_assertions)]
@@ -114,9 +115,8 @@ impl Connection {
let client_start_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let client_start_time = client_start_time.to_string();

.as_secs()
.to_string();
// fixme: update uuid's on the retry
let request_id = request_id.to_string();
let request_guid = request_guid.to_string();
@@ -157,4 +157,27 @@ impl Connection {

Ok(resp.json::<R>().await?)
}

pub async fn get_chunk(
&self,
url: &str,
headers: &HashMap<String, String>,
) -> Result<bytes::Bytes, ConnectionError> {
let mut header_map = HeaderMap::new();
for (k, v) in headers {
header_map.insert(
HeaderName::from_bytes(k.as_bytes()).unwrap(),
HeaderValue::from_bytes(v.as_bytes()).unwrap(),
);
}
let bytes = self
.client
.get(url)
.headers(header_map)
.send()
.await?
.bytes()
.await?;
Ok(bytes)
}
}
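
For reference, a minimal sketch of how the new get_chunk helper might be called from downstream code. The fetch_chunk wrapper, the snowflake_api::connection module path, and the boxed error type are illustrative assumptions, not part of the crate's documented API:

use std::collections::HashMap;

use snowflake_api::connection::Connection; // module path assumed for illustration

// Download one result chunk given its URL and the per-chunk headers that
// Snowflake returns alongside the query response.
async fn fetch_chunk(
    connection: &Connection,
    chunk_url: &str,
    chunk_headers: &HashMap<String, String>,
) -> Result<bytes::Bytes, Box<dyn std::error::Error>> {
    let bytes = connection.get_chunk(chunk_url, chunk_headers).await?;
    Ok(bytes)
}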
56 changes: 32 additions & 24 deletions snowflake-api/src/lib.rs
@@ -12,6 +12,7 @@ use arrow::datatypes::ToByteSlice;
use arrow::ipc::reader::StreamReader;
use arrow::record_batch::RecordBatch;
use base64::Engine;
use futures::future::try_join_all;
use object_store::aws::AmazonS3Builder;
use object_store::local::LocalFileSystem;
use object_store::ObjectStore;
@@ -90,7 +91,6 @@ pub struct SnowflakeApi {
connection: Arc<Connection>,
session: Session,
account_identifier: String,
sequence_id: u64,
}

impl SnowflakeApi {
@@ -122,7 +122,6 @@ impl SnowflakeApi {
connection: Arc::clone(&connection),
session,
account_identifier,
sequence_id: 0,
})
}

@@ -154,7 +153,6 @@ impl SnowflakeApi {
connection: Arc::clone(&connection),
session,
account_identifier,
sequence_id: 0,
})
}

@@ -168,7 +166,7 @@ impl SnowflakeApi {

/// Execute a single query against the API.
/// If the statement is a PUT, the file will be uploaded to the Snowflake-managed storage
pub async fn exec(&mut self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
pub async fn exec(&self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
let put_re = Regex::new(r"(?i)^(?:/\*.*\*/\s*)*put\s+").unwrap();

// put commands go through a different flow and result is side-effect
@@ -181,7 +179,7 @@ impl SnowflakeApi {
}
}

async fn exec_put(&mut self, sql: &str) -> Result<(), SnowflakeApiError> {
async fn exec_put(&self, sql: &str) -> Result<(), SnowflakeApiError> {
let resp = self
.run_sql::<ExecResponse>(sql, QueryType::JsonQuery)
.await?;
@@ -262,7 +260,7 @@ impl SnowflakeApi {
.await
}

async fn exec_arrow(&mut self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
async fn exec_arrow(&self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
let resp = self
.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
.await?;
@@ -283,22 +281,35 @@ impl SnowflakeApi {
// if response was empty, base64 data is empty string
// todo: still return empty arrow batch with proper schema? (schema always included)
if resp.data.returned == 0 {
log::info!("Got response with 0 rows");

log::debug!("Got response with 0 rows");
Ok(QueryResult::Empty)
} else if let Some(json) = resp.data.rowset {
log::info!("Got JSON response");

log::debug!("Got JSON response");
// NOTE: json response could be chunked too; however, go clients should receive arrow by default,
// unless the user sets a session variable to return json. This case was added for debugging and for
// status information passed through those fields.
Ok(QueryResult::Json(json))
} else if let Some(base64) = resp.data.rowset_base64 {
log::info!("Got base64 encoded response");
let bytes = base64::engine::general_purpose::STANDARD.decode(base64)?;
let fr = StreamReader::try_new_unbuffered(bytes.to_byte_slice(), None)?;

// fixme: loads everything into memory
let mut res = Vec::new();
for batch in fr {
res.push(batch?);
let mut res = vec![];
if !base64.is_empty() {
log::debug!("Got base64 encoded response");
let bytes = base64::engine::general_purpose::STANDARD.decode(base64)?;
let fr = StreamReader::try_new_unbuffered(bytes.to_byte_slice(), None)?;
for batch in fr {
res.push(batch?);
}
}
let chunks = try_join_all(resp.data.chunks.iter().map(|chunk| {
self.connection
.get_chunk(&chunk.url, &resp.data.chunk_headers)
}))
.await?;
for bytes in chunks {
let fr = StreamReader::try_new_unbuffered(&*bytes, None)?;
for batch in fr {
res.push(batch?);
}
}

Ok(QueryResult::Arrow(res))
@@ -308,21 +319,18 @@ impl SnowflakeApi {
}

async fn run_sql<R: serde::de::DeserializeOwned>(
&mut self,
&self,
sql_text: &str,
query_type: QueryType,
) -> Result<R, SnowflakeApiError> {
log::debug!("Executing: {}", sql_text);

let tokens = self.session.get_token().await?;
// expected by snowflake api for all requests within session to follow sequence id
// fixme: possible race condition if multiple requests run in parallel, shouldn't be a big problem however
self.sequence_id += 1;
let parts = self.session.get_token().await?;

let body = ExecRequest {
sql_text: sql_text.to_string(),
async_exec: false,
sequence_id: self.sequence_id,
sequence_id: parts.sequence_id,
is_internal: false,
};

@@ -332,7 +340,7 @@ impl SnowflakeApi {
query_type,
&self.account_identifier,
&[],
Some(&tokens.session_token.auth_header()),
Some(&parts.session_token_auth_header),
body,
)
.await?;
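
Because exec, exec_put, exec_arrow and run_sql now take &self, a single SnowflakeApi can serve overlapping queries, and chunk downloads inside exec_arrow are awaited together via try_join_all. A minimal sketch of a concurrent call site, assuming SnowflakeApi, QueryResult and SnowflakeApiError are public at the crate root:

use futures::try_join;
use snowflake_api::{QueryResult, SnowflakeApi, SnowflakeApiError}; // re-exports assumed

// Issue two statements concurrently through a shared reference; this is
// only possible now that exec takes &self instead of &mut self.
async fn run_concurrently(api: &SnowflakeApi) -> Result<(), SnowflakeApiError> {
    let (first, second) = try_join!(api.exec("SELECT 1"), api.exec("SELECT 2"))?;
    for result in [first, second] {
        if let QueryResult::Arrow(batches) = result {
            // Chunked responses were already fetched and decoded, so
            // batches holds every RecordBatch of the result set.
            println!("received {} record batches", batches.len());
        }
    }
    Ok(())
}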
59 changes: 34 additions & 25 deletions snowflake-api/src/session.rs
@@ -1,6 +1,7 @@
use std::sync::Arc;
use std::time::{Duration, Instant};

use futures::lock::Mutex;
use snowflake_jwt::generate_jwt_token;
use thiserror::Error;

@@ -40,18 +41,26 @@ pub enum AuthError {
}

#[derive(Debug)]
pub struct AuthTokens {
pub session_token: AuthToken,
pub master_token: AuthToken,
struct AuthTokens {
session_token: AuthToken,
master_token: AuthToken,
/// expected by snowflake api for all requests within session to follow sequence id
sequence_id: u64,
}

#[derive(Debug, Clone)]
pub struct AuthToken {
pub token: String,
struct AuthToken {
token: String,
valid_for: Duration,
issued_on: Instant,
}

#[derive(Debug, Clone)]
pub struct AuthParts {
pub session_token_auth_header: String,
pub sequence_id: u64,
}

impl AuthToken {
pub fn new(token: &str, validity_in_seconds: i64) -> Self {
let token = token.to_string();
@@ -92,8 +101,7 @@ enum AuthType {
pub struct Session {
connection: Arc<Connection>,

// wrap them into mutex?
auth_tokens: Option<AuthTokens>,
auth_tokens: Mutex<Option<AuthTokens>>,
auth_type: AuthType,
account_identifier: String,

@@ -135,7 +143,7 @@ impl Session {

Session {
connection,
auth_tokens: None,
auth_tokens: Mutex::new(None),
auth_type: AuthType::Certificate,
private_key_pem,
account_identifier,
@@ -173,7 +181,7 @@ impl Session {

Session {
connection,
auth_tokens: None,
auth_tokens: Mutex::new(None),
auth_type: AuthType::Password,
account_identifier,
warehouse,
@@ -187,16 +195,15 @@ impl Session {
}

/// Get cached token or request a new one if old one has expired.
pub async fn get_token(&mut self) -> Result<&AuthTokens, AuthError> {
if self.auth_tokens.is_none()
|| self
.auth_tokens
pub async fn get_token(&self) -> Result<AuthParts, AuthError> {
let mut auth_tokens = self.auth_tokens.lock().await;
if auth_tokens.is_none()
|| auth_tokens
.as_ref()
.map(|at| at.master_token.is_expired())
.unwrap_or(false)
{
// Create a new session if tokens are absent or can not be exchanged

let tokens = match self.auth_type {
AuthType::Certificate => {
log::info!("Starting session with certificate authentication");
@@ -207,24 +214,26 @@ impl Session {
self.create(self.passwd_request_body()?).await
}
}?;
self.auth_tokens = Some(tokens);
} else if self
.auth_tokens
*auth_tokens = Some(tokens);
} else if auth_tokens
.as_ref()
.map(|at| at.session_token.is_expired())
.unwrap_or(false)
{
// Renew old session token

let tokens = self.renew().await?;
self.auth_tokens = Some(tokens);
*auth_tokens = Some(tokens);
}

self.auth_tokens.as_ref().ok_or(AuthError::TokenFetchFailed)
auth_tokens.as_mut().unwrap().sequence_id += 1;
let session_token_auth_header = auth_tokens.as_ref().unwrap().session_token.auth_header();
Ok(AuthParts {
session_token_auth_header,
sequence_id: auth_tokens.as_ref().unwrap().sequence_id,
})
}

pub async fn close(&mut self) -> Result<(), AuthError> {
if let Some(tokens) = self.auth_tokens.as_ref() {
if let Some(tokens) = self.auth_tokens.lock().await.take() {
log::debug!("Closing sessions");

let resp = self
@@ -238,8 +247,6 @@ impl Session {
)
.await?;

self.auth_tokens = None;

match resp {
AuthResponse::Close(_) => Ok(()),
AuthResponse::Error(e) => Err(AuthError::AuthFailed(
@@ -322,6 +329,7 @@ impl Session {
Ok(AuthTokens {
session_token,
master_token,
sequence_id: 0,
})
}
AuthResponse::Error(e) => Err(AuthError::AuthFailed(
@@ -353,7 +361,7 @@ impl Session {
}

async fn renew(&self) -> Result<AuthTokens, AuthError> {
if let Some(token) = &self.auth_tokens {
if let Some(token) = self.auth_tokens.lock().await.take() {
log::debug!("Renewing the token");
let auth = token.master_token.auth_header();
let body = RenewSessionRequest {
@@ -382,6 +390,7 @@ impl Session {
Ok(AuthTokens {
session_token,
master_token,
sequence_id: token.sequence_id,
})
}
AuthResponse::Error(e) => Err(AuthError::AuthFailed(
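
The session changes mirror the same pattern: cached tokens and the per-session sequence_id now live behind a futures::lock::Mutex, so get_token works through &self and increments the counter under the lock. A stripped-down illustration of that interior-mutability pattern, using hypothetical types rather than the crate's own:

use futures::lock::Mutex;

// Illustrative only: a counter behind an async mutex, mirroring how Session
// bumps sequence_id inside get_token without needing &mut self.
struct SequenceCounter {
    state: Mutex<u64>,
}

impl SequenceCounter {
    fn new() -> Self {
        Self { state: Mutex::new(0) }
    }

    async fn next(&self) -> u64 {
        let mut guard = self.state.lock().await;
        *guard += 1;
        *guard
    }
}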
