Skip to content

Commit

Permalink
add scaling (#8256)
Browse files Browse the repository at this point in the history
* add scaling

* add scaling

* add scaling

* format
  • Loading branch information
lugimzzz authored Apr 12, 2024
1 parent 662feb1 commit 0d05544
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions paddlenlp/peft/lora/lora_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import json
import math
import os
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union
Expand Down Expand Up @@ -94,6 +95,15 @@ def __post_init__(self):
)
self.use_quick_lora = False

@property
def scaling(self):
    """Return the scaling factor applied to the low-rank (LoRA) update.

    Three mutually exclusive regimes, checked in precedence order:
      * PiSSA: no rescaling, always 1.0 (pissa wins even if rslora is set).
      * rsLoRA: lora_alpha / sqrt(r), the rank-stabilized variant.
      * vanilla LoRA: lora_alpha / r.
    """
    # PiSSA initializes the adapter from the principal components of the
    # base weight, so no alpha/r rescaling is applied.
    if self.pissa:
        return 1.0
    # rsLoRA stabilizes training across ranks by dividing by sqrt(r)
    # instead of r.
    if self.rslora:
        return self.lora_alpha / math.sqrt(self.r)
    return self.lora_alpha / self.r

@property
def __dict__(self):
    """Expose every dataclass field as a plain dict.

    Overriding ``__dict__`` as a property means serialization code can
    read ``self.__dict__`` and get the recursively-converted field
    values that ``dataclasses.asdict`` produces.
    """
    return {name: value for name, value in asdict(self).items()}
Expand All @@ -114,6 +124,7 @@ def save_pretrained(self, save_directory):
os.makedirs(save_directory, exist_ok=True)

output_dict = self.__dict__
output_dict["scaling"] = self.scaling
output_path = os.path.join(save_directory, LORA_CONFIG_NAME)

# save it
Expand All @@ -136,6 +147,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
raise ValueError(f"Can't find lora_config.json at '{pretrained_model_name_or_path}'")

loaded_attributes = cls.from_json_file(config_file)
loaded_attributes.pop("scaling", None)

config = cls(**kwargs)

Expand Down

0 comments on commit 0d05544

Please sign in to comment.