
BaseFinetuning callback can add the same parameter to the optimizer multiple times #16465

Bug description

The BaseFinetuning callback adds the same parameter to the optimizer multiple times when one module reuses a parameter owned by another module (e.g. tied weights). This triggers: UserWarning: optimizer contains a parameter group with duplicate parameters; in future, this will cause an error; see github.com/pytorch/pytorch/issues/40967 for more information.

I think the problem is with how BaseFinetuning.filter_params() gathers parameters to add to the optimizer: it walks the flattened leaf modules and yields each module's own parameters, so a parameter shared between two modules is yielded once per owner. I created a reproducible example below.
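
To illustrate, here is a minimal sketch of that behavior. The toy model is mine (not from the Colab), and it assumes filter_params works as in recent Lightning releases:

import torch
from torch import nn
from pytorch_lightning.callbacks import BaseFinetuning

class TiedToy(nn.Module):  # hypothetical model with a tied weight
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10, bias=False)
        self.head.weight = self.embed.weight  # one Parameter, two owner modules

params = list(BaseFinetuning.filter_params(TiedToy(), train_bn=True, requires_grad=True))
print(len(params), params[0] is params[1])  # 2 True -- the same tensor, yielded twice
torch.optim.Adam([{"params": params}])      # emits the duplicate-parameter UserWarning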

How to reproduce the bug

Here's a link to a reproducible example: https://colab.research.google.com/drive/1b8-CNJzyDB9bhryF_vOoZkOymZ9spKyY?usp=sharing

Error messages and logs

 UserWarning: optimizer contains a parameter group with duplicate parameters; in future, this will cause an error; see github.com/pytorch/pytorch/issues/40967 for more information

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 1.10):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

For additional context, this issue came up while I was finetuning ESM (https://github.com/facebookresearch/esm). Specifically, ESM2 uses a module called RobertaLMHead (https://github.com/facebookresearch/esm/blob/7c2beef1eb74d8b5744f28ffc215a244d874a74f/esm/model/esm2.py#L71), which reuses embed_tokens.weight as its output projection.
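
In simplified form, the sharing looks like this (an illustrative sketch, not the actual ESM code):

import torch.nn as nn

class LMHeadSketch(nn.Module):
    # Stand-in for RobertaLMHead: assigning the passed-in Parameter to an
    # attribute registers it on this module too, so two leaf modules end up
    # owning the same tensor.
    def __init__(self, weight):
        super().__init__()
        self.weight = weight

class ESM2Sketch(nn.Module):
    def __init__(self, vocab_size=33, embed_dim=320):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim)
        self.lm_head = LMHeadSketch(weight=self.embed_tokens.weight)

BaseFinetuning.filter_params() then collects embed_tokens.weight once from the embedding and again from the head.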

My current workaround is to check for duplicate parameters in the params gathered by unfreeze_and_add_param_group and drop them before they reach the optimizer:

# Deduplicate while preserving order; tensors hash by identity, so the
# set membership test compares the underlying Parameter objects.
unique_params = set()
unique_params_list = []
for param in params:
    if param not in unique_params:
        unique_params.add(param)
        unique_params_list.append(param)
params = unique_params_list
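
Wired into a callback, the fix can live in a BaseFinetuning subclass. Below is a sketch that mirrors the flow of the upstream unfreeze_and_add_param_group; the signature matches recent Lightning releases, so double-check it against your version:

from pytorch_lightning.callbacks import BaseFinetuning

class DedupFinetuning(BaseFinetuning):
    @staticmethod
    def unfreeze_and_add_param_group(modules, optimizer, lr=None, initial_denom_lr=10.0, train_bn=True):
        # Same steps as the upstream method, with an order-preserving dedup
        # inserted before the param group reaches the optimizer.
        BaseFinetuning.make_trainable(modules)
        params_lr = optimizer.param_groups[0]["lr"] if lr is None else float(lr)
        denom_lr = initial_denom_lr if lr is None else 1.0
        params = BaseFinetuning.filter_params(modules, train_bn=train_bn, requires_grad=True)
        params = BaseFinetuning.filter_on_optimizer(optimizer, params)
        # dict.fromkeys keeps the first occurrence of each Parameter; tensors
        # hash by identity, so a shared weight collapses to a single entry.
        params = list(dict.fromkeys(params))
        if params:
            optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})

Overriding the static method keeps the fix local to the callback, without patching Lightning itself.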
