diff --git a/src/scippneutron/absorption/base.py b/src/scippneutron/absorption/base.py index 938b7bd36..dcb273869 100644 --- a/src/scippneutron/absorption/base.py +++ b/src/scippneutron/absorption/base.py @@ -1,3 +1,5 @@ +from functools import partial + import scipp as sc @@ -12,7 +14,7 @@ def single_scatter_distance_through_sample( def transmission_fraction(material, distance_through_sample, wavelength): return sc.exp( -(material.attenuation_coefficient(wavelength) * distance_through_sample).to( - unit='dimensionless' + unit='dimensionless', copy=False ) ) @@ -23,25 +25,76 @@ def compute_transmission_map( beam_direction, wavelength, detector_position, - quadrature_kind='expensive', + quadrature_kind='medium', ): points, weights = sample_shape.quadrature(quadrature_kind) scatter_direction = detector_position - points.to(unit=detector_position.unit) scatter_direction /= sc.norm(scatter_direction) - Ltot = single_scatter_distance_through_sample( - sample_shape, points, beam_direction, scatter_direction - ) - total_transmission = sc.concat( - # The Ltot array is already large, to avoid OOM, don't vectorize this operation - [ - (transmission_fraction(sample_material, Ltot, w) * weights).sum(weights.dim) - / sample_shape.volume - for w in wavelength - ], - dim=wavelength.dim, + transmission = _integrate_transmission_fraction( + partial( + single_scatter_distance_through_sample, sample_shape, points, beam_direction + ), + partial(transmission_fraction, sample_material), + points, + weights, + scatter_direction, + wavelength, ) return sc.DataArray( - data=total_transmission, + data=transmission / sample_shape.volume, coords={'detector_position': detector_position, 'wavelength': wavelength}, ) + + +def _size_after_broadcast(a, b): + 'Size of the result of broadcasting a and b' + size = 1 + for s in {**a.sizes, **b.sizes}.values(): + size *= s + return size + + +def _integrate_transmission_fraction( + distance_through_sample, + transmission, + points, + weights, + scatter_direction, + wavelengths, +): + # If size after broadcast is too large + # then don't vectorize the operation to avoid OOM + if _size_after_broadcast(points, scatter_direction) > 100_000_000: + dim = scatter_direction.dims[0] + return sc.concat( + [ + _integrate_transmission_fraction( + distance_through_sample, + transmission, + points, + weights, + scatter_direction[dim, i], + wavelengths, + ) + for i in range(scatter_direction.shape[0]) + ], + dim=dim, + ) + + Ltot = distance_through_sample(scatter_direction) + + # The Ltot array is already large, to avoid OOM, don't vectorize this operation + return sc.concat( + [ + # Instead of broadcast multiply and sum, use matvec for efficiency + # this becomes relevant when the number of wavelength points grows + sc.array( + dims=(tf := transmission(Ltot, w)).dims[:-1], + values=tf.values @ weights.values, + unit=tf.unit * weights.unit, + ) + for w in wavelengths + ], + dim=wavelengths.dim, + )