Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MoRA Implementation #9562

Merged
merged 11 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
MoRA
  • Loading branch information
lcykww committed Dec 14, 2024
commit 9100581e9b8c04aa96e86dcaf661fab86a6f2598
174 changes: 87 additions & 87 deletions paddlenlp/peft/lora/lora_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,19 @@
class LoRALinear(nn.Linear):
# LoRA implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
use_quick_lora: bool = False,
rslora: bool = False,
lora_plus_scale: float = 1.0,
pissa: bool = False,
lora_use_mixer: bool = False,
use_mora: bool = False,
**kwargs
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
use_quick_lora: bool = False,
rslora: bool = False,
lora_plus_scale: float = 1.0,
pissa: bool = False,
lora_use_mixer: bool = False,
use_mora: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
Expand Down Expand Up @@ -129,7 +129,7 @@

if not rslora and not pissa:
self.scaling = self.lora_alpha / self.r
elif pissa or use_mora:

Check warning on line 132 in paddlenlp/peft/lora/lora_layers.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/peft/lora/lora_layers.py#L132

Added line #L132 was not covered by tests
self.scaling = 1.0
else:
self.scaling = self.lora_alpha / math.sqrt(self.r)
Expand Down Expand Up @@ -161,11 +161,11 @@
def RoPE_init(self, r, rb1):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

最好都是小写比如def rope_init

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

if self.cos is None or self.sin is None:
inv_freq = 1.0 / (10000 ** (paddle.arange(0, r, 2, dtype=paddle.float32) / r))
t = paddle.arange(rb1, dtype=paddle.float32)
t = paddle.arange(rb1, dtype=self._dtype)
freqs = t.unsqueeze(1) @ inv_freq.unsqueeze(0)
emb = paddle.concat([freqs, freqs], axis=-1)
self.cos = paddle.unsqueeze(paddle.cos(emb), axis=0).astype(paddle.float32)
self.sin = paddle.unsqueeze(paddle.sin(emb), axis=0).astype(paddle.float32)
self.cos = paddle.unsqueeze(paddle.cos(emb), axis=0).astype(self._dtype)
self.sin = paddle.unsqueeze(paddle.sin(emb), axis=0).astype(self._dtype)

