diff --git a/greenlet.c b/greenlet.c index d88c0ff2..16a0bd6c 100644 --- a/greenlet.c +++ b/greenlet.c @@ -507,9 +507,14 @@ g_switch(PyGreenlet* target, PyObject* args, PyObject* kwargs) break; } if (!PyGreenlet_STARTED(target)) { + int err; void* dummymarker; ts_target = target; - if (g_initialstub(&dummymarker) < 0) { + err = g_initialstub(&dummymarker); + if (err == 1) { + continue; /* retry the switch */ + } + else if (err < 0) { g_passaround_clear(); return NULL; } @@ -590,11 +595,14 @@ static int GREENLET_NOINLINE(g_initialstub)(void* mark) int err; PyObject* o; PyObject *exc, *val, *tb; + PyGreenlet* self = ts_target; + PyObject* args = ts_passaround_args; + PyObject* kwargs = ts_passaround_kwargs; /* save exception in case getattr clears it */ PyErr_Fetch(&exc, &val, &tb); - /* ts_target.run is the object to call in the new greenlet */ - PyObject* run = PyObject_GetAttrString((PyObject*) ts_target, "run"); + /* self.run is the object to call in the new greenlet */ + PyObject* run = PyObject_GetAttrString((PyObject*) self, "run"); if (run == NULL) { Py_XDECREF(exc); Py_XDECREF(val); @@ -603,27 +611,48 @@ static int GREENLET_NOINLINE(g_initialstub)(void* mark) } /* restore saved exception */ PyErr_Restore(exc, val, tb); + + /* recheck the state in case getattr caused thread switches */ + if (!STATE_OK) { + Py_DECREF(run); + return -1; + } + + /* by the time we got here another start could happen elsewhere, + * that means it should now be a regular switch + */ + if (PyGreenlet_STARTED(self)) { + Py_DECREF(run); + ts_passaround_args = args; + ts_passaround_kwargs = kwargs; + return 1; + } + /* restore arguments in case they are clobbered */ + ts_target = self; + ts_passaround_args = args; + ts_passaround_kwargs = kwargs; + /* now use run_info to store the statedict */ - o = ts_target->run_info; - ts_target->run_info = green_statedict(ts_target->parent); - Py_INCREF(ts_target->run_info); + o = self->run_info; + self->run_info = green_statedict(self->parent); + Py_INCREF(self->run_info); Py_XDECREF(o); /* start the greenlet */ - ts_target->stack_start = NULL; - ts_target->stack_stop = (char*) mark; + self->stack_start = NULL; + self->stack_stop = (char*) mark; if (ts_current->stack_start == NULL) { /* ts_current is dying */ - ts_target->stack_prev = ts_current->stack_prev; + self->stack_prev = ts_current->stack_prev; } else { - ts_target->stack_prev = ts_current; + self->stack_prev = ts_current; } - ts_target->top_frame = NULL; - ts_target->exc_type = NULL; - ts_target->exc_value = NULL; - ts_target->exc_traceback = NULL; - ts_target->recursion_depth = PyThreadState_GET()->recursion_depth; + self->top_frame = NULL; + self->exc_type = NULL; + self->exc_value = NULL; + self->exc_traceback = NULL; + self->recursion_depth = PyThreadState_GET()->recursion_depth; err = g_switchstack(); /* returns twice! The 1st time with err=1: we are in the new greenlet @@ -631,15 +660,10 @@ static int GREENLET_NOINLINE(g_initialstub)(void* mark) */ if (err == 1) { /* in the new greenlet */ - PyObject* args; - PyObject* kwargs; PyObject* result; PyGreenlet* parent; - PyGreenlet* ts_self = ts_current; - ts_self->stack_start = (char*) 1; /* running */ + self->stack_start = (char*) 1; /* running */ - args = ts_passaround_args; - kwargs = ts_passaround_kwargs; if (args == NULL) { /* pending exception */ result = NULL; @@ -654,8 +678,8 @@ static int GREENLET_NOINLINE(g_initialstub)(void* mark) result = g_handle_exit(result); /* jump back to parent */ - ts_self->stack_start = NULL; /* dead */ - for (parent = ts_self->parent; parent != NULL; parent = parent->parent) { + self->stack_start = NULL; /* dead */ + for (parent = self->parent; parent != NULL; parent = parent->parent) { result = g_switch(parent, result, NULL); /* Return here means switch to parent failed, * in which case we throw *current* exception @@ -663,7 +687,7 @@ static int GREENLET_NOINLINE(g_initialstub)(void* mark) */ } /* We ran out of parents, cannot continue */ - PyErr_WriteUnraisable((PyObject *) ts_self); + PyErr_WriteUnraisable((PyObject *) self); Py_FatalError("greenlets cannot continue"); } /* back in the parent */ diff --git a/tests/test_greenlet.py b/tests/test_greenlet.py index d151d00b..a467c88e 100644 --- a/tests/test_greenlet.py +++ b/tests/test_greenlet.py @@ -344,3 +344,19 @@ def creator(): t.start() t.join() self.assertRaises(greenlet.error, result[0].throw, SomeError()) + + def test_recursive_startup(self): + class convoluted(greenlet): + def __init__(self): + greenlet.__init__(self) + self.count = 0 + def __getattribute__(self, name): + if name == 'run' and self.count == 0: + self.count = 1 + self.switch(43) + return greenlet.__getattribute__(self, name) + def run(self, value): + while True: + self.parent.switch(value) + g = convoluted() + self.assertEqual(g.switch(42), 43)