Skip to content

Commit 9f6aaaf

Browse files
Enhance PromptCache and SysPrompt to support versioning for cache operations
1 parent 0b179d2 commit 9f6aaaf

File tree

2 files changed

+40
-30
lines changed

2 files changed

+40
-30
lines changed

sysprompt/cache.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,40 +17,43 @@ def __init__(self, default_ttl: int = 3600):
1717
self._expiration: dict = {}
1818
self._default_ttl: int = default_ttl
1919

20-
def get(self, key: str) -> Optional[Any]:
20+
def get(self, key: str, version: Optional[str] = None) -> Optional[Any]:
2121
"""
2222
Get a prompt from the cache.
2323
"""
24+
cache_key = self._calculate_key(key, version)
2425
if self._default_ttl == 0:
2526
return None # Cache is disabled
26-
if key in self._cache:
27-
if self._is_expired(key):
28-
self.delete(key)
27+
if cache_key in self._cache:
28+
if self._is_expired(cache_key):
29+
self.delete(cache_key)
2930
return None
30-
self._cache.move_to_end(key)
31-
return self._cache[key]
31+
self._cache.move_to_end(cache_key)
32+
return self._cache[cache_key]
3233
return None
3334

34-
def set(self, key: str, value: Any, ttl: Optional[int] = None):
35+
def set(self, key: str, version: Optional[str] = None, value: Any = None, ttl: Optional[int] = None):
3536
"""
3637
Set a prompt in the cache.
3738
"""
3839
ttl = ttl if ttl is not None else self._default_ttl
40+
cache_key = self._calculate_key(key, version)
3941
if ttl == 0:
4042
return # Don't cache if TTL is 0
41-
if key in self._cache:
42-
del self._cache[key]
43-
self._cache[key] = value
44-
self._expiration[key] = time.time() + (ttl or self._default_ttl)
45-
self._cache.move_to_end(key)
43+
if cache_key in self._cache:
44+
del self._cache[cache_key]
45+
self._cache[cache_key] = value
46+
self._expiration[cache_key] = time.time() + (ttl or self._default_ttl)
47+
self._cache.move_to_end(cache_key)
4648

47-
def delete(self, key: str):
49+
def delete(self, key: str, version: Optional[str] = None):
4850
"""
4951
Delete a prompt from the cache.
5052
"""
51-
if key in self._cache:
52-
del self._cache[key]
53-
del self._expiration[key]
53+
cache_key = self._calculate_key(key, version)
54+
if cache_key in self._cache:
55+
del self._cache[cache_key]
56+
del self._expiration[cache_key]
5457

5558
def clear(self):
5659
"""
@@ -71,6 +74,9 @@ def set_default_ttl(self, ttl: int):
7174
"""
7275
self._default_ttl = ttl
7376

77+
def _calculate_key(self, key: str, version: Optional[str] = None) -> str:
78+
return f"{key}:{version}" if version else key
79+
7480
def __len__(self):
7581
"""
7682
Gets the number of prompts in the cache.

sysprompt/main.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ class SysPrompt:
2222
"""
2323

2424
BASE_URL: str = "https://api.sysprompt.com"
25-
PROMPT_RETRIEVAL_PATH: str = "sdk/prompts/{code_id}"
26-
PROMPT_LOGGING_PATH: str = "sdk/logs/prompt"
25+
PROMPT_RETRIEVAL_PATH: str = "sdk/prompts/{code_id}/"
26+
PROMPT_LOGGING_PATH: str = "sdk/logs/prompt/"
2727
API_VERSION: str = "v1"
2828

