Skip to content

Commit 397e5af

Browse files
fix: support __dict__ on histogram like axes (#1041)
* fix: support __dict__ on histogram like axes

  Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>

* fix: ignore None values

  Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>

* style: pre-commit fixes

---------

Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fdf8597 commit 397e5af

File tree

5 files changed

+70
-54
lines changed

5 files changed

+70
-54
lines changed

src/boost_histogram/histogram.py

Lines changed: 47 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -226,27 +226,35 @@ def __init_subclass__(cls, *, family: object | None = None) -> None:
226226
cls._family = family if family is not None else object()
227227

228228
@typing.overload
229-
def __init__(self, arg: Histogram, /, *, metadata: Any = ...) -> None: ...
229+
def __init__(
230+
self, arg: Histogram, /, *, metadata: Any = ..., __dict__: Any = ...
231+
) -> None: ...
230232

231233
@typing.overload
232-
def __init__(self, arg: dict[str, Any], /, *, metadata: Any = ...) -> None: ...
234+
def __init__(
235+
self, arg: dict[str, Any], /, *, metadata: Any = ..., __dict__: Any = ...
236+
) -> None: ...
233237

234238
@typing.overload
235-
def __init__(self, arg: CppHistogram, /, *, metadata: Any = ...) -> None: ...
239+
def __init__(
240+
self, arg: CppHistogram, /, *, metadata: Any = ..., __dict__: Any = ...
241+
) -> None: ...
236242

237243
@typing.overload
238244
def __init__(
239245
self,
240246
*axes: Axis | CppAxis,
241247
storage: Storage = ...,
242248
metadata: Any = ...,
249+
__dict__: Any = ...,
243250
) -> None: ...
244251

245252
def __init__(
246253
self,
247254
*axes: Axis | CppAxis | Histogram | CppHistogram | dict[str, Any],
248255
storage: Storage | None = None,
249256
metadata: Any = NO_METADATA,
257+
__dict__: Any = None,
250258
) -> None:
251259
"""
252260
Construct a new histogram.
@@ -263,33 +271,45 @@ def __init__(
263271
Select a storage to use in the histogram
264272
metadata : Any = None
265273
Data that is passed along if a new histogram is created. Do not use
266-
in new code; set properties in __dict__ directly instead.
274+
in new code; use ``__dict__`` instead.
275+
__dict__ : Any = None
276+
Better way to set metadata.
267277
"""
268278
self._variance_known = True
269279
storage_err_msg = "storage= is not allowed with conversion constructor"
270280

281+
if metadata is not NO_METADATA and __dict__:
282+
msg = (
283+
"Can't set both metadata and __dict__. Set the 'metadata' key instead."
284+
)
285+
raise TypeError(msg)
286+
if metadata is not NO_METADATA:
287+
__dict__ = {"metadata": metadata}
288+
if __dict__ is None:
289+
__dict__ = {}
290+
271291
# Allow construction from a raw histogram object (internal)
272292
if len(axes) == 1 and isinstance(axes[0], tuple(_histograms)):
273293
if storage is not None:
274294
raise TypeError(storage_err_msg)
275295
cpp_hist: CppHistogram = axes[0] # type: ignore[assignment]
276-
self._from_histogram_cpp(cpp_hist, metadata=metadata)
296+
self._from_histogram_cpp(cpp_hist, __dict__=__dict__)
277297
return
278298

279299
# If we construct with another Histogram as the only positional argument,
280300
# support that too
281301
if len(axes) == 1 and isinstance(axes[0], Histogram):
282302
if storage is not None:
283303
raise TypeError(storage_err_msg)
284-
self._from_histogram_object(axes[0], metadata=metadata)
304+
self._from_histogram_object(axes[0], __dict__=__dict__)
285305
return
286306

287307
# Support objects that provide a to_boost method, like Uproot
288308
if len(axes) == 1 and hasattr(axes[0], "_to_boost_histogram_"):
289309
if storage is not None:
290310
raise TypeError(storage_err_msg)
291311
self._from_histogram_object(
292-
axes[0]._to_boost_histogram_(), metadata=metadata
312+
axes[0]._to_boost_histogram_(), __dict__=__dict__
293313
)
294314
return
295315

@@ -298,15 +318,14 @@ def __init__(
298318
if storage is not None:
299319
raise TypeError(storage_err_msg)
300320
self._from_histogram_object(
301-
serialization.from_uhi(axes[0]), metadata=metadata
321+
serialization.from_uhi(axes[0]), __dict__=__dict__
302322
)
303323
return
304324

305325
if storage is None:
306326
storage = Double()
307327

308-
if metadata is not NO_METADATA:
309-
self.metadata = metadata
328+
self.__dict__.update(__dict__)
310329

311330
# Check for missed parenthesis or incorrect types
312331
if not isinstance(storage, Storage):
@@ -347,7 +366,7 @@ def _clone(
347366

348367
self = cls.__new__(cls)
349368
if isinstance(_hist, tuple(_histograms)):
350-
self._from_histogram_cpp(_hist, metadata=NO_METADATA) # type: ignore[arg-type]
369+
self._from_histogram_cpp(_hist, __dict__={}) # type: ignore[arg-type]
351370
if other is not None:
352371
return cls._clone(self, other=other, memo=memo)
353372
return self
@@ -357,18 +376,19 @@ def _clone(
357376
if other is None:
358377
other = _hist
359378

360-
self._from_histogram_object(_hist, metadata=NO_METADATA)
361-
362379
if memo is NOTHING:
363-
self.__dict__ = copy.copy(other.__dict__)
380+
dict_copy = copy.copy(other.__dict__)
364381
else:
365-
self.__dict__ = copy.deepcopy(other.__dict__, memo)
382+
dict_copy = copy.deepcopy(other.__dict__, memo)
383+
384+
self._from_histogram_object(_hist, __dict__=dict_copy)
366385

367386
for ax in self.axes:
368387
if memo is NOTHING:
369-
ax.__dict__ = copy.copy(ax._ax.raw_metadata)
388+
ax._ax.raw_metadata = copy.copy(ax._ax.raw_metadata)
370389
else:
371-
ax.__dict__ = copy.deepcopy(ax._ax.raw_metadata, memo)
390+
ax._ax.raw_metadata = copy.deepcopy(ax._ax.raw_metadata, memo)
391+
ax.__dict__ = ax._ax.raw_metadata
372392
return self
373393

374394
def _new_hist(self, _hist: CppHistogram, memo: Any = NOTHING) -> Self:
@@ -377,17 +397,20 @@ def _new_hist(self, _hist: CppHistogram, memo: Any = NOTHING) -> Self:
377397
"""
378398
return self.__class__._clone(_hist, other=self, memo=memo)
379399

380-
def _from_histogram_cpp(self, other: CppHistogram, *, metadata: Any) -> None:
400+
def _from_histogram_cpp(
401+
self, other: CppHistogram, *, __dict__: dict[str, Any]
402+
) -> None:
381403
"""
382404
Import a Cpp histogram.
383405
"""
384406
self._variance_known = True
385407
self._hist = other
386-
if metadata is not NO_METADATA:
387-
self.metadata = metadata
408+
self.__dict__.update(__dict__)
388409
self.axes = self._generate_axes_()
389410

390-
def _from_histogram_object(self, other: Histogram, *, metadata: Any) -> None:
411+
def _from_histogram_object(
412+
self, other: Histogram, *, __dict__: dict[str, Any]
413+
) -> None:
391414
"""
392415
Convert self into a new histogram object based on another, possibly
393416
converting from a different subclass.
@@ -396,11 +419,9 @@ def _from_histogram_object(self, other: Histogram, *, metadata: Any) -> None:
396419
self.__dict__ = copy.copy(other.__dict__)
397420
self.axes = self._generate_axes_()
398421
for ax in self.axes:
399-
ax.__dict__ = copy.copy(ax._ax.raw_metadata)
400-
if metadata is not NO_METADATA:
401-
self.metadata = metadata
402-
elif "metadata" in other.__dict__:
403-
self.metadata = other.metadata
422+
ax.__dict__.update(ax._ax.raw_metadata)
423+
self.__dict__.update(other.__dict__)
424+
self.__dict__.update(__dict__)
404425

405426
# Allow custom behavior on either "from" or "to"
406427
other._export_bh_(self)

src/boost_histogram/serialization/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
# pylint: disable-next=import-error
77
from .. import histogram, version
88
from ._axis import _axis_from_dict, _axis_to_dict
9+
from ._common import serialize_metadata
910
from ._storage import _data_from_dict, _storage_from_dict, _storage_to_dict
1011

1112
__all__ = ["from_uhi", "remove_writer_info", "to_uhi"]
@@ -25,7 +26,7 @@ def to_uhi(h: histogram.Histogram, /) -> dict[str, Any]:
2526
"axes": [_axis_to_dict(axis) for axis in h.axes],
2627
"storage": _storage_to_dict(h.storage_type(), h.view(flow=True)),
2728
}
28-
data["metadata"] = {k: v for k, v in h.__dict__.items() if not k.startswith("@")}
29+
data["metadata"] = serialize_metadata(h.__dict__)
2930

3031
return data
3132

src/boost_histogram/serialization/_axis.py

Lines changed: 13 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from typing import Any
55

66
from .. import axis
7+
from ._common import serialize_metadata
78

89
__all__ = ["_axis_from_dict", "_axis_to_dict"]
910

@@ -55,10 +56,8 @@ def _(ax: axis.Regular | axis.Integer, /) -> dict[str, Any]:
5556

5657
if isinstance(ax, axis.Integer):
5758
data["writer_info"] = {"boost-histogram": {"orig_type": "Integer"}}
58-
if ax.metadata is not None:
59-
data["metadata"] = {
60-
k: v for k, v in ax.metadata.items() if not k.startswith("@")
61-
}
59+
60+
data["metadata"] = serialize_metadata(ax.__dict__)
6261

6362
return data
6463

@@ -73,10 +72,7 @@ def _(ax: axis.Variable, /) -> dict[str, Any]:
7372
"overflow": ax.traits.overflow,
7473
"circular": ax.traits.circular,
7574
}
76-
if ax.metadata is not None:
77-
data["metadata"] = {
78-
k: v for k, v in ax.metadata.items() if not k.startswith("@")
79-
}
75+
data["metadata"] = serialize_metadata(ax.__dict__)
8076

8177
return data
8278

@@ -89,10 +85,7 @@ def _(ax: axis.IntCategory, /) -> dict[str, Any]:
8985
"categories": list(ax),
9086
"flow": ax.traits.overflow,
9187
}
92-
if ax.metadata is not None:
93-
data["metadata"] = {
94-
k: v for k, v in ax.metadata.items() if not k.startswith("@")
95-
}
88+
data["metadata"] = serialize_metadata(ax.__dict__)
9689

9790
return data
9891

@@ -105,10 +98,7 @@ def _(ax: axis.StrCategory, /) -> dict[str, Any]:
10598
"categories": list(ax),
10699
"flow": ax.traits.overflow,
107100
}
108-
if ax.metadata is not None:
109-
data["metadata"] = {
110-
k: v for k, v in ax.metadata.items() if not k.startswith("@")
111-
}
101+
data["metadata"] = serialize_metadata(ax.__dict__)
112102

