Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow casting struct to bigger nullable struct #12

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 53 additions & 25 deletions cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
// Implementation of casting to (or between) list types

#include <limits>
#include <set>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -340,6 +341,8 @@ struct CastFixedList {

struct CastStruct {
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
static constexpr int kFillNullSentinel = -2;

const CastOptions& options = CastState::Get(ctx);
const auto& in_type = checked_cast<const StructType&>(*batch[0].type());
const auto& out_type = checked_cast<const StructType&>(*out->type());
Expand All @@ -348,25 +351,46 @@ struct CastStruct {

std::vector<int> fields_to_select(out_field_count, -1);

int out_field_index = 0;
for (int in_field_index = 0;
in_field_index < in_field_count && out_field_index < out_field_count;
++in_field_index) {
const auto& in_field = in_type.field(in_field_index);
std::set<std::string> all_in_field_names;
for (int in_field_index = 0; in_field_index < in_field_count; ++in_field_index) {
all_in_field_names.insert(in_type.field(in_field_index)->name());
}

for (int in_field_index = 0, out_field_index = 0;
out_field_index < out_field_count;) {
const auto& out_field = out_type.field(out_field_index);
if (in_field->name() == out_field->name()) {
if (in_field->nullable() && !out_field->nullable()) {
return Status::TypeError("cannot cast nullable field to non-nullable field: ",
in_type.ToString(), " ", out_type.ToString());
if (in_field_index < in_field_count) {
const auto& in_field = in_type.field(in_field_index);
// If there are more in_fields check if they match the out_field.
if (in_field->name() == out_field->name()) {
if (in_field->nullable() && !out_field->nullable()) {
return Status::TypeError("cannot cast nullable field to non-nullable field: ",
in_type.ToString(), " ", out_type.ToString());
}
// Found matching in_field and out_field.
fields_to_select[out_field_index++] = in_field_index;
// Using the same in_field for multiple out_fields is not allowed.
in_field_index++;
continue;
}
fields_to_select[out_field_index++] = in_field_index;
}
}

if (out_field_index < out_field_count) {
return Status::TypeError(
"struct fields don't match or are in the wrong order: Input fields: ",
in_type.ToString(), " output fields: ", out_type.ToString());
if (all_in_field_names.count(out_field->name()) == 0 && out_field->nullable()) {
// Didn't match current in_field, but we can fill with null.
// Filling with null is only acceptable on nullable fields when there
// is definitely no in_field with matching name.

fields_to_select[out_field_index++] = kFillNullSentinel;
} else if (in_field_index < in_field_count) {
// Didn't match current in_field, and the we cannot fill with null, so
// try next in_field.
in_field_index++;
} else {
// Didn't match current in_field, we cannot fill with null, and there
// are no more in_fields to try, so fail.
return Status::TypeError(
"struct fields don't match or are in the wrong order: Input fields: ",
in_type.ToString(), " output fields: ", out_type.ToString());
}
}

const ArraySpan& in_array = batch[0].array;
Expand All @@ -378,17 +402,21 @@ struct CastStruct {
in_array.offset, in_array.length));
}

out_field_index = 0;
int out_field_index = 0;
for (int field_index : fields_to_select) {
const auto& values = (in_array.child_data[field_index].ToArrayData()->Slice(
in_array.offset, in_array.length));
const auto& target_type = out->type()->field(out_field_index++)->type();

ARROW_ASSIGN_OR_RAISE(Datum cast_values,
Cast(values, target_type, options, ctx->exec_context()));

DCHECK(cast_values.is_array());
out_array->child_data.push_back(cast_values.array());
if (field_index == kFillNullSentinel) {
ARROW_ASSIGN_OR_RAISE(auto nulls,
MakeArrayOfNull(target_type->GetSharedPtr(), batch.length));
out_array->child_data.push_back(nulls->data());
} else {
const auto& values = (in_array.child_data[field_index].ToArrayData()->Slice(
in_array.offset, in_array.length));
ARROW_ASSIGN_OR_RAISE(Datum cast_values,
Cast(values, target_type, options, ctx->exec_context()));
DCHECK(cast_values.is_array());
out_array->child_data.push_back(cast_values.array());
}
}

return Status::OK();
Expand Down
111 changes: 78 additions & 33 deletions cpp/src/arrow/compute/kernels/scalar_cast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2540,6 +2540,8 @@ static void CheckStructToStructSubset(
d2 = ArrayFromJSON(dest_value_type, "[6, 51, 49]");
e2 = ArrayFromJSON(dest_value_type, "[19, 17, 74]");

auto nulls = ArrayFromJSON(dest_value_type, "[null, null, null]");

ASSERT_OK_AND_ASSIGN(auto src,
StructArray::Make({a1, b1, c1, d1, e1}, field_names));
ASSERT_OK_AND_ASSIGN(auto dest1,
Expand All @@ -2565,34 +2567,38 @@ static void CheckStructToStructSubset(
CheckCast(src, dest5);

// field does not exist
const auto dest6 = arrow::struct_({std::make_shared<Field>("a", int8()),
std::make_shared<Field>("d", int16()),
std::make_shared<Field>("f", int64())});
const auto options6 = CastOptions::Safe(dest6);
ASSERT_OK_AND_ASSIGN(auto dest6,
StructArray::Make({a1, d1, nulls}, {"a", "d", "f"}));
CheckCast(src, dest6);

const auto dest7 = arrow::struct_(
{std::make_shared<Field>("a", int8()), std::make_shared<Field>("d", int16()),
std::make_shared<Field>("f", int64(), /*nullable=*/false)});
const auto options7 = CastOptions::Safe(dest7);
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("struct fields don't match or are in the wrong order"),
Cast(src, options6));
Cast(src, options7));

// fields in wrong order
const auto dest7 = arrow::struct_({std::make_shared<Field>("a", int8()),
const auto dest8 = arrow::struct_({std::make_shared<Field>("a", int8()),
std::make_shared<Field>("c", int16()),
std::make_shared<Field>("b", int64())});
const auto options7 = CastOptions::Safe(dest7);
const auto options8 = CastOptions::Safe(dest8);
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("struct fields don't match or are in the wrong order"),
Cast(src, options7));
Cast(src, options8));

// duplicate missing field names
const auto dest8 = arrow::struct_(
const auto dest9 = arrow::struct_(
{std::make_shared<Field>("a", int8()), std::make_shared<Field>("c", int16()),
std::make_shared<Field>("d", int32()), std::make_shared<Field>("a", int64())});
const auto options8 = CastOptions::Safe(dest8);
const auto options9 = CastOptions::Safe(dest9);
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("struct fields don't match or are in the wrong order"),
Cast(src, options8));
Cast(src, options9));

// duplicate present field names
ASSERT_OK_AND_ASSIGN(
Expand Down Expand Up @@ -2639,6 +2645,8 @@ static void CheckStructToStructSubsetWithNulls(
d2 = ArrayFromJSON(dest_value_type, "[6, 51, null]");
e2 = ArrayFromJSON(dest_value_type, "[null, 17, 74]");

auto nulls = ArrayFromJSON(dest_value_type, "[null, null, null]");

std::shared_ptr<Buffer> null_bitmap;
BitmapFromVector<int>({0, 1, 0}, &null_bitmap);

Expand Down Expand Up @@ -2674,34 +2682,39 @@ static void CheckStructToStructSubsetWithNulls(
CheckCast(src_null, dest5_null);

// field does not exist
const auto dest6_null = arrow::struct_({std::make_shared<Field>("a", int8()),
std::make_shared<Field>("d", int16()),
std::make_shared<Field>("f", int64())});
const auto options6_null = CastOptions::Safe(dest6_null);
ASSERT_OK_AND_ASSIGN(
auto dest6_null,
StructArray::Make({a1, d1, nulls}, {"a", "d", "f"}, null_bitmap));
CheckCast(src_null, dest6_null);

const auto dest7_null = arrow::struct_(
{std::make_shared<Field>("a", int8()), std::make_shared<Field>("d", int16()),
std::make_shared<Field>("f", int64(), /*nullable=*/false)});
const auto options7_null = CastOptions::Safe(dest7_null);
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("struct fields don't match or are in the wrong order"),
Cast(src_null, options6_null));
Cast(src_null, options7_null));

// fields in wrong order
const auto dest7_null = arrow::struct_({std::make_shared<Field>("a", int8()),
const auto dest8_null = arrow::struct_({std::make_shared<Field>("a", int8()),
std::make_shared<Field>("c", int16()),
std::make_shared<Field>("b", int64())});
const auto options7_null = CastOptions::Safe(dest7_null);
const auto options8_null = CastOptions::Safe(dest8_null);
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("struct fields don't match or are in the wrong order"),
Cast(src_null, options7_null));
Cast(src_null, options8_null));

// duplicate missing field names
const auto dest8_null = arrow::struct_(
const auto dest9_null = arrow::struct_(
{std::make_shared<Field>("a", int8()), std::make_shared<Field>("c", int16()),
std::make_shared<Field>("d", int32()), std::make_shared<Field>("a", int64())});
const auto options8_null = CastOptions::Safe(dest8_null);
const auto options9_null = CastOptions::Safe(dest9_null);
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("struct fields don't match or are in the wrong order"),
Cast(src_null, options8_null));
Cast(src_null, options9_null));

// duplicate present field values
ASSERT_OK_AND_ASSIGN(
Expand Down Expand Up @@ -2737,20 +2750,26 @@ TEST(Cast, StructToStructSubsetWithNulls) {
}

TEST(Cast, StructToSameSizedButDifferentNamedStruct) {
std::vector<std::string> field_names = {"a", "b"};
std::vector<std::string> src_field_names = {"a", "b"};
std::shared_ptr<Array> a, b;
a = ArrayFromJSON(int8(), "[1, 2]");
b = ArrayFromJSON(int8(), "[3, 4]");
ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a, b}, field_names));
auto nulls = ArrayFromJSON(int8(), "[null, null]");
ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a, b}, src_field_names));

