Skip to content

Commit 9f6a0d4

Browse files
committed
align validations
1 parent d0fa6eb commit 9f6a0d4

File tree

5 files changed

+70
-111
lines changed

5 files changed

+70
-111
lines changed

velox/substrait/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ add_library(velox_substrait_plan_converter ${SRCS})
4747
target_include_directories(velox_substrait_plan_converter
4848
PUBLIC ${PROTO_OUTPUT_DIR})
4949
target_link_libraries(velox_substrait_plan_converter velox_connector
50-
velox_dwio_dwrf_common)
50+
velox_dwio_dwrf_common velox_functions_spark)
5151

5252
if(${VELOX_BUILD_TESTING})
5353
add_subdirectory(tests)

velox/substrait/SubstraitToVeloxPlan.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -549,12 +549,10 @@ std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(
549549

550550
std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(
551551
const ::substrait::Plan& sPlan) {
552-
// Construct the function map based on the Substrait representation.
552+
// Construct the function map based on the Substrait representation,
553+
// and initialize the expression converter with it.
553554
constructFuncMap(sPlan);
554555

555-
// Create the expression converter.
556-
exprConverter_ = std::make_shared<SubstraitVeloxExprConverter>(functionMap_);
557-
558556
// In fact, only one RelRoot or Rel is expected here.
559557
for (const auto& sRel : sPlan.relations()) {
560558
if (sRel.has_root()) {
@@ -579,6 +577,7 @@ void SubstraitVeloxPlanConverter::constructFuncMap(
579577
auto name = sFmap.name();
580578
functionMap_[id] = name;
581579
}
580+
exprConverter_ = std::make_shared<SubstraitVeloxExprConverter>(functionMap_);
582581
}
583582

584583
std::string SubstraitVeloxPlanConverter::nextPlanNodeId() {
@@ -674,6 +673,9 @@ int32_t SubstraitVeloxPlanConverter::streamIsInput(
674673
VELOX_FAIL(err.what());
675674
}
676675
}
676+
if (validationMode_) {
677+
return -1;
678+
}
677679
VELOX_FAIL("Local file is expected.");
678680
}
679681

@@ -1316,7 +1318,7 @@ SubstraitVeloxPlanConverter::connectWithAnd(
13161318
while (idx < remainingFunctions.size()) {
13171319
std::vector<std::shared_ptr<const core::ITypedExpr>> params;
13181320
params.reserve(2);
1319-
params.emplace_back(std::move(remainingFilter));
1321+
params.emplace_back(remainingFilter);
13201322
params.emplace_back(
13211323
exprConverter_->toVeloxExpr(remainingFunctions[idx], inputType));
13221324
remainingFilter = std::make_shared<const core::CallTypedExpr>(

velox/substrait/SubstraitToVeloxPlan.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ namespace facebook::velox::substrait {
2525
/// This class is used to convert the Substrait plan into Velox plan.
2626
class SubstraitVeloxPlanConverter {
2727
public:
28+
SubstraitVeloxPlanConverter(bool validationMode = false)
29+
: validationMode_(validationMode) {}
30+
2831
/// Used to convert Substrait JoinRel into Velox PlanNode.
2932
std::shared_ptr<const core::PlanNode> toVeloxPlan(
3033
const ::substrait::JoinRel& sJoin);
@@ -65,7 +68,8 @@ class SubstraitVeloxPlanConverter {
6568
const ::substrait::Plan& sPlan);
6669

6770
/// Used to construct the function map between the index
68-
/// and the Substrait function name.
71+
/// and the Substrait function name. Initialize the expression
72+
/// converter based on the constructed function map.
6973
void constructFuncMap(const ::substrait::Plan& sPlan);
7074

7175
/// Will return the function map used by this plan converter.
@@ -371,6 +375,9 @@ class SubstraitVeloxPlanConverter {
371375
/// The Expression converter used to convert Substrait representations into
372376
/// Velox expressions.
373377
std::shared_ptr<SubstraitVeloxExprConverter> exprConverter_;
378+
379+
/// Whether this converter runs in validation mode (the plan validator sets it to true).
380+
bool validationMode_ = false;
374381
};
375382

376383
} // namespace facebook::velox::substrait

velox/substrait/SubstraitToVeloxPlanValidator.cpp

Lines changed: 53 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "velox/substrait/SubstraitToVeloxPlanValidator.h"
1818
#include "TypeUtils.h"
19+
#include "velox/functions/sparksql/Register.h"
1920

2021
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
2122

@@ -78,7 +79,8 @@ bool SubstraitToVeloxPlanValidator::validate(
7879
const auto& extension = sProject.advanced_extension();
7980
std::vector<TypePtr> types;
8081
if (!validateInputTypes(extension, types)) {
81-
std::cout << "Validation failed for input types in ProjectRel" << std::endl;
82+
std::cout << "Validation failed for input types in ProjectRel."
83+
<< std::endl;
8284
return false;
8385
}
8486

@@ -112,7 +114,45 @@ bool SubstraitToVeloxPlanValidator::validate(
112114

113115
bool SubstraitToVeloxPlanValidator::validate(
114116
const ::substrait::FilterRel& sFilter) {
115-
return false;
117+
if (sFilter.has_input() && !validate(sFilter.input())) {
118+
return false;
119+
}
120+
121+
// Get and validate the input types from extension.
122+
if (!sFilter.has_advanced_extension()) {
123+
std::cout << "Input types are expected in FilterRel." << std::endl;
124+
return false;
125+
}
126+
const auto& extension = sFilter.advanced_extension();
127+
std::vector<TypePtr> types;
128+
if (!validateInputTypes(extension, types)) {
129+
std::cout << "Validation failed for input types in FilterRel." << std::endl;
130+
return false;
131+
}
132+
133+
int32_t inputPlanNodeId = 0;
134+
// Create the fake input names to be used in row type.
135+
std::vector<std::string> names;
136+
names.reserve(types.size());
137+
for (uint32_t colIdx = 0; colIdx < types.size(); colIdx++) {
138+
names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx));
139+
}
140+
auto rowType = std::make_shared<RowType>(std::move(names), std::move(types));
141+
142+
std::vector<std::shared_ptr<const core::ITypedExpr>> expressions;
143+
expressions.reserve(1);
144+
try {
145+
expressions.emplace_back(
146+
exprConverter_->toVeloxExpr(sFilter.condition(), rowType));
147+
// Try to compile the expressions. If there is any unregistered function
148+
// or mismatched type, exception will be thrown.
149+
exec::ExprSet exprSet(std::move(expressions), &execCtx_);
150+
} catch (const VeloxException& err) {
151+
std::cout << "Validation failed for expression in ProjectRel due to:"
152+
<< err.message() << std::endl;
153+
return false;
154+
}
155+
return true;
116156
}
117157

118158
bool SubstraitToVeloxPlanValidator::validate(
@@ -228,6 +268,7 @@ bool SubstraitToVeloxPlanValidator::validate(
228268
auto typeCase = arg.rex_type_case();
229269
switch (typeCase) {
230270
case ::substrait::Expression::RexTypeCase::kSelection:
271+
case ::substrait::Expression::RexTypeCase::kLiteral:
231272
break;
232273
default:
233274
std::cout << "Only field is supported in aggregate functions."
@@ -256,110 +297,18 @@ bool SubstraitToVeloxPlanValidator::validate(
256297

257298
bool SubstraitToVeloxPlanValidator::validate(
258299
const ::substrait::ReadRel& sRead) {
259-
if (!sRead.has_base_schema()) {
260-
std::cout << "Validation failed due to schema was not found in ReadRel."
300+
try {
301+
u_int32_t index;
302+
std::vector<std::string> paths;
303+
std::vector<u_int64_t> starts;
304+
std::vector<u_int64_t> lengths;
305+
306+
planConverter_->toVeloxPlan(sRead, index, paths, starts, lengths);
307+
} catch (const VeloxException& err) {
308+
std::cout << "ReadRel validation failed due to:" << err.message()
261309
<< std::endl;
262310
return false;
263311
}
264-
const auto& sTypes = sRead.base_schema().struct_().types();
265-
for (const auto& sType : sTypes) {
266-
if (!validate(sType)) {
267-
std::cout << "Validation failed due to type was not supported in ReadRel."
268-
<< std::endl;
269-
return false;
270-
}
271-
}
272-
std::vector<::substrait::Expression_ScalarFunction> scalarFunctions;
273-
if (sRead.has_filter()) {
274-
try {
275-
planConverter_->flattenConditions(sRead.filter(), scalarFunctions);
276-
} catch (const VeloxException& err) {
277-
std::cout
278-
<< "Validation failed due to flattening conditions failed in ReadRel due to:"
279-
<< err.message() << std::endl;
280-
return false;
281-
}
282-
}
283-
// Get and validate the filter functions.
284-
std::vector<std::string> funcSpecs;
285-
funcSpecs.reserve(scalarFunctions.size());
286-
for (const auto& scalarFunction : scalarFunctions) {
287-
try {
288-
funcSpecs.emplace_back(
289-
planConverter_->findFuncSpec(scalarFunction.function_reference()));
290-
} catch (const VeloxException& err) {
291-
std::cout << "Validation failed in ReadRel due to:" << err.message()
292-
<< std::endl;
293-
return false;
294-
}
295-
296-
if (scalarFunction.args().size() == 1) {
297-
// Field is expected.
298-
for (const auto& param : scalarFunction.args()) {
299-
auto typeCase = param.rex_type_case();
300-
switch (typeCase) {
301-
case ::substrait::Expression::RexTypeCase::kSelection:
302-
break;
303-
default:
304-
std::cout << "Field is Expected." << std::endl;
305-
return false;
306-
}
307-
}
308-
} else if (scalarFunction.args().size() == 2) {
309-
// Expect there being two args. One is field and the other is literal.
310-
bool fieldExists = false;
311-
bool litExists = false;
312-
for (const auto& param : scalarFunction.args()) {
313-
auto typeCase = param.rex_type_case();
314-
switch (typeCase) {
315-
case ::substrait::Expression::RexTypeCase::kSelection: {
316-
fieldExists = true;
317-
break;
318-
}
319-
case ::substrait::Expression::RexTypeCase::kLiteral: {
320-
litExists = true;
321-
break;
322-
}
323-
default:
324-
std::cout << "Type case: " << typeCase
325-
<< " is not supported in ReadRel." << std::endl;
326-
return false;
327-
}
328-
}
329-
if (!fieldExists || !litExists) {
330-
std::cout << "Only the case of Field and Literal is supported."
331-
<< std::endl;
332-
return false;
333-
}
334-
} else {
335-
std::cout << "More than two args is not supported in ReadRel."
336-
<< std::endl;
337-
return false;
338-
}
339-
}
340-
std::unordered_set<std::string> supportedFilters = {
341-
"is_not_null", "gte", "gt", "lte", "lt"};
342-
std::unordered_set<std::string> supportedTypes = {"opt", "req", "fp64"};
343-
for (const auto& funcSpec : funcSpecs) {
344-
// Validate the functions.
345-
auto funcName = subParser_->getSubFunctionName(funcSpec);
346-
if (supportedFilters.find(funcName) == supportedFilters.end()) {
347-
std::cout << "Validation failed due to " << funcName
348-
<< " was not supported in ReadRel." << std::endl;
349-
return false;
350-
}
351-
352-
// Validate the types.
353-
std::vector<std::string> funcTypes;
354-
subParser_->getSubFunctionTypes(funcSpec, funcTypes);
355-
for (const auto& funcType : funcTypes) {
356-
if (supportedTypes.find(funcType) == supportedTypes.end()) {
357-
std::cout << "Validation failed due to " << funcType
358-
<< " was not supported in ReadRel." << std::endl;
359-
return false;
360-
}
361-
}
362-
}
363312
return true;
364313
}
365314

@@ -393,6 +342,7 @@ bool SubstraitToVeloxPlanValidator::validate(
393342

394343
bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Plan& sPlan) {
395344
functions::prestosql::registerAllScalarFunctions();
345+
functions::sparksql::registerFunctions("");
396346
// Create plan converter and expression converter to help the validation.
397347
planConverter_->constructFuncMap(sPlan);
398348
exprConverter_ = std::make_shared<SubstraitVeloxExprConverter>(

velox/substrait/SubstraitToVeloxPlanValidator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class SubstraitToVeloxPlanValidator {
6464

6565
/// A converter used to convert Substrait plan into Velox's plan node.
6666
std::shared_ptr<SubstraitVeloxPlanConverter> planConverter_ =
67-
std::make_shared<SubstraitVeloxPlanConverter>();
67+
std::make_shared<SubstraitVeloxPlanConverter>(true);
6868

6969
/// A parser used to convert Substrait plan into recognizable representations.
7070
std::shared_ptr<SubstraitParser> subParser_ =

0 commit comments

Comments
 (0)