Skip to content

Commit 41dbf29

Browse files
committed
Promote TokenCache._find() to TokenCache.search()
Change all find() in application.py to search() Update msal/token_cache.py Co-authored-by: Jiashuo Li <4003950+jiasli@users.noreply.github.com> Refine inline comments
1 parent 3a5990a commit 41dbf29

File tree

2 files changed

+37
-21
lines changed

2 files changed

+37
-21
lines changed

msal/application.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,7 +1142,7 @@ def _find_msal_accounts(self, environment):
11421142
"local_account_id": a.get("local_account_id"), # Tenant-specific
11431143
"realm": a.get("realm"), # Tenant-specific
11441144
}
1145-
for a in self.token_cache.find(
1145+
for a in self.token_cache.search(
11461146
TokenCache.CredentialType.ACCOUNT,
11471147
query={"environment": environment})
11481148
if a["authority_type"] in interested_authority_types
@@ -1188,18 +1188,22 @@ def _sign_out(self, home_account):
11881188
"home_account_id": home_account["home_account_id"],} # realm-independent
11891189
app_metadata = self._get_app_metadata(home_account["environment"])
11901190
# Remove RTs/FRTs, and they are realm-independent
1191-
for rt in [rt for rt in self.token_cache.find(
1191+
for rt in [ # Remove RTs from a static list (rather than from a dynamic generator),
1192+
# to avoid changing self.token_cache while it is being iterated
1193+
rt for rt in self.token_cache.search(
11921194
TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_home_account)
11931195
# Do RT's app ownership check as a precaution, in case family apps
11941196
# and 3rd-party apps share same token cache, although they should not.
11951197
if rt["client_id"] == self.client_id or (
11961198
app_metadata.get("family_id") # Now let's settle family business
11971199
and rt.get("family_id") == app_metadata["family_id"])
1198-
]:
1200+
]:
11991201
self.token_cache.remove_rt(rt)
1200-
for at in self.token_cache.find( # Remove ATs
1201-
# Regardless of realm, b/c we've removed realm-independent RTs anyway
1202-
TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_home_account):
1202+
for at in list(self.token_cache.search( # Remove ATs from a static list,
1203+
# to avoid changing self.token_cache while it is being iterated
1204+
TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_home_account,
1205+
# Regardless of realm, b/c we've removed realm-independent RTs anyway
1206+
)):
12031207
# To avoid the complexity of locating sibling family app's AT,
12041208
# we skip AT's app ownership check.
12051209
# It means ATs for other apps will also be removed, it is OK because:
@@ -1213,11 +1217,15 @@ def _forget_me(self, home_account):
12131217
owned_by_home_account = {
12141218
"environment": home_account["environment"],
12151219
"home_account_id": home_account["home_account_id"],} # realm-independent
1216-
for idt in self.token_cache.find( # Remove IDTs, regardless of realm
1217-
TokenCache.CredentialType.ID_TOKEN, query=owned_by_home_account):
1220+
for idt in list(self.token_cache.search( # Remove IDTs from a static list,
1221+
# to avoid changing self.token_cache while it is being iterated
1222+
TokenCache.CredentialType.ID_TOKEN, query=owned_by_home_account, # regardless of realm
1223+
)):
12181224
self.token_cache.remove_idt(idt)
1219-
for a in self.token_cache.find( # Remove Accounts, regardless of realm
1220-
TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account):
1225+
for a in list(self.token_cache.search( # Remove Accounts from a static list,
1226+
# to avoid changing self.token_cache while it is being iterated
1227+
TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account, # regardless of realm
1228+
)):
12211229
self.token_cache.remove_account(a)
12221230

12231231
def _acquire_token_by_cloud_shell(self, scopes, data=None):
@@ -1350,12 +1358,12 @@ def _acquire_token_silent_with_error(
13501358
return result
13511359
final_result = result
13521360
for alias in self._get_authority_aliases(self.authority.instance):
1353-
if not self.token_cache.find(
1361+
if not list(self.token_cache.search( # Need a list to test emptiness
13541362
self.token_cache.CredentialType.REFRESH_TOKEN,
13551363
# target=scopes, # MUST NOT filter by scopes, because:
13561364
# 1. AAD RTs are scope-independent;
13571365
# 2. therefore target is optional per schema;
1358-
query={"environment": alias}):
1366+
query={"environment": alias})):
13591367
# Skip heavy weight logic when RT for this alias doesn't exist
13601368
continue
13611369
the_authority = Authority(
@@ -1410,11 +1418,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
14101418
query["key_id"] = key_id
14111419
now = time.time()
14121420
refresh_reason = msal.telemetry.AT_ABSENT
1413-
for entry in self.token_cache._find( # It returns a generator
1421+
for entry in self.token_cache.search( # A generator allows us to
1422+
# break early in cache-hit without finding a full list
14141423
self.token_cache.CredentialType.ACCESS_TOKEN,
14151424
target=scopes,
14161425
query=query,
1417-
): # Note that _find() holds a lock during this for loop;
1426+
): # This loop is about token search, not about token deletion.
1427+
# Note that search() holds a lock during this loop;
14181428
# that is fine because this loop is fast
14191429
expires_in = int(entry["expires_on"]) - now
14201430
if expires_in < 5*60: # Then consider it expired
@@ -1552,10 +1562,10 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
15521562
rt_remover=None, break_condition=lambda response: False,
15531563
refresh_reason=None, correlation_id=None, claims_challenge=None,
15541564
**kwargs):
1555-
matches = self.token_cache.find(
1565+
matches = list(self.token_cache.search( # We want a list to test emptiness
15561566
self.token_cache.CredentialType.REFRESH_TOKEN,
15571567
# target=scopes, # AAD RTs are scope-independent
1558-
query=query)
1568+
query=query))
15591569
logger.debug("Found %d RTs matching %s", len(matches), {
15601570
k: _pii_less_home_account_id(v) if k == "home_account_id" and v else v
15611571
for k, v in query.items()
@@ -2252,11 +2262,12 @@ def remove_tokens_for_client(self):
22522262
:func:`~acquire_token_for_client()` for the current client."""
22532263
for env in [self.authority.instance] + self._get_authority_aliases(
22542264
self.authority.instance):
2255-
for at in self.token_cache.find(TokenCache.CredentialType.ACCESS_TOKEN, query={
2265+
for at in list(self.token_cache.search( # Remove ATs from a snapshot
2266+
TokenCache.CredentialType.ACCESS_TOKEN, query={
22562267
"client_id": self.client_id,
22572268
"environment": env,
22582269
"home_account_id": None, # These are mostly app-only tokens
2259-
}):
2270+
})):
22602271
self.token_cache.remove_at(at)
22612272
# acquire_token_for_client() obtains no RTs, so we have no RT to remove
22622273

msal/token_cache.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import threading
33
import time
44
import logging
5+
import warnings
56

67
from .authority import canonicalize
78
from .oauth2cli.oidc import decode_part, decode_id_token
@@ -117,7 +118,7 @@ def _get(self, credential_type, key, default=None): # O(1)
117118
with self._lock:
118119
return self._cache.get(credential_type, {}).get(key, default)
119120

120-
def _find(self, credential_type, target=None, query=None): # O(n) generator
121+
def search(self, credential_type, target=None, query=None): # O(n) generator
121122
"""Returns a generator of matching entries.
122123
123124
It is O(1) for AT hits, and O(n) for other types.
@@ -150,8 +151,12 @@ def _find(self, credential_type, target=None, query=None): # O(n) generator
150151
if entry != preferred_result: # Avoid yielding the same entry twice
151152
yield entry
152153

153-
def find(self, credential_type, target=None, query=None): # Obsolete. Use _find() instead.
154-
return list(self._find(credential_type, target=target, query=query))
154+
def find(self, credential_type, target=None, query=None):
155+
"""Equivalent to list(search(...))."""
156+
warnings.warn(
157+
"Use list(search(...)) instead to explicitly get a list.",
158+
DeprecationWarning)
159+
return list(self.search(credential_type, target=target, query=query))
155160

156161
def add(self, event, now=None):
157162
"""Handle a token obtaining event, and add tokens into cache."""

0 commit comments

Comments
 (0)