|
| 1 | +import re |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | +from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer |
| 7 | +from transformers.utils.quantization_config import Mxfp4Config |
| 8 | + |
| 9 | +from QEfficient.transformers.quantizers.quantizer_utils import convert_moe_packed_tensors, get_keys_to_not_convert |
| 10 | +from QEfficient.utils.logging_utils import logger |
| 11 | + |
| 12 | + |
| 13 | +class QEffMxfp4GptOssExperts(nn.Module): |
| 14 | + def __init__(self, config): |
| 15 | + super().__init__() |
| 16 | + self.config = config |
| 17 | + |
| 18 | + self.num_experts = config.num_local_experts |
| 19 | + self.intermediate_size = config.intermediate_size |
| 20 | + self.hidden_size = config.hidden_size |
| 21 | + |
| 22 | + self.gate_up_proj_blocks = nn.Parameter( |
| 23 | + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), |
| 24 | + requires_grad=False, |
| 25 | + ) |
| 26 | + self.gate_up_proj_scales = nn.Parameter( |
| 27 | + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8), |
| 28 | + requires_grad=False, |
| 29 | + ) |
| 30 | + self.gate_up_proj_bias = nn.Parameter( |
| 31 | + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False |
| 32 | + ) |
| 33 | + |
| 34 | + self.down_proj_blocks = nn.Parameter( |
| 35 | + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), |
| 36 | + requires_grad=False, |
| 37 | + ) |
| 38 | + self.down_proj_scales = nn.Parameter( |
| 39 | + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), |
| 40 | + requires_grad=False, |
| 41 | + ) |
| 42 | + self.down_proj_bias = nn.Parameter( |
| 43 | + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False |
| 44 | + ) |
| 45 | + self.alpha = 1.702 |
| 46 | + self.limit = 7.0 |
| 47 | + |
| 48 | + self.gate_up_proj_precision_config = None |
| 49 | + self.down_proj_precision_config = None |
| 50 | + |
| 51 | + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: |
| 52 | + gate_up_proj = convert_moe_packed_tensors( |
| 53 | + self.gate_up_proj_blocks, self.gate_up_proj_scales, dtype=torch.float32 |
| 54 | + ) |
| 55 | + down_proj = convert_moe_packed_tensors(self.down_proj_blocks, self.down_proj_scales, dtype=torch.float32) |
| 56 | + batch_size = hidden_states.shape[0] |
| 57 | + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) |
| 58 | + num_experts = routing_weights.shape[1] |
| 59 | + hidden_states = hidden_states.repeat(num_experts, 1) |
| 60 | + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) |
| 61 | + gate_up = torch.bmm(hidden_states, gate_up_proj) + self.gate_up_proj_bias[..., None, :] |
| 62 | + gate, up = gate_up[..., ::2], gate_up[..., 1::2] |
| 63 | + gate = gate.clamp(min=None, max=self.limit) |
| 64 | + up = up.clamp(min=-self.limit, max=self.limit) |
| 65 | + glu = gate * torch.sigmoid(gate * self.alpha) |
| 66 | + next_states = torch.bmm(((up + 1) * glu), down_proj) |
| 67 | + next_states = next_states + self.down_proj_bias[..., None, :] |
| 68 | + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) |
| 69 | + next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] |
| 70 | + next_states = next_states.sum(dim=0) |
| 71 | + return next_states |
| 72 | + |
| 73 | + |
| 74 | +def should_convert_module(current_key_name, patterns): |
| 75 | + current_key_name_str = ".".join(current_key_name) |
| 76 | + if not any( |
| 77 | + re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns |
| 78 | + ): |
| 79 | + return True |
| 80 | + return False |
| 81 | + |
| 82 | + |
| 83 | +class QEffMxfp4Config(Mxfp4Config): |
| 84 | + """ |
| 85 | + Currently there is not need to change the implementation of Mxfp4Config |
| 86 | + This is placeholder for future when we would want to change this |
| 87 | + """ |
| 88 | + |
| 89 | + pass |
| 90 | + |
| 91 | + |
| 92 | +class QEffMxfp4HfQuantizer(Mxfp4HfQuantizer): |
| 93 | + def validate_environment(self, *args, **kwargs): |
| 94 | + return True |
| 95 | + |
| 96 | + def update_torch_dtype(self, torch_dtype): |
| 97 | + if torch_dtype not in [None, torch.float32]: |
| 98 | + logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None") |
| 99 | + return None |
| 100 | + |
| 101 | + def _process_model_before_weight_loading( |
| 102 | + self, |
| 103 | + model: torch.nn.Module, |
| 104 | + keep_in_fp32_modules: Optional[list[str]] = None, |
| 105 | + **kwargs, |
| 106 | + ): |
| 107 | + self.modules_to_not_convert = get_keys_to_not_convert(model) |
| 108 | + self.modules_to_not_convert = ( |
| 109 | + ["lm_head"] if self.modules_to_not_convert is None else self.modules_to_not_convert |
| 110 | + ) |
| 111 | + self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) |
| 112 | + self.modules_to_not_convert = list(set(self.modules_to_not_convert)) |
| 113 | + config = model.config |
| 114 | + |
| 115 | + # -- Defining local method as it uses lot of local variables -- |
| 116 | + def _replace_with_mxfp4_linear( |
| 117 | + model, |
| 118 | + modules_to_not_convert=None, |
| 119 | + current_key_name=None, |
| 120 | + quantization_config=None, |
| 121 | + has_been_replaced=False, |
| 122 | + ): |
| 123 | + if current_key_name is None: |
| 124 | + current_key_name = [] |
| 125 | + |
| 126 | + for name, module in model.named_children(): |
| 127 | + current_key_name.append(name) |
| 128 | + if not should_convert_module(current_key_name, modules_to_not_convert): |
| 129 | + current_key_name.pop(-1) |
| 130 | + continue |
| 131 | + if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize: |
| 132 | + model._modules[name] = QEffMxfp4GptOssExperts(config) |
| 133 | + has_been_replaced = True |
| 134 | + if len(list(module.children())) > 0: |
| 135 | + _, has_been_replaced = _replace_with_mxfp4_linear( |
| 136 | + module, |
| 137 | + modules_to_not_convert, |
| 138 | + current_key_name, |
| 139 | + quantization_config, |
| 140 | + has_been_replaced=has_been_replaced, |
| 141 | + ) |
| 142 | + current_key_name.pop(-1) |
| 143 | + return model, has_been_replaced |
| 144 | + |
| 145 | + _replace_with_mxfp4_linear( |
| 146 | + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config |
| 147 | + ) |
| 148 | + model.config.quantization_config = self.quantization_config |
0 commit comments