@@ -15,14 +15,6 @@
 GIGABYTE = 1024**3
 
 
-def get_model_config_attr(config: PretrainedConfig, attr_name: str):
-    if hasattr(config, attr_name):
-        return getattr(config, attr_name)
-    elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
-        return getattr(config, config.attribute_map[attr_name])
-    raise AttributeError(f"{attr_name} is not found in config")
-
-
 class KVCacheManager:
     """KVCacheManager manages both the logical cache blocks and physical KV cache (tensors).
 
@@ -53,7 +45,7 @@ class KVCacheManager:
     And it's possible to have a batch of sequences with different lengths of block tables.
     """
 
-    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
+    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> None:
         self.logger = get_dist_logger(__name__)
         self.device = get_current_device()
 
@@ -62,14 +54,11 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         # Model settings
         self.dtype = config.dtype
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
-        self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
-        self.head_num = get_model_config_attr(model_config, "num_attention_heads")
-        self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
-
-        if hasattr(config, "num_key_value_heads"):
-            self.kv_head_num = getattr(config, "num_key_value_heads")
-        elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
-            self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
+        self.num_layers = model_config.num_hidden_layers
+        self.head_num = model_config.num_attention_heads
+        self.head_size = model_config.hidden_size // self.head_num
+        if hasattr(model_config, "num_key_value_heads"):
+            self.kv_head_num = model_config.num_key_value_heads
         else:
             self.kv_head_num = self.head_num
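Why the helper can go (a note beyond the diff itself): `transformers.PretrainedConfig` resolves `attribute_map` aliases inside `__getattribute__`, so plain attribute access already reaches values stored under model-specific names (e.g. GPT-2's `n_layer` for `num_hidden_layers`). The diff also fixes a lookup bug: the removed branch probed `config` (the `InferenceConfig`) for `num_key_value_heads`, while the new code correctly probes `model_config`. A minimal sketch of the same lookup pattern, using stock Hugging Face configs (illustrative only, not code from this commit):

```python
from transformers import GPT2Config, LlamaConfig

# GPT2Config stores the layer count as `n_layer`; PretrainedConfig's
# attribute_map makes it reachable under the canonical name as well,
# which is why direct access can replace the get_model_config_attr helper.
gpt2 = GPT2Config()
assert gpt2.num_hidden_layers == gpt2.n_layer

# Grouped-query models carry num_key_value_heads; configs without it
# (e.g. GPT-2) fall back to num_attention_heads, mirroring the diff's
# if/else on model_config.
for cfg in (LlamaConfig(), GPT2Config()):
    kv_head_num = getattr(cfg, "num_key_value_heads", cfg.num_attention_heads)
    assert kv_head_num > 0
```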