diff --git a/velox/functions/sparksql/ConcatWs.cpp b/velox/functions/sparksql/ConcatWs.cpp index 8b745391cfdd0..b477086165d2c 100644 --- a/velox/functions/sparksql/ConcatWs.cpp +++ b/velox/functions/sparksql/ConcatWs.cpp @@ -39,16 +39,12 @@ class ConcatWs : public exec::VectorFunction { const exec::LocalDecodedVector& decodedSeparator) const { auto arrayArgNum = decodedArrays.size(); std::vector arrayVectors; - std::vector elementsDecodedVectors; + std::vector elementsDecodedVectors; for (auto i = 0; i < arrayArgNum; ++i) { auto arrayVector = decodedArrays[i].get()->base()->as(); arrayVectors.push_back(arrayVector); - auto elements = arrayVector->elements(); - exec::LocalSelectivityVector nestedRows(context, elements->size()); - nestedRows.get()->setAll(); - exec::LocalDecodedVector elementsHolder( - context, *elements, *nestedRows.get()); - elementsDecodedVectors.push_back(elementsHolder.get()); + SelectivityVector nestedRows(arrayVector->elements()->size()); + elementsDecodedVectors.emplace_back(*arrayVector->elements(), nestedRows); } size_t totalResultBytes = 0; @@ -60,17 +56,18 @@ class ConcatWs : public exec::VectorFunction { int32_t allElements = 0; // Calculate size for array columns data. for (int i = 0; i < arrayArgNum; i++) { + if (decodedArrays[i]->isNullAt(row)) { + continue; + } auto arrayVector = arrayVectors[i]; - auto rawSizes = arrayVector->rawSizes(); - auto rawOffsets = arrayVector->rawOffsets(); auto indices = decodedArrays[i].get()->indices(); - auto elementsDecoded = elementsDecodedVectors[i]; + auto size = arrayVector->sizeAt(indices[row]); + auto offset = arrayVector->offsetAt(indices[row]); - auto size = rawSizes[indices[row]]; - auto offset = rawOffsets[indices[row]]; for (int j = 0; j < size; ++j) { - if (!elementsDecoded->isNullAt(offset + j)) { - auto element = elementsDecoded->valueAt(offset + j); + if (!elementsDecodedVectors[i].isNullAt(offset + j)) { + auto element = + elementsDecodedVectors[i].valueAt(offset + j); // No matter empty string or not. ++allElements; totalResultBytes += element.size(); @@ -209,16 +206,12 @@ class ConcatWs : public exec::VectorFunction { decodedSeparator); std::vector arrayVectors; - std::vector elementsDecodedVectors; + std::vector elementsDecodedVectors; for (auto i = 0; i < decodedArrays.size(); ++i) { auto arrayVector = decodedArrays[i].get()->base()->as(); arrayVectors.push_back(arrayVector); - auto elements = arrayVector->elements(); - exec::LocalSelectivityVector nestedRows(context, elements->size()); - nestedRows.get()->setAll(); - exec::LocalDecodedVector elementsHolder( - context, *elements, *nestedRows.get()); - elementsDecodedVectors.push_back(elementsHolder.get()); + SelectivityVector nestedRows(arrayVector->elements()->size()); + elementsDecodedVectors.emplace_back(*arrayVector->elements(), nestedRows); } // Allocate a string buffer. auto rawBuffer = @@ -255,17 +248,19 @@ class ConcatWs : public exec::VectorFunction { for (auto itArgs = args.begin() + 1; itArgs != args.end(); ++itArgs) { if ((*itArgs)->typeKind() == TypeKind::ARRAY) { + if ((*itArgs)->isNullAt(row)) { + ++i; + continue; + } auto arrayVector = arrayVectors[i]; - auto rawSizes = arrayVector->rawSizes(); - auto rawOffsets = arrayVector->rawOffsets(); auto indices = decodedArrays[i].get()->indices(); - auto elementsDecoded = elementsDecodedVectors[i]; + auto size = arrayVector->sizeAt(indices[row]); + auto offset = arrayVector->offsetAt(indices[row]); - auto size = rawSizes[indices[row]]; - auto offset = rawOffsets[indices[row]]; for (int k = 0; k < size; ++k) { - if (!elementsDecoded->isNullAt(offset + k)) { - auto element = elementsDecoded->valueAt(offset + k); + if (!elementsDecodedVectors[i].isNullAt(offset + k)) { + auto element = + elementsDecodedVectors[i].valueAt(offset + k); copyToBuffer( element, isConstantSeparator() diff --git a/velox/functions/sparksql/tests/ConcatWsTest.cpp b/velox/functions/sparksql/tests/ConcatWsTest.cpp index 3e51838b448bc..4cf6703c68f2b 100644 --- a/velox/functions/sparksql/tests/ConcatWsTest.cpp +++ b/velox/functions/sparksql/tests/ConcatWsTest.cpp @@ -127,7 +127,7 @@ TEST_F(ConcatWsTest, stringArgsWithNulls) { auto result = evaluate>( "concat_ws('~','',c0,'x',NULL::VARCHAR)", makeRowVector({input})); - auto expected = makeNullableFlatVector({ + auto expected = makeFlatVector({ "~~x", "~x", "~a~x", @@ -219,6 +219,13 @@ TEST_F(ConcatWsTest, arrayArgs) { "red--purple--green--red--purple--green", }); velox::test::assertEqualVectors(expected, result); + + // Constant arrays. + auto dummyInput = makeRowVector(makeRowType({VARCHAR()}), 1); + result = evaluate>( + "concat_ws('--', array['a','b','c'], array['d'])", dummyInput); + expected = makeFlatVector({"a--b--c--d"}); + velox::test::assertEqualVectors(expected, result); } TEST_F(ConcatWsTest, mixedStringAndArrayArgs) {