
Commit 56ee32e

dcfidalgo and KumoLiu authored
Fix: Small logic mistake in the AsDiscrete.__call__ method (#7984)
Hi MONAI Team! Thank you very much for this super nice framework, really appreciate it! I just found a small logic mistake in one of the transform classes. To reproduce:

```python
import torch
from monai.transforms.post.array import AsDiscrete

transform = AsDiscrete(argmax=True)
prediction = torch.rand(2, 3, 3)
transform(prediction, argmax=False)  # will still apply argmax
```

### Description

Proposed fix: `argmax` is explicitly checked for `None` in the `__call__` method.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: David Carreto Fidalgo <davidc.fidalgo@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent ae5a04d commit 56ee32e

File tree

1 file changed (+2 −1 lines changed)

monai/transforms/post/array.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -211,7 +211,8 @@ def __call__(
             raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
         img = convert_to_tensor(img, track_meta=get_track_meta())
         img_t, *_ = convert_data_type(img, torch.Tensor)
-        if argmax or self.argmax:
+        argmax = self.argmax if argmax is None else argmax
+        if argmax:
             img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True))

         to_onehot = self.to_onehot if to_onehot is None else to_onehot
```
