CPU Memory Usage for Tasks with CPU-GPU Transfer #1351

Issue + Reproducers

I have an I/O job that reads data onto the CPU and passes it to the GPU in a map_blocks call, and then uses CuPy downstream in a non-standard map_blocks call. Here is the reproducer, minus the I/O:

from dask_cuda import LocalCUDACluster
from dask.distributed import Client
import dask.array as da
import dask

import cupy as cp
import rmm
import numpy as np
from rmm.allocators.cupy import rmm_cupy_allocator
from cupyx.scipy import sparse as cp_sparse
from scipy import sparse

def set_mem():
    rmm.reinitialize(managed_memory=True)
    cp.cuda.set_allocator(rmm_cupy_allocator)

cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES="0")
client = Client(cluster)
client.run(set_mem)

M = 100_000
N = 4_000

def make_chunk():
    arr = np.random.random((M, N))  # host (NumPy) allocation
    chunk = cp.array(arr)           # copy to the device
    del arr                         # drop the host copy explicitly
    return chunk

arr = da.map_blocks(make_chunk, meta=cp.array((1.,), dtype=cp.float64), dtype=cp.float64, chunks=((M,) * 50, (N,) * 1))
blocks = arr.to_delayed().ravel()

def __gram_block(block):
    return block.T @ block

gram_chunk_matrices = da.map_blocks(__gram_block, arr, chunks=((arr.shape[1],) * len(blocks) , (arr.shape[1],)), dtype=arr.dtype, meta=cp.array([]))
gram_chunk_matrices = gram_chunk_matrices.reshape(len(blocks), arr.shape[1], arr.shape[1])
gram_matrix = gram_chunk_matrices.sum(axis=0).compute()

This uses an unaccountable amount of CPU memory, on the order of 4-8 GB, and I have no idea why. Nothing here should use CPU memory except the initial read, and yet when the job completes, dask still reports that it is holding on to 4 GB (!) of memory. I see at most 2 tasks running, with another 2-6 in memory. The CPU memory being this high doesn't make sense: the individual numpy arrays are 320 MB, so this should be at most about 640 MB (and even that seems high given how briefly they live on the CPU before I call del). I don't think this is a dask memory-reporting issue, because top shows the same amount of memory usage.
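
For completeness, here is how I would cross-check the workers' resident memory independently of the dashboard. A minimal sketch, assuming psutil (which distributed already depends on) and the client from the reproducer above:

import os
import psutil

def worker_rss():
    # resident set size of the worker process, in GiB
    return psutil.Process(os.getpid()).memory_info().rss / 2**30

# runs on every worker; returns {worker_address: rss_in_gib}
print(client.run(worker_rss))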

I also don't think this has to do with the computation I chose as:

def __double_block(block):
    return block * 2

doubled_matrices = da.map_blocks(__double_block, arr)
doubled_matrices = doubled_matrices.reshape(len(blocks), M, N)

doubled_matrix = doubled_matrices.sum(axis=0).compute()

has the same issue, albeit with a warning about chunk sizes (although I'm not sure why I get the warning, since the reshape is right along the block boundaries). In any case, it's the same memory problem. Without the reshape, and therefore without that warning, the behavior is the same:

def __double_block(block):
    return block * 2

doubled_matrices = da.map_blocks(__double_block, arr)
doubled_matrices.sum(axis=0).compute()
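
For reference, a quick way to see where the chunk-size warning mentioned above comes from might be to compare the chunk layouts before and after the reshape (a sketch reusing the arrays defined earlier):

print(doubled_matrices.chunks)                             # ((M,) * 50, (N,))
print(doubled_matrices.reshape(len(blocks), M, N).chunks)  # layout after the reshape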

Some env info:

Python 3.11.7 | packaged by conda-forge | (main, Dec 23 2023, 14:43:09) [GCC 12.3.0] on linux
dask==2024.2.1
dask-cuda==24.6.0
dask-cudf-cu12==24.6.0
distributed==2024.2.1
distributed-ucxx-cu12==0.38.0
cuda-python==12.5.0
cudf-cu12==24.6.0
cugraph-cu12==24.6.1
cupy-cuda12x==13.2.0
numpy==1.26.4

Some monitoring screenshots:

[Screenshot 2024-06-21 at 16:19:07]
[Screenshot 2024-06-21 at 16:18:51]

Addendum 1: I don't see this as spilling

I also don't think this is spilling because the GPU memory is not that high:

[Screenshot 2024-06-24 at 10:48:59]

That number is also roughly what I'd expect: ((4000 * 4000) * 4 + (100_000 * 4000) * 4) < 2 GB, where the 4000 * 4000 term is the partial sum held in memory and the 100_000 * 4000 term is the input data.
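
Spelling that estimate out (using the same bytes-per-element figure as the expression above):

# one N x N partial sum plus one M x N input block resident at a time
partial_sum = 4_000 * 4_000 * 4       # ~0.06 GiB
input_block = 100_000 * 4_000 * 4     # ~1.49 GiB
print((partial_sum + input_block) / 2**30)  # ~1.55, i.e. under 2 GB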

Addendum 2: This is GPU specific

This behavior does not happen on CPU dask:

import dask.distributed as dd
import dask.array as da
import numpy as np
cluster = dd.LocalCluster(n_workers=1)
client = dd.Client(cluster)


M = 100_000
N = 4_000
def make_chunk():
    arr = np.random.random((M,N))
    return arr

arr = da.map_blocks(make_chunk, meta=np.array((1.,), dtype=np.float64), dtype=np.float64, chunks=((M,) * 50, (N,) * 1))
def __double_block(block):
    return block * 2

doubled_matrices = da.map_blocks(__double_block, arr)
doubled_matrices.sum(axis=0).compute()

I see the memory use fluctuate, but the baseline of ~1.5 GB makes sense given the in-memory/processing stats I cited above, and the memory is released at the end.
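
For a more systematic comparison of the two runs, distributed's MemorySampler could record cluster memory over each compute. A sketch (the label is just illustrative):

from distributed.diagnostics import MemorySampler

ms = MemorySampler()

# sample cluster-wide process memory while the reduction runs
with ms.sample("cpu-local-cluster"):
    doubled_matrices.sum(axis=0).compute()

# one column per labelled run
print(ms.to_pandas())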

Addendum 3: Not an rmm issue

I tried commenting out client.run(set_mem) and that also had no effect.
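
For reference, the same managed-memory setup can also be requested through LocalCUDACluster itself instead of rmm.reinitialize in client.run. A sketch of that variation, which I wouldn't expect to change anything given the above:

# variation: let dask-cuda configure RMM managed memory at worker startup;
# CuPy is still pointed at the RMM allocator per worker, as before
cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES="0", rmm_managed_memory=True)
client = Client(cluster)
client.run(lambda: cp.cuda.set_allocator(rmm_cupy_allocator))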
