Skip to content

Commit d0fa6eb

Browse files
committed
support in(double), fix not(eq) and refine
1 parent f1b9e77 commit d0fa6eb

13 files changed

+1210
-810
lines changed

velox/substrait/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ add_custom_target(substrait_proto ALL DEPENDS ${PROTO_OUTPUT_FILES})
4242
add_dependencies(substrait_proto protobuf::libprotobuf)
4343

4444
set(SRCS ${PROTO_SRCS} SubstraitUtils.cpp SubstraitToVeloxPlanValidator.cpp
45-
SubstraitToVeloxExpr.cpp SubstraitToVeloxPlan.cpp TypeUtils.cpp)
45+
SubstraitToVeloxExpr.cpp SubstraitToVeloxPlan.cpp TypeUtils.cpp VectorCreater.cpp)
4646
add_library(velox_substrait_plan_converter ${SRCS})
4747
target_include_directories(velox_substrait_plan_converter
4848
PUBLIC ${PROTO_OUTPUT_DIR})

velox/substrait/SubstraitToVeloxExpr.cpp

Lines changed: 46 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -16,85 +16,10 @@
1616

1717
#include "velox/substrait/SubstraitToVeloxExpr.h"
1818
#include "velox/substrait/TypeUtils.h"
19+
#include "velox/substrait/VectorCreater.h"
1920

