
Generation does not work #2769

Closed
@Shariar076

Description


The generation example in the docs at https://docs.pytorch.org/torchtune/stable/generated/torchtune.generation.generate.html#torchtune.generation.generate does not work: decoding the output prints only the prompt. I tried the following:

```
>>> import torch
>>> from torchtune.models.llama3 import llama3_tokenizer
>>> from torchtune.models.llama3 import llama3_8b
>>> from torchtune.generation import generate
>>> model = llama3_8b().cuda()
>>> tokenizer = llama3_tokenizer("checkpoints/Meta-Llama-3-8B-Instruct/original/tokenizer.model")
>>> prompt = tokenizer.encode("Hi my name is")
>>> output, logits = generate(model, torch.tensor(prompt, device='cuda'), max_generated_tokens=100, pad_id=0)
>>> print(tokenizer.decode(output[0].tolist()))
Hi my name is
>>> output
tensor([[128000,  13347,    856,    836,    374, 128001,  49263,  45892, 117650,
          80252,  13825,  93166,  86232,   1309, 119791,  57968, 109216, 119099,
         127824,  50532, 114899, 101806,  63967,  82748,  81405, 119646,   1323,
          88452,  81382,  43309,  46070,  60111,  98318,  89937,  82561,  86967,
          15046,  46705,  92231,  49405, 105751,  12936,  63385,  78030,  65426,
         115513,  63001,  47715,  26661, 115855,  74187,  20661,  46922,  52735,
          71358,  45263,   9412,  70215,  46441,  17561,  34201,  16042, 105009,
         111681,  57920,  97103,   7404,  96699,  85056,  65707,   8174,  12481,
         121220,   2882,  41843,  39199,  99413,  44659, 110860,  58017,  41245,
          74254,  91415,   4041, 120132,  21972,  46548,  68651,  11568,  88572,
         106217,  34486,  95538, 126271, 128138,  34382, 115571,  85049, 126324,
         107640, 124685,  17832, 118559,  54026, 124872, 102345]],
       device='cuda:0')
>>>
```
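
Two things in the session above likely explain the result. First, `llama3_8b()` only builds the model architecture with randomly initialized weights; nothing in the session loads the Meta-Llama-3-8B-Instruct model weights (only the tokenizer file is read), so the 100 ids generated after the prompt are effectively random. Second, `tokenizer.encode` appears to append the end-of-text token by default (id 128001, the sixth id in the output), and `decode` appears to truncate at the first such token, so only the prompt is printed. The sketch below reproduces that second effect in plain Python on the reported ids. `truncate_at_eos`, `strip_special`, `BOS_ID`, and `EOS_ID` are hypothetical stand-ins for the tokenizer's behavior, not torchtune code.

```python
from typing import Iterable, List

# Assumed Llama 3 special-token ids, matching the ids visible in the output
# above: <|begin_of_text|> (BOS) and <|end_of_text|> (EOS).
BOS_ID = 128000
EOS_ID = 128001


def truncate_at_eos(token_ids: List[int], eos_id: int = EOS_ID) -> List[int]:
    """Hypothetical helper: keep only the ids before the first EOS id."""
    try:
        return token_ids[: token_ids.index(eos_id)]
    except ValueError:
        # No EOS in the sequence, so nothing is cut off.
        return list(token_ids)


def strip_special(token_ids: Iterable[int]) -> List[int]:
    """Hypothetical helper: drop BOS/EOS ids, as a decoder skipping special tokens would."""
    return [t for t in token_ids if t not in (BOS_ID, EOS_ID)]


# First nine ids of the reported output[0]: the encoded prompt (which already
# ends with EOS) followed by the first generated ids.
output_ids = [128000, 13347, 856, 836, 374, 128001, 49263, 45892, 117650]
print(strip_special(truncate_at_eos(output_ids)))
# -> [13347, 856, 836, 374]  (only the prompt ids survive)

# Hypothetical sequence for a prompt encoded without a trailing EOS:
# the ids after the prompt are no longer truncated.
no_eos_ids = [128000, 13347, 856, 836, 374, 49263, 45892]
print(strip_special(truncate_at_eos(no_eos_ids)))
# -> [13347, 856, 836, 374, 49263, 45892]
```

If this reading is right, encoding the prompt with `tokenizer.encode("Hi my name is", add_eos=False)` should keep the generated ids visible after decoding, and the continuation will only be meaningful once pretrained weights are loaded into the model before calling `generate`.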
