diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index 414b8ebcc4..5da859f102 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import logging as log from math import ceil from typing import Literal @@ -7,6 +8,7 @@ from django.conf import settings from django.core.cache import cache +from decouple import config from elasticsearch.exceptions import NotFoundError from elasticsearch_dsl import Q, Search from elasticsearch_dsl.query import EMPTY_QUERY @@ -28,6 +30,10 @@ from api.utils.search_context import SearchContext +module_logger = logging.getLogger(__name__) + + +NESTING_THRESHOLD = config("POST_PROCESS_NESTING_THRESHOLD", cast=int, default=5) SOURCE_CACHE_TIMEOUT = 60 * 60 * 4 # 4 hours FILTER_CACHE_TIMEOUT = 30 THUMBNAIL = "thumbnail" @@ -49,7 +55,7 @@ def _quote_escape(query_string): def _post_process_results( - s, start, end, page_size, search_results, filter_dead + s, start, end, page_size, search_results, filter_dead, nesting=0 ) -> list[Hit] | None: """ Perform some steps on results fetched from the backend. @@ -66,17 +72,26 @@ def _post_process_results( :param search_results: The Elasticsearch response object containing search results. :param filter_dead: Whether images should be validated. + :param nesting: the level of nesting at which this function is being called :return: List of results. """ - results = [] - to_validate = [] - for res in search_results: - if hasattr(res.meta, "highlight"): - res.fields_matched = dir(res.meta.highlight) - to_validate.append(res.url) - results.append(res) + + logger = module_logger.getChild("_post_process_results") + if nesting > NESTING_THRESHOLD: + logger.info( + { + "message": "Nesting threshold breached", + "nesting": nesting, + "start": start, + "end": end, + "page_size": page_size, + } + ) + + results = list(search_results) if filter_dead: + to_validate = [res.url for res in search_results] query_hash = get_query_hash(s) check_dead_links(query_hash, start, results, to_validate) @@ -130,7 +145,7 @@ def _post_process_results( search_response = get_es_response(s, es_query="postprocess_search") return _post_process_results( - s, start, end, page_size, search_response, filter_dead + s, start, end, page_size, search_response, filter_dead, nesting + 1 ) return results[:page_size] diff --git a/api/api/views/media_views.py b/api/api/views/media_views.py index 20aebf8ef4..1e5eae4327 100644 --- a/api/api/views/media_views.py +++ b/api/api/views/media_views.py @@ -83,7 +83,7 @@ def get_db_results(self, results): results = list(self.get_queryset().filter(identifier__in=identifiers)) results.sort(key=lambda x: identifiers.index(str(x.identifier))) for result, hit in zip(results, hits): - result.fields_matched = getattr(hit, "fields_matched", None) + result.fields_matched = getattr(hit.meta, "highlight", None) return results diff --git a/api/test/unit/controllers/test_search_controller.py b/api/test/unit/controllers/test_search_controller.py index 43318fa481..b344a1074d 100644 --- a/api/test/unit/controllers/test_search_controller.py +++ b/api/test/unit/controllers/test_search_controller.py @@ -1,3 +1,4 @@ +import logging import random import re from collections.abc import Callable @@ -760,3 +761,37 @@ def test_post_process_results_recurses_as_needed( } assert wrapped_post_process_results.call_count == 2 + + +@mock.patch( + "api.controllers.search_controller.check_dead_links", +) +def test_excessive_recursion_in_post_process( + mock_check_dead_links, + image_media_type_config, + redis, + caplog, +): + def _delete_all_results_but_first(_, __, results, ___): + results[1:] = [] + + mock_check_dead_links.side_effect = _delete_all_results_but_first + + serializer = image_media_type_config.search_request_serializer( + # This query string does not matter, ultimately, as pook is mocking + # the ES response regardless of the input + data={"q": "bird perched"} + ) + serializer.is_valid() + + with caplog.at_level(logging.INFO): + results, _, _, _ = search_controller.search( + search_params=serializer, + ip=0, + origin_index=image_media_type_config.origin_index, + exact_index=True, + page=1, + page_size=2, + filter_dead=True, + ) + assert "Nesting threshold breached" in caplog.text diff --git a/api/test/unit/views/test_media_views.py b/api/test/unit/views/test_media_views.py index c34ea0b91e..b94e309d16 100644 --- a/api/test/unit/views/test_media_views.py +++ b/api/test/unit/views/test_media_views.py @@ -11,8 +11,15 @@ @pytest.mark.django_db def test_list_query_count(api_client, media_type_config): num_results = 20 + + # Since controller returns a list of ``Hit``s, not model instances, we must + # set the ``meta`` param on each of them to match the shape of ``Hit``. + results = media_type_config.model_factory.create_batch(size=num_results) + for result in results: + result.meta = None + controller_ret = ( - media_type_config.model_factory.create_batch(size=num_results), # results + results, 1, # num_pages num_results, {}, # search_context