Skip to content

Commit 2c69e76

Browse files
authored
feat: overlap-and-add (#1)
* feat: overlap-and-add * feat: overlap-and-add working implementation * refactor: don't explicitly support negative overlap
1 parent 55545b4 commit 2c69e76

File tree

3 files changed

+85
-20
lines changed

3 files changed

+85
-20
lines changed

examples/440Hz.wav

434 KB
Binary file not shown.

examples/read_and_filter_file.livemd

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Untitled notebook
2+
3+
```elixir
4+
Mix.install([
5+
{:waverider, ">= 0.0.0", github: "StareIntoTheBeard/waverider"},
6+
{:nx_signal, ">= 0.0.0", github: "polvalente/nx-signal"},
7+
{:nx, ">= 0.0.0",
8+
github: "elixir-nx/nx", branch: "pv-fix/reduce-instead-of-sum", sparse: "nx", override: true},
9+
{:vega_lite, "~> 0.1.4"},
10+
{:kino_vega_lite, "~> 0.1.1"}
11+
])
12+
```
13+
14+
## Section
15+
16+
```elixir
17+
# Audio sample download from https://www.mediacollege.com/audio/tone/files/440Hz_44100Hz_16bit_05sec.wav
18+
19+
{:ok, contents} = File.read("./examples/440Hz.wav")
20+
{:ok, wavefile} = Wave.parse(contents)
21+
```

lib/nx_signal.ex

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ defmodule NxSignal do
1717
1818
* `:fs` - the sampling frequency for the input in Hz. Defaults to `1000`.
1919
* `:nfft` - the DFT length that will be passed to `Nx.fft/2`. Defaults to `:power_of_two`.
20-
* `overlap_size` - the number of samples for the overlap between frames.
20+
* `:overlap_size` - the number of samples for the overlap between frames.
2121
Defaults to `div(frame_size, 2)`.
2222
2323
## Examples
@@ -224,33 +224,77 @@ defmodule NxSignal do
224224
output
225225
end
226226

227+
@doc """
Performs the overlap-and-add algorithm over
an M by N tensor, where M is the number of
windows and N is the window size.

Each window is added into the output at an offset of
`window_index * (N - overlap_size)`. The output has
`M * (N - overlap_size) + overlap_size` samples, so
the last window fully appears in the result.

The output has the same type as the input tensor.

## Options

  * `:overlap_size` - The number of overlapping samples between windows.
    Required. Must be an integer less than the window size.

## Examples

    iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_size: 0)
    #Nx.Tensor<
      s64[12]
      [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
    >

    iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_size: 3)
    #Nx.Tensor<
      s64[6]
      [0, 5, 15, 18, 17, 11]
    >

"""
defn overlap_and_add(tensor, opts \\ []) do
  # Shapes and strides must be plain Elixir integers, so they are computed
  # at definition (trace) time inside `transform` instead of as tensor ops.
  {stride, num_windows, window_size, output_holder_shape} =
    transform({tensor, opts}, fn {tensor, opts} ->
      # Swap the defn kernel for the regular Elixir kernel so that guards,
      # comparisons and `raise` operate on ordinary terms here.
      import Nx.Defn.Kernel, only: []
      import Elixir.Kernel

      {num_windows, window_size} = Nx.shape(tensor)
      overlap_size = opts[:overlap_size]

      # Only integers yield a valid stride/shape; floats (and a missing
      # option, i.e. nil) are rejected here with a descriptive error.
      if not (is_integer(overlap_size) and overlap_size < window_size) do
        raise ArgumentError,
              "overlap_size must be an integer less than the window size #{window_size}, got: #{inspect(overlap_size)}"
      end

      stride = window_size - overlap_size

      # The last window starts at (num_windows - 1) * stride and spans
      # window_size samples, which simplifies to the expression below.
      output_holder_shape = {num_windows * stride + overlap_size}

      {stride, num_windows, window_size, output_holder_shape}
    end)

  # The accumulator must have the same type as the input: `while` requires
  # the loop state to keep its type across iterations, and `indexed_add`
  # with float updates over an integer accumulator would change it.
  zeros = Nx.broadcast(Nx.tensor(0, type: Nx.type(tensor)), output_holder_shape)

  {output, _, _, _, _, _} =
    while {
            out = zeros,
            tensor,
            i = 0,
            idx_template = Nx.iota({window_size, 1}),
            stride,
            num_windows
          },
          i < num_windows do
      current_window = tensor[i]
      # Shift the 0..window_size-1 index column to this window's offset.
      idx = idx_template + i * stride

      {
        Nx.indexed_add(out, idx, current_window),
        tensor,
        i + 1,
        idx_template,
        stride,
        num_windows
      }
    end

  output
end
256300
end

0 commit comments

Comments
 (0)