Remove unnecessary assert on sub_module.training #5215

Conversation

@ringohoffman (Contributor) commented Mar 1, 2024

Related: Lightning-AI/pytorch-lightning#19467

Why were these asserts added? nn.Module.training is for controlling forward() behavior. I have never seen it used to control backward() behavior, let alone raise an error because of it.

To me and the users in the ticket I linked, these checks are very unexpected. It seems like training is being co-opted for something beyond what it was originally intended for. If I just put my whole model into train mode before calling backward, I stop seeing these errors. How is training expected to be set on partially frozen models and why?
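For context, here is a minimal sketch (the class and module names are made up) of the usual division of labour: .training gates forward() behaviour such as dropout, while freezing is expressed through requires_grad, and plain PyTorch happily runs backward through a frozen, eval-mode submodule:

```python
import torch
import torch.nn as nn

class PartiallyFrozenModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))  # frozen
        self.head = nn.Linear(8, 1)                                        # trainable

    def forward(self, x):
        return self.head(self.backbone(x))

model = PartiallyFrozenModel()
model.backbone.requires_grad_(False)  # freezing: no gradients for the backbone
model.backbone.eval()                 # .training only changes forward() (disables dropout)

loss = model(torch.randn(4, 8)).sum()
loss.backward()                       # backward runs fine through the eval-mode backbone
```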

@JakobLS commented Mar 2, 2024

Removing these asserts allows me to launch my training script, though I am not yet aware of whether it has any consequences further down the line.

@ringohoffman marked this pull request as draft March 8, 2024 17:49
@championsnet commented

I tried removing these asserts as well, but now I get an error later on, in parameter_offload.py:
File "/.local/lib64/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 316, in fetch_sub_module
    assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()

@Boltzmachine commented

Has this been merged yet? It is really ridiculous to have such an assertion.

@ringohoffman (Contributor, Author) commented

> Has this been merged yet? It is really ridiculous to have such an assertion.

Last I checked, it doesn't work even if you remove the assertion. That is why I gave up on this.

@tjruwase (Contributor) commented

> Has this been merged yet? It is really ridiculous to have such an assertion.

> Last I checked, it doesn't work even if you remove the assertion. That is why I gave up on this.

@ringohoffman, @Boltzmachine, apologies for missing this.

First, I want to affirm your observations:

  1. The assertion is incorrect and a problem for the backward pass over frozen weights.
  2. We hijacked .training for the module prefetching optimization by having separate prefetchers for (1) the forward+backward trace (i.e., training) and (2) the forward-only trace (i.e., eval/inference).
  3. It makes sense that removing these assertions uncovers problems previously unknown to us, since we have not tested those execution paths. We had previously assumed the .train() workaround would make things work (a short illustration of its side effects follows this list). However, creating gradients on frozen weights is not a long-term solution, so this requires fixing.
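As a small illustration (ordinary PyTorch, not DeepSpeed code, and the layer choices are arbitrary) of why the .train() workaround only papers over the problem: putting the whole model into train mode silences the conflation, but it also flips the behaviour of the frozen layers, e.g. a frozen BatchNorm starts updating its running statistics again:

```python
import torch
import torch.nn as nn

frozen = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
frozen.requires_grad_(False)            # freeze the weights
frozen.eval()                           # desired: fixed running statistics
model = nn.Sequential(frozen, nn.Linear(8, 1))

model.train()                           # the workaround: whole model in train mode

before = frozen[1].running_mean.clone()
model(torch.randn(4, 8)).sum().backward()
print(torch.equal(before, frozen[1].running_mean))  # False: frozen stats drifted anyway
```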

Second, here are my thoughts for next steps:

  1. Remove these assertions and the misuse of .training for the prefetching optimization. This will enable correct handling of backward on frozen weights and avoid user confusion (such as that discussed in the Lightning forum).
  2. Fix the issues arising from the above change. @ringohoffman, are you able to revive this PR or share those failures? I am curious whether disabling prefetching would address the failures you observed.
  3. Improve prefetching robustness:
    1. Find a different way to distinguish the forward+backward trace from the forward-only trace (one possible signal is sketched after this list).
    2. Eliminate prefetching across the forward/backward boundary; instead, have separate prefetchers for the forward and backward traces. We need to understand the performance implications of this.
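On point 3.1, one possible signal, sketched below purely as a strawman (this is not existing DeepSpeed code), is to decide whether a backward pass can even follow the current forward from grad mode and requires_grad rather than from .training:

```python
import torch
import torch.nn as nn

def backward_may_follow(sub_module: nn.Module, *inputs) -> bool:
    # A backward pass can only follow this forward if autograd is recording
    # and something in the computation actually requires grad.
    if not torch.is_grad_enabled():
        return False
    if any(p.requires_grad for p in sub_module.parameters()):
        return True
    return any(isinstance(t, torch.Tensor) and t.requires_grad for t in inputs)
```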

We would really appreciate your help with the above plan. We also understand this might no longer be a priority for you.

@tohtana, FYI
