Skip to content

Commit b1ee28f

Browse files
committed
* allow multiple allocators to coexist for the same device.
Using available allocator instead of requested is leading to an unpexpected crash
1 parent 82e298c commit b1ee28f

File tree

4 files changed

+29
-15
lines changed

4 files changed

+29
-15
lines changed

include/tvm/runtime/memory/memory_manager.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,20 @@ namespace tvm {
3939
namespace runtime {
4040
namespace memory {
4141

42+
enum AllocatorType {
43+
kNaive = 1,
44+
kPooled,
45+
};
46+
4247
struct Buffer {
4348
/*! \brief The pointer to the allocated block of memory. */
4449
void* data{nullptr};
4550
/*! \brief The size of the block. */
4651
size_t size{0};
4752
/*! \brief The context of the allocated buffers. */
4853
Device device;
49-
};
50-
51-
enum AllocatorType {
52-
kNaive = 1,
53-
kPooled,
54+
/*! \brief The allocator that created this buffer. */
55+
AllocatorType alloc_type;
5456
};
5557

5658
class Allocator {
@@ -113,16 +115,18 @@ class MemoryManager {
113115
/*!
114116
* \brief Get an allocator given the context.
115117
* \param dev The TVM device
118+
* \param type The allocator type
116119
* \return The memory allocator.
117120
*/
118-
static Allocator* GetAllocator(Device dev);
121+
static Allocator* GetAllocator(Device dev, AllocatorType type);
119122

120123
private:
121124
MemoryManager() {}
122125

123126
protected:
124127
std::mutex mu_;
125-
std::unordered_map<Device, std::unique_ptr<Allocator>> allocators_;
128+
std::unordered_map<Device, std::unordered_map<AllocatorType, std::unique_ptr<Allocator>>>
129+
allocators_;
126130
};
127131

128132
/*! \brief An object representing a storage allocation. */
@@ -138,7 +142,7 @@ class StorageObj : public Object {
138142
static void Deleter(Object* ptr);
139143

140144
~StorageObj() {
141-
auto alloc = MemoryManager::Global()->GetAllocator(buffer.device);
145+
auto alloc = MemoryManager::Global()->GetAllocator(buffer.device, buffer.alloc_type);
142146
alloc->Free(buffer);
143147
}
144148

src/runtime/memory/memory_manager.cc

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ static void BufferDeleter(Object* obj) {
3737
auto* ptr = static_cast<NDArray::Container*>(obj);
3838
ICHECK(ptr->manager_ctx != nullptr);
3939
Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
40-
MemoryManager::GetAllocator(buffer->device)->Free(*(buffer));
40+
MemoryManager::GetAllocator(buffer->device, buffer->alloc_type)->Free(*(buffer));
4141
delete buffer;
4242
delete ptr;
4343
}
@@ -122,6 +122,9 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
122122
MemoryManager* m = MemoryManager::Global();
123123
std::lock_guard<std::mutex> lock(m->mu_);
124124
if (m->allocators_.find(dev) == m->allocators_.end()) {
125+
m->allocators_.emplace(dev, std::unordered_map<AllocatorType, std::unique_ptr<Allocator>>());
126+
}
127+
if (m->allocators_.at(dev).find(type) == m->allocators_.at(dev).end()) {
125128
std::unique_ptr<Allocator> alloc;
126129
switch (type) {
127130
case kNaive: {
@@ -138,26 +141,29 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
138141
LOG(FATAL) << "Unknown allocator type: " << type;
139142
}
140143
auto ret = alloc.get();
141-
m->allocators_.emplace(dev, std::move(alloc));
144+
m->allocators_.at(dev).emplace(type, std::move(alloc));
142145
return ret;
143146
}
144-
auto alloc = m->allocators_.at(dev).get();
145-
if (alloc->type() != type) {
147+
auto alloc = m->allocators_.at(dev).at(type).get();
148+
/*if (alloc->type() != type) {
146149
LOG(WARNING) << "The type of existing allocator for " << dev
147150
<< " is different from the request type (" << alloc->type() << " vs " << type
148151
<< ")";
149-
}
152+
}*/
150153
return alloc;
151154
}
152155

153-
Allocator* MemoryManager::GetAllocator(Device dev) {
156+
Allocator* MemoryManager::GetAllocator(Device dev, AllocatorType type) {
154157
MemoryManager* m = MemoryManager::Global();
155158
std::lock_guard<std::mutex> lock(m->mu_);
156159
auto it = m->allocators_.find(dev);
157160
if (it == m->allocators_.end()) {
158161
LOG(FATAL) << "Allocator for " << dev << " has not been created yet.";
159162
}
160-
return it->second.get();
163+
if (it->second.find(type) == it->second.end()) {
164+
LOG(FATAL) << "Allocator for " << dev << " of type " << type << " has not been created yet.";
165+
}
166+
return it->second.at(type).get();
161167
}
162168

163169
NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice dev,

src/runtime/memory/naive_allocator.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class NaiveAllocator final : public Allocator {
4141
Buffer buf;
4242
buf.device = device_;
4343
buf.size = nbytes;
44+
buf.alloc_type = kNaive;
4445
buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, nbytes, alignment, type_hint);
4546
used_memory_.fetch_add(nbytes, std::memory_order_relaxed);
4647
DLOG(INFO) << "allocate " << nbytes << " B, used memory " << used_memory_ << " B";
@@ -59,6 +60,7 @@ class NaiveAllocator final : public Allocator {
5960
auto tmp_buf = Allocator::Alloc(device_, shape, type_hint, mem_scope);
6061
buf.size = tmp_buf.size;
6162
buf.data = tmp_buf.data;
63+
buf.alloc_type = kNaive;
6264
return buf;
6365
}
6466

@@ -67,6 +69,7 @@ class NaiveAllocator final : public Allocator {
6769
type_hint, String(mem_scope));
6870
used_memory_.fetch_add(nbytes, std::memory_order_relaxed);
6971
DLOG(INFO) << "allocate " << nbytes << " B, used memory " << used_memory_ << " B";
72+
buf.alloc_type = kNaive;
7073
return buf;
7174
}
7275

src/runtime/memory/pooled_allocator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class PooledAllocator final : public Allocator {
5858
Buffer buf;
5959
buf.device = device_;
6060
buf.size = size;
61+
buf.alloc_type = kPooled;
6162
try {
6263
buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint);
6364
} catch (InternalError& err) {

0 commit comments

Comments
 (0)