@@ -79,19 +79,19 @@ class LogicalOperationTranslator(mapped_base.MappedOperationTranslatorBase):
79
79
80
80
Args:
81
81
prim_name: The name of the primitive that should be handled.
82
- int_tmpl : The template used for the integer case.
83
- bool_tmpl : The template used for the bool case.
82
+ bitwise_tmpl : The template used for the bitwise case.
83
+ logical_tmpl : The template used for the logical case.
84
84
85
85
Note:
86
86
Since it does not make sense to single out `not` and keep the other
87
87
logical operations in `ArithmeticOperationTranslator` all of them are
88
88
handled by this class.
89
89
"""
90
90
91
- def __init__ (self , prim_name : str , int_tmpl : str , bool_tmpl : str ) -> None :
91
+ def __init__ (self , prim_name : str , bitwise_tmpl : str , logical_tmpl : str ) -> None :
92
92
super ().__init__ (primitive_name = prim_name )
93
- self ._int_tmpl = int_tmpl
94
- self ._bool_tmpl = bool_tmpl
93
+ self ._bitwise_tmpl = bitwise_tmpl
94
+ self ._logical_tmpl = logical_tmpl
95
95
96
96
@override
97
97
def write_tasklet_code (
@@ -101,8 +101,8 @@ def write_tasklet_code(
101
101
eqn : jax_core .JaxprEqn ,
102
102
) -> str :
103
103
if all (util .get_jax_var_dtype (invar ) is dace .bool_ for invar in eqn .invars ):
104
- return self ._bool_tmpl
105
- return self ._int_tmpl
104
+ return self ._logical_tmpl
105
+ return self ._bitwise_tmpl
106
106
107
107
108
108
# Maps the name of an arithmetic JAX primitive to the code template that is used to
@@ -176,17 +176,29 @@ def write_tasklet_code(
176
176
# Maps the name of a logical primitive to the two code templates, first the integer
177
177
# case and second the boolean case, that are used to create the body of the mapped
178
178
# tasklet. They are used to instantiate the `LogicalOperationTranslator` translators.
179
- _LOGICAL_OPERATION_TEMPLATES : Final [dict [str , tuple [str , str ]]] = {
180
- "or" : ("__out = (__in0) | (__in1)" , "__out = (__in0) or (__in1)" ),
181
- "not" : ("__out = ~(__in0)" , "__out = not (__in0)" ),
182
- "and" : ("__out = (__in0) & (__in1)" , "__out = (__in0) and (__in1)" ),
183
- "xor" : ("__out = (__in0) ^ (__in1)" , "__out = (__in0) != (__in1)" ),
179
+ _LOGICAL_OPERATION_TEMPLATES : Final [dict [str , dict [str , str ]]] = {
180
+ "or" : {
181
+ "bitwise_tmpl" : "__out = (__in0) | (__in1)" ,
182
+ "logical_tmpl" : "__out = (__in0) or (__in1)" ,
183
+ },
184
+ "not" : {
185
+ "bitwise_tmpl" : "__out = ~(__in0)" ,
186
+ "logical_tmpl" : "__out = not (__in0)" ,
187
+ },
188
+ "and" : {
189
+ "bitwise_tmpl" : "__out = (__in0) & (__in1)" ,
190
+ "logical_tmpl" : "__out = (__in0) and (__in1)" ,
191
+ },
192
+ "xor" : {
193
+ "bitwise_tmpl" : "__out = (__in0) ^ (__in1)" ,
194
+ "logical_tmpl" : "__out = (__in0) != (__in1)" ,
195
+ },
184
196
}
185
197
# fmt: on
186
198
187
199
188
200
# Instantiate the arithmetic and logical translators from the templates.
189
201
for pname , ptmpl in _ARITMETIC_OPERATION_TEMPLATES .items ():
190
202
translator .register_primitive_translator (ArithmeticOperationTranslator (pname , ptmpl ))
191
- for pname , ( itmpl , btmpl ) in _LOGICAL_OPERATION_TEMPLATES .items ():
192
- translator .register_primitive_translator (LogicalOperationTranslator (pname , itmpl , btmpl ))
203
+ for pname , ptmpl in _LOGICAL_OPERATION_TEMPLATES .items (): # type: ignore[assignment] # Type confusion
204
+ translator .register_primitive_translator (LogicalOperationTranslator (pname , ** ptmpl )) # type: ignore[arg-type] # Type confusion
0 commit comments