4848from vllm .sequence import SamplerOutput
4949
5050
@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Apply bias-free layer normalization over the last dimension.

    The statistics and the scaling are computed in float32; the result is
    cast back to the dtype ``hidden_states`` had on entry.

    Args:
        hidden_states: Tensor to normalize along its last dimension.
        weight: Scale tensor multiplied into the normalized values
            (converted to float32 before the multiplication).
        variance_epsilon: Value added to the variance before the reciprocal
            square root is taken.

    Returns:
        The normalized, scaled tensor in the original dtype of
        ``hidden_states``.
    """
    orig_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)

    # Center once and reuse the result for both the variance and the output.
    centered = x - x.mean(-1, keepdim=True)
    # Mean of squared deviations, i.e. the biased (population) variance.
    var = centered.pow(2).mean(-1, keepdim=True)
    inv_std = torch.rsqrt(var + variance_epsilon)

    scaled = weight.to(torch.float32) * (centered * inv_std)
    return scaled.to(orig_dtype)
61+
62+
5163class LayerNorm (nn .Module ):
5264
5365 def __init__ (self , param_shape = None , eps = 1e-5 ):
@@ -57,14 +69,9 @@ def __init__(self, param_shape=None, eps=1e-5):
5769 set_weight_attrs (self .weight , {"weight_loader" : self .weight_loader })
5870
def forward(self, hidden_states, residuals=None):
    """Normalize ``hidden_states`` and pass ``residuals`` through untouched.

    The normalization itself is delegated to ``layer_norm_func`` using this
    module's ``weight`` and ``variance_epsilon``.

    Returns:
        A ``(normalized_hidden_states, residuals)`` tuple; ``residuals`` is
        returned exactly as it was received.
    """
    normed = layer_norm_func(hidden_states, self.weight,
                             self.variance_epsilon)
    return normed, residuals
6875
6976 def weight_loader (self , param : Parameter , loaded_weight : torch .Tensor ):
7077 tp_rank = get_tensor_model_parallel_rank ()
0 commit comments