Issue + Reproducers
So I have an I/O job that reads data onto the CPU and passes it to the GPU in a `map_blocks` call, and then uses CuPy downstream in a non-standard `map_blocks` call. Here is the reproducer, minus the I/O:
```python
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
import dask.array as da
import dask
import cupy as cp
import rmm
import numpy as np
from rmm.allocators.cupy import rmm_cupy_allocator
from cupyx.scipy import sparse as cp_sparse
from scipy import sparse

def set_mem():
    # Use RMM managed memory and route CuPy allocations through RMM on each worker
    rmm.reinitialize(managed_memory=True)
    cp.cuda.set_allocator(rmm_cupy_allocator)

cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES="0")
client = Client(cluster)
client.run(set_mem)

M = 100_000
N = 4_000

def make_chunk():
    # Allocate on the CPU, copy to the GPU, then drop the host copy
    arr = np.random.random((M, N))
    chunk = cp.array(arr)
    del arr
    return chunk

arr = da.map_blocks(make_chunk, meta=cp.array((1.,), dtype=cp.float64), dtype=cp.float64, chunks=((M,) * 50, (N,) * 1))

blocks = arr.to_delayed().ravel()

def __gram_block(block):
    # Per-block Gram matrix (N x N)
    return block.T @ block

gram_chunk_matrices = da.map_blocks(__gram_block, arr, chunks=((arr.shape[1],) * len(blocks), (arr.shape[1],)), dtype=arr.dtype, meta=cp.array([]))
gram_chunk_matrices = gram_chunk_matrices.reshape(len(blocks), arr.shape[1], arr.shape[1])
gram_matrix = gram_chunk_matrices.sum(axis=0).compute()
```
This uses an unaccountable amount of CPU memory, on the order of 4-8 GB, and I have no idea why. Nothing here should be using CPU memory except the initial read. And when the job completes, Dask still reports that it is holding on to 4 GB (!) of memory. I see at most 2 tasks running, with another 2-6 in memory. In total, the CPU memory being this high doesn't make sense: the individual NumPy arrays are 320 MB, so this should be at most 640 MB (and even that seems high given how short a time they should live on the CPU before I call `del`). I don't think this is a Dask memory-reporting issue, because `top` shows the same amount of memory usage.
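A quick way to cross-check what `top` shows from the client side is to query the worker's RSS directly. A minimal sketch, assuming `psutil` is installed in the worker environment:

```python
import psutil

def worker_rss_gb():
    # Resident set size of the worker process, in GB
    return psutil.Process().memory_info().rss / 1e9

# Run on every worker; the result is keyed by worker address
print(client.run(worker_rss_gb))
```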
I also don't think this has to do with the computation I chose, as:
```python
def __double_block(block):
    return block * 2

doubled_matrices = da.map_blocks(__double_block, arr)
doubled_matrices = doubled_matrices.reshape(len(blocks), M, N)
doubled_matrix = doubled_matrices.sum(axis=0).compute()
```
has the same issue, albeit with a warning about chunk sizes (although I'm not sure why I get the warning, since the reshape is exactly along the block boundaries). In any case, it's the same memory problem. Without the reshape, and therefore without that warning, the behavior is the same:
```python
def __double_block(block):
    return block * 2

doubled_matrices = da.map_blocks(__double_block, arr)
doubled_matrices.sum(axis=0).compute()
```
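To put numbers on the host-memory growth over the course of the computation, distributed's `MemorySampler` diagnostic can be used; a sketch (the label is arbitrary):

```python
from distributed.diagnostics import MemorySampler

ms = MemorySampler()

# Sample cluster-wide process memory while the reduction runs
with ms.sample("double + sum"):
    doubled_matrices.sum(axis=0).compute()

# Peak sampled memory per label (or ms.plot() for a chart)
print(ms.to_pandas().max())
```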
Some env info:
```
Python 3.11.7 | packaged by conda-forge | (main, Dec 23 2023, 14:43:09) [GCC 12.3.0] on linux
dask==2024.2.1
dask-cuda==24.6.0
dask-cudf-cu12==24.6.0
distributed==2024.2.1
distributed-ucxx-cu12==0.38.0
cuda-python==12.5.0
cudf-cu12==24.6.0
cugraph-cu12==24.6.1
cupy-cuda12x==13.2.0
numpy==1.26.4
```
Some monitoring screenshots:
Addendum 1: I don't see this as spilling
I also don't think this is spilling, because the GPU memory is not that high. That number is also basically correct: `((4000 * 4000) * 4 + (100_000 * 4000) * 4) < 2GB`, where the first `4000 * 4000` is from holding the partial sum in memory and the `100_000 * 4000` is the input data.
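For completeness, a sketch of how to confirm device-memory usage on the worker via CuPy's runtime bindings (with managed memory these numbers are only approximate):

```python
import cupy as cp

def gpu_used_gb():
    # (free, total) device memory in bytes, as reported by the CUDA runtime
    free, total = cp.cuda.runtime.memGetInfo()
    return (total - free) / 1e9

print(client.run(gpu_used_gb))
```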
Addendum 2: This is GPU specific
This behavior does not happen on CPU dask:
```python
import dask.distributed as dd
import dask.array as da  # needed for map_blocks below; missing from the snippet as originally posted
import numpy as np

cluster = dd.LocalCluster(n_workers=1)
client = dd.Client(cluster)

M = 100_000
N = 4_000

def make_chunk():
    arr = np.random.random((M, N))
    return arr

arr = da.map_blocks(make_chunk, meta=np.array((1.,), dtype=np.float64), dtype=np.float64, chunks=((M,) * 50, (N,) * 1))

def __double_block(block):
    return block * 2

doubled_matrices = da.map_blocks(__double_block, arr)
doubled_matrices.sum(axis=0).compute()
```
I see the memory use fluctuating, but the baseline of ~1.5 GB makes sense given the in-memory/processing stats I cited above. It also releases the memory at the end.
Addendum 3: Not an `rmm` issue
I tried commenting out `client.run(set_mem)`, and that also had no effect.
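If it helps, a quick sanity check (sketch, not part of the original runs) that the workers really are on CuPy's default memory pool when `set_mem` is skipped:

```python
import cupy as cp

def default_pool_stats():
    # CuPy's built-in pool; RMM is not involved when set_mem was never run
    pool = cp.get_default_memory_pool()
    return {"used_bytes": pool.used_bytes(), "total_bytes": pool.total_bytes()}

print(client.run(default_pool_stats))
```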