feat: implement Tokenizer API, a simple wrapper over HuggingFace-style tokenizers #5813

Open · wants to merge 6 commits into base branch main
75 changes: 74 additions & 1 deletion python/sglang/srt/entrypoints/http_server.py
@@ -40,7 +40,7 @@
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse

from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.function_call_parser import FunctionCallParser
@@ -605,6 +605,79 @@ async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("bat
)


@app.post("/tokenize", response_class=ORJSONResponse)
async def tokenize(request: Request):
"""Tokenize text using the model's tokenizer."""
try:
data = await request.json()
text = data.get("text", "")
if not isinstance(text, str):
return JSONResponse(
status_code=400,
content={"error": "The 'text' field must be a string"},
)

tokenizer = _global_state.tokenizer_manager.tokenizer
token_ids = tokenizer.encode(text)

return {"tokens": token_ids, "count": len(token_ids)}
except Exception as e:
return JSONResponse(
status_code=500,
content={"error": f"Failed to tokenize: {str(e)}"},
)
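
For reference, a minimal client call against the endpoint above might look like the following sketch. The server address is an assumption; the response keys ("tokens", "count") are taken from the handler.

import requests

# Assumed local server address; adjust to your deployment.
BASE_URL = "http://127.0.0.1:30000"

resp = requests.post(f"{BASE_URL}/tokenize", json={"text": "Hello, world!"})
resp.raise_for_status()
payload = resp.json()
# "tokens" is the list of token IDs, "count" its length (see the handler above).
print(payload["tokens"], payload["count"])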


@app.post("/detokenize", response_class=ORJSONResponse)
async def detokenize(request: Request):
"""Detokenize token IDs using the model's tokenizer."""
try:
data = await request.json()
tokens = data.get("tokens", [])
if not isinstance(tokens, list):
return JSONResponse(
status_code=400,
content={"error": "The 'tokens' field must be a list of integers"},
)

for token in tokens:
if not isinstance(token, int):
return JSONResponse(
status_code=400,
content={"error": "All tokens must be integers"},
)

tokenizer = _global_state.tokenizer_manager.tokenizer
text = tokenizer.decode(tokens)

special_tokens = [
"<|begin_of_text|>",
"<|endoftext|>",
"<s>",
"</s>",
"<pad>",
"[CLS]",
"[SEP]",
"[PAD]",
"[MASK]",
"<bos>",
"<eos>",
]

        # Strip known special tokens unless the caller asks to keep them.
        if not data.get("keep_special_tokens", False):
            for token in special_tokens:
                text = text.replace(token, "")

return {"text": text}
except Exception as e:
return JSONResponse(
status_code=500,
content={"error": f"Failed to detokenize: {str(e)}"},
)
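
The handler above strips a hard-coded list of common special-token strings unless keep_special_tokens is set, which keeps the endpoint independent of any particular tokenizer class. If the wrapped tokenizer is a standard HuggingFace tokenizer, a possible alternative (a sketch, not what this PR does) is to let decode drop its own special tokens; the model name below is purely illustrative.

# Sketch only: assumes a HuggingFace tokenizer backend; "gpt2" is illustrative.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
token_ids = tok.encode("Hello, world!")
# skip_special_tokens makes decode drop the tokenizer's own special tokens,
# which would replace the hard-coded replacement list used in the handler.
print(tok.decode(token_ids, skip_special_tokens=True))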


@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
# https://platform.openai.com/docs/api-reference/files/delete
51 changes: 51 additions & 0 deletions test/srt/test_openai_server.py
@@ -10,6 +10,7 @@
import unittest

import openai
import requests

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
@@ -549,6 +550,56 @@ def test_model_list(self):
assert len(models) == 1
assert isinstance(getattr(models[0], "max_model_len", None), int)

def test_tokenizer_endpoints(self):
"""Test the tokenizer endpoints."""
# Test /tokenize endpoint
text = "Hello, world! This is a test of the tokenizer API."
response = requests.post(
f"{self.base_url.replace('/v1', '')}/tokenize", json={"text": text}
)
assert response.status_code == 200, f"Failed with: {response.text}"
data = response.json()
assert "tokens" in data
assert "count" in data
assert isinstance(data["tokens"], list)
assert data["count"] == len(data["tokens"])

# Verify tokens are correct by comparing with local tokenizer
expected_tokens = self.tokenizer.encode(text)
assert data["tokens"] == expected_tokens

# Test /detokenize endpoint
tokens = data["tokens"]
response = requests.post(
f"{self.base_url.replace('/v1', '')}/detokenize", json={"tokens": tokens}
)
assert response.status_code == 200, f"Failed with: {response.text}"
data = response.json()
assert "text" in data
assert data["text"].strip() == text.strip()

response = requests.post(
f"{self.base_url.replace('/v1', '')}/detokenize",
json={"tokens": tokens, "keep_special_tokens": True},
)
assert response.status_code == 200, f"Failed with: {response.text}"
data_with_special = response.json()
assert "text" in data_with_special

# Test with empty inputs
response = requests.post(
f"{self.base_url.replace('/v1', '')}/tokenize", json={"text": ""}
)
assert response.status_code == 200, f"Failed with: {response.text}"
empty_tokens = self.tokenizer.encode("")
assert response.json()["count"] == len(empty_tokens)

response = requests.post(
f"{self.base_url.replace('/v1', '')}/detokenize", json={"tokens": []}
)
assert response.status_code == 200, f"Failed with: {response.text}"
assert response.json()["text"] == ""
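
The new routes are registered at the server root (no /v1 prefix), which is why the test strips '/v1' from base_url before calling them. A standalone sketch of the same round trip, with the server address as an assumption:

import requests

BASE_URL = "http://127.0.0.1:30000"  # assumed server root, no /v1 prefix

tokens = requests.post(f"{BASE_URL}/tokenize", json={"text": "Hello, world!"}).json()["tokens"]

# Special tokens are stripped by default; keep_special_tokens=True retains them.
plain = requests.post(f"{BASE_URL}/detokenize", json={"tokens": tokens}).json()["text"]
raw = requests.post(
    f"{BASE_URL}/detokenize",
    json={"tokens": tokens, "keep_special_tokens": True},
).json()["text"]
print(plain)
print(raw)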


# -------------------------------------------------------------------------
# EBNF Test Class: TestOpenAIServerEBNF
207 changes: 207 additions & 0 deletions test/srt/test_tokenizer_api.py
@@ -0,0 +1,207 @@
"""
Test the tokenizer API endpoints.

Run with:
python3 -m unittest sglang/test/srt/test_tokenizer_api.py
or directly:
python3 sglang/test/srt/test_tokenizer_api.py
"""

import unittest

import requests

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)


class TestTokenizerAPI(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = None
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
)
cls.tokenizer = get_tokenizer(cls.model)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def test_tokenize_simple(self):
"""Test tokenizing a simple string."""
text = "Hello, world!"
response = requests.post(f"{self.base_url}/tokenize", json={"text": text})

self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
data = response.json()

# Compare with local tokenizer
expected_tokens = self.tokenizer.encode(text)

self.assertIn("tokens", data)
self.assertIsInstance(data["tokens"], list)
self.assertEqual(data["tokens"], expected_tokens)

self.assertIn("count", data)
self.assertEqual(data["count"], len(expected_tokens))

def test_tokenize_empty(self):
"""Test tokenizing an empty string."""
text = ""
response = requests.post(f"{self.base_url}/tokenize", json={"text": text})
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
data = response.json()

expected_tokens = self.tokenizer.encode(text)

self.assertIn("tokens", data)
self.assertEqual(data["tokens"], expected_tokens)
self.assertEqual(data["count"], len(expected_tokens))

def test_tokenize_long(self):
"""Test tokenizing a longer text."""
text = "This is a longer text that should be tokenized properly. " * 10
response = requests.post(f"{self.base_url}/tokenize", json={"text": text})
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
data = response.json()

expected_tokens = self.tokenizer.encode(text)

self.assertIn("tokens", data)
self.assertEqual(data["tokens"], expected_tokens)
self.assertEqual(data["count"], len(expected_tokens))

def test_tokenize_invalid(self):
"""Test tokenizing with invalid input."""
response = requests.post(
f"{self.base_url}/tokenize", json={"text": 123} # Not a string
)
self.assertEqual(response.status_code, 400)
self.assertIn("error", response.json())

def test_detokenize_simple(self):
"""Test detokenizing a simple token list."""
text = "Hello, world!"
tokens = self.tokenizer.encode(text)

response = requests.post(f"{self.base_url}/detokenize", json={"tokens": tokens})
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
data = response.json()

self.assertIn("text", data)
self.assertEqual(data["text"].strip(), text.strip())

def test_detokenize_empty(self):
"""Test detokenizing an empty token list."""
tokens = []

response = requests.post(f"{self.base_url}/detokenize", json={"tokens": tokens})
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
data = response.json()

self.assertIn("text", data)
self.assertEqual(data["text"], "")

def test_detokenize_invalid_format(self):
"""Test detokenizing with invalid input format."""
response = requests.post(
f"{self.base_url}/detokenize", json={"tokens": "not_a_list"}
)
self.assertEqual(response.status_code, 400)
self.assertIn("error", response.json())

def test_detokenize_invalid_token(self):
"""Test detokenizing with invalid token type."""
response = requests.post(
f"{self.base_url}/detokenize", json={"tokens": [1, 2, "not_an_int", 4]}
)
self.assertEqual(response.status_code, 400)
self.assertIn("error", response.json())

def test_detokenize_keep_special_tokens(self):
"""Test detokenizing with the option to keep special tokens."""
text = "Hello, world!"
tokens = self.tokenizer.encode(text)

response = requests.post(f"{self.base_url}/detokenize", json={"tokens": tokens})
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
data_without_special = response.json()

response = requests.post(
f"{self.base_url}/detokenize",
json={"tokens": tokens, "keep_special_tokens": True},
)
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
data_with_special = response.json()

self.assertIn("text", data_with_special)
self.assertIsInstance(data_with_special["text"], str)

if data_with_special["text"] != data_without_special["text"]:
special_tokens = [
"<|begin_of_text|>",
"<|endoftext|>",
"<s>",
"</s>",
"<pad>",
"[CLS]",
"[SEP]",
"[PAD]",
"[MASK]",
"<bos>",
"<eos>",
]
has_special_token = any(
token in data_with_special["text"] for token in special_tokens
)
self.assertTrue(
has_special_token,
f"Expected special tokens in: {data_with_special['text']}",
)

def test_roundtrip(self):
"""Test tokenize followed by detokenize roundtrip."""
original_text = "This is a test of the tokenizer API roundtrip functionality."

# First tokenize
tokenize_response = requests.post(
f"{self.base_url}/tokenize", json={"text": original_text}
)
self.assertEqual(
tokenize_response.status_code, 200, f"Failed with: {tokenize_response.text}"
)
tokens = tokenize_response.json()["tokens"]

# Then detokenize
detokenize_response = requests.post(
f"{self.base_url}/detokenize", json={"tokens": tokens}
)
self.assertEqual(
detokenize_response.status_code,
200,
f"Failed with: {detokenize_response.text}",
)
reconstructed_text = detokenize_response.json()["text"]

# Compare original and reconstructed text (ignore any special tokens)
self.assertEqual(reconstructed_text.strip(), original_text.strip())


if __name__ == "__main__":
unittest.main()