Skip to content

Commit

Permalink
test: more cylinder intersection tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jokasimr committed Sep 17, 2024
1 parent 5844be2 commit 131859a
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 50 deletions.
35 changes: 27 additions & 8 deletions src/scippneutron/absorption/cylinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import scipp as sc
from numpy.polynomial.chebyshev import chebgauss
from numpy.polynomial.legendre import leggauss

from . import quadratures
Expand Down Expand Up @@ -42,13 +43,16 @@ def volume(self):

def _select_quadrature_points(self, kind):
if kind == 'expensive':
x, w = leggauss(10)
# x, w = leggauss(40)
x, w = chebgauss(35)
w *= (1 - x**2) ** 0.5
w /= sum(w) / 2
quad = _cylinder_quadrature_from_product(
quadratures.disk254_cheb,
dict(x=x, weights=w), # noqa: C408
)
elif kind == 'medium':
x, w = leggauss(8)
x, w = leggauss(40)
# Would be nice to have a medium size Chebychev quadrature on the disk,
# but I only found the large one for now.
quad = _cylinder_quadrature_from_product(
Expand All @@ -61,6 +65,16 @@ def _select_quadrature_points(self, kind):
quadratures.disk12,
dict(x=x, weights=w), # noqa: C408
)
elif kind == 'mc':
r = np.random.random(5000) ** 0.5
th = 2 * np.pi * np.random.random(5000)
z = 2 * np.random.random(5000) - 1
quad = {
'x': r * np.cos(th),
'y': r * np.sin(th),
'z': z,
'weights': 2 * np.pi * np.ones(5000) / 5000,
}
else:
raise NotImplementedError
return {k: sc.array(dims=['quad'], values=v) for k, v in quad.items()}
Expand Down Expand Up @@ -151,19 +165,24 @@ def _line_infinite_cylinder_intersection(a, b, r, n):
'''
nxa = sc.cross(n, a)
nxa_square = sc.dot(nxa, nxa)
parallel_to_cylinder = nxa_square == sc.scalar(0.0, unit=nxa.unit)
s2 = nxa_square * r**2 - sc.dot(b, nxa) ** 2
s = sc.sqrt(s2)
m = sc.dot(nxa, sc.cross(b, a))
intersection = s2 >= sc.scalar(0.0, unit=s2.unit)
left = (m - s) / nxa_square
right = (m + s) / nxa_square
left = sc.where(
parallel_to_cylinder,
sc.scalar(float('-inf'), unit=m.unit),
(m - s) / nxa_square,
)
right = sc.where(
parallel_to_cylinder, sc.scalar(float('inf'), unit=m.unit), (m + s) / nxa_square
)
origin_in_cylinder = sc.norm(b - sc.dot(b, a) * a) <= r
ndota = sc.dot(n, a)
parallel_to_cylinder = sc.abs(ndota) == sc.scalar(1.0, unit=ndota.unit)
return (
sc.where(parallel_to_cylinder, origin_in_cylinder, intersection),
sc.where(parallel_to_cylinder, sc.scalar(float('-inf'), unit=left.unit), left),
sc.where(parallel_to_cylinder, sc.scalar(float('inf'), unit=right.unit), right),
left,
right,
)


Expand Down
160 changes: 118 additions & 42 deletions tests/absorption/cylinder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,69 +8,145 @@

@pytest.fixture(
params=[
sc.vector([1, 0, 0]),
sc.vector([0, 1, 0]),
sc.vector([2**-0.5, 2**-0.5, 0]),
sc.vector([2**-0.5, -(2**-0.5), 0]),
sc.vector([-(2**-0.5), 2**-0.5, 0]),
sc.vector([-(2**-0.5), -(2**-0.5), 0]),
sc.vector([-0.9854498123542775, -0.1699666653521188, 0]),
sc.vector([0.7242293076582138, 0.6895592142295717, 0]),
sc.vector([0.9445601257691059, 0.3283384972966938, 0]),
sc.vector([-0.8107703958381735, -0.5853643012965615, 0]),
sc.vector([0.48126033718549477, -0.8765777135269319, 0]),
]
)
def point_on_unit_circle(request):
return request.param


@pytest.mark.parametrize('h', [sc.scalar(0.2), sc.scalar(1.0)])
@pytest.mark.parametrize('r', [sc.scalar(0.2), sc.scalar(1.0)])
def test_intersection_in_base(r, h, point_on_unit_circle):
c = Cylinder(sc.vector([0, 0, 1.0]), sc.vector([0, 0, 0]), r, h)
@pytest.fixture(params=[sc.scalar(1.2), sc.scalar(0.2)])
def height(request):
return request.param


@pytest.fixture(params=[sc.scalar(1.2), sc.scalar(0.2)])
def radius(request):
return request.param


@pytest.fixture(
params=[
sc.vector([0.0, 0.0, 1.0]),
sc.vector([-0.5620126808026259, -0.1798933776079791, 0.8073290031392648]),
]
)
def axis(request):
return request.param


@pytest.fixture(
params=[
sc.vector([0.0, 0.0, 0.0]),
sc.vector([1.0, -2.0, 3.0]),
]
)
def base(request):
return request.param


@pytest.fixture()
def cylinder(request, axis, base, radius, height):
return Cylinder(axis, base, radius, height)


def _rotate_from_z_to_axis(p, ax):
z = sc.vector([0, 0, 1.0])
if ax == z:
return p
u = sc.cross(z, ax)
un = sc.norm(u)
u *= sc.asin(un) / un
return sc.spatial.rotations_from_rotvecs(u) * p


def test_intersection_in_base(cylinder, point_on_unit_circle):
v = _rotate_from_z_to_axis(point_on_unit_circle, cylinder.symmetry_line)
assert_allclose(
c.beam_intersection(point_on_unit_circle, -point_on_unit_circle), 2 * r
cylinder.beam_intersection(
cylinder.radius * v
+ cylinder.center_of_base
# Move start point just inside cylinder.
# Required to make all tests pass.
+ 2 * np.finfo(float).eps * cylinder.symmetry_line,
-v,
),
2 * cylinder.radius,
)


@pytest.mark.parametrize('h', [sc.scalar(0.2), sc.scalar(1.0)])
@pytest.mark.parametrize('r', [sc.scalar(0.2), sc.scalar(1.0)])
def test_intersection_diagonal(r, h, point_on_unit_circle):
ax, base = sc.vector([0, 0, 1.0]), sc.vector([0, 0, 0])
c = Cylinder(ax, base, r, h)
x = r * point_on_unit_circle
n = base - x + ax * h / 2
def test_intersection_diagonal(cylinder, point_on_unit_circle):
v = _rotate_from_z_to_axis(point_on_unit_circle, cylinder.symmetry_line)
n = -2 * cylinder.radius * v + cylinder.symmetry_line * cylinder.height
n /= sc.norm(n)
# Diagonal intersection is as expected
assert_allclose(
c.beam_intersection(x, n / sc.norm(n)), ((2 * r) ** 2 + h**2) ** 0.5
cylinder.beam_intersection(cylinder.radius * v + cylinder.center_of_base, n),
((2 * cylinder.radius) ** 2 + cylinder.height**2) ** 0.5,
)
# Intersection is zero in other directions
assert_allclose(
cylinder.beam_intersection(cylinder.radius * v + cylinder.center_of_base, -n),
sc.scalar(0.0),
atol=sc.scalar(1e-14),
)
n = 2 * cylinder.radius * v + cylinder.symmetry_line * cylinder.height
n /= sc.norm(n)
assert_allclose(
cylinder.beam_intersection(cylinder.radius * v + cylinder.center_of_base, n),
sc.scalar(0.0),
atol=sc.scalar(1e-14),
)
assert_allclose(
cylinder.beam_intersection(cylinder.radius * v + cylinder.center_of_base, -n),
sc.scalar(0.0),
atol=sc.scalar(1e-14),
)


@pytest.mark.parametrize('h', [sc.scalar(0.2), sc.scalar(1.0)])
@pytest.mark.parametrize('r', [sc.scalar(0.2), sc.scalar(1.0)])
def test_intersection_along_axis(r, h, point_on_unit_circle):
ax, base = sc.vector([0, 0, 1.0]), sc.vector([0, 0, 0])
c = Cylinder(ax, base, r, h)
x = (1.0 - np.finfo(float).eps) * r * point_on_unit_circle
assert_allclose(c.beam_intersection(x, ax), h)
def test_intersection_along_axis(cylinder, point_on_unit_circle):
v = _rotate_from_z_to_axis(point_on_unit_circle, cylinder.symmetry_line)
# Move point just inside cylinder
x = cylinder.radius * v + cylinder.center_of_base - 2 * np.finfo(float).eps * v
assert_allclose(
cylinder.beam_intersection(x, cylinder.symmetry_line), cylinder.height
)


def test_no_intersection(point_on_unit_circle):
c = Cylinder(
sc.vector([0, 0, 1.0]), sc.vector([0, 0, 0]), sc.scalar(1.0), sc.scalar(1.0)
def test_no_intersection(cylinder, point_on_unit_circle):
v = _rotate_from_z_to_axis(point_on_unit_circle, cylinder.symmetry_line)
x = cylinder.center_of_base + cylinder.radius * v
assert_allclose(
cylinder.beam_intersection(x, v), sc.scalar(0.0), atol=sc.scalar(1e-14)
)
# Rotate 90 deg around cylinder axis.
# Should still not intersect the cylinder.
v = (
sc.spatial.rotations_from_rotvecs(
cylinder.symmetry_line * sc.scalar(np.pi / 2, unit='rad')
)
* v
)
assert_allclose(
c.beam_intersection(point_on_unit_circle, point_on_unit_circle), sc.scalar(0.0)
cylinder.beam_intersection(x, v), sc.scalar(0.0), atol=sc.scalar(1e-7)
)

n = point_on_unit_circle - sc.vector([0, 0, 1]) / 2
n /= sc.norm(n)
assert_allclose(c.beam_intersection(point_on_unit_circle, n), sc.scalar(0.0))

x = (1.0 + np.finfo(float).eps) * point_on_unit_circle
assert_allclose(c.beam_intersection(x, sc.vector([0, 0, 1])), sc.scalar(0.0))


@pytest.mark.parametrize('h', [sc.scalar(0.2), sc.scalar(1.0)])
@pytest.mark.parametrize('r', [sc.scalar(0.2), sc.scalar(1.0)])
def test_intersection_interior(r, h, point_on_unit_circle):
c = Cylinder(sc.vector([0, 0, 1.0]), sc.vector([0, 0, -h.value / 2]), r, h)
assert_allclose(c.beam_intersection(sc.vector([0, 0, 0]), point_on_unit_circle), r)
def test_intersection_from_center(cylinder, point_on_unit_circle):
v = _rotate_from_z_to_axis(point_on_unit_circle, cylinder.symmetry_line)
assert_allclose(cylinder.beam_intersection(cylinder.center, v), cylinder.radius)
assert_allclose(
cylinder.beam_intersection(cylinder.center, cylinder.symmetry_line),
cylinder.height / 2,
)
assert_allclose(
cylinder.beam_intersection(cylinder.center, -cylinder.symmetry_line),
cylinder.height / 2,
)


@pytest.mark.parametrize('kind', ['expensive', 'medium', 'cheap'])
Expand Down

0 comments on commit 131859a

Please sign in to comment.