Commit d8c0a58

fix
fix fix
1 parent 70e8113 commit d8c0a58

11 files changed: 117 additions, 47 deletions

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 7 additions & 2 deletions
@@ -14,7 +14,12 @@
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+from colossalai.tensor.padded_tensor import (
+    init_as_padded_tensor,
+    is_padded_tensor,
+    to_padded_tensor,
+    to_unpadded_tensor,
+)
 from colossalai.utils import get_current_device
 
 from .general_checkpoint_io import GeneralCheckpointIO

@@ -873,7 +878,7 @@ def gather_from_sharded_optimizer_state(
 
             padding_dim = search_padding_dim(v.shape, original_shape)
             if padding_dim is not None:
-                v = init_as_ptensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
+                v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
                 v = to_unpadded_tensor(v)
 
             state_[k] = v.detach().clone().to(device)
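For orientation: `search_padding_dim(v.shape, original_shape)` reports which dimension of the gathered optimizer state was padded (or `None`), after which the state is re-tagged with `init_as_padded_tensor` and stripped with `to_unpadded_tensor`. The sketch below is a hypothetical re-implementation of such a check, for illustration only; it is not the helper shipped in ColossalAI.

```python
from typing import Optional

import torch


def search_padding_dim(current_shape: torch.Size, origin_shape: torch.Size) -> Optional[int]:
    """Hypothetical sketch: return the first dim whose padded size differs from the original."""
    for dim, (cur, org) in enumerate(zip(current_shape, origin_shape)):
        if cur != org:
            return dim
    return None


print(search_padding_dim(torch.Size([12, 4]), torch.Size([10, 4])))  # 0
print(search_padding_dim(torch.Size([10, 4]), torch.Size([10, 4])))  # None
```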

colossalai/checkpoint_io/utils.py

Lines changed: 4 additions & 7 deletions
@@ -19,7 +19,6 @@
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.tensor.p_tensor.api import init_as_ptensor, is_padded_tensor
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"

@@ -208,13 +207,11 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor:
     """
     param_ = param if keep_vars else param.detach()
     if is_distributed_tensor(param_):
-        param_ = to_global(param_)
+        return to_global(param_)
     elif is_customized_distributed_tensor(param_):
-        param_ = to_global_for_customized_distributed_tensor(param_)
-
-    if is_padded_tensor(param):
-        param_ = init_as_ptensor(param_, param.current_length, param.origin_length, param.padding_dim)
-    return param_
+        return to_global_for_customized_distributed_tensor(param_)
+    else:
+        return param_
 
 
 def save_state_dict_shards(
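The padding handling removed from `gather_distributed_param` does not disappear; as the Gemini diffs below show, callers now gather first and then re-attach the padding metadata themselves before unpadding. A minimal sketch of that caller-side pattern (the helper name `gather_and_unpad` is made up for illustration; the padded-tensor functions are the ones introduced by this commit):

```python
import torch

from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.tensor.padded_tensor import (
    init_as_padded_tensor,
    is_padded_tensor,
    to_unpadded_tensor,
)


def gather_and_unpad(param: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gather a (possibly sharded) param, then strip its padding."""
    gathered = gather_distributed_param(param, keep_vars=False)
    if is_padded_tensor(param):
        # Re-attach the padding metadata carried by the original param, then unpad.
        gathered = init_as_padded_tensor(
            gathered, param._current_length, param._origin_length, param._padding_dim
        )
        gathered = to_unpadded_tensor(gathered)
    return gathered
```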

colossalai/shardformer/layer/parallel_module.py

Lines changed: 2 additions & 3 deletions
@@ -20,7 +20,7 @@
     is_distributed_tensor,
     sharded_tensor_to_param,
 )
-from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
 
 __all__ = ["ParallelModule"]
 

@@ -297,8 +297,7 @@ def _load_from_state_dict(
                 continue
 
             if is_padded_tensor(param):
-                print("is_padded_tensor(param)", is_padded_tensor(param))
-                input_param = to_padded_tensor(input_param, param.current_length, param.padding_dim)
+                input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim)
 
             if is_distributed_tensor(param):
                 # shard the input param

colossalai/tensor/d_tensor/layout_converter.py

Lines changed: 14 additions & 3 deletions
@@ -10,6 +10,7 @@
 from colossalai.tensor.d_tensor.comm_spec import *
 from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.misc import LayoutException
+from colossalai.tensor.padded_tensor.api import init_as_padded_tensor, is_padded_tensor
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
 
 from .sharding_spec import ShardingSpec

@@ -607,8 +608,18 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layout
                 [3.],
                 [3.]])
         """
+
         _, comm_action_sequence = self.layout_converting(source_layout, target_layout)
+
+        target_tensor = tensor
         for comm_spec in comm_action_sequence:
-            tensor = comm_spec.covert_spec_to_action(tensor)
-        tensor.dist_layout = target_layout
-        return tensor
+            target_tensor = comm_spec.covert_spec_to_action(target_tensor)
+        target_tensor.dist_layout = target_layout
+
+        # restore the padding information
+        if is_padded_tensor(tensor) and not is_padded_tensor(target_tensor):
+            target_tensor = init_as_padded_tensor(
+                target_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
+            )
+
+        return target_tensor
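The extra restore step exists because the padding metadata lives in plain Python attributes on the tensor object, and a communication action that returns a new tensor does not carry those attributes over. A standalone illustration of that behaviour (plain PyTorch, not ColossalAI code):

```python
import torch

t = torch.ones(4)
t._padding_dim = 0  # metadata stored as an ordinary instance attribute, as in padded_tensor.api

new_t = t * 2  # any op that builds a new tensor drops the attribute
print(hasattr(t, "_padding_dim"), hasattr(new_t, "_padding_dim"))  # True False
```

Hence `apply` checks `is_padded_tensor(tensor) and not is_padded_tensor(target_tensor)` and re-initializes the metadata on the converted tensor.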

colossalai/tensor/p_tensor/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

colossalai/tensor/padded_tensor/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from .api import init_as_padded_tensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+
+__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_padded_tensor"]
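A short usage sketch of the renamed package, assuming (from the call sites in this commit) that `to_padded_tensor(tensor, current_length, padding_dim)` pads `padding_dim` up to `current_length` and records the original length for later unpadding:

```python
import torch

from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

t = torch.randn(10, 4)

t = to_padded_tensor(t, 12, 0)       # pad dim 0 from 10 up to 12
print(is_padded_tensor(t), t.shape)  # True torch.Size([12, 4])

t = to_unpadded_tensor(t)            # slice back to the recorded original length
print(is_padded_tensor(t), t.shape)  # False torch.Size([10, 4])
```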

colossalai/tensor/p_tensor/api.py renamed to colossalai/tensor/padded_tensor/api.py

Lines changed: 18 additions & 18 deletions
@@ -16,16 +16,16 @@ def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
 
     def new_detach(self):
         t_ = self._unpad_detach()
-        t_.padding_dim = self.padding_dim
-        t_.origin_length = self.origin_length
-        t_.current_length = self.current_length
+        t_._padding_dim = self._padding_dim
+        t_._origin_length = self._origin_length
+        t_._current_length = self._current_length
         return t_
 
     def new_clone(self, *args, **kwargs):
         t_ = self._unpad_clone(*args, **kwargs)
-        t_.padding_dim = self.padding_dim
-        t_.origin_length = self.origin_length
-        t_.current_length = self.current_length
+        t_._padding_dim = self._padding_dim
+        t_._origin_length = self._origin_length
+        t_._current_length = self._current_length
         return t_
 
     # bind the new methods to the tensor

@@ -63,7 +63,7 @@ def is_padded_tensor(tensor: torch.Tensor) -> bool:
     Returns:
         bool: Whether the given tensor is a padding tensor.
     """
-    return hasattr(tensor, "padding_dim")
+    return hasattr(tensor, "_padding_dim")
 
 
 def to_padded_tensor(

@@ -89,9 +89,9 @@ def to_padded_tensor(
     )
     tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()
 
-    setattr(tensor, "padding_dim", padding_dim)
-    setattr(tensor, "origin_length", origin_length)
-    setattr(tensor, "current_length", current_length)
+    tensor._padding_dim = padding_dim
+    tensor._origin_length = origin_length
+    tensor._current_length = current_length
 
     _hijack_detach_and_clone(tensor)
 

@@ -103,25 +103,25 @@ def to_unpadded_tensor(ptensor: torch.Tensor):
         return ptensor
 
     unpad_slices = [slice(None)] * ptensor.dim()
-    unpad_slices[ptensor.padding_dim] = slice(None, ptensor.origin_length)
+    unpad_slices[ptensor._padding_dim] = slice(None, ptensor._origin_length)
     ptensor.data = ptensor.data[tuple(unpad_slices)]
 
-    delattr(ptensor, "padding_dim")
-    delattr(ptensor, "origin_length")
-    delattr(ptensor, "current_length")
+    delattr(ptensor, "_padding_dim")
+    delattr(ptensor, "_origin_length")
+    delattr(ptensor, "_current_length")
 
     _hijack_back_detach_and_clone(ptensor)
 
     return ptensor
 
 
-def init_as_ptensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
+def init_as_padded_tensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
     if is_padded_tensor(tensor):
         return tensor
 
-    setattr(tensor, "padding_dim", padding_dim)
-    setattr(tensor, "origin_length", origin_length)
-    setattr(tensor, "current_length", current_length)
+    tensor._padding_dim = padding_dim
+    tensor._origin_length = origin_length
+    tensor._current_length = current_length
 
    _hijack_detach_and_clone(tensor)
 
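The `_hijack_detach_and_clone` changes above keep the same pattern as before, just with underscore-prefixed attribute names. The general idea, in a standalone sketch (my own minimal version, not the library code): stash the original bound method under another name, then bind a wrapper that copies the private metadata onto the result.

```python
import types

import torch


def hijack_detach(t: torch.Tensor) -> torch.Tensor:
    """Minimal sketch of the detach-hijack pattern (see _hijack_detach_and_clone above)."""
    t._unpad_detach = t.detach  # keep the original bound method around

    def new_detach(self):
        out = self._unpad_detach()
        # Only one attribute shown here; the library copies _padding_dim,
        # _origin_length and _current_length.
        out._padding_dim = self._padding_dim
        return out

    t.detach = types.MethodType(new_detach, t)  # bind the wrapper to this tensor only
    return t


x = torch.randn(3)
x._padding_dim = 0
hijack_detach(x)
print(x.detach()._padding_dim)  # 0 -- the metadata now survives detach()
```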

colossalai/zero/gemini/gemini_ddp.py

Lines changed: 9 additions & 4 deletions
@@ -27,7 +27,12 @@
     is_customized_distributed_tensor,
     is_distributed_tensor,
 )
-from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+from colossalai.tensor.padded_tensor import (
+    init_as_padded_tensor,
+    is_padded_tensor,
+    to_padded_tensor,
+    to_unpadded_tensor,
+)
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
 

@@ -462,8 +467,8 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
                 )
                 record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
                 if is_padded_tensor(tensor):
-                    record_tensor = init_as_ptensor(
-                        record_tensor, tensor.current_length, tensor.origin_length, tensor.padding_dim
+                    record_tensor = init_as_padded_tensor(
+                        record_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
                     )
                 record_tensor = to_unpadded_tensor(record_tensor)
 

@@ -661,7 +666,7 @@ def load(
             global_shape = get_global_shape(dest_tensor)
 
             if is_padded_tensor(dest_tensor):
-                padding_dim = dest_tensor.padding_dim
+                padding_dim = dest_tensor._padding_dim
                 input_param = to_padded_tensor(input_param, global_shape[padding_dim], padding_dim)
 
             if source_device_mesh is not None and source_sharding_spec is not None:

colossalai/zero/gemini/gemini_optimizer.py

Lines changed: 11 additions & 6 deletions
@@ -28,7 +28,12 @@
     is_customized_distributed_tensor,
     is_distributed_tensor,
 )
-from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+from colossalai.tensor.padded_tensor import (
+    init_as_padded_tensor,
+    is_padded_tensor,
+    to_padded_tensor,
+    to_unpadded_tensor,
+)
 from colossalai.utils import disposable, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager

@@ -495,8 +500,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
                 state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
                 state_tensor = state_tensor.reshape(global_shape)
                 if is_padded_tensor(param):
-                    state_tensor = init_as_ptensor(
-                        state_tensor, param.current_length, param.origin_length, param.padding_dim
+                    state_tensor = init_as_padded_tensor(
+                        state_tensor, param._current_length, param._origin_length, param._padding_dim
                     )
                 state_tensor = to_unpadded_tensor(state_tensor)
                 collected_states[state_name] = state_tensor

@@ -555,8 +560,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
                 )
                 state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
                 if is_padded_tensor(param):
-                    state_tensor = init_as_ptensor(
-                        state_tensor, param.current_length, param.origin_length, param.padding_dim
+                    state_tensor = init_as_padded_tensor(
+                        state_tensor, param._current_length, param._origin_length, param._padding_dim
                     )
                 state_tensor = to_unpadded_tensor(state_tensor)
 

@@ -732,7 +737,7 @@ def cast(param, state_range, value, global_shape, origin_shape, key=None):
 
             if is_padded_tensor(real_param):
                 value = torch.reshape(value, origin_shape)
-                padding_dim = real_param.padding_dim
+                padding_dim = real_param._padding_dim
                 value = to_padded_tensor(value, global_shape[padding_dim], padding_dim)
 
             if is_dtensor:
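On the load path the direction is reversed: the saved, unpadded value is reshaped to its original shape and then padded back out to the runtime global shape before being scattered. A small sketch of that step, under the same assumed `to_padded_tensor(tensor, current_length, padding_dim)` signature as above:

```python
import torch

from colossalai.tensor.padded_tensor import to_padded_tensor

saved_state = torch.randn(10, 4)    # unpadded state as stored in the checkpoint
global_shape = torch.Size([12, 4])  # runtime shape, with dim 0 padded to 12
padding_dim = 0

value = to_padded_tensor(saved_state, global_shape[padding_dim], padding_dim)
print(value.shape)                  # torch.Size([12, 4])
```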

tests/test_shardformer/test_model/_utils.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 from colossalai.shardformer._utils import getattr_
 from colossalai.shardformer.policies.auto_policy import Policy
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-from colossalai.tensor.p_tensor.api import is_padded_tensor, to_unpadded_tensor
+from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor
 
 
 def build_model(
