Skip to content

Commit 06f2799

Browse files
tyb0807tensorflower-gardener
authored andcommitted
[xla:gpu][NFC] Use meaningful constexpr
PiperOrigin-RevId: 620194665
1 parent 3016662 commit 06f2799

File tree

1 file changed

+36
-25
lines changed
  • third_party/xla/xla/service/gpu/fusions

1 file changed

+36
-25
lines changed

third_party/xla/xla/service/gpu/fusions/custom.cc

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ namespace xla {
7070
namespace gpu {
7171
namespace {
7272

73+
constexpr unsigned kLHSOperandIndex = 0;
74+
constexpr unsigned kRHSOperandIndex = 1;
75+
76+
constexpr unsigned kGEMMOutputBufferIndex = 0;
77+
constexpr unsigned kGEMMWorkspaceBufferIndex = 1;
78+
7379
absl::StatusOr<std::unique_ptr<Thunk>> BuildCustomKernelThunkForFusion(
7480
IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
7581
CustomKernel custom_kernel) {
@@ -144,12 +150,14 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
144150
TF_ASSIGN_OR_RETURN(
145151
BufferAllocation::Slice lhs_slice,
146152
GetSliceWithUpdatedOffsetAndSize(buffer_assignment, adaptor, fusion,
147-
*custom_call.operand(0), /*index=*/{}));
153+
*custom_call.operand(kLHSOperandIndex),
154+
/*index=*/{}));
148155

149156
TF_ASSIGN_OR_RETURN(
150157
BufferAllocation::Slice rhs_slice,
151158
GetSliceWithUpdatedOffsetAndSize(buffer_assignment, adaptor, fusion,
152-
*custom_call.operand(1), /*index=*/{}));
159+
*custom_call.operand(kRHSOperandIndex),
160+
/*index=*/{}));
153161

154162
BufferAllocation::Slice output;
155163
std::optional<BufferAllocation::Slice> workspace;
@@ -161,10 +169,11 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
161169
TF_ASSIGN_OR_RETURN(output,
162170
GetAllocationSlice(buffer_assignment, &fusion, {}));
163171
} else {
164-
TF_ASSIGN_OR_RETURN(output,
165-
GetAllocationSlice(buffer_assignment, &fusion, {0}));
172+
TF_ASSIGN_OR_RETURN(output, GetAllocationSlice(buffer_assignment, &fusion,
173+
{kGEMMOutputBufferIndex}));
166174
TF_ASSIGN_OR_RETURN(workspace,
167-
GetAllocationSlice(buffer_assignment, &fusion, {1}));
175+
GetAllocationSlice(buffer_assignment, &fusion,
176+
{kGEMMWorkspaceBufferIndex}));
168177
}
169178

170179
bool deterministic_ops =
@@ -249,15 +258,15 @@ absl::StatusOr<FusionEmissionResult> EmitDynamicSlicedGemm(
249258
slice_instr->index_operands().front()->shape().element_type()));
250259
};
251260

252-
TF_ASSIGN_OR_RETURN(
253-
BufferAllocation::Slice lhs_slice,
254-
get_original_operand_slice(custom_call.operand(0), /*index=*/{}));
261+
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_slice,
262+
get_original_operand_slice(
263+
custom_call.operand(kLHSOperandIndex), /*index=*/{}));
255264
collect_slice_info();
256265

257266
slice_instr = nullptr;
258-
TF_ASSIGN_OR_RETURN(
259-
BufferAllocation::Slice rhs_slice,
260-
get_original_operand_slice(custom_call.operand(1), /*index=*/{}));
267+
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_slice,
268+
get_original_operand_slice(
269+
custom_call.operand(kRHSOperandIndex), /*index=*/{}));
261270
collect_slice_info();
262271

263272
slice_instr = nullptr;
@@ -309,13 +318,15 @@ absl::StatusOr<FusionEmissionResult> EmitDynamicSlicedGemm(
309318
slice_instr = nullptr;
310319
collect_slice_info();
311320
} else {
312-
TF_ASSIGN_OR_RETURN(output,
313-
get_original_result_slice(&custom_call, /*index=*/{0}));
321+
TF_ASSIGN_OR_RETURN(
322+
output, get_original_result_slice(&custom_call,
323+
/*index=*/{kGEMMOutputBufferIndex}));
314324
collect_slice_info();
315325
// TODO(vuson): If we want to support slices of workspace, we'd need to
316326
// start `HloFindIf` with `get-tuple-element` with the right index.
317-
TF_ASSIGN_OR_RETURN(workspace, GetAllocationSlice(buffer_assignment,
318-
&fusion, /*index=*/{1}));
327+
TF_ASSIGN_OR_RETURN(
328+
workspace, GetAllocationSlice(buffer_assignment, &fusion,
329+
/*index=*/{kGEMMWorkspaceBufferIndex}));
319330
slice_instr = nullptr;
320331
collect_slice_info();
321332
fake_allocations[3] = std::make_unique<BufferAllocation>(
@@ -340,18 +351,18 @@ absl::StatusOr<FusionEmissionResult> EmitDynamicSlicedGemm(
340351
GemmConfig::For(static_cast<const HloInstruction*>(&custom_call)));
341352

342353
int64_t lhs_byte_size =
343-
ShapeUtil::ByteSizeOf(custom_call.operand(0)->shape());
344-
fake_allocations[0] = std::make_unique<BufferAllocation>(
345-
/*index=*/0, lhs_byte_size, /*color=*/0);
346-
BufferAllocation::Slice slice_lhs_fake(fake_allocations[0].get(), 0,
347-
lhs_byte_size);
354+
ShapeUtil::ByteSizeOf(custom_call.operand(kLHSOperandIndex)->shape());
355+
fake_allocations[kLHSOperandIndex] = std::make_unique<BufferAllocation>(
356+
/*index=*/kLHSOperandIndex, lhs_byte_size, /*color=*/0);
357+
BufferAllocation::Slice slice_lhs_fake(
358+
fake_allocations[kLHSOperandIndex].get(), 0, lhs_byte_size);
348359

349360
int64_t rhs_byte_size =
350-
ShapeUtil::ByteSizeOf(custom_call.operand(1)->shape());
351-
fake_allocations[1] = std::make_unique<BufferAllocation>(
352-
/*index=*/1, rhs_byte_size, /*color=*/0);
353-
BufferAllocation::Slice slice_rhs_fake(fake_allocations[1].get(), 0,
354-
rhs_byte_size);
361+
ShapeUtil::ByteSizeOf(custom_call.operand(kRHSOperandIndex)->shape());
362+
fake_allocations[kRHSOperandIndex] = std::make_unique<BufferAllocation>(
363+
/*index=*/kRHSOperandIndex, rhs_byte_size, /*color=*/0);
364+
BufferAllocation::Slice slice_rhs_fake(
365+
fake_allocations[kRHSOperandIndex].get(), 0, rhs_byte_size);
355366

356367
fake_allocations[2] = std::make_unique<BufferAllocation>(
357368
/*index=*/2, out_fake_byte_size, /*color=*/0);

0 commit comments

Comments
 (0)