Skip to content

Commit

Permalink
feat: add qinfo accessor
Browse files Browse the repository at this point in the history
Adds a `qinfo` accessor that prints a table summarizing the data in a human readable format. Closes #27
  • Loading branch information
kmnhan committed Nov 5, 2024
1 parent e56163b commit eb3a742
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 55 deletions.
59 changes: 59 additions & 0 deletions src/erlab/accessors/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
"SelectionAccessor",
]

import functools
import importlib
from collections.abc import Hashable, Mapping
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -19,6 +21,7 @@
ERLabDatasetAccessor,
either_dict_or_kwargs,
)
from erlab.utils.formatting import format_html_table
from erlab.utils.misc import emit_user_level_warning


Expand Down Expand Up @@ -504,3 +507,59 @@ def around(
if average:
return masked.mean(sel_kw.keys())
return masked


@xr.register_dataarray_accessor("qinfo")
class InfoDataArrayAccessor(ERLabDataArrayAccessor):
"""`xarray.Dataset.qinfo` accessor for displaying information about the data."""

def get_value(self, attr_or_coord_name: str) -> Any:
"""Get the value of the specified attribute or coordinate.
If the attribute or coordinate is not found, `None` is returned.
Parameters
----------
attr_or_coord_name
The name of the attribute or coordinate.
"""
if attr_or_coord_name in self._obj.attrs:
return self._obj.attrs[attr_or_coord_name]
if attr_or_coord_name in self._obj.coords:
return self._obj.coords[attr_or_coord_name]
return None

@functools.cached_property
def _summary_table(self) -> list[tuple[str, str, str]]:
import erlab.io

if "data_loader_name" in self._obj.attrs:
loader = erlab.io.loaders[self._obj.attrs["data_loader_name"]]
else:
raise ValueError("Data loader information not found in data attributes")

out: list[tuple[str, str, str]] = []

for key, true_key in loader.summary_attrs.items():
val = loader.get_formatted_attr_or_coord(self._obj, true_key)
if callable(true_key):
true_key = ""
out.append((key, loader.value_to_string(val), true_key))

return out

def _repr_html_(self) -> str:
return format_html_table(
[("Name", "Value", "Key"), *self._summary_table],
header_cols=1,
header_rows=1,
)

def __repr__(self) -> str:
return "\n".join(
[
f"{key}: {val}" if not true_key else f"{key} ({true_key}): {val}"
for key, val, true_key in self._summary_table
]
)
14 changes: 8 additions & 6 deletions src/erlab/analysis/fit/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"StepEdgeModel",
]

import contextlib
from typing import Literal

import lmfit
Expand Down Expand Up @@ -178,10 +177,11 @@ def guess(self, data, x, **kwargs):
np.argmin(np.gradient(scipy.ndimage.gaussian_filter1d(data, 0.2 * len(x))))
]

temp = 30.0
temp = None
if isinstance(data, xr.DataArray):
with contextlib.suppress(KeyError):
temp = float(data.attrs["sample_temp"])
temp = data.qinfo.get_value("sample_temp")
if temp is None:
temp = 30.0

