Skip to content

Commit 28b8b65

Browse files
authored
✨ Feature: add cache strategy option (#161)
1 parent 08d5034 commit 28b8b65

File tree

8 files changed

+114
-35
lines changed

8 files changed

+114
-35
lines changed

docs/usage/configuration.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ github = GitHub(
1212
user_agent="GitHubKit/Python",
1313
follow_redirects=True,
1414
timeout=None,
15+
cache_strategy=None,
1516
http_cache=True,
1617
auto_retry=True,
1718
rest_api_validate_body=True,
@@ -24,13 +25,15 @@ Or, you can pass the config object directly (not recommended):
2425
import httpx
2526
from githubkit import GitHub, Config
2627
from githubkit.retry import RETRY_DEFAULT
28+
from githubkit.cache import DEFAULT_CACHE_STRATEGY
2729

2830
config = Config(
2931
base_url="https://api.github.com/",
3032
accept="application/vnd.github+json",
3133
user_agent="GitHubKit/Python",
3234
follow_redirects=True,
3335
timeout=httpx.Timeout(None),
36+
cache_strategy=DEFAULT_CACHE_STRATEGY,
3437
http_cache=True,
3538
auto_retry=RETRY_DEFAULT,
3639
rest_api_validate_body=True,
@@ -65,6 +68,10 @@ The `follow_redirects` option is used to enable or disable the HTTP redirect fol
6568

6669
The `timeout` option is used to set the request timeout. You can pass a float, `None` or `httpx.Timeout` to this field. By default, the requests will never timeout. See [Timeout](https://www.python-httpx.org/advanced/timeouts/) for more information.
6770

71+
### `cache_strategy`
72+
73+
The `cache_strategy` option defines how to cache the tokens or http responses. You can provide a githubkit built-in cache strategy or a custom one that implements the `BaseCacheStrategy` interface. By default, githubkit uses the `MemCacheStrategy` to cache the data in memory.
74+
6875
### `http_cache`
6976

7077
The `http_cache` option enables the http caching feature powered by [Hishel](https://hishel.com/) for HTTPX. GitHub API limits the number of requests that you can make within a specific amount of time. This feature is useful to reduce the number of requests to GitHub API and avoid hitting the rate limit.

githubkit/auth/app.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from dataclasses import dataclass
2-
from typing import TYPE_CHECKING, Union, Optional
2+
from typing_extensions import LiteralString
33
from datetime import datetime, timezone, timedelta
44
from collections.abc import Generator, AsyncGenerator
5+
from typing import TYPE_CHECKING, Union, ClassVar, Optional
56

67
import httpx
78

89
from githubkit.exception import AuthCredentialError
9-
from githubkit.cache import DEFAULT_CACHE, BaseCache
1010
from githubkit.utils import UNSET, Unset, exclude_unset
1111
from githubkit.compat import model_dump, type_validate_python
1212

@@ -38,10 +38,9 @@ class AppAuth(httpx.Auth):
3838
repositories: Union[Unset, list[str]] = UNSET
3939
repository_ids: Union[Unset, list[int]] = UNSET
4040
permissions: Union[Unset, "AppPermissionsType"] = UNSET
41-
cache: "BaseCache" = DEFAULT_CACHE
4241

43-
JWT_CACHE_KEY = "githubkit:auth:app:{issuer}:jwt"
44-
INSTALLATION_CACHE_KEY = (
42+
JWT_CACHE_KEY: ClassVar[LiteralString] = "githubkit:auth:app:{issuer}:jwt"
43+
INSTALLATION_CACHE_KEY: ClassVar[LiteralString] = (
4544
"githubkit:auth:app:{issuer}:installation:"
4645
"{installation_id}:{permissions}:{repositories}:{repository_ids}"
4746
)
@@ -89,17 +88,19 @@ def _get_jwt_cache_key(self) -> str:
8988
return self.JWT_CACHE_KEY.format(issuer=self.issuer)
9089

9190
def get_jwt(self) -> str:
91+
cache = self.github.config.cache_strategy.get_cache_storage()
9292
cache_key = self._get_jwt_cache_key()
93-
if not (token := self.cache.get(cache_key)):
93+
if not (token := cache.get(cache_key)):
9494
token = self._create_jwt()
95-
self.cache.set(cache_key, token, timedelta(minutes=8))
95+
cache.set(cache_key, token, timedelta(minutes=8))
9696
return token
9797

9898
async def aget_jwt(self) -> str:
99+
cache = self.github.config.cache_strategy.get_async_cache_storage()
99100
cache_key = self._get_jwt_cache_key()
100-
if not (token := await self.cache.aget(cache_key)):
101+
if not (token := await cache.aget(cache_key)):
101102
token = self._create_jwt()
102-
await self.cache.aset(cache_key, token, timedelta(minutes=8))
103+
await cache.aset(cache_key, token, timedelta(minutes=8))
103104
return token
104105

105106
def _build_installation_auth_request(self) -> httpx.Request:
@@ -202,8 +203,9 @@ def sync_auth_flow(
202203
).sync_auth_flow(request)
203204
return
204205

206+
cache = self.github.config.cache_strategy.get_cache_storage()
205207
key = self._get_installation_cache_key()
206-
if not (token := self.cache.get(key)):
208+
if not (token := cache.get(key)):
207209
token_request = self._build_installation_auth_request()
208210
token_request.headers["Authorization"] = f"Bearer {self.get_jwt()}"
209211
response = yield token_request
@@ -213,7 +215,7 @@ def sync_auth_flow(
213215
expire = datetime.strptime(
214216
response.parsed_data.expires_at, "%Y-%m-%dT%H:%M:%SZ"
215217
).replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)
216-
self.cache.set(key, token, expire)
218+
cache.set(key, token, expire)
217219
request.headers["Authorization"] = f"token {token}"
218220
yield request
219221

@@ -239,8 +241,9 @@ async def async_auth_flow(
239241
yield request
240242
return
241243

244+
cache = self.github.config.cache_strategy.get_async_cache_storage()
242245
key = self._get_installation_cache_key()
243-
if not (token := await self.cache.aget(key)):
246+
if not (token := await cache.aget(key)):
244247
token_request = self._build_installation_auth_request()
245248
token_request.headers["Authorization"] = f"Bearer {await self.aget_jwt()}"
246249
response = yield token_request
@@ -250,7 +253,7 @@ async def async_auth_flow(
250253
expire = datetime.strptime(
251254
response.parsed_data.expires_at, "%Y-%m-%dT%H:%M:%SZ"
252255
).replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)
253-
await self.cache.aset(key, token, expire)
256+
await cache.aset(key, token, expire)
254257
request.headers["Authorization"] = f"token {token}"
255258
yield request
256259

@@ -263,7 +266,6 @@ class AppAuthStrategy(BaseAuthStrategy):
263266
private_key: str
264267
client_id: Optional[str] = None
265268
client_secret: Optional[str] = None
266-
cache: "BaseCache" = DEFAULT_CACHE
267269

268270
def __post_init__(self):
269271
# either app_id or client_id must be provided
@@ -288,7 +290,6 @@ def as_installation(
288290
repositories,
289291
repository_ids,
290292
permissions,
291-
self.cache,
292293
)
293294

294295
def as_oauth_app(self) -> OAuthAppAuthStrategy:
@@ -305,7 +306,6 @@ def get_auth_flow(self, github: "GitHubCore") -> httpx.Auth:
305306
self.private_key,
306307
self.client_id,
307308
self.client_secret,
308-
cache=self.cache,
309309
)
310310

311311

@@ -321,7 +321,6 @@ class AppInstallationAuthStrategy(BaseAuthStrategy):
321321
repositories: Union[Unset, list[str]] = UNSET
322322
repository_ids: Union[Unset, list[int]] = UNSET
323323
permissions: Union[Unset, "AppPermissionsType"] = UNSET
324-
cache: "BaseCache" = DEFAULT_CACHE
325324

326325
def __post_init__(self):
327326
# either app_id or client_id must be provided
@@ -336,7 +335,6 @@ def as_app(self) -> AppAuthStrategy:
336335
self.private_key,
337336
self.client_id,
338337
self.client_secret,
339-
self.cache,
340338
)
341339

342340
def get_auth_flow(self, github: "GitHubCore") -> httpx.Auth:
@@ -350,5 +348,4 @@ def get_auth_flow(self, github: "GitHubCore") -> httpx.Auth:
350348
self.repositories,
351349
self.repository_ids,
352350
self.permissions,
353-
cache=self.cache,
354351
)

githubkit/cache/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from .base import BaseCache as BaseCache
22
from .mem_cache import MemCache as MemCache
3+
from .base import AsyncBaseCache as AsyncBaseCache
4+
from .base import BaseCacheStrategy as BaseCacheStrategy
5+
from .mem_cache import MemCacheStrategy as MemCacheStrategy
36

4-
DEFAULT_CACHE = MemCache()
7+
DEFAULT_CACHE_STRATEGY = MemCacheStrategy()

githubkit/cache/base.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,42 @@
22
from typing import Optional
33
from datetime import timedelta
44

5+
from hishel import BaseStorage, AsyncBaseStorage
6+
57

68
class BaseCache(abc.ABC):
79
@abc.abstractmethod
810
def get(self, key: str) -> Optional[str]:
911
raise NotImplementedError
1012

1113
@abc.abstractmethod
12-
async def aget(self, key: str) -> Optional[str]:
14+
def set(self, key: str, value: str, ex: timedelta) -> None:
1315
raise NotImplementedError
1416

17+
18+
class AsyncBaseCache(abc.ABC):
1519
@abc.abstractmethod
16-
def set(self, key: str, value: str, ex: timedelta) -> None:
20+
async def aget(self, key: str) -> Optional[str]:
1721
raise NotImplementedError
1822

1923
@abc.abstractmethod
2024
async def aset(self, key: str, value: str, ex: timedelta) -> None:
2125
raise NotImplementedError
26+
27+
28+
class BaseCacheStrategy(abc.ABC):
29+
@abc.abstractmethod
30+
def get_cache_storage(self) -> BaseCache:
31+
raise NotImplementedError
32+
33+
@abc.abstractmethod
34+
def get_async_cache_storage(self) -> AsyncBaseCache:
35+
raise NotImplementedError
36+
37+
@abc.abstractmethod
38+
def get_hishel_storage(self) -> BaseStorage:
39+
raise NotImplementedError
40+
41+
@abc.abstractmethod
42+
def get_async_hishel_storage(self) -> AsyncBaseStorage:
43+
raise NotImplementedError

githubkit/cache/mem_cache.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
from dataclasses import dataclass
33
from datetime import datetime, timezone, timedelta
44

5-
from .base import BaseCache
5+
from hishel import InMemoryStorage, AsyncInMemoryStorage
6+
7+
from .base import BaseCache, AsyncBaseCache, BaseCacheStrategy
68

79

810
@dataclass(frozen=True)
@@ -11,7 +13,7 @@ class _Item:
1113
expire_at: Optional[datetime] = None
1214

1315

14-
class MemCache(BaseCache):
16+
class MemCache(AsyncBaseCache, BaseCache):
1517
"""Simple Memory Cache with Expiration Support"""
1618

1719
def __init__(self):
@@ -36,3 +38,28 @@ def set(self, key: str, value: str, ex: timedelta) -> None:
3638

3739
async def aset(self, key: str, value: str, ex: timedelta) -> None:
3840
return self.set(key, value, ex)
41+
42+
43+
class MemCacheStrategy(BaseCacheStrategy):
44+
def __init__(self) -> None:
45+
self._cache: Optional[MemCache] = None
46+
self._hishel_storage: Optional[InMemoryStorage] = None
47+
self._hishel_async_storage: Optional[AsyncInMemoryStorage] = None
48+
49+
def get_cache_storage(self) -> MemCache:
50+
if self._cache is None:
51+
self._cache = MemCache()
52+
return self._cache
53+
54+
def get_async_cache_storage(self) -> MemCache:
55+
return self.get_cache_storage()
56+
57+
def get_hishel_storage(self) -> InMemoryStorage:
58+
if self._hishel_storage is None:
59+
self._hishel_storage = InMemoryStorage()
60+
return self._hishel_storage
61+
62+
def get_async_hishel_storage(self) -> AsyncInMemoryStorage:
63+
if self._hishel_async_storage is None:
64+
self._hishel_async_storage = AsyncInMemoryStorage()
65+
return self._hishel_async_storage

githubkit/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .retry import RETRY_DEFAULT
88
from .typing import RetryDecisionFunc
9+
from .cache import DEFAULT_CACHE_STRATEGY, BaseCacheStrategy
910

1011

1112
@dataclass(frozen=True)
@@ -15,6 +16,7 @@ class Config:
1516
user_agent: str
1617
follow_redirects: bool
1718
timeout: httpx.Timeout
19+
cache_strategy: BaseCacheStrategy
1820
http_cache: bool
1921
auto_retry: Optional[RetryDecisionFunc]
2022
rest_api_validate_body: bool
@@ -64,6 +66,12 @@ def build_timeout(
6466
return timeout if isinstance(timeout, httpx.Timeout) else httpx.Timeout(timeout)
6567

6668

69+
def build_cache_strategy(
70+
cache_strategy: Optional[BaseCacheStrategy],
71+
) -> BaseCacheStrategy:
72+
return cache_strategy or DEFAULT_CACHE_STRATEGY
73+
74+
6775
def build_auto_retry(
6876
auto_retry: Union[bool, RetryDecisionFunc] = True,
6977
) -> Optional[RetryDecisionFunc]:
@@ -76,12 +84,14 @@ def build_auto_retry(
7684

7785

7886
def get_config(
87+
*,
7988
base_url: Optional[Union[str, httpx.URL]] = None,
8089
accept_format: Optional[str] = None,
8190
previews: Optional[list[str]] = None,
8291
user_agent: Optional[str] = None,
8392
follow_redirects: bool = True,
8493
timeout: Optional[Union[float, httpx.Timeout]] = None,
94+
cache_strategy: Optional[BaseCacheStrategy] = None,
8595
http_cache: bool = True,
8696
auto_retry: Union[bool, RetryDecisionFunc] = True,
8797
rest_api_validate_body: bool = True,
@@ -92,6 +102,7 @@ def get_config(
92102
build_user_agent(user_agent),
93103
follow_redirects,
94104
build_timeout(timeout),
105+
build_cache_strategy(cache_strategy),
95106
http_cache,
96107
build_auto_retry(auto_retry),
97108
rest_api_validate_body,

0 commit comments

Comments
 (0)