@@ -17,7 +17,7 @@ defmodule NxSignal do
17
17
18
18
* `:fs` - the sampling frequency for the input in Hz. Defaults to `1000`.
19
19
* `:nfft` - the DFT length that will be passed to `Nx.fft/2`. Defaults to `:power_of_two`.
20
- * `overlap_size` - the number of samples for the overlap between frames.
20
+ * `: overlap_size` - the number of samples for the overlap between frames.
21
21
Defaults to `div(frame_size, 2)`.
22
22
23
23
## Examples
@@ -224,33 +224,77 @@ defmodule NxSignal do
224
224
output
225
225
end
226
226
227
+ @ doc """
228
+ Performs the overlap-and-add algorithm over
229
+ an M by N tensor, where M is the number of
230
+ windows and N is the window size.
231
+
232
+ The tensor is zero-padded on the right so
233
+ the last window fully appears in the result.
234
+
235
+ ## Options
236
+
237
+ * `:overlap_size` - The number of overlapping samples between windows
238
+
239
+ ## Examples
240
+
241
+ iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_size: 0)
242
+ #Nx.Tensor<
243
+ s64[12]
244
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
245
+ >
246
+
247
+ iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_size: 3)
248
+ #Nx.Tensor<
249
+ s64[6]
250
+ [0, 5, 15, 18, 17, 11]
251
+ >
252
+
253
+ """
227
254
defn overlap_and_add ( tensor , opts \\ [ ] ) do
228
- # pad the tensor to recover the edges
229
- padded = Nx . pad ( tensor , 0 , [ { 1 , 1 , 0 } , { 0 , 0 , 0 } ] )
255
+ { stride , num_windows , window_size , output_holder_shape } =
256
+ transform ( { tensor , opts } , fn { tensor , opts } ->
257
+ import Nx.Defn.Kernel , only: [ ]
258
+ import Elixir.Kernel
230
259
231
- { num_frames , output_shape , index_template_shape , hop_size } =
232
- transform ( { opts , padded } , fn { opts , padded } ->
233
- { num_frames , num_samples } = Nx . shape ( padded )
234
- hop_size = num_samples - opts [ :overlap_size ]
260
+ { num_windows , window_size } = Nx . shape ( tensor )
261
+ overlap_size = opts [ :overlap_size ]
235
262
236
- # TO-DO: this can probably be calculated dinamically
237
- output_shape = { opts [ :n ] }
263
+ unless is_number ( overlap_size ) and overlap_size < window_size do
264
+ raise ArgumentError ,
265
+ "overlap_size must be a number less than the window size #{ window_size } , got: #{ inspect ( window_size ) } "
266
+ end
238
267
239
- index_template_shape = { num_samples , 1 }
268
+ stride = window_size - overlap_size
240
269
241
- { num_frames - 2 , output_shape , index_template_shape , hop_size }
242
- end )
270
+ output_holder_shape = { num_windows * stride + overlap_size }
243
271
244
- zeros = Nx . broadcast ( Nx . tensor ( 0 , type: Nx . type ( padded ) ) , output_shape )
272
+ { stride , num_windows , window_size , output_holder_shape }
273
+ end )
245
274
246
- { result , _ , _ , _ , _ } =
247
- while { x = zeros , update_offset = 0 , frame_offset = 0 ,
248
- index_template = Nx . iota ( index_template_shape ) , tensor } ,
249
- frame_offset < num_frames do
250
- updated = Nx . indexed_add ( x , index_template + update_offset , tensor [ frame_offset ] )
251
- { updated , update_offset + hop_size , frame_offset + 1 , index_template , tensor }
275
+ { output , _ , _ , _ , _ , _ } =
276
+ while {
277
+ out = Nx . broadcast ( 0 , output_holder_shape ) ,
278
+ tensor ,
279
+ i = 0 ,
280
+ idx_template = Nx . iota ( { window_size , 1 } ) ,
281
+ stride ,
282
+ num_windows
283
+ } ,
284
+ i < num_windows do
285
+ current_window = tensor [ i ]
286
+ idx = idx_template + i * stride
287
+
288
+ {
289
+ Nx . indexed_add ( out , idx , current_window ) ,
290
+ tensor ,
291
+ i + 1 ,
292
+ idx_template ,
293
+ stride ,
294
+ num_windows
295
+ }
252
296
end
253
297
254
- result
298
+ output
255
299
end
256
300
end
0 commit comments