@@ -679,20 +679,23 @@ class MSRAInitializer(Initializer):
679679
680680 .. math::
681681
682- x = \sqrt{\frac{6.0}{fan\_in}}
682+ x = gain \times \sqrt{\frac{3}{fan\_in}}
683683
684684 In case of Normal distribution, the mean is 0 and the standard deviation
685685 is
686686
687687 .. math::
688688
689- \sqrt{\frac{2.0}{fan\_in}}
689+ \frac{gain}{\sqrt{fan\_in}}
690690
691691 Args:
692692 uniform (bool): whether to use uniform or normal distribution
693- fan_in (float32|None): fan_in for MSRAInitializer. If None, it is\
694- inferred from the variable. default is None.
693+ fan_in (float32|None): fan_in (in_features) of the trainable Tensor.\
694+ If None, it will be inferred automatically. If you don't want to use the in_features of the Tensor,\
695+ you can set the value of 'fan_in' yourself. default is None.
695696 seed (int32): random seed
697+ negative_slope (float, optional): negative_slope (only used with leaky_relu). default is 0.0.
698+ nonlinearity (str, optional): the non-linear function. default is 'relu'.
696699
697700 Note:
698701 It is recommended to set fan_in to None for most cases.
@@ -709,7 +712,12 @@ class MSRAInitializer(Initializer):
709712
710713 """
711714
712- def __init__ (self , uniform = True , fan_in = None , seed = 0 ):
715+ def __init__ (self ,
716+ uniform = True ,
717+ fan_in = None ,
718+ seed = 0 ,
719+ negative_slope = 0 ,
720+ nonlinearity = 'relu' ):
713721 """Constructor for MSRAInitializer
714722 """
715723 assert uniform is not None
@@ -718,6 +726,8 @@ def __init__(self, uniform=True, fan_in=None, seed=0):
718726 self ._uniform = uniform
719727 self ._fan_in = fan_in
720728 self ._seed = seed
729+ self ._negative_slope = negative_slope
730+ self ._nonlinearity = nonlinearity
721731
722732 def __call__ (self , var , block = None ):
723733 """Initialize the input tensor with MSRA initialization.
@@ -759,13 +769,16 @@ def __call__(self, var, block=None):
759769
760770 if framework ._non_static_mode ():
761771 if self ._uniform :
762- limit = np .sqrt (6.0 / float (fan_in ))
772+ gain = calculate_gain (self ._nonlinearity , self ._negative_slope )
773+ limit = gain * math .sqrt (3.0 / float (fan_in ))
774+
763775 out_var = _C_ops .uniform_random ('shape' , out_var .shape , 'min' ,
764776 - limit , 'max' , limit , 'seed' ,
765777 self ._seed , 'dtype' ,
766778 int (out_dtype ))
767779 else :
768- std = math .sqrt (2.0 / float (fan_in ))
780+ gain = calculate_gain (self ._nonlinearity , self ._negative_slope )
781+ std = gain / math .sqrt (float (fan_in ))
769782 if in_dygraph_mode ():
770783 place = _current_expected_place ()
771784 out_var = _C_ops .final_state_gaussian_random (
@@ -786,33 +799,33 @@ def __call__(self, var, block=None):
786799 return None
787800 else :
788801 if self ._uniform :
789- limit = np . sqrt ( 6.0 / float ( fan_in ) )
790- op = block . append_op (
791- type = "uniform_random" ,
792- inputs = {},
793- outputs = {"Out" : out_var },
794- attrs = {
795- "shape" : out_var .shape ,
796- "dtype" : int (out_dtype ),
797- "min" : - limit ,
798- "max" : limit ,
799- "seed" : self ._seed
800- },
801- stop_gradient = True )
802+ gain = calculate_gain ( self . _nonlinearity , self . _negative_slope )
803+ limit = gain * math . sqrt ( 3.0 / float ( fan_in ))
804+ op = block . append_op ( type = "uniform_random" ,
805+ inputs = {},
806+ outputs = {"Out" : out_var },
807+ attrs = {
808+ "shape" : out_var .shape ,
809+ "dtype" : int (out_dtype ),
810+ "min" : - limit ,
811+ "max" : limit ,
812+ "seed" : self ._seed
813+ },
814+ stop_gradient = True )
802815
803816 else :
804- std = np . sqrt ( 2.0 / float ( fan_in ) )
805- op = block . append_op (
806- type = "gaussian_random" ,
807- outputs = {"Out" : out_var },
808- attrs = {
809- "shape" : out_var .shape ,
810- "dtype" : int (out_dtype ),
811- "mean" : 0.0 ,
812- "std" : std ,
813- "seed" : self ._seed
814- },
815- stop_gradient = True )
817+ gain = calculate_gain ( self . _nonlinearity , self . _negative_slope )
818+ std = gain / math . sqrt ( float ( fan_in ))
819+ op = block . append_op ( type = "gaussian_random" ,
820+ outputs = {"Out" : out_var },
821+ attrs = {
822+ "shape" : out_var .shape ,
823+ "dtype" : int (out_dtype ),
824+ "mean" : 0.0 ,
825+ "std" : std ,
826+ "seed" : self ._seed
827+ },
828+ stop_gradient = True )
816829
817830 if var .dtype == VarDesc .VarType .FP16 or (
818831 var .dtype == VarDesc .VarType .BF16 and not self ._uniform ):
0 commit comments