
FSDP always sets requires_grad to False? Breaking checkpoint_wrapper. #758

Open
dave-epstein opened this issue Aug 1, 2021 · 44 comments · Fixed by #761
Labels: activation checkpoint, FSDP FullyShardedDataParallel (zero-3)

@dave-epstein

❓ Questions and Help

Hi, I get the following error when trying to use FSDP with checkpoint_wrapper:
RuntimeError: None of the outputs have requires_grad=True, this checkpoint() is not necessary

I found this weird, so I started digging around. In the checkpoint wrapper function, all the input tensors stashed in the forward pass had requires_grad=False, which led to this error. I checked that the inputs themselves do have requires_grad=True (at least one of them) and tried to find where the flag got clobbered. It happens in the call to cast_floats_to_right_precision(True, True, *args, **kwargs) in the forward method of FullyShardedDataParallel.

I am really confused by this method, since it seems to methodically set requires_grad to False for all inputs. If relevant, the layer being wrapped with a checkpoint is a ResNet-like block that is an intermediate layer in a network; the layer before it is an un-checkpointed Conv2D. So the input to the block is N x C x H x W and has requires_grad=True, but by the time the forward call is dispatched (outputs = self.module(*args, **kwargs)), this tensor has requires_grad=False. This breaks checkpoint_wrapper :(
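For illustration, a minimal, hypothetical sketch of the kind of setup described above (not the reporter's actual model): an un-checkpointed conv feeds a checkpoint-wrapped block, so the block's input is a non-leaf N x C x H x W activation with requires_grad=True.

import torch.nn as nn
from fairscale.nn import checkpoint_wrapper

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 16, 3, padding=1)         # not checkpointed
        self.block = checkpoint_wrapper(nn.Sequential(     # checkpointed ResNet-like block
            nn.Conv2d(16, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1),
        ))

    def forward(self, x):
        # when Net is wrapped in FSDP with mixed precision, the fp16 copy of this
        # intermediate activation can reach self.block with requires_grad=False
        return self.block(self.stem(x))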

Help would be much appreciated!

@anj-s
Contributor

anj-s commented Aug 2, 2021

Thanks for opening this issue! We do have tests with checkpoint_wrapper here and here. How does your example compare? Would it be possible to modify one of these tests so that we can repro the issue?
I see that the requires_grad field is carried over for leaf nodes in the function you mentioned. Is the requires_grad property being unset for both non-leaf and leaf nodes?

cc @min-xu-ai who might have some ideas about this.

@min-xu-ai
Contributor

Thanks for the bug report!

this method since it seems that it methodically sets requires_grad to False for all inputs

The method actually carries over the requires_grad flag from the original tensor, so I don't see it setting the flag to False blindly.

Do you have a small reproducible case that we can use to debug this?

@dave-epstein
Author

dave-epstein commented Aug 2, 2021

Thanks to both. I am looking into this more deeply still. One key difference may be that I'm using PyTorch Lightning with the FSDP plugin, so I'm not handling the instantiation of the model with full granularity. It may then turn out that this is a Lightning+FSDP+checkpoint_wrapper bug. I will reply with more details.

Another small difference - I am using auto_wrap from fairscale.nn to wrap models in the configure_sharded_model method in PyTorch Lightning. I'm using fairscale 0.4 and lightning 1.4.

I'm having some issues using the PTL "BoringModel" for some reason right now, I wanted to build a simple reproducible case. As I make progress I'll update this thread.

@dave-epstein
Author

dave-epstein commented Aug 2, 2021

Ok, seems like the issue is actually simple - the input feature map has requires_grad=True, is fp32, but is NOT a leaf. So the fp16 copy that gets created does NOT get requires_grad set to True, and this causes the issue. But clearly this is causing some undesirable behavior. Is this an issue in my code then, or in the control flow logic of cast_floats_to_right_precision?
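To make the mechanism concrete, here is a minimal sketch (plain PyTorch, not FairScale's actual code) of how a non-leaf fp32 activation loses requires_grad when its fp16 copy is created under no_grad:

import torch

conv = torch.nn.Conv2d(3, 8, 3)
x = conv(torch.randn(1, 3, 16, 16))   # output of a prior layer: requires_grad=True, is_leaf=False

with torch.no_grad():
    y = x.half()                       # fp16 copy created under no_grad

print(x.requires_grad, x.is_leaf)      # True False
print(y.requires_grad)                 # False

# if the cast helper only restores requires_grad for leaf inputs, y reaches the
# checkpointed module with requires_grad=False and checkpoint() raises the error above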

@dave-epstein
Author

I removed the is_leaf check and now copy over x's requires_grad to y's requires_grad unconditionally. What results now is the following opaque error:

SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f11d22aea00> returned NULL without setting an error

What can I do to help you guys dig into this?

@dave-epstein
Author

This may have something to do with the fact that I'm doing GAN training. I'm trying to use regular ddp_sharded and I get a complaint that no parameter requires gradient when I'm in the discriminator optimization step and running noise through a generator. I may need to file a PyTorch Lightning issue.

@min-xu-ai
Contributor

Thanks for digging in and for the detailed description of what you tried and what you got. I think this is a case not yet covered by FSDP's tests. Perhaps the input to your root FSDP model is generated by another model? Did you try running without mixed precision, so that this code is skipped, to see whether the problem is still there? Also, did you try removing the no_grad context in the conversion function?

@dave-epstein
Author

Right now I am trying to fully understand what works and does not work with regular DDP, DDP sharded, and FSDP, while using checkpoint_wrappers for GAN training. I will comment here when I have finished my tests :)

@dave-epstein
Author

This will be a bit long, but I hope it is comprehensive and helpful, and that it provides somewhere for us to start from to continue debugging. In all cases, I use the same snippet for checkpoint wrapping, which I am putting at the end of this comment. To address your comment above, I will try to re-add the is-leaf check and see what happens with fp32 training. I think the nograd context is not the problem, since the requires_grad was being propagated fine when I removed the is-leaf check.

I have been focusing on getting things working in the regular DDP and DDP Sharded setting first.

With regular DDP, I get an error - RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: .... So, that's not very good, and Googling brought no relevant results. I spent some time on this to confirm that I was not double-wrapping any module in checkpoint or doing anything weird with accessing parameters. This is regardless of whether I use just the G optimizer or both G and D. So, I'm stuck there.

With DDP Sharded, the situation is as follows:

When I try to train with both optimizers, I get the error I described above, where it complains that no parameter in the input requires gradient. This is because the generator model has all params with requires_grad=False during the discriminator step. So my hack, which I think fixes this, is that instead of feeding torch.randn(...) to the generator, I feed it torch.randn(...).requires_grad_().

When I do that, I keep getting messages that trainable parameters change, at every iteration (I am 99% sure it's a non-issue and expected behavior due to toggling between states - in _detect_train_change where this message is printed, trainable_mask was equal to the exact negative of self._reference_trainable_mask). In addition to this, I get a warning that some gradients are not being reduced (the check in refresh_trainable). I spent some time looking into the latter and it seems concerning - the same 4 parameters are always not reduced, as reported by: [i for i,v in enumerate(self._grad_to_be_reduced) if v]. Not sure what more I can find of substance here, but it indicates a potential issue.

Then, I tried to train only the optimizer for the generator, so no 2 optimizers. The first forward pass works without complaining, and then the next time around the cleanup() function crashes at an assert param.grad is not None. If I remove the assertion, the training loop then appears to work without complaining. But I'm not sure whether I can safely remove that assertion or not.

With FSDP (by setting plugins='fsdp' in PTL), in the one-optimizer setting, I get a (maybe irrelevant) warning ShardedGradScaler is to be used in combination with a sharded optimizer, this could not be checked and then the error assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer.".

With FSDP, in the two optimizer setting, I get SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f9b573dea00> returned NULL without setting an error.

As you can see - nothing really works, but maybe I made some progress with the DDP sharded setting. I know this is pretty complicated, but I think I'm not capable of solving it myself at this point :)

The snippet I use for creating the models used, with model initialization simplified/abstracted:

# (methods on the LightningModule; module-level imports shown here for completeness)
from typing import Optional
from torch import nn
from fairscale.nn import checkpoint_wrapper, auto_wrap

def wrap_checkpoints(self, modules: list[type], blacklist: Optional[list[str]] = None,
                     root: Optional[nn.Module] = None, **checkpoint_kwargs):
  root = root or self
  if blacklist is None:
      blacklist = []
  for n, m in root.named_modules():
      if any([n.startswith(s) for s in blacklist]) or getattr(m, '_checkpointed', False):
          continue
      if any([isinstance(m, m_) for m_ in modules]):
          root.get_submodule(n).apply(lambda mm: setattr(mm, '_checkpointed', True))
          parent_n, n = n.rsplit('.', 1)
          setattr(root.get_submodule(parent_n), n, checkpoint_wrapper(m, **checkpoint_kwargs))
  return root

def configure_sharded_model(self):
  self.discriminator = auto_wrap(
      self.wrap_checkpoints(root=DiscriminatorNetworks(**D_kwargs), modules=self.checkpoint_modules,
                            offload_to_cpu=True))
  self.generator = auto_wrap(
      self.wrap_checkpoints(root=GeneratorNetwork(**G_kwargs), modules=self.checkpoint_modules,
                            offload_to_cpu=True))

@min-xu-ai
Contributor

Let me add @zhaojuanmao and @blefaudeux

Do you have instructions for running the code snippet?

@anj-s
Contributor

anj-s commented Aug 2, 2021

@dave-epstein I want to confirm: for ShardedDDP are we using the same code snippet? In that case, are you replacing auto_wrap with ShardedDataParallel? For DDP, I am guessing we just call self.wrap_checkpoints?

Is the goal to use FSDP or SDP?

@anj-s
Contributor

anj-s commented Aug 2, 2021

That's what I am trying to confirm. I am unclear about what is being used with checkpoint_wrapper. I know DDP and FSDP, but what is DDP sharded?

@blefaudeux
Contributor

blefaudeux commented Aug 2, 2021

This will be a bit long, but I hope it is comprehensive and helpful, and that it provides somewhere for us to start from to continue debugging. In all cases, I use the same snippet for checkpoint wrapping, which I am putting at the end of this comment. To address your comment above, I will try to re-add the is-leaf check and see what happens with fp32 training. I think the nograd context is not the problem, since the requires_grad was being propagated fine when I removed the is-leaf check.

I have been focusing on getting things working in the regular DDP and DDP Sharded setting first.

With regular DDP, I get an error - RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: .... So, that's not very good, and Googling brought no relevant results. I spent some time on this to confirm that I was not double-wrapping any module in checkpoint or doing anything weird with accessing parameters. This is regardless of whether I use just the G optimizer or both G and D. So, I'm stuck there.

This happens when the same parameter produces grads twice, AFAIK. DDP's job is to all-reduce the gradients as they are produced, and in that case it gets confused: this parameter's gradient has already been reduced, yet the parameter just produced new gradients, so something is probably amiss. It can happen if the same param is used twice in different places in your forward pass; each use will produce a gradient. If the two uses can/should be different, you could have two instances instead and that should fix this issue.

With DDP Sharded, the situation is as follows:

When I try to train with both optimizers, there was the error I described above, where it complains that no parameter in the input requires gradient. This is because the generator model has all params with requires_grad=False in the discriminator step. So my hack, which I think fixes this, is that instead of inputting torch.randn(...) to the generator, I input torch.randn(...).requires_grad_().

When I do that, I keep getting messages that trainable parameters change, at every iteration (I am 99% sure it's a non-issue and expected behavior due to toggling between states - in _detect_train_change where this message is printed, trainable_mask was equal to the exact negative of self._reference_trainable_mask).

This means that the trainability of the parameters has changed (at least one param), and it does not look like what you want?

In addition to this, I get a warning that some gradients are not being reduced (the check in refresh_trainable).

This means that at least one param "requires grad" but was probably not actually used in the forward pass (or was used with an input without grads, one of those options), so it does not produce grads. Same as above, not a deal breaker per se, but it's meant to warn you that something is probably wrong in the compute graph.

I spent some time looking into the latter and it seems concerning - the same 4 parameters are always not reduced, as reported by: [i for i,v in enumerate(self._grad_to_be_reduced) if v]. Not sure what more I can find of substance here, but it indicates a potential issue.

I would bet that these parameters are not really being touched (or are used with inputs which have lost their gradient), so their backward hook is never called.

Then, I tried to only train the optimizer for the generator, so no 2 optimizers. The first forward pass works without complaining, and then the next time the cleanup() function crashes where there is an assert param.grad is not None. If I remove the assertion it appears that the training loop then works without complaining, it seems. But I'm not sure if I can safely remove that assertion or not.

To me there's probably something wrong in the compute (in the formulation of the model); the two warnings above (grads not produced and trainability changed) should not fire for no reason. It's pretty easy to have a typo - say you produce integers somewhere and the backward graph will be cut there, for instance - easy to overlook (just an example).

With FSDP (by setting plugins='fsdp' in PTL), in the one-optimizer setting, I get a (maybe irrelevant) warning ShardedGradScaler is to be used in combination with a sharded optimizer, this could not be checked and then the error assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer.".

This tells me that the AMP context is probably wrong (the second message, not the first one); it's a meaningful error which should not happen, I believe.

With FSDP, in the two optimizer setting, I get SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f9b573dea00> returned NULL without setting an error.

As you can see - nothing really works, but maybe I made some progress with the DDP sharded setting. I know this is pretty complicated, but I think I'm not capable of solving it myself at this point :)


@dave-epstein
Author

dave-epstein commented Aug 2, 2021

@anj-s My understanding was that outside of FSDP, auto_wrap is a no-op. I never call DDP/Sharded DDP ( https://fairscale.readthedocs.io/en/stable/api/nn/sharded_ddp.html ) /FSDP manually. Pytorch Lightning handles that. The instructions in PTL for using FSDP are to do what I did basically (this page gives examples for both Sharded DDP and FSDP - https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#fully-sharded-training - the page says that auto_wrap is a no-op when not using FSDP).

The goal is to use FSDP but I thought that maybe Sharded DDP is more mature so I should try to get it working first / use it to debug my model.

@blefaudeux I am looking into what you wrote and will reply with a separate comment.

@anj-s
Contributor

anj-s commented Aug 2, 2021

@anj-s My understanding was that outside of FSDP, auto_wrap is a no-op. I never call DDP/Sharded DDP ( https://fairscale.readthedocs.io/en/stable/api/nn/sharded_ddp.html ) /FSDP manually. Pytorch Lightning handles that. The instructions in PTL for using FSDP are to do what I did basically (this page gives examples for both Sharded DDP and FSDP - https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#fully-sharded-training - the page says that auto_wrap is a no-op when not using FSDP).

The goal is to use FSDP but I thought that maybe Sharded DDP is more mature so I should try to get it working first / use it to debug my model.

@blefaudeux I am looking into what you wrote and will reply with a separate comment.

Thanks! That makes sense. I would like to repro with just DDP + checkpoint_wrapper to see if that works. Looks like there were some errors with that as well. I think we can then try FSDP + checkpoint_wrapper. If you have a smaller repro example to start with, that would be good.

@dave-epstein
Author

dave-epstein commented Aug 2, 2021

@blefaudeux

Re DDP (same param marked twice): I do call my discriminator multiple times with different batches, as well as call some underlying module in the generator repeatedly (an iterative generation process, so cannot be unfolded into different params). Is this use case fundamentally unsupported by gradient checkpointing? I didn't see any notes like that anywhere. (@anj-s since you mentioned wanting to repro with DDP + checkpoint_wrapper, I'm not sure if this is relevant) The regular DDP error message is actually stumping me more than any of the other ones, so I was not planning on trying to get it to work. That being said, I should be able to hopefully come up with a small snippet that reproduces both this "marked twice" error in DDP as well as the opaque NULL error in FSDP.

Re sharded DDP / trainability: I think I explained this in my first comment. The trainability warning printing every iteration is annoying, but I was saying it's to be expected since I am switching between the generator being trainable and the discriminator being trainable, so of course the trainability changes. I'm not bothered by that warning and I don't think you guys should be either (except that maybe you could suppress it from being printed on every forward pass).

Re some gradients not being reduced: I am pretty sure that every parameter is being touched. This is what I spent some time looking into yesterday. I agree it is a concerning error, and I will keep investigating today. Other than this warning message, sharded DDP appears to work.

Re the cleanup() bug: yes, something weird is probably occurring here, since I kept all my generator/discriminator code but removed the discriminator optimizer. I don't want to focus on this since it doesn't seem like the most relevant issue.

It seems that the focus now should be getting a compact example that reproduces these bugs (specifically the FSDP ones) sent over to you guys. Let me work on that now. Is there a good way for me to share it with you privately?

@dave-epstein
Author

dave-epstein commented Aug 2, 2021

BTW just as a sanity check I ran the same scaffolding code (that is, the way I'm calling checkpoint_wrapper and wrap/auto_wrap) with Sharded DDP and FSDP on another model (one optimizer, regular feed-forward CNN more or less). It worked fine in both cases.

EDIT - the model actually did NOT work fine with regular DDP and checkpointing. The "mark a variable ready only once" error was thrown. Looking into it.

@blefaudeux
Contributor

@dave-epstein re trainability, my bad, I missed that part of your explanation, makes sense !

It makes me think of a possible explanation with respect to DDP though: AFAIK it doesn't expect the graph to change in terms of trainability from step to step, so depending on the structure of your code (if the generator and discriminator share something, for instance) it could dislike something there. Cc @mrshenli just in case.

Sorry if this is obvious, but even for FSDP I would focus on getting things to work without activation checkpointing first, and then turn it on; in terms of the backward pass it does complicate the flow a fair bit. Side note: if you're using a Transformer you can also use reversible layers for a similar effect with a more straightforward graph.

@dave-epstein
Author

dave-epstein commented Aug 2, 2021

For DDP, the model started working perfectly fine once I short-circuited the forward pass so that it did not call the same model multiple times. Okay, so we know why that's happening in DDP (multiple calls to same model). Shouldn't this be allowable? It seems like a pretty common use case. In my case it was an off-the-shelf ResNet50 that I was calling twice with two different batches of images:

self.encoder(batchone)
self.encoder(batchtwo)

@dave-epstein
Author

@blefaudeux Thanks. Nothing is too obvious, especially when I've been staring at things for so long :) See my comment posted right after yours for DDP (G and D don't share any weights). Not using a Transformer in this case, but good tip.

Focusing on DDP / Sharded DDP / FSDP without checkpointing:

  • DDP: no checkpointing - it works! Checkpointing - throws the "Expected to mark a variable ready only once." error. Seems like this is caused by multiple fwd passes of the same model. I'd say this should be supported, since it seems like such a common use case.
  • Sharded DDP: Same issue with and without checkpointing - Grads waiting to be reduced.
  • FSDP: Same issue with and without checkpointing - the NULL error thing.

@anj-s
Contributor

anj-s commented Aug 2, 2021

@dave-epstein I think we can only debug with a small repro example accompanied by stack traces(if possible). It is best if the smaller repro example is something you can share with us publicly on GitHub.

@dave-epstein
Author

Ok, let me work on putting together a smallest possible example to repro, and share it here when done. Hopefully I can get it to reproduce all of these errors :)

@dave-epstein
Author

dave-epstein commented Aug 2, 2021

Here is a very, very simple example that causes the "Expected to mark a variable ready only once." error to fire on regular DDP when using checkpoint_wrapper. As I suspected, it is caused (among other things, I guess) by calling the same model more than once. Are you guys able to easily connect this to Lightning scaffolding, or should I share the entry point file too?

from typing import Optional

import torch
from fairscale.nn import checkpoint_wrapper, auto_wrap, wrap
from pytorch_lightning import LightningModule
from torch import nn
from torchvision.models import resnet50
from torchvision.models.resnet import Bottleneck


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()

    def configure_sharded_model(self):
        self.layer = self._checkpoint_and_shard(resnet50(), modules=[Bottleneck], wrap_mode='wrap')

    def _checkpoint_and_shard(self, m, modules, wrap_mode='auto'):
        wrap_fn = auto_wrap if wrap_mode == 'auto' else wrap
        return wrap_fn(self.wrap_checkpoints(root=m, modules=modules, offload_to_cpu=True))

    def wrap_checkpoints(self, modules: list[type], blacklist: Optional[list[str]] = None,
                         root: Optional[nn.Module] = None, **checkpoint_kwargs):
        root = root or self
        count = 0
        if blacklist is None:
            blacklist = []
        for n, m in root.named_modules():
            if any([n.startswith(s) for s in blacklist]) or getattr(m, '_checkpointed', False):
                continue
            if any([isinstance(m, m_) for m_ in modules]):
                root.get_submodule(n).apply(lambda mm: setattr(mm, '_checkpointed', True))
                parent_n, n = n.rsplit('.', 1)
                setattr(root.get_submodule(parent_n), n, checkpoint_wrapper(m, **checkpoint_kwargs))
                count += 1
        print(f'Checkpointed {count} modules of type {modules}')
        return root

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        batch = batch[0]
        n = len(batch)
        firsthalf = self.layer(batch[:n//2])
        secondhalf = self.layer(batch[n//2:])
        firstloss = torch.nn.functional.cross_entropy(firsthalf, torch.ones(len(firsthalf)).to(firsthalf.device).to(torch.long))
        secondloss = torch.nn.functional.cross_entropy(secondhalf, torch.ones(len(secondhalf)).to(firsthalf.device).to(torch.long))
        loss = firstloss + secondloss
        return {'loss': loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def on_train_start(self) -> None:
        # if int(os.environ.get("LOCAL_RANK", 0)) == 0: import ipdb; ipdb.set_trace()
        pass

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x['x'] for x in outputs]).mean()

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log('fake_test_acc', loss)
        return {"y": loss}

    def test_epoch_end(self, outputs) -> None:
        torch.stack([x["y"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

BTW, on sharded DDP and FSDP, this example works fine. But I figure it's probably still a bug.

Working on examples to repro the other errors too.

@dave-epstein
Author

Any GAN setup with PT Lightning should give you the repeat "trainable parameters changed" error with Sharded DDP, even this one - https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/basic-gan.html - which I tried. However, the grads waiting to be reduced error is more mysterious, and seems to be somehow tied to my system which is a bit more complex. I'll post here (as usual) when I figure more out.

@min-xu-ai
Contributor

For DDP, the model started working perfectly fine once I short-circuited the forward pass so that it did not call the same model multiple times. Okay, so we know why that's happening in DDP (multiple calls to same model). Shouldn't this be allowable? It seems like a pretty common use case. In my case it was an off-the-shelf ResNet50 that I was calling twice with two different batches of images:

self.encoder(batchone)
self.encoder(batchtwo)

This is a familiar case that we have debugged a lot for the VISSL framework as well. You can check the "multiple_fwd" test cases in the tests/nn/data_parallel directory to see what is being tested.

With multiple forwards of the same module, when it is also checkpointed, we had to put in special logic to handle it with FSDP and checkpoint_wrapper. But since this is a complicated case, I wouldn't be surprised if there are still corner cases not handled.

DDP did not handle multiple forwards + checkpoint when I tried it about 6 months ago. The best workaround I had at that time was to use the model.no_sync() context to stop DDP from doing gradient reduction, and then do the reduction with custom (non-DDP) code after the backward is done (a sketch is below). This is certainly slower since the reduction is not overlapped with the backward compute.
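For reference, a rough sketch of that workaround (names like loss_fn and batch are placeholders, and it assumes torch.distributed is already initialized): DDP skips its own reduction inside no_sync(), and the gradients are all-reduced manually afterwards.

import torch
import torch.distributed as dist

def train_step(ddp_model, loss_fn, batch, world_size):
    with ddp_model.no_sync():                 # disable DDP's gradient reduction hooks
        out1 = ddp_model(batch["first"])      # first forward of the same module
        out2 = ddp_model(batch["second"])     # second forward of the same module
        loss = loss_fn(out1) + loss_fn(out2)
        loss.backward()

    # manual gradient reduction; not overlapped with backward, hence slower
    for p in ddp_model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size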

However, I think @zhaojuanmao has put in a new DDP flag called "static_graph" or something similar that may have fixed this issue since then.

@dave-epstein
Author

dave-epstein commented Aug 2, 2021

The _set_static_graph does not work in my case due to the separate G and D forward pass logic. Regardless, I'm (more than) okay with using FSDP instead of DDP. I'm not spending more time on that case. Instead I'm trying to build a portable example to reproduce the bugs I'm experiencing with sharded DDP and FSDP, which may be exposing a corner case in the logic you guys have written.

@min-xu-ai
Contributor

min-xu-ai commented Aug 3, 2021

build a portable example to reproduce the bugs I'm experiencing with sharded DDP and FSDP, which may be exposing a corner case in the logic you guys have written.

Thank you so much! It will be tremendously helpful!

@dave-epstein
Author

I just prepared a really simple example that reproduces the behavior I described above identically. I am packaging it now and will post another comment here once it's done.

@dave-epstein
Author

dave-epstein commented Aug 3, 2021

So I think the FSDP bug is very related to Issue #617. The error hiding in the FSDP output was: ERROR: expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.BACKWARD_PRE.

Here is the code: https://github.com/dave-epstein/example-code/blob/main/example.tar.gz
Sorry if there was a nicer way to share it, this was the most convenient thing I could come up with.

FWIW I'm on V100s with 32GB memory, CUDA 11.1, torch 1.9.0. But you don't need more than a few GB of VRAM to run the experiments.

You should get the following situation - run the code with PTL 1.4, fairscale 0.4, you need Hydra installed too.

  • python run_debug.py +experiment=debug trainer.plugins=fsdp model.checkpoint=true and python run_debug.py +experiment=debug trainer.plugins=fsdp model.checkpoint=false
    FSDP - You get the state error I pasted above. Note that this toy example seems to differ from my case in that I didn't have to modify the is_leaf condition in the casting that was discussed above. But we can deal with that later... :) It also differs in that I don't think it elicits the ShardedGradScaler is to be used in combination with a sharded optimizer, this could not be checked message which my full code does. But again that can be looked at later.

  • python run_debug.py +experiment=debug trainer.plugins=ddp_sharded model.checkpoint=true and python run_debug.py +experiment=debug trainer.plugins=ddp_sharded model.checkpoint=false
    Sharded DDP - You get the "trainable params changed" due to switching between G/D (benign), as well as the more concerning Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context in refresh_trainable in sharded_ddp.py. The layers that have self._grad_to_be_reduced[index] set to True when it shouldn't be are consistent between forward passes, and I think this should not happen. However, since this is just a warning, the training does proceed.

The below experiments are just regular DDP:

  • python run_debug.py +experiment=debug "~trainer.plugins" model.checkpoint=true
    You get RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons..., as we discussed above. Not you guys' problem, I guess. And I'm fine with using either sharded or fully sharded DDP (I prefer it).
  • python run_debug.py +experiment=debug "~trainer.plugins" model.checkpoint=false
    This should just work.

I hope this helps. Please let me know if you have any questions.

@anj-s anj-s self-assigned this Aug 3, 2021
@anj-s
Contributor

anj-s commented Aug 3, 2021

@dave-epstein I can reproduce the issue. Thank you for sharing the repro example, this really helps a great deal. I'll start investigating and update once I have something.

@blefaudeux
Contributor

blefaudeux commented Aug 4, 2021

Any GAN setup with PT Lightning should give you the repeat "trainable parameters changed" error with Sharded DDP, even this one - https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/basic-gan.html - which I tried. However, the grads waiting to be reduced error is more mysterious, and seems to be somehow tied to my system which is a bit more complex. I'll post here (as usual) when I figure more out.

Could you try with "auto refresh trainable" set to off ? This stops ShardedDDP from monitoring the trainability status (it will stick to the one initially defined when wrapping the model), both warnings are related to that (ShardedDDP detects that the graph changed, but some hooks defined previously to track trainable parameters have not fired -> warning that something could be wrong in the setup)

Edit: we can try that out ourselves; I only have access to my phone at the moment, so just throwing out ideas, sorry.
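For reference, a rough sketch of wrapping manually with that flag turned off (assuming the constructor argument is named auto_refresh_trainable and that torch.distributed is already initialized; in Lightning the flag would have to be plumbed through the plugin instead):

import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

model = torch.nn.Linear(8, 8).cuda()
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
ddp_model = ShardedDataParallel(model, optimizer, auto_refresh_trainable=False)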

@dave-epstein
Author

dave-epstein commented Aug 4, 2021

Any GAN setup with PT Lightning should give you the repeat "trainable parameters changed" error with Sharded DDP, even this one - https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/basic-gan.html - which I tried. However, the grads waiting to be reduced error is more mysterious, and seems to be somehow tied to my system which is a bit more complex. I'll post here (as usual) when I figure more out.

Could you try with "auto refresh trainable" set to off ? This stops ShardedDDP from monitoring the trainability status (it will stick to the one initially defined when wrapping the model), both warnings are related to that (ShardedDDP detects that the graph changed, but some hooks defined previously to track trainable parameters have not fired -> warning that something could be wrong in the setup)

Edit: we can try that out ourselves, only access to my phone at the moment so throwing out ideas, sorry

I can try this, but won't it break because the parameters which require grad change every other forward pass? From my recollection of the code, the list of parameters is pulled from the optimizers and checked for requires_grad. I'll take a look tomorrow.

@dave-epstein
Author

dave-epstein commented Aug 5, 2021

Yes, I think I'm correct about that. If you don't end up calling refresh_trainable, then it will just stick to the list of params that had requires_grad at first. Depending on when this initial setup happens, this will be either only G or only D params (bad for obvious reasons), or all params, both G and D (if the setup happens before any logic in PTL turns off params belonging to the non-active optimizer). In the latter case, there is the downside of doing unnecessary work in each forward pass, waiting for grads that will never come, but I don't know whether that's computationally significant or not.

@anj-s
Contributor

anj-s commented Aug 10, 2021

@dave-epstein I spent some time debugging and identified that the issue was that the generator model (ResNet) had requires_grad=False on its params. This meant that the backward hook was not attached and triggered. We did not handle the case where the root FSDP instance has params but those params do not require grads. That is what led to the error being raised. I have a potential fix that I tagged in this issue. Can you patch it and see if it fixes your original code?

One of the things I did not understand is the following: I see configure_optimizers in dummycomposer.py being called twice. Any idea why this is the case?

@dave-epstein
Author

dave-epstein commented Aug 10, 2021

@anj-s It does seem that that fixed the FSDP problem! Were you able to look at the regular DDP Sharded situation too? It's not clear to me when I might prefer using just that over FSDP, just curious.

Re: configure_optimizers, I poked around a bit. The crux of it (not sure if it's an issue or not?) is in pre_dispatch in accelerator.py, in the PyTorch Lightning code. L104 calls self.training_type_plugin.pre_dispatch which is the fully_sharded.py method (the file is under plugins/training_type in PTL), which calls configure_ddp which then calls setup_optimizers (L330, accelerator.py). Then, L106 in pre_dispatch in accelerator.py also calls the same self.setup_optimizers. Hope that makes sense. Seems like none of this is code in the fairscale repo, all in PTL.

Another thing: unfortunately, it seems like I can't use activation checkpointing to save memory, because I'm computing an R1 loss which regularizes model gradients, and checkpoint seems to support only .backward() and not autograd.grad(). If this is something that can be worked around or merits further discussion, I can open a new issue.

@ananthsub

@dave-epstein I can follow up with you on fixes to PTL for FSDP.

@anj-s
Copy link
Contributor

anj-s commented Aug 11, 2021

@anj-s It does seem that that fixed the FSDP problem! Were you able to look at the regular DDP Sharded situation too? It's not clear to me when I might prefer using just that over FSDP, just curious.
--> Great! Glad to know the issue was fixed. Sorry, did you mean ShardedDataParallel or DDP? I did not get a chance to look at either of those errors. For DDP it might be best to post in the PyTorch forums or open an issue. I can take a look at the SDP warnings that you mentioned.

Re: configure_optimizers, I poked around a bit. The crux of it (not sure if it's an issue or not?) is in pre_dispatch in accelerator.py, in the PyTorch Lightning code. L104 calls self.training_type_plugin.pre_dispatch which is the fully_sharded.py method (the file is under plugins/training_type in PTL), which calls configure_ddp which then calls setup_optimizers (L330, accelerator.py). Then, L106 in pre_dispatch in accelerator.py also calls the same self.setup_optimizers. Hope that makes sense. Seems like none of this is code in the fairscale repo, all in PTL.
--> Looks like ananthsub@ will take a look at these.

Another thing: Unfortunately, it seems like I can't use model checkpointing to save memory because I'm computing R1 loss which regularizes model gradients, and checkpoint seems to require .backward() and not .grad(). If this is something that can be worked around or merits further discussion, I can open a new issue.
--> Let us open a new issue for this. First I would check whether the vanilla checkpoint_wrapper from PyTorch supports this. If I recall correctly, when using a custom autograd.Function you need to use backward().

@dave-epstein
Author

Yes, I meant ShardedDataParallel. DDP is not under your purview :)

Right, vanilla PT checkpointing doesn't support this either. My instinct is that it's a pretty deep limitation and not something that can be patched, but maybe I'll make a forum post or issue.

Thanks! I will leave the issue open for you to take a look at Sharded.

@anj-s anj-s reopened this Aug 12, 2021
@dave-epstein
Author

dave-epstein commented Sep 2, 2021

Hi @anj-s, I'm still getting this issue now when I use wrap instead of auto_wrap, in a standard GAN training setup (i.e. when the FSDP wrapper is actually used, since the model is otherwise under the parameter count threshold). This is on a new machine where I made the two changes to fully_sharded_data_parallel.py that are in the above commit. It takes a bit of time to put together a reproducible example, so I was wondering if there are any checks I could conduct on my end to see what's going on.

any([p.requires_grad for p in m.params]) is True, so it's expecting state BACKWARD_POST but it gets BACKWARD_PRE.

@dave-epstein
Author

dave-epstein commented Sep 2, 2021

I think (99% sure) it's related to autograd.grad being called to generate a loss. The forward pass works fine until one of these grad-of-grad losses kicks in. This is a common use case with GANs, e.g. for path length loss on a generator or gradient regularization on a discriminator.
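For context, a hedged sketch of the R1 gradient-penalty pattern being referred to (the discriminator D and real_images are placeholders); the torch.autograd.grad call with create_graph=True is exactly the grad-of-grad computation that the checkpoint function does not support:

import torch

def r1_penalty(D, real_images):
    real_images = real_images.detach().requires_grad_(True)
    scores = D(real_images)
    grads, = torch.autograd.grad(
        outputs=scores.sum(),      # summing gives per-sample gradients w.r.t. the inputs
        inputs=real_images,
        create_graph=True,         # the penalty itself is backpropagated through
    )
    return grads.pow(2).reshape(grads.shape[0], -1).sum(1).mean()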

@min-xu-ai
Contributor

Hi @anj-s @dave-epstein, did you see issue #771? It also mentions autograd.grad, which is not supported by the current checkpoint function. Maybe you can ask the PyTorch team about the root cause and their plan to fix it in the future?

What TODO is left in this issue?

@anj-s
Contributor

anj-s commented Sep 7, 2021

auto_wrap

@dave-epstein Can you confirm that all the params of the model and its submodules have requires_grad set on them? The error previously was that we were checking the condition when some of the params in the submodules had requires_grad=False. auto_wrap and wrap may differ in which modules are wrapped and hence which params we check for the state. I'll take another look to see what could be an issue wrt the params. Is this happening with checkpoint + autograd.grad?

@min-xu-ai see the comment above for what is remaining. Essentially this issue was kept open to follow up on the ShardedDataParallel warnings. However, it seems now that the original error has resurfaced when using only wrap. Let me know if you have any ideas about what could fundamentally be different between wrap and auto_wrap.

@min-xu-ai
Contributor

@dave-epstein, there have been several recent changes that fixed issues around FSDP multiple forward passes in the same iteration, as well as requires_grad handling in checkpoint_wrapper. It might be worth checking whether the master branch of fairscale works better for your use case. Let us know what you find if you try!

@min-xu-ai min-xu-ai added FSDP FullyShardedDataParallel (zero-3) activation checkpoint labels Sep 15, 2021
@Alex-Songs

@ananthsub @blefaudeux @dave-epstein @min-xu-ai @anj-s
Hello, when I use fairscale 0.4.6 I get the same warning: "None of the inputs have requires_grad=True. Gradients will be None." Could you please help me resolve it? I am only using the activation checkpoint for the middle layers of the Transformer.
