-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MoRA Implementation #9562
Merged
Merged
MoRA Implementation #9562
Changes from 1 commit
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
e8d72b6
MoRA Implementation
lcykww a46fce2
MoRA算法
lcykww 9100581
MoRA
lcykww 77945bc
MoRA
lcykww 0e420f3
MoRA
lcykww a91ae43
MoRA
lcykww c0e4416
MoRA
lcykww 142cd3d
MoRA
lcykww 46e3da3
MoRA
lcykww 5017004
MoRA
lcykww 6181b2b
changed lora_config.py
lcykww File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
MoRA
- Loading branch information
commit 9100581e9b8c04aa96e86dcaf661fab86a6f2598
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,19 +53,19 @@ | |
class LoRALinear(nn.Linear): | ||
# LoRA implemented in a dense layer | ||
def __init__( | ||
self, | ||
in_features: int, | ||
out_features: int, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
use_quick_lora: bool = False, | ||
rslora: bool = False, | ||
lora_plus_scale: float = 1.0, | ||
pissa: bool = False, | ||
lora_use_mixer: bool = False, | ||
use_mora: bool = False, | ||
**kwargs | ||
self, | ||
in_features: int, | ||
out_features: int, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
use_quick_lora: bool = False, | ||
rslora: bool = False, | ||
lora_plus_scale: float = 1.0, | ||
pissa: bool = False, | ||
lora_use_mixer: bool = False, | ||
use_mora: bool = False, | ||
**kwargs | ||
): | ||
nn.Linear.__init__(self, in_features, out_features, **kwargs) | ||
if not isinstance(r, int) or r <= 0: | ||
|
@@ -129,7 +129,7 @@ | |
|
||
if not rslora and not pissa: | ||
self.scaling = self.lora_alpha / self.r | ||
elif pissa or use_mora: | ||
self.scaling = 1.0 | ||
else: | ||
self.scaling = self.lora_alpha / math.sqrt(self.r) | ||
|
@@ -161,11 +161,11 @@ | |
def RoPE_init(self, r, rb1): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 最好都是小写比如def rope_init There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
if self.cos is None or self.sin is None: | ||
inv_freq = 1.0 / (10000 ** (paddle.arange(0, r, 2, dtype=paddle.float32) / r)) | ||
t = paddle.arange(rb1, dtype=paddle.float32) | ||
t = paddle.arange(rb1, dtype=self._dtype) | ||
freqs = t.unsqueeze(1) @ inv_freq.unsqueeze(0) | ||
emb = paddle.concat([freqs, freqs], axis=-1) | ||
self.cos = paddle.unsqueeze(paddle.cos(emb), axis=0).astype(paddle.float32) | ||
self.sin = paddle.unsqueeze(paddle.sin(emb), axis=0).astype(paddle.float32) | ||
self.cos = paddle.unsqueeze(paddle.cos(emb), axis=0).astype(self._dtype) | ||
self.sin = paddle.unsqueeze(paddle.sin(emb), axis=0).astype(self._dtype) | ||
|
||
@property | ||
def use_quick_lora(self): | ||
|
@@ -189,7 +189,7 @@ | |
# create RoPE | ||
self.RoPE_init(r, rb1) | ||
# apply RoPE rotation | ||
rh_in_x = paddle.concat([-in_x[..., r // 2:], in_x[..., : r // 2]], axis=-1) | ||
rh_in_x = paddle.concat([-in_x[..., r // 2 :], in_x[..., : r // 2]], axis=-1) | ||
in_x = in_x * self.cos + rh_in_x * self.sin | ||
|
||
# matmul with high rank matrix | ||
|
@@ -198,10 +198,10 @@ | |
# reshape the output | ||
out_x = out_x.reshape([*x.shape[:-1], -1])[..., : self.out_features] | ||
if out_x.shape[-1] < self.out_features: | ||
repeat_time = self.out_features // out_x.shape[-1] | ||
if self.out_features % out_x.shape[-1] != 0: | ||
repeat_time += 1 | ||
out_x = paddle.concat([out_x] * repeat_time, axis=-1)[..., : self.out_features] | ||
|
||
return out_x | ||
|
||
|
@@ -221,19 +221,19 @@ | |
# create RoPE | ||
self.RoPE_init(r, rb1) | ||
# create the weights after rotation | ||
aw2 = paddle.concat([self.lora_A[:, r // 2:], -self.lora_A[:, : r // 2]], axis=1) | ||
aw2 = paddle.concat([self.lora_A[:, r // 2 :], -self.lora_A[:, : r // 2]], axis=1) | ||
# apply RoPE | ||
for i in range(rb1 - 1): | ||
w[i * r: (i + 1) * r, i * r: (i + 1) * r] = aw2 * self.sin[:, i] + self.lora_A * self.cos[:, i] | ||
w[i * r : (i + 1) * r, i * r : (i + 1) * r] = aw2 * self.sin[:, i] + self.lora_A * self.cos[:, i] | ||
# Process the last chunk that may be incomplete | ||
i = rb1 - 1 | ||
w[i * r:, i * r:] = (aw2 * self.sin[:, i] + self.lora_A * self.cos[:, i])[:, : r - pad_size] | ||
w[i * r :, i * r :] = (aw2 * self.sin[:, i] + self.lora_A * self.cos[:, i])[:, : r - pad_size] | ||
# padding | ||
if pad_size > 0: | ||
w[i * r:, :pad_size] = (aw2 * self.sin[:, i] + self.lora_A * self.cos[:, i])[:, r - pad_size:] | ||
w[i * r :, :pad_size] = (aw2 * self.sin[:, i] + self.lora_A * self.cos[:, i])[:, r - pad_size :] | ||
# reshape the weights | ||
if self.in_features < self.out_features: | ||
w = paddle.concat([w] * rb2, axis=0)[: self.out_features] | ||
else: | ||
w = w[: self.out_features] | ||
final_weight = w | ||
|
@@ -286,25 +286,25 @@ | |
|
||
class RowParallelLoRALinear(RowParallelLinear): | ||
def __init__( | ||
self, | ||
in_features: int, | ||
out_features: int, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
rslora: bool = False, | ||
lora_plus_scale: float = 1.0, | ||
use_quick_lora: bool = False, | ||
pissa: bool = False, | ||
use_mora: bool = False, | ||
**kwargs | ||
self, | ||
in_features: int, | ||
out_features: int, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
rslora: bool = False, | ||
lora_plus_scale: float = 1.0, | ||
use_quick_lora: bool = False, | ||
pissa: bool = False, | ||
use_mora: bool = False, | ||
**kwargs | ||
): | ||
RowParallelLinear.__init__(self, in_features, out_features, **kwargs) | ||
if not isinstance(r, int) or r <= 0: | ||
raise ValueError("Lora rank r should be a positive integer") | ||
|
||
if pissa or use_mora: | ||
raise ValueError("Pissa or Mora is not supported in model parallel by now") | ||
|
||
self.r = r | ||
self.lora_alpha = lora_alpha | ||
|
@@ -440,16 +440,16 @@ | |
|
||
class RowSequenceParallelLoRALinear(RowSequenceParallelLinear): | ||
def __init__( | ||
self, | ||
in_features: int, | ||
out_features: int, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
rslora: bool = False, | ||
lora_plus_scale: float = 1.0, | ||
use_quick_lora: bool = False, | ||
**kwargs | ||
self, | ||
in_features: int, | ||
out_features: int, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
rslora: bool = False, | ||
lora_plus_scale: float = 1.0, | ||
use_quick_lora: bool = False, | ||
**kwargs | ||
): | ||
RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs) | ||
if not isinstance(r, int) or r <= 0: | ||
|
@@ -551,26 +551,26 @@ | |
|
||
class ColumnParallelLoRALinear(ColumnParallelLinear): | ||
def __init__( | ||
self, | ||
in_features: int, | ||
out_features: int, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
rslora: bool = False, | ||
lora_plus_scale: float = 1.0, | ||
lora_A_weight_attr: Optional[paddle.ParamAttr] = None, | ||
use_quick_lora: bool = False, | ||
pissa: bool = False, | ||
use_mora: bool = False, | ||
**kwargs | ||
self, | ||
in_features: int, | ||
out_features: int, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
rslora: bool = False, | ||
lora_plus_scale: float = 1.0, | ||
lora_A_weight_attr: Optional[paddle.ParamAttr] = None, | ||
use_quick_lora: bool = False, | ||
pissa: bool = False, | ||
use_mora: bool = False, | ||
**kwargs | ||
): | ||
ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs) | ||
if not isinstance(r, int) or r <= 0: | ||
raise ValueError("Lora rank r should be a positive integer") | ||
|
||
if pissa or use_mora: | ||
raise ValueError("Pissa or Mora is not supported in model parallel by now") | ||
|
||
self.r = r | ||
self.lora_alpha = lora_alpha | ||
|
@@ -687,17 +687,17 @@ | |
|
||
class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear): | ||
def __init__( | ||
self, | ||
in_features: int, | ||
out_features: int, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
rslora: bool = False, | ||
lora_plus_scale: float = 1.0, | ||
lora_A_weight_attr: Optional[paddle.ParamAttr] = None, | ||
use_quick_lora: bool = False, | ||
**kwargs | ||
self, | ||
in_features: int, | ||
out_features: int, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
rslora: bool = False, | ||
lora_plus_scale: float = 1.0, | ||
lora_A_weight_attr: Optional[paddle.ParamAttr] = None, | ||
use_quick_lora: bool = False, | ||
**kwargs | ||
): | ||
ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs) | ||
if not isinstance(r, int) or r <= 0: | ||
|
@@ -802,14 +802,14 @@ | |
class LoRAConv2D(nn.Conv2D): | ||
# LoRA implemented in a dense layer | ||
def __init__( | ||
self, | ||
in_channels, | ||
out_channels, | ||
kernel_size, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
**kwargs | ||
self, | ||
in_channels, | ||
out_channels, | ||
kernel_size, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
**kwargs | ||
): | ||
nn.Conv2D.__init__(self, in_channels, out_channels, kernel_size, **kwargs) | ||
if not isinstance(r, int) or r <= 0: | ||
|
@@ -866,11 +866,11 @@ | |
else: | ||
# conv2d 3x3 | ||
delta_weight = ( | ||
F.conv2d( | ||
weight_A.transpose([1, 0, 2, 3]), | ||
weight_B, | ||
).transpose([1, 0, 2, 3]) | ||
* self.scaling | ||
F.conv2d( | ||
weight_A.transpose([1, 0, 2, 3]), | ||
weight_B, | ||
).transpose([1, 0, 2, 3]) | ||
* self.scaling | ||
) | ||
# Make sure that the weights are not merged | ||
new_weight = self.weight - delta_weight | ||
|
@@ -889,11 +889,11 @@ | |
else: | ||
# conv2d 3x3 | ||
delta_weight = ( | ||
F.conv2d( | ||
weight_A.transpose([1, 0, 2, 3]), | ||
weight_B, | ||
).transpose([1, 0, 2, 3]) | ||
* self.scaling | ||
F.conv2d( | ||
weight_A.transpose([1, 0, 2, 3]), | ||
weight_B, | ||
).transpose([1, 0, 2, 3]) | ||
* self.scaling | ||
) | ||
# Merge the weights and mark it | ||
new_weight = self.weight + delta_weight | ||
|
@@ -905,8 +905,8 @@ | |
result = super().forward(input) | ||
if not self.merged and not self.disable_lora: | ||
result += ( | ||
self.lora_B_forward(self.lora_A_forward(self.lora_dropout(input.cast(dtype=self.lora_A.dtype)))) | ||
* self.scaling | ||
self.lora_B_forward(self.lora_A_forward(self.lora_dropout(input.cast(dtype=self.lora_A.dtype)))) | ||
* self.scaling | ||
) | ||
result = result.cast(dtype=previous_dtype) | ||
return result | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://github.com/kongds/MoRA/blob/main/peft-mora/src/peft/tuners/lora/layer.py#L122 这里scaling设为1