Skip to content

Commit

Permalink
Allow the use of percent string in Bool.__and__ method
Browse files Browse the repository at this point in the history
  • Loading branch information
Godefroy-Amaury committed Jul 24, 2024
1 parent ea0a718 commit 7d2d368
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### Removed
### Fixed
- Fixed Search helper to ensure proper retention of the _collapse attribute in chained operations. ([#771](https://github.com/opensearch-project/opensearch-py/pull/771))
- Fixed the use of `minimum_should_match` with `Bool` to allow the use of string-based value (percent string, combination)
### Updated APIs
- Updated opensearch-py APIs to reflect [opensearch-api-specification@0b033a9](https://github.com/opensearch-project/opensearch-api-specification/commit/0b033a92cac4cb20ec3fb51350c139afc753b089)
- Updated opensearch-py APIs to reflect [opensearch-api-specification@d5ca873](https://github.com/opensearch-project/opensearch-api-specification/commit/d5ca873d20ff54be16ec48e7bd629cda7c4a6332)
Expand Down
12 changes: 10 additions & 2 deletions opensearchpy/helpers/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# under the License.

import collections.abc as collections_abc
import contextlib
from itertools import chain
from typing import Any, Optional

Expand Down Expand Up @@ -219,10 +220,17 @@ def __and__(self, other: "Bool") -> Any:
del q._params["minimum_should_match"]

for qx in (self, other):
# TODO: percentages will fail here
min_should_match = qx._min_should_match

# attempt to convert a string integer representation to int
with contextlib.suppress(ValueError):
min_should_match = int(min_should_match)

# all subqueries are required
if len(qx.should) <= min_should_match:
if (
isinstance(min_should_match, int)
and len(qx.should) <= min_should_match
):
q.must.extend(qx.should)
# not all of them are required, use it and remember min_should_match
elif not q.should:
Expand Down
83 changes: 81 additions & 2 deletions test_opensearchpy/test_helpers/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Type, Union

from typing import Any

import pytest
from pytest import raises

from opensearchpy.helpers import function, query
Expand Down Expand Up @@ -564,3 +564,82 @@ def test_script_score() -> None:
assert isinstance(q.query, query.MatchAll)
assert q.script == {"source": "...", "params": {}}
assert q.to_dict() == d


@pytest.mark.parametrize( # type: ignore[misc]
"minimum_should_match, expected_type",
[
("1", int),
("-1", int),
("50%", str),
("-50%", str),
("1<50%", str),
],
)
def test_bool_with_minimum_should_match_as_string(
minimum_should_match: str,
expected_type: Type[Union[str, int]],
) -> None:

q1 = query.Bool(
minimum_should_match=minimum_should_match,
should=[
query.Q("term", field="aa1"),
query.Q("term", field="aa2"),
],
)

q2 = query.Bool(
minimum_should_match=minimum_should_match,
should=[
query.Q("term", field="bb1"),
query.Q("term", field="bb2"),
],
)

q3 = q1 & q2

d1 = {
"bool": {
"minimum_should_match": minimum_should_match,
"should": [
{"term": {"field": "aa1"}},
{"term": {"field": "aa2"}},
],
}
}

d2 = {
"bool": {
"minimum_should_match": minimum_should_match,
"should": [
{"term": {"field": "bb1"}},
{"term": {"field": "bb2"}},
],
}
}

d3 = {
"bool": {
"should": [
{"term": {"field": "aa1"}},
{"term": {"field": "aa2"}},
],
"must": [
{
"bool": {
"should": [
{"term": {"field": "bb1"}},
{"term": {"field": "bb2"}},
],
"minimum_should_match": expected_type(minimum_should_match),
}
}
],
"minimum_should_match": expected_type(minimum_should_match),
}
}

assert q1.to_dict() == d1
assert q2.to_dict() == d2
assert q3.to_dict() == d3

0 comments on commit 7d2d368

Please sign in to comment.