Skip to content

Commit 2197797

Browse files
Merge pull request sympy#26948 from oscarbenjamin/pr_zero_comparison_113
fix(core): make comparisons of Integer(0) and Float(0) consistent.
2 parents 863fc64 + 9d6100a commit 2197797

File tree

10 files changed

+34
-21
lines changed

10 files changed

+34
-21
lines changed

sympy/core/evalf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,15 +1494,15 @@ def evalf(x: 'Expr', prec: int, options: OPT_DICT) -> TMP_RES:
14941494
re, im = as_real_imag()
14951495
if re.has(re_) or im.has(im_):
14961496
raise NotImplementedError
1497-
if re == 0.0:
1497+
if not re:
14981498
re = None
14991499
reprec = None
15001500
elif re.is_number:
15011501
re = re._to_mpmath(prec, allow_ints=False)._mpf_
15021502
reprec = prec
15031503
else:
15041504
raise NotImplementedError
1505-
if im == 0.0:
1505+
if not im:
15061506
im = None
15071507
imprec = None
15081508
elif im.is_number:

sympy/core/numbers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,8 +1597,6 @@ def __eq__(self, other):
15971597
# S(0) == S.false is False
15981598
# S(0) == False is True
15991599
return False
1600-
if not self:
1601-
return not other
16021600
if other.is_NumberSymbol:
16031601
if other.is_irrational:
16041602
return False

sympy/core/tests/test_numbers.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -455,13 +455,27 @@ def eq(a, b):
455455
t = Float("1.0E-15")
456456
return (-t < a - b < t)
457457

458-
zeros = (0, S.Zero, 0., Float(0))
459-
for i, j in permutations(zeros[:-1], 2):
460-
assert i == j
461-
for i, j in permutations(zeros[-2:], 2):
462-
assert i == j
463-
for z in zeros:
464-
assert z in zeros
458+
equal_pairs = [
459+
(0, 0.0), # This is just how Python works...
460+
(0, S.Zero),
461+
(0.0, Float(0)),
462+
]
463+
unequal_pairs = [
464+
(0.0, S.Zero),
465+
(0, Float(0)),
466+
(S.Zero, Float(0)),
467+
]
468+
for p1, p2 in equal_pairs:
469+
assert (p1 == p2) is True
470+
assert (p1 != p2) is False
471+
assert (p2 == p1) is True
472+
assert (p2 != p1) is False
473+
for p1, p2 in unequal_pairs:
474+
assert (p1 == p2) is False
475+
assert (p1 != p2) is True
476+
assert (p2 == p1) is False
477+
assert (p2 != p1) is True
478+
465479
assert S.Zero.is_zero
466480

467481
a = Float(2) ** Float(3)

sympy/geometry/tests/test_point.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def test_arguments():
418418
a = Point(0, 1)
419419
assert a/10.0 == Point(0, 0.1, evaluate=False)
420420
a = Point(0, 1)
421-
assert a*10.0 == Point(0.0, 10.0, evaluate=False)
421+
assert a*10.0 == Point(0, 10.0, evaluate=False)
422422

423423
# test evaluate=False when changing dimensions
424424
u = Point(.1, .2, evaluate=False)

sympy/integrals/tests/test_integrals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2080,7 +2080,7 @@ def test_issue_20782():
20802080
assert integrate(fun1, L) == 1
20812081
assert integrate(fun2, L) == 0
20822082
assert integrate(-fun1, L) == -1
2083-
assert integrate(-fun2, L) == 0.
2083+
assert integrate(-fun2, L) == 0
20842084
assert integrate(fun_sum, L) == 1.
20852085
assert integrate(-fun_sum, L) == -1.
20862086

sympy/physics/quantum/qubit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def matrix_to_qubit(matrix):
493493
element = matrix[0, i]
494494
if format in ('numpy', 'scipy.sparse'):
495495
element = complex(element)
496-
if element != 0.0:
496+
if element:
497497
# Form Qubit array; 0 in bit-locations where i is 0, 1 in
498498
# bit-locations where i is 1
499499
qubit_array = [int(i & (1 << x) != 0) for x in range(nqubits)]
@@ -582,7 +582,7 @@ def measure_all(qubit, format='sympy', normalize=True):
582582
size = max(m.shape) # Max of shape to account for bra or ket
583583
nqubits = int(math.log(size)/math.log(2))
584584
for i in range(size):
585-
if m[i] != 0.0:
585+
if m[i]:
586586
results.append(
587587
(Qubit(IntQubit(i, nqubits=nqubits)), m[i]*conjugate(m[i]))
588588
)

sympy/polys/matrices/tests/test_linsolve.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def test__linsolve_float():
3232
y - x,
3333
y - 0.0216 * x
3434
]
35-
sol = {x:0.0, y:0.0}
35+
# Should _linsolve return floats here?
36+
sol = {x:0, y:0}
3637
assert _linsolve(eqs, (x, y)) == sol
3738

3839
# Other cases should be close to eps

sympy/polys/tests/test_polytools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3033,7 +3033,7 @@ def test_nroots():
30333033
eps = Float("1e-5")
30343034

30353035
assert re(roots[0]).epsilon_eq(-0.75487, eps) is S.true
3036-
assert im(roots[0]) == 0.0
3036+
assert im(roots[0]) == 0
30373037
assert re(roots[1]) == Float(-0.5, 5)
30383038
assert im(roots[1]).epsilon_eq(-0.86602, eps) is S.true
30393039
assert re(roots[2]) == Float(-0.5, 5)
@@ -3046,7 +3046,7 @@ def test_nroots():
30463046
eps = Float("1e-6")
30473047

30483048
assert re(roots[0]).epsilon_eq(-0.75487, eps) is S.false
3049-
assert im(roots[0]) == 0.0
3049+
assert im(roots[0]) == 0
30503050
assert re(roots[1]) == Float(-0.5, 5)
30513051
assert im(roots[1]).epsilon_eq(-0.86602, eps) is S.false
30523052
assert re(roots[2]) == Float(-0.5, 5)

sympy/solvers/tests/test_numeric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@ def getroot(x0):
7373

7474
def test_issue_6408():
7575
x = Symbol('x')
76-
assert nsolve(Piecewise((x, x < 1), (x**2, True)), x, 2) == 0.0
76+
assert nsolve(Piecewise((x, x < 1), (x**2, True)), x, 2) == 0
7777

7878

7979
def test_issue_6408_integral():
8080
x, y = symbols('x y')
81-
assert nsolve(Integral(x*y, (x, 0, 5)), y, 2) == 0.0
81+
assert nsolve(Integral(x*y, (x, 0, 5)), y, 2) == 0
8282

8383

8484
@conserve_mpmath_dps

sympy/utilities/tests/test_wester.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def test_C24():
269269

270270

271271
def test_D1():
272-
assert 0.0 / sqrt(2) == 0.0
272+
assert 0.0 / sqrt(2) == 0
273273

274274

275275
def test_D2():

0 commit comments

Comments
 (0)