Skip to content

Commit 63cda4b

Browse files
authored
Merge branch 'main' into main
2 parents b8c8fa8 + 47dd371 commit 63cda4b

File tree

4 files changed

+252
-107
lines changed

4 files changed

+252
-107
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ examples/tutorials/*.svg
180180
doc/_build/*
181181
doc/tutorials/*
182182
doc/sources/*
183+
*sg_execution_times.rst
183184

184185
examples/getting_started/tmp_*
185186
examples/getting_started/phy

src/spikeinterface/core/core_tools.py

Lines changed: 113 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from __future__ import annotations
22
from pathlib import Path, WindowsPath
3-
from typing import Union
3+
from typing import Union, Generator
44
import os
55
import sys
66
import datetime
77
import json
88
from copy import deepcopy
99
import importlib
1010
from math import prod
11+
from collections import namedtuple
1112

1213
import numpy as np
1314

@@ -183,6 +184,75 @@ def is_dict_extractor(d: dict) -> bool:
183184
return is_extractor
184185

185186

187+
extractor_dict_element = namedtuple(typename="extractor_dict_element", field_names=["value", "name", "access_path"])
188+
189+
190+
def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element]:
191+
"""
192+
Iterator for recursive traversal of a dictionary.
193+
This function explores the dictionary recursively and yields the path to each value along with the value itself.
194+
195+
By path here we mean the keys that lead to the value in the dictionary:
196+
e.g. for the dictionary {'a': {'b': 1}}, the path to the value 1 is ('a', 'b').
197+
198+
See `BaseExtractor.to_dict()` for a description of `extractor_dict` structure.
199+
200+
Parameters
201+
----------
202+
extractor_dict : dict
203+
Input dictionary
204+
205+
Yields
206+
------
207+
extractor_dict_element
208+
Named tuple containing the value, the name, and the access_path to the value in the dictionary.
209+
210+
"""
211+
212+
def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""):
213+
if isinstance(dict_list_or_value, dict):
214+
for k, v in dict_list_or_value.items():
215+
yield from _extractor_dict_iterator(v, access_path + (k,), name=k)
216+
elif isinstance(dict_list_or_value, list):
217+
for i, v in enumerate(dict_list_or_value):
218+
yield from _extractor_dict_iterator(
219+
v, access_path + (i,), name=name
220+
) # Propagate name of list to children
221+
else:
222+
yield extractor_dict_element(
223+
value=dict_list_or_value,
224+
name=name,
225+
access_path=access_path,
226+
)
227+
228+
yield from _extractor_dict_iterator(extractor_dict)
229+
230+
231+
def set_value_in_extractor_dict(extractor_dict: dict, access_path: tuple, new_value):
232+
"""
233+
In place modification of a value in a nested dictionary given its access path.
234+
235+
Parameters
236+
----------
237+
extractor_dict : dict
238+
The dictionary to modify
239+
access_path : tuple
240+
The path to the value in the dictionary
241+
new_value : object
242+
The new value to set
243+
244+
Returns
245+
-------
246+
dict
247+
The modified dictionary
248+
"""
249+
250+
current = extractor_dict
251+
for key in access_path[:-1]:
252+
current = current[key]
253+
current[access_path[-1]] = new_value
254+
255+
186256
def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
187257
"""
188258
Generic function for recursive modification of paths in an extractor dict.
@@ -250,15 +320,17 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
250320
raise ValueError(f"{k} key for path must be str or list[str]")
251321

252322

253-
def _get_paths_list(d):
254-
# this explore a dict and get all paths flatten in a list
255-
# the trick is to use a closure func called by recursive_path_modifier()
256-
path_list = []
323+
# This is the current definition that an element in a extractor_dict is a path
324+
# This is shared across a couple of definition so it is here for DNRY
325+
element_is_path = lambda element: "path" in element.name and isinstance(element.value, (str, Path))
326+
257327

258-
def append_to_path(p):
259-
path_list.append(p)
328+
def _get_paths_list(d: dict) -> list[str | Path]:
329+
path_list = [e.value for e in extractor_dict_iterator(d) if element_is_path(e)]
330+
331+
# if check_if_exists: TODO: Enable this once container_tools test uses proper mocks
332+
# path_list = [p for p in path_list if Path(p).exists()]
260333

261-
recursive_path_modifier(d, append_to_path, target="path", copy=True)
262334
return path_list
263335

264336

@@ -318,7 +390,7 @@ def check_paths_relative(input_dict, relative_folder) -> bool:
318390
return len(not_possible) == 0
319391

320392

321-
def make_paths_relative(input_dict, relative_folder) -> dict:
393+
def make_paths_relative(input_dict: dict, relative_folder: str | Path) -> dict:
322394
"""
323395
Recursively transform a dict describing an BaseExtractor to make every path relative to a folder.
324396
@@ -334,9 +406,22 @@ def make_paths_relative(input_dict, relative_folder) -> dict:
334406
output_dict: dict
335407
A copy of the input dict with modified paths.
336408
"""
409+
337410
relative_folder = Path(relative_folder).resolve().absolute()
338-
func = lambda p: _relative_to(p, relative_folder)
339-
output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True)
411+
412+
path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)]
413+
# Only paths that exist are made relative
414+
path_elements_in_dict = [e for e in path_elements_in_dict if Path(e.value).exists()]
415+
416+
output_dict = deepcopy(input_dict)
417+
for element in path_elements_in_dict:
418+
new_value = _relative_to(element.value, relative_folder)
419+
set_value_in_extractor_dict(
420+
extractor_dict=output_dict,
421+
access_path=element.access_path,
422+
new_value=new_value,
423+
)
424+
340425
return output_dict
341426

342427

@@ -359,12 +444,28 @@ def make_paths_absolute(input_dict, base_folder):
359444
base_folder = Path(base_folder)
360445
# use as_posix instead of str to make the path unix like even on window
361446
func = lambda p: (base_folder / p).resolve().absolute().as_posix()
362-
output_dict = recursive_path_modifier(input_dict, func, target="path", copy=True)
447+
448+
path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)]
449+
output_dict = deepcopy(input_dict)
450+
451+
output_dict = deepcopy(input_dict)
452+
for element in path_elements_in_dict:
453+
absolute_path = (base_folder / element.value).resolve()
454+
if Path(absolute_path).exists():
455+
new_value = absolute_path.as_posix() # Not so sure about this, Sam
456+
set_value_in_extractor_dict(
457+
extractor_dict=output_dict,
458+
access_path=element.access_path,
459+
new_value=new_value,
460+
)
461+
363462
return output_dict
364463

365464

366465
def recursive_key_finder(d, key):
367466
# Find all values for a key on a dictionary, even if nested
467+
# TODO refactor to use extractor_dict_iterator
468+
368469
for k, v in d.items():
369470
if isinstance(v, dict):
370471
yield from recursive_key_finder(v, key)

0 commit comments

Comments
 (0)