Skip to content

Commit cac8923

Browse files
committed
fix ci
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 43fc5e4 commit cac8923

File tree

2 files changed

+64
-38
lines changed

2 files changed

+64
-38
lines changed

monai/inferers/merger.py

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -310,17 +310,56 @@ def __init__(
310310
self.chunks = chunks
311311

312312
# Handle compressor/codecs based on zarr version
313-
self.codecs = codecs
314-
self.value_codecs = value_codecs
315-
self.count_codecs = count_codecs
316-
317-
# For backward compatibility
318-
if compressor is not None and codecs is None:
319-
self.codecs = compressor if isinstance(compressor, (list, tuple)) else [compressor]
320-
if value_compressor is not None and value_codecs is None:
321-
self.value_codecs = value_compressor if isinstance(value_compressor, (list, tuple)) else [value_compressor]
322-
if count_compressor is not None and count_codecs is None:
323-
self.count_codecs = count_compressor if isinstance(count_compressor, (list, tuple)) else [count_compressor]
313+
is_zarr_v3 = version_geq(get_package_version("zarr"), "3.0.0")
314+
315+
if is_zarr_v3:
316+
# For zarr v3, use codecs or convert compressor to codecs
317+
if codecs is not None:
318+
self.codecs = codecs
319+
elif compressor is not None:
320+
# Convert compressor to codec format
321+
if isinstance(compressor, (list, tuple)):
322+
self.codecs = compressor
323+
else:
324+
self.codecs = None
325+
326+
if value_codecs is not None:
327+
self.value_codecs = value_codecs
328+
elif value_compressor is not None:
329+
if isinstance(value_compressor, (list, tuple)):
330+
self.value_codecs = value_compressor
331+
else:
332+
self.value_codecs = None
333+
334+
if count_codecs is not None:
335+
self.count_codecs = count_codecs
336+
elif count_compressor is not None:
337+
if isinstance(count_compressor, (list, tuple)):
338+
self.count_codecs = count_compressor
339+
else:
340+
self.count_codecs = [
341+
{"name": "bytes", "configuration": {}},
342+
{"name": count_compressor.lower(), "configuration": {}},
343+
]
344+
else:
345+
self.count_codecs = None
346+
else:
347+
# For zarr v2, use compressors
348+
if codecs is not None:
349+
# If codecs are specified in v2, use the first codec as compressor
350+
self.codecs = codecs[0] if isinstance(codecs, (list, tuple)) else codecs
351+
else:
352+
self.codecs = compressor
353+
354+
if value_codecs is not None:
355+
self.value_codecs = value_codecs[0] if isinstance(value_codecs, (list, tuple)) else value_codecs
356+
else:
357+
self.value_codecs = value_compressor
358+
359+
if count_codecs is not None:
360+
self.count_codecs = count_codecs[0] if isinstance(count_codecs, (list, tuple)) else count_codecs
361+
else:
362+
self.count_codecs = count_compressor
324363

325364
# Create zarr arrays with appropriate parameters based on version
326365
if is_zarr_v3:

tests/inferers/test_zarr_avg_merger.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import unittest
1515
import warnings
16+
from typing import Any
1617

1718
import numpy as np
1819
import torch
@@ -428,38 +429,24 @@ def test_zarr_avg_merge_none_merged_shape_error(self):
428429
with self.assertRaises(ValueError):
429430
ZarrAvgMerger(merged_shape=None)
430431

431-
def test_deprecated_compressor_warning(self):
432+
def _check_deprecation_warning(self, param_name: str, value: Any):
433+
"""Helper function to check deprecation warnings for compressor parameters."""
432434
with warnings.catch_warnings(record=True) as w:
433435
warnings.simplefilter("always")
434-
ZarrAvgMerger(merged_shape=TENSOR_4x4.shape, compressor="LZ4")
436+
kwargs = {"merged_shape": TENSOR_4x4.shape, param_name: value}
437+
ZarrAvgMerger(**kwargs)
435438
self.assertTrue(len(w) == 1)
436-
self.assertTrue(issubclass(w[-1].category, DeprecationWarning))
437-
expected_message_part = (
438-
"The `compressor` argument is deprecated since 1.5.0 and will be removed in 1.7.0. "
439-
"Please use 'codecs' instead."
439+
self.assertTrue(issubclass(w[-1].category, FutureWarning))
440+
expected_message = (
441+
f"Argument `{param_name}` has been deprecated since version 1.5.0. It will be removed in version 1.7.0."
440442
)
441-
self.assertTrue(expected_message_part in str(w[-1].message))
443+
self.assertIn(expected_message, str(w[-1].message))
444+
445+
def test_deprecated_compressor_warning(self):
446+
self._check_deprecation_warning("compressor", numcodecs.VLenBytes())
442447

443448
def test_deprecated_value_compressor_warning(self):
444-
with warnings.catch_warnings(record=True) as w:
445-
warnings.simplefilter("always")
446-
ZarrAvgMerger(merged_shape=TENSOR_4x4.shape, value_compressor="LZ4")
447-
self.assertTrue(len(w) == 1)
448-
self.assertTrue(issubclass(w[-1].category, DeprecationWarning))
449-
expected_message_part = (
450-
"The `value_compressor` argument is deprecated since 1.5.0 and will be removed in 1.7.0. "
451-
"Please use 'value_codecs' instead."
452-
)
453-
self.assertTrue(expected_message_part in str(w[-1].message))
449+
self._check_deprecation_warning("value_compressor", numcodecs.VLenBytes())
454450

455451
def test_deprecated_count_compressor_warning(self):
456-
with warnings.catch_warnings(record=True) as w:
457-
warnings.simplefilter("always")
458-
ZarrAvgMerger(merged_shape=TENSOR_4x4.shape, count_compressor="LZ4")
459-
self.assertTrue(len(w) == 1)
460-
self.assertTrue(issubclass(w[-1].category, DeprecationWarning))
461-
expected_message_part = (
462-
"The `count_compressor` argument is deprecated since 1.5.0 and will be removed in 1.7.0. "
463-
"Please use 'count_codecs' instead."
464-
)
465-
self.assertTrue(expected_message_part in str(w[-1].message))
452+
self._check_deprecation_warning("count_compressor", numcodecs.VLenBytes())

0 commit comments

Comments
 (0)