Skip to content

Commit af5a4bc

Browse files
miguelgrinberggithub-actions[bot]
authored andcommitted
Allow tuples and other iterables in source() method (#1895)
* Allow tuples and other iterables in source() method Fixes #1893 * review feedback (cherry picked from commit 6853db7)
1 parent 7dbf074 commit af5a4bc

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

elasticsearch_dsl/search_base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@
3737
from typing_extensions import Self, TypeVar
3838

3939
from .aggs import A, Agg, AggBase
40+
from .document_base import InstrumentedField
4041
from .exceptions import IllegalOperation
4142
from .query import Bool, Q, Query
4243
from .response import Hit, Response
4344
from .utils import _R, AnyUsingType, AttrDict, DslBase, recursive_to_dict
4445

4546
if TYPE_CHECKING:
46-
from .document_base import InstrumentedField
4747
from .field import Field, Object
4848

4949

@@ -714,10 +714,14 @@ def ensure_strings(
714714
Dict[str, List[Union[str, "InstrumentedField"]]],
715715
]
716716
) -> Union[str, List[str], Dict[str, List[str]]]:
717-
if isinstance(fields, list):
718-
return [str(f) for f in fields]
719-
elif isinstance(fields, dict):
717+
if isinstance(fields, dict):
720718
return {k: ensure_strings(v) for k, v in fields.items()}
719+
elif not isinstance(fields, (str, InstrumentedField)):
720+
# we assume that if `fields` is not a any of [dict, str,
721+
# InstrumentedField] then it is an iterable of strings or
722+
# InstrumentedFields, so we convert them to a plain list of
723+
# strings
724+
return [str(f) for f in fields]
721725
else:
722726
return str(fields)
723727

tests/_async/test_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ def test_source() -> None:
546546

547547
assert {
548548
"_source": {"includes": ["foo.bar.*"], "excludes": ["foo.one"]}
549-
} == AsyncSearch().source(includes=["foo.bar.*"], excludes=["foo.one"]).to_dict()
549+
} == AsyncSearch().source(includes=["foo.bar.*"], excludes=("foo.one",)).to_dict()
550550

551551
assert {"_source": False} == AsyncSearch().source(False).to_dict()
552552

tests/_sync/test_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ def test_source() -> None:
546546

547547
assert {
548548
"_source": {"includes": ["foo.bar.*"], "excludes": ["foo.one"]}
549-
} == Search().source(includes=["foo.bar.*"], excludes=["foo.one"]).to_dict()
549+
} == Search().source(includes=["foo.bar.*"], excludes=("foo.one",)).to_dict()
550550

551551
assert {"_source": False} == Search().source(False).to_dict()
552552

0 commit comments

Comments
 (0)