Skip to content

Allow relative_to with recursive=False (in-place on objects) #1824

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 52 additions & 26 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np

from .globals import get_global_tmp_folder, is_set_global_tmp_folder
from .core_tools import check_json, is_dict_extractor, recursive_path_modifier, SIJsonEncoder
from .core_tools import check_json, is_dict_extractor, recursive_path_modifier, dict_contains_extractors, SIJsonEncoder
from .job_tools import _shared_job_kwargs_doc


Expand Down Expand Up @@ -310,6 +310,7 @@ def to_dict(
relative_to: Union[str, Path, None] = None,
folder_metadata=None,
recursive: bool = False,
skip_recursive_path_modifier_warning: bool = False,
) -> dict:
"""
Make a nested serialized dictionary out of the extractor. The dictionary produced can be used to re-initialize
Expand All @@ -329,6 +330,8 @@ def to_dict(
Folder with numpy `npy` files containing additional information (e.g. probe in BaseRecording) and properties.
recursive: bool
If True, all dictionaries in the kwargs are expanded with `to_dict` as well, by default False.
skip_recursive_path_modifier_warning: bool
If True, skip the warning that is raised when `recursive=True` and `relative_to` is not None.
Comment on lines +333 to +334
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which situation would we have recursive=True and relative_to not None?
I think this option is too much detail and to_dict() is more or less internal.
So I think that either we are strict and raise an error, or we do not care — but adding this flag is quite heavy for the API, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as you said, to_dict is internal so we can keep it in my opinion


Returns
-------
Expand Down Expand Up @@ -359,6 +362,7 @@ def to_dict(
new_kwargs[name] = transform_extractors_to_dict(value)

kwargs = new_kwargs

class_name = str(type(self)).replace("<class '", "").replace("'>", "")
module = class_name.split(".")[0]
imported_module = importlib.import_module(module)
Expand All @@ -376,11 +380,6 @@ def to_dict(
"relative_paths": (relative_to is not None),
}

try:
dump_dict["version"] = imported_module.__version__
except AttributeError:
dump_dict["version"] = "unknown"

Comment on lines -379 to -383
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it moved elsewhere ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if include_annotations:
dump_dict["annotations"] = self._annotations
else:
Expand All @@ -394,9 +393,12 @@ def to_dict(
dump_dict["properties"] = {k: self._properties.get(k, None) for k in self._main_properties}

if relative_to is not None:
relative_to = Path(relative_to).absolute()
relative_to = Path(relative_to).resolve().absolute()
assert relative_to.is_dir(), "'relative_to' must be an existing directory"
dump_dict = _make_paths_relative(dump_dict, relative_to)
copy = False if dict_contains_extractors(dump_dict) else True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I follow the logic here. Are we making an in-place modification of extractors._kwargs?
If this is the case I think this is not a good idea.
rec.to_dict() should not modify rec itself. Maybe I am missing something.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are. that is the only way relative_to can work without recursive=True

dump_dict = _make_paths_relative(
dump_dict, relative_to, copy=copy, skip_warning=skip_recursive_path_modifier_warning
)

if folder_metadata is not None:
if relative_to is not None:
Expand Down Expand Up @@ -424,7 +426,8 @@ def from_dict(dictionary: dict, base_folder: Optional[Union[Path, str]] = None)
"""
if dictionary["relative_paths"]:
assert base_folder is not None, "When relative_paths=True, need to provide base_folder"
dictionary = _make_paths_absolute(dictionary, base_folder)
copy = False if dict_contains_extractors(dictionary) else True
dictionary = _make_paths_absolute(dictionary, base_folder, copy=copy)
extractor = _load_extractor_from_dict(dictionary)
folder_metadata = dictionary.get("folder_metadata", None)
if folder_metadata is not None:
Expand Down Expand Up @@ -463,9 +466,9 @@ def clone(self) -> "BaseExtractor":
"""
Clones an existing extractor into a new instance.
"""
d = self.to_dict(include_annotations=True, include_properties=True)
d = deepcopy(d)
clone = BaseExtractor.from_dict(d)
dictionary = self.to_dict(include_annotations=True, include_properties=True)
dictionary = deepcopy(dictionary)
clone = BaseExtractor.from_dict(dictionary)
Comment on lines +469 to +471
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer "d"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

d doesn't work with the debugger.

return clone

def check_if_dumpable(self):
Expand Down Expand Up @@ -557,7 +560,9 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No
else:
raise ValueError("Dump: file must .json or .pkl")

def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=None, folder_metadata=None) -> None:
def dump_to_json(
self, file_path: Union[str, Path, None] = None, relative_to=None, folder_metadata=None, recursive=False
) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not understand. I think that JSON should always be recursive, no?
How could it not be recursive?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed it is, but it is handled by the SIJsonEncoder directly

