Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support numerical arrays in bin.reinterpret #20456

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,49 @@ impl DataType {
},
}
}

pub fn byte_size(&self) -> Option<usize> {
match self {
DataType::Boolean => Some(1),
DataType::UInt8 => Some(1),
DataType::UInt16 => Some(2),
DataType::UInt32 => Some(4),
DataType::UInt64 => Some(8),
DataType::Int8 => Some(1),
DataType::Int16 => Some(2),
DataType::Int32 => Some(4),
DataType::Int64 => Some(8),
DataType::Int128 => Some(16),
DataType::Float32 => Some(4),
DataType::Float64 => Some(8),
DataType::Decimal(_, _) => Some(16),
DataType::String => None,
DataType::Binary => None,
DataType::BinaryOffset => None,
DataType::Date => Some(4),
DataType::Datetime(_, _) => Some(8),
DataType::Duration(_) => Some(8),
DataType::Time => Some(8),
DataType::Array(data_type, size) => data_type.byte_size().map(|v| v * size),
DataType::List(_) => None,
DataType::Object(_, _) => None,
DataType::Null => Some(0),
DataType::Categorical(_, _) => None,
DataType::Enum(_, _) => None,
DataType::Struct(vec) => {
let mut total_size = 0usize;
for field in vec.iter() {
if let Some(byte_size) = field.dtype.byte_size() {
total_size += byte_size;
} else {
return None;
}
}
Some(total_size)
},
DataType::Unknown(_) => None,
}
}
}

impl Display for DataType {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use arrow::array::{Array, BinaryViewArray, PrimitiveArray};
use arrow::array::{
Array, BinaryViewArray, FixedSizeListArray, MutableArray, MutableFixedSizeListArray,
MutablePrimitiveArray, PrimitiveArray, TryPush,
};
use arrow::datatypes::ArrowDataType;
use arrow::types::NativeType;
use polars_error::PolarsResult;
Expand Down Expand Up @@ -78,3 +81,65 @@ where
is_little_endian,
)))
}

/// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any un-castable value a Null.
pub(super) fn try_cast_binview_to_array_primitive<T>(
from: &BinaryViewArray,
to: &ArrowDataType,
is_little_endian: bool,
element_size: usize,
) -> PolarsResult<FixedSizeListArray>
where
T: Cast + NativeType,
{
let size = if let ArrowDataType::FixedSizeList(_, size) = to {
*size
} else {
todo!("Hello")
};
let mut result = MutableFixedSizeListArray::new(MutablePrimitiveArray::<T>::new(), size);

from.iter().try_for_each(|x| {
if let Some(x) = x {
if x.len() % element_size != 0 {
todo!("Return error here.")
}

result.try_push(Some(
x.chunks_exact(element_size)
.map(|val| {
if is_little_endian {
T::cast_le(val)
} else {
T::cast_be(val)
}
})
.collect::<Vec<_>>(),
))
} else {
result.push_null();
Ok(())
}
})?;

Ok(result.into())
}

pub(super) fn cast_binview_to_array_primitive_dyn<T>(
from: &dyn Array,
to: &ArrowDataType,
is_little_endian: bool,
element_size: usize,
) -> PolarsResult<Box<dyn Array>>
where
T: Cast + NativeType,
{
let from = from.as_any().downcast_ref().unwrap();

Ok(Box::new(try_cast_binview_to_array_primitive::<T>(
from,
to,
is_little_endian,
element_size,
)?))
}
71 changes: 57 additions & 14 deletions crates/polars-ops/src/chunked_array/binary/namespace.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
#[cfg(feature = "binary_encoding")]
use std::borrow::Cow;

