Skip to content

Commit fc1394e

Browse files
committed
chore: all methods now exists in rpc datasource/agent
1 parent da7c642 commit fc1394e

File tree

9 files changed

+676
-144
lines changed

9 files changed

+676
-144
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ src/datasource_toolkit/poetry.lock
177177
src/flask_agent/poetry.lock
178178
src/django_agent/poetry.lock
179179
src/datasource_django/poetry.lock
180+
src/datasource_rpc/poetry.lock
181+
src/agent_rpc/poetry.lock
182+
src/rpc_common/poetry.lock
180183

181184
# generate file during tests
182185
src/django_agent/tests/test_project_agent/.forestadmin-schema.json

src/agent_rpc/forestadmin/agent_rpc/agent.py

Lines changed: 134 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
import asyncio
22
import json
3-
import os
4-
import pprint
5-
import signal
6-
import sys
7-
import threading
8-
import time
93
from enum import Enum
4+
from uuid import UUID
105

116
from aiohttp import web
127
from aiohttp_sse import sse_response
@@ -16,8 +11,11 @@
1611

1712
# from forestadmin.agent_rpc.services.datasource import DatasourceService
1813
from forestadmin.agent_toolkit.agent import Agent
19-
from forestadmin.agent_toolkit.forest_logger import ForestLogger
20-
from forestadmin.datasource_toolkit.interfaces.fields import is_column
14+
from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType
15+
from forestadmin.datasource_toolkit.utils.schema import SchemaUtils
16+
from forestadmin.rpc_common.hmac import is_valid_hmac
17+
from forestadmin.rpc_common.serializers.actions import ActionFormSerializer, ActionResultSerializer
18+
from forestadmin.rpc_common.serializers.aes import aes_decrypt, aes_encrypt
2119
from forestadmin.rpc_common.serializers.collection.aggregation import AggregationSerializer
2220
from forestadmin.rpc_common.serializers.collection.filter import (
2321
FilterSerializer,
@@ -56,27 +54,46 @@ def __init__(self, options: RpcOptions):
5654
# self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
5755
# self.server = grpc.aio.server()
5856
self.listen_addr, self.listen_port = options["listen_addr"].rsplit(":", 1)
59-
self.app = web.Application()
57+
self.app = web.Application(middlewares=[self.hmac_middleware])
6058
# self.server.add_insecure_port(options["listen_addr"])
6159
options["skip_schema_update"] = True
6260
options["env_secret"] = "f" * 64
6361
options["server_url"] = "http://fake"
64-
options["auth_secret"] = "fake"
62+
# options["auth_secret"] = "f48186505a3c5d62c27743126d6a76c1dd8b3e2d8897de19"
6563
options["schema_path"] = "./.forestadmin-schema.json"
66-
6764
super().__init__(options)
65+
66+
self.aes_key = self.options["auth_secret"][:16].encode()
67+
self.aes_iv = self.options["auth_secret"][-16:].encode()
6868
self._server_stop = False
6969
self.setup_routes()
7070
# signal.signal(signal.SIGUSR1, self.stop_handler)
7171

72+
@web.middleware
73+
async def hmac_middleware(self, request: web.Request, handler):
74+
if request.method == "POST":
75+
body = await request.read()
76+
if not is_valid_hmac(
77+
self.options["auth_secret"].encode(), body, request.headers.get("X-FOREST-HMAC", "").encode("utf-8")
78+
):
79+
return web.Response(status=401)
80+
return await handler(request)
81+
7282
def setup_routes(self):
83+
# self.app.middlewares.append(self.hmac_middleware)
7384
self.app.router.add_route("GET", "/sse", self.sse_handler)
7485
self.app.router.add_route("GET", "/schema", self.schema)
7586
self.app.router.add_route("POST", "/collection/list", self.collection_list)
7687
self.app.router.add_route("POST", "/collection/create", self.collection_create)
7788
self.app.router.add_route("POST", "/collection/update", self.collection_update)
7889
self.app.router.add_route("POST", "/collection/delete", self.collection_delete)
7990
self.app.router.add_route("POST", "/collection/aggregate", self.collection_aggregate)
91+
self.app.router.add_route("POST", "/collection/get-form", self.collection_get_form)
92+
self.app.router.add_route("POST", "/collection/execute", self.collection_execute)
93+
self.app.router.add_route("POST", "/collection/render-chart", self.collection_render_chart)
94+
95+
self.app.router.add_route("POST", "/execute-native-query", self.native_query)
96+
self.app.router.add_route("POST", "/render-chart", self.render_chart)
8097
self.app.router.add_route("GET", "/", lambda _: web.Response(text="OK"))
8198

8299
async def sse_handler(self, request: web.Request) -> web.StreamResponse:
@@ -103,11 +120,14 @@ async def collection_list(self, request: web.Request):
103120

104121
records = await collection.list(caller, filter_, projection)
105122
records = [RecordSerializer.serialize(record) for record in records]
106-
return web.json_response(records)
123+
return web.Response(text=aes_encrypt(json.dumps(records), self.aes_key, self.aes_iv))
107124

108125
async def collection_create(self, request: web.Request):
109-
body_params = await request.json()
126+
body_params = await request.text()
127+
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
128+
body_params = json.loads(body_params)
110129
ds = await self.customizer.get_datasource()
130+
111131
collection = ds.get_collection(body_params["collectionName"])
112132
caller = CallerSerializer.deserialize(body_params["caller"])
113133
data = [RecordSerializer.deserialize(r, collection) for r in body_params["data"]]
@@ -117,7 +137,10 @@ async def collection_create(self, request: web.Request):
117137
return web.json_response(records)
118138

119139
async def collection_update(self, request: web.Request):
120-
body_params = await request.json()
140+
body_params = await request.text()
141+
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
142+
body_params = json.loads(body_params)
143+
121144
ds = await self.customizer.get_datasource()
122145
collection = ds.get_collection(body_params["collectionName"])
123146
caller = CallerSerializer.deserialize(body_params["caller"])
@@ -147,7 +170,103 @@ async def collection_aggregate(self, request: web.Request):
147170

148171
records = await collection.aggregate(caller, filter_, aggregation)
149172
# records = [RecordSerializer.serialize(record) for record in records]
150-
return web.json_response(records)
173+
return web.Response(text=aes_encrypt(json.dumps(records), self.aes_key, self.aes_iv))
174+
175+
async def collection_get_form(self, request: web.Request):
176+
body_params = await request.text()
177+
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
178+
body_params = json.loads(body_params)
179+
180+
ds = await self.customizer.get_datasource()
181+
collection = ds.get_collection(body_params["collectionName"])
182+
183+
caller = CallerSerializer.deserialize(body_params["caller"]) if body_params["caller"] else None
184+
action_name = body_params["actionName"]
185+
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) if body_params["filter"] else None
186+
data = body_params["data"]
187+
meta = body_params["meta"]
188+
189+
form = await collection.get_form(caller, action_name, data, filter_, meta)
190+
return web.Response(
191+
text=aes_encrypt(json.dumps(ActionFormSerializer.serialize(form)), self.aes_key, self.aes_iv)
192+
)
193+
194+
async def collection_execute(self, request: web.Request):
195+
body_params = await request.text()
196+
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
197+
body_params = json.loads(body_params)
198+
199+
ds = await self.customizer.get_datasource()
200+
collection = ds.get_collection(body_params["collectionName"])
201+
202+
caller = CallerSerializer.deserialize(body_params["caller"]) if body_params["caller"] else None
203+
action_name = body_params["actionName"]
204+
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) if body_params["filter"] else None
205+
data = body_params["data"]
206+
207+
result = await collection.execute(caller, action_name, data, filter_)
208+
return web.Response(
209+
text=aes_encrypt(json.dumps(ActionResultSerializer.serialize(result)), self.aes_key, self.aes_iv)
210+
)
211+
212+
async def collection_render_chart(self, request: web.Request):
213+
body_params = await request.text()
214+
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
215+
body_params = json.loads(body_params)
216+
217+
ds = await self.customizer.get_datasource()
218+
collection = ds.get_collection(body_params["collectionName"])
219+
220+
caller = CallerSerializer.deserialize(body_params["caller"])
221+
name = body_params["name"]
222+
record_id = body_params["recordId"]
223+
ret = []
224+
for i, value in enumerate(record_id):
225+
type_record_id = collection.schema["fields"][SchemaUtils.get_primary_keys(collection.schema)[i]][
226+
"column_type"
227+
]
228+
229+
if type_record_id == PrimitiveType.DATE:
230+
ret.append(value.fromisoformat())
231+
elif type_record_id == PrimitiveType.DATE_ONLY:
232+
ret.append(value.fromisoformat())
233+
elif type_record_id == PrimitiveType.DATE:
234+
ret.append(value.fromisoformat())
235+
elif type_record_id == PrimitiveType.POINT:
236+
ret.append((value[0], value[1]))
237+
elif type_record_id == PrimitiveType.UUID:
238+
ret.append(UUID(value))
239+
else:
240+
ret.append(value)
241+
242+
result = await collection.render_chart(caller, name, record_id)
243+
return web.Response(text=aes_encrypt(json.dumps(result), self.aes_key, self.aes_iv))
244+
245+
async def render_chart(self, request: web.Request):
246+
body_params = await request.text()
247+
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
248+
body_params = json.loads(body_params)
249+
250+
ds = await self.customizer.get_datasource()
251+
252+
caller = CallerSerializer.deserialize(body_params["caller"])
253+
name = body_params["name"]
254+
255+
result = await ds.render_chart(caller, name)
256+
return web.Response(text=aes_encrypt(json.dumps(result), self.aes_key, self.aes_iv))
257+
258+
async def native_query(self, request: web.Request):
259+
body_params = await request.text()
260+
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
261+
body_params = json.loads(body_params)
262+
263+
ds = await self.customizer.get_datasource()
264+
connection_name = body_params["connectionName"]
265+
native_query = body_params["nativeQuery"]
266+
parameters = body_params["parameters"]
267+
268+
result = await ds.execute_native_query(connection_name, native_query, parameters)
269+
return web.Response(text=aes_encrypt(json.dumps(result), self.aes_key, self.aes_iv))
151270

152271
def start(self):
153272
web.run_app(self.app, host=self.listen_addr, port=int(self.listen_port))

src/agent_rpc/forestadmin/agent_rpc/options.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@
33

44
class RpcOptions(TypedDict):
55
listen_addr: str
6+
aes_key: bytes
7+
aes_iv: bytes

0 commit comments

Comments
 (0)