"""
Dump recording extractor to json file.
The extractor can be re-loaded with load_extractor_from_json(json_file)
Expand All @@ -568,17 +573,29 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non
Path of the json file
relative_to: str, Path, or None
If not None, file_paths are serialized relative to this path
folder_metadata: str, Path, or None
Folder with numpy files containing additional information (e.g. probe in BaseRecording) and properties.
recursive: bool
If True, all dictionaries in the kwargs are expanded with `to_dict` as well, by default False.
"""
assert self.check_if_dumpable()
assert self.check_if_json_serializable(), "The extractor is not json serializable"
dump_dict = self.to_dict(
include_annotations=True, include_properties=False, relative_to=relative_to, folder_metadata=folder_metadata
include_annotations=True,
include_properties=False,
relative_to=relative_to,
folder_metadata=folder_metadata,
recursive=recursive,
skip_recursive_path_modifier_warning=True, # we skip warning because we will make paths absolute again
)
file_path = self._get_file_path(file_path, [".json"])

file_path.write_text(
json.dumps(dump_dict, indent=4, cls=SIJsonEncoder),
encoding="utf8",
)
if relative_to:
# Make paths absolute again
dump_dict = _make_paths_absolute(dump_dict, relative_to, copy=False, skip_warning=True)

def dump_to_pickle(
self,
Expand All @@ -603,18 +620,23 @@ def dump_to_pickle(
recursive: bool
If True, all dictionaries in the kwargs are expanded with `to_dict` as well, by default False.
"""
assert self.check_if_dumpable()
assert self.check_if_dumpable(), "The extractor is not dumpable to pickle"
if relative_to:
assert recursive, "When relative_to is given, recursive must be True"
dump_dict = self.to_dict(
include_annotations=True,
include_properties=include_properties,
relative_to=relative_to,
folder_metadata=folder_metadata,
recursive=recursive,
skip_recursive_path_modifier_warning=True, # we skip warning because we will make paths absolute again
)
file_path = self._get_file_path(file_path, [".pkl", ".pickle"])

file_path.write_bytes(pickle.dumps(dump_dict))

# we don't need to make paths absolute, because for pickle this is only available for recursive=True

@staticmethod
def load(file_path: Union[str, Path], base_folder=None) -> "BaseExtractor":
"""
Expand All @@ -630,16 +652,16 @@ def load(file_path: Union[str, Path], base_folder=None) -> "BaseExtractor":
# standard case based on a file (json or pickle)
if str(file_path).endswith(".json"):
with open(str(file_path), "r") as f:
d = json.load(f)
dictionary = json.load(f)
elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"):
with open(str(file_path), "rb") as f:
d = pickle.load(f)
dictionary = pickle.load(f)
else:
raise ValueError(f"Impossible to load {file_path}")
if "warning" in d and "not dumpable" in d["warning"]:
if "warning" in dictionary and "not dumpable" in dictionary["warning"]:
print("The extractor was not dumpable")
return None
extractor = BaseExtractor.from_dict(d, base_folder=base_folder)
extractor = BaseExtractor.from_dict(dictionary, base_folder=base_folder)
return extractor

elif file_path.is_dir():
Expand Down Expand Up @@ -920,16 +942,20 @@ def save_to_zarr(
return cached


def _make_paths_relative(d, relative) -> dict:
relative = str(Path(relative).absolute())
func = lambda p: os.path.relpath(str(p), start=relative)
return recursive_path_modifier(d, func, target="path", copy=True)
def _make_paths_relative(d, relative, copy=True, skip_warning=False) -> dict:
    """
    Rewrite every "path"-targeted entry of `d` as a path relative to `relative`.

    Parameters
    ----------
    d : dict
        The (possibly nested) dictionary whose path entries are rewritten.
    relative : str or Path
        Base directory that the rewritten paths will be expressed relative to.
    copy : bool
        If True, work on a copy of `d`; if False, modify `d` in place.
    skip_warning : bool
        If True, suppress the warning emitted by `recursive_path_modifier`.

    Returns
    -------
    dict
        The dictionary with relative paths (same object as `d` when copy=False).
    """
    base_dir = Path(relative).absolute()

    # Resolve each path fully before relativizing so symlinks/".." don't
    # produce inconsistent relative paths.
    def to_relative(p):
        return os.path.relpath(Path(p).resolve().absolute(), start=base_dir)

    return recursive_path_modifier(
        d,
        to_relative,
        target="path",
        copy=copy,
        skip_targets=["relative_paths"],  # the boolean flag itself must not be touched
        skip_warning=skip_warning,
    )


def _make_paths_absolute(d, base):
def _make_paths_absolute(d, base, copy=True, skip_warning=False) -> dict:
    """
    Rewrite every "path"-targeted entry of `d` as an absolute path anchored at `base`.

    Parameters
    ----------
    d : dict
        The (possibly nested) dictionary whose path entries are rewritten.
    base : str or Path
        Base directory that relative path entries are joined onto.
    copy : bool
        If True, work on a copy of `d`; if False, modify `d` in place.
    skip_warning : bool
        If True, suppress the warning emitted by `recursive_path_modifier`.

    Returns
    -------
    dict
        The dictionary with absolute paths (same object as `d` when copy=False).
    """
    base_dir = Path(base)

    # Join onto the base, then resolve to a canonical absolute string.
    def to_absolute(p):
        return str((base_dir / p).resolve().absolute())

    return recursive_path_modifier(
        d,
        to_absolute,
        target="path",
        copy=copy,
        skip_targets=["relative_paths"],  # the boolean flag itself must not be touched
        skip_warning=skip_warning,
    )


def _load_extractor_from_dict(dic) -> BaseExtractor:
Expand Down
18 changes: 9 additions & 9 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ def has_time_vector(self, segment_index=None):
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
d = rs.get_times_kwargs()
return d["time_vector"] is not None
time_kwargs = rs.get_times_kwargs()
return time_kwargs["time_vector"] is not None

def set_times(self, times, segment_index=None, with_warning=True):
"""Set times for a recording segment.
Expand Down Expand Up @@ -501,8 +501,8 @@ def _save(self, format="binary", **save_kwargs):
# save time vector if any
t_starts = np.zeros(self.get_num_segments(), dtype="float64") * np.nan
for segment_index, rs in enumerate(self._recording_segments):
d = rs.get_times_kwargs()
time_vector = d["time_vector"]
time_kwargs = rs.get_times_kwargs()
time_vector = time_kwargs["time_vector"]
if time_vector is not None:
_ = zarr_root.create_dataset(
name=f"times_seg{segment_index}",
Expand All @@ -511,7 +511,7 @@ def _save(self, format="binary", **save_kwargs):
compressor=zarr_kwargs["compressor"],
)
elif d["t_start"] is not None:
t_starts[segment_index] = d["t_start"]
t_starts[segment_index] = time_kwargs["t_start"]

if np.any(~np.isnan(t_starts)):
zarr_root.create_dataset(name="t_starts", data=t_starts, compressor=None)
Expand All @@ -530,8 +530,8 @@ def _save(self, format="binary", **save_kwargs):
cached.set_probegroup(probegroup)

for segment_index, rs in enumerate(self._recording_segments):
d = rs.get_times_kwargs()
time_vector = d["time_vector"]
time_kwargs = rs.get_times_kwargs()
time_vector = time_kwargs["time_vector"]
if time_vector is not None:
cached._recording_segments[segment_index].time_vector = time_vector

Expand Down Expand Up @@ -559,8 +559,8 @@ def _extra_metadata_to_folder(self, folder):

# save time vector if any
for segment_index, rs in enumerate(self._recording_segments):
d = rs.get_times_kwargs()
time_vector = d["time_vector"]
time_kwargs = rs.get_times_kwargs()
time_vector = time_kwargs["time_vector"]
if time_vector is not None:
np.save(folder / f"times_cached_seg{segment_index}.npy", time_vector)

Expand Down
13 changes: 7 additions & 6 deletions src/spikeinterface/core/binaryfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,23 @@ def __init__(self, folder_path):
folder_path = Path(folder_path)

with open(folder_path / "binary.json", "r") as f:
d = json.load(f)
dictionary = json.load(f)

if not d["class"].endswith(".BinaryRecordingExtractor"):
if not dictionary["class"].endswith(".BinaryRecordingExtractor"):
raise ValueError("This folder is not a binary spikeinterface folder")

assert d["relative_paths"]
assert dictionary["relative_paths"]

d = _make_paths_absolute(d, folder_path)
kwargs = dictionary["kwargs"]
kwargs = _make_paths_absolute(kwargs, folder_path)

BinaryRecordingExtractor.__init__(self, **d["kwargs"])
BinaryRecordingExtractor.__init__(self, **kwargs)

folder_metadata = folder_path
self.load_metadata_from_folder(folder_metadata)

self._kwargs = dict(folder_path=str(folder_path.absolute()))
self._bin_kwargs = d["kwargs"]
self._bin_kwargs = kwargs
if "num_channels" not in self._bin_kwargs:
assert "num_chan" in self._bin_kwargs, "Cannot find num_channels or num_chan in binary.json"
self._bin_kwargs["num_channels"] = self._bin_kwargs["num_chan"]
Expand Down
Loading