remove the draft version of the GDSDataset #1473

Merged
2 commits merged on Jul 28, 2023
151 changes: 23 additions & 128 deletions modules/GDS_dataset.ipynb
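In short, this PR deletes the notebook's hand-rolled GDSDataset (a PersistentDataset subclass that re-implemented _cachecheck to read and write its cache through GPUDirect Storage) and switches the tutorial to the implementation that now ships with MONAI, imported as "from monai.data.dataset import GDSDataset". A minimal usage sketch under the tutorial's assumptions (MONAI with kvikio available and a CUDA device); the input file name below is hypothetical, not from the notebook:

    import tempfile

    import monai.transforms as mt
    from monai.data.dataset import GDSDataset  # upstream class that replaces the notebook's draft

    # hypothetical input; the tutorial uses images extracted from the downloaded samples.zip
    data = [{"image": "sample_0.nii.gz"}]
    transform = mt.Compose([mt.LoadImaged(keys="image"), mt.EnsureTyped(keys="image")])

    cache_dir = tempfile.mkdtemp()
    # device selects the CUDA device used when reading/writing the GDS-backed cache
    gds_ds = GDSDataset(data=data, transform=transform, cache_dir=cache_dir, device=0)
    item = gds_ds[0]  # first access runs the transforms and writes the cache; later accesses read it via GDS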
@@ -39,10 +39,6 @@
" \n",
" (Replace X with the major version of the CUDA toolkit, and Y with the minor version.)\n",
"\n",
"- `GDSDataset` inherited from `PersistentDataset`.\n",
"\n",
" In this tutorial, we have implemented a `GDSDataset` that inherits from `PersistentDataset`. We have re-implemented the `_cachecheck` method to create and save cache using GDS.\n",
"\n",
"- A simple demo comparing the time taken with and without GDS.\n",
"\n",
" In this tutorial, we are creating a conda environment to install `kvikio`, which provides a Python API for GDS. To install `kvikio` using other methods, refer to https://github.com/rapidsai/kvikio#install.\n",
@@ -79,28 +75,21 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import time\n",
"import cupy\n",
"import torch\n",
"import shutil\n",
"import tempfile\n",
"import numpy as np\n",
"from typing import Any\n",
"from pathlib import Path\n",
"from copy import deepcopy\n",
"from collections.abc import Callable, Sequence\n",
"from kvikio.numpy import fromfile, tofile\n",
"\n",
"import monai\n",
"import monai.transforms as mt\n",
"from monai.config import print_config\n",
"from monai.data.utils import pickle_hashing, SUPPORTED_PICKLE_MOD\n",
"from monai.utils import convert_to_tensor, set_determinism, look_up_option\n",
"from monai.data.dataset import GDSDataset\n",
"from monai.utils import set_determinism\n",
"\n",
"print_config()"
]
@@ -135,100 +124,6 @@
"print(root_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## GDSDataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class GDSDataset(monai.data.PersistentDataset):\n",
" def __init__(\n",
" self,\n",
" data: Sequence,\n",
" transform: Sequence[Callable] | Callable,\n",
" cache_dir: Path | str | None,\n",
" hash_func: Callable[..., bytes] = pickle_hashing,\n",
" hash_transform: Callable[..., bytes] | None = None,\n",
" reset_ops_id: bool = True,\n",
" device: int = None,\n",
" **kwargs: Any,\n",
" ) -> None:\n",
" super().__init__(\n",
" data=data,\n",
" transform=transform,\n",
" cache_dir=cache_dir,\n",
" hash_func=hash_func,\n",
" hash_transform=hash_transform,\n",
" reset_ops_id=reset_ops_id,\n",
" **kwargs,\n",
" )\n",
" self.device = device\n",
"\n",
" def _cachecheck(self, item_transformed):\n",
" \"\"\"given the input dictionary ``item_transformed``, return a transformed version of it\"\"\"\n",
" hashfile = None\n",
" # compute a cache id\n",
" if self.cache_dir is not None:\n",
" data_item_md5 = self.hash_func(item_transformed).decode(\"utf-8\")\n",
" data_item_md5 += self.transform_hash\n",
" hashfile = self.cache_dir / f\"{data_item_md5}.pt\"\n",
"\n",
" if hashfile is not None and hashfile.is_file(): # cache hit\n",
" with cupy.cuda.Device(self.device):\n",
" item = {}\n",
" for k in item_transformed:\n",
" meta_k = torch.load(self.cache_dir / f\"{hashfile.name}-{k}-meta\")\n",
" item[k] = fromfile(f\"{hashfile}-{k}\", dtype=np.float32, like=cupy.empty(()))\n",
" item[k] = convert_to_tensor(item[k].reshape(meta_k[\"shape\"]), device=f\"cuda:{self.device}\")\n",
" item[f\"{k}_meta_dict\"] = meta_k\n",
" return item\n",
"\n",
" # create new cache\n",
" _item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed\n",
" if hashfile is None:\n",
" return _item_transformed\n",
"\n",
" for k in _item_transformed: # {'image': ..., 'label': ...}\n",
" _item_transformed_meta = _item_transformed[k].meta\n",
" _item_transformed_data = _item_transformed[k].array\n",
" _item_transformed_meta[\"shape\"] = _item_transformed_data.shape\n",
" tofile(_item_transformed_data, f\"{hashfile}-{k}\")\n",
" try:\n",
" # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation\n",
" # to make the cache more robust to manual killing of parent process\n",
" # which may leave partially written cache files in an incomplete state\n",
" with tempfile.TemporaryDirectory() as tmpdirname:\n",
" meta_hash_file_name = f\"{hashfile.name}-{k}-meta\"\n",
" meta_hash_file = self.cache_dir / meta_hash_file_name\n",
" temp_hash_file = Path(tmpdirname) / meta_hash_file_name\n",
" torch.save(\n",
" obj=_item_transformed_meta,\n",
" f=temp_hash_file,\n",
" pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),\n",
" pickle_protocol=self.pickle_protocol,\n",
" )\n",
" if temp_hash_file.is_file() and not meta_hash_file.is_file():\n",
" # On Unix, if target exists and is a file, it will be replaced silently if the\n",
" # user has permission.\n",
" # for more details: https://docs.python.org/3/library/shutil.html#shutil.move.\n",
" try:\n",
" shutil.move(str(temp_hash_file), meta_hash_file)\n",
" except FileExistsError:\n",
" pass\n",
" except PermissionError: # project-monai/monai issue #3613\n",
" pass\n",
" open(hashfile, \"a\").close() # store cacheid\n",
"\n",
" return _item_transformed"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -245,16 +140,16 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-07-12 09:26:17,878 - INFO - Expected md5 is None, skip md5 check for file samples.zip.\n",
"2023-07-12 09:26:17,878 - INFO - File exists: samples.zip, skipped downloading.\n",
"2023-07-12 09:26:17,879 - INFO - Writing into directory: /raid/yliu/test_tutorial.\n"
"2023-07-27 07:59:12,054 - INFO - Expected md5 is None, skip md5 check for file samples.zip.\n",
"2023-07-27 07:59:12,055 - INFO - File exists: samples.zip, skipped downloading.\n",
"2023-07-27 07:59:12,056 - INFO - Writing into directory: /raid/yliu/test_tutorial.\n"
]
}
],
@@ -283,7 +178,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -299,7 +194,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -332,19 +227,19 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch0 time 19.746733903884888\n",
"epoch1 time 0.9976603984832764\n",
"epoch2 time 0.982248067855835\n",
"epoch3 time 0.9838874340057373\n",
"epoch4 time 0.9793403148651123\n",
"total time 23.69102692604065\n"
"epoch0 time 20.148560762405396\n",
"epoch1 time 0.9835140705108643\n",
"epoch2 time 0.9708101749420166\n",
"epoch3 time 0.9711742401123047\n",
"epoch4 time 0.9711296558380127\n",
"total time 24.04619812965393\n"
]
}
],
@@ -372,19 +267,19 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch0 time 21.206729650497437\n",
"epoch1 time 1.510526180267334\n",
"epoch2 time 1.588256597518921\n",
"epoch3 time 1.4431262016296387\n",
"epoch4 time 1.4594802856445312\n",
"total time 27.20927882194519\n"
"epoch0 time 21.170511722564697\n",
"epoch1 time 1.482978105545044\n",
"epoch2 time 1.5378782749176025\n",
"epoch3 time 1.4499244689941406\n",
"epoch4 time 1.4379286766052246\n",
"total time 27.08065962791443\n"
]
}
],
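A note on the timing numbers in the two output cells above: both comparisons loop over the dataset for five epochs and print per-epoch wall-clock time. The first epoch includes cache creation (roughly 20 s in both runs); steady-state epochs take roughly 0.97-0.98 s with the GDS-backed GDSDataset versus roughly 1.44-1.54 s with the non-GDS baseline. The benchmarking cell itself is collapsed in this diff view, so the loop below is a hypothetical sketch of that pattern, not the notebook's exact code:

    import time

    def time_epochs(dataset, epochs=5):
        # iterate the full dataset each epoch and report wall-clock time per epoch
        total_start = time.time()
        for epoch in range(epochs):
            start = time.time()
            for _ in dataset:  # epoch 0 builds the cache; later epochs read it back
                pass
            print(f"epoch{epoch} time", time.time() - start)
        print("total time", time.time() - total_start)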