
Commit f9afe0a

[hotfix] Fix KV Heads Number Assignment in KVCacheManager (#5695)
- Fix the KV head number assignment in KVCacheManager, as well as the method of accessing model config attributes
1 parent 1ace106 commit f9afe0a

2 files changed, +11 −20 lines


colossalai/inference/kv_cache/kvcache_manager.py

+6 −17

@@ -15,14 +15,6 @@
 GIGABYTE = 1024**3
 
 
-def get_model_config_attr(config: PretrainedConfig, attr_name: str):
-    if hasattr(config, attr_name):
-        return getattr(config, attr_name)
-    elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
-        return getattr(config, config.attribute_map[attr_name])
-    raise AttributeError(f"{attr_name} is not found in config")
-
-
 class KVCacheManager:
     """KVCacheManager manages both the logical cache blocks and physical KV cache (tensors).
 
@@ -53,7 +45,7 @@ class KVCacheManager:
     And it's possible to have a batch of sequences with different lengths of block tables.
     """
 
-    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
+    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> None:
         self.logger = get_dist_logger(__name__)
         self.device = get_current_device()
 
@@ -62,14 +54,11 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         # Model settings
         self.dtype = config.dtype
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
-        self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
-        self.head_num = get_model_config_attr(model_config, "num_attention_heads")
-        self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
-
-        if hasattr(config, "num_key_value_heads"):
-            self.kv_head_num = getattr(config, "num_key_value_heads")
-        elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
-            self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
+        self.num_layers = model_config.num_hidden_layers
+        self.head_num = model_config.num_attention_heads
+        self.head_size = model_config.hidden_size // self.head_num
+        if hasattr(model_config, "num_key_value_heads"):
+            self.kv_head_num = model_config.num_key_value_heads
         else:
             self.kv_head_num = self.head_num
 
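With this change, the KV head count is read directly from the model config, falling back to the attention head count when the config does not define num_key_value_heads. A minimal sketch of that fallback logic (resolve_kv_head_num and the SimpleNamespace configs below are illustrative stand-ins, not part of the codebase):

from types import SimpleNamespace

def resolve_kv_head_num(model_config) -> int:
    # Mirrors the updated KVCacheManager.__init__: prefer the explicit GQA
    # setting when the config exposes it, otherwise fall back to MHA.
    if hasattr(model_config, "num_key_value_heads"):
        return model_config.num_key_value_heads
    return model_config.num_attention_heads

# Hypothetical configs, standing in for PretrainedConfig instances.
gqa_config = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8)
mha_config = SimpleNamespace(num_attention_heads=32)

print(resolve_kv_head_num(gqa_config))  # 8: grouped-query attention
print(resolve_kv_head_num(mha_config))  # 32: falls back to the attention head count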

colossalai/shardformer/policies/llama.py

+5 −3

@@ -141,9 +141,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
             ), f"The number of attention heads must be divisible by tensor parallel size."
-            assert (
-                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
-            ), f"The number of key_value heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
+                    and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
