@@ -11,7 +11,7 @@
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple

 import torch
 import torch._dynamo.config
@@ -32,6 +32,7 @@
 B_INST, E_INST = "[INST]", "[/INST]"
 B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"

+
 class ChatFormat:
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
@@ -62,7 +63,6 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
         return tokens


-
 @dataclass
 class GeneratorArgs:
     prompt: str = "torchchat is pronounced torch-chat and is so cool because"
@@ -210,11 +210,17 @@ def decode_n_tokens(
 ):
     new_tokens, new_probs = [], []
     encountered_eos = False
-    for i in range(num_new_tokens - 1):  # -1 to save space to run an EoS if dont generate it naturally
+    for i in range(
+        num_new_tokens - 1
+    ):  # -1 to save space to run an EoS if dont generate it naturally
         # Actually better for Inductor to codegen attention here
         with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
             next_token, next_prob = decode_one_token(
-                model, cur_token.clone(), input_pos, need_probs=need_probs, **sampling_kwargs
+                model,
+                cur_token.clone(),
+                input_pos,
+                need_probs=need_probs,
+                **sampling_kwargs,
             )
             input_pos += 1
             new_tokens.append(next_token.clone())
@@ -223,15 +229,25 @@ def decode_n_tokens(
                 new_probs.append(next_prob.clone())
             cur_token = next_token.view(1, -1)
             # encountered eos
-            if (next_token.item() == eos_token_id or (eot_id is not None and next_token.item() == eot_id)):
+            if next_token.item() == eos_token_id or (
+                eot_id is not None and next_token.item() == eot_id
+            ):
                 encountered_eos = True
-                _, _ = decode_one_token(model, cur_token, input_pos, need_probs, **sampling_kwargs)
+                _, _ = decode_one_token(
+                    model, cur_token, input_pos, need_probs, **sampling_kwargs
+                )
                 input_pos += 1
                 break
     if not encountered_eos:
-        eos_token = torch.tensor([eos_token_id if eot_id is None else eot_id], dtype=cur_token.dtype, device=cur_token.device)
+        eos_token = torch.tensor(
+            [eos_token_id if eot_id is None else eot_id],
+            dtype=cur_token.dtype,
+            device=cur_token.device,
+        )
         new_tokens.append(eos_token.clone())
-        _, _ = decode_one_token(model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs)
+        _, _ = decode_one_token(
+            model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs
+        )
         input_pos += 1

     return new_tokens, new_probs
@@ -337,7 +353,9 @@ def generate(
     with torch.device(device):
         model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+            draft_model.setup_caches(
+                max_batch_size=1, max_seq_length=max_seq_length
+            )

     # create an empty tensor of the expected final shape and
     # fill in the current tokens
@@ -366,7 +384,9 @@ def generate(

     num_tokens_generated = 0
     input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
-    accept_counts = [0] * (speculate_k + 1)  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
+    accept_counts = [0] * (
+        speculate_k + 1
+    )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long

     if is_speculative:
         input_pos = input_pos.item()  # for speculative decoding easier to keep on host
@@ -392,12 +412,14 @@ def generate(
             max_new_tokens - 1,
             callback=callback,
             need_probs=False,
-            eos_token_id = tokenizer.eos_id() if tokenizer else 2,
-            eot_id = tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
+            eos_token_id=tokenizer.eos_id() if tokenizer else 2,
+            eot_id=tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
             **sampling_kwargs,
         )
         seq[T + 1 : T + 1 + len(generated_tokens)] = torch.cat(generated_tokens)
-        seq = seq[:T + 1 + len(generated_tokens)]  # If we dont generate all the way to max_new_tokens slice off the extra space we allocated.
+        seq = seq[
+            : T + 1 + len(generated_tokens)
+        ]  # If we dont generate all the way to max_new_tokens slice off the extra space we allocated.

     generate_stats = {"accept_counts": accept_counts}
     return seq, generate_stats
@@ -410,7 +432,6 @@ def encode_tokens(tokenizer, string, bos=True, device="cpu"):
     return torch.tensor(tokens, dtype=torch.int, device=device)


-
 def get_device_info(name: str) -> str:
     import platform
     from subprocess import check_output
@@ -481,7 +502,9 @@ def _main(
     # Piggy backing off of this flag then for now to identify llama3 without prompting user.
     is_llama3_model = tokenizer_args.is_tiktoken
     if generator_args.chat_mode and is_llama3_model:
-        logging.debug("Llama3 model detected in chat mode. Using updated sentence schemas")
+        logging.debug(
+            "Llama3 model detected in chat mode. Using updated sentence schemas"
+        )

     builder_args.setup_caches = False
     model = _initialize_model(builder_args, quantize, tokenizer)
@@ -534,20 +557,29 @@ def _main(
     if generator_args.compile_prefill:
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

-    system_prompt = None
+    system_prompt = None
     # Set up our max_seq_length
     if generator_args.chat_mode:
-        max_seq_length = 2048
-        print(f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye")
-        system_prompt = input("System Prompt [Optional]: ")
+        max_seq_length = model.config.max_seq_length
+        print(
+            f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
+        )
+        get_system_prompt = input(
+            "Do you want to enter a system prompt? Enter y for yes and anything else for no. \n"
+        )
+        if get_system_prompt == "y" or get_system_prompt == "Y":
+            system_prompt = input("What is your system prompt? \n")
         if is_llama3_model:
             chat_formatter = ChatFormat(tokenizer)
     else:
-        max_seq_length = min(encoded.size(0) + generator_args.max_new_tokens, model.config.block_size)
-
+        max_seq_length = min(
+            encoded.size(0) + generator_args.max_new_tokens, model.config.block_size
+        )

     max_seq_length = (
-        max_seq_length + speculate_k + 1 if draft_model is not None else max_seq_length
+        max_seq_length + speculative_builder_args.speculate_k + 1
+        if draft_model is not None
+        else max_seq_length
     )

     aggregate_metrics = {
@@ -557,39 +589,59 @@ def _main(
     start = -1 if generator_args.compile else 0
     start_pos = 0

-
     # arbitrarily large number as chat mode goes until max_seq length or user exits
     num_samples = generator_args.num_samples if not generator_args.chat_mode else 100000
-    i = -1  # long loop and Im scared someone will add a continue in it, so start at -1 and increment at the start
-    while (i < num_samples):
+    i = (
+        -1
+    )  # long loop and Im scared someone will add a continue in it, so start at -1 and increment at the start
+    while i < num_samples:
         i += 1
         device_sync(device=builder_args.device)
         if i >= 0 and generator_args.chat_mode:
             prompt = input("User: ")
-            if (prompt == "/bye"):
+            if prompt == "/bye":
                 print("Exiting Chat.\n")
                 break
             if not is_llama3_model:
                 if system_prompt:
                     prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip} {E_INST}"
-                    system_prompt = None  # can only provide system prompt on first interaction
+                    system_prompt = (
+                        None  # can only provide system prompt on first interaction
+                    )
                 else:
                     prompt = f"{B_INST} {prompt.strip()} {E_INST}"
                 encoded = encode_tokens(
                     tokenizer, prompt, bos=True, device=builder_args.device
                 )
             else:
-                if system_prompt:
-                    encoded = chat_formatter.encode_dialog_prompt([{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}])
+                if system_prompt is not None:
+                    encoded = chat_formatter.encode_dialog_prompt(
+                        [
+                            {"role": "system", "content": system_prompt},
+                            {"role": "user", "content": prompt},
+                        ]
+                    )
                     system_prompt = None
-                elif (i == 0):
-                    encoded = chat_formatter.encode_dialog_prompt([{"role": "user", "content": prompt}])
+                elif i == 0:
+                    encoded = chat_formatter.encode_dialog_prompt(
+                        [{"role": "user", "content": prompt}]
+                    )
                 else:
-                    encoded = chat_formatter.encode_message({"role": "user", "content": prompt})
-                    encoded.extend(chat_formatter.encode_header({"role": "assistant", "content": ""}))
-                encoded = torch.tensor(encoded, dtype=torch.int, device=builder_args.device)
-            if (encoded.size(0) + start_pos > max_seq_length):
-                print("This prompt would take us past the max_seq_length. Ending Conversation.")
+                    encoded = chat_formatter.encode_message(
+                        {"role": "user", "content": prompt}
+                    )
+                    encoded.extend(
+                        chat_formatter.encode_header(
+                            {"role": "assistant", "content": ""}
+                        )
+                    )
+                encoded = torch.tensor(
+                    encoded, dtype=torch.int, device=builder_args.device
+                )
+            if encoded.size(0) + start_pos > max_seq_length:
+                print(
+                    "This prompt would take us past the max_seq_length. Ending Conversation."
+                )
                 break

         if generator_args.chat_mode and i >= 0:
@@ -604,12 +656,17 @@ def callback(
             ):
                 if done_generating:
                     return
-                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])  # I think this results in the first output token being dropped from the display which is wrong.
+                buffer.append(
+                    tokenizer.decode([period_id] + x.tolist())[1:]
+                )  # I think this results in the first output token being dropped from the display which is wrong.
                 if x.item() == tokenizer.eos_id():
                     done_generating = True
-                if (is_llama3_model and x.item() == tokenizer.special_tokens["<|eot_id|>"]):
+                if (
+                    is_llama3_model
+                    and x.item() == tokenizer.special_tokens["<|eot_id|>"]
+                ):
                     done_generating = True
-                    buffer = buffer[:-1] # drop the eot_id from the output buffer
+                    buffer = buffer[:-1]  # drop the eot_id from the output buffer
                 if len(buffer) == 4 or done_generating:
                     print("".join(buffer), end="", flush=True)
                     buffer.clear()
@@ -672,7 +729,7 @@ def callback(x):
         )
         logging.debug(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")

-        if (start_pos >= max_seq_length):
+        if start_pos >= max_seq_length:
             print("Max Sequence Length Reached. Ending Conversation.")
             break
