Skip to content

Commit aa23aca

Browse files
Support selecting a subfield from struct (#294)
Co-authored-by: zhejiangxiaomai <zhenhui.zhao@intel.com>
1 parent 7e73041 commit aa23aca

File tree

4 files changed

+119
-40
lines changed

4 files changed

+119
-40
lines changed

velox/substrait/SubstraitToVeloxExpr.cpp

Lines changed: 28 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,16 @@ VectorPtr constructFlatVectorForStruct(
189189
return vector;
190190
}
191191

192+
/// Creates a field-access expression called `name` whose result type is `type`.
///
/// @param name  Name of the field being accessed.
/// @param type  Type of the accessed field.
/// @param input Optional parent expression. When non-null, the new expression
///              is constructed on top of it (i.e. it dereferences `name` out of
///              `input`); when null, the expression is constructed from the
///              type and name alone.
/// @return The newly allocated field-access expression.
core::FieldAccessTypedExprPtr makeFieldAccessExpr(
    const std::string& name,
    const TypePtr& type,
    core::FieldAccessTypedExprPtr input) {
  if (!input) {
    // No parent to dereference from: build the expression from type and name.
    return std::make_shared<core::FieldAccessTypedExpr>(type, name);
  }

  // Chain onto the parent expression.
  return std::make_shared<core::FieldAccessTypedExpr>(type, input, name);
}
192202
} // namespace
193203

