Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jaxlib/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ cc_library(
":tpu_inc_gen",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
Expand Down
38 changes: 38 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,44 @@ def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]>
let hasVerifier = 1;
}

// tpu.pack_elementwise: takes a variadic list of same-typed f32/i32 vectors
// (`sources`) and produces a single i32 vector (`output`). The narrow element
// type the sources are packed into is given by the `target_type` attribute.
// The op is declared Pure, SameTypeOperands and ElementwiseMappable.
def TPU_PackElementwiseOp : TPU_Op<"pack_elementwise", [Pure, SameTypeOperands, ElementwiseMappable]> {
let description = [{
Packs multiple `sources` elementwise into a single vector of a narrower `target_type`.

The number of `sources` must equal the packing factor, which is the ratio of
the element bitwidth of the `sources` to the element bitwidth of the
`target_type`. Elements from the `sources` are interleaved and packed into
each word of the `output`, ordered from lowest to highest bits,
corresponding to their order in the `sources`.
}];
// `sources`: one or more vectors of rank >= 1 with f32 or i32 elements.
// `target_type`: the narrow element type as a TypeAttr.
let arguments = (ins
Variadic<VectorOfNonZeroRankOf<[F32, I32]>>:$sources,
TypeAttr:$target_type
);
// The packed words are always exposed as a vector of i32.
let results = (outs VectorOfNonZeroRankOf<[I32]>:$output);
// `target_type` is not named in the format below, so it is expected to be
// carried inside `attr-dict`.
let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }];
// Requests a hand-written C++ verify() hook for this op.
let hasVerifier = 1;
}

// tpu.unpack_elementwise: takes one i32 vector (`source`) and produces one
// f32/i32 vector (`output`). The `source_type` attribute names the narrow
// element type stored in `source`, and `index` selects which packed value of
// each word is extracted. The op is declared Pure and ElementwiseMappable.
def TPU_UnpackElementwiseOp : TPU_Op<"unpack_elementwise", [Pure, ElementwiseMappable]> {
let description = [{
Unpacks a single vector from `source`, which contains multiple `source_type`
vectors packed elementwise.

The `index` selects which packed value to extract from each word of `source`.
An `index` of 0 corresponds to the lowest bits. The extracted values are
cast to the output element type.
}];
// `source`: vector of rank >= 1 with i32 elements holding the packed words.
// `source_type`: the narrow packed element type as a TypeAttr.
// `index`: position of the packed value to extract (32-bit integer attribute).
let arguments = (ins
VectorOfNonZeroRankOf<[I32]>:$source,
TypeAttr:$source_type,
I32Attr:$index
);
let results = (outs VectorOfNonZeroRankOf<[F32, I32]>:$output);
// `index` is printed right after the operand; `source_type` is not named in
// the format, so it is expected to be carried inside `attr-dict`.
let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }];
// Requests a hand-written C++ verify() hook for this op.
let hasVerifier = 1;
}

