
QLoRA Worse Memory When linear_nf4 Used on Output #1433

Open

Description

@pbontrager

When I replace the output layer for Llama 3.1 70B,

nn.Linear(8192, 128_256, bias=False) with FrozenNF4Linear(8192, 128_256, bias=False)

in torchtune, I surprisingly end up using a lot more memory. Leaving the output layer in bf16 results in the training run peaking at ~43 GB of active memory, while quantizing the output raises that to ~52 GB active. I wonder if this is due to the large size of the output layer.
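For scale, a back-of-the-envelope calculation of the weight storage involved (plain arithmetic, not a measurement; the idea that linear_nf4 materializes a full bf16 copy of the weight for the matmul is my assumption about torchao internals, not something I've verified):

```python
# Rough sizes for the 8192 x 128_256 output projection.
bf16_weight_gb = 8192 * 128_256 * 2 / 1e9    # ~2.10 GB stored in bf16
nf4_weight_gb = 8192 * 128_256 * 0.5 / 1e9   # ~0.53 GB stored as 4-bit NF4 (ignoring scale metadata)

# If linear_nf4 dequantizes the whole weight back to bf16 for the matmul
# (assumption), each step briefly holds both copies, i.e. ~2.6 GB vs ~2.1 GB
# for the plain bf16 layer -- which still wouldn't explain the full ~9 GB gap.
print(f"bf16 weight: {bf16_weight_gb:.2f} GB, NF4 weight: {nf4_weight_gb:.2f} GB")
```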

Steps to reproduce:

  • Replace nn.Linear with FrozenNF4Linear in the model here (FrozenNF4Linear is just a linear_nf4 wrapper; a standalone sketch that isolates the output layer follows these steps)
  • tune config here
  • command: tune run lora_finetune_single_device --config ./70B_qlora_long_context.yaml tokenizer.max_seq_len=8192
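For completeness, a minimal standalone check that isolates the output projection (a sketch, not part of the recipe: the torchtune.modules.low_precision import path is assumed, FrozenNF4Linear is assumed to take the same constructor arguments as nn.Linear as above, and peak_gb is a hypothetical helper; shapes mirror tokenizer.max_seq_len=8192):

```python
# Hypothetical check: peak CUDA memory for one forward/backward through
# only the output projection, bf16 vs NF4.
import torch
import torch.nn as nn
from torchtune.modules.low_precision import FrozenNF4Linear  # assumed import path

torch.set_default_dtype(torch.bfloat16)

def peak_gb(make_layer):
    """Peak allocated CUDA memory (GB) for one forward/backward through the layer."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.device("cuda"):
        layer = make_layer()
        for p in layer.parameters():
            p.requires_grad_(False)  # base weights are frozen during (Q)LoRA fine-tuning
        # batch 1, seq len 8192 to mirror tokenizer.max_seq_len=8192
        x = torch.randn(1, 8192, 8192, requires_grad=True)
    layer(x).sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1e9

print(f"bf16 output layer peak: {peak_gb(lambda: nn.Linear(8192, 128_256, bias=False)):.1f} GB")
print(f"nf4  output layer peak: {peak_gb(lambda: FrozenNF4Linear(8192, 128_256, bias=False)):.1f} GB")
```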
