1515import logging
1616import threading
1717import weakref
18+ from enum import Enum
1819from functools import wraps
1920from typing import (
2021 TYPE_CHECKING ,
2122 Any ,
2223 Callable ,
2324 Collection ,
25+ Dict ,
2426 Generic ,
25- Iterable ,
2627 List ,
2728 Optional ,
2829 Type ,
@@ -190,7 +191,7 @@ def __init__(
190191 root : "ListNode[_Node]" ,
191192 key : KT ,
192193 value : VT ,
193- cache : "weakref.ReferenceType[LruCache]" ,
194+ cache : "weakref.ReferenceType[LruCache[KT, VT] ]" ,
194195 clock : Clock ,
195196 callbacks : Collection [Callable [[], None ]] = (),
196197 prune_unread_entries : bool = True ,
@@ -290,6 +291,12 @@ def move_to_front(self, clock: Clock, cache_list_root: ListNode) -> None:
290291 self ._global_list_node .update_last_access (clock )
291292
292293
294+ class _Sentinel (Enum ):
295+ # defining a sentinel in this way allows mypy to correctly handle the
296+ # type of a dictionary lookup.
297+ sentinel = object ()
298+
299+
293300class LruCache (Generic [KT , VT ]):
294301 """
295302 Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
@@ -302,7 +309,7 @@ def __init__(
302309 max_size : int ,
303310 cache_name : Optional [str ] = None ,
304311 cache_type : Type [Union [dict , TreeCache ]] = dict ,
305- size_callback : Optional [Callable ] = None ,
312+ size_callback : Optional [Callable [[ VT ], int ] ] = None ,
306313 metrics_collection_callback : Optional [Callable [[], None ]] = None ,
307314 apply_cache_factor_from_config : bool = True ,
308315 clock : Optional [Clock ] = None ,
@@ -339,7 +346,7 @@ def __init__(
339346 else :
340347 real_clock = clock
341348
342- cache = cache_type ()
349+ cache : Union [ Dict [ KT , _Node [ KT , VT ]], TreeCache ] = cache_type ()
343350 self .cache = cache # Used for introspection.
344351 self .apply_cache_factor_from_config = apply_cache_factor_from_config
345352
@@ -374,7 +381,7 @@ def __init__(
374381 # creating more each time we create a `_Node`.
375382 weak_ref_to_self = weakref .ref (self )
376383
377- list_root = ListNode [_Node ].create_root_node ()
384+ list_root = ListNode [_Node [ KT , VT ] ].create_root_node ()
378385
379386 lock = threading .Lock ()
380387
@@ -422,7 +429,7 @@ def cache_len() -> int:
422429 def add_node (
423430 key : KT , value : VT , callbacks : Collection [Callable [[], None ]] = ()
424431 ) -> None :
425- node = _Node (
432+ node : _Node [ KT , VT ] = _Node (
426433 list_root ,
427434 key ,
428435 value ,
@@ -439,10 +446,10 @@ def add_node(
439446 if caches .TRACK_MEMORY_USAGE and metrics :
440447 metrics .inc_memory_usage (node .memory )
441448
442- def move_node_to_front (node : _Node ) -> None :
449+ def move_node_to_front (node : _Node [ KT , VT ] ) -> None :
443450 node .move_to_front (real_clock , list_root )
444451
445- def delete_node (node : _Node ) -> int :
452+ def delete_node (node : _Node [ KT , VT ] ) -> int :
446453 node .drop_from_lists ()
447454
448455 deleted_len = 1
@@ -496,7 +503,7 @@ def cache_get(
496503
497504 @synchronized
498505 def cache_set (
499- key : KT , value : VT , callbacks : Iterable [Callable [[], None ]] = ()
506+ key : KT , value : VT , callbacks : Collection [Callable [[], None ]] = ()
500507 ) -> None :
501508 node = cache .get (key , None )
502509 if node is not None :
@@ -590,8 +597,6 @@ def cache_clear() -> None:
590597 def cache_contains (key : KT ) -> bool :
591598 return key in cache
592599
593- self .sentinel = object ()
594-
595600 # make sure that we clear out any excess entries after we get resized.
596601 self ._on_resize = evict
597602
@@ -608,18 +613,18 @@ def cache_contains(key: KT) -> bool:
608613 self .clear = cache_clear
609614
610615 def __getitem__ (self , key : KT ) -> VT :
611- result = self .get (key , self .sentinel )
612- if result is self .sentinel :
616+ result = self .get (key , _Sentinel .sentinel )
617+ if result is _Sentinel .sentinel :
613618 raise KeyError ()
614619 else :
615- return cast ( VT , result )
620+ return result
616621
617622 def __setitem__ (self , key : KT , value : VT ) -> None :
618623 self .set (key , value )
619624
620625 def __delitem__ (self , key : KT , value : VT ) -> None :
621- result = self .pop (key , self .sentinel )
622- if result is self .sentinel :
626+ result = self .pop (key , _Sentinel .sentinel )
627+ if result is _Sentinel .sentinel :
623628 raise KeyError ()
624629
625630 def __len__ (self ) -> int :
0 commit comments