It would be helpful if we could pass the data type instead of defaulting to dtype=torch.bfloat16.
I managed to get this running on the 'cpu' device with OpenMP by hacking the setting to dtype=torch.float32. With the default dtype=torch.bfloat16 it just sat all night using one core and never progressed further (this is an older Xeon without native bf16 support, so it was likely trying to upcast to fp32 under the hood).
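For illustration, here is a minimal sketch of the kind of option I mean. The function name and signature are hypothetical placeholders, not the real interface; only the underlying transformers call is real:

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical sketch: let the caller choose the dtype instead of the loader
# hard-coding torch.bfloat16. Names here are placeholders, not the real API.
def load_hf_model(model_name: str, device: str = "cuda",
                  dtype: torch.dtype = torch.bfloat16):
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
    return model.to(device)

# On an older Xeon without native bf16, float32 is the dtype that actually runs:
# model = load_hf_model("meta-llama/Llama-2-7b-hf", device="cpu", dtype=torch.float32)
```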
I also tried preloading the model in 4-bit (with the 'cuda' device), which should work for Llama models:

```python
# Snippet from the loading path; assumes hf_model and model_name are in scope.
import torch
from packaging import version

if hf_model is not None:
    hf_cfg = hf_model.config.to_dict()
    qc = hf_cfg.get("quantization_config", {})
    load_in_4bit = qc.get("load_in_4bit", False)
    load_in_8bit = qc.get("load_in_8bit", False)
    quant_method = qc.get("quant_method", "")
    assert not load_in_8bit, "8-bit quantization is not supported"
    assert not (
        load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
    ), "Quantization is only supported for torch versions >= 2.1.1"
    assert not (
        load_in_4bit and ("llama" not in model_name.lower())
    ), "Quantization is only supported for Llama models"
    if load_in_4bit:
        assert (
            quant_method == "bitsandbytes"
        ), "Only bitsandbytes quantization is supported"
else:
    hf_cfg = {}
```

But I got the same shape mismatch exception as mentioned in this thread:
TransformerLensOrg/TransformerLens#569
Might be worth adding the ability to use 4-bit once that bug is fixed.
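For reference, this is roughly how I preloaded the model in 4-bit before passing it in (a sketch using the standard transformers + bitsandbytes path; the model name and quantization settings are just example values):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Example 4-bit preload; requires bitsandbytes and a CUDA device.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
)

hf_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # example model name
    quantization_config=bnb_config,
    device_map="cuda",
)
```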