
Commit 4a151dd

Add activation registry (#126)
1 parent 057daef commit 4a151dd

File tree

5 files changed: +22 additions, -13 deletions


cacheflow/entrypoints/llm.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def generate(
         while self.llm_server.has_unfinished_requests():
             step_outputs = self.llm_server.step()
             for output in step_outputs:
-                if output.done:
+                if output.finished():
                     outputs.append(output)
                     if use_tqdm:
                         pbar.update(1)
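
For context, the hunk above patches the drain loop inside the LLM entrypoint's generate method: the server is stepped until no requests remain, and finished requests are collected. The sketch below replays the same pattern as a standalone helper, purely for illustration; apart from the identifiers visible in the diff (has_unfinished_requests, step, finished), the drain name and the tqdm wiring are assumptions, not code from this commit.

from typing import List

from tqdm import tqdm


def drain(llm_server, use_tqdm: bool = False) -> List:
    """Step the server until every request has finished (illustrative sketch).

    Assumes `llm_server` exposes has_unfinished_requests() and step(), and that
    each step output exposes finished(), exactly as in the diff above.
    """
    outputs = []
    pbar = tqdm(desc="Processed requests") if use_tqdm else None
    while llm_server.has_unfinished_requests():
        for output in llm_server.step():
            if output.finished():  # renamed from `output.done` in this commit
                outputs.append(output)
                if pbar is not None:
                    pbar.update(1)
    if pbar is not None:
        pbar.close()
    return outputs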

cacheflow/model_executor/layers/activation.py

Lines changed: 15 additions & 0 deletions
@@ -4,6 +4,21 @@
 
 from cacheflow import activation_ops
 
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_new": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+    "gelu_fast": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+    "relu": nn.ReLU(),
+}
+
+
+def get_act_fn(act_fn: str) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn = act_fn.lower()
+    if act_fn in _ACTIVATION_REGISTRY:
+        return _ACTIVATION_REGISTRY[act_fn]
+    raise ValueError(f"Activation function {act_fn!r} is not supported.")
+
 
 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.

cacheflow/model_executor/models/gpt2.py

Lines changed: 2 additions & 6 deletions
@@ -27,6 +27,7 @@
 from transformers import GPT2Config
 
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -92,12 +93,7 @@ def __init__(
         self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
                                         bias=True, input_is_parallel=True,
                                         perform_initialization=False)
-
-        act_fn = config.activation_function
-        if act_fn != "gelu_new":
-            raise ValueError(f"Unsupported activation: {act_fn}. "
-                             "GPT-2 only supports gelu_new for now.")
-        self.act = torch.nn.GELU(approximate="tanh")
+        self.act = get_act_fn(config.activation_function)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)

cacheflow/model_executor/models/gpt_neox.py

Lines changed: 2 additions & 4 deletions
@@ -26,6 +26,7 @@
 from transformers import GPTNeoXConfig
 
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -94,10 +95,7 @@ def __init__(self, config: GPTNeoXConfig):
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
                                                input_is_parallel=True,
                                                perform_initialization=False)
-        if config.hidden_act != 'gelu':
-            raise ValueError(f'Unsupported activation: {config.hidden_act}. '
-                             'Only gelu is supported for now.')
-        self.act = torch.nn.GELU()
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)

cacheflow/model_executor/models/opt.py

Lines changed: 2 additions & 2 deletions
@@ -26,6 +26,7 @@
 from transformers import OPTConfig
 
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -105,8 +106,7 @@ def __init__(self, config: OPTConfig):
             bias=config.enable_bias,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
-        assert config.activation_function == 'relu'
-        self.activation_fn = nn.ReLU()
+        self.activation_fn = get_act_fn(config.activation_function)
 
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
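
A side benefit of the registry: the per-model guard blocks deleted above collapse into one lookup, and supporting an additional activation later becomes a single dictionary entry rather than another copy of the if/assert. The snippet below is purely illustrative; the "silu" entry is not part of this commit.

import torch.nn as nn

# Hypothetical future extension of _ACTIVATION_REGISTRY in activation.py;
# the "silu" entry is NOT added by this commit.
_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_new": nn.GELU(approximate="tanh"),
    "gelu_fast": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
    "silu": nn.SiLU(),  # one new line is enough to support a new model family
}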
