Skip to content

Commit d4609e0

Browse files
allanrenucci authored and Google-ML-Automation committed
[Mosaic GPU] Avoid unnecessary conversion to and from ptr for lowering of dialect barriers.
PiperOrigin-RevId: 788907078
1 parent 89b91ca commit d4609e0

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

jax/experimental/mosaic/gpu/utils.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -949,11 +949,8 @@ def get_ptr(self):
949949
def as_barrier_memref(self) -> ir.Value:
950950
num_barriers = self.barrier_ref.num_barriers
951951
shape = () if num_barriers == 1 else (num_barriers,)
952-
return ptr_as_memref(
953-
self.get_ptr(),
954-
ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
955-
ptr_memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE,
956-
)
952+
memref_type = ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier"))
953+
return builtin.unrealized_conversion_cast([memref_type], [self.get_ptr()])
957954

958955
@classmethod
959956
def from_barrier_memref(cls, barrier: ir.Value):
@@ -967,17 +964,18 @@ def from_barrier_memref(cls, barrier: ir.Value):
967964
f"!mosaic_gpu.barrier, but got {barrier.type}"
968965
)
969966

967+
ptr_type = ir.Type.parse(f"!llvm.ptr<{WORKGROUP_NVPTX_ADDRESS_SPACE}>")
968+
addr = builtin.unrealized_conversion_cast([ptr_type], [barrier])
970969
return cls(
971970
barrier_ref=BarrierRef(
972-
base_address=memref_ptr(
973-
barrier, memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE
974-
),
971+
base_address=addr,
975972
offset=c(0, ir.IntegerType.get_signless(64)),
976973
phases=None,
977974
num_barriers=(1 if memref_type.rank == 0 else memref_type.shape[0]),
978975
)
979976
)
980977

978+
981979
@dataclasses.dataclass(frozen=True)
982980
class CollectiveBarrierRef:
983981
barrier: BarrierRef

0 commit comments

Comments
 (0)