Skip to content

Commit d968058

Browse files
authored
Limit aggregators (#36)
* feat: add config to limit aggregators * fix: end of file line * feat: better logs if bad query * feat: reject aggregation requests for unspecific resources * docs: add warning for aggregation requests * feat: adapt tests * feat: update test * fix: fix metrics * fix: lint * docs: update changelog * docs: indicate what's expected in new config * fix: better swagger test, required pytest-mock
1 parent 1360f83 commit d968058

File tree

10 files changed

+123
-35
lines changed

10 files changed

+123
-35
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Current (in progress)
44

55
- Handle queries with aggregators [#35](https://github.com/datagouv/api-tabular/pull/35)
6+
- Restrict aggregators to a list of specific resources [#36](https://github.com/datagouv/api-tabular/pull/36)
67

78
## 0.2.1 (2024-11-21)
89

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ column_name__max
161161
# sum
162162
column_name__sum
163163
```
164+
165+
> /!\ WARNING: aggregation requests are only available for resources that are listed in the `ALLOW_AGGREGATION` list of the config file.
166+
164167
> NB: passing an aggregation operator (`count`, `avg`, `min`, `max`, `sum`) returns a column that is named `<column_name>__<operator>` (for instance: `?birth__groupby&score__sum` will return a list of dicts with the keys `birth` and `score__sum`).
165168
166169
For instance:

api_tabular/app.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,11 @@ async def resource_data(request):
9595
offset = 0
9696

9797
try:
98-
sql_query = build_sql_query_string(query_string, page_size, offset)
98+
sql_query = build_sql_query_string(query_string, resource_id, page_size, offset)
9999
except ValueError as e:
100100
raise QueryException(400, None, "Invalid query string", f"Malformed query: {e}")
101-
101+
except PermissionError as e:
102+
raise QueryException(403, None, "Unauthorized parameters", str(e))
102103
resource = await get_resource(request.app["csession"], resource_id, ["parsing_table"])
103104
response, total = await get_resource_data(request.app["csession"], resource, sql_query)
104105

@@ -123,9 +124,11 @@ async def resource_data_csv(request):
123124
query_string = request.query_string.split("&") if request.query_string else []
124125

125126
try:
126-
sql_query = build_sql_query_string(query_string)
127+
sql_query = build_sql_query_string(query_string, resource_id)
127128
except ValueError:
128129
raise QueryException(400, None, "Invalid query string", "Malformed query")
130+
except PermissionError as e:
131+
raise QueryException(403, None, "Unauthorized parameters", str(e))
129132

130133
resource = await get_resource(request.app["csession"], resource_id, ["parsing_table"])
131134

api_tabular/config_default.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ PAGE_SIZE_DEFAULT = 20
66
PAGE_SIZE_MAX = 50
77
BATCH_SIZE = 50000
88
DOC_PATH = "/api/doc"
9+
ALLOW_AGGREGATION = [] # list of resource_ids for which aggregation queries are allowed

api_tabular/metrics.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@ async def metrics_data(request):
7878
else:
7979
offset = 0
8080
try:
81-
sql_query = build_sql_query_string(query_string, page_size, offset)
82-
except ValueError:
83-
raise QueryException(400, None, "Invalid query string", "Malformed query")
81+
sql_query = build_sql_query_string(query_string, page_size=page_size, offset=offset)
82+
except ValueError as e:
83+
raise QueryException(400, None, "Invalid query string", f"Malformed query: {e}")
8484

8585
response, total = await get_object_data(request.app["csession"], model, sql_query)
8686

@@ -104,8 +104,8 @@ async def metrics_data_csv(request):
104104

105105
try:
106106
sql_query = build_sql_query_string(query_string)
107-
except ValueError:
108-
raise QueryException(400, None, "Invalid query string", "Malformed query")
107+
except ValueError as e:
108+
raise QueryException(400, None, "Invalid query string", f"Malformed query: {e}")
109109

110110
response_headers = {
111111
"Content-Disposition": f'attachment; filename="{model}.csv"',

api_tabular/utils.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,30 +69,40 @@
6969
"groupby": {
7070
"name": "{}__groupby",
7171
"description": "Performs `group by values` operation in column: {}",
72+
"is_aggregator": True,
7273
},
7374
"count": {
7475
"name": "{}__count",
7576
"description": "Performs `count values` operation in column: {}",
77+
"is_aggregator": True,
7678
},
7779
"avg": {
7880
"name": "{}__avg",
7981
"description": "Performs `mean` operation in column: {}",
82+
"is_aggregator": True,
8083
},
8184
"min": {
8285
"name": "{}__min",
8386
"description": "Performs `minimum` operation in column: {}",
87+
"is_aggregator": True,
8488
},
8589
"max": {
8690
"name": "{}__max",
8791
"description": "Performs `maximum` operation in column: {}",
92+
"is_aggregator": True,
8893
},
8994
"sum": {
9095
"name": "{}__sum",
9196
"description": "Performs `sum` operation in column: {}",
97+
"is_aggregator": True,
9298
},
9399
}
94100

95101

102+
def is_aggregation_allowed(resource_id: str):
103+
return resource_id in config.ALLOW_AGGREGATION
104+
105+
96106
async def get_app_version() -> str:
97107
"""Parse pyproject.toml and return the version or an error."""
98108
try:
@@ -105,7 +115,12 @@ async def get_app_version() -> str:
105115
return f"unknown ({str(e)})"
106116

107117

108-
def build_sql_query_string(request_arg: list, page_size: int = None, offset: int = 0) -> str:
118+
def build_sql_query_string(
119+
request_arg: list,
120+
resource_id: Optional[str] = None,
121+
page_size: int = None,
122+
offset: int = 0,
123+
) -> str:
109124
sql_query = []
110125
aggregators = defaultdict(list)
111126
sorted = False
@@ -125,6 +140,11 @@ def build_sql_query_string(request_arg: list, page_size: int = None, offset: int
125140
else:
126141
raise ValueError(f"argument '{arg}' could not be parsed")
127142
if aggregators:
143+
if resource_id and not is_aggregation_allowed(resource_id):
144+
raise PermissionError(
145+
f"Aggregation parameters `{'`, `'.join(aggregators.keys())}` "
146+
f"are not allowed for resource '{resource_id}'"
147+
)
128148
agg_query = "select="
129149
for operator in aggregators:
130150
if operator == "groupby":
@@ -217,7 +237,7 @@ def url_for(request: Request, route: str, *args, **kwargs) -> str:
217237
return router[route].url_for(**kwargs)
218238

219239

220-
def swagger_parameters(resource_columns: dict) -> list:
240+
def swagger_parameters(resource_columns: dict, resource_id: str) -> list:
221241
parameters_list = [
222242
{
223243
"name": "page",
@@ -239,6 +259,10 @@ def swagger_parameters(resource_columns: dict) -> list:
239259
# see cast for db here: https://github.com/datagouv/hydra/blob/main/udata_hydra/analysis/csv.py
240260
for key, value in resource_columns.items():
241261
for op in OPERATORS_DESCRIPTIONS:
262+
if not is_aggregation_allowed(resource_id) and OPERATORS_DESCRIPTIONS[op].get(
263+
"is_aggregator"
264+
):
265+
continue
242266
if op in TYPE_POSSIBILITIES[value["python_type"]]:
243267
parameters_list.extend(
244268
[
@@ -354,7 +378,7 @@ def swagger_component(resource_columns: dict) -> dict:
354378

355379

356380
def build_swagger_file(resource_columns: dict, rid: str) -> str:
357-
parameters_list = swagger_parameters(resource_columns)
381+
parameters_list = swagger_parameters(resource_columns, rid)
358382
component_dict = swagger_component(resource_columns)
359383
swagger_dict = {
360384
"openapi": "3.0.3",

poetry.lock

Lines changed: 18 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ aioresponses = "^0.7.4"
2020
bumpx = "^0.3.10"
2121
pytest = "^7.2.1"
2222
pytest-asyncio = "^0.20.3"
23+
pytest-mock = "^3.14.0"
2324
ruff = "^0.6.5"
2425

2526
[tool.ruff]

tests/test_query.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
import pytest
22

3+
from api_tabular import config
34
from api_tabular.utils import build_sql_query_string
45

6+
from .conftest import RESOURCE_ID
7+
58

69
def test_query_build_limit():
710
query_str = []
8-
result = build_sql_query_string(query_str, 12)
11+
result = build_sql_query_string(query_str, page_size=12)
912
assert result == "limit=12&order=__id.asc"
1013

1114

1215
def test_query_build_offset():
1316
query_str = []
14-
result = build_sql_query_string(query_str, 12, 12)
17+
result = build_sql_query_string(query_str, page_size=12, offset=12)
1518
assert result == "limit=12&offset=12&order=__id.asc"
1619

1720

1821
def test_query_build_sort_asc():
1922
query_str = ["column_name__sort=asc"]
20-
result = build_sql_query_string(query_str, 50)
23+
result = build_sql_query_string(query_str, page_size=50)
2124
assert result == 'order="column_name".asc&limit=50'
2225

2326

@@ -39,43 +42,43 @@ def test_query_build_sort_asc_with_page_in_query():
3942

4043
def test_query_build_sort_desc():
4144
query_str = ["column_name__sort=desc"]
42-
result = build_sql_query_string(query_str, 50)
45+
result = build_sql_query_string(query_str, page_size=50)
4346
assert result == 'order="column_name".desc&limit=50'
4447

4548

4649
def test_query_build_exact():
4750
query_str = ["column_name__exact=BIDULE"]
48-
result = build_sql_query_string(query_str, 50)
51+
result = build_sql_query_string(query_str, page_size=50)
4952
assert result == '"column_name"=eq.BIDULE&limit=50&order=__id.asc'
5053

5154

5255
def test_query_build_differs():
5356
query_str = ["column_name__differs=BIDULE"]
54-
result = build_sql_query_string(query_str, 50)
57+
result = build_sql_query_string(query_str, page_size=50)
5558
assert result == '"column_name"=neq.BIDULE&limit=50&order=__id.asc'
5659

5760

5861
def test_query_build_contains():
5962
query_str = ["column_name__contains=BIDULE"]
60-
result = build_sql_query_string(query_str, 50)
63+
result = build_sql_query_string(query_str, page_size=50)
6164
assert result == '"column_name"=ilike.*BIDULE*&limit=50&order=__id.asc'
6265

6366

6467
def test_query_build_in():
6568
query_str = ["column_name__in=value1,value2,value3"]
66-
result = build_sql_query_string(query_str, 50)
69+
result = build_sql_query_string(query_str, page_size=50)
6770
assert result == '"column_name"=in.(value1,value2,value3)&limit=50&order=__id.asc'
6871

6972

7073
def test_query_build_less():
7174
query_str = ["column_name__less=12"]
72-
result = build_sql_query_string(query_str, 50, 12)
75+
result = build_sql_query_string(query_str, page_size=50, offset=12)
7376
assert result == '"column_name"=lte.12&limit=50&offset=12&order=__id.asc'
7477

7578

7679
def test_query_build_greater():
7780
query_str = ["column_name__greater=12"]
78-
result = build_sql_query_string(query_str, 50)
81+
result = build_sql_query_string(query_str, page_size=50)
7982
assert result == '"column_name"=gte.12&limit=50&order=__id.asc'
8083

8184

@@ -85,7 +88,7 @@ def test_query_build_multiple():
8588
"column_name__greater=12",
8689
"column_name__exact=BIDULE",
8790
]
88-
result = build_sql_query_string(query_str, 50)
91+
result = build_sql_query_string(query_str, page_size=50)
8992
assert (
9093
result
9194
== '"column_name"=eq.BIDULE&"column_name"=gte.12&"column_name"=eq.BIDULE&limit=50&order=__id.asc'
@@ -95,16 +98,29 @@ def test_query_build_multiple():
9598
def test_query_build_multiple_with_unknown():
9699
query_str = ["select=numnum"]
97100
with pytest.raises(ValueError):
98-
build_sql_query_string(query_str, 50)
99-
100-
101-
def test_query_aggregators():
101+
build_sql_query_string(query_str, page_size=50)
102+
103+
104+
@pytest.mark.parametrize(
105+
"allow_aggregation",
106+
[
107+
False,
108+
True,
109+
],
110+
)
111+
def test_query_aggregators(allow_aggregation, mocker):
112+
if allow_aggregation:
113+
mocker.patch("api_tabular.config.ALLOW_AGGREGATION", [RESOURCE_ID])
102114
query_str = [
103115
"column_name__groupby",
104116
"column_name__min",
105117
"column_name__avg",
106118
]
107-
results = build_sql_query_string(query_str, 50).split("&")
119+
if not allow_aggregation:
120+
with pytest.raises(PermissionError):
121+
build_sql_query_string(query_str, resource_id=RESOURCE_ID, page_size=50)
122+
return
123+
results = build_sql_query_string(query_str, resource_id=RESOURCE_ID, page_size=50).split("&")
108124
assert "limit=50" in results
109125
assert "order=__id.asc" not in results # no sort if aggregators
110126
select = [_ for _ in results if "select" in _]

tests/test_swagger.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import pytest
55
import yaml
66

7-
from api_tabular.utils import TYPE_POSSIBILITIES
7+
from api_tabular import config
8+
from api_tabular.utils import OPERATORS_DESCRIPTIONS, TYPE_POSSIBILITIES
89

910
from .conftest import RESOURCE_ID, TABLES_INDEX_PATTERN
1011

@@ -17,7 +18,16 @@ async def test_swagger_endpoint(client, rmock):
1718
assert res.status == 200
1819

1920

20-
async def test_swagger_content(client, rmock):
21+
@pytest.mark.parametrize(
22+
"allow_aggregation",
23+
[
24+
False,
25+
True,
26+
],
27+
)
28+
async def test_swagger_content(client, rmock, allow_aggregation, mocker):
29+
if allow_aggregation:
30+
mocker.patch("api_tabular.config.ALLOW_AGGREGATION", [RESOURCE_ID])
2131
with open("db/sample.csv", newline="") as csvfile:
2232
spamreader = csv.reader(csvfile, delimiter=",", quotechar='"')
2333
# getting the csv-detective output in the test file
@@ -49,8 +59,20 @@ async def test_swagger_content(client, rmock):
4959
elif p == "in":
5060
value = "value1,value2,..."
5161
for _p in _params:
52-
if (
53-
f"{c}__{_p}={value}" not in params # filters
54-
and f"{c}__{_p}" not in params # aggregators
55-
):
56-
raise ValueError(f"{c}__{_p} is missing in {output} output")
62+
if allow_aggregation:
63+
if (
64+
f"{c}__{_p}={value}" not in params # filters
65+
and f"{c}__{_p}" not in params # aggregators
66+
):
67+
raise ValueError(f"{c}__{_p} is missing in {output} output")
68+
else:
69+
if (
70+
not OPERATORS_DESCRIPTIONS.get(_p, {}).get("is_aggregator")
71+
and f"{c}__{_p}={value}" not in params # filters are in
72+
):
73+
raise ValueError(f"{c}__{_p} is missing in {output} output")
74+
if (
75+
OPERATORS_DESCRIPTIONS.get(_p, {}).get("is_aggregator")
76+
and f"{c}__{_p}" in params # aggregators are out
77+
):
78+
raise ValueError(f"{c}__{_p} is in {output} output but should not")

0 commit comments

Comments
 (0)