Merged
15 changes: 10 additions & 5 deletions src/stac_fastapi/indexed/core.py
@@ -17,7 +17,7 @@
from stac_pydantic.shared import BBox

from stac_fastapi.indexed.constants import rel_parent, rel_root, rel_self
from stac_fastapi.indexed.db import fetchall, fetchone
from stac_fastapi.indexed.db import fetchall, fetchone, format_query_object_name
from stac_fastapi.indexed.links.catalog import get_catalog_link
from stac_fastapi.indexed.links.collection import (
fix_collection_links,
@@ -53,7 +53,7 @@ async def get_collection(
self, collection_id: str, request: Request, **kwargs
) -> Collection:
row = fetchone(
"SELECT stac_location FROM collections WHERE id = ?",
f"SELECT stac_location FROM {format_query_object_name('collections')} WHERE id = ?",
[collection_id],
)
if row is not None:
@@ -102,7 +102,7 @@ async def get_item(
collection_id, request=request
) # will error if collection does not exist
row = fetchone(
"SELECT stac_location, applied_fixes FROM items WHERE collection_id = ? and id = ?",
f"SELECT stac_location, applied_fixes FROM {format_query_object_name('items')} WHERE collection_id = ? and id = ?",
[collection_id, item_id],
)
if row is not None:
@@ -200,7 +200,10 @@ def _get_minimal_collections_response(self) -> Collections:
collections=[
Collection(**{"id": id})
for id in [
row[0] for row in fetchall("SELECT id FROM collections ORDER BY id")
row[0]
for row in fetchall(
f"SELECT id FROM {format_query_object_name('collections')} ORDER BY id"
)
]
],
links=[],
@@ -211,7 +214,9 @@ async def _get_full_collections_response(self, request: Request) -> Collections:
fetch_dict(url)
for url in [
row[0]
for row in fetchall("SELECT stac_location FROM collections ORDER BY id")
for row in fetchall(
f"SELECT stac_location FROM {format_query_object_name('collections')} ORDER BY id"
)
]
]
collections = [
61 changes: 48 additions & 13 deletions src/stac_fastapi/indexed/db.py
@@ -1,7 +1,8 @@
import re
from logging import Logger, getLogger
from os import environ
from time import time
from typing import Any, Final, List, Optional
from typing import Any, Dict, Final, List, Optional

from duckdb import DuckDBPyConnection
from duckdb import connect as duckdb_connect
@@ -11,7 +12,14 @@

_logger: Final[Logger] = getLogger(__name__)
_query_timing_precision: Final[int] = 3
_query_object_identifier_prefix: Final[str] = "src:"
_query_object_identifier_suffix: Final[str] = ":src"
_query_object_identifier_template: Final[str] = (
f"{_query_object_identifier_prefix}{{}}{_query_object_identifier_suffix}"
)

_root_db_connection: DuckDBPyConnection = None
_parquet_uris: Dict[str, str] = {}


async def connect_to_db() -> None:
@@ -40,22 +48,21 @@ async def connect_to_db() -> None:
# Dockerfiles pre-install extensions, so they don't need installing here.
# Local debugging (e.g. running in VS Code) still requires this install.
execute("INSTALL spatial")
times["install spatial extension"]
times["install spatial extension"] = time()
execute("INSTALL httpfs")
times["install httpfs extension"]
times["install httpfs extension"] = time()
execute("LOAD spatial")
times["load spatial extension"] = time()
execute("LOAD httpfs")
times["load httpfs extension"] = time()
duckdb_thread_count = settings.duckdb_threads
if duckdb_thread_count:
_set_duckdb_threads(duckdb_thread_count)
parquet_uris = await index_source.get_parquet_uris()
if len(parquet_uris.keys()) == 0:
global _parquet_uris
_parquet_uris = await index_source.get_parquet_uris()
times["get parquet URIs"] = time()
if len(_parquet_uris.keys()) == 0:
raise Exception(f"no URIs found from '{index_manifest_uri}'")
for view_name, source_uri in parquet_uris.items():
execute(f"CREATE VIEW {view_name} AS SELECT * FROM '{source_uri}'")
times["create views from parquet"] = time()
for operation, completed_at in times.items():
_logger.info(
"'{}' completed in {}s".format(
@@ -73,30 +80,58 @@ async def disconnect_from_db() -> None:
_logger.error(e)


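A note on the times fix above: the removed lines such as times["install spatial extension"] lacked the = time() assignment, so those steps were never recorded (depending on how times is constructed, which is elided in this diff, a bare lookup on a missing key would in fact raise KeyError). A minimal sketch of the recording pattern, assuming per-step durations are derived from consecutive timestamps; the body of the logging loop is not shown in this diff:

    from time import time

    times = {}
    start = time()
    # each step stores its completion timestamp under a descriptive key
    times["install spatial extension"] = time()
    times["load spatial extension"] = time()

    previous = start
    for operation, completed_at in times.items():
        print(f"'{operation}' completed in {round(completed_at - previous, 3)}s")
        previous = completed_at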
def get_db_connection():
return _root_db_connection.cursor()
# SQL queries include placeholder strings that are replaced with Parquet URIs immediately prior to query execution.
# This improves query performance relative to creating DuckDB views over the Parquet files and querying those.
# Placeholders are kept until the point of query execution so that API search pagination tokens,
# which are JWT-encoded SQL queries and therefore visible to the client, do not leak
# implementation details such as Parquet URI locations.
def format_query_object_name(object_name: str) -> str:
return _query_object_identifier_template.format(object_name)


def execute(statement: str, params: Optional[List[Any]] = None) -> None:
start = time()
get_db_connection().execute(statement, params)
statement = _prepare_statement(statement)
_get_db_connection().execute(statement, params)
_sql_log_message(statement, time() - start, None, params)


def fetchone(statement: str, params: Optional[List[Any]] = None) -> Any:
start = time()
result = get_db_connection().execute(statement, params).fetchone()
statement = _prepare_statement(statement)
result = _get_db_connection().execute(statement, params).fetchone()
_sql_log_message(statement, time() - start, 1 if result is not None else 0, params)
return result


def fetchall(statement: str, params: Optional[List[Any]] = None) -> List[Any]:
start = time()
result = get_db_connection().execute(statement, params).fetchall()
statement = _prepare_statement(statement)
result = _get_db_connection().execute(statement, params).fetchall()
_sql_log_message(statement, time() - start, len(result), params)
return result


def _get_db_connection():
return _root_db_connection.cursor()


def _prepare_statement(statement: str) -> str:
query_object_identifier_regex = rf"\b{re.escape(_query_object_identifier_prefix)}([^:]+){re.escape(_query_object_identifier_suffix)}\b"
for query_object_name in re.findall(query_object_identifier_regex, statement):
if query_object_name not in _parquet_uris:
_logger.warning(
f"{query_object_name} not in parquet URI map, query will likely fail"
)
continue
statement = re.sub(
rf"\b{re.escape(_query_object_identifier_prefix)}{re.escape(query_object_name)}{re.escape(_query_object_identifier_suffix)}\b",
f"'{_parquet_uris[query_object_name]}'",
statement,
)
return statement
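The placeholder round-trip can be seen end to end in a small standalone sketch. The URI map here is a hypothetical stand-in for _parquet_uris (in the service it is populated from index_source.get_parquet_uris() at startup):

    import re

    prefix, suffix = "src:", ":src"
    uris = {"items": "s3://example-bucket/items.parquet"}  # hypothetical stand-in for _parquet_uris

    def format_name(name: str) -> str:
        # mirrors format_query_object_name: "items" -> "src:items:src"
        return f"{prefix}{name}{suffix}"

    statement = f"SELECT stac_location FROM {format_name('items')} WHERE id = ?"
    # mirrors _prepare_statement: swap each placeholder for its quoted Parquet URI
    pattern = rf"\b{re.escape(prefix)}([^:]+){re.escape(suffix)}\b"
    for name in re.findall(pattern, statement):
        statement = re.sub(
            rf"\b{re.escape(prefix)}{re.escape(name)}{re.escape(suffix)}\b",
            f"'{uris[name]}'",
            statement,
        )
    print(statement)
    # SELECT stac_location FROM 's3://example-bucket/items.parquet' WHERE id = ?

Because substitution happens only inside _prepare_statement, any SQL embedded in a pagination token still reads src:items:src rather than a concrete URI.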


def _sql_log_message(
statement: str,
duration: float,
6 changes: 3 additions & 3 deletions src/stac_fastapi/indexed/errors.py
@@ -2,7 +2,7 @@

from stac_index.common.indexing_error import IndexingError

from stac_fastapi.indexed.db import fetchall
from stac_fastapi.indexed.db import fetchall, format_query_object_name


def get_all_errors() -> List[IndexingError]:
@@ -18,7 +18,7 @@ def get_all_errors() -> List[IndexingError]:
collection=row[6],
item=row[7],
)
for row in fetchall("""
for row in fetchall(f"""
SELECT time
, error_type
, subtype
@@ -27,7 +27,7 @@ def get_all_errors() -> List[IndexingError]:
, possible_fixes
, collection
, item
FROM errors
FROM {format_query_object_name('errors')}
ORDER BY id
""")
]
8 changes: 4 additions & 4 deletions src/stac_fastapi/indexed/queryables/queryable_field_map.py
@@ -3,7 +3,7 @@
from logging import Logger, getLogger
from typing import Dict, Final

from stac_fastapi.indexed.db import fetchall
from stac_fastapi.indexed.db import fetchall, format_query_object_name

_logger: Final[Logger] = getLogger(__name__)

@@ -25,7 +25,7 @@ def get_queryable_config_by_name() -> Dict[str, QueryableConfig]:
_logger.debug("fetching queryable field config")
field_config = {}
for row in fetchall(
"""
f"""
SELECT name
, qbc.collection_id
, qbc.description
@@ -34,8 +34,8 @@ def get_queryable_config_by_name() -> Dict[str, QueryableConfig]:
, icols.column_type as items_column_type
, icols.column_type = 'GEOMETRY' as is_geometry
, icols.column_type IN ('TIMESTAMP WITH TIME ZONE') as is_temporal
FROM queryables_by_collection qbc
INNER JOIN (DESCRIBE items) icols ON qbc.items_column = icols.column_name
FROM {format_query_object_name('queryables_by_collection')} qbc
INNER JOIN (DESCRIBE {format_query_object_name('items')}) icols ON qbc.items_column = icols.column_name
""",
):
field_config[row[0]] = QueryableConfig(
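The INNER JOIN above leans on DuckDB's DESCRIBE, which returns one row per column with its name and type. A minimal sketch of what (DESCRIBE items) yields once the placeholder has been resolved, using a hypothetical local Parquet file in place of the real items source:

    import duckdb

    con = duckdb.connect()
    rows = con.execute(
        "DESCRIBE SELECT * FROM read_parquet('items.parquet')"  # hypothetical file path
    ).fetchall()
    for column_name, column_type, *_ in rows:
        # e.g. ('geometry', 'GEOMETRY') or ('datetime', 'TIMESTAMP WITH TIME ZONE')
        print(column_name, column_type)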
5 changes: 3 additions & 2 deletions src/stac_fastapi/indexed/search/search_handler.py
@@ -18,7 +18,7 @@
from stac_pydantic.api.extensions.sort import SortDirections, SortExtension

from stac_fastapi.indexed.constants import rel_root, rel_self
from stac_fastapi.indexed.db import fetchall
from stac_fastapi.indexed.db import fetchall, format_query_object_name
from stac_fastapi.indexed.links.catalog import get_catalog_link
from stac_fastapi.indexed.links.item import fix_item_links
from stac_fastapi.indexed.links.search import get_search_link, get_token_link
@@ -165,12 +165,13 @@ def _new_query(
params.extend(addition.params)
query = """
SELECT stac_location, applied_fixes
FROM items
FROM {table_name}
{where}
{order}
{{limit}}
{{offset}}
""".format(
table_name=format_query_object_name("items"),
where="WHERE {}".format(" AND ".join(clauses)) if len(clauses) > 0 else "",
order="ORDER BY {}".format(", ".join(sorts)),
)
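Note the two-stage templating here: the single-brace fields (table_name, where, order) are filled by this .format() call, while the doubled {{limit}} and {{offset}} survive it as literal {limit} and {offset} placeholders for a later substitution, presumably once pagination values are known. A minimal illustration with hypothetical values:

    template = "SELECT stac_location FROM {table_name} {where} {{limit}}"
    partial = template.format(table_name="src:items:src", where="WHERE collection = ?")
    # -> "SELECT stac_location FROM src:items:src WHERE collection = ? {limit}"
    final = partial.format(limit="LIMIT 10")
    # -> "SELECT stac_location FROM src:items:src WHERE collection = ? LIMIT 10"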
1 change: 0 additions & 1 deletion src/stac_fastapi/indexed/settings.py
@@ -11,7 +11,6 @@ class _Settings(ApiSettings):
log_level: str = "info"
index_manifest_uri: str
token_jwt_secret: str
s3_endpoint: Optional[str] = None
duckdb_threads: Optional[int] = None
deployment_root_path: Optional[str] = None
install_duckdb_extensions: bool = (
6 changes: 3 additions & 3 deletions src/stac_fastapi/indexed/sortables/sortable_config.py
@@ -3,7 +3,7 @@
from logging import Logger, getLogger
from typing import Dict, Final, List

from stac_fastapi.indexed.db import fetchall
from stac_fastapi.indexed.db import fetchall, format_query_object_name

_logger: Final[Logger] = getLogger(__name__)

@@ -27,12 +27,12 @@ def get_sortable_configs() -> List[SortableConfig]:
items_column=row[3],
)
for row in fetchall(
"""
f"""
SELECT name
, collection_id
, description
, items_column
FROM sortables_by_collection
FROM {format_query_object_name('sortables_by_collection')}
"""
)
]
8 changes: 4 additions & 4 deletions tests/with_environment/integration_tests/common.py
@@ -1,4 +1,4 @@
from base64 import b64decode, b64encode
from base64 import urlsafe_b64decode, urlsafe_b64encode
from datetime import datetime
from glob import glob
from json import dumps, load, loads
@@ -116,8 +116,8 @@ def get_claims_from_token(token: str) -> Dict[str, Any]:
claims_part = token_parts[1]
missing_padding = len(claims_part) % 4
if missing_padding:
claims_part += "=" * (4 - missing_padding)
decoded_bytes = b64decode(claims_part)
claims_part += "=" * (-len(claims_part) % 4)
decoded_bytes = urlsafe_b64decode(claims_part)
return loads(decoded_bytes.decode("UTF-8"))


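The padding change above matters at the boundary: when len(claims_part) is already a multiple of 4, the old expression 4 - len(claims_part) % 4 evaluates to 4 (spurious padding), whereas -len(claims_part) % 4 evaluates to 0, so the new expression stays correct even without the missing_padding guard. The switch to urlsafe_b64decode is also the right call for JWTs, whose segments use the base64url alphabet (- and _ instead of + and /). A quick check of the arithmetic:

    for n in range(1, 9):
        s = "A" * n
        # the new formula always pads to a multiple of 4
        assert len(s + "=" * (-len(s) % 4)) % 4 == 0
        # the old formula over-pads exactly when the length is already aligned
        assert (4 - n % 4 == 4) == (n % 4 == 0)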
@@ -128,7 +128,7 @@ def rebuild_token_with_altered_claims(
assert len(token_parts) == 3
return "{}.{}.{}".format(
token_parts[0],
b64encode(dumps(altered_claims).encode("UTF-8")).decode("UTF-8"),
urlsafe_b64encode(dumps(altered_claims).encode("UTF-8")).decode("UTF-8"),
token_parts[2],
)

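For context on why rebuild_token_with_altered_claims leaves token_parts[2] untouched: re-encoding the claims changes the signed payload, so the original signature can no longer verify, which is exactly what the immutability tests below rely on. A minimal sketch of that verification, assuming the tokens are HS256-signed JWTs (consistent with the token_jwt_secret setting):

    import hashlib
    import hmac
    from base64 import urlsafe_b64encode

    def signature_is_valid(token: str, secret: str) -> bool:
        # recompute the HS256 signature over header.claims and compare
        header_b64, claims_b64, signature_b64 = token.split(".")
        expected = (
            urlsafe_b64encode(
                hmac.new(
                    secret.encode("UTF-8"),
                    f"{header_b64}.{claims_b64}".encode("UTF-8"),
                    hashlib.sha256,
                ).digest()
            )
            .rstrip(b"=")
            .decode("UTF-8")
        )
        return hmac.compare_digest(expected, signature_b64)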
@@ -66,7 +66,10 @@ def test_collection_items_token_immutable() -> None:
token_match = match(r".+\?token=(.+)$", next_link["href"])
assert token_match
token = token_match.group(1)
token_claims = get_claims_from_token(token)
try:
token_claims = get_claims_from_token(token)
except Exception as e:
raise Exception(f"token decode failed on '{token}', link '{next_link}'", e)
assert "limit" in token_claims
assert token_claims["limit"] == limit
altered_claims = {