Merge pull request #717 from splitgraph/align-coalesced-batch-nullability

Align batch schema nullability during squashing
gruuya authored Oct 22, 2024
2 parents b66a48a + 9f03eeb commit eea12e1
Showing 3 changed files with 100 additions and 15 deletions.
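
For context on the change below: squashing concatenates the pending sync batches with arrow's concat_batches, which requires each input batch's schema to be compatible with the target schema, so batches that disagree only on field nullability could previously fail to concatenate. The commit fixes this by promoting every field to nullable before concatenation. The following standalone sketch is not part of the commit; it illustrates the pattern and assumes an arrow-rs version, such as the one this PR builds against, in which concat_batches accepts batches whose schema is contained in the target schema, i.e. nullability may only be widened.

use std::sync::Arc;

use arrow::array::{Int32Array, RecordBatch};
use arrow::compute::concat_batches;
use arrow_schema::{DataType, Field, Fields, Schema};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Two batches carrying the same logical column, but with different nullability flags.
    let non_nullable = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)]));
    let nullable = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)]));

    let batch_1 = RecordBatch::try_new(
        non_nullable.clone(),
        vec![Arc::new(Int32Array::from(vec![1]))],
    )?;
    let batch_2 = RecordBatch::try_new(
        nullable.clone(),
        vec![Arc::new(Int32Array::from(vec![None::<i32>]))],
    )?;

    // Using the first batch's schema as the target fails: batch_2 has a nullable
    // field where the target declares a non-nullable one.
    assert!(concat_batches(&non_nullable, &[batch_1.clone(), batch_2.clone()]).is_err());

    // Promote every field to nullable, mirroring the loop added in squash_batches;
    // a nullable target field accepts both nullable and non-nullable inputs.
    let aligned = Fields::from(
        batch_1
            .schema()
            .fields()
            .iter()
            .map(|f| Arc::new(Field::new(f.name(), f.data_type().clone(), true)))
            .collect::<Vec<_>>(),
    );
    let target = Arc::new(Schema::new(aligned));

    // With the aligned target schema the concatenation succeeds.
    let combined = concat_batches(&target, &[batch_1, batch_2])?;
    assert_eq!(combined.num_rows(), 2);
    Ok(())
}
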
34 changes: 30 additions & 4 deletions src/sync/schema/mod.rs
@@ -1,5 +1,5 @@
use crate::sync::SyncError;
use arrow_schema::{DataType, FieldRef, SchemaRef};
use crate::sync::{SyncError, SyncResult};
use arrow_schema::{DataType, FieldRef, Fields, SchemaRef};
use clade::sync::{ColumnDescriptor, ColumnRole};
use std::collections::{HashMap, HashSet};
use std::fmt::{Display, Formatter};
@@ -20,15 +20,15 @@ impl SyncSchema {
column_descriptors: Vec<ColumnDescriptor>,
schema: SchemaRef,
validate_pks: bool,
) -> Result<Self, SyncError> {
) -> SyncResult<Self> {
if column_descriptors.len() != schema.flattened_fields().len() {
return Err(SyncError::SchemaError {
reason: "Column descriptors do not match the schema".to_string(),
});
}

// Validate field role's are parsable, we have the correct number of old/new PKs,
// and Changed role is non-nullable boolean type which points to an existing column
// and Changed role is a boolean type which points to an existing column
// TODO: Validate a column can not be a PK and Value at the same time
let mut old_pk_types = HashSet::new();
let mut new_pk_types = HashSet::new();
@@ -122,6 +122,32 @@ impl SyncSchema {
}
}

// Replace the existing field references for sync columns with new ones.
pub fn with_fields(&mut self, fields: &Fields) -> SyncResult<()> {
if fields.len() != self.columns.len() {
return Err(SyncError::SchemaError {
reason: "New and old field counts are different".to_string(),
});
}

for (col, field) in self.columns.iter_mut().zip(fields.iter()) {
let col_data_type = col.field.data_type();
let field_data_type = field.data_type();

if col_data_type != field_data_type {
return Err(SyncError::SchemaError {
reason: format!(
"Schema mismatch for column {col:?}: expected data type {col_data_type}, got {field_data_type}"
),
});
}

col.field = field.clone();
}

Ok(())
}

pub fn column(&self, name: &str, role: ColumnRole) -> Option<&SyncColumn> {
self.indices
.get(&role)
74 changes: 66 additions & 8 deletions src/sync/utils.rs
@@ -5,12 +5,14 @@ use arrow::array::{new_null_array, Array, ArrayRef, RecordBatch, Scalar, UInt64A
use arrow::compute::kernels::cmp::{gt_eq, lt_eq};
use arrow::compute::{and_kleene, bool_or, concat_batches, filter, is_not_null, take};
use arrow_row::{Row, RowConverter, SortField};
use arrow_schema::{Field, Fields, Schema};
use clade::sync::ColumnRole;
use datafusion::functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};
use datafusion::physical_optimizer::pruning::PruningStatistics;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{col, lit, Accumulator, Expr};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use tracing::log::{debug, warn};

// Returns the total number of bytes and rows in the slice of batches
@@ -29,12 +31,30 @@ pub(super) fn get_size_and_rows(batches: &[RecordBatch]) -> (usize, usize) {
// output batch (meaning the last NewPk and Value role columns and the last Value column where the
// accompanying Changed field was `true`).
pub(super) fn squash_batches(
sync_schema: &SyncSchema,
sync_schema: &mut SyncSchema,
data: &[RecordBatch],
) -> Result<RecordBatch> {
// Concatenate all the record batches into a single one
let schema = data.first().unwrap().schema();
) -> SyncResult<RecordBatch> {
debug!("Concatenating {} batch(es)", data.len());
// Concatenate all the record batches into a single one, after making all fields nullable first.
let nullable_fields = Fields::from(
data.first()
.unwrap()
.schema()
.fields()
.iter()
.map(|field| {
if field.is_nullable() {
field.clone()
} else {
// Clone the field but set it to nullable
Arc::new(Field::new(field.name(), field.data_type().clone(), true))
}
})
.collect::<Vec<_>>(),
);

sync_schema.with_fields(&nullable_fields)?;
let schema = Arc::new(Schema::new(nullable_fields));
let batch = concat_batches(&schema, data)?;

// Get columns, sort fields and null arrays for a particular role
@@ -469,7 +489,7 @@ mod tests {
Field::new("value_c3", DataType::Utf8, true),
]));

let sync_schema = arrow_to_sync_schema(schema.clone())?;
let mut sync_schema = arrow_to_sync_schema(schema.clone())?;

// Test a batch with several edge cases with:
// - multiple changes to the same row
@@ -511,7 +531,7 @@
],
)?;

let squashed = squash_batches(&sync_schema, &[batch.clone()])?;
let squashed = squash_batches(&mut sync_schema, &[batch.clone()])?;

let expected = [
"+-----------+-----------+----------+------------+----------+",
@@ -540,7 +560,7 @@
Field::new("value_c4", DataType::Utf8, true),
]));

let sync_schema = arrow_to_sync_schema(schema.clone())?;
let mut sync_schema = arrow_to_sync_schema(schema.clone())?;

let mut rng = rand::thread_rng();
let row_count = rng.gen_range(1..=1000); // With more than 1000 rows the test becomes slow
@@ -660,7 +680,7 @@
],
)?;

let squashed = squash_batches(&sync_schema, &[batch.clone()])?;
let squashed = squash_batches(&mut sync_schema, &[batch.clone()])?;
println!(
"Squashed PKs from {row_count} to {} rows",
squashed.num_rows()
@@ -674,6 +694,44 @@
Ok(())
}

#[test]
fn test_squash_align_nullability() -> Result<(), Box<dyn std::error::Error>> {
// Old PK nullable
let schema = Arc::new(Schema::new(vec![
Field::new("old_pk_c1", DataType::Int32, true),
Field::new("new_pk_c1", DataType::Int32, false),
]));

let mut sync_schema = arrow_to_sync_schema(schema.clone())?;

let batch_1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![None])),
Arc::new(Int32Array::from(vec![1])),
],
)?;

// New PK nullable
let schema = Arc::new(Schema::new(vec![
Field::new("old_pk_c1", DataType::Int32, false),
Field::new("new_pk_c1", DataType::Int32, true),
]));
let batch_2 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1])),
Arc::new(Int32Array::from(vec![None])),
],
)?;

let squashed = squash_batches(&mut sync_schema, &[batch_1, batch_2])?;

assert_eq!(squashed.num_rows(), 0);

Ok(())
}

#[test]
fn test_sync_filter() -> Result<(), Box<dyn std::error::Error>> {
let schema = Arc::new(Schema::new(vec![
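
The test_squash_align_nullability test added above constructs two batches whose primary-key columns swap nullability, which is exactly the case that used to break concatenation. The small check below is not part of the commit; assuming arrow-schema's Schema::contains and Field::contains semantics, it shows the containment rule that makes the all-nullable target schema work.

use arrow_schema::{DataType, Field, Schema};

fn main() {
    let nullable = Schema::new(vec![Field::new("new_pk_c1", DataType::Int32, true)]);
    let non_nullable = Schema::new(vec![Field::new("new_pk_c1", DataType::Int32, false)]);

    // Widening nullability is allowed: an all-nullable target schema can stand in
    // for batches whose fields are declared non-nullable...
    assert!(nullable.contains(&non_nullable));
    // ...but not the other way around, which is why concatenating the two test
    // batches under either original schema fails without this change.
    assert!(!non_nullable.contains(&nullable));
}
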
7 changes: 4 additions & 3 deletions src/sync/writer.rs
@@ -320,7 +320,8 @@ impl SeafowlDataSyncWriter {
Ok(())
}

// Criteria for return the cached entry ready to be persisted to storage.
// Criteria for flushing a cached entry to object storage.
//
// First flush any records that are explicitly beyond the configured max
// lag, followed by further entries if we're still above max cache size.
fn flush_ready(&mut self) -> SyncResult<Option<String>> {
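
The body of flush_ready is collapsed in this view; purely to illustrate the two-step policy spelled out in the updated comment (flush entries past the configured max lag first, then keep flushing while the cache is over its size budget), here is a hypothetical, heavily simplified sketch. The CachedEntry type, its fields, and the parameters are invented for the example and do not reflect Seafowl's actual writer internals.

use std::time::{Duration, Instant};

// Hypothetical stand-in for a cached sync entry; field names are illustrative only.
struct CachedEntry {
    url: String,
    first_insert: Instant,
    size_bytes: usize,
}

// Return the url of the next entry to flush, if any: staleness beyond `max_lag`
// wins first, otherwise flush the oldest entry while total size exceeds `max_size`.
fn flush_ready(entries: &[CachedEntry], max_lag: Duration, max_size: usize) -> Option<String> {
    // Criterion 1: any entry explicitly older than the configured max lag.
    if let Some(stale) = entries.iter().find(|e| e.first_insert.elapsed() > max_lag) {
        return Some(stale.url.clone());
    }

    // Criterion 2: still above the max cache size; pick the oldest entry.
    let total: usize = entries.iter().map(|e| e.size_bytes).sum();
    if total > max_size {
        return entries
            .iter()
            .min_by_key(|e| e.first_insert)
            .map(|e| e.url.clone());
    }

    None
}
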
@@ -372,7 +373,7 @@ impl SeafowlDataSyncWriter {
Ok(None)
}

// Flush the table containing the oldest sync in memory
// Flush the table with the provided url
async fn flush_syncs(&mut self, url: String) -> SyncResult<()> {
self.physical_squashing(&url)?;
let entry = match self.syncs.get(&url) {
@@ -512,7 +513,7 @@ impl SeafowlDataSyncWriter {
let (old_size, old_rows) = get_size_and_rows(&item.data);

let start = Instant::now();
let batch = squash_batches(&item.sync_schema, &item.data)?;
let batch = squash_batches(&mut item.sync_schema, &item.data)?;
let duration = start.elapsed().as_millis();

// Get new size and row count
