Description
When I replace the output layer for llama3.1 70B
`nn.Linear(8192, 128_256, bias=False)`
with `FrozenNF4Linear(8192, 128_256, bias=False)`
in torchtune, I surprisingly end up using a lot more memory. Leaving the output layer in bf16 results in the training run using ~43 GB of peak active memory, while quantizing the output layer results in ~52 GB of peak active memory. I wonder if this is due to the large size of the output layer.
Steps to reproduce:
- Replace `nn.Linear` with `FrozenNF4Linear` in the model here (`FrozenNF4Linear` is just a `linear_nf4` wrapper); a sketch of the swap follows the command below
- tune config here
- command:
  `tune run lora_finetune_single_device --config ./70B_qlora_long_context.yaml tokenizer.max_seq_len=8192`
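For the first step, a minimal sketch of the swap, assuming `FrozenNF4Linear` is imported from `torchtune.modules.low_precision` (the helper name is illustrative; the actual change goes in the linked model builder):

```python
import torch
import torch.nn as nn
from torchtune.modules.low_precision import FrozenNF4Linear  # assumed import path

def build_output_proj(quantize_output: bool) -> nn.Module:
    """Illustrative helper: the llama3.1 70B output projection, optionally NF4-quantized."""
    if quantize_output:
        # Frozen (non-trainable) linear whose weight is stored in NF4; the forward
        # pass goes through torchao's linear_nf4, per the description above.
        return FrozenNF4Linear(8192, 128_256, bias=False)
    # Baseline: keep the output projection in bf16.
    return nn.Linear(8192, 128_256, bias=False, dtype=torch.bfloat16)
```

Toggling between the two branches is what produces the ~43 GB vs ~52 GB peak active memory difference reported above.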