feat: reduce memory usage to avoid oom
jokasimr committed Sep 24, 2024
1 parent 86f25ed commit db12bdd
Showing 1 changed file with 67 additions and 14 deletions.
src/scippneutron/absorption/base.py:

@@ -1,3 +1,5 @@
+from functools import partial
+
 import scipp as sc


@@ -12,7 +14,7 @@ def single_scatter_distance_through_sample(
 def transmission_fraction(material, distance_through_sample, wavelength):
     return sc.exp(
         -(material.attenuation_coefficient(wavelength) * distance_through_sample).to(
-            unit='dimensionless'
+            unit='dimensionless', copy=False
         )
     )

@@ -23,25 +25,76 @@ def compute_transmission_map(
     beam_direction,
     wavelength,
     detector_position,
-    quadrature_kind='expensive',
+    quadrature_kind='medium',
 ):
     points, weights = sample_shape.quadrature(quadrature_kind)
     scatter_direction = detector_position - points.to(unit=detector_position.unit)
     scatter_direction /= sc.norm(scatter_direction)

-    Ltot = single_scatter_distance_through_sample(
-        sample_shape, points, beam_direction, scatter_direction
-    )
-    total_transmission = sc.concat(
-        # The Ltot array is already large, to avoid OOM, don't vectorize this operation
-        [
-            (transmission_fraction(sample_material, Ltot, w) * weights).sum(weights.dim)
-            / sample_shape.volume
-            for w in wavelength
-        ],
-        dim=wavelength.dim,
+    transmission = _integrate_transmission_fraction(
+        partial(
+            single_scatter_distance_through_sample, sample_shape, points, beam_direction
+        ),
+        partial(transmission_fraction, sample_material),
+        points,
+        weights,
+        scatter_direction,
+        wavelength,
     )
     return sc.DataArray(
-        data=total_transmission,
+        data=transmission / sample_shape.volume,
         coords={'detector_position': detector_position, 'wavelength': wavelength},
     )
+
+
+def _size_after_broadcast(a, b):
+    'Size of the result of broadcasting a and b'
+    size = 1
+    for s in {**a.sizes, **b.sizes}.values():
+        size *= s
+    return size
+
+
+def _integrate_transmission_fraction(
+    distance_through_sample,
+    transmission,
+    points,
+    weights,
+    scatter_direction,
+    wavelengths,
+):
+    # If the size after broadcast is too large,
+    # don't vectorize the operation, to avoid OOM.
+    if _size_after_broadcast(points, scatter_direction) > 100_000_000:
+        dim = scatter_direction.dims[0]
+        return sc.concat(
+            [
+                _integrate_transmission_fraction(
+                    distance_through_sample,
+                    transmission,
+                    points,
+                    weights,
+                    scatter_direction[dim, i],
+                    wavelengths,
+                )
+                for i in range(scatter_direction.shape[0])
+            ],
+            dim=dim,
+        )
+
+    Ltot = distance_through_sample(scatter_direction)
+
+    # The Ltot array is already large; to avoid OOM, don't vectorize this operation.
+    return sc.concat(
+        [
+            # Instead of a broadcast multiply and sum, use matvec for efficiency;
+            # this becomes relevant when the number of wavelength points grows.
+            sc.array(
+                dims=(tf := transmission(Ltot, w)).dims[:-1],
+                values=tf.values @ weights.values,
+                unit=tf.unit * weights.unit,
+            )
+            for w in wavelengths
+        ],
+        dim=wavelengths.dim,
+    )
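
For illustration only (not part of the commit): the sketch below uses plain NumPy with made-up array sizes, a made-up element-count threshold, and hypothetical helper names (integrate_chunked, compute_tf) to show the two ideas behind the change. A matrix-vector product replaces the broadcast multiply-and-sum over the quadrature weights, and the integration is chunked along the detector axis when the full transmission array would be too large to hold in memory at once.

import numpy as np

# Hypothetical sizes: n_det detector pixels, n_quad quadrature points.
n_det, n_quad = 10_000, 500
rng = np.random.default_rng(0)

tf = rng.random((n_det, n_quad))  # transmission fraction per (pixel, quadrature point)
weights = rng.random(n_quad)      # quadrature weights

# Broadcast multiply-and-sum materializes a second (n_det, n_quad) array
# before reducing it:
integrated_broadcast = (tf * weights).sum(axis=-1)

# A matrix-vector product performs the same weighted sum without that
# intermediate:
integrated_matvec = tf @ weights
assert np.allclose(integrated_broadcast, integrated_matvec)


def integrate_chunked(compute_tf, n_det, weights, max_elements):
    """Weighted sum over quadrature points, computing the transmission one
    detector slice at a time when the full array would exceed max_elements
    (analogous to the recursion over scatter_direction in the commit)."""
    if n_det * len(weights) <= max_elements:
        return compute_tf(slice(None)) @ weights
    return np.concatenate(
        [compute_tf(slice(i, i + 1)) @ weights for i in range(n_det)]
    )


# With a small threshold the chunked path is taken; the result is unchanged.
result = integrate_chunked(lambda s: tf[s], n_det, weights, max_elements=1_000)
assert np.allclose(result, integrated_matvec)

In the commit itself the same roles are played by tf.values @ weights.values and by the recursion over scatter_direction in _integrate_transmission_fraction, with a threshold of 100,000,000 broadcast elements.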
