Commit 36c4bb2

[Fix] Grok-1 use tokenizer from the same pretrained path (#5532)
* [fix] use the tokenizer from the same pretrained path as the model
* pass trust_remote_code when loading it
1 parent 00525f7 commit 36c4bb2
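
In short, both example scripts now load the tokenizer from the same pretrained path as the model instead of the separate Xenova/grok-1-tokenizer checkpoint, keeping tokenizer and weights in sync. A minimal sketch of the new pattern, not the full example scripts; "path/to/grok-1" is a hypothetical placeholder for the scripts' args.pretrained and should point at a checkpoint that ships both model weights and tokenizer files:

# Sketch of the pattern this commit adopts (placeholder path, not the real script).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

pretrained = "path/to/grok-1"  # hypothetical; stands in for args.pretrained

# Load the tokenizer from the same pretrained path as the model so the two
# always match; trust_remote_code=True lets transformers use any custom
# tokenizer/model classes bundled with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    pretrained,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model.eval()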

2 files changed: +8 −10 lines

examples/language/grok-1/inference.py

Lines changed: 4 additions & 5 deletions
@@ -1,14 +1,17 @@
 import time
 
 import torch
-from transformers import AutoModelForCausalLM, LlamaTokenizerFast
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils import get_defualt_parser, inference, print_output
 
 if __name__ == "__main__":
     parser = get_defualt_parser()
     args = parser.parse_args()
     start = time.time()
     torch.set_default_dtype(torch.bfloat16)
+
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
+
     model = AutoModelForCausalLM.from_pretrained(
         args.pretrained,
         trust_remote_code=True,
@@ -18,10 +21,6 @@
     model.eval()
     init_time = time.time() - start
 
-    # A transformers-compatible version of the grok-1 tokenizer by Xenova
-    # https://huggingface.co/Xenova/grok-1-tokenizer
-    tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
-
     for text in args.text:
         output = inference(
             model,

examples/language/grok-1/inference_tp.py

Lines changed: 4 additions & 5 deletions
@@ -2,7 +2,7 @@
 
 import torch
 from grok1_policy import Grok1ForCausalLMPolicy
-from transformers import AutoModelForCausalLM, LlamaTokenizerFast
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils import get_defualt_parser, inference, print_output
 
 import colossalai
@@ -27,6 +27,9 @@
     )
     booster = Booster(plugin=plugin)
     torch.set_default_dtype(torch.bfloat16)
+
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
+
     with LazyInitContext(default_device=get_current_device()):
         model = AutoModelForCausalLM.from_pretrained(
             args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
@@ -35,10 +38,6 @@
     model.eval()
     init_time = time.time() - start
 
-    # A transformers-compatible version of the grok-1 tokenizer by Xenova
-    # https://huggingface.co/Xenova/grok-1-tokenizer
-    tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
-
     for text in args.text:
         output = inference(
             model.unwrap(),
