1
1
import numbers
2
+ import numpy as np
2
3
import mindspore
3
4
from mindspore import ops
4
5
from mindspore ._c_expression import typing
7
8
from mindspore .ops .auto_generate .gen_ops_prim import inplace_normal_op , inplace_scatter_value_op , inplace_scatter_src_reduce_op , \
8
9
inplace_scatter_src_op , inplace_fill_tensor_op , inplace_fill_scalar_op , inplace_zero_op , inplace_uniform_op , \
9
10
inplace_masked_fill_scalar_op , inplace_masked_fill_tensor_op , inplace_random_op , inplace_clamp_scalar_op , \
10
- inplace_clamp_tensor_op , inplace_copy_op , inplace_index_add_op
11
+ inplace_clamp_tensor_op , inplace_copy_op , inplace_index_add_op , inplace_erfinv_op
11
12
12
13
from mindnlp import core
13
14
from ..configs import use_pyboost
@@ -50,7 +51,7 @@ def inplace_normal(input, mean=0, std=1, *, generator=None):
50
51
if input .device .type == 'npu' :
51
52
inplace_normal_op (input , mean , std , seed , offset )
52
53
else :
53
- input .data = ops . normal (input .shape , mean , std )
54
+ input .data = core . tensor ( np . random . normal (mean , std , input .shape ), dtype = input . dtype )
54
55
return input
55
56
56
57
# uniform_
@@ -77,7 +78,8 @@ def inplace_uniform(input, *args, **kwargs):
77
78
if input .device .type == 'npu' :
78
79
inplace_uniform_op (input , from_ , to_ , seed , offset )
79
80
else :
80
- input .data = core .rand (input .shape , generator = generator_ , dtype = input .dtype ) * (to_ - from_ ) + from_
81
+ input .data = core .tensor (np .random .uniform (from_ , to_ , input .shape ), dtype = input .dtype )
82
+ # core.rand(input.shape, generator=generator_, dtype=input.dtype) * (to_ - from_) + from_
81
83
return input
82
84
83
85
def inplace_add (input , other , alpha ):
@@ -227,6 +229,13 @@ def inplace_clamp(self, min=None, max=None):
227
229
self .data = ops .clamp (self , min , max )
228
230
return self
229
231
232
+ def inplace_erfinv (self ):
233
+ if self .device .type == 'npu' :
234
+ inplace_erfinv_op (self )
235
+ else :
236
+ self .data = core .erfinv (self )
237
+ return self
238
+
230
239
__all__ = [
231
240
'inplace_copy' ,
232
241
'inplace_zero' ,
@@ -253,5 +262,6 @@ def inplace_clamp(self, min=None, max=None):
253
262
'inplace_tril' ,
254
263
'inplace_masked_fill' ,
255
264
'inplace_random' ,
256
- 'inplace_clamp'
265
+ 'inplace_clamp' ,
266
+ 'inplace_erfinv'
257
267
]
0 commit comments