diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs
index a590958b0ad7..752486677267 100644
--- a/crates/polars-core/src/datatypes/dtype.rs
+++ b/crates/polars-core/src/datatypes/dtype.rs
@@ -902,6 +902,53 @@ impl DataType {
             },
         }
     }
+
+    /// Returns the number of bytes a single value of this dtype occupies, if that is fixed.
+    ///
+    /// Returns `None` for dtypes that are reported as having no fixed per-value size
+    /// (`String`, `Binary`, `BinaryOffset`, `List`, `Object`, `Categorical`, `Enum` and
+    /// `Unknown`). An `Array` is its inner size times its width and a `Struct` is the sum of
+    /// its field sizes; both are `None` if any inner dtype is `None` or the total would
+    /// overflow `usize`.
+    pub fn byte_size(&self) -> Option<usize> {
+        // NOTE(review): some of these variants may be feature-gated where the enum is defined
+        // (not visible here) -- confirm whether the arms need matching `#[cfg]` attributes.
+        match self {
+            // NOTE(review): one byte per boolean; Arrow bit-packs booleans -- confirm callers
+            // really want 1 here.
+            DataType::Boolean => Some(1),
+            DataType::UInt8 => Some(1),
+            DataType::UInt16 => Some(2),
+            DataType::UInt32 => Some(4),
+            DataType::UInt64 => Some(8),
+            DataType::Int8 => Some(1),
+            DataType::Int16 => Some(2),
+            DataType::Int32 => Some(4),
+            DataType::Int64 => Some(8),
+            DataType::Int128 => Some(16),
+            DataType::Float32 => Some(4),
+            DataType::Float64 => Some(8),
+            DataType::Decimal(_, _) => Some(16),
+            DataType::String => None,
+            DataType::Binary => None,
+            DataType::BinaryOffset => None,
+            DataType::Date => Some(4),
+            DataType::Datetime(_, _) => Some(8),
+            DataType::Duration(_) => Some(8),
+            DataType::Time => Some(8),
+            DataType::Array(inner, width) => inner.byte_size().and_then(|v| v.checked_mul(*width)),
+            DataType::List(_) => None,
+            DataType::Object(_, _) => None,
+            DataType::Null => Some(0),
+            DataType::Categorical(_, _) => None,
+            DataType::Enum(_, _) => None,
+            // Sum the field sizes, bailing out with `None` on the first unsized field.
+            DataType::Struct(fields) => fields
+                .iter()
+                .try_fold(0usize, |total, field| total.checked_add(field.dtype.byte_size()?)),
+            DataType::Unknown(_) => None,
+        }
+    }
 }
 
 impl Display for DataType {
diff --git a/crates/polars-ops/src/chunked_array/binary/cast_binary_to_numerical.rs b/crates/polars-ops/src/chunked_array/binary/cast_binary_to_numerical.rs
index d3f76f6b8263..6fde8d3b96c1 100644
--- a/crates/polars-ops/src/chunked_array/binary/cast_binary_to_numerical.rs
+++ b/crates/polars-ops/src/chunked_array/binary/cast_binary_to_numerical.rs
@@ -1,4 +1,7 @@
-use arrow::array::{Array, BinaryViewArray, PrimitiveArray};
+use arrow::array::{
+    Array, BinaryViewArray, FixedSizeListArray, 
MutableArray, MutableFixedSizeListArray,
+    MutablePrimitiveArray, PrimitiveArray, TryPush,
+};
 use arrow::datatypes::ArrowDataType;
 use arrow::types::NativeType;
 use polars_error::PolarsResult;
@@ -78,3 +81,91 @@ where
         is_little_endian,
     )))
 }
