5 files changed (+22, -13 lines)

@@ -61,7 +61,7 @@ def generate(
         while self.llm_server.has_unfinished_requests():
             step_outputs = self.llm_server.step()
             for output in step_outputs:
-                if output.done:
+                if output.finished():
                     outputs.append(output)
                     if use_tqdm:
                         pbar.update(1)
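The loop above polls the server until every request completes; the rename from `output.done` to `output.finished()` suggests completion is now computed rather than read from a stored flag. A minimal sketch of what such a request-output object could look like, assuming per-sequence completion tracking (class and field names here are illustrative, not taken from cacheflow):

# Illustrative sketch only: a request-output object whose finished() helper
# derives completion from its sequences. Names and fields are assumptions,
# not the actual cacheflow API.
from dataclasses import dataclass, field
from typing import List


@dataclass
class _SequenceOutput:
    text: str = ""
    done: bool = False


@dataclass
class _RequestOutput:
    request_id: int
    seqs: List[_SequenceOutput] = field(default_factory=list)

    def finished(self) -> bool:
        # A request is finished once all of its sequences have completed.
        return bool(self.seqs) and all(seq.done for seq in self.seqs)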
@@ -4,6 +4,21 @@
 
 from cacheflow import activation_ops
 
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_new": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+    "gelu_fast": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+    "relu": nn.ReLU(),
+}
+
+
+def get_act_fn(act_fn: str) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn = act_fn.lower()
+    if act_fn in _ACTIVATION_REGISTRY:
+        return _ACTIVATION_REGISTRY[act_fn]
+    raise ValueError(f"Activation function {act_fn!r} is not supported.")
+
+
 
 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.
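For reference, a quick usage sketch of the new registry (assuming cacheflow is installed so the import resolves; the strings are the Hugging Face config activation names handled above):

# Sketch: resolving activations by their Hugging Face config names via the
# registry added above.
import torch

from cacheflow.model_executor.layers.activation import get_act_fn

act = get_act_fn("gelu_new")      # nn.GELU(approximate="tanh") from the registry
x = torch.randn(2, 8)
y = act(x)                        # element-wise activation, same shape as x

try:
    get_act_fn("swish")           # not registered
except ValueError as err:
    print(err)                    # Activation function 'swish' is not supported.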
@@ -27,6 +27,7 @@
 from transformers import GPT2Config
 
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -92,12 +93,7 @@ def __init__(
         self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
                                         bias=True, input_is_parallel=True,
                                         perform_initialization=False)
-
-        act_fn = config.activation_function
-        if act_fn != "gelu_new":
-            raise ValueError(f"Unsupported activation: {act_fn}. "
-                             "GPT-2 only supports gelu_new for now.")
-        self.act = torch.nn.GELU(approximate="tanh")
+        self.act = get_act_fn(config.activation_function)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
@@ -26,6 +26,7 @@
 from transformers import GPTNeoXConfig
 
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -94,10 +95,7 @@ def __init__(self, config: GPTNeoXConfig):
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
                                                input_is_parallel=True,
                                                perform_initialization=False)
-        if config.hidden_act != 'gelu':
-            raise ValueError(f'Unsupported activation: {config.hidden_act}. '
-                             'Only gelu is supported for now.')
-        self.act = torch.nn.GELU()
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
@@ -26,6 +26,7 @@
 from transformers import OPTConfig
 
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -105,8 +106,7 @@ def __init__(self, config: OPTConfig):
             bias=config.enable_bias,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
-        assert config.activation_function == 'relu'
-        self.activation_fn = nn.ReLU()
+        self.activation_fn = get_act_fn(config.activation_function)
 
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
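All three model files now follow the same pattern: resolve the activation once in `__init__` from the config string instead of hard-coding one per architecture, so an unsupported name fails with the registry's `ValueError` at construction time. A minimal sketch of that shared shape, using plain `nn.Linear` in place of the parallel linear layers (the class and argument names are illustrative only):

# Sketch only: the MLP pattern shared by the GPT-2, GPT-NeoX, and OPT changes.
# HypotheticalMLP and its arguments are illustrative; the real modules use
# ColumnParallelLinear/RowParallelLinear and their own configs.
import torch
import torch.nn as nn

from cacheflow.model_executor.layers.activation import get_act_fn


class HypotheticalMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, act_name: str):
        super().__init__()
        self.fc_in = nn.Linear(hidden_size, intermediate_size)
        self.fc_out = nn.Linear(intermediate_size, hidden_size)
        # Unsupported activation names raise ValueError here, at construction
        # time, instead of each model carrying its own hard-coded check.
        self.act = get_act_fn(act_name)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc_out(self.act(self.fc_in(hidden_states)))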