113103
return data
114104

@@ -119,10 +109,7 @@ def _(ax: axis.Boolean, /) -> dict[str, Any]:
119109
data: dict[str, Any] = {
120110
"type": "boolean",
121111
}
122-
if ax.metadata is not None:
123-
data["metadata"] = {
124-
k: v for k, v in ax.metadata.items() if not k.startswith("@")
125-
}
112+
data["metadata"] = serialize_metadata(ax.__dict__)
126113

127114
return data
128115

@@ -139,7 +126,7 @@ def _axis_from_dict(data: dict[str, Any], /) -> axis.Axis:
139126
underflow=data["underflow"],
140127
overflow=data["overflow"],
141128
circular=data["circular"],
142-
metadata=data.get("metadata"),
129+
__dict__=data.get("metadata"),
143130
)
144131

145132
hist_type = data["type"]
@@ -151,29 +138,29 @@ def _axis_from_dict(data: dict[str, Any], /) -> axis.Axis:
151138
underflow=data["underflow"],
152139
overflow=data["overflow"],
153140
circular=data["circular"],
154-
metadata=data.get("metadata"),
141+
__dict__=data.get("metadata"),
155142
)
156143
if hist_type == "variable":
157144
return axis.Variable(
158145
data["edges"],
159146
underflow=data["underflow"],
160147
overflow=data["overflow"],
161148
circular=data["circular"],
162-
metadata=data.get("metadata"),
149+
__dict__=data.get("metadata"),
163150
)
164151
if hist_type == "category_int":
165152
return axis.IntCategory(
166153
data["categories"],
167154
overflow=data["flow"],
168-
metadata=data.get("metadata"),
155+
__dict__=data.get("metadata"),
169156
)
170157
if hist_type == "category_str":
171158
return axis.StrCategory(
172159
data["categories"],
173160
overflow=data["flow"],
174-
metadata=data.get("metadata"),
161+
__dict__=data.get("metadata"),
175162
)
176163
if hist_type == "boolean":
177-
return axis.Boolean(metadata=data.get("metadata"))
164+
return axis.Boolean(__dict__=data.get("metadata"))
178165