+
+/// Reinterprets every value of a [`BinaryViewArray`] as one row of a [`FixedSizeListArray`].
+///
+/// Each non-null value is split into `element_size`-byte chunks and every chunk is decoded
+/// with `T::cast_le` or `T::cast_be`, depending on `is_little_endian`. Null values stay null.
+///
+/// # Errors
+///
+/// Returns an error if `to` is not a `FixedSizeList`, if `element_size` is zero, if the length
+/// of a value is not a multiple of `element_size`, or if pushing a row to the output fails.
+pub(super) fn try_cast_binview_to_array_primitive<T>(
+    from: &BinaryViewArray,
+    to: &ArrowDataType,
+    is_little_endian: bool,
+    element_size: usize,
+) -> PolarsResult<FixedSizeListArray>
+where
+    T: Cast + NativeType,
+{
+    let width = match to {
+        ArrowDataType::FixedSizeList(_, width) => *width,
+        _ => {
+            return Err(polars_error::polars_err!(
+                InvalidOperation: "expected a fixed-size list as target type, got {:?}", to
+            ));
+        },
+    };
+    if element_size == 0 {
+        return Err(polars_error::polars_err!(
+            InvalidOperation: "cannot reinterpret binary data as zero-sized elements"
+        ));
+    }
+    let mut result = MutableFixedSizeListArray::new(MutablePrimitiveArray::<T>::new(), width);
+
+    for value in from.iter() {
+        match value {
+            Some(bytes) => {
+                if bytes.len() % element_size != 0 {
+                    return Err(polars_error::polars_err!(
+                        InvalidOperation: "binary value of {} bytes is not a multiple of {}",
+                        bytes.len(), element_size
+                    ));
+                }
+                let row = bytes
+                    .chunks_exact(element_size)
+                    .map(|chunk| {
+                        if is_little_endian {
+                            T::cast_le(chunk)
+                        } else {
+                            T::cast_be(chunk)
+                        }
+                    })
+                    .collect::<Vec<_>>();
+                result.try_push(Some(row))?;
+            },
+            None => result.push_null(),
+        }
+    }
+
+    Ok(result.into())
+}
+
+/// Type-erased wrapper around [`try_cast_binview_to_array_primitive`].
+///
+/// # Errors
+///
+/// Returns an error if `from` is not a [`BinaryViewArray`] or if the cast itself fails.
+pub(super) fn cast_binview_to_array_primitive_dyn<T>(
+    from: &dyn Array,
+    to: &ArrowDataType,
+    is_little_endian: bool,
+    element_size: usize,
+) -> PolarsResult<Box<dyn Array>>
+where
+    T: Cast + NativeType,
+{
+    let from = match from.as_any().downcast_ref::<BinaryViewArray>() {
+        Some(from) => from,
+        None => {
+            return Err(polars_error::polars_err!(
+                InvalidOperation: "expected a binary view array as input"
+            ));
+        },
+    };
+
+    let out = try_cast_binview_to_array_primitive::<T>(from, to, is_little_endian, element_size)?;
+    Ok(Box::new(out))
+}
diff --git a/crates/polars-ops/src/chunked_array/binary/namespace.rs b/crates/polars-ops/src/chunked_array/binary/namespace.rs
index 3cd299892fa5..f74fddc5dffb 100644
--- a/crates/polars-ops/src/chunked_array/binary/namespace.rs
+++ b/crates/polars-ops/src/chunked_array/binary/namespace.rs
@@ -1,11 +1,16 @@
 #[cfg(feature = "binary_encoding")]
 use std::borrow::Cow;
 
+#[cfg(feature = "binary_encoding")]
+use arrow::array::Array;
+#[cfg(feature = "binary_encoding")]
+use arrow::datatypes::PhysicalType;
 use 
arrow::with_match_primitive_type;
 #[cfg(feature = "binary_encoding")]
 use base64::engine::general_purpose;
 #[cfg(feature = "binary_encoding")]
 use base64::Engine as _;
+use cast_binary_to_numerical::cast_binview_to_array_primitive_dyn;
 use memchr::memmem::find;
 use polars_compute::size::binary_size_bytes;
 use polars_core::prelude::arity::{broadcast_binary_elementwise_values, unary_elementwise_values};
