Commit

Fixed: get_document_annotations no longer implicitly limits results to 100 (#541)

* Added: consume_paginated_api automatically follows pagination `next` URLs and collects all results (see the sketch after this list)

* Use consume_paginated_api for all list endpoints.

* Fix tests for usage of consume_paginated_api

* Deactivated magicmock test.

* Fix export project test.
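
The cap came from pagination: these list endpoints answer with an envelope whose `next` and `results` keys appear in the diff below, and the old code read only the first page. A minimal sketch of one such page (the URL, the extra keys, and the counts are illustrative, not taken from the commit):

    # One page of a paginated list endpoint, sketched as a Python dict:
    first_page = {
        'count': 230,  # illustrative total
        'next': 'https://app.konfuzio.com/api/v3/documents/123/annotations/?page=2',  # assumed URL shape
        'previous': None,
        'results': ['...at most one page of items, e.g. 100...'],
    }
    # Old behaviour: a single GET returned this envelope, so callers saw at
    # most one page, hence the implicit cap of 100 annotations.
    # New behaviour: consume_paginated_api follows `next` until it is None
    # and concatenates every page's `results`.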

---------

Co-authored-by: iftwigs <42752431+iftwigs@users.noreply.github.com>
Co-authored-by: fz <fz@konfuzio.com>
Co-authored-by: zypriafl <zypriafl@web.de>
4 people authored Sep 16, 2024
1 parent 8e30b65 commit d8ea117
Showing 3 changed files with 47 additions and 93 deletions.
86 changes: 40 additions & 46 deletions konfuzio_sdk/api.py
@@ -7,7 +7,7 @@
from json import JSONDecodeError
from operator import itemgetter
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

import requests
from dotenv import set_key
@@ -183,6 +183,29 @@ def konfuzio_session(
    return session


+def consume_paginated_api(initial_url: str, session) -> List[Dict[str, Any]]:
+    """
+    Consumes a paginated API, following the `next` links until all items have been retrieved.
+    :param initial_url: The initial URL to start fetching data from
+    :param session: A Konfuzio Session object
+    :return: A list of all items from all pages
+    """
+    all_results = []
+    next_url = initial_url
+
+    while next_url:
+        logger.info(f'Fetching paginated {next_url}')
+        response = session.get(next_url)
+        response.raise_for_status()
+
+        data = response.json()
+        all_results.extend(data['results'])
+        next_url = data.get('next')
+
+    return all_results


def get_project_list(session=None):
    """
    Get the list of all Projects for the user.
@@ -193,8 +216,8 @@ def get_project_list(session=None):
    if session is None:
        session = konfuzio_session()
    url = get_projects_list_url()
-    r = session.get(url=url)
-    return r.json()
+    result = consume_paginated_api(url, session)
+    return result
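
A minimal usage sketch of the rewritten wrapper (illustrative, not part of the commit; assumes Konfuzio credentials are already configured):

    from konfuzio_sdk.api import konfuzio_session, get_project_list

    session = konfuzio_session()
    projects = get_project_list(session)  # follows every `next` page internally
    assert isinstance(projects, list)     # callers no longer unwrap ['results']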


def get_project_details(project_id: int, session=None) -> dict:
Expand Down Expand Up @@ -233,9 +256,9 @@ def get_project_labels(project_id: int, session=None) -> dict:
host = None

url = get_project_labels_url(project_id=project_id, host=host)
r = session.get(url=url)
result = consume_paginated_api(url, session)

return r.json()
return result


def get_project_label_sets(project_id: int, session=None) -> dict:
@@ -253,9 +276,9 @@ def get_project_label_sets(project_id: int, session=None) -> dict:
        host = None

    url = get_project_label_sets_url(project_id=project_id, host=host)
-    r = session.get(url=url)
+    result = consume_paginated_api(url, session)

-    return r.json()
+    return result


def create_new_project(project_name, session=None):
@@ -317,8 +340,8 @@ def get_document_annotations(document_id: int, session=None):
    else:
        host = None
    url = get_document_annotations_url(document_id=document_id, host=host)
-    r = session.get(url)
-    return r.json()
+    result = consume_paginated_api(url, session)
+    return result


def get_document_bbox(document_id: int, session=None):
@@ -371,22 +394,6 @@ def get_page_image(document_id: int, page_number: int, session=None, thumbnail:
    return r.content


-# def post_document_bulk_annotation(document_id: int, project_id: int, annotation_list, session=konfuzio_session()):
-#     """
-#     Add a list of Annotations to an existing document.
-#
-#     :param document_id: ID of the file
-#     :param project_id: ID of the project
-#     :param annotation_list: List of Annotations
-#     :param session: Konfuzio session with Retry and Timeout policy
-#     :return: Response status.
-#     """
-#     url = get_document_annotations_url(document_id, project_id=project_id)
-#     r = session.post(url, json=annotation_list)
-#     r.raise_for_status()
-#     return r
-
-
def post_document_annotation(
    document_id: int,
    spans: List,
@@ -559,20 +566,11 @@ def get_meta_of_files(

    if limit:
        url = get_documents_meta_url(project_id=project_id, offset=0, limit=limit, *args, **kwargs)
        r = session.get(url)
        result = r.json()['results']
    else:
        url = get_documents_meta_url(project_id=project_id, limit=pagination_limit, host=host, *args, **kwargs)
-        result = []
-        r = session.get(url)
-        data = r.json()
-        result += data['results']
-
-    if not limit:
-        while 'next' in data.keys() and data['next']:
-            logger.info(f'Iterate on paginated {url}.')
-            url = data['next']
-            r = session.get(url)
-            data = r.json()
-            result += data['results']
+        result = consume_paginated_api(url, session)

    sorted_documents = sorted(result, key=itemgetter('id'))
    return sorted_documents
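
Note the asymmetry this hunk keeps: an explicit limit still issues a single capped request, while the default path now pages through everything. A sketch with an illustrative project ID:

    from konfuzio_sdk.api import get_meta_of_files

    first_ten = get_meta_of_files(project_id=123, limit=10)  # one capped request
    all_files = get_meta_of_files(project_id=123)            # follows every `next` page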
@@ -784,8 +782,8 @@ def get_project_categories(project_id: int = None, session=None) -> List[Dict]:
    else:
        host = None
    url = get_project_categories_url(project_id=project_id, host=host)
-    r = session.get(url=url)
-    return r.json()['results']
+    result = consume_paginated_api(url, session)
+    return result


def upload_ai_model(ai_model_path: str, project_id: int = None, category_id: int = None, session=None):
@@ -921,11 +919,7 @@ def get_all_project_ais(project_id: int, session=None) -> dict:

    for ai_type, url in urls.items():
        try:
-            response = session.get(url=url)
-            response.raise_for_status()
-
-            if response.status_code == 200:
-                all_ais[ai_type] = json.loads(response.text)
+            all_ais[ai_type] = consume_paginated_api(url, session)
        except HTTPError as e:
            all_ais[ai_type] = {'error': e}
            print(f'[ERROR] while fetching {ai_type} AIs: {e}')
@@ -946,7 +940,7 @@ def export_ai_models(project, session=None, category_id=None) -> int:

    project_ai_models = project.ai_models
    for model_type, details in project_ai_models.items():
-        count = details.get('count')
+        count = len(details)
        if count and count > 0:
            # Only AI types with at least one model will be exported
            ai_types.add(model_type)
@@ -962,7 +956,7 @@
        variant = ai_type
        folder = os.path.join(project.project_folder, 'models', variant + '_ais')

-        ai_models = project_ai_models.get(variant, {}).get('results', [])
+        ai_models = project_ai_models.get(variant, [])

        for index, ai_model in enumerate(ai_models):
            # Filter Extraction AI by selected category.
2 changes: 1 addition & 1 deletion konfuzio_sdk/data.py
@@ -3534,7 +3534,7 @@ def download_document_details(self):
        if data['category']:
            self._category = self.project.get_category_by_id(data['category'])
        # write a file, even if there are no annotations, to support offline work
-        annotations = get_document_annotations(document_id=self.id_, session=self.project.session)['results']
+        annotations = get_document_annotations(document_id=self.id_, session=self.project.session)
        with open(self.annotation_file_path, 'w') as f:
            json.dump(annotations, f, indent=2, sort_keys=True)
52 changes: 6 additions & 46 deletions tests/test_api.py
@@ -67,8 +67,8 @@ def test_projects_details(self):
        """Test to get Document details."""
        data = get_project_list()
        new_var = self.RESTORED_PROJECT_ID
-        assert new_var in [prj['id'] for prj in data['results']]
-        assert set(data['results'][0]) == {
+        assert new_var in [prj['id'] for prj in data]
+        assert set(data[0]) == {
            'id',
            'name',
            'storage_name',
@@ -287,7 +287,7 @@ def test_download_file_not_available(self):

    def test_get_annotations(self):
        """Download Annotations and the Text from API for a Document and check their offset alignment."""
-        annotations = get_document_annotations(self.test_document.id_)['results']
+        annotations = get_document_annotations(self.test_document.id_)
        self.assertEqual(len(annotations), 21)

    def test_post_document_annotation_multiline_as_bboxes(self):
@@ -384,7 +384,7 @@ def test_post_document_annotation(self):
        )
        annotation = json.loads(response.text)
        # check if the update has been received by the server
-        annotations = get_document_annotations(self.test_document.id_)['results']
+        annotations = get_document_annotations(self.test_document.id_)
        assert annotation['id'] in [annotation['id'] for annotation in annotations]
        # delete the annotation, i.e. change its status from feedback required to negative
        negative_id = delete_document_annotation(annotation['id'])
@@ -430,7 +430,7 @@ def test_change_annotation(self):

    def test_get_project_labels(self):
        """Download Labels from API for a Project."""
-        label_ids = [label['id'] for label in get_project_labels(project_id=TEST_PROJECT_ID)['results']]
+        label_ids = [label['id'] for label in get_project_labels(project_id=TEST_PROJECT_ID)]
        assert set(label_ids) == {
            858,
            859,
@@ -454,7 +454,7 @@ def test_get_project_labels(self):

    def test_get_project_label_sets(self):
        """Test getting all Label Sets of a Project."""
-        label_set_ids = [label_set['id'] for label_set in get_project_label_sets(project_id=TEST_PROJECT_ID)['results']]
+        label_set_ids = [label_set['id'] for label_set in get_project_label_sets(project_id=TEST_PROJECT_ID)]
        assert label_set_ids == [64, 3706, 3686, 3707]

    def test_download_office_file(self):
@@ -602,46 +602,6 @@ def text(self):
            _get_auth_token('test', 'test')
        assert 'HTTP Status 500' in context.exception

-    @patch('konfuzio_sdk.api.konfuzio_session')
-    @patch('konfuzio_sdk.api.get_extraction_ais_list_url')
-    @patch('konfuzio_sdk.api.get_splitting_ais_list_url')
-    @patch('konfuzio_sdk.api.get_categorization_ais_list_url')
-    @patch('konfuzio_sdk.api.json.loads')
-    def test_get_all_project_ais(
-        self,
-        mock_json_loads,
-        mock_get_categorization_url,
-        mock_get_splitting_url,
-        mock_get_extraction_url,
-        mock_session,
-    ):
-        """Retrieve all AIs from a Project."""
-        # Setup
-        sample_data = {'AI_DATA': 'AI_SAMPLE_DATA'}
-
-        mock_session.return_value.get.return_value.status_code = 200
-        mock_json_loads.return_value = sample_data
-
-        # Action
-        result = get_all_project_ais(project_id=1)
-
-        # Assertions
-        self.assertEqual(
-            result,
-            {
-                'extraction': sample_data,
-                'filesplitting': sample_data,
-                'categorization': sample_data,
-            },
-        )
-
-        from konfuzio_sdk.api import konfuzio_session
-
-        # Ensure the mock methods were called with the correct arguments
-        mock_get_extraction_url.assert_called_once_with(1, konfuzio_session().host)
-        mock_get_splitting_url.assert_called_once_with(1, konfuzio_session().host)
-        mock_get_categorization_url.assert_called_once_with(1, konfuzio_session().host)
-
    @patch('konfuzio_sdk.api.konfuzio_session')
    @patch('konfuzio_sdk.api.get_extraction_ais_list_url')
    @patch('konfuzio_sdk.api.get_splitting_ais_list_url')