2021
namespace facebook::velox::substrait {
2122

22-
template <TypeKind KIND>
23-
VectorPtr SubstraitVeloxExprConverter::setVectorFromVariantsByKind(
24-
const std::vector<velox::variant>& value,
25-
memory::MemoryPool* pool) {
26-
using T = typename TypeTraits<KIND>::NativeType;
27-
28-
auto flatVector = std::dynamic_pointer_cast<FlatVector<T>>(
29-
BaseVector::create(CppToType<T>::create(), value.size(), pool));
30-
31-
for (vector_size_t i = 0; i < value.size(); i++) {
32-
if (value[i].isNull()) {
33-
flatVector->setNull(i, true);
34-
} else {
35-
flatVector->set(i, value[i].value<T>());
36-
}
37-
}
38-
return flatVector;
39-
}
40-
41-
template <>
42-
VectorPtr
43-
SubstraitVeloxExprConverter::setVectorFromVariantsByKind<TypeKind::VARBINARY>(
44-
const std::vector<velox::variant>& value,
45-
memory::MemoryPool* pool) {
46-
throw std::invalid_argument("Return of VARBINARY data is not supported");
47-
}
48-
49-
template <>
50-
VectorPtr
51-
SubstraitVeloxExprConverter::setVectorFromVariantsByKind<TypeKind::VARCHAR>(
52-
const std::vector<velox::variant>& value,
53-
memory::MemoryPool* pool) {
54-
auto flatVector = std::dynamic_pointer_cast<FlatVector<StringView>>(
55-
BaseVector::create(VARCHAR(), value.size(), pool));
56-
57-
for (vector_size_t i = 0; i < value.size(); i++) {
58-
if (value[i].isNull()) {
59-
flatVector->setNull(i, true);
60-
} else {
61-
flatVector->set(i, StringView(value[i].value<Varchar>()));
62-
}
63-
}
64-
return flatVector;
65-
}
66-
67-
VectorPtr SubstraitVeloxExprConverter::setVectorFromVariants(
68-
const TypePtr& type,
69-
const std::vector<velox::variant>& value,
70-
velox::memory::MemoryPool* pool) {
71-
return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
72-
setVectorFromVariantsByKind, type->kind(), value, pool);
73-
}
74-
75-
ArrayVectorPtr SubstraitVeloxExprConverter::toArrayVector(
76-
TypePtr type,
77-
VectorPtr vector,
78-
memory::MemoryPool* pool) {
79-
vector_size_t size = 1;
80-
BufferPtr offsets = AlignedBuffer::allocate<vector_size_t>(size, pool);
81-
BufferPtr sizes = AlignedBuffer::allocate<vector_size_t>(size, pool);
82-
BufferPtr nulls = AlignedBuffer::allocate<uint64_t>(size, pool);
83-
84-
auto rawOffsets = offsets->asMutable<vector_size_t>();
85-
auto rawSizes = sizes->asMutable<vector_size_t>();
86-
auto rawNulls = nulls->asMutable<uint64_t>();
87-
88-
bits::fillBits(rawNulls, 0, size, pool);
89-
vector_size_t nullCount = 0;
90-
91-
*rawSizes++ = vector->size();
92-
*rawOffsets++ = 0;
93-
94-
return std::make_shared<ArrayVector>(
95-
pool, ARRAY(type), nulls, size, offsets, sizes, vector, nullCount);
96-
}
97-
9823
std::shared_ptr<const core::FieldAccessTypedExpr>
9924
SubstraitVeloxExprConverter::toVeloxExpr(
10025
const ::substrait::Expression::FieldReference& sField,
@@ -130,6 +55,7 @@ std::shared_ptr<const core::ITypedExpr>
13055
SubstraitVeloxExprConverter::toAliasExpr(
13156
const std::vector<std::shared_ptr<const core::ITypedExpr>>& params) {
13257
VELOX_CHECK(params.size() == 1, "Alias expects one parameter.");
58+
// Alias is omitted due to name change is not needed.
13359
return params[0];
13460
}
13561

@@ -161,7 +87,6 @@ SubstraitVeloxExprConverter::toVeloxExpr(
16187
const auto& veloxType =
16288
toVeloxType(subParser_->parseType(sFunc.output_type())->type);
16389

164-
// Omit alias because because name change is not needed.
16590
if (veloxFunction == "alias") {
16691
return toAliasExpr(params);
16792
}
@@ -172,44 +97,6 @@ SubstraitVeloxExprConverter::toVeloxExpr(
17297
veloxType, std::move(params), veloxFunction);
17398
}
17499

175-
TypePtr SubstraitVeloxExprConverter::literalToType(
176-
const ::substrait::Expression::Literal& literal) {
177-
auto typeCase = literal.literal_type_case();
178-
switch (typeCase) {
179-
case ::substrait::Expression_Literal::LiteralTypeCase::kBoolean:
180-
return BOOLEAN();
181-
case ::substrait::Expression_Literal::LiteralTypeCase::kI32:
182-
return INTEGER();
183-
case ::substrait::Expression_Literal::LiteralTypeCase::kI64:
184-
return BIGINT();
185-
case ::substrait::Expression_Literal::LiteralTypeCase::kFp64:
186-
return DOUBLE();
187-
case ::substrait::Expression_Literal::LiteralTypeCase::kString:
188-
return VARCHAR();
189-
default:
190-
VELOX_NYI("LiteralToType not supported for type case '{}'", typeCase);
191-
}
192-
}
193-
194-
variant SubstraitVeloxExprConverter::toVariant(
195-
const ::substrait::Expression::Literal& literal) {
196-
auto typeCase = literal.literal_type_case();
197-
switch (typeCase) {
198-
case ::substrait::Expression_Literal::LiteralTypeCase::kBoolean:
199-
return variant(literal.boolean());
200-
case ::substrait::Expression_Literal::LiteralTypeCase::kI32:
201-
return variant(literal.i32());
202-
case ::substrait::Expression_Literal::LiteralTypeCase::kI64:
203-
return variant(literal.i64());
204-
case ::substrait::Expression_Literal::LiteralTypeCase::kFp64:
205-
return variant(literal.fp64());
206-
case ::substrait::Expression_Literal::LiteralTypeCase::kString:
207-
return variant(literal.string());
208-
default:
209-
VELOX_NYI("ToVariant not supported for type case '{}'", typeCase);
210-
}
211-
}
212-
213100
std::shared_ptr<const core::ConstantTypedExpr>
214101
SubstraitVeloxExprConverter::toVeloxExpr(
215102
const ::substrait::Expression::Literal& sLit) {
@@ -220,24 +107,32 @@ SubstraitVeloxExprConverter::toVeloxExpr(
220107
case ::substrait::Expression_Literal::LiteralTypeCase::kI64:
221108
case ::substrait::Expression_Literal::LiteralTypeCase::kFp64:
222109
case ::substrait::Expression_Literal::LiteralTypeCase::kString:
223-
return std::make_shared<core::ConstantTypedExpr>(toVariant(sLit));
110+
return std::make_shared<core::ConstantTypedExpr>(
111+
toTypedVariant(sLit)->veloxVariant);
224112
case ::substrait::Expression_Literal::LiteralTypeCase::kList: {
113+
// List is used in 'in' expression. Will wrap a constant
114+
// vector with an array vector inside to create the constant expression.
225115
std::vector<variant> variants;
226116
variants.reserve(sLit.list().values().size());
227117
VELOX_CHECK(
228118
sLit.list().values().size() > 0,
229119
"List should have at least one item.");
230120
std::optional<TypePtr> literalType = std::nullopt;
231121
for (const auto& literal : sLit.list().values()) {
122+
auto typedVariant = toTypedVariant(literal);
232123
if (!literalType.has_value()) {
233-
literalType = literalToType(literal);
124+
literalType = typedVariant->variantType;
234125
}
235-
variants.emplace_back(toVariant(literal));
126+
variants.emplace_back(typedVariant->veloxVariant);
236127
}
237128
VELOX_CHECK(literalType.has_value(), "Type expected.");
238-
auto type = literalType.value();
239-
VectorPtr vector = setVectorFromVariants(type, variants, pool_.get());
240-
ArrayVectorPtr arrayVector = toArrayVector(type, vector, pool_.get());
129+
// Create flat vector from the variants.
130+
VectorPtr vector =
131+
setVectorFromVariants(literalType.value(), variants, pool_.get());
132+
// Create array vector from the flat vector.
133+
ArrayVectorPtr arrayVector =
134+
toArrayVector(literalType.value(), vector, pool_.get());
135+
// Wrap the array vector into constant vector.
241136
auto constantVector = BaseVector::wrapInConstant(1, 0, arrayVector);
242137
auto constantExpr =
243138
std::make_shared<core::ConstantTypedExpr>(constantVector);
@@ -285,4 +180,34 @@ SubstraitVeloxExprConverter::toVeloxExpr(
285180
}
286181
}
287182

183+
std::shared_ptr<SubstraitVeloxExprConverter::TypedVariant>
184+
SubstraitVeloxExprConverter::toTypedVariant(
185+
const ::substrait::Expression::Literal& literal) {
186+
auto typeCase = literal.literal_type_case();
187+
switch (typeCase) {
188+
case ::substrait::Expression_Literal::LiteralTypeCase::kBoolean: {
189+
TypedVariant typedVariant = {variant(literal.boolean()), BOOLEAN()};
190+
return std::make_shared<TypedVariant>(typedVariant);
191+
}
192+
case ::substrait::Expression_Literal::LiteralTypeCase::kI32: {
193+
TypedVariant typedVariant = {variant(literal.i32()), INTEGER()};
194+
return std::make_shared<TypedVariant>(typedVariant);
195+
}
196+
case ::substrait::Expression_Literal::LiteralTypeCase::kI64: {
197+
TypedVariant typedVariant = {variant(literal.i64()), BIGINT()};
198+
return std::make_shared<TypedVariant>(typedVariant);
199+
}
200+
case ::substrait::Expression_Literal::LiteralTypeCase::kFp64: {
201+
TypedVariant typedVariant = {variant(literal.fp64()), DOUBLE()};
202+
return std::make_shared<TypedVariant>(typedVariant);
203+
}
204+
case ::substrait::Expression_Literal::LiteralTypeCase::kString: {
205+
TypedVariant typedVariant = {variant(literal.string()), VARCHAR()};
206+
return std::make_shared<TypedVariant>(typedVariant);
207+
}
208+
default:
209+
VELOX_NYI("ToVariant not supported for type case '{}'", typeCase);
210+
}
211+
}
212+
288213
} // namespace facebook::velox::substrait

velox/substrait/SubstraitToVeloxExpr.h

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ class SubstraitVeloxExprConverter {
3737
const std::unordered_map<uint64_t, std::string>& functionMap)
3838
: functionMap_(functionMap) {}
3939

40+
/// Stores the variant and its type.
41+
struct TypedVariant {
42+
variant veloxVariant;
43+
TypePtr variantType;
44+
};
45+
4046
/// Used to convert Substrait Field into Velox Field Expression.
4147
std::shared_ptr<const core::FieldAccessTypedExpr> toVeloxExpr(
4248
const ::substrait::Expression::FieldReference& sField,
@@ -52,17 +58,15 @@ class SubstraitVeloxExprConverter {
5258
const ::substrait::Expression::Cast& castExpr,
5359
const RowTypePtr& inputType);
5460

61+
/// Create expression for alias.
5562
std::shared_ptr<const core::ITypedExpr> toAliasExpr(
5663
const std::vector<std::shared_ptr<const core::ITypedExpr>>& params);
5764

65+
/// Create expression for is_not_null.
5866
std::shared_ptr<const core::ITypedExpr> toIsNotNullExpr(
5967
const std::vector<std::shared_ptr<const core::ITypedExpr>>& params,
6068
const TypePtr& outputType);
6169

62-
TypePtr literalToType(const ::substrait::Expression::Literal& literal);
63-
64-
variant toVariant(const ::substrait::Expression::Literal& literal);
65-
6670
/// Used to convert Substrait Literal into Velox Expression.
6771
std::shared_ptr<const core::ConstantTypedExpr> toVeloxExpr(
6872
const ::substrait::Expression::Literal& sLit);
@@ -72,21 +76,12 @@ class SubstraitVeloxExprConverter {
7276
const ::substrait::Expression& sExpr,
7377
const RowTypePtr& inputType);
7478

75-
private:
76-
template <TypeKind KIND>
77-
VectorPtr setVectorFromVariantsByKind(
78-
const std::vector<velox::variant>& value,
79-
memory::MemoryPool* pool);
79+
/// Get variant and its type from Substrait Literal.
80+
std::shared_ptr<TypedVariant> toTypedVariant(
81+
const ::substrait::Expression::Literal& literal);
8082

81-
VectorPtr setVectorFromVariants(
82-
const TypePtr& type,
83-
const std::vector<velox::variant>& value,
84-
velox::memory::MemoryPool* pool);
85-
86-
ArrayVectorPtr
87-
toArrayVector(TypePtr type, VectorPtr vector, memory::MemoryPool* pool);
88-
89-
// A tmp used memory pool. Needs to be removed.
83+
private:
84+
// Memory pool.
9085
std::unique_ptr<memory::MemoryPool> pool_{
9186
memory::getDefaultScopedMemoryPool()};
9287

0 commit comments

Comments
 (0)