Skip to content

Commit

Permalink
Added fix for KeyError caused by a missing 'hits' key. (#616)
Browse files Browse the repository at this point in the history
Updated CHANGELOG.md.



nox formatting applied.



Added new unit test for actions scan function.



Added type hints & nox formatting.



Added fix to async scan function & added matching unit tests for async.

Signed-off-by: Djcarrillo6 <djcarrillo6@yahoo.com>
  • Loading branch information
Djcarrillo6 authored Dec 4, 2023
1 parent 44f916c commit db61b59
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### Removed
- Removed unnecessary `# -*- coding: utf-8 -*-` headers from .py files ([#615](https://github.com/opensearch-project/opensearch-py/pull/615), [#617](https://github.com/opensearch-project/opensearch-py/pull/617))
### Fixed
- Fix KeyError when scroll returns no hits ([#616](https://github.com/opensearch-project/opensearch-py/pull/616))
### Security

## [2.4.2]
Expand Down
15 changes: 9 additions & 6 deletions opensearchpy/_async/helpers/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,14 +390,17 @@ async def async_scan(
scroll_id = resp.get("_scroll_id")

try:
while scroll_id and resp["hits"]["hits"]:
for hit in resp["hits"]["hits"]:
while scroll_id and resp.get("hits", {}).get("hits"):
for hit in resp.get("hits", {}).get("hits", []):
yield hit

# Default to 0 if the value isn't included in the response
shards_successful = resp["_shards"].get("successful", 0)
shards_skipped = resp["_shards"].get("skipped", 0)
shards_total = resp["_shards"].get("total", 0)
_shards = resp.get("_shards")

if _shards:
# Default to 0 if the value isn't included in the response
shards_successful = _shards.get("successful", 0)
shards_skipped = _shards.get("skipped", 0)
shards_total = _shards.get("total", 0)

# check if we have any errors
if (shards_successful + shards_skipped) < shards_total:
Expand Down
16 changes: 10 additions & 6 deletions opensearchpy/helpers/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,14 +586,17 @@ def scan(
scroll_id = resp.get("_scroll_id")

try:
while scroll_id and resp["hits"]["hits"]:
for hit in resp["hits"]["hits"]:
while scroll_id and resp.get("hits", {}).get("hits"):
for hit in resp.get("hits", {}).get("hits", []):
yield hit

# Default to 0 if the value isn't included in the response
shards_successful = resp["_shards"].get("successful", 0)
shards_skipped = resp["_shards"].get("skipped", 0)
shards_total = resp["_shards"].get("total", 0)
_shards = resp.get("_shards")

if _shards:
# Default to 0 if the value isn't included in the response
shards_successful = _shards.get("successful", 0)
shards_skipped = _shards.get("skipped", 0)
shards_total = _shards.get("total", 0)

# check if we have any errors
if (shards_successful + shards_skipped) < shards_total:
Expand All @@ -614,6 +617,7 @@ def scan(
shards_total,
),
)

resp = client.scroll(
body={"scroll_id": scroll_id, "scroll": scroll}, **scroll_kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,34 @@ async def test_scan_auth_kwargs_favor_scroll_kwargs_option(
}
assert async_client.scroll.call_args[1]["sort"] == "asc"

async def test_async_scan_with_missing_hits_key(
    self, async_client: Any, scan_teardown: Any
) -> None:
    """Check that ``async_scan`` yields nothing when responses lack ``hits``.

    ``search``, ``scroll`` and ``clear_scroll`` are patched on the client so
    that the search and scroll responses carry only ``_scroll_id`` and
    ``_shards``.
    """
    # One flat ``with`` enters the three patches in the same order as
    # nesting them would: search, then scroll, then clear_scroll.
    with patch.object(
        async_client,
        "search",
        return_value=MockResponse({"_scroll_id": "dummy_scroll_id", "_shards": {}}),
    ), patch.object(
        async_client,
        "scroll",
        return_value=MockResponse({"_scroll_id": "dummy_scroll_id", "_shards": {}}),
    ), patch.object(
        async_client, "clear_scroll", return_value=MockResponse({})
    ):
        collected_hits = []
        async for hit in actions.async_scan(
            async_client, query={"query": {"match_all": {}}}
        ):
            collected_hits.append(hit)

        assert (
            collected_hits == []
        ), "Expected empty results when 'hits' key is missing"


@pytest.fixture(scope="function") # type: ignore
async def reindex_setup(async_client: Any) -> Any:
Expand Down
22 changes: 22 additions & 0 deletions test_opensearchpy/test_helpers/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import threading
import time
from typing import Any
from unittest.mock import Mock

import mock
import pytest
Expand Down Expand Up @@ -270,3 +271,24 @@ def test_string_actions_are_marked_as_simple_inserts(self) -> None:
self.assertEqual(
('{"index":{}}', "whatever"), helpers.expand_action("whatever")
)


class TestScanFunction(TestCase):
    """Tests for ``helpers.scan`` with the client's network methods patched."""

    @mock.patch("opensearchpy.OpenSearch.clear_scroll")
    @mock.patch("opensearchpy.OpenSearch.scroll")
    @mock.patch("opensearchpy.OpenSearch.search")
    def test_scan_with_missing_hits_key(
        self, mock_search: Mock, mock_scroll: Mock, mock_clear_scroll: Mock
    ) -> None:
        """``scan`` should return no hits, not raise, when ``hits`` is absent."""
        client = OpenSearch()

        # Patches are applied bottom-up, so the mocks arrive in the order
        # search, scroll, clear_scroll.
        mock_clear_scroll.return_value = None

        # The search response and the single queued scroll response both
        # omit the 'hits' key entirely.
        mock_scroll.side_effect = [{"_scroll_id": "dummy_scroll_id", "_shards": {}}]
        mock_search.return_value = {"_scroll_id": "dummy_scroll_id", "_shards": {}}

        # Draining the generator must complete without a KeyError.
        scan_result = []
        scan_result.extend(helpers.scan(client, query={"query": {"match_all": {}}}))

        assert scan_result == [], "Expected empty results when 'hits' key is missing"

0 comments on commit db61b59

Please sign in to comment.