@property
def use_quick_lora(self):
Expand All @@ -189,7 +189,7 @@
# create RoPE
self.RoPE_init(r, rb1)
# apply RoPE rotation
rh_in_x = paddle.concat([-in_x[..., r // 2:], in_x[..., : r // 2]], axis=-1)
rh_in_x = paddle.concat([-in_x[..., r // 2 :], in_x[..., : r // 2]], axis=-1)
in_x = in_x * self.cos + rh_in_x * self.sin

# matmul with high rank matrix
Expand All @@ -198,10 +198,10 @@
# reshape the output
out_x = out_x.reshape([*x.shape[:-1], -1])[..., : self.out_features]
if out_x.shape[-1] < self.out_features:
repeat_time = self.out_features // out_x.shape[-1]
if self.out_features % out_x.shape[-1] != 0:
repeat_time += 1
out_x = paddle.concat([out_x] * repeat_time, axis=-1)[..., : self.out_features]

Check warning on line 204 in paddlenlp/peft/lora/lora_layers.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/peft/lora/lora_layers.py#L201-L204

Added lines #L201 - L204 were not covered by tests

return out_x

Expand All @@ -221,19 +221,19 @@
# create RoPE
self.RoPE_init(r, rb1)
# create the weights after rotation
aw2 = paddle.concat([self.lora_A[:, r // 2:], -self.lora_A[:, : r // 2]], axis=1)
aw2 = paddle.concat([self.lora_A[:, r // 2 :], -self.lora_A[:, : r // 2]], axis=1)
# apply RoPE
for i in range(rb1 - 1):
w[i * r: (i + 1) * r, i * r: (i + 1) * r] = aw2 * self.sin[:, i] + self.lora_A * self.cos[:, i]
w[i * r : (i + 1) * r, i * r : (i + 1) * r] = aw2 * self.sin[:, i] + self.lora_A * self.cos[:, i]
# Process the last chunk that may be incomplete
i = rb1 - 1
w[i * r:, i * r:] = (aw2 * self.sin[:, i] + self.lora_A * self.cos[:, i])[:, : r - pad_size]
w[i * r :, i * r :] = (aw2 * self.sin[:, i] + self.lora_A * self.cos[:, i])[:, : r - pad_size]
# padding
if pad_size > 0:
w[i * r:, :pad_size] = (aw2 * self.sin[:, i] + self.lora_A * self.cos[:, i])[:, r - pad_size:]
w[i * r :, :pad_size] = (aw2 * self.sin[:, i] + self.lora_A * self.cos[:, i])[:, r - pad_size :]
# reshape the weights
if self.in_features < self.out_features:
w = paddle.concat([w] * rb2, axis=0)[: self.out_features]

Check warning on line 236 in paddlenlp/peft/lora/lora_layers.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/peft/lora/lora_layers.py#L236

Added line #L236 was not covered by tests
else:
w = w[: self.out_features]
final_weight = w
Expand Down Expand Up @@ -286,25 +286,25 @@

class RowParallelLoRALinear(RowParallelLinear):
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
use_quick_lora: bool = False,
pissa: bool = False,
use_mora: bool = False,
**kwargs
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
use_quick_lora: bool = False,
pissa: bool = False,
use_mora: bool = False,
**kwargs
):
RowParallelLinear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
raise ValueError("Lora rank r should be a positive integer")

if pissa or use_mora:
raise ValueError("Pissa or Mora is not supported in model parallel by now")

Check warning on line 307 in paddlenlp/peft/lora/lora_layers.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/peft/lora/lora_layers.py#L306-L307

Added lines #L306 - L307 were not covered by tests

self.r = r
self.lora_alpha = lora_alpha
Expand Down Expand Up @@ -440,16 +440,16 @@

class RowSequenceParallelLoRALinear(RowSequenceParallelLinear):
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
use_quick_lora: bool = False,
**kwargs
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
use_quick_lora: bool = False,
**kwargs
):
RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
Expand Down Expand Up @@ -551,26 +551,26 @@

class ColumnParallelLoRALinear(ColumnParallelLinear):
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
use_quick_lora: bool = False,
pissa: bool = False,
use_mora: bool = False,
**kwargs
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
use_quick_lora: bool = False,
pissa: bool = False,
use_mora: bool = False,
**kwargs
):
ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
raise ValueError("Lora rank r should be a positive integer")

if pissa or use_mora:
raise ValueError("Pissa or Mora is not supported in model parallel by now")

Check warning on line 573 in paddlenlp/peft/lora/lora_layers.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/peft/lora/lora_layers.py#L572-L573

Added lines #L572 - L573 were not covered by tests

self.r = r
self.lora_alpha = lora_alpha
Expand Down Expand Up @@ -687,17 +687,17 @@

class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear):
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
use_quick_lora: bool = False,
**kwargs
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
use_quick_lora: bool = False,
**kwargs
):
ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
Expand Down Expand Up @@ -802,14 +802,14 @@
class LoRAConv2D(nn.Conv2D):
# LoRA implemented in a dense layer
def __init__(
self,
in_channels,
out_channels,
kernel_size,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs
self,
in_channels,
out_channels,
kernel_size,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs
):
nn.Conv2D.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
if not isinstance(r, int) or r <= 0:
Expand Down Expand Up @@ -866,11 +866,11 @@
else:
# conv2d 3x3
delta_weight = (
F.conv2d(
weight_A.transpose([1, 0, 2, 3]),
weight_B,
).transpose([1, 0, 2, 3])
* self.scaling
F.conv2d(
weight_A.transpose([1, 0, 2, 3]),
weight_B,
).transpose([1, 0, 2, 3])
* self.scaling
)
# Make sure that the weights are not merged
new_weight = self.weight - delta_weight
Expand All @@ -889,11 +889,11 @@
else:
# conv2d 3x3
delta_weight = (
F.conv2d(
weight_A.transpose([1, 0, 2, 3]),
weight_B,
).transpose([1, 0, 2, 3])
* self.scaling
F.conv2d(
weight_A.transpose([1, 0, 2, 3]),
weight_B,
).transpose([1, 0, 2, 3])
* self.scaling
)
# Merge the weights and mark it
new_weight = self.weight + delta_weight
Expand All @@ -905,8 +905,8 @@
result = super().forward(input)
if not self.merged and not self.disable_lora:
result += (
self.lora_B_forward(self.lora_A_forward(self.lora_dropout(input.cast(dtype=self.lora_A.dtype))))
* self.scaling
self.lora_B_forward(self.lora_A_forward(self.lora_dropout(input.cast(dtype=self.lora_A.dtype))))
* self.scaling
)
result = result.cast(dtype=previous_dtype)
return result
Expand Down
Loading