
[BUG] Empty torch.Tensor() calls crash when combined with DeepSpeed's partition_parameters.py #4095

Closed
@rosario-purple

Description

Describe the bug

When you create a new TransformerEngine (https://github.com/NVIDIA) PyTorch Linear layer with bias=False while using DeepSpeed, this line in TransformerEngine's pytorch/module/linear.py:

self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=torch.cuda.current_device())

winds up calling this function in DeepSpeed's partition_parameters.py:

    def new_tensor(cls, *args) -> Tensor:
        device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
        tensor = _orig_torch_empty(0, device=device).new_empty(*args)  # fails when args == ()
        if tensor.is_floating_point():
            tensor = tensor.to(dtype)

        return tensor

    return new_tensor

Because args is empty and new_empty() requires a size argument, the call crashes.
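
The same failure can be shown with plain PyTorch, independent of DeepSpeed (a minimal sketch; the exact error message may vary across torch versions):

    import torch

    # new_empty(*args) with args == () has no size to allocate, so it raises --
    # the same thing that happens when torch.Tensor() is routed through new_tensor()
    base = torch.empty(0)
    try:
        base.new_empty()
    except TypeError as err:
        print(f"new_empty() with no size fails: {err}")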

To Reproduce
Steps to reproduce the behavior:

  1. Create a PyTorch model using TransformerEngine fp8 Linear layers with no bias: layer = te.Linear(foo, bar, bias=False)
  2. Attempt to train the model using DeepSpeed with partitioned parameters (ZeRO Level 3)
  3. Training crashes because args is empty (see the minimal sketch after this list).
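
A minimal sketch combining steps 1 and 2 (assumes deepspeed and transformer_engine are installed; deepspeed.zero.Init and te.Linear are the real entry points, but the layer sizes are arbitrary placeholders and, depending on the environment, deepspeed.init_distributed() may need to run first):

    import deepspeed
    import transformer_engine.pytorch as te

    # zero.Init patches torch tensor creation for ZeRO-3 partitioned parameters
    with deepspeed.zero.Init():
        # bias=False makes TransformerEngine call torch.Tensor() with no args,
        # which hits the broken new_tensor() path shown above
        layer = te.Linear(1024, 1024, bias=False)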

Expected behavior
Training should not crash.

ds_report output


DeepSpeed C++/CUDA extension op report

NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.

JIT compiled ops requires ninja
ninja .................. [OKAY]

op name ................ installed .. compatible

async_io ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
[WARNING] using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]

DeepSpeed general environment info:
torch install path ............... ['/home/alyssavance/miniforge-pypy3/envs/brr/lib/python3.9/site-packages/torch']
torch version .................... 2.0.1
deepspeed install path ........... ['/home/alyssavance/miniforge-pypy3/envs/brr/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.10.0, unknown, unknown
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 11.8
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.8

Screenshots
N/A

System info (please complete the following information):

  • OS: Ubuntu 22.04
  • GPU count and types: 8x H100
  • Interconnects (if applicable): NVLink
  • Python version: 3.9.16

Launcher context
Are you launching your experiment with the deepspeed launcher, MPI, or something else?

deepspeed launcher

Docker context
Are you using a specific docker image that you can share?

No Docker

Additional context

DeepSpeed should now support fp8 training: https://raw.githubusercontent.com/microsoft/DeepSpeed/master/tests/unit/runtime/half_precision/test_fp8.py
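
A possible workaround would be to guard the empty-args case inside new_tensor (an untested sketch against the snippet quoted above, not an actual DeepSpeed patch; get_accelerator, _orig_torch_empty, and dtype come from the enclosing closure in partition_parameters.py):

    def new_tensor(cls, *args) -> Tensor:
        device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
        if not args:
            # torch.Tensor() with no arguments yields an empty tensor of shape [0]
            args = (0,)
        tensor = _orig_torch_empty(0, device=device).new_empty(*args)
        if tensor.is_floating_point():
            tensor = tensor.to(dtype)
        return tensor

Since torch.Tensor() outside the Init context returns an empty tensor of shape [0], defaulting args to (0,) should preserve that behavior under zero.Init.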

Labels

bug, training
