From b1f70935b2a3626a7a862e695eff90ec286d367d Mon Sep 17 00:00:00 2001 From: Matthew Wootten Date: Fri, 14 Jun 2024 18:59:54 -0400 Subject: [PATCH] Ensure regex matches valid JSON for "const" and "enum" with booleans, nulls, and strings Use JSON's serialization, rather than Python's, in this case. Fixes #971. Note that this still does not correctly handle arrays and objects, which are allowed by the JSON schema spec; however, those would be more complex to handle correctly. --- outlines/fsm/json_schema.py | 17 ++++++------- tests/fsm/test_json_schema.py | 47 ++++++++++++++++++++++++++++++++--- 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 810ef5910..e331d2ff7 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -269,19 +269,18 @@ def to_regex( elif "enum" in instance: choices = [] for choice in instance["enum"]: - if type(choice) in [int, float, bool, None]: - choices.append(re.escape(str(choice))) - elif type(choice) == str: - choices.append(f'"{re.escape(choice)}"') - + if type(choice) in [int, float, bool, type(None), str]: + choices.append(re.escape(json.dumps(choice))) + else: + raise TypeError(f"Unsupported data type in enum: {type(choice)}") return f"({'|'.join(choices)})" elif "const" in instance: const = instance["const"] - if type(const) in [int, float, bool, None]: - const = re.escape(str(const)) - elif type(const) == str: - const = f'"{re.escape(const)}"' + if type(const) in [int, float, bool, type(None), str]: + const = re.escape(json.dumps(const)) + else: + raise TypeError(f"Unsupported data type in const: {type(const)}") return const elif "$ref" in instance: diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index f2cc4115b..ef3071f8a 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -177,29 +177,53 @@ def test_match_number(pattern, does_match): '"Marc"', [('"Marc"', True), ('"Jean"', False), ('"John"', False)], ), - # Make sure strings are escaped + # Make sure strings are escaped with regex escaping ( {"title": "Foo", "const": ".*", "type": "string"}, r'"\.\*"', [('".*"', True), (r'"\s*"', False), (r'"\.\*"', False)], ), + # Make sure strings are escaped with JSON escaping + ( + {"title": "Foo", "const": '"', "type": "string"}, + r'"\\""', + [('"\\""', True), ('"""', False)], + ), # Const integer ( {"title": "Foo", "const": 0, "type": "integer"}, "0", [("0", True), ("1", False), ("a", False)], ), + # Const float + ( + {"title": "Foo", "const": 0.2, "type": "float"}, + r"0\.2", + [("0.2", True), ("032", False)], + ), + # Const boolean + ( + {"title": "Foo", "const": True, "type": "boolean"}, + "true", + [("true", True), ("True", False)], + ), + # Const null + ( + {"title": "Foo", "const": None, "type": "null"}, + "null", + [("null", True), ("None", False), ("", False)], + ), # Enum string ( {"title": "Foo", "enum": ["Marc", "Jean"], "type": "string"}, '("Marc"|"Jean")', [('"Marc"', True), ('"Jean"', True), ('"John"', False)], ), - # Make sure strings are escaped + # Make sure strings are escaped with regex and JSON escaping ( {"title": "Foo", "enum": [".*", r"\s*"], "type": "string"}, - r'("\.\*"|"\\s\*")', - [('".*"', True), (r'"\s*"', True), (r'"\.\*"', False)], + r'("\.\*"|"\\\\s\*")', + [('".*"', True), (r'"\\s*"', True), (r'"\.\*"', False)], ), # Enum integer ( @@ -207,6 +231,21 @@ def test_match_number(pattern, does_match): "(0|1)", [("0", True), ("1", True), ("a", False)], ), + # Enum mix of types + ( + {"title": "Foo", "enum": [6, 5.3, "potato", True, None]}, + r'(6|5\.3|"potato"|true|null)', + [ + ("6", True), + ("5.3", True), + ('"potato"', True), + ("true", True), + ("null", True), + ("523", False), + ("True", False), + ("None", False), + ], + ), # integer ( {