Skip to content

Commit

Permalink
Merge branch 'Project-MONAI:dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
K-Rilla authored Jul 24, 2024
2 parents 77207ad + 37917e0 commit 6e1970f
Show file tree
Hide file tree
Showing 95 changed files with 16,406 additions and 218 deletions.
5 changes: 5 additions & 0 deletions docs/source/engines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ Workflows
.. autoclass:: GanTrainer
:members:

`AdversarialTrainer`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: AdversarialTrainer
:members:

`Evaluator`
~~~~~~~~~~~
.. autoclass:: Evaluator
Expand Down
23 changes: 23 additions & 0 deletions docs/source/inferers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,29 @@ Inferers
:members:
:special-members: __call__

`DiffusionInferer`
~~~~~~~~~~~~~~~~~~
.. autoclass:: DiffusionInferer
:members:
:special-members: __call__

`LatentDiffusionInferer`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LatentDiffusionInferer
:members:
:special-members: __call__

`ControlNetDiffusionInferer`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ControlNetDiffusionInferer
:members:
:special-members: __call__

`ControlNetLatentDiffusionInferer`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ControlNetLatentDiffusionInferer
:members:
:special-members: __call__

Splitters
---------
Expand Down
5 changes: 5 additions & 0 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,8 @@ Component store
---------------
.. autoclass:: monai.utils.component_store.ComponentStore
:members:

Ordering
--------
.. automodule:: monai.utils.ordering
:members:
2 changes: 1 addition & 1 deletion monai/apps/auto3dseg/hpo_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def update_params(self, *args, **kwargs):
raise NotImplementedError

@abstractmethod
def set_score(self):
def set_score(self, *args, **kwargs):
"""Report the evaluated results to HPO."""
raise NotImplementedError

