Updated the `allclose` function to have less precision according to the dtype of input DNDarrays (#1177)
Changes from all commits
680e34e
0925dad
8a1546a
9076e9b
bfc4c23
7a0a468
da9778d
aa35803
8ee5062
```diff
@@ -106,7 +106,11 @@ def local_all(t, *args, **kwargs):

 def allclose(
-    x: DNDarray, y: DNDarray, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False
+    x: DNDarray,
+    y: DNDarray,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
+    equal_nan: bool = False,
 ) -> bool:
     """
     Test whether two tensors are element-wise equal within a tolerance. Returns ``True`` if ``|x-y|<=atol+rtol*|y|``
```
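The switch from fixed float defaults to `Optional[float] = None` is the common sentinel-default pattern: `None` signals "not supplied", so the function body can compute a dtype-dependent default while still honoring any explicitly passed value. A minimal sketch of that pattern (hypothetical names, not Heat code):

```python
from typing import Optional

def scaled_default(rtol: Optional[float] = None, bits: int = 32) -> float:
    # None means "use the dtype-aware default"; an explicit value always wins.
    if rtol is None:
        return 1e-05 * (64 / bits)  # looser tolerance for narrower dtypes
    return rtol
```

With this pattern, `scaled_default()` yields a bit-width-scaled default, while `scaled_default(3e-6)` returns the caller's value unchanged.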
```diff
@@ -139,17 +143,48 @@ def allclose(
     """
     t1, t2 = __sanitize_close_input(x, y)

+    # Here the adjustment of rtol and atol for float32 arrays increases the values of rtol and atol,
+    # effectively decreasing the precision required to return True.
+    # By adjusting the tolerance values, it allows for a looser comparison
+    # that considers the reduced precision of float32 arrays.
+    try:
+        dtype_precision = torch.finfo(t1.larray.dtype).bits
+    except TypeError:
+        dtype_precision = torch.iinfo(t1.larray.dtype).bits
+
+    adjustment_factor = 1.0
+    if dtype_precision != 64:
+        adjustment_factor = 64 / dtype_precision
+
+    if rtol is None:
+        rtol = 1e-05
+        adjusted_rtol = rtol * adjustment_factor
+    else:
+        adjusted_rtol = rtol
+
+    if atol is None:
+        atol = 1e-08
+        adjusted_atol = atol * adjustment_factor
+    else:
+        adjusted_atol = atol
+
     # no sanitation for shapes of x and y needed, torch.allclose raises relevant errors
     try:
-        _local_allclose = torch.tensor(torch.allclose(t1.larray, t2.larray, rtol, atol, equal_nan))
+        _local_allclose = torch.tensor(
+            torch.allclose(t1.larray, t2.larray, adjusted_rtol, adjusted_atol, equal_nan)
+        )
     except RuntimeError:
         promoted_dtype = torch.promote_types(t1.larray.dtype, t2.larray.dtype)
         _local_allclose = torch.tensor(
             torch.allclose(
                 t1.larray.type(promoted_dtype),
                 t2.larray.type(promoted_dtype),
-                rtol,
-                atol,
+                adjusted_rtol,
+                adjusted_atol,
                 equal_nan,
             )
         )
```

Review comments on the `if rtol is None:` line (code spans elided in the original page are marked […]):

Comment: "If I understand this correctly, the default […] Nevertheless, I think that […]"

Reply: "Sure...sir, I agree. I think […]"
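The tolerance-scaling logic in the diff can be sketched as a self-contained function. This illustration uses NumPy's `finfo`/`iinfo` in place of torch's so it runs without PyTorch; the names mirror the diff, but this is a sketch, not Heat's actual code.

```python
from typing import Optional, Tuple
import numpy as np

def adjusted_tolerances(
    dtype,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
) -> Tuple[float, float]:
    # Bit width of the dtype: finfo for floating-point types, iinfo for
    # integers. (torch.finfo raises TypeError on int dtypes; np.finfo
    # raises ValueError, so both are caught here.)
    try:
        bits = np.finfo(dtype).bits
    except (TypeError, ValueError):
        bits = np.iinfo(dtype).bits
    # 64-bit types keep the NumPy-style defaults; narrower types are
    # compared with proportionally looser tolerances.
    factor = 1.0 if bits == 64 else 64 / bits
    rtol = 1e-05 * factor if rtol is None else rtol
    atol = 1e-08 * factor if atol is None else atol
    return rtol, atol
```

For `float32` (32 bits) the factor is 2, so the defaults become `rtol=2e-05` and `atol=2e-08`; explicitly passed tolerances pass through unscaled.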
I think this needs to be discussed (e.g. in the PR conversation): should we keep the NumPy/PyTorch standard and choose fixed float limits as defaults (but adapted to Heat's float32 default datatype), or do we want to proceed differently and let the tolerances be chosen automatically by default, depending on the input datatypes?
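To make the trade-off concrete, here is a small illustration (using NumPy's `allclose` so it is self-contained; the thresholds are the diff's defaults, not a claim about Heat's final choice). Two float32 values whose relative difference sits between the fixed default `rtol=1e-05` and the float32-adjusted `2e-05` compare as unequal under the former but equal under the latter:

```python
import numpy as np

a = np.float32(1.0)
b = np.float32(1.0 + 1.5e-5)  # relative difference ~1.5e-5

# With the fixed float64-style default, the comparison fails...
strict = np.allclose(a, b, rtol=1e-05, atol=0.0)  # False
# ...but with the float32-adjusted rtol = 2e-05 it passes.
loose = np.allclose(a, b, rtol=2e-05, atol=0.0)   # True
```

This is exactly the class of float32 comparison the PR intends to accept by default, and also what a caller who relies on the NumPy-standard defaults would see change.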