Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import numpy as np
import math
from typing import Optional, Sequence, Tuple, TypeVar, Union

Expand Down Expand Up @@ -2048,6 +2049,9 @@
attn_weight, _ = op.Dropout(attn_weight, dropout_p)
return op.MatMul(attn_weight, value)

def float_lowest(dtype):
"""Returns the lowest representable value for the given numpy dtype."""
return np.finfo(np.dtype(dtype)).min

def _aten_scaled_dot_product_attention_bool_mask_onnx(
query: TFloat,
Expand Down Expand Up @@ -2078,7 +2082,7 @@
key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
# Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype))
neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype))
neg_inf = op.Constant(value=ir.tensor(float_lowest(query.dtype)), dtype=query.dtype)
Copy link
Collaborator

@justinchuby justinchuby Oct 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, you can use query.dtype.min directly because it is an implemented method in ir.DataType: https://onnx.ai/ir-py/api/generated/onnx_ir.DataType.html#onnx_ir.DataType.min

Check failure

Code scanning / lintrunner

PYLINT/E1123 Error

Unexpected keyword argument 'dtype' in method call (unexpected-keyword-arg)
See unexpected-keyword-arg. To disable, use # pylint: disable=unexpected-keyword-arg
attn_mask = op.Where(attn_mask, zero, neg_inf)
attn_weight = op.Softmax(
op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
Expand Down
7 changes: 7 additions & 0 deletions tests/common/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch

from onnxscript import optimizer
from onnxscript.onnx_opset import opset18 as op
from onnxscript.rewriter import onnxruntime as ort_rewriter
from onnxscript.utils import evaluation_utils

Expand Down Expand Up @@ -101,3 +102,9 @@ def test_onnxruntime_rewrite(
f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
)
raise

def test_softmax_with_all_inf_mask():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the test. It does not belong here

# GH #2561
input = np.array([[-float("inf"), -float("inf")]], dtype=np.float32)
output = op.Softmax(input, axis=-1)
assert np.isnan(output).all(), "Softmax should return NaN when all inputs are -inf"
Loading