Skip to content

Commit 71f9f86

Browse files
Applied Enriques primarly fixes.
1 parent d88752a commit 71f9f86

File tree

4 files changed

+30
-35
lines changed

4 files changed

+30
-35
lines changed

src/jace/translator/jaxpr_translator_builder.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -179,24 +179,6 @@ def append_new_state(
179179
self._ctx.terminal_state = new_state
180180
return new_state
181181

182-
def add_orphan_state(
183-
self,
184-
label: str,
185-
) -> dace.SDFGState:
186-
"""
187-
Add a new orphan state to the SDFG.
188-
189-
The state is not connected to any other state, nor it is the new start state.
190-
Except you know what you are doing you should not use this function and
191-
instead use `self.append_new_state()`.
192-
193-
Args:
194-
label: The name of the state.
195-
"""
196-
if not self.is_allocated():
197-
raise RuntimeError("Builder is not allocated.")
198-
return self._ctx.sdfg.add_state(label=label, is_start_block=False)
199-
200182
@property
201183
def arrays(self) -> Mapping[str, dace_data.Data]:
202184
"""
@@ -520,7 +502,8 @@ def _allocate_translation_ctx(
520502
@property
521503
def _ctx(self) -> TranslationContext:
522504
"""Returns the currently active translation context."""
523-
assert len(self._ctx_stack) != 0, "No context is active."
505+
if not self.is_allocated():
506+
raise RuntimeError("The context is not allocated.")
524507
return self._ctx_stack[-1]
525508

526509
def _clear_translation_ctx(self) -> TranslationContext | None:

src/jace/translator/primitive_translators/arithmetic_logical_translators.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,19 @@ class LogicalOperationTranslator(mapped_base.MappedOperationTranslatorBase):
7979
8080
Args:
8181
prim_name: The name of the primitive that should be handled.
82-
int_tmpl: The template used for the integer case.
83-
bool_tmpl: The template used for the bool case.
82+
bitwise_tmpl: The template used for the bitwise case.
83+
logical_tmpl: The template used for the logical case.
8484
8585
Note:
8686
Since it does not make sense to single out `not` and keep the other
8787
logical operations in `ArithmeticOperationTranslator` all of them are
8888
handled by this class.
8989
"""
9090

91-
def __init__(self, prim_name: str, int_tmpl: str, bool_tmpl: str) -> None:
91+
def __init__(self, prim_name: str, bitwise_tmpl: str, logical_tmpl: str) -> None:
9292
super().__init__(primitive_name=prim_name)
93-
self._int_tmpl = int_tmpl
94-
self._bool_tmpl = bool_tmpl
93+
self._bitwise_tmpl = bitwise_tmpl
94+
self._logical_tmpl = logical_tmpl
9595

9696
@override
9797
def write_tasklet_code(
@@ -101,8 +101,8 @@ def write_tasklet_code(
101101
eqn: jax_core.JaxprEqn,
102102
) -> str:
103103
if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars):
104-
return self._bool_tmpl
105-
return self._int_tmpl
104+
return self._logical_tmpl
105+
return self._bitwise_tmpl
106106

107107

108108
# Maps the name of an arithmetic JAX primitive to the code template that is used to
@@ -176,17 +176,29 @@ def write_tasklet_code(
176176
# Maps the name of a logical primitive to the two code templates, first the integer
177177
# case and second the boolean case, that are used to create the body of the mapped
178178
# tasklet. They are used to instantiate the `LogicalOperationTranslator` translators.
179-
_LOGICAL_OPERATION_TEMPLATES: Final[dict[str, tuple[str, str]]] = {
180-
"or": ("__out = (__in0) | (__in1)", "__out = (__in0) or (__in1)"),
181-
"not": ("__out = ~(__in0)", "__out = not (__in0)"),
182-
"and": ("__out = (__in0) & (__in1)", "__out = (__in0) and (__in1)"),
183-
"xor": ("__out = (__in0) ^ (__in1)", "__out = (__in0) != (__in1)"),
179+
_LOGICAL_OPERATION_TEMPLATES: Final[dict[str, dict[str, str]]] = {
180+
"or": {
181+
"bitwise_tmpl": "__out = (__in0) | (__in1)",
182+
"logical_tmpl": "__out = (__in0) or (__in1)",
183+
},
184+
"not": {
185+
"bitwise_tmpl": "__out = ~(__in0)",
186+
"logical_tmpl": "__out = not (__in0)",
187+
},
188+
"and": {
189+
"bitwise_tmpl": "__out = (__in0) & (__in1)",
190+
"logical_tmpl": "__out = (__in0) and (__in1)",
191+
},
192+
"xor": {
193+
"bitwise_tmpl": "__out = (__in0) ^ (__in1)",
194+
"logical_tmpl": "__out = (__in0) != (__in1)",
195+
},
184196
}
185197
# fmt: on
186198

187199

188200
# Instantiate the arithmetic and logical translators from the templates.
189201
for pname, ptmpl in _ARITMETIC_OPERATION_TEMPLATES.items():
190202
translator.register_primitive_translator(ArithmeticOperationTranslator(pname, ptmpl))
191-
for pname, (itmpl, btmpl) in _LOGICAL_OPERATION_TEMPLATES.items():
192-
translator.register_primitive_translator(LogicalOperationTranslator(pname, itmpl, btmpl))
203+
for pname, ptmpl in _LOGICAL_OPERATION_TEMPLATES.items(): # type: ignore[assignment] # Type confusion
204+
translator.register_primitive_translator(LogicalOperationTranslator(pname, **ptmpl)) # type: ignore[arg-type] # Type confusion

src/jace/translator/primitive_translators/conditions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def condition_translator(
116116
branch_states.append(branch_state)
117117

118118
# Connect all branch states to the join state
119-
join_state = builder.add_orphan_state(f"{name_pattern}__join_state")
119+
join_state = builder._ctx.sdfg.add_state(label=f"{name_pattern}__join_state")
120120
for branch_state in branch_states:
121121
builder.sdfg.add_edge(
122122
branch_state,

src/jace/translator/primitive_translators/slicing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def make_input_memlets(
5757
eqn: jax_core.JaxprEqn,
5858
) -> dict[str, dace.Memlet]:
5959
strides: Sequence[int] = (
60-
((1,) * len(tskl_ranges)) if eqn.params["strides"] is None else eqn.params["strides"]
60+
eqn.params["strides"] if eqn.params["strides"] else ((1,) * len(tskl_ranges))
6161
)
6262
start_indices: Sequence[int] = eqn.params["start_indices"] # Fist index to slice
6363
return {

0 commit comments

Comments
 (0)