2 changes: 1 addition & 1 deletion build/builder.py
@@ -407,7 +407,7 @@ def _initialize_model(
if builder_args.setup_caches:
max_seq_length = 350
with torch.device(builder_args.device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
model.setup_caches(max_batch_size=1, max_seq_length=model.config.max_seq_len)

model.to(dtype=builder_args.precision)

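The hunk above swaps the hard-coded cache length for the model's own configuration. A minimal sketch of the resulting setup path, assuming the intended call passes the new max_seq_len field added to ModelArgs in build/model.py below; setup_model_caches is a hypothetical wrapper name, not a function in the repo:

import torch

def setup_model_caches(model, builder_args):
    # Sketch only: mirrors the _initialize_model snippet above, with the
    # sequence length taken from the model's config rather than a constant.
    if builder_args.setup_caches:
        with torch.device(builder_args.device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_length=model.config.max_seq_len,
            )
    model.to(dtype=builder_args.precision)
    return model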
1 change: 1 addition & 0 deletions build/model.py
@@ -37,6 +37,7 @@ class ModelArgs:
multiple_of: int = 256
ffn_dim_multiplier: Optional[int] = None
use_tiktoken: bool = False
max_seq_len: int = 8192

def __post_init__(self):
if self.n_local_heads == -1:
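For reference, a condensed sketch of ModelArgs after this change; only the fields visible in the diff hunk are shown, and the real dataclass in build/model.py carries many more:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    # Condensed: the real dataclass has additional fields (dim, n_heads, ...).
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[int] = None
    use_tiktoken: bool = False
    max_seq_len: int = 8192  # new default context length added by this commit

args = ModelArgs()
assert args.max_seq_len == 8192  # per-checkpoint configs can still override this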
129 changes: 91 additions & 38 deletions generate.py
@@ -11,7 +11,7 @@
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, List
from typing import List, Optional, Tuple

import torch
import torch._dynamo.config
@@ -32,6 +32,7 @@
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"


class ChatFormat:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
@@ -62,7 +63,6 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
return tokens



@dataclass
class GeneratorArgs:
prompt: str = "torchchat is pronounced torch-chat and is so cool because"
@@ -210,11 +210,17 @@ def decode_n_tokens(
):
new_tokens, new_probs = [], []
encountered_eos = False
for i in range(num_new_tokens - 1): # -1 to save space to run an EoS if we don't generate it naturally
for i in range(
num_new_tokens - 1
): # -1 to save space to run an EoS if we don't generate it naturally
# Actually better for Inductor to codegen attention here
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
next_token, next_prob = decode_one_token(
model, cur_token.clone(), input_pos, need_probs=need_probs, **sampling_kwargs
model,
cur_token.clone(),
input_pos,
need_probs=need_probs,
**sampling_kwargs,
)
input_pos += 1
new_tokens.append(next_token.clone())
@@ -223,15 +229,25 @@ def decode_n_tokens(
new_probs.append(next_prob.clone())
cur_token = next_token.view(1, -1)
# encountered eos
if (next_token.item() == eos_token_id or (eot_id is not None and next_token.item() == eot_id)):
if next_token.item() == eos_token_id or (
eot_id is not None and next_token.item() == eot_id
):
encountered_eos = True
_, _ = decode_one_token(model, cur_token, input_pos, need_probs, **sampling_kwargs)
_, _ = decode_one_token(
model, cur_token, input_pos, need_probs, **sampling_kwargs
)
input_pos += 1
break
if not encountered_eos:
eos_token = torch.tensor([eos_token_id if eot_id is None else eot_id], dtype=cur_token.dtype, device=cur_token.device)
eos_token = torch.tensor(
[eos_token_id if eot_id is None else eot_id],
dtype=cur_token.dtype,
device=cur_token.device,
)
new_tokens.append(eos_token.clone())
_, _ = decode_one_token(model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs)
_, _ = decode_one_token(
model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs
)
input_pos += 1

return new_tokens, new_probs
@@ -337,7 +353,9 @@ def generate(
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if is_speculative and draft_model is not model:
draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
draft_model.setup_caches(
max_batch_size=1, max_seq_length=max_seq_length
)

# create an empty tensor of the expected final shape and
# fill in the current tokens
@@ -366,7 +384,9 @@ def generate(

num_tokens_generated = 0
input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
accept_counts = [0] * (speculate_k + 1) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
accept_counts = [0] * (
speculate_k + 1
) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long

if is_speculative:
input_pos = input_pos.item() # for speculative decoding easier to keep on host
@@ -392,12 +412,14 @@ def generate(
max_new_tokens - 1,
callback=callback,
need_probs=False,
eos_token_id = tokenizer.eos_id() if tokenizer else 2,
eot_id = tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
eos_token_id=tokenizer.eos_id() if tokenizer else 2,
eot_id=tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
**sampling_kwargs,
)
seq[T + 1 : T + 1 + len(generated_tokens)] = torch.cat(generated_tokens)
seq = seq[:T + 1 + len(generated_tokens)] # If we don't generate all the way to max_new_tokens, slice off the extra space we allocated.
seq = seq[
: T + 1 + len(generated_tokens)
] # If we don't generate all the way to max_new_tokens, slice off the extra space we allocated.

generate_stats = {"accept_counts": accept_counts}
return seq, generate_stats
@@ -410,7 +432,6 @@ def encode_tokens(tokenizer, string, bos=True, device="cpu"):
return torch.tensor(tokens, dtype=torch.int, device=device)



def get_device_info(name: str) -> str:
import platform
from subprocess import check_output
@@ -481,7 +502,9 @@ def _main(
# Piggybacking off of this flag for now to identify llama3 without prompting the user.
is_llama3_model = tokenizer_args.is_tiktoken
if generator_args.chat_mode and is_llama3_model:
logging.debug("Llama3 model detected in chat mode. Using updated sentence schemas")
logging.debug(
"Llama3 model detected in chat mode. Using updated sentence schemas"
)

builder_args.setup_caches = False
model = _initialize_model(builder_args, quantize, tokenizer)
@@ -534,19 +557,24 @@ def _main(
if generator_args.compile_prefill:
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

system_prompt=None
system_prompt = None
# Set up our max_seq_length
if generator_args.chat_mode:
max_seq_length = 2048
print(f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye")
get_system_prompt = input("Do you want to enter a system prompt? Enter y for yes and anything else for no. \n")
if (get_system_prompt == "y" or get_system_prompt == "Y"):
print(
f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
)
get_system_prompt = input(
"Do you want to enter a system prompt? Enter y for yes and anything else for no. \n"
)
if get_system_prompt == "y" or get_system_prompt == "Y":
system_prompt = input("What is your system prompt? \n")
if is_llama3_model:
chat_formatter = ChatFormat(tokenizer)
else:
max_seq_length = min(encoded.size(0) + generator_args.max_new_tokens, model.config.block_size)

max_seq_length = min(
encoded.size(0) + generator_args.max_new_tokens, model.config.block_size
)

max_seq_length = (
max_seq_length + speculate_k + 1 if draft_model is not None else max_seq_length
@@ -559,39 +587,59 @@ def _main(
start = -1 if generator_args.compile else 0
start_pos = 0


# arbitrarily large number, as chat mode runs until max_seq_length is hit or the user exits
num_samples = generator_args.num_samples if not generator_args.chat_mode else 100000
i = -1 # long loop and I'm scared someone will add a continue in it, so start at -1 and increment at the start
while (i < num_samples):
i = (
-1
) # long loop and I'm scared someone will add a continue in it, so start at -1 and increment at the start
while i < num_samples:
i += 1
device_sync(device=builder_args.device)
if i >= 0 and generator_args.chat_mode:
prompt = input("What is your prompt? \n")
if (prompt == "/bye"):
if prompt == "/bye":
print("Exiting Chat.\n")
break
if not is_llama3_model:
if system_prompt is not None:
prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip} {E_INST}"
system_prompt = None # can only provide system prompt on first interaction
system_prompt = (
None # can only provide system prompt on first interaction
)
else:
prompt = f"{B_INST} {prompt.strip()} {E_INST}"
encoded = encode_tokens(
tokenizer, prompt, bos=True, device=builder_args.device
)
else:
if system_prompt is not None:
encoded = chat_formatter.encode_dialog_prompt([{"role" : "system", "content" : system_prompt}, {"role" : "user", "content" : prompt}])
encoded = chat_formatter.encode_dialog_prompt(
[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
]
)
system_prompt = None
elif(i == 0):
encoded = chat_formatter.encode_dialog_prompt([{"role" : "user", "content" : prompt}])
elif i == 0:
encoded = chat_formatter.encode_dialog_prompt(
[{"role": "user", "content": prompt}]
)
else:
encoded = chat_formatter.encode_message({"role" : "user", "content" : prompt})
encoded.extend(chat_formatter.encode_header({"role": "assistant", "content": ""}))
encoded = torch.tensor(encoded, dtype=torch.int, device=builder_args.device)
if (encoded.size(0) + start_pos > max_seq_length):
print("This prompt would take us past the max_seq_length. Ending Conversation.")
encoded = chat_formatter.encode_message(
{"role": "user", "content": prompt}
)
encoded.extend(
chat_formatter.encode_header(
{"role": "assistant", "content": ""}
)
)
encoded = torch.tensor(
encoded, dtype=torch.int, device=builder_args.device
)
if encoded.size(0) + start_pos > max_seq_length:
print(
"This prompt would take us past the max_seq_length. Ending Conversation."
)
break

if generator_args.chat_mode and i >= 0:
Expand All @@ -604,12 +652,17 @@ def callback(
):
if done_generating:
return
buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) # I think this results in the first output token being dropped from the display which is wrong.
buffer.append(
tokenizer.decode([period_id] + x.tolist())[1:]
) # I think this results in the first output token being dropped from the display which is wrong.
if x.item() == tokenizer.eos_id():
done_generating = True
if (is_llama3_model and x.item() == tokenizer.special_tokens["<|eot_id|>"]):
if (
is_llama3_model
and x.item() == tokenizer.special_tokens["<|eot_id|>"]
):
done_generating = True
buffer = buffer[:-1] # drop the eot_id from the output buffer
buffer = buffer[:-1] # drop the eot_id from the output buffer
if len(buffer) == 4 or done_generating:
print("".join(buffer), end="", flush=True)
buffer.clear()
@@ -672,7 +725,7 @@ def callback(x):
)
logging.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")

if (start_pos >= max_seq_length):
if start_pos >= max_seq_length:
print("Max Sequence Length Reached. Ending Conversation.")
break

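To summarize the chat-mode branching introduced above: the first turn encodes a full dialog (optionally with a system prompt), while later turns append only the new user message plus an empty assistant header so generation continues in the assistant role. A hedged recap as a standalone helper; encode_turn is a hypothetical name, and the ChatFormat methods are used exactly as they appear in this diff:

import torch

def encode_turn(chat_formatter, system_prompt, prompt, turn_index, device):
    # Hypothetical helper distilling the Llama 3 branch of the chat loop above.
    if system_prompt is not None:
        # First turn with a system prompt: encode system + user together.
        tokens = chat_formatter.encode_dialog_prompt(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ]
        )
    elif turn_index == 0:
        # First turn without a system prompt.
        tokens = chat_formatter.encode_dialog_prompt(
            [{"role": "user", "content": prompt}]
        )
    else:
        # Later turns: encode only the new user message, then open an
        # assistant header so the model replies in the assistant role.
        tokens = chat_formatter.encode_message({"role": "user", "content": prompt})
        tokens.extend(
            chat_formatter.encode_header({"role": "assistant", "content": ""})
        )
    return torch.tensor(tokens, dtype=torch.int, device=device)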
11 changes: 7 additions & 4 deletions quantize.py
@@ -15,7 +15,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from build.utils import find_multiple, get_precision, use_et_backend
from build.utils import find_multiple, get_precision, name_to_dtype, use_et_backend


#########################################################################
@@ -97,11 +97,14 @@ def quantized_model(self) -> nn.Module:


class PrecisionHandler(QuantHandler):
def __init__(self, model: nn.Module, device="cpu", tokenizer=None, **kwargs):
def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
self.model_ = model
self.device = device
self.tokenizer = tokenizer
self.kwargs = kwargs

if isinstance(dtype, str):
dtype = name_to_dtype(dtype)
self.dtype = dtype

def create_quantized_state_dict(self) -> Dict: # "StateDict"
pass
@@ -110,7 +113,7 @@ def convert_for_runtime(self) -> nn.Module:
pass

def quantized_model(self) -> nn.Module:
return self.model_.to(device=self.device, **self.kwargs)
return self.model_.to(device=self.device, dtype=self.dtype)


#########################################################################
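The PrecisionHandler now requires dtype as a keyword-only argument instead of collecting **kwargs, and accepts either a torch.dtype or a string resolved through name_to_dtype. A minimal self-contained sketch; the name_to_dtype body below is an illustrative stand-in for the helper imported from build.utils, not its real implementation:

import torch
import torch.nn as nn

def name_to_dtype(name: str) -> torch.dtype:
    # Illustrative stand-in for build.utils.name_to_dtype.
    return {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}[name]

class PrecisionHandler:
    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
        self.model_ = model
        self.device = device
        self.tokenizer = tokenizer
        if isinstance(dtype, str):
            dtype = name_to_dtype(dtype)
        self.dtype = dtype

    def quantized_model(self) -> nn.Module:
        return self.model_.to(device=self.device, dtype=self.dtype)

handler = PrecisionHandler(nn.Linear(4, 4), dtype="bf16")
print(handler.quantized_model().weight.dtype)  # torch.bfloat16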