[Mosaic TPU] Support dynamic DMA and ref slice on the 2nd minor when memref is untiled #24354

Open · wants to merge 1 commit into main
58 changes: 42 additions & 16 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -3020,14 +3020,27 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
FAILUREOR_ASSIGN_OR_RETURN(
Tiling memref_tiling,
getMemRefTiling(load_op.getBase(), ctx.target_shape));
if (memref_tiling != layout_out.tiling() &&
!(memref_tiling[0] == 1 && layout_out.tiling()[0] == 1 &&
memref_tiling[1] % layout_out.tiling()[1] == 0)) {
// Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes).
// TODO(b/295393167): need to support strided load for bitwidth < 32.
if (layout_out.bitwidth() != 32 ||
layout_out.tiling() != std::array<int64_t, 2>{1, ctx.target_shape[1]}) {
return op.emitOpError("Not implemented");
if (memref_tiling != layout_out.tiling()) {
if (memref_tiling[0] == 1 && layout_out.tiling()[0] == 1 &&
memref_tiling[1] % layout_out.tiling()[1] == 0) {
// In this case, it is valid to use output tiling (1, 128 * packing) when
// loading from a 1D memref.
} else if (layout_out.bitwidth() == 32 &&
layout_out.tiling() ==
std::array<int64_t, 2>{1, ctx.target_shape[1]}) {
// In this case, it is valid to use output tiling (1, TARGET_SHAPE.lanes)
// because we strided-load one row from each tile of the memref. This can
// save us a bunch of loads!
// TODO(b/295393167): need to support strided load for bitwidth < 32.
} else if (layout_out.bitwidth() == 32 &&
canReinterpretToUntiledMemref(memref_ty, ctx.target_shape)) {
// In this case, the memref can be reinterpreted as untiled, so any output
// tiling is valid. But using native tiling can save us a bunch of loads!
} else {
return op.emitOpError(
    "Not implemented: mismatch between memref tiling and vector tiling in "
    "load");
}
}
// TODO(apaszke): Check that loads are from vmem!
@@ -4204,14 +4217,27 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
FAILUREOR_ASSIGN_OR_RETURN(
const Tiling memref_tiling,
getMemRefTiling(store_op.getBase(), ctx.target_shape));
if (memref_tiling != to_store_layout.tiling() &&
!(memref_tiling[0] == 1 && to_store_layout.tiling()[0] == 1 &&
memref_tiling[1] % to_store_layout.tiling()[1] == 0)) {
// Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes).
// TODO(b/295393167): need to support strided store for bitwidth < 32.
if (to_store_layout.bitwidth() != 32 ||
to_store_layout.tiling() != Tiling{1, ctx.target_shape[1]}) {
return op.emitOpError("Not implemented");
if (memref_tiling != to_store_layout.tiling()) {
if (memref_tiling[0] == 1 && to_store_layout.tiling()[0] == 1 &&
memref_tiling[1] % to_store_layout.tiling()[1] == 0) {
// In this case, it is valid to have to_store tiling (1, 128 * packing)
// when storing to a 1D memref.
} else if (to_store_layout.bitwidth() == 32 &&
to_store_layout.tiling() ==
std::array<int64_t, 2>{1, ctx.target_shape[1]}) {
// In this case, it is valid to have to_store tiling (1,
// TARGET_SHAPE.lanes) because we strided-store one row to each tile of
// the memref. This can save us a bunch of stores!
// TODO(b/295393167): need to support strided store for bitwidth < 32.
} else if (to_store_layout.bitwidth() == 32 &&
canReinterpretToUntiledMemref(memref_ty, ctx.target_shape)) {
// In this case, the memref can be reinterpreted as untiled, so any tiling
// is valid for to_store. But using native tiling can save us a bunch of
// stores!
} else {
return op.emitOpError(
    "Not implemented: mismatch between memref tiling and vector tiling in "
    "store");
}
}

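To see the newly accepted combinations at a glance, here is a small standalone sketch of the same decision on plain values. The enum names are invented for illustration and the target shape is passed in explicitly; this is a reading aid, not part of the PR.

#include <array>
#include <cstdint>

using Tiling = std::array<int64_t, 2>;

// Mirrors the branches in vector_load_rule / vector_store_rule above.
enum class TilingMismatch {
  kNone,         // memref and vector tiling already agree
  kOneDRef,      // (1, 128 * packing) access on a (1, N)-tiled ref
  kStridedRows,  // 32-bit (1, lanes) access: one row per memref tile
  kUntiledRef,   // 32-bit access, ref is contiguous enough to drop its tiling
  kUnsupported,  // falls through to the "Not implemented" error above
};

TilingMismatch Classify(Tiling memref_tiling, Tiling vector_tiling,
                        int64_t bitwidth, bool ref_reinterpretable_as_untiled,
                        const std::array<int64_t, 2>& target_shape) {
  if (memref_tiling == vector_tiling) return TilingMismatch::kNone;
  if (memref_tiling[0] == 1 && vector_tiling[0] == 1 &&
      memref_tiling[1] % vector_tiling[1] == 0) {
    return TilingMismatch::kOneDRef;
  }
  if (bitwidth == 32 && vector_tiling == Tiling{1, target_shape[1]}) {
    return TilingMismatch::kStridedRows;
  }
  if (bitwidth == 32 && ref_reinterpretable_as_untiled) {
    return TilingMismatch::kUntiledRef;
  }
  return TilingMismatch::kUnsupported;
}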
112 changes: 91 additions & 21 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
@@ -87,7 +87,18 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
int64_t leading_tile_rows = 0) {
if (auto tiled_layout_attr =
dyn_cast<TiledLayoutAttr>(memref_ty.getLayout())) {
return tiled_layout_attr;
// If the expected leading_tile_rows does not match the sublane tiling in
// the existing layout, we override the layout.
if (!tiled_layout_attr.getTiles().empty() &&
tiled_layout_attr.getTiles().front().dimensions().size() == 2 &&
tiled_layout_attr.getTiles().front().dimensions()[0] ==
leading_tile_rows) {
return tiled_layout_attr;
} else {
memref_ty =
MemRefType::get(memref_ty.getShape(), memref_ty.getElementType(),
/*layout=*/nullptr, memref_ty.getMemorySpace());
}
}
if (auto affine_map_attr = dyn_cast<AffineMapAttr>(memref_ty.getLayout())) {
if (memref_ty.getRank() == 0) {
@@ -226,19 +237,38 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
if (auto alloca_op = dyn_cast<memref::AllocaOp>(op)) {
TypedValue<MemRefType> arg = alloca_op.getResult();
const MemRefType memref_ty = alloca_op.getResult().getType();
FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation,
target_shape, tpu_tiling_flags));
// If the memref can be reinterpreted as untiled, force the tiling to
// {packing, target.lane_count}.
int64_t leading_tile_rows = 0;
if (canReinterpretToUntiledMemref(memref_ty, target_shape)) {
leading_tile_rows = 32 / memref_ty.getElementTypeBitWidth();
}
FAILUREOR_ASSIGN_OR_RETURN(
const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation, target_shape,
tpu_tiling_flags, leading_tile_rows));
alloca_op.getResult().setType(new_memref_ty);
if (memref_ty != new_memref_ty) {
OpBuilder builder(alloca_op->getContext());
builder.setInsertionPointAfter(alloca_op);
auto erase_op = builder.create<tpu::EraseLayoutOp>(
arg.getLoc(),
MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(),
/*layout=*/nullptr, new_memref_ty.getMemorySpace()),
arg);
arg.replaceAllUsesExcept(erase_op.getResult(), erase_op);
// Users are free to create their own EraseLayoutOp, so we do not need to
// create one if it already exists.
bool has_erase_layout_op = false;
for (auto user : arg.getUsers()) {
if (auto erase_op = dyn_cast<tpu::EraseLayoutOp>(user)) {
has_erase_layout_op = true;
break;
}
}
if (!has_erase_layout_op) {
auto erase_op = builder.create<tpu::EraseLayoutOp>(
arg.getLoc(),
MemRefType::get(new_memref_ty.getShape(),
memref_ty.getElementType(),
/*layout=*/nullptr, new_memref_ty.getMemorySpace()),
arg);
arg.replaceAllUsesExcept(erase_op.getResult(), erase_op);
}
}
} else if (auto alloca_op = dyn_cast<tpu::AllocaSemaphoreOp>(op)) {
TypedValue<MemRefType> arg = alloca_op.getResult();
@@ -296,22 +326,62 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
}

FAILUREOR_ASSIGN_OR_RETURN(
const MemRefType new_memref_ty,
MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation, target_shape,
tpu_tiling_flags, leading_tile_rows));
arg.setType(new_memref_ty);
new_arg_types.push_back(arg.getType());
if (memref_ty != new_memref_ty) {
// Some standard MLIR ops have static checks that seems unreasonable,
// and we know they hold in the way they are used in Mosaic. Still,
// verification with layouts likes to fail, because it can't statically
// prove the properties.
auto erase_op = builder.create<tpu::EraseLayoutOp>(
arg.getLoc(),
MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(),
/*layout=*/nullptr, new_memref_ty.getMemorySpace()),
arg);
arg.replaceAllUsesExcept(erase_op.getResult(), erase_op);
Value val = arg;
Operation * arg_use_op = nullptr;
// If the arg memref can be reinterpreted as untiled, we insert a
// ReinterpretCastOp that uses tiling {packing, target.lane_count} before
// the EraseLayoutOp, for the arg memrefs only, and expect the rest of the
// memref layout inference to be based on the cast layout automatically.
// This helps lift many alignment restrictions when consuming this memref.
if (canReinterpretToUntiledMemref(new_memref_ty, target_shape)) {
auto tiled_layout =
cast<tpu::TiledLayoutAttr>(new_memref_ty.getLayout());
SmallVector<xla::Tile> tiles(tiled_layout.getTiles());
tiles[0] = ::xla::Tile(
{32 / memref_ty.getElementTypeBitWidth(), target_shape[1]});
auto new_tile_strides =
ComputeTileStrides(new_memref_ty, tiles[0].dimensions());
new_memref_ty = MemRefType::get(
new_memref_ty.getShape(), new_memref_ty.getElementType(),
TiledLayoutAttr::get(new_memref_ty.getContext(), tiles,
new_tile_strides),
new_memref_ty.getMemorySpace());
arg_use_op = builder.create<tpu::ReinterpretCastOp>(val.getLoc(),
new_memref_ty, val);
val = arg_use_op->getResult(0);
}
// Users are free to create their own EraseLayoutOp, so we do not need to
// create one if it already exists.
bool has_erase_layout_op = false;
for (auto user : arg.getUsers()) {
if (auto erase_op = dyn_cast<tpu::EraseLayoutOp>(user)) {
has_erase_layout_op = true;
break;
}
}
if (!has_erase_layout_op) {
// Some standard MLIR ops have static checks that seem unreasonable, and
// we know they hold in the way they are used in Mosaic. Still,
// verification with layouts tends to fail, because it can't statically
// prove the properties.
auto erase_op = builder.create<tpu::EraseLayoutOp>(
val.getLoc(),
MemRefType::get(new_memref_ty.getShape(),
memref_ty.getElementType(),
/*layout=*/nullptr, new_memref_ty.getMemorySpace()),
val);
if (!arg_use_op) {
arg_use_op = erase_op;
}
arg.replaceAllUsesExcept(erase_op.getResult(), arg_use_op);
}
}
}
f.setFunctionType(
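Roughly, a function argument that canReinterpretToUntiledMemref accepts is now rewritten into the op chain sketched below. The bf16 shape and the layout spellings are assumptions used only to illustrate the op order, not exact IR syntax.

// Before: %arg carries the inferred tiled layout, e.g.
//   %arg : memref<64x128xbf16> with tiles [(16, 128), (2, 1)]
// After this pass:
//   %cast  = tpu.reinterpret_cast %arg   // re-tiled to {packing, lane_count} = (2, 128)
//   %erase = tpu.erase_layout %cast      // drops the layout for downstream ops
// Every other use of %arg is redirected to %erase, so the rest of the memref
// layout inference is based on the (2, 128) tiling of the cast rather than the
// original (16, 128) tiling.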
29 changes: 23 additions & 6 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -50,6 +50,7 @@ limitations under the License.
#include "mlir/include/mlir/Pass/Pass.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "xla/layout.h"

namespace mlir::tpu {
@@ -1187,7 +1188,7 @@ class VectorLayoutInferer {
}

LogicalResult infer(vector::LoadOp op) {
auto src_ty = op.getMemRefType();
auto src_ty = getMemRefType(op.getBase());
auto res_ty = op.getVectorType();
TPU_CHECK_OP(src_ty.getRank() == res_ty.getRank(),
"memref and vector rank mismatch");
@@ -1280,6 +1281,15 @@ class VectorLayoutInferer {
setLayout(op, in_layout,
VectorLayout(bitwidth, {std::nullopt, offsets[1]},
layout_tiling, ImplicitDim::kNone));
} else if (bitwidth == 32 &&
canReinterpretToUntiledMemref(src_ty, target_shape_) &&
*(src_ty.getShape().end() - 2) > 1) {
// Since the memref is untiled, we can load from an arbitrary address,
// which means the sublane offset is always 0.
// Note: if src_shape[-2] == 1, we can just use the tiling from the ref.
setLayout(op, in_layout,
VectorLayout(bitwidth, {0, offsets[1].value_or(0)},
nativeTiling(bitwidth), ImplicitDim::kNone));
} else {
setLayout(
op, in_layout,
@@ -1515,7 +1525,7 @@ class VectorLayoutInferer {
}

LogicalResult infer(vector::StoreOp op) {
auto ref_ty = op.getMemRefType();
auto ref_ty = getMemRefType(op.getBase());
auto store_ty = op.getValueToStore().getType();
TPU_CHECK_OP(ref_ty.getRank() == store_ty.getRank(),
"memref and vector rank mismatch");
@@ -1596,11 +1606,18 @@
// We can strided-store sublanes if we're storing a single sublane
// multiple times. Enabling this helps store one entire row to the memref
// more efficiently.
store_layout = VectorLayout(store_ty.getElementTypeBitWidth(), offsets,
{1, tiling[1]}, ImplicitDim::kNone);
store_layout =
VectorLayout(bitwidth, offsets, {1, tiling[1]}, ImplicitDim::kNone);
} else if (bitwidth == 32 && offsets[0].value_or(0) == 0 &&
offsets[1].value_or(0) == 0 &&
canReinterpretToUntiledMemref(ref_ty, target_shape_)) {
// Since the memref is untiled, we can store to an arbitrary address,
// which means the sublane offset is 0.
store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
ImplicitDim::kNone);
} else {
store_layout = VectorLayout(store_ty.getElementTypeBitWidth(), offsets,
{tiling[0], tiling[1]}, ImplicitDim::kNone);
store_layout = VectorLayout(bitwidth, offsets, {tiling[0], tiling[1]},
ImplicitDim::kNone);
}
}
SmallVector<Layout, 5> in_layout{store_layout};
@@ -70,6 +70,10 @@ LogicalResult specializeMemorySpace(TypedValue<MemRefType> value,
to_update.pop_back();
// Here we only have to handle the operations allowed on refs with
// unspecified memory space.
if (auto op = dyn_cast<tpu::ReinterpretCastOp>(some_op)) {
updateResultFrom(op, op.getInput().getType());
continue;
}
if (auto op = dyn_cast<tpu::MemRefSliceOp>(some_op)) {
updateResultFrom(op, op.getMemRef().getType());
continue;
24 changes: 24 additions & 0 deletions jaxlib/mosaic/dialect/tpu/util.cc
@@ -15,12 +15,14 @@ limitations under the License.

#include "jaxlib/mosaic/dialect/tpu/util.h"

#include <array>
#include <cstdint>

#include "llvm/Support/MathExtras.h"
#include "absl/types/span.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"

namespace mlir::tpu {
SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
@@ -39,4 +41,26 @@ SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
}
return tile_strides;
}

bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty,
const std::array<int64_t, 2>& target_shape) {
auto tiled_layout =
dyn_cast<tpu::TiledLayoutAttr>(tiled_memref_ty.getLayout());
if (!tiled_layout) {
// We expect the tiled memref to have a tiled layout.
return false;
}
if (tiled_layout.getTiles().empty() ||
tiled_layout.getTiles().front().dimensions().size() != 2 ||
tiled_memref_ty.getRank() < 2) {
// TODO(jevinjiang): Currently we only support >= 2D memrefs; we might
// need to handle 1D memrefs if we find a use case.
return false;
}
auto packing = 32 / tiled_memref_ty.getElementTypeBitWidth();
return (*(tiled_memref_ty.getShape().end() - 1) <= target_shape[1] &&
*(tiled_memref_ty.getShape().end() - 2) % packing == 0 &&
*(tiled_layout.getTileStrides().end() - 1) == 1 &&
*(tiled_layout.getTileStrides().end() - 2) == 1);
}
} // namespace mlir::tpu
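As a quick sanity check of these conditions, a standalone sketch on plain integers with a few worked examples. The (8, 128) target shape and the sample shapes are assumptions chosen for illustration; the real implementation above operates on MLIR types.

#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Same predicate, stripped of MLIR types.
bool CanReinterpretUntiled(std::vector<int64_t> shape,
                           std::vector<int64_t> tile_strides, int64_t bitwidth,
                           std::array<int64_t, 2> target_shape) {
  if (shape.size() < 2) return false;
  const int64_t packing = 32 / bitwidth;
  return shape.end()[-1] <= target_shape[1] &&  // minor dim fits in one lane tile
         shape.end()[-2] % packing == 0 &&      // 2nd minor aligned to packing
         tile_strides.end()[-1] == 1 &&         // the two minor tile strides
         tile_strides.end()[-2] == 1;           // are contiguous
}

int main() {
  // 16x128 i32 ref with contiguous tile strides: reinterpretable.
  assert(CanReinterpretUntiled({16, 128}, {1, 1}, 32, {8, 128}));
  // 8x256 i32 ref: the minor dimension spans two lane tiles, so no.
  assert(!CanReinterpretUntiled({8, 256}, {2, 1}, 32, {8, 128}));
  // 3x128 bf16 ref: packing = 2 and 3 % 2 != 0, so no.
  assert(!CanReinterpretUntiled({3, 128}, {1, 1}, 16, {8, 128}));
  return 0;
}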
6 changes: 6 additions & 0 deletions jaxlib/mosaic/dialect/tpu/util.h
@@ -98,6 +98,12 @@ std::string shapeToString(const T &shape) {

SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
absl::Span<const int64_t> tiling);

// Returns true if a >= 2D memref has a tiled layout and can be reinterpreted
// as an untiled memref, meaning the data is contiguous (with padding allowed
// only in the minormost dimension).
bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty,
const std::array<int64_t, 2> &target_shape);
} // namespace mlir::tpu

#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_
33 changes: 33 additions & 0 deletions tests/pallas/tpu_pallas_test.py
@@ -1472,6 +1472,39 @@ def kernel(index, x, y, sem):
del y


def test_dynamic_dma_on_2nd_minor(self):
def kernel(array, data, index, size, _, sem):
pltpu.async_copy(
data.at[pl.ds(0, size[0])], array.at[pl.ds(index[0], size[0])], sem
).wait()

def run(array, data, index, size):
return pl.pallas_call(
kernel,
out_shape=array,
in_specs=[
pl.BlockSpec(memory_space=pltpu.ANY),
pl.BlockSpec(memory_space=pltpu.VMEM),
pl.BlockSpec(memory_space=pltpu.SMEM),
pl.BlockSpec(memory_space=pltpu.SMEM),
],
scratch_shapes=[
pltpu.SemaphoreType.DMA,
],
out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
input_output_aliases={0: 0},
)(array, data, index, size)

array = jnp.zeros((1024, 128), jnp.int32)
data = jnp.ones((8, 128), jnp.int32)
index = jnp.array([0], jnp.int32)
size = jnp.array([3], jnp.int32)

result = run(array, data, index, size)
assert jnp.all(result[:3] == data[:3])
assert jnp.all(result[3:] == array[3:])


class PallasCallDMAInterpretTest(PallasCallDMATest):
INTERPRET = True
