3
3
import itertools
4
4
import json
5
5
import warnings
6
- from collections .abc import Hashable , Iterator , Sequence
6
+ from collections .abc import Callable , Hashable , Iterator , Sequence
7
7
from operator import itemgetter
8
- from typing import Any , Callable , Optional , Union
8
+ from typing import Any
9
9
10
10
import numpy as np
11
11
import xarray as xr
@@ -55,10 +55,10 @@ class BatchSchema:
55
55
56
56
def __init__ (
57
57
self ,
58
- ds : Union [ xr .Dataset , xr .DataArray ] ,
58
+ ds : xr .Dataset | xr .DataArray ,
59
59
input_dims : dict [Hashable , int ],
60
- input_overlap : Optional [ dict [Hashable , int ]] = None ,
61
- batch_dims : Optional [ dict [Hashable , int ]] = None ,
60
+ input_overlap : dict [Hashable , int ] | None = None ,
61
+ batch_dims : dict [Hashable , int ] | None = None ,
62
62
concat_input_bins : bool = True ,
63
63
preload_batch : bool = True ,
64
64
):
@@ -91,9 +91,7 @@ def __init__(
91
91
)
92
92
self .selectors : BatchSelectorSet = self ._gen_batch_selectors (ds )
93
93
94
- def _gen_batch_selectors (
95
- self , ds : Union [xr .DataArray , xr .Dataset ]
96
- ) -> BatchSelectorSet :
94
+ def _gen_batch_selectors (self , ds : xr .DataArray | xr .Dataset ) -> BatchSelectorSet :
97
95
"""
98
96
Create batch selectors dict, which can be used to create a batch
99
97
from an Xarray data object.
@@ -106,9 +104,7 @@ def _gen_batch_selectors(
106
104
else : # Each patch gets its own batch
107
105
return {ind : [value ] for ind , value in enumerate (patch_selectors )}
108
106
109
- def _gen_patch_selectors (
110
- self , ds : Union [xr .DataArray , xr .Dataset ]
111
- ) -> PatchGenerator :
107
+ def _gen_patch_selectors (self , ds : xr .DataArray | xr .Dataset ) -> PatchGenerator :
112
108
"""
113
109
Create an iterator that can be used to index an Xarray Dataset/DataArray.
114
110
"""
@@ -127,7 +123,7 @@ def _gen_patch_selectors(
127
123
return all_slices
128
124
129
125
def _combine_patches_into_batch (
130
- self , ds : Union [ xr .DataArray , xr .Dataset ] , patch_selectors : PatchGenerator
126
+ self , ds : xr .DataArray | xr .Dataset , patch_selectors : PatchGenerator
131
127
) -> BatchSelectorSet :
132
128
"""
133
129
Combine the patch selectors to form a batch
@@ -169,7 +165,7 @@ def _combine_patches_grouped_by_batch_dims(
169
165
return dict (enumerate (batch_selectors ))
170
166
171
167
def _combine_patches_grouped_by_input_and_batch_dims (
172
- self , ds : Union [ xr .DataArray , xr .Dataset ] , patch_selectors : PatchGenerator
168
+ self , ds : xr .DataArray | xr .Dataset , patch_selectors : PatchGenerator
173
169
) -> BatchSelectorSet :
174
170
"""
175
171
Combine patches with multiple slices along ``batch_dims`` grouped into
@@ -197,7 +193,7 @@ def _gen_empty_batch_selectors(self) -> BatchSelectorSet:
197
193
n_batches = np .prod (list (self ._n_batches_per_dim .values ()))
198
194
return {k : [] for k in range (n_batches )}
199
195
200
- def _gen_patch_numbers (self , ds : Union [ xr .DataArray , xr .Dataset ] ):
196
+ def _gen_patch_numbers (self , ds : xr .DataArray | xr .Dataset ):
201
197
"""
202
198
Calculate the number of patches per dimension and the number of patches
203
199
in each batch per dimension.
@@ -214,7 +210,7 @@ def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
214
210
for dim , length in self ._all_sliced_dims .items ()
215
211
}
216
212
217
- def _gen_batch_numbers (self , ds : Union [ xr .DataArray , xr .Dataset ] ):
213
+ def _gen_batch_numbers (self , ds : xr .DataArray | xr .Dataset ):
218
214
"""
219
215
Calculate the number of batches per dimension
220
216
"""
@@ -324,7 +320,7 @@ def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> list[sli
324
320
325
321
326
322
def _iterate_through_dimensions (
327
- ds : Union [ xr .Dataset , xr .DataArray ] ,
323
+ ds : xr .Dataset | xr .DataArray ,
328
324
* ,
329
325
dims : dict [Hashable , int ],
330
326
overlap : dict [Hashable , int ] = {},
@@ -350,10 +346,10 @@ def _iterate_through_dimensions(
350
346
351
347
352
348
def _drop_input_dims (
353
- ds : Union [ xr .Dataset , xr .DataArray ] ,
349
+ ds : xr .Dataset | xr .DataArray ,
354
350
input_dims : dict [Hashable , int ],
355
351
suffix : str = '_input' ,
356
- ) -> Union [ xr .Dataset , xr .DataArray ] :
352
+ ) -> xr .Dataset | xr .DataArray :
357
353
# remove input_dims coordinates from datasets, rename the dimensions
358
354
# then put intput_dims back in as coordinates
359
355
out = ds .copy ()
@@ -368,9 +364,9 @@ def _drop_input_dims(
368
364
369
365
370
366
def _maybe_stack_batch_dims (
371
- ds : Union [ xr .Dataset , xr .DataArray ] ,
367
+ ds : xr .Dataset | xr .DataArray ,
372
368
input_dims : Sequence [Hashable ],
373
- ) -> Union [ xr .Dataset , xr .DataArray ] :
369
+ ) -> xr .Dataset | xr .DataArray :
374
370
batch_dims = [d for d in ds .sizes if d not in input_dims ]
375
371
if len (batch_dims ) < 2 :
376
372
return ds
@@ -424,14 +420,14 @@ class BatchGenerator:
424
420
425
421
def __init__ (
426
422
self ,
427
- ds : Union [ xr .Dataset , xr .DataArray ] ,
423
+ ds : xr .Dataset | xr .DataArray ,
428
424
input_dims : dict [Hashable , int ],
429
425
input_overlap : dict [Hashable , int ] = {},
430
426
batch_dims : dict [Hashable , int ] = {},
431
427
concat_input_dims : bool = False ,
432
428
preload_batch : bool = True ,
433
- cache : Optional [ dict [str , Any ]] = None ,
434
- cache_preprocess : Optional [ Callable ] = None ,
429
+ cache : dict [str , Any ] | None = None ,
430
+ cache_preprocess : Callable | None = None ,
435
431
):
436
432
self .ds = ds
437
433
self .cache = cache
@@ -466,14 +462,14 @@ def concat_input_dims(self):
466
462
def preload_batch (self ):
467
463
return self ._batch_selectors .preload_batch
468
464
469
- def __iter__ (self ) -> Iterator [Union [ xr .DataArray , xr .Dataset ] ]:
465
+ def __iter__ (self ) -> Iterator [xr .DataArray | xr .Dataset ]:
470
466
for idx in self ._batch_selectors .selectors :
471
467
yield self [idx ]
472
468
473
469
def __len__ (self ) -> int :
474
470
return len (self ._batch_selectors .selectors )
475
471
476
- def __getitem__ (self , idx : int ) -> Union [ xr .Dataset , xr .DataArray ] :
472
+ def __getitem__ (self , idx : int ) -> xr .Dataset | xr .DataArray :
477
473
if not isinstance (idx , int ):
478
474
raise NotImplementedError (
479
475
f'{ type (self ).__name__ } .__getitem__ currently requires a single integer key'
@@ -532,7 +528,7 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
532
528
def _batch_in_cache (self , idx : int ) -> bool :
533
529
return self .cache is not None and f'{ idx } /.zgroup' in self .cache
534
530
535
- def _cache_batch (self , idx : int , batch : Union [ xr .Dataset , xr .DataArray ] ) -> None :
531
+ def _cache_batch (self , idx : int , batch : xr .Dataset | xr .DataArray ) -> None :
536
532
batch .to_zarr (self .cache , group = str (idx ), mode = 'a' )
537
533
538
534
def _get_cached_batch (self , idx : int ) -> xr .Dataset :
0 commit comments