Skip to content

Commit

Permalink
[Mosaic:TPU][NFC] In ext and trunc rules, avoid vreg array reshape by…
Browse files Browse the repository at this point in the history
… always using implicit shapes

PiperOrigin-RevId: 704098211
  • Loading branch information
tlongeri authored and Google-ML-Automation committed Dec 9, 2024
1 parent 79318a0 commit bd1efcf
Showing 1 changed file with 19 additions and 32 deletions.
51 changes: 19 additions & 32 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -850,20 +850,13 @@ FailureOr<xla::Array<Value>> ext_op_rule_impl(RewriteContext &ctx,
const VectorLayout &layout_out) {
const auto result_ty = cast<VectorType>(op.getResult().getType());
auto source = cast<TypedValue<VectorType>>(op.getIn());
const auto source_ty = source.getType();
auto output_vregs_shape =
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape);
layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape);
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> input_vregs,
disassemble(builder, layout_in, source, ctx.target_shape));
disassemble(builder, layout_in, source, ctx.target_shape,
/*use_implicit_shape=*/true));
xla::Array<Value> output_vregs(output_vregs_shape);
// TODO(jevinjiang): maybe just use tileArrayImplicitShape in disassemble?
if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(),
ctx.target_shape));
output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(),
ctx.target_shape));
}
const VectorType res_vreg_ty =
getNativeVregType(result_ty.getElementType(), ctx.target_shape);
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
Expand Down Expand Up @@ -900,9 +893,6 @@ FailureOr<xla::Array<Value>> ext_op_rule_impl(RewriteContext &ctx,
op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
});
}
if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
output_vregs.Reshape(output_vregs_shape);
}
return output_vregs;
}

Expand All @@ -925,8 +915,9 @@ LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op,
*layouts_out.front()));
const auto result_ty = cast<VectorType>(extf_op.getResult().getType());
extf_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(),
std::move(output_vregs), ctx.target_shape)
.getResult());
std::move(output_vregs), ctx.target_shape,
/*use_implicit_shape=*/true)
.getResult());
extf_op.erase();
return success();
}
Expand All @@ -946,8 +937,10 @@ LogicalResult arith_extsi_rule(RewriteContext &ctx, Operation &op,
*layouts_out.front()));
const auto result_ty = cast<VectorType>(extsi_op.getResult().getType());
extsi_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(),
std::move(output_vregs), ctx.target_shape)
.getResult());
std::move(output_vregs),
ctx.target_shape,
/*use_implicit_shape=*/true)
.getResult());
extsi_op.erase();
return success();
}
Expand Down Expand Up @@ -998,8 +991,10 @@ LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op,
*v = builder.create<BitcastVregOp>(op.getLoc(), res_vreg_ty, unpacked);
});
extui_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(),
std::move(output_vregs), ctx.target_shape)
.getResult());
std::move(output_vregs),
ctx.target_shape,
/*use_implicit_shape=*/true)
.getResult());
extui_op.erase();
return success();
}
Expand All @@ -1010,13 +1005,13 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_out) {
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
auto source = cast<TypedValue<VectorType>>(op.getIn());
const auto source_ty = source.getType();
auto result_ty = cast<VectorType>(op.getResult().getType());
auto output_vregs_shape =
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape);
layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape);
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> input_vregs,
disassemble(builder, layout_in, source, ctx.target_shape));
disassemble(builder, layout_in, source, ctx.target_shape,
/*use_implicit_shape=*/true));
xla::Array<Value> output_vregs(output_vregs_shape);
if (layout_in.bitwidth() != 32) {
return op.emitOpError("Not implemented: Only 32-bit truncation supported");
Expand All @@ -1031,12 +1026,6 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
if (layout_in.tiling() != ctx.target_shape) {
return op.emitOpError("Not implemented: Only (8,128) tiling supported");
}
if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(),
ctx.target_shape));
output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(),
ctx.target_shape));
}
VectorType res_vreg_ty =
getNativeVregType(result_ty.getElementType(), ctx.target_shape);
if (layout_out.tiling() == ctx.target_shape) {
Expand Down Expand Up @@ -1081,11 +1070,9 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
} else {
return op.emitOpError("Not implemented: unsupported output tiling");
}
if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
output_vregs.Reshape(output_vregs_shape);
}
op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
std::move(output_vregs), ctx.target_shape)
std::move(output_vregs), ctx.target_shape,
/*use_implicit_shape=*/true)
.getResult());
op.erase();
return success();
Expand Down

0 comments on commit bd1efcf

Please sign in to comment.