From 792dcb37d523ce2ccdbcea1d8886aa60e4657fd0 Mon Sep 17 00:00:00 2001 From: Johannes Kasimir Date: Fri, 27 Sep 2024 14:17:59 +0200 Subject: [PATCH] fix: reduce size before broadcasting --- src/scippneutron/absorption/base.py | 34 ++++++++++++----------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/src/scippneutron/absorption/base.py b/src/scippneutron/absorption/base.py index 1830d8dca..fe3a89e21 100644 --- a/src/scippneutron/absorption/base.py +++ b/src/scippneutron/absorption/base.py @@ -16,9 +16,6 @@ def compute_transmission_map( quadrature_kind: Any = 'medium', ) -> sc.DataArray: points, weights = sample_shape.quadrature(quadrature_kind) - scatter_direction = detector_position - points.to(unit=detector_position.unit) - scatter_direction /= sc.norm(scatter_direction) - transmission = _integrate_transmission_fraction( partial( _single_scatter_distance_through_sample, @@ -29,7 +26,7 @@ def compute_transmission_map( partial(_transmission_fraction, sample_material), points, weights, - scatter_direction, + detector_position, wavelength, ) return sc.DataArray( @@ -54,41 +51,38 @@ def _transmission_fraction(material, distance_through_sample, wavelength): ) -def _size_after_broadcast(a, b): - 'Size of the result of broadcasting a and b' - size = 1 - for s in {**a.sizes, **b.sizes}.values(): - size *= s - return size - - def _integrate_transmission_fraction( distance_through_sample, transmission, points, weights, - scatter_direction, + detector_position, wavelengths, ): # If size after broadcast is too large # then don't vectorize the operation to avoid OOM - if _size_after_broadcast(points, scatter_direction) > 100_000_000: - dim = scatter_direction.dims[0] - return sc.concat( - [ + if points.size * detector_position.size > 20_000_000: + out = [] + dim = detector_position.dims[0] + for i in range(detector_position.sizes[dim]): + out.append( # noqa: PERF401 _integrate_transmission_fraction( distance_through_sample, transmission, points, weights, - scatter_direction[dim, i], + detector_position[dim, i], wavelengths, ) - for i in range(scatter_direction.shape[0]) - ], + ) + + return sc.concat( + out, dim=dim, ) + scatter_direction = detector_position - points.to(unit=detector_position.unit) + scatter_direction /= sc.norm(scatter_direction) Ltot = distance_through_sample(scatter_direction) # The Ltot array is already large, to avoid OOM, don't vectorize this operation