Skip to content

Commit 15a1df4

Browse files
klucketensorflower-gardener
authored andcommitted
Remove test use of StreamExecutor::GetAllocator.
PiperOrigin-RevId: 627053448
1 parent b66b503 commit 15a1df4

File tree

8 files changed

+104
-63
lines changed

8 files changed

+104
-63
lines changed

third_party/xla/xla/service/generic_transfer_manager_test.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ limitations under the License.
3030
#include "xla/shape.h"
3131
#include "xla/shape_tree.h"
3232
#include "xla/shape_util.h"
33+
#include "xla/stream_executor/device_memory_allocator.h"
3334
#include "xla/stream_executor/host/host_platform_id.h"
3435
#include "xla/stream_executor/platform_manager.h"
3536
#include "xla/stream_executor/stream_executor.h"
@@ -62,18 +63,21 @@ class GenericTransferManagerTest : public ::testing::Test {
6263
se::PlatformManager::PlatformWithId(se::host::kHostPlatformId));
6364
TF_ASSERT_OK_AND_ASSIGN(stream_executor_, platform->ExecutorForDevice(0));
6465
TF_ASSERT_OK_AND_ASSIGN(stream_, stream_executor_->CreateStream());
66+
allocator_ =
67+
std::make_unique<se::StreamExecutorMemoryAllocator>(stream_executor_);
6568
}
6669

6770
ScopedShapedBuffer AllocateBuffer(const Shape& shape) {
68-
auto buffer = transfer_manager_.AllocateScopedShapedBuffer(
69-
shape, stream_executor_->GetAllocator(),
70-
/*device_ordinal=*/0);
71+
auto buffer =
72+
transfer_manager_.AllocateScopedShapedBuffer(shape, allocator_.get(),
73+
/*device_ordinal=*/0);
7174
return std::move(buffer.value());
7275
}
7376

7477
PackingTransferManager transfer_manager_;
7578
se::StreamExecutor* stream_executor_;
7679
std::unique_ptr<se::Stream> stream_;
80+
std::unique_ptr<se::DeviceMemoryAllocator> allocator_;
7781
};
7882

7983
TEST_F(GenericTransferManagerTest, TransferLiteralToDevice) {

third_party/xla/xla/service/gpu/fusions/cudnn_test.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -335,13 +335,13 @@ ENTRY e {
335335
backend_config={"fusion_backend_config":{"kind":"__cudnn$fusion","cudnn_fusion_config":{"plan_id":"0"}}}
336336
})";
337337

338+
se::StreamExecutorMemoryAllocator allocator(
339+
backend().default_stream_executor());
338340
// Verify that a command buffer is applied.
339-
TF_ASSERT_OK_AND_ASSIGN(
340-
std::unique_ptr<Executable> executable,
341-
backend().compiler()->RunBackend(
342-
GetOptimizedModule(kHloText).value(),
343-
backend().default_stream_executor(),
344-
backend().default_stream_executor()->GetAllocator()));
341+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Executable> executable,
342+
backend().compiler()->RunBackend(
343+
GetOptimizedModule(kHloText).value(),
344+
backend().default_stream_executor(), &allocator));
345345
absl::StatusOr<bool> filecheck_result =
346346
RunFileCheck(executable->module().ToString(), R"(
347347
; CHECK: ENTRY

third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,9 @@ TEST(AddressComputationThunkTest, SlicedGemm) {
186186

187187
// Preparing parameters for thunk execution.
188188
ServiceExecutableRunOptions run_options;
189+
se::StreamExecutorMemoryAllocator allocator(executor);
189190
BufferAllocations allocations(
190-
{lhs, rhs, out, workspace, lhs_offset_0, lhs_offset_1}, 0,
191-
executor->GetAllocator());
191+
{lhs, rhs, out, workspace, lhs_offset_0, lhs_offset_1}, 0, &allocator);
192192

193193
Thunk::ExecuteParams params =
194194
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
@@ -358,9 +358,10 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) {
358358

359359
// Preparing parameters for thunk execution.
360360
ServiceExecutableRunOptions run_options;
361+
se::StreamExecutorMemoryAllocator allocator(executor);
361362
BufferAllocations allocations({lhs, rhs, out, workspace, lhs_offset_0,
362363
lhs_offset_1, rhs_offset_0, rhs_offset_1},
363-
0, executor->GetAllocator());
364+
0, &allocator);
364365

365366
Thunk::ExecuteParams params =
366367
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
@@ -527,9 +528,10 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) {
527528

528529
// Preparing parameters for thunk execution.
529530
ServiceExecutableRunOptions run_options;
531+
se::StreamExecutorMemoryAllocator allocator(executor);
530532
BufferAllocations allocations({lhs, rhs, out, workspace, lhs_offset_0,
531533
lhs_offset_1, rhs_offset_0, rhs_offset_1},
532-
0, executor->GetAllocator());
534+
0, &allocator);
533535

