Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions frontends/PyRTG/test/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ class Singleton(Config):


# MLIR-LABEL: rtg.target @Tgt0 : !rtg.dict<entry0: !rtg.set<index>>
# MLIR-NEXT: [[C0:%.+]] = index.constant 0
# MLIR-NEXT: [[C1:%.+]] = index.constant 1
# MLIR-NEXT: [[SET:%.+]] = rtg.set_create [[C0:%.+]], [[C1:%.+]] : index
# MLIR-NEXT: [[SET:%.+]] = rtg.constant #rtg.set<0 : index, 1 : index> : !rtg.set<index>
# MLIR-NEXT: rtg.yield [[SET]] : !rtg.set<index>
# MLIR-NEXT: }

Expand Down Expand Up @@ -142,6 +140,7 @@ def test1_args(config):
# MLIR-DAG: [[STR:%.+]] = rtg.constant "l1" : !rtg.string
# MLIR-DAG: [[LBL5:%.+]] = rtg.constant #rtg.isa.label<"L_5">
# MLIR-DAG: [[LBL3:%.+]] = rtg.constant #rtg.isa.label<"L_3">
# MLIR-DAG: [[EMPTY_SET:%.+]] = rtg.constant #rtg.set<> : !rtg.set<!rtg.isa.label>
# MLIR-NEXT: [[L1:%.+]] = rtg.label_unique_decl [[STR]]
# MLIR-NEXT: [[L2:%.+]] = rtg.label_unique_decl [[STR]]
# MLIR-NEXT: rtg.label global [[L0]]
Expand All @@ -150,7 +149,6 @@ def test1_args(config):

# MLIR-NEXT: [[SET0:%.+]] = rtg.set_create [[L0]], [[L1]] : !rtg.isa.label
# MLIR-NEXT: [[SET1:%.+]] = rtg.set_create [[L2]] : !rtg.isa.label
# MLIR-NEXT: [[EMPTY_SET:%.+]] = rtg.set_create : !rtg.isa.label
# MLIR-NEXT: [[SET2_1:%.+]] = rtg.set_union [[SET0]], [[SET1]] : !rtg.set<!rtg.isa.label>
# MLIR-NEXT: [[SET2:%.+]] = rtg.set_union [[SET2_1]], [[EMPTY_SET]] : !rtg.set<!rtg.isa.label>
# MLIR-NEXT: [[RL0:%.+]] = rtg.set_select_random [[SET2]] : !rtg.set<!rtg.isa.label>
Expand All @@ -164,7 +162,7 @@ def test1_args(config):

# MLIR-NEXT: [[BAG0:%.+]] = rtg.bag_create (%idx2 x [[L0:%.+]], %idx1 x [[L1:%.+]]) : !rtg.isa.label
# MLIR-NEXT: [[BAG1:%.+]] = rtg.bag_create (%idx1 x [[L2:%.+]]) : !rtg.isa.label
# MLIR-NEXT: [[EMPTY_BAG:%.+]] = rtg.bag_create : !rtg.isa.label
# MLIR-NEXT: [[EMPTY_BAG:%.+]] = rtg.bag_create : !rtg.isa.label
# MLIR-NEXT: [[BAG2_1:%.+]] = rtg.bag_union [[BAG0]], [[BAG1]] : !rtg.bag<!rtg.isa.label>
# MLIR-NEXT: [[BAG2:%.+]] = rtg.bag_union [[BAG2_1]], [[EMPTY_BAG]] : !rtg.bag<!rtg.isa.label>
# MLIR-NEXT: [[RL2:%.+]] = rtg.bag_select_random [[BAG2]] : !rtg.bag<!rtg.isa.label>
Expand Down
9 changes: 9 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def SetCreateOp : RTGOp<"set_create", [Pure, SameTypeOperands]> {

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let hasFolder = 1;
}

