[Mosaic:TPU] Allow null parts for tpu.pack_subelements, meaning "don't care"

PiperOrigin-RevId: 703707282
tlongeri authored and Google-ML-Automation committed Dec 9, 2024
1 parent 6f69774 commit 70871b5
Showing 4 changed files with 174 additions and 63 deletions.
11 changes: 11 additions & 0 deletions jaxlib/mosaic/dialect/tpu/layout.h
@@ -252,6 +252,17 @@ class VectorLayout {

   int8_t bitwidth() const { return bitwidth_; }
   const LayoutOffsets &offsets() const { return offsets_; }
+  const LayoutOffsets getCanonicalOffsets(
+      const ArrayRef<int64_t> shape,
+      const std::array<int64_t, 2> target_shape) const {
+    // For (1, n) tiling with a single row, replication does not change
+    // anything about the layout - it is equivalent to an offset of 0. We
+    // choose a replicated offset as "canonical".
+    const std::array<int64_t, 2> tiled_ishape = getImplicitTiledDims(shape, 1);
+    return {
+        tiling_[0] == 1 && tiled_ishape[0] == 1 ? std::nullopt : offsets_[0],
+        offsets_[1]};
+  }
   const std::array<int64_t, 2> &tiling() const { return tiling_; }
   ImplicitDim implicit_dim() const { return implicit_dim_; }
   int packing() const { return 32 / bitwidth_; }
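
For intuition, a minimal standalone sketch of the canonicalization rule above, with the layout types modeled as plain std::optional values rather than the actual jaxlib headers (an illustration, not the real implementation): with (1, n) tiling and a single implicitly-tiled row, an offset of 0 and a replicated (nullopt) offset describe the same data placement, so the replicated form is picked as canonical.

#include <array>
#include <cstdint>
#include <optional>

using LayoutOffset = std::optional<int64_t>;  // nullopt means "replicated"
using LayoutOffsets = std::array<LayoutOffset, 2>;

// Mirrors the rule in getCanonicalOffsets for the second-minor dimension:
// replication is only forced when each tile spans a single row and the
// implicitly-tiled shape also has a single row.
LayoutOffsets canonicalize(LayoutOffsets offsets,
                           const std::array<int64_t, 2>& tiling,
                           int64_t tiled_rows) {
  if (tiling[0] == 1 && tiled_rows == 1) {
    offsets[0] = std::nullopt;  // Offset 0 and replication are equivalent here.
  }
  return offsets;
}

int main() {
  const LayoutOffsets c = canonicalize({{0, 5}}, {1, 128}, /*tiled_rows=*/1);
  return (!c[0].has_value() && c[1] == 5) ? 0 : 1;  // returns 0 on success
}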
12 changes: 10 additions & 2 deletions jaxlib/mosaic/dialect/tpu/tpu.td
@@ -370,13 +370,21 @@ def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
 }
 
 // Integer packs are always signed at the moment.
-def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure]> {
+def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]> {
   let arguments = (ins
-    Variadic<AnyVectorOfNonZeroRank>:$sources,
+    Variadic<TPU_Vreg>:$sources,
+    DenseI32ArrayAttr:$positions,
     TPU_PackFormatEnum:$pack_format
   );
   let results = (outs AnyVectorOfNonZeroRank:$output);
   let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }];
+  let builders = [
+    OpBuilder<(ins "::mlir::VectorType":$output_type, "::mlir::ArrayRef<::mlir::Value>":$padded_sources, "::mlir::tpu::PackFormat":$pack_format)>,
+  ];
+  let extraClassDeclaration = [{
+    static ::mlir::SmallVector<::mlir::Value> getPaddedSources(::mlir::ValueRange sources, ::mlir::ArrayRef<int32_t> positions, int packing_factor);
+  }];
+  let hasVerifier = 1;
 }
 
 def TPU_GatherOp : TPU_Op<"gather", [Pure]> {
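
A note on the redefined op: each source vreg now carries an explicit subelement position, and positions must be distinct values in [0, packing_factor), where the packing factor is the ratio of source to output element bitwidths (see the verifier added in tpu_ops.cc below). Positions with no source operand are "don't care". A small standalone sketch of that arithmetic, for illustration only:

#include <cassert>

// Packing factor = source element bits / output element bits; e.g. packing
// 32-bit values into an 8-bit output leaves 4 subelement positions.
int packing_factor(int source_bits, int output_bits) {
  assert(output_bits > 0 && source_bits % output_bits == 0);
  return source_bits / output_bits;
}

int main() {
  assert(packing_factor(32, 8) == 4);   // valid positions: 0, 1, 2, 3
  assert(packing_factor(32, 16) == 2);  // valid positions: 0, 1
}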
49 changes: 49 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <cstddef>
 #include <cstdint>
 #include <optional>
 #include <string_view>
@@ -1087,6 +1088,54 @@ LogicalResult LogOp::verify() {
       stringifyCoreType(logging_core_type_maybe->value())));
 }
 
+void PackSubelementsOp::build(OpBuilder &builder, OperationState &state,
+                              const VectorType output_type,
+                              const ArrayRef<Value> padded_sources,
+                              const PackFormat pack_format) {
+  SmallVector<Value> sources;
+  SmallVector<int32_t> positions;
+  for (size_t i = 0; i < padded_sources.size(); ++i) {
+    if (padded_sources[i] != nullptr) {
+      sources.push_back(padded_sources[i]);
+      positions.push_back(i);
+    }
+  }
+  build(builder, state, output_type, sources, positions, pack_format);
+}
+
+/*static*/ SmallVector<Value> PackSubelementsOp::getPaddedSources(
+    ValueRange sources, const ArrayRef<int32_t> positions,
+    const int packing_factor) {
+  SmallVector<Value> padded_sources(packing_factor);
+  for (const auto [source, position] : llvm::zip(sources, positions)) {
+    padded_sources[position] = source;
+  }
+  return padded_sources;
+}
+
+LogicalResult PackSubelementsOp::verify() {
+  if (getSources().empty()) {
+    return emitOpError("At least one source is required");
+  }
+  if (getPositions().size() != getSources().size()) {
+    return emitOpError("Size of sources and positions must match");
+  }
+  const int packing_factor = cast<VectorType>(getSources().front().getType())
+                                 .getElementTypeBitWidth() /
+                             getType().getElementTypeBitWidth();
+  SmallVector<bool> seen_positions(packing_factor);
+  for (const int32_t position : getPositions()) {
+    if (position < 0 || packing_factor <= position) {
+      return emitOpError("Positions must be between 0 and the packing factor");
+    }
+    if (seen_positions[position]) {
+      return emitOpError("Positions must be unique");
+    }
+    seen_positions[position] = true;
+  }
+  return success();
+}
+
 }  // namespace tpu
 }  // namespace mlir
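
The builder and getPaddedSources are inverses of each other. The following standalone sketch mirrors that round-trip with plain pointers standing in for mlir::Value (an illustration under simplified types, not the actual MLIR API); a null entry means "don't care", so no operand is recorded for that position.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Vreg {};  // stand-in for an MLIR vreg value

// Mirrors PackSubelementsOp::build: drop null parts, record their positions.
void encode(const std::vector<Vreg*>& padded, std::vector<Vreg*>& sources,
            std::vector<int32_t>& positions) {
  for (int32_t i = 0; i < static_cast<int32_t>(padded.size()); ++i) {
    if (padded[i] != nullptr) {
      sources.push_back(padded[i]);
      positions.push_back(i);
    }
  }
}

// Mirrors PackSubelementsOp::getPaddedSources: scatter back by position.
std::vector<Vreg*> decode(const std::vector<Vreg*>& sources,
                          const std::vector<int32_t>& positions,
                          int packing_factor) {
  std::vector<Vreg*> padded(packing_factor, nullptr);
  for (size_t i = 0; i < sources.size(); ++i) {
    padded[positions[i]] = sources[i];
  }
  return padded;
}

int main() {
  Vreg a, b;
  const std::vector<Vreg*> padded = {&a, nullptr, &b, nullptr};  // factor 4
  std::vector<Vreg*> sources;
  std::vector<int32_t> positions;
  encode(padded, sources, positions);
  assert(sources.size() == 2 && positions[0] == 0 && positions[1] == 2);
  assert(decode(sources, positions, 4) == padded);  // round-trip
}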
165 changes: 104 additions & 61 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -1034,11 +1034,16 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
       SmallVector<Value> parts;
       SmallVector<int64_t> idxs_local(toArrayRef(idxs));
       idxs_local.back() *= packing;
-      for (int64_t i = 0; i < packing; ++i) {
-        parts.push_back(input_vregs(idxs_local));
-        // Pack any data lying around if OOB
-        if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
-          ++idxs_local.back();
+      if (!layout_out.offsets()[1].has_value()) {
+        parts.append(packing, input_vregs(idxs_local));
+      } else {
+        for (int64_t i = 0; i < packing; ++i) {
+          if (idxs_local.back() < input_vregs.dimensions().back()) {
+            parts.push_back(input_vregs(idxs_local));
+            ++idxs_local.back();
+          } else {
+            parts.push_back(nullptr);
+          }
         }
       }
       *v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
@@ -1051,16 +1056,17 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
     output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
       CHECK_GE(idxs.size(), 2);
       SmallVector<int64_t> idxs_local(toArrayRef(idxs));
-      idxs_local[idxs.size() - 2] *= packing;
-      parts.push_back(input_vregs(idxs_local));
-      idxs_local[idxs.size() - 2]++;
-      while (parts.size() < packing) {
-        if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
-          parts.push_back(input_vregs(idxs_local));
-          idxs_local[idxs.size() - 2]++;
-        } else {
-          // Once we run out of tiles, we can pick any one we like.
-          parts.push_back(parts.back());
+      *(idxs_local.end() - 2) *= packing;
+      if (!layout_out.offsets()[0].has_value()) {
+        parts.append(packing, input_vregs(idxs_local));
+      } else {
+        for (int64_t i = 0; i < packing; ++i) {
+          if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
+            parts.push_back(input_vregs(idxs_local));
+            ++*(idxs_local.end() - 2);
+          } else {
+            parts.push_back(nullptr);
+          }
         }
       }
       *v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
@@ -6159,45 +6165,68 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
           ctx.target_shape[0] * (ctx.target_shape[0] + 1);
   const auto &target_shape = ctx.target_shape;
   const std::array<int64_t, 2> src_tiling = src.tiling();
+  const LayoutOffsets src_offsets =
+      src.getCanonicalOffsets(vty.getShape(), ctx.target_shape);
   if (src_tiling == dst_tiling) {
     return std::pair(src, std::move(vregs));
   }
   const int packing = src.packing();
   const int8_t bitwidth = src.bitwidth();
-  // Handle retiling from (1, 128) to (8, 128) for 32-bit data with replicating
-  // sublanes.
-  if (try_replicate_rows && packing == 1 &&
-      *(vregs.dimensions().end() - 2) == 1 &&
-      src.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
-      dst_tiling == ctx.target_shape) {
-    DCHECK_EQ(src.offsets()[0].value_or(0), 0);
-    const LayoutOffset dst_minor_offset =
-        src.offsets()[1] ? LayoutOffset(*src.offsets()[1] % target_shape[1])
-                         : std::nullopt;
+  const std::array<int64_t, 2> dst_vreg_slice =
+      VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
+
+  // Fully replicated offsets are handled efficiently elsewhere (in relayout)
+  CHECK(src.offsets()[0].has_value() || src.offsets()[1].has_value());
+
+  // Handle small-to-large retiling with replicated 2nd minor.
+  // This retiling is one-to-many vregs.
+  // TODO(tlongeri): Large-to-small retiling with replicated minor is analogous
+  // to this.
+  if (src_tiling[1] == ctx.target_shape[1] &&
+      dst_tiling[1] == ctx.target_shape[1] &&
+      dst_tiling[0] % src_tiling[0] == 0 && !src_offsets[0].has_value() &&
+      // For (packing, 128) tiling, prefer scratch retiling for older
+      // generations, where gather/broadcast is more expensive:
+      (try_replicate_rows || src_tiling[0] > packing ||
+       ctx.hardware_generation >= 5)) {
+    DCHECK(src_offsets[1].has_value());  // Otherwise handled elsewhere
+    const int64_t dst_minor_offset = *src.offsets()[1] % dst_vreg_slice[1];
     const VectorLayout dst(bitwidth, {std::nullopt, dst_minor_offset},
                            dst_tiling, src.implicit_dim());
-    xla::Array<Value> retiled(
-        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
-    retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
-      SmallVector<int64_t> src_idx(idx.begin(), idx.end());
-      *(src_idx.end() - 2) *= target_shape[0];
-      if (!src.offsets()[1].has_value()) {
-        // With (1, 128) tiling each vreg holds values from a single row. This
-        // means that if the columns are replicated, then the whole vreg is
-        // already replicated.
-        *(src_idx.end() - 1) = 0;
-        *tile = vregs(src_idx);
-      } else {
-        // The column (in units of sublanes) of the sublane we want:
-        const int64_t sublane_column =
-            *(src_idx.end() - 1) + *src.offsets()[1] / target_shape[1];
-        *(src_idx.end() - 1) = sublane_column / target_shape[0];
-        const int64_t src_sl_idx = sublane_column % target_shape[0];
-        *tile =
-            broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
-      }
+    const SmallVector<int64_t> dst_vreg_array_shape =
+        dst.tileArrayImplicitShape(vty.getShape(), target_shape);
+    const int64_t dst_tiles_per_vreg = dst.tilesPerVreg(ctx.target_shape);
+    const int64_t src_sublanes_per_tile = src.sublanesPerTile(ctx.target_shape);
+    const int64_t dst_sublanes_per_tile = dst.sublanesPerTile(ctx.target_shape);
+    const int64_t tiling_ratio = dst_tiling[0] / src_tiling[0];
+    xla::Array<Value> retiled(dst_vreg_array_shape);
+    SmallVector<int64_t> idxs;
+    retiled.Each([&](absl::Span<const int64_t> dst_idx, Value *vreg) {
+      const int64_t dst_col_idx = *(dst_idx.end() - 1);
+      const int64_t dst_col_idx_with_offsets =
+          dst_col_idx + *src.offsets()[1] / dst_vreg_slice[1];
+      const int64_t src_col_idx = dst_col_idx_with_offsets / tiling_ratio;
+      const int64_t src_part_idx = dst_col_idx_with_offsets % tiling_ratio;
+      SmallVector<int32_t, 8> gather_pattern;
+      // Iterate over the sublanes in the dst vreg:
+      for (int32_t sublane = 0; sublane < ctx.target_shape[0]; ++sublane) {
+        const int64_t src_tile_idx =
            sublane / dst_sublanes_per_tile + src_part_idx * dst_tiles_per_vreg;
+        // Although replication may give us several sublanes to choose from,
+        // we always gather from the first sublane in the source tile. This
+        // degenerates to a broadcast when dst_tiling is native, which can
+        // be cheaper than an arbitrary gather (for some hardware gens).
+        const int64_t src_sublane = src_tile_idx * src_sublanes_per_tile;
+        gather_pattern.push_back(src_sublane);
+      }
+      idxs.assign(dst_idx.begin(), dst_idx.end());
+      *(idxs.end() - 2) = 0;
+      *(idxs.end() - 1) = src_col_idx;
+      Value src_vreg = vregs(idxs);
+      *vreg = builder.create<tpu::GatherOp>(loc, src_vreg.getType(), src_vreg,
+                                            gather_pattern,
+                                            /*dimension=*/0);
     });
     // We have successfully replicated sublanes
     return std::pair(dst, std::move(retiled));
   }
   VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
@@ -6211,6 +6240,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
   if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
       dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * dst.packing(),
                                            ctx.target_shape[1]}) {
+    DCHECK(src_offsets[0].has_value());  // Otherwise handled above
     // Note: for int4, retiling with scratch is always faster.
     if (bitwidth != 4 || !has_enough_scratch) {
       xla::Array<Value> retiled(dst_tiles_shape);
@@ -6227,15 +6257,17 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
         src_idx[src_idx.size() - 2] *= vty_packing;
         src_idx[src_idx.size() - 1] /= vty_packing;
         for (int i = 0; i < vty_packing; ++i) {
-          parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
-              loc, vreg_x32, vregs(src_idx), vreg_part));
-          if (src_idx[src_idx.size() - 2] <
-              vregs.dim(vregs.num_dimensions() - 2) - 1) {
-            ++src_idx[src_idx.size() - 2];
+          if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2)) {
+            parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
+                loc, vreg_x32, vregs(src_idx), vreg_part));
+            ++*(src_idx.end() - 2);
+          } else {
+            parts.push_back(nullptr);
           }
         }
         *tile = builder.create<tpu::PackSubelementsOp>(
-            loc, vregs.begin()->getType(), parts, tpu::PackFormat::kCompressed);
+            loc, cast<VectorType>(vregs.begin()->getType()), parts,
+            tpu::PackFormat::kCompressed);
       });
       return std::pair(dst, std::move(retiled));
     }
@@ -6306,16 +6338,24 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
         *(src_idx.end() - 2) *= packing;
         const int64_t vreg_part = *(src_idx.end() - 1) % packing;
         *(src_idx.end() - 1) /= packing;
-        for (int i = 0; i < packing; ++i) {
-          parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
-              loc, vreg_x32, vregs(src_idx), vreg_part));
-          if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2) - 1) {
-            ++*(src_idx.end() - 2);
-          }  // The rest is padding, so just pick any of the input parts (but not
-             // an arbitrary vreg so we don't add an extra dependency).
+        if (try_replicate_rows) {
+          DCHECK(!src_offsets[0].has_value());
+          parts.append(packing, builder.create<tpu::UnpackSubelementsOp>(
+                                    loc, vreg_x32, vregs(src_idx), vreg_part));
+        } else {
+          for (int i = 0; i < packing; ++i) {
+            if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2)) {
+              parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
+                  loc, vreg_x32, vregs(src_idx), vreg_part));
+              ++*(src_idx.end() - 2);
+            } else {
+              parts.push_back(nullptr);
+            }
+          }
         }
         *tile = builder.create<tpu::PackSubelementsOp>(
-            loc, vregs.begin()->getType(), parts, tpu::PackFormat::kInterleaved);
+            loc, cast<VectorType>(vregs.begin()->getType()), parts,
+            tpu::PackFormat::kInterleaved);
       });
       return std::pair(dst, std::move(retiled));
     }
@@ -6574,8 +6614,11 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
     return assemble_with_mask_check(src_tiles,
                                     /*use_implicit_shape=*/true);
   }
-  if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() &&
-      !src.offsets()[1].has_value()) {
+
+  if (const LayoutOffsets src_offsets =
+          src.getCanonicalOffsets(vty.getShape(), ctx.target_shape);
+      src.layout_rank() >= dst.layout_rank() && !src_offsets[0].has_value() &&
+      !src_offsets[1].has_value()) {
     // A fully replicated value is always easy to relayout
     xla::Array<Value> dst_tiles(
         dst.tileArrayImplicitShape(vty.getShape(), target_shape));
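
To see what the gather-based retiling above computes, here is a worked standalone sketch of the per-vreg sublane gather pattern. The constants assume one example configuration (32-bit data retiled from (1, 128) to (8, 128) tiling on an (8, 128) target, zero minor offset); the names mirror the variables in changeTiling, but this is illustrative arithmetic, not the actual jaxlib code.

#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::array<int64_t, 2> target_shape = {8, 128};
  const int64_t dst_tiles_per_vreg = 1;     // an (8, 128) tile fills the vreg
  const int64_t src_sublanes_per_tile = 1;  // a (1, 128) tile is one sublane
  const int64_t dst_sublanes_per_tile = 8;
  const int64_t tiling_ratio = 8;           // dst_tiling[0] / src_tiling[0]
  for (int64_t dst_col = 0; dst_col < 2; ++dst_col) {
    // With a zero minor offset, the source part index is dst_col % ratio.
    const int64_t src_part_idx = dst_col % tiling_ratio;
    std::vector<int32_t> gather_pattern;
    for (int32_t sublane = 0; sublane < target_shape[0]; ++sublane) {
      const int64_t src_tile_idx =
          sublane / dst_sublanes_per_tile + src_part_idx * dst_tiles_per_vreg;
      gather_pattern.push_back(src_tile_idx * src_sublanes_per_tile);
    }
    // dst_col 0 -> {0,0,0,0,0,0,0,0}, dst_col 1 -> {1,1,1,1,1,1,1,1}: each
    // dst vreg broadcasts a single source sublane, matching the comment that
    // the gather degenerates to a broadcast for native dst tiling.
    std::printf("dst_col %ld gathers from sublane %d\n",
                static_cast<long>(dst_col), gather_pattern[0]);
  }
  return 0;
}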
