Fixing mistral nemo. (#2276)
Narsil authored and ErikKaum committed Jul 26, 2024
1 parent 4bf3e59 commit d3ebcdc
Showing 3 changed files with 10 additions and 4 deletions.
server/text_generation_server/models/__init__.py (0 additions, 2 deletions)

@@ -762,8 +762,6 @@ def get_model(
                 default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
-                # hidden_size / num_attention_heads is wrong in `google/gemma-2-9b-it`
-                head_size=config_dict["head_dim"],
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
Second changed file, path not shown in this view (4 additions, 1 deletion)

@@ -117,7 +117,10 @@ def __init__(self, prefix: str, config, weights, layer_id):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads
+        if hasattr(config, "head_dim"):
+            self.head_size = config.head_dim
+        else:
+            self.head_size = self.hidden_size // self.num_heads

         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
server/text_generation_server/models/flash_causal_lm.py (6 additions, 1 deletion)

@@ -925,7 +925,12 @@ def __init__(
         assert self.num_kv_heads > 0

         if head_size is None:
-            self.head_size = config.hidden_size // config.num_attention_heads
+            # Some models use GQA and different sizes for o_proj
+            # and q_proj, that allows for that.
+            if hasattr(config, "head_dim"):
+                self.head_size = config.head_dim
+            else:
+                self.head_size = config.hidden_size // config.num_attention_heads
         else:
             self.head_size = head_size
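
For reference, both modeling paths above apply the same rule: trust an explicit head_dim on the model config when it is present, and only derive the head size from hidden_size // num_attention_heads as a fallback. Below is a minimal standalone sketch of that rule, not the repository's code; the sample config values are assumptions chosen to mirror a Mistral-Nemo-style model, where head_dim differs from the derived value and the old derivation therefore picked the wrong size.

# Minimal sketch (not TGI's code) of the head_size resolution this commit
# introduces. The SimpleNamespace configs below are illustrative assumptions.
from types import SimpleNamespace


def resolve_head_size(config) -> int:
    """Prefer an explicit head_dim; otherwise derive it from the hidden size."""
    if hasattr(config, "head_dim"):
        return config.head_dim
    return config.hidden_size // config.num_attention_heads


# Hypothetical Mistral-Nemo-like config: head_dim is set explicitly and does
# not equal hidden_size // num_attention_heads (5120 // 32 == 160, not 128).
nemo_like = SimpleNamespace(hidden_size=5120, num_attention_heads=32, head_dim=128)
assert resolve_head_size(nemo_like) == 128  # taken from head_dim

# Config without head_dim: fall back to the derived value, as before.
plain = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
assert resolve_head_size(plain) == 128  # 4096 // 32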
