120 changes: 96 additions & 24 deletions src/ess/reduce/time_of_flight/toa_to_tof.py
@@ -389,6 +389,56 @@ def _time_of_flight_data_histogram(
return rebinned.assign_coords(tof=tofs)


def _guess_pulse_stride_offset(
pulse_index: sc.Variable,
ltotal: sc.Variable,
event_time_offset: sc.Variable,
pulse_stride: int,
interp: Callable,
) -> int:
"""
Using the minimum ``event_time_zero`` to calculate a reference time when computing
the time-of-flight for the neutron events makes the workflow depend on when the
first event was recorded. There is no straightforward way to know whether recording
started at the beginning of a frame or halfway through one without looking at the
chopper logs. This can be corrected manually via the pulse_stride_offset parameter,
but doing so makes automatic data reduction difficult.
See https://github.com/scipp/essreduce/issues/184.

Here, we make a simple guess for ``pulse_stride_offset`` if it is not provided:
we pick a few random events, compute their time-of-flight for every possible value
of pulse_stride_offset, and return the value that yields the fewest NaNs in the
computed time-of-flight.

Parameters
----------
pulse_index:
Pulse index for every event.
ltotal:
Total length of the flight path from the source to the detector for each event.
event_time_offset:
Time of arrival of the neutron at the detector for each event.
pulse_stride:
Stride of used pulses.
interp:
2D interpolator for the lookup table.
"""
tofs = {}
# Choose a few random events to compute the time-of-flight
inds = np.random.choice(
len(event_time_offset), min(5000, len(event_time_offset)), replace=False
)
pulse_index_values = pulse_index.values[inds]
ltotal_values = ltotal.values[inds]
etos_values = event_time_offset.values[inds]
for i in range(pulse_stride):
pulse_inds = (pulse_index_values + i) % pulse_stride
tofs[i] = interp((pulse_inds, ltotal_values, etos_values))
# Return the offset that gives the fewest NaN values in the computed time-of-flight
return sorted(tofs, key=lambda x: np.isnan(tofs[x]).sum())[0]

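As a standalone illustration of the fewest-NaN heuristic (not part of the diff), the sketch below uses a fake interpolator in place of the real lookup-table interpolator: it only returns finite values when the pulse index is consistent with the event_time_offset, so the offset that recovers the true indices is the one with the fewest NaNs. All numbers are made up.

import numpy as np

# Fake interpolator: pulse index 0 is only valid for event_time_offsets below
# half the pulse period, pulse index 1 only for offsets above it.
pulse_period = 71.0e3  # us
pulse_stride = 2

def toy_interp(args):
    pulse_inds, ltotal, etos = args
    valid = pulse_inds == (etos >= 0.5 * pulse_period).astype(int)
    return np.where(valid, etos + ltotal, np.nan)

rng = np.random.default_rng(seed=1)
etos = rng.uniform(0.0, pulse_period, size=2000)
ltotal = np.full_like(etos, 80.0)
true_index = (etos >= 0.5 * pulse_period).astype(int)
# Pretend recording started half-way through a frame, so the raw pulse indices
# are shifted by one with respect to the true ones.
raw_index = (true_index + 1) % pulse_stride

nan_counts = {
    offset: int(
        np.isnan(toy_interp(((raw_index + offset) % pulse_stride, ltotal, etos))).sum()
    )
    for offset in range(pulse_stride)
}
print(nan_counts)  # {0: 2000, 1: 0} -> offset 1 recovers the true indices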

def _time_of_flight_data_events(
da: sc.DataArray,
lookup: sc.DataArray,
@@ -399,28 +449,6 @@ def _time_of_flight_data_events(
) -> sc.DataArray:
etos = da.bins.coords["event_time_offset"]
eto_unit = elem_unit(etos)
pulse_period = pulse_period.to(unit=eto_unit)
frame_period = pulse_period * pulse_stride

# TODO: Finding the `tmin` below will not work in the case were data is processed
# in chunks, as taking the minimum time in each chunk will lead to inconsistent
# pulse indices (this will be the case in live data, or when using the
# StreamProcessor). We could instead read it from the first chunk and store it?

# Compute a pulse index for every event: it is the index of the pulse within a
# frame period. When there is no pulse skipping, those are all zero. When there is
# pulse skipping, the index ranges from zero to pulse_stride - 1.
tmin = da.bins.coords['event_time_zero'].min()
pulse_index = (
(
(da.bins.coords['event_time_zero'] - tmin).to(unit=eto_unit)
+ 0.5 * pulse_period
)
% frame_period
) // pulse_period
# Apply the pulse_stride_offset
pulse_index += pulse_stride_offset
pulse_index %= pulse_stride

# Create 2D interpolator
interp = _make_tof_interpolator(
@@ -430,7 +458,51 @@
# Operate on events (broadcast distances to all events)
ltotal = sc.bins_like(etos, ltotal).bins.constituents["data"]
etos = etos.bins.constituents["data"]
pulse_index = pulse_index.bins.constituents["data"]

# Compute a pulse index for every event: it is the index of the pulse within a
# frame period. When there is no pulse skipping, those are all zero. When there is
# pulse skipping, the index ranges from zero to pulse_stride - 1.
if pulse_stride == 1:
pulse_index = sc.zeros(sizes=etos.sizes)
else:
etz_unit = 'ns'
etz = (
da.bins.coords["event_time_zero"]
.bins.constituents["data"]
.to(unit=etz_unit, copy=False)
)
pulse_period = pulse_period.to(unit=etz_unit, dtype=int)
frame_period = pulse_period * pulse_stride
# Define a common reference time using epoch as a base, but making sure that it
# is aligned with the pulse_period and the frame_period.
# We need to use a global reference time instead of simply taking the minimum
# event_time_zero because the events may arrive in chunks, and the first event
# may not be the first event of the first pulse for all chunks. This would lead
# to inconsistent pulse indices.
epoch = sc.datetime(0, unit=etz_unit)
diff_to_epoch = (etz.min() - epoch) % pulse_period
# Offset the reference by half a pulse period to guard against jitter in the
# event_time_zeros: they are emitted by the source trigger and are not always
# separated by exactly the pulse period. The jitter is small, so shifting the
# reference by half a pulse period is a simple and sufficient fix.
reference = epoch + diff_to_epoch - (pulse_period // 2)
# Use in-place operations to avoid large allocations
pulse_index = etz - reference
pulse_index %= frame_period
pulse_index //= pulse_period

# Apply the pulse_stride_offset
if pulse_stride_offset is None:
pulse_stride_offset = _guess_pulse_stride_offset(
pulse_index=pulse_index,
ltotal=ltotal,
event_time_offset=etos,
pulse_stride=pulse_stride,
interp=interp,
)
pulse_index += pulse_stride_offset
pulse_index %= pulse_stride

# Compute time-of-flight for all neutrons using the interpolator
tofs = sc.array(
@@ -535,7 +607,7 @@ def default_parameters() -> dict:
return {
PulsePeriod: 1.0 / sc.scalar(14.0, unit="Hz"),
PulseStride: 1,
PulseStrideOffset: 0,
PulseStrideOffset: None,
DistanceResolution: sc.scalar(0.1, unit="m"),
TimeResolution: sc.scalar(250.0, unit='us'),
LookupTableRelativeErrorThreshold: 0.1,
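To make the new pulse-index logic above concrete, here is a small self-contained sketch in plain NumPy (toy numbers, not taken from the PR) of the epoch-aligned reference and the per-event pulse-index arithmetic:

import numpy as np

# Toy numbers: a 71.4 ms pulse period (roughly 14 Hz) and a pulse stride of 2.
pulse_period = 71_400_000  # ns
pulse_stride = 2
frame_period = pulse_period * pulse_stride

# Fake event_time_zero values in ns since epoch, one per event, jittering
# slightly around multiples of the pulse period.
etz = np.array([500, 71_400_200, 142_800_900, 214_199_600], dtype=np.int64)

# Reference aligned with the pulse period (using epoch as the base) and shifted
# back by half a period so that trigger jitter cannot move an event into a
# neighbouring pulse.
diff_to_epoch = etz.min() % pulse_period
reference = diff_to_epoch - pulse_period // 2

pulse_index = ((etz - reference) % frame_period) // pulse_period
print(pulse_index)  # [0 1 0 1]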
4 changes: 2 additions & 2 deletions src/ess/reduce/time_of_flight/types.py
@@ -101,10 +101,10 @@ class SimulationResults:
Stride of used pulses. Usually 1, but may be a small integer when pulse-skipping.
"""

PulseStrideOffset = NewType("PulseStrideOffset", int)
PulseStrideOffset = NewType("PulseStrideOffset", int | None)
"""
When pulse-skipping, the offset of the first pulse in the stride. This is typically
zero but can be a small integer < pulse_stride.
zero but can be a small integer < pulse_stride. If None, the offset is guessed from the data.
"""

RawData = NewType("RawData", sc.DataArray)
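A hypothetical usage sketch of the new default (mirroring the test below; the other pipeline inputs such as RawData, SimulationResults and LtotalRange are assumed to be set as in that test): with PulseStrideOffset left at None the offset is guessed from the data, while setting it explicitly still overrides the guess.

pl = sl.Pipeline(
    time_of_flight.providers(), params=time_of_flight.default_parameters()
)
pl[time_of_flight.PulseStride] = 2
# PulseStrideOffset defaults to None: the offset is guessed from the data.
# Uncomment to force a specific offset instead:
# pl[time_of_flight.PulseStrideOffset] = 1
tofs = pl.compute(time_of_flight.TofData)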
43 changes: 43 additions & 0 deletions tests/time_of_flight/unwrap_test.py
@@ -235,6 +235,49 @@ def test_pulse_skipping_unwrap_180_phase_shift() -> None:
assert sc.isclose(mon.data.nansum(), tofs.data.nansum(), rtol=sc.scalar(1.0e-3))


def test_pulse_skipping_stride_offset_guess_gives_expected_result() -> None:
distance = sc.scalar(100.0, unit="m")
choppers = fakes.psc_choppers()
choppers["pulse_skipping"] = fakes.pulse_skipping_chopper()
choppers["pulse_skipping"].phase.value += 180.0

beamline = fakes.FakeBeamline(
choppers=choppers,
monitors={"detector": distance},
run_length=sc.scalar(1.0, unit="s"),
events_per_pulse=100_000,
seed=4,
)
mon = beamline.get_monitor("detector")[0]

sim = time_of_flight.simulate_beamline(
choppers=choppers, neutrons=300_000, pulses=2, seed=1234
)

pl = sl.Pipeline(
time_of_flight.providers(), params=time_of_flight.default_parameters()
)

pl[time_of_flight.RawData] = mon
pl[time_of_flight.SimulationResults] = sim
pl[time_of_flight.LtotalRange] = distance, distance
pl[time_of_flight.PulseStride] = 2

# Cache the table to avoid noise from re-computing
pl[time_of_flight.TimeOfFlightLookupTable] = pl.compute(
time_of_flight.TimeOfFlightLookupTable
)

with_guess = pl.compute(time_of_flight.TofData)
pl[time_of_flight.PulseStrideOffset] = 1 # Start the stride at the second pulse
no_guess = pl.compute(time_of_flight.TofData)
assert sc.allclose(
with_guess.bins.concat().value.coords['tof'],
no_guess.bins.concat().value.coords['tof'],
equal_nan=True,
)


def test_pulse_skipping_unwrap_when_all_neutrons_arrive_after_second_pulse() -> None:
distance = sc.scalar(150.0, unit="m")
choppers = fakes.psc_choppers()