def TPU_RelayoutOp : TPU_Op<"relayout", [Pure, SameOperandsAndResultType]> {
let arguments = (ins AnyVectorOfAnyRank:$input);
let results = (outs AnyVectorOfAnyRank:$output);
Expand Down
53 changes: 52 additions & 1 deletion jaxlib/mosaic/dialect/tpu/tpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1817,7 +1817,7 @@ LogicalResult ReciprocalOp::verify() {
LogicalResult UnpackSubelementsOp::verify() {
const int packing_factor = getType().getElementTypeBitWidth() /
getSource().getType().getElementTypeBitWidth();
if (auto index = getIndex(); index < 0 || index >= packing_factor) {
if (auto index = getIndex(); index >= packing_factor) {
return emitOpError("Index must be between 0 and the packing factor (")
<< packing_factor << "), got " << index;
}
Expand Down Expand Up @@ -1912,6 +1912,57 @@ LogicalResult PackSubelementsOp::verify() {
return success();
}

namespace {
// Checks that `packed_ty` is a supported narrow partner for `unpacked_ty`:
//   * f32 may only be paired with bf16;
//   * signless i32 may only be paired with signless i16, i8 or i4.
// Any other `unpacked_ty` is accepted without further checks. On a mismatch an
// error is emitted on `op` and failure is returned; otherwise success.
LogicalResult verifyElementwisePacking(Operation *op, Type unpacked_ty,
                                       Type packed_ty) {
  const bool unpacked_is_f32 = unpacked_ty.isF32();
  if (unpacked_is_f32 && !packed_ty.isBF16()) {
    return op->emitOpError(
        "Only packing/unpacking between f32 and bf16 is supported for floats");
  }
  // Only the i32 case has further restrictions.
  if (!unpacked_ty.isSignlessInteger(32)) {
    return success();
  }
  const bool packed_is_supported_int = packed_ty.isSignlessInteger(16) ||
                                       packed_ty.isSignlessInteger(8) ||
                                       packed_ty.isSignlessInteger(4);
  if (packed_is_supported_int) {
    return success();
  }
  return op->emitOpError(
      "Only packing/unpacking between i32 and i16/i8/i4 is supported for "
      "integers");
}
}  // namespace

// Verifies a pack_elementwise op:
//   1. there is at least one source;
//   2. the (source element type, target type) pair is accepted by
//      verifyElementwisePacking;
//   3. the number of sources equals the packing factor, i.e. the source
//      element bitwidth divided by the target type bitwidth.
// Only the first source's type is inspected here.
LogicalResult PackElementwiseOp::verify() {
  const auto sources = getSources();
  if (sources.empty()) {
    return emitOpError("At least one source is required");
  }
  const auto first_src_vty = cast<VectorType>(sources.front().getType());
  const Type target_elem_ty = getTargetType();
  if (failed(verifyElementwisePacking(*this, first_src_vty.getElementType(),
                                      target_elem_ty))) {
    return failure();
  }
  const int packing_factor = first_src_vty.getElementTypeBitWidth() /
                             target_elem_ty.getIntOrFloatBitWidth();
  const auto num_sources = sources.size();
  if (packing_factor == num_sources) {
    return success();
  }
  return emitOpError("The number of sources must match the packing factor (")
         << packing_factor << "), got " << num_sources;
}

// Verifies an unpack_elementwise op:
//   1. the (result element type, source_type) pair is accepted by
//      verifyElementwisePacking;
//   2. `index` is smaller than the packing factor, i.e. the result element
//      bitwidth divided by the bitwidth of `source_type`.
LogicalResult UnpackElementwiseOp::verify() {
  // getType() is the result *vector* type. verifyElementwisePacking queries
  // its `unpacked_ty` argument with isF32()/isSignlessInteger(32), which only
  // match scalar element types, so the element type must be passed (as the
  // pack verifier does). Passing the vector type made the check a no-op.
  const Type unpacked_elem_ty = getType().getElementType();
  const Type packed_elem_ty = getSourceType();
  if (failed(verifyElementwisePacking(*this, unpacked_elem_ty,
                                      packed_elem_ty))) {
    return failure();
  }
  // The pair check above restricts `source_type` to bf16/i16/i8/i4 whenever
  // the result element type is f32 or signless i32, so this bitwidth query is
  // valid for those results.
  const int packing_factor = getType().getElementTypeBitWidth() /
                             packed_elem_ty.getIntOrFloatBitWidth();
  if (auto index = getIndex(); index >= packing_factor) {
    return emitOpError("Index must be between 0 and the packing factor (")
           << packing_factor << "), got " << index;
  }
  return success();
}

LogicalResult DynamicGatherOp::verify() {
const int64_t rank = getSource().getType().getRank();
SmallVector<bool> seen(rank, false);
Expand Down
121 changes: 98 additions & 23 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/types/span.h"
#include "llvm/ADT/APInt.h"
Expand Down Expand Up @@ -692,50 +693,46 @@ FailureOr<xla::Array<Value>> insertImplicitMinorDimension(
return new_vregs;
}

LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
TPU_ASSERT_OP(OpTrait::hasElementwiseMappableTraits(&op));
if (op.getNumResults() != 1) {
return op.emitError("Not implemented: Only ops with one result supported");
}
// A generic rule for elementwise operations that applies a given function to
// each vreg of the operands.
LogicalResult elementwise_op_rule_impl(
RewriteContext &ctx,
Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out,
absl::AnyInvocable<Value(ImplicitLocOpBuilder &, ArrayRef<Value>)>
vreg_op_creator) {
TPU_ASSERT_EQ_OP(layouts_in.size(), op.getNumOperands());
TPU_ASSERT_GT_OP(layouts_in.size(), 0);
TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
OpBuilder builder(&op);
if (!(layouts_out.front().has_value() &&
llvm::all_of(layouts_in,
[&](const Layout &l) { return l.has_value(); }))) {
return op.emitOpError(
"Not implemented: Null layout / non-vector operand in elementwise "
"operation");
}
const auto vty = cast<VectorType>(op.getResult(0).getType());
const VectorLayout &layout = *layouts_out.front();
if (!llvm::all_of(layouts_in,
[&](const Layout &l) { return layout == *l; })) {
return op.emitOpError(
"Not implemented: Different layouts in elementwise operation");
}

ImplicitLocOpBuilder builder(op.getLoc(), &op);
const unsigned num_operands = op.getNumOperands();
SmallVector<xla::Array<Value>> in_vreg_arrays;
in_vreg_arrays.reserve(num_operands);
for (unsigned i = 0; i < num_operands; ++i) {
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> tile_array,
const xla::Array<Value> tile_array,
disassemble(builder, *layouts_in[i],
cast<TypedValue<VectorType>>(op.getOperand(i)),
ctx.target_shape));
in_vreg_arrays.emplace_back(std::move(tile_array));
}

const VectorType out_vreg_ty = getNativeVregOrVmaskType(
vty.getElementType(), layout.bitwidth(), ctx.target_shape);

NamedAttrList attributes(op.getAttrDictionary());
attributes.erase("in_layout");
attributes.erase("out_layout");

const auto vty = cast<VectorType>(op.getResult(0).getType());
// TODO(tlongeri): Can we avoid initializing the array before filling values?
xla::Array<Value> out_vreg_array(
layout.tileArrayShape(vty.getShape(), ctx.target_shape));
Expand All @@ -745,19 +742,47 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
for (unsigned i = 0; i < num_operands; ++i) {
operands[i] = in_vreg_arrays[i](idx);
}
Operation *vreg_op =
builder.create(op.getLoc(), op.getName().getIdentifier(), operands,
out_vreg_ty, attributes.getAttrs());
CHECK(vreg_op);
CHECK_EQ(vreg_op->getNumResults(), 1);
*out_vreg = vreg_op->getResult(0);
*out_vreg = vreg_op_creator(builder, operands);
CHECK(*out_vreg);
});
op.replaceAllUsesWith(assemble(builder, vty, layout,
std::move(out_vreg_array), ctx.target_shape));
op.erase();
return success();
}

// Layout-application rule for single-result ops that carry the
// elementwise-mappable traits. For every set of operand vregs handed over by
// elementwise_op_rule_impl, it creates a new op with the same name as `op`,
// with the native vreg (or vmask) result type and with all of `op`'s
// attributes except "in_layout" and "out_layout".
LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
                                  const ArrayRef<Layout> layouts_in,
                                  const ArrayRef<Layout> layouts_out) {
  TPU_ASSERT_OP(OpTrait::hasElementwiseMappableTraits(&op));
  if (op.getNumResults() != 1) {
    return op.emitError("Not implemented: Only ops with one result supported");
  }

  // Attributes forwarded to each per-vreg op; the layout annotations are
  // dropped.
  NamedAttrList vreg_op_attrs(op.getAttrDictionary());
  vreg_op_attrs.erase("in_layout");
  vreg_op_attrs.erase("out_layout");

  // Filled in on the first callback invocation, so that `layouts_out` is only
  // dereferenced from inside the callback and not before
  // elementwise_op_rule_impl has had a chance to inspect the layouts.
  std::optional<VectorType> cached_vreg_ty;
  auto create_vreg_op = [&, vreg_op_attrs](ImplicitLocOpBuilder &builder,
                                           ArrayRef<Value> operands) -> Value {
    if (!cached_vreg_ty) {
      const auto result_vty = cast<VectorType>(op.getResult(0).getType());
      const VectorLayout &out_layout = *layouts_out.front();
      cached_vreg_ty = getNativeVregOrVmaskType(
          result_vty.getElementType(), out_layout.bitwidth(),
          ctx.target_shape);
    }
    Operation *created =
        builder.create(op.getLoc(), op.getName().getIdentifier(), operands,
                       *cached_vreg_ty, vreg_op_attrs.getAttrs());
    CHECK(created);
    CHECK_EQ(created->getNumResults(), 1);
    return created->getResult(0);
  };
  return elementwise_op_rule_impl(ctx, op, layouts_in, layouts_out,
                                  std::move(create_vreg_op));
}

FailureOr<std::pair<VectorLayout, xla::Array<Value>>> retileWithCombineHalves(
RewriteContext& ctx, OpBuilder& builder, Location loc,
const ArrayRef<int64_t> shape, const xla::Array<Value>& old_vregs,
Expand Down Expand Up @@ -3514,6 +3539,52 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
return success();
}

// Layout-application rule for tpu.pack_elementwise. Each group of operand
// vregs is packed with an interleaved tpu.pack_subelements into a native vreg
// of the op's target type, and the result is then bitcast to a native i32
// vreg.
LogicalResult tpu_pack_elementwise_rule(RewriteContext &ctx, Operation &op,
                                        const ArrayRef<Layout> layouts_in,
                                        const ArrayRef<Layout> layouts_out) {
  const mlir::Type target_elem_ty =
      cast<tpu::PackElementwiseOp>(op).getTargetType();
  const VectorType narrow_vreg_ty =
      getNativeVregType(target_elem_ty, ctx.target_shape);

  auto pack_vregs = [&ctx, narrow_vreg_ty](
                        ImplicitLocOpBuilder &builder,
                        ArrayRef<Value> vreg_operands) -> Value {
    Value packed = tpu::PackSubelementsOp::create(
        builder, narrow_vreg_ty, vreg_operands, PackFormat::kInterleaved);
    const VectorType i32_vreg_ty =
        getNativeVregType(builder.getI32Type(), ctx.target_shape);
    return tpu::BitcastVregOp::create(builder, i32_vreg_ty, packed);
  };
  return elementwise_op_rule_impl(ctx, op, layouts_in, layouts_out,
                                  std::move(pack_vregs));
}

// Layout-application rule for tpu.unpack_elementwise. Each operand vreg is
// bitcast to a native vreg of the op's `source_type`, and then an interleaved
// tpu.unpack_subelements extracts the sub-element selected by the op's
// `index` into a native vreg of the result element type.
LogicalResult tpu_unpack_elementwise_rule(RewriteContext &ctx, Operation &op,
                                          const ArrayRef<Layout> layouts_in,
                                          const ArrayRef<Layout> layouts_out) {
  auto unpack_op = cast<tpu::UnpackElementwiseOp>(op);
  const int64_t subelement_index = unpack_op.getIndex();
  const mlir::Type narrow_elem_ty = unpack_op.getSourceType();
  const mlir::Type wide_elem_ty =
      cast<VectorType>(op.getResult(0).getType()).getElementType();
  const VectorType narrow_vreg_ty =
      getNativeVregType(narrow_elem_ty, ctx.target_shape);
  const VectorType wide_vreg_ty =
      getNativeVregType(wide_elem_ty, ctx.target_shape);

  auto unpack_vreg = [narrow_vreg_ty, wide_vreg_ty, subelement_index](
                         ImplicitLocOpBuilder &builder,
                         ArrayRef<Value> vreg_operands) -> Value {
    // The op has a single operand, so only the first vreg is used.
    Value reinterpreted =
        tpu::BitcastVregOp::create(builder, narrow_vreg_ty, vreg_operands[0]);
    return tpu::UnpackSubelementsOp::create(builder, wide_vreg_ty,
                                            reinterpreted, subelement_index,
                                            PackFormat::kInterleaved);
  };
  return elementwise_op_rule_impl(ctx, op, layouts_in, layouts_out,
                                  std::move(unpack_vreg));
}

LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
Expand Down Expand Up @@ -9164,6 +9235,10 @@ const llvm::StringMap<rule_type> &rules() {
{tpu::RotateOp::getOperationName(), tpu_rotate_rule},
{tpu::DynamicRotateOp::getOperationName(), tpu_dynamic_rotate_rule},
{tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule},
{tpu::PackElementwiseOp::getOperationName(),
tpu_pack_elementwise_rule},
{tpu::UnpackElementwiseOp::getOperationName(),
tpu_unpack_elementwise_rule},
{tpu::IotaOp::getOperationName(), tpu_iota_rule},
{tpu::GatherOp::getOperationName(), tpu_gather_rule},
{tpu::DynamicGatherOp::getOperationName(), tpu_dynamic_gather_rule},
Expand Down
Loading