From eb3a74297211aae8f13e6974563e6da819bfbedb Mon Sep 17 00:00:00 2001 From: Kimoon Han Date: Mon, 4 Nov 2024 23:37:36 +0900 Subject: [PATCH] feat: add `qinfo` accessor Adds a `qinfo` accessor that prints a table summarizing the data in a human readable format. Closes #27 --- src/erlab/accessors/general.py | 59 ++++++++++++++++++++++++++++++ src/erlab/analysis/fit/models.py | 14 ++++--- src/erlab/analysis/gold.py | 16 ++++---- src/erlab/interactive/fermiedge.py | 6 +-- src/erlab/io/dataloader.py | 10 ++--- src/erlab/io/plugins/merlin.py | 27 ++++++-------- src/erlab/io/plugins/ssrl52.py | 25 +++++-------- tests/conftest.py | 4 +- tests/io/plugins/test_merlin.py | 13 +++++++ 9 files changed, 119 insertions(+), 55 deletions(-) diff --git a/src/erlab/accessors/general.py b/src/erlab/accessors/general.py index 16612a9..a495449 100644 --- a/src/erlab/accessors/general.py +++ b/src/erlab/accessors/general.py @@ -7,8 +7,10 @@ "SelectionAccessor", ] +import functools import importlib from collections.abc import Hashable, Mapping +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -19,6 +21,7 @@ ERLabDatasetAccessor, either_dict_or_kwargs, ) +from erlab.utils.formatting import format_html_table from erlab.utils.misc import emit_user_level_warning @@ -504,3 +507,59 @@ def around( if average: return masked.mean(sel_kw.keys()) return masked + + +@xr.register_dataarray_accessor("qinfo") +class InfoDataArrayAccessor(ERLabDataArrayAccessor): + """`xarray.Dataset.qinfo` accessor for displaying information about the data.""" + + def get_value(self, attr_or_coord_name: str) -> Any: + """Get the value of the specified attribute or coordinate. + + If the attribute or coordinate is not found, `None` is returned. + + Parameters + ---------- + attr_or_coord_name + The name of the attribute or coordinate. + + """ + if attr_or_coord_name in self._obj.attrs: + return self._obj.attrs[attr_or_coord_name] + if attr_or_coord_name in self._obj.coords: + return self._obj.coords[attr_or_coord_name] + return None + + @functools.cached_property + def _summary_table(self) -> list[tuple[str, str, str]]: + import erlab.io + + if "data_loader_name" in self._obj.attrs: + loader = erlab.io.loaders[self._obj.attrs["data_loader_name"]] + else: + raise ValueError("Data loader information not found in data attributes") + + out: list[tuple[str, str, str]] = [] + + for key, true_key in loader.summary_attrs.items(): + val = loader.get_formatted_attr_or_coord(self._obj, true_key) + if callable(true_key): + true_key = "" + out.append((key, loader.value_to_string(val), true_key)) + + return out + + def _repr_html_(self) -> str: + return format_html_table( + [("Name", "Value", "Key"), *self._summary_table], + header_cols=1, + header_rows=1, + ) + + def __repr__(self) -> str: + return "\n".join( + [ + f"{key}: {val}" if not true_key else f"{key} ({true_key}): {val}" + for key, val, true_key in self._summary_table + ] + ) diff --git a/src/erlab/analysis/fit/models.py b/src/erlab/analysis/fit/models.py index 796f78b..b120842 100644 --- a/src/erlab/analysis/fit/models.py +++ b/src/erlab/analysis/fit/models.py @@ -10,7 +10,6 @@ "StepEdgeModel", ] -import contextlib from typing import Literal import lmfit @@ -178,10 +177,11 @@ def guess(self, data, x, **kwargs): np.argmin(np.gradient(scipy.ndimage.gaussian_filter1d(data, 0.2 * len(x)))) ] - temp = 30.0 + temp = None if isinstance(data, xr.DataArray): - with contextlib.suppress(KeyError): - temp = float(data.attrs["sample_temp"]) + temp = data.qinfo.get_value("sample_temp") + if temp is None: + temp = 30.0 pars[f"{self.prefix}center"].set( value=efermi, min=np.asarray(x).min(), max=np.asarray(x).max() @@ -190,7 +190,7 @@ def guess(self, data, x, **kwargs): pars[f"{self.prefix}back1"].set(value=back1) pars[f"{self.prefix}dos0"].set(value=dos0) pars[f"{self.prefix}dos1"].set(value=dos1) - pars[f"{self.prefix}temp"].set(value=temp) + pars[f"{self.prefix}temp"].set(value=float(temp)) pars[f"{self.prefix}resolution"].set(value=0.02) return lmfit.models.update_param_vals(pars, self.prefix, **kwargs) @@ -418,7 +418,9 @@ def guess(self, data, eV, alpha, **kwargs): pars[f"{self.prefix}lin_bkg"].set(value=dos1) if isinstance(data, xr.DataArray): - pars[f"{self.prefix}temp"].set(value=data.attrs["sample_temp"]) + temp = data.qinfo.get_value("sample_temp") + if temp is not None: + pars[f"{self.prefix}temp"].set(value=float(temp)) return lmfit.models.update_param_vals(pars, self.prefix, **kwargs) diff --git a/src/erlab/analysis/gold.py b/src/erlab/analysis/gold.py index 42c2427..c417202 100644 --- a/src/erlab/analysis/gold.py +++ b/src/erlab/analysis/gold.py @@ -221,8 +221,13 @@ def edge( model_cls: lmfit.Model = StepEdgeModel else: if temp is None: - temp = gold.attrs["sample_temp"] - params = lmfit.create_params(temp={"value": temp, "vary": vary_temp}) + temp = gold.qinfo.get_value("sample_temp") + if temp is None: + raise ValueError( + "Temperature not found in data attributes, please provide manually" + ) + + params = lmfit.create_params(temp={"value": float(temp), "vary": vary_temp}) model_cls = FermiEdgeModel model = model_cls() @@ -602,11 +607,8 @@ def quick_fit( data_fit = data.sel(eV=slice(*eV_range)) if eV_range is not None else data if temp is None: - if "sample_temp" in data.coords: - temp = float(data.coords["sample_temp"]) - elif "sample_temp" in data.attrs: - temp = float(data.attrs["sample_temp"]) - else: + temp = data.qinfo.get_value("sample_temp") + if temp is None: raise ValueError( "Temperature not found in data attributes, please provide manually" ) diff --git a/src/erlab/interactive/fermiedge.py b/src/erlab/interactive/fermiedge.py index 5f8d957..0070e87 100644 --- a/src/erlab/interactive/fermiedge.py +++ b/src/erlab/interactive/fermiedge.py @@ -187,10 +187,10 @@ def __init__( self.axes[2].setVisible(False) self.hists[2].setVisible(False) - try: - temp = float(self.data.attrs["sample_temp"]) - except KeyError: + temp = self.data.qinfo.get_value("sample_temp") + if temp is None: temp = 30.0 + temp = float(temp) self.params_roi = ROIControls(self.aw.add_roi(0)) self.params_edge = ParameterGroup( diff --git a/src/erlab/io/dataloader.py b/src/erlab/io/dataloader.py index 4126cfd..086fc2f 100644 --- a/src/erlab/io/dataloader.py +++ b/src/erlab/io/dataloader.py @@ -1544,13 +1544,13 @@ def validate(cls, data: xr.DataArray | xr.Dataset | xr.DataTree) -> None: for c in ("beta", "delta", "xi", "hv"): if c not in data.coords: - cls._raise_or_warn(f"Missing coordinate {c}") + cls._raise_or_warn(f"Missing coordinate '{c}'") - for a in ("configuration", "sample_temp"): - if a not in data.attrs: - cls._raise_or_warn(f"Missing attribute {a}") + if data.qinfo.get_value("sample_temp") is None: + cls._raise_or_warn("Missing attribute 'sample_temp'") if "configuration" not in data.attrs: + cls._raise_or_warn("Missing attribute 'configuration'") return if data.attrs["configuration"] not in (1, 2): @@ -1559,7 +1559,7 @@ def validate(cls, data: xr.DataArray | xr.Dataset | xr.DataTree) -> None: f"Invalid configuration {data.attrs['configuration']}" ) elif "chi" not in data.coords: - cls._raise_or_warn("Missing coordinate chi") + cls._raise_or_warn("Missing coordinate 'chi'") def load_multiple_parallel( self, diff --git a/src/erlab/io/plugins/merlin.py b/src/erlab/io/plugins/merlin.py index 0ba0cbe..eb4b678 100644 --- a/src/erlab/io/plugins/merlin.py +++ b/src/erlab/io/plugins/merlin.py @@ -83,17 +83,17 @@ class MERLINLoader(LoaderBase): } summary_attrs: ClassVar[dict[str, str | Callable[[xr.DataArray], Any]]] = { - "Time": _parse_time, - "Type": _determine_kind, - "Lens Mode": "Lens Mode", - "Scan Type": "Acquisition Mode", - "T(K)": "sample_temp", - "Pass E": "Pass Energy", - "Analyzer Slit": "Slit Plate", - "Polarization": "polarization", + "time": _parse_time, + "type": _determine_kind, + "lens mode": "Lens Mode", + "mode": "Acquisition Mode", + "temperature": "sample_temp", + "pass energy": "Pass Energy", + "analyzer slit": "Slit Plate", + "pol": "polarization", "hv": "hv", - "Entrance Slit": "Entrance Slit", - "Exit Slit": "Exit Slit", + "entrance slit": "Entrance Slit", + "exit slit": "Exit Slit", "polar": "beta", "tilt": "xi", "azi": "delta", @@ -102,7 +102,7 @@ class MERLINLoader(LoaderBase): "z": "z", } - summary_sort = "Time" + summary_sort = "time" always_single = False @@ -185,11 +185,6 @@ def post_process(self, data: xr.DataArray) -> xr.DataArray: if "eV" in data.coords: data = data.assign_coords(eV=-data.eV.values) - if "sample_temp" in data.coords: - # Add temperature to attributes, for backwards compatibility - temp = float(data.sample_temp.mean()) - data = data.assign_attrs(sample_temp=temp) - return data def load_live(self, identifier, data_dir): diff --git a/src/erlab/io/plugins/ssrl52.py b/src/erlab/io/plugins/ssrl52.py index 611aeea..982bdd6 100644 --- a/src/erlab/io/plugins/ssrl52.py +++ b/src/erlab/io/plugins/ssrl52.py @@ -59,26 +59,24 @@ class SSRL52Loader(LoaderBase): } summary_attrs: ClassVar[dict[str, str | Callable[[xr.DataArray], Any]]] = { - "Time": "CreationTimeStamp", - "Type": "Description", - "Lens Mode": "LensModeName", - "Region": "RegionName", - "T(K)": "sample_temp", - "Pass E": "PassEnergy", - "Polarization": "polarization", + "time": "CreationTimeStamp", + "type": "Description", + "lens mode": "LensModeName", + "region": "RegionName", + "temperature": "sample_temp", + "pass energy": "PassEnergy", + "pol": "polarization", "hv": "hv", - # "Entrance Slit": "Entrance Slit", - # "Exit Slit": "Exit Slit", "polar": "chi", "tilt": "xi", "azi": "delta", - "DA": "beta", + "deflector": "beta", "x": "x", "y": "y", "z": "z", } - summary_sort = "Time" + summary_sort = "time" always_single: bool = True skip_validate: bool = True @@ -266,11 +264,6 @@ def load_single( def post_process(self, data: xr.DataArray) -> xr.DataArray: data = super().post_process(data) - if "sample_temp" in data.coords: - # Add temperature to attributes - temp = float(data.sample_temp.mean()) - data = data.assign_attrs(sample_temp=temp) - # Convert to binding energy if ( "sample_workfunction" in data.attrs diff --git a/tests/conftest.py b/tests/conftest.py index 80d8146..46acf2f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,10 +9,10 @@ from erlab.io.exampledata import generate_data_angles, generate_gold_edge -DATA_COMMIT_HASH = "ff4ff7ee6fb60f5b5f2e272ed15ddff0663943ca" +DATA_COMMIT_HASH = "ad7dbdf35ef2404feee0854cb3a52973770709f4" """The commit hash of the commit to retrieve from `kmnhan/erlabpy-data`.""" -DATA_KNOWN_HASH = "ddd9f6dda241945e8e7f8a5a4387513359f6372e5409dbebe5a1067b0556a07d" +DATA_KNOWN_HASH = "43d89ef27482e127e7509b65d635b7a6a0cbf648f84aa597c8b4094bdc0c46ab" """The hash of the `.tar.gz` file.""" diff --git a/tests/io/plugins/test_merlin.py b/tests/io/plugins/test_merlin.py index a900104..f26463f 100644 --- a/tests/io/plugins/test_merlin.py +++ b/tests/io/plugins/test_merlin.py @@ -44,3 +44,16 @@ def test_load_live(expected_dir): def test_summarize(data_dir): erlab.io.summarize() + + +def test_qinfo(data_dir): + data = erlab.io.load(5) + assert ( + data.qinfo.__repr__() + == """time: 2022-03-27 07:53:26\ntype: map\nlens mode (Lens Mode): A30 +mode (Acquisition Mode): Dither\ntemperature (sample_temp): 110.67 +pass energy (Pass Energy): 10\nanalyzer slit (Slit Plate): 7\npol (polarization): LH +hv (hv): 100\nentrance slit (Entrance Slit): 70\nexit slit (Exit Slit): 70 +polar (beta): [-15.5, -15]\ntilt (xi): 0\nazi (delta): 3\nx (x): 2.487\ny (y): 0.578 +z (z): -1.12""" + )