Fairseq support #1915
Conversation
Hi @jeffra,

The pass-through feature is very neat! I had no idea it was added. We have been unwrapping the model for things like [...]. Though this magic now makes things very difficult to debug, since it doesn't pass the engine object through to the method it dispatches to, so it can be a puzzle to a user who thinks a normal engine method is being called.

So we can clearly see what happens, but the confusing part is: the call ends up in the model's own method, yet nobody monkeypatched deepspeed. I had to scour all the places to validate that nobody did any monkeypatching, until I thought of searching for an overridden __getattr__.
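To make that concrete, here is a minimal, generic sketch (not DeepSpeed's actual engine code; Wrapper and Inner are made-up names) of attribute pass-through via an overridden __getattr__, and of why a method resolved this way never re-enters the wrapper's forward:

```python
import torch

class Inner(torch.nn.Module):
    def forward(self, x):
        return x + 1

    def generate(self, x):
        # NOTE: here `self` is the Inner module, so this calls Inner.forward
        # directly and never goes through Wrapper.forward.
        return self.forward(x) * 2


class Wrapper(torch.nn.Module):
    """Generic stand-in for an engine that wraps a model (not DeepSpeed's code)."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        print("Wrapper.forward was called")  # where engine-side hooks would run
        return self.module(x)

    def __getattr__(self, name):
        # Pass unknown attributes through to the wrapped module.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)


wrapped = Wrapper(Inner())
wrapped(torch.zeros(1))           # prints "Wrapper.forward was called"
wrapped.generate(torch.zeros(1))  # resolves to Inner.generate; Wrapper.forward never runs
```

This is exactly the debugging puzzle: wrapped.generate() looks like a call on the wrapper, but it executes entirely on the inner module.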
So @VictorSanh found that a model's [...] breaks, because calling the unwrapped model bypasses the external-parameter auto-discovery that the engine enables here:

DeepSpeed/deepspeed/runtime/engine.py, lines 1821 to 1826 at 349f845

and that it works if he manually adds the registration before calling [...]. In other words, any utility relying on auto-discovery of external parameters and which gets invoked on an unwrapped model will break. The same model works just fine under normal training mode or when not using HF [...].
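For reference, the manual registration mentioned above looks roughly like this. deepspeed.zero.register_external_parameter is the documented API (see the ZeRO-3 docs linked further down); the TiedVocabProjection module and its attributes are made up for illustration:

```python
import torch
import torch.nn.functional as F
import deepspeed

class TiedVocabProjection(torch.nn.Module):
    """Hypothetical head that projects hidden states with the embedding's weight."""

    def __init__(self, embedding: torch.nn.Embedding):
        super().__init__()
        self.embedding = embedding
        # The weight used in forward() is owned by the embedding module, so under
        # ZeRO-3 it is an "external parameter" from this module's point of view.
        # Registering it explicitly keeps it available even when the engine's
        # auto-discovery does not run (e.g. the unwrapped model is called directly).
        deepspeed.zero.register_external_parameter(self, self.embedding.weight)

    def forward(self, hidden_states):
        # vocab projection via the tied embedding weight
        return F.linear(hidden_states, self.embedding.weight)
```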
This, of course, also means that any HF [...] is affected. Ideally, the solution would be to simply pass the wrapped model to [...]. The other approach would be to pass two objects: one to resolve the methods on, and the other the wrapped model, to keep passing along until [...].
Thanks Stas! Some more context about this:

The initial error arose when trying to perform generation (inference) on an HF transformers-like model (saying "like" because it has some custom components that I will describe later). The model is wrapped into a DSEngine, and [...] is called at that point. The error arises in the vocab projection module (i.e. at the very end of the network), which, instead of a regular linear layer, uses [...].

The error would typically be: [...] in the call to [...].
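As a debugging aid, this is roughly how one can check whether the tied weight was still partitioned at the failure point. The model handle and the attribute path to the weight are made up; deepspeed.zero.GatheredParameters and the ds_shape attribute are real ZeRO-3 facilities, and a partitioned ZeRO-3 parameter presents itself as an empty tensor until it is gathered:

```python
import deepspeed

# hypothetical handles: `model` is the unwrapped model, and this is the path to
# the weight shared by the embedding and the vocab projection
w = model.decoder.embed_tokens.weight

print(w.shape)                       # torch.Size([0]) while the param is partitioned
print(getattr(w, "ds_shape", None))  # the full shape DeepSpeed will materialize

# gather it explicitly (read-only) to inspect the real tensor
with deepspeed.zero.GatheredParameters(w):
    print(w.shape)                   # the full shape inside the context
```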
ok, so perhaps we can rescue this situation in a few ways (this and the next comment):

deepspeed solution a. Factor out the external-parameter auto-discovery that the engine currently enables only inside its own forward:

DeepSpeed/deepspeed/runtime/engine.py, lines 1821 to 1826 at 349f845

into something the user can invoke directly. Let's call it: [...]. Of course the user can register external parameters explicitly (https://deepspeed.readthedocs.io/en/latest/zero3.html#registering-external-parameters) to overcome this, but the problem here is that we have an inconsistency between train/eval and [...].
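For concreteness, such a helper might look like the sketch below. This is purely hypothetical: the name external_parameter_discovery is made up, and it assumes the engine.py lines quoted above enable auto-discovery simply by setting a per-module _in_forward flag around the wrapped forward call.

```python
from contextlib import contextmanager

@contextmanager
def external_parameter_discovery(engine):
    """Hypothetical helper: turn on ZeRO-3 external-parameter auto-discovery for
    code that runs the unwrapped model outside DeepSpeedEngine.forward, assuming
    the engine lines quoted above toggle a per-module `_in_forward` flag."""
    enabled = engine.zero_optimization_partition_weights()
    if enabled:
        for module in engine.module.modules():
            module._parameters._in_forward = True
    try:
        yield
    finally:
        if enabled:
            for module in engine.module.modules():
                module._parameters._in_forward = False

# hypothetical usage:
# with external_parameter_discovery(engine):
#     engine.module.generate(input_ids)
```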
I think this pass-through method magic is actually problematic, as it sweeps the problem under the carpet, as can be seen from the above comments.

model provider solution a. This doesn't have to be deepspeed-specific and could be used by any framework that does the wrapping and needs to ensure the wrapping stays in place during [...].
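On the model-provider side, the guard could be as small as the sketch below; maybe_unwrap is a made-up helper, and the only DeepSpeed API it relies on is the DeepSpeedEngine class:

```python
import torch
import deepspeed

def maybe_unwrap(model: torch.nn.Module) -> torch.nn.Module:
    """Made-up framework-side helper: only strip wrappers that are safe to strip,
    and keep a DeepSpeed engine wrapped so any deepspeed-side fix can route
    generation/eval through the engine."""
    if isinstance(model, deepspeed.DeepSpeedEngine):
        return model  # keep the wrapper in place
    # e.g. unwrap torch.nn.parallel.DistributedDataParallel and similar wrappers
    return getattr(model, "module", model)
```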
deepspeed solution b: [...]
This is the first step towards deepspeed supporting fairseq. In order to use deepspeed with fairseq we still require some changes on the fairseq trainer side. This currently pairs with a fork of fairseq that adds deepspeed trainer support: https://github.com/jeffra/fairseq/tree/wip-deepspeed. We'll update the documentation on the deepspeed side once this fork has a new home and is fully featured and tested.
This PR adds the following features:

- Pass-through of nn.Module or DeepSpeedEngine attributes without code change. For example, BaseFairseqModel (and its children) expose many attributes that are used throughout the fairseq codebase. Without this feature we would need to modify many places in the fairseq code with deepspeed-enabled checks and invoke the model with either model.<fairseq-method>() or model.module.<fairseq-method>().
- override_loss_scale(loss_scale). This can be called before each training step to manually set a specific loss scale value (a hedged usage sketch follows this list).
- Support for FairseqOptimizer w/o needing to specify that it is untested.
- model_f for loading the model state_dict.
- zero.Init in zero-3 now allows kwarg apply; this allows F.linear() usage, which was previously not allowed.
- Common base classes (DeepSpeedOptimizer, ZeROOptimizer) for all of the major deepspeed optimizers; this makes it easier to detect we're using a deepspeed optimizer on the client side (see the sketch below). In the future it would be great to add common functionality to these classes that all of our optimizers can inherit from (e.g., loss scale logic, grad norm).
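A hedged usage sketch for the loss-scale override and the new optimizer base classes. It assumes engine is the DeepSpeedEngine returned by deepspeed.initialize() and that a dataloader exists; compute_loss_scale is a made-up stand-in for fairseq's own loss-scaler logic:

```python
import deepspeed  # `engine` below is the DeepSpeedEngine returned by deepspeed.initialize()

def compute_loss_scale(step: int) -> float:
    # made-up stand-in for fairseq's dynamic loss-scale schedule
    return 2.0 ** max(4, 16 - step // 1000)

for step, batch in enumerate(dataloader):  # `dataloader` and `engine` assumed to exist
    loss = engine(batch)                   # forward through the engine

    # fairseq manages its own loss scale; push it into deepspeed before stepping
    engine.override_loss_scale(compute_loss_scale(step))

    engine.backward(loss)
    engine.step()

# The new common base classes (DeepSpeedOptimizer / ZeROOptimizer) also make it
# possible to detect a deepspeed-managed optimizer on the client side, e.g. via
# an isinstance() check against engine.optimizer.
```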