forked from theroyallab/tabbyAPI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathauth.py
128 lines (99 loc) · 4.03 KB
/
auth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
This method of authorization is pretty insecure, but since TabbyAPI is a local
application, it should be fine.
"""
import secrets
from typing import Optional
from fastapi import Header, HTTPException
from pydantic import BaseModel
import yaml
from logger import init_logger
# Module-level logger used for the auth-disabled warning and the key notice.
logger = init_logger(__name__)
class AuthKeys(BaseModel):
    """
    Authentication keys for the application.

    Holds two keys: ``api_key`` for general API calls and ``admin_key`` for
    administrative tasks. An admin key is also accepted wherever an API key
    is required. ``verify_key`` checks a client-supplied key against them.
    """

    api_key: str
    admin_key: str

    def verify_key(self, test_key: str, key_type: str) -> bool:
        """Verify whether ``test_key`` matches the stored key for ``key_type``.

        Comparisons use ``secrets.compare_digest`` so that response timing
        does not reveal how much of a guessed key was correct.

        Args:
            test_key: The key supplied by the client.
            key_type: Either ``"api_key"`` or ``"admin_key"``.

        Returns:
            True if ``test_key`` is valid for ``key_type``. False otherwise,
            including for an unrecognized ``key_type`` or a non-string
            ``test_key``.
        """

        def matches(expected: str) -> bool:
            # Compare as UTF-8 bytes: compare_digest raises TypeError for str
            # arguments containing non-ASCII characters, which a client can send.
            return secrets.compare_digest(
                test_key.encode("utf-8"), expected.encode("utf-8")
            )

        # A non-string key can never equal a stored key (the old `==` check
        # returned False here too).
        if not isinstance(test_key, str):
            return False

        if key_type == "admin_key":
            return matches(self.admin_key)

        if key_type == "api_key":
            # Admin keys are valid for all API calls. Both comparisons are
            # evaluated so timing does not reveal which key matched.
            api_match = matches(self.api_key)
            admin_match = matches(self.admin_key)
            return api_match or admin_match

        return False
# Keys loaded or generated by load_auth_keys. Stays None until then;
# load_auth_keys returns early without setting it when auth is disabled.
AUTH_KEYS: Optional[AuthKeys] = None
# Set by load_auth_keys. When True, check_api_key and check_admin_key
# return immediately without inspecting any headers.
DISABLE_AUTH: bool = False
def load_auth_keys(disable_from_config: bool):
    """Load the authentication keys from api_tokens.yml.

    If the file does not exist, generate a new random API key and admin key,
    save them to api_tokens.yml, and log both keys.

    Args:
        disable_from_config: When True, authentication is disabled. A warning
            is logged and no keys are loaded or generated.
    """

    global AUTH_KEYS
    global DISABLE_AUTH

    DISABLE_AUTH = disable_from_config
    if disable_from_config:
        logger.warning(
            "Disabling authentication makes your instance vulnerable. "
            "Set the `disable_auth` flag to False in config.yml if you "
            "want to share this instance with others."
        )
        return

    try:
        with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
            auth_keys_dict = yaml.safe_load(auth_file)
    except FileNotFoundError:
        # Only a missing file means "no keys yet". Other OSErrors (e.g.
        # permission denied, or a directory at this path) propagate instead
        # of triggering generation and an attempt to overwrite the file.
        AUTH_KEYS = AuthKeys(
            api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16)
        )
        with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
            yaml.safe_dump(
                AUTH_KEYS.model_dump(), auth_file, default_flow_style=False
            )
    else:
        # Validation errors from a malformed file propagate to the caller.
        AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)

    logger.info(
        f"Your API key is: {AUTH_KEYS.api_key}\n"
        f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
        "If these keys get compromised, make sure to delete api_tokens.yml "
        "and restart the server. Have fun!"
    )
def check_api_key(
    x_api_key: Optional[str] = Header(None),
    authorization: Optional[str] = Header(None),
):
    """Check if the API key is valid.

    The key is read from the ``x_api_key`` header parameter if present.
    Otherwise it is read from an ``authorization`` header of the form
    ``Bearer <key>``. Admin keys are accepted as API keys.

    Returns:
        None when authentication is disabled; otherwise the header value
        that carried the valid key.

    Raises:
        HTTPException: 401 if no key is supplied or the key is invalid.
    """

    # Allow request if auth is disabled
    if DISABLE_AUTH:
        return None

    if x_api_key:
        if not AUTH_KEYS.verify_key(x_api_key, "api_key"):
            raise HTTPException(401, "Invalid API key")
        return x_api_key

    if authorization:
        # RFC 7235 allows one or more spaces between the auth scheme and the
        # credentials, so split on any whitespace run rather than one " ".
        split_key = authorization.split()
        if len(split_key) < 2:
            raise HTTPException(401, "Invalid API key")

        # The auth scheme is case-insensitive.
        if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
            split_key[1], "api_key"
        ):
            raise HTTPException(401, "Invalid API key")

        return authorization

    raise HTTPException(401, "Please provide an API key")
def check_admin_key(
    x_admin_key: Optional[str] = Header(None),
    authorization: Optional[str] = Header(None),
):
    """Check if the admin key is valid.

    The key is read from the ``x_admin_key`` header parameter if present.
    Otherwise it is read from an ``authorization`` header of the form
    ``Bearer <key>``. Only the admin key is accepted.

    Returns:
        None when authentication is disabled; otherwise the header value
        that carried the valid key.

    Raises:
        HTTPException: 401 if no key is supplied or the key is invalid.
    """

    # Allow request if auth is disabled
    if DISABLE_AUTH:
        return None

    if x_admin_key:
        if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"):
            raise HTTPException(401, "Invalid admin key")
        return x_admin_key

    if authorization:
        # RFC 7235 allows one or more spaces between the auth scheme and the
        # credentials, so split on any whitespace run rather than one " ".
        split_key = authorization.split()
        if len(split_key) < 2:
            raise HTTPException(401, "Invalid admin key")

        # The auth scheme is case-insensitive.
        if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
            split_key[1], "admin_key"
        ):
            raise HTTPException(401, "Invalid admin key")

        return authorization

    raise HTTPException(401, "Please provide an admin key")