Skip to content

Commit d5fe93b

Browse files
authored
Merge pull request #250 from h-mayorquin/add_save_to_zarr
Add save to zarr
2 parents 4d15eda + fc09d62 commit d5fe93b

File tree

3 files changed

+200
-11
lines changed

3 files changed

+200
-11
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ test = [
4545
"scipy",
4646
"pandas",
4747
"h5py",
48-
]
48+
"zarr>=2.16.0"
49+
]
4950

5051
docs = [
5152
"pillow",

src/probeinterface/probe.py

Lines changed: 176 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22
import numpy as np
33
from typing import Optional
4-
4+
from pathlib import Path
5+
import json
56

67
from .shank import Shank
78

@@ -197,16 +198,24 @@ def __repr__(self):
197198
return self.get_title()
198199

199200
def annotate(self, **kwargs):
200-
"""Annotates the probe object.
201+
"""
202+
Annotates the probe object.
201203
202-
Parameter
203-
---------
204-
**kwargs : list of keyword arguments to add to the annotations
204+
Parameters
205+
----------
206+
**kwargs : list of keyword arguments to add to the annotations (e.g., brain_area="CA1")
205207
"""
206208
self.annotations.update(kwargs)
207209
self.check_annotations()
208210

209211
def annotate_contacts(self, **kwargs):
212+
"""
213+
Annotates the contacts of the probe.
214+
215+
Parameters
216+
----------
217+
**kwargs : list of keyword arguments to add to the annotations (e.g., quality=["good", "bad", ...])
218+
"""
210219
n = self.get_contact_count()
211220
for k, values in kwargs.items():
212221
assert len(values) == n, (
@@ -506,6 +515,13 @@ def __eq__(self, other):
506515
if not np.array_equal(self.contact_annotations[key], other.contact_annotations[key]):
507516
return False
508517

518+
# planar contour
519+
if self.probe_planar_contour is not None:
520+
if other.probe_planar_contour is None:
521+
return False
522+
if not np.array_equal(self.probe_planar_contour, other.probe_planar_contour):
523+
return False
524+
509525
return True
510526

511527
def copy(self):
@@ -862,7 +878,7 @@ def to_numpy(self, complete: bool = False) -> np.array:
862878
dtype += [(f"plane_axis_{dim}_0", "float64")]
863879
dtype += [(f"plane_axis_{dim}_1", "float64")]
864880
for k, v in self.contact_annotations.items():
865-
dtype += [(f"{k}", np.dtype(v[0]))]
881+
dtype += [(f"{k}", np.array(v, copy=False).dtype)]
866882

867883
arr = np.zeros(self.get_contact_count(), dtype=dtype)
868884
arr["x"] = self.contact_positions[:, 0]
@@ -916,8 +932,28 @@ def from_numpy(arr: np.ndarray) -> "Probe":
916932
probe : Probe
917933
The instantiated Probe object
918934
"""
919-
920935
fields = list(arr.dtype.fields)
936+
main_fields = [
937+
"x",
938+
"y",
939+
"z",
940+
"contact_shapes",
941+
"shank_ids",
942+
"contact_ids",
943+
"device_channel_indices",
944+
"radius",
945+
"width",
946+
"height",
947+
"plane_axis_x_0",
948+
"plane_axis_x_1",
949+
"plane_axis_y_0",
950+
"plane_axis_y_1",
951+
"plane_axis_z_0",
952+
"plane_axis_z_1",
953+
"probe_index",
954+
"si_units",
955+
]
956+
contact_annotation_fields = [f for f in fields if f not in main_fields]
921957

922958
if "z" in fields:
923959
ndim = 3
@@ -964,14 +1000,146 @@ def from_numpy(arr: np.ndarray) -> "Probe":
9641000

9651001
if "device_channel_indices" in fields:
9661002
dev_channel_indices = arr["device_channel_indices"]
967-
probe.set_device_channel_indices(dev_channel_indices)
1003+
if not np.all(dev_channel_indices == -1):
1004+
probe.set_device_channel_indices(dev_channel_indices)
9681005
if "shank_ids" in fields:
9691006
probe.set_shank_ids(arr["shank_ids"])
9701007
if "contact_ids" in fields:
9711008
probe.set_contact_ids(arr["contact_ids"])
9721009

1010+
# contact annotations
1011+
for k in contact_annotation_fields:
1012+
probe.annotate_contacts(**{k: arr[k]})
9731013
return probe
9741014

1015+
def add_probe_to_zarr_group(self, group: "zarr.Group") -> None:
1016+
"""
1017+
Serialize the probe's data and structure to a specified Zarr group.
1018+
1019+
This method is used to save the probe's attributes, annotations, and other
1020+
related data into a Zarr group, facilitating integration into larger Zarr
1021+
structures.
1022+
1023+
Parameters
1024+
----------
1025+
group : zarr.Group
1026+
The target Zarr group where the probe's data will be stored.
1027+
"""
1028+
probe_arr = self.to_numpy(complete=True)
1029+
1030+
# add fields and contact annotations
1031+
for field_name, (dtype, offset) in probe_arr.dtype.fields.items():
1032+
data = probe_arr[field_name]
1033+
group.create_dataset(name=field_name, data=data, dtype=dtype, chunks=False)
1034+
1035+
# Annotations as a group (special attibutes are stored as annotations)
1036+
annotations_group = group.create_group("annotations")
1037+
for key, value in self.annotations.items():
1038+
annotations_group.attrs[key] = value
1039+
1040+
# Add planar contour
1041+
if self.probe_planar_contour is not None:
1042+
group.create_dataset(
1043+
name="probe_planar_contour", data=self.probe_planar_contour, dtype="float64", chunks=False
1044+
)
1045+
1046+
def to_zarr(self, folder_path: str | Path) -> None:
1047+
"""
1048+
Serialize the Probe object to a Zarr file located at the specified folder path.
1049+
1050+
This method initializes a new Zarr group at the given folder path and calls
1051+
`add_probe_to_zarr_group` to serialize the Probe's data into this group, effectively
1052+
storing the entire Probe's state in a Zarr archive.
1053+
1054+
Parameters
1055+
----------
1056+
folder_path : str | Path
1057+
The path to the folder where the Zarr data structure will be created and
1058+
where the serialized data will be stored. If the folder does not exist,
1059+
it will be created.
1060+
"""
1061+
import zarr
1062+
1063+
# Create or open a Zarr group for writing
1064+
zarr_group = zarr.open_group(folder_path, mode="w")
1065+
1066+
# Serialize this Probe object into the Zarr group
1067+
self.add_probe_to_zarr_group(zarr_group)
1068+
1069+
@staticmethod
1070+
def from_zarr_group(group: zarr.Group) -> "Probe":
1071+
"""
1072+
Load a probe instance from a given Zarr group.
1073+
1074+
Parameters
1075+
----------
1076+
group : zarr.Group
1077+
The Zarr group from which to load the probe.
1078+
1079+
Returns
1080+
-------
1081+
Probe
1082+
An instance of the Probe class initialized with data from the Zarr group.
1083+
"""
1084+
import zarr
1085+
1086+
dtype = []
1087+
# load all datasets
1088+
num_contacts = None
1089+
probe_arr_keys = []
1090+
for key in group.keys():
1091+
if key == "probe_planar_contour":
1092+
continue
1093+
if key == "annotations":
1094+
continue
1095+
dset = group[key]
1096+
if isinstance(dset, zarr.Array):
1097+
probe_arr_keys.append(key)
1098+
dtype.append((key, dset.dtype))
1099+
if num_contacts is None:
1100+
num_contacts = len(dset)
1101+
1102+
# Create a structured array from the datasets
1103+
probe_arr = np.zeros(num_contacts, dtype=dtype)
1104+
1105+
for probe_key in probe_arr_keys:
1106+
probe_arr[probe_key] = group[probe_key][:]
1107+
1108+
# Create a Probe instance from the structured array
1109+
probe = Probe.from_numpy(probe_arr)
1110+
1111+
# Load annotations
1112+
annotations_group = group.get("annotations", None)
1113+
for key in annotations_group.attrs.keys():
1114+
# Use the annotate method for each key-value pair
1115+
probe.annotate(**{key: annotations_group.attrs[key]})
1116+
1117+
if "probe_planar_contour" in group:
1118+
# Directly assign since there's no specific setter for probe_planar_contour
1119+
probe.probe_planar_contour = group["probe_planar_contour"][:]
1120+
1121+
return probe
1122+
1123+
@staticmethod
1124+
def from_zarr(folder_path: str | Path) -> "Probe":
1125+
"""
1126+
Deserialize the Probe object from a Zarr file located at the given folder path.
1127+
1128+
Parameters
1129+
----------
1130+
folder_path : str | Path
1131+
The path to the folder where the Zarr file is located.
1132+
1133+
Returns
1134+
-------
1135+
Probe
1136+
An instance of the Probe class initialized with data from the Zarr file.
1137+
"""
1138+
import zarr
1139+
1140+
zarr_group = zarr.open(folder_path, mode="r")
1141+
return Probe.from_zarr_group(zarr_group)
1142+
9751143
def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame":
9761144
"""
9771145
Export the probe to a pandas dataframe

tests/test_probe.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from probeinterface import Probe
22
from probeinterface.generator import generate_dummy_probe
3+
from pathlib import Path
34

45
import numpy as np
56

@@ -148,7 +149,7 @@ def test_probe_equality_dunder():
148149

149150
# Modify probe2
150151
probe2.move([1, 1])
151-
assert probe2 != probe1
152+
assert probe1 != probe2
152153

153154

154155
def test_set_shanks():
@@ -162,7 +163,26 @@ def test_set_shanks():
162163
assert all(probe.shank_ids == shank_ids.astype(str))
163164

164165

166+
def test_save_to_zarr(tmp_path):
167+
# Generate a dummy probe instance
168+
probe = generate_dummy_probe()
169+
170+
# Define file path in the temporary directory
171+
folder_path = Path(tmp_path) / "probe.zarr"
172+
173+
# Save the probe object to Zarr format
174+
probe.to_zarr(folder_path=folder_path)
175+
176+
# Reload the probe object from the saved Zarr file
177+
reloaded_probe = Probe.from_zarr(folder_path=folder_path)
178+
179+
# Assert that the reloaded probe is equal to the original
180+
assert probe == reloaded_probe, "Reloaded Probe object does not match the original"
181+
182+
165183
if __name__ == "__main__":
166184
test_probe()
167185

168-
test_set_shanks()
186+
tmp_path = Path("tmp")
187+
tmp_path.mkdir(exist_ok=True)
188+
test_save_to_zarr(tmp_path)

0 commit comments

Comments
 (0)