
Commit 0a18317

fix peft memory cost issue (#2133)
1 parent c292d3e commit 0a18317

40 files changed: +3061, -7173 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -171,3 +171,5 @@ tests/diffusers/
 tests/transformers/
 tests/huggingface_transformers/
 .gradio/
+
+huanhuan.json

examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.ipynb

Lines changed: 1086 additions & 0 deletions
Large diffs are not rendered by default.

mindnlp/core/__init__.py

Lines changed: 2 additions & 3 deletions
@@ -55,9 +55,8 @@
 from .func import vmap
 from .configs import set_pyboost

-from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \
-    return_types, linalg, fx, backends, testing, nn, fft, _jit_internal, utils
-
+from . import profiler, cuda, amp, compiler, jit, version, __future__, overrides, \
+    return_types, linalg, fx, backends, testing, nn, fft, _jit_internal, utils, optim
 from ._lowrank import svd_lowrank
 from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state

mindnlp/core/_prims/ascend.py

Lines changed: 31 additions & 1 deletion
@@ -3,6 +3,7 @@
 from mindspore.ops.auto_generate import gen_ops_prim
 from mindspore.ops.auto_generate import pyboost_inner_prim
 from mindspore._c_expression import _empty_instance
+from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2

 from mindnlp import core
 from mindnlp.core._C import default_generator
@@ -162,4 +163,33 @@ def reverse_v2(input, dims):
         dims = (dims,)
     return pyboost_inner_prim.reverse_v2_impl(input, dims)

-__all__.append('reverse_v2')
+adam_op = ops.Adam().set_device('Ascend')
+def raw_adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
+    # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad
+    return adam_op(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+
+__all__.append('raw_adam')
+
+depend_op = ops.Depend().set_device('Ascend')
+def depend(*args):
+    return depend_op(*args)
+
+__all__.append('depend')
+
+npu_get_float_status_op = NPUGetFloatStatusV2().set_device('Ascend')
+def npu_get_float_status_v2(status):
+    return npu_get_float_status_op(status)
+
+__all__.append('npu_get_float_status_v2')
+
+npu_clear_float_status_op = NPUClearFloatStatusV2().set_device('Ascend')
+def npu_clear_float_status_v2(status):
+    return npu_clear_float_status_op(status)
+
+__all__.append('npu_clear_float_status_v2')
+
+stop_gradient_op = ops.StopGradient().set_device('Ascend')
+def stop_gradient(*args):
+    return stop_gradient_op(*args)
+
+__all__.append('stop_gradient')
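
For context, a minimal usage sketch of the new raw_adam wrapper: it forwards directly to mindspore.ops.Adam, which performs one fused Adam step in place on the parameter and its two moment buffers. The tensor values and default-device execution below are illustrative assumptions; the commit itself pins every op to Ascend via .set_device('Ascend').

    import numpy as np
    from mindspore import ops, Parameter, Tensor

    # One fused Adam step: var, m (exp_avg) and v (exp_avg_sq) are updated in place.
    var = Parameter(Tensor(np.ones(3, np.float32)), name="var")
    m = Parameter(Tensor(np.zeros(3, np.float32)), name="m")
    v = Parameter(Tensor(np.zeros(3, np.float32)), name="v")
    grad = Tensor(np.full(3, 0.1, np.float32))

    adam = ops.Adam()  # the commit additionally calls .set_device('Ascend') on this op
    # argument order mirrors raw_adam: var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad
    adam(var, m, v, 0.9, 0.999, 1e-3, 0.9, 0.999, 1e-8, grad)
    print(var.asnumpy())  # parameter nudged against the gradient by the fused kernel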

mindnlp/core/_prims/numpy.py

Lines changed: 7 additions & 0 deletions
@@ -511,3 +511,10 @@ def tril_ext(input, diagonal):
     return core.Tensor.from_numpy(out)

 __all__.append('tril_ext')
+
+def randperm_ext(n, seed, offset, dtype):
+    out = np.random.permutation(n)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('randperm_ext')
+
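
A note on the CPU fallback above: it draws from NumPy's global RNG, so the seed, offset and dtype arguments are accepted for signature compatibility but not used. A plain-NumPy sketch of what it returns:

    import numpy as np

    perm = np.random.permutation(8)        # some shuffle of arange(8), e.g. [3 0 7 1 5 2 6 4]
    print(sorted(perm) == list(range(8)))  # True: every index appears exactly once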

mindnlp/core/_tensor.py

Lines changed: 8 additions & 6 deletions
@@ -97,7 +97,11 @@ def tensor(data, *, dtype=None, device=None, requires_grad=False):
     if isinstance(data, Tensor):
         UserWarning("To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than core.tensor(sourceTensor).")
         out = Tensor(data)
-        out._device = data.device
+        if device is not None:
+            out._device = device
+        else:
+            out._device = data.device
+
         return out

     # if isinstance(data, list):
@@ -1306,7 +1310,7 @@ def is_floating_point(self):
 def is_leaf(self):
     if not self.requires_grad:
         return True
-    if self.requires_grad and self._user_created:
+    if self.requires_grad and hasattr(self, 'param_info'):
         return True
     return False

@@ -2397,9 +2401,7 @@ def zero_(self):

 # Tensor.detach
 def detach(self):
-    out = self.data
-    out._requires_grad = False
-    return out
+    return ops.stop_gradient(self)

 # Tensor.detach_
 def detach_(self):
@@ -2501,7 +2503,7 @@ def __mod__(self, other):
     return ops.fmod(self, other)

 def backward(self):
-    pass
+    return self

 def log_softmax(self, dim):
     return ops.log_softmax(self, dim)
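
For context, detach() now delegates to ops.stop_gradient instead of aliasing self.data. A minimal sketch (not from the commit, assuming a recent MindSpore build) of the semantics this relies on:

    import mindspore
    from mindspore import Tensor, ops

    def f(x):
        y = x * x
        z = ops.stop_gradient(y)   # value flows forward, but no gradient flows back through z
        return (y + z).sum()

    x = Tensor([1.0, 2.0, 3.0], mindspore.float32)
    print(mindspore.grad(f)(x))    # only the y branch contributes: 2 * x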

mindnlp/core/amp/grad_scaler.py

Lines changed: 56 additions & 14 deletions
@@ -7,12 +7,32 @@
 from enum import Enum
 from typing import Any, cast, Dict, Iterable, List, Optional, overload, Tuple, Union

+import numpy as np
+import mindspore
 from mindnlp import core
-
+from mindnlp.core.configs import DEVICE_TARGET, SOC

 __all__ = ["OptState", "GradScaler"]


+def non_finite_check(inputs):
+    """all finite check"""
+    if DEVICE_TARGET == 'Ascend':
+        status = core.tensor(np.array([0] * 8), dtype=core.int32, device='npu')
+        status = core.depend(status, inputs)
+        found_inf = core.npu_get_float_status_v2(status)
+        status = core.depend(status, found_inf)
+        clear_status = core.npu_clear_float_status_v2(status)
+        found_inf = core.depend(found_inf, clear_status)
+        return found_inf.sum()
+
+    found_inf = core.all_finite(inputs) # pylint: disable=invalid-unary-operand-type
+    # else:
+    #     outputs = _hypermap(_partial(_overflow), inputs)
+    #     flag_sum = ops.addn(outputs).reshape(())
+    #     status_finite = ops.less(flag_sum, 1)
+    return found_inf
+
 class _MultiDeviceReplicator:
     """Lazily serves copies of a tensor to requested devices.

@@ -276,11 +296,15 @@ def _unscale_grads_(

         for device, per_dtype_grads in per_device_and_dtype_grads.items():
             for grads in per_dtype_grads.values():
-                core._amp_foreach_non_finite_check_and_unscale_(
-                    grads,
-                    per_device_found_inf.get(device),
-                    per_device_inv_scale.get(device),
-                )
+                # core._amp_foreach_non_finite_check_and_unscale_(
+                #     grads,
+                #     per_device_found_inf.get(device),
+                #     per_device_inv_scale.get(device),
+                # )
+                found_inf = per_device_found_inf.get(device)
+                found_inf.copy_(non_finite_check(grads).to(found_inf.dtype))
+                for grad in grads:
+                    grad *= per_device_inv_scale.get(device)

         return per_device_found_inf._per_device_tensors

@@ -518,14 +542,32 @@ def update(self, new_scale: Optional[Union[float, core.Tensor]] = None) -> None:
                 for i in range(1, len(found_infs)):
                     found_inf_combined += found_infs[i]

-            core._amp_update_scale_(
-                _scale,
-                _growth_tracker,
-                found_inf_combined,
-                self._growth_factor,
-                self._backoff_factor,
-                self._growth_interval,
-            )
+            # core._amp_update_scale_(
+            #     _scale,
+            #     _growth_tracker,
+            #     found_inf_combined,
+            #     self._growth_factor,
+            #     self._backoff_factor,
+            #     self._growth_interval,
+            # )
+            if found_inf_combined > 0:
+                _scale.copy_(_scale * self._backoff_factor)
+                _growth_tracker.copy_(_growth_tracker * 0)
+            else:
+                successful = self._growth_interval + 1
+                if successful == self._growth_interval:
+                    new_scale = _scale * self._growth_factor
+                    if core.isfinite(new_scale):
+                        _scale.copy_(new_scale)
+                    _growth_tracker.copy_(_growth_tracker * 0)
+                else:
+                    _growth_tracker.copy_(
+                        core.tensor(successful,
+                                    dtype=_growth_tracker.dtype,
+                                    device=_growth_tracker.device
+                        )
+                    )

         # To prepare for next iteration, clear the data collected from optimizers this iteration.
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
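
For context, the commented-out core._amp_update_scale_ kernel implements the standard dynamic loss-scaling policy that the new Python branch above approximates: back off on overflow, grow the scale after a run of finite steps. A framework-free sketch of that policy (function name and plain-float state are illustrative, not the commit's exact code):

    def update_scale(scale, growth_tracker, found_inf,
                     growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
        if found_inf:
            # Overflow this step: shrink the scale and restart the streak of good steps.
            return scale * backoff_factor, 0
        growth_tracker += 1
        if growth_tracker == growth_interval:
            # Enough consecutive finite steps: try a larger scale.
            return scale * growth_factor, 0
        return scale, growth_tracker

    scale, tracker = 65536.0, 0
    scale, tracker = update_scale(scale, tracker, found_inf=True)    # -> (32768.0, 0)
    scale, tracker = update_scale(scale, tracker, found_inf=False)   # -> (32768.0, 1)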

mindnlp/core/autograd/function.py

Lines changed: 6 additions & 4 deletions
@@ -13,6 +13,8 @@
 except:
     Function = None

+from mindnlp import core
+
 grad_ = GradOperation(False, True, False)
 grad_sens_ = GradOperation(False, True, True)
 grad_input_sens_ = GradOperation(True, True, True)
@@ -29,7 +31,7 @@ def fn_aux(*args):
         outputs = fn(*args)
         no_grad_outputs = ()
         for out in outputs[1:]:
-            no_grad_outputs += (stop_gradient(out),)
+            no_grad_outputs += (out.detach(),)
         return outputs[0], no_grad_outputs

     if has_aux:
@@ -39,19 +41,19 @@ def fn_aux(*args):

     def value_and_grad_f(*args, **kwargs):
         _pynative_executor.set_grad_flag(True)
-        _pynative_executor.new_graph(fn, *args, **kwargs)
+        _pynative_executor.new_graph(fn_, *args, **kwargs)
         values = fn_(*args, **kwargs)
-        _pynative_executor.end_graph(fn, values, *args, **kwargs)
+        _pynative_executor.end_graph(fn_, values, *args, **kwargs)

         run_args = args
         if kwargs:
             run_args = args + tuple(kwargs.values())

         grads = _pynative_executor.check_run(grad_, fn_, params_or_argnums, None, *run_args)
         grads = _pynative_executor.grad(fn_, grad_, params_or_argnums, None, *run_args)
-        grads = tuple(mindspore.Tensor(grad) for grad in grads)
         if attach_grads:
             for param, grad in zip(params_or_argnums, grads):
+                grad = core.tensor(grad, device=param.device)
                 if param.grad is None:
                     param.grad = grad
                 else:
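
For context, the change above stops materializing a full tuple of mindspore.Tensor copies of every gradient and instead wraps each gradient per parameter as it is attached, which is presumably part of the memory saving this commit targets. A plain-Python sketch of the attach-or-accumulate pattern (the accumulate branch is cut off in the diff and assumed here):

    class Param:
        """Stand-in for a parameter object with a .grad slot (illustrative only)."""
        def __init__(self):
            self.grad = None

    def attach_grads(params, grads):
        for param, grad in zip(params, grads):
            if param.grad is None:
                param.grad = grad        # first backward: adopt the gradient buffer as-is
            else:
                param.grad += grad       # later backwards: accumulate (assumed branch)

    p = Param()
    attach_grads([p], [1.0])
    attach_grads([p], [0.5])
    print(p.grad)  # 1.5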

mindnlp/core/npu/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -20,6 +20,13 @@
 HalfTensor = core.FloatTensor
 BFloat16Tensor = core.BFloat16Tensor

+
+class DefaultGenerators:
+    def __getitem__(self, idx):
+        return core.default_generator
+
+default_generators = DefaultGenerators()
+
 def set_compile_mode(*args, **kwargs):
     pass

mindnlp/core/npu/random.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+from typing import Union
+
+from mindnlp import core
+
+
+def get_rng_state(device: Union[int, str, core.device] = "npu") -> core.Tensor:
+    r"""Return the random number generator state of the specified GPU as a ByteTensor.
+
+    Args:
+        device (msadapter.device or int, optional): The device to return the RNG state of.
+            Default: ``'cuda'`` (i.e., ``msadapter.device('cuda')``, the current CUDA device).
+
+    .. warning::
+        This function eagerly initializes CUDA.
+    """
+
+    if isinstance(device, str):
+        device = core.device(device)
+    elif isinstance(device, int):
+        device = core.device("npu", device)
+    idx = device.index
+    default_generator = core.npu.default_generators[idx]
+    return default_generator.get_state()
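
A hedged usage sketch of the new helper (assuming the module is importable as mindnlp.core.npu.random on an NPU build; note the docstring above still carries CUDA/msadapter wording copied from its source):

    from mindnlp.core.npu.random import get_rng_state

    state = get_rng_state("npu")   # snapshot of the default generator's state
    state0 = get_rng_state(0)      # an int index is mapped to core.device("npu", 0)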
