Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand PrestoQueryRunner to test CAST expressions with a VALUES node #10523

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion velox/connectors/hive/tests/HiveConnectorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ TEST_F(HiveConnectorTest, extractFiltersFromRemainingFilter) {
// Change these once HUGEINT filter merge is fixed.
ASSERT_TRUE(remaining);
ASSERT_EQ(
remaining->toString(), "not(lt(ROW[\"c2\"],cast 0 as DECIMAL(20, 0)))");
remaining->toString(), "not(lt(ROW[\"c2\"],cast (0 as DECIMAL(20, 0))))");
}

TEST_F(HiveConnectorTest, prestoTableSampling) {
Expand Down
4 changes: 2 additions & 2 deletions velox/core/Expressions.h
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,10 @@ class CastTypedExpr : public ITypedExpr {
std::string toString() const override {
if (nullOnFailure_) {
return fmt::format(
"try_cast {} as {}", inputs()[0]->toString(), type()->toString());
"try_cast ({} as {})", inputs()[0]->toString(), type()->toString());
} else {
return fmt::format(
"cast {} as {}", inputs()[0]->toString(), type()->toString());
"cast ({} as {})", inputs()[0]->toString(), type()->toString());
}
}

Expand Down
24 changes: 23 additions & 1 deletion velox/exec/fuzzer/PrestoQueryRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ std::optional<std::string> PrestoQueryRunner::toSql(
return toSql(joinNode);
}

if (auto valuesNode =
std::dynamic_pointer_cast<const core::ValuesNode>(plan)) {
return toSql(valuesNode);
}

VELOX_NYI();
}

Expand Down Expand Up @@ -240,6 +245,12 @@ std::string toWindowCallSql(
return sql.str();
}

// Renders a cast expression as SQL text.
//
// The text is exactly what CastTypedExpr::toString() produces, i.e.
// "cast (<input> as <type>)" or "try_cast (<input> as <type>)".
//
// @param cast Cast expression to render.
// @return The textual form of 'cast'.
std::string toCastSql(const core::CastTypedExprPtr& cast) {
  // toString() already returns the finished string, so there is nothing to
  // assemble through an intermediate stream.
  return cast->toString();
}

