|
16 | 16 |
|
17 | 17 | #include "velox/substrait/SubstraitToVeloxPlanValidator.h" |
18 | 18 | #include "TypeUtils.h" |
| 19 | +#include "velox/functions/sparksql/Register.h" |
19 | 20 |
|
20 | 21 | #include "velox/functions/prestosql/registration/RegistrationFunctions.h" |
21 | 22 |
|
@@ -78,7 +79,8 @@ bool SubstraitToVeloxPlanValidator::validate( |
78 | 79 | const auto& extension = sProject.advanced_extension(); |
79 | 80 | std::vector<TypePtr> types; |
80 | 81 | if (!validateInputTypes(extension, types)) { |
81 | | - std::cout << "Validation failed for input types in ProjectRel" << std::endl; |
| 82 | + std::cout << "Validation failed for input types in ProjectRel." |
| 83 | + << std::endl; |
82 | 84 | return false; |
83 | 85 | } |
84 | 86 |
|
@@ -112,7 +114,45 @@ bool SubstraitToVeloxPlanValidator::validate( |
112 | 114 |
|
113 | 115 | bool SubstraitToVeloxPlanValidator::validate( |
114 | 116 | const ::substrait::FilterRel& sFilter) { |
115 | | - return false; |
| 117 | + if (sFilter.has_input() && !validate(sFilter.input())) { |
| 118 | + return false; |
| 119 | + } |
| 120 | + |
| 121 | + // Get and validate the input types from extension. |
| 122 | + if (!sFilter.has_advanced_extension()) { |
| 123 | + std::cout << "Input types are expected in FilterRel." << std::endl; |
| 124 | + return false; |
| 125 | + } |
| 126 | + const auto& extension = sFilter.advanced_extension(); |
| 127 | + std::vector<TypePtr> types; |
| 128 | + if (!validateInputTypes(extension, types)) { |
| 129 | + std::cout << "Validation failed for input types in FilterRel." << std::endl; |
| 130 | + return false; |
| 131 | + } |
| 132 | + |
| 133 | + int32_t inputPlanNodeId = 0; |
| 134 | + // Create the fake input names to be used in row type. |
| 135 | + std::vector<std::string> names; |
| 136 | + names.reserve(types.size()); |
| 137 | + for (uint32_t colIdx = 0; colIdx < types.size(); colIdx++) { |
| 138 | + names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx)); |
| 139 | + } |
| 140 | + auto rowType = std::make_shared<RowType>(std::move(names), std::move(types)); |
| 141 | + |
| 142 | + std::vector<std::shared_ptr<const core::ITypedExpr>> expressions; |
| 143 | + expressions.reserve(1); |
| 144 | + try { |
| 145 | + expressions.emplace_back( |
| 146 | + exprConverter_->toVeloxExpr(sFilter.condition(), rowType)); |
| 147 | + // Try to compile the expressions. If there is any unregistered function |
| 148 | + // or mismatched type, an exception will be thrown. |
| 149 | + exec::ExprSet exprSet(std::move(expressions), &execCtx_); |
| 150 | + } catch (const VeloxException& err) { |
| 151 | + std::cout << "Validation failed for expression in FilterRel due to:" |
| 152 | + << err.message() << std::endl; |
| 153 | + return false; |
| 154 | + } |
| 155 | + return true; |
116 | 156 | } |
117 | 157 |
|
118 | 158 | bool SubstraitToVeloxPlanValidator::validate( |
@@ -228,6 +268,7 @@ bool SubstraitToVeloxPlanValidator::validate( |
228 | 268 | auto typeCase = arg.rex_type_case(); |
229 | 269 | switch (typeCase) { |
230 | 270 | case ::substrait::Expression::RexTypeCase::kSelection: |
| 271 | + case ::substrait::Expression::RexTypeCase::kLiteral: |
231 | 272 | break; |
232 | 273 | default: |
233 | 274 | std::cout << "Only field is supported in aggregate functions." |
@@ -256,110 +297,18 @@ bool SubstraitToVeloxPlanValidator::validate( |
256 | 297 |
|
257 | 298 | bool SubstraitToVeloxPlanValidator::validate( |
258 | 299 | const ::substrait::ReadRel& sRead) { |
259 | | - if (!sRead.has_base_schema()) { |
260 | | - std::cout << "Validation failed due to schema was not found in ReadRel." |
| 300 | + try { |
| 301 | + u_int32_t index = 0; |
| 302 | + std::vector<std::string> paths; |
| 303 | + std::vector<u_int64_t> starts; |
| 304 | + std::vector<u_int64_t> lengths; |
| 305 | + |
| 306 | + planConverter_->toVeloxPlan(sRead, index, paths, starts, lengths); |
| 307 | + } catch (const VeloxException& err) { |
| 308 | + std::cout << "ReadRel validation failed due to:" << err.message() |
261 | 309 | << std::endl; |
262 | 310 | return false; |
263 | 311 | } |
264 | | - const auto& sTypes = sRead.base_schema().struct_().types(); |
265 | | - for (const auto& sType : sTypes) { |
266 | | - if (!validate(sType)) { |
267 | | - std::cout << "Validation failed due to type was not supported in ReadRel." |
268 | | - << std::endl; |
269 | | - return false; |
270 | | - } |
271 | | - } |
272 | | - std::vector<::substrait::Expression_ScalarFunction> scalarFunctions; |
273 | | - if (sRead.has_filter()) { |
274 | | - try { |
275 | | - planConverter_->flattenConditions(sRead.filter(), scalarFunctions); |
276 | | - } catch (const VeloxException& err) { |
277 | | - std::cout |
278 | | - << "Validation failed due to flattening conditions failed in ReadRel due to:" |
279 | | - << err.message() << std::endl; |
280 | | - return false; |
281 | | - } |
282 | | - } |
283 | | - // Get and validate the filter functions. |
284 | | - std::vector<std::string> funcSpecs; |
285 | | - funcSpecs.reserve(scalarFunctions.size()); |
286 | | - for (const auto& scalarFunction : scalarFunctions) { |
287 | | - try { |
288 | | - funcSpecs.emplace_back( |
289 | | - planConverter_->findFuncSpec(scalarFunction.function_reference())); |
290 | | - } catch (const VeloxException& err) { |
291 | | - std::cout << "Validation failed in ReadRel due to:" << err.message() |
292 | | - << std::endl; |
293 | | - return false; |
294 | | - } |
295 | | - |
296 | | - if (scalarFunction.args().size() == 1) { |
297 | | - // Field is expected. |
298 | | - for (const auto& param : scalarFunction.args()) { |
299 | | - auto typeCase = param.rex_type_case(); |
300 | | - switch (typeCase) { |
301 | | - case ::substrait::Expression::RexTypeCase::kSelection: |
302 | | - break; |
303 | | - default: |
304 | | - std::cout << "Field is Expected." << std::endl; |
305 | | - return false; |
306 | | - } |
307 | | - } |
308 | | - } else if (scalarFunction.args().size() == 2) { |
309 | | - // Expect there being two args. One is field and the other is literal. |
310 | | - bool fieldExists = false; |
311 | | - bool litExists = false; |
312 | | - for (const auto& param : scalarFunction.args()) { |
313 | | - auto typeCase = param.rex_type_case(); |
314 | | - switch (typeCase) { |
315 | | - case ::substrait::Expression::RexTypeCase::kSelection: { |
316 | | - fieldExists = true; |
317 | | - break; |
318 | | - } |
319 | | - case ::substrait::Expression::RexTypeCase::kLiteral: { |
320 | | - litExists = true; |
321 | | - break; |
322 | | - } |
323 | | - default: |
324 | | - std::cout << "Type case: " << typeCase |
325 | | - << " is not supported in ReadRel." << std::endl; |
326 | | - return false; |
327 | | - } |
328 | | - } |
329 | | - if (!fieldExists || !litExists) { |
330 | | - std::cout << "Only the case of Field and Literal is supported." |
331 | | - << std::endl; |
332 | | - return false; |
333 | | - } |
334 | | - } else { |
335 | | - std::cout << "More than two args is not supported in ReadRel." |
336 | | - << std::endl; |
337 | | - return false; |
338 | | - } |
339 | | - } |
340 | | - std::unordered_set<std::string> supportedFilters = { |
341 | | - "is_not_null", "gte", "gt", "lte", "lt"}; |
342 | | - std::unordered_set<std::string> supportedTypes = {"opt", "req", "fp64"}; |
343 | | - for (const auto& funcSpec : funcSpecs) { |
344 | | - // Validate the functions. |
345 | | - auto funcName = subParser_->getSubFunctionName(funcSpec); |
346 | | - if (supportedFilters.find(funcName) == supportedFilters.end()) { |
347 | | - std::cout << "Validation failed due to " << funcName |
348 | | - << " was not supported in ReadRel." << std::endl; |
349 | | - return false; |
350 | | - } |
351 | | - |
352 | | - // Validate the types. |
353 | | - std::vector<std::string> funcTypes; |
354 | | - subParser_->getSubFunctionTypes(funcSpec, funcTypes); |
355 | | - for (const auto& funcType : funcTypes) { |
356 | | - if (supportedTypes.find(funcType) == supportedTypes.end()) { |
357 | | - std::cout << "Validation failed due to " << funcType |
358 | | - << " was not supported in ReadRel." << std::endl; |
359 | | - return false; |
360 | | - } |
361 | | - } |
362 | | - } |
363 | 312 | return true; |
364 | 313 | } |
365 | 314 |
|
@@ -393,6 +342,7 @@ bool SubstraitToVeloxPlanValidator::validate( |
393 | 342 |
|
394 | 343 | bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Plan& sPlan) { |
395 | 344 | functions::prestosql::registerAllScalarFunctions(); |
| 345 | + functions::sparksql::registerFunctions(""); |
396 | 346 | // Create plan converter and expression converter to help the validation. |
397 | 347 | planConverter_->constructFuncMap(sPlan); |
398 | 348 | exprConverter_ = std::make_shared<SubstraitVeloxExprConverter>( |
|
0 commit comments