Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions tests/ignite/handlers/test_state_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,21 @@
config2 = (30, [(10, 0), (20, 10)], True, expected_hist2)
config3 = (
PiecewiseLinearStateScheduler,
{"param_name": "linear_scheduled_param", "milestones_values": [(3, 12), (5, 10)]},
{"param_name": "linear_scheduled_param", "milestones_values": [(3, 12), (5, 10)], "create_new": True},
)
config4 = (
ExpStateScheduler,
{"param_name": "exp_scheduled_param", "initial_value": 10, "gamma": 0.99, "create_new": True},
)
config4 = (ExpStateScheduler, {"param_name": "exp_scheduled_param", "initial_value": 10, "gamma": 0.99})
config5 = (
MultiStepStateScheduler,
{"param_name": "multistep_scheduled_param", "initial_value": 10, "gamma": 0.99, "milestones": [3, 6]},
{
"param_name": "multistep_scheduled_param",
"initial_value": 10,
"gamma": 0.99,
"milestones": [3, 6],
"create_new": True,
},
)


Expand All @@ -39,12 +48,16 @@ def __call__(self, event_index):

config6 = (
LambdaStateScheduler,
{"param_name": "custom_scheduled_param", "lambda_obj": LambdaState(initial_value=10, gamma=0.99)},
{
"param_name": "custom_scheduled_param",
"lambda_obj": LambdaState(initial_value=10, gamma=0.99),
"create_new": True,
},
)

config7 = (
StepStateScheduler,
{"param_name": "step_scheduled_param", "initial_value": 10, "gamma": 0.99, "step_size": 5},
{"param_name": "step_scheduled_param", "initial_value": 10, "gamma": 0.99, "step_size": 5, "create_new": True},
)


Expand All @@ -57,7 +70,10 @@ def test_pwlinear_scheduler_linear_increase_history(
# Testing linear increase
engine = Engine(lambda e, b: None)
pw_linear_step_parameter_scheduler = PiecewiseLinearStateScheduler(
param_name="pwlinear_scheduled_param", milestones_values=milestones_values, save_history=save_history,
param_name="pwlinear_scheduled_param",
milestones_values=milestones_values,
save_history=save_history,
create_new=True,
)
pw_linear_step_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand All @@ -78,7 +94,7 @@ def test_pwlinear_scheduler_step_constant(max_epochs, milestones_values):
# Testing step_constant
engine = Engine(lambda e, b: None)
linear_state_parameter_scheduler = PiecewiseLinearStateScheduler(
param_name="pwlinear_scheduled_param", milestones_values=milestones_values
param_name="pwlinear_scheduled_param", milestones_values=milestones_values, create_new=True
)
linear_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand All @@ -96,7 +112,7 @@ def test_pwlinear_scheduler_linear_increase(max_epochs, milestones_values, expec
# Testing linear increase
engine = Engine(lambda e, b: None)
linear_state_parameter_scheduler = PiecewiseLinearStateScheduler(
param_name="pwlinear_scheduled_param", milestones_values=milestones_values
param_name="pwlinear_scheduled_param", milestones_values=milestones_values, create_new=True
)
linear_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand All @@ -115,7 +131,7 @@ def test_pwlinear_scheduler_max_value(
# Testing max_value
engine = Engine(lambda e, b: None)
linear_state_parameter_scheduler = PiecewiseLinearStateScheduler(
param_name="linear_scheduled_param", milestones_values=milestones_values,
param_name="linear_scheduled_param", milestones_values=milestones_values, create_new=True
)
linear_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand Down Expand Up @@ -152,7 +168,7 @@ def test_piecewiselinear_asserts():
def test_exponential_scheduler(max_epochs, initial_value, gamma):
engine = Engine(lambda e, b: None)
exp_state_parameter_scheduler = ExpStateScheduler(
param_name="exp_scheduled_param", initial_value=initial_value, gamma=gamma
param_name="exp_scheduled_param", initial_value=initial_value, gamma=gamma, create_new=True
)
exp_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand All @@ -170,7 +186,11 @@ def test_step_scheduler(
):
engine = Engine(lambda e, b: None)
step_state_parameter_scheduler = StepStateScheduler(
param_name="step_scheduled_param", initial_value=initial_value, gamma=gamma, step_size=step_size
param_name="step_scheduled_param",
initial_value=initial_value,
gamma=gamma,
step_size=step_size,
create_new=True,
)
step_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand All @@ -193,7 +213,11 @@ def test_multistep_scheduler(
):
engine = Engine(lambda e, b: None)
multi_step_state_parameter_scheduler = MultiStepStateScheduler(
param_name="multistep_scheduled_param", initial_value=initial_value, gamma=gamma, milestones=milestones,
param_name="multistep_scheduled_param",
initial_value=initial_value,
gamma=gamma,
milestones=milestones,
create_new=True,
)
multi_step_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=max_epochs)
Expand All @@ -219,7 +243,7 @@ def __call__(self, event_index):
return self.initial_value * self.gamma ** (event_index % 9)

lambda_state_parameter_scheduler = LambdaStateScheduler(
param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99),
param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99), create_new=True
)
lambda_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
engine.run([0] * 8, max_epochs=2)
Expand All @@ -243,7 +267,7 @@ def __init__(self, initial_value, gamma):

