Use Glob in PUT file paths and Allow for Multi File Upload (#13)
colin99d authored Mar 16, 2024
1 parent 56a71d7 commit b908ee0
Showing 6 changed files with 215 additions and 136 deletions.
12 changes: 9 additions & 3 deletions snowflake-api/Cargo.toml
@@ -25,7 +25,6 @@ base64 = "0.21"
bytes = "1"
futures = "0.3"
log = "0.4"
object_store = { version = "0.9", features = ["aws"] }
regex = "1"
reqwest = { version = "0.11", default-features = false, features = [
"gzip",
@@ -36,17 +35,24 @@ reqwest-middleware = "0.2"
reqwest-retry = "0.3"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
snowflake-jwt = { version = "0.3.0", optional = true }
thiserror = "1"
url = "2"
uuid = { version = "1", features = ["v4"] }
snowflake-jwt = { version = "0.3.0", optional = true }

# polars-support
polars-io = { version = ">=0.32", features = ["json", "ipc_streaming"], optional = true}
polars-core = { version = ">=0.32", optional = true}

# put request support
object_store = { version = "0.9", features = ["aws"] }
glob = { version = "0.3"}
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }


[dev-dependencies]
anyhow = "1"
arrow = { version = "50", features = ["prettyprint"] }
clap = { version = "4", features = ["derive"] }
pretty_env_logger = "0.5"
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
tokio = { version = "1.35", features = ["macros", "rt-multi-thread"] }
3 changes: 2 additions & 1 deletion snowflake-api/README.md
@@ -20,7 +20,8 @@ Since it does a lot of I/O the library is async-only, and currently has hard dep
- [x] AWS integration
- [ ] GCloud integration
- [ ] Azure integration
- [ ] Parallel uploading of small files
- [x] Parallel uploading of small files
- [x] Glob support for PUT (eg `*.csv`)
- [x] Polars support [example](./examples/polars/src/main.rs)
- [x] Tracing / custom reqwest middleware [example](./examples/tracing/src/main.rs)

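As a quick illustration of the two new checklist items above (a minimal sketch, not part of this commit; the table name and local path are placeholders, and it assumes the `from_env` and `exec` APIs shown in the diffs below):

```rust
use anyhow::Result;
use snowflake_api::SnowflakeApi;

#[tokio::main]
async fn main() -> Result<()> {
    // Credentials are read from SNOWFLAKE_* environment variables,
    // as in the updated filetransfer example below.
    let mut api = SnowflakeApi::from_env()?;

    // The local path may now contain a glob: every matching file is uploaded,
    // and files under the size threshold are uploaded in parallel.
    api.exec("PUT file:///tmp/data/*.csv @%EXAMPLE_TABLE").await?;

    Ok(())
}
```
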
63 changes: 3 additions & 60 deletions snowflake-api/examples/filetransfer.rs
@@ -2,81 +2,24 @@ use anyhow::Result;
use arrow::util::pretty::pretty_format_batches;
use clap::Parser;
use snowflake_api::{QueryResult, SnowflakeApi};
use std::fs;

extern crate snowflake_api;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Path to RSA PEM private key
#[arg(long)]
private_key: Option<String>,

/// Password if certificate is not present
#[arg(long)]
password: Option<String>,

/// <account_identifier> in Snowflake format, uppercase
#[arg(short, long)]
account_identifier: String,

/// Database name
#[arg(short, long)]
database: String,

/// Schema name
#[arg(long)]
schema: String,

/// Warehouse
#[arg(short, long)]
warehouse: String,

/// username to whom the private key belongs to
#[arg(short, long)]
username: String,

/// role which user will assume
#[arg(short, long)]
role: String,

#[arg(long)]
csv_path: String,
}

// SNOWFLAKE_ACCOUNT=... SNOWFLAKE_USER=... SNOWFLAKE_PASSWORD=... SNOWFLAKE_DATABASE=... SNOWFLAKE_SCHEMA=... \
// cargo run --example filetransfer -- --csv-path $(pwd)/examples/oscar_age_male.csv
#[tokio::main]
async fn main() -> Result<()> {
pretty_env_logger::init();

let args = Args::parse();

let mut api = match (&args.private_key, &args.password) {
(Some(pkey), None) => {
let pem = fs::read_to_string(pkey)?;
SnowflakeApi::with_certificate_auth(
&args.account_identifier,
Some(&args.warehouse),
Some(&args.database),
Some(&args.schema),
&args.username,
Some(&args.role),
&pem,
)?
}
(None, Some(pwd)) => SnowflakeApi::with_password_auth(
&args.account_identifier,
Some(&args.warehouse),
Some(&args.database),
Some(&args.schema),
&args.username,
Some(&args.role),
pwd,
)?,
_ => {
panic!("Either private key path or password must be set")
}
};
let mut api = SnowflakeApi::from_env()?;

log::info!("Creating table");
api.exec(
100 changes: 30 additions & 70 deletions snowflake-api/src/lib.rs
@@ -15,7 +15,6 @@ clippy::missing_panics_doc

use std::fmt::{Display, Formatter};
use std::io;
use std::path::Path;
use std::sync::Arc;

use arrow::error::ArrowError;
@@ -24,27 +23,23 @@ use arrow::record_batch::RecordBatch;
use base64::Engine;
use bytes::{Buf, Bytes};
use futures::future::try_join_all;
use object_store::aws::AmazonS3Builder;
use object_store::local::LocalFileSystem;
use object_store::ObjectStore;
use regex::Regex;
use reqwest_middleware::ClientWithMiddleware;
use thiserror::Error;

use crate::connection::{Connection, ConnectionError};
use responses::ExecResponse;
use session::{AuthError, Session};

use crate::connection::QueryType;
use crate::connection::{Connection, ConnectionError};
use crate::requests::ExecRequest;
use crate::responses::{
AwsPutGetStageInfo, ExecResponseRowType, PutGetExecResponse, PutGetStageInfo, SnowflakeType,
};
use crate::responses::{ExecResponseRowType, SnowflakeType};
use crate::session::AuthError::MissingEnvArgument;

pub mod connection;
#[cfg(feature = "polars")]
mod polars;
mod put;
mod requests;
mod responses;
mod session;
@@ -78,6 +73,9 @@ pub enum SnowflakeApiError {
#[error(transparent)]
ObjectStorePathError(#[from] object_store::path::Error),

#[error(transparent)]
TokioTaskJoinError(#[from] tokio::task::JoinError),

#[error("Snowflake API error. Code: `{0}`. Message: `{1}`")]
ApiError(String, String),

@@ -92,6 +90,12 @@ pub enum SnowflakeApiError {

#[error("Unexpected API response")]
UnexpectedResponse,

#[error(transparent)]
GlobPatternError(#[from] glob::PatternError),

#[error(transparent)]
GlobError(#[from] glob::GlobError),
}

/// Even if Arrow is specified as a return type non-select queries
@@ -274,11 +278,11 @@ impl SnowflakeApiBuilder {

let account_identifier = self.auth.account_identifier.to_uppercase();

Ok(SnowflakeApi {
connection: Arc::clone(&connection),
Ok(SnowflakeApi::new(
Arc::clone(&connection),
session,
account_identifier,
})
))
}
}

@@ -290,6 +294,14 @@ pub struct SnowflakeApi {
}

impl SnowflakeApi {
/// Create a new `SnowflakeApi` object with an existing connection and session.
pub fn new(connection: Arc<Connection>, session: Session, account_identifier: String) -> Self {
Self {
connection,
session,
account_identifier,
}
}
/// Initialize object with password auth. Authentication happens on the first request.
pub fn with_password_auth(
account_identifier: &str,
@@ -314,11 +326,11 @@ impl SnowflakeApi {
);

let account_identifier = account_identifier.to_uppercase();
Ok(Self {
connection: Arc::clone(&connection),
Ok(Self::new(
Arc::clone(&connection),
session,
account_identifier,
})
))
}

/// Initialize object with private certificate auth. Authentication happens on the first request.
@@ -345,11 +357,11 @@ impl SnowflakeApi {
);

let account_identifier = account_identifier.to_uppercase();
Ok(Self {
connection: Arc::clone(&connection),
Ok(Self::new(
Arc::clone(&connection),
session,
account_identifier,
})
))
}

pub fn from_env() -> Result<Self, SnowflakeApiError> {
@@ -381,7 +393,6 @@ impl SnowflakeApi {
// put commands go through a different flow and result is side-effect
if put_re.is_match(sql) {
log::info!("Detected PUT query");

self.exec_put(sql).await.map(|()| RawQueryResult::Empty)
} else {
self.exec_arrow_raw(sql).await
@@ -396,65 +407,14 @@ impl SnowflakeApi {

match resp {
ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(pg) => self.put(pg).await,
ExecResponse::PutGet(pg) => put::put(pg).await,
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
e.message.unwrap_or_default(),
)),
}
}

async fn put(&self, resp: PutGetExecResponse) -> Result<(), SnowflakeApiError> {
match resp.data.stage_info {
PutGetStageInfo::Aws(info) => self.put_to_s3(&resp.data.src_locations, info).await,
PutGetStageInfo::Azure(_) => Err(SnowflakeApiError::Unimplemented(
"PUT local file requests for Azure".to_string(),
)),
PutGetStageInfo::Gcs(_) => Err(SnowflakeApiError::Unimplemented(
"PUT local file requests for GCS".to_string(),
)),
}
}

async fn put_to_s3(
&self,
src_locations: &[String],
info: AwsPutGetStageInfo,
) -> Result<(), SnowflakeApiError> {
let (bucket_name, bucket_path) = info
.location
.split_once('/')
.ok_or(SnowflakeApiError::InvalidBucketPath(info.location.clone()))?;

let s3 = AmazonS3Builder::new()
.with_region(info.region)
.with_bucket_name(bucket_name)
.with_access_key_id(info.creds.aws_key_id)
.with_secret_access_key(info.creds.aws_secret_key)
.with_token(info.creds.aws_token)
.build()?;

// todo: security vulnerability, external system tells you which local files to upload
for src_path in src_locations {
let path = Path::new(src_path);
let filename = path
.file_name()
.ok_or(SnowflakeApiError::InvalidLocalPath(src_path.clone()))?;

// fixme: unwrap
let dest_path = format!("{}{}", bucket_path, filename.to_str().unwrap());
let dest_path = object_store::path::Path::parse(dest_path)?;

let src_path = object_store::path::Path::parse(src_path)?;

let fs = LocalFileSystem::new().get(&src_path).await?;

s3.put(&dest_path, fs.bytes().await?).await?;
}

Ok(())
}

/// Useful for debugging to get the straight query response
#[cfg(debug_assertions)]
pub async fn exec_response(&mut self, sql: &str) -> Result<ExecResponse, SnowflakeApiError> {
170 changes: 170 additions & 0 deletions snowflake-api/src/put.rs
@@ -0,0 +1,170 @@
use std::fs::Metadata;
use std::path::Path;
use std::sync::Arc;

use futures::stream::FuturesUnordered;
use futures::TryStreamExt;
use object_store::aws::AmazonS3Builder;
use object_store::limit::LimitStore;
use object_store::local::LocalFileSystem;
use object_store::ObjectStore;
use tokio::task;

use crate::responses::{AwsPutGetStageInfo, PutGetExecResponse, PutGetStageInfo};
use crate::SnowflakeApiError;

pub async fn put(resp: PutGetExecResponse) -> Result<(), SnowflakeApiError> {
match resp.data.stage_info {
PutGetStageInfo::Aws(info) => {
put_to_s3(
resp.data.src_locations,
info,
resp.data.parallel,
resp.data.threshold,
)
.await
}
PutGetStageInfo::Azure(_) => Err(SnowflakeApiError::Unimplemented(
"PUT local file requests for Azure".to_string(),
)),
PutGetStageInfo::Gcs(_) => Err(SnowflakeApiError::Unimplemented(
"PUT local file requests for GCS".to_string(),
)),
}
}

async fn put_to_s3(
src_locations: Vec<String>,
info: AwsPutGetStageInfo,
max_parallel_uploads: usize,
max_file_size_threshold: i64,
) -> Result<(), SnowflakeApiError> {
// `max_parallel_uploads` and `max_file_size_threshold` come from the Snowflake response;
// their semantics follow Snowflake's documented PUT behaviour
let (bucket_name, bucket_path) = info
.location
.split_once('/')
.ok_or(SnowflakeApiError::InvalidBucketPath(info.location.clone()))?;

let s3 = AmazonS3Builder::new()
.with_region(info.region)
.with_bucket_name(bucket_name)
.with_access_key_id(info.creds.aws_key_id)
.with_secret_access_key(info.creds.aws_secret_key)
.with_token(info.creds.aws_token)
.build()?;

let files = list_files(src_locations, max_file_size_threshold).await?;

for src_path in files.large_files {
put_file(&s3, &src_path, bucket_path).await?;
}

let limit_store = LimitStore::new(s3, max_parallel_uploads);
put_files_par(files.small_files, bucket_path, limit_store).await?;

Ok(())
}

/// Upload files split into those above and those at or below the size threshold
struct SizedFiles {
small_files: Vec<String>,
large_files: Vec<String>,
}

// todo: security vulnerability, external system tells you which local files to upload
async fn list_files(
src_locations: Vec<String>,
threshold: i64,
) -> Result<SizedFiles, SnowflakeApiError> {
let paths = task::spawn_blocking(move || traverse_globs(src_locations)).await??;
let paths_meta = fetch_metadata(paths).await?;

let threshold = u64::try_from(threshold).unwrap_or(0);
let mut small_files = vec![];
let mut large_files = vec![];
for pm in paths_meta {
if pm.meta.len() > threshold {
large_files.push(pm.path);
} else {
small_files.push(pm.path);
}
}

Ok(SizedFiles {
small_files,
large_files,
})
}

fn traverse_globs(globs: Vec<String>) -> Result<Vec<String>, SnowflakeApiError> {
let mut res = vec![];
for g in globs {
for path in glob::glob(&g)? {
if let Some(p) = path?.to_str() {
res.push(p.to_owned());
}
}
}

Ok(res)
}

struct PathMeta {
path: String,
meta: Metadata,
}

async fn fetch_metadata(paths: Vec<String>) -> Result<Vec<PathMeta>, SnowflakeApiError> {
let metadata = FuturesUnordered::new();
for path in paths {
let task = async move {
let meta = tokio::fs::metadata(&path).await?;
Ok(PathMeta { path, meta })
};
metadata.push(task);
}

metadata.try_collect().await
}

async fn put_file<T: ObjectStore>(
store: &T,
src_path: &str,
bucket_path: &str,
) -> Result<(), SnowflakeApiError> {
let filename = Path::new(&src_path)
.file_name()
.and_then(|f| f.to_str())
.ok_or(SnowflakeApiError::InvalidLocalPath(src_path.to_owned()))?;

let dest_path = format!("{bucket_path}{filename}");
let dest_path = object_store::path::Path::parse(dest_path)?;
let src_path = object_store::path::Path::parse(src_path)?;
let fs = LocalFileSystem::new().get(&src_path).await?;

store.put(&dest_path, fs.bytes().await?).await?;

Ok::<(), SnowflakeApiError>(())
}

/// Uploads files in parallel; used for files at or below the size threshold.
/// One caveat: a file's size could change between the metadata check and the upload.
async fn put_files_par<T: ObjectStore>(
files: Vec<String>,
bucket_path: &str,
limit_store: LimitStore<T>,
) -> Result<(), SnowflakeApiError> {
let limit_store = Arc::new(limit_store);
let mut tasks = task::JoinSet::new();
for src_path in files {
let bucket_path = bucket_path.to_owned();
let limit_store = Arc::clone(&limit_store);
tasks.spawn(async move { put_file(limit_store.as_ref(), &src_path, &bucket_path).await });
}
while let Some(result) = tasks.join_next().await {
result??;
}

Ok(())
}
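
The parallel path in `put_files_par` above bounds concurrency by wrapping the store in `object_store`'s `LimitStore` and spawning one task per file in a `JoinSet`. A self-contained sketch of that pattern (not part of this commit; it uses the crate's in-memory backend in place of S3, and the paths and file contents are made up):

```rust
use std::sync::Arc;

use bytes::Bytes;
use object_store::limit::LimitStore;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::ObjectStore;
use tokio::task::JoinSet;

#[tokio::main]
async fn main() -> object_store::Result<()> {
    // At most 4 requests run concurrently, mirroring how put_to_s3 wraps the
    // S3 store with the `parallel` value returned by Snowflake.
    let store = Arc::new(LimitStore::new(InMemory::new(), 4));

    let mut tasks = JoinSet::new();
    for i in 0..16 {
        let store = Arc::clone(&store);
        tasks.spawn(async move {
            // One upload per task; LimitStore throttles how many run at once.
            let path = Path::from(format!("uploads/file-{i}.csv"));
            store.put(&path, Bytes::from_static(b"col_a,col_b\n")).await?;
            Ok::<(), object_store::Error>(())
        });
    }
    while let Some(joined) = tasks.join_next().await {
        joined.expect("upload task panicked")?;
    }

    Ok(())
}
```
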
3 changes: 1 addition & 2 deletions snowflake-api/src/responses.rs
@@ -201,9 +201,8 @@ pub struct PutGetResponseData {
// inconsistent case naming
#[serde(rename = "src_locations", default)]
pub src_locations: Vec<String>,
// todo: support upload parallelism
// file upload parallelism
pub parallel: i32,
pub parallel: usize, // fixme: originally i32, handle this in parsing somehow?
// file size threshold; files at or below it are uploaded with the given parallelism
pub threshold: i64,
// doesn't need compression if source is already compressed
