Skip to content

bpo-15987: Add ast.AST richcompare methods #1375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import weakref

from test import support
from test.support import findfile


def to_tuple(t):
if t is None or isinstance(t, (str, int, complex)):
Expand Down Expand Up @@ -432,6 +434,80 @@ def test_empty_yield_from(self):
self.assertIn("field value is required", str(cm.exception))


class ASTCompareTest(unittest.TestCase):
def setUp(self):
import imp
imp.reload(ast)

def test_normal_compare(self):
self.assertEqual(ast.parse('x = 10'), ast.parse('x = 10'))
self.assertNotEqual(ast.parse('x = 10'), ast.parse(''))
self.assertNotEqual(ast.parse('x = 10'), ast.parse('x'))
self.assertNotEqual(ast.parse('x = 10;y = 20'), ast.parse('class C:pass'))

def test_literals_compare(self):
self.assertEqual(ast.Num(), ast.Num())
self.assertEqual(ast.Num(-20), ast.Num(-20))
self.assertEqual(ast.Num(10), ast.Num(10))
self.assertEqual(ast.Num(2048), ast.Num(2048))
self.assertEqual(ast.Str(), ast.Str())
self.assertEqual(ast.Str("ABCD"), ast.Str("ABCD"))
self.assertEqual(ast.Str("中文字"), ast.Str("中文字"))

self.assertNotEqual(ast.Num(10), ast.Num(20))
self.assertNotEqual(ast.Num(-10), ast.Num(10))
self.assertNotEqual(ast.Str("AAAA"), ast.Str("BBBB"))
self.assertNotEqual(ast.Str("一二三"), ast.Str("中文字"))

self.assertNotEqual(ast.Num(10), ast.Num())
self.assertNotEqual(ast.Str("AB"), ast.Str())

def test_operator_compare(self):
self.assertEqual(ast.Add(), ast.Add())
self.assertEqual(ast.Sub(), ast.Sub())

self.assertNotEqual(ast.Add(), ast.Sub())
self.assertNotEqual(ast.Add(), ast.Num())

def test_complex_ast(self):
fps = [findfile('test_asyncgen.py'),
findfile('test_generators.py'),
findfile('test_unicode.py')]

for fp in fps:
with open(fp) as f:
try:
source = f.read()
except UnicodeDecodeError:
continue

a = ast.parse(source)
b = ast.parse(source)
self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b)))
self.assertFalse(a != b)

def test_exec_compare(self):
for source in exec_tests:
a = ast.parse(source, mode='exec')
b = ast.parse(source, mode='exec')
self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b)))
self.assertFalse(a != b)

def test_single_compare(self):
for source in single_tests:
a = ast.parse(source, mode='single')
b = ast.parse(source, mode='single')
self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b)))
self.assertFalse(a != b)

def test_eval_compare(self):
for source in eval_tests:
a = ast.parse(source, mode='eval')
b = ast.parse(source, mode='eval')
self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b)))
self.assertFalse(a != b)


class ASTHelpers_Test(unittest.TestCase):

def test_parse(self):
Expand Down
59 changes: 58 additions & 1 deletion Python/Python-ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,63 @@ ast_type_reduce(PyObject *self, PyObject *unused)
return Py_BuildValue("O()", Py_TYPE(self));
}

static PyObject *
ast_richcompare(PyObject *self, PyObject *other, int op)
{
int i, len;
PyObject *fields, *key, *a = Py_None, *b = Py_None;

/* Check operator */
if ((op != Py_EQ && op != Py_NE) ||
!PyAST_Check(self) ||
!PyAST_Check(other)) {
Py_RETURN_NOTIMPLEMENTED;
}

/* Compare types */
if (Py_TYPE(self) != Py_TYPE(other)) {
if (op == Py_EQ)
Py_RETURN_FALSE;
else
Py_RETURN_TRUE;
}

/* Compare fields */
fields = PyObject_GetAttrString(self, "_fields");
len = PySequence_Size(fields);
for (i = 0; i < len; ++i) {
key = PySequence_GetItem(fields, i);

if (PyObject_HasAttr(self, key))
a = PyObject_GetAttr(self, key);
if (PyObject_HasAttr(other, key))
b = PyObject_GetAttr(other, key);


if (Py_TYPE(a) != Py_TYPE(b)) {
if (op == Py_EQ) {
Py_RETURN_FALSE;
}
}

if (op == Py_EQ) {
if (!PyObject_RichCompareBool(a, b, Py_EQ)) {
Py_RETURN_FALSE;
}
}
else if (op == Py_NE) {
if (PyObject_RichCompareBool(a, b, Py_NE)) {
Py_RETURN_TRUE;
}
}
}

if (op == Py_EQ)
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
}

static PyMethodDef ast_type_methods[] = {
{"__reduce__", ast_type_reduce, METH_NOARGS, NULL},
{NULL}
Expand Down Expand Up @@ -641,7 +698,7 @@ static PyTypeObject AST_type = {
0, /* tp_doc */
(traverseproc)ast_traverse, /* tp_traverse */
(inquiry)ast_clear, /* tp_clear */
0, /* tp_richcompare */
ast_richcompare, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
Expand Down