Skip to content

Commit 5d3489e

Browse files
[mlir][transform][lingalg][python] Replace pdl.operation => transform.any_op. (#66392)
For some reason, the mix-ins of the Python bindings of this dialect used the PDL type for "any op". However, PDL isn't involved here, so it makes more sense to use the corresponding type of the transform dialect. This PR changes that.
1 parent 6d73cca commit 5d3489e

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

mlir/python/mlir/dialects/_structured_transform_ops_ext.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
try:
66
from ..ir import *
7-
from ..dialects import pdl, transform
7+
from ..dialects import transform
88
except ImportError as e:
99
raise RuntimeError("Error loading imports from extension module") from e
1010

@@ -203,7 +203,8 @@ class DecomposeOp:
203203
"""Specialization for DecomposeOp class."""
204204

205205
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
206-
super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
206+
transformed_type = transform.AnyOpType.get()
207+
super().__init__(transformed_type, target, loc=loc, ip=ip)
207208

208209

209210
class FuseIntoContainingOp:
@@ -274,7 +275,8 @@ class GeneralizeOp:
274275
"""Specialization for GeneralizeOp class."""
275276

276277
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
277-
super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
278+
transformed_type = transform.AnyOpType.get()
279+
super().__init__(transformed_type, target, loc=loc, ip=ip)
278280

279281

280282
class InterchangeOp:
@@ -288,9 +290,9 @@ def __init__(
288290
loc=None,
289291
ip=None,
290292
):
291-
pdl_operation_type = pdl.OperationType.get()
293+
transformed_type = transform.AnyOpType.get()
292294
super().__init__(
293-
pdl_operation_type,
295+
transformed_type,
294296
target,
295297
iterator_interchange=iterator_interchange,
296298
loc=loc,
@@ -503,11 +505,11 @@ def __init__(
503505
):
504506
transpose_paddings = _get_int_array_array_attr(transpose_paddings)
505507

506-
pdl_operation_type = pdl.OperationType.get()
508+
any_op_type = transform.AnyOpType.get()
507509
super().__init__(
508-
pdl_operation_type,
509-
pdl_operation_type,
510-
pdl_operation_type,
510+
any_op_type,
511+
any_op_type,
512+
any_op_type,
511513
target,
512514
padding_values=padding_values,
513515
padding_dimensions=padding_dimensions,
@@ -524,8 +526,8 @@ class ScalarizeOp:
524526
"""Specialization for ScalarizeOp class."""
525527

526528
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
527-
pdl_operation_type = pdl.OperationType.get()
528-
super().__init__(pdl_operation_type, target, loc=loc, ip=ip)
529+
result_type = transform.AnyOpType.get()
530+
super().__init__(result_type, target, loc=loc, ip=ip)
529531

530532

531533
class SplitOp:
@@ -736,9 +738,9 @@ def __init__(
736738
loc=None,
737739
ip=None,
738740
):
739-
pdl_operation_type = pdl.OperationType.get()
741+
transformed_type = transform.AnyOpType.get()
740742
super().__init__(
741-
pdl_operation_type,
743+
transformed_type,
742744
target,
743745
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
744746
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,

0 commit comments

Comments
 (0)