194204
using facebook::velox::core::variantArrayToVector;
@@ -202,44 +212,26 @@ SubstraitVeloxExprConverter::toVeloxExpr(
202212
switch (typeCase) {
203213
case ::substrait::Expression::FieldReference::ReferenceTypeCase::
204214
kDirectReference: {
205-
const auto& dRef = substraitField.direct_reference();
206-
VELOX_CHECK(dRef.has_struct_field(), "Struct field expected.");
207-
int32_t colIdx = subParser_->parseReferenceSegment(dRef);
208-
std::optional<int32_t> childIdx = std::nullopt;
209-
if (dRef.struct_field().has_child()) {
210-
childIdx =
211-
subParser_->parseReferenceSegment(dRef.struct_field().child());
212-
}
213-
214-
const auto& inputTypes = inputType->children();
215-
const auto& inputNames = inputType->names();
216-
const int64_t inputSize = inputNames.size();
217-
218-
if (colIdx >= inputSize) {
219-
VELOX_FAIL("Missing the column with id '{}' .", colIdx);
220-
}
221-
222-
if (!childIdx.has_value()) {
223-
return std::make_shared<core::FieldAccessTypedExpr>(
224-
inputTypes[colIdx],
225-
std::make_shared<core::InputTypedExpr>(inputTypes[colIdx]),
226-
inputNames[colIdx]);
227-
} else {
228-
// Select a subfield in a struct by name.
229-
if (auto inputColumnType = asRowType(inputTypes[colIdx])) {
230-
if (childIdx.value() >= inputColumnType->size()) {
231-
VELOX_FAIL("Missing the subfield with id '{}' .", childIdx.value());
232-
}
233-
return std::make_shared<core::FieldAccessTypedExpr>(
234-
inputColumnType->childAt(childIdx.value()),
235-
std::make_shared<core::FieldAccessTypedExpr>(
236-
inputTypes[colIdx], inputNames[colIdx]),
237-
inputColumnType->nameOf(childIdx.value()));
238-
} else {
239-
VELOX_FAIL("RowType expected.");
215+
const auto& directRef = substraitField.direct_reference();
216+
core::FieldAccessTypedExprPtr fieldAccess{nullptr};
217+
const auto* tmp = &directRef.struct_field();
218+
219+
auto inputColumnType = inputType;
220+
for (;;) {
221+
auto idx = tmp->field();
222+
fieldAccess = makeFieldAccessExpr(
223+
inputColumnType->nameOf(idx),
224+
inputColumnType->childAt(idx),
225+
fieldAccess);
226+
227+
if (!tmp->has_child()) {
228+
break;
240229
}
230+
231+
inputColumnType = asRowType(inputColumnType->childAt(idx));
232+
tmp = &tmp->child().struct_field();
241233
}
242-
break;
234+
return fieldAccess;
243235
}
244236
default:
245237
VELOX_NYI(

velox/substrait/VeloxToSubstraitExpr.cpp

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,26 @@ void convertVectorValue(
368368
}
369369
}
370370
}
371+
372+
uint32_t getFieldIdForIntermediateNode(
373+
const std::string& exprName,
374+
const ::substrait::Expression_ReferenceSegment_StructField& structField,
375+
const RowTypePtr& inputType) {
376+
auto inputColumnType = inputType;
377+
std::vector<int32_t> ids;
378+
const auto* tmp = &structField;
379+
for (;;) {
380+
ids.push_back(tmp->field());
381+
if (!tmp->has_child()) {
382+
break;
383+
}
384+
tmp = &tmp->child().struct_field();
385+
}
386+
for (int32_t i = ids.size() - 1; i >= 0; --i) {
387+
inputColumnType = asRowType(inputColumnType->childAt(ids[i]));
388+
}
389+
return inputColumnType->getChildIdx(exprName);
390+
}
371391
} // namespace
372392

373393
const ::substrait::Expression& VeloxToSubstraitExprConvertor::toSubstraitExpr(
@@ -448,7 +468,40 @@ VeloxToSubstraitExprConvertor::toSubstraitExpr(
448468
::substrait::Expression_ReferenceSegment_StructField* directStruct =
449469
substraitFieldExpr->mutable_direct_reference()->mutable_struct_field();
450470

451-
directStruct->set_field(inputType->getChildIdx(exprName));
471+
// FieldAccessTypedExpr represents one of two things: a leaf in an expression
472+
// or a dereference expression (fieldExpr->isInputColumn() == false).
473+
// For a leaf, look up the index of exprName among the children of inputType.
474+
// For a dereference, reuse the parent's reference path and append the index of exprName within the parent's row type.
475+
if (fieldExpr->isInputColumn()) {
476+
uint32_t idx = inputType->getChildIdx(exprName);
477+
directStruct->set_field(idx);
478+
} else {
479+
auto tmp = toSubstraitExpr(arena, fieldExpr->inputs()[0], inputType)
480+
.selection()
481+
.direct_reference();
482+
if (!tmp.has_struct_field()) {
483+
uint32_t idx = inputType->getChildIdx(exprName);
484+
directStruct->set_field(idx);
485+
} else {
486+
::substrait::Expression_ReferenceSegment_StructField* childStruct =
487+
google::protobuf::Arena::CreateMessage<
488+
::substrait::Expression_ReferenceSegment_StructField>(&arena);
489+
::substrait::Expression_ReferenceSegment* refSegment =
490+
google::protobuf::Arena::CreateMessage<
491+
::substrait::Expression_ReferenceSegment>(&arena);
492+
directStruct->MergeFrom(tmp.struct_field());
493+
childStruct->set_field(getFieldIdForIntermediateNode(
494+
exprName, tmp.struct_field(), inputType));
495+
refSegment->set_allocated_struct_field(childStruct);
496+
::substrait::Expression_ReferenceSegment_StructField* innerChild =
497+
directStruct;
498+
while (innerChild->has_child()) {
499+
innerChild = innerChild->mutable_child()->mutable_struct_field();
500+
}
501+
innerChild->set_allocated_child(refSegment);
502+
}
503+
}
504+
452505
return *substraitFieldExpr;
453506
}
454507

velox/substrait/tests/Substrait2VeloxPlanConversionTest.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,9 @@ TEST_F(Substrait2VeloxPlanConversionTest, ifthenTest) {
309309
" -- TableScan[table: hive_table, range filters: "
310310
"[(hd_demo_sk, Filter(IsNotNull, deterministic, null not allowed)),"
311311
" (hd_vehicle_count, BigintRange: [1, 9223372036854775807] no nulls)], "
312-
"remaining filter: (and(or(equalto(ROW[\"hd_buy_potential\"],\">10000\"),"
313-
"equalto(ROW[\"hd_buy_potential\"],\"unknown\")),if(greaterthan(ROW[\"hd_vehicle_count\"],0),"
314-
"greaterthan(divide(cast ROW[\"hd_dep_count\"] as DOUBLE,cast ROW[\"hd_vehicle_count\"] as DOUBLE),1.2))))]"
312+
"remaining filter: (and(or(equalto(\"hd_buy_potential\",\">10000\"),"
313+
"equalto(\"hd_buy_potential\",\"unknown\")),if(greaterthan(\"hd_vehicle_count\",0),"
314+
"greaterthan(divide(cast \"hd_dep_count\" as DOUBLE,cast \"hd_vehicle_count\" as DOUBLE),1.2))))]"
315315
" -> n0_0:BIGINT, n0_1:VARCHAR, n0_2:BIGINT, n0_3:BIGINT\n",
316316
planNode->toString(true, true));
317317
}

velox/substrait/tests/VeloxSubstraitRoundTripTest.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323

2424
#include "velox/functions/sparksql/Register.h"
2525
#include "velox/substrait/SubstraitToVeloxPlan.h"
26+
#include "velox/substrait/VeloxToSubstraitPlan.h"
27+
#include "velox/vector/tests/utils/VectorTestBase.h"
28+
2629
#include "velox/substrait/VariantToVectorConverter.h"
2730
#include "velox/substrait/VeloxToSubstraitPlan.h"
2831

@@ -507,6 +510,37 @@ TEST_F(VeloxSubstraitRoundTripTest, dateType) {
507510
assertPlanConversion(plan, "SELECT * FROM tmp WHERE c > DATE '1992-01-01'");
508511
}
509512

513+
// Round-trips a plan that packs columns into nested row values and then
// selects subfields back out of them, comparing the output with DuckDB.
TEST_F(VeloxSubstraitRoundTripTest, subField) {
  const RowVectorPtr input = makeRowVector(
      {"a", "b", "c"},
      {
          makeFlatVector<int64_t>({249, 235, 858}),
          makeFlatVector<int32_t>({581, -708, -133}),
          makeFlatVector<double>({0.905, 0.968, 0.632}),
      });
  createDuckDbTable({input});

  // Stage 1: wrap (a, b) into the row column `ab`, keeping a, b and c.
  const std::vector<std::string> packAb = {
      "cast(row_constructor(a, b) as row(a bigint, b bigint)) as ab",
      "a",
      "b",
      "c"};

  // Stage 2: wrap (ab, c) into the two-level row column `abc`, keeping a and b.
  const std::vector<std::string> packAbc = {
      "cast(row_constructor(ab, c) as row(ab row(a bigint, b bigint), c bigint)) as abc",
      "a",
      "b"};

  // Stage 3: dereference subfields at depth one and depth two.
  const std::vector<std::string> selectSubfields = {
      "(cast(row_constructor(a, b) as row(a bigint, b bigint))).a",
      "(abc).ab.a",
      "(abc).ab.b",
      "abc.c"};

  auto plan = PlanBuilder()
                  .values({input})
                  .project(packAb)
                  .project(packAbc)
                  .project(selectSubfields)
                  .planNode();

  assertPlanConversion(plan, "SELECT a, a, b, c FROM tmp");
}
543+
510544
int main(int argc, char** argv) {
511545
facebook::velox::functions::sparksql::registerFunctions("");
512546
testing::InitGoogleTest(&argc, argv);

0 commit comments

Comments
 (0)