Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions findmy/accessory.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,11 @@ def from_json(
class _AccessoryKeyGenerator(KeyGenerator[KeyPair]):
"""KeyPair generator. Uses the same algorithm internally as FindMy accessories do."""

# cache enough keys for an entire week.
# every interval'th key is cached.
_CACHE_SIZE = 4 * 24 * 7 # 4 keys / hour
_CACHE_INTERVAL = 10

def __init__(
self,
master_key: bytes,
Expand All @@ -401,8 +406,7 @@ def __init__(
self._initial_sk = initial_sk
self._key_type = key_type

self._cur_sk = initial_sk
self._cur_sk_ind = 0
self._sk_cache: dict[int, bytes] = {}

self._iter_ind = 0

Expand All @@ -426,14 +430,33 @@ def _get_sk(self, ind: int) -> bytes:
msg = "The key index must be non-negative"
raise ValueError(msg)

if ind < self._cur_sk_ind: # behind us; need to reset :(
self._cur_sk = self._initial_sk
self._cur_sk_ind = 0
# retrieve from cache
cached_sk = self._sk_cache.get(ind)
if cached_sk is not None:
return cached_sk

# not in cache: find largest cached index smaller than ind (if exists)
start_ind: int = 0
cur_sk: bytes = self._initial_sk
for cached_ind in self._sk_cache:
if cached_ind < ind and cached_ind > start_ind:
start_ind = cached_ind
cur_sk = self._sk_cache[cached_ind]

# compute and update cache
for cur_ind in range(start_ind, ind):
cur_sk = crypto.x963_kdf(cur_sk, b"update", 32)

# insert intermediate result into cache and evict oldest entry if necessary
if cur_ind % self._CACHE_INTERVAL == 0:
self._sk_cache[cur_ind] = cur_sk

if len(self._sk_cache) > self._CACHE_SIZE:
# evict oldest entry
oldest_ind = min(self._sk_cache.keys())
del self._sk_cache[oldest_ind]

for _ in range(self._cur_sk_ind, ind):
self._cur_sk = crypto.x963_kdf(self._cur_sk, b"update", 32)
self._cur_sk_ind += 1
return self._cur_sk
return cur_sk

def _get_keypair(self, ind: int) -> KeyPair:
sk = self._get_sk(ind)
Expand All @@ -449,14 +472,14 @@ def _generate_keys(self, start: int, stop: int | None) -> Generator[KeyPair, Non

@override
def __iter__(self) -> KeyGenerator:
self._iter_ind = -1
return self

@override
def __next__(self) -> KeyPair:
key = self._get_keypair(self._iter_ind)
self._iter_ind += 1

return self._get_keypair(self._iter_ind)
return key

@overload
def __getitem__(self, val: int) -> KeyPair: ...
Expand Down