Skip to content

Commit

Permalink
Merge pull request #201 from SciCatProject/dataset-query
Browse files Browse the repository at this point in the history
Add experimental ScicatClient.query_datasets
  • Loading branch information
jl-wynen committed May 29, 2024
2 parents 1b554e0 + af912bb commit 1b46daa
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 5 deletions.
2 changes: 2 additions & 0 deletions docs/release-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ Security
Features
~~~~~~~~

* Added experimental :meth:`client.ScicatClient.query_datasets` for querying datasets by field.

Breaking changes
~~~~~~~~~~~~~~~~

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ extend-include = ["*.ipynb"]
extend-exclude = [".*", "__pycache__", "build", "dist", "venv"]

[tool.ruff.lint]
select = ["B", "C4", "D", "DTZ", "E", "F", "G", "I", "FBT003", "PERF", "PGH", "PT", "PYI", "RUF", "S", "T20", "W"]
select = ["B", "C4", "D", "DTZ", "E", "F", "G", "I", "FBT003", "PERF", "PGH", "PT", "PYI", "RUF", "S", "T20", "UP", "W"]
ignore = [
"D105", # most magic methods don't need docstrings as their purpose is always the same
"E741", "E742", "E743", # do not use names ‘l’, ‘O’, or ‘I’; they are not a problem with a proper font
"UP038", # does not seem to work and leads to slower code
"UP038", # leads to slower code
# Conflict with ruff format, see
# https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules
"COM812", "COM819", "D206", "D300", "E111", "E114", "E117", "ISC001", "ISC002", "Q000", "Q001", "Q002", "Q003", "W191",
Expand Down
121 changes: 118 additions & 3 deletions src/scitacean/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import dataclasses
import datetime
import json
import re
import warnings
from collections.abc import Callable, Iterable, Iterator
Expand All @@ -15,6 +16,7 @@
from urllib.parse import quote_plus

import httpx
import pydantic

from . import model
from ._base_model import convert_download_to_user_model
Expand Down Expand Up @@ -708,6 +710,112 @@ def get_dataset_model(
**dset_json,
)

def query_datasets(
self,
fields: dict[str, Any],
*,
limit: int | None = None,
order: str | None = None,
strict_validation: bool = False,
) -> list[model.DownloadDataset]:
"""Query for datasets in SciCat.
Attention
---------
This function is experimental and may change or be removed in the future.
It is currently unclear how best to implement querying because SciCat
provides multiple, very different APIs and there are plans for supporting
queries via Mongo query language directly.
See `issue #177 <https://github.com/SciCatProject/scitacean/issues/177>`_
for a discussion.
Parameters
----------
fields:
Fields to query for.
Returned datasets must match all fields exactly.
See examples below.
limit:
Maximum number of results to return.
Requires ``order`` to be specified.
If not given, all matching datasets are returned.
order:
Specify order of results.
For example, ``"creationTime:asc"`` and ``"creationTime:desc"`` return
results in ascending or descending order in creation time, respectively.
strict_validation:
If ``True``, the datasets must pass validation.
If ``False``, datasets are still returned if validation fails.
Note that some dataset fields may have a bad value or type.
A warning will be logged if validation fails.
Returns
-------
:
A list of dataset models that match the query.
Examples
--------
Get all datasets belonging to proposal ``abc.123``:
.. code-block:: python
scicat_client.query_datasets({'proposalId': 'abc.123'})
Get all datasets that belong to proposal ``abc.123``
**and** have name ``"ds name"``: (The name and proposal must match exactly.)
.. code-block:: python
scicat_client.query_datasets({
'proposalId': 'abc.123',
'datasetName': 'ds name'
})
Return only the newest 5 datasets for proposal ``bc.123``:
.. code-block:: python
scicat_client.query_datasets(
{'proposalId': 'bc.123'},
limit=5,
order="creationTime:desc",
)
"""
# Use a pydantic model to support serializing custom types to JSON.
params_model = pydantic.create_model( # type: ignore[call-overload]
"QueryParams", **{key: (type(field), ...) for key, field in fields.items()}
)
params = {"fields": params_model(**fields).model_dump_json()}

limits: dict[str, str | int] = {}
if order is not None:
limits["order"] = order
if limit is not None:
if order is None:
raise ValueError("`order` is required when `limit` is specified.")
limits["limit"] = limit
if limits:
params["limits"] = json.dumps(limits)

dsets_json = self._call_endpoint(
cmd="get",
url="datasets/fullquery",
params=params,
operation="query_datasets",
)
if not dsets_json:
return []
return [
model.construct(
model.DownloadDataset,
_strict_validation=strict_validation,
**dset_json,
)
for dset_json in dsets_json
]

def get_orig_datablocks(
self, pid: PID, strict_validation: bool = False
) -> list[model.DownloadOrigDatablock]:
Expand Down Expand Up @@ -1010,7 +1118,12 @@ def validate_dataset_model(
raise ValueError(f"Dataset {dset} did not pass validation in SciCat.")

def _send_to_scicat(
self, *, cmd: str, url: str, data: model.BaseModel | None = None
self,
*,
cmd: str,
url: str,
data: model.BaseModel | None = None,
params: dict[str, str] | None = None,
) -> httpx.Response:
if self._token is not None:
token = self._token.get_str()
Expand All @@ -1029,6 +1142,7 @@ def _send_to_scicat(
content=data.model_dump_json(exclude_none=True)
if data is not None
else None,
params=params,
headers=headers,
timeout=self._timeout.seconds,
)
Expand All @@ -1047,14 +1161,15 @@ def _call_endpoint(
*,
cmd: str,
url: str,
data: model.BaseModel | None = None,
operation: str,
data: model.BaseModel | None = None,
params: dict[str, str] | None = None,
) -> Any:
full_url = _url_concat(self._base_url, url)
logger = get_logger()
logger.info("Calling SciCat API at %s for operation '%s'", full_url, operation)

response = self._send_to_scicat(cmd=cmd, url=full_url, data=data)
response = self._send_to_scicat(cmd=cmd, url=full_url, data=data, params=params)
if not response.is_success:
logger.error(
"SciCat API call to %s failed: %s %s: %s",
Expand Down
Loading

0 comments on commit 1b46daa

Please sign in to comment.