55 changes: 17 additions & 38 deletions torchao/_models/llama/generate.py
@@ -40,35 +40,21 @@ def elapsed_time(self, other_event):
return abs(other_event.event_time - self.event_time) * 1000


def device_timer(device):
if "cuda" in device:
return torch.cuda.Event(enable_timing=True)
elif "xpu" in device:
return torch.xpu.Event(enable_timing=True)
def device_timer(device: str):
if device in ["cuda", "xpu"]:
return torch.Event(enable_timing=True)
elif ("cpu" in device) or ("mps" in device):
return HostEvent()
else:
print(f"device={device} is not yet suppported")


def device_sync(device):
if "cuda" in device:
torch.cuda.synchronize(device)
elif "xpu" in device:
torch.xpu.synchronize(device)
elif ("cpu" in device) or ("mps" in device):
pass
else:
print(f"device={device} is not yet suppported")
def device_sync(device: str):
if torch.accelerator.is_available():
torch.accelerator.synchronize(device)


default_device = (
"cuda"
if torch.cuda.is_available()
else "xpu"
if torch.xpu.is_available()
else "cpu"
)
default_device = acc.type if (acc := torch.accelerator.current_accelerator(check_available=True)) else "cpu"

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
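
For reference, a self-contained sketch of the pattern this hunk converges on: one backend-agnostic timer and sync helper built on torch.Event and torch.accelerator, mirroring the definitions above, plus a hypothetical timing snippet around a matmul. The HostEvent fallback follows the elapsed_time convention shown at the top of the hunk (wall-clock seconds converted to milliseconds); the __main__ usage is illustrative only and not part of generate.py.

import time

import torch


class HostEvent:
    """Wall-clock fallback for backends without event timing (cpu/mps)."""

    def record(self):
        self.event_time = time.perf_counter()

    def elapsed_time(self, other_event):
        # Same convention as generate.py: seconds -> milliseconds.
        return abs(other_event.event_time - self.event_time) * 1000


def device_timer(device: str):
    # torch.Event replaces the per-backend torch.cuda.Event / torch.xpu.Event.
    if device in ["cuda", "xpu"]:
        return torch.Event(enable_timing=True)
    return HostEvent()


def device_sync(device: str):
    # torch.accelerator.synchronize() dispatches to whichever backend is active.
    if torch.accelerator.is_available():
        torch.accelerator.synchronize(device)


default_device = (
    acc.type
    if (acc := torch.accelerator.current_accelerator(check_available=True))
    else "cpu"
)

if __name__ == "__main__":
    # Hypothetical usage: time a single matmul with the backend-agnostic events.
    device = default_device
    start, end = device_timer(device), device_timer(device)
    a = torch.randn(1024, 1024, device=device)
    start.record()
    b = a @ a
    end.record()
    device_sync(device=device)  # drain queued kernels before reading the timer
    print(f"{device}: {start.elapsed_time(end):.3f} ms")
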
@@ -160,10 +146,10 @@ def generate(
kv_cache_quantization: bool = False,
cache_size: Optional[int] = None,
linear_causal_mask: bool = False,
prefill_start_event: Optional[torch.cuda.Event] = None,
prefill_end_event: Optional[torch.cuda.Event] = None,
decode_start_event: Optional[torch.cuda.Event] = None,
decode_end_event: Optional[torch.cuda.Event] = None,
prefill_start_event: Optional[torch.Event] = None,
prefill_end_event: Optional[torch.Event] = None,
decode_start_event: Optional[torch.Event] = None,
decode_end_event: Optional[torch.Event] = None,
**sampling_kwargs,
) -> torch.Tensor:
"""
@@ -281,8 +267,8 @@ def main(
compile_prefill: bool = False,
profile: Optional[Path] = None,
memory_profile: Optional[Path] = None,
device=default_device,
precision=torch.bfloat16,
device: str = default_device,
precision: torch.dtype = torch.bfloat16,
write_result: Optional[Path] = None,
output_json_path: Optional[Path] = None,
output_json_local: bool = False,
@@ -606,7 +592,7 @@ def ffn_or_attn_only(mod, fqn):
prepare_inputs_for_model,
False, # pad_calibration_inputs
model.config.vocab_size,
device="cuda",
device=device,
)
.record_inputs(
["wikitext"],
@@ -616,7 +602,7 @@ def ffn_or_attn_only(mod, fqn):
.values[0]
)
inputs = prepare_inputs_for_model(inputs)
with torch.device("cuda"):
with torch.device(device):
model.setup_caches(
max_batch_size=1, max_seq_length=calibration_seq_length
)
@@ -883,10 +869,7 @@ def ffn_or_attn_only(mod, fqn):

for i in range(start, num_samples):
if i == 0:
if device == "cuda":
torch.cuda.reset_peak_memory_stats() # MKG
elif device == "xpu":
torch.xpu.reset_peak_memory_stats() # MKG
torch.accelerator.reset_peak_memory_stats() # MKG
device_sync(device=device) # MKG
if i >= 0 and interactive:
prompt = input("What is your prompt? ")
@@ -1016,14 +999,10 @@ def callback(x):
torch.tensor(aggregate_metrics["decode_tokens_per_sec"])
).item()
bandwidth = model_size * tokpersec
mem = torch.cuda.max_memory_reserved() / 1e9
print(f"Average overall tokens/sec: {tokpersec:.2f}")
print(f"Average decode tokens/sec: {decode_tokpersec:.04f} s")
print(f"Average TTFT: {ttft:.04f} s")
if device == "cuda":
mem = torch.cuda.max_memory_reserved() / 1e9
elif device == "xpu":
mem = torch.xpu.max_memory_reserved() / 1e9
mem = torch.accelerator.max_memory_reserved() / 1e9
print(f"Average tokens/sec: {tokpersec:.2f}")
if batch_size > 1:
print(f"Average tokens/sec including batches {batch_size * tokpersec:.2f}")
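
Similarly, the last two hunks collapse the per-backend peak-memory branches into the unified accelerator API. Below is a minimal sketch of that pattern, assuming a PyTorch build where torch.accelerator exposes the reset_peak_memory_stats / max_memory_reserved helpers used in this diff; measure_peak_reserved_gb is a hypothetical wrapper, not part of generate.py.

import torch


def measure_peak_reserved_gb(fn) -> float:
    """Run fn() and return peak reserved accelerator memory in GB (0.0 on cpu/mps)."""
    if not torch.accelerator.is_available():
        fn()
        return 0.0
    torch.accelerator.reset_peak_memory_stats()  # clear stats before the run
    fn()
    torch.accelerator.synchronize()
    return torch.accelerator.max_memory_reserved() / 1e9


if __name__ == "__main__":
    dev = (
        acc.type
        if (acc := torch.accelerator.current_accelerator(check_available=True))
        else "cpu"
    )
    x = torch.randn(2048, 2048, device=dev)
    mem = measure_peak_reserved_gb(lambda: x @ x)
    print(f"peak reserved memory on {dev}: {mem:.2f} GB")
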