File tree Expand file tree Collapse file tree 2 files changed +8
-10
lines changed Expand file tree Collapse file tree 2 files changed +8
-10
lines changed Original file line number Diff line number Diff line change 11import time
22
33import torch
4- from transformers import AutoModelForCausalLM , LlamaTokenizerFast
4+ from transformers import AutoModelForCausalLM , AutoTokenizer
55from utils import get_defualt_parser , inference , print_output
66
77if __name__ == "__main__" :
88 parser = get_defualt_parser ()
99 args = parser .parse_args ()
1010 start = time .time ()
1111 torch .set_default_dtype (torch .bfloat16 )
12+
13+ tokenizer = AutoTokenizer .from_pretrained (args .pretrained , trust_remote_code = True )
14+
1215 model = AutoModelForCausalLM .from_pretrained (
1316 args .pretrained ,
1417 trust_remote_code = True ,
1821 model .eval ()
1922 init_time = time .time () - start
2023
21- # A transformers-compatible version of the grok-1 tokenizer by Xenova
22- # https://huggingface.co/Xenova/grok-1-tokenizer
23- tokenizer = LlamaTokenizerFast .from_pretrained ("Xenova/grok-1-tokenizer" )
24-
2524 for text in args .text :
2625 output = inference (
2726 model ,
Original file line number Diff line number Diff line change 22
33import torch
44from grok1_policy import Grok1ForCausalLMPolicy
5- from transformers import AutoModelForCausalLM , LlamaTokenizerFast
5+ from transformers import AutoModelForCausalLM , AutoTokenizer
66from utils import get_defualt_parser , inference , print_output
77
88import colossalai
2727 )
2828 booster = Booster (plugin = plugin )
2929 torch .set_default_dtype (torch .bfloat16 )
30+
31+ tokenizer = AutoTokenizer .from_pretrained (args .pretrained , trust_remote_code = True )
32+
3033 with LazyInitContext (default_device = get_current_device ()):
3134 model = AutoModelForCausalLM .from_pretrained (
3235 args .pretrained , trust_remote_code = True , torch_dtype = torch .bfloat16
3538 model .eval ()
3639 init_time = time .time () - start
3740
38- # A transformers-compatible version of the grok-1 tokenizer by Xenova
39- # https://huggingface.co/Xenova/grok-1-tokenizer
40- tokenizer = LlamaTokenizerFast .from_pretrained ("Xenova/grok-1-tokenizer" )
41-
4241 for text in args .text :
4342 output = inference (
4443 model .unwrap (),
You can’t perform that action at this time.
0 commit comments