FutureWarning: torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead. #20370

Open
loretoparisi opened this issue Oct 28, 2024 · 0 comments
Labels: bug, needs triage, ver: 2.4.x


Bug description

I'm getting the following warning (it is a FutureWarning, not a hard error):

FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
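For context, the migration that the warning asks for could be sketched roughly as below. This is only an illustration of how a library might pick the right decorator across torch versions, not code from fairscale or Lightning. The helper name `resolve_amp_decorator` and its parameters are hypothetical; the only real APIs it relies on are `torch.amp.custom_fwd` / `torch.amp.custom_bwd` (which take a keyword-only `device_type`) and the older `torch.cuda.amp.custom_fwd` / `torch.cuda.amp.custom_bwd`.

```python
import functools


def resolve_amp_decorator(torch_mod, name, device_type="cuda"):
    """Return the non-deprecated AMP decorator when it exists, else the old one.

    `torch_mod` is the imported `torch` module and `name` is "custom_fwd" or
    "custom_bwd". This helper is hypothetical and shown only as a sketch.
    """
    # Newer torch releases expose torch.amp.custom_fwd / custom_bwd, which
    # require a keyword-only `device_type` argument.
    new_api = getattr(getattr(torch_mod, "amp", None), name, None)
    if new_api is not None:
        return functools.partial(new_api, device_type=device_type)
    # Older releases only provide the torch.cuda.amp variants.
    return getattr(torch_mod.cuda.amp, name)
```

Assuming `torch` is imported, `custom_fwd = resolve_amp_decorator(torch, "custom_fwd")` could then be applied as `@custom_fwd` on the `forward` staticmethod of a `torch.autograd.Function`, in the same position where `torch.cuda.amp.custom_fwd` was used before.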

Specifically, the warning seems to be triggered by fairscale:

/home/coder/.local/lib/python3.10/site-packages/fairscale/experimental/nn/offload.py:19: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
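Since the deprecated call lives in third-party code, one possible stopgap is a warning filter installed before fairscale is imported. The sketch below uses only the standard `warnings` module; the function name `silence_amp_deprecation` and the regular expression are assumptions written to match the message text quoted above, so they may need adjusting if the wording changes.

```python
import warnings

# Regex matched against the start of the warning message. It targets both the
# custom_fwd and custom_bwd variants of the deprecation text quoted above.
AMP_DEPRECATION_PATTERN = (
    r".*torch\.cuda\.amp\.custom_(fwd|bwd)\(args\.\.\.\)`? is deprecated"
)


def silence_amp_deprecation() -> None:
    """Ignore only the torch.cuda.amp custom_fwd/custom_bwd FutureWarning."""
    warnings.filterwarnings(
        "ignore",
        message=AMP_DEPRECATION_PATTERN,
        category=FutureWarning,
    )
```

The filter would have to be installed before the first `import fairscale` (or before anything that imports it), because the warning is emitted when the decorator is applied at import time. Other FutureWarnings are left untouched.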

What version are you seeing the problem on?

v2.4

How to reproduce the bug

I'm running the following package versions:


torch==2.5.0
lightning==2.4.0
lightning-utilities==0.11.8
fairscale==0.4.6

The Python code is fairly standard distributed training code using FSDPStrategy as the training strategy, and it was working before:

# Imports implied by the snippet
import lightning as L
import psutil
import torch
import torch.nn as nn
from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp import MixedPrecision
from torch.utils.data import DataLoader, TensorDataset

# modelArgs, CUSTOM_TIMEOUT, cpu_offload, my_collate, preprocess_pairs_tensor,
# logger, checkpoint_callback, accelerator, devices and num_nodes are defined
# elsewhere in the script.

def custom_auto_wrap_policy(module, recurse, nonwrapped_numel, **kwargs):
    # Wrap only Embedding layers
    return isinstance(module, nn.Embedding)

sharding_strategy = modelArgs.sharding_strategy
state_dict_type = modelArgs.state_dict_type

strategy = FSDPStrategy(
    timeout=CUSTOM_TIMEOUT,
    cpu_offload=cpu_offload,
    activation_checkpointing_policy=custom_auto_wrap_policy,
    auto_wrap_policy=custom_auto_wrap_policy,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
    process_group_backend="nccl",
    sharding_strategy=sharding_strategy,
    state_dict_type=state_dict_type,
)

pairs = preprocess_pairs_tensor
train_dataset = TensorDataset(pairs)
trainloader = DataLoader(
    train_dataset,
    batch_size=modelArgs.batch_size,
    collate_fn=my_collate,
    drop_last=False,
    shuffle=True,
    num_workers=psutil.cpu_count(),
    persistent_workers=True,
    pin_memory=True,
)

# Initialize a trainer
trainer = L.Trainer(
    logger=logger,
    log_every_n_steps=1,
    precision="bf16-true",
    callbacks=[checkpoint_callback],
    accelerator=accelerator,
    devices=devices,
    num_nodes=num_nodes,
    strategy=strategy,
    #limit_train_batches=1.0,
    max_epochs=modelArgs.epochs,
    deterministic=True,
)

Error messages and logs

The following warning is logged repeatedly, an indefinite number of times:

/home/coder/.local/lib/python3.10/site-packages/fairscale/experimental/nn/offload.py:19: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
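If the repeated lines come from separate processes each importing fairscale (for example one process per GPU rank, which is an assumption and not confirmed here), a filter set through the `PYTHONWARNINGS` environment variable would be inherited by processes launched afterwards. The sketch below builds such a filter in Python's `action:message:category:module` format. Both helper names are hypothetical, and the module name is taken from the path in the log line above.

```python
import os


def build_pythonwarnings(module: str, category: str = "FutureWarning") -> str:
    """Build an 'ignore' spec in the action:message:category:module format.

    The message field is left empty, so any message of `category` raised from
    exactly `module` is ignored.
    """
    return f"ignore::{category}:{module}"


def export_warning_filter(spec: str, env=os.environ) -> str:
    """Append `spec` to PYTHONWARNINGS in `env` and return the new value."""
    existing = env.get("PYTHONWARNINGS", "")
    # Entries are comma-separated; later entries take precedence.
    combined = f"{existing},{spec}" if existing else spec
    env["PYTHONWARNINGS"] = combined
    return combined


if __name__ == "__main__":
    # Module name assumed from fairscale/experimental/nn/offload.py in the log.
    print(export_warning_filter(
        build_pythonwarnings("fairscale.experimental.nn.offload")
    ))
```

Setting the variable inside a running script only affects child processes started later; to cover the main process as well, the same value would need to be exported in the shell before launching Python.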

Environment

<details>
  <summary>Current environment</summary>

* CUDA:
        - GPU:
                - NVIDIA L4
                - NVIDIA L4
                - NVIDIA L4
                - NVIDIA L4
                - NVIDIA L4
                - NVIDIA L4
                - NVIDIA L4
                - NVIDIA L4
        - available:         True
        - version:           12.4
* Lightning:
        - lightning:         2.4.0
        - lightning-utilities: 0.11.8
        - pytorch-lightning: 1.9.5
        - torch:             2.5.0
        - torchmetrics:      1.5.1
        - torchvision:       0.20.0
* Packages:
        - aiofiles:          22.1.0
        - aiohappyeyeballs:  2.4.3
        - aiohttp:           3.10.10
        - aiosignal:         1.3.1
        - aiosqlite:         0.20.0
        - anyio:             4.6.0
        - appdirs:           1.4.4
        - argon2-cffi:       23.1.0
        - argon2-cffi-bindings: 21.2.0
        - argparse:          1.4.0
        - arrow:             1.3.0
        - asttokens:         2.4.1
        - async-timeout:     4.0.3
        - attrs:             24.2.0
        - autocommand:       2.2.2
        - autofaiss:         2.15.8
        - babel:             2.16.0
        - backports.tarfile: 1.2.0
        - beautifulsoup4:    4.12.3
        - bleach:            6.1.0
        - blinker:           1.4
        - boto3:             1.26.145
        - botocore:          1.29.165
        - certifi:           2024.8.30
        - cffi:              1.17.1
        - charset-normalizer: 3.3.2
        - click:             8.1.7
        - coloredlogs:       15.0.1
        - comm:              0.2.2
        - cryptography:      3.4.8
        - datasets:          2.14.4
        - dbus-python:       1.2.18
        - debugpy:           1.8.6
        - decorator:         5.1.1
        - defusedxml:        0.7.1
        - dill:              0.3.7
        - distro:            1.7.0
        - docker-pycreds:    0.4.0
        - duckdb:            1.1.1
        - embedding-reader:  1.7.0
        - entrypoints:       0.4
        - exceptiongroup:    1.2.2
        - executing:         2.1.0
        - fairscale:         0.4.6
        - faiss-cpu:         1.9.0
        - fastjsonschema:    2.20.0
        - filelock:          3.16.1
        - fire:              0.4.0
        - flatbuffers:       24.3.25
        - fqdn:              1.5.1
        - frozenlist:        1.5.0
        - fsspec:            2024.10.0
        - fuzzywuzzy:        0.18.0
        - gitdb:             4.0.11
        - gitpython:         3.1.43
        - hnswlib:           0.7.0
        - httplib2:          0.20.2
        - huggingface-hub:   0.25.2
        - humanfriendly:     10.0
        - idna:              3.10
        - importlib-metadata: 4.6.4
        - importlib-resources: 6.4.0
        - inflect:           7.3.1
        - ipykernel:         6.29.5
        - ipython:           8.27.0
        - ipython-genutils:  0.2.0
        - isoduration:       20.11.0
        - jaraco.collections: 5.1.0
        - jaraco.context:    5.3.0
        - jaraco.functools:  4.0.1
        - jaraco.text:       3.12.1
        - jedi:              0.19.1
        - jeepney:           0.7.1
        - jinja2:            3.1.4
        - jmespath:          1.0.1
        - joblib:            1.3.2
        - json5:             0.9.25
        - jsonpointer:       3.0.0
        - jsonschema:        4.23.0
        - jsonschema-specifications: 2023.12.1
        - jupyter-client:    7.4.9
        - jupyter-core:      5.7.2
        - jupyter-events:    0.10.0
        - jupyter-server:    2.14.2
        - jupyter-server-fileid: 0.9.3
        - jupyter-server-terminals: 0.5.3
        - jupyter-server-ydoc: 0.8.0
        - jupyter-ydoc:      0.2.5
        - jupyterlab:        3.6.2
        - jupyterlab-pygments: 0.3.0
        - jupyterlab-server: 2.27.3
        - keyring:           23.5.0
        - launchpadlib:      1.10.16
        - lazr.restfulclient: 0.14.4
        - lazr.uri:          1.0.6
        - levenshtein:       0.23.0
        - lightning:         2.4.0
        - lightning-utilities: 0.11.8
        - markupsafe:        2.1.5
        - matplotlib-inline: 0.1.7
        - mistune:           3.0.2
        - more-itertools:    8.10.0
        - mpmath:            1.3.0
        - multidict:         6.1.0
        - multiprocess:      0.70.15
        - nbclassic:         1.1.0
        - nbclient:          0.10.0
        - nbconvert:         7.16.4
        - nbformat:          5.10.4
        - nest-asyncio:      1.6.0
        - networkx:          3.4.2
        - nltk:              3.9.1
        - notebook:          6.5.7
        - notebook-shim:     0.2.4
        - numpy:             1.26.4
        - nvidia-cublas-cu12: 12.4.5.8
        - nvidia-cuda-cupti-cu12: 12.4.127
        - nvidia-cuda-nvrtc-cu12: 12.4.127
        - nvidia-cuda-runtime-cu12: 12.4.127
        - nvidia-cudnn-cu12: 9.1.0.70
        - nvidia-cufft-cu12: 11.2.1.3
        - nvidia-curand-cu12: 10.3.5.147
        - nvidia-cusolver-cu12: 11.6.1.9
        - nvidia-cusparse-cu12: 12.3.1.170
        - nvidia-nccl-cu12:  2.21.5
        - nvidia-nvjitlink-cu12: 12.4.127
        - nvidia-nvtx-cu12:  12.4.127
        - oauthlib:          3.2.0
        - onnx:              1.17.0
        - onnxruntime-gpu:   1.19.2
        - optimum:           1.23.2
        - overrides:         7.7.0
        - packaging:         24.1
        - pandas:            1.3.5
        - pandocfilters:     1.5.1
        - parso:             0.8.4
        - pexpect:           4.9.0
        - pillow:            11.0.0
        - pip:               24.2
        - platformdirs:      4.3.6
        - prometheus-client: 0.21.0
        - prompt-toolkit:    3.0.48
        - propcache:         0.2.0
        - protobuf:          4.25.3
        - psutil:            5.9.5
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.3
        - pyarrow:           12.0.1
        - pycparser:         2.22
        - pygments:          2.18.0
        - pygobject:         3.42.1
        - pyjwt:             2.3.0
        - pyparsing:         2.4.7
        - python-apt:        2.4.0+ubuntu4
        - python-dateutil:   2.9.0.post0
        - python-json-logger: 2.0.7
        - python-levenshtein: 0.23.0
        - pytorch-lightning: 1.9.5
        - pytz:              2024.2
        - pyyaml:            6.0.2
        - pyzmq:             26.0.3
        - rapidfuzz:         3.4.0
        - referencing:       0.35.1
        - regex:             2024.9.11
        - requests:          2.32.3
        - rfc3339-validator: 0.1.4
        - rfc3986-validator: 0.1.1
        - rpds-py:           0.20.0
        - s3transfer:        0.6.2
        - safetensors:       0.4.5
        - scikit-learn:      1.5.2
        - scipy:             1.14.1
        - secretstorage:     3.3.1
        - send2trash:        1.8.3
        - sentence-transformers: 2.2.2
        - sentencepiece:     0.2.0
        - sentry-sdk:        2.17.0
        - setproctitle:      1.3.3
        - setuptools:        75.1.0
        - six:               1.16.0
        - smmap:             5.0.1
        - sniffio:           1.3.1
        - soupsieve:         2.6
        - stack-data:        0.6.3
        - sympy:             1.13.1
        - termcolor:         2.5.0
        - terminado:         0.18.1
        - threadpoolctl:     3.5.0
        - tinycss2:          1.3.0
        - tokenizers:        0.20.1
        - tomli:             2.0.1
        - torch:             2.5.0
        - torchmetrics:      1.5.1
        - torchvision:       0.20.0
        - tornado:           6.4.1
        - tqdm:              4.66.1
        - traitlets:         5.14.3
        - transformers:      4.46.0
        - triton:            3.1.0
        - typeguard:         4.3.0
        - types-python-dateutil: 2.9.0.20240906
        - typing-extensions: 4.12.2
        - uri-template:      1.3.0
        - urllib3:           1.26.20
        - wadllib:           1.3.6
        - wandb:             0.16.6
        - wcwidth:           0.2.13
        - webcolors:         24.8.0
        - webencodings:      0.5.1
        - websocket-client:  1.8.0
        - wheel:             0.44.0
        - xxhash:            3.5.0
        - y-py:              0.6.2
        - yarl:              1.16.0
        - ypy-websocket:     0.8.4
        - zipp:              1.0.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - 
        - processor:         x86_64
        - python:            3.10.12
        - release:           5.10.219-208.866.amzn2.x86_64
        - version:           #1 SMP Tue Jun 18 14:00:06 UTC 2024

</details>

More info

No response

loretoparisi added the bug and needs triage labels on Oct 28, 2024