SqueezeLLM Support #1326

Merged
merged 19 commits on Oct 22, 2023
Fix Mistral
WoosukKwon committed Oct 22, 2023
commit 8f0e1f7a81a802c49597d364dbf47f64af6d59e8
12 changes: 7 additions & 5 deletions in vllm/model_executor/models/mistral.py
@@ -331,10 +331,10 @@ def load_weights(self,
             if "rotary_emb.inv_freq" in name:
                 continue

-            is_packed = False
+            packed_dim = None
             is_transposed = False
             if self.quant_config is not None:
-                is_packed = self.quant_config.is_packed(name)
+                packed_dim = self.quant_config.get_packed_dim(name)
                 is_transposed = self.quant_config.is_transposed(name)
             if is_transposed:
                 loaded_weight = convert_pyslice_to_tensor(loaded_weight)
@@ -348,9 +348,11 @@ def load_weights(self,
                 if is_transposed:
                     param = param.T

-                if is_packed:
-                    shard_size //= self.quant_config.pack_factor
-                    offset //= self.quant_config.pack_factor
+                if packed_dim is not None:
+                    shard_dim = 0 if not is_transposed else 1
+                    if packed_dim == shard_dim:
+                        shard_size //= self.quant_config.pack_factor
+                        offset //= self.quant_config.pack_factor

                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
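The change above rescales the shard bounds only when the quantization pack dimension coincides with the dimension being sliced for tensor parallelism; previously a simple is_packed flag divided the bounds unconditionally. A minimal standalone sketch of that check follows; the adjust_shard_for_packing helper and the example numbers are hypothetical illustrations, not vLLM's actual loader code.

from typing import Optional

def adjust_shard_for_packing(shard_size: int, offset: int, pack_factor: int,
                             packed_dim: Optional[int], is_transposed: bool):
    """Rescale shard bounds only when the packed dim is the sharded dim.

    SqueezeLLM-style quantization packs several low-bit values into one
    integer along a single dimension (packed_dim). If that dimension is also
    the one the tensor-parallel shard is taken along, the slice bounds are
    expressed in packed units and must be divided by the pack factor;
    otherwise they are left unchanged.
    """
    shard_dim = 0 if not is_transposed else 1  # same rule as in the diff
    if packed_dim is not None and packed_dim == shard_dim:
        shard_size //= pack_factor
        offset //= pack_factor
    return shard_size, offset

# Packed along the sharded dim of a transposed weight -> bounds are rescaled.
print(adjust_shard_for_packing(4096, 4096, 8, packed_dim=1, is_transposed=True))
# (512, 512)
# Packed along the other dim -> slice bounds stay in unpacked units.
print(adjust_shard_for_packing(4096, 4096, 8, packed_dim=0, is_transposed=True))
# (4096, 4096)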