Skip to content

Commit

Permalink
Fix return same column with different alias (#4233)
Browse files Browse the repository at this point in the history
* Fix return same column with different alias

* Run clang-format

---------

Co-authored-by: CI Bot <[email protected]>
  • Loading branch information
andyfengHKU and andyfengHKU authored Sep 11, 2024
1 parent 60cc1af commit 796f70d
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 62 deletions.
86 changes: 59 additions & 27 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,55 @@ static std::pair<expression_vector, std::vector<std::string>> rewriteProjectionI
return {newExprs, newAliases};
}

void validateColumnNamesAreUnique(const std::vector<std::string>& columnNames) {
auto existColumnNames = std::unordered_set<std::string>();
for (auto& name : columnNames) {
if (existColumnNames.contains(name)) {
throw BinderException(
"Multiple result column with the same name " + name + " are not supported.");
}
existColumnNames.insert(name);
}
}

std::vector<std::string> getColumnNames(const expression_vector& exprs,
const std::vector<std::string>& aliases) {
std::vector<std::string> columnNames;
for (auto i = 0u; i < exprs.size(); ++i) {
if (aliases[i].empty()) {
columnNames.push_back(exprs[i]->toString());
} else {
columnNames.push_back(aliases[i]);
}
}
return columnNames;
}

