[Distributed] Add support for torchchat checkpoint format #1268

Merged: 6 commits, Oct 7, 2024

58 changes: 36 additions & 22 deletions dist_run.py
@@ -23,10 +23,10 @@
 from torchchat.distributed.logging_utils import SingletonLogger

 # TODO - these are not distributed specific, consider moving to new package
-from torchchat.distributed.safetensor_utils import (
+from torchchat.distributed.checkpoint_utils import (
     get_hf_config_file,
-    get_hf_weight_map_and_path,
-    load_safetensor_weights,
+    load_weights_from_hf_format,
+    load_weights_from_torchchat_format,
 )
 from torchchat.distributed.utils import (
     bytes_to_readable,
@@ -129,26 +129,33 @@ def _build_chat_tokenizer(
     return tokenizer


-def _load_model_weights(stage_module, distribution, device, model_config):
+def _load_model_weights(
+    stage_module: torch.nn.Module,
+    distribution: str,
+    device: torch.device,
+    model_config: ModelArgs,
+    chpt_from: str,

    [Review comment, Contributor, on the `chpt_from` parameter]: for better clarity, 'chkpt' is a much better abbreviation for checkpoint; 'chpt' reads like an abbreviation for 'chapter', which is confusing.

+):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
-    """
-
-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
-
-    num_loaded_weights, num_missing_weights = load_safetensor_weights(
-        stage_module,
-        weight_map,
-        weight_path,
-        key_map,
-        device,
-        model_config=model_config,
-    )
-    logger.info(
-        f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
-    )
-    if num_missing_weights > 0:
-        raise ValueError(f"Missing {num_missing_weights} weights")
+
+    Args:
+        stage_module (torch.nn.Module): The model stage to load the weights into.
+        distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
+        device (torch.device): The device to load the weights onto.
+        model_config (ModelArgs): The model config.
+        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
+    """
+    if chpt_from == "hf":
+        # This format stands for: index file + multiple binary files
+        load_weights_from_hf_format(stage_module, distribution, device, model_config)
+    elif chpt_from == "torchchat":
+        # This format stands for:
+        # single binary file, OR
+        # multiple binary files without index files.
+        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+    else:
+        raise ValueError(f"Unknown checkpoint format: {chpt_from}")


 def _encode_strings(
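
Note (editorial, not part of the diff): the inline comments above distinguish the two on-disk layouts this dispatcher handles. The HF layout ships an index file plus multiple weight shards, while the torchchat layout is a single binary file, or several binary files with no index. As a rough, hypothetical sketch only (not the implementation added in torchchat/distributed/checkpoint_utils.py), loading a single-file torchchat-style checkpoint into one pipeline stage could look like the following; the helper name and the strict=False handling are assumptions for illustration.

import torch

def load_single_file_checkpoint_sketch(
    stage_module: torch.nn.Module,
    ckpt_path: str,
    device: torch.device,
) -> None:
    # Read the whole checkpoint; mmap avoids materializing it twice in host memory.
    state_dict = torch.load(ckpt_path, map_location=device, mmap=True, weights_only=True)
    # A full-model checkpoint contains keys owned by other pipeline stages; those are
    # expected to go unused here, but every parameter this stage owns must be found.
    missing, _unexpected = stage_module.load_state_dict(state_dict, strict=False)
    if missing:
        raise ValueError(f"Stage parameters missing from checkpoint: {missing}")

The real loaders in this PR do more than this sketch: per the docstring above, they also permute the wq and wk weights based on the attention head layout and remap key names between formats.
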
@@ -306,7 +313,7 @@ def main(args):
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
+    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")

     # Model-level config
     model_config = ModelArgs.from_name(distribution)
@@ -368,7 +375,7 @@ def main(args):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
+        _load_model_weights(model, distribution, device, config, args.chpt_from)

     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -602,6 +609,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         default=False,
         help="Whether to decode token into string in flight",
     )
+    parser.add_argument(
+        "--chpt-from",
+        type=str,
+        default="hf",  # TODO: change to torchchat once we support it well
+        help="Checkpoint format to load from",
+        choices=["hf", "torchchat"],
+    )
     args = parser.parse_args()

     main(args)
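
Note (editorial, not part of the diff): argparse converts the dashes in --chpt-from to an underscore on the parsed namespace, which is why the call site above reads args.chpt_from. A minimal, self-contained sketch with illustrative values only:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--chpt-from",
    type=str,
    default="hf",
    choices=["hf", "torchchat"],
)
args = parser.parse_args(["--chpt-from", "torchchat"])
print(args.chpt_from)  # -> "torchchat"; invalid values fail fast thanks to `choices`
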
16 changes: 11 additions & 5 deletions torchchat/cli/builder.py
@@ -335,11 +335,7 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:
     return model


-def _load_model_default(builder_args: BuilderArgs) -> Model:
-    assert not builder_args.gguf_path
-
-    model: Model = _init_model_on_meta_device(builder_args)
-
+def _load_checkpoint(builder_args: BuilderArgs):
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
         print("Loading Tune checkpoint")
         meta_checkpoint = torch.load(
@@ -377,6 +373,16 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
             mmap=True,
             weights_only=True,
         )
+    return checkpoint
+
+
+def _load_model_default(builder_args: BuilderArgs) -> Model:
+    assert not builder_args.gguf_path
+
+    model: Model = _init_model_on_meta_device(builder_args)
+
+    # Load checkpoint from filesystem
+    checkpoint = _load_checkpoint(builder_args)

     if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
         checkpoint = checkpoint["model"]
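
Note (editorial, not part of the diff): the net effect of this refactor is that the filesystem-reading logic (Tune vs. regular checkpoints, mmap, weights_only) now lives in a standalone _load_checkpoint helper rather than being inlined in _load_model_default, presumably so the distributed torchchat-format loader can reuse it. A hypothetical reuse, not code from this PR (load_stage_state_dict and stage_param_names are made-up names; the import assumes the post-merge module layout):

from torchchat.cli.builder import BuilderArgs, _load_checkpoint

def load_stage_state_dict(builder_args: BuilderArgs, stage_param_names: set) -> dict:
    # Read the full torchchat checkpoint once, then keep only the entries
    # that belong to this pipeline stage.
    checkpoint = _load_checkpoint(builder_args)
    return {name: tensor for name, tensor in checkpoint.items() if name in stage_param_names}
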