Skip to content

Commit

Permalink
Add starcoder2 support (AutoGPTQ#578)
Browse files Browse the repository at this point in the history
* add starcoder2 support

* add starcoder2 support
  • Loading branch information
TechxGenus authored Mar 18, 2024
1 parent 80b0571 commit 09289d8
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 1 deletion.
1 change: 1 addition & 0 deletions auto_gptq/modeling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
from .qwen2 import Qwen2GPTQForCausalLM
from .rw import RWGPTQForCausalLM
from .stablelmepoch import StableLMEpochGPTQForCausalLM
from .starcoder2 import Starcoder2GPTQForCausalLM
from .xverse import XverseGPTQForCausalLM
from .yi import YiGPTQForCausalLM
3 changes: 2 additions & 1 deletion auto_gptq/modeling/_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
SUPPORTED_MODELS.append("qwen2")
if compare_transformers_version("v4.38.0", op="ge"):
SUPPORTED_MODELS.append("gemma")

if compare_transformers_version("v4.39.0.dev0", op="ge"):
SUPPORTED_MODELS.append("starcoder2")

EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048

Expand Down
2 changes: 2 additions & 0 deletions auto_gptq/modeling/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .qwen2 import Qwen2GPTQForCausalLM
from .rw import RWGPTQForCausalLM
from .stablelmepoch import StableLMEpochGPTQForCausalLM
from .starcoder2 import Starcoder2GPTQForCausalLM
from .xverse import XverseGPTQForCausalLM
from .yi import YiGPTQForCausalLM

Expand All @@ -48,6 +49,7 @@
"xverse": XverseGPTQForCausalLM,
"deci": DeciLMGPTQForCausalLM,
"stablelm_epoch": StableLMEpochGPTQForCausalLM,
"starcoder2": Starcoder2GPTQForCausalLM,
"mixtral": MixtralGPTQForCausalLM,
"qwen2": Qwen2GPTQForCausalLM,
"longllama": LongLlamaGPTQForCausalLM,
Expand Down
21 changes: 21 additions & 0 deletions auto_gptq/modeling/starcoder2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from logging import getLogger

from ._base import BaseGPTQForCausalLM


logger = getLogger(__name__)


class Starcoder2GPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "Starcoder2DecoderLayer"
layers_block_name = "model.layers"
outside_layer_modules = ["model.embed_tokens", "model.norm"]
inside_layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
["self_attn.o_proj"],
["mlp.c_fc"],
["mlp.c_proj"],
]


__all__ = ["Starcoder2GPTQForCausalLM"]

0 comments on commit 09289d8

Please sign in to comment.