Skip to content

Commit 8a7495a

Browse files
committed
Refactor cache to consolidate the add behavior
1 parent d46e5b8 commit 8a7495a

File tree

1 file changed

+54
-76
lines changed

1 file changed

+54
-76
lines changed

msal/token_cache.py

Lines changed: 54 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,50 @@ def __init__(self):
3636
self._lock = threading.RLock()
3737
self._cache = {}
3838
self.key_makers = {
39-
self.CredentialType.REFRESH_TOKEN: self._build_rt_key,
40-
self.CredentialType.ACCESS_TOKEN: self._build_at_key,
41-
self.CredentialType.ID_TOKEN: self._build_idt_key,
42-
self.CredentialType.ACCOUNT: self._build_account_key,
39+
self.CredentialType.REFRESH_TOKEN:
40+
lambda home_account_id=None, environment=None, client_id=None,
41+
target=None, **ignored_payload_from_a_real_token:
42+
"-".join([
43+
home_account_id or "",
44+
environment or "",
45+
self.CredentialType.REFRESH_TOKEN,
46+
client_id or "",
47+
"", # RT is cross-tenant in AAD
48+
target or "", # raw value could be None if deserialized from other SDK
49+
]).lower(),
50+
self.CredentialType.ACCESS_TOKEN:
51+
lambda home_account_id=None, environment=None, client_id=None,
52+
realm=None, target=None, **ignored_payload_from_a_real_token:
53+
"-".join([
54+
home_account_id or "",
55+
environment or "",
56+
self.CredentialType.ACCESS_TOKEN,
57+
client_id,
58+
realm or "",
59+
target or "",
60+
]).lower(),
61+
self.CredentialType.ID_TOKEN:
62+
lambda home_account_id=None, environment=None, client_id=None,
63+
realm=None, **ignored_payload_from_a_real_token:
64+
"-".join([
65+
home_account_id or "",
66+
environment or "",
67+
self.CredentialType.ID_TOKEN,
68+
client_id or "",
69+
realm or "",
70+
"" # Albeit irrelevant, schema requires an empty scope here
71+
]).lower(),
72+
self.CredentialType.ACCOUNT:
73+
lambda home_account_id=None, environment=None, realm=None,
74+
**ignored_payload_from_a_real_entry:
75+
"-".join([
76+
home_account_id or "",
77+
environment or "",
78+
realm or "",
79+
]).lower(),
80+
self.CredentialType.APP_METADATA:
81+
lambda environment=None, client_id=None, **kwargs:
82+
"appmetadata-{}-{}".format(environment or "", client_id or ""),
4383
}
4484

4585
def find(self, credential_type, target=None, query=None):
@@ -88,12 +128,9 @@ def add(self, event, now=None):
88128
with self._lock:
89129

