Skip to content

Commit f2fcf95

Browse files
committed
[GPU] Add 2d load/store ops validation
1 parent d3f403d commit f2fcf95

File tree

2 files changed

+426
-3
lines changed

2 files changed

+426
-3
lines changed

lib/gc/Dialect/LLVMIR/IR/XeVMDialect.cpp

Lines changed: 255 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
1111
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1213
#include "mlir/IR/DialectImplementation.h"
1314
#include "llvm/ADT/TypeSwitch.h"
1415

@@ -18,9 +19,260 @@ using namespace xevm;
1819
#include "gc/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
1920
#include "gc/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
2021

21-
// TODO
22-
// Placeholder verifier: performs no checks and reports success for every op.
LogicalResult BlockLoad2dOp::verify() { return success(); }
23-
// Placeholder verifier: performs no checks and reports success for every op.
LogicalResult BlockStore2dOp::verify() { return success(); }
22+
namespace {
23+
// Divisor applied to the total tile size (in bits) when computing the expected
// size of a 2D block load result.
// NOTE(review): presumably the number of work-items per subgroup -- confirm
// against the target hardware documentation.
constexpr uint32_t subgroupSize = 16;
24+
25+
/// Checks the constraints shared by the 2D block load and store ops.
///
/// An op error is emitted and failure is returned when any of the following
/// holds (checked in this order):
///  - base width and base pitch are both constants and pitch < width;
///  - elem_size_in_bits is not 8, 16 or 32;
///  - tile_height is not one of 1, 2, 4, 8, 16, 32;
///  - v_blocks is not one of 1, 2, 4, 8.
/// Otherwise success is returned.
template <typename Op> LogicalResult verifyMatrixInput(Op op) {
  static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp>::value,
                "Unexpected template parameter");

  // The pitch/width relation is only checkable when both fold to constants.
  const std::optional<int64_t> baseWidth =
      getConstantIntValue(op.getBaseWidth());
  const std::optional<int64_t> basePitch =
      getConstantIntValue(op.getBasePitch());
  if (basePitch.has_value() && baseWidth.has_value() &&
      *basePitch < *baseWidth)
    return op->emitOpError(
        "4th operand (base pitch) should be >= 2nd operand (base width)");

  switch (op.getElemSizeInBits()) {
  case 8:
  case 16:
  case 32:
    break;
  default:
    return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
  }

  switch (op.getTileHeight()) {
  case 1:
  case 2:
  case 4:
  case 8:
  case 16:
  case 32:
    break;
  default:
    return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
  }

  switch (op.getVBlocks()) {
  case 1:
  case 2:
  case 4:
  case 8:
    break;
  default:
    return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
  }

