Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add query parsing unit tests #3672

Merged
merged 2 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion langchain/chains/query_constructor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def parse(self, text: str) -> StructuredQuery:
parsed = parse_json_markdown(text, expected_keys)
if len(parsed["query"]) == 0:
parsed["query"] = " "
if parsed["filter"] == "NO_FILTER":
if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
parsed["filter"] = None
else:
parsed["filter"] = self.ast_parse(parsed["filter"])
Expand Down
21 changes: 14 additions & 7 deletions langchain/chains/query_constructor/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,21 @@

func_call: CNAME "(" [args] ")"

?value: SIGNED_NUMBER -> number
?value: SIGNED_INT -> int
| SIGNED_FLOAT -> float
| list
| string
| "false" -> false
| "true" -> true
| ("false" | "False" | "FALSE") -> false
| ("true" | "True" | "TRUE") -> true

args: expr ("," expr)*
string: ESCAPED_STRING
string: /'[^']*'/ | ESCAPED_STRING
list: "[" [args] "]"

%import common.CNAME
%import common.SIGNED_NUMBER
%import common.ESCAPED_STRING
%import common.SIGNED_FLOAT
%import common.SIGNED_INT
%import common.WS
%ignore WS
"""
Expand All @@ -44,7 +46,7 @@ def __init__(
self,
*args: Any,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]],
allowed_operators: Optional[Sequence[Operator]] = None,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -93,9 +95,14 @@ def true(self) -> bool:
return True

def list(self, item: Any) -> list:
if item is None:
return []
return list(item)

def number(self, item: Any) -> float:
def int(self, item: Any) -> int:
return int(item)

def float(self, item: Any) -> float:
return float(item)

def string(self, item: Any) -> str:
Expand Down
2 changes: 1 addition & 1 deletion langchain/chains/query_constructor/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
{{
"query": "teenager love",
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\gg\"))"
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))"
}}"""

NO_FILTER_ANSWER = """\
Expand Down
Empty file.
116 changes: 116 additions & 0 deletions tests/unit_tests/chains/query_constructor/test_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Test LLM-generated structured query parsing."""
from typing import Any, Tuple, cast

import pytest

from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
)
from langchain.chains.query_constructor.parser import get_parser

# Parser built with get_parser()'s default arguments; the tests below that do
# not need custom allowed comparators/operators use this module-level instance.
DEFAULT_PARSER = get_parser()


@pytest.mark.parametrize("x", ["", "foo", 'foo("bar", "baz")'])
def test_parse_invalid_grammar(x: str) -> None:
    """Each parametrized input is expected to make ``DEFAULT_PARSER.parse`` raise."""
    with pytest.raises(Exception):
        DEFAULT_PARSER.parse(x)


def test_parse_comparison() -> None:
    """Check that spelling variants of one comparison parse to the same result.

    The variants differ only in quote style, whitespace, and writing the
    number as ``2.0`` instead of ``2``; every one must equal the expected
    ``Comparison``.
    """
    comp = 'gte("foo", 2)'
    expected = Comparison(comparator=Comparator.GTE, attribute="foo", value=2)
    variants = (
        comp,
        comp.replace('"', "'"),
        comp.replace(" ", ""),
        # Double every space; replacing a space with a single space is a no-op
        # and would only repeat the first variant.
        comp.replace(" ", "  "),
        comp.replace("(", " ("),
        comp.replace(",", ", "),
        comp.replace("2", "2.0"),
    )
    # Loop variable is ``text`` rather than ``input`` so the builtin is not shadowed.
    for text in variants:
        actual = DEFAULT_PARSER.parse(text)
        # Name the offending variant so a failure is easy to locate.
        assert expected == actual, f"unexpected parse result for {text!r}"