BoundWithClause Binder::bindWithClause(const WithClause& withClause) {
auto projectionBody = withClause.getProjectionBody();
auto boundProjectionBody = bindProjectionBody(*projectionBody, true /* isWithClause */);
auto [projectionExprs, aliases] = bindProjectionList(*projectionBody);
// Check all expressions are aliased
for (auto& alias : aliases) {
if (alias.empty()) {
throw BinderException("Expression in WITH must be aliased (use AS).");
}
}
auto columnNames = getColumnNames(projectionExprs, aliases);
validateColumnNamesAreUnique(columnNames);
// Rewrite projection list
auto originalProjectionExprs = projectionExprs;
auto originalAliases = aliases;
auto [newExprs, newAliases] = rewriteProjectionInWithClause(projectionExprs, aliases);
projectionExprs = newExprs;
aliases = newAliases;

auto boundProjectionBody = bindProjectionBody(*projectionBody, projectionExprs, aliases);
validateOrderByFollowedBySkipOrLimitInWithClause(boundProjectionBody);
// Update scope
scope.clear();
for (auto i = 0u; i < originalProjectionExprs.size(); ++i) {
addToScope(originalAliases[i], originalProjectionExprs[i]);
}
auto boundWithClause = BoundWithClause(std::move(boundProjectionBody));
if (withClause.hasWhereExpression()) {
boundWithClause.setWhereExpression(bindWhereExpression(*withClause.getWhereExpression()));
Expand All @@ -87,10 +132,14 @@ BoundWithClause Binder::bindWithClause(const WithClause& withClause) {

BoundReturnClause Binder::bindReturnClause(const ReturnClause& returnClause) {
auto projectionBody = returnClause.getProjectionBody();
auto boundProjectionBody = bindProjectionBody(*projectionBody, false /* isWithClause */);
auto [projectionExprs, aliases] = bindProjectionList(*projectionBody);
auto columnNames = getColumnNames(projectionExprs, aliases);
validateColumnNamesAreUnique(columnNames);
auto boundProjectionBody = bindProjectionBody(*projectionBody, projectionExprs, aliases);
auto statementResult = BoundStatementResult();
for (auto& expression : boundProjectionBody.getProjectionExpressions()) {
statementResult.addColumn(expression);
KU_ASSERT(columnNames.size() == projectionExprs.size());
for (auto i = 0u; i < columnNames.size(); ++i) {
statementResult.addColumn(columnNames[i], projectionExprs[i]);
}
return BoundReturnClause(std::move(boundProjectionBody), std::move(statementResult));
}
Expand All @@ -113,8 +162,8 @@ static expression_vector getAggregateExpressions(const std::shared_ptr<Expressio
return result;
}

BoundProjectionBody Binder::bindProjectionBody(const parser::ProjectionBody& projectionBody,
bool isWithClause) {
std::pair<expression_vector, std::vector<std::string>> Binder::bindProjectionList(
const ProjectionBody& projectionBody) {
expression_vector projectionExprs;
std::vector<std::string> aliases;
for (auto& parsedExpr : projectionBody.getProjectionExpressions()) {
Expand Down Expand Up @@ -147,19 +196,11 @@ BoundProjectionBody Binder::bindProjectionBody(const parser::ProjectionBody& pro
aliases.push_back(parsedExpr->hasAlias() ? parsedExpr->getAlias() : expr->getAlias());
}
}
auto originProjectionExprs = projectionExprs;
auto originAliases = aliases;
return {projectionExprs, aliases};
}

if (isWithClause) {
for (auto& alias : aliases) {
if (alias.empty()) {
throw BinderException("Expression in WITH must be aliased (use AS).");
}
}
auto [a, b] = rewriteProjectionInWithClause(projectionExprs, aliases);
projectionExprs = a;
aliases = b;
}
BoundProjectionBody Binder::bindProjectionBody(const parser::ProjectionBody& projectionBody,
const expression_vector& projectionExprs, const std::vector<std::string>& aliases) {

expression_vector groupByExprs;
expression_vector aggregateExprs;
Expand All @@ -177,10 +218,6 @@ BoundProjectionBody Binder::bindProjectionBody(const parser::ProjectionBody& pro
expr->setAlias(aliases[i]);
}

for (auto i = 0u; i < originProjectionExprs.size(); ++i) {
originProjectionExprs[i]->setAlias(originAliases[i]);
}
validateProjectionColumnNamesAreUnique(originProjectionExprs);
auto boundProjectionBody = BoundProjectionBody(projectionBody.getIsDistinct());
boundProjectionBody.setProjectionExpressions(projectionExprs);

Expand Down Expand Up @@ -235,11 +272,6 @@ BoundProjectionBody Binder::bindProjectionBody(const parser::ProjectionBody& pro
boundProjectionBody.setLimitNumber(
bindSkipLimitExpression(*projectionBody.getLimitExpression()));
}
// Update scope.
if (isWithClause) {
scope.clear();
addExpressionsToScope(originProjectionExprs);
}
return boundProjectionBody;
}

Expand Down
12 changes: 0 additions & 12 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,6 @@ std::shared_ptr<Expression> Binder::createVariable(const std::string& name,
return expression;
}

void Binder::validateProjectionColumnNamesAreUnique(const expression_vector& expressions) {
auto existColumnNames = std::unordered_set<std::string>();
for (auto& expression : expressions) {
auto columnName = expression->hasAlias() ? expression->getAlias() : expression->toString();
if (existColumnNames.contains(columnName)) {
throw BinderException(
"Multiple result column with the same name " + columnName + " are not supported.");
}
existColumnNames.insert(columnName);
}
}

void Binder::validateOrderByFollowedBySkipOrLimitInWithClause(
const BoundProjectionBody& boundProjectionBody) {
auto hasSkipOrLimit = boundProjectionBody.hasSkip() || boundProjectionBody.hasLimit();
Expand Down
2 changes: 1 addition & 1 deletion src/binder/bound_statement_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ BoundStatementResult BoundStatementResult::createSingleStringColumnResult(
auto result = BoundStatementResult();
auto value = Value(LogicalType::STRING(), columnName);
auto stringColumn = std::make_shared<LiteralExpression>(std::move(value), columnName);
result.addColumn(stringColumn);
result.addColumn(columnName, stringColumn);
return result;
}

Expand Down
7 changes: 3 additions & 4 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,10 @@ class Binder {
BoundWithClause bindWithClause(const parser::WithClause& withClause);
BoundReturnClause bindReturnClause(const parser::ReturnClause& returnClause);

std::pair<expression_vector, std::vector<std::string>> bindProjectionList(
const parser::ProjectionBody& projectionBody);
BoundProjectionBody bindProjectionBody(const parser::ProjectionBody& projectionBody,
bool isWithClause);
const expression_vector& projectionExprs, const std::vector<std::string>& aliases);

expression_vector bindOrderByExpressions(
const std::vector<std::unique_ptr<parser::ParsedExpression>>& orderByExpressions);
Expand Down Expand Up @@ -287,9 +289,6 @@ class Binder {
const common::table_id_vector_t& tableIDs);

/*** validations ***/
// E.g. ... RETURN a, b AS a
static void validateProjectionColumnNamesAreUnique(const expression_vector& expressions);

static void validateOrderByFollowedBySkipOrLimitInWithClause(
const BoundProjectionBody& boundProjectionBody);

Expand Down
24 changes: 21 additions & 3 deletions src/include/binder/bound_statement_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,45 @@ namespace binder {
class BoundStatementResult {
public:
BoundStatementResult() = default;
explicit BoundStatementResult(expression_vector columns) : columns{std::move(columns)} {}
explicit BoundStatementResult(expression_vector columns, std::vector<std::string> columnNames)
: columns{std::move(columns)}, columnNames{std::move(columnNames)} {}
EXPLICIT_COPY_DEFAULT_MOVE(BoundStatementResult);

static BoundStatementResult createEmptyResult() { return BoundStatementResult(); }

static BoundStatementResult createSingleStringColumnResult(
const std::string& columnName = "result");

void addColumn(std::shared_ptr<Expression> column) { columns.push_back(std::move(column)); }
void addColumn(const std::string& columnName, std::shared_ptr<Expression> column) {
columns.push_back(std::move(column));
columnNames.push_back(columnName);
}
expression_vector getColumns() const { return columns; }
std::vector<std::string> getColumnNames() const { return columnNames; }
std::vector<common::LogicalType> getColumnTypes() const {
std::vector<common::LogicalType> columnTypes;
for (auto& column : columns) {
columnTypes.push_back(column->getDataType().copy());
}
return columnTypes;
}

std::shared_ptr<Expression> getSingleColumnExpr() const {
KU_ASSERT(columns.size() == 1);
return columns[0];
}

private:
BoundStatementResult(const BoundStatementResult& other) : columns{other.columns} {}
BoundStatementResult(const BoundStatementResult& other)
: columns{other.columns}, columnNames{other.columnNames} {}

private:
expression_vector columns;
// ColumnNames might be different from column.toString() because the same column might have
// different aliases, e.g. RETURN id AS a, id AS b
// For both columns we currently refer to the same id expr object so we cannot resolve column
// name properly from expression object.
std::vector<std::string> columnNames;
};

} // namespace binder
Expand Down
5 changes: 3 additions & 2 deletions src/include/main/query_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ class QueryResult {
KUZU_API std::unique_ptr<ArrowArray> getNextArrowChunk(int64_t chunkSize);

private:
void initResultTableAndIterator(std::shared_ptr<processor::FactorizedTable> factorizedTable_,
const std::vector<std::shared_ptr<binder::Expression>>& columns);
void setColumnHeader(std::vector<std::string> columnNames,
std::vector<common::LogicalType> columnTypes);
void initResultTableAndIterator(std::shared_ptr<processor::FactorizedTable> factorizedTable_);
void validateQuerySucceed() const;

private:
Expand Down
5 changes: 3 additions & 2 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,9 @@ std::unique_ptr<QueryResult> ClientContext::executeNoLock(PreparedStatement* pre
}
executingTimer.stop();
queryResult->querySummary->executionTime = executingTimer.getElapsedTimeMS();
queryResult->initResultTableAndIterator(std::move(resultFT),
preparedStatement->statementResult->getColumns());
auto sResult = preparedStatement->statementResult.get();
queryResult->setColumnHeader(sResult->getColumnNames(), sResult->getColumnTypes());
queryResult->initResultTableAndIterator(std::move(resultFT));
return queryResult;
}

Expand Down
20 changes: 9 additions & 11 deletions src/main/query_result.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "main/query_result.h"

#include "binder/expression/expression.h"
#include "common/arrow/arrow_converter.h"
#include "common/exception/runtime.h"
#include "common/types/value/node.h"
Expand Down Expand Up @@ -57,20 +56,19 @@ void QueryResult::resetIterator() {
iterator->resetState();
}

void QueryResult::setColumnHeader(std::vector<std::string> columnNames_,
std::vector<LogicalType> columnTypes_) {
columnNames = std::move(columnNames_);
columnDataTypes = std::move(columnTypes_);
}

void QueryResult::initResultTableAndIterator(
std::shared_ptr<processor::FactorizedTable> factorizedTable_,
const binder::expression_vector& columns) {
std::shared_ptr<processor::FactorizedTable> factorizedTable_) {
factorizedTable = std::move(factorizedTable_);
tuple = std::make_shared<FlatTuple>();
std::vector<Value*> valuesToCollect;
for (auto i = 0u; i < columns.size(); ++i) {
auto column = columns[i].get();
const auto& columnType = column->getDataType();
auto columnName = column->hasAlias() ? column->getAlias() : column->toString();
columnDataTypes.push_back(columnType.copy());
columnNames.push_back(columnName);
std::unique_ptr<Value> value =
std::make_unique<Value>(Value::createDefaultValue(columnType.copy()));
for (auto& type : columnDataTypes) {
auto value = std::make_unique<Value>(Value::createDefaultValue(type.copy()));
valuesToCollect.push_back(value.get());
tuple->addValue(std::move(value));
}
Expand Down
4 changes: 4 additions & 0 deletions test/test_files/load_from/load_from.test
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,7 @@ Copy exception: Error in file ${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv
---- 2
1|Some sample text.
2|Some more sample "text".
-STATEMENT LOAD FROM '${KUZU_ROOT_DIRECTORY}/dataset/load-from-test/escape-char/double_quote_escape_char.csv' (header=true) RETURN id AS a, id AS b
---- 2
1|1
2|2

0 comments on commit 796f70d

Please sign in to comment.