
Allow us to set the dtype (for 'cpu' device) #23

@jukofyork


It would be helpful if we could pass the data type instead of defaulting to dtype=torch.bfloat16.

I managed to get this running on the 'cpu' device with OpenMP by hacking the setting to dtype=torch.float32. With the default dtype=torch.bfloat16 it just sat all night using one core and never progressed (this is an older Xeon without native bf16 support, so it was likely upcasting everything to fp32).
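To illustrate the request, here is a minimal sketch of the kind of dtype selection I mean. The function name and signature are hypothetical, not part of any existing API; the idea is simply to honour an explicit user choice and otherwise fall back to float32 on CPUs without native bf16 support:

```python
def pick_dtype(device, requested=None, cpu_has_bf16=False):
    """Hypothetical helper: choose a default dtype name per device
    instead of hard-coding bfloat16 everywhere."""
    if requested is not None:
        return requested  # honour an explicit user choice
    if device == "cpu" and not cpu_has_bf16:
        return "float32"  # emulated bf16 on older CPUs is extremely slow
    return "bfloat16"

print(pick_dtype("cpu"))               # float32 on an old Xeon
print(pick_dtype("cuda"))              # bfloat16 by default on GPU
print(pick_dtype("cpu", "float16"))    # explicit request wins
```

In practice the return value would be a torch.dtype (torch.float32 etc.) rather than a string, but the dispatch logic is the same.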


I also tried preloading the model in 4-bit (with the 'cuda' device), which should work for Llama models:

        if hf_model is not None:
            hf_cfg = hf_model.config.to_dict()
            qc = hf_cfg.get("quantization_config", {})
            load_in_4bit = qc.get("load_in_4bit", False)
            load_in_8bit = qc.get("load_in_8bit", False)
            quant_method = qc.get("quant_method", "")
            assert not load_in_8bit, "8-bit quantization is not supported"
            assert not (
                load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
            ), "Quantization is only supported for torch versions >= 2.1.1"
            assert not (
                load_in_4bit and ("llama" not in model_name.lower())
            ), "Quantization is only supported for Llama models"
            if load_in_4bit:
                assert (
                    qc.get("quant_method", "") == "bitsandbytes"
                ), "Only bitsandbytes quantization is supported"
        else:
            hf_cfg = {}

from: https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/HookedTransformer.py

But I got the same shape mismatch exception as mentioned in this thread:

TransformerLensOrg/TransformerLens#569

It might be worth adding the ability to use 4-bit loading here once they fix this bug.
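For reference, this is the shape of quantization_config that satisfies the checks in the quoted snippet. The dict below is illustrative only, mimicking what a 4-bit bitsandbytes model exposes via config.to_dict(); the keys mirror exactly what the snippet reads:

```python
# Illustrative config fragment, not taken from a real model checkpoint.
hf_cfg = {
    "quantization_config": {
        "load_in_4bit": True,
        "load_in_8bit": False,
        "quant_method": "bitsandbytes",
    }
}

# Same checks as the quoted TransformerLens snippet performs.
qc = hf_cfg.get("quantization_config", {})
assert not qc.get("load_in_8bit", False), "8-bit quantization is not supported"
if qc.get("load_in_4bit", False):
    assert (
        qc.get("quant_method", "") == "bitsandbytes"
    ), "Only bitsandbytes quantization is supported"
print("config passes the 4-bit checks")
```

So the config side is straightforward; the blocker is purely the shape mismatch from the linked issue.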
