diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 82ec387c..7af5b7f8 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -46,12 +46,12 @@ jobs:
       - name: Test with pytest
         if: matrix.python-version != '3.12' || matrix.qt-api != 'pyqt6'
         run: |
-          pytest -vv
+          python -u -m pytest -vv --full-trace

       - name: Test with pytest with coverage
         if: matrix.python-version == '3.12' && matrix.qt-api == 'pyqt6'
         run: |
-          pytest -vv --cov erlab --junitxml=junit.xml
+          python -u -m pytest -vv --full-trace --cov erlab --junitxml=junit.xml

       - name: Upload coverage to Codecov
         if: matrix.python-version == '3.12' && matrix.qt-api == 'pyqt6'
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a135e532..dc475c58 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -33,7 +33,7 @@ repos:
   # Lint and format with ruff
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.6.9
+    rev: v0.7.2
     hooks:
       # Run the linter.
       - id: ruff
@@ -43,7 +43,7 @@ repos:

   # Commitizen
   - repo: https://github.com/commitizen-tools/commitizen
-    rev: v3.29.1
+    rev: v3.30.0
     hooks:
       - id: commitizen
         additional_dependencies: [ cz-changeup ]
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b9180b6e..34ec9dc9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,98 @@
+## Unreleased
+
+### Breaking Changes
+
+- Deprecated module `erlab.io.utilities` is removed. Use `erlab.io.utils` instead. ([e189722](https://github.com/kmnhan/erlabpy/commit/e189722f129d55cab0d2ec279e5303929cb09979))
+- Deprecated module `erlab.interactive.utilities` is removed. Use `erlab.interactive.utils` instead. ([af2c81c](https://github.com/kmnhan/erlabpy/commit/af2c81c676455ddfa19ae9bbbbdbdd68d257f26c))
+- Deprecated module `erlab.characterization` is removed. Use `erlab.io.characterization` instead. ([8d770bf](https://github.com/kmnhan/erlabpy/commit/8d770bfe298253c020aeda6d61a9eab625facf6c))
+- Deprecated module `erlab.analysis.utils` is removed. Use `erlab.analysis.transform.shift` and `erlab.analysis.gold.correct_with_edge`. ([0b2ca44](https://github.com/kmnhan/erlabpy/commit/0b2ca44844cc5802d32d9ed949e831b534525183))
+- Deprecated alias `slice_along_path` in `erlab.analysis` is removed. Call from `erlab.analysis.interpolate` instead. ([305832b](https://github.com/kmnhan/erlabpy/commit/305832b7bb18aa3d1fda21f4cd0c0992b174d639))
+- Deprecated aliases `correct_with_edge` and `quick_resolution` in `erlab.analysis` are removed. Call from `erlab.analysis.gold` instead. ([075eaf8](https://github.com/kmnhan/erlabpy/commit/075eaf8cd222044aa5cc0c3459698ab33568958c))
+- Removed deprecated aliases `load_igor_ibw` and `load_igor_pxp`. Use `xarray.open_dataarray` and `xarray.open_dataset` instead. ([7f07ad2](https://github.com/kmnhan/erlabpy/commit/7f07ad2c46f80d48c255d408f3f200ae01930060))
+- The default attribute name for the sample temperature is changed to `sample_temp` from `temp_sample`. This will unfortunately break a lot of code that relies on the key `temp_sample`, but will be easy to refactor with find and replace. ([32e1cd5](https://github.com/kmnhan/erlabpy/commit/32e1cd5fb45bce12cfa83c520e8c61af96a8cb39))
+- All dataloaders must now add a new keyword argument to `load_single`, but implementing it is not mandatory.
+
+  Also, dataloaders that implement summary generation by overriding `generate_summary` must now switch to the new method.
+
+  See the summary generation section in the updated user guide.
+
+  Furthermore, the `isummarize` method is no longer public; code that uses this method should use `summarize` instead.
+
+  The `usecache` argument to the `summarize` method is no longer available, and the cache will be updated whenever it is outdated. ([0f5dab4](https://github.com/kmnhan/erlabpy/commit/0f5dab46e3d3a75fc77908b4072f08aa89059acd))
+
+### Features
+
+- **io.igor:** enable loading experiment files to DataTree ([1835be0](https://github.com/kmnhan/erlabpy/commit/1835be0d08ed899b2edbb06fb442cd9addb40929))
+
+  Added methods to the backend to allow using `xarray.open_datatree` and `xarray.open_groups` with Igor packed experiment files. Closes [#29](https://github.com/kmnhan/erlabpy/issues/29)
+- add `qinfo` accessor ([eb3a742](https://github.com/kmnhan/erlabpy/commit/eb3a74297211aae8f13e6974563e6da819bfbedb))
+
+  Adds a `qinfo` accessor that prints a table summarizing the data in a human-readable format. Closes [#27](https://github.com/kmnhan/erlabpy/issues/27)
+- **interactive.kspace:** pass lattice parameters and colormap info to `ktool` ([6830af3](https://github.com/kmnhan/erlabpy/commit/6830af343326e0367a6dfb016728a6cf1325cf64))
+
+  Added the ability to pass lattice vectors and colormaps to `ktool`.
+- **interactive.kspace:** add circle ROI to ktool ([304e1a5](https://github.com/kmnhan/erlabpy/commit/304e1a53f189ebed9a890680c3499a756c586498))
+
+  Added a button to the visualization tab which creates a circle ROI. The position and radius can be edited by right-clicking on the ROI.
+- **interactive.colors:** add zero center button to right-click colorbar ([c037de1](https://github.com/kmnhan/erlabpy/commit/c037de1f4387c0daf7cc7aa252124f01269bc633))
+- **interactive.imagetool:** add `.ibw` and `.pxt` files to load menu ([73c3afe](https://github.com/kmnhan/erlabpy/commit/73c3afef306109be858d23dbf8511617c5d203dd))
+- **io.dataloader:** allow passing rcParams to interactive summary plot ([a348366](https://github.com/kmnhan/erlabpy/commit/a34836673315fdc9acc0ed52d8e56edc90c18456))
+- **io.dataloader:** implement automatic summary generation ([0f5dab4](https://github.com/kmnhan/erlabpy/commit/0f5dab46e3d3a75fc77908b4072f08aa89059acd))
+
+  It is now much easier to implement a summary generation mechanism. This commit also adds a new keyword argument to `load_single` that can greatly speed up summary generation.
+- **io.dataloader:** support callable objects in `additional_attrs` ([e209499](https://github.com/kmnhan/erlabpy/commit/e209499c8044f0085fda74b7dc491517a695099c))
+
+### Bug Fixes
+
+- **interactive.imagetool:** fix copy cursor value for numpy 2 ([dc19c82](https://github.com/kmnhan/erlabpy/commit/dc19c827c4082989e47b0f8e2d7adda45ad62aaa))
+- **io.dataloader:** retain selected dimension in interactive summary ([9d54f8b](https://github.com/kmnhan/erlabpy/commit/9d54f8b3402767cf15e6cf5ab00ee5a1b766d172))
+- **accessors.general:** keep associated coords in `qsel` when averaging ([03a7b4a](https://github.com/kmnhan/erlabpy/commit/03a7b4a30b4c6a635f904fcab377298b06b86f66))
+- **io.dataloader:** ignore old summary files ([bda95fc](https://github.com/kmnhan/erlabpy/commit/bda95fc1f0aaec73c179fd47258f6fde8056aaf9))
+- **io.plugins.kriss:** fix KRISS ibw file match pattern ([7ced571](https://github.com/kmnhan/erlabpy/commit/7ced57152edb802bd14f831c77494a6f805f5097))
+- **analysis.gold:** retain attributes in `quick_resolution` ([504acdc](https://github.com/kmnhan/erlabpy/commit/504acdc1d7d9b8dcd4613ca97551d78c366f0337))
+- do not require qt libs on initial import ([118ead6](https://github.com/kmnhan/erlabpy/commit/118ead603b89867e56b29932f59bd02b476ab43b))
+
+### Code Refactor
+
+- **io:** remove deprecated module ([e189722](https://github.com/kmnhan/erlabpy/commit/e189722f129d55cab0d2ec279e5303929cb09979))
+- **interactive:** remove deprecated module ([af2c81c](https://github.com/kmnhan/erlabpy/commit/af2c81c676455ddfa19ae9bbbbdbdd68d257f26c))
+- remove deprecated module `erlab.characterization` ([8d770bf](https://github.com/kmnhan/erlabpy/commit/8d770bfe298253c020aeda6d61a9eab625facf6c))
+- **analysis:** remove deprecated module ([0b2ca44](https://github.com/kmnhan/erlabpy/commit/0b2ca44844cc5802d32d9ed949e831b534525183))
+- **analysis:** remove deprecated alias ([305832b](https://github.com/kmnhan/erlabpy/commit/305832b7bb18aa3d1fda21f4cd0c0992b174d639))
+- **analysis:** remove deprecated aliases ([075eaf8](https://github.com/kmnhan/erlabpy/commit/075eaf8cd222044aa5cc0c3459698ab33568958c))
+- **interactive.imagetool.manager:** add prefix to temporary directories for better identification ([e56163b](https://github.com/kmnhan/erlabpy/commit/e56163ba7fe7d92f3a01ec78098c2d0194ea0302))
+- **io.plugins:** implement DA30 file identification patterns in superclass ([f6dfc44](https://github.com/kmnhan/erlabpy/commit/f6dfc4412b56fc1d83efceb4a65070eb9ef1c2b1))
+- **io:** remove deprecated aliases ([7f07ad2](https://github.com/kmnhan/erlabpy/commit/7f07ad2c46f80d48c255d408f3f200ae01930060))
+- change temperature attribute name ([32e1cd5](https://github.com/kmnhan/erlabpy/commit/32e1cd5fb45bce12cfa83c520e8c61af96a8cb39))
+
+  Changes `temp_sample` to `sample_temp` for all data loaders and analysis code.
+- **utils.formatting:** change formatting for numpy arrays ([95d9f0b](https://github.com/kmnhan/erlabpy/commit/95d9f0b602551141232eb5a2fa10c421d11d2233))
+
+  For arrays with 2 or more dimensions upon squeezing, only the minimum and maximum values are shown. Also, arrays with only two entries are displayed as a list.
+- **io.dataloader:** disable parallel loading by default ([fed2428](https://github.com/kmnhan/erlabpy/commit/fed2428229e3ef70fc95a35670fc75ace44024bd))
+
+  Parallel loading is now disabled by default since the overhead is larger than the performance gain in most cases.
+- change some warnings to emit from user level ([e81f2b1](https://github.com/kmnhan/erlabpy/commit/e81f2b121d2931b327d30b146db1e77e7a3b3ec2))
+- **io.dataloader:** cache summary only if directory is writable ([85bcb80](https://github.com/kmnhan/erlabpy/commit/85bcb80bdf27ea12edb9314247a978f71c8be6dc))
+- **io.plugins:** improve warning message when a plugin fails to load ([9ee0b90](https://github.com/kmnhan/erlabpy/commit/9ee0b901b1b904dabb38d29f4c166dca07c9a7e9))
+- **io:** update datatree to use public api ([6c27e07](https://github.com/kmnhan/erlabpy/commit/6c27e074c5aceb16eb9808cca38b8ba73748f07e))
+
+  Also bumps the minimum supported xarray version to 2024.10.0.
+- **io.dataloader:** make `RegistryBase` private ([df7079e](https://github.com/kmnhan/erlabpy/commit/df7079e4fc96b195d34436bcc93684e10ddecdad))
+- **io.dataloader:** rename loader registry attribute `default_data_dir` to `current_data_dir` ([d87eba7](https://github.com/kmnhan/erlabpy/commit/d87eba7db6cea051f76b61ea7b0834e439460810))
+
+  The attribute `default_data_dir` has been renamed to `current_data_dir` so that it is consistent with `current_loader`. Accessing the old name is now deprecated.
+
+  Also, the `current_loader` and `current_data_dir` can now be assigned directly with a syntax like `erlab.io.loaders.current_loader = "merlin"`.
+
+### Performance
+
+- **io.plugins.da30:** faster summary generation for DA30 zip files ([22b77bf](https://github.com/kmnhan/erlabpy/commit/22b77bf0ee787fe1236fb85262702b79265e3b8d))
+- **io.igor:** suppress `igor2` logging ([5cd3a8c](https://github.com/kmnhan/erlabpy/commit/5cd3a8c273b143d1a83f3286678638fede1ddd01))
+- **analysis.interpolate:** extend acceleration ([84daa88](https://github.com/kmnhan/erlabpy/commit/84daa8866ec4223555568f441b6010bb4936a413))
+
+  The fast linear interpolator now allows more general interpolation points like interpolating 3D data on a 2D grid. This means that passing `method='linearfast'` to `DataArray.interp` is faster in many cases.
+
 ## v2.12.0 (2024-10-22)

 ### Features
@@ -668,15 +763,11 @@

 ## v2.0.0 (2024-04-08)

-### BREAKING CHANGE
+### Breaking Changes

 - `PolyFunc` is now `PolynomialFunction`, and `FermiEdge2dFunc` is now `FermiEdge2dFunction`. The corresponding model names are unchanged. ([20d784c](https://github.com/kmnhan/erlabpy/commit/20d784c1d8fdcd786ab73b3ae03d3e331dc04df5))
-
-  BREAKING CHANGE: `PolyFunc` is now `PolynomialFunction`, and `FermiEdge2dFunc` is now `FermiEdge2dFunction`. The corresponding model names are unchanged.
 - This change disables the use of guess_fit. All fitting must be performed in the syntax recommended by lmfit. Addition of a accessor or a convenience function for coordinate-aware fitting is planned in the next release. ([59163d5](https://github.com/kmnhan/erlabpy/commit/59163d5f0e000d65aa53690a51b6db82df1ce5f1))
-
-  BREAKING CHANGE: This change disables the use of guess_fit. All fitting must be performed in the syntax recommended by lmfit. Addition of a accessor or a convenience function for coordinate-aware fitting is planned in the next release.
- ### Features - **itool:** add copy code to PlotItem vb menu ([7b4f30a](https://github.com/kmnhan/erlabpy/commit/7b4f30ada21c5accc1d3824ad3d0f8097f9a99c1)) @@ -689,8 +780,6 @@ - add derivative module with minimum gradient implementation ([e0eabde](https://github.com/kmnhan/erlabpy/commit/e0eabde60e6860c3827959b45be6d4f491918363)) - **fit:** directly base models on lmfit.Model ([59163d5](https://github.com/kmnhan/erlabpy/commit/59163d5f0e000d65aa53690a51b6db82df1ce5f1)) - BREAKING CHANGE: This change disables the use of guess_fit. All fitting must be performed in the syntax recommended by lmfit. Addition of a accessor or a convenience function for coordinate-aware fitting is planned in the next release. - ### Bug Fixes - **dynamic:** properly broadcast xarray input ([2f6672f](https://github.com/kmnhan/erlabpy/commit/2f6672f3b003792ecd98b4fbc99fb11fcc0efb8b)) @@ -703,8 +792,6 @@ ### Code Refactor - **fit:** unify dynamic function names ([20d784c](https://github.com/kmnhan/erlabpy/commit/20d784c1d8fdcd786ab73b3ae03d3e331dc04df5)) - - BREAKING CHANGE: `PolyFunc` is now `PolynomialFunction`, and `FermiEdge2dFunc` is now `FermiEdge2dFunction`. The corresponding model names are unchanged. - update dtool to use new functions ([a6e46bb](https://github.com/kmnhan/erlabpy/commit/a6e46bb8b19512e438291afbbd5e0e9a4eb4fe87)) - **analysis.image:** add documentation and reorder functions ([340665d](https://github.com/kmnhan/erlabpy/commit/340665dc507a99acc7d56c46a2a2326fbb56b1e3)) - rename module to image and add citation ([b74a654](https://github.com/kmnhan/erlabpy/commit/b74a654e07d9f4522cee2db0b897f1ffcdb86e94)) diff --git a/docs/source/changelog.md b/docs/source/changelog.md new file mode 100644 index 00000000..922f663e --- /dev/null +++ b/docs/source/changelog.md @@ -0,0 +1,4 @@ +# Changelog + +```{include} ../../CHANGELOG.md +``` diff --git a/docs/source/conf.py b/docs/source/conf.py index 521c9dca..90f8c765 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -40,7 +40,6 @@ "sphinx.ext.intersphinx", "sphinx.ext.mathjax", "sphinx.ext.napoleon", - "sphinx.ext.autosectionlabel", "sphinx_autodoc_typehints", # "IPython.sphinxext.ipython_directive", # "IPython.sphinxext.ipython_console_highlighting", @@ -54,6 +53,7 @@ "matplotlib.sphinxext.roles", "sphinxcontrib.bibtex", "sphinx_qt_documentation", + "myst_parser", ] if os.getenv("READTHEDOCS"): @@ -114,10 +114,7 @@ def linkcode_resolve(domain, info) -> str | None: fn = os.path.relpath(fn, start=os.path.dirname(erlab.__file__)) - return ( - f"https://github.com/kmnhan/erlabpy/blob/" - f"v{version}/src/erlab/{fn}{linespec}" - ) + return f"https://github.com/kmnhan/erlabpy/blob/v{version}/src/erlab/{fn}{linespec}" # -- Autosummary and autodoc settings ---------------------------------------- @@ -161,8 +158,9 @@ def linkcode_resolve(domain, info) -> str | None: napoleon_preprocess_types = True napoleon_type_aliases = { "ndarray": "numpy.ndarray", - "DataArray": "xarray.DataArray", - "Dataset": "xarray.Dataset", + "DataArray": "`DataArray `", + "Dataset": "`Dataset `", + "DataTree": "`DataTree `", "np.float32": "float32", "numpy.float32": "float32", "np.float64": "float64", diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index 136359a7..1fc7ae4a 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -232,7 +232,8 @@ The editing workflow -------------------- 1. Make some changes. Make sure to follow the :ref:`code standards - ` and the `documentation standards <#documentation>`_. 
+ ` and the :ref:`documentation standards + `. 2. See which files have changed with ``git status``. You'll see a listing like this one: :: @@ -277,7 +278,7 @@ Writing tests for data loader plugins ------------------------------------- When contributing a new data loader plugin, it is important to write tests to ensure -that the plugin works as expected over time. +that the plugin always returns the expected data for newer package versions. Since ARPES data required for testing take up a lot of space, we have a separate repository for test data: `erlabpy-data `_. @@ -293,7 +294,8 @@ add tests, follow these steps: 3. Place the test data files into the directory you created in step 3. It's a good practice to also include a processed version of the data that the plugin should - return, and use this as a reference in the tests. + return, and use this as a reference in the tests. See preexisting directories and + tests for examples. 4. Set the environment variable `ERLAB_TEST_DATA_DIR` to the path of the cloned `erlabpy-data `_ repository in your @@ -327,23 +329,10 @@ add tests, follow these steps: the test data has not been modified or corrupted since the last time the tests were run. - To calculate the hash, first download the tarball from the GitHub API:: - - https://api.github.com/repos/kmnhan/erlabpy-data/tarball/ - - The new hash can be calculated by running the following command in the terminal: - - .. code-block:: bash - - openssl sha256 path/to/kmnhan-erlabpy-data-.tar.gz - - or using `pooch `_: - - .. code-block:: python - - import pooch - - pooch.file_hash("path/to/kmnhan-erlabpy-data-.tar.gz") + The hash is calculated by `this workflow + `_ for each + push to main. It can be copied from the workflow summary corresponding to the + commit you wish to refer to. .. _development.code-standards: @@ -385,6 +374,8 @@ Code standards super().__init__() self.setupUi(self) +.. _development.docs: + Documentation ============= diff --git a/docs/source/getting-started.rst b/docs/source/getting-started.rst index 95b234c4..c3115508 100644 --- a/docs/source/getting-started.rst +++ b/docs/source/getting-started.rst @@ -158,8 +158,3 @@ Notes on compatibility - There are some `known compatibility issues `_ with PyQt5 and PySide2, so it is recommended to use the newer PyQt6 or PySide6 if possible. -- If you meet any unexpected behaviour while using IPython's `autoreload extension - `_, try - excluding the following modules: :: - - %aimport -erlab.io -erlab.accessors diff --git a/docs/source/images/ktool_1_dark.png b/docs/source/images/ktool_1_dark.png index 680273b3..4a242861 100644 Binary files a/docs/source/images/ktool_1_dark.png and b/docs/source/images/ktool_1_dark.png differ diff --git a/docs/source/images/ktool_1_light.png b/docs/source/images/ktool_1_light.png index c69a9e59..14018cae 100644 Binary files a/docs/source/images/ktool_1_light.png and b/docs/source/images/ktool_1_light.png differ diff --git a/docs/source/images/ktool_2_dark.png b/docs/source/images/ktool_2_dark.png index 07c68465..93db91f2 100644 Binary files a/docs/source/images/ktool_2_dark.png and b/docs/source/images/ktool_2_dark.png differ diff --git a/docs/source/images/ktool_2_light.png b/docs/source/images/ktool_2_light.png index ac8eb97a..05bb7a40 100644 Binary files a/docs/source/images/ktool_2_light.png and b/docs/source/images/ktool_2_light.png differ diff --git a/docs/source/index.rst b/docs/source/index.rst index 3c787110..3fbc81fb 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -89,4 +89,4 @@ research! 
reference contributing bibliography - Changelog + changelog diff --git a/docs/source/user-guide/imagetool.rst b/docs/source/user-guide/imagetool.rst index a661b9ba..b25b6f96 100644 --- a/docs/source/user-guide/imagetool.rst +++ b/docs/source/user-guide/imagetool.rst @@ -54,6 +54,12 @@ Tips quickly pasted to a Python script or Jupyter notebook to reproduce the data in the clicked region. +- Right-clicking on the colorbar allows you to manually set the color range or to + copy the color range to the clipboard. + +- Data manipulation such as rotation and normalization are also possible. Try exploring + the edit and view menus. + - ImageTool is also very extensible. At our home lab, we use a modified version of ImageTool to plot data as it is being collected in real-time! diff --git a/docs/source/user-guide/io.ipynb b/docs/source/user-guide/io.ipynb index 60e7b066..1b5e39e4 100644 --- a/docs/source/user-guide/io.ipynb +++ b/docs/source/user-guide/io.ipynb @@ -22,24 +22,33 @@ } }, "source": [ - "In ERLabPy, most data are represented as :class:`xarray.Dataset` objects or\n", - ":class:`xarray.DataArray` objects.\n", - "\n", - ":class:`xarray.DataArray` are similar to waves in Igor Pro, but are much more flexible.\n", - "Opposed to the maximum of 4 dimensions in Igor, :class:`xarray.DataArray` can have as\n", - "many dimensions as you want (up to 64). Another advantage is that the coordinates of the\n", - "dimensions do not have to be evenly spaced. In fact, they are not limited to numbers but\n", - "can be any type of data, such as date and time representations.\n", - "\n", - ":class:`xarray.Dataset` is a collection of :class:`xarray.DataArray` objects. It is used to\n", - "store multiple data arrays that are related to each other, such as a set of measurements.\n", - "\n", ".. note::\n", "\n", " If you are not familiar with :mod:`xarray`, it is recommended to read the `xarray\n", " tutorial `_ and the `xarray user guide\n", " `_ first.\n", "\n", + "In ERLabPy, data are represented as :class:`xarray.DataArray`, :class:`xarray.Dataset`,\n", + "and :class:`xarray.DataTree` objects.\n", + "\n", + "- :class:`xarray.DataArray` are similar to waves in Igor Pro, but are much more flexible.\n", + " Opposed to the maximum of 4 dimensions in Igor, :class:`xarray.DataArray` can have as\n", + " many dimensions as you want (up to 64). Another advantage is that the coordinates of\n", + " the dimensions do not have to be evenly spaced. In fact, they are not limited to\n", + " numbers but can be any type of data, such as date and time representations.\n", + "\n", + "- :class:`xarray.Dataset` is a collection of :class:`xarray.DataArray` objects. It is\n", + " used to store multiple data arrays that are related to each other, such as a set of\n", + " measurements.\n", + "\n", + "- :class:`xarray.DataTree` is a hierarchical data structure that can store multiple\n", + " :class:`xarray.Dataset` objects, just like an Igor experiment file with multiple waves\n", + " within nested folders.\n", + "\n", + "See `Data Structures\n", + "`_ in the xarray\n", + "documentation for more information.\n", + "\n", "This guide will introduce you to reading and writing data from and to various file\n", "formats, and how to implement a custom plugin for a experimental setup.\n", "\n", @@ -69,11 +78,12 @@ } }, "source": [ - ":mod:`xarray` provides native support for reading and writing NetCDF and HDF5 files into\n", + ":mod:`xarray` provides basic support for reading and writing NetCDF and HDF5 files into\n", ":mod:`xarray` objects. 
See the :mod:`xarray` documentation on `I/O operations\n", - "`_.\n", + "`_ for more information.\n", "\n", - "Here, we will focus on working with Igor Pro and xarray objects." + "Here, we will focus on working with data exported from Igor Pro and some other commonly\n", + "used file formats." ] }, { @@ -109,12 +119,8 @@ "```python\n", "import xarray as xr\n", "\n", - "data = xr.open_dataarray(\"path/to/file.ibw\")\n", - "```\n", - "\n", - "Along with the Igor Pro file formats, the backend also supports loading HDF5 files\n", - "exported from Igor Pro. For such files, the engine must be specified explicitly with\n", - "`engine=\"erlab-igor\"`." + "data = xr.open_dataarray(\"path/to/wave.ibw\")\n", + "```" ] }, { @@ -131,14 +137,20 @@ } }, "source": [ - ".. warning::\n", - "\n", - " Loading waves from complex ``.pxp`` files may fail or produce unexpected results. It\n", - " is recommended to export the waves to a ``.ibw`` file to load them in ERLabPy. If you\n", - " encounter any problems, please let us know by opening an issue.\n", + "Loading an experiment file to a :class:`xarray.DataTree` is also possible:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```python\n", + "data = xr.open_datatree(\"path/to/experiment.pxpt\")\n", + "```\n", "\n", - "Convenience functions that can load Igor Pro files directly can be found in\n", - ":mod:`erlab.io.igor`. However, it is recommended to use the xarray functions." + "Along with the Igor Pro file formats, the backend also supports loading HDF5 files\n", + "exported from Igor Pro. For such files, the engine must be specified explicitly with\n", + "`engine=\"erlab-igor\"`." ] }, { @@ -155,6 +167,15 @@ } }, "source": [ + ".. warning::\n", + "\n", + " Loading waves from complex ``.pxp`` files may fail or produce unexpected results. It\n", + " is recommended to export the waves to a ``.ibw`` file to load them in ERLabPy. If you\n", + " encounter any problems, please let us know by opening an issue.\n", + "\n", + "Convenience functions that can load Igor Pro files directly are implemented in\n", + ":mod:`erlab.io.igor`, but it is recommended to use the xarray functions presented above.\n", + "\n", "From arbitrary formats\n", "~~~~~~~~~~~~~~~~~~~~~~\n", "\n", @@ -168,20 +189,14 @@ " :meth:`pandas.DataFrame.to_xarray` or :meth:`xarray.Dataset.from_dataframe`.\n", "\n", "* When reading HDF5 files with arbitrary groups and metadata, you must first explore the \n", - " group structure using `h5netcdf `_ or ``open_datatree``.\n", - "\n", - " .. note::\n", - " \n", - " ``open_datatree`` can be imported from ``xarray.backends.api``. In the near future,\n", - " it will be documented and made available in the public API.\n", - " \n", - " Loading a specific HDF5 group into an xarray object can be done using\n", - " :func:`xarray.open_dataset` or :func:`xarray.open_mfdataset` by supplying the\n", - " ``group`` argument.\n", + " group structure using `h5netcdf `_. More conveniently, you can\n", + " use :func:`xarray.open_groups` to inspect the group structure.\n", "\n", "* FITS files can be read with `astropy\n", " `_. In the near future, ERLabPy\n", - " will provide a loader for FITS files." + " will provide a loader for FITS files.\n", + "\n", + "* For working with NeXus files, see :mod:`erlab.io.nexusutils`." 
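For example, a quick way to explore such a file might look like the following (a minimal sketch; the file name and group path are hypothetical):

```python
import xarray as xr

# Inspect the group structure: returns a dict mapping group paths to Datasets
groups = xr.open_groups("measurement.h5", engine="h5netcdf")
print(list(groups))

# Once the structure is known, load a single group as a Dataset
ds = xr.open_dataset("measurement.h5", group="/entry/detector", engine="h5netcdf")
```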
] }, { @@ -193,13 +208,12 @@ "Since the state and variables of a Python interpreter are not saved, it is important to\n", "save your data in a format that can be easily read and written.\n", "\n", - "While it is possible to save and load entire Python interpreter sessions using `pickle`\n", - "or the more versatile [dill](https://github.com/uqfoundation/dill), it is out of the\n", - "scope of this guide. Instead, we recommend saving your data in a format that is easy to\n", - "read and write, such as HDF5 or NetCDF. These formats are supported by many programming\n", - "languages and are optimized for fast read and write operations.\n", - "\n", - "To save and load `xarray` objects, see the `xarray` documentation on [I/O operations](https://docs.xarray.dev/en/stable/user-guide/io.html)." + "While it is possible to save and load entire Python interpreter sessions using\n", + "[pickle](https://docs.python.org/3/library/pickle.html) or the more versatile\n", + "[dill](https://github.com/uqfoundation/dill), it is out of the scope of this guide.\n", + "Instead, we recommend saving your data in a format that is easy to read and write, such\n", + "as HDF5 or NetCDF. To save and load xarray objects to such formats, see the xarray\n", + "documentation on [I/O operations](https://docs.xarray.dev/en/stable/user-guide/io.html)." ] }, { @@ -220,13 +234,15 @@ "~~~~~~~~~~~\n", "\n", "As an experimental feature, :func:`save_as_hdf5 ` can save\n", - "certain :class:`xarray.DataArray`\\ s in a format that is compatible with the Igor Pro\n", - "HDF5 loader. An `accompanying Igor procedure\n", + "certain DataArrays in a format that is compatible with the Igor Pro HDF5 loader. An\n", + "`accompanying Igor procedure\n", "`_ is available in the\n", "repository. If loading in Igor Pro fails, try saving again with all attributes removed.\n", "\n", "Alternatively, `igorwriter `_ can be used to write\n", - "numpy arrays to ``.ibw`` and ``.itx`` files directly." + "numpy arrays to ``.ibw`` and ``.itx`` files directly.\n", + "\n", + ".. _loading-arpes-data:" ] }, { @@ -256,10 +272,6 @@ " ERLabPy is still in development and the API may change. Some major changes regarding\n", " data loading and handling are planned:\n", "\n", - " - The `xarray datatree structure `_\n", - " will enable much more intuitive and powerful data handling. Once the feature gets\n", - " incorporated into xarray, ERLabPy will be updated to use it.\n", - "\n", " - A universal translation layer between true data header attributes and\n", " human-readable representations will be implemented. This will allow for more\n", " consistent and user-friendly data handling." @@ -270,7 +282,7 @@ "metadata": {}, "source": [ "ERLabPy's data loading framework consists of various plugins, or *loaders*, each\n", - "designed to load data from a different beamline or laboratory. Each loader is a class\n", + "designed to load data from a different beamline or laboratory. Each *loader* is a class\n", "that has a `load` method which takes a file path or sequence number and returns data.\n", "\n", "Let's see the list of loaders available by default:" @@ -501,6 +513,16 @@ "```" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you don't want automatic concatenation to happen, you can suppress it with `combine=False`. 
The following code will return a list of DataArrays:\n", + "```python\n", + "erlab.io.load(3, combine=False)\n", + "```" + ] + }, { "cell_type": "raw", "metadata": { @@ -517,16 +539,14 @@ "source": [ "Handling multiple data directories\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", - " \n", + "\n", "If you call :func:`erlab.io.set_loader` or :func:`erlab.io.set_data_dir` multiple times,\n", "the last call will override the previous ones. While this is useful for changing the\n", "loader or data directory, it makes data loading *dependent on execution order*. This may\n", - "lead to unexpected behavior.\n", + "lead to unexpected behavior in notebooks.\n", "\n", "If you plan to use multiple loaders or data directories in the same session, it is\n", - "recommended to use the context manager. If you have to load data from multiple\n", - "directories multiple times, it may be convenient to define functions that set the loader\n", - "and data directory and call :func:`erlab.io.load` with the appropriate arguments. For example:" + "recommended to use the context manager :func:`erlab.io.loader_context`:" ] }, { @@ -534,9 +554,8 @@ "metadata": {}, "source": [ "```python\n", - "def load1(identifier):\n", - " with erlab.io.loader_context(\"merlin\", data_dir=\"/path/to/data1\"):\n", - " return erlab.io.load(identifier)\n", + "with erlab.io.loader_context(\"merlin\", data_dir=\"/path/to/data\"):\n", + " data = erlab.io.load(identifier)\n", "```" ] }, @@ -554,16 +573,17 @@ } }, "source": [ + "It may also be convenient to define functions that set the loader and data directory and\n", + "call :func:`erlab.io.load` with the appropriate arguments.\n", + "\n", "Summarizing data\n", "~~~~~~~~~~~~~~~~\n", "\n", - "Some loaders have :meth:`generate_summary\n", - "` implemented, which generates a\n", - ":class:`pandas.DataFrame` containing an overview of the data in a given directory. The\n", - "generated summary can be viewed as a table with the :meth:`summarize\n", - "` method. If ``ipywidgets`` is installed, an\n", - "interactive widget is also displayed. This is useful for quickly skimming through the\n", - "data.\n", + "Some supported loaders can generate a :class:`pandas.DataFrame` containing an overview\n", + "of the data in a given directory. The generated summary can be viewed as a table with\n", + "the :meth:`summarize ` method. If\n", + "``ipywidgets`` is installed, an interactive widget is also displayed. This is useful for\n", + "quickly skimming through the data.\n", "\n", "Just like :meth:`load `, :meth:`summarize\n", "` can also be accessed with the shortcut\n", @@ -577,9 +597,9 @@ "source": [ "```python\n", "erlab.io.set_loader(\"merlin\")\n", - "erlab.io.set_data_dir(\"/path/to/data\")\n", - "erlab.io.summarize()\n", - "```" + "erlab.io.summarize(\"/path/to/data\")\n", + "```\n", + "If the path is not specified, the current data directory is used." ] }, { @@ -597,7 +617,9 @@ }, "source": [ "To see what the generated summary looks like, see the :ref:`example below `." + "example>`.\n", + "\n", + ".. _implementing-plugins:" ] }, { @@ -677,32 +699,38 @@ "~~~~~~~~~~~\n", "\n", "There are some rules that loaded ARPES data must follow to ensure that analysis\n", - "procedures such as momentum conversion and fitting works seamlessly:\n", + "procedures such as momentum conversion and fermi edge fitting works seamlessly:\n", "\n", "- The experimental geometry should be stored in the ``'configuration'`` attribute as an\n", " integer. 
See :ref:`Nomenclature ` and :class:`AxesConfiguration\n", " ` for more information.\n", "- All standard angle coordinates must follow the naming conventions in\n", " :ref:`Nomenclature `.\n", - "- The sample temperature, if available, should be stored in the ``'temp_sample'``\n", - " attribute.\n", - "- The sample work function, if available, should be stored in the\n", - " ``'sample_workfunction'`` attribute.\n", + "- The sample temperature, if available, should be stored in an attribute or coordinate\n", + " named ``'sample_temp'``.\n", + "- The sample work function, if available, should be stored in an attribute named\n", + " ``'sample_workfunction'``.\n", + "- The angular resolution of the analyzer, if available, should be stored in an attribute\n", + " named ``'angle_resolution'``. This is used in estimating momentum grid sizes when\n", + " converting to momentum space.\n", + "\n", + "In addition, use the following units:\n", "- Energies should be given in electronvolts.\n", "- Angles should be given in degrees.\n", "- Temperatures should be given in Kelvins.\n", "\n", - "All loaders by default does a basic check for a subset of these rules using\n", - ":meth:`validate ` and will raise a warning if\n", - "some are missing. This behavior can be controlled with loader class attributes\n", + "All loaders by default does a basic check for some of these rules using :meth:`validate\n", + "` for every data file loaded. A warning is\n", + "issued if some are missing. This behavior can be controlled with loader class attributes\n", ":attr:`skip_validate ` and\n", ":attr:`strict_validation `.\n", "\n", "A minimal example\n", "~~~~~~~~~~~~~~~~~\n", "\n", - "Consider a setup that saves data into a ``.csv`` file named ``data_0001.csv`` and so on.\n", - "A bare minimum implementation of a loader for the setup will look something like this:" + "Consider a setup that saves data into a ``.csv`` file named ``data_0001.csv``,\n", + "``data_0002.csv``, and so on. A bare minimum implementation of a loader for the setup\n", + "will look something like this:" ] }, { @@ -735,10 +763,17 @@ " file = os.path.join(data_dir, f\"data_{str(num).zfill(4)}.csv\")\n", " return [file], {}\n", "\n", - " def load_single(self, file_path):\n", + " def load_single(self, file_path, without_values=False):\n", " return pd.read_csv(file_path).to_xarray()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, the `without_values` argument to `load_single` is unused; it will be explained later." 
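With the loader defined, loading by sequence number goes through the usual interface. A minimal usage sketch, assuming the loader above registered itself under the name ``example`` and that ``data_0001.csv`` exists in the data directory (the path below is hypothetical):

```python
import erlab.io

erlab.io.set_loader("example")
erlab.io.set_data_dir("/path/to/csv/data")  # hypothetical path
data = erlab.io.load(1)  # identify() resolves this to data_0001.csv
```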
+ ] + }, { "cell_type": "code", "execution_count": null, @@ -833,6 +868,7 @@ " \"delta\": \"Azimuth\",\n", " }\n", " )\n", + " dt = datetime.datetime.now()\n", "\n", " # Assign some attributes that real data would have\n", " data = data.assign_attrs(\n", @@ -841,7 +877,8 @@ " \"SpectrumType\": \"Fixed\", # Acquisition mode of the analyzer\n", " \"PassEnergy\": 10, # Pass energy of the analyzer\n", " \"UndPol\": 0, # Undulator polarization\n", - " \"DateTime\": datetime.datetime.now().isoformat(), # Acquisition time\n", + " \"Date\": dt.strftime(r\"%d/%m/%Y\"), # Date of the measurement\n", + " \"Time\": dt.strftime(\"%I:%M:%S %p\"), # Time of the measurement\n", " \"TB\": temp,\n", " \"X\": 0.0,\n", " \"Y\": 0.0,\n", @@ -916,7 +953,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The data has been properly loaded, but the coordinates and attrubutes have names that\n", + "The data has been properly loaded, but the coordinates and attributes have names that\n", "are specific to the beamline.\n", "\n", "Our loader should do three things: rename the coordinates and attributes to standard\n", @@ -989,8 +1026,12 @@ " \"z\": \"Z\",\n", " \"hv\": \"PhotonEnergy\",\n", " \"polarization\": \"UndPol\",\n", - " \"temp_sample\": \"TB\",\n", + " \"sample_temp\": \"TB\",\n", " }\n", + " # Map the names of the coordinates or attributes in the resulting data to the names\n", + " # present in the data returned by `load_single`. Note that the order of\n", + " # non-dimension coordinates in the output data will follow the order of the keys in\n", + " # this dictionary.\n", "\n", " coordinate_attrs: tuple[str, ...] = (\n", " \"beta\",\n", @@ -1002,6 +1043,7 @@ " \"z\",\n", " \"polarization\",\n", " \"photon_flux\",\n", + " \"sample_temp\",\n", " )\n", " # Attributes to be used as coordinates. Place all attributes that we don't want to\n", " # lose when merging multiple file scans here.\n", @@ -1018,8 +1060,6 @@ " # Additional non-dimension coordinates to be added to the data, for instance the\n", " # photon energy for lab-based ARPES.\n", "\n", - " skip_validate = False\n", - "\n", " always_single = False\n", "\n", " def identify(self, num, data_dir):\n", @@ -1055,7 +1095,7 @@ "\n", " return files, coord_dict\n", "\n", - " def load_single(self, file_path):\n", + " def load_single(self, file_path, without_values=False):\n", " return xr.open_dataarray(file_path, engine=\"h5netcdf\")\n", "\n", " def infer_index(self, name):\n", @@ -1072,6 +1112,14 @@ " return None, None" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that there are more class attributes and methods that can be inherited or\n", + "overridden to customize the loader's behavior." + ] + }, { "cell_type": "code", "execution_count": null, @@ -1124,14 +1172,37 @@ }, "source": [ "Brilliant! We now have a working loader for our hypothetical setup. However, we can't\n", - "use :func:`erlab.io.summarize` with our loader since we haven't implemented\n", - ":meth:`generate_summary `.\n", - "\n", - "This method should return a :class:`pandas.DataFrame` with the index containing file\n", - "names. The only requirement for the DataFrame is that it should include a column named\n", - "``'Path'`` that contains the paths to the data files. Other than that, the DataFrame can\n", - "contain any metadata you wish to display in the summary. 
Let's implement it in a\n", - "subclass of the ``example`` loader:" + "use :func:`erlab.io.summarize` with our loader yet.\n", + "\n", + "To enable summary generation, we need to implement two attributes and one method:\n", + "\n", + "- :attr:`formatters `: A dictionary that maps\n", + " attribute or coordinate names in the data to functions that convert the coordinate or\n", + " attribute value into a human-readable form.\n", + "\n", + "- :attr:`summary_attrs `: A dictionary\n", + " that maps summary column names to attribute or coordinate names in the data. A\n", + " callable can also be used to generate entries for attributes that are not directly\n", + " present in the data. \n", + "\n", + "- :meth:`files_for_summary `: A method\n", + " that takes a path to a directory and returns a list of file paths in the directory\n", + " that are associated with the loader. \n", + "\n", + "You can also choose to implement the following attribute to further customize the summary:\n", + "\n", + "- :attr:`summary_sort `: A string that determines the\n", + " column name to sort the summary table with. If not provided, the table will respect\n", + " the order of the files returned by :meth:`files_for_summary\n", + " `.\n", + "\n", + "To improve the performance of summary generation, you can optionally implement\n", + ":meth:`load_single ` to utilize the\n", + "``without_values`` argument. If it is True, it means that the values in the\n", + "returned data of :meth:`load_single ` will\n", + "not be accessed, so you can return the data with its values set to arbitrary numbers.\n", + "This is useful when only the metadata is needed for the summary. An example of this will\n", + "be shown below.\n" ] }, { @@ -1140,100 +1211,76 @@ "metadata": {}, "outputs": [], "source": [ + "def _format_polarization(val) -> str:\n", + " val = round(float(val))\n", + " return {0: \"LH\", 2: \"LV\", -1: \"RC\", 1: \"LC\"}.get(val, str(val))\n", + "\n", + "\n", + "def _parse_time(darr: xr.DataArray) -> datetime.datetime:\n", + " return datetime.datetime.strptime(\n", + " f\"{darr.attrs['Date']} {darr.attrs['Time']}\",\n", + " \"%d/%m/%Y %I:%M:%S %p\",\n", + " )\n", + "\n", + "\n", + "def _determine_kind(darr: xr.DataArray) -> str:\n", + " if \"scan_type\" in darr.attrs and darr.attrs[\"scan_type\"] == \"live\":\n", + " return \"LP\" if \"beta\" in darr.dims else \"LXY\"\n", + "\n", + " data_type = \"xps\"\n", + " if \"alpha\" in darr.dims:\n", + " data_type = \"cut\"\n", + " if \"beta\" in darr.dims:\n", + " data_type = \"map\"\n", + " if \"hv\" in darr.dims:\n", + " data_type = \"hvdep\"\n", + " return data_type\n", + "\n", + "\n", "class ExampleLoaderComplete(ExampleLoader):\n", " name = \"example_complete\"\n", " aliases = [\"ExC\"]\n", "\n", - " def generate_summary(self, data_dir):\n", - " # Get all valid data files in directory\n", - " files = {}\n", - " for path in erlab.io.utils.get_files(data_dir, extensions=[\".h5\"]):\n", - " # If multiple scans, strip the _S### part\n", - " name_match = re.match(r\"(.*?_\\d{3})_(?:_S\\d{3})?\", path.stem)\n", - " data_name = path.stem if name_match is None else name_match.group(1)\n", - " files[data_name] = str(path)\n", - "\n", - " # Map dataframe column names to data attributes\n", - " attrs_mapping = {\n", - " \"Lens Mode\": \"LensMode\",\n", - " \"Scan Type\": \"SpectrumType\",\n", - " \"T(K)\": \"temp_sample\",\n", - " \"Pass E\": \"PassEnergy\",\n", - " \"Polarization\": \"polarization\",\n", - " \"hv\": \"hv\",\n", - " \"x\": \"x\",\n", - " \"y\": \"y\",\n", - " \"z\": 
\"z\",\n", - " \"polar\": \"beta\",\n", - " \"tilt\": \"xi\",\n", - " \"azi\": \"delta\",\n", - " }\n", - " column_names = [\"File Name\", \"Path\", \"Time\", \"Type\", *attrs_mapping.keys()]\n", - "\n", - " data_info = []\n", - "\n", - " processed_indices = set()\n", - " for name, path in files.items():\n", - " # Skip already processed multi-file scans\n", - " index, _ = self.infer_index(name)\n", - " if index in processed_indices:\n", - " continue\n", - " if index is not None:\n", - " processed_indices.add(index)\n", - "\n", - " # Load data\n", - " data = self.load(path)\n", - "\n", - " # Determine type of scan\n", - " data_type = \"core\"\n", - " if \"alpha\" in data.dims:\n", - " data_type = \"cut\"\n", - " if \"beta\" in data.dims:\n", - " data_type = \"map\"\n", - " if \"hv\" in data.dims:\n", - " data_type = \"hvdep\"\n", - "\n", - " data_info.append(\n", - " [\n", - " name,\n", - " path,\n", - " datetime.datetime.fromisoformat(data.attrs[\"DateTime\"]),\n", - " data_type,\n", - " ]\n", + " formatters = {\n", + " \"polarization\": _format_polarization,\n", + " \"LensMode\": lambda x: x.replace(\"Angular\", \"A\"),\n", + " }\n", + "\n", + " summary_attrs = {\n", + " \"Time\": _parse_time,\n", + " \"Type\": _determine_kind,\n", + " \"Lens Mode\": \"LensMode\",\n", + " \"Scan Type\": \"SpectrumType\",\n", + " \"T(K)\": \"sample_temp\",\n", + " \"Pass E\": \"PassEnergy\",\n", + " \"Polarization\": \"polarization\",\n", + " \"hv\": \"hv\",\n", + " \"x\": \"x\",\n", + " \"y\": \"y\",\n", + " \"z\": \"z\",\n", + " \"polar\": \"beta\",\n", + " \"tilt\": \"xi\",\n", + " \"azi\": \"delta\",\n", + " }\n", + "\n", + " summary_sort = \"Time\"\n", + "\n", + " def load_single(self, file_path, without_values=False):\n", + " darr = xr.open_dataarray(file_path, engine=\"h5netcdf\")\n", + "\n", + " if without_values:\n", + " # Do not load the data into memory\n", + " return xr.DataArray(\n", + " np.zeros(darr.shape, darr.dtype),\n", + " coords=darr.coords,\n", + " dims=darr.dims,\n", + " attrs=darr.attrs,\n", " )\n", "\n", - " for k, v in attrs_mapping.items():\n", - " # Try to get the attribute from the data, then from the coordinates\n", - " try:\n", - " val = data.attrs[v]\n", - " except KeyError:\n", - " try:\n", - " val = data.coords[v].values\n", - " if val.size == 1:\n", - " val = val.item()\n", - " except KeyError:\n", - " val = \"\"\n", - "\n", - " # Convert polarization values to human readable form\n", - " if k == \"Polarization\":\n", - " if np.iterable(val):\n", - " val = np.asarray(val).astype(int)\n", - " else:\n", - " val = [round(val)]\n", - " val = [{0: \"LH\", 2: \"LV\", -1: \"RC\", 1: \"LC\"}.get(v, v) for v in val]\n", - " if len(val) == 1:\n", - " val = val[0]\n", - "\n", - " data_info[-1].append(val)\n", - "\n", - " del data\n", - "\n", - " # Sort by time and set index\n", - " return (\n", - " pd.DataFrame(data_info, columns=column_names)\n", - " .sort_values(\"Time\")\n", - " .set_index(\"File Name\")\n", - " )\n", + " return darr\n", + "\n", + " def files_for_summary(self, data_dir):\n", + " return erlab.io.utils.get_files(data_dir, extensions=[\".h5\"])\n", "\n", "\n", "erlab.io.loaders" @@ -1255,11 +1302,7 @@ "source": [ ".. _summary example:\n", "\n", - "The implementation looks complicated, but most of the code is boilerplate, and the\n", - "actual logic is quite simple. You get a list of file names and paths to generate a\n", - "summary for, define DataFrame columns and corresponding attributes, and then load the\n", - "data one by one and extract the metadata. 
Let's see how the resulting summary looks\n", - "like.\n", + "Let's see how the resulting summary looks like.\n", "\n", ".. note::\n", "\n", @@ -1292,17 +1335,15 @@ }, "source": [ "Each cell in the summary table is formatted with :meth:`formatter\n", - "`. If additional formatting that cannot be\n", - "achieved within :meth:`generate_summary\n", - "` is needed, :meth:`formatter\n", - "` can be inherited in the subclass.\n", + "` after applying the :attr:`formatters\n", + "`.\n", "\n", "Tips\n", "~~~~\n", "\n", "- The data loading framework is designed to be simple and flexible, but it may not cover\n", " all possible setups. If you encounter a setup that cannot be loaded with the existing\n", - " loaders, please let us know by opening an issue!\n", + " api, please let us know by opening an issue!\n", "\n", "- Before implementing a loader, see :doc:`../generated/erlab.io.dataloader` for\n", " descriptions about each attribute, and the values and types of the expected outputs.\n", @@ -1310,11 +1351,6 @@ " starting point; see the `source code on github\n", " `_.\n", "\n", - "- If you have implemented a new loader or have improved an existing one, consider\n", - " contributing it to the ERLabPy project by opening a pull request. We are always\n", - " looking for new loaders to support more experimental setups! See more about\n", - " contributing in the :doc:`../contributing`.\n", - "\n", "- If you wish to add post-processing steps that are applicable to all data loaded by\n", " that loader such as fixing the sign of the binding energy coordinates, you can inherit\n", " the :meth:`post_process ` which by\n", @@ -1340,7 +1376,12 @@ " ` from an identifier given as a file name. This\n", " is where the second return value of :meth:`infer_index\n", " ` comes in handy, where you can return a\n", - " dictionary which is passed to :meth:`load `." + " dictionary which is passed to :meth:`load `.\n", + "\n", + "- If you have implemented a new loader or have improved an existing one, consider\n", + " contributing it to the ERLabPy project by opening a pull request. We are always\n", + " looking for new loaders to support more experimental setups! See more about\n", + " contributing in the :doc:`../contributing`.\n" ] }, { @@ -1375,7 +1416,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.6" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/docs/source/user-guide/kconv.ipynb b/docs/source/user-guide/kconv.ipynb index 8808480f..713eaae9 100644 --- a/docs/source/user-guide/kconv.ipynb +++ b/docs/source/user-guide/kconv.ipynb @@ -53,14 +53,48 @@ "source": [ "Momentum conversion in ERLabPy follows the nomenclature from :cite:t:`ishida2018kconv`.\n", "All experimental geometry are classified into 4 types. Definition of angles differ for\n", - "each geometry.\n", - "\n", + "each geometry.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "For instance, imagine a typical Type 1 setup with a vertical slit that acquires maps by\n", "rotating about the `z` axis in the lab frame. 
In this case, the polar angle (rotation\n", - "about `z`) is :math:`β`, and the tilt angle is :math:`ξ`.\n", + "about `z`) is $\\beta$, and the tilt angle is $\\xi$.\n", + "\n", + "In all cases, $\\delta$ is the azimuthal angle that indicates in-plane rotation, and\n", + "$\\alpha$ is the angle detected by the analyzer.\n", "\n", - "In all cases, :math:`δ` is the azimuthal angle that indicates in-plane rotation, and\n", - ":math:`α` is the angle along the slit." + "The following table summarizes angle conventions for commonly encountered configurations." + ] + }, + { + "cell_type": "raw", + "metadata": { + "editable": true, + "raw_mimetype": "text/restructuredtext", + "slideshow": { + "slide_type": "" + }, + "tags": [], + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "+---------------------------+---------------+-------------------+-------+------+-----------+---------+----------+\n", + "| Analyzer slit orientation | Mapping angle | Configuration | Polar | Tilt | Deflector | Azimuth | Analyzer |\n", + "+===========================+===============+===================+=======+======+===========+=========+==========+\n", + "| Vertical | Polar | 1 (Type 1) | beta | xi | | delta | alpha |\n", + "+---------------------------+---------------+-------------------+-------+------+-----------+ | |\n", + "| Horizontal | Tilt | 2 (Type 2) | xi | beta | | | |\n", + "+---------------------------+---------------+-------------------+-------+------+-----------+ | |\n", + "| Vertical | Deflector | 3 (Type 1 + DA) | chi | xi | beta | | |\n", + "+---------------------------+ +-------------------+ | | | | |\n", + "| Horizontal | | 4 (Type 2 + DA) | | | | | |\n", + "+---------------------------+---------------+-------------------+-------+------+-----------+---------+----------+" ] }, { @@ -394,7 +428,7 @@ " :class: only-dark\n", "\n", "The second tab provides visualization options. You can overlay Brillouin zones and high\n", - "symmetry points on the result, adjust colors, and apply binning." + "symmetry points on the result, adjust colors, apply binning, and more." 
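As a rough sketch of how the table above enters actual conversion code (assuming ``data`` is an angle-space ``DataArray`` whose loader did not already record the geometry), the configuration number is stored in the ``'configuration'`` attribute before converting:

```python
# Configuration 1: vertical analyzer slit, mapping by polar rotation
# (see erlab.analysis.kspace.AxesConfiguration for the enumeration)
data.attrs["configuration"] = 1

# Interpolate the angle-space data onto a regular momentum grid
kdata = data.kspace.convert()
```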
] } ], diff --git a/environment.yml b/environment.yml index ccdf7a28..239e72e9 100644 --- a/environment.yml +++ b/environment.yml @@ -21,4 +21,4 @@ dependencies: - scipy>=1.13.0 - tqdm>=4.66.2 - varname>=0.13.0 - - xarray>=2024.07.0 + - xarray>=2024.10.0 diff --git a/pyproject.toml b/pyproject.toml index 33cbd5bc..fc1fee94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "scipy>=1.13.0", "tqdm>=4.66.2", "varname>=0.13.0", - "xarray>=2024.07.0", + "xarray>=2024.10.0", ] [project.optional-dependencies] @@ -52,11 +52,10 @@ dev = [ "pre-commit>=3.7.0", "pytest-cov>=5.0.0", "pytest-qt>=4.4.0", - "pytest-xdist>=3.6.0", "pytest-datadir>=1.5.0", "pytest>=8.3.0", "commitizen>=3.29.0", - "cz-changeup>=0.3.0", + "cz-changeup>=1.0.1", "ruff>=0.6.3", ] docs = [ @@ -67,6 +66,7 @@ docs = [ "sphinx-qt-documentation", "pybtex", "nbsphinx", + "myst-parser", "furo>=2024.07.18", "sphinx-design", ] @@ -84,6 +84,8 @@ Changelog = "https://github.com/kmnhan/erlabpy/blob/main/CHANGELOG.md" erlab-igor = "erlab.io.igor:IgorBackendEntrypoint" [tool.commitizen] +change_type_map = { "BREAKING CHANGE" = "Breaking Changes", "feat" = "Features", "fix" = "Bug Fixes", "refactor" = "Code Refactor", "perf" = "Performance" } + version_provider = "pep621" update_changelog_on_bump = true tag_format = "v$version" @@ -91,14 +93,12 @@ changelog_start_rev = "v1.2.1" changelog_merge_prerelease = true name = 'cz_changeup' change_type_order = [ - "BREAKING CHANGE", + "Breaking Changes", "Features", "Bug Fixes", "Code Refactor", "Performance", ] -change_type_map = { "feat" = "Features", "fix" = "Bug Fixes", "refactor" = "Code Refactor", "perf" = "Performance" } - # cz-changeup configuration changeup_repo_base_url = "https://github.com/kmnhan/erlabpy" changeup_show_body = true diff --git a/src/erlab/accessors/__init__.py b/src/erlab/accessors/__init__.py index e16c1d24..38bba492 100644 --- a/src/erlab/accessors/__init__.py +++ b/src/erlab/accessors/__init__.py @@ -3,22 +3,8 @@ `_ for convenient data analysis and visualization. -.. currentmodule:: erlab.accessors - -Modules -======= - -.. autosummary:: - :toctree: generated - - utils - general - kspace - fit - -ERLabPy provides a collection of `xarray accessors -`_ for convenient -data analysis and visualization. The following table lists the available accessors. +ERLabPy provides a collection of accessors for convenient data analysis and +visualization. The following table lists all accessors provided by ERLabPy: .. list-table:: :header-rows: 1 @@ -37,4 +23,17 @@ * - :class:`da.kspace ` - Momentum conversion +.. currentmodule:: erlab.accessors + +Modules +======= + +.. 
autosummary:: + :toctree: generated + + utils + general + kspace + fit + """ # noqa: D205 diff --git a/src/erlab/accessors/fit.py b/src/erlab/accessors/fit.py index 617a6783..f401188e 100644 --- a/src/erlab/accessors/fit.py +++ b/src/erlab/accessors/fit.py @@ -11,7 +11,6 @@ import contextlib import copy import itertools -import warnings from collections.abc import Collection, Hashable, Iterable, Mapping, Sequence from typing import Any, Literal, cast @@ -26,6 +25,7 @@ ERLabDataArrayAccessor, ERLabDatasetAccessor, ) +from erlab.utils.misc import emit_user_level_warning from erlab.utils.parallel import joblib_progress @@ -57,10 +57,10 @@ def _concat_along_keys(d: dict[str, xr.DataArray], dim_name: str) -> xr.DataArra def _parse_params( d: dict[str, Any] | lmfit.Parameters, dask: bool -) -> xr.DataArray | _ParametersWraper: +) -> xr.DataArray | _ParametersWrapper: if isinstance(d, lmfit.Parameters): # Input to apply_ufunc cannot be a Mapping, so wrap in a class - return _ParametersWraper(d) + return _ParametersWrapper(d) # Iterate over all values for v in _nested_dict_vals(d): @@ -69,7 +69,7 @@ def _parse_params( # convert to str return _parse_multiple_params(copy.deepcopy(d), dask) - return _ParametersWraper(lmfit.create_params(**d)) + return _ParametersWrapper(lmfit.create_params(**d)) def _parse_multiple_params(d: dict[str, Any], as_str: bool) -> xr.DataArray: @@ -110,7 +110,7 @@ def _reduce_to_param(arr, axis=0): return da.reduce(_reduce_to_param, ("__dict_keys", "__param_names")) -class _ParametersWraper: +class _ParametersWrapper: def __init__(self, params: lmfit.Parameters) -> None: self.params = params @@ -129,7 +129,7 @@ def __call__( | dict[str, float | dict[str, Any]] | xr.DataArray | xr.Dataset - | _ParametersWraper + | _ParametersWrapper | None = None, guess: bool = False, errors: Literal["raise", "ignore"] = "raise", @@ -337,7 +337,7 @@ def _wrapper(Y, *args, **kwargs): initial_params = lmfit.create_params() if guess else model.make_params() - if isinstance(init_params_, _ParametersWraper): + if isinstance(init_params_, _ParametersWrapper): initial_params.update(init_params_.params) elif isinstance(init_params_, str): @@ -404,10 +404,9 @@ def _wrapper(Y, *args, **kwargs): initial_params ) except NotImplementedError: - warnings.warn( + emit_user_level_warning( f"`guess` is not implemented for {model}, " - "using supplied initial parameters", - stacklevel=1, + "using supplied initial parameters" ) initial_params = model.make_params().update(initial_params) try: @@ -530,10 +529,9 @@ def _output_wrapper(name, da, out=None) -> dict: if parallel: if is_dask: - warnings.warn( + emit_user_level_warning( "The input Dataset is chunked. Parallel fitting will not offer any " - "performance benefits.", - stacklevel=1, + "performance benefits." 
) parallel_kw.setdefault("n_jobs", -1) diff --git a/src/erlab/accessors/general.py b/src/erlab/accessors/general.py index 78d85721..a4954491 100644 --- a/src/erlab/accessors/general.py +++ b/src/erlab/accessors/general.py @@ -7,9 +7,10 @@ "SelectionAccessor", ] +import functools import importlib -import warnings from collections.abc import Hashable, Mapping +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -20,6 +21,8 @@ ERLabDatasetAccessor, either_dict_or_kwargs, ) +from erlab.utils.formatting import format_html_table +from erlab.utils.misc import emit_user_level_warning @xr.register_dataarray_accessor("qplot") @@ -245,10 +248,9 @@ def get_comps(*s): part_params = hvplot.bind(get_slice_params, *sliders).interactive() if "modelfit_results" not in self._obj.data_vars: - warnings.warn( - "`model_results` not included in Dataset. " - "Components will not be plotted", - stacklevel=2, + emit_user_level_warning( + "`modelfit_results` not included in Dataset. " + "Components will not be plotted" ) plot_components = False @@ -391,6 +393,7 @@ def __call__( scalars: dict[Hashable, float] = {} slices: dict[Hashable, slice] = {} avg_dims: list[Hashable] = [] + lost_dims: list[Hashable] = [] for dim, width in bin_widths.items(): value = indexers[dim] @@ -408,6 +411,9 @@ def __call__( ) slices[dim] = slice(value - width / 2.0, value + width / 2.0) avg_dims.append(dim) + for k, v in self._obj.coords.items(): + if dim in v.dims: + lost_dims.append(k) unindexed_dims: list[Hashable] = [ k for k in slices | scalars if k not in self._obj.indexes @@ -423,9 +429,8 @@ def __call__( if len(scalars) >= 1: for k, v in scalars.items(): if v < out[k].min() or v > out[k].max(): - warnings.warn( - f"Selected value {v} for `{k}` is outside coordinate bounds", - stacklevel=2, + emit_user_level_warning( + f"Selected value {v} for `{k}` is outside coordinate bounds" ) out = out.sel({str(k): v for k, v in scalars.items()}, method="nearest") @@ -433,7 +438,7 @@ def __call__( out = out.sel(slices) lost_coords = { - k: out[k].mean() for k in avg_dims if k not in unindexed_dims + k: out[k].mean() for k in lost_dims if k not in unindexed_dims } out = out.mean(dim=avg_dims, keep_attrs=True) out = out.assign_coords(lost_coords) @@ -502,3 +507,59 @@ def around( if average: return masked.mean(sel_kw.keys()) return masked + + +@xr.register_dataarray_accessor("qinfo") +class InfoDataArrayAccessor(ERLabDataArrayAccessor): + """`xarray.Dataset.qinfo` accessor for displaying information about the data.""" + + def get_value(self, attr_or_coord_name: str) -> Any: + """Get the value of the specified attribute or coordinate. + + If the attribute or coordinate is not found, `None` is returned. + + Parameters + ---------- + attr_or_coord_name + The name of the attribute or coordinate. 
+ + """ + if attr_or_coord_name in self._obj.attrs: + return self._obj.attrs[attr_or_coord_name] + if attr_or_coord_name in self._obj.coords: + return self._obj.coords[attr_or_coord_name] + return None + + @functools.cached_property + def _summary_table(self) -> list[tuple[str, str, str]]: + import erlab.io + + if "data_loader_name" in self._obj.attrs: + loader = erlab.io.loaders[self._obj.attrs["data_loader_name"]] + else: + raise ValueError("Data loader information not found in data attributes") + + out: list[tuple[str, str, str]] = [] + + for key, true_key in loader.summary_attrs.items(): + val = loader.get_formatted_attr_or_coord(self._obj, true_key) + if callable(true_key): + true_key = "" + out.append((key, loader.value_to_string(val), true_key)) + + return out + + def _repr_html_(self) -> str: + return format_html_table( + [("Name", "Value", "Key"), *self._summary_table], + header_cols=1, + header_rows=1, + ) + + def __repr__(self) -> str: + return "\n".join( + [ + f"{key}: {val}" if not true_key else f"{key} ({true_key}): {val}" + for key, val, true_key in self._summary_table + ] + ) diff --git a/src/erlab/accessors/kspace.py b/src/erlab/accessors/kspace.py index e147d89a..e1806890 100644 --- a/src/erlab/accessors/kspace.py +++ b/src/erlab/accessors/kspace.py @@ -4,7 +4,6 @@ import functools import time -import warnings from collections.abc import Hashable, ItemsView, Iterable, Iterator, Mapping from typing import Literal, Self, cast @@ -16,6 +15,7 @@ from erlab.analysis.kspace import AxesConfiguration, get_kconv_func, kz_func from erlab.constants import rel_kconv, rel_kzconv from erlab.utils.formatting import format_html_table +from erlab.utils.misc import emit_user_level_warning def only_angles(method=None): @@ -212,9 +212,8 @@ def inner_potential(self) -> float: """ if "inner_potential" in self._obj.attrs: return float(self._obj.attrs["inner_potential"]) - warnings.warn( + emit_user_level_warning( "Inner potential not found in data attributes, assuming 10 eV", - stacklevel=1, ) return 10.0 @@ -243,9 +242,8 @@ def work_function(self) -> float: """ if "sample_workfunction" in self._obj.attrs: return float(self._obj.attrs["sample_workfunction"]) - warnings.warn( - "Work function not found in data attributes, assuming 4.5 eV", - stacklevel=1, + emit_user_level_warning( + "Work function not found in data attributes, assuming 4.5 eV" ) return 4.5 @@ -262,11 +260,23 @@ def angle_resolution(self) -> float: This property is used in `best_kp_resolution` upon estimating momentum step sizes through `estimate_resolution`. + + Note + ---- + This property provides a setter method that takes a float value and sets the + data attribute accordingly. + + Example + ------- + >>> data.kspace.angle_resolution = 0.05 + >>> data.kspace.angle_resolution + 0.05 + """ try: return float(self._obj.attrs["angle_resolution"]) except KeyError: - # warnings.warn( + # emit_user_level_warning( # "Angle resolution not found in data attributes, assuming 0.1 degrees" # ) return 0.1 @@ -277,7 +287,7 @@ def angle_resolution(self, value: float) -> None: @property def slit_axis(self) -> Literal["kx", "ky"]: - """Return the momentum axis parallel to the slit. + """Return the momentum axis parallel to the analyzer slit. Returns ------- @@ -292,7 +302,7 @@ def slit_axis(self) -> Literal["kx", "ky"]: @property def other_axis(self) -> Literal["kx", "ky"]: - """Return the momentum axis perpendicular to the slit. + """Return the momentum axis perpendicular to the analyzer slit. 
diff --git a/src/erlab/analysis/__init__.py b/src/erlab/analysis/__init__.py
index 5badf939..4219d95f 100644
--- a/src/erlab/analysis/__init__.py
+++ b/src/erlab/analysis/__init__.py
@@ -20,42 +20,4 @@
 
 """
 
-import warnings
-
 from erlab.analysis import fit, gold, image, interpolate, mask, transform  # noqa: F401
-
-
-def correct_with_edge(*args, **kwargs):
-    from erlab.analysis.gold import correct_with_edge
-
-    warnings.warn(
-        "importing as erlab.analysis.correct_with_edge is deprecated, "
-        "use erlab.analysis.gold.correct_with_edge instead",
-        DeprecationWarning,
-        stacklevel=2,
-    )
-    return correct_with_edge(*args, **kwargs)
-
-
-def quick_resolution(*args, **kwargs):
-    from erlab.analysis.gold import quick_resolution
-
-    warnings.warn(
-        "importing as erlab.analysis.quick_resolution is deprecated, "
-        "use erlab.analysis.gold.quick_resolution instead",
-        DeprecationWarning,
-        stacklevel=2,
-    )
-    return quick_resolution(*args, **kwargs)
-
-
-def slice_along_path(*args, **kwargs):
-    from erlab.analysis.interpolate import slice_along_path
-
-    warnings.warn(
-        "importing as erlab.analysis.slice_along_path is deprecated, "
-        "use erlab.analysis.interpolate.slice_along_path instead",
-        DeprecationWarning,
-        stacklevel=2,
-    )
-    return slice_along_path(*args, **kwargs)
diff --git a/src/erlab/analysis/fit/functions/general.py b/src/erlab/analysis/fit/functions/general.py
index d873ded5..25904321 100644
--- a/src/erlab/analysis/fit/functions/general.py
+++ b/src/erlab/analysis/fit/functions/general.py
@@ -110,7 +110,7 @@ def do_convolve(
     Parameters
     ----------
     x
-        A evenly spaced array specifing where to evaluate the convolution.
+        An evenly spaced array specifying where to evaluate the convolution.
     func
         Function to convolve.
     resolution
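Since the module-level shims are deleted above, the submodules are now the only import paths. For reference:

```python
# Replacements for the removed erlab.analysis aliases:
from erlab.analysis.gold import correct_with_edge, quick_resolution
from erlab.analysis.interpolate import slice_along_path
```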
diff --git a/src/erlab/analysis/fit/models.py b/src/erlab/analysis/fit/models.py
index 7b876d98..b1208424 100644
--- a/src/erlab/analysis/fit/models.py
+++ b/src/erlab/analysis/fit/models.py
@@ -10,7 +10,6 @@
     "StepEdgeModel",
 ]
 
-import contextlib
 from typing import Literal
 
 import lmfit
@@ -178,10 +177,11 @@ def guess(self, data, x, **kwargs):
             np.argmin(np.gradient(scipy.ndimage.gaussian_filter1d(data, 0.2 * len(x))))
         ]
 
-        temp = 30.0
+        temp = None
         if isinstance(data, xr.DataArray):
-            with contextlib.suppress(KeyError):
-                temp = float(data.attrs["temp_sample"])
+            temp = data.qinfo.get_value("sample_temp")
+        if temp is None:
+            temp = 30.0
 
         pars[f"{self.prefix}center"].set(
             value=efermi, min=np.asarray(x).min(), max=np.asarray(x).max()
@@ -190,7 +190,7 @@ def guess(self, data, x, **kwargs):
         pars[f"{self.prefix}back1"].set(value=back1)
         pars[f"{self.prefix}dos0"].set(value=dos0)
         pars[f"{self.prefix}dos1"].set(value=dos1)
-        pars[f"{self.prefix}temp"].set(value=temp)
+        pars[f"{self.prefix}temp"].set(value=float(temp))
         pars[f"{self.prefix}resolution"].set(value=0.02)
 
         return lmfit.models.update_param_vals(pars, self.prefix, **kwargs)
@@ -418,7 +418,9 @@ def guess(self, data, eV, alpha, **kwargs):
         pars[f"{self.prefix}lin_bkg"].set(value=dos1)
 
         if isinstance(data, xr.DataArray):
-            pars[f"{self.prefix}temp"].set(value=data.attrs["temp_sample"])
+            temp = data.qinfo.get_value("sample_temp")
+            if temp is not None:
+                pars[f"{self.prefix}temp"].set(value=float(temp))
 
         return lmfit.models.update_param_vals(pars, self.prefix, **kwargs)
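With the change above, `guess` resolves the sample temperature through `data.qinfo.get_value("sample_temp")` for xarray input and falls back to 30 K otherwise. A hedged sketch with synthetic data (the edge shape and 20 K value are illustrative; kB ≈ 8.617e-5 eV/K):

```python
import numpy as np

from erlab.analysis.fit.models import FermiEdgeModel

eV = np.linspace(-0.2, 0.2, 201)
kT = 8.617e-5 * 20.0  # 20 K expressed in eV
edc = 1.0 / (np.exp(eV / kT) + 1.0) + 0.05  # synthetic Fermi edge

model = FermiEdgeModel()
params = model.guess(edc, x=eV)  # plain ndarray: temperature falls back to 30 K
result = model.fit(edc, params, x=eV)
```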
diff --git a/src/erlab/analysis/gold.py b/src/erlab/analysis/gold.py
index e2b0071d..c417202c 100644
--- a/src/erlab/analysis/gold.py
+++ b/src/erlab/analysis/gold.py
@@ -221,8 +221,13 @@ def edge(
         model_cls: lmfit.Model = StepEdgeModel
     else:
         if temp is None:
-            temp = gold.attrs["temp_sample"]
-        params = lmfit.create_params(temp={"value": temp, "vary": vary_temp})
+            temp = gold.qinfo.get_value("sample_temp")
+            if temp is None:
+                raise ValueError(
+                    "Temperature not found in data attributes, please provide manually"
+                )
+
+        params = lmfit.create_params(temp={"value": float(temp), "vary": vary_temp})
         model_cls = FermiEdgeModel
 
     model = model_cls()
@@ -556,9 +561,10 @@ def quick_fit(
     bkg_slope: bool = True,
     **kwargs,
 ) -> xr.Dataset:
-    """Perform a quick Fermi edge fit on the given data.
+    """Perform a quick Fermi edge fit on an EDC.
 
-    The data is averaged over all dimensions except the energy prior to fitting.
+    If data with 2 or more dimensions is provided, the data is averaged over all
+    dimensions except the energy prior to fitting.
 
     Parameters
     ----------
@@ -601,9 +607,8 @@ def quick_fit(
     data_fit = data.sel(eV=slice(*eV_range)) if eV_range is not None else data
 
     if temp is None:
-        if "temp_sample" in data.attrs:
-            temp = float(data.attrs["temp_sample"])
-        else:
+        temp = data.qinfo.get_value("sample_temp")
+        if temp is None:
             raise ValueError(
                 "Temperature not found in data attributes, please provide manually"
             )
@@ -632,9 +637,7 @@ def quick_fit(
 
 
 def quick_resolution(
-    darr: xr.DataArray,
-    ax: matplotlib.axes.Axes | None = None,
-    **kwargs,
+    darr: xr.DataArray, ax: matplotlib.axes.Axes | None = None, **kwargs
 ) -> xr.Dataset:
     """Fit a Fermi edge to the given data and plot the results.
@@ -661,23 +664,50 @@ def quick_resolution(
     if ax is None:
         ax = plt.gca()
 
-    darr = darr.mean([d for d in darr.dims if d != "eV"])
     result = quick_fit(darr, **kwargs)
 
+    data = darr.mean([d for d in darr.dims if d != "eV"])
+
     ax.plot(
-        darr.eV, darr, ".", mec="0.6", alpha=1, mfc="none", ms=5, mew=0.3, label="Data"
+        data.eV, data, ".", mec="0.6", alpha=1, mfc="none", ms=5, mew=0.3, label="Data"
     )
     result.modelfit_best_fit.qplot(ax=ax, c="r", label="Fit")
 
     ax.set_ylabel("Intensity (arb. units)")
-    if (darr.eV[0] * darr.eV[-1]) < 0:
+    if (data.eV[0] * data.eV[-1]) < 0:
         ax.set_xlabel("$E - E_F$ (eV)")
     else:
         ax.set_xlabel(r"$E_{kin}$ (eV)")
 
     coeffs = result.modelfit_coefficients
-    center = result.modelfit_results.item().uvars["center"]
-    resolution = result.modelfit_results.item().uvars["resolution"]
+    modelresult: lmfit.model.ModelResult = result.modelfit_results.item()
+
+    if hasattr(modelresult, "uvars"):
+        center = modelresult.uvars["center"]
+        resolution = modelresult.uvars["resolution"]
+        center_bounds = ((center - resolution).n, (center + resolution).n)
+
+        center_repr = (
+            f"$E_F = {center * 1e3:L}$ meV"
+            if center < 0.1
+            else f"$E_F = {center:L}$ eV"
+        )
+        resolution_repr = f"$\\Delta E = {resolution * 1e3:L}$ meV"
+
+    else:
+        center = coeffs.sel(param="center")
+        resolution = coeffs.sel(param="resolution")
+        center_bounds = (center - resolution, center + resolution)
+
+        center_repr = (
+            f"$E_F = {center * 1e3:.3f}$ meV"
+            if center < 0.1
+            else f"$E_F = {center:.6f}$ eV"
+        )
+        resolution_repr = f"$\\Delta E = {resolution * 1e3:.3f}$ meV"
+
+    if kwargs.get("fix_center", False):
+        center_repr = ""
 
     fig = ax.figure
     if fig is not None:
@@ -685,13 +715,7 @@ def quick_resolution(
             0,
             0,
             "\n".join(
-                [
-                    f"$T ={coeffs.sel(param='temp'):.3f}$ K",
-                    f"$E_F = {center * 1e3:L}$ meV"
-                    if center < 0.1
-                    else f"$E_F = {center:L}$ eV",
-                    f"$\\Delta E = {resolution * 1e3:L}$ meV",
-                ]
+                [f"$T ={coeffs.sel(param='temp'):.3f}$ K", center_repr, resolution_repr]
             ),
             ha="left",
             va="baseline",
@@ -700,16 +724,10 @@ def quick_resolution(
                 6 / 72, 6 / 72, fig.dpi_scale_trans
             ),
         )
-    ax.set_xlim(darr.eV[[0, -1]])
+    ax.set_xlim(data.eV[[0, -1]])
     ax.set_title("")
     ax.axvline(coeffs.sel(param="center"), ls="--", c="k", lw=0.4, alpha=0.5)
-    ax.axvspan(
-        (center - resolution).n,
-        (center + resolution).n,
-        color="r",
-        alpha=0.2,
-        label="FWHM",
-    )
+    ax.axvspan(*center_bounds, color="r", alpha=0.2, label="FWHM")
 
     return result
@@ -746,7 +764,7 @@ def resolution(
     edc_avg = gold_roi.mean("alpha").sel(eV=slice(*eV_range_fit))
 
     params = lmfit.create_params(
-        temp={"value": gold_roi.attrs["temp_sample"], "vary": False},
+        temp={"value": gold_roi.attrs["sample_temp"], "vary": False},
         resolution={"value": 0.1, "vary": True, "min": 0},
     )
     model = FermiEdgeModel()
@@ -799,7 +817,7 @@ def resolution_roi(
     edc_avg = gold_roi.mean("alpha").sel(eV=slice(*eV_range))
 
     params = lmfit.create_params(
-        temp={"value": gold_roi.attrs["temp_sample"], "vary": not fix_temperature},
+        temp={"value": gold_roi.attrs["sample_temp"], "vary": not fix_temperature},
         resolution={"value": 0.1, "vary": True, "min": 0},
     )
     model = FermiEdgeModel()
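A hypothetical call into the helpers patched above. The synthetic gold edge is illustrative; the `temp` and `eV_range` parameter names are taken from the hunks.

```python
import numpy as np
import xarray as xr

from erlab.analysis.gold import quick_fit

eV = np.linspace(-0.2, 0.2, 201)
kT = 8.617e-5 * 15.0  # 15 K in eV
gold = xr.DataArray(1.0 / (np.exp(eV / kT) + 1.0), coords={"eV": eV})

# Without a "sample_temp" attribute, omitting temp now raises ValueError
# instead of silently reading the old "temp_sample" key:
result = quick_fit(gold, temp=15.0, eV_range=(-0.2, 0.2))
```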
diff --git a/src/erlab/analysis/interpolate.py b/src/erlab/analysis/interpolate.py
index bed63ca8..4c34d1eb 100644
--- a/src/erlab/analysis/interpolate.py
+++ b/src/erlab/analysis/interpolate.py
@@ -4,7 +4,6 @@
 
 import itertools
 import math
-import warnings
 from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
 from typing import cast
@@ -16,6 +15,7 @@
 import xarray.core.missing
 
 from erlab.accessors.utils import either_dict_or_kwargs
+from erlab.utils.misc import emit_user_level_warning
 
 
 class FastInterpolator(scipy.interpolate.RegularGridInterpolator):
@@ -38,10 +38,8 @@ class FastInterpolator(scipy.interpolate.RegularGridInterpolator):
 
     * `method` is ``"linear"``.
     * Coordinates along all dimensions are evenly spaced.
-    * Values are 1D, 2D or 3D.
+    * Points are 1D, 2D or 3D.
     * Extrapolation is disabled, i.e., `fill_value` is not `None`.
-    * The dimension of coordinates `xi` matches the number of dimensions of the values.
-      Also, each coordinate array in `xi` must have the same shape.
 
     See Also
     --------
@@ -103,44 +101,42 @@ def from_xarray(
         )
 
     def __call__(self, xi, method: str | None = None):
+        ndim: int = len(self.grid)
         is_linear: bool = method == "linear" or self.method == "linear"
-        nd_supported: bool = self.values.ndim in (1, 2, 3)
+        nd_supported: bool = ndim in (1, 2, 3)
         no_extrap: bool = self.fill_value is not None
 
         if (len(self.uneven_dims) == 0) and is_linear and nd_supported and no_extrap:
             if isinstance(xi, np.ndarray):
-                xi = tuple(xi.take(i, axis=-1) for i in range(xi.shape[-1]))
+                if xi.ndim == 1:
+                    xi_tuple = (xi,)
+                else:
+                    xi_tuple = tuple(xi.take(i, axis=-1) for i in range(xi.shape[-1]))
+            else:
+                xi_tuple = tuple(xi)
 
-            xi_shapes = [x.shape for x in xi]
+            xi_shapes = [x.shape for x in xi_tuple]
             if not all(s == xi_shapes[0] for s in xi_shapes):
-                warnings.warn(
+                emit_user_level_warning(
                     "Not all coordinate arrays have the same shape, "
                     "falling back to scipy.",
                     RuntimeWarning,
-                    stacklevel=1,
-                )
-            elif len(xi) != self.values.ndim:
-                warnings.warn(
-                    f"Number of input dimensions ({len(xi)}) does not match "
-                    "the input data dimensions, "
-                    "falling back to scipy.",
-                    RuntimeWarning,
-                    stacklevel=1,
                 )
-            else:
-                return _get_interp_func(self.values.ndim)(
+
+            elif ndim == len(xi_tuple):
+                interp_func = _get_interp_func(ndim)
+                return interp_func(
                     *self.grid,
                     self.values,
-                    *(c.ravel() for c in xi),
+                    *xi_tuple,
                     fill_value=self.fill_value,
-                ).reshape(xi[0].shape + self.values.shape[self.values.ndim :])
+                )
 
         if (len(self.uneven_dims) != 0) and is_linear:
-            warnings.warn(
+            emit_user_level_warning(
                 f"Dimension(s) {self.uneven_dims} are not uniform, "
                 "falling back to scipy.",
                 RuntimeWarning,
-                stacklevel=1,
             )
         return super().__call__(xi, method)
@@ -214,7 +210,7 @@ def _do_interp3(x, y, z, v0, v1, v2, v3, v4, v5, v6, v7):
 @numba.njit(nogil=True, inline="always")
 def _calc_interp1(values, v0):
     i0 = math.floor(v0)
-    n0 = values.size
+    n0 = values.shape[0]
     j0 = min(i0 + 1, n0 - 1)
     return _do_interp1(v0 - i0, values[i0], values[j0])
 
@@ -222,7 +218,7 @@ def _calc_interp1(values, v0):
 @numba.njit(nogil=True, inline="always")
 def _calc_interp2(values, v0, v1):
     i0, i1 = math.floor(v0), math.floor(v1)
-    n0, n1 = values.shape
+    n0, n1 = values.shape[:2]
     j0, j1 = min(i0 + 1, n0 - 1), min(i1 + 1, n1 - 1)
     return _do_interp2(
         v0 - i0,
@@ -237,7 +233,7 @@ def _calc_interp2(values, v0, v1):
 @numba.njit(nogil=True, inline="always")
 def _calc_interp3(values, v0, v1, v2):
     i0, i1, i2 = math.floor(v0), math.floor(v1), math.floor(v2)
-    n0, n1, n2 = values.shape
+    n0, n1, n2 = values.shape[:3]
     j0, j1, j2 = min(i0 + 1, n0 - 1), min(i1 + 1, n1 - 1), min(i2 + 1, n2 - 1)
     return _do_interp3(
         v0 - i0,
@@ -263,50 +259,63 @@ def _val2ind(val, coord):
 
 @numba.njit(nogil=True, parallel=True)
 def _interp1(x, values, xc, fill_value=np.nan):
-    n = len(xc)
+    out_shape = xc.shape + values.shape[1:]
+    xc_flat = xc.ravel()
+    n = len(xc_flat)
 
-    arr_new = np.empty(n, values.dtype)
+    arr_new = np.empty((n,) + values.shape[1:], values.dtype)
     for m in numba.prange(n):
-        v0 = _val2ind(xc[m], x)
+        v0 = _val2ind(xc_flat[m], x)
 
         if np.isnan(v0):
             arr_new[m] = fill_value
         else:
             arr_new[m] = _calc_interp1(values, v0)
-    return arr_new
+
+    return arr_new.reshape(out_shape)
 
 
 @numba.njit(nogil=True, parallel=True)
 def _interp2(x, y, values, xc, yc, fill_value=np.nan):
-    n = len(xc)
+    out_shape = xc.shape + values.shape[2:]
+    xc_flat, yc_flat = xc.ravel(), yc.ravel()
+    n = len(xc_flat)
 
-    arr_new = np.empty(n, values.dtype)
+    arr_new = np.empty((n,) + values.shape[2:], values.dtype)
     for m in numba.prange(n):
-        v0, v1 = _val2ind(xc[m], x), _val2ind(yc[m], y)
 
+        v0, v1 = _val2ind(xc_flat[m], x), _val2ind(yc_flat[m], y)
         if np.isnan(v0) or np.isnan(v1):
             arr_new[m] = fill_value
         else:
             arr_new[m] = _calc_interp2(values, v0, v1)
-    return arr_new
+
+    return arr_new.reshape(out_shape)
 
 
 @numba.njit(nogil=True, parallel=True)
 def _interp3(x, y, z, values, xc, yc, zc, fill_value=np.nan):
-    n = len(xc)
+    out_shape = xc.shape + values.shape[3:]
+    xc_flat, yc_flat, zc_flat = xc.ravel(), yc.ravel(), zc.ravel()
+    n = len(xc_flat)
 
-    arr_new = np.empty(n, values.dtype)
+    arr_new = np.empty((n,) + values.shape[3:], values.dtype)
     for m in numba.prange(n):
-        v0, v1, v2 = _val2ind(xc[m], x), _val2ind(yc[m], y), _val2ind(zc[m], z)
+        v0, v1, v2 = (
+            _val2ind(xc_flat[m], x),
+            _val2ind(yc_flat[m], y),
+            _val2ind(zc_flat[m], z),
+        )
 
         if np.isnan(v0) or np.isnan(v1) or np.isnan(v2):
             arr_new[m] = fill_value
         else:
             arr_new[m] = _calc_interp3(values, v0, v1, v2)
-    return arr_new
+
+    return arr_new.reshape(out_shape)
 
 
 def _get_interp_func(ndim: int) -> Callable:
@@ -451,6 +460,8 @@ def slice_along_path(
         xr.DataArray(p, dims=dim_name, coords={dim_name: path_coord})
         for p in points_arr
     ]
+    interp_kwargs.setdefault("method", "linearfast")
+
     return darr.interp(
         dict(zip(vertices.keys(), interp_coords, strict=False)), **interp_kwargs
     )
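A sketch of a path slice using the function touched above. The array is illustrative; after this change, `"linearfast"` is the default interpolation method unless the caller overrides it through the interpolation keyword arguments.

```python
import numpy as np
import xarray as xr

from erlab.analysis.interpolate import slice_along_path

k = np.linspace(-1.0, 1.0, 51)
darr = xr.DataArray(
    np.random.default_rng(0).random((51, 51)), coords={"kx": k, "ky": k}
)

# Cut from (0, 0) to (0.5, 0.5); uses the new "linearfast" default method.
cut = slice_along_path(darr, vertices={"kx": [0.0, 0.5], "ky": [0.0, 0.5]})
```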
diff --git a/src/erlab/analysis/utils.py b/src/erlab/analysis/utils.py
deleted file mode 100644
index 88355093..00000000
--- a/src/erlab/analysis/utils.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import warnings
-
-import xarray as xr
-
-
-def shift(*args, **kwargs) -> xr.DataArray:
-    from erlab.analysis.transform import shift as shift_func
-
-    warnings.warn(
-        "erlab.analysis.utils.shift is deprecated, "
-        "use erlab.analysis.gold.correct_with_edge instead",
-        DeprecationWarning,
-        stacklevel=2,
-    )
-    return shift_func(*args, **kwargs)
-
-
-def correct_with_edge(*args, **kwargs):
-    from erlab.analysis.gold import correct_with_edge
-
-    warnings.warn(
-        "erlab.analysis.utils.correct_with_edge is deprecated, "
-        "use erlab.analysis.gold.correct_with_edge instead",
-        DeprecationWarning,
-        stacklevel=2,
-    )
-    return correct_with_edge(*args, **kwargs)
diff --git a/src/erlab/characterization/__init__.py b/src/erlab/characterization/__init__.py
deleted file mode 100644
index 238c186f..00000000
--- a/src/erlab/characterization/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import warnings
-
-from erlab.io.characterization import resistance, xrd  # noqa: F401
-
-warnings.warn(
-    "`erlab.characterization` is deprecated. Use `erlab.io.characterization` instead",
-    DeprecationWarning,
-    stacklevel=2,
-)
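The two shim modules deleted above already pointed at their replacements; the remaining import paths are:

```python
# Replacements for the removed shim modules:
from erlab.analysis.transform import shift
from erlab.analysis.gold import correct_with_edge
from erlab.io.characterization import resistance, xrd
```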
diff --git a/src/erlab/interactive/colors.py b/src/erlab/interactive/colors.py
index 5135fbcd..3b11780a 100644
--- a/src/erlab/interactive/colors.py
+++ b/src/erlab/interactive/colors.py
@@ -323,25 +323,28 @@ def __init__(self, colorbar: BetterColorBarItem) -> None:
         self.cb = colorbar
 
         layout = QtWidgets.QFormLayout(self)
-        layout.setContentsMargins(0, 0, 0, 0)
-        layout.setVerticalSpacing(0)
         self.setLayout(layout)
+        self.setFocusPolicy(QtCore.Qt.FocusPolicy.StrongFocus)
 
         self.max_spin = pg.SpinBox(dec=True, compactHeight=False, finite=False)
         self.min_spin = pg.SpinBox(dec=True, compactHeight=False, finite=False)
+        self.zero_btn = QtWidgets.QPushButton("Center Zero")
         self.rst_btn = QtWidgets.QPushButton("Reset")
 
         self.max_spin.setObjectName("_vmax_spin")
         self.min_spin.setObjectName("_vmin_spin")
+        self.zero_btn.setObjectName("_vlim_zero_btn")
         self.rst_btn.setObjectName("_vlim_reset_btn")
 
         layout.addRow("Max", self.max_spin)
         layout.addRow("Min", self.min_spin)
+        layout.addRow(self.zero_btn)
         layout.addRow(self.rst_btn)
 
         self.cb._span.sigRegionChanged.connect(self.region_changed)
         self.min_spin.sigValueChanged.connect(self.update_region)
         self.max_spin.sigValueChanged.connect(self.update_region)
+        self.zero_btn.clicked.connect(self.center_zero)
         self.rst_btn.clicked.connect(self.reset)
 
     def _set_spin_values(self, mn: float, mx: float) -> None:
@@ -357,6 +360,20 @@ def region_changed(self):
         mn, mx = self.cb._span.getRegion()
         self._set_spin_values(mn, mx)
 
+    @QtCore.Slot()
+    def center_zero(self):
+        old_min, old_max = self.cb._span.getRegion()
+        self.reset()
+
+        mn, mx = self.cb._span.getRegion()
+        if mn < 0 < mx:
+            half_len = min(abs(mn), abs(mx))
+            self._set_spin_values(-half_len, half_len)
+        else:
+            self._set_spin_values(old_min, old_max)
+
+        self.update_region()
+
     @QtCore.Slot()
     def reset(self):
         self._set_spin_values(-np.inf, np.inf)
diff --git a/src/erlab/interactive/curvefittingtool.py b/src/erlab/interactive/curvefittingtool.py
index f2ec2183..948c389f 100644
--- a/src/erlab/interactive/curvefittingtool.py
+++ b/src/erlab/interactive/curvefittingtool.py
@@ -4,7 +4,6 @@
 
 import lmfit
 import pyqtgraph as pg
-import xarray as xr
 from qtpy import QtCore, QtWidgets
 
 from erlab.analysis.fit.models import MultiPeakModel
@@ -682,31 +681,3 @@ def __post_init__(self, execute=None):
             pass
         if execute:
             self.qapp.exec()
-
-
-if __name__ == "__main__":
-    data = xr.open_dataarray(
-        "~/Library/CloudStorage/Dropbox-KAIST_12/Kimoon Han/ERLab/Projects/TiSe2 Chiral"
-        "/Experiment/220922 ALS BL4/TS2_testedc_2209ALS.nc"
-    )
-    edctool(
-        data,
-        3,
-        parameters={
-            "p0_center": -0.11933563574455268,
-            "p0_width": 0.1678003093053242,
-            "p0_height": 2041.8150354041834,
-            "p1_center": -0.22027597972817484,
-            "p1_width": 0.31653868423339476,
-            "p1_height": 2750.4023331130743,
-            "p2_center": -0.017453580016272105,
-            "p2_width": 0.0031789748073378895,
-            "p2_height": 7749.551605384519,
-            "lin_bkg": 86.67942474925626,
-            "const_bkg": 68.66930909421156,
-            "efermi": -0.0,
-            "temp": 30,
-            "offset": 0.7379771332091531,
-            "resolution": 0.04826209466467372,
-        },
-    )
diff --git a/src/erlab/interactive/fermiedge.py b/src/erlab/interactive/fermiedge.py
index fc5c0a83..0070e87d 100644
--- a/src/erlab/interactive/fermiedge.py
+++ b/src/erlab/interactive/fermiedge.py
@@ -187,10 +187,10 @@ def __init__(
         self.axes[2].setVisible(False)
         self.hists[2].setVisible(False)
 
-        try:
-            temp = float(self.data.attrs["temp_sample"])
-        except KeyError:
+        temp = self.data.qinfo.get_value("sample_temp")
+        if temp is None:
             temp = 30.0
+        temp = float(temp)
 
         self.params_roi = ROIControls(self.aw.add_roi(0))
         self.params_edge = ParameterGroup(
diff --git a/src/erlab/interactive/imagetool/__init__.py b/src/erlab/interactive/imagetool/__init__.py
index 6a11dac4..00f5db6c 100644
--- a/src/erlab/interactive/imagetool/__init__.py
+++ b/src/erlab/interactive/imagetool/__init__.py
@@ -62,6 +62,15 @@ def _parse_input(
     return [xr.DataArray(d) if not isinstance(d, xr.DataArray) else d for d in data]
 
 
+def _convert_to_native(obj: list[Any]) -> list[Any]:
+    """Convert a nested list of numpy objects to native types."""
+    if isinstance(obj, np.generic):
+        return obj.item()
+    if isinstance(obj, list):
+        return [_convert_to_native(item) for item in obj]
+    return obj
+
+
 def itool(
     data: Collection[xr.DataArray | npt.NDArray]
     | xr.DataArray
@@ -693,10 +702,14 @@ def _set_colormap_options(self) -> None:
         )
 
     def _copy_cursor_val(self) -> None:
-        copy_to_clipboard(str(self.slicer_area.array_slicer._values))
+        copy_to_clipboard(
+            str(_convert_to_native(self.slicer_area.array_slicer._values))
+        )
 
     def _copy_cursor_idx(self) -> None:
-        copy_to_clipboard(str(self.slicer_area.array_slicer._indices))
+        copy_to_clipboard(
+            str(_convert_to_native(self.slicer_area.array_slicer._indices))
+        )
 
     @QtCore.Slot()
     def _open_file(
@@ -709,6 +722,11 @@ def _open_file(
         valid_loaders: dict[str, tuple[Callable, dict]] = {
             "xarray HDF5 Files (*.h5)": (xr.load_dataarray, {"engine": "h5netcdf"}),
             "NetCDF Files (*.nc *.nc4 *.cdf)": (xr.load_dataarray, {}),
+            "Igor Binary Waves (*.ibw)": (xr.load_dataarray, {"engine": "erlab-igor"}),
+            "Igor Packed Experiment Templates (*.pxt)": (
+                xr.load_dataarray,
+                {"engine": "erlab-igor"},
+            ),
         }
         try:
             import erlab.io
@@ -785,7 +803,10 @@ def _to_hdf5(darr: xr.DataArray, file: str, **kwargs) -> None:
             _to_netcdf(_add_igor_scaling(darr), file, **kwargs)
 
         valid_savers: dict[str, tuple[Callable, dict[str, Any]]] = {
-            "xarray HDF5 Files (*.h5)": (_to_hdf5, {"engine": "h5netcdf"}),
+            "xarray HDF5 Files (*.h5)": (
+                _to_hdf5,
+                {"engine": "h5netcdf", "invalid_netcdf": True},
+            ),
             "NetCDF Files (*.nc *.nc4 *.cdf)": (_to_netcdf, {}),
         }
 
diff --git a/src/erlab/interactive/imagetool/_deprecated/imagetool_old.py b/src/erlab/interactive/imagetool/_deprecated/imagetool_old.py
index 849a961e..2794f452 100644
--- a/src/erlab/interactive/imagetool/_deprecated/imagetool_old.py
+++ b/src/erlab/interactive/imagetool/_deprecated/imagetool_old.py
@@ -816,7 +816,7 @@ class pg_itool(pg.GraphicsLayoutWidget):
         The data to explore. Must have three coordinate axes.
     snap : bool, default: True
-        Wheter to snap the cursor to data pixels.
+        Whether to snap the cursor to data pixels.
     gamma : float, default: 0.5
         Colormap default gamma.
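For context, the helper added to `imagetool/__init__.py` above unwraps nested numpy scalars before the cursor values are copied, so the clipboard string stays readable on NumPy 2 (where `repr(np.float64(1.0))` is `np.float64(1.0)`). A minimal demonstration:

```python
import numpy as np

from erlab.interactive.imagetool import _convert_to_native

# Nested numpy scalars become plain Python numbers, so str() of the result
# reads "[1.0, [2, 3]]" rather than "[np.float64(1.0), ...]".
print(_convert_to_native([np.float64(1.0), [np.int64(2), 3]]))
```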
diff --git a/src/erlab/interactive/imagetool/core.py b/src/erlab/interactive/imagetool/core.py
index 3c4c1382..1db730b0 100644
--- a/src/erlab/interactive/imagetool/core.py
+++ b/src/erlab/interactive/imagetool/core.py
@@ -1416,6 +1416,11 @@ def hoverEvent(self, ev) -> None:
         else:
             self.setMouseHover(False)
 
+    def _computeBoundingRect(self):
+        """CursorLine debugging."""
+        _ = self.getViewBox().size()
+        return super()._computeBoundingRect()
+
 
 class ItoolCursorSpan(pg.LinearRegionItem):
     def __init__(self, *args, **kargs) -> None:
diff --git a/src/erlab/interactive/imagetool/fastbinning.py b/src/erlab/interactive/imagetool/fastbinning.py
index bccccc89..d42ae582 100644
--- a/src/erlab/interactive/imagetool/fastbinning.py
+++ b/src/erlab/interactive/imagetool/fastbinning.py
@@ -341,7 +341,7 @@ def fast_nanmean(
         A numpy array of floats.
     axis
         Axis or iterable of axis along which the means are computed. If `None`, the mean
-        of the flattend array is computed.
+        of the flattened array is computed.
 
     Returns
     -------
@@ -354,7 +354,7 @@ def fast_nanmean(
       ``len(axis) < N``.
 
     - For calculating the average of a flattened array (``axis = None`` or ``len(axis)
-      == N``), the `numba` implemenation of `numpy.nanmean` is used.
+      == N``), the `numba` implementation of `numpy.nanmean` is used.
 
     - For bigger ``N``, ``numbagg.nanmean`` is used if `numbagg
      <https://github.com/numbagg/numbagg>`_ is installed. Otherwise, the calculation
diff --git a/src/erlab/interactive/imagetool/manager.py b/src/erlab/interactive/imagetool/manager.py
index 25d19786..ff76f7f7 100644
--- a/src/erlab/interactive/imagetool/manager.py
+++ b/src/erlab/interactive/imagetool/manager.py
@@ -426,7 +426,7 @@ def __init__(self: _ImageToolManagerGUI) -> None:
         self.options_layout.addStretch()
 
         # Temporary directory for storing archived data
-        self._tmp_dir = tempfile.TemporaryDirectory()
+        self._tmp_dir = tempfile.TemporaryDirectory(prefix="erlab_archive_")
 
         # Store most recent name filter and directory for new windows
         self._recent_name_filter: str | None = None
@@ -807,7 +807,7 @@ def show_in_manager(
     darr_list: list[xarray.DataArray] = _parse_input(data)
 
     # Save the data to a temporary file
-    tmp_dir = tempfile.mkdtemp()
+    tmp_dir = tempfile.mkdtemp(prefix="erlab_manager_")
 
     files: list[str] = []
 
diff --git a/src/erlab/interactive/imagetool/slicer.py b/src/erlab/interactive/imagetool/slicer.py
index ce7c4942..89460ab0 100644
--- a/src/erlab/interactive/imagetool/slicer.py
+++ b/src/erlab/interactive/imagetool/slicer.py
@@ -683,6 +683,7 @@ def qsel_args(self, cursor: int, disp: Sequence[int]) -> dict:
         for dim, selector in self.isel_args(cursor, disp, int_if_one=True).items():
             inc = self.incs[self._obj.dims.index(dim)]
 
+            # Estimate minimum number of decimal places required to represent selection
             order = int(-np.floor(np.log10(inc)) + 1)
 
             if binned[self._obj.dims.index(dim)]:
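A hypothetical call into the binning helper whose docstring is fixed above; the input array is illustrative. `axis` accepts a single axis, an iterable of axes, or `None` for the mean of the flattened array.

```python
import numpy as np

from erlab.interactive.imagetool.fastbinning import fast_nanmean

arr = np.random.default_rng(0).random((4, 5, 6))

print(fast_nanmean(arr, axis=None))  # scalar: mean of the flattened array
print(fast_nanmean(arr, axis=(0, 2)).shape)  # (5,)
```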
diff --git a/src/erlab/interactive/kspace.py b/src/erlab/interactive/kspace.py
index 62ca22a9..c4d5f027 100644
--- a/src/erlab/interactive/kspace.py
+++ b/src/erlab/interactive/kspace.py
@@ -1,19 +1,21 @@
 """Interactive momentum conversion tool."""
 
+from __future__ import annotations
+
 __all__ = ["ktool"]
 
 import os
 import sys
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, cast
 
 import numpy as np
 import numpy.typing as npt
 import pyqtgraph as pg
 import varname
-import xarray as xr
-from qtpy import QtGui, QtWidgets, uic
+from qtpy import QtCore, QtGui, QtWidgets, uic
 
 import erlab.analysis
+import erlab.lattice
 from erlab.interactive.colors import (
     BetterColorBarItem,  # noqa: F401
     ColorMapComboBox,  # noqa: F401
@@ -23,11 +25,115 @@
 from erlab.interactive.utils import copy_to_clipboard, generate_code, xImageItem
 from erlab.plotting.bz import get_bz_edge
 
+if TYPE_CHECKING:
+    import xarray as xr
+
+
+class _CircleROIControlWidget(QtWidgets.QWidget):
+    def __init__(self, roi: _MovableCircleROI) -> None:
+        super().__init__()
+        self.setGeometry(QtCore.QRect(0, 640, 242, 182))
+        self._roi = roi
+
+        layout = QtWidgets.QFormLayout(self)
+        self.setLayout(layout)
+
+        self.x_spin = pg.SpinBox(dec=True, compactHeight=False)
+        self.y_spin = pg.SpinBox(dec=True, compactHeight=False)
+        self.r_spin = pg.SpinBox(dec=True, compactHeight=False)
+        self.x_spin.sigValueChanged.connect(self.update_roi)
+        self.y_spin.sigValueChanged.connect(self.update_roi)
+        self.r_spin.sigValueChanged.connect(self.update_roi)
+
+        layout.addRow("X", self.x_spin)
+        layout.addRow("Y", self.y_spin)
+        layout.addRow("Radius", self.r_spin)
+
+        self._roi.sigRegionChanged.connect(self.update_spins)
+
+    @QtCore.Slot()
+    def update_roi(self) -> None:
+        self._roi.blockSignals(True)
+        self._roi.set_position(
+            (self.x_spin.value(), self.y_spin.value()), self.r_spin.value()
+        )
+        self._roi.blockSignals(False)
+
+    @QtCore.Slot()
+    def update_spins(self) -> None:
+        x, y, r = self._roi.get_position()
+        self.x_spin.blockSignals(True)
+        self.y_spin.blockSignals(True)
+        self.r_spin.blockSignals(True)
+        self.x_spin.setValue(x)
+        self.y_spin.setValue(y)
+        self.r_spin.setValue(r)
+        self.x_spin.blockSignals(False)
+        self.y_spin.blockSignals(False)
+        self.r_spin.blockSignals(False)
+
+    def setVisible(self, visible: bool) -> None:
+        super().setVisible(visible)
+        if visible:
+            self.update_spins()
+
+
+class _MovableCircleROI(pg.CircleROI):
+    """Circle ROI with a menu to control position and radius."""
+
+    def __init__(self, pos, size=None, radius=None, **args):
+        args.setdefault("removable", True)
+        super().__init__(pos, size, radius, **args)
+
+    def getMenu(self):
+        if self.menu is None:
+            self.menu = QtWidgets.QMenu()
+            self.menu.setTitle("ROI")
+            if self.removable:
+                remAct = QtGui.QAction("Remove Circle", self.menu)
+                remAct.triggered.connect(self.removeClicked)
+                self.menu.addAction(remAct)
+                self.menu.remAct = remAct
+            self._pos_menu = self.menu.addMenu("Edit Circle")
+            ctrlAct = QtWidgets.QWidgetAction(self._pos_menu)
+            ctrlAct.setDefaultWidget(_CircleROIControlWidget(self))
+            self._pos_menu.addAction(ctrlAct)
+
+        return self.menu
+
+    def radius(self) -> float:
+        """Radius of the circle."""
+        return float(self.size()[0] / 2)
+
+    def center(self) -> tuple[float, float]:
+        """Center of the circle."""
+        x, y = self.pos()
+        r = self.radius()
+        return x + r, y + r
+
+    def get_position(self) -> tuple[float, float, float]:
+        """Return the center and radius of the circle."""
+        return (*self.center(), self.radius())
+
+    def set_position(self, center, radius: float | None = None) -> None:
+        """Set the center and radius of the circle."""
+        if radius is None:
+            radius = self.radius()
+        else:
+            diameter = 2 * radius
+            self.setSize((diameter, diameter), update=False)
+        self.setPos(center[0] - radius, center[1] - radius)
+
 
 class KspaceToolGUI(
     *uic.loadUiType(os.path.join(os.path.dirname(__file__), "ktool.ui"))  # type: ignore[misc]
 ):
-    def __init__(self) -> None:
+    def __init__(
+        self,
+        avec: npt.NDArray | None = None,
+        cmap: str | None = None,
+        gamma: float = 0.5,
+    ) -> None:
         # Initialize UI
         super().__init__()
         self.setupUi(self)
@@ -44,13 +150,19 @@ def __init__(self) -> None:
             plot.addItem(self.images[i])
             plot.showGrid(x=True, y=True, alpha=0.5)
 
+        if cmap is None:
+            cmap = "ColdWarm"
+
+        if cmap.endswith("_r"):
+            cmap = cmap[:-2]
+            self.invert_check.setChecked(True)
+
         # Set up colormap controls
-        self.cmap_combo.setDefaultCmap("terrain")
+        self.cmap_combo.setDefaultCmap(cmap)
         self.cmap_combo.textActivated.connect(self.update_cmap)
         self.gamma_widget.setValue(0.5)
         self.gamma_widget.valueChanged.connect(self.update_cmap)
         self.invert_check.stateChanged.connect(self.update_cmap)
-        self.invert_check.setChecked(True)
         self.contrast_check.stateChanged.connect(self.update_cmap)
         self.update_cmap()
@@ -72,6 +184,34 @@ def __init__(self) -> None:
             lambda: self.plotitems[0].setVisible(self.angle_plot_check.isChecked())
         )
 
+        self._roi_list: list[_MovableCircleROI] = []
+        self.add_circle_btn.clicked.connect(self._add_circle)
+
+        if avec is not None:
+            self._populate_bz(avec)
+
+    def _populate_bz(self, avec) -> None:
+        a, b, c, _, _, gamma = erlab.lattice.avec2abc(avec)
+        self.a_spin.setValue(a)
+        self.b_spin.setValue(b)
+        self.c_spin.setValue(c)
+        self.ab_spin.setValue(a)
+        self.ang_spin.setValue(gamma)
+
+    @QtCore.Slot()
+    def _add_circle(self) -> None:
+        roi = _MovableCircleROI(
+            [-0.3, -0.3], radius=0.3, removable=True, pen=pg.mkPen("m", width=2)
+        )
+        self.plotitems[1].addItem(roi)
+        self._roi_list.append(roi)
+
+        def _remove_roi():
+            self.plotitems[1].removeItem(roi)
+            self._roi_list.remove(roi)
+
+        roi.sigRemoveRequested.connect(_remove_roi)
+
     def update_cmap(self) -> None:
         name = self.cmap_combo.currentText()
         if name == self.cmap_combo.LOAD_ALL_TEXT:
@@ -123,8 +263,16 @@ def update_bz(self) -> None:
 
 
 class KspaceTool(KspaceToolGUI):
-    def __init__(self, data: xr.DataArray, *, data_name: str | None = None) -> None:
-        super().__init__()
+    def __init__(
+        self,
+        data: xr.DataArray,
+        avec: npt.NDArray | None = None,
+        cmap: str | None = None,
+        gamma: float = 0.5,
+        *,
+        data_name: str | None = None,
+    ) -> None:
+        super().__init__(avec=avec, cmap=cmap, gamma=gamma)
 
         self._argnames = {}
@@ -208,6 +356,15 @@ def __init__(self, data: xr.DataArray, *, data_name: str | None = None) -> None:
             self._resolution_spins[k].setSuffix(" Å⁻¹")
             self.resolution_group.layout().addRow(k, self._resolution_spins[k])
 
+        # Temporary customization for beta scaling
+        # self._beta_scale_spin = QtWidgets.QDoubleSpinBox()
+        # self._beta_scale_spin.setValue(1.0)
+        # self._beta_scale_spin.setDecimals(2)
+        # self._beta_scale_spin.setSingleStep(0.01)
+        # self._beta_scale_spin.setRange(0.01, 10)
+        # self.offsets_group.layout().addRow("scale", self._beta_scale_spin)
+        # self._beta_scale_spin.valueChanged.connect(self.update)
+
         self.res_btn.clicked.connect(self.calculate_resolution)
         self.res_npts_check.toggled.connect(self.calculate_resolution)
@@ -217,6 +374,8 @@ def __init__(self, data: xr.DataArray, *, data_name: str | None = None) -> None:
         self.open_btn.clicked.connect(self.show_converted)
         self.copy_btn.clicked.connect(self.copy_code)
         self.update()
+        if avec is not None:
+            self.bz_group.setChecked(True)
 
     def calculate_resolution(self) -> None:
         for k, spin in self._resolution_spins.items():
@@ -328,6 +487,10 @@ def _angle_data(self) -> xr.DataArray:
     def get_data(self) -> tuple[xr.DataArray, xr.DataArray]:
         # Set angle offsets
         data_ang = self._angle_data()
+        # if "beta" in data_ang.dims:
+        #     data_ang = data_ang.assign_coords(
+        #         beta=data_ang.beta * self._beta_scale_spin.value()
+        #     )
         data_ang.kspace.offsets = self.offset_dict
 
         if self.data.kspace.has_hv:
@@ -397,9 +560,31 @@ def closeEvent(self, event: QtGui.QCloseEvent) -> None:
 
 
 def ktool(
-    data: xr.DataArray, *, data_name: str | None = None, execute: bool | None = None
+    data: xr.DataArray,
+    avec: npt.NDArray | None = None,
+    cmap: str | None = None,
+    gamma: float = 0.5,
+    *,
+    data_name: str | None = None,
+    execute: bool | None = None,
 ) -> KspaceTool:
-    """Interactive momentum conversion tool."""
+    """Interactive momentum conversion tool.
+
+    Parameters
+    ----------
+    data
+        Data to convert.
+    avec : array-like, optional
+        Real-space lattice vectors as a 2x2 or 3x3 numpy array. If provided, the
+        Brillouin zone boundary overlay will be calculated based on these vectors.
+    cmap : str, optional
+        Name of the default colormap to use.
+    gamma
+        Default gamma value for the colormap.
+    data_name
+        Name to use in code generation. If not provided, the name will be inferred.
+
+    """
     if data_name is None:
         try:
             data_name = str(varname.argname("data", func=ktool, vars_only=False))
@@ -412,7 +597,7 @@ def ktool(
 
     cast(QtWidgets.QApplication, qapp).setStyle("Fusion")
 
-    win = KspaceTool(data, data_name=data_name)
+    win = KspaceTool(data, avec=avec, cmap=cmap, gamma=gamma, data_name=data_name)
     win.show()
     win.raise_()
     win.activateWindow()
@@ -432,8 +617,3 @@ def ktool(
         qapp.exec()
 
     return win
-
-
-if __name__ == "__main__":
-    dat = cast(xr.DataArray, erlab.io.load_hdf5("/Users/khan/2210_ALS_f0008.h5"))
-    win = ktool(dat)
diff --git a/src/erlab/interactive/ktool.ui b/src/erlab/interactive/ktool.ui
index 1676ffdc..b68dea4b 100644
--- a/src/erlab/interactive/ktool.ui
+++ b/src/erlab/interactive/ktool.ui
@@ -503,6 +503,16 @@
+      <item>
+       <widget class="QPushButton" name="add_circle_btn">
+        <property name="toolTip">
+         <string>Create a circle ROI. The position and radius can be edited by right-clicking on the created ROI.</string>
+        </property>
+        <property name="text">
+         <string>Add Circle ROI</string>
+        </property>
+       </widget>
+      </item>
diff --git a/src/erlab/interactive/utilities.py b/src/erlab/interactive/utilities.py
deleted file mode 100644
index 0db35ac8..00000000
--- a/src/erlab/interactive/utilities.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import warnings
-
-from erlab.interactive.utils import *  # noqa: F403
-
-warnings.warn(
-    "`erlab.interactive.utilities` has been moved to `erlab.interactive.utils` "
-    "and will be removed in a future release",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/src/erlab/interactive/utils.py b/src/erlab/interactive/utils.py
index 84b7aaa5..87bdb7eb 100644
--- a/src/erlab/interactive/utils.py
+++ b/src/erlab/interactive/utils.py
@@ -1872,6 +1872,11 @@ def mouseDragEvent(self, ev) -> None:
             if ev.isFinish():
                 self.moving = False
 
+    def _computeBoundingRect(self):
+        """RotatableLine debugging."""
+        _ = self.getViewBox().size()
+        return super()._computeBoundingRect()
+
 
 def make_crosshairs(n: Literal[1, 2, 3] = 1) -> list[pg.TargetItem | RotatableLine]:
     r"""Create a :class:`pyqtgraph.TargetItem` and associated `RotatableLine`\ s.
@@ -1901,58 +1906,3 @@ def make_crosshairs(n: Literal[1, 2, 3] = 1) -> list[pg.TargetItem | RotatableLi
         l0.link(l1)
 
     return [*lines, target]
-
-
-if __name__ == "__main__":
-    from scipy.ndimage import gaussian_filter  # , uniform_filter
-
-    qapp: QtWidgets.QApplication = cast(
-        QtWidgets.QApplication, QtWidgets.QApplication.instance()
-    )
-    if not qapp:
-        qapp = QtWidgets.QApplication(sys.argv)
-    qapp.setStyle("Fusion")
-
-    dat = (
-        xr.open_dataarray(
-            "/Users/khan/Documents/ERLab/CsV3Sb5/2021_Dec_ALS_CV3Sb5/Data/cvs_kxy_small.nc"
-        )
-        .sel(eV=-0.15, method="nearest")
-        .fillna(0)
-    )
-    win = AnalysisWindow(dat, analysisWidget=ComparisonWidget, orientation="vertical")
-    win.setAttribute(QtCore.Qt.WidgetAttribute.WA_AcceptTouchEvents, False)
-
-    def gaussfilt_2d(dat, sx, sy):
-        return gaussian_filter(dat, sigma=(sx, sy))
-
-    win.aw.set_main_function(gaussfilt_2d, sx=0.1, sy=1, only_values=True)
-
-    # win.set_pre_function(gaussian_filter, sigma=[1, 1], only_values=True)
-    # win.set_pre_function(gaussian_filter, sigma=(0.1, 0.1))
-
-    # layout.addWidget(win)
-    win.addParameterGroup(
-        sigma_x={
-            "qwtype": "btspin",
-            "minimum": 0,
-            "maximum": 10,
-            "valueChanged": lambda x: win.aw.set_main_function_args(sx=x),
-        },
-        sigma_y={
-            "qwtype": "btspin",
-            "minimum": 0,
-            "maximum": 10,
-            "valueChanged": lambda x: win.aw.set_main_function_args(sy=x),
-        },
-        b={"qwtype": "combobox", "items": ["item1", "item2", "item3"]},
-    )
-
-    win.__post_init__(execute=True)
-    # new_roi = win.add_roi(0)
-
-    # layout.addWidget(ROIControls(new_roi))
-
-    # wdgt.show()
-    # wdgt.activateWindow()
-    # wdgt.raise_()
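A hypothetical invocation of the extended `ktool`; the stand-in array and lattice constant are illustrative, and real data additionally needs the usual momentum-conversion attributes (photon energy, axis configuration, and so on). The parameter names come from the signature above.

```python
import numpy as np
import xarray as xr

from erlab.interactive.kspace import ktool

darr = xr.DataArray(  # minimal stand-in for an ARPES cut
    np.random.default_rng(0).random((10, 10)),
    coords={"alpha": np.linspace(-15, 15, 10), "eV": np.linspace(-0.2, 0.1, 10)},
)

avec = 3.54 * np.eye(3)  # illustrative cubic lattice vectors in Å
win = ktool(darr, avec=avec, cmap="ColdWarm_r", gamma=0.5)
# A trailing "_r" is stripped from the colormap name and the invert checkbox
# is checked instead, per KspaceToolGUI.__init__ above.
```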
diff --git a/src/erlab/io/__init__.py b/src/erlab/io/__init__.py
index b132bea8..a0a4357a 100644
--- a/src/erlab/io/__init__.py
+++ b/src/erlab/io/__init__.py
@@ -43,7 +43,7 @@
 
     >>> erlab.io.set_loader("merlin")
 
-Learn more about loaders in the :ref:`User Guide `.
+Learn more about loaders in the :ref:`User Guide `.
 
 """
@@ -82,9 +82,6 @@
 set_loader = loaders.set_loader
 summarize = loaders.summarize
 
-merlin = loaders["merlin"]
-ssrl52 = loaders["ssrl52"]
-
 
 def load_wave(*args, **kwargs):
     from erlab.io.igor import load_wave as _load_wave
@@ -100,11 +97,3 @@ def load_experiment(*args, **kwargs):
     warnings.warn("Use `xarray.open_dataset` instead", DeprecationWarning, stacklevel=2)
 
     return _load_experiment(*args, **kwargs)
-
-
-def load_igor_ibw(*args, **kwargs):
-    return load_wave(*args, **kwargs)
-
-
-def load_igor_pxp(*args, **kwargs):
-    return load_experiment(*args, **kwargs)
diff --git a/src/erlab/io/dataloader.py b/src/erlab/io/dataloader.py
index 49047a58..086fc2f0 100644
--- a/src/erlab/io/dataloader.py
+++ b/src/erlab/io/dataloader.py
@@ -7,7 +7,7 @@
 methods and attributes.
 
 A detailed guide on how to implement a data loader can be found in the :ref:`User Guide
-`.
+`.
 """
 
 from __future__ import annotations
@@ -34,9 +34,9 @@
 import numpy as np
 import pandas
 import xarray as xr
-from xarray.core.datatree import DataTree
 
 from erlab.utils.formatting import format_html_table, format_value
+from erlab.utils.misc import emit_user_level_warning
 
 if TYPE_CHECKING:
     from collections.abc import (
@@ -176,17 +176,26 @@ class LoaderBase(metaclass=_Loader):
     :meth:`post_process <erlab.io.dataloader.LoaderBase.post_process>`
     """
 
-    additional_attrs: ClassVar[dict[str, str | int | float]] = {}
+    additional_attrs: ClassVar[
+        dict[str, str | float | Callable[[xr.DataArray], str | float]]
+    ] = {}
     """Additional attributes to be added to the data after loading.
 
+    If a callable is provided, it will be called with the data as the only argument.
+
     Notes
     -----
     - The attributes are added after renaming with :meth:`process_keys
       <erlab.io.dataloader.LoaderBase.process_keys>`, so keys will appear in the data
       as provided.
-    - If an attribute with the same name is already present in the data, it is skipped.
+    - If an attribute with the same name is already present in the data, it is skipped
+      unless the key is listed in :attr:`overridden_attrs
+      <erlab.io.dataloader.LoaderBase.overridden_attrs>`.
     """
 
+    overridden_attrs: tuple[str, ...] = ()
+    """Keys in :attr:`additional_attrs` that should override existing attributes."""
+
     additional_coords: ClassVar[dict[str, str | int | float]] = {}
     """Additional non-dimension coordinates to be added to the data after loading.
@@ -195,22 +204,13 @@ class LoaderBase(metaclass=_Loader):
     - The coordinates are added after renaming with :meth:`process_keys
       <erlab.io.dataloader.LoaderBase.process_keys>`, so keys will appear in the data
       as provided.
-    - If a coordinate with the same name is already present in the data, it is skipped.
+    - If a coordinate with the same name is already present in the data, it is skipped
+      unless the key is listed in :attr:`overridden_coords
+      <erlab.io.dataloader.LoaderBase.overridden_coords>`.
     """
 
-    formatters: ClassVar[dict[str, Callable]] = {}
-    """Mapping from attribute names (after renaming) to custom formatters.
-
-    The formatters must take the attribute value and return a value that can be
-    converted to a string with :meth:`value_to_string
-    <erlab.io.dataloader.LoaderBase.value_to_string>`. The resulting formats are used
-    for human readable display of some attributes in the summary table and the
-    information accessor.
-
-    Note
-    ----
-    The formatters are only used for display purposes and do not affect the stored data.
-    """
+    overridden_coords: tuple[str, ...] = ()
+    """Keys in :attr:`additional_coords` that should override existing coordinates."""
 
     always_single: bool = True
     """
@@ -221,7 +221,7 @@ class LoaderBase(metaclass=_Loader):
     skip_validate: bool = False
     """
     If `True`, validation checks will be skipped. If `False`, data will be checked with
-    :meth:`validate <erlab.io.dataloader.LoaderBase.validate>` every time it is loaded.
+    :meth:`validate <erlab.io.dataloader.LoaderBase.validate>`.
     """
 
     strict_validation: bool = False
@@ -231,6 +231,56 @@ class LoaderBase(metaclass=_Loader):
     `skip_validate` is `True`.
     """
 
+    formatters: ClassVar[dict[str, Callable]] = {}
+    """Optional mapping from attr or coord names (after renaming) to custom formatters.
+
+    The formatters are callables that take the attribute value and return a value that
+    can be converted to a string via :meth:`value_to_string
+    <erlab.io.dataloader.LoaderBase.value_to_string>`. The resulting string
+    representations are used for human readable display in the summary table and the
+    information accessor.
+
+    The values returned by the formatters will be further formatted by
+    :meth:`value_to_string <erlab.io.dataloader.LoaderBase.value_to_string>` before
+    being displayed.
+
+    If the key is a coordinate, the function will automatically be vectorized over
+    every value.
+
+    Note
+    ----
+    The formatters are only used for display purposes and do not affect the stored
+    data.
+
+    See Also
+    --------
+    :meth:`get_formatted_attr_or_coord`
+        The method that uses this mapping to provide human-readable values.
+    """
+
+    summary_sort: str | None = None
+    """Optional default column to sort the summary table by.
+
+    If `None`, the summary table is sorted in the order of the files returned by
+    :meth:`files_for_summary <erlab.io.dataloader.LoaderBase.files_for_summary>`.
+    """
+
+    @property
+    def summary_attrs(self) -> dict[str, str | Callable[[xr.DataArray], Any]]:
+        """Mapping from summary column names to attr or coord names (after renaming).
+
+        If the value is a callable, it will be called with the data as the only
+        argument. This can be used to extract values from the data that are not stored
+        as attributes or spread across multiple attributes.
+
+        If not overridden, returns a basic mapping based on :attr:`name_map`.
+
+        It is highly recommended to override this property to provide a more detailed
+        and informative summary. See existing loaders for examples.
+
+        """
+        excluded = {"eV", "alpha", "sample_workfunction"}
+        return {k: k for k in self.name_map if k not in excluded}
+
     @property
     def name_map_reversed(self) -> dict[str, str]:
         """Reverse of :attr:`name_map <erlab.io.dataloader.LoaderBase.name_map>`.
@@ -313,7 +363,9 @@ def value_to_string(cls, val: object) -> str:
 
         The default behavior formats the given value with :func:`format_value
         <erlab.utils.formatting.format_value>`. Override this classmethod to change the
-        printed format of each cell.
+        printed format of summaries and information accessors. This method is applied
+        after the formatters in :attr:`formatters
+        <erlab.io.dataloader.LoaderBase.formatters>`.
         """
         return format_value(val)
@@ -340,7 +392,7 @@ def get_styler(cls, df: pandas.DataFrame) -> pandas.io.formats.style.Styler:
         """
         style = df.style.format(cls.value_to_string)
 
-        hidden = [c for c in ("Time", "Path") if c in df.columns]
+        hidden = [c for c in ("Path",) if c in df.columns]
         if len(hidden) > 0:
             style = style.hide(hidden, axis="columns")
@@ -353,16 +405,16 @@ def load(
         *,
         single: bool = False,
         combine: bool = True,
-        parallel: bool | None = None,
+        parallel: bool = False,
         load_kwargs: dict[str, Any] | None = None,
         **kwargs,
     ) -> (
         xr.DataArray
         | xr.Dataset
-        | DataTree
+        | xr.DataTree
         | list[xr.DataArray]
         | list[xr.Dataset]
-        | list[DataTree]
+        | list[xr.DataTree]
     ):
         """Load ARPES data.
@@ -399,12 +451,13 @@ def load(
             `False`, a list of data is returned. If `True`, the loader tries to combined
             the data into a single data object and return it. Depending on the type of
             each data object, the returned object can be a `xarray.DataArray`,
-            `xarray.Dataset`, or a `DataTree`.
+            `xarray.Dataset`, or a `xarray.DataTree`.
 
             This argument is only used when `single` is `False`.
         parallel
-            Whether to load multiple files in parallel. If not specified, files are
-            loaded in parallel only when there are more than 15 files to load.
+            Whether to load multiple files in parallel using the `joblib` library.
+
+            This argument is only used when `single` is `False`.
         load_kwargs
             Additional keyword arguments to be passed to :meth:`load_single
             <erlab.io.dataloader.LoaderBase.load_single>`.
@@ -413,7 +466,7 @@ def load(
 
         Returns
         -------
-        `xarray.DataArray` or `xarray.Dataset` or `DataTree`
+        `xarray.DataArray` or `xarray.Dataset` or `xarray.DataTree`
             The loaded data.
 
         Notes
@@ -475,7 +528,7 @@ def load(
 
         if len(file_paths) == 1:
             # Single file resolved
-            data: xr.DataArray | xr.Dataset | DataTree = self.load_single(
+            data: xr.DataArray | xr.Dataset | xr.DataTree = self.load_single(
                 file_paths[0], **load_kwargs
             )
         else:
@@ -520,6 +573,7 @@ def load(
                 single=single,
                 combine=combine,
                 parallel=parallel,
+                load_kwargs=load_kwargs,
                 **new_kwargs,
             )
@@ -536,7 +590,9 @@ def load(
         return data
 
     def get_formatted_attr_or_coord(
-        self, data: xr.DataArray, attr_or_coord_name: str
+        self,
+        data: xr.DataArray,
+        attr_or_coord_name: str | Callable[[xr.DataArray], Any],
     ) -> Any:
         """Return the formatted value of the given attribute or coordinate.
@@ -548,10 +604,14 @@ def get_formatted_attr_or_coord(
         ----------
         data : DataArray
             The data to extract the attribute or coordinate from.
-        attr_or_coord_name : str
-            The name of the attribute or coordinate to extract.
+        attr_or_coord_name : str or callable
+            The name of the attribute or coordinate to extract. If a callable is passed,
+            it is called with the data as the only argument.
""" + if callable(attr_or_coord_name): + return attr_or_coord_name(data) + func = self.formatters.get(attr_or_coord_name, lambda x: x) if attr_or_coord_name in data.attrs: @@ -569,11 +629,11 @@ def get_formatted_attr_or_coord( def summarize( self, data_dir: str | os.PathLike, - usecache: bool = True, + exclude: str | Sequence[str] | None = None, *, cache: bool = True, display: bool = True, - **kwargs, + rc: dict[str, Any] | None = None, ) -> pandas.DataFrame | pandas.io.formats.style.Styler | None: """Summarize the data in the given directory. @@ -589,24 +649,23 @@ def summarize( ---------- data_dir Directory to summarize. - usecache - Whether to use the cached summary if available. If `False`, the summary will - be regenerated. The cache will be updated if `cache` is `True`. + exclude + A string or sequence of strings specifying glob patterns for files to be + excluded from the summary. If provided, caching will be disabled. cache - Whether to cache the summary in a pickle file in the directory. If `False`, - no cache will be created or updated. Note that existing cache files will not - be deleted, and will be used if `usecache` is `True`. + Whether to use caching for the summary. display Whether to display the formatted dataframe using the IPython shell. If `False`, the dataframe will be returned without formatting. If `True` but the IPython shell is not detected, the dataframe styler will be returned. - **kwargs - Additional keyword arguments to be passed to :meth:`generate_summary - `. + rc + Optional dictionary of matplotlib rcParams to override the default for the + plot in the interactive summary. Figure size and the colormap can be changed + using this argument. Returns ------- - df : pandas.DataFrame or pandas.io.formats.style.Styler or None + pandas.DataFrame or pandas.io.formats.style.Styler or None Summary of the data in the directory. - If `display` is `False`, the summary DataFrame is returned. @@ -621,29 +680,35 @@ def summarize( for the summary DataFrame will be returned. """ - if not os.path.isdir(data_dir): + data_dir = pathlib.Path(data_dir) + + if not data_dir.is_dir(): raise FileNotFoundError( errno.ENOENT, os.strerror(errno.ENOENT), str(data_dir) ) - pkl_path = os.path.join(data_dir, ".summary.pkl") + pkl_path = data_dir / ".summary.pkl" df = None - if usecache: + + if exclude is not None: + cache = False + + if pkl_path.is_file() and cache: try: df = pandas.read_pickle(pkl_path) - df = df.head(len(df)) - except FileNotFoundError: - pass + except Exception: + df = None + + if df is not None: + contents = {str(f.relative_to(data_dir)) for f in data_dir.glob("[!.]*")} + if contents != df.attrs.get("__contents", set()): + # Cache is outdated + df = None if df is None: - df = self.generate_summary(data_dir, **kwargs) - if cache: - try: - df.to_pickle(pkl_path) - except OSError: - warnings.warn( - f"Failed to cache summary to {pkl_path}", stacklevel=1 - ) + df = self.generate_summary(data_dir, exclude) + if cache and os.access(data_dir, os.W_OK): + df.to_pickle(pkl_path) if not display: return df @@ -663,7 +728,7 @@ def summarize( display(styled) # type: ignore[misc] if importlib.util.find_spec("ipywidgets"): - return self._isummarize(df) + return self._isummarize(df, rc=rc) return None @@ -672,39 +737,141 @@ def summarize( return styled - def isummarize(self, df: pandas.DataFrame | None = None, **kwargs) -> None: - """Display an interactive summary. 
+    def generate_summary(
+        self, data_dir: str | os.PathLike, exclude: str | Sequence[str] | None = None
+    ) -> pandas.DataFrame:
+        """Generate a dataframe summarizing the data in the given directory.
 
-        This method provides an interactive summary of the data using ipywidgets and
-        matplotlib.
+        Takes a path to a directory and summarizes the data in the directory to a pandas
+        DataFrame, much like a log file. This is useful for quickly inspecting the
+        contents of a directory.
 
         Parameters
         ----------
-        df
-            A summary dataframe as returned by :meth:`generate_summary
-            <erlab.io.dataloader.LoaderBase.generate_summary>`. If `None`, a dataframe
-            will be generated using :meth:`summarize
-            <erlab.io.dataloader.LoaderBase.summarize>`. Defaults to `None`.
-        **kwargs
-            Additional keyword arguments to be passed to :meth:`summarize
-            <erlab.io.dataloader.LoaderBase.summarize>` if `df` is None.
+        data_dir
+            Path to a directory.
+        exclude
+            A string or sequence of strings specifying glob patterns for files to be
+            excluded from the summary.
 
-        Note
-        ----
-        This method requires `ipywidgets` to be installed.
+        Returns
+        -------
+        pandas.DataFrame
+            Summary of the data in the directory.
 
         """
+        data_dir = pathlib.Path(data_dir)
+
+        excluded: list[pathlib.Path] = []
+
+        if exclude is not None:
+            if isinstance(exclude, str):
+                exclude = [exclude]
+
+            for pattern in exclude:
+                excluded = excluded + list(data_dir.glob(pattern))
+
+        target_files: list[pathlib.Path] = [
+            pathlib.Path(f)
+            for f in self.files_for_summary(data_dir)
+            if pathlib.Path(f) not in excluded
+        ]
+
+        if not self.always_single:
+            signatures: list[int | None] = [
+                self.infer_index(f.stem)[0] for f in target_files
+            ]
+
+            # Removing duplicates that exist in same multi-file scan
+            seen = set()  # set to track seen elements
+            target_files_new = []
+            for f, sig in zip(target_files, signatures, strict=True):
+                if sig is not None and sig in seen:
+                    # sig[0] == None for files that cannot be inferred, keep them
+                    continue
+                seen.add(sig)
+                target_files_new.append(f)
+            target_files = target_files_new
+
+        columns = ["File Name", "Path", *self.summary_attrs.keys()]
+        content = []
+
+        def _add_content(
+            data: xr.DataArray | xr.Dataset | xr.DataTree,
+            file_path: pathlib.Path,
+            suffix: str | None = None,
+        ) -> None:
+            if suffix is None:
+                suffix = ""
+
+            if isinstance(data, xr.DataArray):
+                name = file_path.stem
+                if suffix != "":
+                    name = f"{name} ({suffix})"
+                content.append(
+                    [
+                        name,
+                        str(file_path),
+                        *(
+                            self.get_formatted_attr_or_coord(data, v)
+                            for v in self.summary_attrs.values()
+                        ),
+                    ]
+                )
+
+            elif isinstance(data, xr.Dataset):
+                if len(data.data_vars) == 1:
+                    _add_content(next(iter(data.data_vars.values())), file_path, suffix)
+                else:
+                    for k, darr in data.data_vars.items():
+                        _add_content(darr, file_path, suffix=suffix + k)
+
+            elif isinstance(data, xr.DataTree):
+                for leaf in data.leaves:
+                    _add_content(leaf.dataset, file_path, suffix=leaf.path)
+
+        for f in target_files:
+            _add_content(
+                cast(
+                    xr.DataArray | xr.Dataset | xr.DataTree,
+                    self.load(f, load_kwargs={"without_values": True}),
+                ),
+                f,
+            )
+
+        sort_by = self.summary_sort if self.summary_sort is not None else "File Name"
+
+        df = (
+            pandas.DataFrame(content, columns=columns)
+            .sort_values(sort_by)
+            .set_index("File Name")
+        )
+
+        # Cache directory contents for determining whether cache is up-to-date
+        contents = {str(f.relative_to(data_dir)) for f in data_dir.glob("[!.]*")}
+        df.attrs["__contents"] = contents
+
+        return df
+
+    def _isummarize(
+        self,
+        summary: pandas.DataFrame | None = None,
+        rc: dict[str, Any] | None = None,
+        **kwargs,
+    ):
+        rc_dict: dict[str, Any] = {} if rc is None else rc
None else rc + if not importlib.util.find_spec("ipywidgets"): raise ImportError( "ipywidgets and IPython is required for interactive summaries" ) - if df is None: + + if summary is None: kwargs["display"] = False df = cast(pandas.DataFrame, self.summarize(**kwargs)) + else: + df = summary - self._isummarize(df) - - def _isummarize(self, df: pandas.DataFrame): import matplotlib.pyplot as plt from ipywidgets import ( HTML, @@ -730,7 +897,7 @@ def _format_data_info(series: pandas.Series) -> str: table = "" table += ( "
" + "style='height:300px;overflow-y:auto;'>" ) table += "" table += "" @@ -783,12 +950,20 @@ def _update_data(_, *, full: bool = False) -> None: # !TODO: Add 2 sliders for 4D data if self._temp_data.ndim == 3: + old_dim = str(dim_sel.value) + dim_sel.unobserve(_update_sliders, "value") coord_sel.unobserve(_update_plot, "value") dim_sel.options = self._temp_data.dims - # Set the default dimension to the one with the smallest size - dim_sel.value = self._temp_data.dims[np.argmin(self._temp_data.shape)] + # Set the default dimension to the one with the smallest size if + # previous dimension is not present + if old_dim in dim_sel.options: + dim_sel.value = old_dim + else: + dim_sel.value = self._temp_data.dims[ + np.argmin(self._temp_data.shape) + ] coord_sel.observe(_update_plot, "value") dim_sel.observe(_update_sliders, "value") @@ -837,12 +1012,14 @@ def _update_plot(_) -> None: else: plot_data = self._temp_data + old_rc = {k: v for k, v in plt.rcParams.items() if k in rc_dict} with out: + plt.rcParams.update(rc_dict) plot_data.qplot(ax=plt.gca()) plt.title("") # Remove automatically generated title - # Add line at Fermi level if the data is 2D and has an energy dimension - # that includes zero + # Add line at Fermi level if the data is 2D and has an energy + # dimension that includes zero if (plot_data.ndim == 2 and "eV" in plot_data.dims) and ( plot_data["eV"].values[0] * plot_data["eV"].values[-1] < 0 ): @@ -850,6 +1027,7 @@ def _update_plot(_) -> None: orientation="h" if plot_data.dims[0] == "eV" else "v" ) show_inline_matplotlib_plots() + plt.rcParams.update(old_rc) def _next(_) -> None: # Select next row @@ -876,7 +1054,9 @@ def _prev(_) -> None: buttons = [prev_button, next_button, full_button] # List of data files - data_select = Select(options=list(df.index), value=next(iter(df.index)), rows=8) + data_select = Select( + options=list(df.index), value=next(iter(df.index)), rows=10 + ) data_select.observe(_update_data, "value") # HTML table for data info @@ -905,8 +1085,10 @@ def _prev(_) -> None: ) def load_single( - self, file_path: str | os.PathLike - ) -> xr.DataArray | xr.Dataset | DataTree: + self, + file_path: str | os.PathLike, + without_values: bool = False, + ) -> xr.DataArray | xr.Dataset | xr.DataTree: r"""Load a single file and return it as an xarray data structure. Any scan-specific postprocessing should be implemented in this method. @@ -915,14 +1097,19 @@ def load_single( that represents the data in a single file. For instance, if a single file contains a single scan region, the method should return a single `xarray.DataArray`. If it contains multiple regions, the method should return a - `xarray.Dataset` or `DataTree` depending on whether the regions can be merged - with without conflicts (i.e., all mutual coordinates of the regions are the - same). + `xarray.Dataset` or `xarray.DataTree` depending on whether the regions can be + merged with without conflicts (i.e., all mutual coordinates of the regions are + the same). Parameters ---------- file_path Full path to the file to be loaded. + without_values + Used when creating a summary table. With this option set to `True`, only the + coordinates and attributes of the output data are accessed so that the + values can be replaced with placeholder numbers, speeding up the summary + generation for lazy loading enabled file formats like HDF5 or NeXus. Returns ------- @@ -939,8 +1126,8 @@ def load_single( :meth:`combine_multiple `. 
This should not be a problem since in most cases, the data structure of associated files acquired during the same scan will be identical. - - For `DataTree` objects, returned trees must be named with a unique identifier - to avoid conflicts when combining. + - For `xarray.DataTree` objects, returned trees must be named with a unique + identifier to avoid conflicts when combining. """ raise NotImplementedError("method must be implemented in the subclass") @@ -1019,22 +1206,21 @@ def infer_index(self, name: str) -> tuple[int | None, dict[str, Any]]: """ raise NotImplementedError("method must be implemented in the subclass") - def generate_summary(self, data_dir: str | os.PathLike) -> pandas.DataFrame: - """Generate a dataframe summarizing the data in the given directory. + def files_for_summary(self, data_dir: str | os.PathLike) -> list[str | os.PathLike]: + """Return a list of files that can be loaded by the loader. - Takes a path to a directory and summarizes the data in the directory to a pandas - DataFrame, much like a log file. This is useful for quickly inspecting the - contents of a directory. + This method is used to select files that can be loaded by the loader when + generating a summary. Parameters ---------- data_dir - Path to a directory. + The directory containing the data. Returns ------- - pandas.DataFrame - Summary of the data in the directory. + list of str or path-like + A list of files that can be loaded by the loader. """ raise NotImplementedError( @@ -1089,16 +1275,16 @@ def combine_multiple( @overload def combine_multiple( self, - data_list: list[DataTree], + data_list: list[xr.DataTree], coord_dict: dict[str, Sequence], - ) -> DataTree: ... + ) -> xr.DataTree: ... def combine_multiple( self, - data_list: list[xr.DataArray] | list[xr.Dataset] | list[DataTree], + data_list: list[xr.DataArray] | list[xr.Dataset] | list[xr.DataTree], coord_dict: dict[str, Sequence], - ) -> xr.DataArray | xr.Dataset | DataTree: - if _is_sequence_of(data_list, DataTree): + ) -> xr.DataArray | xr.Dataset | xr.DataTree: + if _is_sequence_of(data_list, xr.DataTree): raise NotImplementedError( "Combining DataTrees into a single tree " "will be supported in a future release" @@ -1108,7 +1294,7 @@ def combine_multiple( # No coordinates to combine given # Multiregion scans over multiple files may be provided like this - if _is_sequence_of(data_list, DataTree): + if _is_sequence_of(data_list, xr.DataTree): pass else: try: @@ -1231,14 +1417,26 @@ def post_process(self, darr: xr.DataArray) -> xr.DataArray: v = darr[k].values.mean() darr = darr.drop_vars(k).assign_attrs({k: v}) + new_attrs: dict[str, str | float] = {} + for k, v in self.additional_attrs.items(): + if k not in darr.attrs: + if callable(v): + new_attrs[k] = v(darr) + else: + new_attrs[k] = v + new_attrs = { - k: v for k, v in self.additional_attrs.items() if k not in darr.attrs + k: v + for k, v in self.additional_attrs.items() + if k not in darr.attrs or k in self.overridden_attrs } new_attrs["data_loader_name"] = str(self.name) darr = darr.assign_attrs(new_attrs) new_coords = { - k: v for k, v in self.additional_coords.items() if k not in darr.coords + k: v + for k, v in self.additional_coords.items() + if k not in darr.coords or k in self.overridden_coords } return darr.assign_coords(new_coords) @@ -1266,11 +1464,11 @@ def post_process_general(self, data: xr.DataArray) -> xr.DataArray: ... def post_process_general(self, data: xr.Dataset) -> xr.Dataset: ... @overload - def post_process_general(self, data: DataTree) -> DataTree: ... 
+ def post_process_general(self, data: xr.DataTree) -> xr.DataTree: ... def post_process_general( - self, data: xr.DataArray | xr.Dataset | DataTree - ) -> xr.DataArray | xr.Dataset | DataTree: + self, data: xr.DataArray | xr.Dataset | xr.DataTree + ) -> xr.DataArray | xr.Dataset | xr.DataTree: """Post-process any data structure. This method extends :meth:`post_process @@ -1290,7 +1488,7 @@ def post_process_general( post-processed using :meth:`post_process ` is returned. The attributes of the original `Dataset` are preserved. - - If a `DataTree`, the post-processing is applied to each leaf node + - If a `xarray.DataTree`, the post-processing is applied to each leaf node `Dataset`. Returns @@ -1310,15 +1508,15 @@ def post_process_general( attrs=data.attrs, ) - if isinstance(data, DataTree): - return cast(DataTree, data.map_over_subtree(self.post_process_general)) + if isinstance(data, xr.DataTree): + return cast(xr.DataTree, data.map_over_datasets(self.post_process_general)) raise TypeError( "data must be a DataArray, Dataset, or DataTree, but got " + type(data) ) @classmethod - def validate(cls, data: xr.DataArray | xr.Dataset | DataTree) -> None: + def validate(cls, data: xr.DataArray | xr.Dataset | xr.DataTree) -> None: """Validate the input data to ensure it is in the correct format. Checks for the presence of all coordinates and attributes required for common @@ -1331,8 +1529,8 @@ def validate(cls, data: xr.DataArray | xr.Dataset | DataTree) -> None: Parameters ---------- data : DataArray or Dataset or DataTree - The data to be validated. If a `Dataset` or `DataTree` is passed, validation - is performed on each data variable recursively. + The data to be validated. If a `xarray.Dataset` or `xarray.DataTree` is + passed, validation is performed on each data variable recursively. """ if isinstance(data, xr.Dataset): @@ -1340,19 +1538,19 @@ def validate(cls, data: xr.DataArray | xr.Dataset | DataTree) -> None: cls.validate(v) return - if isinstance(data, DataTree): - data.map_over_subtree(cls.validate) + if isinstance(data, xr.DataTree): + data.map_over_datasets(cls.validate) return for c in ("beta", "delta", "xi", "hv"): if c not in data.coords: - cls._raise_or_warn(f"Missing coordinate {c}") + cls._raise_or_warn(f"Missing coordinate '{c}'") - for a in ("configuration", "temp_sample"): - if a not in data.attrs: - cls._raise_or_warn(f"Missing attribute {a}") + if data.qinfo.get_value("sample_temp") is None: + cls._raise_or_warn("Missing attribute 'sample_temp'") if "configuration" not in data.attrs: + cls._raise_or_warn("Missing attribute 'configuration'") return if data.attrs["configuration"] not in (1, 2): @@ -1361,15 +1559,15 @@ def validate(cls, data: xr.DataArray | xr.Dataset | DataTree) -> None: f"Invalid configuration {data.attrs['configuration']}" ) elif "chi" not in data.coords: - cls._raise_or_warn("Missing coordinate chi") + cls._raise_or_warn("Missing coordinate 'chi'") def load_multiple_parallel( self, file_paths: list[str], - parallel: bool | None = None, + parallel: bool = False, post_process: bool = False, **kwargs, - ) -> list[xr.DataArray] | list[xr.Dataset] | list[DataTree]: + ) -> list[xr.DataArray] | list[xr.Dataset] | list[xr.DataTree]: """Load multiple files in parallel. Parameters @@ -1388,9 +1586,6 @@ def load_multiple_parallel( ------- A list of the loaded data. 
""" - if parallel is None: - parallel = len(file_paths) > 15 - if post_process: def _load_func(filename): @@ -1414,17 +1609,17 @@ def _load_func(filename): def _raise_or_warn(cls, msg: str) -> None: if cls.strict_validation: raise ValidationError(msg) - warnings.warn(msg, ValidationWarning, stacklevel=2) + emit_user_level_warning(msg, ValidationWarning) -class RegistryBase: +class _RegistryBase: """Base class for the loader registry. This class implements the singleton pattern, ensuring that only one instance of the registry is created and used throughout the application. """ - __instance: RegistryBase | None = None + __instance: _RegistryBase | None = None def __new__(cls): if not isinstance(cls.__instance, cls): @@ -1437,41 +1632,72 @@ def instance(cls) -> Self: return cls() -class LoaderRegistry(RegistryBase): - loaders: ClassVar[dict[str, LoaderBase | type[LoaderBase]]] = {} - """Registered loaders \n\n:meta hide-value:""" +class LoaderRegistry(_RegistryBase): + _loaders: ClassVar[dict[str, LoaderBase | type[LoaderBase]]] = {} + """Mapping of registered loaders.""" + + _alias_mapping: ClassVar[dict[str, str]] = {} + """Mapping of aliases to loader names.""" + + _current_loader: LoaderBase | None = None + _current_data_dir: pathlib.Path | None = None - alias_mapping: ClassVar[dict[str, str]] = {} - """Mapping of aliases to loader names \n\n:meta hide-value:""" + @property + def current_loader(self) -> LoaderBase | None: + """Current loader.""" + return self._current_loader + + @current_loader.setter + def current_loader(self, loader: str | LoaderBase | None) -> None: + self.set_loader(loader) + + @property + def current_data_dir(self) -> os.PathLike | None: + """Directory to search for data files.""" + return self._current_data_dir - current_loader: LoaderBase | None = None - """Current loader \n\n:meta hide-value:""" + @current_data_dir.setter + def current_data_dir(self, data_dir: str | os.PathLike | None) -> None: + self.set_data_dir(data_dir) - default_data_dir: pathlib.Path | None = None - """Default directory to search for data files \n\n:meta hide-value:""" + @property + def default_data_dir(self) -> os.PathLike | None: + """Deprecated alias for current_data_dir. + + .. deprecated:: 3.0.0 + + Use :attr:`current_data_dir` instead. 
+ """ + warnings.warn( + "`default_data_dir` is deprecated, use `current_data_dir` instead", + DeprecationWarning, + stacklevel=1, + ) + return self.current_data_dir def _register(self, loader_class: type[LoaderBase]) -> None: # Add class to loader - self.loaders[loader_class.name] = loader_class + self._loaders[loader_class.name] = loader_class # Add aliases to mapping - self.alias_mapping[loader_class.name] = loader_class.name + self._alias_mapping[loader_class.name] = loader_class.name if loader_class.aliases is not None: for alias in loader_class.aliases: - self.alias_mapping[alias] = loader_class.name + self._alias_mapping[alias] = loader_class.name def keys(self) -> KeysView[str]: - return self.loaders.keys() + return self._loaders.keys() def items(self) -> ItemsView[str, LoaderBase | type[LoaderBase]]: - return self.loaders.items() + return self._loaders.items() def get(self, key: str) -> LoaderBase: - loader_name = self.alias_mapping.get(key) + """Get a loader instance by name or alias.""" + loader_name = self._alias_mapping.get(key) if loader_name is None: raise LoaderNotFoundError(key) - loader = self.loaders.get(loader_name) + loader = self._loaders.get(loader_name) if loader is None: raise LoaderNotFoundError(key) @@ -1479,12 +1705,12 @@ def get(self, key: str) -> LoaderBase: if not isinstance(loader, LoaderBase): # If not an instance, create one loader = loader() - self.loaders[loader_name] = loader + self._loaders[loader_name] = loader return loader def __iter__(self) -> Iterator[str]: - return iter(self.loaders) + return iter(self._loaders) def __getitem__(self, key: str) -> LoaderBase: return self.get(key) @@ -1498,7 +1724,7 @@ def __getattr__(self, key: str) -> LoaderBase: def set_loader(self, loader: str | LoaderBase | None) -> None: """Set the current data loader. - All subsequent calls to `load` will use the loader set here. + All subsequent calls to `load` will use the provided loader. Parameters ---------- @@ -1515,9 +1741,9 @@ def set_loader(self, loader: str | LoaderBase | None) -> None: """ if isinstance(loader, str): - self.current_loader = self.get(loader) + self._current_loader = self.get(loader) else: - self.current_loader = loader + self._current_loader = loader @contextlib.contextmanager def loader_context( @@ -1559,7 +1785,7 @@ def loader_context( self.set_loader(loader) if data_dir is not None: - old_data_dir = self.default_data_dir + old_data_dir = self.current_data_dir self.set_data_dir(data_dir) try: @@ -1574,13 +1800,13 @@ def loader_context( def set_data_dir(self, data_dir: str | os.PathLike | None) -> None: """Set the default data directory for the data loader. - All subsequent calls to :func:`erlab.io.load` will use the `data_dir` set here + All subsequent calls to :func:`erlab.io.load` will use the provided `data_dir` unless specified. Parameters ---------- data_dir - The path to a directory. + The default data directory to use. 
Note ---- @@ -1589,23 +1815,28 @@ def set_data_dir(self, data_dir: str | os.PathLike | None) -> None: """ if data_dir is None: - self.default_data_dir = None + self._current_data_dir = None return - self.default_data_dir = pathlib.Path(data_dir).resolve(strict=True) + self._current_data_dir = pathlib.Path(data_dir).resolve(strict=True) def load( self, identifier: str | os.PathLike | int, data_dir: str | os.PathLike | None = None, + *, + single: bool = False, + combine: bool = True, + parallel: bool = False, + load_kwargs: dict[str, Any] | None = None, **kwargs, ) -> ( xr.DataArray | xr.Dataset - | DataTree + | xr.DataTree | list[xr.DataArray] | list[xr.Dataset] - | list[DataTree] + | list[xr.DataTree] ): loader, default_dir = self._get_current_defaults() @@ -1619,12 +1850,11 @@ def load( default_file = (default_dir / identifier).resolve() if default_file.exists() and abs_file != default_file: - warnings.warn( + emit_user_level_warning( f"Found {identifier!s} in the default directory " f"{default_dir!s}, but conflicting file {abs_file!s} was found. " "The first file will be loaded. " "Consider specifying the directory explicitly.", - stacklevel=2, ) else: # If the identifier is a path to a file, ignore default_dir @@ -1633,16 +1863,24 @@ def load( if data_dir is None: data_dir = default_dir - return loader.load(identifier, data_dir=data_dir, **kwargs) + return loader.load( + identifier, + data_dir=data_dir, + single=single, + combine=combine, + parallel=parallel, + load_kwargs=load_kwargs, + **kwargs, + ) def summarize( self, data_dir: str | os.PathLike | None = None, - usecache: bool = True, + exclude: str | Sequence[str] | None = None, *, cache: bool = True, display: bool = True, - **kwargs, + rc: dict[str, Any] | None = None, ) -> pandas.DataFrame | pandas.io.formats.style.Styler | None: loader, default_dir = self._get_current_defaults() @@ -1650,7 +1888,7 @@ def summarize( data_dir = default_dir return loader.summarize( - data_dir, usecache, cache=cache, display=display, **kwargs + data_dir=data_dir, exclude=exclude, cache=cache, display=display, rc=rc ) def _get_current_defaults(self): @@ -1658,18 +1896,18 @@ def _get_current_defaults(self): raise ValueError( "No loader has been set. 
Set a loader with `erlab.io.set_loader` first" ) - return self.current_loader, self.default_data_dir + return self.current_loader, self.current_data_dir def __repr__(self) -> str: out = "Registered data loaders\n=======================\n\n" out += "Loaders\n-------\n" + "\n".join( - [f"{k}: {v}" for k, v in self.loaders.items()] + [f"{k}: {v}" for k, v in self._loaders.items()] ) out += "\n\n" out += "Aliases\n-------\n" + "\n".join( [ f"{k}: {tuple(v.aliases)}" - for k, v in self.loaders.items() + for k, v in self._loaders.items() if v.aliases is not None ] ) @@ -1678,7 +1916,7 @@ def __repr__(self) -> str: def _repr_html_(self) -> str: rows: list[tuple[str, str, str]] = [("Name", "Aliases", "Loader class")] - for k, v in self.loaders.items(): + for k, v in self._loaders.items(): aliases = ", ".join(v.aliases) if v.aliases is not None else "" # May be either a class or an instance diff --git a/src/erlab/io/exampledata.py b/src/erlab/io/exampledata.py index 8e8ba0eb..5dc24c8e 100644 --- a/src/erlab/io/exampledata.py +++ b/src/erlab/io/exampledata.py @@ -294,7 +294,7 @@ def generate_data_angles( if assign_attributes: out = out.assign_attrs( configuration=int(configuration), - temp_sample=temp, + sample_temp=temp, sample_workfunction=4.5, ) @@ -392,4 +392,4 @@ def generate_gold_edge( rng.poisson(data).astype(float), sigma=ccd_sigma ) - return data.assign_attrs(temp_sample=temp) + return data.assign_attrs(sample_temp=temp) diff --git a/src/erlab/io/igor.py b/src/erlab/io/igor.py index 35b4e0d2..fb5effbc 100644 --- a/src/erlab/io/igor.py +++ b/src/erlab/io/igor.py @@ -2,7 +2,7 @@ __all__ = ["IgorBackendEntrypoint", "load_experiment", "load_igor_hdf5", "load_wave"] -import contextlib +import logging import os from typing import TYPE_CHECKING, Any @@ -20,6 +20,9 @@ from xarray.backends.common import AbstractDataStore +# https://github.com/AFM-analysis/igor2/issues/20 +logging.getLogger("igor2.struct").setLevel(logging.ERROR) + class IgorBackendEntrypoint(BackendEntrypoint): """Backend for Igor Pro files. @@ -58,6 +61,35 @@ def guess_can_open( return ext in {".pxt", ".pxp", ".ibw"} return False + def open_datatree( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + recursive: bool = True, + **kwargs, + ) -> xr.DataTree: + if not isinstance(filename_or_obj, str | os.PathLike): + raise TypeError("filename_or_obj must be a string or a path-like object") + return xr.DataTree.from_dict( + self.open_groups_as_dict(filename_or_obj, recursive=recursive) + ) + + def open_groups_as_dict( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + recursive: bool = True, + **kwargs, + ) -> dict[str, xr.Dataset]: + if not isinstance(filename_or_obj, str | os.PathLike): + raise TypeError("filename_or_obj must be a string or a path-like object") + return { + k: v.to_dataset() + for k, v in _load_experiment_raw( + filename_or_obj, recursive=recursive + ).items() + } + def _open_igor_ds( filename: str | os.PathLike[Any], @@ -89,11 +121,6 @@ def _load_experiment_raw( recursive: bool = False, **kwargs, ) -> dict[str, xr.DataArray]: - if folder is None: - split_path: list[Any] = [] - if ignore is None: - ignore = [] - expt = None for bo in [">", "=", "<"]: try: @@ -105,29 +132,40 @@ def _load_experiment_raw( if expt is None: raise OSError("Failed to load the experiment file. 
Please report this issue.") - waves: dict[str, xr.DataArray] = {} - if isinstance(folder, str): - split_path = folder.split("/") - split_path = [n.encode() for n in split_path] + if folder is None: + split_path: list[bytes] = [] + else: + folder = folder.strip().strip("/") + split_path = [n.encode() for n in folder.split("/")] + + if ignore is None: + ignore = set() expt = expt["root"] for dirname in split_path: expt = expt[dirname] - def unpack_folders(expt) -> None: - for name, record in expt.items(): + def _unpack_folders(contents: dict, parent: str = "") -> dict[str, xr.DataArray]: + # drop: set = set() + waves: dict[str, xr.DataArray] = {} + + for name, record in contents.items(): + decoded_name = name.decode() if isinstance(name, bytes) else name + new_name = f"{parent}/{decoded_name}" if parent else decoded_name + if isinstance(record, igor2.record.WaveRecord): - if prefix is not None and not name.decode().startswith(prefix): + if prefix is not None and not decoded_name.startswith(prefix): continue - if name.decode() in ignore: + if decoded_name in ignore: continue - waves[name.decode()] = load_wave(record, **kwargs) - elif isinstance(record, dict): - if recursive: - unpack_folders(record) + waves[new_name] = load_wave(record, **kwargs) + + elif isinstance(record, dict) and recursive: + waves.update(_unpack_folders(record, new_name)) + + return waves - unpack_folders(expt) - return waves + return _unpack_folders(expt) def load_experiment( @@ -280,42 +318,36 @@ def load_wave( # data_units = d["data_units"].decode() axis_units = [d["dimension_units"].decode()] - def get_dim_name(index): - dim = dim_labels[index] - unit = axis_units[index] + coords = {} + for i, (a, b, c) in enumerate(zip(sfA, sfB, shape, strict=True)): + if c == 0: + continue + + dim, unit = dim_labels[i], axis_units[i] + if dim == "": if unit == "": - return DEFAULT_DIMS[index] - return unit - if unit == "": - return dim - return f"{dim} ({unit})" - - dims = [get_dim_name(i) for i in range(_MAXDIM)] - coords = { - dims[i]: np.linspace(b, b + a * (c - 1), c) - for i, (a, b, c) in enumerate(zip(sfA, sfB, shape, strict=True)) - if c != 0 - } - - attrs = {} - for ln in d.get("note", "").decode().splitlines(): + dim = DEFAULT_DIMS[i] + else: + # If dim is empty, but the unit is not, use the unit as the dim name + dim, unit = unit, "" + + coords[dim] = np.linspace(b, b + a * (c - 1), c) + if unit != "": + coords[dim] = xr.DataArray(coords[dim], dims=(dim,), attrs={"units": unit}) + + attrs: dict[str, int | float | str] = {} + for ln in d.get("note", b"").decode().splitlines(): if "=" in ln: - k, v = ln.split("=", 1) + key, value = ln.split("=", maxsplit=1) try: - v = int(v) + attrs[key] = int(value) except ValueError: - with contextlib.suppress(ValueError): - v = float(v) - attrs[k] = v + try: + attrs[key] = float(value) + except ValueError: + attrs[key] = value return xr.DataArray( d["wData"], dims=coords.keys(), coords=coords, attrs=attrs ).rename(wave_header["bname"].decode()) - - -load_pxp = load_experiment -"""Alias for :func:`load_experiment`.""" - -load_ibw = load_wave -"""Alias for :func:`load_wave`.""" diff --git a/src/erlab/io/nexusutils.py b/src/erlab/io/nexusutils.py index ae982f8c..a519b3eb 100644 --- a/src/erlab/io/nexusutils.py +++ b/src/erlab/io/nexusutils.py @@ -24,7 +24,7 @@ def _parse_value(value): return _parse_value(value.nxdata) if isinstance(value, np.ndarray) and value.size == 1: return _parse_value(np.atleast_1d(value)[0]) - if isinstance(value, np.number): + if isinstance(value, np.generic): # 
Convert to native Python type return value.item() return value @@ -77,7 +77,7 @@ def _parse_group( exclude List of full paths to exclude from the output. parse - Wheter to parse the values of NXfields to native Python types. + Whether to parse the values of NXfields to native Python types. Note ---- @@ -151,19 +151,19 @@ def get_primary_coords(group: NXgroup) -> list[NXfield]: Returns ------- - list of NXfield + fields_primary : list of NXfield """ - out: list[NXfield] = [] - _get_primary_coords(group, out) - return sorted(out, key=lambda field: int(field.axis)) + fields_primary: list[NXfield] = [] + _get_primary_coords(group, fields_primary) + return sorted(fields_primary, key=lambda field: int(field.axis)) def get_non_primary_coords(group: NXgroup) -> list[NXfield]: """Get all non-primary coordinates in a group. Retrieves all fields with the attribute `axis` in the group and its subgroups - recursively. The output list in the order of traversal. + recursively. The output list is sorted by the order of traversal. Parameters ---------- @@ -172,12 +172,12 @@ def get_non_primary_coords(group: NXgroup) -> list[NXfield]: Returns ------- - list of NXfield + fields_non_primary : list of NXfield """ - out: list[NXfield] = [] - _get_non_primary_coords(group, out) - return out + fields_non_primary: list[NXfield] = [] + _get_non_primary_coords(group, fields_non_primary) + return fields_non_primary def get_primary_coord_dict( @@ -311,13 +311,13 @@ def nexus_group_to_dict( exclude List of paths to exclude from the output. relative - Wheter to use the relative or absolute paths of the items. If `True`, the keys + Whether to use the relative or absolute paths of the items. If `True`, the keys are the paths of the items relative to the path of the group. If `False`, the keys are the absolute paths of the items relative to the root of the NeXus file. replace_slash - Wheter to replace the slashes in the paths with dots. + Whether to replace the slashes in the paths with dots. parse - Wheter to coerce the values of NXfields to native Python types. + Whether to coerce the values of NXfields to native Python types. """ if exclude is None: @@ -344,6 +344,10 @@ def nxfield_to_xarray(field: NXfield, no_dims: bool = False) -> xr.DataArray: field The NeXus field to be converted. + Returns + ------- + DataArray + """ attrs = _remove_axis_attrs(field.attrs) @@ -362,7 +366,9 @@ def nxfield_to_xarray(field: NXfield, no_dims: bool = False) -> xr.DataArray: def nxgroup_to_xarray( - group: NXgroup, data: str | Callable[[NXgroup], NXfield] + group: NXgroup, + data: str | Callable[[NXgroup], NXfield], + without_values: bool = False, ) -> xr.DataArray: """Convert a NeXus group to an xarray DataArray. @@ -379,10 +385,13 @@ def nxgroup_to_xarray( - If a callable, it must be a function that takes ``group`` as an argument and returns the `NXfield ` containing the data values. + without_values + If `True`, the returned DataArray values will be filled with zeros. Use this to + check the coords or attrs quickly without loading in the full data. Returns ------- - xarray.DataArray + DataArray The DataArray containing the data. Dimension and coordinate names are the relative paths of the corresponding NXfields, with the slashes replaced by dots. 
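Since `load_single` implementations can now receive `without_values=True` (see the loader changes elsewhere in this diff), the typical pattern for a lazy-loading backend is to keep coords and attrs but substitute zeros for the values. A minimal sketch, assuming a hypothetical `_open_lazy` helper that opens the file without reading the data values:

```python
import numpy as np
import xarray as xr


def load_single(self, file_path, without_values: bool = False) -> xr.DataArray:
    # ``_open_lazy`` is a stand-in for whatever lazy open the backend
    # provides (e.g. h5netcdf- or nexusformat-backed arrays).
    darr = self._open_lazy(file_path)
    if without_values:
        # Summary generation only inspects coords and attrs, so skip the
        # disk read and fill the values with zeros of the right shape.
        return xr.DataArray(
            np.zeros(darr.shape, darr.dtype),
            coords=darr.coords,
            dims=darr.dims,
            attrs=darr.attrs,
        )
    return darr.load()  # materialize the values for normal use
```

This is the same placeholder strategy `nxgroup_to_xarray` uses above when `without_values` is set.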
@@ -422,6 +431,9 @@ def _make_coord_relative(t: xr.DataArray | tuple) -> xr.DataArray | tuple: dims = tuple(_make_relative(d) for d in dims) coords = {_make_relative(k): _make_coord_relative(v) for k, v in coords.items()} + + if without_values: + values = np.zeros(values.shape, values.dtype) return xr.DataArray(values, dims=dims, coords=coords, attrs=attrs) @@ -436,6 +448,11 @@ def get_entry(filename: str | os.PathLike, entry: str | None = None) -> NXentry: The path of the entry to get. If `None`, the first entry in the file is returned. + Returns + ------- + entry : NXentry + The NXentry object obtained from the file. + """ root = nxload(filename) if entry is None: diff --git a/src/erlab/io/plugins/__init__.py b/src/erlab/io/plugins/__init__.py index a63f3450..cf0d82dc 100644 --- a/src/erlab/io/plugins/__init__.py +++ b/src/erlab/io/plugins/__init__.py @@ -24,9 +24,13 @@ import importlib import pathlib -import traceback import warnings + +class PluginImportWarning(UserWarning): + """Issued when a plugin fails to load.""" + + for path in pathlib.Path(__file__).resolve().parent.iterdir(): if ( path.is_file() @@ -38,7 +42,8 @@ importlib.import_module(module_name) except Exception: warnings.warn( - f"Failed to load module {module_name} due to the following error:\n" - f"{traceback.format_exc()}", + f"Failed to load '{module_name}'. " + f"Import the module to trace the error.", + PluginImportWarning, stacklevel=1, ) diff --git a/src/erlab/io/plugins/da30.py b/src/erlab/io/plugins/da30.py index a0aff5c6..fba01ff7 100644 --- a/src/erlab/io/plugins/da30.py +++ b/src/erlab/io/plugins/da30.py @@ -6,6 +6,7 @@ import configparser import os +import re import tempfile import zipfile from collections.abc import Iterable @@ -13,12 +14,14 @@ import numpy as np import xarray as xr -from xarray.core.datatree import DataTree +import erlab.io from erlab.io.dataloader import LoaderBase class CasePreservingConfigParser(configparser.ConfigParser): + """ConfigParser that preserves the case of the keys.""" + def optionxform(self, optionstr): return str(optionstr) @@ -42,13 +45,13 @@ def file_dialog_methods(self): return {"DA30 Raw Data (*.ibw *.pxt *.zip)": (self.load, {})} def load_single( - self, file_path: str | os.PathLike - ) -> xr.DataArray | xr.Dataset | DataTree: + self, file_path: str | os.PathLike, without_values: bool = False + ) -> xr.DataArray | xr.Dataset | xr.DataTree: ext = os.path.splitext(file_path)[-1] match ext: case ".ibw": - data: xr.DataArray | xr.Dataset | DataTree = xr.load_dataarray( + data: xr.DataArray | xr.Dataset | xr.DataTree = xr.load_dataarray( file_path, engine="erlab-igor" ) @@ -60,13 +63,35 @@ def load_single( data = data[next(iter(data.data_vars))] case ".zip": - data = load_zip(file_path) + data = load_zip(file_path, without_values) case _: raise ValueError(f"Unsupported file extension {ext}") return data + def identify(self, num: int, data_dir: str | os.PathLike): + for file in erlab.io.utils.get_files( + data_dir, extensions=(".ibw", ".pxt", ".zip") + ): + match file.suffix: + case ".zip": + m = re.match(r"(.*?)" + str(num).zfill(4), file.stem) + + case ".pxt": + m = re.match(r"(.*?)" + str(num).zfill(4), file.stem) + + case ".ibw": + m = re.match( + r"(.*?)" + str(num).zfill(4) + ".*" + str(num).zfill(3), + file.stem, + ) + + if m is not None: + return [file], {} + + return None + def post_process(self, data: xr.DataArray) -> xr.DataArray: data = super().post_process(data) @@ -75,10 +100,22 @@ def post_process(self, data: xr.DataArray) -> xr.DataArray: return data + def 
files_for_summary(self, data_dir: str | os.PathLike): + return sorted( + erlab.io.utils.get_files(data_dir, extensions=(".pxt", ".ibw", ".zip")) + ) + def load_zip( - filename: str | os.PathLike, -) -> xr.DataArray | xr.Dataset | DataTree: + filename: str | os.PathLike, without_values: bool = False +) -> xr.DataArray | xr.Dataset | xr.DataTree: + """Load data from a ``.zip`` file from a Scienta Omicron DA30 analyzer. + + If the file contains a single region, a DataArray is returned. If the file contains + multiple regions that can be merged without conflicts, a Dataset is returned. If the + regions cannot be merged without conflicts, a DataTree containing all regions is + returned. + """ with zipfile.ZipFile(filename) as z: regions: list[str] = [ fn[9:-4] @@ -90,7 +127,6 @@ def load_zip( with tempfile.TemporaryDirectory() as tmp_dir: z.extract(f"Spectrum_{region}.ini", tmp_dir) z.extract(f"{region}.ini", tmp_dir) - z.extract(f"Spectrum_{region}.bin", tmp_dir) region_info = parse_ini( os.path.join(tmp_dir, f"Spectrum_{region}.ini") @@ -99,9 +135,12 @@ def load_zip( for d in parse_ini(os.path.join(tmp_dir, f"{region}.ini")).values(): attrs.update(d) - arr = np.fromfile( - os.path.join(tmp_dir, f"Spectrum_{region}.bin"), dtype=np.float32 - ) + if not without_values: + z.extract(f"Spectrum_{region}.bin", tmp_dir) + arr = np.fromfile( + os.path.join(tmp_dir, f"Spectrum_{region}.bin"), + dtype=np.float32, + ) shape = [] coords = {} @@ -114,13 +153,13 @@ def load_zip( offset, offset + (n - 1) * delta, n ) + if not without_values: + arr = arr.reshape(shape) + else: + arr = np.zeros(shape, dtype=np.float32) + out.append( - xr.DataArray( - arr.reshape(shape), - coords=coords, - name=region_info["name"], - attrs=attrs, - ) + xr.DataArray(arr, coords=coords, name=region_info["name"], attrs=attrs) ) if len(out) == 1: @@ -131,16 +170,30 @@ def load_zip( return xr.merge(out, join="exact") except: # noqa: E722 # On failure, combine into DataTree - return DataTree.from_dict( + return xr.DataTree.from_dict( {str(da.name): da.to_dataset(promote_attrs=True) for da in out} ) +def _parse_value(value): + if isinstance(value, str): + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + pass + return value + + def parse_ini(filename: str | os.PathLike) -> dict: + """Parse an ``.ini`` file into a dictionary.""" parser = CasePreservingConfigParser(strict=False) out = {} with open(filename, encoding="utf-8") as f: parser.read_file(f) for section in parser.sections(): - out[section] = dict(parser.items(section)) + out[section] = {k: _parse_value(v) for k, v in parser.items(section)} return out diff --git a/src/erlab/io/plugins/esm.py b/src/erlab/io/plugins/esm.py index 0ccbdc91..78fc0afd 100644 --- a/src/erlab/io/plugins/esm.py +++ b/src/erlab/io/plugins/esm.py @@ -1,10 +1,7 @@ """Data loader for beamline ID21 ESM at NSLS-II.""" -import os -import re from typing import ClassVar -import erlab.io.utils from erlab.io.plugins.da30 import DA30Loader @@ -16,25 +13,3 @@ class ESMLoader(DA30Loader): coordinate_attrs = ("beta", "hv") additional_attrs: ClassVar[dict] = {"configuration": 3} - - def identify(self, num: int, data_dir: str | os.PathLike): - for file in erlab.io.utils.get_files( - data_dir, extensions=(".ibw", ".pxt", ".zip") - ): - match file.suffix: - case ".zip": - m = re.match(r"(.*?)" + str(num).zfill(4) + r".zip", file.name) - - case ".pxt": - m = re.match(r"(.*?)" + str(num).zfill(4) + r".pxt", file.name) - - case ".ibw": - m = re.match( - r"(.*?)" + 
str(num).zfill(4) + str(num).zfill(3) + r".ibw", - file.name, - ) - - if m is not None: - return [file], {} - - return None diff --git a/src/erlab/io/plugins/i05.py b/src/erlab/io/plugins/i05.py index ac700b96..fedaf9ea 100644 --- a/src/erlab/io/plugins/i05.py +++ b/src/erlab/io/plugins/i05.py @@ -27,7 +27,7 @@ class I05Loader(LoaderBase): "y": "instrument.manipulator.say", "z": "instrument.manipulator.saz", "hv": "instrument.monochromator.energy", - "temp_sample": "sample.temperature", + "sample_temp": "sample.temperature", } coordinate_attrs = ("beta", "delta", "chi", "xi", "hv", "x", "y", "z") @@ -40,8 +40,10 @@ class I05Loader(LoaderBase): def file_dialog_methods(self): return {"Diamond I05 Raw Data (*.nxs)": (self.load, {})} - def load_single(self, file_path) -> xr.DataArray: - out = nxgroup_to_xarray(get_entry(file_path), "analyser/data").squeeze() + def load_single(self, file_path, without_values=False) -> xr.DataArray: + out = nxgroup_to_xarray( + get_entry(file_path), "analyser/data", without_values + ).squeeze() if ( "instrument.centre_energy.centre_energy" in out.dims diff --git a/src/erlab/io/plugins/kriss.py b/src/erlab/io/plugins/kriss.py index a4f65112..f47a17ac 100644 --- a/src/erlab/io/plugins/kriss.py +++ b/src/erlab/io/plugins/kriss.py @@ -1,37 +1,16 @@ """Plugin for data acquired at KRISS.""" -import os -import re from typing import ClassVar -import erlab.io.utils from erlab.io.plugins.da30 import DA30Loader class KRISSLoader(DA30Loader): name = "kriss" - aliases = ("KRISS",) - coordinate_attrs = ("beta", "chi", "xi", "hv", "x", "y", "z") - additional_attrs: ClassVar[dict] = {"configuration": 4} @property def name_map(self): return super().name_map | {"chi": "ThetaY", "xi": "ThetaX"} - - def identify(self, num: int, data_dir: str | os.PathLike): - for file in erlab.io.utils.get_files(data_dir, extensions=(".ibw", ".zip")): - if file.suffix == ".zip": - match = re.match(r"(.*?)" + str(num).zfill(4) + r".zip", str(file)) - else: - match = re.match( - r"(.*?)" + str(num).zfill(4) + str(num).zfill(3) + r".ibw", - str(file), - ) - - if match is not None: - return [file], {} - - return None diff --git a/src/erlab/io/plugins/lorea.py b/src/erlab/io/plugins/lorea.py index d29100b3..6be5b123 100644 --- a/src/erlab/io/plugins/lorea.py +++ b/src/erlab/io/plugins/lorea.py @@ -33,7 +33,7 @@ class LOREALoader(LoaderBase): "y": "instrument.manipulator.say", "z": "instrument.manipulator.saz", "hv": "instrument.monochromator.energy", - "temp_sample": "sample.temperature", + "sample_temp": "sample.temperature", } coordinate_attrs = ("beta", "delta", "chi", "xi", "hv", "x", "y", "z") @@ -46,11 +46,11 @@ class LOREALoader(LoaderBase): def file_dialog_methods(self): return {"ALBA BL20 LOREA Raw Data (*.nxs, *.krx)": (self.load, {})} - def load_single(self, file_path) -> xr.DataArray: + def load_single(self, file_path, without_values: bool = False) -> xr.DataArray: if pathlib.Path(file_path).suffix == ".krx": return self._load_krx(file_path) - return nxgroup_to_xarray(get_entry(file_path), _get_data) + return nxgroup_to_xarray(get_entry(file_path), _get_data, without_values) def identify(self, num, data_dir, krax=False): file = None diff --git a/src/erlab/io/plugins/maestro.py b/src/erlab/io/plugins/maestro.py index bb7f09fd..8860f3b4 100644 --- a/src/erlab/io/plugins/maestro.py +++ b/src/erlab/io/plugins/maestro.py @@ -10,68 +10,66 @@ import os import re -import warnings from pathlib import Path from typing import TYPE_CHECKING, ClassVar import numpy as np import xarray as xr -# from 
xarray.backends.api import open_dataarray import erlab.io from erlab.io.dataloader import LoaderBase +from erlab.utils.misc import emit_user_level_warning if TYPE_CHECKING: from collections.abc import Hashable -def open_datatree(file_path, **kwargs): - """:meta private:""" # noqa: D400 - # Temporary fix for https://github.com/pydata/xarray/issues/9427 - from xarray.backends.api import open_groups - from xarray.core.datatree import DataTree - - return DataTree.from_dict(open_groups(file_path, **kwargs)) - - -def get_cache_file(file_path): +def get_cache_file(file_path: str | os.PathLike) -> Path: file_path = Path(file_path) data_dir = file_path.parent cache_dir = data_dir.with_name(f".{data_dir.name}_cache") return cache_dir.joinpath(file_path.stem + "_2D_Data" + file_path.suffix) -def cache_as_float32(file_path): +def cache_as_float32( + file_path: str | os.PathLike, data: xr.Dataset, without_values: bool +) -> xr.DataArray: """Cache and return the 2D part of the data as a float32 DataArray. + If the cache file exists, it is loaded and returned. + Loading MAESTRO `.h5` files is slow because the data is stored in double precision. - This function caches the 2D Data part in float32 to speed up subsequent loading. As - a consequence, the loader will fail in read-only file systems. + This function caches the 2D Data part in float32 to speed up subsequent loading. - """ - dt = open_datatree(file_path, engine="h5netcdf", phony_dims="sort") + Caching is disabled in read-only file systems. + """ cache_file = get_cache_file(file_path) + if cache_file.is_file(): - return None + return xr.open_dataarray(cache_file, engine="h5netcdf") - if not cache_file.parent.is_dir(): - cache_file.parent.mkdir(parents=True) + writable: bool = os.access(cache_file.parent.parent, os.W_OK) - data = dt["2D_Data"].load().to_dataset().astype(np.float32) + if writable and not cache_file.parent.is_dir() and not without_values: + cache_file.parent.mkdir(parents=True) if len(data.data_vars) > 1: - warnings.warn( + emit_user_level_warning( "More than one data variable is present in the data." 
"Only the first one will be used", - stacklevel=2, ) # Get the first data variable data = data[next(iter(data.data_vars))] - # Save cache - data.to_netcdf(cache_file, engine="h5netcdf") + if without_values: + data = xr.DataArray(np.zeros(data.shape), dims=data.dims, attrs=data.attrs) + + elif writable: + # Save cache + data = data.astype(np.float32) + data.to_netcdf(cache_file, engine="h5netcdf") return data @@ -82,18 +80,27 @@ class MAESTROMicroLoader(LoaderBase): aliases = ("ALS_BL7", "als_bl7", "BL702", "bl702") name_map: ClassVar[dict] = { - "x": "LMOTOR0", - "y": "LMOTOR1", - "z": "LMOTOR2", - "chi": "LMOTOR3", # Theta, polar - "xi": "LMOTOR4", # Beta, tilt - "delta": "LMOTOR5", # Phi, azimuth - "beta": ("Slit Defl", "LMOTOR9"), - "hv": ("MONOEV", "BL_E"), - "temp_sample": "Cryostat_A", - "polarization": "EPU Polarization", + "x": "Motors_Logical.X", + "y": "Motors_Logical.Y", + "z": "Motors_Logical.Z", + "chi": "Motors_Logical.Theta", # polar + "xi": "Motors_Logical.Beta", # tilt + "delta": "Motors_Logical.Phi", # azimuth + "beta": ("Slit Defl", "Motors_Logical.Slit Defl"), + "hv": ("MONOEV", "Beamline.Beamline Energy"), + "sample_temp": "Cryostat_D", } - coordinate_attrs = ("beta", "delta", "chi", "xi", "hv", "x", "y", "z") + coordinate_attrs = ( + "beta", + "delta", + "hv", + "sample_temp", + "chi", + "xi", + "x", + "y", + "z", + ) additional_attrs: ClassVar[dict] = {} skip_validate: bool = True @@ -114,14 +121,18 @@ def identify(self, num, data_dir): return [file], {} - def load_single(self, file_path) -> xr.DataArray: - cache_file = get_cache_file(file_path) - dt = open_datatree(file_path, engine="h5netcdf", phony_dims="sort") + def load_single(self, file_path, without_values: bool = False) -> xr.DataArray: + groups = xr.open_groups(file_path, engine="h5netcdf", phony_dims="sort") + + if "PreScan" in groups["/Comments"]: + pre_scan: str = groups["/Comments"]["PreScan"].item()[0].decode() + else: + pre_scan = "" - if "PreScan" in dt["Comments"]: - comment: str = dt["Comments"]["PreScan"].item()[0].decode() + if "PostScan" in groups["/Comments"]: + post_scan: str = groups["/Comments"]["PostScan"].item()[0].decode() else: - comment = "" + post_scan = "" def _parse_attr(v) -> str | int | float: """Strip quotes and convert numerical strings to int or float.""" @@ -138,8 +149,7 @@ def _parse_attr(v) -> str | int | float: return v nested_attrs: dict[Hashable, dict[str, tuple[str, str | int | float]]] = {} - combined_attrs: dict[str, str | int | float] = {} - for key, val in dt["Headers"].data_vars.items(): + for key, val in groups["/Headers"].data_vars.items(): # v given as (longname, name, value, comment) # we want to extract the name, comment and value nested_attrs[key] = { @@ -147,47 +157,49 @@ def _parse_attr(v) -> str | int | float: for v in val.values } - combined_attrs = { - **combined_attrs, - **{k: v[1] for k, v in nested_attrs[key].items()}, - } + human_readable_attrs: dict[str, str | int | float] = {} + # Final attributes are stored here + # Keys are in the form "group_name.commment" - if "LWLVNM" in combined_attrs: - scan_type: str = str(combined_attrs["LWLVNM"]) + for group_name, contents_dict in nested_attrs.items(): + for k, v in contents_dict.items(): + new_key = f"{group_name}.{k}" if v[0] == "" else f"{group_name}.{v[0]}" + human_readable_attrs[new_key] = v[1] - lwlvlpn = int(combined_attrs["LWLVLPN"]) # number of low level loops + scan_attrs: dict[str, str | int | float] = { + k: v[1] for k, v in nested_attrs.get("Low_Level_Scan", {}).items() + } + + if "LWLVNM" in 
scan_attrs: + scan_type: str = str(scan_attrs["LWLVNM"]) + + lwlvlpn = int(scan_attrs["LWLVLPN"]) # number of low level loops motors: list[str] = [] motor_shape: list[int] = [] for i in range(lwlvlpn): nmsbdv = int( - combined_attrs[f"NMSBDV{i}"] + scan_attrs[f"NMSBDV{i}"] ) # number of subdevices in i-th loop for j in range(nmsbdv): nm = str( - combined_attrs[f"NM_{i}_{j}"] + scan_attrs[f"NM_{i}_{j}"] ) # name of j-th subdevice in i-th loop nm = nm.replace("CRYO-", "").strip() motors.append(nm) - motor_shape.append(int(combined_attrs[f"N_{i}_{j}"])) + motor_shape.append(int(scan_attrs[f"N_{i}_{j}"])) else: scan_type = "unknown" motors = ["XY"] # Get coords - coords = ( - dt["0D_Data"].load().rename({"phony_dim_0": "phony_dim_3"}).to_dataset() - ) + coords = groups["/0D_Data"].rename({"phony_dim_0": "phony_dim_3"}) if len(motors) == 1: coords = coords.swap_dims({"phony_dim_3": motors[0]}) - if cache_file.is_file(): - # Load cache - data = xr.load_dataarray(cache_file, engine="h5netcdf") - else: - # Create cache - data = cache_as_float32(file_path) + # Create or load cache + data = cache_as_float32(file_path, groups["/2D_Data"], without_values) coord_dict = { name: np.linspace(offset, offset + (size - 1) * delta, size) @@ -214,13 +226,13 @@ def _parse_attr(v) -> str | int | float: .unstack("phony_dim_3") ) - # The configuration is hardcoded to 3, which is for vertical analyzer slit and - # deflector map. For horizontal slit configuration or beta maps, coordinates and - # the attribute must be changed accordingly. + # The configuration is hardcoded to 3, which is for vertical analyzer slit with + # deflector map. For horizontal slit configuration and/or tilt/polar maps, + # coordinates and the attribute must be changed accordingly. data.attrs = { "scan_type": scan_type, - "comment": comment, + "pre_scan": pre_scan, + "post_scan": post_scan, "configuration": 3, - "nested_attrs": nested_attrs, } - return data.assign_attrs(combined_attrs).squeeze() + return data.assign_attrs(human_readable_attrs).squeeze() diff --git a/src/erlab/io/plugins/merlin.py b/src/erlab/io/plugins/merlin.py index bf988333..eb4b678a 100644 --- a/src/erlab/io/plugins/merlin.py +++ b/src/erlab/io/plugins/merlin.py @@ -6,11 +6,10 @@ import re import warnings from collections.abc import Callable -from typing import Any, ClassVar, cast +from typing import Any, ClassVar import numpy as np import numpy.typing as npt -import pandas as pd import xarray as xr import erlab.io.utils @@ -22,6 +21,27 @@ def _format_polarization(val) -> str: return {0: "LH", 2: "LV", -1: "RC", 1: "LC"}.get(val, str(val)) +def _parse_time(darr: xr.DataArray) -> datetime.datetime: + return datetime.datetime.strptime( + f"{darr.attrs['Date']} {darr.attrs['Time']}", + "%d/%m/%Y %I:%M:%S %p", + ) + + +def _determine_kind(data: xr.DataArray) -> str: + if "scan_type" in data.attrs and data.attrs["scan_type"] == "live": + return "LP" if "beta" in data.dims else "LXY" + + data_type = "xps" + if "alpha" in data.dims: + data_type = "cut" + if "beta" in data.dims: + data_type = "map" + if "hv" in data.dims: + data_type = "hvdep" + return data_type + + class MERLINLoader(LoaderBase): name = "merlin" @@ -37,7 +57,7 @@ class MERLINLoader(LoaderBase): "z": "Sample Z", "hv": "BL Energy", "polarization": "EPU POL", - "temp_sample": "Temperature Sensor B", + "sample_temp": "Temperature Sensor B", "mesh_current": "Mesh Current", } coordinate_attrs = ( @@ -50,12 +70,10 @@ class MERLINLoader(LoaderBase): "z", "polarization", "mesh_current", - "temp_sample", + "sample_temp", 
) - additional_attrs: ClassVar[dict] = { - "configuration": 1, - "sample_workfunction": 4.44, - } + additional_attrs: ClassVar[dict] = {"configuration": 1} + formatters: ClassVar[dict[str, Callable]] = { "polarization": _format_polarization, "Lens Mode": lambda x: x.replace("Angular", "A"), @@ -63,13 +81,38 @@ class MERLINLoader(LoaderBase): "Exit Slit": round, "Slit Plate": round, } + + summary_attrs: ClassVar[dict[str, str | Callable[[xr.DataArray], Any]]] = { + "time": _parse_time, + "type": _determine_kind, + "lens mode": "Lens Mode", + "mode": "Acquisition Mode", + "temperature": "sample_temp", + "pass energy": "Pass Energy", + "analyzer slit": "Slit Plate", + "pol": "polarization", + "hv": "hv", + "entrance slit": "Entrance Slit", + "exit slit": "Exit Slit", + "polar": "beta", + "tilt": "xi", + "azi": "delta", + "x": "x", + "y": "y", + "z": "z", + } + + summary_sort = "time" + always_single = False @property def file_dialog_methods(self): return {"ALS BL4.0.3 Raw Data (*.pxt *.ibw)": (self.load, {})} - def load_single(self, file_path: str | os.PathLike) -> xr.DataArray: + def load_single( + self, file_path: str | os.PathLike, without_values: bool = False + ) -> xr.DataArray: if os.path.splitext(file_path)[1] == ".ibw": return self._load_live(file_path) @@ -142,11 +185,6 @@ def post_process(self, data: xr.DataArray) -> xr.DataArray: if "eV" in data.coords: data = data.assign_coords(eV=-data.eV.values) - if "temp_sample" in data.coords: - # Add temperature to attributes, for backwards compatibility - temp = float(data.temp_sample.mean()) - data = data.assign_attrs(temp_sample=temp) - return data def load_live(self, identifier, data_dir): @@ -166,100 +204,8 @@ def _load_live(self, file_path: str | os.PathLike) -> xr.DataArray: if k in wave.dims } ) + wave = wave.assign_attrs(scan_type="live") return wave.assign_coords(eV=-wave["eV"] + wave.attrs["BL Energy"]) - def generate_summary( - self, data_dir: str | os.PathLike, exclude_live: bool = False - ) -> pd.DataFrame: - files: dict[str, str] = {} - - for pth in erlab.io.utils.get_files(data_dir, extensions=(".pxt",)): - data_name = pth.stem - name_match = re.match(r"(.*?_\d{3})_(?:_S\d{3})?", data_name) - if name_match is not None: - data_name = name_match.group(1) - files[data_name] = str(pth) - - if not exclude_live: - for file in os.listdir(data_dir): - if file.endswith(".ibw"): - data_name = os.path.splitext(file)[0] - files[data_name] = os.path.join(data_dir, file) - - summary_attrs: dict[str, str] = { - "Lens Mode": "Lens Mode", - "Scan Type": "Acquisition Mode", - "T(K)": "temp_sample", - "Pass E": "Pass Energy", - "Analyzer Slit": "Slit Plate", - "Polarization": "polarization", - "hv": "hv", - "Entrance Slit": "Entrance Slit", - "Exit Slit": "Exit Slit", - "x": "x", - "y": "y", - "z": "z", - "polar": "beta", - "tilt": "xi", - "azi": "delta", - } - - cols = ["File Name", "Path", "Time", "Type", *summary_attrs.keys()] - - data_info = [] - processed_indices: list[int] = [] - - def _add_darr(dname: str, file: str, darr: xr.DataArray, live: bool = False): - if live: - data_type = "LP" if "beta" in darr.dims else "LXY" - else: - data_type = "core" - if "alpha" in darr.dims: - data_type = "cut" - if "beta" in darr.dims: - data_type = "map" - if "hv" in darr.dims: - data_type = "hvdep" - data_info.append( - [ - dname, - file, - datetime.datetime.strptime( - f"{darr.attrs['Date']} {darr.attrs['Time']}", - "%d/%m/%Y %I:%M:%S %p", - ), - data_type, - *( - self.get_formatted_attr_or_coord(darr, k) - for k in summary_attrs.values() - ), - ] - 
) - - for name, path in files.items(): - if os.path.splitext(path)[1] == ".ibw": - _add_darr( - name, path, darr=cast(xr.DataArray, self.load(path)), live=True - ) - else: - idx, _ = self.infer_index(os.path.splitext(os.path.basename(path))[0]) - if idx in processed_indices: - continue - - if idx is not None: - processed_indices.append(idx) - - data = cast(xr.DataArray | xr.Dataset, self.load(path)) - - if isinstance(data, xr.Dataset): - for k, darr in data.data_vars.items(): - _add_darr(f"{name}_{k}", path, darr) - else: - _add_darr(name, path, data) - del data - - return ( - pd.DataFrame(data_info, columns=cols) - .sort_values("Time") - .set_index("File Name") - ) + def files_for_summary(self, data_dir: str | os.PathLike): + return sorted(erlab.io.utils.get_files(data_dir, extensions=(".pxt", ".ibw"))) diff --git a/src/erlab/io/plugins/ssrl52.py b/src/erlab/io/plugins/ssrl52.py index fe503935..982bdd65 100644 --- a/src/erlab/io/plugins/ssrl52.py +++ b/src/erlab/io/plugins/ssrl52.py @@ -3,16 +3,27 @@ import datetime import os import re -import warnings -from typing import ClassVar, cast +from collections.abc import Callable +from typing import Any, ClassVar import h5netcdf import numpy as np -import pandas as pd import xarray as xr import erlab.io.utils from erlab.io.dataloader import LoaderBase +from erlab.utils.misc import emit_user_level_warning + + +def _format_polarization(val) -> str: + val = float(np.round(val, 3)) + return {0.0: "LH", 0.5: "LV", 0.25: "RC", -0.25: "LC"}.get(val, str(val)) + + +def _parse_value(value): + if isinstance(value, np.generic): + return value.item() + return value class SSRL52Loader(LoaderBase): @@ -30,7 +41,7 @@ class SSRL52Loader(LoaderBase): "y": "Y", "z": "Z", "hv": ["energy", "photon_energy"], - "temp_sample": ["TB", "sample_stage_temperature"], + "sample_temp": ["TB", "sample_stage_temperature"], "sample_workfunction": "WorkFunction", } @@ -41,6 +52,32 @@ class SSRL52Loader(LoaderBase): "sample_workfunction": 4.5, } + formatters: ClassVar[dict[str, Callable]] = { + "CreationTimeStamp": datetime.datetime.fromtimestamp, + "PassEnergy": round, + "polarization": _format_polarization, + } + + summary_attrs: ClassVar[dict[str, str | Callable[[xr.DataArray], Any]]] = { + "time": "CreationTimeStamp", + "type": "Description", + "lens mode": "LensModeName", + "region": "RegionName", + "temperature": "sample_temp", + "pass energy": "PassEnergy", + "pol": "polarization", + "hv": "hv", + "polar": "chi", + "tilt": "xi", + "azi": "delta", + "deflector": "beta", + "x": "x", + "y": "y", + "z": "z", + } + + summary_sort = "time" + always_single: bool = True skip_validate: bool = True @@ -48,7 +85,9 @@ class SSRL52Loader(LoaderBase): def file_dialog_methods(self): return {"SSRL BL5-2 Raw Data (*.h5)": (self.load, {})} - def load_single(self, file_path: str | os.PathLike) -> xr.DataArray: + def load_single( + self, file_path: str | os.PathLike, without_values: bool = False + ) -> xr.DataArray: is_hvdep: bool = False dim_mapping: dict[str, str] = {} @@ -92,7 +131,7 @@ def load_single(self, file_path: str | os.PathLike) -> xr.DataArray: dim_mapping = { f"phony_dim_{i}": str(ax["label"]) for i, ax in enumerate(axes) } - data = ds.rename_dims(dim_mapping).load() + data = ds.rename_dims(dim_mapping) # Apply coordinates for i, ax in enumerate(axes): @@ -114,12 +153,11 @@ def load_single(self, file_path: str | os.PathLike) -> xr.DataArray: # For now, just ignore them and use beamline attributes continue if ax["label"] != "Kinetic Energy": - warnings.warn( + 
emit_user_level_warning( "Undefined offset for non-energy axis. This was " "not taken into account while writing the loader " "code. Please report this issue. Resulting data " "may be incorrect", - stacklevel=1, ) continue is_hvdep = True @@ -139,12 +177,11 @@ def load_single(self, file_path: str | os.PathLike) -> xr.DataArray: delta = np.array(ncf["MapInfo"][ax["delta"][8:]]) # may be ~1e-8 difference between values if not np.allclose(delta, delta[0], atol=1e-7): - warnings.warn( + emit_user_level_warning( "Non-uniform delta for hv-dependent scan. This " "was not taken into account while writing the " "loader code. Please report this issue. " "Resulting data may be incorrect", - stacklevel=1, ) delta = delta[0] else: @@ -163,6 +200,8 @@ def load_single(self, file_path: str | os.PathLike) -> xr.DataArray: data = data.assign_coords({ax["label"]: coord}) + attrs = {k: _parse_value(v) for k, v in attrs.items()} + coord_names = list(data.coords.keys()) coord_sizes = [len(data[coord]) for coord in coord_names] coord_attrs: dict = {} @@ -183,44 +222,48 @@ def load_single(self, file_path: str | os.PathLike) -> xr.DataArray: same_length_indices.remove(idx) if len(same_length_indices) != 1: # Multiple dimensions with the same length, ambiguous - warnings.warn( + emit_user_level_warning( f"Ambiguous length for {k}. This was not taken into " "account while writing the loader code. Please report this " "issue. Resulting data may be incorrect", - stacklevel=1, ) idx = same_length_indices[-1] coord_attrs[k] = xr.DataArray(var, dims=[coord_names[idx]]) - if is_hvdep: - data = data.assign_coords( - { - "Binding Energy": data["Binding Energy"] - - data["energy"].values[0] - + attrs.get("WorkFunction", 4.465) - } - ) + if is_hvdep: + data = data.assign_coords( + { + "Binding Energy": data["Binding Energy"] + - data["energy"].values[0] + + attrs.get("WorkFunction", 4.465) + } + ) - # data = data.rename(energy="hv") + # data = data.rename(energy="hv") - if "time" in data.variables: - # Normalize by dwell time - darr = data["spectrum"] / data["time"] - else: darr = data["spectrum"] - darr = darr.assign_attrs(attrs) + if not without_values: + darr = darr.load() # Load into memory before closing file + if "time" in data.variables: + # Normalize by dwell time + darr = darr / data["time"] + + else: + darr = xr.DataArray( + np.zeros(darr.shape, darr.dtype), + coords=darr.coords, + dims=darr.dims, + attrs=darr.attrs, + ) + + darr = darr.assign_attrs(attrs) return darr.assign_coords(coord_attrs) def post_process(self, data: xr.DataArray) -> xr.DataArray: data = super().post_process(data) - if "temp_sample" in data.coords: - # Add temperature to attributes - temp = float(data.temp_sample.mean()) - data = data.assign_attrs(temp_sample=temp) - # Convert to binding energy if ( "sample_workfunction" in data.attrs @@ -258,88 +301,5 @@ def identify( def load_zap(self, identifier, data_dir): return self.load(identifier, data_dir, zap=True) - def generate_summary( - self, data_dir: str | os.PathLike, exclude_zap: bool = False - ) -> pd.DataFrame: - files: dict[str, str] = {} - - if exclude_zap: - target_files = erlab.io.utils.get_files( - data_dir, extensions=(".h5",), notcontains="zap" - ) - else: - target_files = erlab.io.utils.get_files(data_dir, extensions=(".h5",)) - - for path in target_files: - files[path.stem] = str(path) - - summary_attrs: dict[str, str] = { - "Type": "Description", - "Lens Mode": "LensModeName", - "Region": "RegionName", - "T(K)": "temp_sample", - "Pass E": "PassEnergy", - "Polarization": 
"polarization", - "hv": "hv", - # "Entrance Slit": "Entrance Slit", - # "Exit Slit": "Exit Slit", - "x": "x", - "y": "y", - "z": "z", - "polar": "chi", - "tilt": "xi", - "azi": "delta", - "DA": "beta", - } - - cols = ["File Name", "Path", "Time", *summary_attrs.keys()] - - data_info = [] - - for name, path in files.items(): - data = cast(xr.DataArray, self.load(path)) - - data_info.append( - [ - name, - path, - datetime.datetime.fromtimestamp(data.attrs["CreationTimeStamp"]), - ] - ) - - for k, v in summary_attrs.items(): - try: - val = data.attrs[v] - except KeyError: - try: - val = data.coords[v].values - if val.size == 1: - val = val.item() - except KeyError: - val = "" - - if k == "Pass E": - val = round(val) - - elif k == "Polarization": - if np.iterable(val): - val = np.round(np.asarray(val), 3).astype(float) - else: - val = [float(np.round(val, 3))] - val = [ - {0.0: "LH", 0.5: "LV", 0.25: "RC", -0.25: "LC"}.get(v, v) - for v in val - ] - - if len(val) == 1: - val = val[0] - - data_info[-1].append(val) - - del data - - return ( - pd.DataFrame(data_info, columns=cols) - .sort_values("Time") - .set_index("File Name") - ) + def files_for_summary(self, data_dir): + return sorted(erlab.io.utils.get_files(data_dir, extensions=(".h5",))) diff --git a/src/erlab/io/utilities.py b/src/erlab/io/utilities.py deleted file mode 100644 index 8cbfef5f..00000000 --- a/src/erlab/io/utilities.py +++ /dev/null @@ -1,10 +0,0 @@ -import warnings - -from erlab.io.utils import * # noqa: F403 - -warnings.warn( - "`erlab.io.utilities` has been moved to `erlab.io.utils` " - "and will be removed in a future release", - DeprecationWarning, - stacklevel=2, -) diff --git a/src/erlab/io/utils.py b/src/erlab/io/utils.py index 317c7c8f..8e9633fb 100644 --- a/src/erlab/io/utils.py +++ b/src/erlab/io/utils.py @@ -10,13 +10,14 @@ import importlib.util import os import pathlib -import warnings from collections.abc import Sequence import numpy as np import numpy.typing as npt import xarray as xr +from erlab.utils.misc import emit_user_level_warning + def showfitsinfo(path: str | os.PathLike) -> None: """Print raw metadata from a ``.fits`` file. @@ -45,6 +46,7 @@ def get_files( extensions: Sequence[str] | str | None = None, contains: str | None = None, notcontains: str | None = None, + exclude: str | Sequence[str] | None = None, ) -> set[pathlib.Path]: """Return file names in a directory with the given extension(s). @@ -60,6 +62,8 @@ def get_files( String to filter for in the file names. notcontains String to filter out of the file names. + exclude + Glob patterns to exclude from the search. Returns ------- @@ -72,7 +76,18 @@ def get_files( if isinstance(extensions, str): extensions = [extensions] - for f in pathlib.Path(directory).iterdir(): + dir_path = pathlib.Path(directory) + + excluded: list[pathlib.Path] = [] + + if exclude is not None: + if isinstance(exclude, str): + exclude = [exclude] + + for pattern in exclude: + excluded = excluded + list(dir_path.glob(pattern)) + + for f in dir_path.iterdir(): if ( f.is_dir() or ( @@ -84,7 +99,8 @@ def get_files( ): continue - files.add(f) + if f not in excluded: + files.add(f) return files @@ -111,17 +127,15 @@ def fix_attr_format(da: xr.DataArray): if not isValid: try: da = da.assign_attrs({key: str(da.attrs[key])}) - warnings.warn( + emit_user_level_warning( f"The attribute {key} with invalid type {dt}" " will be converted to string", - stacklevel=1, ) except TypeError: # this is VERY unprobable... 
da = da.assign_attrs({key: ""}) - warnings.warn( + emit_user_level_warning( f"The attribute {key} with invalid type {dt} will be removed", - stacklevel=1, ) return da diff --git a/src/erlab/plotting/general.py b/src/erlab/plotting/general.py index a1790888..b38a8eef 100644 --- a/src/erlab/plotting/general.py +++ b/src/erlab/plotting/general.py @@ -17,7 +17,6 @@ import contextlib import copy -import warnings from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Literal, Union, cast @@ -41,6 +40,7 @@ nice_colorbar, ) from erlab.utils.array import is_dims_uniform +from erlab.utils.misc import emit_user_level_warning if TYPE_CHECKING: from collections.abc import Callable, Collection, Sequence @@ -154,11 +154,10 @@ def array_extent( for dim, coord in zip(darr.dims, data_coords, strict=True): dif = np.diff(coord) if not np.allclose(dif, dif[0], rtol=rtol, atol=atol): - warnings.warn( + emit_user_level_warning( f"Coordinates for {dim} are not evenly spaced, and the plot may not be " "accurate. Use `DataArray.plot`, `xarray.plot.pcolormesh` or " "`matplotlib.pyplot.pcolormesh` for non-evenly spaced data.", - stacklevel=2, ) data_incs = tuple(coord[1] - coord[0] for coord in data_coords) diff --git a/src/erlab/utils/__init__.py b/src/erlab/utils/__init__.py index 1fe5b898..eb2b1306 100644 --- a/src/erlab/utils/__init__.py +++ b/src/erlab/utils/__init__.py @@ -1,6 +1,9 @@ """ Generic utilities used in various parts of the package. +Most of the functions in this module are used internally and are not intended to be used +directly by the user. + .. currentmodule:: erlab.utils Modules @@ -12,5 +15,6 @@ array parallel formatting + misc """ diff --git a/src/erlab/utils/formatting.py b/src/erlab/utils/formatting.py index e5536213..042e271b 100644 --- a/src/erlab/utils/formatting.py +++ b/src/erlab/utils/formatting.py @@ -1,4 +1,4 @@ -"""Utilites related to representing data in a human-readable format.""" +"""Utilities related to representing data in a human-readable format.""" __all__ = ["format_html_table", "format_value"] import datetime @@ -50,7 +50,7 @@ def format_html_table( def format_value( - val: object, precision: int = 4, use_unicode_minus: bool = True + val: object, precision: int = 4, use_unicode_minus: bool = False ) -> str: """Format the given value based on its type. @@ -96,8 +96,11 @@ def format_value( `format_value(val[0])`. - If the array is not monotonic, the minimum and maximum values are formatted and returned as a string in the format "min~max". + - If the array has two elements, the two elements are formatted and + returned. - - For arrays with more dimensions, the array is returned as is. + - For arrays with more dimensions, the minimum and maximum values are formatted + and returned as a string in the format "min~max". 
- For lists: The list is grouped by consecutive equal elements, and the count of each element @@ -150,8 +153,11 @@ def _format(val: object) -> str: if val.size == 1: return _format(val.item()) - if val.squeeze().ndim == 1: - val = val.squeeze() + val = val.squeeze() + + if val.ndim == 1: + if len(val) == 2: + return f"[{_format(val[0])}, {_format(val[1])}]" if is_uniform_spaced(val): start, end, step = tuple( @@ -168,10 +174,8 @@ def _format(val: object) -> str: return f"{_format(val[0])}→{_format(val[-1])} ({len(val)})" - mn, mx = tuple(_format(v) for v in (np.min(val), np.max(val))) - return f"{mn}~{mx} ({len(val)})" - - return str(val) + mn, mx = tuple(_format(v) for v in (np.nanmin(val), np.nanmax(val))) + return f"{mn}~{mx} ({len(val)})" if isinstance(val, list): return ", ".join( @@ -192,6 +196,10 @@ def _format(val: object) -> str: return str(val).replace("-", "−") return str(val) + if isinstance(val, np.generic): + # Convert to native Python type + return _format(val.item()) + if isinstance(val, datetime.datetime): return val.strftime("%Y-%m-%d %H:%M:%S") diff --git a/src/erlab/utils/misc.py b/src/erlab/utils/misc.py new file mode 100644 index 00000000..92074a84 --- /dev/null +++ b/src/erlab/utils/misc.py @@ -0,0 +1,57 @@ +import inspect +import pathlib +import sys +import warnings + +import xarray + + +def _find_stack_level() -> int: + """Find the first place in the stack that is not inside erlab, xarray, or stdlib. + + This is unless the code emanates from a test, in which case we would prefer to see + the source. + + This function is adapted from xarray.core.utils.find_stack_level. + + Returns + ------- + stacklevel : int + First level in the stack that is not part of erlab or stdlib. + """ + import erlab + + xarray_dir = pathlib.Path(xarray.__file__).parent + pkg_dir = pathlib.Path(erlab.__file__).parent.parent.parent + test_dir = pkg_dir / "tests" + + std_lib_init = sys.modules["os"].__file__ + if std_lib_init is None: + return 0 + + std_lib_dir = pathlib.Path(std_lib_init).parent + + frame = inspect.currentframe() + n = 0 + while frame: + fname = inspect.getfile(frame) + if ( + (fname.startswith(str(pkg_dir)) and not fname.startswith(str(test_dir))) + or ( + fname.startswith(str(std_lib_dir)) + and "site-packages" not in fname + and "dist-packages" not in fname + ) + or fname.startswith(str(xarray_dir)) + ): + frame = frame.f_back + n += 1 + else: + break + return n + + +def emit_user_level_warning(message, category=None) -> None: + """Emit a warning at the user level by inspecting the stack trace.""" + stacklevel = _find_stack_level() + return warnings.warn(message, category=category, stacklevel=stacklevel) diff --git a/src/erlab/utils/parallel.py b/src/erlab/utils/parallel.py index 787e5a3a..69950e75 100644 --- a/src/erlab/utils/parallel.py +++ b/src/erlab/utils/parallel.py @@ -8,7 +8,6 @@ import joblib import joblib._parallel_backends import tqdm.auto -from qtpy import QtCore @contextlib.contextmanager @@ -35,7 +34,7 @@ def tqdm_print_progress(self) -> None: @contextlib.contextmanager -def joblib_progress_qt(signal: QtCore.Signal): +def joblib_progress_qt(signal): """Context manager for interactive windows. The number of completed tasks are emitted by the given signal. 
diff --git a/tests/accessors/test_general.py b/tests/accessors/test_general.py index 7027a4e1..c8df710c 100644 --- a/tests/accessors/test_general.py +++ b/tests/accessors/test_general.py @@ -98,6 +98,22 @@ def test_qsel_slice_with_width(): dat.qsel({"x": slice(1.0, 3.0), "x_width": 1.0}) +def test_qsel_associated_dim(): + dat = xr.DataArray( + np.arange(25).reshape(5, 5), + dims=("x", "y"), + coords={"x": np.arange(5), "y": np.arange(5), "z": ("x", np.arange(5))}, + ) + xr.testing.assert_identical( + dat.qsel(x=2, x_width=3), + xr.DataArray( + np.array([10.0, 11.0, 12.0, 13.0, 14.0]), + dims=("y",), + coords={"y": np.arange(5), "x": 2.0, "z": 2.0}, + ), + ) + + def test_qsel_value_outside_bounds(): dat = xr.DataArray( np.arange(25).reshape(5, 5), diff --git a/tests/analysis/fit/test_models.py b/tests/analysis/fit/test_models.py index 8fb866a7..7f814f13 100644 --- a/tests/analysis/fit/test_models.py +++ b/tests/analysis/fit/test_models.py @@ -53,7 +53,7 @@ def test_fermi_edge_2d_model(): dims=["eV", "alpha"], coords={"eV": eV, "alpha": alpha}, ) - data.attrs["temp_sample"] = 300.0 + data.attrs["sample_temp"] = 300.0 # Create an instance of FermiEdge2dModel model = models.FermiEdge2dModel(degree=2) diff --git a/tests/analysis/test_gold.py b/tests/analysis/test_gold.py index e7e3f9fa..6069352e 100644 --- a/tests/analysis/test_gold.py +++ b/tests/analysis/test_gold.py @@ -3,7 +3,7 @@ import pytest from numpy.testing import assert_allclose -from erlab.analysis.gold import correct_with_edge, poly, quick_fit, spline +from erlab.analysis.gold import correct_with_edge, poly, quick_resolution, spline @pytest.mark.parametrize("parallel_kw", [None, {"n_jobs": 1, "return_as": "list"}]) @@ -59,10 +59,10 @@ def test_spline(gold): @pytest.mark.parametrize("resolution", [None, 1e-2]) @pytest.mark.parametrize("temp", [None, 100.0]) @pytest.mark.parametrize("eV_range", [None, (-0.2, 0.2)]) -def test_quick_fit( +def test_quick_resolution( gold, eV_range, temp, resolution, fix_temp, fix_center, fix_resolution, bkg_slope ): - ds = quick_fit( + ds = quick_resolution( gold, eV_range=eV_range, temp=temp, @@ -72,4 +72,5 @@ def test_quick_fit( fix_resolution=fix_resolution, bkg_slope=bkg_slope, ) + plt.close() assert ds.modelfit_results.item().success diff --git a/tests/analysis/test_interpolate.py b/tests/analysis/test_interpolate.py index 68738182..99c4720c 100644 --- a/tests/analysis/test_interpolate.py +++ b/tests/analysis/test_interpolate.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import scipy.interpolate import xarray as xr @@ -17,12 +18,25 @@ def value_func_3d(x, y, z): return 2 * x + 3 * y - z -def test_interpn_1d(): - x = np.linspace(0, 4, 5) - points = (x,) - values = value_func_1d(*np.meshgrid(*points, indexing="ij")) - point = np.array([[2.21], [2.67]]) +def value_func_4d(x, y, z, w): + return 2 * x + 3 * y - z + w + + +x = np.linspace(0, 4, 5) +y = np.linspace(0, 5, 6) +z = np.linspace(0, 6, 7) +w = np.linspace(0, 7, 8) + +values_1d = value_func_1d(*np.meshgrid(x, indexing="ij")) +values_2d = value_func_2d(*np.meshgrid(x, y, indexing="ij")) +values_3d = value_func_3d(*np.meshgrid(x, y, z, indexing="ij")) +values_4d = value_func_4d(*np.meshgrid(x, y, z, w, indexing="ij")) + +@pytest.mark.parametrize("values", [values_1d, values_2d, values_3d, values_4d]) +@pytest.mark.parametrize("point", [np.array([2.21, 2.67]), np.array([[2.21], [2.67]])]) +def test_interpn_1d(values, point): + points = (x,) assert np.allclose( interpn(points, values, point), scipy.interpolate.interpn( @@ -31,13 +45,12 @@ def 
test_interpn_1d(): ) -def test_interpn_2d(): - x = np.linspace(0, 4, 5) - y = np.linspace(0, 5, 6) +@pytest.mark.parametrize("values", [values_2d, values_3d, values_4d]) +@pytest.mark.parametrize( + "point", [np.array([2.21, 3.12]), np.array([[2.21, 3.12], [2.67, 3.54]])] +) +def test_interpn_2d(values, point): points = (x, y) - values = value_func_2d(*np.meshgrid(*points, indexing="ij")) - point = np.array([[2.21, 3.12], [2.67, 3.54]]) - assert np.allclose( interpn(points, values, point), scipy.interpolate.interpn( @@ -46,14 +59,10 @@ def test_interpn_2d(): ) -def test_interpn_3d(): - x = np.linspace(0, 4, 5) - y = np.linspace(0, 5, 6) - z = np.linspace(0, 6, 7) +@pytest.mark.parametrize("values", [values_3d, values_4d]) +@pytest.mark.parametrize("point", [np.array([[2.21, 3.12, 1.15], [2.67, 3.54, 1.03]])]) +def test_interpn_3d(values, point): points = (x, y, z) - values = value_func_3d(*np.meshgrid(*points, indexing="ij")) - point = np.array([[2.21, 3.12, 1.15], [2.67, 3.54, 1.03]]) - assert np.allclose( interpn(points, values, point), scipy.interpolate.interpn( diff --git a/tests/conftest.py b/tests/conftest.py index 5ef724de..46acf2f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,10 +9,10 @@ from erlab.io.exampledata import generate_data_angles, generate_gold_edge -DATA_COMMIT_HASH = "985457233a02eace67da27a0d5450fc0301e55c3" +DATA_COMMIT_HASH = "ad7dbdf35ef2404feee0854cb3a52973770709f4" """The commit hash of the commit to retrieve from `kmnhan/erlabpy-data`.""" -DATA_KNOWN_HASH = "691d69af64feaa92f928365050392f08c640d3f11a04243c9bdf323c1bc3f0b0" +DATA_KNOWN_HASH = "43d89ef27482e127e7509b65d635b7a6a0cbf648f84aa597c8b4094bdc0c46ab" """The hash of the `.tar.gz` file.""" diff --git a/tests/interactive/test_imagetool.py b/tests/interactive/test_imagetool.py index d8cc222d..c4de3b8a 100644 --- a/tests/interactive/test_imagetool.py +++ b/tests/interactive/test_imagetool.py @@ -179,7 +179,7 @@ def _go_to_file(dialog: QtWidgets.QFileDialog): win.close() - xr.testing.assert_equal(data, xr.load_dataarray(filename)) + xr.testing.assert_equal(data, xr.load_dataarray(filename, engine="h5netcdf")) tmp_dir.cleanup() @@ -229,6 +229,8 @@ def test_itool(qtbot): assert clw.max_spin.value() == 2.0 clw.rst_btn.click() assert win.slicer_area.levels == (0.0, 24.0) + clw.zero_btn.click() + assert win.slicer_area.levels == (0.0, 24.0) win.slicer_area.levels = (1.0, 23.0) win.slicer_area.lock_levels(False) @@ -402,8 +404,13 @@ def test_value_update(qtbot): win.close() -def test_value_update_errors(): +def test_value_update_errors(qtbot): win = ImageTool(xr.DataArray(np.arange(25).reshape((5, 5)), dims=["x", "y"])) + qtbot.addWidget(win) + + with qtbot.waitExposed(win): + win.show() + win.activateWindow() with pytest.raises(ValueError, match="DataArray dimensions do not match"): win.slicer_area.update_values( @@ -420,6 +427,8 @@ def test_value_update_errors(): with pytest.raises(ValueError, match="^Data shape does not match.*"): win.slicer_area.update_values(np.arange(24).reshape((4, 6))) + win.close() + def test_sync(qtbot): manager = ImageToolManager() diff --git a/tests/interactive/test_tools.py b/tests/interactive/test_tools.py index fa70644a..e346ab57 100644 --- a/tests/interactive/test_tools.py +++ b/tests/interactive/test_tools.py @@ -3,6 +3,7 @@ import xarray as xr from numpy.testing import assert_allclose +import erlab.lattice from erlab.interactive.bzplot import BZPlotter from erlab.interactive.curvefittingtool import edctool, mdctool from erlab.interactive.derivative import dtool @@ 
-16,7 +17,6 @@ def test_goldtool(qtbot, gold): with qtbot.waitExposed(win): win.show() win.activateWindow() - win.raise_() win.params_edge.widgets["# CPU"].setValue(1) win.params_edge.widgets["Fast"].setChecked(True) @@ -45,7 +45,6 @@ def test_dtool(qtbot): with qtbot.waitExposed(win): win.show() win.activateWindow() - win.raise_() win.tab_widget.setCurrentIndex(0) win.interp_group.setChecked(False) @@ -72,12 +71,17 @@ def test_dtool(qtbot): def test_ktool(qtbot, anglemap): - win = ktool(anglemap, execute=False) + win = ktool( + anglemap, + avec=erlab.lattice.abc2avec(6.97, 6.97, 8.685, 90, 90, 120), + cmap="terrain_r", + execute=False, + ) + qtbot.addWidget(win) with qtbot.waitExposed(win): win.show() win.activateWindow() - win.raise_() win._offset_spins["delta"].setValue(30.0) win._offset_spins["xi"].setValue(20.0) @@ -88,6 +92,17 @@ def test_ktool(qtbot, anglemap): == """anglemap.kspace.offsets = {"delta": 30.0, "xi": 20.0, "beta": 10.0} anglemap_kconv = anglemap.kspace.convert()""" ) + win.add_circle_btn.click() + roi = win._roi_list[0] + roi.getMenu() + roi.set_position((0.1, 0.1), 0.2) + assert roi.get_position() == (0.1, 0.1, 0.2) + + roi_control_widget = roi._pos_menu.actions()[0].defaultWidget() + roi_control_widget.x_spin.setValue(0.0) + roi_control_widget.y_spin.setValue(0.2) + roi_control_widget.r_spin.setValue(0.3) + assert roi.get_position() == (0.0, 0.2, 0.3) def test_curvefittingtool(qtbot): diff --git a/tests/io/plugins/test_da30.py b/tests/io/plugins/test_da30.py new file mode 100644 index 00000000..809b1da5 --- /dev/null +++ b/tests/io/plugins/test_da30.py @@ -0,0 +1,32 @@ +import pytest +import xarray as xr + +import erlab.io + + +@pytest.fixture(scope="module") +def data_dir(test_data_dir): + erlab.io.set_loader("da30") + erlab.io.set_data_dir(test_data_dir / "da30") + return test_data_dir / "da30" + + +@pytest.fixture(scope="module") +def expected_dir(data_dir): + return data_dir / "expected" + + +@pytest.mark.parametrize( + ("args", "expected"), + [ + ("f0001f_001.ibw", "f0001f_001.h5"), + ("f0002.zip", "f0002.h5"), + ("f0003.pxt", "f0003.h5"), + ], +) +def test_load(expected_dir, args, expected): + loaded = erlab.io.load(**args) if isinstance(args, dict) else erlab.io.load(args) + + xr.testing.assert_identical( + loaded, xr.load_dataarray(expected_dir / expected, engine="h5netcdf") + ) diff --git a/tests/io/plugins/test_maestro.py b/tests/io/plugins/test_maestro.py new file mode 100644 index 00000000..9c6f7b36 --- /dev/null +++ b/tests/io/plugins/test_maestro.py @@ -0,0 +1,31 @@ +import pytest +import xarray as xr + +import erlab.io + + +@pytest.fixture(scope="module") +def data_dir(test_data_dir): + erlab.io.set_loader("maestro") + erlab.io.set_data_dir(test_data_dir / "maestro") + return test_data_dir / "maestro" + + +@pytest.fixture(scope="module") +def expected_dir(data_dir): + return data_dir / "expected" + + +@pytest.mark.parametrize( + ("args", "expected"), + [ + (1, "20241026_00001.h5"), + ("20241026_00001.h5", "20241026_00001.h5"), + ], +) +def test_load(expected_dir, args, expected): + loaded = erlab.io.load(**args) if isinstance(args, dict) else erlab.io.load(args) + + xr.testing.assert_identical( + loaded, xr.load_dataarray(expected_dir / expected, engine="h5netcdf") + ) diff --git a/tests/io/plugins/test_merlin.py b/tests/io/plugins/test_merlin.py index 20a8c960..f26463ff 100644 --- a/tests/io/plugins/test_merlin.py +++ b/tests/io/plugins/test_merlin.py @@ -19,18 +19,18 @@ def expected_dir(data_dir): def test_load_xps(expected_dir): 
xr.testing.assert_identical( erlab.io.load("core.pxt"), - xr.load_dataarray(expected_dir / "core.nc"), + xr.load_dataarray(expected_dir / "core.h5"), ) def test_load_multiple(expected_dir): xr.testing.assert_identical( erlab.io.load("f_005_S001.pxt"), - xr.load_dataarray(expected_dir / "5.nc"), + xr.load_dataarray(expected_dir / "5.h5"), ) xr.testing.assert_identical( erlab.io.load(5), - xr.load_dataarray(expected_dir / "5.nc"), + xr.load_dataarray(expected_dir / "5.h5"), ) @@ -38,9 +38,22 @@ def test_load_live(expected_dir): for live in ("lp", "lxy"): xr.testing.assert_identical( erlab.io.load(f"{live}.ibw"), - xr.load_dataarray(expected_dir / f"{live}.nc"), + xr.load_dataarray(expected_dir / f"{live}.h5"), ) def test_summarize(data_dir): erlab.io.summarize() + + +def test_qinfo(data_dir): + data = erlab.io.load(5) + assert ( + data.qinfo.__repr__() + == """time: 2022-03-27 07:53:26\ntype: map\nlens mode (Lens Mode): A30 +mode (Acquisition Mode): Dither\ntemperature (sample_temp): 110.67 +pass energy (Pass Energy): 10\nanalyzer slit (Slit Plate): 7\npol (polarization): LH +hv (hv): 100\nentrance slit (Entrance Slit): 70\nexit slit (Exit Slit): 70 +polar (beta): [-15.5, -15]\ntilt (xi): 0\nazi (delta): 3\nx (x): 2.487\ny (y): 0.578 +z (z): -1.12""" + ) diff --git a/tests/io/plugins/test_ssrl52.py b/tests/io/plugins/test_ssrl52.py new file mode 100644 index 00000000..529538cb --- /dev/null +++ b/tests/io/plugins/test_ssrl52.py @@ -0,0 +1,37 @@ +import pytest +import xarray as xr + +import erlab.io + + +@pytest.fixture(scope="module") +def data_dir(test_data_dir): + erlab.io.set_loader("ssrl52") + erlab.io.set_data_dir(test_data_dir / "ssrl52") + return test_data_dir / "ssrl52" + + +@pytest.fixture(scope="module") +def expected_dir(data_dir): + return data_dir / "expected" + + +@pytest.mark.parametrize( + ("args", "expected"), + [ + ("f_0002.h5", "f_0002.h5"), + (2, "f_0002.h5"), + ("f_zap_0002.h5", "f_zap_0002.h5"), + ({"identifier": 2, "zap": True}, "f_zap_0002.h5"), + ], +) +def test_load(expected_dir, args, expected): + loaded = erlab.io.load(**args) if isinstance(args, dict) else erlab.io.load(args) + + xr.testing.assert_identical( + loaded, xr.load_dataarray(expected_dir / expected, engine="h5netcdf") + ) + + +def test_summarize(data_dir): + erlab.io.summarize() diff --git a/tests/io/test_dataloader.py b/tests/io/test_dataloader.py index 77e63d80..0a26d180 100644 --- a/tests/io/test_dataloader.py +++ b/tests/io/test_dataloader.py @@ -8,8 +8,8 @@ from typing import ClassVar import numpy as np -import pandas as pd import pytest +import xarray as xr import erlab.io from erlab.io.dataloader import LoaderBase @@ -21,11 +21,11 @@ def make_data(beta=5.0, temp=20.0, hv=50.0, bandshift=0.0): shape=(250, 1, 300), angrange={"alpha": (-15, 15), "beta": (beta, beta)}, hv=hv, - configuration=1, # Configuration, see + configuration=1, temp=temp, bandshift=bandshift, - count=1000, assign_attributes=False, + seed=1, ).T # Rename coordinates. The loader must rename them back to the original names. 
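The expected `qinfo` repr checked above, e.g. `polar (beta): [-15.5, -15]`, reflects the `format_value` changes in `erlab.utils.formatting`: two-element arrays are now shown as a pair, and min/max ranges are NaN-safe. A rough sketch of the new rules (printed outputs are indicative, not verbatim):

import numpy as np

from erlab.utils.formatting import format_value

# Two-element arrays are formatted as a pair instead of a "min~max"
# range, e.g. the two beta values of the map scan above.
print(format_value(np.array([-15.5, -15.0])))  # roughly "[-15.5, -15]"

# Size-1 arrays collapse to their single item.
print(format_value(np.array([110.67])))  # roughly "110.67"

# Non-monotonic 1D arrays still fall back to "min~max (n)", now computed
# with np.nanmin/np.nanmax so a stray NaN no longer breaks the summary.
print(format_value(np.array([2.0, np.nan, 0.5, 1.0])))  # roughly "0.5~2 (4)"

Note also that the default `use_unicode_minus` is now `False`, which is why the expected repr contains a plain ASCII `-15.5` rather than a Unicode minus sign.
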
@@ -39,6 +39,7 @@ def make_data(beta=5.0, temp=20.0, hv=50.0, bandshift=0.0): "delta": "Azimuth", } ) + dt = datetime.datetime.now() # Assign some attributes that real data would have return data.assign_attrs( @@ -47,7 +48,8 @@ def make_data(beta=5.0, temp=20.0, hv=50.0, bandshift=0.0): "SpectrumType": "Fixed", # Acquisition mode of the analyzer "PassEnergy": 10, # Pass energy of the analyzer "UndPol": 0, # Undulator polarization - "DateTime": datetime.datetime.now().isoformat(), # Acquisition time + "Date": dt.strftime(r"%d/%m/%Y"), # Date of the measurement + "Time": dt.strftime("%I:%M:%S %p"), # Time of the measurement "TB": temp, "X": 0.0, "Y": 0.0, @@ -56,6 +58,32 @@ def make_data(beta=5.0, temp=20.0, hv=50.0, bandshift=0.0): ) +def _format_polarization(val) -> str: + val = round(float(val)) + return {0: "LH", 2: "LV", -1: "RC", 1: "LC"}.get(val, str(val)) + + +def _parse_time(darr: xr.DataArray) -> datetime.datetime: + return datetime.datetime.strptime( + f"{darr.attrs['Date']} {darr.attrs['Time']}", + r"%d/%m/%Y %I:%M:%S %p", + ) + + +def _determine_kind(darr: xr.DataArray) -> str: + if "scan_type" in darr.attrs and darr.attrs["scan_type"] == "live": + return "LP" if "beta" in darr.dims else "LXY" + + data_type = "xps" + if "alpha" in darr.dims: + data_type = "cut" + if "beta" in darr.dims: + data_type = "map" + if "hv" in darr.dims: + data_type = "hvdep" + return data_type + + def test_loader(): # Create a temporary directory tmp_dir = tempfile.TemporaryDirectory() @@ -63,32 +91,25 @@ def test_loader(): # Generate a map beta_coords = np.linspace(2, 7, 10) + # Generate and save cuts with different beta values for i, beta in enumerate(beta_coords): - erlab.io.save_as_hdf5( - make_data(beta=beta, temp=20.0 + i, hv=50.0), - filename=f"{tmp_dir.name}/data_001_S{str(i + 1).zfill(3)}.h5", - igor_compat=False, - ) + data = make_data(beta=beta, temp=20.0 + i, hv=50.0) + filename = f"{tmp_dir.name}/data_001_S{str(i + 1).zfill(3)}.h5" + data.to_netcdf(filename, engine="h5netcdf") # Write scan coordinates to a csv file - with open( - f"{tmp_dir.name}/data_001_axis.csv", "w", newline="", encoding="utf-8" - ) as file: + with open(f"{tmp_dir.name}/data_001_axis.csv", "w", newline="") as file: writer = csv.writer(file) writer.writerow(["Index", "Polar"]) for i, beta in enumerate(beta_coords): writer.writerow([i + 1, beta]) - # Generate a cut - erlab.io.save_as_hdf5( - make_data(beta=5.0, temp=20.0, hv=50.0), - filename=f"{tmp_dir.name}/data_002.h5", - igor_compat=False, - ) - - # List the generated files - sorted(os.listdir(tmp_dir.name)) + # Generate some cuts with different band shifts + for i in range(4): + data = make_data(beta=5.0, temp=20.0, hv=50.0, bandshift=-i * 0.05) + filename = f"{tmp_dir.name}/data_{str(i + 2).zfill(3)}.h5" + data.to_netcdf(filename, engine="h5netcdf") class ExampleLoader(LoaderBase): name = "example" @@ -102,6 +123,7 @@ class ExampleLoader(LoaderBase): "Polar", "Polar Compens", ], # Can have multiple names assigned to the same name + # If both are present in the data, a ValueError will be raised "delta": "Azimuth", "xi": "Tilt", "x": "X", @@ -109,7 +131,7 @@ class ExampleLoader(LoaderBase): "z": "Z", "hv": "PhotonEnergy", "polarization": "UndPol", - "temp_sample": "TB", + "sample_temp": "TB", } coordinate_attrs = ( @@ -122,7 +144,7 @@ class ExampleLoader(LoaderBase): "z", "polarization", "photon_flux", - "temp_sample", + "sample_temp", ) # Attributes to be used as coordinates. Place all attributes that we don't want # to lose when merging multiple file scans here. 
@@ -132,6 +154,30 @@ class ExampleLoader(LoaderBase): "sample_workfunction": 4.3, } # Any additional metadata you want to add to the data + formatters: ClassVar[dict] = { + "polarization": _format_polarization, + "LensMode": lambda x: x.replace("Angular", "A"), + } + + summary_attrs: ClassVar[dict] = { + "Time": _parse_time, + "Type": _determine_kind, + "Lens Mode": "LensMode", + "Scan Type": "SpectrumType", + "T(K)": "sample_temp", + "Pass E": "PassEnergy", + "Polarization": "polarization", + "hv": "hv", + "x": "x", + "y": "y", + "z": "z", + "polar": "beta", + "tilt": "xi", + "azi": "delta", + } + + summary_sort = "Time" + skip_validate = False always_single = False @@ -166,8 +212,29 @@ def identify(self, num, data_dir): return files, coord_dict - def load_single(self, file_path): - return erlab.io.load_hdf5(file_path) + def load_single(self, file_path, without_values=False): + darr = xr.open_dataarray(file_path, engine="h5netcdf") + + if without_values: + # Do not load the data into memory + return xr.DataArray( + np.zeros(darr.shape, darr.dtype), + coords=darr.coords, + dims=darr.dims, + attrs=darr.attrs, + ) + + return darr + + def post_process(self, data: xr.DataArray) -> xr.DataArray: + data = super().post_process(data) + + if "sample_temp" in data.coords: + # Add temperature to attributes, for backwards compatibility + temp = float(data.sample_temp.mean()) + data = data.assign_attrs(sample_temp=temp) + + return data def infer_index(self, name): # Get the scan number from file name @@ -180,99 +247,8 @@ def infer_index(self, name): return int(scan_num), {} return None, None - def generate_summary(self, data_dir): - # Get all valid data files in directory - files = {} - for path in erlab.io.utils.get_files(data_dir, extensions=[".h5"]): - # If multiple scans, strip the _S### part - name_match = re.match(r"(.*?_\d{3})_(?:_S\d{3})?", path.stem) - data_name = path.stem if name_match is None else name_match.group(1) - files[data_name] = str(path) - - # Map dataframe column names to data attributes - attrs_mapping = { - "Lens Mode": "LensMode", - "Scan Type": "SpectrumType", - "T(K)": "temp_sample", - "Pass E": "PassEnergy", - "Polarization": "polarization", - "hv": "hv", - "x": "x", - "y": "y", - "z": "z", - "polar": "beta", - "tilt": "xi", - "azi": "delta", - } - column_names = ["File Name", "Path", "Time", "Type", *attrs_mapping.keys()] - - data_info = [] - - processed_indices = set() - for name, path in files.items(): - # Skip already processed multi-file scans - index, _ = self.infer_index(name) - if index in processed_indices: - continue - - if index is not None: - processed_indices.add(index) - - # Load data - data = self.load(path) - - # Determine type of scan - data_type = "core" - if "alpha" in data.dims: - data_type = "cut" - if "beta" in data.dims: - data_type = "map" - if "hv" in data.dims: - data_type = "hvdep" - - data_info.append( - [ - name, - path, - datetime.datetime.fromisoformat(data.attrs["DateTime"]), - data_type, - ] - ) - - for k, v in attrs_mapping.items(): - # Try to get the attribute from the data, then from the coordinates - try: - val = data.attrs[v] - except KeyError: - try: - val = data.coords[v].values - if val.size == 1: - val = val.item() - except KeyError: - val = "" - - # Convert polarization values to human readable form - if k == "Polarization": - if np.iterable(val): - val = np.asarray(val).astype(int) - else: - val = [round(val)] - val = [ - {0: "LH", 2: "LV", -1: "RC", 1: "LC"}.get(v, v) for v in val - ] - if len(val) == 1: - val = val[0] - - 
data_info[-1].append(val) - - del data - - # Sort by time and set index - return ( - pd.DataFrame(data_info, columns=column_names) - .sort_values("Time") - .set_index("File Name") - ) + def files_for_summary(self, data_dir): + return erlab.io.utils.get_files(data_dir, extensions=[".h5"]) with erlab.io.loader_context("example", tmp_dir.name): erlab.io.load(1) @@ -293,23 +269,30 @@ def generate_summary(self, data_dir): assert repr(erlab.io.loaders).startswith("Registered data loaders") assert erlab.io.loaders._repr_html_().startswith("