PTQ for generate_v2 #1866

Open · wants to merge 18 commits into base: main · showing changes from 3 commits
18 changes: 11 additions & 7 deletions recipes/configs/llama2/generation_v2.yaml
@@ -9,6 +9,10 @@
# Model arguments
model:
_component_: torchtune.models.llama2.llama2_7b
# You can uncomment the following lines to enable quantization for faster inference and potentially lower VRAM usage
Comment (Contributor Author): Leave this commented out until the user wants to do something with it.

# quantization_method:
# _component_: torchao.quantization.quant_api.int4_weight_only # int4_weight_only is a good balance of speed and memory
Comment (Collaborator): Dumb question: is the torchtune.training.quantization API just for QAT now, or are we not using it anymore?

Comment (Collaborator): I see you mentioned this in the PR description; if we're going to use the torchao APIs instead, it'd be good to follow up with an issue.

# use_hqq: False # Turn on for more accurate results
Comment (Contributor): @HDCharles is this true?

Comment (Contributor Author): Yeah, sorry, this was anecdotal.


# Transform arguments
tokenizer:
@@ -27,16 +31,16 @@ checkpointer:
output_dir: ./
model_type: LLAMA2

# Device
device: cuda
dtype: bf16
seed: 1234
log_level: INFO

# Generation arguments
prompt:
system: You are a helpful and creative AI assistant.
user: What is the capital of France?
max_new_tokens: 200
max_new_tokens: 500
Comment (Contributor Author): Allow longer generation to really see the benefit of quant + compile.

temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300

# Device
device: cuda
dtype: bf16
seed: 1234
log_level: INFO
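
For context, a minimal sketch (not part of this diff) of how the commented-out quantization_method entry above is consumed once uncommented: config.instantiate resolves the _component_ path and calls it, and the result is handed to torchao's quantize_. The dummy model below is hypothetical, and int4_weight_only generally expects bf16 weights on a CUDA device.

```python
import torch
from omegaconf import OmegaConf
from torchao.quantization.quant_api import quantize_
from torchtune import config

# In-memory mirror of the commented-out YAML above, for illustration only.
cfg = OmegaConf.create(
    {"quantization_method": {"_component_": "torchao.quantization.quant_api.int4_weight_only"}}
)

# Hypothetical stand-in for the instantiated llama2_7b model.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).to(device="cuda", dtype=torch.bfloat16)

# Builds int4_weight_only() from the config entry.
quantization_method = config.instantiate(cfg.quantization_method)

# Replaces eligible Linear weights with int4 tensor subclasses in place.
quantize_(model, quantization_method)
```
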
53 changes: 31 additions & 22 deletions recipes/dev/generate_v2.py
@@ -17,6 +17,7 @@
from torchtune.generation import sample

from torchtune.modules.transforms import Transform
from torchtune.training import compile_model


class SingleTurnYAMLToMessages(Transform):
@@ -65,29 +66,37 @@ class InferenceRecipe:

This *does not* currently support the following features:
- torch.compile
- quantization through torchao
- multi-GPU generation
- batch generation
"""

def __init__(self, cfg: DictConfig) -> None:
self._device = utils.get_device(device=cfg.device)
self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
self.device = utils.get_device(device=cfg.device)
Comment (Contributor Author): It's a public recipe, so there's no need for these to be "private" variables. cc @pbontrager

self.dtype = training.get_dtype(dtype=cfg.dtype, device=self.device)
self.logger = utils.get_logger(cfg.log_level)
training.set_seed(seed=cfg.seed)

def setup(self, cfg: DictConfig) -> None:
"""Setup the model and transforms."""
# Load checkpointer and state_dict
# Load checkpointer
_checkpointer = config.instantiate(cfg.checkpointer)
_ckpt_dict = _checkpointer.load_checkpoint()

# Instantiate model
with training.set_default_dtype(self._dtype), self._device:
with training.set_default_dtype(self.dtype), self.device:
model = config.instantiate(cfg.model)
model.load_state_dict(_ckpt_dict[training.MODEL_KEY])
self.logger.info(f"Model was initialized with precision {self.dtype}.")

# Quantize the model if specified
if cfg.get("quantization_method") is not None:
from torchao.quantization.quant_api import quantize_
Comment (Contributor Author): Lazily import the torchao API.


quantization_method = config.instantiate(cfg.quantization_method)
compile_model(model)
Comment (Contributor Author): Compiling the model is necessary for quantization to be really worth it.

Comment (Collaborator): I'm curious whether compiling the model results in greater speedups than compiling just the next-token-prediction function, as gpt-fast does.

Comment (Contributor): Actually, we should compile after quantize_ to get the speedup.

Comment (Contributor Author): Interesting! I was following the pattern from AO's README, where the model is compiled first:

model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

Why should the model be compiled after quantization?

Comment (Contributor Author): @jerryzh168 Anecdotally, I don't see much difference in tok/sec (after the first token) between putting compile first or second. Can you share some more details about which one is correct?

Comment (Contributor): Oh, right now quantize_ needs compile to come after it: https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#full-affine-quantization-flow-example

But with autoquant, the model is compiled first, before autoquant is called.

Comment (Contributor): mode='max-autotune' will take a long time to compile. Is it worth it? We don't do it for training.

It's interesting that AO's README says to put compile first. Do we also do it for QLoRA?

Comment (Contributor):
> I don't see much difference in tok/sec (after first token) between putting compile first or second.

I haven't actually tried calling quantize_ after compile; maybe it would have the same effect as well. Need to confirm.
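
For reference, a standalone sketch (not part of this diff) contrasting the two orderings discussed here: the quantize_-then-compile flow from the torchao README linked above, and the autoquant pattern where compile comes first. The API names are torchao's; everything else is illustrative.

```python
import torch
import torchao
from torchao.quantization.quant_api import int4_weight_only, quantize_


def quantize_then_compile(model: torch.nn.Module) -> torch.nn.Module:
    # quantize_ flow: quantize in place first, then compile the quantized model.
    quantize_(model, int4_weight_only())
    return torch.compile(model, mode="max-autotune")


def autoquant_pattern(model: torch.nn.Module) -> torch.nn.Module:
    # autoquant flow: compile first, then wrap with autoquant.
    return torchao.autoquant(torch.compile(model, mode="max-autotune"))
```
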

quantize_(model, quantization_method, device=self.device)

self.model = model
self._logger.info(f"Model was initialized with precision {self._dtype}.")

# Instantiate transforms
self.model_transform = config.instantiate(cfg.tokenizer)
@@ -105,13 +114,13 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
for p in itertools.chain(self.model.parameters(), self.model.buffers())
]
)
self._logger.info(
self.logger.info(
f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec"
)
self._logger.info(
self.logger.info(
f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
)
self._logger.info(
self.logger.info(
f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
)

@@ -128,10 +137,10 @@ def generate(self, cfg: DictConfig):
total_response_length = seq_len + cfg.max_new_tokens

# 3. Setup KV cache
with self._device:
with self.device:
self.model.setup_caches(
batch_size=1,
dtype=self._dtype,
dtype=self.dtype,
encoder_max_seq_len=(
self.model_transform.image_seq_len if is_multimodal_input else None
),
@@ -143,7 +152,7 @@
torch.ones(
size=(total_response_length, total_response_length),
dtype=torch.bool,
device=self._device,
device=self.device,
)
)
input_pos = torch.arange(total_response_length)
@@ -155,20 +164,20 @@
[model_inputs], pad_direction="left", pad_max_images=1
)
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
prompt = batch.pop("tokens").to(self._device)
prompt = batch.pop("tokens").to(self.device)
else:
prompt = torch.tensor(
model_inputs["tokens"], device=self._device
).unsqueeze(0)
prompt = torch.tensor(model_inputs["tokens"], device=self.device)[None, :]
Comment (Contributor Author, @joecummings, Oct 18, 2024): I wanted this to fit on one line lol.

batch["mask"] = causal_mask[None, :seq_len]
batch["input_pos"] = input_pos[None, :seq_len]
utils.batch_to_device(batch, self._device)
utils.batch_to_device(batch, self.device)

# 6. Prefill step
generated_tokens = []
t0 = time.perf_counter()
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
t1 = time.perf_counter()
Comment (Contributor Author, @joecummings, Oct 18, 2024): Now that we might have a warmup run, we log this differently so the user can see how good quantization / compilation is.

self.logger.info(f"Time to generate first token: {t1 - t0:.02f} sec")
generated_tokens.append(token.item())

if is_multimodal_input:
Expand All @@ -192,15 +201,15 @@ def generate(self, cfg: DictConfig):
generated_tokens.append(token.item())
seq_len += 1

t = time.perf_counter() - t0
t2 = time.perf_counter() - t1

# 8. Translate tokens back to text
decoded = self.model_transform.decode(generated_tokens)
self._logger.info(f"\n\n{decoded}\n")
self.logger.info(f"\n{decoded}\n")

# 9. Log metrics
tokens_per_second = len(generated_tokens) / t
self.log_metrics(total_time=t, tokens_per_second=tokens_per_second)
tokens_per_second = len(generated_tokens) / t2
self.log_metrics(total_time=t2, tokens_per_second=tokens_per_second)


@config.parse
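To make the metric change above concrete, a toy illustration (made-up numbers) of why throughput is now computed from t2, the decode time measured after the first token, rather than from total wall time: with compile and quantization enabled, the prefill / first-token step absorbs the compilation warmup, which would otherwise dominate the denominator and understate steady-state tokens/sec.

```python
# Hypothetical timings, for illustration only.
first_token_time = 30.0  # sec; includes the torch.compile warmup
decode_time = 10.0       # sec for the remaining tokens (this is t2)
num_generated = 500

naive_tps = num_generated / (first_token_time + decode_time)  # 12.5 tok/s
steady_state_tps = num_generated / decode_time                # 50.0 tok/s
print(f"naive: {naive_tps:.1f} tok/s, steady-state: {steady_state_tps:.1f} tok/s")
```
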
52 changes: 51 additions & 1 deletion tests/recipes/dev/test_generate_v2.py
@@ -12,7 +12,12 @@

from tests.common import TUNE_PATH
from tests.recipes.utils import MODEL_TEST_CONFIGS, write_hf_ckpt_config
from tests.test_utils import CKPT_MODEL_PATHS, mps_ignored_test, TOKENIZER_PATHS
from tests.test_utils import (
CKPT_MODEL_PATHS,
gpu_test,
mps_ignored_test,
TOKENIZER_PATHS,
)


class TestGenerateV2:
@@ -62,6 +67,51 @@ def test_llama2_generate_results(self, caplog, monkeypatch, tmpdir):
logs = caplog.text
assert expected_output in logs

@pytest.mark.integration_test
@gpu_test(gpu_count=1)
def test_llama2_generate_with_quantization(self, caplog, monkeypatch, tmpdir):
ckpt = "llama2_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
tokenizer_path = Path(TOKENIZER_PATHS["llama2"])
ckpt_dir = ckpt_path.parent

# Config file needed for model conversion.
write_hf_ckpt_config(ckpt_dir)

cmd = f"""
tune run dev/generate_v2 \
--config llama2/generation_v2 \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
device=cuda \
dtype=bf16 \
max_new_tokens=10 \
seed=123 \
quantization_method._component_=torchao.quantization.quant_api.int4_weight_only \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2"]
cmd = cmd + model_config

monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

# this is gibberish b/c the model is random weights, but it's
# the expected value for what we currently have in V2
# this test should catch any changes to the generate recipe that affect output
expected_output = (
"Halfotherтература retir pushingroad Chem CURLorientationocation Stadium"
)

logs = caplog.text
assert expected_output in logs

@pytest.mark.integration_test
def test_llama2_fail_on_bad_input(self, capsys, monkeypatch, tmpdir):
"""Should fail when user passes in a bad input: