Skip to content

Commit

Permalink
Converts Spectrum to a dataclass.
Browse files Browse the repository at this point in the history
    Formerly abstract, `Spectrum` has to be made concrete due to
    [mypy issue#5374](python/mypy#5374)
    `Only concrete class can be given where "Type[Abstract]" is expected`
  • Loading branch information
jevandezande committed Mar 8, 2022
1 parent 2f3e2e9 commit 2524d20
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 35 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ classifiers =
Programming Language :: Python :: 3.9
Topic :: Software Development :: Libraries :: Python Modules


[options]
packages = spectra
python_requires = >= 3.9
Expand All @@ -38,6 +37,7 @@ line_length = 120
[mypy]
files = spectra, tests
ignore_missing_imports = true
plugins = numpy.typing.mypy_plugin

[flake8]
ignore = E203, E266, E501, W503, E731
Expand Down
37 changes: 21 additions & 16 deletions spectra/_abc_spectrum.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,49 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Generator, Iterable, Optional

import numpy as np
from numpy.typing import ArrayLike

from ._typing import ArrayLike
from .tools import index_of_x, integrate, read_csvs


class Spectrum(ABC):
@dataclass
class Spectrum:
name: str
energies: np.ndarray
intensities: np.ndarray
units: Optional[str] = None
style: Optional[str] = None
time: Optional[datetime] = None

def __init__(
self,
name: str,
energies: ArrayLike,
intensities: ArrayLike,
units: Optional[str] = None,
style: Optional[str] = None,
time=None,
time: Optional[datetime] = None,
):
energies = np.asarray(energies)
intensities = np.asarray(intensities)
assert len(energies.shape) == 1
assert energies.shape == intensities.shape

self.name = name
self.energies = np.asarray(energies, dtype=float)
self.intensities = np.asarray(intensities, dtype=float)
self.units = units
self.style = style
self.time = time

def __post_init__(self):
self.energies = np.asarray(self.energies, dtype=float)
self.intensities = np.asarray(self.intensities, dtype=float)
assert len(self.energies.shape) == 1
assert self.energies.shape == self.intensities.shape

def __repr__(self) -> str:
return f"<{self.__class__.__name__}: {self.name}>"

def __str__(self) -> str:
return repr(self)

def __iter__(self) -> Generator[tuple[float, float], None, None]:
"""
Iterate over points in the Spectrum.
Expand Down Expand Up @@ -73,9 +80,8 @@ def __abs__(self) -> Spectrum:
def __radd__(self, other: float) -> Spectrum:
return self.__add__(other)

@abstractmethod
def __add__(self, other: float) -> Spectrum:
pass
raise NotImplementedError()

def __rtruediv__(self, other: float) -> Spectrum:
new = self.copy()
Expand Down Expand Up @@ -282,9 +288,8 @@ def sliced(self, start: Optional[float] = None, end: Optional[float] = None) ->
new.intensities = new.intensities[start_i:end_i]
return new

@abstractmethod
def smoothed(self, box_pts: int | bool = True) -> Spectrum:
pass
raise NotImplementedError()

@classmethod
def from_csvs(cls, *inps: str, names: Optional[Iterable[str]] = None):
Expand Down
3 changes: 3 additions & 0 deletions spectra/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

from matplotlib.axes import Axes
from matplotlib.figure import Figure
from numpy import ndarray

ITER_STR = Optional[Union[Iterable[str], str]]
ITER_FLOAT = Optional[Union[Iterable[float], float]]

OPT_PLOT = Optional[tuple[Figure, Axes]]

ArrayLike = Union[Iterable, ndarray]
3 changes: 1 addition & 2 deletions spectra/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import ArrayLike

from ._typing import OPT_PLOT
from ._typing import OPT_PLOT, ArrayLike
from .conv_spectrum import ConvSpectrum
from .tools import integrate, smooth_curve

Expand Down
3 changes: 2 additions & 1 deletion spectra/shapes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from numpy.typing import ArrayLike

from ._typing import ArrayLike


def gaussian(energy: float, width: float, xs: ArrayLike) -> np.ndarray:
Expand Down
19 changes: 6 additions & 13 deletions spectra/sticks_spectrum.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,26 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import numpy as np
from numpy.typing import ArrayLike

from ._abc_spectrum import Spectrum
from .conv_spectrum import ConvSpectrum
from .shapes import gaussian


@dataclass
class SticksSpectrum(Spectrum):
"""
A SticksSpectrum is a collection of intensities at various energies
These may be convolved with a shape to produce a ConvSpectrum.
"""

def __init__(
self,
name: str,
energies: ArrayLike,
intensities: ArrayLike,
units: Optional[str] = None,
style: Optional[str] = None,
time=None,
y_shift: float = 0,
):
super().__init__(name, energies, intensities, units, style, time)
self.y_shift = y_shift
y_shift: float = 0

def __eq__(self, other):
return self.y_shift == other.y_shift and super().__eq__(other)

def __rsub__(self, other: float) -> SticksSpectrum:
"""
Expand Down
3 changes: 2 additions & 1 deletion spectra/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from typing import TYPE_CHECKING, Generator, Iterable, Optional, Sequence

import numpy as np
from numpy.typing import ArrayLike
from scipy import constants

from ._typing import ArrayLike

if TYPE_CHECKING:
from ._abc_spectrum import Spectrum

Expand Down
5 changes: 4 additions & 1 deletion tests/test_sticks_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def test_str():
energies, intensities = np.arange(10), np.arange(10)
s1 = SticksSpectrum("Hello World", energies, intensities)

assert str(s1) == "<SticksSpectrum: Hello World>"
assert (
str(s1)
== "SticksSpectrum(name='Hello World', energies=array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), intensities=array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), units=None, style=None, time=None, y_shift=0)"
)


def test_add_sub():
Expand Down

0 comments on commit 2524d20

Please sign in to comment.