bool isSupportedDwrfType(const TypePtr& type) {
if (type->isDate() || type->isIntervalDayTime() || type->isUnKnown()) {
return false;
Expand Down Expand Up @@ -341,6 +352,10 @@ std::optional<std::string> PrestoQueryRunner::toSql(
auto call =
std::dynamic_pointer_cast<const core::CallTypedExpr>(projection)) {
sql << toCallSql(call);
} else if (
auto call =
std::dynamic_pointer_cast<const core::CastTypedExpr>(projection)) {
sql << toCastSql(call);
} else {
VELOX_NYI();
}
Expand All @@ -352,6 +367,12 @@ std::optional<std::string> PrestoQueryRunner::toSql(
return sql.str();
}

// Translates a VALUES node into the relation referenced by the generated SQL.
// The node itself is not inspected: VALUES always maps to the table named
// "tmp", which is the table that is created when data is present.
std::optional<std::string> PrestoQueryRunner::toSql(
    const std::shared_ptr<const velox::core::ValuesNode>& valuesNode) {
  return std::make_optional<std::string>("tmp");
}

namespace {

void appendWindowFrame(
Expand Down Expand Up @@ -826,7 +847,8 @@ std::string PrestoQueryRunner::startQuery(
{"X-Presto-Catalog", "hive"},
{"X-Presto-Schema", "tpch"},
{"Content-Type", "text/plain"},
{"X-Presto-Session", sessionProperty}});
{"X-Presto-Session", sessionProperty},
{"X-Presto-Time-Zone", "GMT+3"}});
cpr::Timeout timeout{timeout_};
cpr::Response response = cpr::Post(url, body, header, timeout);
VELOX_CHECK_EQ(
Expand Down
3 changes: 3 additions & 0 deletions velox/exec/fuzzer/PrestoQueryRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner {
std::optional<std::string> toSql(
const std::shared_ptr<const core::NestedLoopJoinNode>& joinNode);

/// Returns the SQL relation used in place of a ValuesNode. The definition in
/// PrestoQueryRunner.cpp does not inspect 'valuesNode' and returns the fixed
/// table name "tmp".
std::optional<std::string> toSql(
const std::shared_ptr<const velox::core::ValuesNode>& valuesNode);

std::string startQuery(
const std::string& sql,
const std::string& sessionProperty = "");
Expand Down
22 changes: 11 additions & 11 deletions velox/exec/tests/PlanNodeToStringTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,16 @@ TEST_F(PlanNodeToStringTest, recursive) {

TEST_F(PlanNodeToStringTest, detailed) {
ASSERT_EQ(
"-- Project[4][expressions: (out3:BIGINT, plus(cast ROW[\"out1\"] as BIGINT,10))] -> out3:BIGINT\n",
"-- Project[4][expressions: (out3:BIGINT, plus(cast (ROW[\"out1\"] as BIGINT),10))] -> out3:BIGINT\n",
plan_->toString(true, false));
}

TEST_F(PlanNodeToStringTest, recursiveAndDetailed) {
ASSERT_EQ(
"-- Project[4][expressions: (out3:BIGINT, plus(cast ROW[\"out1\"] as BIGINT,10))] -> out3:BIGINT\n"
" -- Filter[3][expression: lt(mod(cast ROW[\"out1\"] as BIGINT,10),8)] -> out1:SMALLINT, out2:BIGINT\n"
" -- Project[2][expressions: (out1:SMALLINT, ROW[\"c0\"]), (out2:BIGINT, plus(mod(cast ROW[\"c0\"] as BIGINT,100),mod(cast ROW[\"c1\"] as BIGINT,50)))] -> out1:SMALLINT, out2:BIGINT\n"
" -- Filter[1][expression: lt(mod(cast ROW[\"c0\"] as BIGINT,10),9)] -> c0:SMALLINT, c1:INTEGER, c2:BIGINT\n"
"-- Project[4][expressions: (out3:BIGINT, plus(cast (ROW[\"out1\"] as BIGINT),10))] -> out3:BIGINT\n"
" -- Filter[3][expression: lt(mod(cast (ROW[\"out1\"] as BIGINT),10),8)] -> out1:SMALLINT, out2:BIGINT\n"
" -- Project[2][expressions: (out1:SMALLINT, ROW[\"c0\"]), (out2:BIGINT, plus(mod(cast (ROW[\"c0\"] as BIGINT),100),mod(cast (ROW[\"c1\"] as BIGINT),50)))] -> out1:SMALLINT, out2:BIGINT\n"
" -- Filter[1][expression: lt(mod(cast (ROW[\"c0\"] as BIGINT),10),9)] -> c0:SMALLINT, c1:INTEGER, c2:BIGINT\n"
" -- Values[0][5 rows in 1 vectors] -> c0:SMALLINT, c1:INTEGER, c2:BIGINT\n",
plan_->toString(true, true));
}
Expand All @@ -101,7 +101,7 @@ TEST_F(PlanNodeToStringTest, withContext) {
plan_->toString(false, false, addContext));

ASSERT_EQ(
"-- Project[4][expressions: (out3:BIGINT, plus(cast ROW[\"out1\"] as BIGINT,10))] -> out3:BIGINT\n"
"-- Project[4][expressions: (out3:BIGINT, plus(cast (ROW[\"out1\"] as BIGINT),10))] -> out3:BIGINT\n"
" Context for 4\n",
plan_->toString(true, false, addContext));

Expand All @@ -119,13 +119,13 @@ TEST_F(PlanNodeToStringTest, withContext) {
plan_->toString(false, true, addContext));

ASSERT_EQ(
"-- Project[4][expressions: (out3:BIGINT, plus(cast ROW[\"out1\"] as BIGINT,10))] -> out3:BIGINT\n"
"-- Project[4][expressions: (out3:BIGINT, plus(cast (ROW[\"out1\"] as BIGINT),10))] -> out3:BIGINT\n"
" Context for 4\n"
" -- Filter[3][expression: lt(mod(cast ROW[\"out1\"] as BIGINT,10),8)] -> out1:SMALLINT, out2:BIGINT\n"
" -- Filter[3][expression: lt(mod(cast (ROW[\"out1\"] as BIGINT),10),8)] -> out1:SMALLINT, out2:BIGINT\n"
" Context for 3\n"
" -- Project[2][expressions: (out1:SMALLINT, ROW[\"c0\"]), (out2:BIGINT, plus(mod(cast ROW[\"c0\"] as BIGINT,100),mod(cast ROW[\"c1\"] as BIGINT,50)))] -> out1:SMALLINT, out2:BIGINT\n"
" -- Project[2][expressions: (out1:SMALLINT, ROW[\"c0\"]), (out2:BIGINT, plus(mod(cast (ROW[\"c0\"] as BIGINT),100),mod(cast (ROW[\"c1\"] as BIGINT),50)))] -> out1:SMALLINT, out2:BIGINT\n"
" Context for 2\n"
" -- Filter[1][expression: lt(mod(cast ROW[\"c0\"] as BIGINT,10),9)] -> c0:SMALLINT, c1:INTEGER, c2:BIGINT\n"
" -- Filter[1][expression: lt(mod(cast (ROW[\"c0\"] as BIGINT),10),9)] -> c0:SMALLINT, c1:INTEGER, c2:BIGINT\n"
" Context for 1\n"
" -- Values[0][5 rows in 1 vectors] -> c0:SMALLINT, c1:INTEGER, c2:BIGINT\n"
" Context for 0\n",
Expand All @@ -147,7 +147,7 @@ TEST_F(PlanNodeToStringTest, withMultiLineContext) {
plan_->toString(false, false, addContext));

ASSERT_EQ(
"-- Project[4][expressions: (out3:BIGINT, plus(cast ROW[\"out1\"] as BIGINT,10))] -> out3:BIGINT\n"
"-- Project[4][expressions: (out3:BIGINT, plus(cast (ROW[\"out1\"] as BIGINT),10))] -> out3:BIGINT\n"
" Context for 4: line 1\n"
" Context for 4: line 2\n",
plan_->toString(true, false, addContext));
Expand Down
8 changes: 8 additions & 0 deletions velox/exec/tests/utils/PlanBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,14 @@ PlanBuilder& PlanBuilder::projectExpressions(
return *this;
}

// Adds a ProjectNode on top of the current plan node using expressions that
// are already typed.
//
// @param projectNames Output column names handed to the ProjectNode.
// @param projections Typed expressions handed to the ProjectNode.
// @return This builder, to allow call chaining.
PlanBuilder& PlanBuilder::projectTypedExpressions(
    const std::vector<std::string>& projectNames,
    const std::vector<core::TypedExprPtr>& projections) {
  // Same precondition as project(): a projection needs an upstream node to
  // read from, so fail early instead of building a node with a null source.
  VELOX_CHECK_NOT_NULL(planNode_, "Project cannot be the source node");
  planNode_ = std::make_shared<core::ProjectNode>(
      nextPlanNodeId(), projectNames, projections, planNode_);
  return *this;
}

PlanBuilder& PlanBuilder::project(const std::vector<std::string>& projections) {
VELOX_CHECK_NOT_NULL(planNode_, "Project cannot be the source node");
std::vector<std::shared_ptr<const core::IExpr>> expressions;
Expand Down
6 changes: 6 additions & 0 deletions velox/exec/tests/utils/PlanBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,12 @@ class PlanBuilder {
PlanBuilder& projectExpressions(
const std::vector<std::shared_ptr<const core::IExpr>>& projections);

/// Variation of project that takes typed expressions instead of SQL strings.
/// This bypasses the duckdb parser. The definition wraps the current plan
/// node in a ProjectNode built from the given names and expressions.
///
/// @param projectNames Output column names passed to the ProjectNode.
/// @param projections Typed expressions passed to the ProjectNode.
PlanBuilder& projectTypedExpressions(
const std::vector<std::string>& projectNames,
const std::vector<core::TypedExprPtr>& projections);

/// Similar to project() except 'optionalProjections' could be empty and the
/// function will skip creating a ProjectNode in that case.
PlanBuilder& optionalProject(
Expand Down
9 changes: 9 additions & 0 deletions velox/functions/prestosql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,12 @@ target_link_libraries(
gflags::gflags
GTest::gmock
GTest::gmock_main)

# Test binary built from PrestoStringCastTest.cpp and registered as a test of
# the same name. Links the fuzzer utilities in addition to the functions test
# library and GTest.
add_executable(velox_presto_functions_test PrestoStringCastTest.cpp)
add_test(velox_presto_functions_test velox_presto_functions_test)
target_link_libraries(
velox_presto_functions_test
velox_functions_test_lib
velox_fuzzer_util
GTest::gtest
GTest::gtest_main)
190 changes: 190 additions & 0 deletions velox/functions/prestosql/tests/PrestoStringCastTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/exec/fuzzer/PrestoQueryRunner.h"
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
#include "velox/functions/prestosql/tests/CastBaseTest.h"

using namespace facebook::velox::exec::test;
using namespace facebook::velox;

namespace facebook::velox::functions::test {

/// Fixture that evaluates a CAST both in Velox and through a
/// PrestoQueryRunner and compares the two results.
class PrestoStringCastTest : public functions::test::CastBaseTest {
 public:
  /// Registers the Presto scalar functions and creates the query runner.
  void SetUp() override {
    velox::functions::prestosql::registerAllScalarFunctions();
    queryRunner_ = std::make_unique<PrestoQueryRunner>(
        pool(),
        "http://127.0.0.1:8080",
        "hive",
        std::chrono::milliseconds{5000});
  }

  /// Builds a Values -> Project(cast) plan over 'data', checks the SQL
  /// generated for it, runs that SQL through the query runner and asserts
  /// that the result equals what Velox produces for the same plan.
  ///
  /// @param data Single input column; it becomes column "c0" of the input.
  /// @param outputType Type the column is cast to.
  void evalCastTypedExpression(
      const VectorPtr& data,
      const TypePtr& outputType) {
    constexpr auto kOutputColName = "p0";

    auto input = makeRowVector({data});

    // Non-try cast from the input column's type to the requested type.
    auto castExpr =
        buildCastExpr(input->childAt(0)->type(), outputType, false);

    auto plan = velox::exec::test::PlanBuilder()
                    .values({input})
                    .projectTypedExpressions({kOutputColName}, {castExpr})
                    .planNode();

    // The plan must be translatable and must produce exactly this query.
    auto maybeSql = queryRunner_->toSql(plan);
    ASSERT_TRUE(maybeSql.has_value());
    const auto& sqlText = maybeSql.value();
    SCOPED_TRACE(fmt::format("SQL: {}", sqlText));

    const auto expectedSql = fmt::format(
        "SELECT cast (\"c0\" as {}) as {} FROM (tmp)",
        outputType->toString(),
        kOutputColName);
    ASSERT_EQ(sqlText, expectedSql);

    // Reference result from the query runner.
    auto outputRowType = ROW({kOutputColName}, {outputType});
    auto prestoResults =
        queryRunner_->execute(sqlText, {input}, outputRowType);

    // Result of running the same plan in Velox.
    auto veloxResults =
        velox::exec::test::AssertQueryBuilder(plan).copyResults(pool());

    velox::exec::test::assertEqualResults(
        prestoResults, plan->outputType(), {veloxResults});
  }

  /// Target type shared by the tests: every input is cast to VARCHAR.
  const TypePtr kTargetType_ = VARCHAR();

 private:
  std::unique_ptr<PrestoQueryRunner> queryRunner_;
};

// Casts VARCHAR values (one null and three strings of increasing length) to
// the fixture's target type.
TEST_F(PrestoStringCastTest, DISABLED_varchar) {
  // One literal per row. Adjacent string literals without a separating comma
  // are concatenated by the compiler into a single value, so every row needs
  // its own trailing comma.
  auto data = makeNullableFlatVector<std::string>(
      {std::nullopt,
       "ABCDEFSFFFFFFF",
       "ABCDEFSDDDDDDDDDDD",
       "ABCDEFSEEEEEEEEEEEEEEE"});

  evalCastTypedExpression(data, kTargetType_);
}

// Casts BOOLEAN values (null, true, false) to the fixture's target type.
TEST_F(PrestoStringCastTest, DISABLED_boolean) {
  const auto input =
      makeNullableFlatVector<bool>({std::nullopt, true, false});
  evalCastTypedExpression(input, kTargetType_);
}

// Casts SMALLINT values, including null and both range limits, to the
// fixture's target type.
TEST_F(PrestoStringCastTest, DISABLED_smallint) {
  using Limits = std::numeric_limits<int16_t>;

  const auto input = makeNullableFlatVector<int16_t>(
      {std::nullopt, 12345, -12345, Limits::min(), Limits::max()});
  evalCastTypedExpression(input, kTargetType_);
}

// Casts INTEGER values, including null and both range limits, to the
// fixture's target type.
TEST_F(PrestoStringCastTest, DISABLED_integer) {
  using Limits = std::numeric_limits<int32_t>;

  const auto input = makeNullableFlatVector<int32_t>(
      {std::nullopt,
       12345,
       -12345,
       12345678,
       -12345678,
       Limits::min(),
       Limits::max()});
  evalCastTypedExpression(input, kTargetType_);
}

// Casts BIGINT values, including null and both range limits, to the
// fixture's target type.
TEST_F(PrestoStringCastTest, DISABLED_bigint) {
  using Limits = std::numeric_limits<int64_t>;

  const auto input = makeNullableFlatVector<int64_t>(
      {std::nullopt,
       12345,
       -12345,
       12345678,
       -12345678,
       12345678901234,
       -12345678901234,
       Limits::min(),
       Limits::max()});
  evalCastTypedExpression(input, kTargetType_);
}

// Casts REAL values, including null and the numeric_limits<float> min/max
// values, to the fixture's target type.
//
// TODO: Precision issue. Reported mismatch when enabled:
//   1 extra row:   "1.1754944E-38"
//   1 missing row: "1.17549435E-38"
// NOTE(review): the original note attributed this to float::max, but the
// values above have the magnitude of numeric_limits<float>::min() - confirm.
TEST_F(PrestoStringCastTest, DISABLED_real) {
  using Limits = std::numeric_limits<float>;

  const auto input = makeNullableFlatVector<float>(
      {std::nullopt,
       12345.0,
       -12345.0,
       12345678,
       -12345678,
       Limits::min(),
       Limits::max()});
  evalCastTypedExpression(input, kTargetType_);
}

// Casts DOUBLE values, including null and the numeric_limits<double> min/max
// values, to the fixture's target type.
//
// TODO: Precision issue in the number of digits printed (noted against
// double::max). Reported mismatch when enabled:
//   2 extra rows:
//     "-9.2233720368547758E18"
//     "9.2233720368547758E18"
//   2 missing rows:
//     "-9.223372036854776E18"
//     "9.223372036854776E18"
TEST_F(PrestoStringCastTest, DISABLED_double) {
  using Limits = std::numeric_limits<double>;

  const auto input = makeNullableFlatVector<double>(
      {std::nullopt,
       12345678,
       -12345678,
       12345678901234,
       -12345678901234,
       Limits::min(),
       Limits::max()});
  evalCastTypedExpression(input, kTargetType_);
}

// Casts TIMESTAMP values (null, epoch, one negative and one positive value)
// to the fixture's target type.
//
// TODO: The DWRF data written to file looks as if it had been written in a
// GMT+3 session time zone instead of UTC. As a result, X-Presto-Time-Zone
// needs to be set accordingly for the values to match.
TEST_F(PrestoStringCastTest, DISABLED_timestamp) {
  const auto input = makeNullableFlatVector<Timestamp>(
      {std::nullopt,
       Timestamp(0, 0),
       Timestamp(-1'000'000, 0),
       Timestamp(9'000'000, 500)});
  evalCastTypedExpression(input, kTargetType_);
}

} // namespace facebook::velox::functions::test
Loading