feat: Postgres Provider
holicc committed Aug 15, 2024
1 parent 858fb68 commit cb0c5bb
Showing 14 changed files with 350 additions and 62 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
@@ -23,6 +23,8 @@ tokio = { version = "1.37.0", features = ["full"] }
async-trait = "0.1.80"
tokio-stream = "0.1.15"
log = "0.4.21"
+dashmap = "6.0.1"
+connectorx = { git = "https://github.com/holicc/connector-x.git" }


[workspace.lints.rust]
2 changes: 1 addition & 1 deletion docker-compose.yaml
@@ -9,6 +9,6 @@ services:
      POSTGRES_USER: root
      POSTGRES_PASSWORD: root
    volumes:
-      - ./tests/testdata/db/pg/migration.sql:/docker-entrypoint-initdb.d/init.sql
+      - ./tests/db/pg:/docker-entrypoint-initdb.d
    ports:
      - 5433:5432
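
One practical effect of this change (an inference from the mount paths, not stated in the commit): the Postgres image executes every .sql file it finds in /docker-entrypoint-initdb.d at first startup, so mounting the whole ./tests/db/pg directory picks up any number of seed scripts instead of only the single migration.sql mounted before.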
15 changes: 6 additions & 9 deletions qurious/Cargo.toml
@@ -4,20 +4,17 @@ version = "0.1.0"
edition = "2021"

[dependencies]
sqlparser = { workspace = true }
parquet = { workspace = true }
arrow = { workspace = true }
url = { workspace = true }
tokio = { workspace = true }
async-trait = { workspace = true }
tokio-stream = { workspace = true }
log = { workspace = true }
-dashmap = "6.0.1"
+dashmap = { workspace = true }
+connectorx = { workspace = true, features = ["src_postgres", "dst_arrow"] }
+postgres = "0.19.8"
+itertools = "0.13.0"
+rayon = "1.10.0"

[dev-dependencies]
arrow = { version = "52.0.0", features = ["prettyprint", "test_utils"] }
3 changes: 3 additions & 0 deletions qurious/src/common/table_relation.rs
@@ -92,6 +92,9 @@ impl TableRelation {
            }
        }
    }



}

impl Display for TableRelation {
14 changes: 14 additions & 0 deletions qurious/src/datasource/connectorx/mod.rs
@@ -0,0 +1,14 @@
use crate::error::{Error, Result};
use arrow::record_batch::RecordBatch;
use connectorx::prelude::{get_arrow, ArrowDestination, CXQuery, SourceConn};

pub mod postgres;

fn query_batchs(source: &SourceConn, sql: &str) -> Result<Vec<RecordBatch>> {
    query(source, sql).and_then(|dst| dst.arrow().map_err(|e| Error::InternalError(e.to_string())))
}

fn query(source: &SourceConn, sql: &str) -> Result<ArrowDestination> {
    let queries = &[CXQuery::from(sql)];
    get_arrow(source, None, queries).map_err(|e| Error::InternalError(e.to_string()))
}
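
For orientation, a minimal sketch of how these helpers are meant to be called (editorial, not part of the commit; the URL is illustrative and assumes the docker-compose database):

    use connectorx::prelude::SourceConn;

    fn example() -> Result<()> {
        // Build a connectorx source from a Postgres URL, then pull Arrow batches.
        let source = SourceConn::try_from("postgresql://root:root@localhost:5433/qurious")
            .map_err(|e| Error::InternalError(e.to_string()))?;
        let batches = query_batchs(&source, "select 1 as one")?;
        println!("fetched {} batches", batches.len());
        Ok(())
    }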
255 changes: 255 additions & 0 deletions qurious/src/datasource/connectorx/postgres.rs
@@ -0,0 +1,255 @@
use std::fmt::Debug;
use std::sync::Arc;

use arrow::array::{as_string_array, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use connectorx::prelude::{SourceConn, SourceType};
use dashmap::DashMap;
use itertools::multizip;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

use crate::common::table_relation::TableRelation;
use crate::datatypes::scalar::ScalarValue;
use crate::error::{Error, Result};
use crate::logical::expr::LogicalExpr;
use crate::provider::catalog::CatalogProvider;
use crate::provider::schema::SchemaProvider;
use crate::provider::table::TableProvider;

use super::query_batchs;

#[derive(Debug)]
pub struct PostgresCatalogProvider {
    schemas: DashMap<String, Arc<dyn SchemaProvider>>,
}

impl PostgresCatalogProvider {
    pub fn try_new(url: &str) -> Result<Self> {
        let source = SourceConn::try_from(url).map_err(|e| Error::InternalError(e.to_string()))?;
        let config: postgres::Config = url
            .parse()
            .map_err(|e: postgres::Error| Error::InternalError(e.to_string()))?;
        match source.ty {
            SourceType::Postgres => {}
            _ => return Err(Error::InternalError("Invalid source type".to_string())),
        }
        let db_name = config
            .get_dbname()
            .ok_or(Error::InternalError("No database name".to_string()))?;
        let schemas: DashMap<String, Arc<dyn SchemaProvider>> = DashMap::new();

        for batch in query_batchs(&source, "select schema_name from information_schema.schemata")? {
            let array = as_string_array(batch.column(0));
            let values = array
                .iter()
                .filter_map(|x| x.map(|x| x.to_string()))
                .collect::<Vec<_>>();

            values
                .into_par_iter()
                .map(|schema| PostgresSchemaProvider::try_new(source.clone(), db_name.to_owned(), schema).map(Arc::new))
                .collect::<Result<Vec<_>>>()?
                .into_iter()
                .for_each(|provider| {
                    schemas.insert(provider.schema.clone(), provider);
                });
        }

        Ok(Self { schemas })
    }
}

impl CatalogProvider for PostgresCatalogProvider {
    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
        self.schemas.get(name).map(|x| x.value().clone())
    }
}
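
// Design note (editorial, not in the committed source): construction is eager.
// Every schema — and, below, every table — is introspected up front, with
// rayon's `into_par_iter` fanning the per-object metadata queries across threads.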

#[derive(Debug)]
pub struct PostgresSchemaProvider {
    schema: String,
    tables: DashMap<String, Arc<dyn TableProvider>>,
}

impl PostgresSchemaProvider {
    pub fn try_new(source: SourceConn, db_name: String, schema: String) -> Result<Self> {
        let sql = format!(
            "select table_name from information_schema.tables where table_schema = '{}'",
            schema
        );
        let tables: DashMap<String, Arc<dyn TableProvider>> = DashMap::new();
        for batch in query_batchs(&source, &sql)? {
            let array = batch
                .column(0)
                .as_any()
                .downcast_ref::<arrow::array::StringArray>()
                .expect("Failed to downcast");
            let values = array
                .iter()
                .filter_map(|x| x.map(|x| x.to_string()))
                .collect::<Vec<_>>();

            values
                .into_par_iter()
                .map(|table| {
                    PostgresTableProvider::try_new(source.clone(), format!("{}.{}.{}", db_name, schema, table).into())
                        .map(Arc::new)
                })
                .collect::<Result<Vec<_>>>()?
                .into_iter()
                .for_each(|table_provider| {
                    tables.insert(table_provider.table.table().to_owned(), table_provider);
                });
        }

        Ok(Self { schema, tables })
    }
}

impl SchemaProvider for PostgresSchemaProvider {
    fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
        self.tables.get(name).map(|x| x.value().clone())
    }
}

#[derive(Debug)]
pub struct PostgresTableProvider {
    schema: SchemaRef,
    table: TableRelation,
    source: SourceConn,
    default_values: DashMap<String, ScalarValue>,
}

impl PostgresTableProvider {
    pub fn try_new(source: SourceConn, table: TableRelation) -> Result<Self> {
        let sql = format!(
            r#"
            select
                column_name,
                column_default,
                is_nullable,
                data_type
            from
                information_schema.columns
            where table_catalog = '{}' and table_schema = '{}' and table_name = '{}'"#,
            table.catalog().ok_or(Error::InternalError("No catalog".to_string()))?,
            table.schema().ok_or(Error::InternalError("No schema".to_string()))?,
            table.table()
        );
        let default_values: DashMap<String, ScalarValue> = DashMap::new();
        let mut fields = vec![];
        for batch in query_batchs(&source, &sql)? {
            let column_names = as_string_array(batch.column(0)).into_iter();
            let column_defaults = as_string_array(batch.column(1)).into_iter();
            let column_nullables = as_string_array(batch.column(2)).into_iter();
            let column_types = as_string_array(batch.column(3)).into_iter();
            // zip the four metadata columns back into per-column rows
            let rows = multizip((column_names, column_defaults, column_nullables, column_types)).collect::<Vec<_>>();
            for (col_name, default_val, nullable, col_type) in rows {
                let col_name = col_name.expect("Failed to get column name").to_owned();

                if let Some(default_val) = default_val {
                    default_values.insert(
                        col_name.clone(),
                        to_default_value(
                            &to_arrow_type(col_type.expect("Failed to get column type")),
                            &default_val,
                        )?,
                    );
                }

                fields.push(Field::new(
                    col_name,
                    to_arrow_type(col_type.expect("Failed to get column type")),
                    nullable.map(|v| v == "YES").unwrap_or_default(),
                ));
            }
        }

        Ok(Self {
            default_values,
            schema: Arc::new(Schema::new(fields)),
            source,
            table,
        })
    }
}

impl TableProvider for PostgresTableProvider {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn scan(&self, projection: Option<Vec<String>>, filters: &[LogicalExpr]) -> Result<Vec<RecordBatch>> {
        let projection = projection.unwrap_or_else(|| self.schema.fields().iter().map(|x| x.name().clone()).collect());
        let mut sql = format!("select {} from {}", projection.join(","), self.table);

        if !filters.is_empty() {
            let filters = filters.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(" and ");
            sql = format!("{} where {}", sql, filters);
        }

        query_batchs(&self.source, &sql)
    }

    fn get_column_default(&self, column: &str) -> Option<ScalarValue> {
        self.default_values.get(column).map(|x| x.value().clone())
    }
}
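
// Illustration (editorial, not in the committed source): for a hypothetical
// table `qurious.public.schools`, a projection of ["id", "name"] and one
// pushed-down filter, `scan` assembles:
//
//     select id,name from qurious.public.schools where id > 10
//
// Projection names are joined with "," and filters with " and ".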

fn to_arrow_type(col_type: &str) -> DataType {
    match col_type {
        "bigint" | "integer" => DataType::Int64,
        "smallint" => DataType::Int16,
        "character varying" => DataType::Utf8,
        "character" => DataType::Utf8,
        "text" => DataType::Utf8,
        "timestamp without time zone" => DataType::Timestamp(TimeUnit::Second, None),
        "timestamp with time zone" => DataType::Timestamp(TimeUnit::Second, None),
        "date" => DataType::Date32,
        "boolean" => DataType::Boolean,
        "real" => DataType::Float32,
        "double precision" => DataType::Float64,
        "numeric" => DataType::Float64,
        _ => DataType::Utf8,
    }
}

fn to_default_value(data_type: &DataType, default_value: &str) -> Result<ScalarValue> {
    match data_type {
        DataType::Int64 => Ok(ScalarValue::Int64(Some(default_value.parse()?))),
        DataType::Int32 => Ok(ScalarValue::Int32(Some(default_value.parse()?))),
        DataType::Int16 => Ok(ScalarValue::Int16(Some(default_value.parse()?))),
        DataType::Utf8 => Ok(ScalarValue::Utf8(Some(default_value.to_string()))),
        DataType::Boolean => {
            // Parse the default value as a boolean
            let boolean = default_value.parse()?;
            Ok(ScalarValue::Boolean(Some(boolean)))
        }
        DataType::Float32 => Ok(ScalarValue::Float32(Some(default_value.parse()?))),
        DataType::Float64 => Ok(ScalarValue::Float64(Some(default_value.parse()?))),
        _ => Err(Error::InternalError("Unsupported data type".to_string())),
    }
}
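
// Caveat (editorial, not in the committed source): information_schema reports
// `column_default` as a SQL expression, e.g. `0`, `'x'::text`, or
// `nextval('some_seq'::regclass)`. Only plain literals survive the `parse()`
// calls above; an expression default on a numeric column would propagate as a
// parse error out of `PostgresTableProvider::try_new`.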

#[cfg(test)]
mod tests {

    use arrow::util::pretty::print_batches;

    use super::*;

    #[test]
    fn test_postgres_catalog_provider() {
        // Integration test: expects the docker-compose Postgres instance
        // (localhost:5433, root/root) with a `schools` table to be running.
        let url = "postgresql://root:root@localhost:5433/qurious";
        let catalog = PostgresCatalogProvider::try_new(url).unwrap();

        let schema = catalog.schema("public").unwrap();

        let table = schema.table("schools").unwrap();

        let batch = table.scan(None, &vec![]).unwrap();

        print_batches(&batch).unwrap();
    }
}
4 changes: 2 additions & 2 deletions qurious/src/datasource/memory.rs
@@ -68,7 +68,7 @@ impl TableProvider for MemoryTable {
        }
    }

-    fn get_column_default(&self, column: &str) -> Option<&ScalarValue> {
-        self.column_defaults.get(column)
+    fn get_column_default(&self, column: &str) -> Option<ScalarValue> {
+        self.column_defaults.get(column).map(|v| v.clone())
    }
}
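
A plausible reading of this signature change (an inference; the commit does not say): the new Postgres provider keeps its defaults in a DashMap, whose read guards cannot hand out a bare &ScalarValue, so the trait method now returns an owned clone and MemoryTable follows suit.

    // Hypothetical fragment illustrating the constraint: a DashMap read guard
    // borrows a shard lock, so values must be cloned/copied out to escape it.
    let defaults: dashmap::DashMap<String, i32> = dashmap::DashMap::new();
    defaults.insert("id".to_string(), 0);
    let owned: Option<i32> = defaults.get("id").map(|g| *g.value());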
2 changes: 1 addition & 1 deletion qurious/src/datasource/mod.rs
@@ -1,5 +1,5 @@
pub mod file;
pub mod memory;
-pub mod postgres;
+pub mod connectorx;


28 changes: 0 additions & 28 deletions qurious/src/datasource/postgres.rs

This file was deleted.