90130
if access_token:
91-
key = self._build_at_key(
92-
home_account_id, environment, event.get("client_id", ""),
93-
realm, target)
94131
now = time.time() if now is None else now
95132
expires_in = response.get("expires_in", 3599)
96-
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = {
133+
at = {
97134
"credential_type": self.CredentialType.ACCESS_TOKEN,
98135
"secret": access_token,
99136
"home_account_id": home_account_id,
@@ -106,12 +143,12 @@ def add(self, event, now=None):
106143
"extended_expires_on": str(int( # Same here
107144
now + response.get("ext_expires_in", expires_in))),
108145
}
146+
self.modify(self.CredentialType.ACCESS_TOKEN, at, at)
109147

110148
if client_info:
111149
decoded_id_token = decode_id_token(
112150
id_token, client_id=event["client_id"]) if id_token else {}
113-
key = self._build_account_key(home_account_id, environment, realm)
114-
self._cache.setdefault(self.CredentialType.ACCOUNT, {})[key] = {
151+
account = {
115152
"home_account_id": home_account_id,
116153
"environment": environment,
117154
"realm": realm,
@@ -123,11 +160,10 @@ def add(self, event, now=None):
123160
else self.AuthorityType.MSSTS,
124161
# "client_info": response.get("client_info"), # Optional
125162
}
163+
self.modify(self.CredentialType.ACCOUNT, account, account)
126164

127165
if id_token:
128-
key = self._build_idt_key(
129-
home_account_id, environment, event.get("client_id", ""), realm)
130-
self._cache.setdefault(self.CredentialType.ID_TOKEN, {})[key] = {
166+
idt = {
131167
"credential_type": self.CredentialType.ID_TOKEN,
132168
"secret": id_token,
133169
"home_account_id": home_account_id,
@@ -136,11 +172,9 @@ def add(self, event, now=None):
136172
"client_id": event.get("client_id"),
137173
# "authority": "it is optional",
138174
}
175+
self.modify(self.CredentialType.ID_TOKEN, idt, idt)
139176

140177
if refresh_token:
141-
key = self._build_rt_key(
142-
home_account_id, environment,
143-
event.get("client_id", ""), target)
144178
rt = {
145179
"credential_type": self.CredentialType.REFRESH_TOKEN,
146180
"secret": refresh_token,
@@ -151,53 +185,33 @@ def add(self, event, now=None):
151185
}
152186
if "foci" in response:
153187
rt["family_id"] = response["foci"]
154-
self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})[key] = rt
188+
self.modify(self.CredentialType.REFRESH_TOKEN, rt, rt)
155189

156-
key = self._build_appmetadata_key(environment, event.get("client_id"))
157190
app_metadata = {
158191
"client_id": event.get("client_id"),
159192
"environment": environment,
160193
}
161194
if "foci" in response:
162195
app_metadata["family_id"] = response.get("foci")
163-
self._cache.setdefault(self.CredentialType.APP_METADATA, {})[key] = app_metadata
196+
self.modify(self.CredentialType.APP_METADATA, app_metadata, app_metadata)
164197

165198
def modify(self, credential_type, old_entry, new_key_value_pairs=None):
166199
# Modify the specified old_entry with new_key_value_pairs,
167200
# or remove the old_entry if the new_key_value_pairs is None.
168201

169-
# This helper exists to consolidate all token modify/remove behaviors,
202+
# This helper exists to consolidate all token add/modify/remove behaviors,
170203
# so that the sub-classes will have only one method to work on,
171204
# instead of patching a pair of update_xx() and remove_xx() per type.
172205
# You can monkeypatch self.key_makers to support more types on-the-fly.
173206
key = self.key_makers[credential_type](**old_entry)
174207
with self._lock:
175208
if new_key_value_pairs: # Update with them
176209
entries = self._cache.setdefault(credential_type, {})
177-
entry = entries.get(key, {}) # key usually exists, but we'll survive its absence
210+
entry = entries.setdefault(key, {}) # Create it if not yet exist
178211
entry.update(new_key_value_pairs)
179212
else: # Remove old_entry
180213
self._cache.setdefault(credential_type, {}).pop(key, None)
181214

182-
183-
@staticmethod
184-
def _build_appmetadata_key(environment, client_id):
185-
return "appmetadata-{}-{}".format(environment or "", client_id or "")
186-
187-
@classmethod
188-
def _build_rt_key(
189-
cls,
190-
home_account_id=None, environment=None, client_id=None, target=None,
191-
**ignored_payload_from_a_real_token):
192-
return "-".join([
193-
home_account_id or "",
194-
environment or "",
195-
cls.CredentialType.REFRESH_TOKEN,
196-
client_id or "",
197-
"", # RT is cross-tenant in AAD
198-
target or "", # raw value could be None if deserialized from other SDK
199-
]).lower()
200-
201215
def remove_rt(self, rt_item):
202216
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
203217
return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item)
@@ -207,50 +221,14 @@ def update_rt(self, rt_item, new_rt):
207221
return self.modify(
208222
self.CredentialType.REFRESH_TOKEN, rt_item, {"secret": new_rt})
209223

210-
@classmethod
211-
def _build_at_key(cls,
212-
home_account_id=None, environment=None, client_id=None,
213-
realm=None, target=None, **ignored_payload_from_a_real_token):
214-
return "-".join([
215-
home_account_id or "",
216-
environment or "",
217-
cls.CredentialType.ACCESS_TOKEN,
218-
client_id,
219-
realm or "",
220-
target or "",
221-
]).lower()
222-
223224
def remove_at(self, at_item):
224225
assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
225226
return self.modify(self.CredentialType.ACCESS_TOKEN, at_item)
226227

227-
@classmethod
228-
def _build_idt_key(cls,
229-
home_account_id=None, environment=None, client_id=None, realm=None,
230-
**ignored_payload_from_a_real_token):
231-
return "-".join([
232-
home_account_id or "",
233-
environment or "",
234-
cls.CredentialType.ID_TOKEN,
235-
client_id or "",
236-
realm or "",
237-
"" # Albeit irrelevant, schema requires an empty scope here
238-
]).lower()
239-
240228
def remove_idt(self, idt_item):
241229
assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
242230
return self.modify(self.CredentialType.ID_TOKEN, idt_item)
243231

244-
@classmethod
245-
def _build_account_key(cls,
246-
home_account_id=None, environment=None, realm=None,
247-
**ignored_payload_from_a_real_entry):
248-
return "-".join([
249-
home_account_id or "",
250-
environment or "",
251-
realm or "",
252-
]).lower()
253-
254232
def remove_account(self, account_item):
255233
assert "authority_type" in account_item
256234
return self.modify(self.CredentialType.ACCOUNT, account_item)

0 commit comments

Comments
 (0)