Download fix #1366

Merged (3 commits) on Nov 19, 2024
35 changes: 24 additions & 11 deletions torchchat/cli/convert_hf_checkpoint.py
@@ -39,19 +39,14 @@ def convert_hf_checkpoint(
     config = TransformerArgs.from_params(config_args)
     print(f"Model config {config.__dict__}")
 
-    # Load the json file containing weight mapping
+    # Find all candidate weight mapping index files
     model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
     # Llama 3 has a consolidated model and tokenizer.
-    if not model_map_json.is_file():
+    # Otherwise raise an error.
+    if not model_map_json_matches:
         consolidated_pth = model_dir / "original" / "consolidated.00.pth"
         tokenizer_pth = model_dir / "original" / "tokenizer.model"
         if consolidated_pth.is_file() and tokenizer_pth.is_file():
@@ -68,11 +63,30 @@ def convert_hf_checkpoint(
             return
         else:
             raise RuntimeError(
-                f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}"
+                f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}"
             )
 
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
+    # Load the json file(s) containing weight mapping
+    #
+    # NOTE: If there are multiple index files, there are two possibilities:
+    # 1. The files could be mapped to different weight format files (e.g. .bin
+    #    vs .safetensors)
+    # 2. The files could be split subsets of the mappings that need to be
+    #    merged
+    #
+    # In either case, we can simply keep the mappings where the target file is
+    # valid in the model dir.
+    bin_index = {}
+    for weight_map_file in model_map_json_matches:
+        with open(weight_map_file, "r") as handle:
+            weight_map = json.load(handle)
+            valid_mappings = {
+                k: model_dir / v
+                for (k, v) in weight_map.get("weight_map", {}).items()
+                if (model_dir / v).is_file()
+            }
+        bin_index.update(valid_mappings)
+    bin_files = set(bin_index.values())
 
     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
@@ -96,7 +110,6 @@ def convert_hf_checkpoint(
         "model.norm.weight": "norm.weight",
         "lm_head.weight": "output.weight",
     }
-    bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
         return (
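
For context, here is a minimal, self-contained sketch of the merge behavior the new loop implements. The shard and index file names are hypothetical, and merge_weight_indices is an illustrative helper rather than torchchat API; it shows how mappings whose target shard is missing from the model dir are dropped while the rest are merged:

import json
from pathlib import Path
from tempfile import TemporaryDirectory

def merge_weight_indices(model_dir: Path) -> dict:
    """Union the weight_map of every *.index.json, keeping only entries
    whose target shard actually exists in model_dir (as the loop above does)."""
    bin_index = {}
    for index_file in sorted(model_dir.glob("*.index.json")):
        weight_map = json.loads(index_file.read_text()).get("weight_map", {})
        bin_index.update(
            {k: model_dir / v for k, v in weight_map.items() if (model_dir / v).is_file()}
        )
    return bin_index

with TemporaryDirectory() as tmp:
    model_dir = Path(tmp)
    # Case 1 from the NOTE: .bin and .safetensors index files side by side,
    # but only the safetensors shard was actually downloaded.
    (model_dir / "model-00001-of-00001.safetensors").touch()
    (model_dir / "pytorch_model.bin.index.json").write_text(
        json.dumps({"weight_map": {"lm_head.weight": "pytorch_model-00001-of-00001.bin"}})
    )
    (model_dir / "model.safetensors.index.json").write_text(
        json.dumps({"weight_map": {"lm_head.weight": "model-00001-of-00001.safetensors"}})
    )
    merged = merge_weight_indices(model_dir)
    # Only the mapping backed by a real file survives.
    assert merged == {"lm_head.weight": model_dir / "model-00001-of-00001.safetensors"}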
9 changes: 5 additions & 4 deletions torchchat/cli/download.py
@@ -35,22 +35,23 @@ def _download_hf_snapshot(
     model_info = model_info(model_config.distribution_path, token=hf_token)
     model_fnames = [f.rfilename for f in model_info.siblings]
 
-    # Check the model config for preference between safetensors and pth
+    # Check the model config for preference between safetensors and pth/bin
     has_pth = any(f.endswith(".pth") for f in model_fnames)
+    has_bin = any(f.endswith(".bin") for f in model_fnames)
     has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
 
-    # If told to prefer safetensors, ignore pth files
+    # If told to prefer safetensors, ignore pth/bin files
     if model_config.prefer_safetensors:
         if not has_safetensors:
             print(
                 f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Using pth files instead.",
                 file=sys.stderr,
             )
             exit(1)
-        ignore_patterns = "*.pth"
+        ignore_patterns = ["*.pth", "*.bin"]
 
     # If the model has both, prefer pth files over safetensors
-    elif has_pth and has_safetensors:
+    elif (has_pth or has_bin) and has_safetensors:
         ignore_patterns = "*safetensors*"
 
     # Otherwise, download everything
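
Assuming the resulting ignore_patterns value is ultimately passed to huggingface_hub.snapshot_download (which accepts either a single glob string or a list of globs, so both branches above are valid), a hedged sketch of that call, with a placeholder repo id and target directory:

from huggingface_hub import snapshot_download

# Mirrors the prefer_safetensors branch: skip both .pth and .bin weights.
# Repo id and local_dir are hypothetical placeholders.
snapshot_download(
    "org/model-repo",
    local_dir="checkpoints/model",
    ignore_patterns=["*.pth", "*.bin"],
)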