Skip to content

Commit 6ca6a53

Browse files
committed
* More tests for allocators to cover various conflicts.
1 parent 24610c8 commit 6ca6a53

File tree

5 files changed

+429
-6
lines changed

5 files changed

+429
-6
lines changed

src/runtime/memory/memory_manager.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ Allocator* MemoryManager::GetAllocator(Device dev, AllocatorType type) {
169169

170170
NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice dev,
171171
Optional<String> mem_scope) {
172+
ICHECK(dev.device_type == device_.device_type)
173+
<< "Device mismatch, expected type " << device_.device_type << " got type" << dev.device_type;
174+
ICHECK(dev.device_id == device_.device_id)
175+
<< "Device mismatch, expected id " << device_.device_id << " got id" << dev.device_id;
172176
VerifyDataType(dtype);
173177
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, dev);
174178
container->SetDeleter(BufferDeleter);

src/runtime/memory/naive_allocator.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ class NaiveAllocator final : public Allocator {
5252
}
5353

5454
void Free(const Buffer& buffer) override {
55+
ICHECK(buffer.device.device_type == device_.device_type)
56+
<< "Device mismatch, expected type " << device_.device_type << " got type"
57+
<< buffer.device.device_type;
58+
ICHECK(buffer.device.device_id == device_.device_id)
59+
<< "Device mismatch, expected id " << device_.device_id << " got id"
60+
<< buffer.device.device_id;
5561
ICHECK(buffer.alloc_type == type())
5662
<< "Allocator type mismatch, expected " << type() << " got " << buffer.alloc_type;
5763
DeviceAPI::Get(device_)->FreeDataSpace(buffer.device, buffer.data);

src/runtime/memory/pooled_allocator.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ class PooledAllocator final : public Allocator {
7777
}
7878

7979
void Free(const Buffer& buffer) override {
80+
ICHECK(buffer.device.device_type == device_.device_type)
81+
<< "Device mismatch, expected type " << device_.device_type << " got type"
82+
<< buffer.device.device_type;
83+
ICHECK(buffer.device.device_id == device_.device_id)
84+
<< "Device mismatch, expected id " << device_.device_id << " got id"
85+
<< buffer.device.device_id;
8086
ICHECK(buffer.alloc_type == type())
8187
<< "Allocator type mismatch, expected " << type() << " got " << buffer.alloc_type;
8288
std::lock_guard<std::recursive_mutex> lock(mu_);

tests/cpp/relay/backend/graph_plan_token_alloc.cc

Lines changed: 188 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class TokenAllocatorMixedWrapper : public TokenAllocatorMixed {
3333
inline size_t AllocListSize() const { return data_.size(); }
3434
};
3535

36-
TEST(TokenMixedAlloc, OneToken) {
36+
TEST(TokenMixedAlloc, TextureOneToken) {
3737
TokenAllocatorMixedWrapper alloc;
3838
int storage_ids = 0;
3939
EXPECT_EQ(alloc.AllocListSize(), 0);
@@ -63,7 +63,7 @@ TEST(TokenMixedAlloc, OneToken) {
6363
EXPECT_EQ(alloc.FreeListSize(), 1);
6464
}
6565

66-
TEST(TokenMixedAlloc, EqualSizeTokenReuse) {
66+
TEST(TokenMixedAlloc, TextureEqualSizeTokenReuse) {
6767
TokenAllocatorMixedWrapper alloc;
6868
int storage_ids = 0;
6969
EXPECT_EQ(alloc.AllocListSize(), 0);
@@ -137,7 +137,7 @@ TEST(TokenMixedAlloc, EqualSizeTokenReuse) {
137137
EXPECT_EQ(alloc.FreeListSize(), 1);
138138
}
139139

140-
TEST(TokenMixedAlloc, EqualSizeDiffTypes) {
140+
TEST(TokenMixedAlloc, TextureEqualSizeDiffTypes) {
141141
TokenAllocatorMixedWrapper alloc;
142142
int storage_ids = 0;
143143
EXPECT_EQ(alloc.AllocListSize(), 0);
@@ -186,7 +186,7 @@ TEST(TokenMixedAlloc, EqualSizeDiffTypes) {
186186
EXPECT_EQ(alloc.FreeListSize(), 1);
187187
}
188188

189-
TEST(TokenMixedAlloc, DifferentSizesTokenReuse) {
189+
TEST(TokenMixedAlloc, TextureDifferentSizesTokenReuse) {
190190
TokenAllocatorMixedWrapper alloc;
191191
int storage_ids = 0;
192192
EXPECT_EQ(alloc.AllocListSize(), 0);
@@ -255,7 +255,7 @@ TEST(TokenMixedAlloc, DifferentSizesTokenReuse) {
255255
EXPECT_EQ(sizeReq, 576000);
256256
}
257257

258-
TEST(TokenMixedAlloc, DifferentSizesTokenReuse2) {
258+
TEST(TokenMixedAlloc, TextureDifferentSizesTokenReuse2) {
259259
TokenAllocatorMixedWrapper alloc;
260260
int storage_ids = 0;
261261
EXPECT_EQ(alloc.AllocListSize(), 0);
@@ -302,7 +302,7 @@ TEST(TokenMixedAlloc, DifferentSizesTokenReuse2) {
302302
EXPECT_EQ(sizeReq, 140800);
303303
}
304304

305-
TEST(TokenMixedAlloc, SameSizesButDiffMemoryScopes) {
305+
TEST(TokenMixedAlloc, TextureSameSizesButDiffMemoryScopes) {
306306
TokenAllocatorMixedWrapper alloc;
307307
int storage_ids = 0;
308308
EXPECT_EQ(alloc.AllocListSize(), 0);
@@ -354,5 +354,187 @@ TEST(TokenMixedAlloc, SameSizesButDiffMemoryScopes) {
354354
EXPECT_EQ(alloc.AllocListSize(), 1);
355355
EXPECT_EQ(alloc.FreeListSize(), 1);
356356
}
357+
358+
TEST(TokenMixedAlloc, OneToken) {
359+
TokenAllocatorMixedWrapper alloc;
360+
int storage_ids = 0;
361+
EXPECT_EQ(alloc.AllocListSize(), 0);
362+
EXPECT_EQ(alloc.FreeListSize(), 0);
363+
364+
TensorType tt1({1, 22, 20, 20, 4}, DataType(kDLFloat, 32, 1));
365+
VirtualDevice vd1(kDLOpenCL, 0, Target("opencl"));
366+
StorageToken tok1 = {
367+
1, // ref_counter
368+
0, // max bytes
369+
tt1, // tensor type
370+
vd1, // virtual device
371+
-1 // storage_id
372+
};
373+
EXPECT_EQ(alloc.Request(&tok1), nullptr);
374+
375+
alloc.Alloc(&tok1, storage_ids++);
376+
EXPECT_EQ(alloc.AllocListSize(), 1);
377+
EXPECT_EQ(alloc.FreeListSize(), 0);
378+
379+
tok1.ref_counter -= 1;
380+
alloc.CheckForRelease(&tok1);
381+
EXPECT_EQ(alloc.AllocListSize(), 1);
382+
EXPECT_EQ(alloc.FreeListSize(), 1);
383+
}
384+
385+
TEST(TokenMixedAlloc, EqualSizeTokenReuse) {
386+
TokenAllocatorMixedWrapper alloc;
387+
int storage_ids = 0;
388+
EXPECT_EQ(alloc.AllocListSize(), 0);
389+
EXPECT_EQ(alloc.FreeListSize(), 0);
390+
391+
TensorType tt1({1, 22, 20, 20, 4}, DataType(kDLFloat, 32, 1));
392+
VirtualDevice vd1(kDLOpenCL, 0, Target("opencl"));
393+
StorageToken tok1 = {
394+
1, // ref_counter
395+
0, // max bytes
396+
tt1, // tensor type
397+
vd1, // virtual device
398+
-1 // storage_id
399+
};
400+
EXPECT_EQ(alloc.Request(&tok1), nullptr);
401+
402+
alloc.Alloc(&tok1, storage_ids++);
403+
EXPECT_EQ(alloc.AllocListSize(), 1);
404+
EXPECT_EQ(alloc.FreeListSize(), 0);
405+
406+
tok1.ref_counter -= 1;
407+
alloc.CheckForRelease(&tok1);
408+
EXPECT_EQ(alloc.AllocListSize(), 1);
409+
EXPECT_EQ(alloc.FreeListSize(), 1);
410+
411+
StorageToken tok2 = {
412+
1, // ref_counter
413+
0, // max bytes
414+
tt1, // tensor type
415+
vd1, // virtual device
416+
-1 // storage_id
417+
};
418+
auto req = alloc.Request(&tok2);
419+
EXPECT_NE(req, nullptr);
420+
EXPECT_EQ(alloc.AllocListSize(), 1);
421+
EXPECT_EQ(alloc.FreeListSize(), 0);
422+
EXPECT_EQ(req->storage_id, storage_ids - 1);
423+
EXPECT_EQ(req->ref_counter, 1);
424+
425+
req->ref_counter -= 1;
426+
alloc.CheckForRelease(req);
427+
EXPECT_EQ(alloc.AllocListSize(), 1);
428+
EXPECT_EQ(alloc.FreeListSize(), 1);
429+
}
430+
431+
TEST(TokenMixedAlloc, EqualSizeDiffTypes) {
432+
TokenAllocatorMixedWrapper alloc;
433+
int storage_ids = 0;
434+
EXPECT_EQ(alloc.AllocListSize(), 0);
435+
EXPECT_EQ(alloc.FreeListSize(), 0);
436+
437+
TensorType tt1({1, 22, 20, 20, 4}, DataType(kDLFloat, 32, 1));
438+
VirtualDevice vd1(kDLOpenCL, 0, Target("opencl"));
439+
StorageToken tok1 = {
440+
1, // ref_counter
441+
0, // max bytes
442+
tt1, // tensor type
443+
vd1, // virtual device
444+
-1 // storage_id
445+
};
446+
EXPECT_EQ(alloc.Request(&tok1), nullptr);
447+
448+
alloc.Alloc(&tok1, storage_ids++);
449+
EXPECT_EQ(alloc.AllocListSize(), 1);
450+
EXPECT_EQ(alloc.FreeListSize(), 0);
451+
452+
tok1.ref_counter -= 1;
453+
alloc.CheckForRelease(&tok1);
454+
EXPECT_EQ(alloc.AllocListSize(), 1);
455+
EXPECT_EQ(alloc.FreeListSize(), 1);
456+
457+
TensorType tt2({1, 22, 20, 20, 4}, DataType(kDLFloat, 16, 1));
458+
StorageToken tok2 = {
459+
1, // ref_counter
460+
0, // max bytes
461+
tt2, // tensor type
462+
vd1, // virtual device
463+
-1 // storage_id
464+
};
465+
466+
auto req1 = alloc.Request(&tok2);
467+
EXPECT_NE(req1, nullptr);
468+
EXPECT_EQ(alloc.AllocListSize(), 1);
469+
EXPECT_EQ(alloc.FreeListSize(), 0);
470+
471+
req1->ref_counter -= 1;
472+
alloc.CheckForRelease(req1);
473+
EXPECT_EQ(alloc.AllocListSize(), 1);
474+
EXPECT_EQ(alloc.FreeListSize(), 1);
475+
}
476+
477+
TEST(TokenMixedAlloc, DifferentSizesTokenReuse) {
478+
TokenAllocatorMixedWrapper alloc;
479+
int storage_ids = 0;
480+
EXPECT_EQ(alloc.AllocListSize(), 0);
481+
EXPECT_EQ(alloc.FreeListSize(), 0);
482+
483+
TensorType tt1({1, 22, 20, 20, 4}, DataType(kDLFloat, 32, 1));
484+
VirtualDevice vd1(kDLOpenCL, 0, Target("opencl"));
485+
StorageToken tok1 = {
486+
1, // ref_counter
487+
0, // max bytes
488+
tt1, // tensor type
489+
vd1, // virtual device
490+
-1 // storage_id
491+
};
492+
EXPECT_EQ(alloc.Request(&tok1), nullptr);
493+
494+
alloc.Alloc(&tok1, storage_ids++);
495+
EXPECT_EQ(alloc.AllocListSize(), 1);
496+
EXPECT_EQ(alloc.FreeListSize(), 0);
497+
498+
tok1.ref_counter -= 1;
499+
alloc.CheckForRelease(&tok1);
500+
EXPECT_EQ(alloc.AllocListSize(), 1);
501+
EXPECT_EQ(alloc.FreeListSize(), 1);
502+
503+
TensorType tt2({1, 40, 30, 30, 4}, DataType(kDLFloat, 32, 1));
504+
StorageToken tok2 = {
505+
1, // ref_counter
506+
0, // max bytes
507+
tt2, // tensor type
508+
vd1, // virtual device
509+
-1 // storage_id
510+
};
511+
auto req = alloc.Request(&tok2);
512+
EXPECT_NE(req, nullptr);
513+
EXPECT_EQ(alloc.AllocListSize(), 1);
514+
EXPECT_EQ(alloc.FreeListSize(), 0);
515+
EXPECT_EQ(req->storage_id, storage_ids - 1);
516+
EXPECT_EQ(req->ref_counter, 1);
517+
518+
req->ref_counter -= 1;
519+
alloc.CheckForRelease(req);
520+
EXPECT_EQ(alloc.AllocListSize(), 1);
521+
EXPECT_EQ(alloc.FreeListSize(), 1);
522+
523+
TensorType tt3({1, 25, 30, 30, 4}, DataType(kDLFloat, 32, 1));
524+
StorageToken tok3 = {
525+
1, // ref_counter
526+
0, // max bytes
527+
tt3, // tensor type
528+
vd1, // virtual device
529+
-1 // storage_id
530+
};
531+
auto req2 = alloc.Request(&tok3);
532+
EXPECT_NE(req2, nullptr);
533+
EXPECT_EQ(alloc.AllocListSize(), 1);
534+
EXPECT_EQ(alloc.FreeListSize(), 0);
535+
EXPECT_EQ(req2->storage_id, storage_ids - 1);
536+
EXPECT_EQ(req2->ref_counter, 1);
537+
}
538+
357539
} // namespace relay
358540
} // namespace tvm

0 commit comments

Comments
 (0)