Expand Down
4 changes: 2 additions & 2 deletions monai/apps/detection/utils/anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def generate_anchors(
w_ratios = 1 / area_scale
h_ratios = area_scale
# if 3d, w:h:d = 1:aspect_ratios[:,0]:aspect_ratios[:,1]
elif self.spatial_dims == 3:
else:
area_scale = torch.pow(aspect_ratios_t[:, 0] * aspect_ratios_t[:, 1], 1 / 3.0)
w_ratios = 1 / area_scale
h_ratios = aspect_ratios_t[:, 0] / area_scale
Expand All @@ -199,7 +199,7 @@ def generate_anchors(
hs = (h_ratios[:, None] * scales_t[None, :]).view(-1)
if self.spatial_dims == 2:
base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2.0
elif self.spatial_dims == 3:
else: # elif self.spatial_dims == 3:
ds = (d_ratios[:, None] * scales_t[None, :]).view(-1)
base_anchors = torch.stack([-ws, -hs, -ds, ws, hs, ds], dim=1) / 2.0

Expand Down
11 changes: 7 additions & 4 deletions monai/apps/pathology/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
SobelGradients,
)
from monai.transforms.transform import Transform
from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique
from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique, where
from monai.utils import TransformBackends, convert_to_numpy, optional_import
from monai.utils.misc import ensure_tuple_rep
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
Expand Down Expand Up @@ -162,7 +162,8 @@ def __call__(self, prob_map: NdarrayOrTensor) -> NdarrayOrTensor:
pred = label(pred)[0]
if self.remove_small_objects is not None:
pred = self.remove_small_objects(pred)
pred[pred > 0] = 1
pred_indices = np.where(pred > 0)
pred[pred_indices] = 1

return convert_to_dst_type(pred, prob_map, dtype=self.dtype)[0]

Expand Down Expand Up @@ -338,7 +339,8 @@ def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> N
instance_border = instance_border >= self.threshold # uncertain area

marker = mask - convert_to_dst_type(instance_border, mask)[0] # certain foreground
marker[marker < 0] = 0
marker_indices = where(marker < 0)
marker[marker_indices] = 0 # type: ignore[index]
marker = self.postprocess_fn(marker)
marker = convert_to_numpy(marker)

Expand Down Expand Up @@ -379,6 +381,7 @@ def _generate_contour_coord(self, current: np.ndarray, previous: np.ndarray) ->
"""

p_delta = (current[0] - previous[0], current[1] - previous[1])
row, col = -1, -1

if p_delta in ((0.0, 1.0), (0.5, 0.5), (1.0, 0.0)):
row = int(current[0] + 0.5)
Expand Down Expand Up @@ -634,7 +637,7 @@ def __call__( # type: ignore

seg_map_crop = convert_to_dst_type(seg_map_crop == instance_id, type_map_crop, dtype=bool)[0]

inst_type = type_map_crop[seg_map_crop]
inst_type = type_map_crop[seg_map_crop] # type: ignore[index]
type_list, type_pixels = unique(inst_type, return_counts=True)
type_list = list(zip(type_list, type_pixels))
type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
Expand Down
1 change: 1 addition & 0 deletions monai/bundle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
raise ValueError(f"Cannot find config file '{full_cname}'")

ardata = archive.read(full_cname)
cdata = {}

if full_cname.lower().endswith("json"):
cdata = json.loads(ardata, **load_kw_args)
Expand Down
1 change: 1 addition & 0 deletions monai/data/dataset_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def collect_meta_data(self):
"""

for data in self.data_loader:
meta_dict = {}
if isinstance(data[self.image_key], MetaTensor):
meta_dict = data[self.image_key].meta
elif self.meta_key in data:
Expand Down
4 changes: 2 additions & 2 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def peek_pending_rank(self):
a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine
return 1 if a is None else int(max(1, len(a) - 1))

def new_empty(self, size, dtype=None, device=None, requires_grad=False):
def new_empty(self, size, dtype=None, device=None, requires_grad=False): # type: ignore[override]
"""
must be defined for deepcopy to work
Expand Down Expand Up @@ -580,7 +580,7 @@ def ensure_torch_and_prune_meta(
img.affine = MetaTensor.get_default_affine()
return img

def __repr__(self):
def __repr__(self): # type: ignore[override]
"""
Prints a representation of the tensor.
Prepends "meta" to ``torch.Tensor.__repr__``.
Expand Down
15 changes: 9 additions & 6 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@
pytorch_after,
)

if pytorch_after(1, 13):
# import private code for reuse purposes, comment in case things break in the future
from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map

pd, _ = optional_import("pandas")
DataFrame, _ = optional_import("pandas", name="DataFrame")
nib, _ = optional_import("nibabel")
Expand Down Expand Up @@ -454,8 +450,13 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`
and so should not be used as a collate function directly in dataloaders.
"""
collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate
collated = collate_fn(batch) # type: ignore
if pytorch_after(1, 13):
from torch.utils.data._utils.collate import collate_tensor_fn # imported here for pylint/mypy issues

collated = collate_tensor_fn(batch)
else:
collated = default_collate(batch)

meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
if common_:
Expand Down Expand Up @@ -496,6 +497,8 @@ def list_data_collate(batch: Sequence):

if pytorch_after(1, 13):
# needs to go here to avoid circular import
from torch.utils.data._utils.collate import default_collate_fn_map

from monai.data.meta_tensor import MetaTensor

default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn})
Expand Down
4 changes: 3 additions & 1 deletion monai/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
from __future__ import annotations

from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
from .trainer import GanTrainer, SupervisedTrainer, Trainer
from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer
from .utils import (
DiffusionPrepareBatch,
IterationEvents,
PrepareBatch,
PrepareBatchDefault,
PrepareBatchExtraInput,
VPredictionPrepareBatch,
default_make_latent,
default_metric_cmp_fn,
default_prepare_batch,
Expand Down
Loading

0 comments on commit 6e1970f

Please sign in to comment.