Skip to content

Commit

Permalink
[Substrait] Support root case in PlanRel message.
Browse files Browse the repository at this point in the history
This essentially consist in adding optional field names to the
`PlanRelOp`. Since verifying if those field names match the yielded type
of the relation consists of the same logic as the one used in the
verifier of the `NamedTableOp`, this commit also factors out that
verification log. While touching that code, the commit also slightly
extends the error messages emitted on verification failure.
  • Loading branch information
ingomueller-net committed Mar 27, 2024
1 parent a71c4de commit 789a0ae
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 41 deletions.
18 changes: 14 additions & 4 deletions include/structured/Dialect/Substrait/IR/SubstraitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,22 @@ def Substrait_PlanRelOp : Substrait_Op<"relation", [
let summary = "Represents a query tree in a Substrait plan";
let description = [{
Represents a `PlanRel` message, which is used in the `relations` field of
the `Plan` message. The body of this op contains various `RelOpInterface`
ops (corresponding to the `Rel` message type) producing SSA values and the
one being yielded reprents the root of the query tree that this op contains.
the `Plan` message. The same op can represent either the `Rel`, in which
case the `fieldNames` attribute is not set, or the `RootRel` case, in which
case the `fieldNames` attribute corresponds to the `RelRoot.names` field.
The body of this op contains various `RelOpInterface` ops (corresponding to
the `Rel` message type) producing SSA values and the one being yielded
reprents the root of the query tree that this op contains.
}];
let arguments = (ins OptionalAttr<StringArrayAttr>:$fieldNames);
let regions = (region RegionOf<RelationBodyOp>:$body);
let assemblyFormat = "attr-dict-with-keyword $body";
let assemblyFormat = "(`as` $fieldNames^)? attr-dict-with-keyword $body";
let hasRegionVerifier = 1;
let builders = [
OpBuilder<(ins ), [{
build($_builder, $_state, ArrayAttr());
}]>
];
let extraClassDefinition = [{
/// Implement OpAsmOpInterface.
::llvm::StringRef $cppClass::getDefaultDialect() {
Expand Down
40 changes: 28 additions & 12 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ LiteralOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
/// own). Furthermore, the names on each nesting level need to be unique. For
/// details, see
/// https://substrait.io/tutorial/sql_to_substrait/#types-and-schemas.
FailureOr<int> verifyNamedStruct(Location loc,
llvm::ArrayRef<Attribute> fieldNames,
TypeRange fieldTypes) {
FailureOr<int> verifyNamedStructHelper(Location loc,
llvm::ArrayRef<Attribute> fieldNames,
TypeRange fieldTypes) {
int numConsumedNames = 0;
llvm::SmallSet<llvm::StringRef, 8> currentLevelNames;
for (Type type : fieldTypes) {
Expand All @@ -200,7 +200,7 @@ FailureOr<int> verifyNamedStruct(Location loc,
llvm::ArrayRef<Attribute> remainingNames =
fieldNames.drop_front(numConsumedNames);
FailureOr<int> res =
verifyNamedStruct(loc, remainingNames, nestedFieldTypes);
verifyNamedStructHelper(loc, remainingNames, nestedFieldTypes);
if (failed(res))
return failure();
numConsumedNames += res.value();
Expand All @@ -209,31 +209,31 @@ FailureOr<int> verifyNamedStruct(Location loc,
return numConsumedNames;
}

LogicalResult NamedTableOp::verify() {
Location loc = getLoc();
llvm::ArrayRef<Attribute> fieldNames = getFieldNames().getValue();
auto tupleType = llvm::cast<TupleType>(getResult().getType());
LogicalResult verifyNamedStruct(Operation *op,
llvm::ArrayRef<Attribute> fieldNames,
TupleType tupleType) {
Location loc = op->getLoc();
TypeRange fieldTypes = tupleType.getTypes();

// Emits error message with context on failure.
auto emitErrorMessage = [&]() {
InFlightDiagnostic error = ::emitError(loc)
<< "mismatching 'field_names' ([";
InFlightDiagnostic error = op->emitOpError()
<< "has mismatching 'field_names' ([";
llvm::interleaveComma(fieldNames, error);
error << "]) and result type (" << tupleType << ")";
return error;
};

// Call recursive verification function.
FailureOr<int> numConsumedNames =
verifyNamedStruct(loc, fieldNames, fieldTypes);
verifyNamedStructHelper(loc, fieldNames, fieldTypes);

// Relay any failure.
if (failed(numConsumedNames))
return emitErrorMessage();

// If we haven't consumed all names, we got too many of them, so report.
if (numConsumedNames.value() != static_cast<int>(getFieldNames().size())) {
if (numConsumedNames.value() != static_cast<int>(fieldNames.size())) {
InFlightDiagnostic error = emitErrorMessage();
error.attachNote(loc) << "too many field names provided";
return error;
Expand All @@ -242,6 +242,22 @@ LogicalResult NamedTableOp::verify() {
return success();
}

LogicalResult NamedTableOp::verify() {
llvm::ArrayRef<Attribute> fieldNames = getFieldNames().getValue();
auto tupleType = llvm::cast<TupleType>(getResult().getType());
return verifyNamedStruct(getOperation(), fieldNames, tupleType);
}

LogicalResult PlanRelOp::verifyRegions() {
if (!getFieldNames().has_value())
return success();

llvm::ArrayRef<Attribute> fieldNames = getFieldNames()->getValue();
auto yieldOp = llvm::cast<YieldOp>(getBody().front().getTerminator());
auto tupleType = llvm::cast<TupleType>(yieldOp.getValue().getType());
return verifyNamedStruct(getOperation(), fieldNames, tupleType);
}

} // namespace substrait
} // namespace mlir

Expand Down
15 changes: 14 additions & 1 deletion lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,21 @@ FailureOr<std::unique_ptr<Plan>> exportOperation(PlanOp op) {
if (failed(rel))
return failure();

// Handle `Rel`/`RelRoot` cases depending on whether `names` is set.
PlanRel *planRel = plan->add_relations();
planRel->set_allocated_rel(rel.value().release());
if (std::optional<Attribute> names = relOp.getFieldNames()) {
auto root = std::make_unique<RelRoot>();
root->set_allocated_input(rel->release());

auto namesArray = names->cast<ArrayAttr>().getAsRange<StringAttr>();
for (StringAttr name : namesArray) {
root->add_names(name.getValue().str());
}

planRel->set_allocated_root(root.release());
} else {
planRel->set_allocated_rel(rel->release());
}
}

return std::move(plan);
Expand Down
53 changes: 33 additions & 20 deletions lib/Target/SubstraitPB/Import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,31 +306,44 @@ static FailureOr<PlanRelOp> importPlanRel(ImplicitLocOpBuilder builder,
MLIRContext *context = builder.getContext();
Location loc = UnknownLoc::get(context);

PlanRel::RelTypeCase relType = message.rel_type_case();
switch (relType) {
case PlanRel::RelTypeCase::kRel: {
auto planRelOp = builder.create<PlanRelOp>();
planRelOp.getBody().push_back(new Block());
Block *block = &planRelOp.getBody().front();

OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToEnd(block);
const Rel &rel = message.rel();
mlir::FailureOr<Operation *> rootRel = importRel(builder, rel);
if (failed(rootRel))
return failure();

builder.setInsertionPointToEnd(block);
builder.create<YieldOp>(rootRel.value()->getResult(0));

return planRelOp;
}
default: {
if (!message.has_rel() && !message.has_root()) {
PlanRel::RelTypeCase relType = message.rel_type_case();
const pb::FieldDescriptor *desc =
PlanRel::GetDescriptor()->FindFieldByNumber(relType);
return emitError(loc) << Twine("unsupported PlanRel type: ") + desc->name();
}

// Create new `PlanRelOp`.
auto planRelOp = builder.create<PlanRelOp>();
planRelOp.getBody().push_back(new Block());
Block *block = &planRelOp.getBody().front();

// Handle `Rel` and `RelRoot` separately.
const Rel *rel;
if (message.has_rel())
rel = &message.rel();
else {
const RelRoot &root = message.root();
rel = &root.input();

// Extract names.
SmallVector<std::string> names(root.names().begin(), root.names().end());
SmallVector<llvm::StringRef> nameAttrs(names.begin(), names.end());
ArrayAttr namesAttr = builder.getStrArrayAttr(nameAttrs);
planRelOp.setFieldNamesAttr(namesAttr);
}

// Import body of `PlanRelOp`.
OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToEnd(block);
mlir::FailureOr<Operation *> rootRel = importRel(builder, *rel);
if (failed(rootRel))
return failure();

builder.setInsertionPointToEnd(block);
builder.create<YieldOp>(rootRel.value()->getResult(0));

return planRelOp;
}

static mlir::FailureOr<RelOpInterface>
Expand Down
6 changes: 3 additions & 3 deletions test/Dialect/Substrait/named-table-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// Test error if providing too many names (1 name for 0 fields).
substrait.plan version 0 : 42 : 1 {
relation {
// expected-error@+2 {{mismatching 'field_names' (["a"]) and result type ('tuple<>')}}
// expected-error@+2 {{'substrait.named_table' op has mismatching 'field_names' (["a"]) and result type ('tuple<>')}}
// expected-note@+1 {{too many field names provided}}
%0 = named_table @t1 as ["a"] : tuple<>
yield %0 : tuple<>
Expand All @@ -15,7 +15,7 @@ substrait.plan version 0 : 42 : 1 {
// Test error if providing too few names (0 names for 1 field).
substrait.plan version 0 : 42 : 1 {
relation {
// expected-error@+2 {{mismatching 'field_names' ([]) and result type ('tuple<si32>')}}
// expected-error@+2 {{'substrait.named_table' op has mismatching 'field_names' ([]) and result type ('tuple<si32>')}}
// expected-error@+1 {{not enough field names provided}}
%0 = named_table @t1 as [] : tuple<si32>
yield %0 : tuple<si32>
Expand All @@ -28,7 +28,7 @@ substrait.plan version 0 : 42 : 1 {
// Test error if providing duplicate field names in the same nesting level.
substrait.plan version 0 : 42 : 1 {
relation {
// expected-error@+2 {{mismatching 'field_names' (["a", "a"]) and result type ('tuple<si32, si32>')}}
// expected-error@+2 {{'substrait.named_table' op has mismatching 'field_names' (["a", "a"]) and result type ('tuple<si32, si32>')}}
// expected-error@+1 {{duplicate field name: 'a'}}
%0 = named_table @t1 as ["a", "a"] : tuple<si32, si32>
yield %0 : tuple<si32, si32>
Expand Down
36 changes: 36 additions & 0 deletions test/Dialect/Substrait/plan-relation-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: structured-opt -verify-diagnostics -split-input-file %s

// Test error if providing too many names (1 name for 0 fields).
substrait.plan version 0 : 42 : 1 {
// expected-error@+2 {{'substrait.relation' op has mismatching 'field_names' (["x", "y"]) and result type ('tuple<si32>')}}
// expected-note@+1 {{too many field names provided}}
relation as ["x", "y"] {
%0 = named_table @t1 as ["a"] : tuple<si32>
yield %0 : tuple<si32>
}
}

// -----

// Test error if providing too few names (0 names for 1 field).
substrait.plan version 0 : 42 : 1 {
// expected-error@+2 {{'substrait.relation' op has mismatching 'field_names' (["x"]) and result type ('tuple<si32, si32>')}}
// expected-error@+1 {{not enough field names provided}}
relation as ["x"] {
%0 = named_table @t1 as ["a", "b"] : tuple<si32, si32>
yield %0 : tuple<si32, si32>
}
}


// -----

// Test error if providing duplicate field names in the same nesting level.
substrait.plan version 0 : 42 : 1 {
// expected-error@+2 {{'substrait.relation' op has mismatching 'field_names' (["x", "x"]) and result type ('tuple<si32, si32>')}}
// expected-error@+1 {{duplicate field name: 'x'}}
relation as ["x", "x"] {
%0 = named_table @t1 as ["a", "b"] : tuple<si32, si32>
yield %0 : tuple<si32, si32>
}
}
16 changes: 16 additions & 0 deletions test/Dialect/Substrait/plan-version.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,19 @@ substrait.plan version 0 : 42 : 1 {
yield %0 : tuple<si32, si32>
}
}

// -----

// CHECK: substrait.plan
// CHECK-NEXT: relation as ["x", "y", "z"] {
// CHECK-NEXT: named_table
// CHECK-NEXT: yield
// CHECK-NEXT: }
// CHECK-NEXT: }

substrait.plan version 0 : 42 : 1 {
relation as ["x", "y", "z"] {
%0 = named_table @t as ["a", "b", "c"] : tuple<si32, tuple<si32>>
yield %0 : tuple<si32, tuple<si32>>
}
}
30 changes: 29 additions & 1 deletion test/Target/SubstraitPB/Export/plan-version.mlir
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// RUN: structured-translate -substrait-to-protobuf %s \
// RUN: structured-translate -substrait-to-protobuf --split-input-file %s \
// RUN: | FileCheck %s

// RUN: structured-translate -substrait-to-protobuf %s \
// RUN: --split-input-file --output-split-marker="# -----" \
// RUN: | structured-translate -protobuf-to-substrait \
// RUN: --split-input-file="# -----" --output-split-marker="// ""-----" \
// RUN: | structured-translate -substrait-to-protobuf \
// RUN: --split-input-file --output-split-marker="# -----" \
// RUN: | FileCheck %s

// CHECK-LABEL: version {
Expand All @@ -17,3 +20,28 @@ substrait.plan
git_hash "hash"
producer "producer"
{}

// -----

// CHECK: relations {
// CHECK-NEXT: root {
// CHECK-NEXT: input {
// CHECK-NEXT: read {
// CHECK: named_table {
// CHECK-NEXT: names
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: names: "x"
// CHECK-NEXT: names: "y"
// CHECK-NEXT: names: "z"
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: version

substrait.plan version 0 : 42 : 1 {
relation as ["x", "y", "z"] {
%0 = named_table @t as ["a", "b", "c"] : tuple<si32, tuple<si32>>
yield %0 : tuple<si32, tuple<si32>>
}
}

0 comments on commit 789a0ae

Please sign in to comment.