fix: don't eagerly materialize fields that the user hasn't asked for #3442

Merged · 9 commits · Feb 11, 2025
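In short: the row-count paths (Dataset::count_rows, FileFragment::count_rows, and the Python/Java bindings) now scan with an empty projection plus the row id instead of materializing every column. A minimal Python sketch of the pattern this PR standardizes on (path and column names are illustrative, not from the diff):

    import lance
    import pyarrow as pa

    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
    ds = lance.write_dataset(table, "/tmp/count_demo")  # hypothetical path

    # The count path projects no columns and carries only row ids, so wide
    # columns like "b" are never decoded just to count matching rows.
    assert ds.scanner(columns=[], with_row_id=True, filter="a < 50").count_rows() == 50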
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.2.2
+    rev: v0.4.1
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
8 changes: 6 additions & 2 deletions java/core/src/test/java/com/lancedb/lance/ScannerTest.java
@@ -162,7 +162,12 @@ void testDatasetScannerCountRows() throws Exception {
      // write id with value from 0 to 39
      try (Dataset dataset = testDataset.write(1, 40)) {
        try (LanceScanner scanner =
-           dataset.newScan(new ScanOptions.Builder().filter("id < 20").build())) {
+           dataset.newScan(
+               new ScanOptions.Builder()
+                   .columns(Arrays.asList())
+                   .withRowId(true)
+                   .filter("id < 20")
+                   .build())) {
          assertEquals(20, scanner.countRows());
        }
      }
@@ -387,7 +392,6 @@ void testDatasetScannerBatchReadahead() throws Exception {
        // This test is more about ensuring that the batchReadahead parameter is accepted
        // and doesn't cause errors. The actual effect of batchReadahead might not be
        // directly observable in this test.
-       assertEquals(totalRows, scanner.countRows());
        try (ArrowReader reader = scanner.scanBatches()) {
          int rowCount = 0;
          while (reader.loadNextBatch()) {
4 changes: 3 additions & 1 deletion python/python/lance/dataset.py
@@ -914,7 +914,9 @@ def count_rows(
"""
if isinstance(filter, pa.compute.Expression):
# TODO: consolidate all to use scanner
return self.scanner(filter=filter).count_rows()
return self.scanner(
columns=[], with_row_id=True, filter=filter
).count_rows()

return self._ds.count_rows(filter)

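For the pa.compute.Expression branch above, only the filter column needs to be decoded. A usage sketch (reusing the ds from the earlier sketch):

    import pyarrow.compute as pc

    # Expression filters now route through the scanner with columns=[] and
    # with_row_id=True, so only "a" is read to evaluate the predicate.
    assert ds.count_rows(filter=pc.field("a") < 10) == 10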
16 changes: 10 additions & 6 deletions python/python/lance/fragment.py
@@ -354,8 +354,10 @@ def fragment_id(self):
    def count_rows(
        self, filter: Optional[Union[pa.compute.Expression, str]] = None
    ) -> int:
-       if filter is not None:
-           return self.scanner(filter=filter).count_rows()
+       if isinstance(filter, pa.compute.Expression):
+           return self.scanner(
+               with_row_id=True, columns=[], filter=filter
+           ).count_rows()
        return self._fragment.count_rows(filter)

@property
@@ -540,10 +542,12 @@ def merge(

    def merge_columns(
        self,
-       value_func: Dict[str, str]
-       | BatchUDF
-       | ReaderLike
-       | Callable[[pa.RecordBatch], pa.RecordBatch],
+       value_func: (
+           Dict[str, str]
+           | BatchUDF
+           | ReaderLike
+           | Callable[[pa.RecordBatch], pa.RecordBatch]
+       ),
        columns: Optional[list[str]] = None,
        batch_size: Optional[int] = None,
        reader_schema: Optional[pa.Schema] = None,
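The fragment-level path mirrors the dataset-level one. A sketch of counting within a single fragment (this assumes the ds from earlier was written as one fragment, which the diff does not guarantee):

    # Expression filters count via an empty projection per fragment;
    # string filters still go through the native count path.
    frag = ds.get_fragments()[0]
    assert frag.count_rows(filter=pc.field("a") < 50) == 50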
12 changes: 11 additions & 1 deletion python/python/tests/test_dataset.py
@@ -778,6 +778,16 @@ def test_count_rows(tmp_path: Path):
    assert dataset.count_rows(filter="a < 50") == 50


+def test_select_none(tmp_path: Path):
+    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
+    base_dir = tmp_path / "test"
+    ds = lance.write_dataset(table, base_dir)
+
+    assert "projection=[a]" in ds.scanner(
+        columns=[], filter="a < 50", with_row_id=True
+    ).explain_plan(True)
+
+
def test_get_fragments(tmp_path: Path):
    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
    base_dir = tmp_path / "test"
@@ -2200,7 +2210,7 @@ def test_scan_count_rows(tmp_path: Path):
    df = pd.DataFrame({"a": range(42), "b": range(42)})
    dataset = lance.write_dataset(df, base_dir)

-   assert dataset.scanner().count_rows() == 42
+   assert dataset.scanner(columns=[], with_row_id=True).count_rows() == 42
    assert dataset.count_rows(filter="a < 10") == 10
    assert dataset.count_rows(filter=pa_ds.field("a") < 20) == 20

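test_select_none above asserts on the plan text; the same explain_plan call is useful interactively to confirm what a scan will actually read. A sketch (exact plan text can vary by version):

    # Expect the scan node's projection to list only the filter column,
    # e.g. "projection=[a]", rather than every column in the schema.
    print(ds.scanner(columns=[], filter="a < 50", with_row_id=True).explain_plan(True))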
3 changes: 3 additions & 0 deletions rust/lance/src/dataset/fragment.rs
@@ -903,6 +903,9 @@ impl FileFragment {
        match filter {
            Some(expr) => self
                .scan()
+               .project(&Vec::<String>::default())
+               .unwrap()
+               .with_row_id()
                .filter(&expr)?
                .count_rows()
                .await
156 changes: 126 additions & 30 deletions rust/lance/src/dataset/scanner.rs
@@ -189,6 +189,7 @@ impl MaterializationStyle {
}

/// Filter for filtering rows
+#[derive(Debug)]
pub enum LanceFilter {
    /// The filter is an SQL string
    Sql(String),
@@ -1027,11 +1028,22 @@ impl Scanner {
        Ok(concat_batches(&schema, &batches)?)
    }

-   /// Scan and return the number of matching rows
-   #[instrument(skip_all)]
-   pub fn count_rows(&self) -> BoxFuture<Result<u64>> {
+   fn create_count_plan(&self) -> BoxFuture<Result<Arc<dyn ExecutionPlan>>> {
        // Future intentionally boxed here to avoid large futures on the stack
        async move {
+           if !self.projection_plan.physical_schema.fields.is_empty() {
+               return Err(Error::invalid_input(
+                   "count_rows should not be called on a plan selecting columns".to_string(),
+                   location!(),
+               ));
+           }
[Review comment on lines +1034 to +1039 — Contributor Author]
I'm a little torn on this error. Ideally, we would just silently blank out the projection plan and then create the count plan. However, to do that we either have to clone the scan, which is a pretty big thing to be cloning, or modify the scanner, which would maybe not be what users would expect from count_rows.

For now, I want to get something out soon, so I'm just raising an error, with the assumption that Scanner::count_rows is a mostly internal method anyways (users should use Dataset::count_rows).

[Reply — Contributor]
I think you're right that is pretty internal. Plus easy to work around.

+           if self.limit.is_some() || self.offset.is_some() {
+               log::warn!(
+                   "count_rows called with limit or offset which could have surprising results"
+               );
+           }

            let plan = self.create_plan().await?;
            // Datafusion interprets COUNT(*) as COUNT(1)
            let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1))));
@@ -1046,14 +1058,27 @@
            let count_expr = builder.build()?;

            let plan_schema = plan.schema();
-           let count_plan = Arc::new(AggregateExec::try_new(
+           Ok(Arc::new(AggregateExec::try_new(
                AggregateMode::Single,
                PhysicalGroupBy::new_single(Vec::new()),
                vec![Arc::new(count_expr)],
                vec![None],
                plan,
                plan_schema,
-           )?);
+           )?) as Arc<dyn ExecutionPlan>)
+       }
+       .boxed()
+   }
+
+   /// Scan and return the number of matching rows
+   ///
+   /// Note: calling [`Dataset::count_rows`] can be more efficient than calling this method
+   /// especially if there is no filter.
+   #[instrument(skip_all)]
+   pub fn count_rows(&self) -> BoxFuture<Result<u64>> {
+       // Future intentionally boxed here to avoid large futures on the stack
+       async move {
+           let count_plan = self.create_count_plan().await?;
            let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?;

            // A count plan will always return a single batch with a single row.
@@ -1127,15 +1152,25 @@
        }
    }

-   fn calc_eager_columns(&self, filter_plan: &FilterPlan) -> Result<Arc<Schema>> {
-       let columns = filter_plan.refine_columns();
+   // If we are going to filter on `filter_plan`, then which columns are so small it is
+   // cheaper to read the entire column and filter in memory.
+   //
+   // Note: only add columns that we actually need to read
+   fn calc_eager_columns(
+       &self,
+       filter_plan: &FilterPlan,
+       desired_schema: &Schema,
+   ) -> Result<Arc<Schema>> {
+       let filter_columns = filter_plan.refine_columns();
        let early_schema = self
            .dataset
            .empty_projection()
-           // We need the filter columns
-           .union_columns(columns, OnMissing::Error)?
-           // And also any columns that are eager
-           .union_predicate(|f| self.is_early_field(f))
+           // Start with the desired schema
+           .union_schema(desired_schema)
+           // Subtract columns that are expensive
+           .subtract_predicate(|f| !self.is_early_field(f))
+           // Add back columns that we need for filtering
+           .union_columns(filter_columns, OnMissing::Error)?
            .into_schema_ref();

[Review comment on the `.subtract_predicate` line — Contributor]
do "early" and "eager" mean the same thing in the vocabulary?

        if early_schema.fields.iter().any(|f| !f.is_default_storage()) {
@@ -1340,7 +1375,10 @@
            (Some(index_query), Some(_)) => {
                // If there is a filter then just load the eager columns and
                // "take" the other columns later.
-               let eager_schema = self.calc_eager_columns(&filter_plan)?;
+               let eager_schema = self.calc_eager_columns(
+                   &filter_plan,
+                   self.projection_plan.physical_schema.as_ref(),
+               )?;
                self.scalar_indexed_scan(&eager_schema, index_query).await?
            }
            (None, Some(_)) if use_stats && self.batch_size.is_none() => {
@@ -1352,7 +1390,10 @@
                let eager_schema = if filter_plan.has_refine() {
                    // If there is a filter then only load the filter columns in the
                    // initial scan. We will `take` the remaining columns later
-                   self.calc_eager_columns(&filter_plan)?
+                   self.calc_eager_columns(
+                       &filter_plan,
+                       self.projection_plan.physical_schema.as_ref(),
+                   )?
                } else {
                    // If there is no filter we eagerly load everything
                    self.projection_plan.physical_schema.clone()
@@ -3913,14 +3954,11 @@
            .unwrap();

        let dataset = Dataset::open(test_uri).await.unwrap();
-       assert_eq!(32, dataset.scan().count_rows().await.unwrap());
+       assert_eq!(32, dataset.count_rows(None).await.unwrap());
        assert_eq!(
            16,
            dataset
-               .scan()
-               .filter("`Filter_me` > 15")
-               .unwrap()
-               .count_rows()
+               .count_rows(Some("`Filter_me` > 15".to_string()))
                .await
                .unwrap()
        );
@@ -3948,7 +3986,7 @@
            .unwrap();

        let dataset = Dataset::open(test_uri).await.unwrap();
-       assert_eq!(32, dataset.scan().count_rows().await.unwrap());
+       assert_eq!(dataset.count_rows(None).await.unwrap(), 32);

        let mut scanner = dataset.scan();
@@ -3996,7 +4034,7 @@
            .unwrap();

        let dataset = Dataset::open(test_uri).await.unwrap();
-       assert_eq!(32, dataset.scan().count_rows().await.unwrap());
+       assert_eq!(dataset.count_rows(None).await.unwrap(), 32);

        let mut scanner = dataset.scan();
@@ -4519,20 +4557,13 @@
        }
    }

-   /// Assert that the plan when formatted matches the expected string.
-   ///
-   /// Within expected, you can use `...` to match any number of characters.
-   async fn assert_plan_equals(
-       dataset: &Dataset,
-       plan: impl Fn(&mut Scanner) -> Result<&mut Scanner>,
+   async fn assert_plan_node_equals(
+       plan_node: Arc<dyn ExecutionPlan>,
        expected: &str,
    ) -> Result<()> {
-       let mut scan = dataset.scan();
-       plan(&mut scan)?;
-       let exec_plan = scan.create_plan().await?;
        let plan_desc = format!(
            "{}",
-           datafusion::physical_plan::displayable(exec_plan.as_ref()).indent(true)
+           datafusion::physical_plan::displayable(plan_node.as_ref()).indent(true)
        );

        let to_match = expected.split("...").collect::<Vec<_>>();
@@ -4559,6 +4590,71 @@
        Ok(())
    }

+   /// Assert that the plan when formatted matches the expected string.
+   ///
+   /// Within expected, you can use `...` to match any number of characters.
+   async fn assert_plan_equals(
+       dataset: &Dataset,
+       plan: impl Fn(&mut Scanner) -> Result<&mut Scanner>,
+       expected: &str,
+   ) -> Result<()> {
+       let mut scan = dataset.scan();
+       plan(&mut scan)?;
+       let exec_plan = scan.create_plan().await?;
+       assert_plan_node_equals(exec_plan, expected).await
+   }
+
+   #[tokio::test]
+   async fn test_count_plan() {
+       // A count rows operation should load the minimal amount of data
+       let dim = 256;
+       let fixture = TestVectorDataset::new_with_dimension(LanceFileVersion::Stable, true, dim)
+           .await
+           .unwrap();
+
+       // By default, all columns are returned, this is bad for a count_rows op
+       let err = fixture
+           .dataset
+           .scan()
+           .create_count_plan()
+           .await
+           .unwrap_err();
+       assert!(matches!(err, Error::InvalidInput { .. }));
+
+       let mut scan = fixture.dataset.scan();
+       scan.project(&Vec::<String>::default()).unwrap();
+
+       // with_row_id needs to be specified
+       let err = scan.create_count_plan().await.unwrap_err();
+       assert!(matches!(err, Error::InvalidInput { .. }));
+
+       scan.with_row_id();
+
+       let plan = scan.create_count_plan().await.unwrap();
+
+       assert_plan_node_equals(
+           plan,
+           "AggregateExec: mode=Single, gby=[], aggr=[count_rows]
+  LanceScan: uri=..., projection=[], row_id=true, row_addr=false, ordered=true",
+       )
+       .await
+       .unwrap();
+
+       scan.filter("s == ''").unwrap();
+
+       let plan = scan.create_count_plan().await.unwrap();
+
+       assert_plan_node_equals(
+           plan,
+           "AggregateExec: mode=Single, gby=[], aggr=[count_rows]
+  ProjectionExec: expr=[_rowid@1 as _rowid]
+    FilterExec: s@0 = 
+      LanceScan: uri=..., projection=[s], row_id=true, row_addr=false, ordered=true",
+       )
+       .await
+       .unwrap();
+   }

    #[rstest]
    #[tokio::test]
    async fn test_late_materialization(
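The plans asserted in test_count_plan summarize the design: the count is an AggregateExec over a scan with projection=[], and a refine filter scans only the filter column while projecting _rowid through. calc_eager_columns generalizes this to filtered scans that do select columns: cheap columns are read during the scan, expensive ones are fetched afterwards by a take on the matching rows. A Python sketch to observe that split (column names illustrative; whether a take step appears depends on column sizes and the materialization style):

    # Filter on "a" while selecting both columns: the initial scan should
    # project the eager columns, with any expensive ones appearing in a
    # later take step rather than in the scan itself.
    print(ds.scanner(columns=["a", "b"], filter="a < 50").explain_plan(True))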
5 changes: 1 addition & 4 deletions rust/lance/src/dataset/write/merge_insert.rs
@@ -1816,10 +1816,7 @@ mod tests {

        // Check that the data is as expected
        let updated = ds
-           .scan()
-           .filter("value = 9999999")
-           .unwrap()
-           .count_rows()
+           .count_rows(Some("value = 9999999".to_string()))
            .await
            .unwrap();
        assert_eq!(updated, 2048);