Commit 66fed33
authored
[mlir][vector] Update
Updates `castAwayContractionLeadingOneDim` to check for leading unit
dimensions before inserting `vector.transpose` ops.
Currently `castAwayContractionLeadingOneDim` removes all leading unit
dims based on the accumulator and transpose any subsequent operands to
match the accumulator indexing. This does not take into account if the
transpose is strictly necessary, for instance when given this
vector-matrix contract:
```mlir
%result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
```
Passing this through `castAwayContractionLeadingOneDim` pattern produces
the following:
```mlir
%0 = vector.transpose %arg0, [1, 0, 2] : vector<1x1x8xi32> to vector<1x1x8xi32>
%1 = vector.extract %0[0] : vector<1x8xi32> from vector<1x1x8xi32>
%2 = vector.extract %arg2[0] : vector<8xi32> from vector<1x8xi32>
%3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %arg1, %2 : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
%4 = vector.broadcast %3 : vector<8xi32> to vector<1x8xi32>
```
The `vector.transpose` introduced does not affect the underlying data
layout (effectively a no op), but it cannot be folded automatically.
This change avoids inserting transposes when only leading unit
dimensions are involved.
Fixes #85691castAwayContractionLeadingOneDim to omit transposes solely on leading unit dims. (#85694)1 parent c511c90 commit 66fed33
File tree
2 files changed
+30
-3
lines changed- mlir
- lib/Dialect/Vector/Transforms
- test/Dialect/Vector
2 files changed
+30
-3
lines changedLines changed: 19 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
398 | 398 | | |
399 | 399 | | |
400 | 400 | | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
401 | 416 | | |
402 | 417 | | |
403 | 418 | | |
404 | 419 | | |
405 | 420 | | |
406 | | - | |
407 | | - | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
408 | 425 | | |
409 | 426 | | |
410 | 427 | | |
| |||
Lines changed: 11 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
238 | 238 | | |
239 | 239 | | |
240 | 240 | | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
241 | 252 | | |
242 | 253 | | |
243 | 254 | | |
| |||
663 | 674 | | |
664 | 675 | | |
665 | 676 | | |
666 | | - | |
| |||
0 commit comments