Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom spilling handler  #12287

Open
wants to merge 13 commits into
base: branch-23.04
Choose a base branch
from
70 changes: 64 additions & 6 deletions python/cudf/cudf/core/buffer/spill_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import threading
import traceback
import warnings
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple
from weakref import WeakKeyDictionary, WeakValueDictionary

import rmm.mr

Expand Down Expand Up @@ -218,7 +218,10 @@ class SpillManager:
SpillStatistics for the different levels.
"""

_buffers: weakref.WeakValueDictionary[int, SpillableBuffer]
_buffers: WeakValueDictionary[int, SpillableBuffer]
_spill_handlers: WeakKeyDictionary[
SpillableBuffer, Tuple[Callable[..., Optional[int]], Tuple, Dict]
]
statistics: SpillStatistics

def __init__(
Expand All @@ -229,10 +232,11 @@ def __init__(
statistic_level: int = 0,
) -> None:
self._lock = threading.Lock()
self._buffers = weakref.WeakValueDictionary()
self._buffers = WeakValueDictionary()
self._id_counter = 0
self._spill_on_demand = spill_on_demand
self._device_memory_limit = device_memory_limit
self._spill_handlers = WeakKeyDictionary()
self.statistics = SpillStatistics(statistic_level)

if self._spill_on_demand:
Expand Down Expand Up @@ -347,13 +351,22 @@ def spill_device_memory(self, nbytes: int) -> int:
"""
spilled = 0
for buf in self.buffers(order_by_access_time=True):
if spilled >= nbytes:
break
if buf.lock.acquire(blocking=False):
try:
if not buf.is_spilled and buf.spillable:
# Check if `buf` has a registered spill handler
handler = self._spill_handlers.get(buf, None)
if handler is not None:
madsbk marked this conversation as resolved.
Show resolved Hide resolved
func, args, kwargs = handler
s = func(*args, **kwargs)
if s is not None:
madsbk marked this conversation as resolved.
Show resolved Hide resolved
spilled += s
self._spill_handlers.pop(buf, None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems a bit odd that the spill handler would be removed when the buffer is spilled. Is that always what we want?

Also, the usage makes it clear that these spill handlers are only appropriate for device to host spilling. Should we key the spill_handlers dict on both the buffer and the source/target of the spill?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I see why it's removed, I guess it's being added every time the cached property is accessed.

continue
buf.spill(target="cpu")
spilled += buf.size
if spilled >= nbytes:
break
finally:
buf.lock.release()
return spilled
Expand Down Expand Up @@ -385,6 +398,51 @@ def spill_to_device_limit(self, device_limit: int = None) -> int:
)
return self.spill_device_memory(nbytes=unspilled - limit)

def register_spill_handler(
    self,
    buffer: SpillableBuffer,
    func: Callable[..., Optional[int]],
    *args,
    **kwargs,
) -> None:
    """Register a spill handler for a buffer

    This enables customization of how to handle the spilling of a specific
    buffer. When the spill manager chooses to spill the buffer, it calls
    the provided callback function instead of spilling the buffer itself.

    The callback function is called like `func(*args, **kwargs)` and must
    return the number of bytes freed or None. If None, the spill manager
    will spill `buffer`.

    Warning
    -------
    The spill manager keeps a reference to `func`, `args`, and `kwargs`
    thus everything they reference are also kept alive.

    Parameters
    ----------
    buffer : SpillableBuffer
        The buffer `func` handles.
    func : Callable[..., Optional[int]]
        The spill handler.
    *args
        Positional arguments passed to `func`.
    **kwargs
        Keyword arguments passed to `func`.

    Raises
    ------
    RuntimeError
        If a spill handler is already registered for `buffer`.
    """

    # NOTE(review): this check-then-insert pair is not guarded by
    # `self._lock`; concurrent registrations for the same buffer could
    # race — confirm callers serialize registration.
    if buffer in self._spill_handlers:
        raise RuntimeError(
            f"Spill handler already registered for {buffer}"
        )
    # Handlers are kept in a WeakKeyDictionary keyed on the buffer, so an
    # entry disappears automatically when its buffer is garbage collected.
    self._spill_handlers[buffer] = (func, args, kwargs)

