Skip to content

Commit aaa924f

Browse files
authored
Fixes in JsonSchema oneOf parsing, better multilingual testing (#144)
* Switching tokenizer to a larger, more modern one
* Fixing edge case in oneOf parsing
* Added missing import
1 parent f649926 commit aaa924f

File tree

5 files changed

+62
-26
lines changed

5 files changed

+62
-26
lines changed

lmformatenforcer/characterlevelparser.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,11 @@ def get_allowed_characters(self) -> str:
127127
def can_end(self) -> bool:
128128
return any([parser.can_end() for parser in self.parsers])
129129

130-
def shortcut_key(self) -> Optional[str]:
131-
return self.parsers[0].shortcut_key() if len(self.parsers) == 1 else None
130+
def shortcut_key(self) -> Optional[Hashable]:
131+
unique_shortcut_keys = set(parser.shortcut_key() for parser in self.parsers)
132+
if len(unique_shortcut_keys) == 1:
133+
return next(iter(unique_shortcut_keys))
134+
return None
132135

133136
def cache_key(self) -> Optional[Hashable]:
134137
all_cache_keys = tuple(parser.cache_key() for parser in self.parsers)

lmformatenforcer/jsonschemaparser.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ def __init__(self, root: JsonSchemaParser):
160160

161161

162162
def _merge_object_schemas(base_schema: JsonSchemaObject, option_schema: JsonSchemaObject) -> JsonSchemaObject:
163-
for property_name, property_value in base_schema.properties.items():
163+
base_schema_properties = base_schema.properties or {}
164+
for property_name, property_value in base_schema_properties.items():
164165
# We assume that if a property exists in both base and option, the option version will be
165166
# more specific, therefore we only take missing entries
166167
if property_name not in option_schema.properties:
@@ -201,13 +202,13 @@ def get_parser(
201202
max_length=value_schema.maxLength,
202203
pattern=value_schema.pattern,
203204
)
205+
if value_schema.oneOf:
206+
# We create a combined object schema for each option that includes the information from the parent
207+
# And then create a UnionParser based on the combined options
208+
merged_schemas = [_merge_object_schemas(value_schema, option_schema) for option_schema in value_schema.oneOf]
209+
object_parsing_options = [ObjectParsingState(merged_schema, parsing_state) for merged_schema in merged_schemas]
210+
return UnionParser(object_parsing_options)
204211
elif value_schema.type == "object":
205-
if value_schema.oneOf:
206-
# We create a combined object schema for each option that includes the information from the parent
207-
# And then create a UnionParser based on the combined options
208-
merged_schemas = [_merge_object_schemas(value_schema, option_schema) for option_schema in value_schema.oneOf]
209-
object_parsing_options = [ObjectParsingState(merged_schema, parsing_state) for merged_schema in merged_schemas]
210-
return UnionParser(object_parsing_options)
211212
return ObjectParsingState(value_schema, parsing_state)
212213
elif value_schema.type == None and value_schema.ref:
213214
value_class_name = value_schema.ref.split('/')[-1]

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ types-setuptools = "68.1.0.1"
5656
[tool.poetry.group.tests.dependencies]
5757
pytest = {version = "6.2.5", python = ">=3.8"}
5858
coverage = {version = "^7.3.1", python = ">=3.8", extras = ["toml"]}
59-
transformers = ">=4.28.1"
59+
transformers = ">=4.37.0"
6060
torch = {version = "^2.1.0+cpu", source = "pytorch"}
6161
numpy = "^1.21.0"
6262

6363
[tool.poetry.group.samples.dependencies]
6464
Flask = {version = "2.3.2", python = ">=3.8"}
65-
transformers = ">=4.28.1"
65+
transformers = ">=4.37.0"
6666
tokenizers = ">=0.13.3"
6767

6868

tests/common.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from pstats import Stats
33
from typing import Optional
44
from transformers import AutoTokenizer, PreTrainedTokenizerBase
5-
65
from lmformatenforcer import CharacterLevelParser
76
from lmformatenforcer.exceptions import LMFormatEnforcerException
87
from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
98
from lmformatenforcer.integrations.transformers import build_token_enforcer_tokenizer_data
10-
9+
import logging
10+
1111

1212
_tokenizer: Optional[PreTrainedTokenizerBase] = None
1313
_tokenizer_data: Optional[TokenEnforcerTokenizerData] = None
@@ -40,10 +40,19 @@ def assert_parser_with_string_direct(string: str, parser: CharacterLevelParser,
4040
def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevelParser, expect_success: bool, profile_file_path: Optional[str]):
4141
global _tokenizer
4242
if _tokenizer is None:
43-
model_id = 'TheBloke/Llama-2-7b-Chat-GPTQ'
43+
model_id = 'Qwen/Qwen2.5-72B-Instruct'
4444
_tokenizer = AutoTokenizer.from_pretrained(model_id)
45-
45+
4646
global _tokenizer_data
47+
48+
# For testing, we make sure that all letters exist individually in the tokenizer
49+
encoded_0 = _tokenizer.encode("0")
50+
for word in set(string):
51+
encoded_word = _tokenizer.encode(word)
52+
if len(encoded_word) > len(encoded_0):
53+
logging.basicConfig(level=logging.INFO)
54+
logging.warning("Encountered out-of-tokenizer character, LMFE does not deal with this well")
55+
4756
if _tokenizer_data is None:
4857
_tokenizer_data = build_token_enforcer_tokenizer_data(_tokenizer)
4958

tests/test_jsonschemaparser.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,37 @@
66
from lmformatenforcer import JsonSchemaParser
77
from enum import Enum
88
import pytest
9-
from lmformatenforcer.consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES, CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH
9+
from lmformatenforcer.characterlevelparser import CharacterLevelParserConfig
10+
from lmformatenforcer.consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, COMPLETE_ALPHABET, CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES, CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH
1011

1112
from .common import assert_parser_with_string, CharacterNotAllowedException
1213

1314

14-
def _test_json_schema_parsing_with_string(string: str, schema_dict: Optional[dict], expect_success: bool, profile_file_path: Optional[str] = None):
15-
parser = JsonSchemaParser(schema_dict)
15+
def _test_json_schema_parsing_with_string(string: str,
16+
schema_dict: Optional[dict],
17+
expect_success: bool,
18+
profile_file_path: Optional[str] = None,
19+
ensure_ascii_in_json_dumps: bool = False):
20+
alphabet = COMPLETE_ALPHABET
21+
for letter in set(string):
22+
if letter not in alphabet and letter != '\n':
23+
alphabet += letter
24+
if expect_success:
25+
try:
26+
minified = json.dumps(json.loads(string), separators=(',', ':'), ensure_ascii=False)
27+
for letter in set(minified):
28+
if letter not in alphabet and letter != '\n':
29+
alphabet += letter
30+
except:
31+
pass
32+
config = CharacterLevelParserConfig(alphabet=alphabet)
33+
parser = JsonSchemaParser(schema_dict, config=config)
1634
assert_parser_with_string(string, parser, expect_success, profile_file_path)
1735
if expect_success:
1836
# If expecting success, also check minified and pretty-printed
19-
minified = json.dumps(json.loads(string), separators=(',', ':'))
37+
minified = json.dumps(json.loads(string), separators=(',', ':'), ensure_ascii=ensure_ascii_in_json_dumps)
2038
assert_parser_with_string(minified, parser, expect_success)
21-
pretty_printed = json.dumps(json.loads(string), indent=2)
39+
pretty_printed = json.dumps(json.loads(string), indent=2, ensure_ascii=ensure_ascii_in_json_dumps)
2240
assert_parser_with_string(pretty_printed, parser, expect_success)
2341

2442

@@ -190,22 +208,22 @@ class ListOfNoMinLengthModel(BaseModel):
190208
def test_string_escaping():
191209
for escaping_character in BACKSLASH_ESCAPING_CHARACTERS:
192210
test_string = f'{{"num":1,"message":"hello {BACKSLASH}{escaping_character} world"}}'
193-
_test_json_schema_parsing_with_string(test_string, SampleModel.model_json_schema(), True)
211+
_test_json_schema_parsing_with_string(test_string, SampleModel.model_json_schema(), True, ensure_ascii_in_json_dumps=True)
194212
for non_escaping_character in 'a1?':
195213
test_string = f'{{"num":1,"message":"hello {BACKSLASH}{non_escaping_character} world"}}'
196-
_test_json_schema_parsing_with_string(test_string, SampleModel.model_json_schema(), False)
214+
_test_json_schema_parsing_with_string(test_string, SampleModel.model_json_schema(), False, ensure_ascii_in_json_dumps=True)
197215

198216
# Unicode
199217
test_string = f'{{"num":1,"message":"hello {BACKSLASH}uf9f0 world"}}'
200-
_test_json_schema_parsing_with_string(test_string, SampleModel.model_json_schema(), True)
218+
_test_json_schema_parsing_with_string(test_string, SampleModel.model_json_schema(), True, ensure_ascii_in_json_dumps=True)
201219

202220
# Not enough unicode digits
203221
test_string = f'{{"num":1,"message":"hello {BACKSLASH}uf9f world"}}'
204-
_test_json_schema_parsing_with_string(test_string, SampleModel.model_json_schema(), False)
222+
_test_json_schema_parsing_with_string(test_string, SampleModel.model_json_schema(), False, ensure_ascii_in_json_dumps=True)
205223

206224
# Unicode digit outside of hex range
207225
test_string = f'{{"num":1,"message":"hello {BACKSLASH}uf9fP world"}}'
208-
_test_json_schema_parsing_with_string(test_string, SampleModel.model_json_schema(), False)
226+
_test_json_schema_parsing_with_string(test_string, SampleModel.model_json_schema(), False, ensure_ascii_in_json_dumps=True)
209227

210228

211229
def test_comma_after_all_object_keys_fails():
@@ -774,4 +792,9 @@ def test_invalid_number_formats_with_leading_zeros(test_input):
774792
('{"value": -9007199254740992}', True),
775793
])
776794
def test_number_edge_cases(test_input, expected_success):
777-
_test_json_schema_parsing_with_string(test_input, schema, expected_success)
795+
_test_json_schema_parsing_with_string(test_input, schema, expected_success)
796+
797+
def test_chinese_oneof_schema():
798+
test_schema = { "$schema": "http://json-schema.org/draft-07/schema#", "type": "array", "items": { "oneOf": [ { "type": "object", "properties": { "trigger": { "type": "string" }, "event_type": { "enum": [ "公司上市" ] }, "arguments": { "type": "array", "items": { "type": "object", "properties": { "role": { "enum": [ "上市公司", "证券代码", "环节", "披露时间", "发行价格", "事件时间", "市值", "募资金额" ] }, "argument": { "type": "string" } }, "required": [ "role", "argument" ] } } }, "required": [ "trigger", "event_type", "arguments" ] }, { "type": "object", "properties": { "trigger": { "type": "string" }, "event_type": { "enum": [ "被约谈" ] }, "arguments": { "type": "array", "items": { "type": "object", "properties": { "role": { "enum": [ "公司名称", "披露时间", "被约谈时间", "约谈机构" ] }, "argument": { "type": "string" } }, "required": [ "role", "argument" ] } } }, "required": [ "trigger", "event_type", "arguments" ] } ] } }
799+
correct_output = """[{"trigger": "IPO", "event_type": "公司上市", "arguments": [{"role": "上市公司", "argument": "理想汽车"}, {"role": "披露时间", "argument": "30日"}, {"role": "发行价格", "argument": "8-10美元"}, {"role": "环节", "argument": "筹备上市"}]}]"""
800+
_test_json_schema_parsing_with_string(correct_output, test_schema, True)

0 commit comments

Comments
 (0)