1515from modules .constants import HttpResponse , http_code_to_enum
1616from modules .metrics import MetricsHandler
1717from modules .sqlite_helpers import increment_used_column
18+ from modules .cache import Cache
1819
1920
2021app = FastAPI ()
2324
2425app .add_middleware (
2526 CORSMiddleware ,
26- allow_origins = ['*' ],
27- allow_methods = ['*' ],
28- allow_headers = ['*' ],
27+ allow_origins = ["*" ],
28+ allow_methods = ["*" ],
29+ allow_headers = ["*" ],
2930)
3031
31- metrics_handler = MetricsHandler . instance ( )
32+ cache = Cache ( args . cache_size )
3233
33- #maybe create the table if it doesnt already exist
34+ # maybe create the table if it doesnt already exist
3435DATABASE_FILE = args .database_file_path
3536sqlite_helpers .maybe_create_table (DATABASE_FILE )
3637
38+
3739# middleware to get metrics on HTTP response codes
3840@app .middleware ("http" )
3941async def track_response_codes (request : Request , call_next ):
@@ -42,38 +44,45 @@ async def track_response_codes(request: Request, call_next):
4244 MetricsHandler .http_code .labels (request .url .path , status_code ).inc ()
4345 return response
4446
47+
4548@app .post ("/create_url" )
4649async def create_url (request : Request ):
4750 urljson = await request .json ()
4851 logging .debug (f"/create_url called with body: { urljson } " )
4952 alias = None
5053
5154 try :
52- alias = urljson .get (' alias' )
55+ alias = urljson .get (" alias" )
5356 if alias is None :
5457 if args .disable_random_alias :
5558 raise KeyError ("alias must be specified" )
5659 else :
57- alias = generate_alias (urljson [' url' ])
60+ alias = generate_alias (urljson [" url" ])
5861 if not alias .isalnum ():
5962 raise ValueError ("alias must only contain alphanumeric characters" )
6063
6164 with MetricsHandler .query_time .labels ("create" ).time ():
62- response = sqlite_helpers .insert_url (DATABASE_FILE , urljson [' url' ], alias )
65+ response = sqlite_helpers .insert_url (DATABASE_FILE , urljson [" url" ], alias )
6366 if response is not None :
6467 MetricsHandler .url_count .inc (1 )
65- return { "url" : urljson [' url' ], "alias" : alias , "created_at" : response }
68+ return {"url" : urljson [" url" ], "alias" : alias , "created_at" : response }
6669 else :
67- raise HTTPException (status_code = HttpResponse .CONFLICT .code )
70+ raise HTTPException (status_code = HttpResponse .CONFLICT .code )
6871 except KeyError :
6972 logging .exception ("returning 400 due to missing key" )
7073 raise HTTPException (status_code = HttpResponse .BAD_REQUEST .code )
7174 except ValueError :
72- logging .exception (f" returning 422 due to invalid alias of \ "{ alias } \" " )
75+ logging .exception (f' returning 422 due to invalid alias of "{ alias } "' )
7376 raise HTTPException (status_code = HttpResponse .INVALID_ARGUMENT_EXCEPTION .code )
7477
78+
7579@app .get ("/list" )
76- async def get_urls (search : Optional [str ] = None , page : int = 0 , sort_by : str = "created_at" , order : str = "DESC" ):
80+ async def get_urls (
81+ search : Optional [str ] = None ,
82+ page : int = 0 ,
83+ sort_by : str = "created_at" ,
84+ order : str = "DESC" ,
85+ ):
7786 valid_sort_attributes = {"id" , "url" , "alias" , "created_at" , "used" }
7887 if order not in {"DESC" , "ASC" }:
7988 raise HTTPException (status_code = 400 , detail = "Invalid order" )
@@ -82,20 +91,36 @@ async def get_urls(search: Optional[str] = None, page: int = 0, sort_by: str = "
8291 if page < 0 :
8392 raise HTTPException (status_code = 400 , detail = "Invalid page number" )
8493 if search and not search .isalnum ():
85- raise HTTPException (status_code = 400 , detail = f'search term "{ search } " is invalid. only alphanumeric chars are allowed' )
94+ raise HTTPException (
95+ status_code = 400 ,
96+ detail = f'search term "{ search } " is invalid. only alphanumeric chars are allowed' ,
97+ )
8698 with MetricsHandler .query_time .labels ("list" ).time ():
87- urls = sqlite_helpers .get_urls (DATABASE_FILE , page , search = search , sort_by = sort_by , order = order )
99+ urls = sqlite_helpers .get_urls (
100+ DATABASE_FILE , page , search = search , sort_by = sort_by , order = order
101+ )
88102 total_urls = sqlite_helpers .get_number_of_entries (DATABASE_FILE , search = search )
89- return {"data" : urls , "total" : total_urls , "rows_per_page" : sqlite_helpers .ROWS_PER_PAGE }
103+ return {
104+ "data" : urls ,
105+ "total" : total_urls ,
106+ "rows_per_page" : sqlite_helpers .ROWS_PER_PAGE ,
107+ }
108+
90109
91110@app .get ("/find/{alias}" )
92111async def get_url (alias : str ):
93112 logging .debug (f"/find called with alias: { alias } " )
113+ url_output = cache .find (alias ) # try to find url in cache
114+ if url_output is not None :
115+ alias_queue .put (alias )
116+ return RedirectResponse (url_output )
117+
94118 with MetricsHandler .query_time .labels ("find" ).time ():
95119 url_output = sqlite_helpers .get_url (DATABASE_FILE , alias )
96-
97120 if url_output is None :
98121 raise HTTPException (status_code = HttpResponse .NOT_FOUND .code )
122+ cache .add (alias , url_output ) # else, adds url and alias to cache
123+
99124 alias_queue .put (alias )
100125 return RedirectResponse (url_output )
101126
@@ -104,15 +129,19 @@ async def get_url(alias: str):
104129async def delete_url (alias : str ):
105130 logging .debug (f"/delete called with alias: { alias } " )
106131 with MetricsHandler .query_time .labels ("delete" ).time ():
107- if (sqlite_helpers .delete_url (DATABASE_FILE , alias )):
108- return {"message" : "URL deleted successfully" }
109- else :
110- raise HTTPException (status_code = HttpResponse .NOT_FOUND .code )
132+ if sqlite_helpers .delete_url (DATABASE_FILE , alias ):
133+ return {"message" : "URL deleted successfully" }
134+ else :
135+ raise HTTPException (status_code = HttpResponse .NOT_FOUND .code )
136+
111137
112138@app .exception_handler (HTTPException )
113139async def http_exception_handler (request , exc ):
114140 status_code_enum = http_code_to_enum [exc .status_code ]
115- return HTMLResponse (content = status_code_enum .content , status_code = status_code_enum .code )
141+ return HTMLResponse (
142+ content = status_code_enum .content , status_code = status_code_enum .code
143+ )
144+
116145
117146@app .get ("/metrics" )
118147def get_metrics ():
@@ -121,20 +150,22 @@ def get_metrics():
121150 content = prometheus_client .generate_latest (),
122151 )
123152
153+
124154logging .Formatter .converter = time .gmtime
125155
126156logging .basicConfig (
127157 # in mondo we trust
128158 format = "%(asctime)s.%(msecs)03dZ %(levelname)s:%(name)s:%(message)s" ,
129159 datefmt = "%Y-%m-%dT%H:%M:%S" ,
130- level = logging .ERROR - (args .verbose * 10 ),
160+ level = logging .ERROR - (args .verbose * 10 ),
131161)
132162
163+
133164def consumer ():
134165 while True :
135166 alias = alias_queue .get ()
136- if alias is None :
137- break
167+ if alias is None :
168+ break
138169 try :
139170 with MetricsHandler .query_time .labels ("increment_used" ).time ():
140171 increment_used_column (DATABASE_FILE , alias )
@@ -144,7 +175,6 @@ def consumer():
144175 alias_queue .task_done ()
145176
146177
147-
148178# we have a separate __name__ check here due to how FastAPI starts
149179# a server. the file is first ran (where __name__ == "__main__")
150180# and then calls `uvicorn.run`. the call to run() reruns the file,
@@ -155,6 +185,7 @@ def consumer():
155185# server uses
156186if __name__ == "server" :
157187 initial_url_count = sqlite_helpers .get_number_of_entries (DATABASE_FILE )
188+ MetricsHandler .init ()
158189 MetricsHandler .url_count .inc (initial_url_count )
159190 consumer_thread = Thread (target = consumer , daemon = True )
160191 consumer_thread .start ()
0 commit comments