vera-pissa method added #8722
@@ -24,7 +24,6 @@
from paddle.distributed.fleet.meta_parallel import PipelineLayer

from ...transformers.model_utils import PretrainedModel, _add_variant, dtype_guard
from ...transformers.utils import weight_name_suffix
from ...utils.env import VERA_WEIGHTS_NAME
from ...utils.log import logger
from .vera_config import VeRAConfig
@@ -46,9 +45,9 @@
        self.model = self.get_vera_model(model, vera_config)
        self.is_pipelinemodel = False
        if issubclass(type(self.model), PipelineLayer):
            raise NotImplementedError("VeRA does not support pipeline parallel yet")
        if vera_config.tensor_parallel_degree > 1:
            raise NotImplementedError("VeRA does not support tensor parallel yet")
        self.forward = self.model.forward

    @classmethod
@@ -77,14 +76,14 @@
                vera_config_tensor_parallel_degree > 1
                and vera_config_tensor_parallel_degree != model.config.tensor_parallel_degree
            ):
                raise NotImplementedError(
                    f"{vera_config_tensor_parallel_degree} is not equal to {model.config.tensor_parallel_degree}. Please merge VeRA weights first."
                )

            # set vera state dict
            vera_model.set_state_dict(vera_state_dict)
        else:
            logger.error(f"VeRA weights not found under {vera_path}, creating VeRA weights from scratch")

        return vera_model
@@ -103,15 +102,15 @@
        save_model_config = kwargs.get("save_model_config", True)

        if self.is_pipelinemodel:
            self.model._single_to_pp_mapping = None
        if self.quantized and merge_tensor_parallel and self.vera_config.tensor_parallel_degree > 1:
            merge_tensor_parallel = False
            logger.warning(
                "Quantized strategy does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
            )
        if self.is_pipelinemodel and merge_tensor_parallel and self.vera_config.tensor_parallel_degree > 1:
            merge_tensor_parallel = False
            logger.warning(
                "Pipeline parallelism does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
            )
@@ -128,9 +127,6 @@
        logger.info(f"vera config to save is {vera_config_to_save}")

        trainable_state_dict = self.get_trainable_state_dict()
        if vera_config_to_save.tensor_parallel_degree > 1:
            if variant is None:
                variant = weight_name_suffix()

        # save vera weight
        vera_weight_name = _add_variant(VERA_WEIGHTS_NAME, variant)
@@ -143,7 +139,7 @@
        if save_model_config:
            model_config_to_save = copy.deepcopy(self.model.config)
            if merge_tensor_parallel:
                model_config_to_save.tensor_parallel_degree = -1
            model_config_to_save.save_pretrained(save_directory)

    def _find_and_replace_module(self, model, module_name, vera_config, enable_vera):
@@ -178,17 +174,17 @@
        setattr(parent_module, attribute_chain[-1], vera_module)

    def _find_and_restore_module(self, module_name):
        parent_module = self.model
        attribute_chain = module_name.split(".")
        for name in attribute_chain[:-1]:
            parent_module = getattr(parent_module, name)
        module = getattr(parent_module, attribute_chain[-1])
        original_model_class = self.restore_layer_map[module.__class__]
        original_module = original_model_class(in_features=module.weight.shape[0], out_features=module.weight.shape[1])
        original_module.weight = module.weight
        if module.bias is not None:
            original_module.bias = module.bias
        setattr(parent_module, attribute_chain[-1], original_module)

    def get_trainable_state_dict(self):
        trainable_state_dict = OrderedDict()
@@ -199,14 +195,14 @@
        return trainable_state_dict

    def print_trainable_parameters(self) -> None:
        freeze_numel = 0
        trainable_numel = 0
        for _, weight in self.model.state_dict().items():
            if weight.stop_gradient:
                freeze_numel += np.prod(weight.shape)
            else:
                trainable_numel += np.prod(weight.shape)
        logger.debug(
            f"Frozen parameters: {freeze_numel:.2e} || Trainable parameters:{trainable_numel:.2e} || Total parameters:{freeze_numel+trainable_numel:.2e}|| Trainable:{trainable_numel / (freeze_numel+trainable_numel):.2%}"
        )
@@ -215,14 +211,14 @@
            if isinstance(layer, VeRALinear):
                for name, weight in layer.state_dict().items():
                    if self.vera_config.trainable_bias in ["vera", "all"] and "bias" in name:
                        weight.stop_gradient = False
                    elif "vera" in name:
                        # notfreezeB=True: vera_b, vera_d, and vera_B are trainable
                        # notfreezeB=False: only vera_b and vera_d are trainable
                        if "vera_b" in name or "vera_d" in name:
                            weight.stop_gradient = False
                        elif "vera_B" in name and notfreezeB:
                            weight.stop_gradient = False
                        else:
                            weight.stop_gradient = True
                    else:
@@ -230,26 +226,26 @@
            else:
                for name, weight in layer.state_dict().items():
                    if self.vera_config.trainable_bias == "all" and "bias" in name:
                        weight.stop_gradient = False
                    else:
                        weight.stop_gradient = True
        if self.vera_config.trainable_modules is not None:
            for name, weight in self.model.state_dict().items():
                if any(
                    re.fullmatch(trainable_module, name) for trainable_module in self.vera_config.trainable_modules
                ):
                    weight.stop_gradient = False

    def get_vera_model(self, model: Union[PretrainedModel, nn.Layer], vera_config: VeRAConfig):

        if vera_config.target_modules is None:
            return model
        elif isinstance(vera_config.target_modules, str):
            target_modules = [vera_config.target_modules]
            if vera_config.enable_vera_list is None:
                enable_vera_list = [vera_config.enable_vera_list]

Review comment: enable_vera_list is simply reused from LoRA; VeRA has no corresponding feature. It is recommended to remove everything related to enable_vera_list and just take the branch in the code where it is None.
Reply: done
Review comment: enable_vera_list should be removed entirely at the vera_config level, since we do not need this parameter; it looks like the current code still keeps it?

            else:
                raise TypeError(
                    f"Invalid `enable_vera_list` value: {vera_config.enable_vera_list}. Since `target_modules` is `str`, `enable_vera_list` must be `None` or `List[bool]`"
                )
        else:
@@ -257,7 +253,7 @@
            if vera_config.enable_vera_list is None:
                enable_vera_list = [None for _ in range(len(target_modules))]
            else:
                raise TypeError(
                    f"Invalid `enable_vera_list` value: {vera_config.enable_vera_list}. Since `target_modules` is `List[str]`, `enable_vera_list` must be `None` or `List[Optional[List[bool]]]`"
                )

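Following the review suggestion above, dropping enable_vera_list would reduce get_vera_model to plain pattern matching over target_modules. The lines below are a rough standalone sketch, not the PR's code: find_and_replace_module stands in for the PR's _find_and_replace_module helper, and model.named_sublayers() plus re.fullmatch mirror how the diff already matches trainable_modules.

# Sketch only: what get_vera_model could look like without enable_vera_list.
import re
from typing import List


def get_vera_model(model, vera_config, find_and_replace_module):
    # Nothing to wrap if no target patterns are configured.
    if vera_config.target_modules is None:
        return model
    # Normalize a single pattern string into a list of patterns.
    if isinstance(vera_config.target_modules, str):
        target_modules: List[str] = [vera_config.target_modules]
    else:
        target_modules = list(vera_config.target_modules)
    # Wrap every sublayer whose name fully matches any pattern; no per-module
    # enable flags are needed once enable_vera_list is gone.
    for target_module in target_modules:
        for name, _ in model.named_sublayers():
            if re.fullmatch(target_module, name):
                find_and_replace_module(model, name, vera_config)
    return model

With only the None branch left, the two TypeError branches above become unreachable and can be deleted together with the parameter, which is what the reviewer asks for.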
@@ -269,23 +265,19 @@
        return model

    def restore_original_model(self):
        # make sure W and vera weights are not merged before we restore the original model
        if self.vera_config.merge_weights:
            self.train()

        for layer_name, layer in self.model.named_sublayers():
            if isinstance(layer, VeRALinear):
                self._find_and_restore_module(layer_name)
            else:
                raise NotImplementedError(f"{layer} restoration is not supported yet.")
        return self.model

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Layer's logic
        except AttributeError:
            return getattr(self.model, name)

    def train(self):
        self.training = True
@@ -0,0 +1,15 @@
{
  "base_model_name_or_path": null,
  "do_qat": false,
  "dtype": null,
  "enable_vera_list": null,
  "head_dim": null,
  "pissa_init": false,
  "r": 8,
  "target_modules": null,
  "tensor_parallel_degree": -1,
  "trainable_bias": null,
  "trainable_modules": null,
  "vera_alpha": 8,
  "vera_dropout": 0.0
}

Review comment: What is the purpose of this file?
Reply: It was for testing and has been deleted. done
Reply: enable_vera_list has now been removed entirely at the vera_config level.
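The JSON above is the serialized VeRA config used as a test fixture. To illustrate the suggestion of dropping enable_vera_list at the config level, a trimmed config could look roughly like the sketch below; the field names follow the JSON fixture, but the dataclass and its save/load helpers are assumptions rather than the PR's VeRAConfig implementation.

# Sketch only: a trimmed config without enable_vera_list. The real VeRAConfig
# in the PR may differ in defaults, validation, and helper methods.
import json
from dataclasses import asdict, dataclass
from typing import List, Optional, Union


@dataclass
class TrimmedVeRAConfig:
    base_model_name_or_path: Optional[str] = None
    do_qat: bool = False
    dtype: Optional[str] = None
    head_dim: Optional[int] = None
    pissa_init: bool = False
    r: int = 8
    target_modules: Optional[Union[str, List[str]]] = None
    tensor_parallel_degree: int = -1
    trainable_bias: Optional[str] = None
    trainable_modules: Optional[List[str]] = None
    vera_alpha: int = 8
    vera_dropout: float = 0.0

    def save_pretrained(self, path: str) -> None:
        # Serialize in the same flat-JSON shape as the fixture above.
        with open(path, "w") as f:
            json.dump(asdict(self), f, indent=2, sort_keys=True)

    @classmethod
    def from_pretrained(cls, path: str) -> "TrimmedVeRAConfig":
        with open(path) as f:
            return cls(**json.load(f))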
Review comment: Why was this deleted?
Reply: To increase code coverage, it has been added back, and corresponding exception tests have been added.
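For context on the "corresponding exception tests", a test of the invalid enable_vera_list path might look roughly like the sketch below. The import location, the VeRAModel(model, config) constructor order, and the tiny base model are all assumptions inferred from the diff, not verified against the merged code.

# Sketch only: an exception test for an invalid enable_vera_list value.
import unittest

import paddle.nn as nn
from paddlenlp.peft import VeRAConfig, VeRAModel  # assumed import location


class TinyModel(nn.Layer):
    """A minimal layer whose q_proj can be targeted by VeRA."""

    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.q_proj(x)


class VeRAConfigExceptionTest(unittest.TestCase):
    def test_invalid_enable_vera_list_raises_type_error(self):
        # target_modules is a str here, so per the diff enable_vera_list must be
        # None or List[bool]; an int should reach the `raise TypeError(...)` branch.
        config = VeRAConfig(target_modules=".*q_proj.*", enable_vera_list=1, r=8, vera_alpha=8)
        with self.assertRaises(TypeError):
            VeRAModel(TinyModel(), config)


if __name__ == "__main__":
    unittest.main()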