Skip to content

Conversation

@moatom
Copy link

@moatom moatom commented Oct 19, 2025

Fixed pytorch/pytorch#147052

$ python -m pytest tests/function_libs/torch_lib/ops_test.py -k ops_aten_stft
====================================================================================================================================================================================================== test session starts ======================================================================================================================================================================================================
platform linux -- Python 3.13.1, pytest-8.4.1, pluggy-1.6.0
Using --randomly-seed=371864411
rootdir: /home/moatom/github/onnxscript
configfile: pyproject.toml
plugins: randomly-3.16.0, xdist-3.8.0, subtests-0.14.2, cov-6.2.1, hypothesis-6.138.2
collected 2158 items / 2154 deselected / 4 selected                                                                                                                                                                                                                                                                                                                                                                             

tests/function_libs/torch_lib/ops_test.py s..x                                                                                                                                                                                                                                                                                                                                                                     [100%]

======================================================================================================================================================================================================= warnings summary ========================================================================================================================================================================================================
onnxscript/converter.py:457: 429 warnings
tests/function_libs/torch_lib/ops_test.py: 15 warnings
  /home/moatom/github/onnxscript/onnxscript/converter.py:457: DeprecationWarning: Expression.__init__ got an unexpected keyword argument 'lineno'. Support for arbitrary keyword arguments is deprecated and will be removed in Python 3.15.
    expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset)

onnxscript/converter.py:457: 429 warnings
tests/function_libs/torch_lib/ops_test.py: 15 warnings
  /home/moatom/github/onnxscript/onnxscript/converter.py:457: DeprecationWarning: Expression.__init__ got an unexpected keyword argument 'col_offset'. Support for arbitrary keyword arguments is deprecated and will be removed in Python 3.15.
    expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset)

tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_stft_cpu_float32
tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_stft_cpu_float32
tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_stft_cpu_float32
  /home/moatom/github/onnxscript/tests/function_libs/torch_lib/ops_test_common.py:329: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword
    value = np.array(value.cpu())

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================================================================================================================== short test summary info ====================================================================================================================================================================================================
SKIPPED [1] tests/function_libs/torch_lib/ops_test.py:101: Traced functions does not have a function proto
=================================================================================================================================================================== 2 passed, 1 skipped, 2154 deselected, 1 xfailed, 891 warnings, 7 subtests passed in 4.42s ===================================================================================================================================================================

@moatom moatom changed the title Add stft Implement aten.stft Oct 19, 2025
@moatom
Copy link
Author

moatom commented Oct 19, 2025

@microsoft-github-policy-service agree

@moatom
Copy link
Author

moatom commented Oct 19, 2025

pytorch/pytorch#147052 (comment)

It would be nice to decay to onnx.STFT node in the graph if options of a given call allow this...

I think this PR’s implementation already uses onnx.STFT as its core: https://github.com/microsoft/onnxscript/pull/2645/files#diff-773a426eaaf0eb2a1ed6039f2cd50613d2db7d61c239453545413499b2a15776R8194

result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)

So I don’t think an decay to onnx.STFT is necessary in this case. 🤔

@moatom
Copy link
Author

moatom commented Oct 19, 2025

@justinchuby

Hi!

pytorch/pytorch#147052 (comment)

I think it can be simplified quite a bit.

Could you elaborate a bit on this comment at your convenience? (If you think this simplification is necessary.)

@moatom moatom marked this pull request as ready for review October 19, 2025 15:00
@justinchuby justinchuby requested review from Copilot and titaiwangms and removed request for titaiwangms October 20, 2025 16:17
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR implements the aten::stft (Short-Time Fourier Transform) operator to resolve issue #147052. The implementation includes handling for various optional parameters like hop_length, win_length, window, normalized, onesided, and return_complex.

Key changes:

  • Added STFT operator implementation with helper functions for batch dimension handling, window centering, and FFT normalization
  • Registered the operator in test data with appropriate tolerance settings and xfail for float16 dtype

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
onnxscript/function_libs/torch_lib/ops/core.py Implements aten_stft and five helper functions for STFT processing
tests/function_libs/torch_lib/ops_test_data.py Registers the new operator in test suite with tolerance and xfail configuration

@titaiwangms titaiwangms self-assigned this Oct 24, 2025
@codecov
Copy link

codecov bot commented Oct 24, 2025

Codecov Report

❌ Patch coverage is 70.96774% with 18 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.47%. Comparing base (8a94ad6) to head (3c12aae).

Files with missing lines Patch % Lines
onnxscript/function_libs/torch_lib/ops/core.py 70.96% 15 Missing and 3 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2645      +/-   ##
==========================================
+ Coverage   70.46%   70.47%   +0.01%     
==========================================
  Files         224      224              
  Lines       26572    26634      +62     
  Branches     2637     2645       +8     
==========================================
+ Hits        18723    18770      +47     
- Misses       6928     6940      +12     
- Partials      921      924       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

if signal_rank == 1:
# Add a batch dimension
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
return op.Identity(self), signal_rank
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justinchuby Is identity op necessary to return self?

) -> TFloat:
n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
result = result / sqrt_nfft
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use op for this kind of calculation when you delete private flag.

Copy link
Author

@moatom moatom Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def _add_batch_dimension(self: TFloat) -> Tuple[TFloat, INT64]:
   signal_rank = op.Size(op.Shape(self))
   if signal_rank == 1:

I am surveying, but I don’t know how to handle conditionals in op right now. (op.If and op.While don’t work well...)

For example, the following expression doesn't make op.Equal true:

self = op.Where(
        op.Equal(signal_rank, op.Constant(value_int=1)),
        op.Unsqueeze(self, op.Constant(value_ints=[0])),
        self
    )

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am surveying, but I don’t know how to handle conditionals in op right now. (op.If and op.While don’t work well...)

I resolved this by moving the conditional parts to aten_stft except _center_window_around_zeros_if_needed.

Perhaps we should add a test to check window = op.Where(op.Less(op.Squeeze(n_win), n_fft), window_padded, window).

# first dimension
n_win = op.Shape(window, start=0, end=1)
# Center window around zeros if needed (required by ONNX's STFT)
if n_win < n_fft:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justinchuby Is there a good way we trace this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Development

Successfully merging this pull request may close these issues.

[ONNX] Implement aten.stft

2 participants