Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-40528: Implement a metadata system for ASDL Generator #20193

Merged
merged 1 commit into from
Jun 22, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 75 additions & 24 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import sys
import textwrap
import types

from argparse import ArgumentParser
from contextlib import contextmanager
Expand Down Expand Up @@ -100,11 +101,12 @@ def asdl_of(name, obj):
class EmitVisitor(asdl.VisitorBase):
"""Visit that emits lines"""

def __init__(self, file):
def __init__(self, file, metadata = None):
self.file = file
self.identifiers = set()
self.singletons = set()
self.types = set()
self._metadata = metadata
super(EmitVisitor, self).__init__()

def emit_identifier(self, name):
Expand All @@ -127,6 +129,42 @@ def emit(self, s, depth, reflow=True):
line = (" " * TABSIZE * depth) + line
self.file.write(line + "\n")

@property
def metadata(self):
if self._metadata is None:
raise ValueError(
"%s was expecting to be annnotated with metadata"
% type(self).__name__
)
return self._metadata

@metadata.setter
def metadata(self, value):
self._metadata = value

class MetadataVisitor(asdl.VisitorBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Metadata:
# - simple_sums: Tracks the list of compound type
# names where all the constructors
# belonging to that type lack of any
# fields.
self.metadata = types.SimpleNamespace(
simple_sums=set()
)

def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)

def visitType(self, type):
self.visit(type.value, type.name)

def visitSum(self, sum, name):
if is_simple(sum):
self.metadata.simple_sums.add(name)

class TypeDefVisitor(EmitVisitor):
def visitModule(self, mod):
Expand Down Expand Up @@ -244,7 +282,7 @@ def visitField(self, field, depth):
ctype = get_c_type(field.type)
name = field.name
if field.seq:
if field.type == 'cmpop':
if field.type in self.metadata.simple_sums:
self.emit("asdl_int_seq *%(name)s;" % locals(), depth)
else:
_type = field.type
Expand Down Expand Up @@ -304,7 +342,7 @@ def get_args(self, fields):
name = f.name
# XXX should extend get_c_type() to handle this
if f.seq:
if f.type == 'cmpop':
if f.type in self.metadata.simple_sums:
ctype = "asdl_int_seq *"
else:
ctype = f"asdl_{f.type}_seq *"
Expand Down Expand Up @@ -549,16 +587,11 @@ def visitFieldDeclaration(self, field, name, sum=None, prod=None, depth=0):
ctype = get_c_type(field.type)
self.emit("%s %s;" % (ctype, field.name), depth)

def isSimpleSum(self, field):
# XXX can the members of this list be determined automatically?
return field.type in ('expr_context', 'boolop', 'operator',
'unaryop', 'cmpop')

def isNumeric(self, field):
return get_c_type(field.type) in ("int", "bool")

def isSimpleType(self, field):
return self.isSimpleSum(field) or self.isNumeric(field)
return field.type in self.metadata.simple_sums or self.isNumeric(field)

def visitField(self, field, name, sum=None, prod=None, depth=0):
ctype = get_c_type(field.type)
Expand Down Expand Up @@ -1282,18 +1315,23 @@ def emit(s, d):

def set(self, field, value, depth):
if field.seq:
# XXX should really check for is_simple, but that requires a symbol table
if field.type == "cmpop":
if field.type in self.metadata.simple_sums:
# While the sequence elements are stored as void*,
# ast2obj_cmpop expects an enum
# simple sums expects an enum
self.emit("{", depth)
self.emit("Py_ssize_t i, n = asdl_seq_LEN(%s);" % value, depth+1)
self.emit("value = PyList_New(n);", depth+1)
self.emit("if (!value) goto failed;", depth+1)
self.emit("for(i = 0; i < n; i++)", depth+1)
# This cannot fail, so no need for error handling
self.emit("PyList_SET_ITEM(value, i, ast2obj_cmpop(state, (cmpop_ty)asdl_seq_GET(%s, i)));" % value,
depth+2, reflow=False)
self.emit(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
field.type,
value
),
depth + 2,
reflow=False,
)
self.emit("}", depth)
else:
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
Expand Down Expand Up @@ -1362,11 +1400,13 @@ class PartingShots(StaticVisitor):
"""

class ChainOfVisitors:
def __init__(self, *visitors):
def __init__(self, *visitors, metadata = None):
self.visitors = visitors
self.metadata = metadata

def visit(self, object):
for v in self.visitors:
v.metadata = self.metadata
v.visit(object)
v.emit("", 0)

Expand Down Expand Up @@ -1468,7 +1508,7 @@ def generate_module_def(mod, f, internal_h):
f.write(' return 1;\n')
f.write('};\n\n')

def write_header(mod, f):
def write_header(mod, metadata, f):
f.write(textwrap.dedent("""
#ifndef Py_INTERNAL_AST_H
#define Py_INTERNAL_AST_H
Expand All @@ -1483,12 +1523,19 @@ def write_header(mod, f):
#include "pycore_asdl.h"

""").lstrip())
c = ChainOfVisitors(TypeDefVisitor(f),
SequenceDefVisitor(f),
StructVisitor(f))

c = ChainOfVisitors(
TypeDefVisitor(f),
SequenceDefVisitor(f),
StructVisitor(f),
metadata=metadata
)
c.visit(mod)

f.write("// Note: these macros affect function definitions, not only call sites.\n")
PrototypeVisitor(f).visit(mod)
prototype_visitor = PrototypeVisitor(f, metadata=metadata)
prototype_visitor.visit(mod)

f.write(textwrap.dedent("""

PyObject* PyAST_mod2obj(mod_ty t);
Expand Down Expand Up @@ -1535,8 +1582,7 @@ def write_internal_h_footer(mod, f):
#endif /* !Py_INTERNAL_AST_STATE_H */
"""), file=f)


def write_source(mod, f, internal_h_file):
def write_source(mod, metadata, f, internal_h_file):
generate_module_def(mod, f, internal_h_file)

v = ChainOfVisitors(
Expand All @@ -1549,6 +1595,7 @@ def write_source(mod, f, internal_h_file):
Obj2ModVisitor(f),
ASTModuleVisitor(f),
PartingShots(f),
metadata=metadata
)
v.visit(mod)

Expand All @@ -1561,6 +1608,10 @@ def main(input_filename, c_filename, h_filename, internal_h_filename, dump_modul
if not asdl.check(mod):
sys.exit(1)

metadata_visitor = MetadataVisitor()
metadata_visitor.visit(mod)
metadata = metadata_visitor.metadata

with c_filename.open("w") as c_file, \
h_filename.open("w") as h_file, \
internal_h_filename.open("w") as internal_h_file:
Expand All @@ -1569,8 +1620,8 @@ def main(input_filename, c_filename, h_filename, internal_h_filename, dump_modul
internal_h_file.write(auto_gen_msg)

write_internal_h_header(mod, internal_h_file)
write_source(mod, c_file, internal_h_file)
write_header(mod, h_file)
write_source(mod, metadata, c_file, internal_h_file)
write_header(mod, metadata, h_file)
write_internal_h_footer(mod, internal_h_file)

print(f"{c_filename}, {h_filename}, {internal_h_filename} regenerated.")
Expand Down