Skip to content

Commit

Permalink
fix: reduce size before broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
jokasimr committed Sep 27, 2024
1 parent f9ea772 commit 792dcb3
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions src/scippneutron/absorption/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ def compute_transmission_map(
quadrature_kind: Any = 'medium',
) -> sc.DataArray:
points, weights = sample_shape.quadrature(quadrature_kind)
scatter_direction = detector_position - points.to(unit=detector_position.unit)
scatter_direction /= sc.norm(scatter_direction)

transmission = _integrate_transmission_fraction(
partial(
_single_scatter_distance_through_sample,
Expand All @@ -29,7 +26,7 @@ def compute_transmission_map(
partial(_transmission_fraction, sample_material),
points,
weights,
scatter_direction,
detector_position,
wavelength,
)
return sc.DataArray(
Expand All @@ -54,41 +51,38 @@ def _transmission_fraction(material, distance_through_sample, wavelength):
)


def _size_after_broadcast(a, b):
'Size of the result of broadcasting a and b'
size = 1
for s in {**a.sizes, **b.sizes}.values():
size *= s
return size


def _integrate_transmission_fraction(
distance_through_sample,
transmission,
points,
weights,
scatter_direction,
detector_position,
wavelengths,
):
# If size after broadcast is too large
# then don't vectorize the operation to avoid OOM
if _size_after_broadcast(points, scatter_direction) > 100_000_000:
dim = scatter_direction.dims[0]
return sc.concat(
[
if points.size * detector_position.size > 20_000_000:
out = []
dim = detector_position.dims[0]
for i in range(detector_position.sizes[dim]):
out.append( # noqa: PERF401
_integrate_transmission_fraction(
distance_through_sample,
transmission,
points,
weights,
scatter_direction[dim, i],
detector_position[dim, i],
wavelengths,
)
for i in range(scatter_direction.shape[0])
],
)

return sc.concat(
out,
dim=dim,
)

scatter_direction = detector_position - points.to(unit=detector_position.unit)
scatter_direction /= sc.norm(scatter_direction)
Ltot = distance_through_sample(scatter_direction)

# The Ltot array is already large, to avoid OOM, don't vectorize this operation
Expand Down

0 comments on commit 792dcb3

Please sign in to comment.