11import asyncio
22import json
3- import os
4- import pprint
5- import signal
6- import sys
7- import threading
8- import time
93from enum import Enum
4+ from uuid import UUID
105
116from aiohttp import web
127from aiohttp_sse import sse_response
1611
1712# from forestadmin.agent_rpc.services.datasource import DatasourceService
1813from forestadmin .agent_toolkit .agent import Agent
19- from forestadmin .agent_toolkit .forest_logger import ForestLogger
20- from forestadmin .datasource_toolkit .interfaces .fields import is_column
14+ from forestadmin .datasource_toolkit .interfaces .fields import PrimitiveType
15+ from forestadmin .datasource_toolkit .utils .schema import SchemaUtils
16+ from forestadmin .rpc_common .hmac import is_valid_hmac
17+ from forestadmin .rpc_common .serializers .actions import ActionFormSerializer , ActionResultSerializer
18+ from forestadmin .rpc_common .serializers .aes import aes_decrypt , aes_encrypt
2119from forestadmin .rpc_common .serializers .collection .aggregation import AggregationSerializer
2220from forestadmin .rpc_common .serializers .collection .filter import (
2321 FilterSerializer ,
@@ -56,27 +54,46 @@ def __init__(self, options: RpcOptions):
5654 # self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
5755 # self.server = grpc.aio.server()
5856 self .listen_addr , self .listen_port = options ["listen_addr" ].rsplit (":" , 1 )
59- self .app = web .Application ()
57+ self .app = web .Application (middlewares = [ self . hmac_middleware ] )
6058 # self.server.add_insecure_port(options["listen_addr"])
6159 options ["skip_schema_update" ] = True
6260 options ["env_secret" ] = "f" * 64
6361 options ["server_url" ] = "http://fake"
64- options ["auth_secret" ] = "fake "
62+ # options["auth_secret"] = "f48186505a3c5d62c27743126d6a76c1dd8b3e2d8897de19 "
6563 options ["schema_path" ] = "./.forestadmin-schema.json"
66-
6764 super ().__init__ (options )
65+
66+ self .aes_key = self .options ["auth_secret" ][:16 ].encode ()
67+ self .aes_iv = self .options ["auth_secret" ][- 16 :].encode ()
6868 self ._server_stop = False
6969 self .setup_routes ()
7070 # signal.signal(signal.SIGUSR1, self.stop_handler)
7171
72+ @web .middleware
73+ async def hmac_middleware (self , request : web .Request , handler ):
74+ if request .method == "POST" :
75+ body = await request .read ()
76+ if not is_valid_hmac (
77+ self .options ["auth_secret" ].encode (), body , request .headers .get ("X-FOREST-HMAC" , "" ).encode ("utf-8" )
78+ ):
79+ return web .Response (status = 401 )
80+ return await handler (request )
81+
7282 def setup_routes (self ):
83+ # self.app.middlewares.append(self.hmac_middleware)
7384 self .app .router .add_route ("GET" , "/sse" , self .sse_handler )
7485 self .app .router .add_route ("GET" , "/schema" , self .schema )
7586 self .app .router .add_route ("POST" , "/collection/list" , self .collection_list )
7687 self .app .router .add_route ("POST" , "/collection/create" , self .collection_create )
7788 self .app .router .add_route ("POST" , "/collection/update" , self .collection_update )
7889 self .app .router .add_route ("POST" , "/collection/delete" , self .collection_delete )
7990 self .app .router .add_route ("POST" , "/collection/aggregate" , self .collection_aggregate )
91+ self .app .router .add_route ("POST" , "/collection/get-form" , self .collection_get_form )
92+ self .app .router .add_route ("POST" , "/collection/execute" , self .collection_execute )
93+ self .app .router .add_route ("POST" , "/collection/render-chart" , self .collection_render_chart )
94+
95+ self .app .router .add_route ("POST" , "/execute-native-query" , self .native_query )
96+ self .app .router .add_route ("POST" , "/render-chart" , self .render_chart )
8097 self .app .router .add_route ("GET" , "/" , lambda _ : web .Response (text = "OK" ))
8198
8299 async def sse_handler (self , request : web .Request ) -> web .StreamResponse :
@@ -103,11 +120,14 @@ async def collection_list(self, request: web.Request):
103120
104121 records = await collection .list (caller , filter_ , projection )
105122 records = [RecordSerializer .serialize (record ) for record in records ]
106- return web .json_response ( records )
123+ return web .Response ( text = aes_encrypt ( json . dumps ( records ), self . aes_key , self . aes_iv ) )
107124
108125 async def collection_create (self , request : web .Request ):
109- body_params = await request .json ()
126+ body_params = await request .text ()
127+ body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
128+ body_params = json .loads (body_params )
110129 ds = await self .customizer .get_datasource ()
130+
111131 collection = ds .get_collection (body_params ["collectionName" ])
112132 caller = CallerSerializer .deserialize (body_params ["caller" ])
113133 data = [RecordSerializer .deserialize (r , collection ) for r in body_params ["data" ]]
@@ -117,7 +137,10 @@ async def collection_create(self, request: web.Request):
117137 return web .json_response (records )
118138
119139 async def collection_update (self , request : web .Request ):
120- body_params = await request .json ()
140+ body_params = await request .text ()
141+ body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
142+ body_params = json .loads (body_params )
143+
121144 ds = await self .customizer .get_datasource ()
122145 collection = ds .get_collection (body_params ["collectionName" ])
123146 caller = CallerSerializer .deserialize (body_params ["caller" ])
@@ -147,7 +170,103 @@ async def collection_aggregate(self, request: web.Request):
147170
148171 records = await collection .aggregate (caller , filter_ , aggregation )
149172 # records = [RecordSerializer.serialize(record) for record in records]
150- return web .json_response (records )
173+ return web .Response (text = aes_encrypt (json .dumps (records ), self .aes_key , self .aes_iv ))
174+
175+ async def collection_get_form (self , request : web .Request ):
176+ body_params = await request .text ()
177+ body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
178+ body_params = json .loads (body_params )
179+
180+ ds = await self .customizer .get_datasource ()
181+ collection = ds .get_collection (body_params ["collectionName" ])
182+
183+ caller = CallerSerializer .deserialize (body_params ["caller" ]) if body_params ["caller" ] else None
184+ action_name = body_params ["actionName" ]
185+ filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) if body_params ["filter" ] else None
186+ data = body_params ["data" ]
187+ meta = body_params ["meta" ]
188+
189+ form = await collection .get_form (caller , action_name , data , filter_ , meta )
190+ return web .Response (
191+ text = aes_encrypt (json .dumps (ActionFormSerializer .serialize (form )), self .aes_key , self .aes_iv )
192+ )
193+
194+ async def collection_execute (self , request : web .Request ):
195+ body_params = await request .text ()
196+ body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
197+ body_params = json .loads (body_params )
198+
199+ ds = await self .customizer .get_datasource ()
200+ collection = ds .get_collection (body_params ["collectionName" ])
201+
202+ caller = CallerSerializer .deserialize (body_params ["caller" ]) if body_params ["caller" ] else None
203+ action_name = body_params ["actionName" ]
204+ filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) if body_params ["filter" ] else None
205+ data = body_params ["data" ]
206+
207+ result = await collection .execute (caller , action_name , data , filter_ )
208+ return web .Response (
209+ text = aes_encrypt (json .dumps (ActionResultSerializer .serialize (result )), self .aes_key , self .aes_iv )
210+ )
211+
212+ async def collection_render_chart (self , request : web .Request ):
213+ body_params = await request .text ()
214+ body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
215+ body_params = json .loads (body_params )
216+
217+ ds = await self .customizer .get_datasource ()
218+ collection = ds .get_collection (body_params ["collectionName" ])
219+
220+ caller = CallerSerializer .deserialize (body_params ["caller" ])
221+ name = body_params ["name" ]
222+ record_id = body_params ["recordId" ]
223+ ret = []
224+ for i , value in enumerate (record_id ):
225+ type_record_id = collection .schema ["fields" ][SchemaUtils .get_primary_keys (collection .schema )[i ]][
226+ "column_type"
227+ ]
228+
229+ if type_record_id == PrimitiveType .DATE :
230+ ret .append (value .fromisoformat ())
231+ elif type_record_id == PrimitiveType .DATE_ONLY :
232+ ret .append (value .fromisoformat ())
233+ elif type_record_id == PrimitiveType .DATE :
234+ ret .append (value .fromisoformat ())
235+ elif type_record_id == PrimitiveType .POINT :
236+ ret .append ((value [0 ], value [1 ]))
237+ elif type_record_id == PrimitiveType .UUID :
238+ ret .append (UUID (value ))
239+ else :
240+ ret .append (value )
241+
242+ result = await collection .render_chart (caller , name , record_id )
243+ return web .Response (text = aes_encrypt (json .dumps (result ), self .aes_key , self .aes_iv ))
244+
245+ async def render_chart (self , request : web .Request ):
246+ body_params = await request .text ()
247+ body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
248+ body_params = json .loads (body_params )
249+
250+ ds = await self .customizer .get_datasource ()
251+
252+ caller = CallerSerializer .deserialize (body_params ["caller" ])
253+ name = body_params ["name" ]
254+
255+ result = await ds .render_chart (caller , name )
256+ return web .Response (text = aes_encrypt (json .dumps (result ), self .aes_key , self .aes_iv ))
257+
258+ async def native_query (self , request : web .Request ):
259+ body_params = await request .text ()
260+ body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
261+ body_params = json .loads (body_params )
262+
263+ ds = await self .customizer .get_datasource ()
264+ connection_name = body_params ["connectionName" ]
265+ native_query = body_params ["nativeQuery" ]
266+ parameters = body_params ["parameters" ]
267+
268+ result = await ds .execute_native_query (connection_name , native_query , parameters )
269+ return web .Response (text = aes_encrypt (json .dumps (result ), self .aes_key , self .aes_iv ))
151270
152271 def start (self ):
153272 web .run_app (self .app , host = self .listen_addr , port = int (self .listen_port ))
0 commit comments