Skip to content

Commit

Permalink
feat: Implement join on struct dtype (pola-rs#21093)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored Feb 7, 2025
1 parent 63cdab5 commit 1aab51c
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 0 deletions.
25 changes: 25 additions & 0 deletions crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use arrow::array::PrimitiveArray;
use polars_core::chunked_array::ops::row_encode::encode_rows_unordered;
use polars_core::series::BitRepr;
use polars_core::utils::split;
use polars_core::with_match_physical_float_polars_type;
Expand Down Expand Up @@ -48,6 +49,12 @@ pub trait SeriesJoin: SeriesSealed + Sized {
let build_null_count = other.null_count();
hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls, build_null_count)
},
#[cfg(feature = "dtype-struct")]
T::Struct(_) => {
let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series();
let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series();
lhs.hash_join_left(rhs, validate, join_nulls)
},
x if x.is_float() => {
with_match_physical_float_polars_type!(lhs.dtype(), |$T| {
let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
Expand Down Expand Up @@ -126,6 +133,12 @@ pub trait SeriesJoin: SeriesSealed + Sized {
hash_join_tuples_left_semi(lhs, rhs, join_nulls)
}
},
#[cfg(feature = "dtype-struct")]
T::Struct(_) => {
let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series();
let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series();
lhs.hash_join_semi_anti(rhs, anti, join_nulls)?
},
x if x.is_float() => {
with_match_physical_float_polars_type!(lhs.dtype(), |$T| {
let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
Expand Down Expand Up @@ -227,6 +240,12 @@ pub trait SeriesJoin: SeriesSealed + Sized {
!swapped,
))
},
#[cfg(feature = "dtype-struct")]
T::Struct(_) => {
let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series();
let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series();
lhs.hash_join_inner(rhs, validate, join_nulls)
},
x if x.is_float() => {
with_match_physical_float_polars_type!(lhs.dtype(), |$T| {
let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
Expand Down Expand Up @@ -297,6 +316,12 @@ pub trait SeriesJoin: SeriesSealed + Sized {
let rhs = rhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
hash_join_tuples_outer(lhs, rhs, swapped, validate, join_nulls)
},
#[cfg(feature = "dtype-struct")]
T::Struct(_) => {
let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series();
let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series();
lhs.hash_join_outer(rhs, validate, join_nulls)
},
x if x.is_float() => {
with_match_physical_float_polars_type!(lhs.dtype(), |$T| {
let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
Expand Down
74 changes: 74 additions & 0 deletions py-polars/tests/unit/operations/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -1626,3 +1626,77 @@ def test_select_after_join_where_20831() -> None:
)

assert q.select(pl.len()).collect().item() == 6


def test_join_on_struct() -> None:
lhs = pl.DataFrame(
{
"a": [{"x": 1}, {"x": 2}, {"x": 3}],
"b": [1, 2, 3],
}
)
rhs = pl.DataFrame(
{
"a": [{"x": 4}, {"x": 2}],
"c": [4, 2],
}
)

assert_frame_equal(
lhs.join(rhs, on="a", how="left", maintain_order="left"),
pl.select(
a=pl.Series([{"x": 1}, {"x": 2}, {"x": 3}]),
b=pl.Series([1, 2, 3]),
c=pl.Series([None, 2, None]),
),
)
assert_frame_equal(
lhs.join(rhs, on="a", how="right", maintain_order="right"),
pl.select(
b=pl.Series([None, 2]),
a=pl.Series([{"x": 4}, {"x": 2}]),
c=pl.Series([4, 2]),
),
)
assert_frame_equal(
lhs.join(rhs, on="a", how="inner"),
pl.select(
a=pl.Series([{"x": 2}]),
b=pl.Series([2]),
c=pl.Series([2]),
),
)
assert_frame_equal(
lhs.join(rhs, on="a", how="full", maintain_order="left_right"),
pl.select(
a=pl.Series([{"x": 1}, {"x": 2}, {"x": 3}, None]),
b=pl.Series([1, 2, 3, None]),
a_right=pl.Series([None, {"x": 2}, None, {"x": 4}]),
c=pl.Series([None, 2, None, 4]),
),
)
assert_frame_equal(
lhs.join(rhs, on="a", how="semi"),
pl.select(
a=pl.Series([{"x": 2}]),
b=pl.Series([2]),
),
)
assert_frame_equal(
lhs.join(rhs, on="a", how="anti", maintain_order="left"),
pl.select(
a=pl.Series([{"x": 1}, {"x": 3}]),
b=pl.Series([1, 3]),
),
)
assert_frame_equal(
lhs.join(rhs, how="cross", maintain_order="left_right"),
pl.select(
a=pl.Series([{"x": 1}, {"x": 1}, {"x": 2}, {"x": 2}, {"x": 3}, {"x": 3}]),
b=pl.Series([1, 1, 2, 2, 3, 3]),
a_right=pl.Series(
[{"x": 4}, {"x": 2}, {"x": 4}, {"x": 2}, {"x": 4}, {"x": 2}]
),
c=pl.Series([4, 2, 4, 2, 4, 2]),
),
)

0 comments on commit 1aab51c

Please sign in to comment.