Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit aef9ba2

Browse files
authored
Merge pull request #240 from Project-MONAI/217_set_default_inference_timesteps
Initialise inference_timesteps to train_timesteps
2 parents 00e7bb7 + 5cad89f commit aef9ba2

File tree

3 files changed

+7
-14
lines changed

3 files changed

+7
-14
lines changed

generative/networks/schedulers/ddim.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,15 @@ def __init__(
103103
# standard deviation of the initial noise distribution
104104
self.init_noise_sigma = 1.0
105105

106-
# setable values
107-
self.num_inference_steps = None
106+
108107
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
109108

110109
self.clip_sample = clip_sample
111110
self.steps_offset = steps_offset
112111

112+
# default the number of inference timesteps to the number of train steps
113+
self.set_timesteps(num_train_timesteps)
114+
113115
def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
114116
"""
115117
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

generative/networks/schedulers/pndm.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,11 @@ def __init__(
117117
self.cur_sample = None
118118
self.ets = []
119119

120-
# settable values
121-
self.num_inference_steps = None
120+
122121
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
123-
self.prk_timesteps = torch.Tensor([])
124-
self.plms_timesteps = torch.Tensor([])
125-
self.timesteps = torch.Tensor([])
126122

123+
# default the number of inference timesteps to the number of train steps
124+
self.set_timesteps(num_train_timesteps)
127125
def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
128126
"""
129127
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

tests/test_scheduler_pndm.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,6 @@ def test_add_noise_2d_shape(self, input_param, input_shape, expected_shape):
3939
noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
4040
self.assertEqual(noisy.shape, expected_shape)
4141

42-
@parameterized.expand(TEST_CASES)
43-
def test_error_if_timesteps_not_set(self, input_param, input_shape, expected_shape):
44-
scheduler = PNDMScheduler(**input_param)
45-
with self.assertRaises(ValueError):
46-
model_output = torch.randn(input_shape)
47-
sample = torch.randn(input_shape)
48-
scheduler.step(model_output=model_output, timestep=500, sample=sample)
4942

5043
@parameterized.expand(TEST_CASES)
5144
def test_step_shape(self, input_param, input_shape, expected_shape):

0 commit comments

Comments
 (0)