[REFACTOR] gptqmodel_post_init (ModelCloud#103)
* refactor gptqmodel_post_init and mod marlin support load shard

* mod clean up

* revert code

* mod quanttype to qlinear cls

* mod clean up

* rename qlinear cls
PZS-ModelCloud authored Jun 29, 2024
1 parent ede2154 commit 824f5d9
Showing 2 changed files with 23 additions and 37 deletions.
2 changes: 1 addition & 1 deletion gptqmodel/models/deepseek_v2.py
@@ -39,4 +39,4 @@ class DeepSeekV2GPTQ(BaseGPTQModel):
# included in layer 1-59
["mlp.shared_experts.gate_proj", "mlp.shared_experts.up_proj"],
["mlp.shared_experts.down_proj"],
]
]
58 changes: 22 additions & 36 deletions gptqmodel/utils/model.py
@@ -17,6 +17,8 @@

from ..models._const import CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS
from ..nn_modules.qlinear import BaseQuantLinear
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as ExllamaQuantLinear
from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear as ExllamaV2QuantLinear
from ..quantization import FORMAT, QuantizeConfig
from .backend import Backend
from .importer import select_quant_linear
@@ -393,37 +395,28 @@ def simple_dispatch_model(model, device_map):

return model


# TODO: refactor. very strange post_init has to re-determine qlinear type again
# when qlinear type is selected, it should auto-override the model post_init method and
# not have to go about looping over modules to match qlinear type a second time as it is
# very prone to bugs
def gptqmodel_post_init(model, use_act_order: bool, max_input_length: Optional[int] = None):
"""
The max_input_length argument is specific to the exllama backend, that requires to initialize a buffer temp_state.
"""

# post init for bitblas backend.
device_to_buffers_size = {}
for _, submodule in model.named_modules():
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "bitblas":
submodule.post_init()

# exllama
model_uses_exllama = False
# exllamav2
fixed_bytes = {}
model_uses_exllamav2 = False

for name, submodule in model.named_modules():
if isinstance(submodule, BaseQuantLinear) and submodule.QUANT_TYPE == "exllama":
if isinstance(submodule, ExllamaQuantLinear):
model_uses_exllama = True
device = submodule.qweight.device
if device not in device_to_buffers_size:
device_to_buffers_size[device] = {
"max_dq_buffer_size": 1,
"max_inner_outer_dim": 1,
}

if not use_act_order:
submodule._use_act_order = False
else:
submodule._use_act_order = True
submodule._use_act_order = True if use_act_order else False

# Disable this heuristic for detecting act_order, but it could be used instead of the config.
"""
@@ -447,6 +440,11 @@ def gptqmodel_post_init(model, use_act_order: bool, max_input_length: Optional[i
submodule.infeatures,
submodule.outfeatures,
)
elif isinstance(submodule, ExllamaV2QuantLinear):
model_uses_exllamav2 = True
device = submodule.qweight.device
scratch_fixed = submodule.scratch_space_fixed()
fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0))

if model_uses_exllama:
# To be honest this is quite ugly, not proud of this.
@@ -496,22 +494,6 @@ def gptqmodel_post_init(model, use_act_order: bool, max_input_length: Optional[i
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

# The buffers need to have been initialized first before calling make_q4.
for name, submodule in model.named_modules():
if isinstance(submodule, BaseQuantLinear) and submodule.QUANT_TYPE == "exllama":
submodule.post_init()

# exllamav2
fixed_bytes = {}
model_uses_exllamav2 = False

for _, submodule in model.named_modules():
if isinstance(submodule, BaseQuantLinear) and submodule.QUANT_TYPE == "exllamav2":
model_uses_exllamav2 = True
device = submodule.qweight.device
scratch_fixed = submodule.scratch_space_fixed()
fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0))

if model_uses_exllamav2:
from ..nn_modules.qlinear.qlinear_exllamav2 import ExLlamaV2DeviceTensors

@@ -522,10 +504,14 @@ def gptqmodel_post_init(model, use_act_order: bool, max_input_length: Optional[i
# have persistent buffers, otherwise we will get OOM
model.device_tensors = device_tensors

for _, submodule in model.named_modules():
if isinstance(submodule, BaseQuantLinear) and submodule.QUANT_TYPE == "exllamav2":
device = submodule.qweight.device
submodule.post_init(temp_dq=model.device_tensors[device])
# The buffers need to have been initialized first before calling make_q4.
for _, submodule in model.named_modules():
if isinstance(submodule, ExllamaV2QuantLinear):
device = submodule.qweight.device
submodule.post_init(temp_dq=model.device_tensors[device])
elif isinstance(submodule, BaseQuantLinear):
submodule.post_init()

torch.cuda.empty_cache()

return model
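
The core of the change above is replacing string matches on QUANT_TYPE with isinstance checks against the concrete QuantLinear classes, so exllama v1/v2 detection and the final post_init calls each happen in a single pass over model.named_modules(). Below is a minimal, self-contained sketch of that dispatch pattern; the *Stub class names and post_init_dispatch are illustrative placeholders, not gptqmodel's real classes.

# Sketch of the isinstance-based post_init dispatch this commit moves toward.
# The *Stub classes are placeholders standing in for gptqmodel's BaseQuantLinear
# and the exllama / exllamav2 QuantLinear implementations.
import torch.nn as nn

class BaseQuantLinearStub(nn.Module):
    def post_init(self):
        print(f"{type(self).__name__}: generic post_init")

class ExllamaLinearStub(BaseQuantLinearStub):
    def post_init(self):
        print(f"{type(self).__name__}: allocate exllama buffers")

class ExllamaV2LinearStub(BaseQuantLinearStub):
    def post_init(self, temp_dq=None):
        print(f"{type(self).__name__}: bind scratch space: {temp_dq}")

def post_init_dispatch(model: nn.Module) -> nn.Module:
    # One pass, dispatching on type rather than on a QUANT_TYPE string,
    # mirroring the final loop in the diff above: exllamav2 modules get their
    # per-device scratch tensor, everything else gets a plain post_init().
    for _, submodule in model.named_modules():
        if isinstance(submodule, ExllamaV2LinearStub):
            submodule.post_init(temp_dq="per-device scratch")
        elif isinstance(submodule, BaseQuantLinearStub):
            submodule.post_init()
    return model

if __name__ == "__main__":
    post_init_dispatch(nn.Sequential(ExllamaLinearStub(), ExllamaV2LinearStub(), BaseQuantLinearStub()))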
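
For orientation, a hedged usage sketch of gptqmodel_post_init, the entry point this commit refactors. The import path is inferred from the file location gptqmodel/utils/model.py; loading a real quantized model is outside this diff, so a plain torch module stands in below, in which case the function should effectively be a no-op that returns the model.

# Hypothetical call site for the refactored gptqmodel_post_init (sketch only;
# assumes the gptqmodel package is installed, import path inferred from the
# file path shown in this diff).
import torch.nn as nn
from gptqmodel.utils.model import gptqmodel_post_init

model = nn.Linear(16, 16)  # stand-in; real callers pass a loaded quantized model
model = gptqmodel_post_init(model, use_act_order=False)
# max_input_length is only relevant for the exllama backend, which uses it to
# size its temp_state buffer (see the docstring in the diff).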
