-
Notifications
You must be signed in to change notification settings - Fork 414
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PTQ for generate_v2
#1866
base: main
Are you sure you want to change the base?
PTQ for generate_v2
#1866
Changes from 3 commits
f89fdd4
86b7784
e006f78
eafd3b2
e09f1a1
0575b67
f318412
2faf50c
80bb4e3
ff2ffba
322c802
24508fb
67718b9
a864ef9
31f64b5
fc501ee
3594697
d03db9f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,10 @@ | |
# Model arguments | ||
model: | ||
_component_: torchtune.models.llama2.llama2_7b | ||
# You can turn uncomment the following lines to enable quantization for faster inference and potentially lower VRAM | ||
# quantization_method: | ||
# _component_: torchao.quantization.quant_api.int4_weight_only # int4_weight_only is a good balance of speed and memory | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dumb q: so the torchtune.training.quantization API is just for QAT.. or we're not using it anymore? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see you mentioned this in the PR description - if we're going to be using the torchao APIs instead it'd be good to follow up with an issue |
||
# use_hqq: False # Turn on for more accurate results | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @HDCharles is this true? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah sorry this was anecdotal. |
||
|
||
# Transform arguments | ||
tokenizer: | ||
|
@@ -27,16 +31,16 @@ checkpointer: | |
output_dir: ./ | ||
model_type: LLAMA2 | ||
|
||
# Device | ||
device: cuda | ||
dtype: bf16 | ||
seed: 1234 | ||
log_level: INFO | ||
|
||
# Generation arguments | ||
prompt: | ||
system: You are a helpful and creative AI assistant. | ||
user: What is the capital of France? | ||
max_new_tokens: 200 | ||
max_new_tokens: 500 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Allow longer generation to really see the benefit of quant + compile. |
||
temperature: 0.6 # 0.8 and 0.6 are popular values to try | ||
top_k: 300 | ||
|
||
# Device | ||
device: cuda | ||
dtype: bf16 | ||
seed: 1234 | ||
log_level: INFO |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
from torchtune.generation import sample | ||
|
||
from torchtune.modules.transforms import Transform | ||
from torchtune.training import compile_model | ||
|
||
|
||
class SingleTurnYAMLToMessages(Transform): | ||
|
@@ -65,29 +66,37 @@ class InferenceRecipe: | |
|
||
This *does not* currently support the following features: | ||
- torch.compile | ||
- quantization through torchao | ||
- multi-GPU generation | ||
- batch generation | ||
""" | ||
|
||
def __init__(self, cfg: DictConfig) -> None: | ||
self._device = utils.get_device(device=cfg.device) | ||
self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device) | ||
self._logger = utils.get_logger(cfg.log_level) | ||
self.device = utils.get_device(device=cfg.device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a public recipe, no need to be a "private" variable. cc @pbontrager |
||
self.dtype = training.get_dtype(dtype=cfg.dtype, device=self.device) | ||
self.logger = utils.get_logger(cfg.log_level) | ||
training.set_seed(seed=cfg.seed) | ||
|
||
def setup(self, cfg: DictConfig) -> None: | ||
"""Setup the model and transforms.""" | ||
# Load checkpointer and state_dict | ||
# Load checkpointer | ||
_checkpointer = config.instantiate(cfg.checkpointer) | ||
_ckpt_dict = _checkpointer.load_checkpoint() | ||
|
||
# Instantiate model | ||
with training.set_default_dtype(self._dtype), self._device: | ||
with training.set_default_dtype(self.dtype), self.device: | ||
model = config.instantiate(cfg.model) | ||
model.load_state_dict(_ckpt_dict[training.MODEL_KEY]) | ||
self.logger.info(f"Model was initialized with precision {self.dtype}.") | ||
|
||
# Quantize the model if specified | ||
if cfg.get("quantization_method") is not None: | ||
from torchao.quantization.quant_api import quantize_ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lazily import torchao API |
||
|
||
quantization_method = config.instantiate(cfg.quantization_method) | ||
compile_model(model) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Compiling the model is necessary for quantization to be really worth it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm curious whether compiling the model results in greater speedups than compiling the next-token-prediction fn like gptfast do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should compile after There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting! I was following the pattern from AO's README where the model is compiled first: model = torchao.autoquant(torch.compile(model, mode='max-autotune')) Why should the model be compiled after quantization? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jerryzh168 Anecdotally, I don't see much difference in tok/sec (after first token) between putting compile first or second. Can you share some more details about which one is correct? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh right now but There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mode='max-autotune' will take a long time to compile. Is it worth it? We dont do it for training. Its interesting that in AO's read it says to put compile first. Do we also do it for QLoRA? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I haven't tried calling |
||
quantize_(model, quantization_method, device=self.device) | ||
|
||
self.model = model | ||
self._logger.info(f"Model was initialized with precision {self._dtype}.") | ||
|
||
# Instantiate transforms | ||
self.model_transform = config.instantiate(cfg.tokenizer) | ||
|
@@ -105,13 +114,13 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None: | |
for p in itertools.chain(self.model.parameters(), self.model.buffers()) | ||
] | ||
) | ||
self._logger.info( | ||
self.logger.info( | ||
f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec" | ||
) | ||
self._logger.info( | ||
self.logger.info( | ||
f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s" | ||
) | ||
self._logger.info( | ||
self.logger.info( | ||
f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB" | ||
) | ||
|
||
|
@@ -128,10 +137,10 @@ def generate(self, cfg: DictConfig): | |
total_response_length = seq_len + cfg.max_new_tokens | ||
|
||
# 3. Setup KV cache | ||
with self._device: | ||
with self.device: | ||
self.model.setup_caches( | ||
batch_size=1, | ||
dtype=self._dtype, | ||
dtype=self.dtype, | ||
encoder_max_seq_len=( | ||
self.model_transform.image_seq_len if is_multimodal_input else None | ||
), | ||
|
@@ -143,7 +152,7 @@ def generate(self, cfg: DictConfig): | |
torch.ones( | ||
size=(total_response_length, total_response_length), | ||
dtype=torch.bool, | ||
device=self._device, | ||
device=self.device, | ||
) | ||
) | ||
input_pos = torch.arange(total_response_length) | ||
|
@@ -155,20 +164,20 @@ def generate(self, cfg: DictConfig): | |
[model_inputs], pad_direction="left", pad_max_images=1 | ||
) | ||
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len] | ||
prompt = batch.pop("tokens").to(self._device) | ||
prompt = batch.pop("tokens").to(self.device) | ||
else: | ||
prompt = torch.tensor( | ||
model_inputs["tokens"], device=self._device | ||
).unsqueeze(0) | ||
prompt = torch.tensor(model_inputs["tokens"], device=self.device)[None, :] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wanted this to fit on one line lol |
||
batch["mask"] = causal_mask[None, :seq_len] | ||
batch["input_pos"] = input_pos[None, :seq_len] | ||
utils.batch_to_device(batch, self._device) | ||
utils.batch_to_device(batch, self.device) | ||
|
||
# 6. Prefill step | ||
generated_tokens = [] | ||
t0 = time.perf_counter() | ||
logits = self.model(prompt, **batch)[:, -1] | ||
token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k) | ||
t1 = time.perf_counter() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that we might have a warmup run, we log this differently so the user can see how good quantization / compilation is. |
||
self.logger.info(f"Time to generate first token: {t1 - t0:.02f} sec") | ||
generated_tokens.append(token.item()) | ||
|
||
if is_multimodal_input: | ||
|
@@ -192,15 +201,15 @@ def generate(self, cfg: DictConfig): | |
generated_tokens.append(token.item()) | ||
seq_len += 1 | ||
|
||
t = time.perf_counter() - t0 | ||
t2 = time.perf_counter() - t1 | ||
|
||
# 8. Translate tokens back to text | ||
decoded = self.model_transform.decode(generated_tokens) | ||
self._logger.info(f"\n\n{decoded}\n") | ||
self.logger.info(f"\n{decoded}\n") | ||
|
||
# 9. Log metrics | ||
tokens_per_second = len(generated_tokens) / t | ||
self.log_metrics(total_time=t, tokens_per_second=tokens_per_second) | ||
tokens_per_second = len(generated_tokens) / t2 | ||
self.log_metrics(total_time=t2, tokens_per_second=tokens_per_second) | ||
|
||
|
||
@config.parse | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Leave this commented out until the user wants to do something with it.