From 13f56c117925392adba4cc66e87f4a6965e2bcfa Mon Sep 17 00:00:00 2001
From: Toshiki Kataoka
Date: Tue, 24 Dec 2024 10:14:23 +0900
Subject: [PATCH] test: BOS occurrence in formatted conversation

Signed-off-by: Toshiki Kataoka
---
 tests/examples_tests/__init__.py   |  0
 tests/examples_tests/test_jinja.py | 64 ++++++++++++++++++++++++++++++
 2 files changed, 64 insertions(+)
 create mode 100644 tests/examples_tests/__init__.py
 create mode 100644 tests/examples_tests/test_jinja.py

diff --git a/tests/examples_tests/__init__.py b/tests/examples_tests/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/examples_tests/test_jinja.py b/tests/examples_tests/test_jinja.py
new file mode 100644
index 0000000000000..78081a6e895f7
--- /dev/null
+++ b/tests/examples_tests/test_jinja.py
@@ -0,0 +1,64 @@
+"""Check BOS placement in the example chat templates."""
+from pathlib import Path
+
+import pytest
+import transformers
+
+# Anchor on ``__file__`` (not the dotted ``__name__``, which is not a
+# filesystem path) so that collection does not depend on the directory pytest
+# is started from.
+# NOTE(review): assumes this file stays at
+# <repo>/tests/examples_tests/test_jinja.py and that the templates live in
+# <repo>/examples -- confirm if either moves.
+EXAMPLES_DIR = Path(__file__).resolve().parents[2] / "examples"
+
+jinja_paths = [
+    pytest.param(path, id=path.stem)
+    for path in sorted(EXAMPLES_DIR.glob("*.jinja"))
+]
+
+
+@pytest.mark.parametrize("path", jinja_paths)
+@pytest.mark.parametrize("num_messages", [1, 3])
+def test_bos(path: Path, num_messages: int) -> None:
+    """Assert that a rendered conversation starts with exactly one BOS token.
+
+    Args:
+        path: Chat template file to render.
+        num_messages: Number of leading messages of the sample conversation
+            to render (1 = single user turn, 3 = user/assistant/user).
+    """
+    chat_template = path.read_text(encoding="utf-8")
+    # We might guess an appropriate tokenizer model from the file name but we
+    # don't maintain such list.
+    # Use arbitrary BOS for testing. It doesn't have to match the str in the
+    # correct tokenizer.
+    bos_token = "=BOS="
+    tokenizer = transformers.PreTrainedTokenizerBase(
+        chat_template=chat_template, bos_token=bos_token, eos_token="=EOS=")
+    conversation = [
+        {
+            "role": "user",
+            "content": "1"
+        },
+        {
+            "role": "assistant",
+            "content": "2"
+        },
+        {
+            "role": "user",
+            "content": "3"
+        },
+    ][:num_messages]
+    try:
+        prompt: str = tokenizer.apply_chat_template(conversation=conversation,
+                                                    tokenize=False)
+    except Exception as e:
+        # A template may refuse the conversation with this exact message;
+        # treat that as "not applicable" and skip. Anything else is re-raised
+        # so it shows up as a real failure.
+        if str(e) == "Embedding models should only embed one message at a time":
+            pytest.skip(reason=str(e))
+        raise
+    assert prompt.startswith(bos_token)
+    assert prompt.count(bos_token) == 1