diff --git a/src/scippneutron/absorption/cylinder.py b/src/scippneutron/absorption/cylinder.py index 8d2d1ac94..ea92c235f 100644 --- a/src/scippneutron/absorption/cylinder.py +++ b/src/scippneutron/absorption/cylinder.py @@ -3,6 +3,7 @@ import numpy as np import scipp as sc +from numpy.polynomial.chebyshev import chebgauss from numpy.polynomial.legendre import leggauss from . import quadratures @@ -42,13 +43,16 @@ def volume(self): def _select_quadrature_points(self, kind): if kind == 'expensive': - x, w = leggauss(10) + # x, w = leggauss(40) + x, w = chebgauss(35) + w *= (1 - x**2) ** 0.5 + w /= sum(w) / 2 quad = _cylinder_quadrature_from_product( quadratures.disk254_cheb, dict(x=x, weights=w), # noqa: C408 ) elif kind == 'medium': - x, w = leggauss(8) + x, w = leggauss(40) # Would be nice to have a medium size Chebychev quadrature on the disk, # but I only found the large one for now. quad = _cylinder_quadrature_from_product( @@ -61,6 +65,16 @@ def _select_quadrature_points(self, kind): quadratures.disk12, dict(x=x, weights=w), # noqa: C408 ) + elif kind == 'mc': + r = np.random.random(5000) ** 0.5 + th = 2 * np.pi * np.random.random(5000) + z = 2 * np.random.random(5000) - 1 + quad = { + 'x': r * np.cos(th), + 'y': r * np.sin(th), + 'z': z, + 'weights': 2 * np.pi * np.ones(5000) / 5000, + } else: raise NotImplementedError return {k: sc.array(dims=['quad'], values=v) for k, v in quad.items()} @@ -151,19 +165,24 @@ def _line_infinite_cylinder_intersection(a, b, r, n): ''' nxa = sc.cross(n, a) nxa_square = sc.dot(nxa, nxa) + parallel_to_cylinder = nxa_square == sc.scalar(0.0, unit=nxa.unit) s2 = nxa_square * r**2 - sc.dot(b, nxa) ** 2 s = sc.sqrt(s2) m = sc.dot(nxa, sc.cross(b, a)) intersection = s2 >= sc.scalar(0.0, unit=s2.unit) - left = (m - s) / nxa_square - right = (m + s) / nxa_square + left = sc.where( + parallel_to_cylinder, + sc.scalar(float('-inf'), unit=m.unit), + (m - s) / nxa_square, + ) + right = sc.where( + parallel_to_cylinder, sc.scalar(float('inf'), unit=m.unit), (m + s) / nxa_square + ) origin_in_cylinder = sc.norm(b - sc.dot(b, a) * a) <= r - ndota = sc.dot(n, a) - parallel_to_cylinder = sc.abs(ndota) == sc.scalar(1.0, unit=ndota.unit) return ( sc.where(parallel_to_cylinder, origin_in_cylinder, intersection), - sc.where(parallel_to_cylinder, sc.scalar(float('-inf'), unit=left.unit), left), - sc.where(parallel_to_cylinder, sc.scalar(float('inf'), unit=right.unit), right), + left, + right, ) diff --git a/tests/absorption/cylinder_test.py b/tests/absorption/cylinder_test.py index 32b33de86..9272e7714 100644 --- a/tests/absorption/cylinder_test.py +++ b/tests/absorption/cylinder_test.py @@ -8,69 +8,145 @@ @pytest.fixture( params=[ - sc.vector([1, 0, 0]), - sc.vector([0, 1, 0]), - sc.vector([2**-0.5, 2**-0.5, 0]), - sc.vector([2**-0.5, -(2**-0.5), 0]), - sc.vector([-(2**-0.5), 2**-0.5, 0]), - sc.vector([-(2**-0.5), -(2**-0.5), 0]), + sc.vector([-0.9854498123542775, -0.1699666653521188, 0]), + sc.vector([0.7242293076582138, 0.6895592142295717, 0]), + sc.vector([0.9445601257691059, 0.3283384972966938, 0]), + sc.vector([-0.8107703958381735, -0.5853643012965615, 0]), + sc.vector([0.48126033718549477, -0.8765777135269319, 0]), ] ) def point_on_unit_circle(request): return request.param -@pytest.mark.parametrize('h', [sc.scalar(0.2), sc.scalar(1.0)]) -@pytest.mark.parametrize('r', [sc.scalar(0.2), sc.scalar(1.0)]) -def test_intersection_in_base(r, h, point_on_unit_circle): - c = Cylinder(sc.vector([0, 0, 1.0]), sc.vector([0, 0, 0]), r, h) +@pytest.fixture(params=[sc.scalar(1.2), sc.scalar(0.2)]) +def height(request): + return request.param + + +@pytest.fixture(params=[sc.scalar(1.2), sc.scalar(0.2)]) +def radius(request): + return request.param + + +@pytest.fixture( + params=[ + sc.vector([0.0, 0.0, 1.0]), + sc.vector([-0.5620126808026259, -0.1798933776079791, 0.8073290031392648]), + ] +) +def axis(request): + return request.param + + +@pytest.fixture( + params=[ + sc.vector([0.0, 0.0, 0.0]), + sc.vector([1.0, -2.0, 3.0]), + ] +) +def base(request): + return request.param + + +@pytest.fixture() +def cylinder(request, axis, base, radius, height): + return Cylinder(axis, base, radius, height) + + +def _rotate_from_z_to_axis(p, ax): + z = sc.vector([0, 0, 1.0]) + if ax == z: + return p + u = sc.cross(z, ax) + un = sc.norm(u) + u *= sc.asin(un) / un + return sc.spatial.rotations_from_rotvecs(u) * p + + +def test_intersection_in_base(cylinder, point_on_unit_circle): + v = _rotate_from_z_to_axis(point_on_unit_circle, cylinder.symmetry_line) assert_allclose( - c.beam_intersection(point_on_unit_circle, -point_on_unit_circle), 2 * r + cylinder.beam_intersection( + cylinder.radius * v + + cylinder.center_of_base + # Move start point just inside cylinder. + # Required to make all tests pass. + + 2 * np.finfo(float).eps * cylinder.symmetry_line, + -v, + ), + 2 * cylinder.radius, ) -@pytest.mark.parametrize('h', [sc.scalar(0.2), sc.scalar(1.0)]) -@pytest.mark.parametrize('r', [sc.scalar(0.2), sc.scalar(1.0)]) -def test_intersection_diagonal(r, h, point_on_unit_circle): - ax, base = sc.vector([0, 0, 1.0]), sc.vector([0, 0, 0]) - c = Cylinder(ax, base, r, h) - x = r * point_on_unit_circle - n = base - x + ax * h / 2 +def test_intersection_diagonal(cylinder, point_on_unit_circle): + v = _rotate_from_z_to_axis(point_on_unit_circle, cylinder.symmetry_line) + n = -2 * cylinder.radius * v + cylinder.symmetry_line * cylinder.height + n /= sc.norm(n) + # Diagonal intersection is as expected assert_allclose( - c.beam_intersection(x, n / sc.norm(n)), ((2 * r) ** 2 + h**2) ** 0.5 + cylinder.beam_intersection(cylinder.radius * v + cylinder.center_of_base, n), + ((2 * cylinder.radius) ** 2 + cylinder.height**2) ** 0.5, + ) + # Intersection is zero in other directions + assert_allclose( + cylinder.beam_intersection(cylinder.radius * v + cylinder.center_of_base, -n), + sc.scalar(0.0), + atol=sc.scalar(1e-14), + ) + n = 2 * cylinder.radius * v + cylinder.symmetry_line * cylinder.height + n /= sc.norm(n) + assert_allclose( + cylinder.beam_intersection(cylinder.radius * v + cylinder.center_of_base, n), + sc.scalar(0.0), + atol=sc.scalar(1e-14), + ) + assert_allclose( + cylinder.beam_intersection(cylinder.radius * v + cylinder.center_of_base, -n), + sc.scalar(0.0), + atol=sc.scalar(1e-14), ) -@pytest.mark.parametrize('h', [sc.scalar(0.2), sc.scalar(1.0)]) -@pytest.mark.parametrize('r', [sc.scalar(0.2), sc.scalar(1.0)]) -def test_intersection_along_axis(r, h, point_on_unit_circle): - ax, base = sc.vector([0, 0, 1.0]), sc.vector([0, 0, 0]) - c = Cylinder(ax, base, r, h) - x = (1.0 - np.finfo(float).eps) * r * point_on_unit_circle - assert_allclose(c.beam_intersection(x, ax), h) +def test_intersection_along_axis(cylinder, point_on_unit_circle): + v = _rotate_from_z_to_axis(point_on_unit_circle, cylinder.symmetry_line) + # Move point just inside cylinder + x = cylinder.radius * v + cylinder.center_of_base - 2 * np.finfo(float).eps * v + assert_allclose( + cylinder.beam_intersection(x, cylinder.symmetry_line), cylinder.height + ) -def test_no_intersection(point_on_unit_circle): - c = Cylinder( - sc.vector([0, 0, 1.0]), sc.vector([0, 0, 0]), sc.scalar(1.0), sc.scalar(1.0) +def test_no_intersection(cylinder, point_on_unit_circle): + v = _rotate_from_z_to_axis(point_on_unit_circle, cylinder.symmetry_line) + x = cylinder.center_of_base + cylinder.radius * v + assert_allclose( + cylinder.beam_intersection(x, v), sc.scalar(0.0), atol=sc.scalar(1e-14) + ) + # Rotate 90 deg around cylinder axis. + # Should still not intersect the cylinder. + v = ( + sc.spatial.rotations_from_rotvecs( + cylinder.symmetry_line * sc.scalar(np.pi / 2, unit='rad') + ) + * v ) assert_allclose( - c.beam_intersection(point_on_unit_circle, point_on_unit_circle), sc.scalar(0.0) + cylinder.beam_intersection(x, v), sc.scalar(0.0), atol=sc.scalar(1e-7) ) - n = point_on_unit_circle - sc.vector([0, 0, 1]) / 2 - n /= sc.norm(n) - assert_allclose(c.beam_intersection(point_on_unit_circle, n), sc.scalar(0.0)) - - x = (1.0 + np.finfo(float).eps) * point_on_unit_circle - assert_allclose(c.beam_intersection(x, sc.vector([0, 0, 1])), sc.scalar(0.0)) - -@pytest.mark.parametrize('h', [sc.scalar(0.2), sc.scalar(1.0)]) -@pytest.mark.parametrize('r', [sc.scalar(0.2), sc.scalar(1.0)]) -def test_intersection_interior(r, h, point_on_unit_circle): - c = Cylinder(sc.vector([0, 0, 1.0]), sc.vector([0, 0, -h.value / 2]), r, h) - assert_allclose(c.beam_intersection(sc.vector([0, 0, 0]), point_on_unit_circle), r) +def test_intersection_from_center(cylinder, point_on_unit_circle): + v = _rotate_from_z_to_axis(point_on_unit_circle, cylinder.symmetry_line) + assert_allclose(cylinder.beam_intersection(cylinder.center, v), cylinder.radius) + assert_allclose( + cylinder.beam_intersection(cylinder.center, cylinder.symmetry_line), + cylinder.height / 2, + ) + assert_allclose( + cylinder.beam_intersection(cylinder.center, -cylinder.symmetry_line), + cylinder.height / 2, + ) @pytest.mark.parametrize('kind', ['expensive', 'medium', 'cheap'])