Skip to content

Commit 7f4a3a9

Browse files
authored
[MLIR][XeGPU][TransformOps] Add convert_layout op (#167342)
Adds `transform.xegpu.convert_layout` transform op that inserts an `xegpu.convert_layout` op for a given `Value`.
1 parent ec4207b commit 7f4a3a9

File tree

5 files changed

+464
-6
lines changed

5 files changed

+464
-6
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,4 +244,90 @@ def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
244244
}];
245245
}
246246

247+
// Transform op that wraps a payload value in an `xegpu.convert_layout` op.
// The input and target layouts are each described by three mixed
// static/dynamic index lists (`sg_layout`, `sg_data`, optional `inst_data`).
def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
  AttrSizedOperandSegments,
  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
  TransformOpInterface
]> {

  let summary = "Convert xegpu.layout attribute for a value.";
  let description = [{
    Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute
    of a value. The input and target layouts are defined by the `*sg_layout`,
    `*sg_data` and optional `*inst_data` attributes. Returns a handle to the
    emitted `xegpu.convert_layout` op.

    The `target` handle must map to exactly one payload value, and that value
    must have at least one user. The new op is placed ahead of the value's use
    and every existing use of the value is rewired to the converted result.
  }];

  // Dynamic entries of each list are passed as operands; the static entries
  // live in the matching `static_*` dense array attribute.
  let arguments = (ins TransformValueHandleTypeInterface:$target,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$input_sg_layout,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$input_sg_data,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$input_inst_data,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$target_sg_layout,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$target_sg_data,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$target_inst_data,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_sg_layout,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_sg_data,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_inst_data,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_sg_layout,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_sg_data,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_inst_data
                   );

  // Operation handle pointing at the newly created `xegpu.convert_layout`.
  let results = (outs TransformHandleTypeInterface:$newConvertOp);
  let builders = [
    OpBuilder<(ins "Value":$target,
                   "ArrayRef<OpFoldResult>":$mixedInputSgLayout,
                   "ArrayRef<OpFoldResult>":$mixedInputSgData,
                   "ArrayRef<OpFoldResult>":$mixedInputInstData,
                   "ArrayRef<OpFoldResult>":$mixedTargetSgLayout,
                   "ArrayRef<OpFoldResult>":$mixedTargetSgData,
                   "ArrayRef<OpFoldResult>":$mixedTargetInstData
                   )>,
  ];

  let assemblyFormat = [{
    $target
    `input_sg_layout` `=` custom<DynamicIndexList>($input_sg_layout, $static_input_sg_layout)
    `input_sg_data` `=` custom<DynamicIndexList>($input_sg_data, $static_input_sg_data)
    (`input_inst_data` `=` custom<DynamicIndexList>($input_inst_data, $static_input_inst_data)^)?
    `target_sg_layout` `=` custom<DynamicIndexList>($target_sg_layout, $static_target_sg_layout)
    `target_sg_data` `=` custom<DynamicIndexList>($target_sg_data, $static_target_sg_data)
    (`target_inst_data` `=` custom<DynamicIndexList>($target_inst_data, $static_target_inst_data)^)?
    attr-dict `:` functional-type(operands, results)
  }];

  // All names below are fully qualified so the generated declaration does not
  // depend on a `using namespace mlir` being active at the include site.
  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure apply(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &transformResults,
        ::mlir::transform::TransformState &state);

    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInputSgLayout() {
      ::mlir::Builder b(getContext());
      return ::mlir::getMixedValues(getStaticInputSgLayout(), getInputSgLayout(), b);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInputSgData() {
      ::mlir::Builder b(getContext());
      return ::mlir::getMixedValues(getStaticInputSgData(), getInputSgData(), b);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInputInstData() {
      ::mlir::Builder b(getContext());
      return ::mlir::getMixedValues(getStaticInputInstData(), getInputInstData(), b);
    }

    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTargetSgLayout() {
      ::mlir::Builder b(getContext());
      return ::mlir::getMixedValues(getStaticTargetSgLayout(), getTargetSgLayout(), b);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTargetSgData() {
      ::mlir::Builder b(getContext());
      return ::mlir::getMixedValues(getStaticTargetSgData(), getTargetSgData(), b);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTargetInstData() {
      ::mlir::Builder b(getContext());
      return ::mlir::getMixedValues(getStaticTargetInstData(), getTargetInstData(), b);
    }
  }];
}
332+
247333
#endif // XEGPU_TRANSFORM_OPS

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,110 @@ void transform::InsertPrefetchOp::getEffects(
537537
modifiesPayload(effects);
538538
}
539539

540+
/// Builds a `transform.xegpu.convert_layout` op from mixed static/dynamic
/// layout parameters.
///
/// Each `mixed*` list is split into its dynamic values (passed as operands)
/// and its static integers (stored in the matching `static_*` attribute)
/// before delegating to the generated builder.
///
/// \param target value handle of the payload value whose layout is converted.
/// \param mixedInputSgLayout / mixedInputSgData / mixedInputInstData
///        description of the input layout.
/// \param mixedTargetSgLayout / mixedTargetSgData / mixedTargetInstData
///        description of the target layout.
void transform::ConvertLayoutOp::build(
    OpBuilder &builder, OperationState &ostate, Value target,
    ArrayRef<OpFoldResult> mixedInputSgLayout,
    ArrayRef<OpFoldResult> mixedInputSgData,
    ArrayRef<OpFoldResult> mixedInputInstData,
    ArrayRef<OpFoldResult> mixedTargetSgLayout,
    ArrayRef<OpFoldResult> mixedTargetSgData,
    ArrayRef<OpFoldResult> mixedTargetInstData) {
  SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
      staticInputInstData;
  SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
      dynamicInputInstData;
  dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout,
                             staticInputSgLayout);
  dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData,
                             staticInputSgData);
  dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData,
                             staticInputInstData);
  SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData,
      staticTargetInstData;
  SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
      dynamicTargetInstData;
  dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout,
                             staticTargetSgLayout);
  dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData,
                             staticTargetSgData);
  dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData,
                             staticTargetInstData);
  // The result is a handle to the created `xegpu.convert_layout` *operation*,
  // whereas `target` is a *value* handle. Reusing `target.getType()` as the
  // result type would therefore violate the result type constraint, so an
  // operation handle type is used instead.
  Type resultType = transform::AnyOpType::get(builder.getContext());
  build(builder, ostate, resultType,
        /*target=*/target,
        /*input_sg_layout=*/dynamicInputSgLayout,
        /*input_sg_data=*/dynamicInputSgData,
        /*input_inst_data=*/dynamicInputInstData,
        /*target_sg_layout=*/dynamicTargetSgLayout,
        /*target_sg_data=*/dynamicTargetSgData,
        /*target_inst_data=*/dynamicTargetInstData,
        /*static_input_sg_layout=*/staticInputSgLayout,
        /*static_input_sg_data=*/staticInputSgData,
        /*static_input_inst_data=*/staticInputInstData,
        /*static_target_sg_layout=*/staticTargetSgLayout,
        /*static_target_sg_data=*/staticTargetSgData,
        /*static_target_inst_data=*/staticTargetInstData);
}
583+
584+
/// Inserts an `xegpu.convert_layout` op for the single payload value mapped to
/// the target handle and redirects all existing uses of that value to the
/// converted result.
///
/// Fails definitely when the target handle does not map to exactly one value,
/// propagates any failure from constructing the input/target layout
/// attributes, and fails silenceably when the value has no users.
/// On success the result handle is mapped to the new `xegpu.convert_layout`.
DiagnosedSilenceableFailure
transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  auto targetValues = state.getPayloadValues(getTarget());
  if (!llvm::hasSingleElement(targetValues))
    return emitDefiniteFailure()
           << "requires exactly one target value handle (got "
           << llvm::range_size(targetValues) << ")";
  Value value = *targetValues.begin();

  // Construct layout attributes.
  xegpu::LayoutAttr inputLayoutAttr = nullptr;
  auto status = getLayoutAttrFromOperands(
      getContext(), state, (*this), getMixedInputSgLayout(),
      getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
  if (!status.succeeded())
    return status;

  xegpu::LayoutAttr targetLayoutAttr = nullptr;
  status = getLayoutAttrFromOperands(
      getContext(), state, (*this), getMixedTargetSgLayout(),
      getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
  if (!status.succeeded())
    return status;

  // The conversion is only meaningful if something consumes the value.
  if (value.use_empty())
    return emitSilenceableFailure(getLoc())
           << "Value has no users to insert layout conversion.";

  // Choose an insertion point that precedes every user. The use list is not
  // ordered by program position, so its first entry may come after other
  // users; inserting there and then replacing all uses would break dominance.
  // When all users share a block, insert before the earliest one; otherwise
  // fall back to inserting right after the definition of the value.
  Operation *earliestUser = nullptr;
  bool usersShareBlock = true;
  for (Operation *user : value.getUsers()) {
    if (!earliestUser) {
      earliestUser = user;
      continue;
    }
    if (user->getBlock() != earliestUser->getBlock()) {
      usersShareBlock = false;
      break;
    }
    if (user->isBeforeInBlock(earliestUser))
      earliestUser = user;
  }
  if (usersShareBlock)
    rewriter.setInsertionPoint(earliestUser);
  else
    rewriter.setInsertionPointAfterValue(value);

  // Emit convert_layout op.
  auto convLayoutOp =
      xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
                                     value, inputLayoutAttr, targetLayoutAttr);
  // Redirect every use of the original value to the converted value, except
  // the operand of the conversion op itself.
  rewriter.replaceUsesWithIf(
      value, convLayoutOp.getResult(), [&](OpOperand &use) {
        return use.getOwner() != convLayoutOp.getOperation();
      });

  results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
  return DiagnosedSilenceableFailure::success();
}
630+
631+
/// Declares the transform-level side effects of this op: every operand handle
/// (the target plus all dynamic layout parameters) is only read, the result
/// handle is produced, and the payload IR is modified.
void transform::ConvertLayoutOp::getEffects(
    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  // The operand list is exactly: target, input_sg_layout, input_sg_data,
  // input_inst_data, target_sg_layout, target_sg_data, target_inst_data.
  onlyReadsHandle(self->getOpOperands(), effects);
  producesHandle(self->getOpResults(), effects);
  modifiesPayload(effects);
}
643+
540644
namespace {
541645
class XeGPUTransformDialectExtension
542646
: public transform::TransformDialectExtension<

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ def __init__(
4242
)
4343

4444

45+
def get_desc_op(
    target: Value,
    *,
    loc=None,
    ip=None,
) -> OpResult:
    """Create a `GetDescOp` for `target` and return its single result.

    `loc` and `ip` are forwarded unchanged to the op constructor.
    """
    desc_op = GetDescOp(target, loc=loc, ip=ip)
    return desc_op.result
52+
53+
4554
@_ods_cext.register_operation(_Dialect, replace=True)
4655
class SetDescLayoutOp(SetDescLayoutOp):
4756
"""Specialization for SetDescLayoutOp class."""
@@ -88,6 +97,25 @@ def __init__(
8897
)
8998

9099

100+
def set_desc_layout(
    target: Union[Operation, Value],
    sg_layout: MixedValues,
    sg_data: MixedValues,
    *,
    inst_data: Optional[MixedValues] = None,
    loc=None,
    ip=None,
) -> OpResult:
    """Create a `SetDescLayoutOp` and return its result.

    All arguments are forwarded unchanged to the `SetDescLayoutOp`
    constructor; `inst_data`, `loc` and `ip` are keyword-only.
    """
    layout_op = SetDescLayoutOp(
        target, sg_layout, sg_data, inst_data=inst_data, loc=loc, ip=ip
    )
    return layout_op.result
117+
118+
91119
@_ods_cext.register_operation(_Dialect, replace=True)
92120
class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
93121
"""Specialization for SetOpLayoutAttrOp class."""
@@ -135,6 +163,29 @@ def __init__(
135163
)
136164

137165

166+
def set_op_layout_attr(
    target: Union[Operation, Value],
    sg_layout: MixedValues,
    sg_data: MixedValues,
    *,
    inst_data: Optional[MixedValues] = None,
    index: Optional[Union[int, Attribute]] = None,
    result: Optional[Union[bool, Attribute]] = None,
    loc=None,
    ip=None,
) -> SetOpLayoutAttrOp:
    """Create and return a `SetOpLayoutAttrOp`.

    All arguments are forwarded unchanged to the `SetOpLayoutAttrOp`
    constructor; the op object itself (not a result value) is returned.
    """
    keyword_args = dict(
        inst_data=inst_data,
        index=index,
        result=result,
        loc=loc,
        ip=ip,
    )
    return SetOpLayoutAttrOp(target, sg_layout, sg_data, **keyword_args)
187+
188+
138189
@_ods_cext.register_operation(_Dialect, replace=True)
139190
class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
140191
"""Specialization for SetGPULaunchThreadsOp class."""
@@ -210,4 +261,98 @@ def insert_prefetch(
210261
loc=None,
211262
ip=None,
212263
) -> OpResult:
213-
return InsertPrefetchOp(target, nb_prefetch=nb_prefetch, loc=loc, ip=ip).result
264+
return InsertPrefetchOp(target, nb_prefetch=nb_prefetch, loc=loc, ip=ip).result
265+
266+
267+
@_ods_cext.register_operation(_Dialect, replace=True)
class ConvertLayoutOp(ConvertLayoutOp):
    """Specialization for ConvertLayoutOp class."""

