
Model parameters don't get updated after upgrading from 1.1.4 to 2.0.7 #18346

Open
yqin-falling-stars opened this issue Aug 20, 2023 · 13 comments
Labels
bug (Something isn't working), repro needed (The issue is missing a reproducible example), ver: 2.0.x, ver: 2.1.x

Comments

@yqin-falling-stars

Bug description

I have code that can be trained in 1.4, but the model parameters are not updated in version 2.0.7. What could be the reason? Thanks a lot.

What version are you seeing the problem on?

v2.0, master

How to reproduce the bug

No response

Error messages and logs

No response

Environment

No response

More info

No response

@yqin-falling-stars added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Aug 20, 2023
@awaelchli
Contributor

@amifallingstars Thanks for reaching out. Unfortunately, it is hard for us to tell what's going on from this extremely brief description.

If you have determined that the model weights don't get updated, I assume you mean the optimizer step doesn't get called? This could be the case if you return None from your training_step for some reason.
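
For reference, a minimal sketch (illustrative names, not the reported code) of how returning None from training_step makes Lightning skip backward and the optimizer step for that batch under automatic optimization:

import torch
import lightning as L

class SketchModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        if not torch.isfinite(loss):
            return None  # Lightning skips backward and optimizer.step() for this batch
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)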

Please note that for us to be able to work on this, we'd need some evidence that there is a bug in Lightning. It would be best if you could provide a code example that reproduces your issue.

@awaelchli added the waiting on author (Waiting on user action, correction, or update) label and removed the needs triage label on Aug 24, 2023
@awaelchli changed the title from "problem refer to version transfer" to "Model parameters don't get updated after upgrading from 1.4 to 2.0" on Aug 24, 2023
@yqin-falling-stars changed the title from "Model parameters don't get updated after upgrading from 1.4 to 2.0" to "Model parameters don't get updated after upgrading from 1.1.4 to 2.0" on Aug 25, 2023
@yqin-falling-stars
Author

yqin-falling-stars commented Aug 25, 2023

@awaelchli Hi, thanks for answering my question.
My problem is that the same code below can be trained with pytorch_lightning version 1.1.4 (not 1.4, sorry):

class Fegnet(L.LightningModule):
    def __init__(self, _config, word_vectors):
        super().__init__()
        self.save_hyperparameters("_config")
        self._config = _config
        self.model = MyNet(_config, word_vectors)
        self.table = []
        utils.set_metrics(self)

    def forward(self, inputs):
        return self.model(inputs)

    def training_step(self, batch, batch_idx):
        inputs, records = self.process_batch(batch)
        output = self(inputs)
        self.logger.experiment.add_scalar("train/loss", output["loss"], self.global_step) if self.logger else None
        self.compute_iou_and_update(output, records)
        return output["loss"]

    def on_training_epoch_end(self):
        self.log_metrics()

    def validation_step(self, batch, batch_idx):
        inputs, records = self.process_batch(batch)
        output = self(inputs)
        self.logger.experiment.add_scalar("val/loss", output["loss"], self.global_step) if self.logger else None
        self.compute_iou_and_update(output, records)
        return output
    
    def on_validation_epoch_end(self):
        self.log_metrics()
    
    def on_fit_end(self):
        self.log_table()
        return super().on_fit_end()
    
    def process_batch(self, batch):
        records, vfeats, vfeat_lens, word_ids, char_ids, s_labels, e_labels, h_labels = batch
        
        inputs = edict()
        inputs.vfeats = vfeats
        inputs.h_labels = h_labels
        inputs.s_labels, inputs.e_labels = s_labels, e_labels
        inputs.word_ids, inputs.char_ids = word_ids, char_ids
        
        # generate mask inputs
        inputs.query_mask = (torch.zeros_like(inputs.word_ids) != inputs.word_ids).float()
        inputs.video_mask = convert_length_to_mask(vfeat_lens)
        
        return inputs, records
    
    def compute_iou_and_update(self, output, records, track=None):
        start_logits_list = output["start_logits_list"]
        end_logits_list = output["end_logits_list"]

        start_indices, end_indices = extract_index(start_logits_list[0], end_logits_list[0])
        for record, start_index, end_index in zip(records, start_indices, end_indices):
            start_time, end_time = index_to_time(start_index, end_index, record["v_len"], record["duration"])
            iou = calculate_iou(i0=[start_time, end_time], i1=[record["s_time"], record["e_time"]])
            utils.update_module_ious(self, iou, operation="update")

    def log_metrics(self):
        phase = "train" if self.training else "test"
        r1i3, r1i5, r1i7, mIoU = utils.get_module_ious(self)
        self.table.append([self.global_step, r1i7*100, r1i5*100, r1i3*100, mIoU*100])
        self.log(f"{phase}/r1i7", r1i7)
        self.logger.experiment.add_scalars(f"{phase}", {
            "r1i7": r1i7*100, 
            "r1i5": r1i5*100, 
            "r1i3": r1i3*100, 
            "mIoU": mIoU*100
        }, self.global_step) if self.logger else None
        if not self.training:
            print(f"\t\t\t\tgstep: {self.global_step}, r1i7: {r1i7*100:02.02f}, r1i5: {r1i5*100:02.02f}, r1i3: {r1i3*100:02.02f}, mIoU: {mIoU*100:02.02f}")
        utils.update_module_ious(self, operation="reset")

    def configure_optimizers(self):
        return utils.set_schedule(self)

    def log_table(self):
        columns = ["gstep", "r1i7", "r1i5", "r1i3", "mIoU"]
        data = sorted(self.table, key=lambda x:x[1], reverse=True)
        # data = list(map(lambda x:[x[0], x[1]*100, x[2]*100, x[3]*100, x[4]*100], data))
        self.logger.log_table(key="performance", columns=columns, data=data)
    
    def log_text(self):
        header = ["gstep", "r1i7", "r1i5", "r1i3", "mIoU"]
        data = sorted(self.table, key=lambda x:x[1], reverse=True)
        text = list_to_markdown_table(data, header)
        self.logger.add_text('Final Performance', text, global_step=0)

and my Trainer is

@ex.automain
def main(_config):
    _config = copy.deepcopy(_config)
    L.seed_everything(12345)
    
    dm = eval(_config["data_module"])(_config)
    _config["model"].update({"word_size": dm.data.n_words, "char_size": dm.data.n_chars})
    model_params = edict()
    model_params.update(_config)
    model = Fegnet(model_params, dm.data.word_vector)
    
    log_dir = os.path.join(_config["paths"]["log_dir"], _config["dataset"], _config["model"]["name"])
    save_dir = os.path.join(log_dir, _config["exp_name"])
    
    if os.path.exists(save_dir): rmtree(save_dir)
    
    logger = TensorBoardLogger(save_dir=save_dir, version=_config["exp_name"]) if _config["use_logger"] else False

    # Define the checkpoints for id and ood training
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(save_dir, "cpt_id"),
        # filename='checkpoint-ood-{global_step}-{test_ood/r1i7*100:.2f}',
        save_top_k=3,
        verbose=True,
        monitor="test/r1i7",
        mode="max",
    )
    

    print('='*71+'Config: '+'='*71)
    pprint(_config)
    print('='*150)
    sys.stdout.flush()
    

    trainer = L.Trainer(
        devices=1,
        max_epochs=_config["train"]["epochs"],
        logger=logger,
        gradient_clip_val=1.0,
        val_check_interval=0.5,
        precision=32,
        callbacks=[checkpoint_callback]
    )

    
    trainer.fit(model, datamodule=dm)

With pytorch_lightning 2.0.7, all of my output stays the same:
[screenshot of training output]
and all my parameters stay the same after every training step.
However, I didn't encounter this problem with pytorch_lightning 1.1.4.

@yqin-falling-stars changed the title from "Model parameters don't get updated after upgrading from 1.1.4 to 2.0" to "Model parameters don't get updated after upgrading from 1.1.4 to 2.0.7" on Aug 25, 2023
@awaelchli removed the waiting on author (Waiting on user action, correction, or update) label on Aug 25, 2023
@awaelchli
Contributor

What is this part?

    def configure_optimizers(self):
        return utils.set_schedule(self)

Is it returning a valid torch Optimizer that handles the closure correctly?
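
For context, in Lightning 2.x automatic optimization the forward/backward of training_step runs inside a closure that is passed to optimizer.step(closure); an optimizer whose step ignores that closure would leave gradients unset and parameters unchanged. A minimal toy sketch of a closure-aware step (an assumed example, not code from this issue):

import torch

class ClosureAwareSGD(torch.optim.Optimizer):
    # Toy optimizer illustrating the closure contract Lightning relies on.
    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()  # runs training_step + backward, populating .grad
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group["lr"])
        return loss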

@awaelchli added the waiting on author (Waiting on user action, correction, or update) label on Aug 26, 2023
@yqin-falling-stars
Author

def set_schedule(pl_module):
    lr = pl_module.hparams._config.train.lr
    wd = pl_module.hparams._config.train.decay
    
    no_decay = [
        "bias",
        "layer_norm",
        "LayerNorm",
    ]

    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in pl_module.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": wd,
        },
        {
            "params": [
                p
                for n, p in pl_module.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    if pl_module.trainer.max_steps is None:
        max_steps = (
            len(pl_module.trainer.datamodule.train_dataloader())
            * pl_module.trainer.max_epochs
        )
    else:
        max_steps = pl_module.trainer.max_steps

    # optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98))
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=max_steps * pl_module.hparams._config.train.warmup_proportion,
        num_training_steps=max_steps,
    )


    return ({
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
        },
    })

@yqin-falling-stars
Author

@awaelchli

@awaelchli
Contributor

@amifallingstars I can't spot anything suspicious in your code.
Can you test that the loss returned from training_step is changing from one iteration to the next?
And how did you verify that the parameters don't change? Did you store and compare them?
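
One way to check (a minimal sketch, assuming the model, trainer, and dm objects from the training script above) is to snapshot the parameters before fitting and compare afterwards:

import torch

before = {name: p.detach().cpu().clone() for name, p in model.named_parameters()}
trainer.fit(model, datamodule=dm)
changed = [name for name, p in model.named_parameters()
           if not torch.equal(before[name], p.detach().cpu())]
print(f"{len(changed)} of {len(before)} parameters changed")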

@yqin-falling-stars
Author

@awaelchli Sure, I have printed the parameters from epoch to epoch, and they didn't change at all. The loss is also the same. I really want to use the new features in version 2.x, but I can't because of this problem. I ran the code on Linux with an Nvidia RTX A5000, a Titan X, and a V100, and none of them worked. I also checked my CUDA version and Python version, which are both within the range given in the official documentation. I tried the demo from the official site that trains an encoder and a decoder on the MNIST dataset, and its loss goes down as expected with version 2.x. I don't know what I can do. T_T

@awaelchli
Contributor

I see, and does the learning rate scheduler assign a positive learning rate? Note that if the learning rate is set to 0, the parameters wouldn't change. As a sanity check, please remove the scheduler and just return

def configure_optimizers(self):
    return torch.optim.AdamW(self.parameters())
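
To see which learning rate the warmup schedule actually assigns, here is a small standalone sketch (illustrative numbers, not the issue's config); get_linear_schedule_with_warmup wraps torch's LambdaLR, which rescales each param group's lr as soon as it is constructed, so with linear warmup the starting value is 0:

import torch
from transformers import get_linear_schedule_with_warmup

params = [torch.nn.Parameter(torch.zeros(4))]
optimizer = torch.optim.AdamW(params, lr=5e-4)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=100)
print(optimizer.param_groups[0]["lr"])  # 0.0 right after construction: linear warmup starts at 0
print(scheduler.get_last_lr())          # the currently scheduled learning rate(s)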

@yqin-falling-stars
Author

Yes, I just returned torch.optim.AdamW(self.parameters()); the parameters changed, but training still failed. The problem must be there, so why does my optimizer_helper fail? I printed the learning rate and it has the correct value.

@yqin-falling-stars
Author

[screenshot] I also printed out the named_parameters, and the value is 5e-4. My optimizer_helper is:
def set_schedule(pl_module):
    lr = pl_module.hparams._config.train.lr
    wd = pl_module.hparams._config.train.decay
    
    no_decay = [
        "bias",
        "layer_norm",
        "LayerNorm",
    ]

    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in pl_module.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": wd,
        },
        {
            "params": [
                p
                for n, p in pl_module.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    if pl_module.trainer.max_steps is None:
        max_steps = (
            len(pl_module.trainer.datamodule.train_dataloader())
            * pl_module.trainer.max_epochs
        )
    else:
        max_steps = pl_module.trainer.max_steps

    # optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98))
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=max_steps * pl_module.hparams._config.train.warmup_proportion,
        num_training_steps=max_steps,
    )


    return ({
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
        },
    })

@awaelchli
Contributor

awaelchli commented Aug 28, 2023

Can you check how get_linear_schedule_with_warmup modifies the optimizer?

From what you have posted earlier, it is clear that the issue comes from the optimizer and how its parameters get managed.

  1. Your screenshot shows that you have some parameters in your model that are 0. Check that this is intentional. Note that depending on the layer, you may get a 0 gradient!
  2. Clearly, from your last message, get_linear_schedule_with_warmup has a side effect on the optimizer; it must be updating it internally somehow.

Please understand that if you don't share the code, I won't be able to help much. The forum is better suited for implementation help questions. Here we are discussing bugs in the Lightning framework, and to pursue them we need some evidence that there is something wrong in Lightning. A screenshot is not sufficient.
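
As an aside, transformers' warmup schedules are meant to be stepped once per optimizer step, whereas Lightning's lr_scheduler config defaults to stepping once per epoch. Whether that matters here is only an assumption, but the interval can be made explicit in the dict returned by configure_optimizers, sketched here against the set_schedule return value shown above (warmup/step counts are illustrative):

def configure_optimizers(self):
    optimizer = torch.optim.AdamW(self.parameters(), lr=5e-4)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=1000
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "step",  # advance the warmup schedule every optimizer step
        },
    }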

@yqin-falling-stars
Author

import torch
import random

from transformers.optimization import AdamW
from transformers import (
    get_polynomial_decay_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from .metrics import IoUThresholdPercentage, Scalar, MeanIoU
from project.models import objectives


def set_metrics(pl_module):
    for split in ["train", "test"]:
        setattr(pl_module, f"{split}_r1i3", IoUThresholdPercentage(threshold=0.3))
        setattr(pl_module, f"{split}_r1i5", IoUThresholdPercentage(threshold=0.5))
        setattr(pl_module, f"{split}_r1i7", IoUThresholdPercentage(threshold=0.7))
        setattr(pl_module, f"{split}_mIoU", MeanIoU())
            

def update_module_ious(pl_module, iou=None, operation=None):
    phase = "train" if pl_module.training else "test"
    names = [f"{phase}_r1i3", f"{phase}_r1i5", f"{phase}_r1i7", f"{phase}_mIoU"]
    for name in names:
        if operation == "update":
            assert iou is not None, "update an invalid iou value"
            getattr(pl_module, name)(iou)
        elif operation == "reset":
            getattr(pl_module, name).reset()
    return None


def get_module_ious(pl_module):
    phase = "train" if pl_module.training else "test"
    names = [f"{phase}_r1i3", f"{phase}_r1i5", f"{phase}_r1i7", f"{phase}_mIoU"]
    return [getattr(pl_module, name).compute() for name in names]
    

def set_schedule(pl_module):
    lr = pl_module.hparams._config.train.lr
    # print(lr)
    # raise
    wd = pl_module.hparams._config.train.decay
    
    no_decay = [
        "bias",
        "layer_norm",
        "LayerNorm",
    ]

    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in pl_module.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": wd,
        },
        {
            "params": [
                p
                for n, p in pl_module.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    # print(optimizer_grouped_parameters)
    # raise
    if pl_module.trainer.max_steps is None:
        max_steps = (
            len(pl_module.trainer.datamodule.train_dataloader())
            * pl_module.trainer.max_epochs
        )
    else:
        max_steps = pl_module.trainer.max_steps

    optimizer = torch.optim.AdamW(pl_module.parameters(), lr=lr, eps=1e-8, betas=(0.9, 0.98))
    # optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=max_steps * pl_module.hparams._config.train.warmup_proportion,
        num_training_steps=max_steps,
    )

    # return optimizer
    return ({
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
        },
    })

get_linear_schedule_with_warmup is a function from transformers, which is a PyTorch-based library.

@yqin-falling-stars
Author

The zeros are the initialization of the linear biases:

def init_parameters(self):
    def init_weights(m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LSTM):
            m.reset_parameters()
    self.apply(init_weights)

@awaelchli added the repro needed (The issue is missing a reproducible example) label and removed the waiting on author (Waiting on user action, correction, or update) label on Sep 20, 2023