From ebf7c5d981fe31ef6194fd456162759273a27469 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 9 Jul 2024 16:52:57 -0700 Subject: [PATCH] fix: add support for list fields with names other than 'item' (#2580) In the future we can maybe add support for a few more things as well: * Users should be able to read back into whatever schema they want (this is present in the rust already but missing from python because projection is missing from python, so this is mostly just tests) * Perhaps add some kind of read option to "normalize" a schema so we always read back the field as "item". Right now, if no schema is provided at read time, we mirror exactly the write time schema. This will cause non-item field names to propagate which is maybe not the best choice. This could also be a write time option to normalize the schema on write. --- python/python/tests/test_file.py | 33 +++++++++++++++++++ rust/lance-encoding/src/decoder.rs | 2 ++ .../src/encodings/logical/list.rs | 13 +++++++- 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/python/python/tests/test_file.py b/python/python/tests/test_file.py index e5efa09b1d..fa8241f3a9 100644 --- a/python/python/tests/test_file.py +++ b/python/python/tests/test_file.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright The Lance Authors import pyarrow as pa +import pyarrow.parquet as pq import pytest from lance.file import LanceFileReader, LanceFileWriter @@ -156,3 +157,35 @@ def test_metadata(tmp_path): assert page.buffers[0].size == 24 assert len(page.encoding) > 0 + + +def test_round_trip_parquet(tmp_path): + pq_path = tmp_path / "foo.parquet" + table = pa.table({"int": [1, 2], "list_str": [["x", "yz", "abc"], ["foo", "bar"]]}) + pq.write_table(table, str(pq_path)) + table = pq.read_table(str(pq_path)) + + lance_path = tmp_path / "foo.lance" + with LanceFileWriter(str(lance_path)) as writer: + writer.write_batch(table) + + reader = LanceFileReader(str(lance_path)) + round_tripped = reader.read_all().to_table() + 
assert round_tripped == table + + +def test_list_field_name(tmp_path): + weird_field = pa.field("why does this name even exist", pa.string()) + weird_string_type = pa.list_(weird_field) + schema = pa.schema([pa.field("list_str", weird_string_type)]) + table = pa.table({"list_str": [["x", "yz", "abc"], ["foo", "bar"]]}, schema=schema) + + path = tmp_path / "foo.lance" + with LanceFileWriter(str(path)) as writer: + writer.write_batch(table) + + reader = LanceFileReader(str(path)) + round_tripped = reader.read_all().to_table() + + assert round_tripped == table + assert round_tripped.schema.field("list_str").type == weird_string_type diff --git a/rust/lance-encoding/src/decoder.rs b/rust/lance-encoding/src/decoder.rs index e110680e1f..8469951ede 100644 --- a/rust/lance-encoding/src/decoder.rs +++ b/rust/lance-encoding/src/decoder.rs @@ -637,6 +637,7 @@ impl FieldDecoderStrategy for CoreFieldDecoderStrategy { file_buffers: buffers, positions_and_sizes: &offsets_column.buffer_offsets_and_sizes, }; + let item_field_name = items_field.name().clone(); let (chain, items_scheduler) = chain.new_child( /*child_idx=*/ 0, &field.children[0], @@ -688,6 +689,7 @@ impl FieldDecoderStrategy for CoreFieldDecoderStrategy { Ok(Arc::new(ListFieldScheduler::new( inner, items_scheduler, + item_field_name.clone(), items_type, offset_type, null_offset_adjustments, diff --git a/rust/lance-encoding/src/encodings/logical/list.rs b/rust/lance-encoding/src/encodings/logical/list.rs index fd7d290ef8..8229226ab5 100644 --- a/rust/lance-encoding/src/encodings/logical/list.rs +++ b/rust/lance-encoding/src/encodings/logical/list.rs @@ -455,6 +455,7 @@ impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> { item_decoder: None, rows_drained: 0, lists_available: 0, + item_field_name: self.scheduler.item_field_name.clone(), num_rows, unloaded: Some(indirect_fut), items_type: self.scheduler.items_type.clone(), @@ -491,6 +492,7 @@ impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> { pub struct 
ListFieldScheduler { offsets_scheduler: Arc, items_scheduler: Arc, + item_field_name: String, items_type: DataType, offset_type: DataType, list_type: DataType, @@ -512,6 +514,7 @@ impl ListFieldScheduler { pub fn new( offsets_scheduler: Arc, items_scheduler: Arc, + item_field_name: String, items_type: DataType, // Should be int32 or int64 offset_type: DataType, @@ -529,6 +532,7 @@ impl ListFieldScheduler { Self { offsets_scheduler, items_scheduler, + item_field_name, items_type, offset_type, offset_page_info, @@ -573,6 +577,7 @@ struct ListPageDecoder { lists_available: u64, num_rows: u64, rows_drained: u64, + item_field_name: String, items_type: DataType, offset_type: DataType, data_type: DataType, @@ -583,6 +588,7 @@ struct ListDecodeTask { validity: BooleanBuffer, // Will be None if there are no items (all empty / null lists) items: Option>, + item_field_name: String, items_type: DataType, offset_type: DataType, } @@ -601,7 +607,11 @@ impl DecodeArrayTask for ListDecodeTask { // TODO: we default to nullable true here, should probably use the nullability given to // us from the input schema - let item_field = Arc::new(Field::new("item", self.items_type.clone(), true)); + let item_field = Arc::new(Field::new( + self.item_field_name, + self.items_type.clone(), + true, + )); // The offsets are already decoded but they need to be shifted back to 0 and cast // to the appropriate type @@ -756,6 +766,7 @@ impl LogicalPageDecoder for ListPageDecoder { task: Box::new(ListDecodeTask { offsets, validity, + item_field_name: self.item_field_name.clone(), items: item_decode, items_type: self.items_type.clone(), offset_type: self.offset_type.clone(),