
Commit db94ff9

fix naming
fix naming fix naming fix
1 parent 80fdc4d commit db94ff9

File tree

4 files changed (+33, -9 lines)


colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 16 additions & 7 deletions
@@ -1,5 +1,6 @@
 import logging
 import warnings
+import enum
 import os
 from functools import partial
 from pathlib import Path
@@ -43,6 +44,11 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 
 SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
 
+class OptimizerParamCheckState(enum.Enum):
+    ORIGIN_PARAM_FINDED = 0
+    ORIGIN_PARAM_NOT_FIND = -1
+    LORA_PARM_EXISTED = -2
+
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(self, module: nn.Module, precision: str) -> None:
@@ -354,15 +360,18 @@ def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
     def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
         origin_param_id = id(origin_param)
         lora_param_id = id(lora_param)
-        target_group_id = -1
+        target_group_id = None
         for group_id, param_group in enumerate(optimizer.param_groups):
             for p in param_group['params']:
                 if id(p) == lora_param_id:
                     # check if the lora parameter exists.
-                    return -2
+                    return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
                 if id(p) == origin_param_id:
                     target_group_id = group_id
-        return target_group_id
+        if target_group_id is not None:
+            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+        else:
+            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND
 
     def add_lora_params_to_optimizer(self, model, optimizer):
         """ add lora parameters to optimizer """
@@ -374,12 +383,12 @@ def add_lora_params_to_optimizer(self, model, optimizer):
             if 'lora_A' in name or 'lora_B' in name:
                 origin_key = name.replace("lora_A.", "")
                 origin_key = origin_key.replace("lora_B.", "")
-                origin_key = origin_key.replace(f"{model.active_adapter}.", "")
+                origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
                 origin_param = name2param[origin_key]
-                group_id = self.get_param_group_id(optimizer, origin_param, param)
-                if group_id == -1:
+                group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
+                if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
                     warnings.warn("Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.")
-                elif group_id >= 0:
+                elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED and group_id is not None and group_id >= 0:
                     optimizer.param_groups[group_id]['params'].append(param)
 
     def configure(

colossalai/pipeline/p2p.py

Lines changed: 15 additions & 1 deletion
@@ -44,6 +44,20 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
 
     return unpickle
 
+def check_for_nccl_backend(group):
+
+    pg = group or c10d._get_default_group()
+    # Gate PG wrapper check on Gloo availability.
+    if c10d._GLOO_AVAILABLE:
+        # It is not expected for PG to be wrapped many times, but support it just
+        # in case
+        while isinstance(pg, c10d._ProcessGroupWrapper):
+            pg = pg.wrapped_pg
+
+    return (
+        c10d.is_nccl_available() and
+        pg.name() == c10d.Backend.NCCL
+    )
 
 def _broadcast_object_list(
     object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
@@ -65,7 +79,7 @@ def _broadcast_object_list(
         c10d._warn_not_in_group("broadcast_object_list")
         return
 
-    is_nccl_backend = c10d._check_for_nccl_backend(group)
+    is_nccl_backend = check_for_nccl_backend(group)
     current_device = None
 
     if device is not None:
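Note: check_for_nccl_backend is a local copy of torch's private c10d._check_for_nccl_backend, so _broadcast_object_list no longer depends on an underscore-prefixed helper that can move between torch releases; it unwraps any _ProcessGroupWrapper layers (only present when Gloo is available) before comparing the backend name against Backend.NCCL. A hedged usage sketch under a single-process Gloo group, where the check should return False (addresses and port are illustrative):

import os

import torch.distributed as dist

from colossalai.pipeline.p2p import check_for_nccl_backend

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

print(check_for_nccl_backend(None))  # False: None falls back to the default (gloo) group
print(check_for_nccl_backend(dist.group.WORLD))  # same group, passed explicitly

dist.destroy_process_group()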

requirements/requirements-test.txt

Lines changed: 1 addition & 1 deletion
@@ -18,5 +18,5 @@ SentencePiece
 ninja
 flash_attn==2.0.5
 datasets
-peft
+peft>=0.7.1
 #auto-gptq now not support torch1.12
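Note: pinning peft to >=0.7.1 lines up with the base_layer rewrite in the plugin diff above: recent peft releases nest the wrapped module under a base_layer attribute, so the origin key is recovered by swapping the adapter name for that attribute (that this layout is what the floor guarantees is my reading, not stated in the commit). A pure-string sketch of the mapping, with a hypothetical parameter name and "default" standing in for model.active_adapter:

# Hypothetical LoRA parameter name as yielded by model.named_parameters().
active_adapter = "default"
name = "base_model.model.fc.lora_A.default.weight"

origin_key = name.replace("lora_A.", "")
origin_key = origin_key.replace("lora_B.", "")
origin_key = origin_key.replace(active_adapter, "base_layer")

assert origin_key == "base_model.model.fc.base_layer.weight"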

requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -14,3 +14,4 @@ einops
 sentencepiece
 google
 protobuf
+peft>=0.7.1
