Skip to content

Commit

Permalink
type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
jokasimr committed Sep 25, 2024
1 parent 7b74c61 commit f9ea772
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/scippneutron/absorption/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .material import Material


__all__ = (compute_transmission_map, Cylinder, Material)
__all__ = ('compute_transmission_map', 'Cylinder', 'Material')
3 changes: 2 additions & 1 deletion src/scippneutron/absorption/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import partial
from typing import Any

import scipp as sc

Expand All @@ -12,7 +13,7 @@ def compute_transmission_map(
beam_direction: sc.Variable,
wavelength: sc.Variable,
detector_position: sc.Variable,
quadrature_kind='medium',
quadrature_kind: Any = 'medium',
) -> sc.DataArray:
points, weights = sample_shape.quadrature(quadrature_kind)
scatter_direction = detector_position - points.to(unit=detector_position.unit)
Expand Down
8 changes: 5 additions & 3 deletions src/scippneutron/absorption/cylinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ class Cylinder(SampleShape):
radius: sc.Variable
height: sc.Variable

def beam_intersection(self, start_point, direction):
def beam_intersection(
self, start_point: sc.Variable, direction: sc.Variable
) -> sc.Variable:
'Length of intersection between beam and cylinder'
base_point = self.center_of_base - start_point
cyl_intersection, *cyl_interval = _line_infinite_cylinder_intersection(
Expand All @@ -37,7 +39,7 @@ def center(self) -> sc.Variable:
return self.center_of_base + self.symmetry_line * self.height / 2

@property
def volume(self):
def volume(self) -> sc.Variable:
return self.radius**2 * self.height * np.pi

def _select_quadrature_points(self, kind):
Expand Down Expand Up @@ -94,7 +96,7 @@ def _select_quadrature_points(self, kind):
def quadrature(
self,
kind: Literal['expensive', 'medium', 'cheap', 'mc'] | tuple[Literal['mc'], int],
):
) -> tuple[sc.Variable, sc.Variable]:
'Returns quadrature points and weights of the cylinder.'
quad = self._select_quadrature_points(kind)
# Scale to size of cylinder
Expand Down
2 changes: 1 addition & 1 deletion src/scippneutron/absorption/material.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Material:
scattering_cross_section: sc.Variable
effective_sample_number_density: sc.Variable

def attenuation_coefficient(self, wavelength):
def attenuation_coefficient(self, wavelength: sc.Variable) -> sc.Variable:
'''Computes marginal attenuation per distance for
the given neutron wavelength.'''
return self.effective_sample_number_density * (
Expand Down
5 changes: 3 additions & 2 deletions src/scippneutron/absorption/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABCMeta, abstractmethod
from typing import Any

import scipp as sc

Expand All @@ -14,12 +15,12 @@ def beam_intersection(

@property
@abstractmethod
def volume(self):
def volume(self) -> sc.Variable:
'''Volume of the shape'''
pass

@abstractmethod
def quadrature(self, kind) -> tuple[sc.Variable, sc.Variable]:
def quadrature(self, kind: Any) -> tuple[sc.Variable, sc.Variable]:
'''Returns quadrature points and weights for evaluating integrals over
the shape. The method returns a tuple where the first entry is
an array containing vectors representing points in the shape and the
Expand Down

0 comments on commit f9ea772

Please sign in to comment.