@@ -70,6 +70,12 @@ namespace xla {
70
70
namespace gpu {
71
71
namespace {
72
72
73
+ constexpr unsigned kLHSOperandIndex = 0 ;
74
+ constexpr unsigned kRHSOperandIndex = 1 ;
75
+
76
+ constexpr unsigned kGEMMOutputBufferIndex = 0 ;
77
+ constexpr unsigned kGEMMWorkspaceBufferIndex = 1 ;
78
+
73
79
absl::StatusOr<std::unique_ptr<Thunk>> BuildCustomKernelThunkForFusion (
74
80
IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion,
75
81
CustomKernel custom_kernel) {
@@ -144,12 +150,14 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
144
150
TF_ASSIGN_OR_RETURN (
145
151
BufferAllocation::Slice lhs_slice,
146
152
GetSliceWithUpdatedOffsetAndSize (buffer_assignment, adaptor, fusion,
147
- *custom_call.operand (0 ), /* index=*/ {}));
153
+ *custom_call.operand (kLHSOperandIndex ),
154
+ /* index=*/ {}));
148
155
149
156
TF_ASSIGN_OR_RETURN (
150
157
BufferAllocation::Slice rhs_slice,
151
158
GetSliceWithUpdatedOffsetAndSize (buffer_assignment, adaptor, fusion,
152
- *custom_call.operand (1 ), /* index=*/ {}));
159
+ *custom_call.operand (kRHSOperandIndex ),
160
+ /* index=*/ {}));
153
161
154
162
BufferAllocation::Slice output;
155
163
std::optional<BufferAllocation::Slice> workspace;
@@ -161,10 +169,11 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
161
169
TF_ASSIGN_OR_RETURN (output,
162
170
GetAllocationSlice (buffer_assignment, &fusion, {}));
163
171
} else {
164
- TF_ASSIGN_OR_RETURN (output,
165
- GetAllocationSlice (buffer_assignment, &fusion, { 0 }));
172
+ TF_ASSIGN_OR_RETURN (output, GetAllocationSlice (buffer_assignment, &fusion,
173
+ { kGEMMOutputBufferIndex }));
166
174
TF_ASSIGN_OR_RETURN (workspace,
167
- GetAllocationSlice (buffer_assignment, &fusion, {1 }));
175
+ GetAllocationSlice (buffer_assignment, &fusion,
176
+ {kGEMMWorkspaceBufferIndex }));
168
177
}
169
178
170
179
bool deterministic_ops =
@@ -249,15 +258,15 @@ absl::StatusOr<FusionEmissionResult> EmitDynamicSlicedGemm(
249
258
slice_instr->index_operands ().front ()->shape ().element_type ()));
250
259
};
251
260
252
- TF_ASSIGN_OR_RETURN (
253
- BufferAllocation::Slice lhs_slice,
254
- get_original_operand_slice ( custom_call.operand (0 ), /* index=*/ {}));
261
+ TF_ASSIGN_OR_RETURN (BufferAllocation::Slice lhs_slice,
262
+ get_original_operand_slice (
263
+ custom_call.operand (kLHSOperandIndex ), /* index=*/ {}));
255
264
collect_slice_info ();
256
265
257
266
slice_instr = nullptr ;
258
- TF_ASSIGN_OR_RETURN (
259
- BufferAllocation::Slice rhs_slice,
260
- get_original_operand_slice ( custom_call.operand (1 ), /* index=*/ {}));
267
+ TF_ASSIGN_OR_RETURN (BufferAllocation::Slice rhs_slice,
268
+ get_original_operand_slice (
269
+ custom_call.operand (kRHSOperandIndex ), /* index=*/ {}));
261
270
collect_slice_info ();
262
271
263
272
slice_instr = nullptr ;
@@ -309,13 +318,15 @@ absl::StatusOr<FusionEmissionResult> EmitDynamicSlicedGemm(
309
318
slice_instr = nullptr ;
310
319
collect_slice_info ();
311
320
} else {
312
- TF_ASSIGN_OR_RETURN (output,
313
- get_original_result_slice (&custom_call, /* index=*/ {0 }));
321
+ TF_ASSIGN_OR_RETURN (
322
+ output, get_original_result_slice (&custom_call,
323
+ /* index=*/ {kGEMMOutputBufferIndex }));
314
324
collect_slice_info ();
315
325
// TODO(vuson): If we want to support slices of workspace, we'd need to
316
326
// start `HloFindIf` with `get-tuple-element` with the right index.
317
- TF_ASSIGN_OR_RETURN (workspace, GetAllocationSlice (buffer_assignment,
318
- &fusion, /* index=*/ {1 }));
327
+ TF_ASSIGN_OR_RETURN (
328
+ workspace, GetAllocationSlice (buffer_assignment, &fusion,
329
+ /* index=*/ {kGEMMWorkspaceBufferIndex }));
319
330
slice_instr = nullptr ;
320
331
collect_slice_info ();
321
332
fake_allocations[3 ] = std::make_unique<BufferAllocation>(
@@ -340,18 +351,18 @@ absl::StatusOr<FusionEmissionResult> EmitDynamicSlicedGemm(
340
351
GemmConfig::For (static_cast <const HloInstruction*>(&custom_call)));
341
352
342
353
int64_t lhs_byte_size =
343
- ShapeUtil::ByteSizeOf (custom_call.operand (0 )->shape ());
344
- fake_allocations[0 ] = std::make_unique<BufferAllocation>(
345
- /* index=*/ 0 , lhs_byte_size, /* color=*/ 0 );
346
- BufferAllocation::Slice slice_lhs_fake (fake_allocations[ 0 ]. get (), 0 ,
347
- lhs_byte_size);
354
+ ShapeUtil::ByteSizeOf (custom_call.operand (kLHSOperandIndex )->shape ());
355
+ fake_allocations[kLHSOperandIndex ] = std::make_unique<BufferAllocation>(
356
+ /* index=*/ kLHSOperandIndex , lhs_byte_size, /* color=*/ 0 );
357
+ BufferAllocation::Slice slice_lhs_fake (
358
+ fake_allocations[ kLHSOperandIndex ]. get (), 0 , lhs_byte_size);
348
359
349
360
int64_t rhs_byte_size =
350
- ShapeUtil::ByteSizeOf (custom_call.operand (1 )->shape ());
351
- fake_allocations[1 ] = std::make_unique<BufferAllocation>(
352
- /* index=*/ 1 , rhs_byte_size, /* color=*/ 0 );
353
- BufferAllocation::Slice slice_rhs_fake (fake_allocations[ 1 ]. get (), 0 ,
354
- rhs_byte_size);
361
+ ShapeUtil::ByteSizeOf (custom_call.operand (kRHSOperandIndex )->shape ());
362
+ fake_allocations[kRHSOperandIndex ] = std::make_unique<BufferAllocation>(
363
+ /* index=*/ kRHSOperandIndex , rhs_byte_size, /* color=*/ 0 );
364
+ BufferAllocation::Slice slice_rhs_fake (
365
+ fake_allocations[ kRHSOperandIndex ]. get (), 0 , rhs_byte_size);
355
366
356
367
fake_allocations[2 ] = std::make_unique<BufferAllocation>(
357
368
/* index=*/ 2 , out_fake_byte_size, /* color=*/ 0 );
0 commit comments