Skip to content

Commit d83aea0

Browse files
feat: allow library to be specified in writer_info removal (#1042)
* feat: allow library to be specified in writer_info removal Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com> * style: pre-commit fixes --------- Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 397e5af commit d83aea0

File tree

2 files changed

+39
-8
lines changed

2 files changed

+39
-8
lines changed

src/boost_histogram/serialization/__init__.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,25 @@ def from_uhi(data: dict[str, Any], /) -> histogram.Histogram:
4646
T = TypeVar("T", bound="dict[str, Any]")
4747

4848

49-
def remove_writer_info(obj: T) -> T:
50-
"""Removes all boost-histogram writer_info from a histogram dict, axes dict, or storage dict. Makes copies where required, and the outer dictionary is always copied."""
49+
def remove_writer_info(obj: T, /, *, library: str | None = "boost-histogram") -> T:
50+
"""
51+
Removes all ``writer_info`` for a library from a histogram dict, axes dict,
52+
or storage dict. Makes copies where required, and the outer dictionary is
53+
always copied.
54+
55+
Specify a library name, or ``None`` to remove all.
56+
"""
5157

5258
obj = copy.copy(obj)
53-
if "boost-histogram" in obj.get("writer_info", {}):
59+
if library is None:
60+
obj.pop("writer_info")
61+
elif library in obj.get("writer_info", {}):
5462
obj["writer_info"] = copy.copy(obj["writer_info"])
55-
del obj["writer_info"]["boost-histogram"]
63+
del obj["writer_info"][library]
5664

5765
if "axes" in obj:
58-
obj["axes"] = [remove_writer_info(ax) for ax in obj["axes"]]
66+
obj["axes"] = [remove_writer_info(ax, library=library) for ax in obj["axes"]]
5967
if "storage" in obj:
60-
obj["storage"] = remove_writer_info(obj["storage"])
68+
obj["storage"] = remove_writer_info(obj["storage"], library=library)
6169

6270
return obj

tests/test_serialization_uhi.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,15 +221,16 @@ def test_round_trip_native() -> None:
221221
assert h2.storage_type is bh.storage.AtomicInt64
222222

223223

224-
def test_round_trip_clean() -> None:
224+
@pytest.mark.parametrize("remove", ["boost-histogram", None])
225+
def test_round_trip_clean(remove: str | None) -> None:
225226
h = bh.Histogram(
226227
bh.axis.Integer(0, 10),
227228
storage=bh.storage.AtomicInt64(),
228229
)
229230
h.fill([-1, 0, 0, 1, 20, 20, 20])
230231

231232
data = to_uhi(h)
232-
data = remove_writer_info(data)
233+
data = remove_writer_info(data, library=remove)
233234
h2 = from_uhi(data)
234235

235236
assert isinstance(h2.axes[0], bh.axis.Regular)
@@ -262,3 +263,25 @@ def test_histogram_metadata() -> None:
262263
"other": 3,
263264
"_variance_known": True,
264265
}
266+
267+
268+
def test_remove_writer_info() -> None:
269+
d = {
270+
"uhi_schema": 1,
271+
"writer_info": {"boost-histogram": {"foo": "bar"}, "hist": {"FOO": "BAR"}},
272+
}
273+
274+
assert remove_writer_info(d, library=None) == {"uhi_schema": 1}
275+
assert remove_writer_info(d) == {
276+
"uhi_schema": 1,
277+
"writer_info": {"hist": {"FOO": "BAR"}},
278+
}
279+
assert remove_writer_info(d, library="boost-histogram") == {
280+
"uhi_schema": 1,
281+
"writer_info": {"hist": {"FOO": "BAR"}},
282+
}
283+
assert remove_writer_info(d, library="hist") == {
284+
"uhi_schema": 1,
285+
"writer_info": {"boost-histogram": {"foo": "bar"}},
286+
}
287+
assert remove_writer_info(d, library="c") == d

0 commit comments

Comments
 (0)