
Support compilation via Torchdynamo, AOT Autograd, NVFuser #17308

Merged: 10 commits into huggingface:main on May 25, 2022

Conversation

Contributor

@anijain2305 anijain2305 commented May 17, 2022

What does this PR do?

Adds support for TorchDynamo compilation with AOT Autograd and nvfuser backends. Detailed context is available at #17204.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

TODO:

Set up pt-nightly CI to run the tests in this PR; instructions:

# install torch-nightly
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch-nightly

# install functorch (and reinstall after `git pull` later if need to sync up)
git clone https://github.com/pytorch/functorch
cd functorch
rm -rf build
pip install -e .[aot]

cd ..
git clone https://github.com/pytorch/torchdynamo
cd torchdynamo
pip install -r requirements.txt
python setup.py develop
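
A quick sanity check that the stack is importable after the steps above (illustrative only, not part of the PR):

# sanity check that the nightly stack installed correctly (illustrative)
import torch
import functorch  # noqa: F401
import torchdynamo  # noqa: F401

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())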

@ydshieh is adding this in PR #17335, commit 52e7021.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 17, 2022

The documentation is not available anymore as the PR was closed or merged.

@stas00 stas00 self-assigned this May 17, 2022
Contributor

@stas00 stas00 left a comment

Thank you for working on adding this feature, @anijain2305!

I left a few suggestions below.

For the missing test(s) part:

  1. Adding a functional test should be trivial: just copy one of the tests in tests/trainer/test_trainer.py, e.g. test_fp16_full_eval - here we aren't comparing fp32 to fp16 but w/ and w/o torchdynamo - but really the meat of this test is in the next item, as we want to test that it actually does something (see the sketch after this list).

and the key to testing is passing this arg (or whatever the final name will be)

trainer = get_regression_trainer(a=a, b=b, torchdynamo=True, skip_memory_metrics=False)
  2. For the quality part this is trickier if we want to detect an actual memory and/or speed saving - please suggest what small model we can use to detect a significant change, since CIs run on different machines and a small change will typically pass on one machine but fail on others.

  3. We will have to sort out how to get the CI to build all these 3rd-party packages, and the test will need to be conditioned on their availability and skip otherwise - let's worry about that last.
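
Something along these lines (just a rough sketch; the test name, the assertion tolerance, and the use of get_regression_trainer / the require_* decorators from the existing test file are illustrative):

@require_torch_gpu
@require_torchdynamo
def test_torchdynamo_full_eval(self):
    a, b = 0.5, 2.5

    # baseline: plain eager PyTorch
    trainer = get_regression_trainer(a=a, b=b, skip_memory_metrics=False)
    baseline_loss = trainer.evaluate()["eval_loss"]

    # same regression task, but routed through TorchDynamo
    trainer = get_regression_trainer(a=a, b=b, torchdynamo="nvfuser", skip_memory_metrics=False)
    dynamo_loss = trainer.evaluate()["eval_loss"]

    # the numerics should match closely; only the execution path differs
    self.assertAlmostEqual(baseline_loss, dynamo_loss, places=4)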

@@ -881,6 +884,18 @@ class TrainingArguments:
)
},
)
use_torchdynamo: bool = field(
default=False,
Contributor

Surely there will be various variations of this coming in the future, won't there?

So I think it'd be good from the get-go to plan this argument as key:value, where the value can be a complex combo, with the first value perhaps being "autograd;nvfuser".

And then I'd call it just torchdynamo, so:

--torchdynamo "autograd;nvfuser"

Unless this is the only option that will ever be used, but somehow I doubt that.

CCing: @Chillee

Collaborator

Let's still allow for one simple option; --torchdynamo "autograd;nvfuser" is not user-friendly at all.

Comment on lines 2228 to 2229
with self.torchdynamo_smart_context_manager():
with self.autocast_smart_context_manager():
Contributor

@stas00 stas00 May 17, 2022

We don't want nested ctx managers, as they repeat a lot and don't help code readability, so let's create one that combines the existing one and the new one.

@sgugger, any suggestions for what we should call the grand combo of all ctx managers in the code above? Perhaps just smart_context_manager()?

Collaborator

Let's just rename autocast_smart_context_manager into compute_loss_context_manager and put everything under it. Also note that we have a ContextManagers utility (in utils) that allows you to group together a list of context managers, if that's useful.
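
A minimal sketch of how that could look (assuming the ContextManagers utility mentioned above; the method name and call site are illustrative):

from transformers.utils import ContextManagers

# inside Trainer (sketch):
def compute_loss_context_manager(self):
    """Single entry point that stacks the torchdynamo and autocast context managers."""
    return ContextManagers(
        [
            self.torchdynamo_smart_context_manager(),
            self.autocast_smart_context_manager(),
        ]
    )

# the call site then becomes a single `with`:
# with self.compute_loss_context_manager():
#     loss = self.compute_loss(model, inputs)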

Collaborator

@sgugger sgugger left a comment

Thanks for working on this! I think we should group all context managers together for code readability.


Collaborator

@sgugger sgugger left a comment

LGTM, thanks for iterating! @stas00 are you okay with the changes?

Comment on lines 2201 to 2202
else:
ctx_manager = contextlib.nullcontext()
Collaborator

Not sure that else is needed since we set it to this by default.

@@ -450,6 +450,9 @@ class TrainingArguments:
full_determinism (`bool`, *optional*, defaults to `False`)
If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
distributed training
torchdynamo (`str`, *optional*):
If `True`, TorchDynamo is called with AOT Autograd and nvfuser compiler to compile the appropriate portions
of the model.
Contributor

Suggested change
of the model.
of the model. This is an experimental API and it may change.

elif self.args.torchdynamo == "nvfuser":
ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
elif self.args.torchdynamo is not None:
raise ValueError("torchdynamo training arg can be eager/nvfuser")
Contributor

Why not do this check at the argparse level via choices? That kills two birds with one stone, as it tells the user which options are legit and tests for wrong choices.
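
For illustration, a field along those lines (just a sketch; it assumes HfArgumentParser forwards extra metadata keys such as "choices" to argparse's add_argument, which both documents and validates the legal values):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class TrainingArguments:  # illustrative excerpt, not the full class
    torchdynamo: Optional[str] = field(
        default=None,
        metadata={
            "help": "Backend for TorchDynamo to use when compiling the model's forward/backward passes.",
            "choices": ["eager", "nvfuser"],
        },
    )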

Contributor Author

Nice! I did not know about this.

@@ -450,6 +450,9 @@ class TrainingArguments:
full_determinism (`bool`, *optional*, defaults to `False`)
If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
distributed training
torchdynamo (`str`, *optional*):
If `True`, TorchDynamo is called with AOT Autograd and nvfuser compiler to compile the appropriate portions
Contributor

This doesn't match the actual usage. Definitely not True/False, but the actual choices: eager, nvfuser.

default=None,
metadata={
"help": (
"Whether or not to use TorchDynamo. TorchDynamo is a Python level JIT compiler designed to make"
Contributor

same as the comment above - not whether or not, but how - via choices

@csarofeen

@kevinstephano could you please take a quick look at this PR? Thanks!

@stas00
Contributor

stas00 commented May 19, 2022

LGTM, thanks for iterating! @stas00 are you okay with the changes?

I'd primarily like to hold off on merging this just yet, to hear from @Chillee (PyTorch) and maybe @csarofeen (NVIDIA), to think about what other options we might want down the road and design the key/values better.

e.g. questions:

  • do we have to go through TorchDynamo, or can we go directly through AOT Autograd? see: [Kernel Fusion] training benchmarks of AOTAutograd (multiple models) #15264
  • should we give users an option to choose other fusers besides nvfuser?
  • is it always a "driver -> fuser" combo? Perhaps then the value should have 2 parts: --compile torchdynamo:nvfuser, --compile aotautograd:fuserxxx (and then the key needs to be renamed and the driver moved into the value) - that way we end up with just one entry point and lots of flexibility for the different combos. Not sure on the best key name.

So ideally we'd collect all the possible combos and then see how to best organize them.

But I added a note that the current API proposed in this PR is experimental, so we could go with it and change it at will later.

@kevinstephano

@kevinstephano could you please take a quick look at this PR? Thanks!

Looks good to me.

@csarofeen

  • do we have to go through TorchDynamo, or can we go directly through AOT Autograd? see: [Kernel Fusion] training benchmarks of AOTAutograd (multiple models) #15264
  • should we give users an option to choose other fusers besides nvfuser?
  • is it always a "driver -> fuser" combo? Perhaps then the value should have 2 parts: --compile torchdynamo:nvfuser, --compile aotautograd:fuserxxx (and then the key needs to be renamed and the driver moved into the value) - that way we end up with just one entry point and lots of flexibility for the different combos. Not sure on the best key name.

For the first question on TorchDynamo I'll leave @Chillee to give an opinion here.

Second point: I personally think nvFuser is going to be your best bet. We're trying to make nvFuser the default for TorchScript (pytorch/pytorch#77579), so nvFuser is likely a good bet. Dynamo is looking at supporting multiple backends, but I believe that will be more of an automated thing and shouldn't require you to worry about it.

For the last point I think again @Chillee is the one to ask. I think AOTAutograd is moving, or has moved, to nvFuser by default? Don't know for sure here, nor what Dynamo is planning/looking at as options.

@Chillee

Chillee commented May 19, 2022

do we have to go through TorchDynamo, or can we go directly through aot Autograd

IMO, going through TorchDynamo is the right option here. As mentioned in the previous PR, using AOTAutograd is somewhat risky, since we can't guarantee correctness. So I don't think it's the right option to provide as a default in the Trainer.

If users want to apply AOTAutograd by themselves then I think they should feel free to do so, but I'm not convinced we should provide it as an option integrated into HF.

should we give users an option to choose other fusers besides nvfuser?

Yes, I think it's reasonable. For example, we also have a TensorRT integration with TorchDynamo that has some fairly good numbers. However, as @csarofeen says, NVFuser is definitely what I'd recommend as the "default" for this PR - if we have other backends it'll just be a different flag.

is it always "driver -> fuser" combo so perhaps the value should have 2 parts

I think TorchDynamo + AOTAutograd are essentially "constants" here. It's plausible that in the future there will be other graph-capture paths (such as if we want to export models), but the UX for that will be significantly different (i.e. it won't be a seamless "always work" thing).

So I think it's fine to have fuser be the only thing that changes.

@stas00
Contributor

stas00 commented May 20, 2022

Thank you for your commentary, @Chillee and @csarofeen

if we have other backends it'll just be a different flag.

That's the whole point of me starting this discussion - we don't want additional flags; we have too many already. That's why I was thinking that perhaps the flag should have some sort of non-implementation-specific name like --fusion or --compiler or ???, and then the value(s) can define the specific path, so perhaps this PR's original cmd arg could be converted to:

--fusion torchdynamo:nvfuser
--fusion torchdynamo:eager

which makes it easy to add other combos in the future w/o needing to change the cmd arg api.

which fits the current code of this PR:

if self.args.torchdynamo == "eager":
ctx_manager = torchdynamo.optimize("eager")
elif self.args.torchdynamo == "nvfuser":
ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)

Does any of this resonate at all? And if it does, what would be an apt generic name for the example I used, --fusion (currently the --torchdynamo key) - perhaps --autofusion, --autooptimize, something else?

@Chillee

Chillee commented May 20, 2022

@stas00 ah sorry, I misspoke - by "another flag", I meant "another value for the config option". I think something like this would be better:

--fusion nvfuser
--fusion eager

(btw, I think "debug" might be a better name than "eager"? I think it's kinda confusing to have a fusion option called "eager" haha. Or perhaps we should just remove it as an option - it's only useful for debugging bugs).

From our side, I think the main option is just going to be torchdynamo. So I think --fusion nvfuser and --fusion eager is probably sufficient.

@stas00
Contributor

stas00 commented May 20, 2022

(btw, I think "debug" might be a better name than "eager"? I think it's kinda confusing to have a fusion option called "eager" haha. Or perhaps we should just remove it as an option - it's only useful for debugging bugs).

I think "eager" is good because that's what you pass to torchdynamo - it'd be easy to document that it doesn't do any fusing and just provides the default behavior.

From our side, I think the main option is just going to be torchdynamo. So I think --fusion nvfuser and --fusion eager is probably sufficient.

so what you're proposing is that torchdynamo is going to be implied as the driver for nvfuser or eager and then in the future there might be other drivers besides torchdynamo?

So currently then we are discussing 2 options:

--fusion nvfuser
--fusion eager

which imply:

--fusion torchdynamo:nvfuser
--fusion torchdynamo:eager

perhaps I should not bother to future proof this flag?

@Chillee

Chillee commented May 20, 2022

so what you're proposing is that torchdynamo is going to be implied as the driver for nvfuser or eager and then in the future there might be other drivers besides torchdynamo?

I think it's unlikely that in the (foreseeable) future there will be other drivers besides torchdynamo with a similar UX. So imo, there's not a significant reason to try to future-proof this flag now - I don't think it'd be that hard to change while preserving BC in the future either.

@stas00
Contributor

stas00 commented May 20, 2022

OK, so then let's keep the original proposal, --torchdynamo <nvfuser|eager>, right?

@anijain2305
Contributor Author

@stas00 This PR is ready for another round of review. Let me know what you think.

@stas00
Contributor

stas00 commented May 25, 2022

  1. The memory test consistently hangs for me:
$ pytest tests/trainer/test_trainer.py -k torchdynamo_memory -sv

nothing useful in the output.

Traceback:

$ py-spy dump --pid 530235
Thread 530235 (idle): "MainThread"
    backward (torch/autograd/__init__.py:173)
    backward (torch/_tensor.py:399)
    _backward (functorch/_src/monkey_patching.py:97)
    training_step (transformers/trainer.py:2263)
    test_torchdynamo_memory (tests/trainer/test_trainer.py:1668)
    _callTestMethod (unittest/case.py:633)
    run (unittest/case.py:676)
    __call__ (unittest/case.py:736)
    runtest (_pytest/unittest.py:327)
    pytest_runtest_call (_pytest/runner.py:166)
    _multicall (pluggy/_callers.py:39)
    _hookexec (pluggy/_manager.py:80)
    __call__ (pluggy/_hooks.py:265)
    <lambda> (_pytest/runner.py:259)
    from_call (_pytest/runner.py:338)
    call_runtest_hook (_pytest/runner.py:258)
    call_and_report (_pytest/runner.py:219)
    runtestprotocol (_pytest/runner.py:130)
    pytest_runtest_protocol (_pytest/runner.py:111)
    _multicall (pluggy/_callers.py:39)
    _hookexec (pluggy/_manager.py:80)
    __call__ (pluggy/_hooks.py:265)
    pytest_runtestloop (_pytest/main.py:347)
    _multicall (pluggy/_callers.py:39)
    _hookexec (pluggy/_manager.py:80)
    __call__ (pluggy/_hooks.py:265)
    _main (_pytest/main.py:322)
    wrap_session (_pytest/main.py:268)
    pytest_cmdline_main (_pytest/main.py:315)
    _multicall (pluggy/_callers.py:39)
    _hookexec (pluggy/_manager.py:80)
    __call__ (pluggy/_hooks.py:265)
    main (_pytest/config/__init__.py:164)
    console_main (_pytest/config/__init__.py:187)
    <module> (pytest:8)
Thread 530372 (idle): "Thread-4"
    wait (threading.py:306)
    wait (threading.py:558)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
Thread 530390 (active)
    _call_impl (torch/nn/modules/module.py:1130)
    _fn (torchdynamo/eval_frame.py:74)
    backward (functorch/_src/aot_autograd.py:185)
    _fn (torchdynamo/eval_frame.py:74)
    apply (torch/autograd/function.py:253)

I tried rebuilding everything and it still hangs. env details below

I can't even Ctrl-C pytest - have to kill it

  2. Once we figure out how to make the test work, I need to see how fast it runs, to potentially decorate it with @slow - which we do for slow tests.

  3. We need to instrument the nightly CI to install all the requirements to run this test. I'm just waiting to confirm how best to approach it.


build env:

PyTorch version: 1.12.0.dev20220518
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 21.10 (x86_64)
GCC version: (Ubuntu 10.3.0-11ubuntu1) 10.3.0
Clang version: 13.0.0-2
CMake version: version 3.21.3
Libc version: glibc-2.34

Python version: 3.8.13 (default, Mar 28 2022, 11:38:47) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.32-051532-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.6.124
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA GeForce GTX 1070 Ti

Nvidia driver version: 510.47.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.4.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] functorch==0.3.0a0+76976db
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.2
[pip3] torch==1.12.0.dev20220518
[pip3] torchaudio==0.12.0.dev20220518
[pip3] torchdynamo==0.2.0
[pip3] torchvision==0.13.0.dev20220518
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] functorch 0.3.0a0+76976db dev_0
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.21.2 pypi_0 pypi
[conda] pytorch 1.12.0.dev20220518 py3.8_cuda11.3_cudnn8.3.2_0 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torch 1.12.0.dev20220404+cu115 pypi_0 pypi
[conda] torchaudio 0.12.0.dev20220404+cu115 pypi_0 pypi
[conda] torchdynamo 0.2.0 dev_0
[conda] torchvision 0.13.0.dev20220404+cu115 pypi_0 pypi

super().__init__()

def forward(self, x):
for _ in range(20):
Contributor Author

@anijain2305 anijain2305 May 25, 2022

@stas00 Can you try changing this number to 10 to see if the speed improves? If it doesn't, let's make it 1 to see if this is the culprit. At 1 the test will fail, but we will have a little more info to debug.

Contributor

@stas00 stas00 May 25, 2022

  • 1 or 2 fails as expected
  • 5 succeeds
  • 10 hangs - same top of the stack

Contributor

@stas00 stas00 May 25, 2022

I think I figured out the trigger - Trainer tried to run DP on 2 gpus and that somehow led to hanging.

Forcing the test to run on a single gpu gets past this hanging issue:

CUDA_VISIBLE_DEVICES=0 pytest tests/trainer/test_trainer.py -k torchdynamo_memory -sv

now fails:

self = <tests.trainer.test_trainer.TrainerIntegrationTest testMethod=test_torchdynamo_memory>

    @require_torch_gpu
    @require_torchdynamo
    def test_torchdynamo_memory(self):
        class MyModule(torch.nn.Module):
            """Simple module that does aggressive fusion"""
    
            def __init__(self):
                super().__init__()
    
            def forward(self, x):
                for _ in range(20):
                    x = torch.nn.functional.relu(x)
                return x
    
        mod = MyModule()
    
        # 1. Default - without TorchDynamo
        a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
        a.grad = None
        trainer = Trainer(model=mod)
        # warmup
        for _ in range(10):
>           orig_loss = trainer.training_step(mod, {"x": a})

tests/trainer/test_trainer.py:1649: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/trainer.py:2263: in training_step
    loss.backward()
/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/_tensor.py:399: in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/autograd/__init__.py:166: in backward
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

outputs = (tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0',
       grad_fn=<SelectBackward0>),), grads = (None,)
is_grads_batched = False

    def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor],
                    is_grads_batched: bool) -> Tuple[_OptionalTensor, ...]:
        new_grads: List[_OptionalTensor] = []
        for out, grad in zip(outputs, grads):
            if isinstance(grad, torch.Tensor):
                grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
                if not out.shape == grad_shape:
                    if is_grads_batched:
                        raise RuntimeError("If `is_grads_batched=True`, we interpret the first "
                                           "dimension of each grad_output as the batch dimension. "
                                           "The sizes of the remaining dimensions are expected to match "
                                           "the shape of corresponding output, but a mismatch "
                                           "was detected: grad_output["
                                           + str(grads.index(grad)) + "] has a shape of "
                                           + str(grad.shape) + " and output["
                                           + str(outputs.index(out)) + "] has a shape of "
                                           + str(out.shape) + ". "
                                           "If you only want some tensors in `grad_output` to be considered "
                                           "batched, consider using vmap.")
                    else:
                        raise RuntimeError("Mismatch in shape: grad_output["
                                           + str(grads.index(grad)) + "] has a shape of "
                                           + str(grad.shape) + " and output["
                                           + str(outputs.index(out)) + "] has a shape of "
                                           + str(out.shape) + ".")
                if out.dtype.is_complex != grad.dtype.is_complex:
                    raise RuntimeError("For complex Tensors, both grad_output and output"
                                       " are required to have the same dtype."
                                       " Mismatch in dtype: grad_output["
                                       + str(grads.index(grad)) + "] has a dtype of "
                                       + str(grad.dtype) + " and output["
                                       + str(outputs.index(out)) + "] has a dtype of "
                                       + str(out.dtype) + ".")
                new_grads.append(grad)
            elif grad is None:
                if out.requires_grad:
                    if out.numel() != 1:
>                       raise RuntimeError("grad can be implicitly created only for scalar outputs")
E                       RuntimeError: grad can be implicitly created only for scalar outputs

/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/autograd/__init__.py:67: RuntimeError

Contributor Author

@anijain2305 anijain2305 May 25, 2022

I see. For single GPU, the training_step was not calling mean() on the output. PyTorch expects a scalar loss for the .backward() call, and therefore you saw the error message you pasted. I just added a CustomTrainer to reduce the output to a scalar, and the test passes for single GPU as well.

I am not sure why it hangs for DP though (maybe it's compiling for each GPU, and thus compilation time is shooting up?).

Is it possible to limit TorchDynamo usage to a single GPU only? We have not really tested TorchDynamo for distributed training.
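
For reference, the scalar-reduction idea looks roughly like this (a sketch, not necessarily the exact code in the commit; it assumes the test's {"x": a} inputs):

from transformers import Trainer

class CustomTrainer(Trainer):
    """Reduces the module's tensor output to a scalar so that loss.backward() gets a scalar loss."""

    def compute_loss(self, model, inputs, return_outputs=False):
        x = inputs["x"]
        output = model(x)
        loss = output.mean()  # reduce the (1024, 1024) activation to a scalar
        return (loss, output) if return_outputs else loss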

Contributor

Perhaps it's because I have a lopsided setup?

GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA GeForce GTX 1070 Ti

Is the compiled version hardware agnostic?

Contributor

it works on 2 gpus with DP as well - should we remove the restrictions then and put it back how you coded it originally?

Contributor Author

Oh, awesome. Thanks for putting time into this.

There is no need to change the test back to the original one. My earlier commit extended the test to single-GPU as well, so the test works for both single and multi-GPU.

Contributor

Except that due to @require_torch_non_multi_gpu it will now only ever be run on a single gpu.


The hang seems like concerning behavior. Glad it seems to have sorted itself out, but please let us know if it returns in any form. I do have a single mixed GPU system I could try reproducing on if it comes back. Though I would always recommend DDP running across matching GPUs.

Contributor

@stas00 stas00 May 25, 2022

The hanging was on a single GPU as well with pytorch nightly from 05-18.

Contributor

@stas00 stas00 left a comment

Tests are working now.

@sgugger sgugger merged commit 897a8dd into huggingface:main May 25, 2022
@stas00
Contributor

stas00 commented May 25, 2022

@anijain2305, are you up to doing one more PR with docs? https://huggingface.co/docs/transformers/main/en/performance

  1. add HF Trainer usage example
  2. add examples of how a user can do it directly

I guess with the new layout the docs would go here:
https://github.com/huggingface/transformers/blob/main/docs/source/en/perf_train_gpu_one.mdx
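
For example, the two flavours could look roughly like this (a sketch based on the API in this PR; model/train_dataset/batch are placeholders, and the import path of aot_autograd_speedup_strategy may differ across torchdynamo versions):

# 1. via the HF Trainer argument added in this PR
from transformers import Trainer, TrainingArguments

args = TrainingArguments(output_dir="out", torchdynamo="nvfuser")
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()

# 2. directly, without the Trainer
import torchdynamo
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy  # path assumed

with torchdynamo.optimize(aot_autograd_speedup_strategy):
    loss = model(**batch).loss
    loss.backward()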

@anijain2305
Contributor Author

@stas00 Yes, I can do docs as well. Let me take a look and I will come back with where to put the section.

@stas00
Contributor

stas00 commented May 25, 2022

Also, as I updated in the OP, @ydshieh is instrumenting the nightly CI to install the prerequisites for this test in PR #17335, commit 52e7021.

elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
…ce#17308)

* Support compilation via Torchdynamo, AOT Autograd, NVFuser

* Address comments

* Lint

* Stas comments - missing quality test

* Lintere

* Quality test

* Doc lint

* Reset CUDA peak mem

* Add CustomTrainer

* require a single gpu

Co-authored-by: Stas Bekman <stas@stason.org>
@stas00
Contributor

stas00 commented Jun 16, 2022

pinging about the docs, @anijain2305 - thank you!

Almost nobody will use your work unless you document it in user-facing docs, so you're the ones who really want to add these docs, I'd think...
