Skip to content

[MLIR][DLTI] Enable types as keys in DLTI-query utils #105995

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/DLTI/DLTI.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace mlir {
namespace dlti {
/// Perform a DLTI-query at `op`, recursively querying each key of `keys` on
/// query interface-implementing attrs, starting from attr obtained from `op`.
FailureOr<Attribute> query(Operation *op, ArrayRef<StringAttr> keys,
FailureOr<Attribute> query(Operation *op, ArrayRef<DataLayoutEntryKey> keys,
bool emitError = false);
} // namespace dlti
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [

A lookup is performed for the given `keys` at `target` op - or its closest
interface-implementing ancestor - by way of the `DLTIQueryInterface`, which
returns an attribute for a key. If more than one key is provided, the lookup
continues recursively, now on the returned attributes, with the condition
that these implement the above interface. For example if the payload IR is
returns an attribute for a key. Each key should be either a (quoted) string
or a type. If more than one key is provided, the lookup continues
recursively, now on the returned attributes, with the condition that these
implement the above interface. For example if the payload IR is

```
module attributes {#dlti.map = #dlti.map<#dlti.dl_entry<"A",
Expand All @@ -52,7 +53,7 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
}];

let arguments = (ins TransformHandleTypeInterface:$target,
StrArrayAttr:$keys);
ArrayAttr:$keys);
let results = (outs TransformParamTypeInterface:$associated_attr);
let assemblyFormat =
"$keys `at` $target attr-dict `:` functional-type(operands, results)";
Expand Down
32 changes: 28 additions & 4 deletions mlir/lib/Dialect/DLTI/DLTI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,16 @@ getClosestQueryable(Operation *op) {
return std::pair(queryable, op);
}

FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
bool emitError) {
FailureOr<Attribute>
dlti::query(Operation *op, ArrayRef<DataLayoutEntryKey> keys, bool emitError) {
if (keys.empty()) {
if (emitError) {
auto diag = op->emitError() << "target op of failed DLTI query";
diag.attachNote(op->getLoc()) << "no keys provided to attempt query with";
}
return failure();
}

auto [queryable, queryOp] = getClosestQueryable(op);
Operation *reportOp = (queryOp ? queryOp : op);

Expand All @@ -438,6 +446,15 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
return failure();
}

auto keyToStr = [](DataLayoutEntryKey key) -> std::string {
std::string buf;
llvm::TypeSwitch<DataLayoutEntryKey>(key)
.Case<StringAttr, Type>( // The only two kinds of key we know of.
[&](auto key) { llvm::raw_string_ostream(buf) << key; })
.Default([](auto) { llvm_unreachable("unexpected entry key kind"); });
return buf;
};

Attribute currentAttr = queryable;
for (auto &&[idx, key] : llvm::enumerate(keys)) {
if (auto map = llvm::dyn_cast<DLTIQueryInterface>(currentAttr)) {
Expand All @@ -446,17 +463,24 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
if (emitError) {
auto diag = op->emitError() << "target op of failed DLTI query";
diag.attachNote(reportOp->getLoc())
<< "key " << key << " has no DLTI-mapping per attr: " << map;
<< "key " << keyToStr(key)
<< " has no DLTI-mapping per attr: " << map;
}
return failure();
}
currentAttr = *maybeAttr;
} else {
if (emitError) {
std::string commaSeparatedKeys;
llvm::interleave(
keys.take_front(idx), // All prior keys.
[&](auto key) { commaSeparatedKeys += keyToStr(key); },
[&]() { commaSeparatedKeys += ","; });

auto diag = op->emitError() << "target op of failed DLTI query";
diag.attachNote(reportOp->getLoc())
<< "got non-DLTI-queryable attribute upon looking up keys ["
<< keys.take_front(idx) << "] at op";
<< commaSeparatedKeys << "] at op";
}
return failure();
}
Expand Down
11 changes: 10 additions & 1 deletion mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,16 @@ void transform::QueryOp::getEffects(
DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results, TransformState &state) {
auto keys = SmallVector<StringAttr>(getKeys().getAsRange<StringAttr>());
SmallVector<DataLayoutEntryKey> keys;
for (Attribute key : getKeys()) {
if (auto strKey = dyn_cast<StringAttr>(key))
keys.push_back(strKey);
else if (auto typeKey = dyn_cast<TypeAttr>(key))
keys.push_back(typeKey.getValue());
else
return emitDefiniteFailure("'transform.dlti.query' keys of wrong type: "
"only StringAttr and TypeAttr are allowed");
}

FailureOr<Attribute> result = dlti::query(target, keys, /*emitError=*/true);

Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/DLTI/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@

// -----

// expected-error@below {{repeated layout entry key: 'i32'}}
"test.unknown_op"() { test.unknown_attr = #dlti.map<
#dlti.dl_entry<i32, 42>,
#dlti.dl_entry<i32, 42>
>} : () -> ()

// -----

// expected-error@below {{repeated layout entry key: 'i32'}}
"test.unknown_op"() { test.unknown_attr = #dlti.dl_spec<
#dlti.dl_entry<i32, 42>,
Expand Down
120 changes: 120 additions & 0 deletions mlir/test/Dialect/DLTI/query.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,60 @@ module attributes {transform.with_named_sequence} {

// -----

// expected-remark @below {{i32 present in set : unit}}
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, unit>>} {
func.func private @f()
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
%param = transform.dlti.query [i32] at %module : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %param, "i32 present in set :" at %module : !transform.any_param, !transform.any_op
transform.yield
}
}

// -----

// expected-remark @below {{associated attr 32 : i32}}
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, #dlti.map<#dlti.dl_entry<"width_in_bits", 32 : i32>>>>} {
func.func private @f()
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
%param = transform.dlti.query [i32,"width_in_bits"] at %module : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
transform.yield
}
}

// -----

// expected-remark @below {{width in bits of i32 = 32 : i64}}
// expected-remark @below {{width in bits of f64 = 64 : i64}}
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>, #dlti.dl_entry<f64, 64>>>>} {
func.func private @f()
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
%i32bits = transform.dlti.query ["width_in_bits",i32] at %module : (!transform.any_op) -> !transform.any_param
%f64bits = transform.dlti.query ["width_in_bits",f64] at %module : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %i32bits, "width in bits of i32 =" at %module : !transform.any_param, !transform.any_op
transform.debug.emit_param_as_remark %f64bits, "width in bits of f64 =" at %module : !transform.any_param, !transform.any_op
transform.yield
}
}

// -----

// expected-remark @below {{associated attr 42 : i32}}
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
func.func private @f()
Expand Down Expand Up @@ -336,6 +390,23 @@ module attributes {transform.with_named_sequence} {

// -----

// expected-note @below {{got non-DLTI-queryable attribute upon looking up keys [i32]}}
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<i32, 32 : i32>>} {
// expected-error @below {{target op of failed DLTI query}}
func.func private @f()
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
// expected-error @below {{'transform.dlti.query' op failed to apply}}
%param = transform.dlti.query [i32,"width_in_bits"] at %func : (!transform.any_op) -> !transform.any_param
transform.yield
}
}

// -----

module {
// expected-error @below {{target op of failed DLTI query}}
// expected-note @below {{no DLTI-queryable attrs on target op or any of its ancestors}}
Expand All @@ -353,6 +424,55 @@ module attributes {transform.with_named_sequence} {

// -----

// expected-note @below {{key i64 has no DLTI-mapping per attr: #dlti.map<#dlti.dl_entry<i32, 32 : i64>>}}
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>>>>} {
// expected-error @below {{target op of failed DLTI query}}
func.func private @f()
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
// expected-error @below {{'transform.dlti.query' op failed to apply}}
%param = transform.dlti.query ["width_in_bits",i64] at %func : (!transform.any_op) -> !transform.any_param
transform.yield
}
}

// -----

module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
func.func private @f()
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
// expected-error @below {{'transform.dlti.query' keys of wrong type: only StringAttr and TypeAttr are allowed}}
%param = transform.dlti.query [1] at %funcs : (!transform.any_op) -> !transform.param<i64>
transform.yield
}
}

// -----

module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"test.id", 42 : i32>>} {
// expected-error @below {{target op of failed DLTI query}}
// expected-note @below {{no keys provided to attempt query with}}
func.func private @f()
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
// expected-error @below {{'transform.dlti.query' op failed to apply}}
%param = transform.dlti.query [] at %func : (!transform.any_op) -> !transform.any_param
transform.yield
}
}

// -----

module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
func.func private @f()
}
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/Dialect/DLTI/valid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,18 @@ module attributes {
"GPU": #dlti.target_device_spec<
#dlti.dl_entry<"L1_cache_size_in_bytes", "128">>
>} {}


// -----

// CHECK: "test.op_with_dlti_map"() ({
// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42 : i64>>}
"test.op_with_dlti_map"() ({
}) { dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42>> } : () -> ()

// -----

// CHECK: "test.op_with_dlti_map"() ({
// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<i32, 42 : i64>>}
"test.op_with_dlti_map"() ({
}) { dlti.map = #dlti.map<#dlti.dl_entry<i32, 42>> } : () -> ()
Loading