@@ -374,68 +374,36 @@ def forward(self, var, block=None):
                                  ["uint16", "float16", "float32", "float64"],
                                  "guassian_random")
 
-        # to be compatible of fp16 initalizers
-        if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
-            out_dtype = VarDesc.VarType.FP32
-            out_var = block.create_var(name=unique_name.generate(".".join(
-                ['normal_init', var.name, 'tmp'])),
-                                       shape=var.shape,
-                                       dtype=out_dtype,
-                                       type=VarDesc.VarType.LOD_TENSOR,
-                                       persistable=False)
-        else:
-            out_dtype = var.dtype
-            out_var = var
-
         if self._seed == 0:
             self._seed = block.program.random_seed
 
         if in_dygraph_mode():
             place = _current_expected_place()
             out_var = _C_ops.gaussian_random(var.shape, self._mean,
                                              self._std_dev, self._seed,
-                                             out_dtype, place)
-
-            if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
-                var_tmp = _C_ops.cast(out_var, var.dtype)
-                var_tmp._share_underline_tensor_to(var)
-            else:
-                out_var._share_underline_tensor_to(var)
+                                             var.dtype, place)
+            out_var._share_underline_tensor_to(var)
             return None
 
         if _in_legacy_dygraph():
             out_var = _legacy_C_ops.gaussian_random(
-                'shape', var.shape, 'dtype', out_dtype, 'mean', self._mean,
+                'shape', var.shape, 'dtype', var.dtype, 'mean', self._mean,
                 'std', self._std_dev, 'seed', self._seed, 'use_mkldnn', False)
 
-            if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
-                var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype', out_var.dtype,
-                                             'out_dtype', var.dtype)
-                var_tmp._share_underline_tensor_to(var)
-            else:
-                out_var._share_underline_tensor_to(var)
+            out_var._share_underline_tensor_to(var)
             return None
         else:
             op = block.append_op(type="gaussian_random",
-                                 outputs={"Out": out_var},
+                                 outputs={"Out": var},
                                  attrs={
                                      "shape": var.shape,
-                                     "dtype": out_dtype,
+                                     "dtype": var.dtype,
                                      "mean": self._mean,
                                      "std": self._std_dev,
                                      "seed": self._seed,
                                      "use_mkldnn": False
                                  },
                                  stop_gradient=True)
-
-            if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
-                block.append_op(type="cast",
-                                inputs={"X": out_var},
-                                outputs={"Out": var},
-                                attrs={
-                                    "in_dtype": out_var.dtype,
-                                    "out_dtype": var.dtype
-                                })
             var.op = op
             return op
 
@@ -695,7 +663,7 @@ def forward(self, var, block=None):
                                  outputs={"Out": out_var},
                                  attrs={
                                      "shape": out_var.shape,
-                                     "dtype": out_dtype,
+                                     "dtype": out_var.dtype,
                                      "mean": 0.0,
                                      "std": std,
                                      "seed": self._seed