  return success();
}
50+
51+
/// Validates the shape attributes of a 2D block load.
///
/// The checks run in this order:
///  1. the result vector's total bit size must equal
///     elem_size_in_bits * tile_height * tile_width * v_blocks / subgroupSize;
///  2. transpose and vnni_transform must not both be set;
///  3. the tile_height / tile_width / v_blocks limits of the selected mode
///     (transpose, vnni_transform, or neither) for the given element size.
/// The first violated check emits an op error and yields failure.
LogicalResult verify2DBlockLoadHWRestriction(BlockLoad2dOp op) {
  VectorType resultType = op.getRes().getType();
  unsigned resultElemBits = resultType.getElementType().getIntOrFloatBitWidth();
  unsigned resultBits = resultType.getNumElements() * resultElemBits;
  unsigned expectedBits = op.getElemSizeInBits() * op.getTileHeight() *
                          op.getTileWidth() * op.getVBlocks() / subgroupSize;
  if (resultBits != expectedBits)
    return op.emitOpError() << "result size of " << resultBits
                            << " bits does not match the expected size of "
                            << expectedBits << " bits";

  const bool transposed = op.getTranspose();
  const bool vnni = op.getVnniTransform();
  if (transposed && vnni)
    return op.emitOpError(
        "transpose and vnni_transform are mutually exclusive");

  const uint32_t height = op.getTileHeight();
  const uint32_t width = op.getTileWidth();
  const uint32_t blocks = op.getVBlocks();

  // True when `value` lies outside the inclusive range [lo, hi].
  const auto outside = [](uint32_t value, uint32_t lo, uint32_t hi) {
    return value < lo || value > hi;
  };
  const bool blocksIs1or2or4 = blocks == 1 || blocks == 2 || blocks == 4;

  // --- Transposed load -----------------------------------------------------
  if (transposed) {
    assert(!op.getVnniTransform() &&
           "Expecting vnni_transform should be false");

    if (blocks != 1)
      return op.emitOpError("expecting v_blocks to be 1");

    switch (op.getElemSizeInBits()) {
    case 32:
      if (outside(height, 1, 32))
        return op.emitOpError("expecting tile_height to be between 1 and 32");
      if (outside(width, 1, 8))
        return op.emitOpError("expecting tile_width to be between 1 and 8");
      break;
    case 64:
      if (height != 8)
        return op.emitOpError(
            "expecting tile_height to be 8 for 64 bit elements");
      if (!(width == 1 || width == 2 || width == 4))
        return op.emitOpError("expecting tile_width to be 1, 2, or 4");
      break;
    default:
      return op.emitOpError("transpose is only supported for 32 and 64 bit "
                            "elements");
    }

    return success();
  }

  // --- VNNI-transformed load -----------------------------------------------
  if (vnni) {
    assert(op.getVnniTransform() && !op.getTranspose() &&
           "Expecting vnni_transform should be true and transpose should be "
           "false");

    if (!blocksIs1or2or4)
      return op.emitOpError("expecting v_blocks to be 1, 2, or 4");

    switch (op.getElemSizeInBits()) {
    case 8:
      if (outside(height, 4, 32))
        return op.emitOpError("expecting tile_height to be between 4 and 32");
      if (outside(width, 4, 16))
        return op.emitOpError("expecting tile_width to be between 4 and 16");
      break;
    case 16:
      if (outside(height, 2, 32))
        return op.emitOpError("expecting tile_height to be between 2 and 32");
      if (outside(width, 2, 16))
        return op.emitOpError("expecting tile_width to be between 2 and 16");
      if (width * blocks > 32)
        return op.emitOpError(
            "tile_width * v_blocks should be less than or equal "
            "to 32 for 16 bit elements");
      break;
    default:
      return op.emitOpError("vnni_transform is only supported for 8 and 16 bit "
                            "elements");
    }

    return success();
  }

  // --- Plain load (neither transpose nor vnni_transform) -------------------
  if (outside(height, 1, 32))
    return op.emitOpError("expecting tile_height to be between 1 and 32");

  switch (op.getElemSizeInBits()) {
  case 8:
    if (outside(width, 4, 64))
      return op.emitOpError("expecting tile_width to be between 4 and 64");
    if (!blocksIs1or2or4)
      return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
    if (width * blocks > 64)
      return op.emitOpError(
          "tile_width * v_blocks should be less than or equal "
          "to 64 for 8 bit elements");
    break;
  case 16:
    if (outside(width, 2, 32))
      return op.emitOpError("expecting tile_width to be between 2 and 32");
    if (!blocksIs1or2or4)
      return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
    if (width * blocks > 32)
      return op.emitOpError(
          "tile_width * v_blocks should be less than or equal "
          "to 32 for 16 bit elements");
    break;
  case 32:
    if (outside(width, 1, 16))
      return op.emitOpError("expecting tile_width to be between 1 and 16");
    if (!(blocks == 1 || blocks == 2))
      return op.emitOpError("expecting v_blocks to be 1 or 2");
    if (width * blocks > 16)
      return op.emitOpError(
          "tile_width * v_blocks should be less than or equal "
          "to 16 for 32 bit elements");
    break;
  case 64:
    if (outside(width, 1, 8))
      return op.emitOpError("expecting tile_width to be between 1 and 8");
    if (blocks != 1)
      return op.emitOpError("expecting v_blocks to be 1");
    break;
  default:
    return op.emitOpError(
        "expecting elem_size_in_bits to be 8, 16, 32, or 64");
  }

  return success();
}
184+
185+
/// Validates the shape attributes of a 2D block store.
///
/// Checked in order: tile_height must be in [1, 8]; tile_width must be within
/// the range allowed for elem_size_in_bits (8, 16, 32 or 64); v_blocks must
/// be 1. The first violated check emits an op error and yields failure.
static LogicalResult verify2DBlockStoreHWRestriction(BlockStore2dOp op) {
  const uint32_t height = op.getTileHeight();
  if (!(height >= 1 && height <= 8))
    return op.emitOpError("expecting tile_height to be between 1 and 8");

  const uint32_t width = op.getTileWidth();
  const auto elemBits = op.getElemSizeInBits();
  if (elemBits == 8) {
    if (width < 4 || width > 64)
      return op.emitOpError("expecting tile_width to be between 4 and 64");
  } else if (elemBits == 16) {
    if (width < 2 || width > 32)
      return op.emitOpError("expecting tile_width to be between 2 and 32");
  } else if (elemBits == 32) {
    if (width < 1 || width > 16)
      return op.emitOpError("expecting tile_width to be between 1 and 16");
  } else if (elemBits == 64) {
    if (width < 1 || width > 8)
      return op.emitOpError("expecting tile_width to be between 1 and 8");
  } else {
    return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
  }

  if (op.getVBlocks() != 1)
    return op.emitOpError("expecting v_blocks to be 1");

  return success();
}
217+
218+
} // namespace
219+
220+
/// Verifier for BlockLoad2dOp.
///
/// Runs the load shape restrictions first, then the checks shared with the
/// store op. Afterwards it requires a 32-bit result element type when
/// elem_size_in_bits is 32 or vnni_transform is set, and a tile_width of 16
/// when vnni_transform is set.
LogicalResult BlockLoad2dOp::verify() {
  // Short-circuit keeps the original order: HW restrictions, then shared
  // input checks.
  if (verify2DBlockLoadHWRestriction(*this).failed() ||
      verifyMatrixInput(*this).failed())
    return failure();

  const bool vnni = getVnniTransform();

  VectorType resultType = getRes().getType();
  const unsigned resultElemBits =
      resultType.getElementType().getIntOrFloatBitWidth();
  if ((vnni || getElemSizeInBits() == 32) && resultElemBits != 32)
    return emitOpError() << "expecting result element type to be 32 bits";

  if (vnni && getTileWidth() != 16)
    return emitOpError(
        "tile_width when vnni_transform is true should be equal "
        "to subgroup size (16 elements)");

  return success();
}
245+
246+
/// Verifier for BlockStore2dOp.
///
/// Runs the store shape restrictions first, then the checks shared with the
/// load op. Afterwards it pins tile_width per element size: 16 or 32 for
/// 8-bit elements, 16 for 16-bit and 32-bit elements.
LogicalResult BlockStore2dOp::verify() {
  if (verify2DBlockStoreHWRestriction(*this).failed() ||
      verifyMatrixInput(*this).failed())
    return failure();

  const uint32_t width = getTileWidth();
  const auto elemBits = getElemSizeInBits();
  if (elemBits == 8) {
    if (!(width == 16 || width == 32))
      return emitOpError("tile_width for 8 bit elements should be equal to "
                         "16 or 32");
  } else if (elemBits == 16) {
    if (width != 16)
      return emitOpError("tile_width for 16 bit elements should be equal "
                         "to 16");
  } else if (elemBits == 32) {
    if (width != 16)
      return emitOpError("tile_width for 32 bit elements should be equal "
                         "to 16");
  } else {
    // verifyMatrixInput has already rejected every other element size.
    llvm_unreachable("unexpected element size");
  }

  return success();
}
24276

25277
void XeVMDialect::initialize() {
26278
// NOLINTBEGIN

0 commit comments

Comments
 (0)