Test fft normalization #2209


Open. Wants to merge 1 commit into base: main.
20 changes: 13 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/fft.py
@@ -14,7 +14,7 @@

 from typing import Optional, Sequence

-from onnxscript import INT64
+from onnxscript import INT64, ir

Check notice (Code scanning / CodeQL): Unused import. Import of 'ir' is not used.

Copilot Autofix (AI, 10 days ago): To fix the issue, remove the unused import of ir from the onnxscript module. This cleans up the code and eliminates the unnecessary dependency. The change is on line 17, where the import statement is defined.

Suggested changeset 1: onnxscript/function_libs/torch_lib/ops/fft.py

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py
--- a/onnxscript/function_libs/torch_lib/ops/fft.py
+++ b/onnxscript/function_libs/torch_lib/ops/fft.py
@@ -16,3 +16,3 @@
 
-from onnxscript import INT64, ir
+from onnxscript import INT64
 from onnxscript.function_libs.torch_lib.registration import torch_op
EOF

Check warning (Code scanning / lintrunner): PYLINT/W0611. Unused ir imported from onnxscript (unused-import). To disable, use # pylint: disable=unused-import

Check warning (Code scanning / lintrunner): RUFF/F401. onnxscript.ir imported but unused. See https://docs.astral.sh/ruff/rules/unused-import
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op
@@ -118,12 +118,18 @@
# Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed
# into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we
# place no such restriction on the ONNX side.
-transformed = op.DFT(
-    transformed,
-    dft_length=last_dim_size,
-    axis=dimension,
-    inverse=True,
-    onesided=False,
-)
+scale = (op.CastLike(last_dim_size, self)) / op.CastLike(
+    op.Shape(transformed, start=dimension, end=dimension + 1), self
+)
+transformed = (
+    op.DFT(
+        transformed,
+        dft_length=last_dim_size,
+        axis=dimension,
+        inverse=True,
+        onesided=False,
+    )
+    * scale
+)

Check warning (Codecov / codecov/patch) on line 121 in onnxscript/function_libs/torch_lib/ops/fft.py: Added line #L121 was not covered by tests.

Check warning (Codecov / codecov/patch) on line 124 in onnxscript/function_libs/torch_lib/ops/fft.py: Added line #L124 was not covered by tests.
transformed = _fftn_onnx_normalization(
transformed,
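The scaling in the hunk above hinges on the 1/n factor that an inverse DFT applies along the transformed axis: padding or truncating the input to a new length before the inverse transform changes which length that implicit factor uses. A minimal NumPy sketch of this (illustrative only; not the torch_lib or ONNX implementation):

```python
import numpy as np

# Inverse DFT definition: x[k] = (1/m) * sum_n X[n] * exp(2j*pi*n*k/m).
# np.fft.ifft(X, n=m) zero-pads (or truncates) X to length m first,
# so the normalization factor becomes 1/m, not 1/len(X).
X = np.fft.fft(np.arange(4.0))
m = 8

padded = np.concatenate([X, np.zeros(m - len(X))])
manual = np.array(
    [padded @ np.exp(2j * np.pi * np.arange(m) * k / m) for k in range(m)]
) / m

assert np.allclose(manual, np.fft.ifft(X, n=m))
```

So whenever the target length (last_dim_size here) differs from the current length along the axis, a ratio of the two lengths is needed to move between the two normalization conventions.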
Contributor @bmehta001 commented Apr 17, 2025:

Perhaps replace line 137 (op.Shape...) with op.CastLike(last_dim_size, self) and then remove scale? Would that yield the same or better results?

Collaborator (Author) replied:

I thought last_dim_size was op.Shape(transformed, start=dimension, end=dimension + 1)? Let me try.

Contributor replied:

Oh, never mind, sorry, you're completely right that op.Shape(transformed, start=dimension, end=dimension + 1) differs between line 122 and line 137. But your code made me realize that, without modifying anything else, line 137 could perhaps be replaced directly with last_dim_size just to save a call to op.Shape.
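For intuition on the scale pattern the thread is discussing, here is a NumPy sketch (a hypothetical analogy under assumed semantics, not the torch_lib code): multiplying an inverse transform of length n by n / current_length yields a result normalized by the original input length, and the scale collapses to 1 exactly when the two lengths agree, which is the case where dropping the op.Shape call is safe.

```python
import numpy as np

def ifft_rescaled(X, n):
    """Hypothetical helper mirroring the PR's pattern: inverse DFT at
    length n, multiplied by scale = n / len(X). Since np.fft.ifft(X, n=n)
    divides by n, the product ends up normalized by len(X) instead."""
    scale = n / len(X)  # analogue of last_dim_size / op.Shape(transformed, ...)
    return np.fft.ifft(X, n=n) * scale

X = np.fft.fft(np.arange(4.0))

# When n == len(X), scale == 1 and this is a plain inverse FFT.
assert np.allclose(ifft_rescaled(X, len(X)), np.fft.ifft(X))

# When n != len(X), the result differs from the plain length-n inverse
# by exactly the ratio of lengths.
assert np.allclose(ifft_rescaled(X, 8), np.fft.ifft(X, n=8) * 2.0)
```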
