Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _is_collectable_expr(node: itir.Node) -> bool:
# also do not collect index nodes, because otherwise the right-hand side of a SetAt becomes a
# let statement instead of an as_fieldop
if cpm.is_call_to(
node, ("lift", "shift", "reduce", "map_", "index")
node, ("lift", "shift", "neighbors", "reduce", "map_", "index")
) or cpm.is_applied_lift(node):
return False
return True
Expand Down Expand Up @@ -396,7 +396,7 @@ def extract_subexpression(
if not eligible_ids:
continue

expr_id = uid_generator.sequential_id()
expr_id = uid_generator.sequential_id(prefix="_cs")
extracted[itir.Sym(id=expr_id)] = expr
expr_ref = itir.SymRef(id=expr_id)
for id_ in eligible_ids:
Expand Down Expand Up @@ -435,9 +435,7 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):

# we use one UID generator per instance so that the generated ids are
# stable across multiple runs (required for caching to work properly)
uids: UIDGenerator = dataclasses.field(
init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="_cs")
)
uids: UIDGenerator = dataclasses.field(repr=False)

collect_all: bool = dataclasses.field(default=False)

Expand All @@ -447,6 +445,8 @@ def apply(
node: ProgramOrExpr,
within_stencil: bool | None = None,
offset_provider_type: common.OffsetProviderType | None = None,
*,
uids: UIDGenerator | None = None,
) -> ProgramOrExpr:
is_program = isinstance(node, itir.Program)
if is_program:
Expand All @@ -457,11 +457,14 @@ def apply(
"The expression's context must be specified using `within_stencil`."
)

if not uids:
uids = UIDGenerator()

offset_provider_type = offset_provider_type or {}
node = itir_type_inference.infer(
node, offset_provider_type=offset_provider_type, allow_undeclared_symbols=not is_program
)
return cls().visit(node, within_stencil=within_stencil)
return cls(uids=uids).visit(node, within_stencil=within_stencil)

def generic_visit(self, node, **kwargs):
if cpm.is_call_to(node, "as_fieldop"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node:

def _post_transform(self, node: ir.Node, new_node: ir.Node) -> ir.Node:
if self.REINFER_TYPES:
itir_type_inference.reinfer(new_node)
kwargs = {}
if hasattr(self, "offset_provider_type"):
kwargs["offset_provider_type"] = self.offset_provider_type
itir_type_inference.reinfer(new_node, **kwargs)
self._preserve_annex(node, new_node)
return new_node

Expand Down
13 changes: 10 additions & 3 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
misc as ir_misc,
)
from gt4py.next.iterator.transforms import (
cse,
fixed_point_transformation,
inline_center_deref_lift_vars,
inline_lambdas,
Expand Down Expand Up @@ -182,6 +183,9 @@ def fuse_as_fieldop(
new_stencil, opcount_preserving=True, force_inline_lift_args=True
)
new_stencil = inline_lifts.InlineLifts().visit(new_stencil)
new_stencil = cse.CommonSubexpressionElimination.apply(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try copying over the type of new_stencil

new_stencil, within_stencil=True, uids=uids
)

new_node = im.as_fieldop(new_stencil, domain)(*new_args.values())

Expand Down Expand Up @@ -278,6 +282,7 @@ def all(self) -> FuseAsFieldOp.Transformation:
enabled_transformations = Transformation.all()

uids: eve_utils.UIDGenerator
offset_provider_type: common.OffsetProviderType

@classmethod
def apply(
Expand All @@ -304,9 +309,11 @@ def apply(
if not uids:
uids = eve_utils.UIDGenerator()

new_node = cls(uids=uids, enabled_transformations=enabled_transformations).visit(
node, within_set_at_expr=within_set_at_expr
)
new_node = cls(
uids=uids,
enabled_transformations=enabled_transformations,
offset_provider_type=offset_provider_type,
).visit(node, within_set_at_expr=within_set_at_expr)
# The `FuseAsFieldOp` pass does not fully preserve the type information yet. In particular
# for the generated lifts this is tricky and error-prone. For simplicity, we just reinfer
# everything here ensuring later passes can use the information.
Expand Down
5 changes: 4 additions & 1 deletion src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def apply_common_transforms(
tmp_uids = eve_utils.UIDGenerator(prefix="__tmp")
mergeasfop_uids = eve_utils.UIDGenerator()
collapse_tuple_uids = eve_utils.UIDGenerator()
cse_uids = eve_utils.UIDGenerator()

ir = MergeLet().visit(ir)
ir = inline_fundefs.InlineFundefs().visit(ir)
Expand Down Expand Up @@ -128,7 +129,9 @@ def apply_common_transforms(

# breaks in test_zero_dim_tuple_arg as trivial tuple_get is not inlined
if common_subexpression_elimination:
ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type)
ir = CommonSubexpressionElimination.apply(
ir, offset_provider_type=offset_provider_type, uids=cse_uids
)
ir = MergeLet().visit(ir)
ir = InlineLambdas.apply(ir, opcount_preserving=True)

Expand Down
Loading