|
1 | 1 | from __future__ import annotations |
2 | 2 | import numpy as np |
3 | 3 | from typing import Optional |
4 | | - |
| 4 | +from pathlib import Path |
| 5 | +import json |
5 | 6 |
|
6 | 7 | from .shank import Shank |
7 | 8 |
|
@@ -197,16 +198,24 @@ def __repr__(self): |
197 | 198 | return self.get_title() |
198 | 199 |
|
199 | 200 | def annotate(self, **kwargs): |
200 | | - """Annotates the probe object. |
| 201 | + """ |
| 202 | + Annotates the probe object. |
201 | 203 |
|
202 | | - Parameter |
203 | | - --------- |
204 | | - **kwargs : list of keyword arguments to add to the annotations |
| 204 | + Parameters |
| 205 | + ---------- |
| 206 | + **kwargs : list of keyword arguments to add to the annotations (e.g., brain_area="CA1") |
205 | 207 | """ |
206 | 208 | self.annotations.update(kwargs) |
207 | 209 | self.check_annotations() |
208 | 210 |
|
209 | 211 | def annotate_contacts(self, **kwargs): |
| 212 | + """ |
| 213 | + Annotates the contacts of the probe. |
| 214 | +
|
| 215 | + Parameters |
| 216 | + ---------- |
| 217 | + **kwargs : list of keyword arguments to add to the annotations (e.g., quality=["good", "bad", ...]) |
| 218 | + """ |
210 | 219 | n = self.get_contact_count() |
211 | 220 | for k, values in kwargs.items(): |
212 | 221 | assert len(values) == n, ( |
@@ -506,6 +515,13 @@ def __eq__(self, other): |
506 | 515 | if not np.array_equal(self.contact_annotations[key], other.contact_annotations[key]): |
507 | 516 | return False |
508 | 517 |
|
| 518 | + # planar contour |
| 519 | + if self.probe_planar_contour is not None: |
| 520 | + if other.probe_planar_contour is None: |
| 521 | + return False |
| 522 | + if not np.array_equal(self.probe_planar_contour, other.probe_planar_contour): |
| 523 | + return False |
| 524 | + |
509 | 525 | return True |
510 | 526 |
|
511 | 527 | def copy(self): |
@@ -862,7 +878,7 @@ def to_numpy(self, complete: bool = False) -> np.array: |
862 | 878 | dtype += [(f"plane_axis_{dim}_0", "float64")] |
863 | 879 | dtype += [(f"plane_axis_{dim}_1", "float64")] |
864 | 880 | for k, v in self.contact_annotations.items(): |
865 | | - dtype += [(f"{k}", np.dtype(v[0]))] |
| 881 | + dtype += [(f"{k}", np.array(v, copy=False).dtype)] |
866 | 882 |
|
867 | 883 | arr = np.zeros(self.get_contact_count(), dtype=dtype) |
868 | 884 | arr["x"] = self.contact_positions[:, 0] |
@@ -916,8 +932,28 @@ def from_numpy(arr: np.ndarray) -> "Probe": |
916 | 932 | probe : Probe |
917 | 933 | The instantiated Probe object |
918 | 934 | """ |
919 | | - |
920 | 935 | fields = list(arr.dtype.fields) |
| 936 | + main_fields = [ |
| 937 | + "x", |
| 938 | + "y", |
| 939 | + "z", |
| 940 | + "contact_shapes", |
| 941 | + "shank_ids", |
| 942 | + "contact_ids", |
| 943 | + "device_channel_indices", |
| 944 | + "radius", |
| 945 | + "width", |
| 946 | + "height", |
| 947 | + "plane_axis_x_0", |
| 948 | + "plane_axis_x_1", |
| 949 | + "plane_axis_y_0", |
| 950 | + "plane_axis_y_1", |
| 951 | + "plane_axis_z_0", |
| 952 | + "plane_axis_z_1", |
| 953 | + "probe_index", |
| 954 | + "si_units", |
| 955 | + ] |
| 956 | + contact_annotation_fields = [f for f in fields if f not in main_fields] |
921 | 957 |
|
922 | 958 | if "z" in fields: |
923 | 959 | ndim = 3 |
@@ -964,14 +1000,146 @@ def from_numpy(arr: np.ndarray) -> "Probe": |
964 | 1000 |
|
965 | 1001 | if "device_channel_indices" in fields: |
966 | 1002 | dev_channel_indices = arr["device_channel_indices"] |
967 | | - probe.set_device_channel_indices(dev_channel_indices) |
| 1003 | + if not np.all(dev_channel_indices == -1): |
| 1004 | + probe.set_device_channel_indices(dev_channel_indices) |
968 | 1005 | if "shank_ids" in fields: |
969 | 1006 | probe.set_shank_ids(arr["shank_ids"]) |
970 | 1007 | if "contact_ids" in fields: |
971 | 1008 | probe.set_contact_ids(arr["contact_ids"]) |
972 | 1009 |
|
| 1010 | + # contact annotations |
| 1011 | + for k in contact_annotation_fields: |
| 1012 | + probe.annotate_contacts(**{k: arr[k]}) |
973 | 1013 | return probe |
974 | 1014 |
|
| 1015 | + def add_probe_to_zarr_group(self, group: "zarr.Group") -> None: |
| 1016 | + """ |
| 1017 | + Serialize the probe's data and structure to a specified Zarr group. |
| 1018 | +
|
| 1019 | + This method is used to save the probe's attributes, annotations, and other |
| 1020 | + related data into a Zarr group, facilitating integration into larger Zarr |
| 1021 | + structures. |
| 1022 | +
|
| 1023 | + Parameters |
| 1024 | + ---------- |
| 1025 | + group : zarr.Group |
| 1026 | + The target Zarr group where the probe's data will be stored. |
| 1027 | + """ |
| 1028 | + probe_arr = self.to_numpy(complete=True) |
| 1029 | + |
| 1030 | + # add fields and contact annotations |
| 1031 | + for field_name, (dtype, offset) in probe_arr.dtype.fields.items(): |
| 1032 | + data = probe_arr[field_name] |
| 1033 | + group.create_dataset(name=field_name, data=data, dtype=dtype, chunks=False) |
| 1034 | + |
| 1035 | + # Annotations as a group (special attibutes are stored as annotations) |
| 1036 | + annotations_group = group.create_group("annotations") |
| 1037 | + for key, value in self.annotations.items(): |
| 1038 | + annotations_group.attrs[key] = value |
| 1039 | + |
| 1040 | + # Add planar contour |
| 1041 | + if self.probe_planar_contour is not None: |
| 1042 | + group.create_dataset( |
| 1043 | + name="probe_planar_contour", data=self.probe_planar_contour, dtype="float64", chunks=False |
| 1044 | + ) |
| 1045 | + |
| 1046 | + def to_zarr(self, folder_path: str | Path) -> None: |
| 1047 | + """ |
| 1048 | + Serialize the Probe object to a Zarr file located at the specified folder path. |
| 1049 | +
|
| 1050 | + This method initializes a new Zarr group at the given folder path and calls |
| 1051 | + `add_probe_to_zarr_group` to serialize the Probe's data into this group, effectively |
| 1052 | + storing the entire Probe's state in a Zarr archive. |
| 1053 | +
|
| 1054 | + Parameters |
| 1055 | + ---------- |
| 1056 | + folder_path : str | Path |
| 1057 | + The path to the folder where the Zarr data structure will be created and |
| 1058 | + where the serialized data will be stored. If the folder does not exist, |
| 1059 | + it will be created. |
| 1060 | + """ |
| 1061 | + import zarr |
| 1062 | + |
| 1063 | + # Create or open a Zarr group for writing |
| 1064 | + zarr_group = zarr.open_group(folder_path, mode="w") |
| 1065 | + |
| 1066 | + # Serialize this Probe object into the Zarr group |
| 1067 | + self.add_probe_to_zarr_group(zarr_group) |
| 1068 | + |
| 1069 | + @staticmethod |
| 1070 | + def from_zarr_group(group: zarr.Group) -> "Probe": |
| 1071 | + """ |
| 1072 | + Load a probe instance from a given Zarr group. |
| 1073 | +
|
| 1074 | + Parameters |
| 1075 | + ---------- |
| 1076 | + group : zarr.Group |
| 1077 | + The Zarr group from which to load the probe. |
| 1078 | +
|
| 1079 | + Returns |
| 1080 | + ------- |
| 1081 | + Probe |
| 1082 | + An instance of the Probe class initialized with data from the Zarr group. |
| 1083 | + """ |
| 1084 | + import zarr |
| 1085 | + |
| 1086 | + dtype = [] |
| 1087 | + # load all datasets |
| 1088 | + num_contacts = None |
| 1089 | + probe_arr_keys = [] |
| 1090 | + for key in group.keys(): |
| 1091 | + if key == "probe_planar_contour": |
| 1092 | + continue |
| 1093 | + if key == "annotations": |
| 1094 | + continue |
| 1095 | + dset = group[key] |
| 1096 | + if isinstance(dset, zarr.Array): |
| 1097 | + probe_arr_keys.append(key) |
| 1098 | + dtype.append((key, dset.dtype)) |
| 1099 | + if num_contacts is None: |
| 1100 | + num_contacts = len(dset) |
| 1101 | + |
| 1102 | + # Create a structured array from the datasets |
| 1103 | + probe_arr = np.zeros(num_contacts, dtype=dtype) |
| 1104 | + |
| 1105 | + for probe_key in probe_arr_keys: |
| 1106 | + probe_arr[probe_key] = group[probe_key][:] |
| 1107 | + |
| 1108 | + # Create a Probe instance from the structured array |
| 1109 | + probe = Probe.from_numpy(probe_arr) |
| 1110 | + |
| 1111 | + # Load annotations |
| 1112 | + annotations_group = group.get("annotations", None) |
| 1113 | + for key in annotations_group.attrs.keys(): |
| 1114 | + # Use the annotate method for each key-value pair |
| 1115 | + probe.annotate(**{key: annotations_group.attrs[key]}) |
| 1116 | + |
| 1117 | + if "probe_planar_contour" in group: |
| 1118 | + # Directly assign since there's no specific setter for probe_planar_contour |
| 1119 | + probe.probe_planar_contour = group["probe_planar_contour"][:] |
| 1120 | + |
| 1121 | + return probe |
| 1122 | + |
| 1123 | + @staticmethod |
| 1124 | + def from_zarr(folder_path: str | Path) -> "Probe": |
| 1125 | + """ |
| 1126 | + Deserialize the Probe object from a Zarr file located at the given folder path. |
| 1127 | +
|
| 1128 | + Parameters |
| 1129 | + ---------- |
| 1130 | + folder_path : str | Path |
| 1131 | + The path to the folder where the Zarr file is located. |
| 1132 | +
|
| 1133 | + Returns |
| 1134 | + ------- |
| 1135 | + Probe |
| 1136 | + An instance of the Probe class initialized with data from the Zarr file. |
| 1137 | + """ |
| 1138 | + import zarr |
| 1139 | + |
| 1140 | + zarr_group = zarr.open(folder_path, mode="r") |
| 1141 | + return Probe.from_zarr_group(zarr_group) |
| 1142 | + |
975 | 1143 | def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame": |
976 | 1144 | """ |
977 | 1145 | Export the probe to a pandas dataframe |
|
0 commit comments