Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 01ff8dc

Browse files
committed
fix docstring and add more test cases for multiplier
1 parent ce72651 commit 01ff8dc

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

generative/networks/blocks/encoder_modules.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ class SpatialRescaler(nn.Module):
3030
n_stages: number of interpolation stages.
3131
size: output spatial size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]).
3232
method: algorithm used for sampling.
33-
multiplier: multiplier for spatial size. If scale_factor is a tuple,
34-
its length has to match the number of spatial dimensions.
33+
multiplier: multiplier for spatial size. If `multiplier` is a sequence,
34+
its length has to match the number of spatial dimensions; `input.dim() - 2`.
3535
in_channels: number of input channels.
3636
out_channels: number of output channels.
3737
bias: whether to have a bias term.
@@ -43,7 +43,7 @@ def __init__(
4343
n_stages: int = 1,
4444
size: Sequence[int] | int | None = None,
4545
method: str = "bilinear",
46-
multiplier: float | None = None,
46+
multiplier: Sequence[float] | float | None = None,
4747
in_channels: int = 3,
4848
out_channels: int = None,
4949
bias: bool = False,

tests/test_encoder_modules.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,18 @@
6969
(1, 3, 16, 16, 16),
7070
(1, 2, 8, 8, 8),
7171
],
72+
[
73+
{
74+
"spatial_dims": 3,
75+
"n_stages": 1,
76+
"method": "trilinear",
77+
"multiplier": (0.25, 0.5, 0.75),
78+
"in_channels": 3,
79+
"out_channels": 2,
80+
},
81+
(1, 3, 20, 20, 20),
82+
(1, 2, 5, 10, 15),
83+
],
7284
[
7385
{"spatial_dims": 2, "n_stages": 1, "size": (8, 8), "method": "bilinear", "in_channels": 3, "out_channels": 2},
7486
(1, 3, 16, 16),

0 commit comments

Comments
 (0)