2929
def __init__(self, api_key: Optional[str], raise_on_error: bool = False, default_cache_ttl: int = 900):
@@ -71,36 +71,37 @@ def preload_prompts(self, code_ids: Optional[List[str]] = None):
7171
"""
7272
prompts: List[dict] = []
7373
if code_ids is None:
74-
prompts = list(self._api_request('GET', 'sdk/prompts'))
74+
prompts = list(self._api_request('GET', 'sdk/prompts/'))
7575
else:
76-
prompts = list(self._api_request('POST', 'sdk/prompts', {'code_ids': code_ids}))
76+
prompts = list(self._api_request('POST', 'sdk/prompts/', {'code_ids': code_ids}))
7777

7878
for prompt in prompts:
79-
self._prompts_cache.set(prompt['code_id'], prompt)
79+
self._prompts_cache.set(prompt['code_id'], None, prompt)
8080

81-
def get_prompt(self, code_id: str) -> str | dict | list:
81+
def get_prompt(self, code_id: str, version: Optional[str] = None) -> str | dict | list:
8282
"""
8383
Retrieve a prompt from the SysPrompt API.
8484
8585
Args:
8686
code_id (str): The unique identifier for the prompt.
87-
87+
version (Optional[str]): The version of the prompt to retrieve, empty for the default version.
8888
Returns:
8989
str | dict | list: The prompt data.
9090
9191
Raises:
9292
SysPromptException: If the API request fails.
9393
"""
94-
cached_prompt = self._prompts_cache.get(code_id)
94+
cached_prompt = self._prompts_cache.get(code_id, version)
9595
if cached_prompt is not None:
9696
return cached_prompt.get('content', [])
9797

9898
try:
9999
if not code_id:
100100
raise SysPromptException("Prompt code ID is required. Got empty string.")
101-
prompt = self._api_request('GET', self.PROMPT_RETRIEVAL_PATH.format(code_id=code_id))
101+
version_string = f"?version={version}" if version else ""
102+
prompt = self._api_request('GET', f"{self.PROMPT_RETRIEVAL_PATH.format(code_id=code_id)}{version_string}")
102103
if prompt:
103-
self._prompts_cache.set(code_id, prompt)
104+
self._prompts_cache.set(code_id, version, prompt)
104105
return prompt.get('content', [])
105106
except Exception as e:
106107
if self._raise_on_error:
@@ -113,21 +114,23 @@ def compile(
113114
self,
114115
prompt_id: str | None = None,
115116
params: dict = {},
116-
prompt_object: None | str | dict | list = None
117+
version: Optional[str] = None,
118+
prompt_object: None | str | dict | list = None,
117119
) -> str | dict | list:
118120
"""
119121
Compile a prompt with the given parameters.
120122
You can either provide a prompt code ID or a prompt object.
121123
122124
Args:
125+
prompt_id (str): The unique identifier for the prompt.
123126
params (dict): The parameters to substitute into the prompt.
124-
prompt_code_id (str): The unique identifier for the prompt.
127+
version (str): The version of the prompt to retrieve, empty for the default version.
125128
prompt_object (str | dict | list): The prompt object to compile.
126129
Returns:
127130
str | dict | list: The compiled prompt.
128131
"""
129132

130-
prompt: str | dict | list | None = self.get_prompt(prompt_id) if prompt_id else prompt_object
133+
prompt: str | dict | list | None = self.get_prompt(prompt_id, version) if prompt_id else prompt_object
131134
if not prompt:
132135
if self._raise_on_error:
133136
raise SysPromptException("Prompt not found")
@@ -316,6 +319,7 @@ def _api_request(self, method: str, path: str, data: Optional[dict] = None) -> d
316319
for attempt in range(max_retries):
317320
try:
318321
url: str = self._get_url(path)
322+
print(url)
319323
response: requests.Response = requests.request(method, url, json=data, headers=self._get_headers())
320324
response.raise_for_status()
321325
return response.json()
@@ -375,7 +379,7 @@ def _get_url(self, path: str) -> str:
375379
Returns:
376380
str: The full URL.
377381
"""
378-
return f"{self.BASE_URL}/{self.API_VERSION}/{path}/"
382+
return f"{self.BASE_URL}/{self.API_VERSION}/{path}"
379383

380384
def _get_headers(self) -> dict:
381385
"""

0 commit comments

Comments
 (0)