-
Notifications
You must be signed in to change notification settings - Fork 169
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add sensitivity analysis tool for layer-wise FIT and Hessian trace (#592
- Loading branch information
1 parent
4962bdd
commit 3de8748
Showing
3 changed files
with
420 additions
and
0 deletions.
There are no files selected for viewing
105 changes: 105 additions & 0 deletions
105
torchao/quantization/prototype/mixed_precision/scripts/fit.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import torch | ||
import numpy as np | ||
import os | ||
from tqdm import tqdm | ||
import transformers | ||
from datasets import load_dataset | ||
import random | ||
from torch.nn.attention import SDPBackend, sdpa_kernel | ||
|
||
def get_wikitext2(nsamples, seed, seqlen, tokenizer): | ||
traindata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train") | ||
testdata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") | ||
|
||
trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") | ||
testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") | ||
|
||
random.seed(seed) | ||
trainloader = [] | ||
for _ in range(nsamples): | ||
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) | ||
j = i + seqlen | ||
inp = trainenc.input_ids[:, i:j] | ||
tar = inp.clone() | ||
tar[:, :-1] = -100 | ||
trainloader.append((inp, tar)) | ||
return trainloader, testenc | ||
|
||
def cal_FIT(device, data, nsamples, model, maxIter, max_seqlen, criterion, num_layers): | ||
|
||
# store the history of trace for each layer | ||
estimated_history=[] | ||
|
||
# store the history of mean trace for each layer | ||
estimated_mean = [[] for _ in range(num_layers)] | ||
trace = [0.] * num_layers | ||
|
||
|
||
for iteration in range(maxIter): | ||
print("iteration: ",iteration) | ||
trace_tmp = [0.] * num_layers | ||
|
||
for i in tqdm(range(nsamples)): | ||
inputs, targets = data[i] | ||
inputs = inputs.to(device) | ||
targets = targets.to(device) | ||
model.zero_grad() | ||
outputs = model(inputs) | ||
logits = outputs.logits | ||
loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) | ||
|
||
grads = torch.autograd.grad(loss, model.parameters()) | ||
|
||
# Trace(Fisher Information Matrix) is calculated by the sum of the square of the gradient | ||
for layerid in range(num_layers): | ||
for (name, _), grad in zip(model.named_parameters(), grads): | ||
if "."+str(layerid)+"." in name and ("self_attn" in name or "mlp" in name): | ||
trace_tmp[layerid] += torch.sum(grad * grad).item() | ||
|
||
# clean cache | ||
model.zero_grad() | ||
del grads | ||
torch.cuda.empty_cache() | ||
|
||
# calculate the mean of the trace on the calibration dataset | ||
for t in range(num_layers): | ||
trace[t] = trace_tmp[t] / float(nsamples) | ||
estimated_mean[t].append(trace[t]) | ||
|
||
print("trace:",trace) | ||
estimated_history.append(trace) | ||
|
||
F_average = np.array([np.mean(i) for i in estimated_mean]) | ||
return F_average, estimated_mean, estimated_history | ||
|
||
def main(max_seqlen, checkpoint, nsamples, maxIter, num_layers): | ||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
|
||
# have been tested models Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M | ||
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16) | ||
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) | ||
model = model.to(device) | ||
model.eval() | ||
|
||
criterion = torch.nn.CrossEntropyLoss() | ||
|
||
# load calibration dataset | ||
seed = 0 | ||
trainloader, testloader = get_wikitext2(nsamples, seed, max_seqlen, tokenizer) | ||
|
||
F_average, estimated_mean, estimated_history = cal_FIT(device=device, data=trainloader, nsamples=nsamples, model=model, maxIter=maxIter, max_seqlen=max_seqlen, criterion=criterion, num_layers=num_layers) | ||
print("Iteration Done") | ||
print("avg_trace:", F_average) | ||
print("estimated_mean:", estimated_mean) | ||
print("estimated_history:", estimated_history) | ||
|
||
if __name__ == '__main__': | ||
import argparse | ||
parser = argparse.ArgumentParser(description='Your CLI description.') | ||
parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') | ||
parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') | ||
parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate FIT') | ||
parser.add_argument('--num_layers', type=int, default=32, help='The number of layers to calculate FIT.') | ||
parser.add_argument('--nsamples', type=int, default=128, help='The number of samples in calibration dataset') | ||
args = parser.parse_args() | ||
main(args.max_seqlen, args.checkpoint, args.nsamples, args.maxIter, args.num_layers) |
151 changes: 151 additions & 0 deletions
151
torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import torch | ||
import numpy as np | ||
import os | ||
from tqdm import tqdm | ||
import transformers | ||
from datasets import load_dataset | ||
import random | ||
from torch.nn.attention import SDPBackend, sdpa_kernel | ||
from torch.autograd.functional import hvp | ||
|
||
def group_product(xs, ys): | ||
return [torch.sum(x * y) for (x, y) in zip(xs, ys)] | ||
|
||
def get_wikitext2(nsamples, seed, seqlen, tokenizer): | ||
traindata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train") | ||
testdata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") | ||
|
||
trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") | ||
testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") | ||
|
||
random.seed(seed) | ||
trainloader = [] | ||
for _ in range(nsamples): | ||
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) | ||
j = i + seqlen | ||
inp = trainenc.input_ids[:, i:j] | ||
tar = inp.clone() | ||
tar[:, :-1] = -100 | ||
trainloader.append((inp, tar)) | ||
return trainloader, testenc.input_ids | ||
|
||
def dataloader_hv_product(layerid, params, device, v, data, nsamples, model, max_seqlen, criterion): | ||
model.zero_grad() | ||
THv = [torch.zeros(p.size()).to(device) for p in params] # accumulate result | ||
|
||
# Freeze all the parameters in the model | ||
for param in model.parameters(): | ||
param.requires_grad = False | ||
|
||
# Unfreeze the parameters of attention and MLP layers in layer 0 | ||
layer_ = model.model.layers[layerid] | ||
for param in layer_.self_attn.parameters(): | ||
param.requires_grad = True | ||
for param in layer_.mlp.parameters(): | ||
param.requires_grad = True | ||
|
||
for i in tqdm(range(nsamples)): | ||
torch.cuda.empty_cache() | ||
inputs, labels = data[i] | ||
inputs = inputs.to(device) | ||
labels = labels.to(device) | ||
# if use testloader: | ||
# inputs = data[:, (i * max_seqlen) : ((i + 1) * max_seqlen)].to(device) | ||
# labels = data[:, (i * max_seqlen) : ((i + 1) * max_seqlen)].to(device) | ||
model.zero_grad() | ||
outputs = model(inputs) | ||
logits = outputs.logits | ||
loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1)) | ||
|
||
# get the first order gradients | ||
grads = torch.autograd.grad(loss, params, create_graph=True, only_inputs=True) | ||
|
||
# calculate Hessian vector product via Jac-vector product | ||
Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=False) | ||
|
||
THv = [THv1 + Hv1 + 0.0 for THv1, Hv1 in zip(THv, Hv)] | ||
|
||
# clean cache | ||
model.zero_grad() | ||
del Hv | ||
del grads | ||
torch.cuda.empty_cache() | ||
|
||
THv = [THv1 / float(nsamples) for THv1 in THv] | ||
return THv | ||
|
||
def cal_trace(layerid, params, device, data, nsamples, model, maxIter, max_seqlen, criterion): | ||
vhv_c_history = [] | ||
trace_history = [] | ||
trace = 0. | ||
|
||
for i in range(maxIter): | ||
print("iteration: ",i) | ||
|
||
# generate Rademacher random variables | ||
v = [ | ||
torch.randint_like(p, high=2, device=device) | ||
for p in params | ||
] | ||
|
||
for v_i in v: | ||
v_i[v_i == 0] = -1 | ||
|
||
# calculate Hessian vector product | ||
Hv = dataloader_hv_product(layerid, params, device, v, data, nsamples, model, max_seqlen, criterion) | ||
|
||
vHv = group_product(Hv, v) | ||
|
||
vHv_c = np.array([i.cpu().numpy() for i in vHv]) | ||
|
||
vhv_c_history.append(vHv_c) | ||
|
||
trace = np.sum(vHv_c) | ||
|
||
trace_history.append(trace) | ||
print("trace,", trace) | ||
print("trace_history,", trace_history) | ||
print("vhv_c_history,", vhv_c_history) | ||
|
||
return np.mean(trace_history) | ||
|
||
|
||
def main(layer_id, checkpoint, max_seqlen, maxIter, nsamples): | ||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
|
||
# to avoid aten::_scaled_dot_product_flash_attention_backward not implemented error | ||
with sdpa_kernel(SDPBackend.MATH): | ||
|
||
# have been tested models Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M | ||
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16) | ||
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) | ||
model = model.cuda() | ||
model.eval() | ||
|
||
criterion = torch.nn.CrossEntropyLoss() | ||
|
||
# load calibration dataset | ||
seed = 0 | ||
trainloader, testloader = get_wikitext2(128, seed, 2048, tokenizer) | ||
|
||
# calculate Hessian for only one layer each time | ||
params=[] | ||
layer_ = model.model.layers[layer_id] | ||
for param in layer_.self_attn.parameters(): | ||
params.append(param) | ||
for param in layer_.mlp.parameters(): | ||
params.append(param) | ||
|
||
trace = cal_trace(layerid=layer_id, params=params, device=device, data=trainloader, nsamples=nsamples, model=model, maxIter=maxIter, max_seqlen=max_seqlen, criterion=criterion) | ||
print("The trace of layer " + str(layer_id) + " is", trace) | ||
|
||
if __name__ == '__main__': | ||
import argparse | ||
parser = argparse.ArgumentParser(description='Your CLI description.') | ||
parser.add_argument('--layer_id', type=int, default=0, help='Which layer to compute the trace and hessian') | ||
parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') | ||
parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') | ||
parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate Hessian trace') | ||
parser.add_argument('--nsamples', type=int, default=128, help='The number of samples in calibration dataset') | ||
args = parser.parse_args() | ||
main(args.layer_id, args.checkpoint, args.max_seqlen, args.maxIter, args.nsamples) |
Oops, something went wrong.