Skip to content

Commit 4391752

Browse files
authored
fix normal memory issue on GPU (#2107)
1 parent 6d94df0 commit 4391752

File tree

4 files changed

+25
-6
lines changed

4 files changed

+25
-6
lines changed

mindnlp/core/_tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,9 @@ def tobytes(self):
797797
Tensor.index_add_ = ops.inplace_index_add
798798
StubTensor.index_add_ = ops.inplace_index_add
799799

800+
Tensor.erfinv_ = ops.inplace_erfinv
801+
StubTensor.erfinv_ = ops.inplace_erfinv
802+
800803
def _rebuild_from_type_v2(func, new_type, args, state):
801804
ret = func(*args)
802805
return ret

mindnlp/core/ops/creation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from mindspore.ops._primitive_cache import _get_cache_prim
1313
from ..configs import use_pyboost, ON_ORANGE_PI
1414
from .._bind import get_default_dtype, get_default_device
15+
from .._dtype import dtype2np
1516
from .utils import py2dtype
1617
from .other import finfo
1718

@@ -195,11 +196,15 @@ def empty(*size, dtype=None, device=None, requires_grad=False, pin_memory=False,
195196
device = 'meta'
196197

197198
# Avoid the problem that using empty causes in irecv and recv.
198-
if device != 'meta':
199+
if device not in ['meta', 'GPU']:
199200
out = mindspore.mint.empty(size, dtype=dtype, device=device)
200201
else:
201202
out = CTensor(dtype=dtype, shape=size)
202203
out = mindspore.Tensor(out)
204+
# else:
205+
# out = np.empty(size, dtype=dtype2np[dtype])
206+
# out = mindspore.Tensor(out)
207+
203208
if requires_grad:
204209
out.requires_grad = True
205210
return out

mindnlp/core/ops/inplace.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numbers
2+
import numpy as np
23
import mindspore
34
from mindspore import ops
45
from mindspore._c_expression import typing
@@ -7,7 +8,7 @@
78
from mindspore.ops.auto_generate.gen_ops_prim import inplace_normal_op, inplace_scatter_value_op, inplace_scatter_src_reduce_op, \
89
inplace_scatter_src_op, inplace_fill_tensor_op, inplace_fill_scalar_op, inplace_zero_op, inplace_uniform_op, \
910
inplace_masked_fill_scalar_op, inplace_masked_fill_tensor_op, inplace_random_op, inplace_clamp_scalar_op, \
10-
inplace_clamp_tensor_op, inplace_copy_op, inplace_index_add_op
11+
inplace_clamp_tensor_op, inplace_copy_op, inplace_index_add_op, inplace_erfinv_op
1112

1213
from mindnlp import core
1314
from ..configs import use_pyboost
@@ -50,7 +51,7 @@ def inplace_normal(input, mean=0, std=1, *, generator=None):
5051
if input.device.type == 'npu':
5152
inplace_normal_op(input, mean, std, seed, offset)
5253
else:
53-
input.data = ops.normal(input.shape, mean, std)
54+
input.data = core.tensor(np.random.normal(mean, std, input.shape), dtype=input.dtype)
5455
return input
5556

5657
# uniform_
@@ -77,7 +78,8 @@ def inplace_uniform(input, *args, **kwargs):
7778
if input.device.type == 'npu':
7879
inplace_uniform_op(input, from_, to_, seed, offset)
7980
else:
80-
input.data = core.rand(input.shape, generator=generator_, dtype=input.dtype) * (to_ - from_) + from_
81+
input.data = core.tensor(np.random.uniform(from_, to_, input.shape), dtype=input.dtype)
82+
# core.rand(input.shape, generator=generator_, dtype=input.dtype) * (to_ - from_) + from_
8183
return input
8284

8385
def inplace_add(input, other, alpha):
@@ -227,6 +229,13 @@ def inplace_clamp(self, min=None, max=None):
227229
self.data = ops.clamp(self, min, max)
228230
return self
229231

232+
def inplace_erfinv(self):
    """Apply the inverse error function to ``self`` in place and return it.

    When ``self.device.type`` equals ``'npu'`` the ``inplace_erfinv_op``
    primitive is invoked directly on the tensor. For every other device type
    the result of ``core.erfinv(self)`` is assigned to ``self.data``.

    Returns:
        The same ``self`` object that was passed in.
    """
    on_npu = self.device.type == 'npu'
    if not on_npu:
        # Fallback path: compute out of place, then rebind the tensor data.
        self.data = core.erfinv(self)
        return self
    inplace_erfinv_op(self)
    return self
238+
230239
__all__ = [
231240
'inplace_copy',
232241
'inplace_zero',
@@ -253,5 +262,6 @@ def inplace_clamp(self, min=None, max=None):
253262
'inplace_tril',
254263
'inplace_masked_fill',
255264
'inplace_random',
256-
'inplace_clamp'
265+
'inplace_clamp',
266+
'inplace_erfinv'
257267
]

mindnlp/core/ops/random.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def randn(*size, generator=None, dtype=None, **kwargs):
127127
dtype = get_default_dtype()
128128
if use_pyboost() and has_randn:
129129
return mindspore.mint.randn(*new_size, generator=generator, dtype=dtype)
130-
return ops.randn(*new_size, dtype=dtype)
130+
# return ops.randn(*new_size, dtype=dtype)
131+
return mindspore.Tensor(np.random.randn(*new_size), dtype=dtype)
131132

132133
# randn_like
133134
has_randn_like = hasattr(mindspore.mint, 'randn_like')

0 commit comments

Comments
 (0)