Description
🚀 The feature, motivation and pitch
There are many models on Hugging Face that are published as `safetensors` checkpoints rather than `model.pth` checkpoints. The request here is to support converting and loading those checkpoints into a format that is usable with torchchat.
There are several places where this limitation is currently enforced:

- `_download_hf_snapshot` explicitly ignores `safetensors` files.
- `convert_hf_checkpoint` explicitly looks for `pytorch_model.bin.index.json`, which would be named differently for models that use `safetensors` (e.g. `model.safetensors.index.json`).
- `convert_hf_checkpoint` only supports `torch.load` to load the `state_dict`, rather than `safetensors.torch.load`.
Alternatives
Currently, this `safetensors` -> `model.pth` conversion can be accomplished manually after downloading a model locally, so this could be solved with documentation instead of code.
Additional context
This issue is a piece of the puzzle for adding support for Granite Code 3b/8b, which use the `llama` architecture in `transformers` but take advantage of several pieces of the architecture that are not currently supported by torchchat. The work-in-progress for Granite Code can be found on my fork: https://github.com/gabe-l-hart/torchchat/tree/GraniteCodeSupport
RFC (Optional)
I have a working implementation to support `safetensors` during download and conversion that I plan to submit as a PR. The changes address the three points in code referenced above:

- Allow the download of `safetensors` files in `_download_hf_snapshot`
  - I'm not yet sure how to avoid double-downloading weights for models that have both `safetensors` and `model.pth`, so I will look to solve this before concluding the work
- When looking for the tensor index file, search for all files ending in `.index.json`, and if a single file is found, use that one
- When loading the `state_dict`, use the correct method based on the type of file (`torch.load` or `safetensors.torch.load`)