[pallas:mosaic_gpu] emit_pipeline no longer ignores transforms
PiperOrigin-RevId: 702726201
superbobry authored and Google-ML-Automation committed Dec 4, 2024
1 parent 2ac2692 commit 12b45b3
Showing 3 changed files with 45 additions and 7 deletions.
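In user-facing terms, transforms attached to a plgpu.GPUBlockSpec handed to plgpu.emit_pipeline are now propagated to the SMEM buffers the pipeline allocates, instead of being silently dropped. Below is a minimal sketch of that usage, mirroring the test updated in this commit; it assumes a GPU supported by the Mosaic GPU backend, and the trailing pallas_call arguments (out_specs, out_shape) follow the usual pattern rather than being visible in this diff.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

num_steps = 4
transforms = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))

def body(x_smem, o_smem):
  # With this change, x_smem/o_smem carry the requested transforms (plus the
  # indexing transform added by the pipeline's leading buffering dimension).
  o_smem[...] = x_smem[...] + 1.0

def kernel(x_gmem, o_gmem):
  plgpu.emit_pipeline(
      body,
      in_specs=[plgpu.GPUBlockSpec((64, 64), lambda i: (0, i), transforms=transforms)],
      out_specs=[plgpu.GPUBlockSpec((64, 64), lambda i: (0, i), transforms=transforms)],
      grid=(num_steps,),
      max_concurrent_steps=2,
  )(x_gmem, o_gmem)

x = jnp.arange(64 * num_steps * 64, dtype=jnp.float32).reshape(64, num_steps * 64)
out = pl.pallas_call(
    kernel,
    in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
    out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x)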
19 changes: 19 additions & 0 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -138,6 +138,14 @@ class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC):
   def to_gpu_transform(self) -> mgpu.MemRefTransform:
     pass
 
+  def batch(self, leading_rank: int):
+    """Returns a transform that accepts a ref with the extra `leading_rank` dims.
+
+    The returned transform should leave the leading dimensions unchanged and
+    only apply to the suffix of the shape.
+    """
+    raise NotImplementedError
+
   def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
     return aval.update(
         shape=self.to_gpu_transform().transform_shape(aval.shape)
@@ -161,6 +169,9 @@ def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
         ref, transforms=(*ref.transforms, UntileRef(self.tiling))
     )
 
+  def batch(self, leading_rank: int):
+    return self
+
   def to_gpu_transform(self) -> mgpu.MemRefTransform:
     return mgpu.TileTransform(self.tiling)
 
@@ -228,6 +239,11 @@ def __post_init__(self):
     if set(self.permutation) != set(range(len(self.permutation))):
       raise ValueError(f"Permutation {self.permutation} is not a permutation.")
 
+  def batch(self, leading_rank: int):
+    return TransposeTransform(
+        (*range(leading_rank), *(d + leading_rank for d in self.permutation))
+    )
+
   def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
     return dataclasses.replace(
         ref,
@@ -304,6 +320,9 @@ def __post_init__(self):
           " accepted."
       )
 
+  def batch(self, leading_rank: int):
+    return self
+
   def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
     return dataclasses.replace(
         ref, transforms=(*ref.transforms, UnswizzleRef(self.swizzle))
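The new batch hook lets a transform accept a ref with extra leading (buffering) dimensions: TilingTransform and SwizzleTransform only touch the trailing dimensions, so they return themselves unchanged, while TransposeTransform has to shift its permutation. A standalone sketch of that index arithmetic follows (a hypothetical helper mirroring the TransposeTransform.batch logic above, not the library class itself):

# Mirrors TransposeTransform.batch from the diff above: keep `leading_rank`
# identity dimensions in front and shift the original permutation past them.
def batch_permutation(permutation: tuple[int, ...], leading_rank: int) -> tuple[int, ...]:
  return (*range(leading_rank), *(d + leading_rank for d in permutation))

# A (1, 0) transpose of a 2D block, batched past one buffering dimension,
# becomes a (0, 2, 1) transpose of the 3D (step, rows, cols) buffer.
assert batch_permutation((1, 0), leading_rank=1) == (0, 2, 1)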
8 changes: 7 additions & 1 deletion jax/_src/pallas/mosaic_gpu/pipeline.py
@@ -195,7 +195,13 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
     in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
     in_smem_refs, out_smem_refs = util.split_list(
         [
-            gpu_core.SMEM((max_concurrent_steps, *spec.block_shape), ref.dtype)  # type: ignore
+            gpu_core.SMEM(
+                (max_concurrent_steps, *spec.block_shape),  # type: ignore
+                ref.dtype,
+                transforms=tuple(
+                    t.batch(1) for t in getattr(spec, "transforms", ())
+                ),
+            )
             if _in_smem(spec)
             else None
             for spec, ref in zip(it.chain(in_specs, out_specs), gmem_refs)
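The pipeline gives each spec an SMEM buffer with an extra leading dimension of size max_concurrent_steps for multi-buffering, and it now also forwards the spec's transforms batched past that dimension. A simplified, standalone mirror of that allocation logic (hypothetical helper name; the real code constructs gpu_core.SMEM directly, as shown above):

# Sketch: compute the SMEM shape and batched transforms for one block spec.
def smem_allocation(block_shape, transforms, max_concurrent_steps):
  shape = (max_concurrent_steps, *block_shape)     # leading buffering dimension
  batched = tuple(t.batch(1) for t in transforms)  # transforms skip that dim
  return shape, batched

# E.g. a (64, 64) block with two in-flight steps is stored as a (2, 64, 64)
# buffer, and a TilingTransform((64, 32)) still tiles only the trailing dims.
print(smem_allocation((64, 64), transforms=(), max_concurrent_steps=2)[0])  # (2, 64, 64)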
25 changes: 19 additions & 6 deletions tests/pallas/mosaic_gpu_test.py
Expand Up @@ -820,7 +820,6 @@ def kernel(x_ref, o_ref):
     x = jnp.arange(256)
     np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256]))
 
-
   @parameterized.parameters(jnp.float16, jnp.float32)
   def test_wgmma(self, dtype):
     self.skip_unless_sm90a()
@@ -1233,23 +1232,37 @@ def body(step, _):
     )
     np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
 
-  def test_emit(self):
+  @parameterized.parameters(
+      ((),),
+      ((plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)),),
+  )
+  def test_emit(self, transforms):
     num_steps = 4
 
     def kernel(x_gmem, o_gmem):
       plgpu.emit_pipeline(
           kernel_body,
-          in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
-          out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
+          in_specs=[
+              plgpu.GPUBlockSpec(
+                  (64, 64), lambda i: (0, i), transforms=transforms
+              )
+          ],
+          out_specs=[
+              plgpu.GPUBlockSpec(
+                  (64, 64), lambda i: (0, i), transforms=transforms
+              )
+          ],
           grid=(num_steps,),
           max_concurrent_steps=2,
       )(x_gmem, o_gmem)
 
     def kernel_body(x_smem, o_smem):
+      # +1 for the indexing done by ``emit_pipeline``.
+      self.assertLen(x_smem.transforms, len(transforms) + 1)
       o_smem[...] = x_smem[...] + 1.0
 
-    x = jnp.arange(32 * num_steps * 16)
-    x = x.reshape(-1, num_steps * 16).astype(jnp.float32)
+    x = jnp.arange(64 * num_steps * 64)
+    x = x.reshape(-1, num_steps * 64).astype(jnp.float32)
     kernel_fn = pl.pallas_call(
         kernel,
         in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
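The test data is resized to match the larger (64, 64) blocks: the input becomes a (64, num_steps * 64) array, and the index map lambda i: (0, i) picks block column i on grid step i. A small NumPy sketch of that blocking arithmetic, just to illustrate which tile each pipeline step sees:

import numpy as np

num_steps = 4
block_shape = (64, 64)
x = np.arange(64 * num_steps * 64, dtype=np.float32).reshape(64, num_steps * 64)

index_map = lambda i: (0, i)  # block indices, scaled by block_shape
for step in range(num_steps):
  bi, bj = index_map(step)
  block = x[
      bi * block_shape[0]:(bi + 1) * block_shape[0],
      bj * block_shape[1]:(bj + 1) * block_shape[1],
  ]
  assert block.shape == block_shape  # each step processes one (64, 64) tile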