def test_parse_operation() -> None:
    """Check that spelling variants of one ``and`` operation parse identically.

    The variants differ only in quote style, whitespace, and a trailing zero
    on the float literal; every one must equal the expected ``Operation``
    built from an ``eq`` and an ``lt`` comparison.
    """
    op = 'and(eq("foo", "bar"), lt("baz", 1995.25))'
    eq = Comparison(comparator=Comparator.EQ, attribute="foo", value="bar")
    lt = Comparison(comparator=Comparator.LT, attribute="baz", value=1995.25)
    expected = Operation(operator=Operator.AND, arguments=[eq, lt])
    variants = (
        op,
        op.replace('"', "'"),
        op.replace(" ", ""),
        # Double every space; replacing a space with a single space is a no-op
        # and would only repeat the first variant.
        op.replace(" ", "  "),
        op.replace("(", " ("),
        op.replace(",", ", "),
        op.replace("25", "250"),
    )
    # Loop variable is ``text`` rather than ``input`` so the builtin is not shadowed.
    for text in variants:
        actual = DEFAULT_PARSER.parse(text)
        # Name the offending variant so a failure is easy to locate.
        assert expected == actual, f"unexpected parse result for {text!r}"


def test_parse_nested_operation() -> None:
    """An ``and`` wrapping an ``or`` and a ``not`` must parse into nested Operations."""
    op = 'and(or(eq("a", "b"), eq("a", "c"), eq("a", "d")), not(eq("z", "foo")))'
    # The three alternatives inside ``or`` all compare attribute "a".
    or_args = [
        Comparison(comparator=Comparator.EQ, attribute="a", value=letter)
        for letter in ("b", "c", "d")
    ]
    negated = Comparison(comparator=Comparator.EQ, attribute="z", value="foo")
    expected = Operation(
        operator=Operator.AND,
        arguments=[
            Operation(operator=Operator.OR, arguments=or_args),
            Operation(operator=Operator.NOT, arguments=[negated]),
        ],
    )
    assert expected == DEFAULT_PARSER.parse(op)


def test_parse_disallowed_comparator() -> None:
    """Parsing a ``gt`` comparison must raise ValueError when only ``eq`` is allowed."""
    eq_only = get_parser(allowed_comparators=[Comparator.EQ])
    with pytest.raises(ValueError):
        eq_only.parse('gt("a", 2)')


def test_parse_disallowed_operator() -> None:
    """Parsing a ``not`` operation must raise ValueError when only ``and`` is allowed."""
    and_only = get_parser(allowed_operators=[Operator.AND])
    with pytest.raises(ValueError):
        and_only.parse('not(gt("a", 2))')


def _test_parse_value(x: Any) -> None:
    """Parse ``eq("x", <x>)`` and assert the resulting comparison's value equals ``x``.

    ``x`` is interpolated into the expression with an f-string, so its ``str()``
    form is what the parser sees.
    """
    expression = f'eq("x", {x})'
    comparison = cast(Comparison, DEFAULT_PARSER.parse(expression))
    assert comparison.value == x


@pytest.mark.parametrize("x", [-1, 0, 1_000_000])
def test_parse_int_value(x: int) -> None:
    """Run the shared value check with negative, zero, and large integers."""
    _test_parse_value(x)


@pytest.mark.parametrize("x", [-1.001, 0.00000002, 1_234_567.6543210])
def test_parse_float_value(x: float) -> None:
    """Run the shared value check with negative, very small, and large floats."""
    _test_parse_value(x)


@pytest.mark.parametrize("x", [[], [1, "b", "true"]])
def test_parse_list_value(x: list) -> None:
    """Run the shared value check with an empty list and a mixed-type list."""
    _test_parse_value(x)


def test_parse_string_value() -> None:
for x in ('""', '" "', '"foo"', "'foo'"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's possible to add extra information after the assertion, e.g. `assert actual == x[1:-1], "details message"`, so that if one of the tests fails it's easy to determine which test case it is? (Or use `pytest.mark.parametrize`, which does that.)

parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
actual = parsed.value
assert actual == x[1:-1]


def test_parse_bool_value() -> None:
for y in ("true", "false"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(style) This is definitely a matter of style: I tend to hard-code the constants for all the test cases instead of relying on code, to keep the specification of the inputs into the test case as simple as possible.

for x in (y, y.upper(), y.title()):
parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
actual = parsed.value
expected = y == "true"
assert actual == expected