Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import deepspeed

from deepspeed import comm as dist
from deepspeed.runtime.utils import see_memory_usage, DummyOptim, register_output_backward_hooks, check_internal_apis_for_count_used_parameters
from deepspeed.runtime.utils import see_memory_usage, DummyOptim, register_output_backward_hooks
from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
Expand Down Expand Up @@ -431,16 +431,18 @@ def __init__(self,
self.register_compile_pass(selective_gather.NAME, selective_gather.selective_gather)
self.register_compile_pass(offload_adam_states.NAME, offload_adam_states.move_opt_states)

# We now support PyTorch style backward, but it relies on the counter in ZeRO optimizers.
# However, we need some internal APIs to count the number of only used parameters.
# So we only enable this feature when those internal APIs are available.
# Otherwise, we fallback to DeepSpeed style backward only.
# We now support PyTorch style backward, which relies on the counter in ZeRO optimizers.
# When the required internal PyTorch APIs are available (PyTorch >= 2.3), we use them
# for precise parameter counting. When they are not available (older PyTorch builds),
# count_used_parameters_in_backward() falls back to a conservative count of all
# grad-requiring parameters, which is correct but may slightly delay the epilogue.
# Either way, the feature is safe to enable for all ZeRO optimizers.
# See `count_used_parameters_in_backward` for more details.
self._running_engine_backward = False
self._support_torch_style_backward = False
# Flag to control whether gradients should be scaled by gradient accumulation steps
self._scale_wrt_gas = True
if isinstance(self.optimizer, ZeROOptimizer) and check_internal_apis_for_count_used_parameters():
if isinstance(self.optimizer, ZeROOptimizer):
self._support_torch_style_backward = True
# These hooks are used for non-scalar backward support, such as `out.backward(out_grad)`,
# not for `engine.backward(loss)`. In this case, we need to ensure that the preprocessing
Expand Down
16 changes: 12 additions & 4 deletions deepspeed/runtime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,16 +1435,24 @@ def count_used_parameters_in_backward(parameters: Sequence[torch.nn.Parameter])
its verification on tensor shapes throws an error with ZeRO3 (it expects original tensor shape).
So this function simplifies register_multi_grad_hook just to count used parameters.

When the required PyTorch internal APIs are not available (e.g. PyTorch < 2.3),
this function falls back to counting all parameters that require gradients, which
is conservative but correct — it may delay the epilogue slightly but will never
trigger it prematurely.

Args:
parameters: Iterable of model parameters to inspect.

Returns:
The number of parameters whose gradient nodes will be executed by the autograd engine
for the active backward call.
for the active backward call. When internal APIs are unavailable, returns the total
count of parameters that require gradients (conservative fallback).
"""
assert check_internal_apis_for_count_used_parameters(), (
"count_used_parameters_in_backward requires internal PyTorch APIs that are not available "
"in this PyTorch build.")
if not check_internal_apis_for_count_used_parameters():
# Fallback for older PyTorch versions (< 2.3) that lack the internal APIs.
# Return the total number of grad-requiring parameters as a conservative
# upper bound. This ensures the epilogue never fires prematurely.
return sum(1 for p in parameters if isinstance(p, torch.Tensor) and p.requires_grad)

from torch.autograd.graph import _get_grad_fn_or_grad_acc
if torch._C._current_graph_task_id() == -1:
Expand Down
205 changes: 205 additions & 0 deletions tests/unit/runtime/test_count_used_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

"""
Tests for count_used_parameters_in_backward and its fallback behaviour
when PyTorch internal APIs are unavailable (issue #7756).
"""

import pytest
import torch
from unittest.mock import patch

from deepspeed.runtime.utils import (
check_internal_apis_for_count_used_parameters,
count_used_parameters_in_backward,
)


# ---------------------------------------------------------------------------
# Helper: build a small parameter list
# ---------------------------------------------------------------------------

def _make_params(n=4, requires_grad=True):
"""Return a list of n Parameters on CPU."""
return [torch.nn.Parameter(torch.randn(3, 3), requires_grad=requires_grad) for _ in range(n)]


# ---------------------------------------------------------------------------
# Tests for check_internal_apis_for_count_used_parameters
# ---------------------------------------------------------------------------

class TestCheckInternalApis:
    """Verify the availability probe returns correct results.

    The probe reports whether the private PyTorch autograd APIs needed by
    ``count_used_parameters_in_backward`` exist in this build. Each
    "missing attribute" test below temporarily deletes one API from the
    live torch modules and restores it in a ``finally`` block.
    """

    def test_returns_bool(self):
        # Whatever PyTorch build is installed, the probe must return a
        # plain bool and never raise.
        result = check_internal_apis_for_count_used_parameters()
        assert isinstance(result, bool)

    def test_false_when_get_grad_fn_missing(self):
        # Capture the real attribute BEFORE removing it. The previous
        # version first patched the attribute to None via patch.object,
        # so `saved` captured that None placeholder; the finally-block
        # restore was then skipped and restoration silently depended on
        # the patch context manager's teardown. This now mirrors the
        # save/delete/restore pattern used by the sibling tests below.
        saved = getattr(torch.autograd.graph, '_get_grad_fn_or_grad_acc', None)
        try:
            if hasattr(torch.autograd.graph, '_get_grad_fn_or_grad_acc'):
                delattr(torch.autograd.graph, '_get_grad_fn_or_grad_acc')
            assert check_internal_apis_for_count_used_parameters() is False
        finally:
            if saved is not None:
                torch.autograd.graph._get_grad_fn_or_grad_acc = saved

    def test_false_when_current_graph_task_id_missing(self):
        # Removing torch._C._current_graph_task_id must make the probe
        # report that the internal APIs are unavailable.
        saved = getattr(torch._C, '_current_graph_task_id', None)
        try:
            if hasattr(torch._C, '_current_graph_task_id'):
                delattr(torch._C, '_current_graph_task_id')
            assert check_internal_apis_for_count_used_parameters() is False
        finally:
            if saved is not None:
                torch._C._current_graph_task_id = saved

    def test_false_when_will_engine_execute_node_missing(self):
        # Removing torch._C._will_engine_execute_node must make the probe
        # report that the internal APIs are unavailable.
        saved = getattr(torch._C, '_will_engine_execute_node', None)
        try:
            if hasattr(torch._C, '_will_engine_execute_node'):
                delattr(torch._C, '_will_engine_execute_node')
            assert check_internal_apis_for_count_used_parameters() is False
        finally:
            if saved is not None:
                torch._C._will_engine_execute_node = saved


# ---------------------------------------------------------------------------
# Tests for the fallback path of count_used_parameters_in_backward
# ---------------------------------------------------------------------------

class TestCountUsedParametersFallback:
    """When internal APIs are missing, the function must fall back to counting
    all parameters that require gradients instead of crashing."""

    @staticmethod
    def _count_with_apis_disabled(params):
        # Force the availability probe to report False so that
        # count_used_parameters_in_backward takes its fallback path.
        with patch('deepspeed.runtime.utils.check_internal_apis_for_count_used_parameters', return_value=False):
            return count_used_parameters_in_backward(params)

    def test_fallback_returns_total_grad_count(self):
        """With APIs unavailable, should return count of grad-requiring params."""
        assert self._count_with_apis_disabled(_make_params(5, requires_grad=True)) == 5

    def test_fallback_excludes_no_grad_params(self):
        """Params with requires_grad=False should not be counted in fallback."""
        mixed_grads = _make_params(3, requires_grad=True) + _make_params(2, requires_grad=False)
        assert self._count_with_apis_disabled(mixed_grads) == 3

    def test_fallback_empty_list(self):
        """Empty parameter list should return 0."""
        assert self._count_with_apis_disabled([]) == 0

    def test_fallback_all_no_grad(self):
        """All params with requires_grad=False should return 0."""
        assert self._count_with_apis_disabled(_make_params(4, requires_grad=False)) == 0

    def test_fallback_mixed_tensors_and_non_tensors(self):
        """Non-tensor items in the parameter list should be skipped."""
        entries = _make_params(2, requires_grad=True) + [None, "not_a_tensor", 42]
        assert self._count_with_apis_disabled(entries) == 2

    def test_fallback_does_not_raise(self):
        """The old code raised AssertionError; the fix must NOT raise."""
        # Completing without any assertion or exception is the point here.
        outcome = self._count_with_apis_disabled(_make_params(3, requires_grad=True))
        assert isinstance(outcome, int)
        assert outcome >= 0

    def test_fallback_single_param(self):
        """Single parameter with grad should return 1."""
        assert self._count_with_apis_disabled(_make_params(1, requires_grad=True)) == 1

    def test_fallback_large_param_list(self):
        """Ensure fallback scales correctly with many parameters."""
        assert self._count_with_apis_disabled(_make_params(100, requires_grad=True)) == 100


# ---------------------------------------------------------------------------
# Tests for the native path (when APIs are available)
# ---------------------------------------------------------------------------

class TestCountUsedParametersNative:
    """When internal APIs are available, verify the native path works correctly."""

    @pytest.mark.skipif(
        not check_internal_apis_for_count_used_parameters(),
        reason="PyTorch internal APIs not available (likely PyTorch < 2.3)"
    )
    def test_native_path_during_backward(self):
        """Native path should work correctly when called inside a backward hook."""
        model = torch.nn.Linear(4, 2)
        params = list(model.parameters())
        results = []  # counts observed while the autograd engine is running

        def hook_fn(grad):
            # Runs during loss.backward(), i.e. while a graph task is
            # active, so this exercises the native counting path rather
            # than the conservative fallback.
            count = count_used_parameters_in_backward(params)
            results.append(count)
            return grad

        x = torch.randn(1, 4)
        out = model(x)
        loss = out.sum()

        # Register a tensor hook on one parameter; it fires while the
        # backward pass executes (note: a tensor hook, not a grad_fn hook).
        params[0].register_hook(hook_fn)
        loss.backward()

        assert len(results) == 1
        assert results[0] > 0  # At least some params should participate

    @pytest.mark.skipif(
        not check_internal_apis_for_count_used_parameters(),
        reason="PyTorch internal APIs not available (likely PyTorch < 2.3)"
    )
    def test_native_path_raises_outside_backward(self):
        """Native path should raise RuntimeError when not inside backward."""
        params = _make_params(3, requires_grad=True)
        # No backward pass is active here, so there is no current graph
        # task and the native implementation must reject the call.
        with pytest.raises(RuntimeError, match="must be called during backward execution"):
            count_used_parameters_in_backward(params)

    @pytest.mark.skipif(
        not check_internal_apis_for_count_used_parameters(),
        reason="PyTorch internal APIs not available (likely PyTorch < 2.3)"
    )
    def test_native_path_empty_list(self):
        """Native path should return 0 for empty list during backward."""
        model = torch.nn.Linear(4, 2)
        results = []

        def hook_fn(grad):
            # Even inside an active backward pass, an empty parameter
            # list must yield a count of zero.
            count = count_used_parameters_in_backward([])
            results.append(count)
            return grad

        x = torch.randn(1, 4)
        out = model(x)
        loss = out.sum()
        list(model.parameters())[0].register_hook(hook_fn)
        loss.backward()

        assert len(results) == 1
        assert results[0] == 0