
Commit

add sensitivity analysis tool for layer-wise FIT and Hessian trace (#592)

* add FIT and Hessian

* renamed files
Hanxian97 authored and jainapurva committed Aug 7, 2024
1 parent 4962bdd commit 3de8748
Showing 3 changed files with 420 additions and 0 deletions.
105 changes: 105 additions & 0 deletions torchao/quantization/prototype/mixed_precision/scripts/fit.py
@@ -0,0 +1,105 @@
import torch
import numpy as np
from tqdm import tqdm
import transformers
from datasets import load_dataset
import random

def get_wikitext2(nsamples, seed, seqlen, tokenizer):
traindata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train")
testdata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")

trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

random.seed(seed)
trainloader = []
for _ in range(nsamples):
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
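        # mask all positions except the last one: CrossEntropyLoss ignores targets
        # of -100, so the loss (and gradient) is taken on the final token of each sample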
tar[:, :-1] = -100
trainloader.append((inp, tar))
return trainloader, testenc

def cal_FIT(device, data, nsamples, model, maxIter, max_seqlen, criterion, num_layers):

# store the history of trace for each layer
    estimated_history = []

# store the history of mean trace for each layer
estimated_mean = [[] for _ in range(num_layers)]
trace = [0.] * num_layers


for iteration in range(maxIter):
print("iteration: ",iteration)
trace_tmp = [0.] * num_layers

for i in tqdm(range(nsamples)):
inputs, targets = data[i]
inputs = inputs.to(device)
targets = targets.to(device)
model.zero_grad()
outputs = model(inputs)
logits = outputs.logits
loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

grads = torch.autograd.grad(loss, model.parameters())

            # the trace of the empirical Fisher Information Matrix is estimated by the
            # sum of squared gradients of the loss w.r.t. each layer's attention and
            # MLP parameters
            for layerid in range(num_layers):
                for (name, _), grad in zip(model.named_parameters(), grads):
                    if "." + str(layerid) + "." in name and ("self_attn" in name or "mlp" in name):
                        trace_tmp[layerid] += torch.sum(grad * grad).item()

# clean cache
model.zero_grad()
del grads
torch.cuda.empty_cache()

        # average the trace over the calibration dataset
        for t in range(num_layers):
            trace[t] = trace_tmp[t] / float(nsamples)
            estimated_mean[t].append(trace[t])

        print("trace:", trace)
        # append a copy: `trace` is mutated in place each iteration, so appending
        # the list itself would make every history entry alias the same values
        estimated_history.append(list(trace))

F_average = np.array([np.mean(i) for i in estimated_mean])
return F_average, estimated_mean, estimated_history
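
# Note: cal_FIT above implements the empirical-Fisher trace estimate
#   Tr(F) ~= (1/nsamples) * sum_i ||grad_theta L(x_i)||^2,
# accumulated per layer and then averaged once more over the maxIter rounds.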

def main(max_seqlen, checkpoint, nsamples, maxIter, num_layers):
device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # tested with Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
model = model.to(device)
model.eval()

criterion = torch.nn.CrossEntropyLoss()

# load calibration dataset
seed = 0
trainloader, testloader = get_wikitext2(nsamples, seed, max_seqlen, tokenizer)

F_average, estimated_mean, estimated_history = cal_FIT(device=device, data=trainloader, nsamples=nsamples, model=model, maxIter=maxIter, max_seqlen=max_seqlen, criterion=criterion, num_layers=num_layers)
print("Iteration Done")
print("avg_trace:", F_average)
print("estimated_mean:", estimated_mean)
print("estimated_history:", estimated_history)

if __name__ == '__main__':
import argparse
    parser = argparse.ArgumentParser(description='Calculate layer-wise FIT (trace of the Fisher Information Matrix) for a causal LM.')
parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model')
parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length')
parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate FIT')
parser.add_argument('--num_layers', type=int, default=32, help='The number of layers to calculate FIT.')
parser.add_argument('--nsamples', type=int, default=128, help='The number of samples in calibration dataset')
args = parser.parse_args()
main(args.max_seqlen, args.checkpoint, args.nsamples, args.maxIter, args.num_layers)
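
For a quick smoke test, main can also be called directly. A minimal sketch, assuming the scripts are importable as a package from the repository root and substituting a Hugging Face model id for the hard-coded local checkpoint path (both are assumptions, not part of this commit):

from torchao.quantization.prototype.mixed_precision.scripts.fit import main

# hypothetical checkpoint id; num_layers must match the model (32 for Llama-3-8B),
# and small nsamples/maxIter keep the run short at the cost of a noisier estimate
main(max_seqlen=2048,
     checkpoint="meta-llama/Meta-Llama-3-8B",
     nsamples=8,
     maxIter=2,
     num_layers=32)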
151 changes: 151 additions & 0 deletions torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py
@@ -0,0 +1,151 @@
import torch
import numpy as np
from tqdm import tqdm
import transformers
from datasets import load_dataset
import random
from torch.nn.attention import SDPBackend, sdpa_kernel

def group_product(xs, ys):
    # per-tensor inner products; summing the returned list gives <xs, ys>
    return [torch.sum(x * y) for (x, y) in zip(xs, ys)]

def get_wikitext2(nsamples, seed, seqlen, tokenizer):
traindata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train")
testdata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")

trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

random.seed(seed)
trainloader = []
for _ in range(nsamples):
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
return trainloader, testenc.input_ids

def dataloader_hv_product(layerid, params, device, v, data, nsamples, model, max_seqlen, criterion):
model.zero_grad()
THv = [torch.zeros(p.size()).to(device) for p in params] # accumulate result

# Freeze all the parameters in the model
for param in model.parameters():
param.requires_grad = False

    # unfreeze only the attention and MLP parameters of the target layer
layer_ = model.model.layers[layerid]
for param in layer_.self_attn.parameters():
param.requires_grad = True
for param in layer_.mlp.parameters():
param.requires_grad = True

for i in tqdm(range(nsamples)):
torch.cuda.empty_cache()
inputs, labels = data[i]
inputs = inputs.to(device)
labels = labels.to(device)
        # alternative, when using the test split (testenc.input_ids) instead:
        # inputs = data[:, (i * max_seqlen) : ((i + 1) * max_seqlen)].to(device)
        # labels = data[:, (i * max_seqlen) : ((i + 1) * max_seqlen)].to(device)
model.zero_grad()
outputs = model(inputs)
logits = outputs.logits
loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))

# get the first order gradients
grads = torch.autograd.grad(loss, params, create_graph=True, only_inputs=True)

        # Hessian-vector product via double backprop: since grads was built with
        # create_graph=True, differentiating (grads . v) w.r.t. params yields H v
        Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=False)

        # accumulate H v over the calibration samples
        THv = [THv1 + Hv1 + 0.0 for THv1, Hv1 in zip(THv, Hv)]

# clean cache
model.zero_grad()
del Hv
del grads
torch.cuda.empty_cache()

THv = [THv1 / float(nsamples) for THv1 in THv]
return THv

def cal_trace(layerid, params, device, data, nsamples, model, maxIter, max_seqlen, criterion):
vhv_c_history = []
trace_history = []
trace = 0.

    for i in range(maxIter):
        print("iteration:", i)

        # draw a Rademacher random vector per parameter tensor (entries +1/-1
        # with equal probability): sample from {0, 1}, then map 0 -> -1
        v = [
            torch.randint_like(p, high=2, device=device)
            for p in params
        ]

        for v_i in v:
            v_i[v_i == 0] = -1

# calculate Hessian vector product
Hv = dataloader_hv_product(layerid, params, device, v, data, nsamples, model, max_seqlen, criterion)

        vHv = group_product(Hv, v)

        vHv_c = np.array([t.cpu().numpy() for t in vHv])

        vhv_c_history.append(vHv_c)

        trace = np.sum(vHv_c)

        trace_history.append(trace)
        print("trace:", trace)
        print("trace_history:", trace_history)
        print("vhv_c_history:", vhv_c_history)

return np.mean(trace_history)
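
# Note: cal_trace above is Hutchinson's stochastic trace estimator,
#   Tr(H) = E_v[v^T H v] with Rademacher v (entries +/-1),
# approximated by the mean of v^T H v over maxIter independent draws.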


def main(layer_id, checkpoint, max_seqlen, maxIter, nsamples):
device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # use the math SDP backend to avoid an aten::_scaled_dot_product_flash_attention_backward
    # not-implemented error during the double backward
    with sdpa_kernel(SDPBackend.MATH):

        # tested with Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M
        model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
        tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
        model = model.to(device)
        model.eval()

criterion = torch.nn.CrossEntropyLoss()

        # load calibration dataset using the CLI arguments (previously hard-coded to 128/2048)
        seed = 0
        trainloader, testloader = get_wikitext2(nsamples, seed, max_seqlen, tokenizer)

        # collect the parameters of a single layer; the Hessian trace is
        # estimated for one layer at a time
        params = []
        layer_ = model.model.layers[layer_id]
        for param in layer_.self_attn.parameters():
            params.append(param)
        for param in layer_.mlp.parameters():
            params.append(param)

trace = cal_trace(layerid=layer_id, params=params, device=device, data=trainloader, nsamples=nsamples, model=model, maxIter=maxIter, max_seqlen=max_seqlen, criterion=criterion)
print("The trace of layer " + str(layer_id) + " is", trace)

if __name__ == '__main__':
import argparse
    parser = argparse.ArgumentParser(description="Estimate the Hessian trace of one transformer layer via Hutchinson's method.")
    parser.add_argument('--layer_id', type=int, default=0, help='Which layer to compute the Hessian trace for')
parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model')
parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length')
parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate Hessian trace')
parser.add_argument('--nsamples', type=int, default=128, help='The number of samples in calibration dataset')
args = parser.parse_args()
main(args.layer_id, args.checkpoint, args.max_seqlen, args.maxIter, args.nsamples)
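
Likewise for the Hessian script, a minimal sketch under the same assumptions (importable package layout, placeholder checkpoint id). Since the trace is estimated one layer at a time, covering a full model takes one run per layer_id:

from torchao.quantization.prototype.mixed_precision.scripts.hessian_grad import main

# hypothetical checkpoint id; small maxIter/nsamples trade accuracy for speed
main(layer_id=0,
     checkpoint="meta-llama/Meta-Llama-3-8B",
     max_seqlen=2048,
     maxIter=2,
     nsamples=8)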
