fix: don't eagerly materialize fields that the user hasn't asked for (#3442)

We added logic a while back to eagerly materialize fields if they are
narrow and there is a filter. However, we forgot to ensure that those
fields are actually part of the final projection. The result is that we
end up loading many columns the user doesn't want and then throwing them
away.

This fix restricts the set of eagerly loaded fields to only those that
were actually asked for.
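
To illustrate the user-visible effect, here is a rough sketch based on the Python API exercised by the new test_select_none test in this commit (the dataset path is an arbitrary placeholder): with this fix, a filtered scan that selects no columns plans a read whose projection contains only the column referenced by the filter, instead of every column in the dataset.

import lance
import pyarrow as pa

# Dataset with two columns; only "a" is referenced by the filter below.
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
ds = lance.write_dataset(table, "/tmp/select_none_demo")

# With the fix, the planned scan should project only the filter column "a";
# previously the eager-materialization logic could also load "b" and discard it.
plan = ds.scanner(columns=[], filter="a < 50", with_row_id=True).explain_plan(True)
print(plan)  # expect something like: ... projection=[a] ...
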
westonpace authored Feb 11, 2025
1 parent 2e2bf1a commit c70d1d2
Showing 9 changed files with 169 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
rev: v0.4.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
9 changes: 8 additions & 1 deletion java/core/src/test/java/com/lancedb/lance/FilterTest.java
@@ -25,6 +25,7 @@

import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;

import static org.junit.jupiter.api.Assertions.assertEquals;

@@ -102,7 +103,13 @@ void testFilters() throws Exception {
}

private void testFilter(String filter, int expectedCount) throws Exception {
try (LanceScanner scanner = dataset.newScan(new ScanOptions.Builder().filter(filter).build())) {
try (LanceScanner scanner =
dataset.newScan(
new ScanOptions.Builder()
.columns(Arrays.asList())
.withRowId(true)
.filter(filter)
.build())) {
assertEquals(expectedCount, scanner.countRows());
}
}
8 changes: 6 additions & 2 deletions java/core/src/test/java/com/lancedb/lance/ScannerTest.java
@@ -162,7 +162,12 @@ void testDatasetScannerCountRows() throws Exception {
// write id with value from 0 to 39
try (Dataset dataset = testDataset.write(1, 40)) {
try (LanceScanner scanner =
dataset.newScan(new ScanOptions.Builder().filter("id < 20").build())) {
dataset.newScan(
new ScanOptions.Builder()
.columns(Arrays.asList())
.withRowId(true)
.filter("id < 20")
.build())) {
assertEquals(20, scanner.countRows());
}
}
@@ -387,7 +392,6 @@ void testDatasetScannerBatchReadahead() throws Exception {
// This test is more about ensuring that the batchReadahead parameter is accepted
// and doesn't cause errors. The actual effect of batchReadahead might not be
// directly observable in this test.
assertEquals(totalRows, scanner.countRows());
try (ArrowReader reader = scanner.scanBatches()) {
int rowCount = 0;
while (reader.loadNextBatch()) {
4 changes: 3 additions & 1 deletion python/python/lance/dataset.py
@@ -914,7 +914,9 @@ def count_rows(
"""
if isinstance(filter, pa.compute.Expression):
# TODO: consolidate all to use scanner
return self.scanner(filter=filter).count_rows()
return self.scanner(
columns=[], with_row_id=True, filter=filter
).count_rows()

return self._ds.count_rows(filter)

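
As a side note on the Python count_rows change above, a hedged usage sketch (the names mirror the existing tests in test_dataset.py; the path is an arbitrary placeholder) showing how both counting paths avoid loading payload columns:

import lance
import pyarrow as pa
import pyarrow.dataset as pa_ds

table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
ds = lance.write_dataset(table, "/tmp/count_rows_demo")

# String filters are pushed down to the native count_rows.
assert ds.count_rows(filter="a < 50") == 50

# Expression filters now go through a scanner that selects no columns and only
# materializes the row id, so no payload columns are read just to count matches.
assert ds.count_rows(filter=pa_ds.field("a") < 20) == 20

# Counting on a scanner directly requires the same minimal projection.
assert ds.scanner(columns=[], with_row_id=True).count_rows() == 100
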
16 changes: 10 additions & 6 deletions python/python/lance/fragment.py
@@ -354,8 +354,10 @@ def fragment_id(self):
def count_rows(
self, filter: Optional[Union[pa.compute.Expression, str]] = None
) -> int:
if filter is not None:
return self.scanner(filter=filter).count_rows()
if isinstance(filter, pa.compute.Expression):
return self.scanner(
with_row_id=True, columns=[], filter=filter
).count_rows()
return self._fragment.count_rows(filter)

@property
@@ -540,10 +542,12 @@ def merge(

def merge_columns(
self,
value_func: Dict[str, str]
| BatchUDF
| ReaderLike
| Callable[[pa.RecordBatch], pa.RecordBatch],
value_func: (
Dict[str, str]
| BatchUDF
| ReaderLike
| Callable[[pa.RecordBatch], pa.RecordBatch]
),
columns: Optional[list[str]] = None,
batch_size: Optional[int] = None,
reader_schema: Optional[pa.Schema] = None,
12 changes: 11 additions & 1 deletion python/python/tests/test_dataset.py
@@ -778,6 +778,16 @@ def test_count_rows(tmp_path: Path):
assert dataset.count_rows(filter="a < 50") == 50


def test_select_none(tmp_path: Path):
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
base_dir = tmp_path / "test"
ds = lance.write_dataset(table, base_dir)

assert "projection=[a]" in ds.scanner(
columns=[], filter="a < 50", with_row_id=True
).explain_plan(True)


def test_get_fragments(tmp_path: Path):
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
base_dir = tmp_path / "test"
@@ -2200,7 +2210,7 @@ def test_scan_count_rows(tmp_path: Path):
df = pd.DataFrame({"a": range(42), "b": range(42)})
dataset = lance.write_dataset(df, base_dir)

assert dataset.scanner().count_rows() == 42
assert dataset.scanner(columns=[], with_row_id=True).count_rows() == 42
assert dataset.count_rows(filter="a < 10") == 10
assert dataset.count_rows(filter=pa_ds.field("a") < 20) == 20

3 changes: 3 additions & 0 deletions rust/lance/src/dataset/fragment.rs
@@ -903,6 +903,9 @@ impl FileFragment {
match filter {
Some(expr) => self
.scan()
.project(&Vec::<String>::default())
.unwrap()
.with_row_id()
.filter(&expr)?
.count_rows()
.await
156 changes: 126 additions & 30 deletions rust/lance/src/dataset/scanner.rs
@@ -189,6 +189,7 @@ impl MaterializationStyle {
}

/// Filter for filtering rows
#[derive(Debug)]
pub enum LanceFilter {
/// The filter is an SQL string
Sql(String),
@@ -1027,11 +1028,22 @@ impl Scanner {
Ok(concat_batches(&schema, &batches)?)
}

/// Scan and return the number of matching rows
#[instrument(skip_all)]
pub fn count_rows(&self) -> BoxFuture<Result<u64>> {
fn create_count_plan(&self) -> BoxFuture<Result<Arc<dyn ExecutionPlan>>> {
// Future intentionally boxed here to avoid large futures on the stack
async move {
if !self.projection_plan.physical_schema.fields.is_empty() {
return Err(Error::invalid_input(
"count_rows should not be called on a plan selecting columns".to_string(),
location!(),
));
}

if self.limit.is_some() || self.offset.is_some() {
log::warn!(
"count_rows called with limit or offset which could have surprising results"
);
}

let plan = self.create_plan().await?;
// Datafusion interprets COUNT(*) as COUNT(1)
let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1))));
@@ -1046,14 +1058,27 @@
let count_expr = builder.build()?;

let plan_schema = plan.schema();
let count_plan = Arc::new(AggregateExec::try_new(
Ok(Arc::new(AggregateExec::try_new(
AggregateMode::Single,
PhysicalGroupBy::new_single(Vec::new()),
vec![Arc::new(count_expr)],
vec![None],
plan,
plan_schema,
)?);
)?) as Arc<dyn ExecutionPlan>)
}
.boxed()
}

/// Scan and return the number of matching rows
///
/// Note: calling [`Dataset::count_rows`] can be more efficient than calling this method
/// especially if there is no filter.
#[instrument(skip_all)]
pub fn count_rows(&self) -> BoxFuture<Result<u64>> {
// Future intentionally boxed here to avoid large futures on the stack
async move {
let count_plan = self.create_count_plan().await?;
let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?;

// A count plan will always return a single batch with a single row.
@@ -1127,15 +1152,25 @@
}
}

fn calc_eager_columns(&self, filter_plan: &FilterPlan) -> Result<Arc<Schema>> {
let columns = filter_plan.refine_columns();
// If we are going to filter on `filter_plan`, decide which columns are so small that it is
// cheaper to read the entire column and filter in memory.
//
// Note: only add columns that we actually need to read
fn calc_eager_columns(
&self,
filter_plan: &FilterPlan,
desired_schema: &Schema,
) -> Result<Arc<Schema>> {
let filter_columns = filter_plan.refine_columns();
let early_schema = self
.dataset
.empty_projection()
// We need the filter columns
.union_columns(columns, OnMissing::Error)?
// And also any columns that are eager
.union_predicate(|f| self.is_early_field(f))
// Start with the desired schema
.union_schema(desired_schema)
// Subtract columns that are expensive
.subtract_predicate(|f| !self.is_early_field(f))
// Add back columns that we need for filtering
.union_columns(filter_columns, OnMissing::Error)?
.into_schema_ref();

if early_schema.fields.iter().any(|f| !f.is_default_storage()) {
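
To make the revised eager-column selection in the hunk above easier to follow, here is an illustrative Python sketch of the set arithmetic only (the names desired, filter_columns, and is_early_field stand in for the Rust projection builder and are not a real API):

def calc_eager_columns_sketch(desired, filter_columns, is_early_field):
    # Start from the columns the user actually asked for, drop the ones that
    # are too expensive to load before filtering (those are fetched later with
    # a "take"), then add back whatever columns the filter itself needs.
    eager = {name for name in desired if is_early_field(name)}
    return eager | set(filter_columns)

# The previous logic built the eager set from the filter columns plus every
# "early" field in the dataset, ignoring the requested projection -- which is
# why unrequested columns were loaded and then thrown away.
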
@@ -1340,7 +1375,10 @@
(Some(index_query), Some(_)) => {
// If there is a filter then just load the eager columns and
// "take" the other columns later.
let eager_schema = self.calc_eager_columns(&filter_plan)?;
let eager_schema = self.calc_eager_columns(
&filter_plan,
self.projection_plan.physical_schema.as_ref(),
)?;
self.scalar_indexed_scan(&eager_schema, index_query).await?
}
(None, Some(_)) if use_stats && self.batch_size.is_none() => {
@@ -1352,7 +1390,10 @@
let eager_schema = if filter_plan.has_refine() {
// If there is a filter then only load the filter columns in the
// initial scan. We will `take` the remaining columns later
self.calc_eager_columns(&filter_plan)?
self.calc_eager_columns(
&filter_plan,
self.projection_plan.physical_schema.as_ref(),
)?
} else {
// If there is no filter we eagerly load everything
self.projection_plan.physical_schema.clone()
@@ -3913,14 +3954,11 @@ mod test {
.unwrap();

let dataset = Dataset::open(test_uri).await.unwrap();
assert_eq!(32, dataset.scan().count_rows().await.unwrap());
assert_eq!(32, dataset.count_rows(None).await.unwrap());
assert_eq!(
16,
dataset
.scan()
.filter("`Filter_me` > 15")
.unwrap()
.count_rows()
.count_rows(Some("`Filter_me` > 15".to_string()))
.await
.unwrap()
);
@@ -3948,7 +3986,7 @@
.unwrap();

let dataset = Dataset::open(test_uri).await.unwrap();
assert_eq!(32, dataset.scan().count_rows().await.unwrap());
assert_eq!(dataset.count_rows(None).await.unwrap(), 32);

let mut scanner = dataset.scan();

@@ -3996,7 +4034,7 @@
.unwrap();

let dataset = Dataset::open(test_uri).await.unwrap();
assert_eq!(32, dataset.scan().count_rows().await.unwrap());
assert_eq!(dataset.count_rows(None).await.unwrap(), 32);

let mut scanner = dataset.scan();

@@ -4519,20 +4557,13 @@
}
}

/// Assert that the plan when formatted matches the expected string.
///
/// Within expected, you can use `...` to match any number of characters.
async fn assert_plan_equals(
dataset: &Dataset,
plan: impl Fn(&mut Scanner) -> Result<&mut Scanner>,
async fn assert_plan_node_equals(
plan_node: Arc<dyn ExecutionPlan>,
expected: &str,
) -> Result<()> {
let mut scan = dataset.scan();
plan(&mut scan)?;
let exec_plan = scan.create_plan().await?;
let plan_desc = format!(
"{}",
datafusion::physical_plan::displayable(exec_plan.as_ref()).indent(true)
datafusion::physical_plan::displayable(plan_node.as_ref()).indent(true)
);

let to_match = expected.split("...").collect::<Vec<_>>();
@@ -4559,6 +4590,71 @@
Ok(())
}

/// Assert that the plan when formatted matches the expected string.
///
/// Within expected, you can use `...` to match any number of characters.
async fn assert_plan_equals(
dataset: &Dataset,
plan: impl Fn(&mut Scanner) -> Result<&mut Scanner>,
expected: &str,
) -> Result<()> {
let mut scan = dataset.scan();
plan(&mut scan)?;
let exec_plan = scan.create_plan().await?;
assert_plan_node_equals(exec_plan, expected).await
}

#[tokio::test]
async fn test_count_plan() {
// A count rows operation should load the minimal amount of data
let dim = 256;
let fixture = TestVectorDataset::new_with_dimension(LanceFileVersion::Stable, true, dim)
.await
.unwrap();

// By default, all columns are returned, which is bad for a count_rows op
let err = fixture
.dataset
.scan()
.create_count_plan()
.await
.unwrap_err();
assert!(matches!(err, Error::InvalidInput { .. }));

let mut scan = fixture.dataset.scan();
scan.project(&Vec::<String>::default()).unwrap();

// with_row_id needs to be specified
let err = scan.create_count_plan().await.unwrap_err();
assert!(matches!(err, Error::InvalidInput { .. }));

scan.with_row_id();

let plan = scan.create_count_plan().await.unwrap();

assert_plan_node_equals(
plan,
"AggregateExec: mode=Single, gby=[], aggr=[count_rows]
LanceScan: uri=..., projection=[], row_id=true, row_addr=false, ordered=true",
)
.await
.unwrap();

scan.filter("s == ''").unwrap();

let plan = scan.create_count_plan().await.unwrap();

assert_plan_node_equals(
plan,
"AggregateExec: mode=Single, gby=[], aggr=[count_rows]
ProjectionExec: expr=[_rowid@1 as _rowid]
FilterExec: s@0 =
LanceScan: uri=..., projection=[s], row_id=true, row_addr=false, ordered=true",
)
.await
.unwrap();
}

#[rstest]
#[tokio::test]
async fn test_late_materialization(
5 changes: 1 addition & 4 deletions rust/lance/src/dataset/write/merge_insert.rs
@@ -1816,10 +1816,7 @@ mod tests {

// Check that the data is as expected
let updated = ds
.scan()
.filter("value = 9999999")
.unwrap()
.count_rows()
.count_rows(Some("value = 9999999".to_string()))
.await
.unwrap();
assert_eq!(updated, 2048);
