Skip to content

Commit

Permalink
[Fix] Fix clean_up_tokenization_spaces in tokenizer (#1510)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Sep 25, 2024
1 parent 067d8e1 commit fb2d068
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
1 change: 1 addition & 0 deletions python/sglang/srt/hf_transformers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def get_tokenizer(
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
+clean_up_tokenization_spaces=False,
**kwargs,
)
except TypeError as e:
Expand Down
9 changes: 3 additions & 6 deletions python/sglang/test/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@

import torch
import torch.nn.functional as F
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM

+from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server import Runtime
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER

Expand Down Expand Up @@ -92,11 +93,7 @@ def __init__(
self.model_proc.start()

def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
-self.tokenizer = AutoTokenizer.from_pretrained(
-model_path,
-torch_dtype=torch_dtype,
-)
-
+self.tokenizer = get_tokenizer(model_path)
if self.is_generation:
self.base_model = AutoModelForCausalLM.from_pretrained(
model_path,
Expand Down
6 changes: 4 additions & 2 deletions scripts/playground/reference_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
import argparse

import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM

+from sglang.srt.hf_transformers_utils import get_tokenizer


@torch.inference_mode()
def normal_text(args):
-t = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
+t = get_tokenizer(args.model_path, trust_remote_code=True)
m = AutoModelForCausalLM.from_pretrained(
args.model_path,
torch_dtype=torch.float16,
Expand Down
5 changes: 4 additions & 1 deletion test/srt/models/test_generation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
-from sglang.test.test_utils import calculate_rouge_l
+from sglang.test.test_utils import calculate_rouge_l, is_in_ci


@dataclasses.dataclass
Expand Down Expand Up @@ -132,6 +132,9 @@ def test_ci_models(self):
)

def test_others(self):
+if is_in_ci():
+return

for model_case in ALL_OTHER_MODELS:
if (
"ONLY_RUN" in os.environ
Expand Down

0 comments on commit fb2d068

Please sign in to comment.