From b2767564b6abc1883881402b3c8f0df67d304260 Mon Sep 17 00:00:00 2001 From: woutdenolf Date: Sun, 21 Apr 2024 16:58:29 +0200 Subject: [PATCH] WIP --- doc/howtoguides.rst | 7 ++ doc/howtoguides/convert_files.rst | 8 ++ doc/howtoguides/install.rst | 6 ++ doc/index.rst | 3 + doc/tutorials.rst | 6 ++ doc/tutorials/models.rst | 34 ++++++ setup.cfg | 4 +- src/pynxxas/apps/__init__.py | 2 + src/pynxxas/apps/nxxas_convert.py | 90 ++++++++++++++-- src/pynxxas/io/__init__.py | 30 ++++++ src/pynxxas/io/hdf5_utils.py | 41 +++++++ src/pynxxas/io/nexus.py | 144 +++++++++++++++++++++++++ src/pynxxas/io/url_utils.py | 44 ++++++++ src/pynxxas/io/xdi.py | 51 +++++---- src/pynxxas/models/__init__.py | 7 ++ src/pynxxas/models/convert/__init__.py | 26 +++++ src/pynxxas/models/convert/nexus.py | 9 ++ src/pynxxas/models/convert/xdi.py | 25 +++++ src/pynxxas/models/nexus.py | 83 ++++++++++++++ src/pynxxas/models/units.py | 5 +- src/pynxxas/models/xdi.py | 21 ++-- src/pynxxas/tests/conftest.py | 20 ++++ src/pynxxas/tests/test_convert.py | 54 ++++++++++ src/pynxxas/tests/test_nexus.py | 71 ++++++++++++ src/pynxxas/tests/test_xdi.py | 24 +++-- 25 files changed, 766 insertions(+), 49 deletions(-) create mode 100644 doc/howtoguides.rst create mode 100644 doc/howtoguides/convert_files.rst create mode 100644 doc/howtoguides/install.rst create mode 100644 doc/tutorials.rst create mode 100644 doc/tutorials/models.rst create mode 100644 src/pynxxas/io/hdf5_utils.py create mode 100644 src/pynxxas/io/nexus.py create mode 100644 src/pynxxas/io/url_utils.py create mode 100644 src/pynxxas/models/convert/__init__.py create mode 100644 src/pynxxas/models/convert/nexus.py create mode 100644 src/pynxxas/models/convert/xdi.py create mode 100644 src/pynxxas/models/nexus.py create mode 100644 src/pynxxas/tests/test_convert.py create mode 100644 src/pynxxas/tests/test_nexus.py diff --git a/doc/howtoguides.rst b/doc/howtoguides.rst new file mode 100644 index 0000000..b9d659f --- /dev/null +++ b/doc/howtoguides.rst @@ -0,0 +1,7 @@ +How-to Guides +============= + +.. toctree:: + + howtoguides/install + howtoguides/convert_files diff --git a/doc/howtoguides/convert_files.rst b/doc/howtoguides/convert_files.rst new file mode 100644 index 0000000..4a661c0 --- /dev/null +++ b/doc/howtoguides/convert_files.rst @@ -0,0 +1,8 @@ +Convert file formats +==================== + +Convert all files in the *xdi_files* and *xas_beamline_data* to *HDF5/NeXus* format + +.. code-block:: bash + + nxxas-convert xdi_files/*.* xas_beamline_data/*.* data.h5 diff --git a/doc/howtoguides/install.rst b/doc/howtoguides/install.rst new file mode 100644 index 0000000..69956a9 --- /dev/null +++ b/doc/howtoguides/install.rst @@ -0,0 +1,6 @@ +Install +======= + +.. code-block:: bash + + pip install pynxxas diff --git a/doc/index.rst b/doc/index.rst index 8ccdcb8..b75796f 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -6,4 +6,7 @@ Library for reading and writing XAS data in `NeXus format =3.8 install_requires = + typing_extensions; python_version < "3.9" + strenum; python_version < "3.11" numpy h5py pydantic >=2.6 pint - typing_extensions; python_version < "3.9" + periodictable [options.packages.find] where=src diff --git a/src/pynxxas/apps/__init__.py b/src/pynxxas/apps/__init__.py index e69de29..2c9e51d 100644 --- a/src/pynxxas/apps/__init__.py +++ b/src/pynxxas/apps/__init__.py @@ -0,0 +1,2 @@ +"""Command-Line Interface (CLI) +""" diff --git a/src/pynxxas/apps/nxxas_convert.py b/src/pynxxas/apps/nxxas_convert.py index 751be61..0b7b917 100644 --- a/src/pynxxas/apps/nxxas_convert.py +++ b/src/pynxxas/apps/nxxas_convert.py @@ -1,9 +1,14 @@ import sys +import pathlib import logging import argparse from glob import glob -from ..io.xdi import read_xdi +from .. import io +from .. import models +from ..models import convert + +logger = logging.getLogger(__name__) def main(argv=None): @@ -14,20 +19,89 @@ def main(argv=None): prog="nxxas_convert", description="Convert data to NXxas format" ) - parser.add_argument("--output", type=str, default=None, help="Path to HDF5 file") parser.add_argument( - "patterns", + "--output-format", type=str, - nargs="+", - help="Glob file name patterns", + default="nexus", + choices=list(models.MODELS), + help="Output format", + ) + + parser.add_argument( + "file_patterns", + type=str, + nargs="*", + help="Files to convert", + ) + + parser.add_argument( + "output_filename", type=pathlib.Path, help="Convert destination filename" ) args = parser.parse_args(argv[1:]) logging.basicConfig() - for pattern in args.patterns: - for filename in glob(pattern): - read_xdi(filename) + output_filename = args.output_filename + model_type = models.MODELS[args.output_format] + + if output_filename.exists(): + result = input(f"Overwrite {output_filename}? (y/[n])") + if not result.lower() in ("y", "yes"): + return 0 + output_filename.unlink() + + filenames = list() + for pattern in args.file_patterns: + filenames.extend(glob(pattern)) + + ndigitsfile = len(str(len(filenames) - 1)) + return_code = 0 + for file_number, filename in enumerate(filenames, 1): + try: + models_in = io.load_models(filename) + except NotImplementedError as e: + return_code = 1 + logger.warning("Error when loading '%s': %s", filename, e) + continue + except Exception: + return_code = 1 + logger.error("Error when loading '%s'", filename, exc_info=True) + continue + + try: + models_out = [convert.convert_model(m, model_type) for m in models_in] + except NotImplementedError as e: + return_code = 1 + logger.warning("Error when loading '%s': %s", filename, e) + continue + except Exception: + return_code = 1 + logger.error("Error when converting '%s'", filename, exc_info=True) + continue + + ndigitsscan = len(str(len(models_out) - 1)) + for scan_number, model_out in enumerate(models_out, 1): + if args.output_format == "nexus": + url = f"{output_filename}?path=/{file_number}.{scan_number}" + else: + nrfmt = f"{{:0{ndigitsfile}d}}_{{:0{ndigitsscan}d}}" + scan_number = nrfmt.format(file_number, scan_number) + url = ( + output_filename.parent + / f"{output_filename.stem}_{scan_number:02}{output_filename.suffix}" + ) + + try: + io.save_model(model_out, url) + except NotImplementedError as e: + return_code = 1 + logger.warning("Error when saving '%s' in '%s': %s", filename, url, e) + except Exception: + return_code = 1 + logger.error( + "Error when saving '%s' in '%s'", filename, url, exc_info=True + ) + return return_code if __name__ == "__main__": diff --git a/src/pynxxas/io/__init__.py b/src/pynxxas/io/__init__.py index e69de29..bd2ccd7 100644 --- a/src/pynxxas/io/__init__.py +++ b/src/pynxxas/io/__init__.py @@ -0,0 +1,30 @@ +"""File formats +""" + +from typing import List + +import pydantic + +from .url_utils import UrlType +from . import xdi +from . import nexus +from .. import models + + +def load_models(url: UrlType) -> List[pydantic.BaseModel]: + if xdi.is_xdi_file(url): + return xdi.load_xdi_file(url) + if nexus.is_nexus_file(url): + return nexus.load_nexus_file(url) + raise NotImplementedError(f"File format not supported: {url}") + + +def save_model(model_instance: pydantic.BaseModel, url: UrlType) -> None: + if isinstance(model_instance, models.NxXasModel): + nexus.save_nexus_file(model_instance, url) + elif isinstance(model_instance, models.XdiModel): + xdi.save_xdi_file(model_instance, url) + else: + raise NotImplementedError( + f"Saving of {type(model_instance).__name__} not implemented" + ) diff --git a/src/pynxxas/io/hdf5_utils.py b/src/pynxxas/io/hdf5_utils.py new file mode 100644 index 0000000..80fda85 --- /dev/null +++ b/src/pynxxas/io/hdf5_utils.py @@ -0,0 +1,41 @@ +import os +from typing import Optional, Union + +import h5py + + +def create_hdf5_link( + h5group: h5py.Group, + target_name: str, + target_filename: Optional[str], + absolute: bool = False, +) -> Union[h5py.SoftLink, h5py.ExternalLink]: + """Create HDF5 soft link (supports relative down paths) or external link (supports relative paths).""" + this_name = h5group.name + this_filename = h5group.file.filename + + target_filename = target_filename or this_filename + + if os.path.isabs(target_filename): + rel_target_filename = os.path.relpath(target_filename, this_filename) + else: + rel_target_filename = target_filename + target_filename = os.path.abs(os.path.join(this_filename, target_filename)) + + if "." not in target_name: + rel_target_name = os.path.relpath(target_name, this_name) + else: + rel_target_name = target_name + target_name = os.path.abspath(os.path.join(this_name, target_name)) + + # Internal link + if rel_target_filename == ".": + if absolute or ".." in rel_target_name: + # h5py.SoftLink does not support relative links upwards + return h5py.SoftLink(target_name) + return h5py.SoftLink(rel_target_name) + + # External link + if absolute: + return h5py.ExternalLink(target_filename, target_name) + return h5py.ExternalLink(rel_target_filename, target_name) diff --git a/src/pynxxas/io/nexus.py b/src/pynxxas/io/nexus.py new file mode 100644 index 0000000..4097179 --- /dev/null +++ b/src/pynxxas/io/nexus.py @@ -0,0 +1,144 @@ +"""NeXus/HDF5 file format +""" + +from typing import List, Any + +try: + from enum import StrEnum +except ImportError: + from strenum import StrEnum + + +import h5py +import pint + +from . import url_utils +from . import hdf5_utils +from ..models import nexus + + +def is_nexus_file(url: url_utils.UrlType) -> bool: + filename = url_utils.as_url(url).path + with open(filename, "rb") as file: + try: + with h5py.File(file, mode="r"): + return True + except Exception: + return False + + +def load_nexus_file(url: url_utils.UrlType) -> List[nexus.NxGroup]: + raise NotImplementedError(f"File format not supported: {url}") + + +def save_nexus_file(nxgroup: nexus.NxXasModel, url: url_utils.UrlType) -> None: + if not isinstance(nxgroup, nexus.NxXasModel): + raise TypeError(f"nxgroup is not of type NxXasModel ({type(nxgroup)})") + if not nxgroup.has_data(): + return + filename = url_utils.as_url(url).path + url = url_utils.as_url(url) + + with h5py.File(filename, mode="a", track_order=True) as nxroot: + nxparent = _prepare_nxparent(nxgroup, url, nxroot) + _save_nxcontent(nxgroup, nxparent) + + +def _save_nxcontent(nxgroup: nexus.NxGroup, nxparent: h5py.Group) -> None: + if not isinstance(nxgroup, nexus.NxGroup): + raise TypeError(f"nxgroup is not of type NxGroup ({type(nxgroup)})") + for field_name, field in nxgroup.__fields__.items(): + field_value = getattr(nxgroup, field_name) + if field_value is None: + continue + elif isinstance(field_value, nexus.NxGroup): + nxchild = nxparent.require_group(field_name) + _save_nxcontent(field_value, nxchild) + if isinstance(field_value, nexus.NxDataModel): + _set_default(nxchild) + elif field.alias and field.alias.startswith("@"): + try: + _save_attribute(nxparent, field_name, field_value) + except Exception as e: + raise ValueError( + f"{field_name} = {field_value} ({type(field_value)}) cannot be saved as an HDF5 attribute" + ) from e + else: + try: + _save_dataset(nxparent, field_name, field_value) + except Exception as e: + raise ValueError( + f"{field_name} = {field_value} ({type(field_value)}) cannot be saved as an HDF5 dataset" + ) from e + + +def _save_dataset(nxparent: h5py.Group, field_name: str, field_value: Any) -> None: + if isinstance(field_value, StrEnum): + nxparent[field_name] = str(field_value) + elif isinstance(field_value, pint.Quantity): + if field_value.size: + nxparent[field_name] = field_value.magnitude + units = str(field_value.units) + if units: + nxparent.attrs["units"] = units + elif isinstance(field_value, nexus.NxLinkModel): + link = hdf5_utils.create_hdf5_link( + nxparent, field_value.target_name, field_value.target_filename + ) + nxparent[field_name] = link + else: + nxparent[field_name] = field_value + + +def _save_attribute(nxparent: h5py.Group, field_name: str, field_value: Any) -> None: + if isinstance(field_value, StrEnum): + nxparent.attrs[field_name] = str(field_value) + else: + nxparent.attrs[field_name] = field_value + + +def _set_default(h5group: h5py.Group) -> None: + while h5group.name != "/": + h5group.parent.attrs["default"] = h5group.name.split("/")[-1] + h5group = h5group.parent + + +def _prepare_nxparent( + nxgroup: nexus.NxGroup, + url: url_utils.ParsedUrlType, + nxroot: h5py.File, +) -> h5py.Group: + """Creates and returns the parent group of `nxgroup`""" + internal_path = url_utils.as_url(url).internal_path + parts = [s for s in internal_path.split("/") if s] + nparts = len(parts) + + if nxgroup.NX_class == "NXroot": + if nparts != 0: + raise ValueError( + f"NXroot URL cannot have an internal path ({internal_path})" + ) + nxclasses = [] + elif nxgroup.NX_class == "NXentry": + if nparts != 1: + raise ValueError( + f"NXentry URL must have an internal path of 1 level deep ({internal_path})" + ) + nxclasses = ["NXentry"] + elif nxgroup.NX_class == "NXsubentry": + if nparts != 2: + raise ValueError( + f"NXsubentry URL must have an internal path of 2 levels deep ({internal_path})" + ) + nxclasses = ["NXentry", "NXsubentry"] + else: + nxclasses = ["NXentry"] + ["NXsubentry"] * (len(parts) - 1) + + nxroot.attrs.setdefault("NX_class", "NXroot") + + nxparent = nxroot + for part, nxclass in zip(parts, nxclasses): + nxparent = nxparent.require_group(part) + nxparent.attrs.setdefault("NX_class", nxclass) + + return nxparent diff --git a/src/pynxxas/io/url_utils.py b/src/pynxxas/io/url_utils.py new file mode 100644 index 0000000..82a414a --- /dev/null +++ b/src/pynxxas/io/url_utils.py @@ -0,0 +1,44 @@ +import os +import sys +import pathlib +import urllib.parse +import urllib.request +from typing import Union, NamedTuple + + +class ParsedUrlType(NamedTuple): + path: str + internal_path: str + + +UrlType = Union[str, pathlib.Path, urllib.parse.ParseResult, ParsedUrlType] + + +_WIN32 = sys.platform == "win32" + + +def as_url(url: UrlType) -> ParsedUrlType: + if isinstance(url, ParsedUrlType): + return url + + if isinstance(url, urllib.parse.ParseResult): + parsed = url + else: + url_str = str(url) + parsed = urllib.parse.urlparse(url_str) + if not parsed.scheme or (_WIN32 and len(parsed.scheme) == 1): + url_str = "file://" + os.path.abspath(url_str).replace("\\", "/") + parsed = urllib.parse.urlparse(url_str) + + if parsed.scheme != "file": + raise ValueError("URL is not a file") + + if parsed.netloc: + path = f"{parsed.netloc}{parsed.path}" + else: + path = parsed.path + + query = urllib.parse.parse_qs(parsed.query) + internal_path = query.get("path", [""])[0] + + return ParsedUrlType(path=path, internal_path=internal_path) diff --git a/src/pynxxas/io/xdi.py b/src/pynxxas/io/xdi.py index d172d9c..66b427a 100644 --- a/src/pynxxas/io/xdi.py +++ b/src/pynxxas/io/xdi.py @@ -1,29 +1,37 @@ +"""XAS Data Interchange (XDI) file format +""" + import re -import pathlib import datetime -from typing import Union, Tuple, Optional +from typing import Union, Tuple, Optional, List import pint import numpy +from . import url_utils from ..models import units from ..models.xdi import XdiModel -def is_xdi_file(filename: Union[str, pathlib.Path]) -> bool: +def is_xdi_file(url: url_utils.UrlType) -> bool: + filename = url_utils.as_url(url).path with open(filename, "r") as file: - for line in file: - line = line.strip() - if not line: - continue - return line.startswith("# XDI") + try: + for line in file: + line = line.strip() + if not line: + continue + return line.startswith("# XDI") + except Exception: + return False -def read_xdi(filename: Union[str, pathlib.Path]) -> XdiModel: +def load_xdi_file(url: url_utils.UrlType) -> List[XdiModel]: """Specs described in https://github.com/XraySpectroscopy/XAS-Data-Interchange/blob/master/specification/spec.md """ + filename = url_utils.as_url(url).path content = {"comments": [], "column": dict(), "data": dict()} with open(filename, "r") as file: @@ -78,10 +86,8 @@ def read_xdi(filename: Union[str, pathlib.Path]) -> XdiModel: key = _parse_xdi_value(key) content[key] = value - # Data - table = numpy.loadtxt(file, dtype=float) - - # Parse data in dictionary of pint Quantity objects + # Data + table = numpy.loadtxt(filename, dtype=float) columns = [ name for _, name in sorted(content.pop("column").items(), key=lambda tpl: tpl[0]) @@ -90,7 +96,13 @@ def read_xdi(filename: Union[str, pathlib.Path]) -> XdiModel: name, quant = _parse_xdi_column_name(name) content["data"][name] = array, quant - return XdiModel(**content) + return [XdiModel(**content)] + + +def save_xdi_file(model_instance: XdiModel, url: url_utils.UrlType) -> None: + raise NotImplementedError( + f"Saving of {type(model_instance).__name__} not implemented" + ) _XDI_FIELD_REGEX = re.compile(r"#\s*([\w.]+):\s*(.*)") @@ -134,10 +146,13 @@ def _parse_xdi_value( def _parse_xdi_column_name( name: str, -) -> Union[Tuple[str, Optional[pint.Unit]]]: +) -> Union[Tuple[str, Optional[str]]]: parts = _SPACES_REGEX.split(name) if len(parts) == 1: return name, None - if len(parts) == 2: - return tuple(parts) - raise ValueError(f"XDI column name '{name}' is not valid") + try: + units.as_units(parts[-1]) + except pint.UndefinedUnitError: + return name, None + name = " ".join(parts[:-1]) + return name, parts[-1] diff --git a/src/pynxxas/models/__init__.py b/src/pynxxas/models/__init__.py index e69de29..9ef6a64 100644 --- a/src/pynxxas/models/__init__.py +++ b/src/pynxxas/models/__init__.py @@ -0,0 +1,7 @@ +"""Data models +""" + +from .xdi import XdiModel +from .nexus import NxXasModel + +MODELS = {"xdi": XdiModel, "nexus": NxXasModel} diff --git a/src/pynxxas/models/convert/__init__.py b/src/pynxxas/models/convert/__init__.py new file mode 100644 index 0000000..1f3a590 --- /dev/null +++ b/src/pynxxas/models/convert/__init__.py @@ -0,0 +1,26 @@ +from typing import Type +import pydantic + +from . import xdi +from . import nexus +from .. import XdiModel +from .. import NxXasModel + + +def convert_model( + instance: pydantic.BaseModel, model_type: Type[pydantic.BaseModel] +) -> pydantic.BaseModel: + if isinstance(instance, model_type): + return instance + + mod_to = _CONVERT_MODULE.get(type(instance)) + mod_from = _CONVERT_MODULE.get(model_type) + if mod_to is None or mod_from is None: + raise NotImplementedError( + f"Conversion from {type(instance).__name__} to {model_type.__name__} is not implemented" + ) + + return mod_from.from_nxxas(mod_to.to_nxxas(instance)) + + +_CONVERT_MODULE = {XdiModel: xdi, NxXasModel: nexus} diff --git a/src/pynxxas/models/convert/nexus.py b/src/pynxxas/models/convert/nexus.py new file mode 100644 index 0000000..4cd8b99 --- /dev/null +++ b/src/pynxxas/models/convert/nexus.py @@ -0,0 +1,9 @@ +from .. import NxXasModel + + +def to_nxxas(nxxas_model: NxXasModel) -> NxXasModel: + return nxxas_model + + +def from_nxxas(nxxas_model: NxXasModel) -> NxXasModel: + return nxxas_model diff --git a/src/pynxxas/models/convert/xdi.py b/src/pynxxas/models/convert/xdi.py new file mode 100644 index 0000000..11166e6 --- /dev/null +++ b/src/pynxxas/models/convert/xdi.py @@ -0,0 +1,25 @@ +from .. import XdiModel +from .. import NxXasModel + + +def to_nxxas(xdi_model: XdiModel) -> NxXasModel: + nxxas_model = NxXasModel( + element=xdi_model.element.symbol, + absorption_edge=xdi_model.element.edge, + mode="transmission", + ) + nxxas_model.energy = xdi_model.data.energy + if xdi_model.data.mutrans is not None: + nxxas_model.intensity = xdi_model.data.mutrans + nxxas_model.mode = "transmission" + return nxxas_model + + +def from_nxxas(nxxas_model: NxXasModel) -> XdiModel: + xdi_model = XdiModel() + xdi_model.element.symbol = nxxas_model.element + xdi_model.element.edge = nxxas_model.absorption_edge + xdi_model.data.energy = nxxas_model.energy + if nxxas_model.mode == "transmission": + xdi_model.data.mutrans = nxxas_model.intensity + return xdi_model diff --git a/src/pynxxas/models/nexus.py b/src/pynxxas/models/nexus.py new file mode 100644 index 0000000..2520223 --- /dev/null +++ b/src/pynxxas/models/nexus.py @@ -0,0 +1,83 @@ +"""NeXus data model_instance +""" + +from typing import Dict, Literal, List, Optional + +try: + from enum import StrEnum +except ImportError: + from strenum import StrEnum + +import pydantic +import periodictable + +from . import units + + +class NxGroup(pydantic.BaseModel, extra="allow"): + pass + + +class NxClass: + _NXCLASSES: Dict[str, "NxClass"] = dict() + + def __init_subclass__(cls, nx_class: str, **kwargs): + super().__init_subclass__(**kwargs) + NxClass._NXCLASSES[nx_class] = cls + + +class NxLinkModel(pydantic.BaseModel): + target_name: str + target_filename: Optional[str] = None + + +class NxDataModel(NxClass, NxGroup, nx_class="NxData"): + NX_class: Literal["NXdata"] = pydantic.Field(default="NXdata", alias="@NX_class") + signal: Literal["intensity"] = pydantic.Field(default="intensity", alias="@signal") + axes: List[str] = pydantic.Field(default=["energy"], alias="@axes") + energy: NxLinkModel + intensity: NxLinkModel + + +class NxEntryClass(StrEnum): + NXentry = "NXentry" + NXsubentry = "NXsubentry" + + +class NxXasMode(StrEnum): + transmission = "transmission" + fluorescence_yield = "fluorescence yield" + + +ChemicalElement = StrEnum( + "ChemicalElement", {el.symbol: el.symbol for el in periodictable.elements} +) + +XRayCoreExcitationState = StrEnum( + "XRayCoreExcitationState", {s: s for s in ("K", "L1", "L2", "L3")} +) + + +class NxXasModel(NxClass, NxGroup, nx_class="NXxas"): + NX_class: NxEntryClass = pydantic.Field(alias="@NX_class", default="NXentry") + definition: Literal["NXxas"] = "NXxas" + mode: NxXasMode + element: ChemicalElement + absorption_edge: XRayCoreExcitationState + energy: units.PydanticQuantity = units.as_quantity([]) + intensity: units.PydanticQuantity = units.as_quantity([]) + title: Optional[str] = None + plot: Optional[NxDataModel] = None + + @pydantic.model_validator(mode="after") + def set_title(self) -> "NxXasModel": + if self.element is not None and self.absorption_edge is not None: + self.title = f"{self.element} {self.absorption_edge}" + if self.plot is None: + energy = NxLinkModel(target_name="../energy") + intensity = NxLinkModel(target_name="../intensity") + self.plot = NxDataModel(energy=energy, intensity=intensity) + return self + + def has_data(self) -> bool: + return bool(self.energy.size and self.intensity.size) diff --git a/src/pynxxas/models/units.py b/src/pynxxas/models/units.py index 9c13f48..82cecb3 100644 --- a/src/pynxxas/models/units.py +++ b/src/pynxxas/models/units.py @@ -43,8 +43,9 @@ def __get_pydantic_core_schema__( _source_type: Any, _handler: pydantic.GetCoreSchemaHandler, ) -> core_schema.CoreSchema: - def serialize(value: pint.Quantity) -> List: - return list(value.to_tuple()) + def serialize(value: Any) -> List: + value = as_quantity(value) + return [value.magnitude.tolist(), str(value.units)] json_schema = core_schema.chain_schema( [ diff --git a/src/pynxxas/models/xdi.py b/src/pynxxas/models/xdi.py index cc63d79..7353d07 100644 --- a/src/pynxxas/models/xdi.py +++ b/src/pynxxas/models/xdi.py @@ -1,3 +1,6 @@ +"""XAS Data Interchange (XDI) data model_instance +""" + import datetime from typing import Optional, List, Any, Mapping @@ -110,12 +113,12 @@ class XdiData(XdiBaseModel): class XdiModel(XdiBaseModel): - element: Optional[XdiElementNamespace] = None - scan: Optional[XdiScanNamespace] = None - mono: Optional[XdiMonoNamespace] = None - beamline: Optional[XdiBeamlineNamespace] = None - facility: Optional[XdiFacilityNamespace] = None - detector: Optional[XdiDetectorNamespace] = None - sample: Optional[XdiSampleNamespace] = None - comments: Optional[List[str]] = None - data: Optional[XdiData] = None + element: XdiElementNamespace = XdiElementNamespace() + scan: XdiScanNamespace = XdiScanNamespace() + mono: XdiMonoNamespace = XdiMonoNamespace() + beamline: XdiBeamlineNamespace = XdiBeamlineNamespace() + facility: XdiFacilityNamespace = XdiFacilityNamespace() + detector: XdiDetectorNamespace = XdiDetectorNamespace() + sample: XdiSampleNamespace = XdiSampleNamespace() + comments: List[str] = list() + data: XdiData = XdiData() diff --git a/src/pynxxas/tests/conftest.py b/src/pynxxas/tests/conftest.py index a555a75..1d34399 100644 --- a/src/pynxxas/tests/conftest.py +++ b/src/pynxxas/tests/conftest.py @@ -1,4 +1,6 @@ import pytest +from ..models import NxXasModel +from ..io.xdi import load_xdi_file @pytest.fixture() @@ -9,6 +11,24 @@ def xdi_file(tmp_path): return filename +@pytest.fixture() +def xdi_model(xdi_file): + return load_xdi_file(xdi_file)[0] + + +@pytest.fixture() +def nxxas_model(): + return NxXasModel(**_NXXAS_CONTENT) + + +_NXXAS_CONTENT = { + "element": "Co", + "absorption_edge": "K", + "mode": "transmission", + "energy": [[7509, 7519], "eV"], + "intensity": [[-0.51329170, -0.78493490], ""], +} + _XDI_CONTENT = """ # XDI/1.0 GSE/1.0 # Column.1: energy eV diff --git a/src/pynxxas/tests/test_convert.py b/src/pynxxas/tests/test_convert.py new file mode 100644 index 0000000..35e2926 --- /dev/null +++ b/src/pynxxas/tests/test_convert.py @@ -0,0 +1,54 @@ +from .. import models +from ..models import convert + + +def test_xdi_to_xdi(xdi_model): + xdi_model = convert.convert_model(xdi_model, models.XdiModel) + _assert_model(xdi_model) + + +def test_nxxas_to_nxxas(nxxas_model): + nxxas_model = convert.convert_model(nxxas_model, models.NxXasModel) + _assert_model(nxxas_model) + + +def test_xdi_to_nexus(xdi_model): + nxxas_model = convert.convert_model(xdi_model, models.NxXasModel) + _assert_model(nxxas_model) + + +def test_nexus_to_xdi(nxxas_model): + xdi_model = convert.convert_model(nxxas_model, models.XdiModel) + _assert_model(xdi_model) + + +def _assert_xdi_model(xdi_model: models.XdiModel): + xdi_model.element.symbol = "Co" + assert str(xdi_model.data.energy.units) == "eV" + + assert xdi_model.data.energy.magnitude.tolist() == [7509, 7519] + assert str(xdi_model.data.energy.units) == "eV" + + assert xdi_model.data.mutrans.magnitude.tolist() == [-0.51329170, -0.78493490] + assert str(xdi_model.data.mutrans.units) == "" + + +def _assert_nxxas_model(xdi_model: models.NxXasModel): + xdi_model.element = "Co" + assert str(xdi_model.energy.units) == "eV" + + assert xdi_model.energy.magnitude.tolist() == [7509, 7519] + assert str(xdi_model.energy.units) == "eV" + + assert xdi_model.intensity.magnitude.tolist() == [-0.51329170, -0.78493490] + assert str(xdi_model.intensity.units) == "" + + +_ASSERT_MODEL = { + models.XdiModel: _assert_xdi_model, + models.NxXasModel: _assert_nxxas_model, +} + + +def _assert_model(model_instance): + _ASSERT_MODEL[type(model_instance)](model_instance) diff --git a/src/pynxxas/tests/test_nexus.py b/src/pynxxas/tests/test_nexus.py new file mode 100644 index 0000000..401a6ab --- /dev/null +++ b/src/pynxxas/tests/test_nexus.py @@ -0,0 +1,71 @@ +from ..models import NxXasModel + + +def test_nxxas(): + data = { + "@NX_class": "NXsubentry", + "definition": "NXxas", + "mode": "transmission", + "element": "Fe", + "absorption_edge": "K", + "energy": [[7, 7.1], "keV"], + "intensity": [10, 20], + } + model_instance = NxXasModel(**data) + + expected = _expected_content("NXsubentry", [[7, 7.1], "keV"], [[10, 20], ""]) + assert model_instance.model_dump() == expected + + +def test_nxxas_defaults(): + data = { + "mode": "transmission", + "element": "Fe", + "absorption_edge": "K", + } + model_instance = NxXasModel(**data) + + expected = _expected_content("NXentry", [[], ""], [[], ""]) + assert model_instance.model_dump() == expected + + +def test_nxxas_fill_data(): + data = { + "mode": "transmission", + "element": "Fe", + "absorption_edge": "K", + } + model_instance = NxXasModel(**data) + model_instance.energy = [7, 7.1], "keV" + model_instance.intensity = [10, 20] + + expected = _expected_content("NXentry", [[7, 7.1], "keV"], [[10, 20], ""]) + assert model_instance.model_dump() == expected + + +def _expected_content(nx_class, energy, intensity): + return { + "NX_class": nx_class, + "definition": "NXxas", + "mode": "transmission", + "element": "Fe", + "absorption_edge": "K", + "energy": energy, + "intensity": intensity, + "plot": { + "NX_class": "NXdata", + "axes": [ + "energy", + ], + "energy": { + "target_filename": None, + "target_name": "../energy", + }, + "intensity": { + "target_filename": None, + "target_name": "../intensity", + }, + "signal": "intensity", + }, + "title": "Fe K", + } diff --git a/src/pynxxas/tests/test_xdi.py b/src/pynxxas/tests/test_xdi.py index 38db020..79d53fb 100644 --- a/src/pynxxas/tests/test_xdi.py +++ b/src/pynxxas/tests/test_xdi.py @@ -5,12 +5,14 @@ def test_is_xdi(xdi_file): assert xdi.is_xdi_file(xdi_file) -def test_read_xdi(xdi_file): - model = xdi.read_xdi(xdi_file) +def test_load_xdi_file(xdi_file): + models = xdi.load_xdi_file(xdi_file) + assert len(models) == 1 + model_instance = models[0] # Fields - assert model.facility.energy.magnitude == 7 - assert str(model.facility.energy.units) == "GeV" + assert model_instance.facility.energy.magnitude == 7 + assert str(model_instance.facility.energy.units) == "GeV" # User ccomments comments = [ @@ -18,14 +20,14 @@ def test_read_xdi(xdi_file): "measured at beamline 13-ID-C", "vert slits = 0.3 x 0.3mm (at ~50m)", ] - assert model.comments == comments + assert model_instance.comments == comments # XAS data - assert model.data.energy.magnitude.tolist() == [7509, 7519] - assert str(model.data.energy.units) == "eV" + assert model_instance.data.energy.magnitude.tolist() == [7509, 7519] + assert str(model_instance.data.energy.units) == "eV" - assert model.data.mutrans.magnitude.tolist() == [-0.51329170, -0.78493490] - assert str(model.data.mutrans.units) == "" + assert model_instance.data.mutrans.magnitude.tolist() == [-0.51329170, -0.78493490] + assert str(model_instance.data.mutrans.units) == "" - assert model.data.i0.magnitude.tolist() == [165872.70, 161255.70] - assert str(model.data.i0.units) == "" + assert model_instance.data.i0.magnitude.tolist() == [165872.70, 161255.70] + assert str(model_instance.data.i0.units) == ""