@@ -949,11 +949,8 @@ def get_ptr(self):
949949 def as_barrier_memref (self ) -> ir .Value :
950950 num_barriers = self .barrier_ref .num_barriers
951951 shape = () if num_barriers == 1 else (num_barriers ,)
952- return ptr_as_memref (
953- self .get_ptr (),
954- ir .MemRefType .get (shape , ir .Type .parse ("!mosaic_gpu.barrier" )),
955- ptr_memory_space = WORKGROUP_NVPTX_ADDRESS_SPACE ,
956- )
952+ memref_type = ir .MemRefType .get (shape , ir .Type .parse ("!mosaic_gpu.barrier" ))
953+ return builtin .unrealized_conversion_cast ([memref_type ], [self .get_ptr ()])
957954
958955 @classmethod
959956 def from_barrier_memref (cls , barrier : ir .Value ):
@@ -967,17 +964,18 @@ def from_barrier_memref(cls, barrier: ir.Value):
967964 f"!mosaic_gpu.barrier, but got { barrier .type } "
968965 )
969966
967+ ptr_type = ir .Type .parse (f"!llvm.ptr<{ WORKGROUP_NVPTX_ADDRESS_SPACE } >" )
968+ addr = builtin .unrealized_conversion_cast ([ptr_type ], [barrier ])
970969 return cls (
971970 barrier_ref = BarrierRef (
972- base_address = memref_ptr (
973- barrier , memory_space = WORKGROUP_NVPTX_ADDRESS_SPACE
974- ),
971+ base_address = addr ,
975972 offset = c (0 , ir .IntegerType .get_signless (64 )),
976973 phases = None ,
977974 num_barriers = (1 if memref_type .rank == 0 else memref_type .shape [0 ]),
978975 )
979976 )
980977
978+
981979@dataclasses .dataclass (frozen = True )
982980class CollectiveBarrierRef :
983981 barrier : BarrierRef
0 commit comments