Skip to content

Commit

Permalink
feat: allow generic auth headers with pystac + validators
Browse files Browse the repository at this point in the history
  • Loading branch information
john-dupuy authored and gadomski committed Apr 19, 2024
1 parent bf6d985 commit dda1a2d
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 9 deletions.
4 changes: 2 additions & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def mypy(session: Session) -> None:
"""Type-check using mypy."""
args = session.posargs or ["src", "tests", "docs/conf.py"]
session.install(".")
session.install("mypy", "pytest", "types-requests", "types-PyYAML")
session.install("mypy", "pytest", "types-requests", "types-PyYAML", "typeguard")
session.run("mypy", *args)
if not session.posargs:
session.run("mypy", f"--python-executable={sys.executable}", "noxfile.py")
Expand All @@ -162,7 +162,7 @@ def mypy(session: Session) -> None:
def tests(session: Session) -> None:
"""Run the test suite."""
session.install(".")
session.install("coverage[toml]", "pytest", "pygments")
session.install("coverage[toml]", "pytest", "pygments", "typeguard")
try:
session.run("coverage", "run", "--parallel", "-m", "pytest", *session.posargs)
finally:
Expand Down
17 changes: 17 additions & 0 deletions src/stac_api_validator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@
"--transaction-collection",
help="The name of the collection to use for Transaction Extension tests.",
)
@click.option(
"-H",
"--headers",
multiple=True,
help="Headers to attach to the main request and dependent pystac requests, curl syntax",
)
def main(
log_level: str,
root_url: str,
Expand All @@ -155,11 +161,21 @@ def main(
query_in_field: Optional[str] = None,
query_in_values: Optional[str] = None,
transaction_collection: Optional[str] = None,
headers: Optional[List[str]] = None,
) -> int:
"""STAC API Validator."""
logging.basicConfig(stream=sys.stdout, level=log_level)

try:
processed_headers = {}
if headers:
processed_headers.update(
{
key.strip(): value.strip()
for key, value in (header.split(":") for header in headers)
}
)

(warnings, errors) = validate_api(
root_url=root_url,
ccs_to_validate=conformance_classes,
Expand All @@ -185,6 +201,7 @@ def main(
query_in_values,
),
transaction_collection=transaction_collection,
headers=processed_headers,
)
except Exception as e:
click.secho(
Expand Down
30 changes: 23 additions & 7 deletions src/stac_api_validator/validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pystac import Collection
from pystac import Item
from pystac import ItemCollection
from pystac import StacIO
from pystac import STACValidationError
from pystac_client import Client
from requests import Request
Expand Down Expand Up @@ -298,6 +299,16 @@ def is_geojson_type(maybe_type: Optional[str]) -> bool:
)


def get_catalog(data_dict: Dict[str, Any], r_session: Session) -> Catalog:
stac_io = StacIO.default()
if r_session.headers and r_session.headers.get("Authorization"):
stac_io.headers = r_session.headers # noqa, type: ignore
stac_io.headers["Accept-Encoding"] = "*"
catalog = Catalog.from_dict(data_dict)
catalog._stac_io = stac_io
return catalog


# def is_json_or_geojson_type(maybe_type: Optional[str]) -> bool:
# return maybe_type and (is_json_type(maybe_type) or is_geojson_type(maybe_type))

Expand Down Expand Up @@ -381,9 +392,8 @@ def retrieve(
additional: Optional[str] = "",
content_type: Optional[str] = None,
) -> Tuple[int, Optional[Dict[str, Any]], Optional[Mapping[str, str]]]:
resp = r_session.send(
Request(method.value, url, headers=headers, params=params, json=body).prepare()
)
request = Request(method.value, url, headers=headers, params=params, json=body)
resp = r_session.send(r_session.prepare_request(request))

# todo: handle connection exception, etc.
# todo: handle timeout
Expand Down Expand Up @@ -537,6 +547,7 @@ def validate_api(
validate_pagination: bool,
query_config: QueryConfig,
transaction_collection: Optional[str],
headers: Optional[Dict[str, str]],
) -> Tuple[Warnings, Errors]:
warnings = Warnings()
errors = Errors()
Expand All @@ -548,6 +559,9 @@ def validate_api(
if auth_query_parameter and (xs := auth_query_parameter.split("=", 1)):
r_session.params = {xs[0]: xs[1]}

if headers:
r_session.headers.update(headers)

_, landing_page_body, landing_page_headers = retrieve(
Method.GET, root_url, errors, Context.CORE, r_session
)
Expand Down Expand Up @@ -704,7 +718,7 @@ def validate_api(

if not errors:
try:
catalog = Client.open(root_url)
catalog = Client.open(root_url, headers=headers)
catalog.validate()
for child in catalog.get_children():
child.validate()
Expand Down Expand Up @@ -811,7 +825,8 @@ def validate_core(
# this validates, among other things, that the child and item link relations reference
# valid STAC Catalogs, Collections, and/or Items
try:
list(take(1000, Catalog.from_dict(root_body).get_all_items()))
catalog = get_catalog(root_body, r_session)
list(take(1000, catalog.get_all_items()))
except pystac.errors.STACTypeError as e:
errors += (
f"[{Context.CORE}] Error while traversing Catalog child/item links to find Items: {e} "
Expand Down Expand Up @@ -839,14 +854,15 @@ def validate_browseable(
# check that at least a few of the items that can be reached from child/item link relations
# can be found through search
try:
for item in take(10, Catalog.from_dict(root_body).get_all_items()):
catalog = get_catalog(root_body, r_session)
for item in take(10, catalog.get_all_items()):
if link := link_by_rel(root_body.get("links"), "search"):
_, body, _ = retrieve(
Method.GET,
link["href"],
errors,
Context.BROWSEABLE,
params={"ids": item.id, "collections": item.collection},
params={"ids": item.id, "collections": item.collection_id},
r_session=r_session,
)
if body and len(body.get("features", [])) != 1:
Expand Down
53 changes: 53 additions & 0 deletions tests/resources/sample-item.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
{
"type": "Feature",
"stac_version": "1.0.0",
"id": "CS3-20160503_132131_05",
"properties": {
"datetime": "2016-05-03T13:22:30.040000Z",
"title": "A CS3 item",
"license": "PDDL-1.0",
"providers": [
{
"name": "CoolSat",
"roles": ["producer", "licensor"],
"url": "https://cool-sat.com/"
}
]
},
"geometry": {
"type": "Polygon",
"coordinates": [
[
[-122.308150179, 37.488035566],
[-122.597502109, 37.538869539],
[-122.576687533, 37.613537207],
[-122.2880486, 37.562818007],
[-122.308150179, 37.488035566]
]
]
},
"links": [
{
"rel": "collection",
"href": "https://raw.githubusercontent.com/radiantearth/stac-spec/v0.8.1/collection-spec/examples/sentinel2.json"
}
],
"assets": {
"analytic": {
"href": "http://cool-sat.com/catalog/CS3-20160503_132130_04/analytic.tif",
"title": "4-Band Analytic",
"product": "http://cool-sat.com/catalog/products/analytic.json",
"type": "image/tiff; application=geotiff; profile=cloud-optimized",
"roles": ["data", "analytic"]
},
"thumbnail": {
"href": "http://cool-sat.com/catalog/CS3-20160503_132130_04/thumbnail.png",
"title": "Thumbnail",
"type": "image/png",
"roles": ["thumbnail"]
}
},
"bbox": [-122.59750209, 37.48803556, -122.2880486, 37.613537207],
"stac_extensions": [],
"collection": "CS3"
}
36 changes: 36 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test cases for the __main__ module."""
import unittest.mock

import pytest
from click.testing import CliRunner
Expand All @@ -15,3 +16,38 @@ def runner() -> CliRunner:
def test_main_fails(runner: CliRunner) -> None:
result = runner.invoke(__main__.main)
assert result.exit_code == 2


def test_retrieve_called_with_auth_headers(
request: pytest.FixtureRequest, runner: CliRunner
) -> None:
if request.config.getoption("typeguard_packages"):
pytest.skip(
"The import hook that typeguard uses seems to break the mock below."
)

expected_headers = {
"User-Agent": "python-requests/2.28.2",
"Accept-Encoding": "gzip, deflate",
"Accept": "*/*",
"Connection": "keep-alive",
"Authorization": "api-key fake-api-key-value",
}

with unittest.mock.patch(
"stac_api_validator.validations.retrieve"
) as retrieve_mock:
runner.invoke(
__main__.main,
args=[
"--root-url",
"https://invalid",
"--conformance",
"core",
"-H",
"Authorization: api-key fake-api-key-value",
],
)
assert retrieve_mock.call_count == 1
r_session = retrieve_mock.call_args.args[-1]
assert r_session.headers == expected_headers
Loading

0 comments on commit dda1a2d

Please sign in to comment.