[Llama] Make torchao's Llama trainable (#728)
* initial change

* skip safetensors weights

* update quantized training script

* add activation checkpointing
gau-nernst authored Aug 22, 2024
1 parent 99644e9 commit 8002099
Showing 3 changed files with 43 additions and 21 deletions.
38 changes: 26 additions & 12 deletions benchmarks/quantized_training/pretrain_llama2.py
@@ -1,5 +1,5 @@
 # pre-train a mini Llama2 on TinyStories with INT8 quantized training
-# pip install transformers sentencepiece wandb
+# pip install huggingface_hub sentencepiece wandb
 #
 # BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
 # INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only
@@ -9,21 +9,33 @@
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

 import argparse
+from functools import partial
 from pathlib import Path

 import numpy as np
 import torch
 import wandb
+from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
-from transformers import LlamaConfig, LlamaForCausalLM

+from torchao._models.llama.model import ModelArgs, Transformer
 from torchao.prototype import low_bit_optim
 from torchao.prototype.quantized_training import int8_weight_only_quantized_training
 from torchao.quantization.quant_api import quantize_


-def get_loss(model: LlamaForCausalLM, batch: torch.Tensor):
-    return model(batch, labels=batch).loss
+# hack from fairseq
+# https://github.com/facebookresearch/fairseq/blob/920a548ca770fb1a951f7f4289b4d3a0c1bc226f/fairseq/modules/checkpoint_activations.py
+def enable_activation_checkpointing(m: torch.nn.Module):
+    assert not hasattr(m, "_forward")
+    m._forward = m.forward
+    m.forward = partial(checkpoint, m.forward)
+
+
+def get_loss(model: Transformer, batch: torch.Tensor):
+    logits = model(batch)[:, :-1].flatten(0, 1)
+    labels = batch[:, 1:].flatten()
+    return torch.nn.functional.cross_entropy(logits, labels)


 def get_tinystories():
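For context on the fairseq-style hack added above, here is a minimal, self-contained sketch (the toy Linear layer, shapes, and print are illustrative additions, not part of the PR) showing that rebinding an instance's forward to partial(checkpoint, forward) makes the module recompute its activations during backward:

from functools import partial

import torch
from torch.utils.checkpoint import checkpoint

def enable_activation_checkpointing(m: torch.nn.Module):
    # stash the original bound forward, then route calls through torch.utils.checkpoint
    # (recent PyTorch may warn that use_reentrant should be passed explicitly)
    assert not hasattr(m, "_forward")
    m._forward = m.forward
    m.forward = partial(checkpoint, m.forward)

layer = torch.nn.Linear(8, 8)
enable_activation_checkpointing(layer)

x = torch.randn(2, 8, requires_grad=True)
y = layer(x).sum()  # nn.Module.__call__ picks up the instance-level forward, so this runs under checkpoint
y.backward()        # forward is re-run here to recompute activations instead of storing them
print(x.grad.shape)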
@@ -91,17 +103,19 @@ def get_tinystories():
     if args.seed is not None:
         torch.manual_seed(args.seed)

-    config = LlamaConfig(
-        hidden_size=args.d_model,
+    config = ModelArgs(
+        block_size=args.seq_len,
+        n_layer=args.depth,
+        n_head=args.d_model // args.head_dim,
+        dim=args.d_model,
         intermediate_size=args.ffn_size,
-        num_hidden_layers=args.depth,
-        num_attention_heads=args.d_model // args.head_dim,
-        max_position_embeddings=args.seq_len,
-        use_cache=False,
     )
-    model = LlamaForCausalLM(config).bfloat16().cuda()
+    model = Transformer(config).bfloat16().cuda()
+    with torch.device("cuda"):
+        model.setup_caches(args.batch_size, args.seq_len, training=True)
     if args.activation_checkpointing:
-        model.gradient_checkpointing_enable()
+        for layer in model.layers:
+            enable_activation_checkpointing(layer)
     if args.quantize == "int8_weight_only":
         quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
     elif args.quantize is not None:
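Taken together, the updated flow looks roughly like the sketch below. This is not the benchmark script itself: the tiny hyperparameters, default-dtype CPU setup, plain torch.optim.AdamW, and random token batch are illustrative stand-ins, while the real run uses .bfloat16().cuda(), the INT8 quantize_ recipe, the low_bit_optim optimizers, and the TinyStories data pipeline.

import torch
import torch.nn.functional as F

from torchao._models.llama.model import ModelArgs, Transformer

# tiny illustrative config; the benchmark builds this from CLI args instead
config = ModelArgs(block_size=256, n_layer=2, n_head=4, dim=256, intermediate_size=1024)
model = Transformer(config)
# training=True skips KV-cache allocation but still builds freqs_cis / causal_mask
model.setup_caches(max_batch_size=4, max_seq_length=256, training=True)

# the INT8 recipe from the diff would be applied here:
# quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)

optim = torch.optim.AdamW(model.parameters(), lr=3e-4)
batch = torch.randint(0, 1000, (4, 256))  # fake token ids, kept below the model's vocab size

logits = model(batch)[:, :-1].flatten(0, 1)  # drop the prediction for the last position
labels = batch[:, 1:].flatten()              # next-token targets, shifted left by one
loss = F.cross_entropy(logits, labels)
loss.backward()
optim.step()
optim.zero_grad()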
2 changes: 1 addition & 1 deletion scripts/download.py
@@ -15,7 +15,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
     from huggingface_hub import snapshot_download
     os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
     try:
-        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
+        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors")
     except HTTPError as e:
         if e.response.status_code == 401:
             print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
24 changes: 16 additions & 8 deletions torchao/_models/llama/model.py
@@ -150,7 +150,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.max_batch_size = -1
         self.max_seq_length = -1

-    def setup_caches(self, max_batch_size, max_seq_length):
+    def setup_caches(self, max_batch_size, max_seq_length, training: bool = False):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
             return
         head_dim = self.config.dim // self.config.n_head
@@ -163,16 +163,21 @@ def setup_caches(self, max_batch_size, max_seq_length):
             dtype = self.output.scales.dtype
         elif hasattr(self.output, "scales_and_zeros"):
             dtype = self.output.scales_and_zeros.dtype
-        for b in self.layers:
-            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
+        if not training:
+            for b in self.layers:
+                b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
-        mask = self.causal_mask[None, None, input_pos]
-        freqs_cis = self.freqs_cis[input_pos]
+        if input_pos is not None:
+            mask = self.causal_mask[None, None, input_pos]
+            freqs_cis = self.freqs_cis[input_pos]
+        else:
+            mask = None
+            freqs_cis = self.freqs_cis[:idx.shape[1]]
         x = self.tok_embeddings(idx)

         for i, layer in enumerate(self.layers):
@@ -194,7 +199,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)

-    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Optional[Tensor], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
@@ -224,7 +229,7 @@ def load_hook(self, state_dict, prefix, *args):
         wv = state_dict.pop(prefix + "wv.weight")
         state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

-    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Optional[Tensor], input_pos: Optional[Tensor] = None) -> Tensor:
         bsz, seqlen, _ = x.shape

         kv_size = self.n_local_heads * self.head_dim
@@ -244,7 +249,10 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona

         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        if mask is not None:
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        else:
+            y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

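As a quick sanity check on the new attention branch (not part of the PR), is_causal=True gives the same result as passing a full lower-triangular boolean mask, which is why the training path can drop the precomputed causal_mask; the shapes below are arbitrary:

import torch
import torch.nn.functional as F

bsz, n_head, seqlen, head_dim = 2, 4, 16, 32
q, k, v = (torch.randn(bsz, n_head, seqlen, head_dim) for _ in range(3))

# inference-style: explicit boolean mask (True = attend), like the one built in setup_caches
causal_mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))
y_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask[None, None], dropout_p=0.0)

# training-style: let the kernel apply causal masking itself
y_causal = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

# loose tolerance since the two calls may dispatch to different SDPA backends
torch.testing.assert_close(y_masked, y_causal, rtol=1e-4, atol=1e-4)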
