Skip to content

Commit

Permalink
Use Glob in PUT file paths and Allow for Multi File Upload (#13)
Browse files Browse the repository at this point in the history
colin99d authored Mar 16, 2024
1 parent 56a71d7 commit b908ee0
Showing 6 changed files with 215 additions and 136 deletions.
12 changes: 9 additions & 3 deletions snowflake-api/Cargo.toml
Original file line number Diff line number Diff line change
@@ -25,7 +25,6 @@ base64 = "0.21"
bytes = "1"
futures = "0.3"
log = "0.4"
object_store = { version = "0.9", features = ["aws"] }
regex = "1"
reqwest = { version = "0.11", default-features = false, features = [
"gzip",
@@ -36,17 +35,24 @@ reqwest-middleware = "0.2"
reqwest-retry = "0.3"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
snowflake-jwt = { version = "0.3.0", optional = true }
thiserror = "1"
url = "2"
uuid = { version = "1", features = ["v4"] }
snowflake-jwt = { version = "0.3.0", optional = true }

# polars-support
polars-io = { version = ">=0.32", features = ["json", "ipc_streaming"], optional = true}
polars-core = { version = ">=0.32", optional = true}

# put request support
object_store = { version = "0.9", features = ["aws"] }
glob = { version = "0.3"}
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }


[dev-dependencies]
anyhow = "1"
arrow = { version = "50", features = ["prettyprint"] }
clap = { version = "4", features = ["derive"] }
pretty_env_logger = "0.5"
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
tokio = { version = "1.35", features = ["macros", "rt-multi-thread"] }
3 changes: 2 additions & 1 deletion snowflake-api/README.md
Original file line number Diff line number Diff line change
@@ -20,7 +20,8 @@ Since it does a lot of I/O the library is async-only, and currently has hard dep
- [x] AWS integration
- [ ] GCloud integration
- [ ] Azure integration
- [ ] Parallel uploading of small files
- [x] Parallel uploading of small files
- [x] Glob support for PUT (eg `*.csv`)
- [x] Polars support [example](./examples/polars/src/main.rs)
- [x] Tracing / custom reqwest middleware [example](./examples/tracing/src/main.rs)

63 changes: 3 additions & 60 deletions snowflake-api/examples/filetransfer.rs
Original file line number Diff line number Diff line change
@@ -2,81 +2,24 @@ use anyhow::Result;
use arrow::util::pretty::pretty_format_batches;
use clap::Parser;
use snowflake_api::{QueryResult, SnowflakeApi};
use std::fs;

extern crate snowflake_api;

// Command-line arguments for the file-transfer example, parsed via clap's derive API.
// NOTE(review): with `#[derive(Parser)]` the `///` comments on the fields below presumably
// double as the generated `--help` text, so rewording them changes user-visible output —
// confirm before editing them.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
// `private_key` and `password` are both optional at the parsing level.
/// Path to RSA PEM private key
#[arg(long)]
private_key: Option<String>,

/// Password if certificate is not present
#[arg(long)]
password: Option<String>,

/// <account_identifier> in Snowflake format, uppercase
#[arg(short, long)]
account_identifier: String,

/// Database name
#[arg(short, long)]
database: String,

// Long flag only: no `short` is requested for this field.
/// Schema name
#[arg(long)]
schema: String,

/// Warehouse
#[arg(short, long)]
warehouse: String,

/// username to whom the private key belongs to
#[arg(short, long)]
username: String,

/// role which user will assume
#[arg(short, long)]
role: String,

// Required `--csv-path` argument; deliberately left without a `///` comment here so the
// help output is unchanged.
#[arg(long)]
csv_path: String,
}

// SNOWFLAKE_ACCOUNT=... SNOWFLAKE_USER=... SNOWFLAKE_PASSWORD=... SNOWFLAKE_DATABASE=... SNOWFLAKE_SCHEMA=... \
// cargo run --example filetransfer -- --csv-path $(pwd)/examples/oscar_age_male.csv
#[tokio::main]
async fn main() -> Result<()> {
pretty_env_logger::init();

let args = Args::parse();

let mut api = match (&args.private_key, &args.password) {
(Some(pkey), None) => {
let pem = fs::read_to_string(pkey)?;
SnowflakeApi::with_certificate_auth(
&args.account_identifier,
Some(&args.warehouse),
Some(&args.database),
Some(&args.schema),
&args.username,
Some(&args.role),
&pem,
)?
}
(None, Some(pwd)) => SnowflakeApi::with_password_auth(
&args.account_identifier,
Some(&args.warehouse),
Some(&args.database),
Some(&args.schema),
&args.username,
Some(&args.role),
pwd,
)?,
_ => {
panic!("Either private key path or password must be set")
}
};
let mut api = SnowflakeApi::from_env()?;

log::info!("Creating table");
api.exec(
100 changes: 30 additions & 70 deletions snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -15,7 +15,6 @@ clippy::missing_panics_doc

use std::fmt::{Display, Formatter};
use std::io;
use std::path::Path;
use std::sync::Arc;

use arrow::error::ArrowError;
@@ -24,27 +23,23 @@ use arrow::record_batch::RecordBatch;
use base64::Engine;
use bytes::{Buf, Bytes};
use futures::future::try_join_all;
use object_store::aws::AmazonS3Builder;
use object_store::local::LocalFileSystem;
use object_store::ObjectStore;
use regex::Regex;
use reqwest_middleware::ClientWithMiddleware;
use thiserror::Error;

use crate::connection::{Connection, ConnectionError};
use responses::ExecResponse;
use session::{AuthError, Session};

use crate::connection::QueryType;
use crate::connection::{Connection, ConnectionError};
use crate::requests::ExecRequest;
use crate::responses::{
AwsPutGetStageInfo, ExecResponseRowType, PutGetExecResponse, PutGetStageInfo, SnowflakeType,
};
use crate::responses::{ExecResponseRowType, SnowflakeType};
use crate::session::AuthError::MissingEnvArgument;

pub mod connection;
#[cfg(feature = "polars")]
mod polars;
mod put;
mod requests;
mod responses;
mod session;
@@ -78,6 +73,9 @@ pub enum SnowflakeApiError {
#[error(transparent)]
ObjectStorePathError(#[from] object_store::path::Error),

#[error(transparent)]
TokioTaskJoinError(#[from] tokio::task::JoinError),

#[error("Snowflake API error. Code: `{0}`. Message: `{1}`")]
ApiError(String, String),

@@ -92,6 +90,12 @@ pub enum SnowflakeApiError {

#[error("Unexpected API response")]
UnexpectedResponse,

#[error(transparent)]
GlobPatternError(#[from] glob::PatternError),

#[error(transparent)]
GlobError(#[from] glob::GlobError),
}

/// Even if Arrow is specified as a return type non-select queries
@@ -274,11 +278,11 @@ impl SnowflakeApiBuilder {

let account_identifier = self.auth.account_identifier.to_uppercase();

Ok(SnowflakeApi {
connection: Arc::clone(&connection),
Ok(SnowflakeApi::new(
Arc::clone(&connection),
session,
account_identifier,
})
))
}
}

@@ -290,6 +294,14 @@ pub struct SnowflakeApi {
}

impl SnowflakeApi {
/// Create a new `SnowflakeApi` object with an existing connection and session.
pub fn new(connection: Arc<Connection>, session: Session, account_identifier: String) -> Self {
Self {
connection,
session,
account_identifier,
}
}
/// Initialize object with password auth. Authentication happens on the first request.
pub fn with_password_auth(
account_identifier: &str,
@@ -314,11 +326,11 @@ impl SnowflakeApi {
);

let account_identifier = account_identifier.to_uppercase();
Ok(Self {
connection: Arc::clone(&connection),
Ok(Self::new(
Arc::clone(&connection),
session,
account_identifier,
})
))
}

/// Initialize object with private certificate auth. Authentication happens on the first request.
@@ -345,11 +357,11 @@ impl SnowflakeApi {
);

let account_identifier = account_identifier.to_uppercase();
Ok(Self {
connection: Arc::clone(&connection),
Ok(Self::new(
Arc::clone(&connection),
session,
account_identifier,
})
))
}

pub fn from_env() -> Result<Self, SnowflakeApiError> {
@@ -381,7 +393,6 @@ impl SnowflakeApi {
// put commands go through a different flow and result is side-effect
if put_re.is_match(sql) {
log::info!("Detected PUT query");

self.exec_put(sql).await.map(|()| RawQueryResult::Empty)
} else {
self.exec_arrow_raw(sql).await
@@ -396,65 +407,14 @@ impl SnowflakeApi {

match resp {
ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(pg) => self.put(pg).await,
ExecResponse::PutGet(pg) => put::put(pg).await,
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
e.message.unwrap_or_default(),
)),
}
}

/// Route a PUT response to the uploader for the stage's cloud provider.
///
/// AWS stages are handed to `put_to_s3` together with the source locations from
/// the response; Azure and GCS stages yield `SnowflakeApiError::Unimplemented`.
async fn put(&self, resp: PutGetExecResponse) -> Result<(), SnowflakeApiError> {
    // Only AWS has an uploader; for the other providers remember the name for the error.
    let unsupported_provider = match resp.data.stage_info {
        PutGetStageInfo::Aws(info) => {
            return self.put_to_s3(&resp.data.src_locations, info).await;
        }
        PutGetStageInfo::Azure(_) => "Azure",
        PutGetStageInfo::Gcs(_) => "GCS",
    };

    Err(SnowflakeApiError::Unimplemented(format!(
        "PUT local file requests for {}",
        unsupported_provider
    )))
}

/// Upload every local file in `src_locations` to the S3 stage described by `info`.
///
/// `info.location` is split at its first `/` into a bucket name and a key prefix;
/// each file is stored under `<prefix><file name>`. Files are processed one at a
/// time, and each file's contents are read fully before being uploaded.
///
/// # Errors
///
/// - `SnowflakeApiError::InvalidBucketPath` if `info.location` contains no `/`.
/// - `SnowflakeApiError::InvalidLocalPath` if a source path has no final
///   component or its file name is not valid UTF-8 (previously this panicked).
/// - Any error from building the S3 client, parsing an object-store path,
///   reading the local file, or uploading is propagated with `?`.
async fn put_to_s3(
    &self,
    src_locations: &[String],
    info: AwsPutGetStageInfo,
) -> Result<(), SnowflakeApiError> {
    // Build the error lazily so the location is only cloned on failure.
    let (bucket_name, bucket_path) = info
        .location
        .split_once('/')
        .ok_or_else(|| SnowflakeApiError::InvalidBucketPath(info.location.clone()))?;

    let s3 = AmazonS3Builder::new()
        .with_region(info.region)
        .with_bucket_name(bucket_name)
        .with_access_key_id(info.creds.aws_key_id)
        .with_secret_access_key(info.creds.aws_secret_key)
        .with_token(info.creds.aws_token)
        .build()?;

    // todo: security vulnerability, external system tells you which local files to upload
    for src_path in src_locations {
        // A missing or non-UTF-8 file name is reported as an error instead of panicking.
        let filename = Path::new(src_path)
            .file_name()
            .and_then(|name| name.to_str())
            .ok_or_else(|| SnowflakeApiError::InvalidLocalPath(src_path.clone()))?;

        let dest_path = format!("{}{}", bucket_path, filename);
        let dest_path = object_store::path::Path::parse(dest_path)?;

        let src_path = object_store::path::Path::parse(src_path)?;

        let fs = LocalFileSystem::new().get(&src_path).await?;

        s3.put(&dest_path, fs.bytes().await?).await?;
    }

    Ok(())
}

/// Useful for debugging to get the straight query response
#[cfg(debug_assertions)]
pub async fn exec_response(&mut self, sql: &str) -> Result<ExecResponse, SnowflakeApiError> {
Loading

0 comments on commit b908ee0

Please sign in to comment.