Skip to content

Commit

Permalink
Add support for Tableau (#475)
Browse files Browse the repository at this point in the history
  • Loading branch information
Brent Gardner authored Oct 28, 2022
1 parent 1d24f1d commit 2c1bc7b
Showing 1 changed file with 196 additions and 43 deletions.
239 changes: 196 additions & 43 deletions ballista/scheduler/src/flight_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ use log::{debug, error, warn};
use std::convert::TryFrom;
use std::pin::Pin;
use std::str::FromStr;
use std::string::ToString;
use std::sync::Arc;
use std::time::Duration;
use tonic::{Request, Response, Status, Streaming};

use crate::scheduler_server::SchedulerServer;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::sql::ProstMessageExt;
use arrow_flight::utils::flight_data_from_arrow_batch;
use arrow_flight::SchemaAsIpc;
use ballista_core::config::BallistaConfig;
use ballista_core::serde::protobuf;
Expand All @@ -52,14 +54,21 @@ use ballista_core::serde::protobuf::SuccessfulJob;
use ballista_core::utils::create_grpc_client_connection;
use dashmap::DashMap;
use datafusion::arrow;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::array::{ArrayRef, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::arrow::error::ArrowError;
use datafusion::arrow::ipc::writer::{IpcDataGenerator, IpcWriteOptions};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::DFSchemaRef;
use datafusion::logical_expr::LogicalPlan;
use datafusion::physical_plan::common::batch_byte_size;
use datafusion::prelude::SessionContext;
use datafusion_proto::protobuf::LogicalPlanNode;
use itertools::Itertools;
use prost::Message;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::time::sleep;
use tokio_stream::wrappers::ReceiverStream;
use tonic::codegen::futures_core::Stream;
use tonic::metadata::MetadataValue;
use uuid::Uuid;
Expand All @@ -70,6 +79,8 @@ pub struct FlightSqlServiceImpl {
contexts: Arc<DashMap<Uuid, Arc<SessionContext>>>,
}

const TABLE_TYPES: [&str; 2] = ["TABLE", "VIEW"];

impl FlightSqlServiceImpl {
pub fn new(server: SchedulerServer<LogicalPlanNode, PhysicalPlanNode>) -> Self {
Self {
Expand All @@ -79,6 +90,43 @@ impl FlightSqlServiceImpl {
}
}

fn tables(&self, ctx: Arc<SessionContext>) -> Result<RecordBatch, ArrowError> {
let schema = Arc::new(Schema::new(vec![
Field::new("catalog_name", DataType::Utf8, true),
Field::new("db_schema_name", DataType::Utf8, true),
Field::new("table_name", DataType::Utf8, false),
Field::new("table_type", DataType::Utf8, false),
]));
let tables = ctx.tables()?;
let names: Vec<_> = tables.iter().map(|it| Some(it.as_str())).collect();
let types: Vec<_> = names.iter().map(|_| Some("TABLE")).collect();
let cats: Vec<_> = names.iter().map(|_| None).collect();
let schemas: Vec<_> = names.iter().map(|_| None).collect();
let rb = RecordBatch::try_new(
schema,
[cats, schemas, names, types]
.into_iter()
.map(|i| Arc::new(StringArray::from(i.clone())) as ArrayRef)
.collect::<Vec<_>>(),
)?;
Ok(rb)
}

fn table_types() -> Result<RecordBatch, ArrowError> {
let schema = Arc::new(Schema::new(vec![Field::new(
"table_type",
DataType::Utf8,
false,
)]));
RecordBatch::try_new(
schema,
[TABLE_TYPES]
.into_iter()
.map(|i| Arc::new(StringArray::from(i.to_vec())) as ArrayRef)
.collect::<Vec<_>>(),
)
}

async fn create_ctx(&self) -> Result<Uuid, Status> {
let config_builder = BallistaConfig::builder();
let config = config_builder
Expand Down Expand Up @@ -233,6 +281,34 @@ impl FlightSqlServiceImpl {
Ok(fieps)
}

fn make_local_fieps(&self, job_id: &str) -> Result<Vec<FlightEndpoint>, Status> {
let (host, port) = ("127.0.0.1".to_string(), 50050); // TODO: use advertise host
let fetch = protobuf::FetchPartition {
job_id: job_id.to_string(),
stage_id: 0,
partition_id: 0,
path: job_id.to_string(),
host: host.clone(),
port,
};
let fetch = protobuf::Action {
action_type: Some(FetchPartition(fetch)),
settings: vec![],
};
let authority = format!("{}:{}", &host, &port); // TODO: use advertise host
let loc = Location {
uri: format!("grpc+tcp://{}", authority),
};
let buf = fetch.as_any().encode_to_vec();
let ticket = Ticket { ticket: buf };
let fiep = FlightEndpoint {
ticket: Some(ticket),
location: vec![loc],
};
let fieps = vec![fiep];
Ok(fieps)
}

fn cache_plan(&self, plan: LogicalPlan) -> Result<Uuid, Status> {
let handle = Uuid::new_v4();
self.statements.insert(handle, plan);
Expand All @@ -257,6 +333,11 @@ impl FlightSqlServiceImpl {

fn df_schema_to_arrow(&self, schema: &DFSchemaRef) -> Result<Vec<u8>, Status> {
let arrow_schema: Schema = (&**schema).into();
let schema_bytes = self.schema_to_arrow(Arc::new(arrow_schema))?;
Ok(schema_bytes)
}

fn schema_to_arrow(&self, arrow_schema: SchemaRef) -> Result<Vec<u8>, Status> {
let options = IpcWriteOptions::default();
let pair = SchemaAsIpc::new(&arrow_schema, &options);
let data_gen = IpcDataGenerator::default();
Expand Down Expand Up @@ -335,6 +416,48 @@ impl FlightSqlServiceImpl {
let resp = Self::create_resp(schema_bytes, fieps, num_rows, num_bytes);
Ok(resp)
}

async fn record_batch_to_resp(
rb: &RecordBatch,
) -> Result<
Response<Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send>>>,
Status,
> {
let (tx, rx): (
Sender<Result<FlightData, Status>>,
Receiver<Result<FlightData, Status>>,
) = channel(2);
let options = IpcWriteOptions::default();
let schema = SchemaAsIpc::new(rb.schema().as_ref(), &options).into();
tx.send(Ok(schema))
.await
.map_err(|e| Status::internal("Error sending schema".to_string()))?;
let (dict, flight) = flight_data_from_arrow_batch(&rb, &options);
let flights = dict.into_iter().chain(std::iter::once(flight));
for flight in flights.into_iter() {
tx.send(Ok(flight))
.await
.map_err(|e| Status::internal("Error sending flight".to_string()))?;
}
let resp = Response::new(Box::pin(ReceiverStream::new(rx))
as Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>);
Ok(resp)
}

fn batch_to_schema_resp(
&self,
data: &RecordBatch,
name: &str,
) -> Result<Response<FlightInfo>, Status> {
let num_bytes = batch_byte_size(&data) as i64;
let schema = data.schema();
let num_rows = data.num_rows() as i64;

let fieps = self.make_local_fieps(name)?;
let schema_bytes = self.schema_to_arrow(schema)?;
let resp = Self::create_resp(schema_bytes, fieps, num_rows, num_bytes);
Ok(resp)
}
}

#[tonic::async_trait]
Expand Down Expand Up @@ -402,49 +525,70 @@ impl FlightSqlService for FlightSqlServiceImpl {

async fn do_get_fallback(
&self,
_request: Request<Ticket>,
request: Request<Ticket>,
message: prost_types::Any,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
println!("type_url: {}", message.type_url);
if message.is::<protobuf::Action>() {
println!("got action!");
let action: protobuf::Action = message
.unpack()
.map_err(|e| Status::internal(format!("{:?}", e)))?
.ok_or(Status::internal("Expected an Action but got None!"))?;
println!("action={:?}", action);
let (host, port) = match &action.action_type {
Some(FetchPartition(fp)) => (fp.host.clone(), fp.port),
None => Err(Status::internal("Expected an ActionType but got None!"))?,
};
debug!("do_get_fallback type_url: {}", message.type_url);
let ctx = self.get_ctx(&request)?;
if !message.is::<protobuf::Action>() {
Err(Status::unimplemented(format!(
"do_get: The defined request is invalid: {}",
message.type_url
)))?
}

let action: protobuf::Action = message
.unpack()
.map_err(|e| Status::internal(format!("{:?}", e)))?
.ok_or(Status::internal("Expected an Action but got None!"))?;
let fp = match &action.action_type {
Some(FetchPartition(fp)) => fp.clone(),
None => Err(Status::internal("Expected an ActionType but got None!"))?,
};

let addr = format!("http://{}:{}", host, port);
println!("BallistaClient connecting to {}", addr);
let connection =
create_grpc_client_connection(addr.clone())
.await
.map_err(|e| {
Status::internal(format!(
// Well-known job ID: respond with the data
match fp.job_id.as_str() {
"get_flight_info_table_types" => {
debug!("Responding with table types");
let rb = FlightSqlServiceImpl::table_types().map_err(|e| {
Status::internal("Error getting table types".to_string())
})?;
let resp = Self::record_batch_to_resp(&rb).await?;
return Ok(resp);
}
"get_flight_info_tables" => {
debug!("Responding with tables");
let rb = self
.tables(ctx)
.map_err(|e| Status::internal("Error getting tables".to_string()))?;
let resp = Self::record_batch_to_resp(&rb).await?;
return Ok(resp);
}
_ => {}
}

// Proxy the flight
let addr = format!("http://{}:{}", fp.host, fp.port);
debug!("Scheduler proxying flight for to {}", addr);
let connection =
create_grpc_client_connection(addr.clone())
.await
.map_err(|e| {
Status::internal(format!(
"Error connecting to Ballista scheduler or executor at {}: {:?}",
addr, e
))
})?;
let mut flight_client = FlightServiceClient::new(connection);
let buf = action.encode_to_vec();
let request = Request::new(Ticket { ticket: buf });
})?;
let mut flight_client = FlightServiceClient::new(connection);
let buf = action.encode_to_vec();
let request = Request::new(Ticket { ticket: buf });

let stream = flight_client
.do_get(request)
.await
.map_err(|e| Status::internal(format!("{:?}", e)))?
.into_inner();
return Ok(Response::new(Box::pin(stream)));
}

Err(Status::unimplemented(format!(
"do_get: The defined request is invalid: {}",
message.type_url
)))
let stream = flight_client
.do_get(request)
.await
.map_err(|e| Status::internal(format!("{:?}", e)))?
.into_inner();
Ok(Response::new(Box::pin(stream)))
}

async fn get_flight_info_statement(
Expand Down Expand Up @@ -494,24 +638,33 @@ impl FlightSqlService for FlightSqlServiceImpl {
debug!("get_flight_info_schemas");
Err(Status::unimplemented("Implement get_flight_info_schemas"))
}

async fn get_flight_info_tables(
&self,
_query: CommandGetTables,
_request: Request<FlightDescriptor>,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
debug!("get_flight_info_tables");
Err(Status::unimplemented("Implement get_flight_info_tables"))
let ctx = self.get_ctx(&request)?;
let data = self
.tables(ctx)
.map_err(|e| Status::internal(format!("Error getting tables: {}", e)))?;
let resp = self.batch_to_schema_resp(&data, "get_flight_info_tables")?;
Ok(resp)
}

async fn get_flight_info_table_types(
&self,
_query: CommandGetTableTypes,
_request: Request<FlightDescriptor>,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
debug!("get_flight_info_table_types");
Err(Status::unimplemented(
"Implement get_flight_info_table_types",
))
let data = FlightSqlServiceImpl::table_types()
.map_err(|e| Status::internal(format!("Error getting table types: {}", e)))?;
let resp = self.batch_to_schema_resp(&data, "get_flight_info_table_types")?;
Ok(resp)
}

async fn get_flight_info_sql_info(
&self,
_query: CommandGetSqlInfo,
Expand Down

0 comments on commit 2c1bc7b

Please sign in to comment.