@@ -688,10 +688,14 @@ def visitModule(self, mod):
688688static int
689689ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
690690{
691+ astmodulestate *state = get_global_ast_state();
692+ if (state == NULL) {
693+ return -1;
694+ }
695+
691696 Py_ssize_t i, numfields = 0;
692697 int res = -1;
693698 PyObject *key, *value, *fields;
694- astmodulestate *state = get_global_ast_state();
695699 if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
696700 goto cleanup;
697701 }
@@ -761,6 +765,10 @@ def visitModule(self, mod):
761765ast_type_reduce(PyObject *self, PyObject *unused)
762766{
763767 astmodulestate *state = get_global_ast_state();
768+ if (state == NULL) {
769+ return NULL;
770+ }
771+
764772 PyObject *dict;
765773 if (_PyObject_LookupAttr(self, state->__dict__, &dict) < 0) {
766774 return NULL;
@@ -969,9 +977,8 @@ def visitModule(self, mod):
969977
970978""" , 0 , reflow = False )
971979
972- self .emit ("static int init_types(void )" ,0 )
980+ self .emit ("static int init_types(astmodulestate *state )" ,0 )
973981 self .emit ("{" , 0 )
974- self .emit ("astmodulestate *state = get_global_ast_state();" , 1 )
975982 self .emit ("if (state->initialized) return 1;" , 1 )
976983 self .emit ("if (init_identifiers(state) < 0) return 0;" , 1 )
977984 self .emit ("state->AST_type = PyType_FromSpec(&AST_type_spec);" , 1 )
@@ -1046,40 +1053,55 @@ def emit_defaults(self, name, fields, depth):
10461053class ASTModuleVisitor (PickleVisitor ):
10471054
10481055 def visitModule (self , mod ):
1049- self .emit ("PyMODINIT_FUNC " , 0 )
1050- self .emit ("PyInit__ast(void )" , 0 )
1056+ self .emit ("static int " , 0 )
1057+ self .emit ("astmodule_exec(PyObject *m )" , 0 )
10511058 self .emit ("{" , 0 )
1052- self .emit ("PyObject *m = PyModule_Create(&_astmodule);" , 1 )
1053- self .emit ("if (!m) {" , 1 )
1054- self .emit ("return NULL;" , 2 )
1055- self .emit ("}" , 1 )
10561059 self .emit ('astmodulestate *state = get_ast_state(m);' , 1 )
1057- self .emit ('' , 1 )
1060+ self .emit ("" , 0 )
10581061
1059- self .emit ("if (!init_types()) {" , 1 )
1060- self .emit ("goto error ;" , 2 )
1062+ self .emit ("if (!init_types(state )) {" , 1 )
1063+ self .emit ("return -1 ;" , 2 )
10611064 self .emit ("}" , 1 )
10621065 self .emit ('if (PyModule_AddObject(m, "AST", state->AST_type) < 0) {' , 1 )
1063- self .emit ('goto error ;' , 2 )
1066+ self .emit ('return -1 ;' , 2 )
10641067 self .emit ('}' , 1 )
10651068 self .emit ('Py_INCREF(state->AST_type);' , 1 )
10661069 self .emit ('if (PyModule_AddIntMacro(m, PyCF_ALLOW_TOP_LEVEL_AWAIT) < 0) {' , 1 )
1067- self .emit ("goto error ;" , 2 )
1070+ self .emit ("return -1 ;" , 2 )
10681071 self .emit ('}' , 1 )
10691072 self .emit ('if (PyModule_AddIntMacro(m, PyCF_ONLY_AST) < 0) {' , 1 )
1070- self .emit ("goto error ;" , 2 )
1073+ self .emit ("return -1 ;" , 2 )
10711074 self .emit ('}' , 1 )
10721075 self .emit ('if (PyModule_AddIntMacro(m, PyCF_TYPE_COMMENTS) < 0) {' , 1 )
1073- self .emit ("goto error ;" , 2 )
1076+ self .emit ("return -1 ;" , 2 )
10741077 self .emit ('}' , 1 )
10751078 for dfn in mod .dfns :
10761079 self .visit (dfn )
1077- self .emit ("return m;" , 1 )
1078- self .emit ("" , 0 )
1079- self .emit ("error:" , 0 )
1080- self .emit ("Py_DECREF(m);" , 1 )
1081- self .emit ("return NULL;" , 1 )
1080+ self .emit ("return 0;" , 1 )
10821081 self .emit ("}" , 0 )
1082+ self .emit ("" , 0 )
1083+ self .emit ("""
1084+ static PyModuleDef_Slot astmodule_slots[] = {
1085+ {Py_mod_exec, astmodule_exec},
1086+ {0, NULL}
1087+ };
1088+
1089+ static struct PyModuleDef _astmodule = {
1090+ PyModuleDef_HEAD_INIT,
1091+ .m_name = "_ast",
1092+ .m_size = sizeof(astmodulestate),
1093+ .m_slots = astmodule_slots,
1094+ .m_traverse = astmodule_traverse,
1095+ .m_clear = astmodule_clear,
1096+ .m_free = astmodule_free,
1097+ };
1098+
1099+ PyMODINIT_FUNC
1100+ PyInit__ast(void)
1101+ {
1102+ return PyModuleDef_Init(&_astmodule);
1103+ }
1104+ """ .strip (), 0 , reflow = False )
10831105
10841106 def visitProduct (self , prod , name ):
10851107 self .addObj (name )
@@ -1095,7 +1117,7 @@ def visitConstructor(self, cons, name):
10951117 def addObj (self , name ):
10961118 self .emit ("if (PyModule_AddObject(m, \" %s\" , "
10971119 "state->%s_type) < 0) {" % (name , name ), 1 )
1098- self .emit ("goto error ;" , 2 )
1120+ self .emit ("return -1 ;" , 2 )
10991121 self .emit ('}' , 1 )
11001122 self .emit ("Py_INCREF(state->%s_type);" % name , 1 )
11011123
@@ -1255,11 +1277,10 @@ class PartingShots(StaticVisitor):
12551277 CODE = """
12561278PyObject* PyAST_mod2obj(mod_ty t)
12571279{
1258- if (!init_types()) {
1280+ astmodulestate *state = get_global_ast_state();
1281+ if (state == NULL) {
12591282 return NULL;
12601283 }
1261-
1262- astmodulestate *state = get_global_ast_state();
12631284 return ast2obj_mod(state, t);
12641285}
12651286
@@ -1281,10 +1302,6 @@ class PartingShots(StaticVisitor):
12811302
12821303 assert(0 <= mode && mode <= 2);
12831304
1284- if (!init_types()) {
1285- return NULL;
1286- }
1287-
12881305 isinstance = PyObject_IsInstance(ast, req_type[mode]);
12891306 if (isinstance == -1)
12901307 return NULL;
@@ -1303,11 +1320,10 @@ class PartingShots(StaticVisitor):
13031320
13041321int PyAST_Check(PyObject* obj)
13051322{
1306- if (!init_types()) {
1323+ astmodulestate *state = get_global_ast_state();
1324+ if (state == NULL) {
13071325 return -1;
13081326 }
1309-
1310- astmodulestate *state = get_global_ast_state();
13111327 return PyObject_IsInstance(obj, state->AST_type);
13121328}
13131329"""
@@ -1358,12 +1374,35 @@ def generate_module_def(f, mod):
13581374 f .write (' PyObject *' + s + ';\n ' )
13591375 f .write ('} astmodulestate;\n \n ' )
13601376 f .write ("""
1361- static astmodulestate global_ast_state;
1377+ static astmodulestate*
1378+ get_ast_state(PyObject *module)
1379+ {
1380+ void *state = PyModule_GetState(module);
1381+ assert(state != NULL);
1382+ return (astmodulestate*)state;
1383+ }
13621384
1363- static astmodulestate *
1364- get_ast_state(PyObject *Py_UNUSED(module) )
1385+ static astmodulestate*
1386+ get_global_ast_state(void )
13651387{
1366- return &global_ast_state;
1388+ _Py_IDENTIFIER(_ast);
1389+ PyObject *name = _PyUnicode_FromId(&PyId__ast); // borrowed reference
1390+ if (name == NULL) {
1391+ return NULL;
1392+ }
1393+ PyObject *module = PyImport_GetModule(name);
1394+ if (module == NULL) {
1395+ if (PyErr_Occurred()) {
1396+ return NULL;
1397+ }
1398+ module = PyImport_Import(name);
1399+ if (module == NULL) {
1400+ return NULL;
1401+ }
1402+ }
1403+ astmodulestate *state = get_ast_state(module);
1404+ Py_DECREF(module);
1405+ return state;
13671406}
13681407
13691408static int astmodule_clear(PyObject *module)
@@ -1390,17 +1429,6 @@ def generate_module_def(f, mod):
13901429 astmodule_clear((PyObject*)module);
13911430}
13921431
1393- static struct PyModuleDef _astmodule = {
1394- PyModuleDef_HEAD_INIT,
1395- .m_name = "_ast",
1396- .m_size = -1,
1397- .m_traverse = astmodule_traverse,
1398- .m_clear = astmodule_clear,
1399- .m_free = astmodule_free,
1400- };
1401-
1402- #define get_global_ast_state() (&global_ast_state)
1403-
14041432""" )
14051433 f .write ('static int init_identifiers(astmodulestate *state)\n ' )
14061434 f .write ('{\n ' )
0 commit comments