def SetSelectRandomOp : RTGOp<"set_select_random", [
Expand Down Expand Up @@ -393,6 +394,8 @@ def SetDifferenceOp : RTGOp<"set_difference", [
let assemblyFormat = [{
$original `,` $diff `:` qualified(type($output)) attr-dict
}];

let hasFolder = 1;
}

def SetUnionOp : RTGOp<"set_union", [
Expand All @@ -410,6 +413,8 @@ def SetUnionOp : RTGOp<"set_union", [
let assemblyFormat = [{
$sets `:` qualified(type($result)) attr-dict
}];

let hasFolder = 1;
}

def SetSizeOp : RTGOp<"set_size", [Pure]> {
Expand All @@ -421,6 +426,8 @@ def SetSizeOp : RTGOp<"set_size", [Pure]> {
let assemblyFormat = [{
$set `:` qualified(type($set)) attr-dict
}];

let hasFolder = 1;
}

def SetCartesianProductOp : RTGOp<"set_cartesian_product", [
Expand Down Expand Up @@ -452,6 +459,8 @@ def SetCartesianProductOp : RTGOp<"set_cartesian_product", [
let results = (outs SetType:$result);

let assemblyFormat = "$inputs `:` qualified(type($inputs)) attr-dict";

let hasFolder = 1;
}

def SetConvertToBagOp : RTGOp<"set_convert_to_bag", [
Expand Down
113 changes: 113 additions & 0 deletions lib/Dialect/RTG/IR/RTGOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,119 @@ OpFoldResult StringToLabelOp::fold(FoldAdaptor adaptor) {
return {};
}

//===----------------------------------------------------------------------===//
// SetCreateOp
//===----------------------------------------------------------------------===//

OpFoldResult SetCreateOp::fold(FoldAdaptor adaptor) {
DenseSet<TypedAttr> elements;
for (auto attr : adaptor.getElements()) {
auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
if (!typedAttr)
return {};

elements.insert(typedAttr);
}

return SetAttr::get(getType(), &elements);
}

//===----------------------------------------------------------------------===//
// SetSizeOp
//===----------------------------------------------------------------------===//

OpFoldResult SetSizeOp::fold(FoldAdaptor adaptor) {
auto setAttr = dyn_cast_or_null<SetAttr>(adaptor.getSet());
if (!setAttr)
return {};

return IntegerAttr::get(IndexType::get(getContext()),
setAttr.getElements()->size());
}

//===----------------------------------------------------------------------===//
// SetUnionOp
//===----------------------------------------------------------------------===//

OpFoldResult SetUnionOp::fold(FoldAdaptor adaptor) {
// Fast track to make sure we're not computing the union of all sets but the
// last of the variadic operands is NULL.
if (llvm::any_of(adaptor.getSets(), [&](Attribute attr) { return !attr; }))
return {};

DenseSet<TypedAttr> res;
for (auto set : adaptor.getSets()) {
auto setAttr = dyn_cast<SetAttr>(set);
if (!set)
return {};

for (auto element : *setAttr.getElements())
res.insert(element);
}

return SetAttr::get(getType(), &res);
}

//===----------------------------------------------------------------------===//
// SetDifferenceOp
//===----------------------------------------------------------------------===//

OpFoldResult SetDifferenceOp::fold(FoldAdaptor adaptor) {
auto original = dyn_cast_or_null<SetAttr>(adaptor.getOriginal());
auto diff = dyn_cast_or_null<SetAttr>(adaptor.getDiff());
if (!original || !diff)
return {};

DenseSet<TypedAttr> res(*original.getElements());
for (auto element : *diff.getElements())
res.erase(element);

return SetAttr::get(getType(), &res);
}

//===----------------------------------------------------------------------===//
// SetCartesianProductOp
//===----------------------------------------------------------------------===//

OpFoldResult SetCartesianProductOp::fold(FoldAdaptor adaptor) {
// Fast track to make sure we're not computing the product of all sets but the
// last of the variadic operands is NULL.
if (llvm::any_of(adaptor.getInputs(), [&](Attribute attr) { return !attr; }))
return {};

DenseSet<TypedAttr> res;
SmallVector<SmallVector<TypedAttr>> tuples;
tuples.push_back({});

for (auto input : adaptor.getInputs()) {
auto setAttr = dyn_cast<SetAttr>(input);
if (!setAttr)
return {};

DenseSet<TypedAttr> set(*setAttr.getElements());
if (set.empty()) {
DenseSet<TypedAttr> empty;
return SetAttr::get(getType(), &empty);
}

for (unsigned i = 0, e = tuples.size(); i < e; ++i) {
for (auto [k, el] : llvm::enumerate(set)) {
if (k == set.size() - 1) {
tuples[i].push_back(el);
continue;
}
tuples.push_back(tuples[i]);
tuples.back().push_back(el);
}
}
}

for (auto &tup : tuples)
res.insert(TupleAttr::get(getContext(), tup));

return SetAttr::get(getType(), &res);
}

//===----------------------------------------------------------------------===//
// TableGen generated logic.
//===----------------------------------------------------------------------===//
Expand Down
59 changes: 59 additions & 0 deletions test/Dialect/RTG/IR/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
func.func @dummy(%arg0: !rtg.isa.label) -> () {return}
func.func @dummy1(%arg0: !rtg.string) -> () {return}
func.func @dummy2(%arg0: !rtg.array<index>) -> () {return}
func.func @dummy3(%arg0: !rtg.set<!rtg.tuple<index>>) -> () {return}
func.func @dummy4(%arg0: !rtg.set<index>) -> () {return}
func.func @dummy5(%arg0: !rtg.set<!rtg.tuple<index, i32, i64>>) -> () {return}
func.func @dummy6(%arg0: index) -> () {return}

// CHECK-LABEL: @interleaveSequences
rtg.test @interleaveSequences(seq0 = %seq0: !rtg.randomized_sequence) {
Expand Down Expand Up @@ -93,3 +97,58 @@ rtg.test @arrays() {
%4 = rtg.array_append %3, %idx1 : !rtg.array<index>
func.call @dummy2(%4) : (!rtg.array<index>) -> ()
}

// CHECK-LABEL: @sets
rtg.test @sets() {
%idx0 = index.constant 0
%idx1 = index.constant 1
%set0 = rtg.constant #rtg.set<1 : index, 0 : index> : !rtg.set<index>
%set1 = rtg.constant #rtg.set<1 : index, 2 : index> : !rtg.set<index>
%set2 = rtg.constant #rtg.set<2 : index, 3 : index> : !rtg.set<index>
%set3 = rtg.constant #rtg.set<> : !rtg.set<i64>
%set4 = rtg.constant #rtg.set<4 : i32, 5 : i32> : !rtg.set<i32>
%set5 = rtg.constant #rtg.set<6 : i64, 7 : i64> : !rtg.set<i64>

// CHECK: [[SET:%.+]] = rtg.constant #rtg.set<0 : index, 1 : index> : !rtg.set<index>
%0 = rtg.set_create %idx1, %idx0 : index

// CHECK: [[SIZE:%.+]] = rtg.constant 2 : index
%size = rtg.set_size %0 : !rtg.set<index>

// CHECK: [[UNION:%.+]] = rtg.constant #rtg.set<0 : index, 1 : index, 2 : index, 3 : index> : !rtg.set<index>
%union = rtg.set_union %set0, %set1, %set2 : !rtg.set<index>

// CHECK: [[DIFF:%.+]] = rtg.constant #rtg.set<0 : index> : !rtg.set<index>
%diff = rtg.set_difference %set0, %set1 : !rtg.set<index>

// CHECK: [[PROD:%.+]] = rtg.constant #rtg.set<
// CHECK-SAME: #rtg.tuple<0 : index, 4 : i32, 6 : i64> : !rtg.tuple<index, i32, i64>
// CHECK-SAME: #rtg.tuple<0 : index, 4 : i32, 7 : i64> : !rtg.tuple<index, i32, i64>
// CHECK-SAME: #rtg.tuple<0 : index, 5 : i32, 6 : i64> : !rtg.tuple<index, i32, i64>
// CHECK-SAME: #rtg.tuple<0 : index, 5 : i32, 7 : i64> : !rtg.tuple<index, i32, i64>
// CHECK-SAME: #rtg.tuple<1 : index, 4 : i32, 6 : i64> : !rtg.tuple<index, i32, i64>
// CHECK-SAME: #rtg.tuple<1 : index, 4 : i32, 7 : i64> : !rtg.tuple<index, i32, i64>
// CHECK-SAME: #rtg.tuple<1 : index, 5 : i32, 6 : i64> : !rtg.tuple<index, i32, i64>
// CHECK-SAME: #rtg.tuple<1 : index, 5 : i32, 7 : i64> : !rtg.tuple<index, i32, i64>>
// CHECK-SAME: !rtg.set<!rtg.tuple<index, i32, i64>>
%prod0 = rtg.set_cartesian_product %set0, %set4, %set5 : !rtg.set<index>, !rtg.set<i32>, !rtg.set<i64>
// CHECK: [[EMPTY:%.+]] = rtg.constant #rtg.set<> : !rtg.set<!rtg.tuple<index, i32, i64>>
%prod1 = rtg.set_cartesian_product %set0, %set4, %set3 : !rtg.set<index>, !rtg.set<i32>, !rtg.set<i64>
// CHECK: [[SET2:%.+]] = rtg.constant #rtg.set<#rtg.tuple<0 : index> : !rtg.tuple<index>, #rtg.tuple<1 : index> : !rtg.tuple<index>> : !rtg.set<!rtg.tuple<index>>
%prod2 = rtg.set_cartesian_product %set0 : !rtg.set<index>

// CHECK: func.call @dummy4([[SET:%.+]])
func.call @dummy4(%0) : (!rtg.set<index>) -> ()
// CHECK: func.call @dummy6([[SIZE:%.+]])
func.call @dummy6(%size) : (index) -> ()
// CHECK: func.call @dummy4([[UNION]])
func.call @dummy4(%union) : (!rtg.set<index>) -> ()
// CHECK: func.call @dummy4([[DIFF]])
func.call @dummy4(%diff) : (!rtg.set<index>) -> ()
// CHECK: func.call @dummy5([[PROD]])
func.call @dummy5(%prod0) : (!rtg.set<!rtg.tuple<index, i32, i64>>) -> ()
// CHECK: func.call @dummy5([[EMPTY]])
func.call @dummy5(%prod1) : (!rtg.set<!rtg.tuple<index, i32, i64>>) -> ()
// CHECK: func.call @dummy3([[SET2]]) : (!rtg.set<!rtg.tuple<index>>) -> ()
func.call @dummy3(%prod2) : (!rtg.set<!rtg.tuple<index>>) -> ()
}