Commit 80fdc4d

Commit message (a combination of 3 commits): Update low_level_zero_plugin.py; Update low_level_zero_plugin.py; fix; fix; fix
1 parent 1b33dcf commit 80fdc4d

File tree: 3 files changed (+33, -17 lines)

.github/workflows/build_on_pr.yml

Lines changed: 1 addition & 1 deletion

@@ -208,7 +208,7 @@ jobs:
 
       - name: Execute Unit Testing
        run: |
-          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/test_booster/test_plugin/test_3d_plugin.py
+          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/
        env:
          DATA: /data/scratch/cifar-10
          NCCL_SHM_DISABLE: 1
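
For context on the filter used above: `-m "not largedist"` deselects any test carrying the `largedist` marker, so widening the target from a single file to `tests/` still skips the heavyweight distributed cases. A minimal sketch of how a test would opt out under this filter (the marker name comes from the command above; where the repository registers it, e.g. in a pytest config file, is an assumption):

import pytest


@pytest.mark.largedist  # deselected by: pytest -m "not largedist"
def test_heavy_multi_node_case():
    ...


def test_light_case():
    # unmarked tests are still collected and run under the filter
    assert True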

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 25 additions & 15 deletions

@@ -351,26 +351,35 @@ def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
                 return group_id
         return -1
 
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
+        origin_param_id = id(origin_param)
+        lora_param_id = id(lora_param)
+        target_group_id = -1
+        for group_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group['params']:
+                if id(p) == lora_param_id:
+                    # the LoRA parameter is already registered in the optimizer
+                    return -2
+                if id(p) == origin_param_id:
+                    target_group_id = group_id
+        return target_group_id
+
     def add_lora_params_to_optimizer(self, model, optimizer):
         """ add lora parameters to optimizer """
         name2param= {}
         for name, param in model.named_parameters():
             name2param[name] = param
 
-        optimizer_param_nums = 0
-        for param_group in optimizer.param_groups:
-            optimizer_param_nums += len(param_group['params'])
-
-        # Check if the optimizer is created after the model is transformed into a LoRa model.
-        if len(name2param) != optimizer_param_nums:
-            for name, param in name2param.items():
-                if 'lora_A' in name or 'lora_B' in name:
-                    origin_key = name.replace("lora_A.", "")
-                    origin_key = origin_key.replace("lora_B.", "")
-                    origin_key = origin_key.replace(f"{model.active_adapter}.", "")
-                    origin_param = name2param[origin_key]
-                    group_id = self.get_param_group_id(optimizer, origin_param)
-                    assert group_id != -1, "Parameter error, origin parameter does't exists."
+        for name, param in name2param.items():
+            if 'lora_A' in name or 'lora_B' in name:
+                origin_key = name.replace("lora_A.", "")
+                origin_key = origin_key.replace("lora_B.", "")
+                origin_key = origin_key.replace(f"{model.active_adapter}.", "")
+                origin_param = name2param[origin_key]
+                group_id = self.get_param_group_id(optimizer, origin_param, param)
+                if group_id == -1:
+                    warnings.warn(f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.")
+                elif group_id >= 0:
                     optimizer.param_groups[group_id]['params'].append(param)
 
     def configure(
@@ -384,7 +393,8 @@ def configure(
         if self.lora_enabled:
             from peft import PeftModel
             assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
-            self.add_lora_params_to_optimizer(model, optimizer)
+            if optimizer is not None:
+                self.add_lora_params_to_optimizer(model, optimizer)
 
 
         if not isinstance(model, ModelWrapper):
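
The core of the change above is the mapping from a LoRA parameter name back to its original parameter, followed by an identity-based search over optimizer.param_groups. Below is a minimal, self-contained sketch of that flow, with hand-written names standing in for what peft produces (the "default" adapter name and the toy module path are assumptions, not taken from the repository):

import torch
from torch.optim import SGD

active_adapter = "default"  # peft's default adapter name (assumed here)

# Pretend the optimizer was created before LoRA was injected, so it only
# knows about the original weight, not the LoRA factor.
origin_param = torch.nn.Parameter(torch.zeros(4, 4))
lora_A = torch.nn.Parameter(torch.zeros(2, 4))

name2param = {
    "model.fc.weight": origin_param,
    "model.fc.lora_A.default.weight": lora_A,
}
optimizer = SGD([origin_param], lr=0.1)

for name, param in name2param.items():
    if "lora_A" in name or "lora_B" in name:
        # Same string surgery as add_lora_params_to_optimizer: strip the
        # LoRA markers and the adapter name to recover the original key.
        origin_key = name.replace("lora_A.", "").replace("lora_B.", "")
        origin_key = origin_key.replace(f"{active_adapter}.", "")
        assert origin_key == "model.fc.weight"

        # Identity-based lookup mirroring get_param_group_id: -2 means the
        # LoRA param is already registered, -1 means the origin param was
        # never found, and >= 0 is the group to extend.
        group_id = -1
        for gid, group in enumerate(optimizer.param_groups):
            for p in group["params"]:
                if p is lora_A:
                    group_id = -2
                elif p is origin_param:
                    group_id = gid
        if group_id >= 0:
            optimizer.param_groups[group_id]["params"].append(param)

assert any(p is lora_A for p in optimizer.param_groups[0]["params"])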

tests/test_booster/test_plugin/test_dp_plugin_base.py

Lines changed: 7 additions & 1 deletion

@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union, Dict
 
 import torch
 import torch.distributed as dist
@@ -51,6 +51,12 @@ def supported_precisions(self) -> List[str]:
     def no_sync(self, model: nn.Module) -> Iterator[None]:
         pass
 
+    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
+        pass
+
+    def support_lora(self) -> bool:
+        pass
+
 
 def check_dataloader_sharding():
     plugin = DPPluginWrapper()
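
The two pass-bodied methods added to DPPluginWrapper keep the wrapper instantiable, presumably because the plugin interface it implements now declares LoRA hooks. A generic sketch of the underlying Python rule, using toy class names rather than the repository's real DPPluginBase hierarchy:

from abc import ABC, abstractmethod


class ToyPlugin(ABC):
    @abstractmethod
    def no_sync(self): ...

    @abstractmethod
    def support_lora(self) -> bool: ...  # newly added to the interface


class ToyWrapper(ToyPlugin):
    def no_sync(self):
        pass

    # Without this stub, ToyWrapper() below raises
    # TypeError: Can't instantiate abstract class ToyWrapper ...
    def support_lora(self) -> bool:
        pass


plugin = ToyWrapper()  # mirrors `plugin = DPPluginWrapper()` in the test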