534536
Thunk::ExecuteParams params =
535537
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
@@ -673,9 +675,9 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) {
673675

674676
// Preparing parameters for thunk execution.
675677
ServiceExecutableRunOptions run_options;
678+
se::StreamExecutorMemoryAllocator allocator(executor);
676679
BufferAllocations allocations(
677-
{src, dst, offset_0, offset_1, offset_2, offset_3}, 0,
678-
executor->GetAllocator());
680+
{src, dst, offset_0, offset_1, offset_2, offset_3}, 0, &allocator);
679681

680682
Thunk::ExecuteParams params =
681683
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
@@ -861,10 +863,11 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) {
861863

862864
// Preparing parameters for thunk execution.
863865
ServiceExecutableRunOptions run_options;
866+
se::StreamExecutorMemoryAllocator allocator(executor);
864867
BufferAllocations allocations(
865868
{src, dst, src_offset_0, src_offset_1, src_offset_2, src_offset_3,
866869
dst_offset_0, dst_offset_1, dst_offset_2, dst_offset_3},
867-
0, executor->GetAllocator());
870+
0, &allocator);
868871

869872
Thunk::ExecuteParams params =
870873
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
@@ -1024,9 +1027,9 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryArgumentOrder) {
10241027

10251028
// Preparing parameters for thunk execution.
10261029
ServiceExecutableRunOptions run_options;
1030+
se::StreamExecutorMemoryAllocator allocator(executor);
10271031
BufferAllocations allocations(
1028-
{workspace, lhs, out, rhs, lhs_offset_0, lhs_offset_1}, 0,
1029-
executor->GetAllocator());
1032+
{workspace, lhs, out, rhs, lhs_offset_0, lhs_offset_1}, 0, &allocator);
10301033

10311034
Thunk::ExecuteParams params =
10321035
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
@@ -1174,10 +1177,11 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryNumberOfArguments) {
11741177

11751178
// Preparing parameters for thunk execution.
11761179
ServiceExecutableRunOptions run_options;
1180+
se::StreamExecutorMemoryAllocator allocator(executor);
11771181
BufferAllocations allocations(
11781182
{workspace, /*garbage, to be ignored*/ se::DeviceMemoryBase(), out, rhs,
11791183
lhs_offset_0, lhs_offset_1, /*garbage, to be ignored*/ rhs, lhs},
1180-
0, executor->GetAllocator());
1184+
0, &allocator);
11811185

11821186
Thunk::ExecuteParams params =
11831187
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
@@ -1323,9 +1327,10 @@ TEST(AddressComputationThunkTest, SlicedTupledOperandGemm) {
13231327

13241328
// Preparing parameters for thunk execution.
13251329
ServiceExecutableRunOptions run_options;
1330+
se::StreamExecutorMemoryAllocator allocator(executor);
13261331
BufferAllocations allocations(
13271332
{lhs_whole_buffer, rhs, out, workspace, lhs_offset_0, lhs_offset_1}, 0,
1328-
executor->GetAllocator());
1333+
&allocator);
13291334

13301335
Thunk::ExecuteParams params =
13311336
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
@@ -1506,10 +1511,11 @@ TEST(AddressComputationThunkTest, SlicedMemcpyOOB) {
15061511

15071512
// Preparing parameters for thunk execution.
15081513
ServiceExecutableRunOptions run_options;
1514+
se::StreamExecutorMemoryAllocator allocator(executor);
15091515
BufferAllocations allocations(
15101516
{src, dst, src_offset_0, src_offset_1, src_offset_2, src_offset_3,
15111517
dst_offset_0, dst_offset_1, dst_offset_2, dst_offset_3},
1512-
0, executor->GetAllocator());
1518+
0, &allocator);
15131519

15141520
Thunk::ExecuteParams params =
15151521
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
@@ -1675,8 +1681,9 @@ TEST(AddressComputationThunkTest, SlicedOperandsSameBufferGemm) {
16751681

16761682
// Preparing parameters for thunk execution.
16771683
ServiceExecutableRunOptions run_options;
1684+
se::StreamExecutorMemoryAllocator allocator(executor);
16781685
BufferAllocations allocations({buffer, workspace, lhs_offset_0, lhs_offset_1},
1679-
0, executor->GetAllocator());
1686+
0, &allocator);
16801687

16811688
Thunk::ExecuteParams params =
16821689
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),

