@Aniketsy

#2561

  • Use lowest representable float value instead of -inf for attention masks.
  • Add NaN-safe handling and a unit test for softmax with all masked positions.
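
For context on why the fill value matters, here is a minimal sketch in plain Python. It uses float64 lists rather than ONNX tensors, and the helper names `softmax` and `apply_mask` are hypothetical, not code from this PR. It assumes a numerically stable softmax that subtracts the row maximum before exponentiating: with `-inf` as the fill value, a fully masked row computes `-inf - (-inf)`, which is NaN, while the lowest finite value yields a uniform distribution.

```python
import math
import sys


def softmax(scores):
    """Numerically stable softmax over a list of Python floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def apply_mask(scores, keep, fill):
    """Add 0.0 where keep is True and `fill` where it is False."""
    return [s + (0.0 if k else fill) for s, k in zip(scores, keep)]


NEG_INF = -math.inf
LOWEST = -sys.float_info.max  # most negative finite float64 (assumed stand-in for dtype.min)

scores = [0.5, 1.5, -2.0]
all_masked = [False, False, False]
partly_masked = [True, False, True]

# Fully masked row: -inf propagates NaN, the finite fill does not.
with_inf = softmax(apply_mask(scores, all_masked, NEG_INF))
with_lowest = softmax(apply_mask(scores, all_masked, LOWEST))

# Partially masked row: both fill values give the same probabilities.
partial_inf = softmax(apply_mask(scores, partly_masked, NEG_INF))
partial_lowest = softmax(apply_mask(scores, partly_masked, LOWEST))
```

In this sketch the two fill values only differ when every position in a row is masked, which is the case the PR description targets.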

Please let me know if my approach or fix needs any improvements. I'm open to feedback and happy to make changes based on suggestions.
Thank you!

@Aniketsy
Author

@microsoft-github-policy-service agree

)
raise

def test_softmax_with_all_inf_mask():
Collaborator

Please remove the test. It does not belong here.

# Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype))
- neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype))
+ neg_inf = op.Constant(value=ir.tensor(float_lowest(query.dtype)), dtype=query.dtype)
@justinchuby
Collaborator

justinchuby commented Oct 26, 2025

Actually, you can use query.dtype.min directly because it is an implemented method in ir.DataType: https://onnx.ai/ir-py/api/generated/onnx_ir.DataType.html#onnx_ir.DataType.min

@justinchuby
Collaborator

If this PR is facilitated by an AI, please disclose its usage.

# Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype))
- neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype))
+ neg_inf = op.Constant(value=ir.tensor(float_lowest(query.dtype)), dtype=query.dtype)

Check failure

Code scanning / lintrunner

PYLINT/E1123 Error

Unexpected keyword argument 'dtype' in method call (unexpected-keyword-arg)
See unexpected-keyword-arg. To disable, use # pylint: disable=unexpected-keyword-arg
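
The error above looks like a misplaced closing parenthesis: `dtype=query.dtype` ends up as an argument to the outer `op.Constant(...)` call instead of the inner `ir.tensor(...)` call. A rough sketch of the same mistake, using hypothetical stand-in functions (`tensor`, `constant`, and `lowest` are illustrations only, not the onnxscript or onnx_ir API):

```python
def tensor(value, dtype=None):
    """Hypothetical stand-in for a tensor constructor that accepts dtype."""
    return {"value": value, "dtype": dtype}


def constant(*, value):
    """Hypothetical stand-in for an op that accepts only `value`."""
    return {"op": "Constant", "value": value}


def lowest(dtype):
    """Hypothetical helper: lowest finite value for a dtype name."""
    return {"float16": -65504.0, "float32": -3.4028234663852886e38}[dtype]


dtype = "float16"

# Misplaced parenthesis: dtype is handed to constant(), which rejects it.
try:
    constant(value=tensor(lowest(dtype)), dtype=dtype)
    misplaced_error = None
except TypeError as exc:
    misplaced_error = str(exc)

# Parenthesis moved: dtype now reaches tensor(), as in the original line.
fixed = constant(value=tensor(lowest(dtype), dtype=dtype))
```

Moving the parenthesis so that `dtype` stays inside the tensor constructor mirrors the shape of the line being replaced in the diff.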
@codecov

codecov bot commented Oct 26, 2025

Codecov Report

❌ Patch coverage is 50.00000% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.46%. Comparing base (8a94ad6) to head (d6eccaa).

Files with missing lines | Patch % | Lines
onnxscript/function_libs/torch_lib/ops/nn.py | 50.00% | 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2654      +/-   ##
==========================================
- Coverage   70.46%   70.46%   -0.01%     
==========================================
  Files         224      224              
  Lines       26572    26575       +3     
  Branches     2637     2637              
==========================================
+ Hits        18723    18725       +2     
- Misses       6928     6929       +1     
  Partials      921      921              


@Aniketsy
Author

Aniketsy commented Oct 26, 2025

I went through the questions you mentioned and, yes, I used AI assistance to help add the unit test.
