diff --git a/noxfile.py b/noxfile.py index e45e00f..5d3e9c6 100644 --- a/noxfile.py +++ b/noxfile.py @@ -152,7 +152,7 @@ def mypy(session: Session) -> None: """Type-check using mypy.""" args = session.posargs or ["src", "tests", "docs/conf.py"] session.install(".") - session.install("mypy", "pytest", "types-requests", "types-PyYAML") + session.install("mypy", "pytest", "types-requests", "types-PyYAML", "typeguard") session.run("mypy", *args) if not session.posargs: session.run("mypy", f"--python-executable={sys.executable}", "noxfile.py") @@ -162,7 +162,7 @@ def mypy(session: Session) -> None: def tests(session: Session) -> None: """Run the test suite.""" session.install(".") - session.install("coverage[toml]", "pytest", "pygments") + session.install("coverage[toml]", "pytest", "pygments", "typeguard") try: session.run("coverage", "run", "--parallel", "-m", "pytest", *session.posargs) finally: diff --git a/src/stac_api_validator/__main__.py b/src/stac_api_validator/__main__.py index be0d74b..db983c7 100644 --- a/src/stac_api_validator/__main__.py +++ b/src/stac_api_validator/__main__.py @@ -131,6 +131,12 @@ "--transaction-collection", help="The name of the collection to use for Transaction Extension tests.", ) +@click.option( + "-H", + "--headers", + multiple=True, + help="Headers to attach to the main request and dependent pystac requests, curl syntax", +) def main( log_level: str, root_url: str, @@ -155,11 +161,21 @@ def main( query_in_field: Optional[str] = None, query_in_values: Optional[str] = None, transaction_collection: Optional[str] = None, + headers: Optional[List[str]] = None, ) -> int: """STAC API Validator.""" logging.basicConfig(stream=sys.stdout, level=log_level) try: + processed_headers = {} + if headers: + processed_headers.update( + { + key.strip(): value.strip() + for key, value in (header.split(":") for header in headers) + } + ) + (warnings, errors) = validate_api( root_url=root_url, ccs_to_validate=conformance_classes, @@ -185,6 +201,7 @@ def main( query_in_values, ), transaction_collection=transaction_collection, + headers=processed_headers, ) except Exception as e: click.secho( diff --git a/src/stac_api_validator/validations.py b/src/stac_api_validator/validations.py index 4c2f429..2ddb15e 100644 --- a/src/stac_api_validator/validations.py +++ b/src/stac_api_validator/validations.py @@ -27,6 +27,7 @@ from pystac import Collection from pystac import Item from pystac import ItemCollection +from pystac import StacIO from pystac import STACValidationError from pystac_client import Client from requests import Request @@ -298,6 +299,16 @@ def is_geojson_type(maybe_type: Optional[str]) -> bool: ) +def get_catalog(data_dict: Dict[str, Any], r_session: Session) -> Catalog: + stac_io = StacIO.default() + if r_session.headers and r_session.headers.get("Authorization"): + stac_io.headers = r_session.headers # noqa, type: ignore + stac_io.headers["Accept-Encoding"] = "*" + catalog = Catalog.from_dict(data_dict) + catalog._stac_io = stac_io + return catalog + + # def is_json_or_geojson_type(maybe_type: Optional[str]) -> bool: # return maybe_type and (is_json_type(maybe_type) or is_geojson_type(maybe_type)) @@ -381,9 +392,8 @@ def retrieve( additional: Optional[str] = "", content_type: Optional[str] = None, ) -> Tuple[int, Optional[Dict[str, Any]], Optional[Mapping[str, str]]]: - resp = r_session.send( - Request(method.value, url, headers=headers, params=params, json=body).prepare() - ) + request = Request(method.value, url, headers=headers, params=params, json=body) + resp = r_session.send(r_session.prepare_request(request)) # todo: handle connection exception, etc. # todo: handle timeout @@ -537,6 +547,7 @@ def validate_api( validate_pagination: bool, query_config: QueryConfig, transaction_collection: Optional[str], + headers: Optional[Dict[str, str]], ) -> Tuple[Warnings, Errors]: warnings = Warnings() errors = Errors() @@ -548,6 +559,9 @@ def validate_api( if auth_query_parameter and (xs := auth_query_parameter.split("=", 1)): r_session.params = {xs[0]: xs[1]} + if headers: + r_session.headers.update(headers) + _, landing_page_body, landing_page_headers = retrieve( Method.GET, root_url, errors, Context.CORE, r_session ) @@ -704,7 +718,7 @@ def validate_api( if not errors: try: - catalog = Client.open(root_url) + catalog = Client.open(root_url, headers=headers) catalog.validate() for child in catalog.get_children(): child.validate() @@ -811,7 +825,8 @@ def validate_core( # this validates, among other things, that the child and item link relations reference # valid STAC Catalogs, Collections, and/or Items try: - list(take(1000, Catalog.from_dict(root_body).get_all_items())) + catalog = get_catalog(root_body, r_session) + list(take(1000, catalog.get_all_items())) except pystac.errors.STACTypeError as e: errors += ( f"[{Context.CORE}] Error while traversing Catalog child/item links to find Items: {e} " @@ -839,14 +854,15 @@ def validate_browseable( # check that at least a few of the items that can be reached from child/item link relations # can be found through search try: - for item in take(10, Catalog.from_dict(root_body).get_all_items()): + catalog = get_catalog(root_body, r_session) + for item in take(10, catalog.get_all_items()): if link := link_by_rel(root_body.get("links"), "search"): _, body, _ = retrieve( Method.GET, link["href"], errors, Context.BROWSEABLE, - params={"ids": item.id, "collections": item.collection}, + params={"ids": item.id, "collections": item.collection_id}, r_session=r_session, ) if body and len(body.get("features", [])) != 1: diff --git a/tests/resources/sample-item.json b/tests/resources/sample-item.json new file mode 100644 index 0000000..3e4bafe --- /dev/null +++ b/tests/resources/sample-item.json @@ -0,0 +1,53 @@ +{ + "type": "Feature", + "stac_version": "1.0.0", + "id": "CS3-20160503_132131_05", + "properties": { + "datetime": "2016-05-03T13:22:30.040000Z", + "title": "A CS3 item", + "license": "PDDL-1.0", + "providers": [ + { + "name": "CoolSat", + "roles": ["producer", "licensor"], + "url": "https://cool-sat.com/" + } + ] + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [-122.308150179, 37.488035566], + [-122.597502109, 37.538869539], + [-122.576687533, 37.613537207], + [-122.2880486, 37.562818007], + [-122.308150179, 37.488035566] + ] + ] + }, + "links": [ + { + "rel": "collection", + "href": "https://raw.githubusercontent.com/radiantearth/stac-spec/v0.8.1/collection-spec/examples/sentinel2.json" + } + ], + "assets": { + "analytic": { + "href": "http://cool-sat.com/catalog/CS3-20160503_132130_04/analytic.tif", + "title": "4-Band Analytic", + "product": "http://cool-sat.com/catalog/products/analytic.json", + "type": "image/tiff; application=geotiff; profile=cloud-optimized", + "roles": ["data", "analytic"] + }, + "thumbnail": { + "href": "http://cool-sat.com/catalog/CS3-20160503_132130_04/thumbnail.png", + "title": "Thumbnail", + "type": "image/png", + "roles": ["thumbnail"] + } + }, + "bbox": [-122.59750209, 37.48803556, -122.2880486, 37.613537207], + "stac_extensions": [], + "collection": "CS3" +} diff --git a/tests/test_main.py b/tests/test_main.py index f21a680..bed1eb0 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,5 @@ """Test cases for the __main__ module.""" +import unittest.mock import pytest from click.testing import CliRunner @@ -15,3 +16,38 @@ def runner() -> CliRunner: def test_main_fails(runner: CliRunner) -> None: result = runner.invoke(__main__.main) assert result.exit_code == 2 + + +def test_retrieve_called_with_auth_headers( + request: pytest.FixtureRequest, runner: CliRunner +) -> None: + if request.config.getoption("typeguard_packages"): + pytest.skip( + "The import hook that typeguard uses seems to break the mock below." + ) + + expected_headers = { + "User-Agent": "python-requests/2.28.2", + "Accept-Encoding": "gzip, deflate", + "Accept": "*/*", + "Connection": "keep-alive", + "Authorization": "api-key fake-api-key-value", + } + + with unittest.mock.patch( + "stac_api_validator.validations.retrieve" + ) as retrieve_mock: + runner.invoke( + __main__.main, + args=[ + "--root-url", + "https://invalid", + "--conformance", + "core", + "-H", + "Authorization: api-key fake-api-key-value", + ], + ) + assert retrieve_mock.call_count == 1 + r_session = retrieve_mock.call_args.args[-1] + assert r_session.headers == expected_headers diff --git a/tests/test_validations.py b/tests/test_validations.py new file mode 100644 index 0000000..37ddfb7 --- /dev/null +++ b/tests/test_validations.py @@ -0,0 +1,149 @@ +""" +Test cases for the 'validations' module +""" +import json +import os +import pathlib +import unittest.mock +from copy import copy +from typing import Dict +from typing import Generator + +import pystac +import pytest +import requests + +from stac_api_validator import validations + + +@pytest.fixture +def r_session() -> Generator[requests.Session, None, None]: + yield requests.Session() + + +@pytest.fixture +def catalog_dict() -> Generator[Dict[str, str], None, None]: + current_path = pathlib.Path(os.path.dirname(os.path.abspath(__file__))) + + with open(current_path / "resources" / "landing_page.json") as f: + # Load the contents of the file into a Python dictionary + data = json.load(f) + + yield data + + +@pytest.fixture +def sample_item() -> Generator[pystac.Item, None, None]: + current_path = pathlib.Path(os.path.dirname(os.path.abspath(__file__))) + + with open(current_path / "resources" / "sample-item.json") as f: + # Load the contents of the file into a Python dictionary + data = json.load(f) + + yield pystac.Item.from_dict(data) + + +@pytest.fixture +def expected_headers() -> Generator[Dict[str, str], None, None]: + yield { + "User-Agent": "python-requests/2.28.2", + "Accept-Encoding": "gzip, deflate", + "Accept": "*/*", + "Connection": "keep-alive", + "Authorization": "api-key fake-api-key-value", + } + + +def test_get_catalog( + r_session: requests.Session, + catalog_dict: Dict[str, str], + expected_headers: Dict[str, str], +) -> None: + r_session.headers = copy(expected_headers) # type: ignore + expected_headers.update({"Accept-Encoding": "*"}) + + catalog = validations.get_catalog(catalog_dict, r_session) + assert catalog._stac_io.headers == expected_headers # type: ignore + + +def test_retrieve( + r_session: requests.Session, expected_headers: Dict[str, str] +) -> None: + headers = {"Authorization": "api-key fake-api-key-value"} + r_session.send = unittest.mock.MagicMock() # type: ignore + r_session.send.status_code = 500 + + validations.retrieve( + validations.Method.GET, + "https://invalid", + validations.Errors(), + validations.Context.CORE, + r_session=r_session, + headers=headers, + ) + assert r_session.send.call_count == 1 + prepared_request_headers = r_session.send.call_args_list[0].args[0].headers + assert prepared_request_headers == expected_headers + + +def test_validate_api( + request: pytest.FixtureRequest, + r_session: requests.Session, + expected_headers: Dict[str, str], +) -> None: + if request.config.getoption("typeguard_packages"): + pytest.skip( + "The import hook that typeguard uses seems to break the mock below." + ) + headers = {"Authorization": "api-key fake-api-key-value"} + + with unittest.mock.patch( + "stac_api_validator.validations.retrieve" + ) as retrieve_mock: + retrieve_mock.return_value = None, None, None + validations.validate_api( + "https://invalid", + ccs_to_validate=["core"], + collection=None, + geometry=None, + auth_bearer_token=None, + auth_query_parameter=None, + fields_nested_property=None, + validate_pagination=None, + query_config=None, + transaction_collection=None, + headers=headers, + ) + assert retrieve_mock.call_count == 1 + r_session = retrieve_mock.call_args.args[-1] + assert r_session.headers == expected_headers + + +def test_validate_browseable( + request: pytest.FixtureRequest, + r_session: requests.Session, + catalog_dict: Dict[str, str], + sample_item: pystac.Item, + expected_headers: Dict[str, str], +) -> None: + if request.config.getoption("typeguard_packages"): + pytest.skip( + "The import hook that typeguard uses seems to break the mock below." + ) + + r_session.headers = copy(expected_headers) # type: ignore + + with unittest.mock.patch( + "stac_api_validator.validations.get_catalog" + ) as get_catalog_mock: + get_catalog_mock.get_all_items.return_value = [sample_item] + + validations.validate_browseable( + catalog_dict, + errors=validations.Errors(), + warnings=validations.Warnings(), + r_session=r_session, + ) + assert get_catalog_mock.call_count == 1 + session_from_mock = get_catalog_mock.call_args.args[-1] + assert session_from_mock.headers == expected_headers