Commit 8a8c0ba

fix
fix fix fix fix
1 parent 4997393 commit 8a8c0ba

3 files changed: +33 −7 lines changed

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 31 additions & 6 deletions
@@ -1,4 +1,5 @@
 import logging
+import warnings
 import os
 from functools import partial
 from pathlib import Path
@@ -9,6 +10,7 @@
 
 import torch
 import torch.nn as nn
+from torch.nn import Parameter
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
@@ -335,13 +337,27 @@ def enable_lora(
         from peft import PeftModel, get_peft_model
         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
+        warnings.warn("You have enabled LoRA training. Please check hyperparameters such as the learning rate.")
 
         if pretrained_dir is None:
             peft_model = get_peft_model(model, lora_config)
         else:
             peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
         return peft_model
 
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, add_param: Parameter):
+        origin_param_id = id(origin_param)
+        add_param_id = id(add_param)
+        group_id = -1
+        for pg_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group['params']:
+                if id(p) == add_param_id:
+                    return -2
+                if id(p) == origin_param_id:
+                    group_id = pg_id
+        return group_id
+
+
     def configure(
         self,
         model: nn.Module,
@@ -353,12 +369,21 @@ def configure(
         if self.lora_enabled:
             from peft import PeftModel
             assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
-
-            optim_params_nums = 0
-            for param_group in optimizer.param_groups:
-                optim_params_nums += len(param_group['params'])
-            model_params_nums = len(list(model.named_parameters()))
-            assert optim_params_nums == model_params_nums, "Optimizer should be initialized after enabling lora."
+
+            # add lora parameters to optimizer
+            name2param = {}
+            for name, param in model.named_parameters():
+                name2param[name] = param
+            for name, param in name2param.items():
+                if 'lora_A' in name or 'lora_B' in name:
+                    origin_key = name.replace("lora_A.", "")
+                    origin_key = origin_key.replace("lora_B.", "")
+                    origin_key = origin_key.replace(f"{model.active_adapter}.", "")
+                    origin_param = name2param[origin_key]
+                    group_id = self.get_param_group_id(optimizer, origin_param, param)
+                    assert group_id != -1, "Parameter error: the origin parameter does not exist."
+                    if group_id >= 0:
+                        optimizer.param_groups[group_id]['params'].append(param)
 
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)
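
For readers skimming the diff, here is a minimal standalone sketch (not part of the commit) of the contract the new get_param_group_id helper follows: it returns -2 when the LoRA parameter is already registered with the optimizer, -1 when the original parameter cannot be found in any group, and otherwise the index of the group that holds the original parameter. The toy module, the SGD optimizer, and the names base / lora_delta below are illustrative assumptions, not code from the repository.

# Illustrative sketch only: mimics the helper's return convention with a toy model.
import torch
import torch.nn as nn
from torch.optim import SGD

def get_param_group_id(optimizer, origin_param, add_param):
    # Same logic as the new plugin method, kept standalone for illustration.
    group_id = -1
    for pg_id, param_group in enumerate(optimizer.param_groups):
        for p in param_group["params"]:
            if id(p) == id(add_param):
                return -2  # the new parameter is already tracked by the optimizer
            if id(p) == id(origin_param):
                group_id = pg_id  # remember which group holds the base weight
    return group_id

base = nn.Linear(4, 4)                        # stands in for a frozen base layer
lora_delta = nn.Parameter(torch.zeros(4, 4))  # stands in for a lora_A / lora_B weight
opt = SGD(base.parameters(), lr=1e-3)

gid = get_param_group_id(opt, base.weight, lora_delta)
if gid >= 0:
    # register the LoRA weight in the same group as its base weight
    opt.param_groups[gid]["params"].append(lora_delta)

Appending the LoRA weight to the same group means it inherits that group's learning rate and weight decay, which is presumably why the commit also adds a warning to double-check hyperparameters such as the learning rate.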
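The registration loop in configure recovers each LoRA parameter's base weight by stripping the lora_A. / lora_B. segment and the active adapter name from its dotted parameter name. Below is a hedged sketch of just that string mapping; the parameter name is made up to look like PEFT-style naming, and the adapter name "default" is an assumption rather than something taken from the commit.

# Illustrative sketch only: how a PEFT-style parameter name maps back to its base key.
active_adapter = "default"  # assumed adapter name
lora_name = "base_model.model.decoder.layers.0.self_attn.q_proj.lora_A.default.weight"

origin_key = lora_name.replace("lora_A.", "")
# -> "base_model.model.decoder.layers.0.self_attn.q_proj.default.weight"
origin_key = origin_key.replace("lora_B.", "")  # no-op for a lora_A weight
origin_key = origin_key.replace(f"{active_adapter}.", "")
print(origin_key)
# -> "base_model.model.decoder.layers.0.self_attn.q_proj.weight"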

tests/test_booster/test_plugin/test_low_level_zero_plugin.py

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)
 
     except Exception as e:
        return repr(e)
+        # raise e
 
 
 
tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: boo
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
-    # check_low_level_zero_checkpointIO()
+    check_low_level_zero_checkpointIO()
     check_low_level_zero_lora_checkpointIO()
     torch.cuda.empty_cache()
 