
Using Stochastic Weight Averaging (SWA) and LearningRateFinder simultaneously can cause issues: #20070

Open
liuzeyu6 opened this issue Jul 10, 2024 · 0 comments
Labels
bug (Something isn't working) · callback: swa · help wanted (Open to be worked on) · ver: 2.2.x

Comments


liuzeyu6 commented Jul 10, 2024

Bug description

trainer.fit(model, data_module)

Traceback (most recent call last):
  File "trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "trainer.py", line 948, in _run
    call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment
  File "call.py", line 95, in _call_setup_hook
    _call_callback_hooks(trainer, "setup", stage=fn)
  File "call.py", line 210, in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
  File "stochastic_weight_avg.py", line 154, in setup
    self._average_model = deepcopy(pl_module)
  File "copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "copy.py", line 206, in _deepcopy_list
    append(deepcopy(a, memo))
  File "copy.py", line 153, in deepcopy
    y = copier(memo)
  File "_tensor.py", line 86, in __deepcopy__
    raise RuntimeError(
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
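
For context (my reading of the traceback, not a confirmed root cause): StochasticWeightAveraging.setup deep-copies the entire LightningModule, and torch.Tensor only supports deepcopy for graph leaves, i.e. tensors created directly by the user. If lr_find leaves any non-leaf tensor attached to the module, the subsequent fit fails exactly as above. A minimal sketch of that PyTorch restriction, independent of Lightning:

    from copy import deepcopy
    import torch

    leaf = torch.ones(2, requires_grad=True)  # created by the user: a graph leaf
    deepcopy(leaf)                            # fine

    non_leaf = leaf * 2                       # produced by an op: not a leaf
    try:
        deepcopy(non_leaf)                    # torch.Tensor.__deepcopy__ raises
    except RuntimeError as err:
        print(err)  # "Only Tensors created explicitly by the user ... support the deepcopy protocol"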

What version are you seeing the problem on?

master

How to reproduce the bug

    import logging

    import lightning as L
    from lightning.pytorch.callbacks import StochasticWeightAveraging
    from lightning.pytorch.tuner import Tuner

    # model, data_module, config, and the other callbacks (timer, metric_callback,
    # print_callback, checkpoint_callback, early_stop_callback, lr_monitor,
    # rich_bar) are defined elsewhere in the training script.
    swa_callback = StochasticWeightAveraging(swa_lrs=1e-2)

    trainer = L.Trainer(
        default_root_dir=config.root_dir,
        devices=1,
        accelerator="auto",
        log_every_n_steps=1,
        max_epochs=config.hyperparameters.max_epoch_nums,
        callbacks=[swa_callback, timer, metric_callback, print_callback,
                   checkpoint_callback, early_stop_callback, lr_monitor, rich_bar],
    )

    tuner = Tuner(trainer)
    lr_finder = tuner.lr_find(model, datamodule=data_module)

    # fig = lr_finder.plot(suggest=True)
    # fig.show()
    suggested_lr = lr_finder.suggestion()
    logging.info(f"Suggested Learning Rate: {suggested_lr}")
    model.hparams.lr = suggested_lr

    trainer.fit(model, data_module)  # fails here in StochasticWeightAveraging.setup
    trainer.test(model, data_module)
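
A possible workaround until this is fixed (untested here, an assumption on my part, not an official recommendation): keep the SWA callback out of the Trainer used for tuning and attach it only to a fresh Trainer for the actual fit, so lr_find never runs alongside SWA:

    # Sketch of a workaround: separate the tuning Trainer from the fitting Trainer.
    # model / data_module are the same objects as above; if lr_find leaves non-leaf
    # tensors on the model, re-instantiating the model before fit may also be needed.
    tune_trainer = L.Trainer(devices=1, accelerator="auto", max_epochs=1)
    lr_finder = Tuner(tune_trainer).lr_find(model, datamodule=data_module)
    model.hparams.lr = lr_finder.suggestion()

    fit_trainer = L.Trainer(
        devices=1,
        accelerator="auto",
        callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
    )
    fit_trainer.fit(model, data_module)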

Error messages and logs

See the full traceback under "Bug description" above.

Environment

Current environment
#- PyTorch Lightning Version (e.g., 1.5.0):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):

More info

No response

cc @carmocca

@liuzeyu6 added the bug and needs triage labels on Jul 10, 2024
@awaelchli added the help wanted and callback: swa labels and removed the needs triage label on Jul 11, 2024