Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Script to convert Grok-1 weights from raw JAX pickle files. #7058

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Previous commit
Next commit
Write tensors in layer order.
  • Loading branch information
heiner committed May 25, 2024
commit 0a1ef1127fc45649694e8d837759e00ae21c237e
31 changes: 21 additions & 10 deletions convert_grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,20 @@ def get_dtype_and_ggml_type(name, tensor, ggml_type):


def dump_state_dict(f, ggml_type, input_dir, config):
weight_names = get_weight_names(config.num_hidden_layers)
weights = {}

# First operate on meta tensors to find shapes and dtypes for GGUF header.
for idx, name in enumerate(weight_names):
weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
# Load weights in file order (mmap'ed).
for idx, name in enumerate(get_weight_names(config.num_hidden_layers)):
weights[name] = get_weights(f"{input_dir}/tensor{idx:05}_000")

logging.debug("Loaded %i files", len(weights))

# But write in layer order.
weight_names = get_weight_names(config.num_hidden_layers, lexicographic=False)

# Operate on meta tensors to find shapes and dtypes for GGUF header.
for name in weight_names:
weight, scales = weights[name]
meta_tensor = convert_weight(name, weight, scales, config, device="meta")
dtype, tensor_ggml_type = get_dtype_and_ggml_type(name, meta_tensor, ggml_type)
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
Expand All @@ -213,8 +221,6 @@ def dump_state_dict(f, ggml_type, input_dir, config):
quantized_meta_tensor.nbytes,
tensor_ggml_type,
)
weights[name] = weight, scales
logging.debug("Loaded %i files", len(weight_names))

f.write_header_to_file()
f.write_kv_data_to_file()
Expand Down Expand Up @@ -244,7 +250,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
except NameError:
pass

if len(tensor_info) != len(weight_names):
if weights:
logging.warning("Not all tensors are converted")


Expand Down Expand Up @@ -293,8 +299,10 @@ def extract_vocabulary_from_model(vocab):
return tokens, scores, toktypes


def get_weight_names(num_hidden_layers=64):
"""Return Grok-1 weight names, in the order in which they are in the tensor#####_000 files."""
def get_weight_names(num_hidden_layers=64, lexicographic=True):
"""Return Grok-1 weight names.

If `lexicographic` is set, the order is as in the tensor#####_000 files."""

weight_names = [
gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD],
Expand All @@ -317,7 +325,10 @@ def get_weight_names(num_hidden_layers=64):
)

layers = [str(bid) for bid in range(64)]
layers.sort() # Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ...

if lexicographic:
# Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ...
layers.sort()

for bid in layers[:num_hidden_layers]:
for key in layer:
Expand Down