#[cfg(feature = "binary_encoding")]
use arrow::array::Array;
#[cfg(feature = "binary_encoding")]
use arrow::datatypes::PhysicalType;
use arrow::with_match_primitive_type;
#[cfg(feature = "binary_encoding")]
use base64::engine::general_purpose;
#[cfg(feature = "binary_encoding")]
use base64::Engine as _;
use cast_binary_to_numerical::cast_binview_to_array_primitive_dyn;
use memchr::memmem::find;
use polars_compute::size::binary_size_bytes;
use polars_core::prelude::arity::{broadcast_binary_elementwise_values, unary_elementwise_values};
Expand Down Expand Up @@ -133,24 +138,62 @@ pub trait BinaryNameSpaceImpl: AsBinary {
#[cfg(feature = "binary_encoding")]
#[allow(clippy::wrong_self_convention)]
fn from_buffer(&self, dtype: &DataType, is_little_endian: bool) -> PolarsResult<Series> {
unsafe {
Ok(Series::from_chunks_and_dtype_unchecked(
self.as_binary().name().clone(),
self._from_buffer_inner(dtype, is_little_endian)?,
dtype,
))
}
}

fn _from_buffer_inner(
&self,
dtype: &DataType,
is_little_endian: bool,
) -> PolarsResult<Vec<Box<dyn Array>>> {
let arrow_data_type = dtype.to_arrow(CompatLevel::newest());
let ca = self.as_binary();
let arrow_type = dtype.to_arrow(CompatLevel::newest());

match arrow_type.to_physical_type() {
arrow::datatypes::PhysicalType::Primitive(ty) => {
match arrow_data_type.to_physical_type() {
PhysicalType::Primitive(ty) => {
with_match_primitive_type!(ty, |$T| {
unsafe {
Ok(Series::from_chunks_and_dtype_unchecked(
ca.name().clone(),
ca.chunks().iter().map(|chunk| {
cast_binview_to_primitive_dyn::<$T>(
&**chunk,
&arrow_type,
is_little_endian,
)
}).collect::<PolarsResult<Vec<_>>>()?,
dtype
))
ca.chunks().iter().map(|chunk| {
cast_binview_to_primitive_dyn::<$T>(
&**chunk,
&arrow_data_type,
is_little_endian,
)
}).collect()
}
})
},
PhysicalType::FixedSizeList => {
let leaf_dtype = dtype.leaf_dtype();
let leaf_physical_type = leaf_dtype
.to_arrow(CompatLevel::newest())
.to_physical_type();
let primitive_type = if let PhysicalType::Primitive(x) = leaf_physical_type {
x
} else {
return Err(
polars_err!(InvalidOperation:"unsupported data type in from_buffer. Only numerical types are allowed in arrays."),
);
};
// Since we know it's a physical size, we
let element_size = leaf_dtype.byte_size().unwrap();

with_match_primitive_type!(primitive_type, |$T| {
unsafe {
ca.chunks().iter().map(|chunk| {
cast_binview_to_array_primitive_dyn::<$T>(
&**chunk,
&arrow_data_type,
is_little_endian,
element_size
)
}).collect()
}
})
},
Expand Down
31 changes: 28 additions & 3 deletions crates/polars-plan/src/dsl/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,35 @@ impl BinaryNameSpace {

#[cfg(feature = "binary_encoding")]
pub fn from_buffer(self, to_type: DataType, is_little_endian: bool) -> Expr {
self.0
let leaf_type = to_type.leaf_dtype();
let shape = to_type.get_shape();

let call_to_type = if let Some(ref shape) = shape {
DataType::Array(
Box::new(leaf_type.clone()),
shape.iter().product(),
)
} else {
to_type
};

let result = self
.0
.map_private(FunctionExpr::BinaryExpr(BinaryFunction::FromBuffer(
to_type,
call_to_type,
is_little_endian,
)))
)));

if let Some(shape) = shape {
let mut dimensions: Vec<ReshapeDimension> = shape
.iter()
.map(|&v| ReshapeDimension::new(v as i64))
.collect();
dimensions.insert(0, ReshapeDimension::Infer);

result.apply_private(FunctionExpr::Reshape(dimensions))
} else {
result
}
}
}
59 changes: 59 additions & 0 deletions py-polars/tests/unit/operations/namespaces/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import struct
from typing import TYPE_CHECKING

import numpy as np
import pytest

import polars as pl
Expand Down Expand Up @@ -210,6 +211,64 @@ def test_reinterpret(
assert_frame_equal(result, expected_df)


@pytest.mark.parametrize(
("dtype", "inner_type_size", "struct_type"),
[
(pl.Array(pl.Int8, 3), 1, "b"),
(pl.Array(pl.UInt8, 3), 1, "B"),
(pl.Array(pl.UInt8, (3, 4, 5)), 1, "B"),
(pl.Array(pl.Int16, 3), 2, "h"),
(pl.Array(pl.UInt16, 3), 2, "H"),
(pl.Array(pl.Int32, 3), 4, "i"),
(pl.Array(pl.UInt32, 3), 4, "I"),
(pl.Array(pl.Int64, 3), 8, "q"),
(pl.Array(pl.UInt64, 3), 8, "Q"),
(pl.Array(pl.Float32, 3), 4, "f"),
(pl.Array(pl.Float64, 3), 8, "d"),
],
)
def test_reinterpret_list(
dtype: pl.Array,
inner_type_size: int,
struct_type: str,
) -> None:
# Make test reproducible
random.seed(42)

type_size = inner_type_size
shape = dtype.shape
if isinstance(shape, int):
shape = (shape,)
for dim_size in dtype.shape:
type_size *= dim_size

byte_arr = [random.randbytes(type_size) for _ in range(3)]
df = pl.DataFrame({"x": byte_arr}, orient="row")

for endianness in ["little", "big"]:
result = df.select(
pl.col("x").bin.reinterpret(dtype=dtype, endianness=endianness) # type: ignore[arg-type]
)

# So that mypy doesn't complain
struct_endianness = "<" if endianness == "little" else ">"
expected = []
for elem_bytes in byte_arr:
vals = [
struct.unpack_from(
f"{struct_endianness}{struct_type}",
elem_bytes[idx : idx + inner_type_size],
)[0]
for idx in range(0, type_size, inner_type_size)
]
if len(shape) > 1:
vals = np.reshape(vals, shape).tolist()
expected.append(vals)
expected_df = pl.DataFrame({"x": expected}, schema={"x": dtype})

assert_frame_equal(result, expected_df)


@pytest.mark.parametrize(
("dtype", "type_size"),
[
Expand Down
Loading