def __repr__(self) -> str:
spilled = sum(buf.size for buf in self.buffers() if buf.is_spilled)
unspilled = sum(
Expand Down
88 changes: 87 additions & 1 deletion python/cudf/cudf/core/buffer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@

from __future__ import annotations

import functools
import sys
import threading
import weakref
from contextlib import ContextDecorator
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, TypeVar, Union

import cudf
from cudf.core.buffer.buffer import Buffer, cuda_array_interface_wrapper
from cudf.core.buffer.spill_manager import get_global_manager
from cudf.core.buffer.spillable_buffer import SpillableBuffer, SpillLock

T = TypeVar("T")


def as_buffer(
data: Union[int, Any],
Expand Down Expand Up @@ -134,3 +140,83 @@ def get_spill_lock() -> Union[SpillLock, None]:
_id = threading.get_ident()
spill_lock, _ = _thread_spill_locks.get(_id, (None, 0))
return spill_lock


def _clear_property_cache(
    instance_ref: weakref.ReferenceType[T], nbytes: int, attrname: str
) -> Optional[int]:
    """Spill handler that drops a `cached_property` entry on an instance.

    The signature is compatible with `SpillManager.register_spill_handler`.
    A weak reference to the instance is taken so this handler does not keep
    the instance alive.

    Parameters
    ----------
    instance_ref
        Weak reference to the instance owning the cache.
    nbytes : int
        Size of the cached data in bytes.
    attrname : str
        Name of the cached attribute.

    Return
    ------
    Optional[int]
        `nbytes` if the cache entry was dropped, 0 if the instance is
        already dead, or None to tell the spill manager to spill the
        buffer as usual.
    """

    owner = instance_ref()
    if owner is None:
        # The instance (and with it the cache) is gone — nothing to free.
        return 0

    entry = owner.__dict__.get(attrname, None)
    if entry is None:
        # The cache has already been cleared.
        return None

    # If `entry` is referenced anywhere outside the cache, dropping the
    # cache frees no memory. Exactly three references originate here:
    # `owner.__dict__`, the local `entry`, and the `sys.getrefcount`
    # argument — anything beyond that is external.
    if sys.getrefcount(entry) > 3:
        return None

    # A single dict.pop is atomic in CPython, which keeps this clear
    # operation thread-safe without an explicit lock.
    owner.__dict__.pop(attrname, None)
    return nbytes


class cached_property(functools.cached_property):
    """A `cached_property` whose cache is deleted instead of spilled.

    When spilling is disabled (the default case), this decorator behaves
    exactly like `functools.cached_property`.

    When spilling is enabled, the first access registers a spill handler
    for the cached data that deletes the data rather than spilling it.
    For now, only cached Columns are handled this way.
    See `SpillManager.register_spill_handler`.
    """

    def __get__(self, instance: T, owner=None):
        already_cached = self.attrname in instance.__dict__
        value = super().__get__(instance, owner)
        # Only a freshly-cached Column gets a spill handler; cache hits
        # and non-Column values pass straight through.
        if already_cached or not isinstance(
            value, cudf.core.column.ColumnBase
        ):
            return value

        spill_mgr = get_global_manager()
        if spill_mgr is None:
            # Spilling is disabled; act like functools.cached_property.
            return value

        data = value.base_data
        if data is None or data.nbytes == 0:
            return value
        assert isinstance(data, SpillableBuffer)

        # Register a handler that clears the cache instead of spilling the
        # buffer; a weakref avoids keeping `instance` alive.
        spill_mgr.register_spill_handler(
            data,
            _clear_property_cache,
            weakref.ref(instance),
            nbytes=data.nbytes,
            attrname=self.attrname,
        )
        return value
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import math
import pickle
import warnings
from functools import cached_property
from numbers import Number
from typing import (
Any,
Expand Down Expand Up @@ -38,6 +37,7 @@
is_string_dtype,
)
from cudf.core._base_index import BaseIndex, _index_astype_docstring
from cudf.core.buffer.utils import cached_property
from cudf.core.column import (
CategoricalColumn,
ColumnBase,
Expand Down Expand Up @@ -245,7 +245,7 @@ def step(self):
def _num_rows(self):
return len(self)

@cached_property # type: ignore
@cached_property
@_cudf_nvtx_annotate
def _values(self):
if len(self) > 0:
Expand Down
56 changes: 56 additions & 0 deletions python/cudf/cudf/tests/test_spilling.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import cudf
import cudf.core.buffer.spill_manager
import cudf.options
from cudf._lib.column import Column
from cudf.core.abc import Serializable
from cudf.core.buffer import (
Buffer,
Expand All @@ -36,6 +37,8 @@
SpillableBufferSlice,
SpillLock,
)
from cudf.core.buffer.utils import cached_property
from cudf.core.column import column
from cudf.testing._utils import assert_eq

if get_global_manager() is not None:
Expand Down Expand Up @@ -609,3 +612,56 @@ def test_statistics_expose(manager: SpillManager):
assert stat.count == 10
assert stat.total_nbytes == buffers[0].nbytes * 10
assert stat.spilled_nbytes == buffers[0].nbytes * 10


def test_cached_property(manager: SpillManager):
    """Exercise the cache-clearing spill handler of `cached_property`."""

    class HasCachedColumn:
        @cached_property
        def cached_column(self) -> Column:
            return column.arange(3)

    # Accessing the property registers a spill handler for the buffer
    # backing the cached column.
    instance = HasCachedColumn()
    cached_col = instance.cached_column
    assert len(manager.buffers()) == 1
    assert manager.buffers()[0] is cached_col.base_data
    assert len(manager._spill_handlers) == 1

    # While `cached_col` is referenced here, the handler declines and the
    # buffer is spilled the ordinary way.
    assert manager.spill_device_memory(nbytes=1) == gen_df_data_nbytes
    assert len(manager.buffers()) == 1
    assert len(manager._spill_handlers) == 1

    # Unspill and drop our reference; the cached buffer and its spill
    # handler both remain registered.
    cached_col.base_data.spill(target="gpu")
    del cached_col
    assert len(manager.buffers()) == 1
    assert len(manager._spill_handlers) == 1

    # With no outside reference left, spilling the cached buffer clears
    # the cache (and the handler) instead of spilling.
    assert manager.spill_device_memory(nbytes=1) == gen_df_data_nbytes
    assert len(manager.buffers()) == 0
    assert len(manager._spill_handlers) == 0


def test_spilling_of_range_index(manager: SpillManager):
    """A materialized RangeIndex is deleted, not spilled, by the manager."""
    frame = single_column_df(target="gpu")
    assert isinstance(frame.index, cudf.RangeIndex)
    assert spilled_and_unspilled(manager) == (0, gen_df_data_nbytes)

    # Materialize the index column.
    frame.index._values
    assert spilled_and_unspilled(manager) == (0, 2 * gen_df_data_nbytes)

    # The data column has the oldest access time, so it spills first.
    manager.spill_device_memory(nbytes=1)
    assert spilled_and_unspilled(manager) == (
        gen_df_data_nbytes,
        gen_df_data_nbytes,
    )

    # The materialized index is deleted instead of being spilled.
    manager.spill_device_memory(nbytes=1)
    assert spilled_and_unspilled(manager) == (gen_df_data_nbytes, 0)