Add GDSdataset #6778


Merged

merged 19 commits on Jul 27, 2023
1 change: 1 addition & 0 deletions monai/data/__init__.py
@@ -33,6 +33,7 @@
CSVDataset,
Dataset,
DatasetFunc,
GDSDataset,
LMDBDataset,
NPZDictItemDataset,
PersistentDataset,
178 changes: 176 additions & 2 deletions monai/data/dataset.py
@@ -34,6 +34,7 @@
from torch.utils.data import Dataset as _TorchDataset
from torch.utils.data import Subset

from monai.data.meta_tensor import MetaTensor
from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
from monai.transforms import (
Compose,
@@ -44,7 +45,7 @@
convert_to_contiguous,
reset_ops_id,
)
from monai.utils import MAX_SEED, get_seed, look_up_option, min_version, optional_import
from monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import
from monai.utils.misc import first

if TYPE_CHECKING:
@@ -54,8 +55,10 @@
else:
tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")

cp, _ = optional_import("cupy")
lmdb, _ = optional_import("lmdb")
pd, _ = optional_import("pandas")
kvikio_numpy, _ = optional_import("kvikio.numpy")
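# cupy and kvikio.numpy are optional dependencies; GDSDataset uses them for GPU direct storage I/O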


class Dataset(_TorchDataset):
@@ -326,7 +329,6 @@ def _pre_transform(self, item_transformed):
first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)

item_transformed = self.transform(item_transformed, end=first_random, threading=True)

if self.reset_ops_id:
@@ -1510,3 +1512,175 @@ def __init__(
dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs
)
super().__init__(data=data, transform=transform)


class GDSDataset(PersistentDataset):
"""
An extension of the PersistentDataset using a direct memory access (DMA) data path between
GPU memory and storage, thus avoiding a bounce buffer through the CPU. This direct path can increase system
bandwidth while decreasing latency and utilization load on the CPU and GPU.

A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/main/modules/GDS_dataset.ipynb.
"""

def __init__(
self,
data: Sequence,
transform: Sequence[Callable] | Callable,
cache_dir: Path | str | None,
device: int,
hash_func: Callable[..., bytes] = pickle_hashing,
hash_transform: Callable[..., bytes] | None = None,
reset_ops_id: bool = True,
**kwargs: Any,
) -> None:
"""
Args:
data: input data file paths to load and transform to generate the dataset for the model.
`GDSDataset` expects the input data to be a list of serializable items
and hashes them as cache keys using `hash_func`.
transform: transforms to execute operations on input data.
cache_dir: If specified, this is the location for gpu direct storage
of pre-computed transformed data tensors. The cache_dir is computed once, and
persists on disk until explicitly removed. Different runs, programs, experiments
may share a common cache dir provided that the transforms pre-processing is consistent.
If `cache_dir` doesn't exist, will automatically create it.
If `cache_dir` is `None`, there is effectively no caching.
device: target device to put the output Tensor data. Note that only an int (the GPU index)
can be used to specify the target GPU.
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
hash_transform: a callable to compute hash from the transform information when caching.
This may reduce errors due to transforms changing during experiments. Defaults to None (no hash).
Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
reset_ops_id: whether to set `TraceKeys.ID` to ``TraceKeys.NONE``, defaults to ``True``.
When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
This is useful for skipping the transform instance checks when inverting applied operations
using the cached content and with re-created transform instances.

"""
super().__init__(
data=data,
transform=transform,
cache_dir=cache_dir,
hash_func=hash_func,
hash_transform=hash_transform,
reset_ops_id=reset_ops_id,
**kwargs,
)
self.device = device
self._meta_cache: dict[Any, dict[Any, Any]] = {}

def _cachecheck(self, item_transformed):
"""
In order to enable direct loading of the hash file content into GPU memory, this method overrides the base-class implementation.
Note that it always returns `torch.Tensor` when loading data from the cache.

Args:
item_transformed: The current data element to be mutated into transformed representation

Returns:
The transformed data_element, either from cache, or explicitly computing it.

Warning:
The current implementation does not encode transform information as part of the
hashing mechanism used for generating cache names when `hash_transform` is None.
If the transforms applied are changed in any way, the objects in the cache dir will be invalid.

"""
hashfile = None
# compute a cache id
if self.cache_dir is not None:
data_item_md5 = self.hash_func(item_transformed).decode("utf-8")
data_item_md5 += self.transform_hash
hashfile = self.cache_dir / f"{data_item_md5}.pt"

if hashfile is not None and hashfile.is_file(): # cache hit
with cp.cuda.Device(self.device):
if isinstance(item_transformed, dict):
item: dict[Any, Any] = {} # type:ignore
for k in item_transformed:
meta_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta")
item[k] = kvikio_numpy.fromfile(f"{hashfile}-{k}", dtype=meta_k["dtype"], like=cp.empty(()))
item[k] = convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}")
item[f"{k}_meta_dict"] = meta_k
return item
elif isinstance(item_transformed, (np.ndarray, torch.Tensor)):
_meta = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-meta")
_data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta.pop("dtype"), like=cp.empty(()))
_data = convert_to_tensor(_data.reshape(_meta.pop("shape")), device=f"cuda:{self.device}")
if bool(_meta):
return (_data, _meta)
return _data
else:
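# otherwise assume a sequence of dicts: each key of each element was cached with a per-index suffix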
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore
for i, _item in enumerate(item_transformed):
for k in _item:
meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}")
item_k = kvikio_numpy.fromfile(f"{hashfile}-{k}-{i}", dtype=meta_i_k["dtype"], like=cp.empty(()))
item_k = convert_to_tensor(item_k.reshape(meta_i_k["shape"]), device=f"cuda:{self.device}")
item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k})
return item

# create new cache
_item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed
if hashfile is None:
return _item_transformed
if isinstance(_item_transformed, dict):
for k in _item_transformed:
data_hashfile = f"{hashfile}-{k}"
meta_hash_file_name = f"{hashfile.name}-{k}-meta"
if isinstance(_item_transformed[k], (np.ndarray, torch.Tensor)):
self._create_new_cache(_item_transformed[k], data_hashfile, meta_hash_file_name)
else:
return _item_transformed
elif isinstance(_item_transformed, (np.ndarray, torch.Tensor)):
data_hashfile = f"{hashfile}"
meta_hash_file_name = f"{hashfile.name}-meta"
self._create_new_cache(_item_transformed, data_hashfile, meta_hash_file_name)
else:
for i, _item in enumerate(_item_transformed):
for k in _item:
data_hashfile = f"{hashfile}-{k}-{i}"
meta_hash_file_name = f"{hashfile.name}-{k}-meta-{i}"
self._create_new_cache(_item, data_hashfile, meta_hash_file_name)
open(hashfile, "a").close() # store cacheid
return _item_transformed

def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
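# keep shape/dtype and any MetaTensor metadata in an in-memory meta cache; the raw array is written with kvikio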
self._meta_cache[meta_hash_file_name] = copy(data.meta) if isinstance(data, MetaTensor) else {}
_item_transformed_data = data.array if isinstance(data, MetaTensor) else data
if isinstance(_item_transformed_data, torch.Tensor):
_item_transformed_data = _item_transformed_data.numpy()
self._meta_cache[meta_hash_file_name]["shape"] = _item_transformed_data.shape
self._meta_cache[meta_hash_file_name]["dtype"] = _item_transformed_data.dtype
kvikio_numpy.tofile(_item_transformed_data, data_hashfile)
try:
# NOTE: Writing to a temporary directory and then using a nearly atomic rename operation
# to make the cache more robust to manual killing of parent process
# which may leave partially written cache files in an incomplete state
with tempfile.TemporaryDirectory() as tmpdirname:
meta_hash_file = self.cache_dir / meta_hash_file_name
temp_hash_file = Path(tmpdirname) / meta_hash_file_name
torch.save(
obj=self._meta_cache[meta_hash_file_name],
f=temp_hash_file,
pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
pickle_protocol=self.pickle_protocol,
)
if temp_hash_file.is_file() and not meta_hash_file.is_file():
# On Unix, if target exists and is a file, it will be replaced silently if the
# user has permission.
# for more details: https://docs.python.org/3/library/shutil.html#shutil.move.
try:
shutil.move(str(temp_hash_file), meta_hash_file)
except FileExistsError:
pass
except PermissionError: # project-monai/monai issue #3613
pass

def _load_meta_cache(self, meta_hash_file_name):
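# return the in-memory metadata if present; otherwise load the persisted meta file from cache_dir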
if meta_hash_file_name in self._meta_cache:
return self._meta_cache[meta_hash_file_name]
else:
return torch.load(self.cache_dir / meta_hash_file_name) # type:ignore
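
For reference, a minimal usage sketch of the new class (not part of the diff), assuming MONAI is installed with cupy and kvikio available and a CUDA GPU at index 0; the file names and transforms below are illustrative placeholders, not part of this PR.

from monai.data import DataLoader, GDSDataset
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged

# hypothetical input files; GDSDataset hashes each dict as the cache key
data = [{"image": "img0.nii.gz"}, {"image": "img1.nii.gz"}]
transform = Compose([LoadImaged(keys="image"), EnsureChannelFirstd(keys="image")])

# cached arrays are written with kvikio; on later runs, cache hits are read back directly into GPU memory
dataset = GDSDataset(data=data, transform=transform, cache_dir="./gds_cache", device=0)
loader = DataLoader(dataset, batch_size=1)
for batch in loader:
    print(batch["image"].shape)  # cache hits are returned as tensors on cuda:0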