    def __init__(
        self,
        target: Value,
        input_sg_layout: MixedValues,
        input_sg_data: MixedValues,
        target_sg_layout: MixedValues,
        target_sg_data: MixedValues,
        *,
        input_inst_data: Optional[MixedValues] = None,
        target_inst_data: Optional[MixedValues] = None,
        loc=None,
        ip=None,
    ):
        # A missing `inst_data` list is treated as an empty list.
        if input_inst_data is None:
            input_inst_data = []
        if target_inst_data is None:
            target_inst_data = []

        # Split every mixed list into its dynamic and static parts. The third
        # element returned by the dispatcher is not needed here. The lists are
        # dispatched in operand order: input layout first, then target layout.
        mixed_lists = (
            input_sg_layout,
            input_sg_data,
            input_inst_data,
            target_sg_layout,
            target_sg_data,
            target_inst_data,
        )
        dynamic_parts = []
        static_parts = []
        for mixed in mixed_lists:
            dynamic, static, _ = _dispatch_dynamic_index_list(mixed)
            dynamic_parts.append(dynamic)
            static_parts.append(static)

        (
            static_in_layout,
            static_in_data,
            static_in_inst,
            static_tgt_layout,
            static_tgt_data,
            static_tgt_inst,
        ) = static_parts

        # The result is always typed as a generic operation handle.
        super().__init__(
            transform.AnyOpType.get(),
            target,
            *dynamic_parts,
            static_input_sg_layout=static_in_layout,
            static_input_sg_data=static_in_data,
            static_input_inst_data=static_in_inst,
            static_target_sg_layout=static_tgt_layout,
            static_target_sg_data=static_tgt_data,
            static_target_inst_data=static_tgt_inst,
            loc=loc,
            ip=ip,
        )
334+
335+
336+
def convert_layout(
    target: Value,
    input_sg_layout: MixedValues,
    input_sg_data: MixedValues,
    target_sg_layout: MixedValues,
    target_sg_data: MixedValues,
    *,
    input_inst_data: Optional[MixedValues] = None,
    target_inst_data: Optional[MixedValues] = None,
    loc=None,
    ip=None,
) -> OpResult:
    """Create a `ConvertLayoutOp` and return its result.

    All arguments are forwarded unchanged to the `ConvertLayoutOp`
    constructor. The return value is the op's `.result` (an `OpResult`),
    not the op object itself.
    """
    return ConvertLayoutOp(
        target,
        input_sg_layout,
        input_sg_data,
        target_sg_layout,
        target_sg_data,
        input_inst_data=input_inst_data,
        target_inst_data=target_inst_data,
        loc=loc,
        ip=ip,
    ).result

0 commit comments

Comments
 (0)