Closed
Description
The generation example in the docs at https://docs.pytorch.org/torchtune/stable/generated/torchtune.generation.generate.html#torchtune.generation.generate does not work as written. I tried the following:
>>> import torch
>>> from torchtune.models.llama3 import llama3_tokenizer
>>> from torchtune.models.llama3 import llama3_8b
>>> from torchtune.generation import generate
>>> model = llama3_8b().cuda()
>>> tokenizer = llama3_tokenizer("checkpoints/Meta-Llama-3-8B-Instruct/original/tokenizer.model")
>>> prompt = tokenizer.encode("Hi my name is")
>>> output, logits = generate(model, torch.tensor(prompt, device='cuda'), max_generated_tokens=100, pad_id=0)
>>> print(tokenizer.decode(output[0].tolist()))
Hi my name is
>>> output
tensor([[128000, 13347, 856, 836, 374, 128001, 49263, 45892, 117650,
80252, 13825, 93166, 86232, 1309, 119791, 57968, 109216, 119099,
127824, 50532, 114899, 101806, 63967, 82748, 81405, 119646, 1323,
88452, 81382, 43309, 46070, 60111, 98318, 89937, 82561, 86967,
15046, 46705, 92231, 49405, 105751, 12936, 63385, 78030, 65426,
115513, 63001, 47715, 26661, 115855, 74187, 20661, 46922, 52735,
71358, 45263, 9412, 70215, 46441, 17561, 34201, 16042, 105009,
111681, 57920, 97103, 7404, 96699, 85056, 65707, 8174, 12481,
121220, 2882, 41843, 39199, 99413, 44659, 110860, 58017, 41245,
74254, 91415, 4041, 120132, 21972, 46548, 68651, 11568, 88572,
106217, 34486, 95538, 126271, 128138, 34382, 115571, 85049, 126324,
107640, 124685, 17832, 118559, 54026, 124872, 102345]],
device='cuda:0')
>>>
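One observation from the output above, offered as a possible lead rather than a confirmed diagnosis: the tensor holds 106 token IDs, so with `max_generated_tokens=100` the prompt was likely 6 tokens long, and its last token is 128001. In the Llama 3 vocabulary, 128000 is `<|begin_of_text|>` and 128001 is `<|end_of_text|>`, which suggests the encoded prompt already ended with an EOS token before generation started. The plain-Python sketch below checks a token list for those IDs. The names `BOS_ID`, `EOS_ID`, and `find_special_tokens` are hypothetical helpers for illustration, not part of torchtune.

```python
# Llama 3 special-token IDs (assumed from the Llama 3 vocabulary).
BOS_ID = 128000  # <|begin_of_text|>
EOS_ID = 128001  # <|end_of_text|>


def find_special_tokens(token_ids, special_ids=(BOS_ID, EOS_ID)):
    """Return (position, id) pairs for every special token in token_ids."""
    return [(i, t) for i, t in enumerate(token_ids) if t in special_ids]


# First six IDs of `output` from the session above (the presumed prompt).
prompt_ids = [128000, 13347, 856, 836, 374, 128001]

hits = find_special_tokens(prompt_ids)
print(hits)  # [(0, 128000), (5, 128001)]

# True here: the prompt ends with EOS, so the model is asked to continue
# a sequence that has already been marked as finished.
prompt_ends_with_eos = prompt_ids[-1] == EOS_ID
print(prompt_ends_with_eos)
```

If this is what is happening, encoding the prompt with `tokenizer.encode("Hi my name is", add_eos=False)` may avoid it, assuming the tokenizer's `encode` accepts an `add_eos` argument. Separately, `llama3_8b()` as called above builds the architecture without loading checkpoint weights, which could explain why the tokens after the prompt look random.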