Skip to content

Commit

Permalink
read XDI as pydantic model
Browse files Browse the repository at this point in the history
  • Loading branch information
woutdenolf committed Apr 21, 2024
1 parent cb88905 commit d6ab4ff
Show file tree
Hide file tree
Showing 12 changed files with 489 additions and 2 deletions.
9 changes: 9 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ package_dir=
packages=find:
python_requires = >=3.8
install_requires =
numpy
h5py
pydantic >=2.6
pint
typing_extensions; python_version < "3.9"

[options.packages.find]
where=src
Expand All @@ -40,6 +45,10 @@ doc =
sphinx-autodoc-typehints >=1.16
pydata-sphinx-theme < 0.15

[options.entry_points]
console_scripts =
nxxas-convert=pynxxas.apps.nxxas_convert:main

# E501 (line too long) ignored for now
# E203 and W503 incompatible with black formatting (https://black.readthedocs.io/en/stable/compatible_configs.html#flake8)
[flake8]
Expand Down
Empty file added src/pynxxas/apps/__init__.py
Empty file.
34 changes: 34 additions & 0 deletions src/pynxxas/apps/nxxas_convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import sys
import logging
import argparse
from glob import glob

from ..io.xdi import read_xdi


def main(argv=None):
if argv is None:
argv = sys.argv

parser = argparse.ArgumentParser(
prog="nxxas_convert", description="Convert data to NXxas format"
)

parser.add_argument("--output", type=str, default=None, help="Path to HDF5 file")
parser.add_argument(
"patterns",
type=str,
nargs="+",
help="Glob file name patterns",
)

args = parser.parse_args(argv[1:])
logging.basicConfig()

for pattern in args.patterns:
for filename in glob(pattern):
read_xdi(filename)


if __name__ == "__main__":
sys.exit(main())
Empty file added src/pynxxas/io/__init__.py
Empty file.
143 changes: 143 additions & 0 deletions src/pynxxas/io/xdi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import re
import pathlib
import datetime
from typing import Union, Tuple, Optional

import pint
import numpy

from ..models import units
from ..models.xdi import XdiModel


def is_xdi_file(filename: Union[str, pathlib.Path]) -> bool:
with open(filename, "r") as file:
for line in file:
line = line.strip()
if not line:
continue
return line.startswith("# XDI")


def read_xdi(filename: Union[str, pathlib.Path]) -> XdiModel:
"""Specs described in
https://github.com/XraySpectroscopy/XAS-Data-Interchange/blob/master/specification/spec.md
"""
content = {"comments": [], "column": dict(), "data": dict()}

with open(filename, "r") as file:
# Version: first non-empty line
for line in file:
line = line.strip()
if not line:
continue
if not line.startswith("# XDI"):
raise ValueError(f"XDI file does not start with '# XDI': '{filename}'")
break

# Fields and comments: lines starting with "#"
is_comment = False
for line in file:
line = line.strip()

if not line.startswith("#"):
raise ValueError(f"Invalid XDI header line: '{line}'")

if _XDI_HEADER_END_REGEX.match(line):
break

if _XDI_FIELDS_END_REGEX.match(line):
# Next lines in the header are user comments
is_comment = True
continue

if is_comment:
match_comment = _XDI_COMMENT_REGEX.match(line)
if not match_comment:
continue
(comment,) = match_comment.groups()
content["comments"].append(comment)
continue

match_namespace = _XDI_FIELD_REGEX.match(line)
if match_namespace:
key, value = match_namespace.groups()
value = _parse_xdi_value(value)
key_parts = key.split(".")
if len(key_parts) > 1:
namespace, key = key_parts
namespace = namespace.lower()
key = key.lower()
key = _parse_xdi_value(key)
if namespace not in content:
content[namespace] = {}
content[namespace][key] = value
else:
key = key_parts[0]
key = _parse_xdi_value(key)
content[key] = value