with pytest.raises(ValueError, match=r"Expected lambda_obj to be callable."):
lambda_state_parameter_scheduler = LambdaStateScheduler(
param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99),
param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99), create_new=True
)


Expand Down Expand Up @@ -286,7 +310,7 @@ def _test(scheduler_cls, scheduler_kwargs):
def test_torch_save_load():

lambda_state_parameter_scheduler = LambdaStateScheduler(
param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99),
param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99), create_new=True
)

torch.save(lambda_state_parameter_scheduler, "dummy_lambda_state_parameter_scheduler.pt")
Expand Down Expand Up @@ -333,6 +357,7 @@ def _test(scheduler_cls, **scheduler_kwargs):
initial_value=10,
gamma=0.99,
milestones=[3, 6],
create_new=True,
)


Expand All @@ -343,7 +368,7 @@ def test_multiple_scheduler_with_save_history():
if "save_history" in config:
del config["save_history"]
_scheduler = scheduler(**config, save_history=True)
_scheduler.attach(engine_multiple_schedulers)
_scheduler.attach(engine_multiple_schedulers,)

engine_multiple_schedulers.run([0] * 8, max_epochs=2)

Expand Down Expand Up @@ -371,7 +396,7 @@ def __init__(self, initial_value, gamma):
def __call__(self, event_index):
return self.initial_value * self.gamma ** (event_index % 9)

param_scheduler = LambdaStateScheduler(param_name="param", lambda_obj=LambdaState(10, 0.99),)
param_scheduler = LambdaStateScheduler(param_name="param", lambda_obj=LambdaState(10, 0.99), create_new=True)

param_scheduler.attach(engine, Events.EPOCH_COMPLETED)

Expand All @@ -382,7 +407,7 @@ def __call__(self, event_index):
engine = Engine(lambda e, b: None)

param_scheduler = PiecewiseLinearStateScheduler(
param_name="param", milestones_values=[(10, 0.5), (20, 0.45), (21, 0.3), (30, 0.1), (40, 0.1)]
param_name="param", milestones_values=[(10, 0.5), (20, 0.45), (21, 0.3), (30, 0.1), (40, 0.1)], create_new=True
)

param_scheduler.attach(engine, Events.EPOCH_COMPLETED)
Expand All @@ -393,7 +418,7 @@ def __call__(self, event_index):

engine = Engine(lambda e, b: None)

param_scheduler = ExpStateScheduler(param_name="param", initial_value=10, gamma=0.99)
param_scheduler = ExpStateScheduler(param_name="param", initial_value=10, gamma=0.99, create_new=True)

param_scheduler.attach(engine, Events.EPOCH_COMPLETED)

Expand All @@ -403,7 +428,7 @@ def __call__(self, event_index):

engine = Engine(lambda e, b: None)

param_scheduler = StepStateScheduler(param_name="param", initial_value=10, gamma=0.99, step_size=5)
param_scheduler = StepStateScheduler(param_name="param", initial_value=10, gamma=0.99, step_size=5, create_new=True)

param_scheduler.attach(engine, Events.EPOCH_COMPLETED)

Expand All @@ -413,7 +438,9 @@ def __call__(self, event_index):

engine = Engine(lambda e, b: None)

param_scheduler = MultiStepStateScheduler(param_name="param", initial_value=10, gamma=0.99, milestones=[3, 6],)
param_scheduler = MultiStepStateScheduler(
param_name="param", initial_value=10, gamma=0.99, milestones=[3, 6], create_new=True
)

param_scheduler.attach(engine, Events.EPOCH_COMPLETED)

Expand Down