Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for generic auth headers with pystac + validators #392

Merged
merged 8 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def mypy(session: Session) -> None:
"""Type-check using mypy."""
args = session.posargs or ["src", "tests", "docs/conf.py"]
session.install(".")
session.install("mypy", "pytest", "types-requests", "types-PyYAML")
session.install("mypy", "pytest", "types-requests", "types-PyYAML", "typeguard")
session.run("mypy", *args)
if not session.posargs:
session.run("mypy", f"--python-executable={sys.executable}", "noxfile.py")
Expand All @@ -162,7 +162,7 @@ def mypy(session: Session) -> None:
def tests(session: Session) -> None:
"""Run the test suite."""
session.install(".")
session.install("coverage[toml]", "pytest", "pygments")
session.install("coverage[toml]", "pytest", "pygments", "typeguard")
try:
session.run("coverage", "run", "--parallel", "-m", "pytest", *session.posargs)
finally:
Expand Down
17 changes: 17 additions & 0 deletions src/stac_api_validator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@
"--transaction-collection",
help="The name of the collection to use for Transaction Extension tests.",
)
@click.option(
"-H",
"--headers",
multiple=True,
help="Headers to attach to the main request and dependent pystac requests, curl syntax",
)
def main(
log_level: str,
root_url: str,
Expand All @@ -155,11 +161,21 @@ def main(
query_in_field: Optional[str] = None,
query_in_values: Optional[str] = None,
transaction_collection: Optional[str] = None,
headers: Optional[List[str]] = None,
) -> int:
"""STAC API Validator."""
logging.basicConfig(stream=sys.stdout, level=log_level)

try:
processed_headers = {}
if headers:
processed_headers.update(
{
key.strip(): value.strip()
for key, value in (header.split(":") for header in headers)
}
)

(warnings, errors) = validate_api(
root_url=root_url,
ccs_to_validate=conformance_classes,
Expand All @@ -185,6 +201,7 @@ def main(
query_in_values,
),
transaction_collection=transaction_collection,
headers=processed_headers,
)
except Exception as e:
click.secho(
Expand Down
30 changes: 23 additions & 7 deletions src/stac_api_validator/validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pystac import Collection
from pystac import Item
from pystac import ItemCollection
from pystac import StacIO
from pystac import STACValidationError
from pystac_client import Client
from requests import Request
Expand Down Expand Up @@ -298,6 +299,16 @@ def is_geojson_type(maybe_type: Optional[str]) -> bool:
)


def get_catalog(data_dict: Dict[str, Any], r_session: Session) -> Catalog:
stac_io = StacIO.default()
if r_session.headers and r_session.headers.get("Authorization"):
stac_io.headers = r_session.headers # noqa, type: ignore
stac_io.headers["Accept-Encoding"] = "*"
catalog = Catalog.from_dict(data_dict)
catalog._stac_io = stac_io
return catalog


# def is_json_or_geojson_type(maybe_type: Optional[str]) -> bool:
# return maybe_type and (is_json_type(maybe_type) or is_geojson_type(maybe_type))

Expand Down Expand Up @@ -381,9 +392,8 @@ def retrieve(
additional: Optional[str] = "",
content_type: Optional[str] = None,
) -> Tuple[int, Optional[Dict[str, Any]], Optional[Mapping[str, str]]]:
resp = r_session.send(
Request(method.value, url, headers=headers, params=params, json=body).prepare()
)
request = Request(method.value, url, headers=headers, params=params, json=body)
resp = r_session.send(r_session.prepare_request(request))

# todo: handle connection exception, etc.
# todo: handle timeout
Expand Down Expand Up @@ -537,6 +547,7 @@ def validate_api(
validate_pagination: bool,
query_config: QueryConfig,
transaction_collection: Optional[str],
headers: Optional[Dict[str, str]],
) -> Tuple[Warnings, Errors]:
warnings = Warnings()
errors = Errors()
Expand All @@ -548,6 +559,9 @@ def validate_api(
if auth_query_parameter and (xs := auth_query_parameter.split("=", 1)):
r_session.params = {xs[0]: xs[1]}

if headers:
r_session.headers.update(headers)

_, landing_page_body, landing_page_headers = retrieve(
Method.GET, root_url, errors, Context.CORE, r_session
)
Expand Down Expand Up @@ -704,7 +718,7 @@ def validate_api(

if not errors:
try:
catalog = Client.open(root_url)
catalog = Client.open(root_url, headers=headers)
catalog.validate()
for child in catalog.get_children():
child.validate()
Expand Down Expand Up @@ -811,7 +825,8 @@ def validate_core(
# this validates, among other things, that the child and item link relations reference
# valid STAC Catalogs, Collections, and/or Items
try:
list(take(1000, Catalog.from_dict(root_body).get_all_items()))
catalog = get_catalog(root_body, r_session)
list(take(1000, catalog.get_all_items()))
except pystac.errors.STACTypeError as e:
errors += (
f"[{Context.CORE}] Error while traversing Catalog child/item links to find Items: {e} "
Expand Down Expand Up @@ -839,14 +854,15 @@ def validate_browseable(
# check that at least a few of the items that can be reached from child/item link relations
# can be found through search
try:
for item in take(10, Catalog.from_dict(root_body).get_all_items()):
catalog = get_catalog(root_body, r_session)
for item in take(10, catalog.get_all_items()):
if link := link_by_rel(root_body.get("links"), "search"):
_, body, _ = retrieve(
Method.GET,
link["href"],
errors,
Context.BROWSEABLE,
params={"ids": item.id, "collections": item.collection},
params={"ids": item.id, "collections": item.collection_id},
r_session=r_session,
)
if body and len(body.get("features", [])) != 1:
Expand Down
53 changes: 53 additions & 0 deletions tests/resources/sample-item.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
{
"type": "Feature",
"stac_version": "1.0.0",
"id": "CS3-20160503_132131_05",
"properties": {
"datetime": "2016-05-03T13:22:30.040000Z",
"title": "A CS3 item",
"license": "PDDL-1.0",
"providers": [
{
"name": "CoolSat",
"roles": ["producer", "licensor"],
"url": "https://cool-sat.com/"
}
]
},
"geometry": {
"type": "Polygon",
"coordinates": [
[
[-122.308150179, 37.488035566],
[-122.597502109, 37.538869539],
[-122.576687533, 37.613537207],
[-122.2880486, 37.562818007],
[-122.308150179, 37.488035566]
]
]
},
"links": [
{
"rel": "collection",
"href": "https://raw.githubusercontent.com/radiantearth/stac-spec/v0.8.1/collection-spec/examples/sentinel2.json"
}
],
"assets": {
"analytic": {
"href": "http://cool-sat.com/catalog/CS3-20160503_132130_04/analytic.tif",
"title": "4-Band Analytic",
"product": "http://cool-sat.com/catalog/products/analytic.json",
"type": "image/tiff; application=geotiff; profile=cloud-optimized",
"roles": ["data", "analytic"]
},
"thumbnail": {
"href": "http://cool-sat.com/catalog/CS3-20160503_132130_04/thumbnail.png",
"title": "Thumbnail",
"type": "image/png",
"roles": ["thumbnail"]
}
},
"bbox": [-122.59750209, 37.48803556, -122.2880486, 37.613537207],
"stac_extensions": [],
"collection": "CS3"
}
37 changes: 37 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test cases for the __main__ module."""

import unittest.mock

import pytest
from click.testing import CliRunner

Expand All @@ -15,3 +17,38 @@ def runner() -> CliRunner:
def test_main_fails(runner: CliRunner) -> None:
result = runner.invoke(__main__.main)
assert result.exit_code == 2


def test_retrieve_called_with_auth_headers(
request: pytest.FixtureRequest, runner: CliRunner
) -> None:
if request.config.getoption("typeguard_packages"):
pytest.skip(
"The import hook that typeguard uses seems to break the mock below."
)

expected_headers = {
"User-Agent": "python-requests/2.32.3",
gadomski marked this conversation as resolved.
Show resolved Hide resolved
"Accept-Encoding": "gzip, deflate",
"Accept": "*/*",
"Connection": "keep-alive",
"Authorization": "api-key fake-api-key-value",
}

with unittest.mock.patch(
"stac_api_validator.validations.retrieve"
) as retrieve_mock:
runner.invoke(
__main__.main,
args=[
"--root-url",
"https://invalid",
"--conformance",
"core",
"-H",
"Authorization: api-key fake-api-key-value",
],
)
assert retrieve_mock.call_count == 1
r_session = retrieve_mock.call_args.args[-1]
assert r_session.headers == expected_headers
Loading
Loading