FSDP always sets requires_grad to False? Breaking checkpoint_wrapper. #758
Comments
Thanks for opening this issue! We do have tests with checkpoint_wrapper here and here. How does your example compare? Would it be possible to modify one of these tests so that we can repro the issue? cc @min-xu-ai who might have some ideas about this. |
Thanks for the bug report!
The method actually carries over the requires_grad from the original, so I don't see it setting the flag to false blindly. Do you have a small reproducible case that we can use to debug it? |
Thanks to both. I am still looking into this more deeply. One key difference may be that I'm using PyTorch Lightning with the FSDP plugin, so I'm not handling the instantiation of the model with full granularity. It may then turn out that this is a Lightning+FSDP+checkpoint_wrapper bug. I will reply with more details. Another small difference - I am using auto_wrap from fairscale.nn to wrap models in the configure_sharded_model method in PyTorch Lightning. I'm using fairscale 0.4 and lightning 1.4. I'm having some issues using the PTL "BoringModel" right now for some reason; I wanted to use it to build a simple reproducible case. As I make progress I'll update this thread. |
Ok, it seems like the issue is actually simple - the input feature map requires grad and is fp32, but it is NOT a leaf. So the created fp16 version does NOT get requires_grad set to True, and this causes the issue. Clearly this is causing some undesirable behavior. Is this an issue in my code, then, or in the control-flow logic of cast_floats_to_right_precision? |
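To illustrate what I mean (a simplified sketch of my own, not the actual fairscale source): casting under no_grad and only copying requires_grad back for leaf tensors reproduces the behavior I'm seeing.

import torch

def cast_fp32_to_fp16(x: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in for cast_floats_to_right_precision-style logic.
    with torch.no_grad():
        y = x.half() if x.dtype == torch.float32 else x
    if x.is_leaf:  # the guard in question: non-leaf inputs skip this branch
        y.requires_grad = x.requires_grad
    return y

leaf = torch.randn(1, 3, 8, 8, requires_grad=True)
non_leaf = torch.nn.Conv2d(3, 3, 1)(leaf)  # activation produced by a prior layer

print(cast_fp32_to_fp16(leaf).requires_grad)      # True
print(cast_fp32_to_fp16(non_leaf).requires_grad)  # False -> checkpoint sees no grad-requiring inputs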
I removed the is_leaf check and now copy x's requires_grad over to y's requires_grad unconditionally. The result is now the following opaque error: SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f11d22aea00> returned NULL without setting an error. What can I do to help you guys dig into this? |
This may have something to do with the fact that I'm doing GAN training. I'm trying to use regular ddp_sharded and I get a complaint that no parameter requires gradient when I'm in the discriminator optimization step and running noise through a generator. I may need to file a PyTorch Lightning issue. |
Thanks for digging in and for the detailed description of what you tried and got. I am thinking this is a case not yet tested by FSDP. Perhaps the input to your root FSDP model is generated by another model? Did you try not using mixed precision, so that this code is skipped, to see if the problem is still there or not? Also, did you try removing the no_grad context in the conversion function? |
Right now I am trying to fully understand what works and does not work with regular DDP, DDP sharded, and FSDP, while using checkpoint_wrappers for GAN training. I will comment here when I have finished my tests :) |
This will be a bit long, but I hope it is comprehensive and helpful, and that it gives us somewhere to start from to continue debugging. In all cases, I use the same snippet for checkpoint wrapping, which I am putting at the end of this comment. To address your comment above, I will try to re-add the is-leaf check and see what happens with fp32 training. I think the no_grad context is not the problem, since requires_grad was being propagated fine once I removed the is-leaf check.

I have been focusing on getting things working in the regular DDP and DDP Sharded settings first.

With regular DDP, I get an error -

With DDP Sharded, the situation is as follows: When I try to train with both optimizers, there was the error I described above, where it complains that no parameter in the input requires gradient. This is because the generator model has all params with When I do that, I keep getting messages that trainable parameters change, at every iteration (I am 99% sure it's a non-issue and expected behavior due to toggling between states - in Then, I tried to only train the optimizer for the generator, so no 2 optimizers. The first forward pass works without complaining, and then the next time the

With FSDP (by setting With FSDP, in the two optimizer setting, I get

As you can see - nothing really works, but maybe I made some progress with the DDP Sharded setting. I know this is pretty complicated, but I think I'm not capable of solving it myself at this point :)

The snippet I use for creating the models, with model initialization simplified/abstracted:

def wrap_checkpoints(self, modules: list[type], blacklist: Optional[list[str]] = None,
                     root: Optional[nn.Module] = None, **checkpoint_kwargs):
    root = root or self
    if blacklist is None:
        blacklist = []
    for n, m in root.named_modules():
        if any([n.startswith(s) for s in blacklist]) or getattr(m, '_checkpointed', False):
            continue
        if any([isinstance(m, m_) for m_ in modules]):
            root.get_submodule(n).apply(lambda mm: setattr(mm, '_checkpointed', True))
            parent_n, n = n.rsplit('.', 1)
            setattr(root.get_submodule(parent_n), n, checkpoint_wrapper(m, **checkpoint_kwargs))
    return root

def configure_sharded_model(self):
    self.discriminator = auto_wrap(
        self.wrap_checkpoints(root=DiscriminatorNetworks(**D_kwargs), modules=self.checkpoint_modules,
                              offload_to_cpu=True))
    self.generator = auto_wrap(
        self.wrap_checkpoints(root=GeneratorNetwork(**G_kwargs), modules=self.checkpoint_modules,
                              offload_to_cpu=True)) |
Let me add @zhaojuanmao and @blefaudeux. Do you have instructions for running the code snippet? |
@dave-epstein Want to confirm that for ShardedDP we use the same code snippet? In that case are you replacing Is the goal to use FSDP or SDP? |
That's what I am trying to confirm. I am unclear about what is being used with checkpoint_wrapper. I know DDP and FSDP, but what is DDP sharded? |
This happens when the same parameter produces a grad twice, AFAIK. DDP's job is to all-reduce the gradients as they are produced, and in that case it is lost because this parameter's gradient has already been reduced but it just produced new gradients, so something is probably amiss. It can happen if the same param is used twice in different places in your forward pass; each use will produce a gradient. If they can/should be different, you could have two instances instead, and that should fix this issue.

This means that the trainability of the parameters has changed, for at least one param, and it does not look like what you want?

This means that at least one param "requires grad" but was probably not actually used in the forward pass (or was used with an input without grads, one of these options), so it does not produce grads. Same as above, not a deal breaker per se, but it's meant to warn you that something is probably wrong in the compute graph.

I would bet that these parameters are really not being touched? (Or they are used with parameters which have lost their gradient.) Their backward hook is not called.

To me there's probably something wrong in the compute (in the formulation of the model); the two warnings above (grads not produced and trainability changed) should not fire for no reason. It's pretty easy to have a typo - say you produce integers somewhere and the backward graph will cut there, for instance - easy to overlook (just an example).
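For instance, a made-up illustration of that integer pitfall:

import torch

x = torch.randn(4, requires_grad=True)
cut = x.argmax()            # integer-valued tensor: autograd cannot flow through it
y = cut.float() * 2.0
print(y.requires_grad)      # False - the backward graph was cut at argmax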
This tells me that the AMP context is probably wrong (the second warning, not the first one); it's a meaningful warning which should not happen, I believe. |
@anj-s My understanding was that outside of FSDP, auto_wrap is a no-op. I never call DDP/Sharded DDP ( https://fairscale.readthedocs.io/en/stable/api/nn/sharded_ddp.html ) /FSDP manually. Pytorch Lightning handles that. The instructions in PTL for using FSDP are to do what I did basically (this page gives examples for both Sharded DDP and FSDP - https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#fully-sharded-training - the page says that auto_wrap is a no-op when not using FSDP). The goal is to use FSDP but I thought that maybe Sharded DDP is more mature so I should try to get it working first / use it to debug my model. @blefaudeux I am looking into what you wrote and will reply with a separate comment. |
Thanks! That makes sense. I would like to repro with just DDP + checkpoint_wrapper to see if that works; looks like there were some errors with that as well. I think we can then try FSDP + checkpoint_wrapper. If you have a smaller repro example to start with, that would be good. |
Re DDP (same param marked twice): I do call my discriminator multiple times with different batches, and I also call some underlying module in the generator repeatedly (an iterative generation process, so it cannot be unfolded into different params). Is this use case fundamentally unsupported by gradient checkpointing? I didn't see any notes like that anywhere. (@anj-s since you mentioned wanting to repro with DDP + checkpoint_wrapper, I'm not sure if this is relevant.) The regular DDP error message is actually stumping me more than any of the other ones, so I was not planning on trying to get it to work. That being said, I should hopefully be able to come up with a small snippet that reproduces both this "marked twice" error in DDP as well as the opaque NULL error in FSDP.

Re sharded DDP / trainability: I think I explained this in my first comment. The trainability warning printing every iteration is annoying, but I was saying it's to be expected since I am switching between the generator being trainable and the discriminator being trainable, so of course the trainability changes. I'm not bothered by that error and I don't think you guys should be either (except that maybe you could do something to suppress it from being printed at every forward pass).

Re some gradients not being reduced: I am pretty sure that every parameter is being touched. This is what I spent some time looking into yesterday. I agree it is a concerning error, and I will keep investigating today. Other than this warning message, sharded DDP appears to work.

Re the cleanup() bug: yes, something weird is probably occurring here, since I kept all my generator/discriminator code but removed the discriminator optimizer. I don't want to focus on this since it doesn't seem like the most relevant issue.

It seems that the focus now should be getting a compact example that reproduces these bugs (specifically the FSDP ones) sent over to you guys. Let me work on that now. Is there a good way for me to share it with you privately? |
BTW, just as a sanity check, I ran the same scaffolding code (that is, the way I'm calling checkpoint_wrapper and wrap/auto_wrap) with Sharded DDP and FSDP on another model (one optimizer, more or less a regular feed-forward CNN). It worked fine in both cases. EDIT - the model actually did NOT work fine with regular DDP and checkpointing. The "mark a variable ready only once" error was thrown. Looking into it. |
@dave-epstein re trainability, my bad, I missed that part of your explanation, makes sense! It makes me think of a possible explanation with respect to DDP though: AFAIK it doesn't expect the graph to change in terms of trainability from step to step, so depending on the structure of your code (if the generator and discriminator share something, for instance) it could dislike something there. Cc @mrshenli just in case. Sorry if this is obvious, but even for FSDP I would focus on getting things to work without activation checkpointing first, and then turn that on; in terms of backwards it does complicate the flow a fair bit. Sidenote: if you're using a Transformer you can also use reversible layers for a similar effect but a more straightforward graph. |
For DDP, the model started working perfectly fine once I short-circuited the forward pass so that it did not call the same model multiple times. Okay, so we know why that's happening in DDP (multiple calls to the same model). Shouldn't this be allowable? It seems like a pretty common use case. In my case it was an off-the-shelf ResNet50 that I was calling twice with two different batches of images:

self.encoder(batchone)
self.encoder(batchtwo) |
@blefaudeux Thanks. Nothing is too obvious, especially when I've been staring at things for so long :) See my comment posted right after yours for DDP (G and D don't share any weights). Not using a Transformer in this case, but good tip. Focusing on DDP / Sharded DDP / FSDP without checkpointing: |
@dave-epstein I think we can only debug with a small repro example accompanied by stack traces (if possible). It is best if the smaller repro example is something you can share with us publicly on GitHub. |
Ok, let me work on putting together the smallest possible example to repro, and I'll share it here when done. Hopefully I can get it to reproduce all of these errors :) |
Here is a very very simple example that causes the "Expected to mark a variable ready only once." error to fire on regular DDP when using checkpoint_wrapper. As I suspected, it is caused (among other things, I guess) by calling the same model more than once. Are you guys able to easily connect this to Lightning scaffolding, or should I share the entry point file too?

from typing import Optional

import torch
from fairscale.nn import checkpoint_wrapper, auto_wrap, wrap
from pytorch_lightning import LightningModule
from torch import nn
from torchvision.models import resnet50
from torchvision.models.resnet import Bottleneck


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()

    def configure_sharded_model(self):
        self.layer = self._checkpoint_and_shard(resnet50(), modules=[Bottleneck], wrap_mode='wrap')

    def _checkpoint_and_shard(self, m, modules, wrap_mode='auto'):
        wrap_fn = auto_wrap if wrap_mode == 'auto' else wrap
        return wrap_fn(self.wrap_checkpoints(root=m, modules=modules, offload_to_cpu=True))

    def wrap_checkpoints(self, modules: list[type], blacklist: Optional[list[str]] = None,
                         root: Optional[nn.Module] = None, **checkpoint_kwargs):
        root = root or self
        count = 0
        if blacklist is None:
            blacklist = []
        for n, m in root.named_modules():
            if any([n.startswith(s) for s in blacklist]) or getattr(m, '_checkpointed', False):
                continue
            if any([isinstance(m, m_) for m_ in modules]):
                root.get_submodule(n).apply(lambda mm: setattr(mm, '_checkpointed', True))
                parent_n, n = n.rsplit('.', 1)
                setattr(root.get_submodule(parent_n), n, checkpoint_wrapper(m, **checkpoint_kwargs))
                count += 1
        print(f'Checkpointed {count} modules of type {modules}')
        return root

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        batch = batch[0]
        n = len(batch)
        firsthalf = self.layer(batch[:n // 2])
        secondhalf = self.layer(batch[n // 2:])
        firstloss = torch.nn.functional.cross_entropy(firsthalf, torch.ones(len(firsthalf)).to(firsthalf.device).to(torch.long))
        secondloss = torch.nn.functional.cross_entropy(secondhalf, torch.ones(len(secondhalf)).to(firsthalf.device).to(torch.long))
        loss = firstloss + secondloss
        return {'loss': loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def on_train_start(self) -> None:
        # if int(os.environ.get("LOCAL_RANK", 0)) == 0: import ipdb; ipdb.set_trace()
        pass

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x['x'] for x in outputs]).mean()

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log('fake_test_acc', loss)
        return {"y": loss}

    def test_epoch_end(self, outputs) -> None:
        torch.stack([x["y"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

BTW, on sharded DDP and FSDP, this example works fine. But I figure it's probably still a bug. Working on examples to repro the other errors too. |
Any GAN setup with PT Lightning should give you the repeated "trainable parameters changed" error with Sharded DDP, even this one - https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/basic-gan.html - which I tried. However, the "grads waiting to be reduced" error is more mysterious and seems to be somehow tied to my system, which is a bit more complex. I'll post here (as usual) when I figure more out. |
This is a familiar case that we have debugged a lot for the VISSL framework as well. You can check the "multiple_fwd" test cases in the tests/nn/data_parallel directory to see what is being tested. With multiple forwards through the same module, when it is also checkpointed, we had to put in special logic to handle it with FSDP and checkpoint_wrapper. But since this is a complicated case, I won't be surprised if there are still corner cases not handled. DDP did not handle multiple fwd + checkpoint when I tried it about 6 months ago. The best workaround I had at that time was to use the model.no_sync() context to stop the DDP code from doing gradient reduction, and then do the gradient reduction with custom (non-DDP) code after the backward is done. This is certainly slower since the reduction is not overlapped with the backward compute. However, I think @zhaojuanmao has put in a new DDP flag called "static_graph" or something similar that may have fixed this issue since then. |
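Roughly, that workaround looked something like this (a from-memory sketch, not the exact code we used; loss_fn, batch, and optimizer here are placeholders):

import torch.distributed as dist

def train_step_without_ddp_reduction(ddp_model, loss_fn, batch, optimizer):
    # Skip DDP's bucketed, overlapped reduction for this fwd/bwd...
    with ddp_model.no_sync():
        loss = loss_fn(ddp_model(batch))
        loss.backward()
    # ...and reduce/average the gradients manually once backward has finished.
    world_size = dist.get_world_size()
    for p in ddp_model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad)
            p.grad.div_(world_size)
    optimizer.step()
    optimizer.zero_grad()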
The _set_static_graph does not work in my case due to the separate G and D forward pass logic. Regardless, I'm (more than) okay with using FSDP instead of DDP. I'm not spending more time on that case. Instead I'm trying to build a portable example to reproduce the bugs I'm experiencing with sharded DDP and FSDP, which may be exposing a corner case in the logic you guys have written. |
Thank you so much! It will be tremendously helpful! |
I just prepared a really simple example that reproduces the behavior I described above identically. I am packaging it now and will post another comment here once it's done. |
So I think the FSDP bug is very related to Issue #617. The error hiding in the FSDP output was:

Here is the code: https://github.com/dave-epstein/example-code/blob/main/example.tar.gz

FWIW I'm on V100s with 32GB memory, CUDA 11.1, torch 1.9.0, but you don't need more than a few GB of VRAM to run the experiments. You should get the following situation - run the code with PTL 1.4 and fairscale 0.4; you need Hydra installed too.

The below experiments are just regular DDP:

I hope this helps. Please let me know if you have any questions. |
@dave-epstein I can reproduce the issue. Thank you for sharing the repro example, this really helps a great deal. I'll start investigating and update once I have something. |
Could you try with "auto refresh trainable" set to off? This stops ShardedDDP from monitoring the trainability status (it will stick to the one initially defined when wrapping the model); both warnings are related to that (ShardedDDP detects that the graph changed, but some hooks defined previously to track trainable parameters have not fired -> warning that something could be wrong in the setup). Edit: we can try that out ourselves; I only have access to my phone at the moment so I'm throwing out ideas, sorry. |
I can try this, but won't it break because the parameters which require grad change every other forward pass? From my recollection of the code, the list of parameters is pulled from the optimizers and checked for requires_grad. I'll take a look tomorrow. |
Yes, I think I'm correct about that. If you don't end up calling |
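For reference, the knob being discussed, as far as I understand the fairscale ShardedDataParallel API (a sketch; the toy model below is just a placeholder, not the real G/D setup):

import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

model = torch.nn.Linear(8, 8).cuda()  # stand-in for the real modules
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

# Stop ShardedDDP from re-scanning trainability on its own at each step.
sharded_model = ShardedDataParallel(model, optimizer, auto_refresh_trainable=False)

# ...train; if the set of trainable parameters really does change
# (e.g. toggling generator vs. discriminator), re-scan explicitly:
sharded_model.refresh_trainable()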
@dave-epstein Spent some time debugging and identified that the issue was that the generator model (ResNet) had requires_grad=False on its params. This meant that the backward hook was not attached and triggered. We did not check for the condition where the root FSDP instance had params but those params did not require grads. That is what led to the error being raised. I have a potential fix that I tagged in this issue. Can you patch it and see if it fixes your original code? One of the things I did not understand is the following: I see |
@anj-s It does seem that that fixed the FSDP problem! Were you able to look at the regular DDP Sharded situation too? It's not clear to me when I might prefer using just that over FSDP; just curious. Re: configure_optimizers, I poked around a bit. The crux of it (not sure if it's an issue or not?) is in Another thing: unfortunately, it seems like I can't use model checkpointing to save memory, because I'm computing an R1 loss which regularizes model gradients, and checkpoint seems to require .backward() and not .grad(). If this is something that can be worked around or merits further discussion, I can open a new issue. |
@dave-epstein I can follow up with you for fixes to PTL for FSDP. |
Yes, I meant ShardedDataParallel. DDP is not under your purview :) Right, vanilla PT checkpointing doesn't support this either. My instinct is that it's a pretty deep limitation and not something that can be patched, but maybe I'll make a forum post or issue. Thanks! I will leave the issue open for you to take a look at Sharded. |
Hi @anj-s, I'm still getting this issue now when I use |
I think (99% sure) it's related to autograd.grad being called to generate a loss. The forward pass works fine until one of these grad-of-grad losses kicks in. This is a common use case with GANs, e.g. for path length loss on a generator or gradient regularization on a discriminator. |
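For concreteness, the kind of grad-of-grad loss I mean (a generic R1-style sketch with placeholder D and real_images, not my actual code):

import torch

def r1_penalty(D, real_images):
    real_images = real_images.detach().requires_grad_(True)
    d_out = D(real_images)
    # grad-of-grad: needs torch.autograd.grad with create_graph=True,
    # which is where checkpointed modules get in the way.
    grad_real, = torch.autograd.grad(d_out.sum(), real_images, create_graph=True)
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()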
Hi @anj-s @dave-epstein, did you guys see issue #771? It also mentions autograd.grad, which is not supported by the current checkpoint function. Maybe you can ask the pytorch team about the root cause of that and the plan to fix it in the future? What TODOs are left here in this issue? |
@dave-epstein Can you confirm that all the params of the model and submodules have requires_grad set on them? The error previously was that we were checking the condition when some of the params in the submodules had requires_grad=False. auto_wrap and wrap may differ in which modules are wrapped and hence which params we check for the state. I'll take another look to see what could be an issue w.r.t. the params. Is this happening with checkpoint + autograd.grad? @min-xu-ai see the comment for what is remaining. Essentially this issue was kept open to follow up on the ShardedDataParallel warnings. However, it seems now that the original error has resurfaced when using only wrap. Let me know if you have any ideas about what could fundamentally be different between wrap and auto_wrap. |
@dave-epstein, there have been several recent changes that fixed issues around FSDP multiple forward passes in the same iteration, as well as requires_grad handling in checkpoint_wrapper. It might be worth checking whether the master branch of fairscale works better for your use case. Let us know what you find if you try! |
@ananthsub @blefaudeux @dave-epstein @min-xu-ai @anj-s |
❓ Questions and Help
Hi, I get the following error when trying to use FSDP with checkpoint_wrapper:
RuntimeError: None of the outputs have requires_grad=True, this checkpoint() is not necessary
I found this weird so I started digging around. I found that, in the checkpoint wrapper function, all the input tensors stashed in the forward pass had requires_grad=False, which led to this error. I checked that the inputs themselves do have requires_grad=True (at least one of them) and tried to find where this got clobbered. It happens in the call to cast_floats_to_right_precision(True, True, *args, **kwargs) in the forward method of FullyShardedDataParallel.
I am really confused by this method since it seems to methodically set requires_grad to False for all inputs. If relevant, the layer being wrapped with a checkpoint is a ResNet-like block that is an intermediate layer in a network; the layer before it is an un-checkpointed Conv2D. So the input to the block is N x C x H x W and it has requires_grad=True, but when the forward call is dispatched (outputs = self.module(*args, **kwargs)), this tensor now has requires_grad=False. This breaks checkpoint_wrapper :(
Help would be much appreciated!