29 changes: 28 additions & 1 deletion colossalai/nn/parallel/data_parallel.py
@@ -7,7 +7,7 @@
from colossalai.tensor.chunk import TensorState, Chunk
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.gemini_mgr import GeminiManager
- from typing import Dict
+ from typing import Dict, Iterable
from colossalai.logging import get_dist_logger


@@ -38,6 +38,8 @@ def __init__(self, module: torch.nn.Module) -> None:
self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
self.dp_world_size = gpc.get_world_size(ParallelMode.DATA)
for p in module.parameters():
if getattr(p, '_ddp_to_ignore', False):
continue
if p.requires_grad:
p.register_hook(partial(self.grad_handle, p))

@@ -55,6 +57,8 @@ def backward(self, loss: torch.Tensor):
loss.backward()
torch.cuda.current_stream().wait_stream(self.comm_stream)
for p in self.module.parameters():
if getattr(p, '_ddp_to_ignore', False):
continue
if p.grad.device.type != "cpu":
p.grad = p._saved_grad

@@ -99,6 +103,25 @@ def zero_grad(self, set_to_none: bool = False) -> None:
p._saved_grad.requires_grad_(False)
p._saved_grad.zero_()

@staticmethod
def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
"""Sets parameters to be ignored by DDP.
This method must be called before initializing ColoDDP.

Example::
>>> params_to_ignore = []
>>> for p in module.parameters():
>>> if should_ignore(p):
>>> params_to_ignore.append(p)
>>> ColoDDP.set_params_to_ignore(params_to_ignore)
>>> module = ColoDDP(module)

Args:
params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
"""
for p in params_to_ignore:
p._ddp_to_ignore = True


class ColoDDPV2(ColoDDP):

@@ -114,6 +137,8 @@ def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
self.chunk_manager.create_group('fp32_param')
# TODO: get param order and filter unused params
for p in module.parameters():
if getattr(p, '_ddp_to_ignore', False):
continue
assert p.dtype == torch.half
fp32_p = p.float().detach()
self.chunk_manager.append_tensor(p, 'fp16_param')
@@ -133,6 +158,8 @@ def forward(self, *args, **kwargs):

def _setup_grads_ptr(self):
for p in self.module.parameters():
if getattr(p, '_ddp_to_ignore', False):
continue
if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad:
p.grad = None
else:
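For context, here is a minimal usage sketch of the new set_params_to_ignore API, adapted from the docstring above. It assumes colossalai.launch(...) has already initialized the distributed environment; the toy model and the choice of ignored parameter are placeholders.

import torch
from colossalai.nn.parallel import ColoDDP

# Toy model: skip the embedding weight, keep the linear layer under DDP.
module = torch.nn.Sequential(torch.nn.Embedding(10, 4), torch.nn.Linear(4, 2))
params_to_ignore = list(module[0].parameters())

# Must be called BEFORE wrapping, so the _ddp_to_ignore flag is already set
# when ColoDDP.__init__ walks module.parameters() and registers gradient hooks.
ColoDDP.set_params_to_ignore(params_to_ignore)
model = ColoDDP(module)

Calling it before wrapping matters because the hooks (and, in ColoDDPV2, the chunk registration) are set up in __init__ and skip anything already flagged.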
2 changes: 2 additions & 0 deletions colossalai/zero/utils/zero_hook_v2.py
@@ -22,6 +22,7 @@ def __init__(self, gemini_manager: GeminiManager) -> None:
self._training_phase = TrainingPhase.FORWARD

def pre_op(self, params):
params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
chunks = self._chunk_manager.get_chunks(params)
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
@@ -33,6 +34,7 @@ def pre_op(self, params):
self._gemini_manager.sample_model_data()

def post_op(self, params):
params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
for p in params:
tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
self._chunk_manager.trans_tensor_state(p, tensor_state)
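The hook changes above rely on the same attribute flag: parameters carrying _ddp_to_ignore are filtered out before the chunk manager is consulted. A self-contained sketch of that filtering pattern (the parameters here are placeholders):

import torch

w1 = torch.nn.Parameter(torch.randn(3, 3))
w2 = torch.nn.Parameter(torch.randn(3, 1))
w2._ddp_to_ignore = True  # the same flag set by ColoDDP.set_params_to_ignore

# getattr with a default keeps any parameter that was never flagged.
tracked = [p for p in (w1, w2) if not getattr(p, '_ddp_to_ignore', False)]
assert tracked == [w1]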
87 changes: 87 additions & 0 deletions tests/test_utils/test_ddp_ignore_params.py
@@ -0,0 +1,87 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ChunkManager
from functools import partial
from colossalai.nn.parallel import ColoDDP, ColoDDPV2
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Callable
import torch.distributed as dist
import os
import random
import numpy as np


def set_seed(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True


def init_ddp(module: torch.nn.Module) -> ColoDDP:
return ColoDDP(module)


def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ColoDDPV2:
chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
chunk_manager = ChunkManager(chunk_size)
gemini_manager = GeminiManager('cuda', chunk_manager)
return ColoDDPV2(module, gemini_manager)


class Net(torch.nn.Module):

def __init__(self) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(3, 3, bias=False)
self.fc2 = torch.nn.Linear(3, 1, bias=False)

def forward(self, x):
return self.fc2(self.fc1(x))


def run_fwd_bwd(ddp_cls: ColoDDP, init_ddp_func: Callable[[torch.nn.Module], ColoDDP]):
with ColoInitContext(device=get_current_device()):
model = Net().cuda()
w1 = model.fc1.weight
w2 = model.fc2.weight
ddp_cls.set_params_to_ignore([w2])
model = init_ddp_func(model)
x = torch.rand(2, 3, device=get_current_device())
logits = model(x)
loss = torch.sum(logits)
model.backward(loss)
w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())]
dist.all_gather(w1_grads, w1.grad)
assert torch.equal(w1_grads[0], w1_grads[1])
w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())]
dist.all_gather(w2_grads, w2.grad)
assert not torch.equal(w2_grads[0], w2_grads[1])


def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
set_seed(dist.get_rank())
run_fwd_bwd(ColoDDP, init_ddp)
run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=False))
run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=True))


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_ddp_ignore_params(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
test_ddp_ignore_params(2)
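A note on the assertions in run_fwd_bwd: w1.grad is expected to match across both ranks because it is still reduced by DDP, while w2.grad is expected to differ because each rank keeps its own local, unreduced gradient (the seed is set per rank, so the ranks see different inputs). On a machine with at least two visible GPUs the test can be run as a script or through pytest, e.g. pytest tests/test_utils/test_ddp_ignore_params.py.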