Skip to content

Commit 05a09e6

Browse files
authored
[MLIR][Affine] Extend/generalize MDG to properly add edges between non-affine ops (#125451)
Drop arbitrary checks and hacks from affine fusion MDG construction and handle all ops using memory read/write effects. This has been a long pending change and it now makes affine fusion more powerful in the presence of non-affine ops and does not limit fusion in parts of the block where it is feasible simply because of non-affine ops elsewhere or intervening non-affine users. Populate memref read and write ops in non-affine region holding ops and non-affine ops at the top level of the Block properly; add the appropriate edges to MDG. Use memory read-write effects and drop assumptions and special handling of ops due to historic reasons. Update MDG to drop unnecessary "unhandled region" hack. This hack is no longer needed with the update to fully and properly construct the MDG. MDG edges now capture dependences between nodes completely. Drop non-affine users check. With the MDG generalization to properly include edges between non-affine nodes/operations, the non-affine users on path check in fusion is no longer needed. Add more test cases to exercise MDG generalization. Drop unnecessary failure when encountering side-effect-free affine.if ops. Improve documentation on MDG.
1 parent 048f533 commit 05a09e6

File tree

6 files changed

+417
-142
lines changed

6 files changed

+417
-142
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/Utils.h

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,28 @@ struct MemRefAccess;
3737
// was encountered in the loop nest.
3838
struct LoopNestStateCollector {
3939
SmallVector<AffineForOp, 4> forOps;
40+
// Affine loads.
4041
SmallVector<Operation *, 4> loadOpInsts;
42+
// Affine stores.
4143
SmallVector<Operation *, 4> storeOpInsts;
42-
bool hasNonAffineRegionOp = false;
44+
// Non-affine loads.
45+
SmallVector<Operation *, 4> memrefLoads;
46+
// Non-affine stores.
47+
SmallVector<Operation *, 4> memrefStores;
48+
// Free operations.
49+
SmallVector<Operation *, 4> memrefFrees;
4350

4451
// Collects load and store operations, and whether or not a region holding op
4552
// other than ForOp and IfOp was encountered in the loop nest.
4653
void collect(Operation *opToWalk);
4754
};
4855

4956
// MemRefDependenceGraph is a graph data structure where graph nodes are
50-
// top-level operations in a `Block` which contain load/store ops, and edges
51-
// are memref dependences between the nodes.
52-
// TODO: Add a more flexible dependence graph representation.
57+
// top-level operations in a `Block` and edges are memref dependences or SSA
58+
// dependences (on memrefs) between the nodes. Nodes are created for all
59+
// top-level operations except in certain cases (see `init` method). Edges are
60+
// created between nodes with a dependence (see `Edge` documentation). Edges
61+
// aren't created from/to nodes that have no memory effects.
5362
struct MemRefDependenceGraph {
5463
public:
5564
// Node represents a node in the graph. A Node is either an entire loop nest
@@ -60,10 +69,18 @@ struct MemRefDependenceGraph {
6069
unsigned id;
6170
// The top-level statement which is (or contains) a load/store.
6271
Operation *op;
63-
// List of load operations.
72+
// List of affine loads.
6473
SmallVector<Operation *, 4> loads;
65-
// List of store op insts.
74+
// List of non-affine loads.
75+
SmallVector<Operation *, 4> memrefLoads;
76+
// List of affine store ops.
6677
SmallVector<Operation *, 4> stores;
78+
// List of non-affine stores.
79+
SmallVector<Operation *, 4> memrefStores;
80+
// List of free operations.
81+
SmallVector<Operation *, 4> memrefFrees;
82+
// Set of private memrefs used in this node.
83+
DenseSet<Value> privateMemrefs;
6784

6885
Node(unsigned id, Operation *op) : id(id), op(op) {}
6986

@@ -73,6 +90,13 @@ struct MemRefDependenceGraph {
7390
// Returns the store op count for 'memref'.
7491
unsigned getStoreOpCount(Value memref) const;
7592

93+
/// Returns true if there exists an operation with a write memory effect to
94+
/// `memref` in this node.
95+
unsigned hasStore(Value memref) const;
96+
97+
// Returns true if the node has a free op on `memref`.
98+
unsigned hasFree(Value memref) const;
99+
76100
// Returns all store ops in 'storeOps' which access 'memref'.
77101
void getStoreOpsForMemref(Value memref,
78102
SmallVectorImpl<Operation *> *storeOps) const;
@@ -86,7 +110,16 @@ struct MemRefDependenceGraph {
86110
void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) const;
87111
};
88112

89-
// Edge represents a data dependence between nodes in the graph.
113+
// Edge represents a data dependence between nodes in the graph. It can either
114+
// be a memory dependence or an SSA dependence. In the former case, it
115+
// corresponds to a pair of memory accesses to the same memref or aliasing
116+
// memrefs where at least one of them has a write or free memory effect. The
117+
// memory accesses need not be affine load/store operations. Operations are
118+
// checked for read/write effects and edges may be added conservatively. Edges
119+
// are not created to/from nodes that have no memory effect. An exception to
120+
// this are SSA dependences between operations that define memrefs (like
121+
// alloc's, view-like ops) and their memory-effecting users that are enclosed
122+
// in loops.
90123
struct Edge {
91124
// The id of the node at the other end of the edge.
92125
// If this edge is stored in Edge = Node.inEdges[i], then
@@ -182,9 +215,12 @@ struct MemRefDependenceGraph {
182215
// of sibling node 'sibId' into node 'dstId'.
183216
void updateEdges(unsigned sibId, unsigned dstId);
184217

185-
// Adds ops in 'loads' and 'stores' to node at 'id'.
186-
void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
187-
const SmallVectorImpl<Operation *> &stores);
218+
// Adds the specified ops to lists of node at 'id'.
219+
void addToNode(unsigned id, ArrayRef<Operation *> loads,
220+
ArrayRef<Operation *> stores,
221+
ArrayRef<Operation *> memrefLoads,
222+
ArrayRef<Operation *> memrefStores,
223+
ArrayRef<Operation *> memrefFrees);
188224

189225
void clearNodeLoadAndStores(unsigned id);
190226

0 commit comments

Comments
 (0)