diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 66ca82207e..8949ee6524 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -1339,6 +1339,26 @@ def test_update_dataset_all_types(tmp_path: Path): assert dataset.to_table() == expected +def test_update_with_binary_field(tmp_path: Path): + # Create a lance dataset with binary fields + table = pa.Table.from_pydict( + { + "a": [f"str-{i}" for i in range(100)], + "b": [b"bin-{i}" for i in range(100)], + "c": list(range(100)), + } + ) + dataset = lance.write_dataset(table, tmp_path) + + # Update binary field + dataset.update({"b": "X'616263'"}, where="c < 2") + + ds = lance.dataset(tmp_path) + assert ds.scanner(filter="c < 2").to_table().column( + "b" + ).combine_chunks() == pa.array([b"abc", b"abc"]) + + def test_create_update_empty_dataset(tmp_path: Path, provide_pandas: bool): base_dir = tmp_path / "dataset" diff --git a/rust/lance/src/io/exec/planner.rs b/rust/lance/src/io/exec/planner.rs index 74202b920d..9446d56afb 100644 --- a/rust/lance/src/io/exec/planner.rs +++ b/rust/lance/src/io/exec/planner.rs @@ -357,7 +357,9 @@ impl Planner { Value::DollarQuotedString(_) => todo!(), Value::EscapedStringLiteral(_) => todo!(), Value::NationalStringLiteral(_) => todo!(), - Value::HexStringLiteral(_) => todo!(), + Value::HexStringLiteral(hsl) => { + Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl))) + } Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))), Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v))), Value::Null => Expr::Literal(ScalarValue::Null), @@ -673,6 +675,42 @@ impl Planner { Ok(resolved) } + /// Try to decode bytes from hex literal string. + /// + /// Copied from datafusion because this is not public. + /// + /// TODO: use SqlToRel from Datafusion directly? + fn try_decode_hex_literal(s: &str) -> Option> { + let hex_bytes = s.as_bytes(); + let mut decoded_bytes = Vec::with_capacity((hex_bytes.len() + 1) / 2); + + let start_idx = hex_bytes.len() % 2; + if start_idx > 0 { + // The first byte is formed of only one char. + decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?); + } + + for i in (start_idx..hex_bytes.len()).step_by(2) { + let high = Self::try_decode_hex_char(hex_bytes[i])?; + let low = Self::try_decode_hex_char(hex_bytes[i + 1])?; + decoded_bytes.push(high << 4 | low); + } + + Some(decoded_bytes) + } + + /// Try to decode a byte from a hex char. + /// + /// None will be returned if the input char is hex-invalid. + const fn try_decode_hex_char(c: u8) -> Option { + match c { + b'A'..=b'F' => Some(c - b'A' + 10), + b'a'..=b'f' => Some(c - b'a' + 10), + b'0'..=b'9' => Some(c - b'0'), + _ => None, + } + } + /// Optimize the filter expression and coerce data types. pub fn optimize_expr(&self, expr: Expr) -> Result { let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?); @@ -1384,4 +1422,21 @@ mod tests { let columns = Planner::column_names_in_expr(&expr); assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]); } + + #[test] + fn test_parse_binary_expr() { + let bin_str = "x'616263'"; + + let schema = Arc::new(Schema::new(vec![Field::new( + "binary", + DataType::Binary, + true, + )])); + let planner = Planner::new(schema); + let expr = planner.parse_expr(bin_str).unwrap(); + assert_eq!( + expr, + Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c']))) + ); + } }