Skip to content
95 changes: 89 additions & 6 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import logging
import warnings
import enum
import os
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Dict, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Dict

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
Expand Down Expand Up @@ -41,6 +44,11 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):

SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]

class OptimizerParamCheckState(enum.Enum):
ORIGIN_PARAM_FINDED = 0
ORIGIN_PARAM_NOT_FIND = -1
LORA_PARM_EXISTED = -2


class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str) -> None:
Expand Down Expand Up @@ -208,6 +216,18 @@ def load_sharded_model(
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()

def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)


class LowLevelZeroPlugin(DPPluginBase):
"""
Expand Down Expand Up @@ -287,6 +307,7 @@ def __init__(
cpu_offload=cpu_offload,
master_weights=master_weights,
)
self.lora_enabled = False
self.verbose = verbose

# set class name with stage, for better error message
Expand All @@ -310,6 +331,66 @@ def control_device(self) -> bool:
def supported_devices(self) -> List[str]:
return ["cuda"]


def support_lora(self) -> bool:
return True

def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
from peft import PeftModel, get_peft_model
assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")

if pretrained_dir is None:
peft_model = get_peft_model(model, lora_config)
else:
peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
return peft_model

def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
origin_param_id = id(origin_param)
for group_id, param_group in enumerate(optimizer.param_groups):
for p in param_group['params']:
if id(p) == origin_param_id:
return group_id
return -1

def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
origin_param_id = id(origin_param)
lora_param_id = id(lora_param)
target_group_id = None
for group_id, param_group in enumerate(optimizer.param_groups):
for p in param_group['params']:
if id(p) == lora_param_id:
# check if the lora parameter exists.
return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
if id(p) == origin_param_id:
target_group_id = group_id
if target_group_id is not None:
return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
else:
return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND

def add_lora_params_to_optimizer(self, model, optimizer):
""" add lora parameters to optimizer """
name2param= {}
for name, param in model.named_parameters():
name2param[name] = param

for name, param in name2param.items():
if 'lora_A' in name or 'lora_B' in name:
origin_key = name.replace("lora_A.", "")
origin_key = origin_key.replace("lora_B.", "")
origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
origin_param = name2param[origin_key]
group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
warnings.warn("Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.")
elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED and group_id is not None and group_id >= 0:
optimizer.param_groups[group_id]['params'].append(param)

def configure(
self,
model: nn.Module,
Expand All @@ -318,6 +399,13 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if self.lora_enabled:
from peft import PeftModel
assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
if optimizer is not None:
self.add_lora_params_to_optimizer(model, optimizer)


if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)

Expand All @@ -339,8 +427,3 @@ def get_checkpoint_io(self) -> CheckpointIO:
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(optimizer, LowLevelZeroOptimizer)
return optimizer.no_sync()

def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
raise NotImplementedError
16 changes: 15 additions & 1 deletion colossalai/pipeline/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -

return unpickle

def check_for_nccl_backend(group):

pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
# It is not expected for PG to be wrapped many times, but support it just
# in case
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg

return (
c10d.is_nccl_available() and
pg.name() == c10d.Backend.NCCL
)

def _broadcast_object_list(
object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
Expand All @@ -65,7 +79,7 @@ def _broadcast_object_list(
c10d._warn_not_in_group("broadcast_object_list")
return

is_nccl_backend = c10d._check_for_nccl_backend(group)
is_nccl_backend = check_for_nccl_backend(group)
current_device = None

if device is not None:
Expand Down
3 changes: 3 additions & 0 deletions colossalai/zero/low_level/bookkeeping/gradient_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def get_working_grads_by_group_id(self, group_id: int) -> List:
"""

grad_list = []
# When using LoRa and the user sets multiple param_groups, it is possible that some param_groups have no parameters with gradients.
if group_id not in self._grads_of_params.keys():
return grad_list
for param_grads in self._grads_of_params[group_id].values():
grad_list.append(param_grads[self._working_index])

Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ SentencePiece
ninja
flash_attn==2.0.5
datasets
peft
peft>=0.7.1
#auto-gptq now not support torch1.12
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ einops
sentencepiece
google
protobuf
peft>=0.7.1
8 changes: 7 additions & 1 deletion tests/test_booster/test_plugin/test_dp_plugin_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Iterator, List, Tuple, Union
from typing import Callable, Iterator, List, Tuple, Union, Dict

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -51,6 +51,12 @@ def supported_precisions(self) -> List[str]:
def no_sync(self, model: nn.Module) -> Iterator[None]:
pass

def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
pass

def support_lora(self) -> bool:
pass


def check_dataloader_sharding():
plugin = DPPluginWrapper()
Expand Down
40 changes: 39 additions & 1 deletion tests/test_booster/test_plugin/test_low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
import torch.distributed as dist
from peft import LoraConfig

import colossalai
from colossalai.booster import Booster
Expand All @@ -18,12 +19,16 @@
_STUCK_MODELS = ["transformers_albert_for_multiple_choice"]


def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]:
try:
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)

if lora_config is not None:
model = booster.enable_lora(model, lora_config=lora_config)

criterion = lambda x: x.mean()
data = data_gen_fn()

Expand All @@ -43,6 +48,8 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:

except Exception as e:
return repr(e)
# raise e



@parameterize("stage", [2])
Expand Down Expand Up @@ -81,10 +88,41 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])


@parameterize("stage", [2])
@parameterize("model_name", ["transformers_llama"])
def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
passed_models = []
failed_info = {} # (model_name, error) pair

sub_model_zoo = model_zoo.get_sub_registry(model_name)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
if name == "transformers_llama_for_casual_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config)

torch.cuda.empty_cache()

if err is None:
passed_models.append(name)
else:
failed_info[name] = err
if early_stop:
break

if dist.get_rank() == 0:
print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])

def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
check_low_level_zero_plugin(early_stop=early_stop)
check_low_level_zero_lora(early_stop=early_stop)


@rerun_if_address_is_in_use()
Expand Down
Loading