
Commit af32022

[Gemini] fix the convert_to_torch_module bug (#2269)
1 parent 879df8b commit af32022

File tree: 4 files changed, +48 / -21 lines


colossalai/gemini/gemini_mgr.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ class GeminiManager:
 
     def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
 
-        assert placement_policy in PlacementPolicyFactory.get_polocy_names()
+        assert placement_policy in PlacementPolicyFactory.get_policy_names()
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager

colossalai/gemini/placement_policy.py

Lines changed: 1 addition & 1 deletion
@@ -236,7 +236,7 @@ def create(policy_name: str) -> Type[PlacementPolicy]:
         return PlacementPolicyFactory.policies[policy_name]
 
     @staticmethod
-    def get_polocy_names():
+    def get_policy_names():
         return tuple(PlacementPolicyFactory.policies.keys())
 
     @staticmethod
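
The rename is applied consistently at both the definition above and its call site in GeminiManager.__init__. For context, a minimal standalone sketch of the factory pattern involved; the policy classes and registered names here are simplified stand-ins, not the real colossalai ones:

from typing import Dict, Type


class PlacementPolicy:
    """Stand-in for the real placement-policy base class."""


class CPUPlacementPolicy(PlacementPolicy):
    pass


class CUDAPlacementPolicy(PlacementPolicy):
    pass


class PlacementPolicyFactory:
    # name -> policy class, mirroring the registry kept by the real factory
    policies: Dict[str, Type[PlacementPolicy]] = {
        'cpu': CPUPlacementPolicy,
        'cuda': CUDAPlacementPolicy,
    }

    @staticmethod
    def create(policy_name: str) -> Type[PlacementPolicy]:
        return PlacementPolicyFactory.policies[policy_name]

    @staticmethod
    def get_policy_names():
        return tuple(PlacementPolicyFactory.policies.keys())


# Call-site pattern from GeminiManager.__init__:
placement_policy = 'cpu'
assert placement_policy in PlacementPolicyFactory.get_policy_names()
policy_cls = PlacementPolicyFactory.create(placement_policy)
print(policy_cls.__name__)    # CPUPlacementPolicy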

colossalai/nn/parallel/data_parallel.py

Lines changed: 39 additions & 13 deletions
@@ -360,24 +360,20 @@ def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0:
             destination = hook_result
         return destination
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
-        r"""Saves module state to `destination` dictionary, containing a state
-        of the module, but not its descendants. This is called on every
-        submodule in :meth:`~torch.nn.Module.state_dict`.
-
-        In rare cases, subclasses can achieve class-specific behavior by
-        overriding this method with custom logic.
+    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
+        """
+        get param content from chunks.
 
         Args:
-            destination (dict): a dict where state will be stored
-            prefix (str): the prefix for parameters and buffers used in this
-                module
-        """
-        assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
+            param_list (_type_): a list of torch.nn.Parameters
+            only_rank_0 (_type_): _description_
 
+        Returns:
+            Dict: a dict whose key is param name and value is param with correct payload
+        """
         # save parameters
         param_to_save_data = dict()
-        chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
+        chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
             temp_chunk = get_temp_total_chunk_on_cuda(chunk)
 
@@ -391,7 +387,37 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
                 param_to_save_data[tensor] = record_tensor
 
             del temp_chunk
+        return param_to_save_data
+
+    def torch_named_parameters(self):
+        """
+        get named_parameters() of self.module. It is used the same of PyTorch param and returns the real param.data payload.
+        It works the same as torch.Module named_parameters
+        """
+        params_list = [p for p in self.parameters(recurse=True)]
+        param_to_save_data = self._get_param_to_save_data(params_list, False)
+        for (name, _), p in zip(self.named_parameters(recurse=True), params_list):
+            if p is not None:
+                assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
+                record_parameter = param_to_save_data[p]
+                yield name, record_parameter
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
+        r"""Saves module state to `destination` dictionary, containing a state
+        of the module, but not its descendants. This is called on every
+        submodule in :meth:`~torch.nn.Module.state_dict`.
+
+        In rare cases, subclasses can achieve class-specific behavior by
+        overriding this method with custom logic.
+
+        Args:
+            destination (dict): a dict where state will be stored
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+        """
+        assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
 
+        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
                 assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
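
The fix in the next file relies on the new torch_named_parameters() generator added here. Below is a standalone sketch of the pattern it follows, with plain PyTorch standing in for the chunk machinery: param_to_full_tensor plays the role of the dict that _get_param_to_save_data builds from gathered chunks, and the lookup is keyed by the parameter object itself, as in the real code.

import torch.nn as nn


def torch_named_parameters_sketch(module: nn.Module, param_to_full_tensor: dict):
    # Yield (name, full_payload) pairs in named_parameters() order,
    # mirroring the shape of GeminiDDP.torch_named_parameters().
    params_list = [p for p in module.parameters(recurse=True)]
    for (name, _), p in zip(module.named_parameters(recurse=True), params_list):
        assert p in param_to_full_tensor, "Parameter '{}' is neglected".format(name)
        yield name, param_to_full_tensor[p]


model = nn.Linear(4, 2)
# In Gemini the payload is reassembled from chunks on CUDA; a clone suffices here.
payloads = {p: p.data.clone() for p in model.parameters()}
for name, tensor in torch_named_parameters_sketch(model, payloads):
    print(name, tuple(tensor.shape))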

colossalai/nn/parallel/utils.py

Lines changed: 7 additions & 6 deletions
@@ -2,7 +2,6 @@
 import torch.distributed as dist
 
 from colossalai.gemini.chunk import Chunk
-from colossalai.tensor import ColoTensor
 from colossalai.utils import get_current_device
 
 
@@ -22,6 +21,7 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
     return total_temp
 
 
+# TODO() not work for module where two params share the same tensor.
 def _add_param(model, name, param):
     name_list = name.split('.')
     module = model._modules[name_list[0]]
@@ -30,7 +30,7 @@ def _add_param(model, name, param):
     module._parameters[name_list[-1]] = param
 
 
-def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
+def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module:
     """convert_to_torch_module
 
     Args:
@@ -39,11 +39,12 @@ def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
     Returns:
         torch.nn.Module: a torch model contains the params of gemini_ddp_model
     """
+    from colossalai.nn.parallel import GeminiDDP
+    assert isinstance(gemini_ddp_model, GeminiDDP)
     module = gemini_ddp_model.module
 
-    for n, p in module.named_parameters():
-        if isinstance(p, ColoTensor):
-            p.to_replicate_()
-            _add_param(module, n, p.data)
+    # replace ColoTensor to torch.nn.Tensor in module
+    for n, p in gemini_ddp_model.torch_named_parameters():
+        _add_param(module, n, p)
 
     return module
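
With the generator in place, convert_to_torch_module now pulls each parameter's real payload from the Gemini chunks via gemini_ddp_model.torch_named_parameters() and re-registers it on the wrapped module, rather than calling to_replicate_() on ColoTensor parameters directly. A hedged usage sketch, assuming the distributed environment has already been initialized (e.g. via colossalai.launch) and the model is wrapped in GeminiDDP; the helper function name and checkpoint path below are illustrative, not part of the library:

import torch

from colossalai.nn.parallel import GeminiDDP
from colossalai.nn.parallel.utils import convert_to_torch_module


def export_plain_module(gemini_model: GeminiDDP, path: str = "model.pt") -> None:
    # Assumes colossalai.launch(...) has run and `gemini_model` wraps the target model.
    torch_model = convert_to_torch_module(gemini_model)    # plain torch.nn.Module
    torch.save(torch_model.state_dict(), path)

Note that, per the TODO on _add_param, modules where two parameters share the same tensor are not yet handled.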
