
Commit 136426d

testing torchao config migration
Summary: Testing for pytorch/ao#1690. Convenient to have this here to test torchao main vs. a torchao experiment branch.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent f5f39b3 commit 136426d
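
For context, a minimal sketch of the torchao API change that the "config migration" in pytorch/ao#1690 refers to, under the assumption that it moves quantize_ from callable-style factories (e.g. int4_weight_only(group_size=128)) to config objects (e.g. Int4WeightOnlyConfig(group_size=128)); the toy model, CUDA device, and class name below are illustrative assumptions, not code from this commit:

import torch
import torch.nn as nn
from torchao.quantization import quantize_, Int4WeightOnlyConfig

# Toy module; int4 weight-only quantization targets nn.Linear weights and
# expects bfloat16 weights on a CUDA device.
model = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")

# Post-migration style: pass a config object to quantize_ instead of the
# older factory-style call int4_weight_only(group_size=128).
quantize_(model, Int4WeightOnlyConfig(group_size=128))
print(type(model[0].weight))  # expect a torchao tensor subclass, not a plain Parameter

The two test files below exercise the same configs through the Hugging Face transformers and diffusers integrations.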

2 files changed: +111 -0 lines changed

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
"""
Test that https://github.com/pytorch/ao/issues/1690 does not break HF
"""

import fire

import torch
import torchao
import transformers

def run():
    print(f"torch version: {torch.__version__}")
    print(f"torchao version: {torchao.__version__}")
    print(f"transformers version: {transformers.__version__}")

    # test code copy-pasted from
    # https://huggingface.co/docs/transformers/main/en/quantization/torchao

    from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

    model_name = "meta-llama/Meta-Llama-3-8B"
    # We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
    # More examples and documentation for arguments can be found in
    # https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
    quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
    quantized_model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config
    )
    # quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    input_text = "What are we having for dinner?"
    input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

    # auto-compile the quantized model with `cache_implementation="static"` to get speedup
    output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
    print(tokenizer.decode(output[0], skip_special_tokens=True))

    # benchmark the performance
    import torch.utils.benchmark as benchmark

    def benchmark_fn(f, *args, **kwargs):
        # Manual warmup
        for _ in range(5):
            f(*args, **kwargs)

        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)",
            globals={"args": args, "kwargs": kwargs, "f": f},
            num_threads=torch.get_num_threads(),
        )
        # mean wall-clock latency (seconds) of one call, via blocked_autorange
        return f"{(t0.blocked_autorange().mean):.3f}"

    MAX_NEW_TOKENS = 1000
    print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

    bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
    output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")  # auto-compile
    print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

if __name__ == '__main__':
    fire.Fire(run)
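
benchmark_fn above returns the mean wall-clock latency, in seconds, of one generate() call. A small hedged helper (illustrative only, not part of the commit) for converting that number into decode throughput, assuming each timed call emits MAX_NEW_TOKENS tokens:

def tokens_per_second(mean_latency_s: float, max_new_tokens: int = 1000) -> float:
    # benchmark_fn times one full generate() call that produces
    # max_new_tokens tokens, so throughput is tokens divided by seconds.
    return max_new_tokens / mean_latency_s

# Example: if the int4wo-128 run prints 20.000, that is 1000 / 20.0 = 50 tokens/s.
print(tokens_per_second(20.0, 1000))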
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
"""
Test that https://huggingface.co/docs/diffusers/en/quantization/torchao is not
broken by https://github.com/pytorch/ao/issues/1690
"""

import fire

def run():
    # copy-pasted from https://huggingface.co/docs/diffusers/en/quantization/torchao

    import torch
    import diffusers
    import torchao
    from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

    print(f"torch version: {torch.__version__}")
    print(f"torchao version: {torchao.__version__}")
    print(f"diffusers version: {diffusers.__version__}")

    model_id = "black-forest-labs/FLUX.1-dev"
    dtype = torch.bfloat16

    quantization_config = TorchAoConfig("int8wo")
    print(quantization_config)
    transformer = FluxTransformer2DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        quantization_config=quantization_config,
        torch_dtype=dtype,
    )
    print(transformer)
    pipe = FluxPipeline.from_pretrained(
        model_id,
        transformer=transformer,
        torch_dtype=dtype,
    )
    pipe.to("cuda")

    # Without quantization: ~31.447 GB
    # With quantization: ~20.40 GB
    print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")

    prompt = "A cat holding a sign that says hello world"
    image = pipe(
        prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
    ).images[0]
    image.save("output.png")

if __name__ == '__main__':
    fire.Fire(run)
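
A quick hedged spot-check (an assumption about how torchao's weight-only quantization represents weights, not code from this commit) that the quantized Flux transformer actually carries torchao tensor-subclass weights rather than plain parameters:

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8wo"),
    torch_dtype=torch.bfloat16,
)
# With int8 weight-only quantization, nn.Linear modules should keep their
# class but have .weight swapped for a quantized tensor subclass.
for name, module in transformer.named_modules():
    if isinstance(module, torch.nn.Linear):
        print(name, type(module.weight))
        break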
