
[Bug] Missing input validation for fp16 loss_scale accepts float('inf') #7852

@amadhan882

Description


Describe the bug

The fp16 configuration in DeepSpeed accepts float('inf') for the loss_scale parameter without any validation. While other parameters such as gradient_accumulation_steps and stage3_max_live_parameters correctly raise AssertionError or ValidationError for invalid or negative inputs, setting loss_scale to infinity lets the engine initialize successfully. The mathematically invalid scale then causes gradients to silently become NaN during training.

To Reproduce

The following script was used on Google Colab to test logic edge cases. The "Infinite Scale Test" successfully bypasses validation while others are caught.

Reproduction Code:

import deepspeed
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
    def forward(self, x): return self.linear(x)

def run_test(name, config):
    print(f"\n[#] Testing: {name}")
    print(f"[-] Config: {config}")
    try:
        model = TinyModel()
        engine, _, _, _ = deepspeed.initialize(
            model=model,
            config=config,
            model_parameters=model.parameters()
        )
        print("[+] STATUS: EXECUTED (Possible Logic Leak)")
    except Exception as e:
        print(f"[X] CRASH: {type(e).__name__} -> {e}")

test_cases = {
    "Negative Accumulation": {
        "train_batch_size": 16,
        "gradient_accumulation_steps": -1
    },
    "Zero Discovery Stage": {
        "train_batch_size": 16,
        "zero_optimization": {
            "stage": 3,
            "stage3_max_live_parameters": -100
        }
    },
    "Infinite Scale Test": {
        "train_batch_size": 16,
        "fp16": {
            "enabled": True,
            "loss_scale": float('inf')
        }
    }
}

if __name__ == "__main__":
    for name, cfg in test_cases.items():
        run_test(name, cfg)

Observed Output:

[#] Testing: Negative Accumulation
[-] Config: {'train_batch_size': 16, 'gradient_accumulation_steps': -1}
[X] CRASH: AssertionError -> Micro batch size per gpu: -16 has to be greater than 0

[#] Testing: Zero Discovery Stage
[-] Config: {'train_batch_size': 16, 'zero_optimization': {'stage': 3, 'stage3_max_live_parameters': -100}}
[X] CRASH: ValidationError -> 1 validation error for DeepSpeedZeroConfig
stage3_max_live_parameters
  Input should be greater than or equal to 0 [type=greater_than_equal, input_value=-100, input_type=int]

[#] Testing: Infinite Scale Test
[-] Config: {'train_batch_size': 16, 'fp16': {'enabled': True, 'loss_scale': inf}}
[+] STATUS: EXECUTED (Possible Logic Leak)

Expected behavior

DeepSpeed should validate that loss_scale is a finite positive number. Providing inf should trigger a validation error during the configuration parsing phase, consistent with how other numerical constraints are enforced.
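A minimal sketch of the proposed check, independent of DeepSpeed internals (the function name `validate_loss_scale` is hypothetical; the assumption that `loss_scale=0` conventionally means "use dynamic loss scaling" follows DeepSpeed's fp16 config documentation):

```python
import math

def validate_loss_scale(loss_scale: float) -> float:
    # Reject non-finite values (inf, -inf, nan) and negatives.
    # 0 is allowed: it conventionally selects dynamic loss scaling.
    if not math.isfinite(loss_scale) or loss_scale < 0:
        raise ValueError(
            f"loss_scale must be a finite non-negative number, got {loss_scale}"
        )
    return loss_scale

validate_loss_scale(0)        # dynamic scaling, accepted
validate_loss_scale(65536.0)  # static scale, accepted
try:
    validate_loss_scale(float("inf"))
except ValueError as e:
    print(f"rejected: {e}")
```

An equivalent constraint could be expressed directly in the existing pydantic config model (the same mechanism that already rejects a negative stage3_max_live_parameters), so the error would surface at config-parsing time alongside the other ValidationErrors.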

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
dc ..................... [NO] ....... [OKAY]
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/usr/local/lib/python3.12/dist-packages/torch']
torch version .................... 2.9.0+cu128
deepspeed install path ........... ['/usr/local/lib/python3.12/dist-packages/deepspeed']
deepspeed info ................... 0.18.6, unknown, unknown
torch cuda version ............... 12.8
nvcc version ..................... 12.8
deepspeed wheel compiled w. ...... torch 2.9, cuda 12.8
shared memory (/dev/shm) size .... 5.68 GB

System info

  • OS: Google Colab (Ubuntu 22.04 LTS)
  • GPU count and types: 1 x Tesla T4 (NVIDIA-SMI 580.82.07)
  • CUDA Version: 13.0
  • Python version: 3.12.12
  • DeepSpeed version: 0.18.6

Launcher context

Local Python Script (Google Colab)

Additional context

The bypass allows mathematically invalid configurations to pass initialization, causing training to fail with NaNs later. Enforcing finiteness in the config parser would improve debugging for users.
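The NaN mechanism can be illustrated without DeepSpeed at all, using plain IEEE 754 float arithmetic: scaling a zero gradient by infinity produces NaN directly, and unscaling a nonzero gradient divides infinity by infinity, which is also NaN:

```python
import math

scale = float("inf")

# A zero gradient scaled by inf: 0.0 * inf is NaN under IEEE 754.
zero_grad = 0.0
print(zero_grad * scale)            # nan

# A nonzero gradient overflows to inf when scaled, and unscaling
# (dividing by the loss scale) yields inf / inf, which is NaN.
scaled = 1.5 * scale                # inf
print(scaled / scale)               # nan
```

So every gradient, zero or nonzero, degrades to NaN, which is why an infinite loss_scale can never be a valid configuration.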

Metadata

Labels: bug (Something isn't working), training
