|
39 | 39 | " \n", |
40 | 40 | " (Replace X with the major version of the CUDA toolkit, and Y with the minor version.)\n", |
41 | 41 | "\n", |
42 | | - "- `GDSDataset` inherited from `PersistentDataset`.\n", |
43 | | - "\n", |
44 | | - " In this tutorial, we have implemented a `GDSDataset` that inherits from `PersistentDataset`. We have re-implemented the `_cachecheck` method to create and save cache using GDS.\n", |
45 | | - "\n", |
46 | 42 | "- A simple demo comparing the time taken with and without GDS.\n", |
47 | 43 | "\n", |
48 | 44 | " In this tutorial, we are creating a conda environment to install `kvikio`, which provides a Python API for GDS. To install `kvikio` using other methods, refer to https://github.com/rapidsai/kvikio#install.\n", |
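
The `kvikio` Python API mentioned above is the same one exercised by the custom `GDSDataset` implementation removed later in this diff (via `kvikio.numpy.fromfile`/`tofile`). As a quick orientation, here is a minimal GDS round-trip sketch using those calls; the file path and array contents are illustrative only, and it assumes `kvikio` and `cupy` are installed in the conda environment described above.

```python
import cupy
from kvikio.numpy import fromfile, tofile  # same API the removed GDSDataset code used

# write a CuPy array from device memory straight to disk via GDS
arr = cupy.arange(16, dtype=cupy.float32)
tofile(arr, "/tmp/gds_demo.bin")  # illustrative path

# read it back into device memory; `like=cupy.empty(())` requests a CuPy (GPU) array
back = fromfile("/tmp/gds_demo.bin", dtype=cupy.float32, like=cupy.empty(()))
assert bool((back == arr).all())
```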
|
79 | 75 | }, |
80 | 76 | { |
81 | 77 | "cell_type": "code", |
82 | | - "execution_count": 1, |
| 78 | + "execution_count": null, |
83 | 79 | "metadata": {}, |
84 | 80 | "outputs": [], |
85 | 81 | "source": [ |
86 | 82 | "import os\n", |
87 | 83 | "import time\n", |
88 | | - "import cupy\n", |
89 | 84 | "import torch\n", |
90 | 85 | "import shutil\n", |
91 | 86 | "import tempfile\n", |
92 | | - "import numpy as np\n", |
93 | | - "from typing import Any\n", |
94 | | - "from pathlib import Path\n", |
95 | | - "from copy import deepcopy\n", |
96 | | - "from collections.abc import Callable, Sequence\n", |
97 | | - "from kvikio.numpy import fromfile, tofile\n", |
98 | 87 | "\n", |
99 | 88 | "import monai\n", |
100 | 89 | "import monai.transforms as mt\n", |
101 | 90 | "from monai.config import print_config\n", |
102 | | - "from monai.data.utils import pickle_hashing, SUPPORTED_PICKLE_MOD\n", |
103 | | - "from monai.utils import convert_to_tensor, set_determinism, look_up_option\n", |
| 91 | + "from monai.data.dataset import GDSDataset\n", |
| 92 | + "from monai.utils import set_determinism\n", |
104 | 93 | "\n", |
105 | 94 | "print_config()" |
106 | 95 | ] |
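
With this change the notebook imports `GDSDataset` directly from `monai.data.dataset` instead of defining it inline. A minimal construction sketch, assuming illustrative file names, transforms, and cache directory (none of these values are taken from the notebook):

```python
import monai.transforms as mt
from monai.data.dataset import GDSDataset

# hypothetical input list and transform chain, for illustration only
data = [{"image": "sample_0.nii.gz"}]
transform = mt.Compose([mt.LoadImaged(keys="image"), mt.EnsureChannelFirstd(keys="image")])

# cache_dir holds the GDS-backed cache; device selects the GPU used for caching,
# mirroring the `device` argument of the custom class removed later in this diff
gds_ds = GDSDataset(data=data, transform=transform, cache_dir="./gds_cache", device=0)
item = gds_ds[0]  # first access builds the cache; later accesses read it back through GDS
```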
|
135 | 124 | "print(root_dir)" |
136 | 125 | ] |
137 | 126 | }, |
138 | | - { |
139 | | - "cell_type": "markdown", |
140 | | - "metadata": {}, |
141 | | - "source": [ |
142 | | - "## GDSDataset" |
143 | | - ] |
144 | | - }, |
145 | | - { |
146 | | - "cell_type": "code", |
147 | | - "execution_count": 3, |
148 | | - "metadata": {}, |
149 | | - "outputs": [], |
150 | | - "source": [ |
151 | | - "class GDSDataset(monai.data.PersistentDataset):\n", |
152 | | - " def __init__(\n", |
153 | | - " self,\n", |
154 | | - " data: Sequence,\n", |
155 | | - " transform: Sequence[Callable] | Callable,\n", |
156 | | - " cache_dir: Path | str | None,\n", |
157 | | - " hash_func: Callable[..., bytes] = pickle_hashing,\n", |
158 | | - " hash_transform: Callable[..., bytes] | None = None,\n", |
159 | | - " reset_ops_id: bool = True,\n", |
160 | | - " device: int = None,\n", |
161 | | - " **kwargs: Any,\n", |
162 | | - " ) -> None:\n", |
163 | | - " super().__init__(\n", |
164 | | - " data=data,\n", |
165 | | - " transform=transform,\n", |
166 | | - " cache_dir=cache_dir,\n", |
167 | | - " hash_func=hash_func,\n", |
168 | | - " hash_transform=hash_transform,\n", |
169 | | - " reset_ops_id=reset_ops_id,\n", |
170 | | - " **kwargs,\n", |
171 | | - " )\n", |
172 | | - " self.device = device\n", |
173 | | - "\n", |
174 | | - " def _cachecheck(self, item_transformed):\n", |
175 | | - " \"\"\"given the input dictionary ``item_transformed``, return a transformed version of it\"\"\"\n", |
176 | | - " hashfile = None\n", |
177 | | - " # compute a cache id\n", |
178 | | - " if self.cache_dir is not None:\n", |
179 | | - " data_item_md5 = self.hash_func(item_transformed).decode(\"utf-8\")\n", |
180 | | - " data_item_md5 += self.transform_hash\n", |
181 | | - " hashfile = self.cache_dir / f\"{data_item_md5}.pt\"\n", |
182 | | - "\n", |
183 | | - " if hashfile is not None and hashfile.is_file(): # cache hit\n", |
184 | | - " with cupy.cuda.Device(self.device):\n", |
185 | | - " item = {}\n", |
186 | | - " for k in item_transformed:\n", |
187 | | - " meta_k = torch.load(self.cache_dir / f\"{hashfile.name}-{k}-meta\")\n", |
188 | | - " item[k] = fromfile(f\"{hashfile}-{k}\", dtype=np.float32, like=cupy.empty(()))\n", |
189 | | - " item[k] = convert_to_tensor(item[k].reshape(meta_k[\"shape\"]), device=f\"cuda:{self.device}\")\n", |
190 | | - " item[f\"{k}_meta_dict\"] = meta_k\n", |
191 | | - " return item\n", |
192 | | - "\n", |
193 | | - " # create new cache\n", |
194 | | - " _item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed\n", |
195 | | - " if hashfile is None:\n", |
196 | | - " return _item_transformed\n", |
197 | | - "\n", |
198 | | - " for k in _item_transformed: # {'image': ..., 'label': ...}\n", |
199 | | - " _item_transformed_meta = _item_transformed[k].meta\n", |
200 | | - " _item_transformed_data = _item_transformed[k].array\n", |
201 | | - " _item_transformed_meta[\"shape\"] = _item_transformed_data.shape\n", |
202 | | - " tofile(_item_transformed_data, f\"{hashfile}-{k}\")\n", |
203 | | - " try:\n", |
204 | | - " # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation\n", |
205 | | - " # to make the cache more robust to manual killing of parent process\n", |
206 | | - " # which may leave partially written cache files in an incomplete state\n", |
207 | | - " with tempfile.TemporaryDirectory() as tmpdirname:\n", |
208 | | - " meta_hash_file_name = f\"{hashfile.name}-{k}-meta\"\n", |
209 | | - " meta_hash_file = self.cache_dir / meta_hash_file_name\n", |
210 | | - " temp_hash_file = Path(tmpdirname) / meta_hash_file_name\n", |
211 | | - " torch.save(\n", |
212 | | - " obj=_item_transformed_meta,\n", |
213 | | - " f=temp_hash_file,\n", |
214 | | - " pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),\n", |
215 | | - " pickle_protocol=self.pickle_protocol,\n", |
216 | | - " )\n", |
217 | | - " if temp_hash_file.is_file() and not meta_hash_file.is_file():\n", |
218 | | - " # On Unix, if target exists and is a file, it will be replaced silently if the\n", |
219 | | - " # user has permission.\n", |
220 | | - " # for more details: https://docs.python.org/3/library/shutil.html#shutil.move.\n", |
221 | | - " try:\n", |
222 | | - " shutil.move(str(temp_hash_file), meta_hash_file)\n", |
223 | | - " except FileExistsError:\n", |
224 | | - " pass\n", |
225 | | - " except PermissionError: # project-monai/monai issue #3613\n", |
226 | | - " pass\n", |
227 | | - " open(hashfile, \"a\").close() # store cacheid\n", |
228 | | - "\n", |
229 | | - " return _item_transformed" |
230 | | - ] |
231 | | - }, |
232 | 127 | { |
233 | 128 | "cell_type": "markdown", |
234 | 129 | "metadata": {}, |
|
245 | 140 | }, |
246 | 141 | { |
247 | 142 | "cell_type": "code", |
248 | | - "execution_count": 4, |
| 143 | + "execution_count": 3, |
249 | 144 | "metadata": {}, |
250 | 145 | "outputs": [ |
251 | 146 | { |
252 | 147 | "name": "stdout", |
253 | 148 | "output_type": "stream", |
254 | 149 | "text": [ |
255 | | - "2023-07-12 09:26:17,878 - INFO - Expected md5 is None, skip md5 check for file samples.zip.\n", |
256 | | - "2023-07-12 09:26:17,878 - INFO - File exists: samples.zip, skipped downloading.\n", |
257 | | - "2023-07-12 09:26:17,879 - INFO - Writing into directory: /raid/yliu/test_tutorial.\n" |
| 150 | + "2023-07-27 07:59:12,054 - INFO - Expected md5 is None, skip md5 check for file samples.zip.\n", |
| 151 | + "2023-07-27 07:59:12,055 - INFO - File exists: samples.zip, skipped downloading.\n", |
| 152 | + "2023-07-27 07:59:12,056 - INFO - Writing into directory: /raid/yliu/test_tutorial.\n" |
258 | 153 | ] |
259 | 154 | } |
260 | 155 | ], |
|
283 | 178 | }, |
284 | 179 | { |
285 | 180 | "cell_type": "code", |
286 | | - "execution_count": 5, |
| 181 | + "execution_count": 4, |
287 | 182 | "metadata": {}, |
288 | 183 | "outputs": [], |
289 | 184 | "source": [ |
|
299 | 194 | }, |
300 | 195 | { |
301 | 196 | "cell_type": "code", |
302 | | - "execution_count": 6, |
| 197 | + "execution_count": 5, |
303 | 198 | "metadata": {}, |
304 | 199 | "outputs": [], |
305 | 200 | "source": [ |
|
332 | 227 | }, |
333 | 228 | { |
334 | 229 | "cell_type": "code", |
335 | | - "execution_count": 7, |
| 230 | + "execution_count": 6, |
336 | 231 | "metadata": {}, |
337 | 232 | "outputs": [ |
338 | 233 | { |
339 | 234 | "name": "stdout", |
340 | 235 | "output_type": "stream", |
341 | 236 | "text": [ |
342 | | - "epoch0 time 19.746733903884888\n", |
343 | | - "epoch1 time 0.9976603984832764\n", |
344 | | - "epoch2 time 0.982248067855835\n", |
345 | | - "epoch3 time 0.9838874340057373\n", |
346 | | - "epoch4 time 0.9793403148651123\n", |
347 | | - "total time 23.69102692604065\n" |
| 237 | + "epoch0 time 20.148560762405396\n", |
| 238 | + "epoch1 time 0.9835140705108643\n", |
| 239 | + "epoch2 time 0.9708101749420166\n", |
| 240 | + "epoch3 time 0.9711742401123047\n", |
| 241 | + "epoch4 time 0.9711296558380127\n", |
| 242 | + "total time 24.04619812965393\n" |
348 | 243 | ] |
349 | 244 | } |
350 | 245 | ], |
|
372 | 267 | }, |
373 | 268 | { |
374 | 269 | "cell_type": "code", |
375 | | - "execution_count": 8, |
| 270 | + "execution_count": 7, |
376 | 271 | "metadata": {}, |
377 | 272 | "outputs": [ |
378 | 273 | { |
379 | 274 | "name": "stdout", |
380 | 275 | "output_type": "stream", |
381 | 276 | "text": [ |
382 | | - "epoch0 time 21.206729650497437\n", |
383 | | - "epoch1 time 1.510526180267334\n", |
384 | | - "epoch2 time 1.588256597518921\n", |
385 | | - "epoch3 time 1.4431262016296387\n", |
386 | | - "epoch4 time 1.4594802856445312\n", |
387 | | - "total time 27.20927882194519\n" |
| 277 | + "epoch0 time 21.170511722564697\n", |
| 278 | + "epoch1 time 1.482978105545044\n", |
| 279 | + "epoch2 time 1.5378782749176025\n", |
| 280 | + "epoch3 time 1.4499244689941406\n", |
| 281 | + "epoch4 time 1.4379286766052246\n", |
| 282 | + "total time 27.08065962791443\n" |
388 | 283 | ] |
389 | 284 | } |
390 | 285 | ], |
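
The epoch/total timings above come from the notebook's comparison loop, which is not shown in this diff. A minimal sketch of such a timing loop, assuming `dataset` is one of the datasets built earlier; the function name and loader settings are illustrative:

```python
import time
from monai.data import DataLoader

def time_epochs(dataset, epochs=5):
    # iterate the dataset several times; the first epoch also pays the cache-building cost
    loader = DataLoader(dataset, batch_size=1, num_workers=0)
    start = time.time()
    for epoch in range(epochs):
        e_start = time.time()
        for _ in loader:
            pass
        print(f"epoch{epoch} time", time.time() - e_start)
    print("total time", time.time() - start)
```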
|