
Commit 78a83a2

cuzethan, pnavab, and evanugarte authored
Add caching of urls (#38)
* added cache.py as part of modules
* added argument for cacheSize and made metrics for cache
* refactor /find endpoint to use new cache
* fixed up a bit of commits
* Update modules/cache.py
  Co-authored-by: pnavab <114110926+pnavab@users.noreply.github.com>
* Update modules/cache.py
  Co-authored-by: pnavab <114110926+pnavab@users.noreply.github.com>
* reworded and reformated code for better readability
* minor fixes
* more small changes
* remove MetricsHandler instance method, cache size uses gauge
* remove call to instance
* linting

---------

Co-authored-by: pnavab <114110926+pnavab@users.noreply.github.com>
Co-authored-by: evan ugarte <evanuxd@gmail.com>
1 parent 18d77fe commit 78a83a2

File tree

4 files changed: +105 -37 lines changed

modules/args.py
modules/cache.py
modules/metrics.py
server.py

modules/args.py

Lines changed: 6 additions & 0 deletions

@@ -32,5 +32,11 @@ def get_args():
         default=0,
         help="increase logging verbosity; can be used multiple times",
     )
+    parser.add_argument(
+        "--cache-size",
+        type=int,
+        default=100,
+        help="number of url redirects to store in memory. defaults to 100"
+    )

     return parser.parse_args()
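For context, a minimal sketch of how the new flag is consumed (it mirrors the wiring added to server.py further down; the command line shown in the comment is only an assumed invocation, not part of this change):

# sketch: argparse maps "--cache-size" to the attribute args.cache_size
# assumed invocation: python server.py --cache-size 500
from modules.args import get_args
from modules.cache import Cache

args = get_args()                # falls back to 100 when --cache-size is omitted
cache = Cache(args.cache_size)   # the /find endpoint reads and populates this cache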

modules/cache.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+from collections import OrderedDict
+import logging
+
+from modules.metrics import MetricsHandler
+
+class Cache:
+    def __init__(self, cacheSize):
+        self.dict = OrderedDict()
+        self.size = cacheSize
+
+    def find(self, alias):
+        if alias not in self.dict:
+            MetricsHandler.cache_misses.inc()
+            return None
+
+        self.dict.move_to_end(alias, last=False)  # move alias to front of cache
+        logging.debug(f"alias: '{alias}' is grabbed from mapping")
+        MetricsHandler.cache_hits.inc()
+        return self.dict[alias]
+
+    def add(self, alias, url_output):
+        if len(self.dict) == self.size:
+            data = self.dict.popitem()  # remove least used alias if size reaches max
+            logging.debug(f"alias: {data[0]} has been removed from cache")
+        self.dict[alias] = url_output
+        MetricsHandler.cache_size.set(len(self.dict))
+        logging.debug("set alias: '" + alias + "' to mapping")

modules/metrics.py

Lines changed: 16 additions & 12 deletions

@@ -21,6 +21,21 @@ class Metrics(enum.Enum):
         prometheus_client.Counter,
         ['path', 'code'],
     )
+    CACHE_SIZE = (
+        "cache_size",
+        "Size of LRU cache for /find",
+        prometheus_client.Gauge,
+    )
+    CACHE_HITS = (
+        "cache_hits",
+        "Number of times cache is used",
+        prometheus_client.Counter,
+    )
+    CACHE_MISSES = (
+        "cache_misses",
+        "Number of times cache is not used",
+        prometheus_client.Counter,
+    )

     def __init__(self, title, description, prometheus_type, labels=()):
         # we use the above default value for labels because it matches what's used

@@ -33,11 +48,7 @@ def __init__(self, title, description, prometheus_type, labels=()):


 class MetricsHandler:
-    _instance = None
-
-    def __init__(self):
-        raise RuntimeError('Call MetricsHandler.instance() instead')
-
+    @classmethod
     def init(self) -> None:
         for metric in Metrics:
             setattr(

@@ -47,10 +58,3 @@ def init(self) -> None:
                     metric.title, metric.description, labelnames=metric.labels
                 ),
             )
-
-    @classmethod
-    def instance(cls):
-        if cls._instance is None:
-            cls._instance = cls.__new__(cls)
-            cls.init(cls)
-        return cls._instance
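Roughly, MetricsHandler.init() walks the Metrics enum and attaches one prometheus_client object per entry as a class attribute named after the metric title, which is what lets modules/cache.py call MetricsHandler.cache_hits and friends directly. A small sketch of the resulting usage (assuming prometheus_client is installed):

import prometheus_client
from modules.metrics import MetricsHandler

MetricsHandler.init()

MetricsHandler.cache_hits.inc()       # Counter defined by Metrics.CACHE_HITS
MetricsHandler.cache_misses.inc()     # Counter defined by Metrics.CACHE_MISSES
MetricsHandler.cache_size.set(42)     # Gauge defined by Metrics.CACHE_SIZE

# the /metrics endpoint returns this same text exposition
print(prometheus_client.generate_latest().decode())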

server.py

Lines changed: 56 additions & 25 deletions

@@ -15,6 +15,7 @@
 from modules.constants import HttpResponse, http_code_to_enum
 from modules.metrics import MetricsHandler
 from modules.sqlite_helpers import increment_used_column
+from modules.cache import Cache


 app = FastAPI()

@@ -23,17 +24,18 @@

 app.add_middleware(
     CORSMiddleware,
-    allow_origins=['*'],
-    allow_methods=['*'],
-    allow_headers=['*'],
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
 )

-metrics_handler = MetricsHandler.instance()
+cache = Cache(args.cache_size)

-#maybe create the table if it doesnt already exist
+# maybe create the table if it doesnt already exist
 DATABASE_FILE = args.database_file_path
 sqlite_helpers.maybe_create_table(DATABASE_FILE)

+
 # middleware to get metrics on HTTP response codes
 @app.middleware("http")
 async def track_response_codes(request: Request, call_next):

@@ -42,38 +44,45 @@ async def track_response_codes(request: Request, call_next):
     MetricsHandler.http_code.labels(request.url.path, status_code).inc()
     return response

+
 @app.post("/create_url")
 async def create_url(request: Request):
     urljson = await request.json()
     logging.debug(f"/create_url called with body: {urljson}")
     alias = None

     try:
-        alias = urljson.get('alias')
+        alias = urljson.get("alias")
         if alias is None:
             if args.disable_random_alias:
                 raise KeyError("alias must be specified")
             else:
-                alias = generate_alias(urljson['url'])
+                alias = generate_alias(urljson["url"])
         if not alias.isalnum():
             raise ValueError("alias must only contain alphanumeric characters")

         with MetricsHandler.query_time.labels("create").time():
-            response = sqlite_helpers.insert_url(DATABASE_FILE, urljson['url'], alias)
+            response = sqlite_helpers.insert_url(DATABASE_FILE, urljson["url"], alias)
         if response is not None:
             MetricsHandler.url_count.inc(1)
-            return { "url": urljson['url'], "alias": alias, "created_at": response}
+            return {"url": urljson["url"], "alias": alias, "created_at": response}
         else:
-            raise HTTPException(status_code=HttpResponse.CONFLICT.code )
+            raise HTTPException(status_code=HttpResponse.CONFLICT.code)
     except KeyError:
         logging.exception("returning 400 due to missing key")
         raise HTTPException(status_code=HttpResponse.BAD_REQUEST.code)
     except ValueError:
-        logging.exception(f"returning 422 due to invalid alias of \"{alias}\"")
+        logging.exception(f'returning 422 due to invalid alias of "{alias}"')
         raise HTTPException(status_code=HttpResponse.INVALID_ARGUMENT_EXCEPTION.code)

+
 @app.get("/list")
-async def get_urls(search: Optional[str] = None, page: int = 0, sort_by: str = "created_at", order: str = "DESC"):
+async def get_urls(
+    search: Optional[str] = None,
+    page: int = 0,
+    sort_by: str = "created_at",
+    order: str = "DESC",
+):
     valid_sort_attributes = {"id", "url", "alias", "created_at", "used"}
     if order not in {"DESC", "ASC"}:
         raise HTTPException(status_code=400, detail="Invalid order")

@@ -82,20 +91,36 @@ async def get_urls(search: Optional[str] = None, page: int = 0, sort_by: str = "created_at", order: str = "DESC"):
     if page < 0:
         raise HTTPException(status_code=400, detail="Invalid page number")
     if search and not search.isalnum():
-        raise HTTPException(status_code=400, detail=f'search term "{search}" is invalid. only alphanumeric chars are allowed')
+        raise HTTPException(
+            status_code=400,
+            detail=f'search term "{search}" is invalid. only alphanumeric chars are allowed',
+        )
     with MetricsHandler.query_time.labels("list").time():
-        urls = sqlite_helpers.get_urls(DATABASE_FILE, page, search=search, sort_by=sort_by, order=order)
+        urls = sqlite_helpers.get_urls(
+            DATABASE_FILE, page, search=search, sort_by=sort_by, order=order
+        )
     total_urls = sqlite_helpers.get_number_of_entries(DATABASE_FILE, search=search)
-    return {"data": urls, "total": total_urls, "rows_per_page": sqlite_helpers.ROWS_PER_PAGE}
+    return {
+        "data": urls,
+        "total": total_urls,
+        "rows_per_page": sqlite_helpers.ROWS_PER_PAGE,
+    }
+

 @app.get("/find/{alias}")
 async def get_url(alias: str):
     logging.debug(f"/find called with alias: {alias}")
+    url_output = cache.find(alias)  # try to find url in cache
+    if url_output is not None:
+        alias_queue.put(alias)
+        return RedirectResponse(url_output)
+
     with MetricsHandler.query_time.labels("find").time():
         url_output = sqlite_helpers.get_url(DATABASE_FILE, alias)
-
     if url_output is None:
         raise HTTPException(status_code=HttpResponse.NOT_FOUND.code)
+    cache.add(alias, url_output)  # else, adds url and alias to cache
+
     alias_queue.put(alias)
     return RedirectResponse(url_output)

@@ -104,15 +129,19 @@ async def get_url(alias: str):
 async def delete_url(alias: str):
     logging.debug(f"/delete called with alias: {alias}")
     with MetricsHandler.query_time.labels("delete").time():
-        if(sqlite_helpers.delete_url(DATABASE_FILE, alias)):
-            return {"message": "URL deleted successfully"}
-        else:
-            raise HTTPException(status_code=HttpResponse.NOT_FOUND.code)
+        if sqlite_helpers.delete_url(DATABASE_FILE, alias):
+            return {"message": "URL deleted successfully"}
+        else:
+            raise HTTPException(status_code=HttpResponse.NOT_FOUND.code)
+

 @app.exception_handler(HTTPException)
 async def http_exception_handler(request, exc):
     status_code_enum = http_code_to_enum[exc.status_code]
-    return HTMLResponse(content=status_code_enum.content, status_code=status_code_enum.code)
+    return HTMLResponse(
+        content=status_code_enum.content, status_code=status_code_enum.code
+    )
+

 @app.get("/metrics")
 def get_metrics():

@@ -121,20 +150,22 @@ def get_metrics():
         content=prometheus_client.generate_latest(),
     )

+
 logging.Formatter.converter = time.gmtime

 logging.basicConfig(
     # in mondo we trust
     format="%(asctime)s.%(msecs)03dZ %(levelname)s:%(name)s:%(message)s",
     datefmt="%Y-%m-%dT%H:%M:%S",
-    level= logging.ERROR - (args.verbose*10),
+    level=logging.ERROR - (args.verbose * 10),
 )

+
 def consumer():
     while True:
         alias = alias_queue.get()
-        if alias is None:
-            break
+        if alias is None:
+            break
         try:
             with MetricsHandler.query_time.labels("increment_used").time():
                 increment_used_column(DATABASE_FILE, alias)

@@ -144,7 +175,6 @@ def consumer():
         alias_queue.task_done()


-
 # we have a separate __name__ check here due to how FastAPI starts
 # a server. the file is first ran (where __name__ == "__main__")
 # and then calls `uvicorn.run`. the call to run() reruns the file,

@@ -155,6 +185,7 @@ def consumer():
 # server uses
 if __name__ == "server":
     initial_url_count = sqlite_helpers.get_number_of_entries(DATABASE_FILE)
+    MetricsHandler.init()
     MetricsHandler.url_count.inc(initial_url_count)
     consumer_thread = Thread(target=consumer, daemon=True)
     consumer_thread.start()
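The net effect on /find is a read-through cache: the first lookup for an alias falls through to SQLite and populates the cache, later lookups are answered from memory, and both paths still enqueue the alias so the consumer thread can bump its used count. A rough end-to-end sketch (assuming a locally running instance on port 8000 and the requests package; both are assumptions, not part of this change):

import requests

BASE = "http://localhost:8000"  # assumed local dev address

requests.post(f"{BASE}/create_url", json={"url": "https://github.com", "alias": "gh"})

first = requests.get(f"{BASE}/find/gh", allow_redirects=False)   # miss: read from SQLite, then cached
second = requests.get(f"{BASE}/find/gh", allow_redirects=False)  # hit: served from the in-memory cache
print(first.headers["location"], second.headers["location"])     # both redirect to https://github.com

print(requests.get(f"{BASE}/metrics").text)  # cache_hits, cache_misses, cache_size appear here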