third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,8 @@ TEST(CommandBufferCmdTest, MemcpyCmd) {
204204
commands.Emplace<MemcpyDeviceToDeviceCmd>(s0, slice_b, slice_a, byte_length);
205205

206206
ServiceExecutableRunOptions run_options;
207-
BufferAllocations allocations({a, b}, 0, executor->GetAllocator());
207+
se::StreamExecutorMemoryAllocator allocator(executor);
208+
BufferAllocations allocations({a, b}, 0, &allocator);
208209

209210
CommandBufferCmd::StateManager state;
210211

@@ -272,7 +273,8 @@ TEST(CommandBufferCmdTest, BarrierCmd) {
272273
commands.Emplace<MemcpyDeviceToDeviceCmd>(s1, slice_e, slice_d, byte_length);
273274

274275
ServiceExecutableRunOptions run_options;
275-
BufferAllocations allocations({a, b, c, d, e}, 0, executor->GetAllocator());
276+
se::StreamExecutorMemoryAllocator allocator(executor);
277+
BufferAllocations allocations({a, b, c, d, e}, 0, &allocator);
276278

277279
CommandBufferCmd::StateManager state;
278280

@@ -341,7 +343,8 @@ TEST(CommandBufferCmdTest, LaunchCmd) {
341343
TF_ASSERT_OK(commands.Initialize({executor, source}, state));
342344

343345
ServiceExecutableRunOptions run_options;
344-
BufferAllocations allocations({a, b}, 0, executor->GetAllocator());
346+
se::StreamExecutorMemoryAllocator allocator(executor);
347+
BufferAllocations allocations({a, b}, 0, &allocator);
345348

346349
Thunk::ExecuteParams params =
347350
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
@@ -400,7 +403,8 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) {
400403
se::DeviceMemoryBase mem0(reinterpret_cast<void*>(0x01234567));
401404
se::DeviceMemoryBase mem1(reinterpret_cast<void*>(0x12345670));
402405

403-
BufferAllocations allocations({mem0, mem1}, 0, executor->GetAllocator());
406+
se::StreamExecutorMemoryAllocator allocator(executor);
407+
BufferAllocations allocations({mem0, mem1}, 0, &allocator);
404408

405409
// No-op trace callback to count how many times it was called.
406410
int64_t num_calls = 0;
@@ -423,7 +427,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) {
423427

424428
// Check that when memory address changes we re-trace the command buffer.
425429
se::DeviceMemoryBase mem2(reinterpret_cast<void*>(0x23456701));
426-
allocations = BufferAllocations({mem0, mem2}, 0, executor->GetAllocator());
430+
allocations = BufferAllocations({mem0, mem2}, 0, &allocator);
427431

428432
TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer2,
429433
traced_cmd_buffer.GetOrTraceCommandBuffer(
@@ -433,7 +437,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) {
433437
EXPECT_EQ(num_calls, 2);
434438

435439
// Check that we keep first command buffer in cache.
436-
allocations = BufferAllocations({mem0, mem1}, 0, executor->GetAllocator());
440+
allocations = BufferAllocations({mem0, mem1}, 0, &allocator);
437441

438442
TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer3,
439443
traced_cmd_buffer.GetOrTraceCommandBuffer(
@@ -442,7 +446,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) {
442446
EXPECT_EQ(num_calls, 2);
443447

444448
// Check that we trace a new graph when buffer allocation pattern is new.
445-
allocations = BufferAllocations({mem0, mem0}, 0, executor->GetAllocator());
449+
allocations = BufferAllocations({mem0, mem0}, 0, &allocator);
446450

447451
TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer4,
448452
traced_cmd_buffer.GetOrTraceCommandBuffer(
@@ -452,7 +456,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) {
452456
EXPECT_EQ(num_calls, 3);
453457

454458
// Check that we still keep the previous graph in cache.
455-
allocations = BufferAllocations({mem0, mem1}, 0, executor->GetAllocator());
459+
allocations = BufferAllocations({mem0, mem1}, 0, &allocator);
456460

457461
TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer5,
458462
traced_cmd_buffer.GetOrTraceCommandBuffer(
@@ -479,12 +483,13 @@ static void BM_GetOrTraceCommandBuffer(benchmark::State& state) {
479483

480484
se::DeviceMemoryBase mem0(reinterpret_cast<void*>(0x01234567));
481485
se::DeviceMemoryBase mem1(reinterpret_cast<void*>(0x12345670));
486+
se::StreamExecutorMemoryAllocator allocator(executor);
482487

483488
std::array<BufferAllocations, 4> allocations = {
484-
BufferAllocations({mem0, mem1}, 0, executor->GetAllocator()),
485-
BufferAllocations({mem1, mem0}, 0, executor->GetAllocator()),
486-
BufferAllocations({mem0, mem0}, 0, executor->GetAllocator()),
487-
BufferAllocations({mem1, mem1}, 0, executor->GetAllocator()),
489+
BufferAllocations({mem0, mem1}, 0, &allocator),
490+
BufferAllocations({mem1, mem0}, 0, &allocator),
491+
BufferAllocations({mem0, mem0}, 0, &allocator),
492+
BufferAllocations({mem1, mem1}, 0, &allocator),
488493
};
489494

490495
int32_t index = 0;

0 commit comments

Comments
 (0)