pars[f"{self.prefix}center"].set(
value=efermi, min=np.asarray(x).min(), max=np.asarray(x).max()
Expand All @@ -190,7 +190,7 @@ def guess(self, data, x, **kwargs):
pars[f"{self.prefix}back1"].set(value=back1)
pars[f"{self.prefix}dos0"].set(value=dos0)
pars[f"{self.prefix}dos1"].set(value=dos1)
pars[f"{self.prefix}temp"].set(value=temp)
pars[f"{self.prefix}temp"].set(value=float(temp))
pars[f"{self.prefix}resolution"].set(value=0.02)

return lmfit.models.update_param_vals(pars, self.prefix, **kwargs)
Expand Down Expand Up @@ -418,7 +418,9 @@ def guess(self, data, eV, alpha, **kwargs):
pars[f"{self.prefix}lin_bkg"].set(value=dos1)

if isinstance(data, xr.DataArray):
pars[f"{self.prefix}temp"].set(value=data.attrs["sample_temp"])
temp = data.qinfo.get_value("sample_temp")
if temp is not None:
pars[f"{self.prefix}temp"].set(value=float(temp))

return lmfit.models.update_param_vals(pars, self.prefix, **kwargs)

Expand Down
16 changes: 9 additions & 7 deletions src/erlab/analysis/gold.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,13 @@ def edge(
model_cls: lmfit.Model = StepEdgeModel
else:
if temp is None:
temp = gold.attrs["sample_temp"]
params = lmfit.create_params(temp={"value": temp, "vary": vary_temp})
temp = gold.qinfo.get_value("sample_temp")
if temp is None:
raise ValueError(
"Temperature not found in data attributes, please provide manually"
)

params = lmfit.create_params(temp={"value": float(temp), "vary": vary_temp})
model_cls = FermiEdgeModel

model = model_cls()
Expand Down Expand Up @@ -602,11 +607,8 @@ def quick_fit(
data_fit = data.sel(eV=slice(*eV_range)) if eV_range is not None else data

if temp is None:
if "sample_temp" in data.coords:
temp = float(data.coords["sample_temp"])
elif "sample_temp" in data.attrs:
temp = float(data.attrs["sample_temp"])
else:
temp = data.qinfo.get_value("sample_temp")
if temp is None:
raise ValueError(
"Temperature not found in data attributes, please provide manually"
)
Expand Down
6 changes: 3 additions & 3 deletions src/erlab/interactive/fermiedge.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ def __init__(
self.axes[2].setVisible(False)
self.hists[2].setVisible(False)

try:
temp = float(self.data.attrs["sample_temp"])
except KeyError:
temp = self.data.qinfo.get_value("sample_temp")
if temp is None:
temp = 30.0
temp = float(temp)

self.params_roi = ROIControls(self.aw.add_roi(0))
self.params_edge = ParameterGroup(
Expand Down
10 changes: 5 additions & 5 deletions src/erlab/io/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,13 +1544,13 @@ def validate(cls, data: xr.DataArray | xr.Dataset | xr.DataTree) -> None:

for c in ("beta", "delta", "xi", "hv"):
if c not in data.coords:
cls._raise_or_warn(f"Missing coordinate {c}")
cls._raise_or_warn(f"Missing coordinate '{c}'")

for a in ("configuration", "sample_temp"):
if a not in data.attrs:
cls._raise_or_warn(f"Missing attribute {a}")
if data.qinfo.get_value("sample_temp") is None:
cls._raise_or_warn("Missing attribute 'sample_temp'")

if "configuration" not in data.attrs:
cls._raise_or_warn("Missing attribute 'configuration'")
return

if data.attrs["configuration"] not in (1, 2):
Expand All @@ -1559,7 +1559,7 @@ def validate(cls, data: xr.DataArray | xr.Dataset | xr.DataTree) -> None:
f"Invalid configuration {data.attrs['configuration']}"
)
elif "chi" not in data.coords:
cls._raise_or_warn("Missing coordinate chi")
cls._raise_or_warn("Missing coordinate 'chi'")

def load_multiple_parallel(
self,
Expand Down
27 changes: 11 additions & 16 deletions src/erlab/io/plugins/merlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,17 @@ class MERLINLoader(LoaderBase):
}

summary_attrs: ClassVar[dict[str, str | Callable[[xr.DataArray], Any]]] = {
"Time": _parse_time,
"Type": _determine_kind,
"Lens Mode": "Lens Mode",
"Scan Type": "Acquisition Mode",
"T(K)": "sample_temp",
"Pass E": "Pass Energy",
"Analyzer Slit": "Slit Plate",
"Polarization": "polarization",
"time": _parse_time,
"type": _determine_kind,
"lens mode": "Lens Mode",
"mode": "Acquisition Mode",
"temperature": "sample_temp",
"pass energy": "Pass Energy",
"analyzer slit": "Slit Plate",
"pol": "polarization",
"hv": "hv",
"Entrance Slit": "Entrance Slit",
"Exit Slit": "Exit Slit",
"entrance slit": "Entrance Slit",
"exit slit": "Exit Slit",
"polar": "beta",
"tilt": "xi",
"azi": "delta",
Expand All @@ -102,7 +102,7 @@ class MERLINLoader(LoaderBase):
"z": "z",
}

summary_sort = "Time"
summary_sort = "time"

always_single = False

Expand Down Expand Up @@ -185,11 +185,6 @@ def post_process(self, data: xr.DataArray) -> xr.DataArray:
if "eV" in data.coords:
data = data.assign_coords(eV=-data.eV.values)

if "sample_temp" in data.coords:
# Add temperature to attributes, for backwards compatibility
temp = float(data.sample_temp.mean())
data = data.assign_attrs(sample_temp=temp)

return data

def load_live(self, identifier, data_dir):
Expand Down
25 changes: 9 additions & 16 deletions src/erlab/io/plugins/ssrl52.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,24 @@ class SSRL52Loader(LoaderBase):
}

summary_attrs: ClassVar[dict[str, str | Callable[[xr.DataArray], Any]]] = {
"Time": "CreationTimeStamp",
"Type": "Description",
"Lens Mode": "LensModeName",
"Region": "RegionName",
"T(K)": "sample_temp",
"Pass E": "PassEnergy",
"Polarization": "polarization",
"time": "CreationTimeStamp",
"type": "Description",
"lens mode": "LensModeName",
"region": "RegionName",
"temperature": "sample_temp",
"pass energy": "PassEnergy",
"pol": "polarization",
"hv": "hv",
# "Entrance Slit": "Entrance Slit",
# "Exit Slit": "Exit Slit",
"polar": "chi",
"tilt": "xi",
"azi": "delta",
"DA": "beta",
"deflector": "beta",
"x": "x",
"y": "y",
"z": "z",
}

summary_sort = "Time"
summary_sort = "time"

always_single: bool = True
skip_validate: bool = True
Expand Down Expand Up @@ -266,11 +264,6 @@ def load_single(
def post_process(self, data: xr.DataArray) -> xr.DataArray:
data = super().post_process(data)

if "sample_temp" in data.coords:
# Add temperature to attributes
temp = float(data.sample_temp.mean())
data = data.assign_attrs(sample_temp=temp)

# Convert to binding energy
if (
"sample_workfunction" in data.attrs
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

from erlab.io.exampledata import generate_data_angles, generate_gold_edge

DATA_COMMIT_HASH = "ff4ff7ee6fb60f5b5f2e272ed15ddff0663943ca"
DATA_COMMIT_HASH = "ad7dbdf35ef2404feee0854cb3a52973770709f4"
"""The commit hash of the commit to retrieve from `kmnhan/erlabpy-data`."""

DATA_KNOWN_HASH = "ddd9f6dda241945e8e7f8a5a4387513359f6372e5409dbebe5a1067b0556a07d"
DATA_KNOWN_HASH = "43d89ef27482e127e7509b65d635b7a6a0cbf648f84aa597c8b4094bdc0c46ab"
"""The hash of the `.tar.gz` file."""


Expand Down
13 changes: 13 additions & 0 deletions tests/io/plugins/test_merlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,16 @@ def test_load_live(expected_dir):

def test_summarize(data_dir):
erlab.io.summarize()


def test_qinfo(data_dir):
data = erlab.io.load(5)
assert (
data.qinfo.__repr__()
== """time: 2022-03-27 07:53:26\ntype: map\nlens mode (Lens Mode): A30
mode (Acquisition Mode): Dither\ntemperature (sample_temp): 110.67
pass energy (Pass Energy): 10\nanalyzer slit (Slit Plate): 7\npol (polarization): LH
hv (hv): 100\nentrance slit (Entrance Slit): 70\nexit slit (Exit Slit): 70
polar (beta): [-15.5, -15]\ntilt (xi): 0\nazi (delta): 3\nx (x): 2.487\ny (y): 0.578
z (z): -1.12"""
)

0 comments on commit eb3a742

Please sign in to comment.