Add Pluggable CUDA allocator backend (pytorch#86786)
Fixes pytorch#43144

This uses the Backend system added by [82682](pytorch#82682) to change allocators dynamically during code execution. This will allow us to use RMM, to use CUDA managed memory for portions of the code that do not fit in GPU memory, to write static memory allocators that reduce fragmentation while training models, and to improve interoperability with external DL compilers/libraries.

For example, we could have the following allocator in C++:

```c++
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>

extern "C" {
// Allocates `size` bytes on `device`; the stream argument allows
// stream-ordered allocation but is unused here. Error handling omitted.
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
   void* ptr;
   std::cout << "alloc " << size << std::endl;
   cudaMalloc(&ptr, size);
   return ptr;
}

void my_free(void* ptr) {
   std::cout << "free" << std::endl;
   cudaFree(ptr);
}
}
```

Compile it as a shared library:
```
nvcc allocator.cc -o alloc.so -shared --compiler-options '-fPIC'
```

And use it from PyTorch as follows:

```python
import torch

# Allocating a tensor here would initialize the default caching allocator
# first, which would make the change_current_allocator call below fail:
# b = torch.zeros(10, device='cuda')
new_alloc = torch.cuda.memory.CUDAPluggableAllocator('alloc.so', 'my_malloc', 'my_free')
old = torch.cuda.memory.get_current_allocator()
torch.cuda.memory.change_current_allocator(new_alloc)
b = torch.zeros(10, device='cuda')
# This will error since the current allocator was already instantiated
# by the allocation above
torch.cuda.memory.change_current_allocator(old)
```
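As a sketch of the managed-memory use case mentioned above (untested; the `managed_malloc`/`managed_free` names are illustrative, and it assumes the same entry-point signatures as `my_malloc`/`my_free`):

```c++
#include <sys/types.h>
#include <cuda_runtime_api.h>

extern "C" {
// Hypothetical variant of my_malloc backed by CUDA managed (unified) memory,
// so allocations can exceed device capacity and migrate on demand.
void* managed_malloc(ssize_t size, int device, cudaStream_t stream) {
   void* ptr;
   cudaMallocManaged(&ptr, size);
   return ptr;
}

void managed_free(void* ptr) {
   cudaFree(ptr);
}
}
```

Loading it would work the same way, e.g. `CUDAPluggableAllocator('alloc.so', 'managed_malloc', 'managed_free')`.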

Things to discuss
- How to test this; it requires compiling external code ...

Pull Request resolved: pytorch#86786
Approved by: https://github.com/albanD
Emilio Castillo authored and pytorchmergebot committed Nov 23, 2022
1 parent 1333fdc commit c9d4390
Showing 13 changed files with 751 additions and 11 deletions.
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -862,6 +862,7 @@ libtorch_python_cuda_core_sources = [
     "torch/csrc/cuda/shared/cudart.cpp",
     "torch/csrc/cuda/shared/nvtx.cpp",
     "torch/csrc/cuda/utils.cpp",
+    "torch/csrc/cuda/CUDAPluggableAllocator.cpp",
 ]

 libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
4 changes: 4 additions & 0 deletions c10/cuda/CUDACachingAllocator.cpp
@@ -2008,6 +2008,10 @@ class NativeCachingAllocator : public CUDAAllocator {
     }
   }

+  bool initialized() override {
+    return device_allocator.size() > 0;
+  }
+
   /** allocates a block which is safe to use from the provided stream */
   void malloc(void** devPtr, int device, size_t size, cudaStream_t stream) {
     TORCH_INTERNAL_ASSERT(
1 change: 1 addition & 0 deletions c10/cuda/CUDACachingAllocator.h
@@ -183,6 +183,7 @@ class CUDAAllocator : public Allocator {
   virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0;
   virtual void raw_delete(void* ptr) = 0;
   virtual void init(int device_count) = 0;
+  virtual bool initialized() = 0;
   virtual void setMemoryFraction(double fraction, int device) = 0;
   virtual void emptyCache() = 0;
   virtual void cacheInfo(int dev_id, size_t* largestBlock) = 0;
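The new `initialized()` hook reports whether an allocator has already served allocations; this is presumably what lets the swap path reject replacing a live allocator, matching the error in the commit message's Python example. A hypothetical sketch of such a guard (the function name and surrounding code are illustrative, not the commit's actual implementation):

```c++
#include <c10/cuda/CUDACachingAllocator.h>

// Illustrative only: a swap routine can use the new initialized() hook to
// refuse replacing an allocator that already owns live allocations, since
// those blocks would otherwise be freed by the wrong allocator.
void changeAllocatorIfUnused(
    c10::cuda::CUDACachingAllocator::CUDAAllocator* candidate) {
  TORCH_CHECK(
      !c10::cuda::CUDACachingAllocator::get()->initialized(),
      "Can't swap an already initialized allocator");
  // ... install `candidate` as the active allocator ...
}
```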
4 changes: 4 additions & 0 deletions c10/cuda/CUDAMallocAsyncAllocator.cpp
@@ -430,6 +430,10 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
     (void)called;
   }

+  bool initialized() {
+    return devs_initialized_flags.size() > 0;
+  }
+
   static inline void assertValidDevice(int device) {
     TORCH_CHECK(
         0 <= device && device < device_count, "Invalid device argument.");
2 changes: 2 additions & 0 deletions docs/source/cuda.rst
@@ -114,6 +114,8 @@ Memory management
     caching_allocator_alloc
     caching_allocator_delete
     get_allocator_backend
+    CUDAPluggableAllocator
+    change_current_allocator
 .. FIXME The following doesn't seem to exist. Is it supposed to?
         https://github.com/pytorch/pytorch/issues/27785
 .. autofunction:: reset_max_memory_reserved
60 changes: 60 additions & 0 deletions docs/source/notes/cuda.rst
@@ -430,6 +430,66 @@ Available options:
 .. _CUDA's built-in asynchronous allocator:
     https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/

+.. _cuda-memory-custom-allocator:
+
+Using custom memory allocators for CUDA
+---------------------------------------
+
+It is possible to define allocators as simple functions in C/C++ and compile
+them as a shared library. The code below shows a basic allocator that just
+traces all the memory operations.
+
+.. code:: C++
+
+   #include <sys/types.h>
+   #include <cuda_runtime_api.h>
+   #include <iostream>
+   // Compile with g++ alloc.cc -o alloc.so -I/usr/local/cuda/include -shared -fPIC
+   extern "C" {
+   void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
+      void* ptr;
+      cudaMalloc(&ptr, size);
+      std::cout << "alloc " << ptr << " " << size << std::endl;
+      return ptr;
+   }
+
+   void my_free(void* ptr, ssize_t size, cudaStream_t stream) {
+      std::cout << "free " << ptr << " " << stream << std::endl;
+      cudaFree(ptr);
+   }
+   }
+
+This can be used in Python through :class:`torch.cuda.memory.CUDAPluggableAllocator`.
+The user is responsible for supplying the path to the `.so` file and the names
+of the alloc/free functions that match the signatures specified above.
+
+.. code:: python
+
+   import torch
+
+   # Load the allocator
+   new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
+       'alloc.so', 'my_malloc', 'my_free')
+   # Swap the current allocator
+   torch.cuda.memory.change_current_allocator(new_alloc)
+   # This will allocate memory in the device using the new allocator
+   b = torch.zeros(10, device='cuda')
+
+.. code:: python
+
+   import torch
+
+   # Do an initial memory allocation with the default allocator
+   b = torch.zeros(10, device='cuda')
+   # Load the allocator
+   new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
+       'alloc.so', 'my_malloc', 'my_free')
+   # This will error since the current allocator was already instantiated
+   torch.cuda.memory.change_current_allocator(new_alloc)
+
 .. _cufft-plan-cache:

 cuFFT plan cache
7 changes: 7 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -1199,6 +1199,13 @@ def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
 def _cuda_memorySnapshot() -> Dict[str, Any]: ...
 def _cuda_recordMemoryHistory(enabled: _bool, record_context: _bool, record_context_cpp: _bool, alloc_trace_max_entries: _int, alloc_trace_record_context: _bool) -> None: ...
 def _cuda_getAllocatorBackend() -> str: ...
+
+class _cuda_CUDAAllocator:
+    ...
+
+def _cuda_customAllocator(alloc_fn: _int, free_fn: _int) -> _cuda_CUDAAllocator: ...
+def _cuda_changeCurrentAllocator(allocator: _cuda_CUDAAllocator) -> None: ...
+def _cuda_getAllocator() -> _cuda_CUDAAllocator: ...
 def _cuda_lock_mutex() -> None: ...
 def _cuda_unlock_mutex() -> None: ...
 def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ...
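The `_cuda_customAllocator(alloc_fn: _int, free_fn: _int)` stub suggests the Python wrapper resolves the shared-library symbols (e.g., via ctypes) and hands their raw addresses to C++ as integers. A speculative sketch of how the native side might turn those back into callable entry points (the types and the `makeCustomAllocator` name are illustrative, not the commit's actual binding code):

```c++
#include <cstdint>
#include <sys/types.h>
#include <cuda_runtime_api.h>

// Function-pointer types matching the extern "C" signatures shown earlier.
using MallocFn = void* (*)(ssize_t size, int device, cudaStream_t stream);
using FreeFn = void (*)(void* ptr);

// Illustrative only: reinterpret the integer addresses received from Python
// as the user's alloc/free entry points.
void makeCustomAllocator(uint64_t alloc_fn_addr, uint64_t free_fn_addr) {
  auto alloc_fn = reinterpret_cast<MallocFn>(alloc_fn_addr);
  auto free_fn = reinterpret_cast<FreeFn>(free_fn_addr);
  // ... wrap alloc_fn/free_fn in a CUDAAllocator and register it ...
}
```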