[Mosaic GPU] Resolve different tile transforms using the greatest common divisor. #29006
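For reference, a minimal standalone sketch of the resolution rule implemented in the diff below, using plain Python tuples and ints in place of the MLIR `TileTransformAttr`/`SwizzleTransformAttr` attributes (the `resolve_tilings` helper and its signature are illustrative, not part of the Mosaic GPU API):

```python
import math


def resolve_tilings(a: tuple[int, ...], b: tuple[int, ...],
                    swizzle_a: int, swizzle_b: int):
    """Resolve two conflicting (tiling, swizzle) transform pairs.

    Tilings of equal rank are merged element-wise with gcd; swizzles must
    already agree.
    """
    if len(a) != len(b):
        raise ValueError(f"Conflicting tile transforms {a} != {b}.")
    if swizzle_a != swizzle_b:
        raise ValueError(f"Swizzle transforms must match, got {swizzle_a} and {swizzle_b}.")
    new_tiling = tuple(math.gcd(ta, tb) for ta, tb in zip(a, b))
    return new_tiling, swizzle_a


# Two users requesting (64, 64) and (8, 64) tiles with the same swizzle
# resolve to an (8, 64) tile that both can use.
assert resolve_tilings((64, 64), (8, 64), 32, 32) == ((8, 64), 32)
```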

Open · wants to merge 1 commit into main
68 changes: 46 additions & 22 deletions jax/experimental/mosaic/gpu/transform_inference.py
@@ -21,6 +21,7 @@
from collections.abc import Callable
from functools import partial
import itertools
import math
from typing import cast

from jax._src.lib import mosaic_gpu_dialect as mgpu
@@ -84,12 +85,32 @@ def _resolve_transforms(
if other_transforms is None:
return transforms

if transforms != other_transforms:
if len(transforms) != len(other_transforms):
raise NotImplementedError(
f"Conflicting transforms {transforms} != {other_transforms}."
)

return transforms
new_transforms = []
for a, b in zip(transforms, other_transforms, strict=True):
if mgpu.TileTransformAttr.isinstance(a) and mgpu.TileTransformAttr.isinstance(b):
a = mgpu.TileTransformAttr(a)
b = mgpu.TileTransformAttr(b)
if (len(a.tiling) != len(b.tiling)):
raise ValueError(f"Conflicting tile transforms {a} != {b}.")
new_tiling = []
for tile_a, tile_b in zip(a.tiling, b.tiling):
new_tiling.append(math.gcd(tile_a, tile_b))
new_transforms.append(mgpu.TileTransformAttr.get(new_tiling))
elif mgpu.SwizzleTransformAttr.isinstance(a) and mgpu.SwizzleTransformAttr.isinstance(b):
a = mgpu.SwizzleTransformAttr(a)
b = mgpu.SwizzleTransformAttr(b)
if a.swizzle != b.swizzle:
raise ValueError(f"Swizzle transforms must match, got {a} and {b}.")
new_transforms.append(a)
else:
raise NotImplementedError(f"Unsupported transforms {a} and {b}")

return ir.ArrayAttr.get(new_transforms)


def _transforms_from_uses(op: ir.OpView) -> ir.Attribute | None:
@@ -302,7 +323,7 @@ def _infer_memref_subview_transforms(
# - We only propagate transforms if they consist of a single tile transform
# and a single swizzle transform.
# TODO(bchetioui): implement more complex propagation rules.
tile_transform, _ = _get_tile_and_swizzle_transforms(transforms)
tile_transform, swizzle_transform = _get_tile_and_swizzle_transforms(transforms)

# Check swizzle transform propagation.
strides, _ = ir.MemRefType.get_strides_and_offset(op.source.type)
@@ -314,17 +335,17 @@
)

# Check tile transform propagation.
num_tiled_axes = len(mgpu.TileTransformAttr(tile_transform).tiling)
old_tiling = mgpu.TileTransformAttr(tile_transform).tiling
num_tiled_axes = len(old_tiling)
last_n_dims = op.source.type.shape[-num_tiled_axes:]
last_n_sizes = list(op.static_sizes)[-num_tiled_axes:]
for slice_size, dim_size in safe_zip(last_n_sizes, last_n_dims):
if slice_size != dim_size:
raise NotImplementedError(
"Tile transforms are only propagated if the tiled axes are not "
"sliced."
)
new_tiling = []

return [transforms], [transforms]
for slice_size, dim_size, old_tile in safe_zip(last_n_sizes, last_n_dims, old_tiling):
new_tiling.append(math.gcd(slice_size, dim_size, old_tile))
new_transforms = ir.ArrayAttr.get([mgpu.TileTransformAttr.get(new_tiling), swizzle_transform])

return [new_transforms], [new_transforms]


@partial(_add_transform_inference_rule, memref.TransposeOp)
@@ -411,14 +432,17 @@ def inference_step(op: ir.Operation):

_set_transform_attributes(op, *maybe_transforms)

# It's enough to do a single backwards propagation (starting from vector
# users), and then a single forward propagation (to feed into the async loads
# and stores).
for op in module.body:
inference_utils.traverse_op(
op, inference_step, inference_utils.TraversalOrder.BACKWARDS
)
for op in module.body:
inference_utils.traverse_op(
op, inference_step, inference_utils.TraversalOrder.FORWARD
)
# We alternate a few backward-propagation passes (starting from vector
# users) with forward-propagation passes (to feed into the async loads and
# stores) in order to handle more complex inference situations.
#
# TODO(bchetioui): Replace this with a more generic inference.
inference_passes = [
inference_utils.TraversalOrder.BACKWARDS,
inference_utils.TraversalOrder.FORWARD,
inference_utils.TraversalOrder.BACKWARDS,
inference_utils.TraversalOrder.FORWARD,
]
for traversal_order in inference_passes:
for op in module.body:
inference_utils.traverse_op(op, inference_step, traversal_order)
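The subview rule above (`_infer_memref_subview_transforms`) additionally shrinks the tiling by a per-axis gcd of the slice size, the source dimension, and the previously inferred tile over the trailing tiled axes. A minimal sketch with plain integers, mirroring the scenario exercised by the test that follows (the `propagate_tiling_through_slice` helper is illustrative only):

```python
import math


def propagate_tiling_through_slice(old_tiling: tuple[int, ...],
                                   source_shape: tuple[int, ...],
                                   slice_sizes: tuple[int, ...]) -> tuple[int, ...]:
    """Shrink a tiling so it still divides a sliced view of the source.

    Only the trailing len(old_tiling) axes are tiled; each new tile is the
    gcd of the slice size, the source dimension, and the old tile.
    """
    n = len(old_tiling)
    trailing_dims = source_shape[-n:]
    trailing_sizes = slice_sizes[-n:]
    return tuple(
        math.gcd(size, dim, tile)
        for size, dim, tile in zip(trailing_sizes, trailing_dims, old_tiling)
    )


# The scenario from the test below: a (64, 64) source tiled as (64, 64),
# sliced into (1, 64), (2, 64) and (8, 64) views. The per-slice tilings
# resolve (via gcd again) to the (1, 64) tiling the test expects.
slices = [(1, 64), (2, 64), (8, 64)]
tilings = [propagate_tiling_through_slice((64, 64), (64, 64), s) for s in slices]
assert tilings == [(1, 64), (2, 64), (8, 64)]
assert tuple(math.gcd(*axis) for axis in zip(*tilings)) == (1, 64)
```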
105 changes: 105 additions & 0 deletions tests/mosaic/gpu_transform_inference_test.py
@@ -519,6 +519,111 @@ def body(in_ref):
with self.assertRaises(NotImplementedError):
mgpu.infer_transforms(self.module)

def test_infer_transforms_for_sibling_subviews_and_distant_op(self):
# This test uses the following op tree extracted from a ragged dot kernel:
#
# subview_op0 (slice = 64, 64)
# - subview_op1 (slice = 1, 64)
# - subview_op2 (slice = 2, 64)
# - subview_op3 (slice = 8, 64)
# - user_op0 (in_transforms = [tile(64, 64), swizzle(32)])
#
# First, the in_transforms of user_op0 have to be propagated up to
# subview_op0. Then they have to be propagated down and resolved. Finally,
# all ops, including user_op0, need to end up with the same transforms.
subview_op0, subview_op1, subview_op2, subview_op3 = None, None, None, None
user_op0 = None

source_shape = (64, 64)
elt_ty = ir.BF16Type.get()
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
source_ref_ty = ir.MemRefType.get(source_shape, elt_ty, memory_space=smem)

slice1_shape = (1, 64)
slice2_shape = (2, 64)
slice3_shape = (8, 64)

slice0_ref_ty = ir.MemRefType.get(source_shape, elt_ty, memory_space=smem)
slice1_ref_ty = ir.MemRefType.get(slice1_shape, elt_ty, memory_space=smem)
slice2_ref_ty = ir.MemRefType.get(slice2_shape, elt_ty, memory_space=smem)
slice3_ref_ty = ir.MemRefType.get(slice3_shape, elt_ty, memory_space=smem)

def body(source_ref):
nonlocal subview_op0, subview_op1, subview_op2, subview_op3, user_op0

subview_op0 = memref.SubViewOp(
slice0_ref_ty,
source_ref,
[], # dynamic offsets
[], # dynamic sizes
[], # dynamic strides
static_offsets=[0, 0],
static_sizes=source_shape,
static_strides=[1, 1],
)
user_op0 = builtin.UnrealizedConversionCastOp(
[slice0_ref_ty], [subview_op0.result]
)
transforms_0 = ir.ArrayAttr.get([
mgpu.dialect.TileTransformAttr.get((64, 64)),
mgpu.dialect.SwizzleTransformAttr.get(32),
])
user_op0.attributes["in_transforms"] = ir.ArrayAttr.get([transforms_0])

subview_op1 = memref.SubViewOp(
slice1_ref_ty,
subview_op0,
[], # dynamic offsets
[], # dynamic sizes
[], # dynamic strides
static_offsets=[0, 0],
static_sizes=slice1_shape,
static_strides=[1, 1],
)

subview_op2 = memref.SubViewOp(
slice2_ref_ty,
subview_op0,
[], # dynamic offsets
[], # dynamic sizes
[], # dynamic strides
static_offsets=[15, 0],
static_sizes=slice2_shape,
static_strides=[1, 1],
)

subview_op3 = memref.SubViewOp(
slice3_ref_ty,
subview_op0,
[], # dynamic offsets
[], # dynamic sizes
[], # dynamic strides
static_offsets=[30, 0],
static_sizes=slice3_shape,
static_strides=[1, 1],
)

with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(source_ref_ty)(body)

mgpu.infer_transforms(self.module)

want = ir.ArrayAttr.get([
mgpu.dialect.TileTransformAttr.get((1, 64)),
mgpu.dialect.SwizzleTransformAttr.get(32),
])

self.assertSequenceEqual(inference_utils.in_transforms(subview_op0), [want])
self.assertSequenceEqual(inference_utils.out_transforms(subview_op0), [want])
self.assertSequenceEqual(inference_utils.in_transforms(subview_op1), [want])
self.assertSequenceEqual(inference_utils.out_transforms(subview_op1), [want])
self.assertSequenceEqual(inference_utils.in_transforms(subview_op2), [want])
self.assertSequenceEqual(inference_utils.out_transforms(subview_op2), [want])
self.assertSequenceEqual(inference_utils.in_transforms(subview_op3), [want])
self.assertSequenceEqual(inference_utils.out_transforms(subview_op3), [want])
self.assertSequenceEqual(inference_utils.in_transforms(user_op0), [want])
self.assertSequenceEqual(inference_utils.out_transforms(user_op0), [want])


if __name__ == "__main__":
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())