Skip to content

Commit 155b45a

Browse files
1. Support selecting a subfield from a struct.
2. Add a round-trip test (nested struct).
1 parent 6c3ee92 commit 155b45a

File tree

3 files changed

+101
-19
lines changed

3 files changed

+101
-19
lines changed

velox/substrait/SubstraitToVeloxExpr.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,37 @@ SubstraitVeloxExprConverter::toVeloxExpr(
2929
kDirectReference: {
3030
const auto& directRef = substraitField.direct_reference();
3131
int32_t colIdx = substraitParser_.parseReferenceSegment(directRef);
32+
33+
const auto& inputTypes = inputType->children();
3234
const auto& inputNames = inputType->names();
3335
const int64_t inputSize = inputNames.size();
34-
if (colIdx <= inputSize) {
35-
const auto& inputTypes = inputType->children();
36-
// Convert type to row.
36+
VELOX_CHECK_LT(
37+
colIdx, inputSize, "Missing the column with id '{}' .", colIdx);
38+
std::optional<int32_t> childIdx;
39+
if (directRef.struct_field().has_child()) {
40+
childIdx = substraitParser_.parseReferenceSegment(
41+
directRef.struct_field().child());
42+
}
43+
44+
if (!childIdx.has_value()) {
3745
return std::make_shared<core::FieldAccessTypedExpr>(
3846
inputTypes[colIdx],
3947
std::make_shared<core::InputTypedExpr>(inputTypes[colIdx]),
4048
inputNames[colIdx]);
41-
} else {
42-
VELOX_FAIL("Missing the column with id '{}' .", colIdx);
4349
}
50+
VELOX_CHECK_EQ(inputTypes[colIdx]->kind(), TypeKind::ROW);
51+
auto inputColumnType = asRowType(inputTypes[colIdx]);
52+
VELOX_CHECK_LT(
53+
childIdx.value(),
54+
inputColumnType->size(),
55+
"Missing the subfield with id '{}' .",
56+
childIdx.value());
57+
// Select a subfield in a struct by name.
58+
return std::make_shared<core::FieldAccessTypedExpr>(
59+
inputColumnType->childAt(childIdx.value()),
60+
std::make_shared<core::FieldAccessTypedExpr>(
61+
inputTypes[colIdx], inputNames[colIdx]),
62+
inputColumnType->nameOf(childIdx.value()));
4463
}
4564
default:
4665
VELOX_NYI(

velox/substrait/VeloxToSubstraitExpr.cpp

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,51 @@ VeloxToSubstraitExprConvertor::toSubstraitExpr(
114114
::substrait::Expression_FieldReference* substraitFieldExpr =
115115
google::protobuf::Arena::CreateMessage<
116116
::substrait::Expression_FieldReference>(&arena);
117-
118-
std::string exprName = fieldExpr->name();
119-
120117
::substrait::Expression_ReferenceSegment_StructField* directStruct =
121118
substraitFieldExpr->mutable_direct_reference()->mutable_struct_field();
119+
std::string exprName = fieldExpr->name();
120+
// FieldAccessTypedExpr represents one of two things: an input column (a leaf
121+
// of the expression) or a dereference expression (fieldExpr->isInputColumn() == false).
122+
// For an input column, find the index of exprName among inputType's children.
123+
// For a dereference expression, find the index of exprName within each ROW-typed child of inputType.
124+
if (fieldExpr->isInputColumn()) {
125+
auto idx = inputType->getChildIdxIfExists(exprName);
126+
if (idx.has_value()) {
127+
directStruct->set_field(idx.value());
128+
} else {
129+
VELOX_USER_FAIL("idx has_value return false.");
130+
}
131+
} else {
132+
int matchCount = 0;
133+
uint32_t idxOfExprName = -1;
134+
for (auto child : inputType->children()) {
135+
auto rowChild = asRowType(child);
136+
if (!rowChild) {
137+
continue;
138+
}
139+
auto idxInChild = rowChild->getChildIdxIfExists(exprName);
140+
if (idxInChild.has_value()) {
141+
matchCount++;
142+
idxOfExprName = idxInChild.value();
143+
}
144+
}
145+
if (matchCount == 0) {
146+
VELOX_USER_FAIL("exprName :{} no name matched!", exprName);
147+
}
148+
if (matchCount > 1) {
149+
VELOX_USER_FAIL("exprName :{} multiple names matched!", exprName);
150+
}
151+
152+
::substrait::Expression_ReferenceSegment* refSegment =
153+
google::protobuf::Arena::CreateMessage<
154+
::substrait::Expression_ReferenceSegment>(&arena);
155+
::substrait::Expression_ReferenceSegment_StructField* childStruct =
156+
refSegment->mutable_struct_field();
157+
childStruct->set_field(idxOfExprName);
158+
refSegment->set_allocated_struct_field(childStruct);
159+
directStruct->set_allocated_child(refSegment);
160+
}
122161

123-
directStruct->set_field(inputType->getChildIdx(exprName));
124162
return *substraitFieldExpr;
125163
}
126164

@@ -235,20 +273,20 @@ VeloxToSubstraitExprConvertor::toSubstraitNotNullLiteral(
235273
google::protobuf::Arena::CreateMessage<::substrait::Expression_Literal>(
236274
&arena);
237275
switch (variantValue.kind()) {
238-
case velox::TypeKind::DOUBLE: {
239-
literalExpr->set_fp64(variantValue.value<TypeKind::DOUBLE>());
240-
break;
241-
}
242-
case velox::TypeKind::BIGINT: {
243-
literalExpr->set_i64(variantValue.value<TypeKind::BIGINT>());
276+
case velox::TypeKind::BOOLEAN: {
277+
literalExpr->set_boolean(variantValue.value<TypeKind::BOOLEAN>());
244278
break;
245279
}
246280
case velox::TypeKind::INTEGER: {
247281
literalExpr->set_i32(variantValue.value<TypeKind::INTEGER>());
248282
break;
249283
}
250-
case velox::TypeKind::BOOLEAN: {
251-
literalExpr->set_boolean(variantValue.value<TypeKind::BOOLEAN>());
284+
case velox::TypeKind::BIGINT: {
285+
literalExpr->set_i64(variantValue.value<TypeKind::BIGINT>());
286+
break;
287+
}
288+
case velox::TypeKind::DOUBLE: {
289+
literalExpr->set_fp64(variantValue.value<TypeKind::DOUBLE>());
252290
break;
253291
}
254292
default:

velox/substrait/tests/VeloxSubstraitRoundTripPlanConverterTest.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@
1818

1919
#include "velox/exec/tests/utils/OperatorTestBase.h"
2020
#include "velox/exec/tests/utils/PlanBuilder.h"
21-
#include "velox/vector/tests/utils/VectorMaker.h"
22-
2321
#include "velox/substrait/SubstraitToVeloxPlan.h"
2422
#include "velox/substrait/VeloxToSubstraitPlan.h"
23+
#include "velox/vector/tests/utils/VectorTestBase.h"
2524

2625
using namespace facebook::velox;
2726
using namespace facebook::velox::test;
@@ -286,6 +285,32 @@ TEST_F(VeloxSubstraitRoundTripPlanConverterTest, ifThen) {
286285
assertPlanConversion(plan, "SELECT if (c0=1, c0 + 1, c1 + 2) as x FROM tmp");
287286
}
288287

288+
TEST_F(VeloxSubstraitRoundTripPlanConverterTest, subField) {
289+
RowVectorPtr data = makeRowVector(
290+
{"a", "b", "c"},
291+
{
292+
makeFlatVector<int64_t>(
293+
{2499109626526694126, 2342493223442167775, 4077358421272316858}),
294+
makeFlatVector<int32_t>({581869302, -708632711, -133711905}),
295+
makeFlatVector<double>(
296+
{0.90579193414549275, 0.96886777112423139, 0.63235925003444637}),
297+
});
298+
createDuckDbTable({data});
299+
auto plan =
300+
PlanBuilder()
301+
.values({data})
302+
.project(
303+
{"cast(row_constructor(a, b) as row(a bigint, b bigint)) as ab",
304+
"c"})
305+
.project(
306+
{"cast(row_constructor(ab, c) as row(ab row(a bigint, b bigint), c bigint)) as abc"})
307+
.project({"(abc).ab", "abc.c"})
308+
.project({"(ab).a", "(ab).b", "c"})
309+
.planNode();
310+
311+
assertPlanConversion(plan, "SELECT a, b, c FROM tmp");
312+
}
313+
289314
int main(int argc, char** argv) {
290315
testing::InitGoogleTest(&argc, argv);
291316
folly::init(&argc, &argv, false);

0 commit comments

Comments
 (0)