@@ -133,24 +138,73 @@ pub trait BinaryNameSpaceImpl: AsBinary {
     #[cfg(feature = "binary_encoding")]
     #[allow(clippy::wrong_self_convention)]
     fn from_buffer(&self, dtype: &DataType, is_little_endian: bool) -> PolarsResult<Series> {
+        // SAFETY: NOTE(review) -- relies on `_from_buffer_inner` producing chunks whose Arrow
+        // type matches `dtype`; confirm this holds for every supported dtype.
+        unsafe {
+            Ok(Series::from_chunks_and_dtype_unchecked(
+                self.as_binary().name().clone(),
+                self._from_buffer_inner(dtype, is_little_endian)?,
+                dtype,
+            ))
+        }
+    }
+
+    /// Casts every chunk of the binary column to the Arrow representation of `dtype`.
+    ///
+    /// Primitive targets go through `cast_binview_to_primitive_dyn`; fixed-size lists with a
+    /// primitive leaf go through `cast_binview_to_array_primitive_dyn`.
+    // Gated like `from_buffer`: `Array` and `PhysicalType` are only imported when the
+    // `binary_encoding` feature is enabled.
+    #[cfg(feature = "binary_encoding")]
+    fn _from_buffer_inner(
+        &self,
+        dtype: &DataType,
+        is_little_endian: bool,
+    ) -> PolarsResult<Vec<Box<dyn Array>>> {
+        let arrow_data_type = dtype.to_arrow(CompatLevel::newest());
         let ca = self.as_binary();
-        let arrow_type = dtype.to_arrow(CompatLevel::newest());
 
-        match arrow_type.to_physical_type() {
-            arrow::datatypes::PhysicalType::Primitive(ty) => {
+        match arrow_data_type.to_physical_type() {
+            PhysicalType::Primitive(ty) => {
                 with_match_primitive_type!(ty, |$T| {
+                    // NOTE(review): likely redundant now that the unchecked `Series`
+                    // construction lives in `from_buffer` -- confirm and drop.
                     unsafe {
-                        Ok(Series::from_chunks_and_dtype_unchecked(
-                            ca.name().clone(),
-                            ca.chunks().iter().map(|chunk| {
-                                cast_binview_to_primitive_dyn::<$T>(
-                                    &**chunk,
-                                    &arrow_type,
-                                    is_little_endian,
-                                )
-                            }).collect::<PolarsResult<Vec<_>>>()?,
-                            dtype
-                        ))
+                        ca.chunks().iter().map(|chunk| {
+                            cast_binview_to_primitive_dyn::<$T>(
+                                &**chunk,
+                                &arrow_data_type,
+                                is_little_endian,
+                            )
+                        }).collect()
+                    }
+                })
+            },
+            PhysicalType::FixedSizeList => {
+                let leaf_dtype = dtype.leaf_dtype();
+                let leaf_physical_type = leaf_dtype
+                    .to_arrow(CompatLevel::newest())
+                    .to_physical_type();
+                let primitive_type = if let PhysicalType::Primitive(x) = leaf_physical_type {
+                    x
+                } else {
+                    return Err(
+                        polars_err!(InvalidOperation:"unsupported data type in 
from_buffer. Only numerical types are allowed in arrays."),
+                    );
+                };
+                // A primitive leaf (checked above) is expected to always have a fixed byte size.
+                let element_size = leaf_dtype
+                    .byte_size()
+                    .expect("primitive leaf dtype must have a fixed byte size");
+
+                with_match_primitive_type!(primitive_type, |$T| {
+                    ca.chunks().iter().map(|chunk| {
+                        cast_binview_to_array_primitive_dyn::<$T>(
+                            &**chunk,
+                            &arrow_data_type,
+                            is_little_endian,
+                            element_size,
+                        )
+                    }).collect()
-                    }
                 })
             },
diff --git a/crates/polars-plan/src/dsl/binary.rs b/crates/polars-plan/src/dsl/binary.rs
index 659d498b4388..00e5962149aa 100644
--- a/crates/polars-plan/src/dsl/binary.rs
+++ b/crates/polars-plan/src/dsl/binary.rs
@@ -67,10 +67,33 @@ impl BinaryNameSpace {
 
     #[cfg(feature = "binary_encoding")]
     pub fn from_buffer(self, to_type: DataType, is_little_endian: bool) -> Expr {
-        self.0
+        let leaf_type = to_type.leaf_dtype();
+        let shape = to_type.get_shape();
+
+        // For (nested) array targets the cast is done to a flat array holding all leaf values
+        // of a row; the result is then reshaped below.
+        let call_to_type = if let Some(ref shape) = shape {
+            DataType::Array(Box::new(leaf_type.clone()), shape.iter().product())
+        } else {
+            to_type
+        };
+
+        let result = self
+            .0
             .map_private(FunctionExpr::BinaryExpr(BinaryFunction::FromBuffer(
-                to_type,
+                call_to_type,
                 is_little_endian,
-            )))
+            )));
+
+        if let Some(shape) = shape {
+            // Prepend an inferred dimension in front of the requested array shape.
+            let dimensions: Vec<ReshapeDimension> = std::iter::once(ReshapeDimension::Infer)
+                .chain(shape.iter().map(|&v| ReshapeDimension::new(v as i64)))
+                .collect();
+
+            result.apply_private(FunctionExpr::Reshape(dimensions))
+        } else {
+            result
+        }
     }
 }
diff --git a/py-polars/tests/unit/operations/namespaces/test_binary.py b/py-polars/tests/unit/operations/namespaces/test_binary.py
index ab86b9b51c15..e5d1130b7110 100644
--- a/py-polars/tests/unit/operations/namespaces/test_binary.py
+++ b/py-polars/tests/unit/operations/namespaces/test_binary.py
@@ -4,6 +4,7 @@
 import struct
 from typing import TYPE_CHECKING
 
+import numpy as np
 import pytest
 
 import polars as pl
@@ -210,6 +211,65 @@ def test_reinterpret(
         assert_frame_equal(result, expected_df)
+
+
+@pytest.mark.parametrize(
+    ("dtype", "inner_type_size", "struct_type"),
+    [
+        (pl.Array(pl.Int8, 3), 1, "b"),
+        (pl.Array(pl.UInt8, 3), 1, "B"),
+        (pl.Array(pl.UInt8, (3, 4, 5)), 1, "B"),
+        (pl.Array(pl.Int16, 3), 2, "h"),
+        (pl.Array(pl.UInt16, 3), 2, "H"),
+        (pl.Array(pl.Int32, 3), 4, "i"),
+        (pl.Array(pl.UInt32, 3), 4, "I"),
+        (pl.Array(pl.Int64, 3), 8, "q"),
+        (pl.Array(pl.UInt64, 3), 8, "Q"),
+        (pl.Array(pl.Float32, 3), 4, "f"),
+        (pl.Array(pl.Float64, 3), 8, "d"),
+    ],
+)
+def test_reinterpret_list(
+    dtype: pl.Array,
+    inner_type_size: int,
+    struct_type: str,
+) -> None:
+    # Make test reproducible
+    random.seed(42)
+
+    # Bytes per row: the leaf size times every dimension of the (normalized) shape
+    shape = dtype.shape
+    if isinstance(shape, int):
+        shape = (shape,)
+    type_size = inner_type_size
+    for dim_size in shape:
+        type_size *= dim_size
+
+    byte_arr = [random.randbytes(type_size) for _ in range(3)]
+    df = pl.DataFrame({"x": byte_arr}, orient="row")
+
+    for endianness in ["little", "big"]:
+        result = df.select(
+            pl.col("x").bin.reinterpret(dtype=dtype, endianness=endianness)  # type: ignore[arg-type]
+        )
+
+        # `struct` byte-order prefix matching the endianness under test
+        struct_endianness = "<" if endianness == "little" else ">"
+        expected = []
+        for elem_bytes in byte_arr:
+            vals = [
+                struct.unpack_from(
+                    f"{struct_endianness}{struct_type}",
+                    elem_bytes[idx : idx + inner_type_size],
+                )[0]
+                for idx in range(0, type_size, inner_type_size)
+            ]
+            if len(shape) > 1:
+                vals = np.reshape(vals, shape).tolist()
+            expected.append(vals)
+        expected_df = pl.DataFrame({"x": expected}, schema={"x": dtype})
+
+        assert_frame_equal(result, expected_df)
+
+
 @pytest.mark.parametrize(
     ("dtype", "type_size"),
     [