std::vector<std::string> dest1_field_names = {"c", "d"};
ASSERT_OK_AND_ASSIGN(auto dest1, StructArray::Make({nulls, nulls}, dest1_field_names));
CheckCast(src, dest1);

const auto dest = arrow::struct_(
{std::make_shared<Field>("c", int8()), std::make_shared<Field>("d", int8())});
const auto options = CastOptions::Safe(dest);
const auto dest2 =
arrow::struct_({std::make_shared<Field>("c", int8(), /*nullable=*/false),
std::make_shared<Field>("d", int8())});
const auto options2 = CastOptions::Safe(dest2);

EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("struct fields don't match or are in the wrong order"),
Cast(src, options));
Cast(src, options2));
}

TEST(Cast, StructToBiggerStruct) {
Expand All @@ -2760,15 +2779,41 @@ TEST(Cast, StructToBiggerStruct) {
b = ArrayFromJSON(int8(), "[3, 4]");
ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a, b}, field_names));

const auto dest = arrow::struct_({std::make_shared<Field>("a", int8()),
std::make_shared<Field>("b", int8()),
std::make_shared<Field>("c", int8())});
const auto options = CastOptions::Safe(dest);
const auto dest1 = arrow::struct_(
{std::make_shared<Field>("a", int8()), std::make_shared<Field>("b", int8()),
std::make_shared<Field>("c", int8(), /*nullable=*/false)});
const auto options1 = CastOptions::Safe(dest1);

EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("struct fields don't match or are in the wrong order"),
Cast(src, options1));

const auto dest2 =
arrow::struct_({std::make_shared<Field>("a", int8()),
std::make_shared<Field>("c", int8(), /*nullable=*/false),
std::make_shared<Field>("b", int8())});
const auto options2 = CastOptions::Safe(dest2);

EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("struct fields don't match or are in the wrong order"),
Cast(src, options));
Cast(src, options2));
}

TEST(Cast, StructToBiggerNullableStruct) {
std::vector<std::string> field_names = {"a", "b"};
std::shared_ptr<Array> a, b, c;
a = ArrayFromJSON(int8(), "[1, 2]");
b = ArrayFromJSON(int8(), "[3, 4]");
ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a, b}, field_names));

c = ArrayFromJSON(int8(), "[null, null]");
ASSERT_OK_AND_ASSIGN(auto dest, StructArray::Make({a, b, c}, {"a", "b", "c"}));
CheckCast(src, dest);

ASSERT_OK_AND_ASSIGN(auto dest2, StructArray::Make({a, c, b}, {"a", "c", "b"}));
CheckCast(src, dest2);
}

TEST(Cast, StructToDifferentNullabilityStruct) {
Expand Down
25 changes: 17 additions & 8 deletions cpp/src/arrow/dataset/scanner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2328,7 +2328,7 @@ DatasetAndBatches MakeNestedDataset() {
field("b", boolean()),
field("c", struct_({
field("d", int64()),
field("e", float64()),
field("e", int64()),
})),
});
const auto physical_schema = ::arrow::schema({
Expand Down Expand Up @@ -2531,26 +2531,35 @@ TEST(ScanNode, MaterializationOfVirtualColumn) {
TEST(ScanNode, MaterializationOfNestedVirtualColumn) {
TestPlan plan;

auto basic = MakeNestedDataset();
auto nested = MakeNestedDataset();

auto options = std::make_shared<ScanOptions>();
options->projection = Materialize({"a", "b", "c"}, /*include_aug_fields=*/true);

ASSERT_OK(acero::Declaration::Sequence(
{
{"scan", ScanNodeOptions{basic.dataset, options}},
{"scan", ScanNodeOptions{nested.dataset, options}},
{"augmented_project",
acero::ProjectNodeOptions{
{field_ref("a"), field_ref("b"), field_ref("c")}}},
{"sink", acero::SinkNodeOptions{&plan.sink_gen}},
})
.AddToPlan(plan.get()));

// TODO(ARROW-1888): allow scanner to "patch up" structs with casts
EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("struct fields don't match or are in the wrong order"),
plan.Run());
auto expected = nested.batches;

for (auto& batch : expected) {
// Scan will fill in "c.d" with nulls.
ASSERT_OK_AND_ASSIGN(auto nulls,
MakeArrayOfNull(int64()->GetSharedPtr(), batch.length));
auto c_data = batch.values[2].array()->Copy();
c_data->child_data.insert(c_data->child_data.begin(), nulls->data());
c_data->type = nested.dataset->schema()->field(2)->type();
auto c_array = std::make_shared<StructArray>(c_data);
batch.values[2] = c_array;
}

ASSERT_THAT(plan.Run(), Finishes(ResultWith(UnorderedElementsAreArray(expected))));
}

TEST(ScanNode, MinimalEndToEnd) {
Expand Down
Loading