From 5236160d82aabbcfc8fc441b89ee65c39cb0502b Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Tue, 16 Apr 2024 15:11:56 +0200 Subject: [PATCH 1/5] Add ScicatClient.query_datasets --- docs/release-notes.rst | 2 + src/scitacean/client.py | 118 ++++++++++++++++- tests/client/query_client_test.py | 210 ++++++++++++++++++++++++++++++ 3 files changed, 327 insertions(+), 3 deletions(-) create mode 100644 tests/client/query_client_test.py diff --git a/docs/release-notes.rst b/docs/release-notes.rst index dbc84076..0b7a73e7 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -41,6 +41,8 @@ Security Features ~~~~~~~~ +* Added experimental :meth:`client.ScicatClient.query_datasets` for querying datasets by field. + Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/src/scitacean/client.py b/src/scitacean/client.py index 013df78d..ee18be14 100644 --- a/src/scitacean/client.py +++ b/src/scitacean/client.py @@ -6,6 +6,7 @@ import dataclasses import datetime +import json import re import warnings from collections.abc import Callable, Iterable, Iterator @@ -15,6 +16,7 @@ from urllib.parse import quote_plus import httpx +import pydantic from . import model from ._base_model import convert_download_to_user_model @@ -708,6 +710,109 @@ def get_dataset_model( **dset_json, ) + def query_datasets( + self, + fields: dict[str, Any], + *, + limit: int | None = None, + order: str | None = None, + strict_validation: bool = False, + ) -> list[model.DownloadDataset]: + """Query for datasets in SciCat. + + Attention + --------- + This function is experimental and may change or be removed in the future. + It is currently unclear how best to implement querying because SciCat + provides multiple, very different APIs and there are plans for supporting + queries via Mongo query language directly. + + See `issue #177 <https://github.com/SciCatProject/scitacean/issues/177>`_ + for a discussion. + + Parameters + ---------- + fields: + Fields to query for. + Returned datasets must match all fields exactly. + See examples below. 
+ limit: + Maximum number of results to return. + Requires ``order`` to be specified. + If not given, all matching datasets are returned. + order: + Specify order of results. + For example, ``"creationTime:asc"`` and ``"creationTime:desc"`` return + results in ascending or descending order in creation time, respectively. + strict_validation: + If ``True``, the datasets must pass validation. + If ``False``, datasets are still returned if validation fails. + Note that some dataset fields may have a bad value or type. + A warning will be logged if validation fails. + + Returns + ------- + : + A list of dataset models that match the query. + + Examples + -------- + Get all datasets belonging to proposal ``abc.123``: + + .. code-block:: python + + scicat_client.query_datasets({'proposalId': 'abc.123'}) + + Get all datasets that belong to proposal ``abc.123`` + **and** have name ``"ds name"``: (The name and proposal must match exactly.) + + .. code-block:: python + + scicat_client.query_datasets({'proposalId': 'abc.123', 'name': 'ds name'}) + + Return only the newest 5 datasets for proposal ``bc.123``: + + .. code-block:: python + + scicat_client.query_datasets( + {'proposalId': 'bc.123'}, + limit=5, + order="creationTime:desc", + ) + """ + # Use a pydantic model to support serializing custom types to JSON. + params_model = pydantic.create_model( + "QueryParams", **{key: (type(field), ...) 
for key, field in fields.items()} + ) + params = {"fields": params_model(**fields).model_dump_json()} + + limits = {} + if order is not None: + limits["order"] = order + if limit is not None: + if order is None: + raise ValueError("`order` is required when `limit` is specified.") + limits["limit"] = limit + if limits: + params["limits"] = json.dumps(limits) + + dsets_json = self._call_endpoint( + cmd="get", + url="datasets/fullquery", + params=params, + operation="query_datasets", + ) + if not dsets_json: + return [] + return [ + model.construct( + model.DownloadDataset, + _strict_validation=strict_validation, + **dset_json, + ) + for dset_json in dsets_json + ] + def get_orig_datablocks( self, pid: PID, strict_validation: bool = False ) -> list[model.DownloadOrigDatablock]: @@ -1010,7 +1115,12 @@ def validate_dataset_model( raise ValueError(f"Dataset {dset} did not pass validation in SciCat.") def _send_to_scicat( - self, *, cmd: str, url: str, data: model.BaseModel | None = None + self, + *, + cmd: str, + url: str, + data: model.BaseModel | None = None, + params: dict[str, str] | None = None, ) -> httpx.Response: if self._token is not None: token = self._token.get_str() @@ -1029,6 +1139,7 @@ def _send_to_scicat( content=data.model_dump_json(exclude_none=True) if data is not None else None, + params=params, headers=headers, timeout=self._timeout.seconds, ) @@ -1047,14 +1158,15 @@ def _call_endpoint( *, cmd: str, url: str, - data: model.BaseModel | None = None, operation: str, + data: model.BaseModel | None = None, + params: dict[str, str] | None = None, ) -> Any: full_url = _url_concat(self._base_url, url) logger = get_logger() logger.info("Calling SciCat API at %s for operation '%s'", full_url, operation) - response = self._send_to_scicat(cmd=cmd, url=full_url, data=data) + response = self._send_to_scicat(cmd=cmd, url=full_url, data=data, params=params) if not response.is_success: logger.error( "SciCat API call to %s failed: %s %s: %s", diff --git 
a/tests/client/query_client_test.py b/tests/client/query_client_test.py new file mode 100644 index 00000000..e9730064 --- /dev/null +++ b/tests/client/query_client_test.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean) + +import pytest +from dateutil.parser import parse as parse_datetime + +from scitacean import Client, DatasetType, RemotePath, model +from scitacean.testing.backend import skip_if_not_backend +from scitacean.testing.backend.config import SciCatAccess + +UPLOAD_DATASETS = { + "raw1": model.UploadRawDataset( + ownerGroup="PLACEHOLDER", + accessGroups=["uu", "faculty"], + contactEmail="ponder.stibbons@uu.am", + creationTime=parse_datetime("2004-06-13T01:45:28.100Z"), + datasetName="dataset 1", + numberOfFiles=0, + numberOfFilesArchived=0, + owner="PLACEHOLDER", + sourceFolder=RemotePath("/hex/raw1"), + type=DatasetType.RAW, + principalInvestigator="investigator 1", + creationLocation="UU", + proposalId="p0124", + ), + "raw2": model.UploadRawDataset( + ownerGroup="PLACEHOLDER", + accessGroups=["uu", "faculty"], + contactEmail="ponder.stibbons@uu.am", + creationTime=parse_datetime("2004-06-14T14:00:30Z"), + datasetName="dataset 2", + numberOfFiles=0, + numberOfFilesArchived=0, + owner="PLACEHOLDER", + sourceFolder=RemotePath("/hex/raw2"), + type=DatasetType.RAW, + principalInvestigator="investigator 2", + creationLocation="UU", + proposalId="p0124", + ), + "raw3": model.UploadRawDataset( + ownerGroup="PLACEHOLDER", + accessGroups=["uu", "faculty"], + contactEmail="ponder.stibbons@uu.am", + creationTime=parse_datetime("2004-06-10T00:13:13Z"), + datasetName="dataset 3", + numberOfFiles=0, + numberOfFilesArchived=0, + owner="PLACEHOLDER", + sourceFolder=RemotePath("/hex/raw3"), + type=DatasetType.RAW, + principalInvestigator="investigator 1", + creationLocation="UU", + proposalId="p0124", + ), + "raw4": model.UploadRawDataset( + ownerGroup="PLACEHOLDER", + 
accessGroups=["uu", "faculty"], + contactEmail="ponder.stibbons@uu.am", + creationTime=parse_datetime("2005-11-03T21:56:02Z"), + datasetName="dataset 1", + numberOfFiles=0, + numberOfFilesArchived=0, + owner="PLACEHOLDER", + sourceFolder=RemotePath("/hex/raw4"), + type=DatasetType.RAW, + principalInvestigator="investigator X", + creationLocation="UU", + ), + "derived1": model.UploadDerivedDataset( + ownerGroup="PLACEHOLDER", + accessGroups=["uu", "faculty"], + contactEmail="ponder.stibbons@uu.am", + creationTime=parse_datetime("2004-10-02T08:47:33Z"), + datasetName="dataset 1", + numberOfFiles=0, + numberOfFilesArchived=0, + owner="PLACEHOLDER", + sourceFolder=RemotePath("/hex/derived1"), + type=DatasetType.DERIVED, + investigator="investigator 1", + inputDatasets=[], + usedSoftware=["scitacean"], + ), + "derived2": model.UploadDerivedDataset( + ownerGroup="PLACEHOLDER", + accessGroups=["uu", "faculty"], + contactEmail="ponder.stibbons@uu.am", + creationTime=parse_datetime("2004-10-14T09:18:58Z"), + datasetName="derived dataset 2", + numberOfFiles=0, + numberOfFilesArchived=0, + owner="PLACEHOLDER", + sourceFolder=RemotePath("/hex/derived2"), + type=DatasetType.DERIVED, + investigator="investigator 1", + inputDatasets=[], + usedSoftware=["scitacean"], + ), +} +SEED = {} + + +@pytest.fixture(scope="module", autouse=True) +def seed_database(request: pytest.FixtureRequest, scicat_access: SciCatAccess) -> None: + skip_if_not_backend(request) + + client = Client.from_credentials( + url=scicat_access.url, + **scicat_access.user.credentials, # type: ignore[arg-type] + ) + for key, dset in UPLOAD_DATASETS.items(): + dset.ownerGroup = scicat_access.user.group + dset.owner = scicat_access.user.username + SEED[key] = client.scicat.create_dataset_model(dset) + + +def test_query_dataset_multiple_by_single_field(real_client, seed_database): + datasets = real_client.scicat.query_datasets({"proposalId": "p0124"}) + actual = {ds.pid: ds for ds in datasets} + expected = 
{SEED[key].pid: SEED[key] for key in ("raw1", "raw2", "raw3")} + assert actual == expected + + +def test_query_dataset_no_match(real_client, seed_database): + datasets = real_client.scicat.query_datasets({"owner": "librarian"}) + assert not datasets + + +def test_query_dataset_multiple_by_multiple_fields(real_client, seed_database): + datasets = real_client.scicat.query_datasets( + {"proposalId": "p0124", "principalInvestigator": "investigator 1"}, + ) + actual = {ds.pid: ds for ds in datasets} + expected = {SEED[key].pid: SEED[key] for key in ("raw1", "raw3")} + assert actual == expected + + +def test_query_dataset_multiple_by_derived_field(real_client, seed_database): + datasets = real_client.scicat.query_datasets( + {"investigator": "investigator 1"}, + ) + actual = {ds.pid: ds for ds in datasets} + expected = {SEED[key].pid: SEED[key] for key in ("derived1", "derived2")} + assert actual == expected + + +def test_query_dataset_uses_conjunction_of_fields(real_client, seed_database): + datasets = real_client.scicat.query_datasets( + {"proposalId": "p0124", "investigator": "investigator X"}, + ) + assert not datasets + + +def test_query_dataset_can_use_custom_type(real_client, seed_database): + datasets = real_client.scicat.query_datasets( + {"sourceFolder": RemotePath("/hex/raw4")}, + ) + expected = [SEED["raw4"]] + assert datasets == expected + + +def test_query_dataset_set_order(real_client, seed_database): + datasets = real_client.scicat.query_datasets( + {"proposalId": "p0124"}, + order="creationTime:desc", + ) + # This test uses a list to check the order + expected = [SEED[key] for key in ("raw2", "raw1", "raw3")] + assert datasets == expected + + +def test_query_dataset_limit_ascending_creation_time(real_client, seed_database): + datasets = real_client.scicat.query_datasets( + {"proposalId": "p0124"}, + limit=2, + order="creationTime:asc", + ) + actual = {ds.pid: ds for ds in datasets} + expected = {SEED[key].pid: SEED[key] for key in ("raw1", "raw3")} + 
assert actual == expected + + +def test_query_dataset_limit_descending_creation_time(real_client, seed_database): + datasets = real_client.scicat.query_datasets( + {"proposalId": "p0124"}, + limit=2, + order="creationTime:desc", + ) + actual = {ds.pid: ds for ds in datasets} + expected = {SEED[key].pid: SEED[key] for key in ("raw1", "raw2")} + assert actual == expected + + +def test_query_dataset_limit_needs_order(real_client, seed_database): + with pytest.raises(ValueError, match="limit"): + real_client.scicat.query_datasets( + {"proposalId": "p0124"}, + limit=2, + ) + + +def test_query_dataset_all(real_client, seed_database): + datasets = real_client.scicat.query_datasets({}) + actual = {ds.pid: ds for ds in datasets} + # We cannot test `datasets` directly because there are other datasets + # in the database from other tests. + for ds in SEED.values(): + assert actual[ds.pid] == ds From ce6cff8e9a9bf98f318061cdb2eb199ece28733e Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Tue, 16 Apr 2024 15:25:23 +0200 Subject: [PATCH 2/5] Fix type issues --- src/scitacean/client.py | 4 ++-- tests/client/query_client_test.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/scitacean/client.py b/src/scitacean/client.py index ee18be14..ab456299 100644 --- a/src/scitacean/client.py +++ b/src/scitacean/client.py @@ -781,12 +781,12 @@ def query_datasets( ) """ # Use a pydantic model to support serializing custom types to JSON. - params_model = pydantic.create_model( + params_model = pydantic.create_model( # type: ignore[call-overload] "QueryParams", **{key: (type(field), ...) 
for key, field in fields.items()} ) params = {"fields": params_model(**fields).model_dump_json()} - limits = {} + limits: dict[str, Union[str, int]] = {} if order is not None: limits["order"] = order if limit is not None: diff --git a/tests/client/query_client_test.py b/tests/client/query_client_test.py index e9730064..49a97950 100644 --- a/tests/client/query_client_test.py +++ b/tests/client/query_client_test.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean) +from typing import Union + import pytest from dateutil.parser import parse as parse_datetime @@ -8,7 +10,9 @@ from scitacean.testing.backend import skip_if_not_backend from scitacean.testing.backend.config import SciCatAccess -UPLOAD_DATASETS = { +UPLOAD_DATASETS: dict[ + str, Union[model.UploadDerivedDataset, model.UploadRawDataset] +] = { "raw1": model.UploadRawDataset( ownerGroup="PLACEHOLDER", accessGroups=["uu", "faculty"], From 56b30ebe27e6b20955a98eb5fb6e78b5994d7be2 Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Mon, 27 May 2024 12:19:55 +0200 Subject: [PATCH 3/5] Enable ruff UP rules --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5d9da4af..829be07e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,11 +98,11 @@ extend-include = ["*.ipynb"] extend-exclude = [".*", "__pycache__", "build", "dist", "venv"] [tool.ruff.lint] -select = ["B", "C4", "D", "DTZ", "E", "F", "G", "I", "FBT003", "PERF", "PGH", "PT", "PYI", "RUF", "S", "T20", "W"] +select = ["B", "C4", "D", "DTZ", "E", "F", "G", "I", "FBT003", "PERF", "PGH", "PT", "PYI", "RUF", "S", "T20", "UP", "W"] ignore = [ "D105", # most magic methods don't need docstrings as their purpose is always the same "E741", "E742", "E743", # do not use names ‘l’, ‘O’, or ‘I’; they are not a problem with a proper font - "UP038", # does not seem to work and leads to slower code + "UP038", # 
leads to slower code # Conflict with ruff format, see # https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules "COM812", "COM819", "D206", "D300", "E111", "E114", "E117", "ISC001", "ISC002", "Q000", "Q001", "Q002", "Q003", "W191", From f26be94856d36433aaa749e6d5b14640506d5185 Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Mon, 27 May 2024 12:20:24 +0200 Subject: [PATCH 4/5] Update type syntax --- src/scitacean/client.py | 2 +- tests/client/query_client_test.py | 41 ++++++++++++++++++------------- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/scitacean/client.py b/src/scitacean/client.py index ab456299..f6cbadad 100644 --- a/src/scitacean/client.py +++ b/src/scitacean/client.py @@ -786,7 +786,7 @@ def query_datasets( ) params = {"fields": params_model(**fields).model_dump_json()} - limits: dict[str, Union[str, int]] = {} + limits: dict[str, str | int] = {} if order is not None: limits["order"] = order if limit is not None: diff --git a/tests/client/query_client_test.py b/tests/client/query_client_test.py index 49a97950..243b25e0 100644 --- a/tests/client/query_client_test.py +++ b/tests/client/query_client_test.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean) -from typing import Union - import pytest from dateutil.parser import parse as parse_datetime @@ -10,9 +8,7 @@ from scitacean.testing.backend import skip_if_not_backend from scitacean.testing.backend.config import SciCatAccess -UPLOAD_DATASETS: dict[ - str, Union[model.UploadDerivedDataset, model.UploadRawDataset] -] = { +UPLOAD_DATASETS: dict[str, model.UploadDerivedDataset | model.UploadRawDataset] = { "raw1": model.UploadRawDataset( ownerGroup="PLACEHOLDER", accessGroups=["uu", "faculty"], @@ -107,7 +103,7 @@ @pytest.fixture(scope="module", autouse=True) -def seed_database(request: pytest.FixtureRequest, scicat_access: SciCatAccess) -> None: +def _seed_database(request: 
pytest.FixtureRequest, scicat_access: SciCatAccess) -> None: skip_if_not_backend(request) client = Client.from_credentials( @@ -120,19 +116,22 @@ def seed_database(request: pytest.FixtureRequest, scicat_access: SciCatAccess) - SEED[key] = client.scicat.create_dataset_model(dset) -def test_query_dataset_multiple_by_single_field(real_client, seed_database): +@pytest.mark.usefixtures("_seed_database") +def test_query_dataset_multiple_by_single_field(real_client): datasets = real_client.scicat.query_datasets({"proposalId": "p0124"}) actual = {ds.pid: ds for ds in datasets} expected = {SEED[key].pid: SEED[key] for key in ("raw1", "raw2", "raw3")} assert actual == expected -def test_query_dataset_no_match(real_client, seed_database): +@pytest.mark.usefixtures("_seed_database") +def test_query_dataset_no_match(real_client): datasets = real_client.scicat.query_datasets({"owner": "librarian"}) assert not datasets -def test_query_dataset_multiple_by_multiple_fields(real_client, seed_database): +@pytest.mark.usefixtures("_seed_database") +def test_query_dataset_multiple_by_multiple_fields(real_client): datasets = real_client.scicat.query_datasets( {"proposalId": "p0124", "principalInvestigator": "investigator 1"}, ) @@ -141,7 +140,8 @@ def test_query_dataset_multiple_by_multiple_fields(real_client, seed_database): assert actual == expected -def test_query_dataset_multiple_by_derived_field(real_client, seed_database): +@pytest.mark.usefixtures("_seed_database") +def test_query_dataset_multiple_by_derived_field(real_client): datasets = real_client.scicat.query_datasets( {"investigator": "investigator 1"}, ) @@ -150,14 +150,16 @@ def test_query_dataset_multiple_by_derived_field(real_client, seed_database): assert actual == expected -def test_query_dataset_uses_conjunction_of_fields(real_client, seed_database): +@pytest.mark.usefixtures("_seed_database") +def test_query_dataset_uses_conjunction_of_fields(real_client): datasets = real_client.scicat.query_datasets( {"proposalId": 
"p0124", "investigator": "investigator X"}, ) assert not datasets -def test_query_dataset_can_use_custom_type(real_client, seed_database): +@pytest.mark.usefixtures("_seed_database") +def test_query_dataset_can_use_custom_type(real_client): datasets = real_client.scicat.query_datasets( {"sourceFolder": RemotePath("/hex/raw4")}, ) @@ -165,7 +167,8 @@ def test_query_dataset_can_use_custom_type(real_client, seed_database): assert datasets == expected -def test_query_dataset_set_order(real_client, seed_database): +@pytest.mark.usefixtures("_seed_database") +def test_query_dataset_set_order(real_client): datasets = real_client.scicat.query_datasets( {"proposalId": "p0124"}, order="creationTime:desc", @@ -175,7 +178,8 @@ def test_query_dataset_set_order(real_client, seed_database): assert datasets == expected -def test_query_dataset_limit_ascending_creation_time(real_client, seed_database): +@pytest.mark.usefixtures("_seed_database") +def test_query_dataset_limit_ascending_creation_time(real_client): datasets = real_client.scicat.query_datasets( {"proposalId": "p0124"}, limit=2, @@ -186,7 +190,8 @@ def test_query_dataset_limit_ascending_creation_time(real_client, seed_database) assert actual == expected -def test_query_dataset_limit_descending_creation_time(real_client, seed_database): +@pytest.mark.usefixtures("_seed_database") +def test_query_dataset_limit_descending_creation_time(real_client): datasets = real_client.scicat.query_datasets( {"proposalId": "p0124"}, limit=2, @@ -197,7 +202,8 @@ def test_query_dataset_limit_descending_creation_time(real_client, seed_database assert actual == expected -def test_query_dataset_limit_needs_order(real_client, seed_database): +@pytest.mark.usefixtures("_seed_database") +def test_query_dataset_limit_needs_order(real_client): with pytest.raises(ValueError, match="limit"): real_client.scicat.query_datasets( {"proposalId": "p0124"}, @@ -205,7 +211,8 @@ def test_query_dataset_limit_needs_order(real_client, seed_database): ) -def 
test_query_dataset_all(real_client, seed_database): +@pytest.mark.usefixtures("_seed_database") +def test_query_dataset_all(real_client): datasets = real_client.scicat.query_datasets({}) actual = {ds.pid: ds for ds in datasets} # We cannot test `datasets` directly because there are other datasets From af912bb6d04fe1d1acdcb6d2b21906001834a3f6 Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Mon, 27 May 2024 12:29:01 +0200 Subject: [PATCH 5/5] Fix docstring --- src/scitacean/client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/scitacean/client.py b/src/scitacean/client.py index f6cbadad..c0bdcf26 100644 --- a/src/scitacean/client.py +++ b/src/scitacean/client.py @@ -768,7 +768,10 @@ def query_datasets( .. code-block:: python - scicat_client.query_datasets({'proposalId': 'abc.123', 'name': 'ds name'}) + scicat_client.query_datasets({ + 'proposalId': 'abc.123', + 'datasetName': 'ds name' + }) Return only the newest 5 datasets for proposal ``bc.123``: