6 changes: 3 additions & 3 deletions colossalai/zero/init_ctx/init_context.py
@@ -215,7 +215,7 @@ def _post_context_exec(self):
assert hasattr(param, 'colo_attr')
if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
param.colo_attr.remove_torch_payload()
param.colo_attr.set_data_none()

del self.param_list

@@ -252,11 +252,11 @@ def half_fn(t: torch.Tensor):
if param.grad is not None:
param.grad = param.grad.to(target_device)

param.colo_attr = ShardedParamV2(param, rm_torch_payload=False)
param.colo_attr = ShardedParamV2(param, set_data_none=False)

if self.shard_param:
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
param.data = param.colo_attr.sharded_data_tensor.payload # set param.data to payload
param.data = param.colo_attr.data_payload # set param.data to payload

# mark whether the param is replicated
param.colo_attr.is_replicated = self.is_replicated
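Taken together, the two hunks above switch the ZeRO init context from the old rm_torch_payload / sharded_data_tensor.payload spellings to the renamed accessor API. A minimal sketch of the new call pattern, assuming a ColossalAI build that contains this PR; shard_strategy and dp_process_group are placeholders normally supplied by the context:

```python
import torch
from colossalai.zero.sharded_param import ShardedParamV2

def attach_zero_attr(param: torch.nn.Parameter, shard_strategy, dp_process_group, shard_param: bool):
    # Keep param.data alive for now: weight init (e.g. kaiming_normal_) may still need it.
    param.colo_attr = ShardedParamV2(param, set_data_none=False)   # was rm_torch_payload=False

    if shard_param:
        shard_strategy.shard([param.colo_attr.sharded_data_tensor], dp_process_group)
    # Point param.data at the (possibly sharded) fp16 payload.
    param.data = param.colo_attr.data_payload                      # was sharded_data_tensor.payload
```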
14 changes: 7 additions & 7 deletions colossalai/zero/sharded_model/sharded_model_v2.py
@@ -260,7 +260,7 @@ def _post_backward_operations(self) -> None:
if not p.colo_attr.param_is_sharded:
tensor_list.append(p.colo_attr.sharded_data_tensor)
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
p.colo_attr.remove_torch_payload()
p.colo_attr.set_data_none()
self.shard_strategy.shard(tensor_list, self.process_group)

# 4. set all parameters' grad to None
@@ -357,8 +357,8 @@ def _save_grad(self, param: Parameter, grad: torch.Tensor):
assert param.colo_attr.saved_grad.is_null(
), 'Gradient accumulation is not supported when reuse_fp16_shard=True'

param.colo_attr.saved_grad.reset_payload(grad)
param.colo_attr.sharded_data_tensor.reset_payload(grad) # release the memory of param
param.colo_attr.reset_grad_payload(grad)
param.colo_attr.reset_grad_payload(grad) # release the memory of param

if param.colo_attr.is_replicated:
param.colo_attr.sharded_data_tensor.is_sharded = True
@@ -367,21 +367,21 @@ def _save_grad(self, param: Parameter, grad: torch.Tensor):
fp32_grad = cast_tensor_to_fp32(grad)

if param.colo_attr.saved_grad.is_null():
param.colo_attr.saved_grad.reset_payload(fp32_grad)
param.colo_attr.reset_grad_payload(fp32_grad)
else:
param.colo_attr.saved_grad.payload.add_(fp32_grad.view_as(param.colo_attr.saved_grad.payload))
param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))

# keep saved_grad in HOLD state
param.colo_attr.saved_grad.trans_state(TensorState.HOLD)

def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
for p in self.sharded_params:
p.data = p.colo_attr.sharded_data_tensor.payload
p.data = p.colo_attr.data_payload
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
for p in self.sharded_params:
p.colo_attr.remove_torch_payload()
p.colo_attr.set_data_none()
return gathered_state_dict

def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
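The state_dict hunk above keeps the existing gather/re-shard protocol and only swaps in the new accessors. A condensed, hedged restatement of that flow (attribute names taken from this diff; not a drop-in replacement for the method):

```python
def gathered_state_dict(model):
    tensors = [p.colo_attr.sharded_data_tensor for p in model.sharded_params]
    model.shard_strategy.gather(tensors, model.process_group)   # materialize full fp16 weights
    for p in model.sharded_params:
        p.data = p.colo_attr.data_payload                       # expose the payload via param.data
    state = model.module.state_dict()
    model.shard_strategy.shard(tensors, model.process_group)    # shard again to free memory
    for p in model.sharded_params:
        p.colo_attr.set_data_none()                             # drop the stale full tensor from param.data
    return state
```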
2 changes: 1 addition & 1 deletion colossalai/zero/sharded_model/utils.py
@@ -14,6 +14,6 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
if shard_flag:
sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
param.data = copy.deepcopy(zero_param.colo_attr.data_payload)
if shard_flag:
sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
15 changes: 7 additions & 8 deletions colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -266,8 +266,7 @@ def _register_master_weight(self):
if shard_flag:
# we always shard replicated parameters
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p] = StatefulTensor(
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
self.master_params[p] = StatefulTensor(cast_tensor_to_fp32(p.colo_attr.data_payload.to(self.device)))
if shard_flag:
# In this branch, there's no need to shard param
# So we gather here
@@ -296,10 +295,10 @@ def _prepare_grads(self):
# If we change p.grad directly
# it may raise error because of different shape/dtype/device of p.data and p.grad
# We just set p.data = p.colo_attr.saved_grad.payload here
p.data = p.colo_attr.saved_grad.payload
p.grad = p.colo_attr.saved_grad.payload
p.data = p.colo_attr.grad_payload
p.grad = p.colo_attr.grad_payload
# Set p.data to empty tensor, in case of memory leaking
p.colo_attr.remove_torch_payload()
p.colo_attr.set_data_none()

def _point_param_fp16_to_master_param(self):
# assign master param pointers to p.data.
@@ -325,9 +324,9 @@ def _copy_master_param_to_param_fp16(self, p):

# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.data = self.master_params[p].payload
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.remove_torch_payload()
p.colo_attr.reset_data_payload(
colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.set_data_none()

if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
# We gather full fp16 param here
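The change in _copy_master_param_to_param_fp16 adds a detach() because the new reset_data_payload (defined in sharded_param.py below) asserts that the incoming tensor has requires_grad == False. A minimal sketch of that write-back step under the same assumption; clone_fn stands in for the colo_model_tensor_clone helper already used by this file, and master_payload is the fp32 master shard:

```python
def copy_master_to_fp16(p, master_payload, clone_fn):
    # Stage the fp32 master weight through param.data, then cast to fp16 and detach it
    # so the new payload carries no autograd history (reset_data_payload asserts this).
    p.data = master_payload
    fp16_shard = clone_fn(p.half().detach(), p.colo_attr.sharded_data_tensor.device)
    p.colo_attr.reset_data_payload(fp16_shard)
    p.colo_attr.set_data_none()   # param.data falls back to the shared empty placeholder
```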
51 changes: 42 additions & 9 deletions colossalai/zero/sharded_param/sharded_param.py
@@ -10,10 +10,20 @@
# empty tensor is expected to raise error when get used
FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')

EMPTY_TENSOR_DICT = {}


def get_empty_tensor(device: torch.device, dtype: torch.dtype):
key = (device, dtype)
if key not in EMPTY_TENSOR_DICT:
EMPTY_TENSOR_DICT[key] = FAKE_EMPTY_TENSOR.to(device, dtype)

return EMPTY_TENSOR_DICT[key]


class ShardedParamV2(object):

def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
# This attribute must be initialized in ShardedModel
@@ -25,24 +35,47 @@ def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# So we can not empty the .data at this time
self.param = param
if rm_torch_payload:
self.remove_torch_payload()
if set_data_none:
self.set_data_none()

def get_payload_tensors(self) -> List[StatefulTensor]:
"""returns stateful tensors kept by this class.
"""
return [self._sharded_data_tensor]

def remove_torch_payload(self):
self.param.data = FAKE_EMPTY_TENSOR.to(self._sharded_data_tensor.device, self._sharded_data_tensor.dtype)
def set_data_none(self):
self.param.data = get_empty_tensor(self.sharded_data_tensor.device, self.sharded_data_tensor.dtype)

def set_grad_none(self):
self.saved_grad.set_null()

@property
def sharded_data_tensor(self):
return self._sharded_data_tensor

@property
def data_payload(self):
return self.sharded_data_tensor.payload

@property
def grad_payload(self):
assert not self.saved_grad.is_null()
return self.saved_grad.payload

@property
def param_is_sharded(self):
return self._sharded_data_tensor.is_sharded
return self.sharded_data_tensor.is_sharded

def reset_data_payload(self, tensor: torch.Tensor):
assert type(tensor) is torch.Tensor
assert tensor.requires_grad is False
self.sharded_data_tensor.reset_payload(tensor)
self.set_data_none()

def reset_grad_payload(self, tensor: torch.Tensor):
assert type(tensor) is torch.Tensor
assert tensor.requires_grad is False
self.saved_grad.reset_payload(tensor)

def get_memory_usage(self) -> Tuple[int, int]:
"""
@@ -63,11 +96,11 @@ def _update_mem_use(t: Optional[torch.Tensor]):
cpu_mem_use += t_cpu

address_set = set()
_update_mem_use(self.sharded_data_tensor.payload)
address_set.add(self.sharded_data_tensor.payload.data_ptr())
_update_mem_use(self.data_payload)
address_set.add(self.data_payload.data_ptr())

if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
_update_mem_use(self.saved_grad.payload)
_update_mem_use(self.grad_payload)
address_set.add(self.saved_grad.data_ptr())

if self.param.data is not None and self.param.data.data_ptr() not in address_set:
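The hunks above are the heart of the PR: ShardedParamV2 now exposes data_payload / grad_payload properties, reset_data_payload / reset_grad_payload setters, and set_data_none / set_grad_none, with the empty placeholder tensors cached per (device, dtype). A short single-process usage sketch, assuming ShardedParamV2 is importable from colossalai.zero.sharded_param in a build that includes this change:

```python
import torch
from colossalai.zero.sharded_param import ShardedParamV2

param = torch.nn.Parameter(torch.randn(4, 4))
attr = ShardedParamV2(param, set_data_none=False)   # was rm_torch_payload=False

assert attr.data_payload.shape == (4, 4)             # was sharded_data_tensor.payload
attr.reset_grad_payload(torch.zeros(16))             # was saved_grad.reset_payload(...)
assert attr.grad_payload.numel() == 16               # grad_payload asserts saved_grad is not null

attr.set_data_none()                                  # was remove_torch_payload()
assert param.data.numel() == 0                        # param.data now points at a cached empty tensor
attr.set_grad_none()
```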
6 changes: 6 additions & 0 deletions colossalai/zero/sharded_param/sharded_tensor.py
@@ -9,6 +9,7 @@ def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD)
r"""
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
"""
assert tensor.requires_grad is False
super().__init__(tensor, state)

# kept the shape, numel and dtype of the init tensor.
@@ -17,6 +18,11 @@ def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD)
self._origin_dtype = tensor.dtype
self._is_sharded = False

@property
def dtype(self) -> torch.dtype:
assert self._payload.dtype == self._origin_dtype
return self._payload.dtype

@property
def origin_numel(self) -> int:
return self._origin_numel
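The new assertion and dtype property tighten the ShardedTensor contract: the wrapped tensor must be detached, and the payload's dtype must match the dtype recorded at construction. A small illustration under the same import assumption as the previous sketch:

```python
import torch
from colossalai.zero.sharded_param import ShardedTensor

t = ShardedTensor(torch.zeros(8))        # fine: plain tensors have requires_grad == False
assert t.dtype == torch.float32          # dtype is cross-checked against the dtype recorded at init

p = torch.nn.Parameter(torch.zeros(8))
# ShardedTensor(p) would now trip the new assert; wrap the detached storage instead:
t2 = ShardedTensor(p.data)
```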
21 changes: 10 additions & 11 deletions colossalai/zero/sharded_param/tensorful_state.py
@@ -19,11 +19,11 @@ class StatefulTensor(object):
https://arxiv.org/abs/2108.05818
"""

def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
self._state = state
self._payload = tensor
if self._state == TensorState.FREE:
assert self._payload is None, f"payload has to None if {self._state}"
assert self._payload is None, f"payload has to None if state is {self._state}"

def data_ptr(self):
if self._payload is None:
@@ -50,13 +50,13 @@ def trans_state(self, state: TensorState) -> None:
self._payload = None

@property
def payload(self) -> int:
def payload(self) -> Optional[torch.Tensor]:
return self._payload

def copy_payload(self, tensor) -> int:
def copy_payload(self, tensor) -> None:
self._payload.view(-1).copy_(tensor.view(-1))

def reset_payload(self, tensor) -> int:
def reset_payload(self, tensor) -> None:
del self._payload
self._payload = tensor
self.trans_state(TensorState.HOLD)
@@ -67,15 +67,14 @@ def device(self) -> torch.device:

@property
def dtype(self) -> torch.dtype:
assert self._payload.dtype == self._origin_dtype
return self._origin_dtype
return self._payload.dtype

@property
def shape(self):
return self._payload.shape

def to(self, device: torch.device):
raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")

def to_(self, device: torch.device):
raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")

@property
def shape(self):
return self._payload.shape
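The StatefulTensor edits above only correct type hints and the assert message; behaviour is unchanged. A small lifecycle sketch of the methods touched here, assuming StatefulTensor and TensorState live in colossalai.zero.sharded_param.tensorful_state as the file path suggests:

```python
import torch
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState

st = StatefulTensor(None, TensorState.FREE)   # FREE requires the payload to be None
assert st.is_null()

st.reset_payload(torch.ones(4))               # installs a payload and moves the state to HOLD
assert st.payload is not None and st.payload.sum().item() == 4.0

st.trans_state(TensorState.FREE)              # transitioning back to FREE drops the payload
assert st.is_null()
```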
8 changes: 4 additions & 4 deletions colossalai/zero/utils/zero_hook.py
@@ -60,7 +60,7 @@ def pre_fwd_exec(self, module: torch.nn.Module, *args):
self._memstarts_collector.sample_memstats()

for param in module.parameters(recurse=False):
param.data = param.colo_attr.sharded_data_tensor.payload
param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"

def post_fwd_exec(self, module: torch.nn.Module, *args):
@@ -79,7 +79,7 @@ def post_fwd_exec(self, module: torch.nn.Module, *args):

# remove torch payload
for param in module.parameters(recurse=False):
param.colo_attr.remove_torch_payload()
param.colo_attr.set_data_none()

def pre_bwd_exec(self, module: torch.nn.Module, input, output):

@@ -105,7 +105,7 @@ def pre_bwd_exec(self, module: torch.nn.Module, input, output):
self._memstarts_collector.sample_memstats()

for param in module.parameters(recurse=False):
param.data = param.colo_attr.sharded_data_tensor.payload
param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"

def post_bwd_exec(self, module: torch.nn.Module, input):
@@ -124,7 +124,7 @@ def post_bwd_exec(self, module: torch.nn.Module, input):

# remove torch payload
for param in module.parameters(recurse=False):
param.colo_attr.remove_torch_payload()
param.colo_attr.set_data_none()

def pre_iter(self):
pass
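All four hooks above follow the same materialize/release pattern with the renamed helpers. A condensed restatement (gather/shard calls and memory-stats bookkeeping elided; parameters are assumed to already carry colo_attr):

```python
def materialize_params(module):
    # Before forward/backward of this submodule: expose the gathered fp16 payloads.
    for p in module.parameters(recurse=False):
        p.data = p.colo_attr.data_payload
        assert p.data.device.type == 'cuda'

def release_params(module):
    # After forward/backward of this submodule: swap param.data back to the empty placeholder.
    for p in module.parameters(recurse=False):
        p.colo_attr.set_data_none()
```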
6 changes: 3 additions & 3 deletions tests/test_moe/test_moe_zero_init.py
@@ -77,10 +77,10 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
assert param.colo_attr.is_replicated

if param.colo_attr.param_is_sharded:
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
assert param.colo_attr.data_payload.device.type == init_device.type, \
f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}'
else:
assert param.colo_attr.sharded_data_tensor.payload.device.type == 'cuda'
assert param.colo_attr.data_payload.device.type == 'cuda'


def _run_dist(rank, world_size, port):
2 changes: 1 addition & 1 deletion tests/test_moe/test_moe_zero_model.py
@@ -37,7 +37,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)
assert_equal_in_group(p.colo_attr.data_payload)

model = MoeModel(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
4 changes: 2 additions & 2 deletions tests/test_moe/test_moe_zero_optim.py
@@ -76,7 +76,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))
assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device()))

model = MoeModel(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
@@ -100,7 +100,7 @@ for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
if 'gate' in n:
p.data = p.float()
p.data.copy_(zp.colo_attr.sharded_data_tensor.payload)
p.data.copy_(zp.colo_attr.data_payload)

for i, (data, label) in enumerate(train_dataloader):
if i > 5:
8 changes: 4 additions & 4 deletions tests/test_zero/common.py
@@ -94,15 +94,15 @@ def check_grads_padding(model, zero_model, loose=False):
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
# zero_grad = zero_p.grad.clone().to(p.device)
if zero_p.colo_attr.is_replicated:
zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
zero_grad = zero_p.colo_attr.grad_payload.clone().to(p.device)
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
grad = chunks[rank].float()
if zero_grad.size(0) > grad.size(0):
zero_grad = zero_grad[:grad.size(0)]
else:
zero_grad = zero_p.colo_attr.saved_grad.payload
zero_grad = zero_p.colo_attr.grad_payload
grad = p.grad.to(zero_grad.dtype)

assert grad.dtype == zero_grad.dtype
@@ -127,15 +127,15 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
rank = dist.get_rank()
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
if zero_p.colo_attr.param_is_sharded:
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
zero_p = zero_p.colo_attr.data_payload.to(p.device).float()
chunks = torch.flatten(p).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
p = chunks[rank].float()
if zero_p.size(0) > p.size(0):
zero_p = zero_p[:p.size(0)]
else:
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device)
zero_p = zero_p.colo_attr.data_payload.to(p.device)

assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
2 changes: 1 addition & 1 deletion tests/test_zero/test_found_inf.py
@@ -55,7 +55,7 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
data, label = data.cuda(), label.cuda()
_run_step(zero_model, sharded_optim, data, label, criterion, False)
for param in zero_model.parameters():
assert not has_inf_or_nan(param.colo_attr.sharded_data_tensor.payload)
assert not has_inf_or_nan(param.colo_attr.data_payload)


def _run_dist(rank, world_size, port):