Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Script to convert Grok-1 weights from raw JAX pickle files. #7058

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Address review comments by foldl.
  • Loading branch information
heiner committed May 25, 2024
commit ef671c693d99dc7e460f54aa69503f01cd45f84d
32 changes: 9 additions & 23 deletions convert_grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
# First operate on meta tensors to find shapes and dtypes for GGUF header.
for idx, name in enumerate(weight_names):
weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta")
meta_tensor = convert_weight(name, weight, scales, config, device="meta")
dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
f.add_tensor_info(
Expand All @@ -226,8 +226,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):

for name in weight_names:
weight, scales = weights.pop(name)
tensor = convert_weight(name, weight, scales, config.experts)
tensor = maybe_permute_tensor(name, tensor, config)
tensor = convert_weight(name, weight, scales, config)
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()

Expand Down Expand Up @@ -258,7 +257,7 @@ def from_numpy(array):
return torch.from_numpy(array)


def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None):
def convert_weight(name, weight, scales, config, dtype=torch.float32, device=None):
# copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
weight = from_numpy(weight).to(device=device, dtype=dtype)
if scales is not None:
Expand All @@ -271,32 +270,19 @@ def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=No
else:
weight = weight * scale

# Transpose linear matrix
if len(weight.shape) >= 2 and "token_embd" not in name:
if name == "token_embd":
weight *= config.embedding_multiplier_scale
elif len(weight.shape) >= 2:
# Transpose linear matrix
weight = weight.transpose(-1, -2)


if name.endswith("ffn_gate_inp") or name.endswith("_exps"):
weight = weight[experts] # gather.
weight = weight[config.experts] # gather.

return weight


def maybe_permute_tensor(name, tensor, config):
def permute(weights, n_head):
return (
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape)
)

if name.endswith("attn_k"):
return permute(tensor, config.num_key_value_heads)
elif name.endswith("attn_q"):
return permute(tensor, config.num_attention_heads)

return tensor


def extract_vocabulary_from_model(vocab):
tokens = []
scores = []
Expand Down