179166
raise TypeError(f"Unsupported axis type: {hist_type}")

src/boost_histogram/serialization/_common.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,7 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
6+
def serialize_metadata(value: dict[str, Any], /) -> dict[str, Any]:
    """
    Return the subset of a metadata mapping that should be serialized.

    Parameters
    ----------
    value : dict[str, Any]
        Mapping of metadata names to values (callers in this commit pass an
        object's ``__dict__``). Keys must be strings, since ``str.startswith``
        is called on each of them.

    Returns
    -------
    dict[str, Any]
        A new dict holding every item of ``value`` except those whose key
        starts with ``"@"`` and those whose value is ``None``. The input
        mapping is not modified.
    """
    return {
        key: item
        for key, item in value.items()
        # "@"-prefixed keys and None values are left out of the output.
        if not key.startswith("@") and item is not None
    }

tests/test_serialization_uhi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def test_round_trip_clean() -> None:
238238

239239
def test_unserializable_metadata() -> None:
240240
h = bh.Histogram(
241-
bh.axis.Integer(0, 10, metadata={"c": 3, "@d": 4}),
241+
bh.axis.Integer(0, 10, __dict__={"c": 3, "@d": 4}),
242242
)
243243
h.__dict__["a"] = 1
244244
h.__dict__["@b"] = 2

0 commit comments

Comments
 (0)