-
Notifications
You must be signed in to change notification settings - Fork 33
Pull database logic out of core.py and transaction.py #62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
aff6bf1
a6f6ef7
22078aa
26c1290
1e260c3
bbbc7f1
514045f
f9dc26b
8234a64
fb643db
8d4a062
7f62f10
ce2d5d2
ac07418
150fbd5
247da4b
e050ac7
2bad487
520d88d
2819621
2225e76
ad6f393
cc96d23
d7b0668
f62de17
83c2843
cf47504
6692323
33b9c49
e2d2f28
20ac8d9
7770152
827c68a
cdd580b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,8 +6,6 @@ | |
| from urllib.parse import urljoin | ||
|
|
||
| import attr | ||
| import elasticsearch | ||
| from elasticsearch_dsl import Q, Search | ||
| from fastapi import HTTPException | ||
| from overrides import overrides | ||
|
|
||
|
|
@@ -18,25 +16,17 @@ | |
|
|
||
| from stac_fastapi.elasticsearch import serializers | ||
| from stac_fastapi.elasticsearch.config import ElasticsearchSettings | ||
| from stac_fastapi.elasticsearch.database_logic import CoreDatabaseLogic | ||
| from stac_fastapi.elasticsearch.session import Session | ||
|
|
||
| # from stac_fastapi.elasticsearch.types.error_checks import ErrorChecks | ||
| from stac_fastapi.types.core import BaseCoreClient | ||
| from stac_fastapi.types.errors import NotFoundError | ||
| from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| NumType = Union[float, int] | ||
|
|
||
| ITEMS_INDEX = "stac_items" | ||
| COLLECTIONS_INDEX = "stac_collections" | ||
|
|
||
|
|
||
| def mk_item_id(item_id: str, collection_id: str): | ||
| """Make the Elasticsearch document _id value from the Item id and collection.""" | ||
| return f"{item_id}|{collection_id}" | ||
|
|
||
|
|
||
| @attr.s | ||
| class CoreCrudClient(BaseCoreClient): | ||
|
|
@@ -51,23 +41,14 @@ class CoreCrudClient(BaseCoreClient): | |
| ) | ||
| settings = ElasticsearchSettings() | ||
| client = settings.create_client | ||
| database = CoreDatabaseLogic() | ||
|
||
|
|
||
| @overrides | ||
| def all_collections(self, **kwargs) -> Collections: | ||
| """Read all collections from the database.""" | ||
| base_url = str(kwargs["request"].base_url) | ||
| try: | ||
| collections = self.client.search( | ||
| index=COLLECTIONS_INDEX, query={"match_all": {}} | ||
| ) | ||
| except elasticsearch.exceptions.NotFoundError: | ||
| raise NotFoundError("No collections exist") | ||
| serialized_collections = [ | ||
| self.collection_serializer.db_to_stac( | ||
| collection["_source"], base_url=base_url | ||
| ) | ||
| for collection in collections["hits"]["hits"] | ||
| ] | ||
| serialized_collections = self.database.get_all_collections(base_url=base_url) | ||
|
|
||
| links = [ | ||
| { | ||
| "rel": Relations.root.value, | ||
|
|
@@ -94,12 +75,8 @@ def all_collections(self, **kwargs) -> Collections: | |
| def get_collection(self, collection_id: str, **kwargs) -> Collection: | ||
| """Get collection by id.""" | ||
| base_url = str(kwargs["request"].base_url) | ||
| try: | ||
| collection = self.client.get(index=COLLECTIONS_INDEX, id=collection_id) | ||
| except elasticsearch.exceptions.NotFoundError: | ||
| raise NotFoundError(f"Collection {collection_id} not found") | ||
|
|
||
| return self.collection_serializer.db_to_stac(collection["_source"], base_url) | ||
| collection = self.database.get_one_collection(collection_id) | ||
| return self.collection_serializer.db_to_stac(collection, base_url) | ||
|
|
||
| @overrides | ||
| def item_collection( | ||
|
|
@@ -108,24 +85,10 @@ def item_collection( | |
| """Read an item collection from the database.""" | ||
| links = [] | ||
| base_url = str(kwargs["request"].base_url) | ||
| search = Search(using=self.client, index="stac_items") | ||
|
|
||
| collection_filter = Q( | ||
| "bool", should=[Q("match_phrase", **{"collection": collection_id})] | ||
| serialized_children, count = self.database.get_item_collection( | ||
| collection_id=collection_id, limit=limit, base_url=base_url | ||
| ) | ||
| search = search.query(collection_filter) | ||
| try: | ||
| count = search.count() | ||
| except elasticsearch.exceptions.NotFoundError: | ||
| raise NotFoundError("No items exist") | ||
| # search = search.sort({"id.keyword" : {"order" : "asc"}}) | ||
| search = search.query()[0:limit] | ||
| collection_children = search.execute().to_dict() | ||
|
|
||
| serialized_children = [ | ||
| self.item_serializer.db_to_stac(item["_source"], base_url=base_url) | ||
| for item in collection_children["hits"]["hits"] | ||
| ] | ||
|
|
||
| context_obj = None | ||
| if self.extension_is_enabled("ContextExtension"): | ||
|
|
@@ -146,15 +109,8 @@ def item_collection( | |
| def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: | ||
| """Get item by item id, collection id.""" | ||
| base_url = str(kwargs["request"].base_url) | ||
| try: | ||
| item = self.client.get( | ||
| index=ITEMS_INDEX, id=mk_item_id(item_id, collection_id) | ||
| ) | ||
| except elasticsearch.exceptions.NotFoundError: | ||
| raise NotFoundError( | ||
| f"Item {item_id} does not exist in Collection {collection_id}" | ||
| ) | ||
| return self.item_serializer.db_to_stac(item["_source"], base_url) | ||
| item = self.database.get_one_item(item_id=item_id, collection_id=collection_id) | ||
| return self.item_serializer.db_to_stac(item, base_url) | ||
|
|
||
| @staticmethod | ||
| def _return_date(interval_str): | ||
|
|
@@ -238,125 +194,63 @@ def get_search( | |
|
|
||
| return resp | ||
|
|
||
| @staticmethod | ||
| def bbox2poly(b0, b1, b2, b3): | ||
| """Transform bbox to polygon.""" | ||
| poly = [[[b0, b1], [b2, b1], [b2, b3], [b0, b3], [b0, b1]]] | ||
| return poly | ||
|
|
||
| def post_search(self, search_request: Search, **kwargs) -> ItemCollection: | ||
| def post_search(self, search_request, **kwargs) -> ItemCollection: | ||
| """POST search catalog.""" | ||
| base_url = str(kwargs["request"].base_url) | ||
| search = ( | ||
| Search() | ||
| .using(self.client) | ||
| .index(ITEMS_INDEX) | ||
| .sort( | ||
| {"properties.datetime": {"order": "desc"}}, | ||
| {"id": {"order": "desc"}}, | ||
| {"collection": {"order": "desc"}}, | ||
| ) | ||
| ) | ||
| search = self.database.create_search_object() | ||
|
|
||
| if search_request.query: | ||
| if type(search_request.query) == str: | ||
| search_request.query = json.loads(search_request.query) | ||
| for (field_name, expr) in search_request.query.items(): | ||
| field = "properties__" + field_name | ||
| for (op, value) in expr.items(): | ||
| if op != "eq": | ||
| key_filter = {field: {f"{op}": value}} | ||
| search = search.query(Q("range", **key_filter)) | ||
| else: | ||
| search = search.query("match_phrase", **{field: value}) | ||
| search = self.database.create_query_filter( | ||
| search=search, op=op, field=field, value=value | ||
| ) | ||
|
|
||
| if search_request.ids: | ||
| id_list = [] | ||
| for item_id in search_request.ids: | ||
| id_list.append(Q("match_phrase", **{"id": item_id})) | ||
| id_filter = Q("bool", should=id_list) | ||
| search = search.query(id_filter) | ||
| search = self.database.search_ids( | ||
| search=search, item_ids=search_request.ids | ||
| ) | ||
|
|
||
| if search_request.collections: | ||
| collection_list = [] | ||
| for collection_id in search_request.collections: | ||
| collection_list.append( | ||
| Q("match_phrase", **{"collection": collection_id}) | ||
| ) | ||
| collection_filter = Q("bool", should=collection_list) | ||
| search = search.query(collection_filter) | ||
| search = self.database.search_collections( | ||
| search=search, collection_ids=search_request.collections | ||
| ) | ||
|
|
||
| if search_request.datetime: | ||
| datetime_search = self._return_date(search_request.datetime) | ||
| if "eq" in datetime_search: | ||
| search = search.query( | ||
| "match_phrase", **{"properties__datetime": datetime_search["eq"]} | ||
| ) | ||
| else: | ||
| search = search.filter( | ||
| "range", properties__datetime={"lte": datetime_search["lte"]} | ||
| ) | ||
| search = search.filter( | ||
| "range", properties__datetime={"gte": datetime_search["gte"]} | ||
| ) | ||
| search = self.database.search_datetime( | ||
| search=search, datetime_search=datetime_search | ||
| ) | ||
|
|
||
| if search_request.bbox: | ||
| bbox = search_request.bbox | ||
| if len(bbox) == 6: | ||
| bbox = [bbox[0], bbox[1], bbox[3], bbox[4]] | ||
| poly = self.bbox2poly(bbox[0], bbox[1], bbox[2], bbox[3]) | ||
|
|
||
| bbox_filter = Q( | ||
| { | ||
| "geo_shape": { | ||
| "geometry": { | ||
| "shape": {"type": "polygon", "coordinates": poly}, | ||
| "relation": "intersects", | ||
| } | ||
| } | ||
| } | ||
| ) | ||
| search = search.query(bbox_filter) | ||
|
|
||
| search = self.database.search_bbox(search=search, bbox=bbox) | ||
|
|
||
| if search_request.intersects: | ||
| intersect_filter = Q( | ||
| { | ||
| "geo_shape": { | ||
| "geometry": { | ||
| "shape": { | ||
| "type": search_request.intersects.type.lower(), | ||
| "coordinates": search_request.intersects.coordinates, | ||
| }, | ||
| "relation": "intersects", | ||
| } | ||
| } | ||
| } | ||
| self.database.search_intersects( | ||
| search=search, intersects=search_request.intersects | ||
| ) | ||
| search = search.query(intersect_filter) | ||
|
|
||
| if search_request.sortby: | ||
| for sort in search_request.sortby: | ||
| if sort.field == "datetime": | ||
| sort.field = "properties__datetime" | ||
| field = sort.field + ".keyword" | ||
| search = search.sort({field: {"order": sort.direction}}) | ||
| search = self.database.sort_field( | ||
| search=search, field=field, direction=sort.direction | ||
| ) | ||
|
|
||
| try: | ||
| count = search.count() | ||
| except elasticsearch.exceptions.NotFoundError: | ||
| raise NotFoundError("No items exist") | ||
|
|
||
| # search = search.sort({"id.keyword" : {"order" : "asc"}}) | ||
| search = search.query()[0 : search_request.limit] | ||
| response = search.execute().to_dict() | ||
|
|
||
| if len(response["hits"]["hits"]) > 0: | ||
| response_features = [ | ||
| self.item_serializer.db_to_stac(item["_source"], base_url=base_url) | ||
| for item in response["hits"]["hits"] | ||
| ] | ||
| else: | ||
| response_features = [] | ||
| count = self.database.search_count(search=search) | ||
|
|
||
| response_features = self.database.execute_search( | ||
| search=search, limit=search_request.limit, base_url=base_url | ||
| ) | ||
|
|
||
| # if self.extension_is_enabled("FieldsExtension"): | ||
| # if search_request.query is not None: | ||
|
|
@@ -384,7 +278,7 @@ def post_search(self, search_request: Search, **kwargs) -> ItemCollection: | |
| else: | ||
| limit = 10 | ||
| response_features = response_features[0:limit] | ||
| limit = 10 | ||
|
|
||
| context_obj = None | ||
| if self.extension_is_enabled("ContextExtension"): | ||
| context_obj = { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My linux laptop doesn't like this line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
interesting, i'm running on macos so probably why i didn't see it