# Data
table = numpy.loadtxt(file, dtype=float)

# Parse data in dictionary of pint Quantity objects
columns = [
name
for _, name in sorted(content.pop("column").items(), key=lambda tpl: tpl[0])
]
for name, array in zip(columns, table.T):
name, quant = _parse_xdi_column_name(name)
content["data"][name] = array, quant

return XdiModel(**content)


_XDI_FIELD_REGEX = re.compile(r"#\s*([\w.]+):\s*(.*)")
_XDI_COMMENT_REGEX = re.compile(r"#\s*(.*)")
_XDI_HEADER_END_REGEX = re.compile(r"#\s*-")
_XDI_FIELDS_END_REGEX = re.compile(r"#\s*///")
_NUMBER_REGEX = re.compile(r"(?=.)([+-]?([0-9]*)(\.([0-9]+))?)([eE][+-]?\d+)?\s+\w+")
_SPACES_REGEX = re.compile(r"\s+")


def _parse_xdi_value(
value: str,
) -> Union[str, datetime.datetime, pint.Quantity, Tuple[str, pint.Quantity]]:
# Dimensionless integral number
try:
return units.as_quantity(int(value))
except ValueError:
pass

# Dimensionless decimal number
try:
return units.as_quantity(float(value))
except ValueError:
pass

# Date and time
try:
return datetime.datetime.fromisoformat(value)
except ValueError:
pass

# Number with units
if _NUMBER_REGEX.match(value):
try:
return units.as_quantity(value)
except pint.UndefinedUnitError:
pass

return value


def _parse_xdi_column_name(
name: str,
) -> Union[Tuple[str, Optional[pint.Unit]]]:
parts = _SPACES_REGEX.split(name)
if len(parts) == 1:
return name, None
if len(parts) == 2:
return tuple(parts)
raise ValueError(f"XDI column name '{name}' is not valid")
Empty file added src/pynxxas/models/__init__.py
Empty file.
83 changes: 83 additions & 0 deletions src/pynxxas/models/units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pint
import pydantic
from pydantic_core import core_schema
from pydantic.json_schema import JsonSchemaValue

from typing import Any, Sequence, Union, List

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated

_REGISTRY = pint.UnitRegistry()
_REGISTRY.default_format = "~" # unit symbols instead of full unit names


def as_quantity(value: Union[str, pint.Quantity, Sequence]) -> pint.Quantity:
if isinstance(value, pint.Quantity):
return value
if (
isinstance(value, Sequence)
and len(value) == 2
and (isinstance(value[1], str) or value[1] is None)
):
value, units = value
else:
units = None
return _REGISTRY.Quantity(value, units)


def as_units(value: Union[str, pint.Unit]) -> pint.Unit:
if isinstance(value, pint.Unit):
return value
return _REGISTRY.parse_units(value)


class _QuantityPydanticAnnotation:
# https://docs.pydantic.dev/latest/concepts/types/#handling-third-party-types

@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: pydantic.GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
def serialize(value: pint.Quantity) -> List:
return list(value.to_tuple())

json_schema = core_schema.chain_schema(
[
core_schema.no_info_plain_validator_function(as_quantity),
]
)

return core_schema.json_or_python_schema(
json_schema=json_schema,
python_schema=core_schema.union_schema(
[
# check if it's an instance first before doing any further work
core_schema.is_instance_schema(pint.Quantity),
json_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(serialize),
)

@classmethod
def __get_pydantic_json_schema__(
cls,
_core_schema: core_schema.CoreSchema,
handler: pydantic.GetJsonSchemaHandler,
) -> JsonSchemaValue:
return handler(
core_schema.union_schema(
[
core_schema.float_schema(),
core_schema.list_schema(core_schema.float_schema()),
]
)
)


PydanticQuantity = Annotated[pint.Quantity, _QuantityPydanticAnnotation]
Loading

0 comments on commit d6ab4ff

Please sign in to comment.