4
4
5
5
try :
6
6
from ..ir import *
7
- from ..dialects import pdl , transform
7
+ from ..dialects import transform
8
8
except ImportError as e :
9
9
raise RuntimeError ("Error loading imports from extension module" ) from e
10
10
@@ -203,7 +203,8 @@ class DecomposeOp:
203
203
"""Specialization for DecomposeOp class."""
204
204
205
205
def __init__ (self , target : Union [Operation , Value ], * , loc = None , ip = None ):
206
- super ().__init__ (pdl .OperationType .get (), target , loc = loc , ip = ip )
206
+ transformed_type = transform .AnyOpType .get ()
207
+ super ().__init__ (transformed_type , target , loc = loc , ip = ip )
207
208
208
209
209
210
class FuseIntoContainingOp :
@@ -274,7 +275,8 @@ class GeneralizeOp:
274
275
"""Specialization for GeneralizeOp class."""
275
276
276
277
def __init__ (self , target : Union [Operation , Value ], * , loc = None , ip = None ):
277
- super ().__init__ (pdl .OperationType .get (), target , loc = loc , ip = ip )
278
+ transformed_type = transform .AnyOpType .get ()
279
+ super ().__init__ (transformed_type , target , loc = loc , ip = ip )
278
280
279
281
280
282
class InterchangeOp :
@@ -288,9 +290,9 @@ def __init__(
288
290
loc = None ,
289
291
ip = None ,
290
292
):
291
- pdl_operation_type = pdl . OperationType .get ()
293
+ transformed_type = transform . AnyOpType .get ()
292
294
super ().__init__ (
293
- pdl_operation_type ,
295
+ transformed_type ,
294
296
target ,
295
297
iterator_interchange = iterator_interchange ,
296
298
loc = loc ,
@@ -503,11 +505,11 @@ def __init__(
503
505
):
504
506
transpose_paddings = _get_int_array_array_attr (transpose_paddings )
505
507
506
- pdl_operation_type = pdl . OperationType .get ()
508
+ any_op_type = transform . AnyOpType .get ()
507
509
super ().__init__ (
508
- pdl_operation_type ,
509
- pdl_operation_type ,
510
- pdl_operation_type ,
510
+ any_op_type ,
511
+ any_op_type ,
512
+ any_op_type ,
511
513
target ,
512
514
padding_values = padding_values ,
513
515
padding_dimensions = padding_dimensions ,
@@ -524,8 +526,8 @@ class ScalarizeOp:
524
526
"""Specialization for ScalarizeOp class."""
525
527
526
528
def __init__ (self , target : Union [Operation , Value ], * , loc = None , ip = None ):
527
- pdl_operation_type = pdl . OperationType .get ()
528
- super ().__init__ (pdl_operation_type , target , loc = loc , ip = ip )
529
+ result_type = transform . AnyOpType .get ()
530
+ super ().__init__ (result_type , target , loc = loc , ip = ip )
529
531
530
532
531
533
class SplitOp :
@@ -736,9 +738,9 @@ def __init__(
736
738
loc = None ,
737
739
ip = None ,
738
740
):
739
- pdl_operation_type = pdl . OperationType .get ()
741
+ transformed_type = transform . AnyOpType .get ()
740
742
super ().__init__ (
741
- pdl_operation_type ,
743
+ transformed_type ,
742
744
target ,
743
745
disable_multi_reduction_to_contract_patterns = disable_multi_reduction_to_contract_patterns ,
744
746
disable_transfer_permutation_map_lowering_patterns = disable_transfer_permutation_map_lowering_patterns ,
0 commit comments