Skip to content

gh-111178: fix UBSan failures in Modules/_functoolsmodule.c #129778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 21, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 46 additions & 29 deletions Modules/_functoolsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ typedef struct {
} partialobject;

// cast a PyObject pointer PTR to a partialobject pointer (no type checks)
#define _PyPartialObject_CAST(PTR) ((partialobject *)(PTR))
#define partialobject_CAST(op) ((partialobject *)(op))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There' no need to change existing macros, especially ones with the _Py prefix that's de-facto reserved for Python.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Err actually this is something I probably wrote recently so I thought it would have been better to make it consistent at least. I can revert it though (I think that's my code because there is a PTR and this is something I used before Victor told me to use "op")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, right, it's from #124733 from October


static void partial_setvectorcall(partialobject *pto);
static struct PyModuleDef _functools_module;
Expand Down Expand Up @@ -312,7 +312,7 @@ partial_new(PyTypeObject *type, PyObject *args, PyObject *kw)
static int
partial_clear(PyObject *self)
{
partialobject *pto = _PyPartialObject_CAST(self);
partialobject *pto = partialobject_CAST(self);
Py_CLEAR(pto->fn);
Py_CLEAR(pto->args);
Py_CLEAR(pto->kw);
Expand All @@ -323,7 +323,7 @@ partial_clear(PyObject *self)
static int
partial_traverse(PyObject *self, visitproc visit, void *arg)
{
partialobject *pto = _PyPartialObject_CAST(self);
partialobject *pto = partialobject_CAST(self);
Py_VISIT(Py_TYPE(pto));
Py_VISIT(pto->fn);
Py_VISIT(pto->args);
Expand All @@ -338,7 +338,7 @@ partial_dealloc(PyObject *self)
PyTypeObject *tp = Py_TYPE(self);
/* bpo-31095: UnTrack is needed before calling any callbacks */
PyObject_GC_UnTrack(self);
if (_PyPartialObject_CAST(self)->weakreflist != NULL) {
if (partialobject_CAST(self)->weakreflist != NULL) {
PyObject_ClearWeakRefs(self);
}
(void)partial_clear(self);
Expand Down Expand Up @@ -372,7 +372,7 @@ static PyObject *
partial_vectorcall(PyObject *self, PyObject *const *args,
size_t nargsf, PyObject *kwnames)
{
partialobject *pto = _PyPartialObject_CAST(self);;
partialobject *pto = partialobject_CAST(self);;
PyThreadState *tstate = _PyThreadState_GET();
Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);

Expand Down Expand Up @@ -482,7 +482,7 @@ partial_setvectorcall(partialobject *pto)
static PyObject *
partial_call(PyObject *self, PyObject *args, PyObject *kwargs)
{
partialobject *pto = _PyPartialObject_CAST(self);
partialobject *pto = partialobject_CAST(self);
assert(PyCallable_Check(pto->fn));
assert(PyTuple_Check(pto->args));
assert(PyDict_Check(pto->kw));
Expand Down Expand Up @@ -595,7 +595,7 @@ static PyGetSetDef partial_getsetlist[] = {
static PyObject *
partial_repr(PyObject *self)
{
partialobject *pto = _PyPartialObject_CAST(self);
partialobject *pto = partialobject_CAST(self);
PyObject *result = NULL;
PyObject *arglist;
PyObject *mod;
Expand Down Expand Up @@ -668,7 +668,7 @@ partial_repr(PyObject *self)
static PyObject *
partial_reduce(PyObject *self, PyObject *Py_UNUSED(args))
{
partialobject *pto = _PyPartialObject_CAST(self);
partialobject *pto = partialobject_CAST(self);
return Py_BuildValue("O(O)(OOOO)", Py_TYPE(pto), pto->fn, pto->fn,
pto->args, pto->kw,
pto->dict ? pto->dict : Py_None);
Expand All @@ -677,7 +677,7 @@ partial_reduce(PyObject *self, PyObject *Py_UNUSED(args))
static PyObject *
partial_setstate(PyObject *self, PyObject *state)
{
partialobject *pto = _PyPartialObject_CAST(self);
partialobject *pto = partialobject_CAST(self);
PyObject *fn, *fnargs, *kw, *dict;

if (!PyTuple_Check(state)) {
Expand Down Expand Up @@ -782,16 +782,19 @@ typedef struct {
PyObject *object;
} keyobject;

#define keyobject_CAST(op) ((keyobject *)(op))

static int
keyobject_clear(keyobject *ko)
keyobject_clear(PyObject *op)
{
keyobject *ko = keyobject_CAST(op);
Py_CLEAR(ko->cmp);
Py_CLEAR(ko->object);
return 0;
}

static void
keyobject_dealloc(keyobject *ko)
keyobject_dealloc(PyObject *ko)
{
PyTypeObject *tp = Py_TYPE(ko);
PyObject_GC_UnTrack(ko);
Expand All @@ -801,8 +804,9 @@ keyobject_dealloc(keyobject *ko)
}

static int
keyobject_traverse(keyobject *ko, visitproc visit, void *arg)
keyobject_traverse(PyObject *op, visitproc visit, void *arg)
{
keyobject *ko = keyobject_CAST(op);
Py_VISIT(Py_TYPE(ko));
Py_VISIT(ko->cmp);
Py_VISIT(ko->object);
Expand All @@ -817,18 +821,18 @@ static PyMemberDef keyobject_members[] = {
};

static PyObject *
keyobject_text_signature(PyObject *self, void *Py_UNUSED(ignored))
keyobject_text_signature(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
return PyUnicode_FromString("(obj)");
}

static PyGetSetDef keyobject_getset[] = {
{"__text_signature__", keyobject_text_signature, (setter)NULL},
{"__text_signature__", keyobject_text_signature, NULL},
{NULL}
};

static PyObject *
keyobject_call(keyobject *ko, PyObject *args, PyObject *kwds);
keyobject_call(PyObject *ko, PyObject *args, PyObject *kwds);

static PyObject *
keyobject_richcompare(PyObject *ko, PyObject *other, int op);
Expand All @@ -854,11 +858,12 @@ static PyType_Spec keyobject_type_spec = {
};

static PyObject *
keyobject_call(keyobject *ko, PyObject *args, PyObject *kwds)
keyobject_call(PyObject *self, PyObject *args, PyObject *kwds)
{
PyObject *object;
keyobject *result;
static char *kwargs[] = {"obj", NULL};
keyobject *ko = keyobject_CAST(self);

if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:K", kwargs, &object))
return NULL;
Expand All @@ -874,17 +879,20 @@ keyobject_call(keyobject *ko, PyObject *args, PyObject *kwds)
}

static PyObject *
keyobject_richcompare(PyObject *ko, PyObject *other, int op)
keyobject_richcompare(PyObject *self, PyObject *other, int op)
{
if (!Py_IS_TYPE(other, Py_TYPE(ko))) {
if (!Py_IS_TYPE(other, Py_TYPE(self))) {
PyErr_Format(PyExc_TypeError, "other argument must be K instance");
return NULL;
}

PyObject *compare = ((keyobject *) ko)->cmp;
keyobject *lhs = keyobject_CAST(self);
keyobject *rhs = keyobject_CAST(other);

PyObject *compare = lhs->cmp;
assert(compare != NULL);
PyObject *x = ((keyobject *) ko)->object;
PyObject *y = ((keyobject *) other)->object;
PyObject *x = lhs->object;
PyObject *y = rhs->object;
if (!x || !y){
PyErr_Format(PyExc_AttributeError, "object");
return NULL;
Expand Down Expand Up @@ -1053,9 +1061,12 @@ typedef struct lru_list_elem {
PyObject *key, *result;
} lru_list_elem;

#define lru_list_elem_CAST(op) ((lru_list_elem *)(op))

static void
lru_list_elem_dealloc(lru_list_elem *link)
lru_list_elem_dealloc(PyObject *op)
{
lru_list_elem *link = lru_list_elem_CAST(op);
PyTypeObject *tp = Py_TYPE(link);
Py_XDECREF(link->key);
Py_XDECREF(link->result);
Expand Down Expand Up @@ -1096,6 +1107,8 @@ typedef struct lru_cache_object {
PyObject *weakreflist;
} lru_cache_object;

#define lru_cache_object_CAST(op) ((lru_cache_object *)(op))

static PyObject *
lru_cache_make_key(PyObject *kwd_mark, PyObject *args,
PyObject *kwds, int typed)
Expand Down Expand Up @@ -1531,8 +1544,9 @@ lru_cache_clear_list(lru_list_elem *link)
}

static int
lru_cache_tp_clear(lru_cache_object *self)
lru_cache_tp_clear(PyObject *op)
{
lru_cache_object *self = lru_cache_object_CAST(op);
lru_list_elem *list = lru_cache_unlink_list(self);
Py_CLEAR(self->cache);
Py_CLEAR(self->func);
Expand All @@ -1545,23 +1559,25 @@ lru_cache_tp_clear(lru_cache_object *self)
}

static void
lru_cache_dealloc(lru_cache_object *obj)
lru_cache_dealloc(PyObject *op)
{
lru_cache_object *obj = lru_cache_object_CAST(op);
PyTypeObject *tp = Py_TYPE(obj);
/* bpo-31095: UnTrack is needed before calling any callbacks */
PyObject_GC_UnTrack(obj);
if (obj->weakreflist != NULL) {
PyObject_ClearWeakRefs((PyObject*)obj);
PyObject_ClearWeakRefs(op);
}

(void)lru_cache_tp_clear(obj);
(void)lru_cache_tp_clear(op);
tp->tp_free(obj);
Py_DECREF(tp);
}

static PyObject *
lru_cache_call(lru_cache_object *self, PyObject *args, PyObject *kwds)
lru_cache_call(PyObject *op, PyObject *args, PyObject *kwds)
{
lru_cache_object *self = lru_cache_object_CAST(op);
PyObject *result;
Py_BEGIN_CRITICAL_SECTION(self);
result = self->wrapper(self, args, kwds);
Expand Down Expand Up @@ -1638,8 +1654,9 @@ lru_cache_deepcopy(PyObject *self, PyObject *unused)
}

static int
lru_cache_tp_traverse(lru_cache_object *self, visitproc visit, void *arg)
lru_cache_tp_traverse(PyObject *op, visitproc visit, void *arg)
{
lru_cache_object *self = lru_cache_object_CAST(op);
Py_VISIT(Py_TYPE(self));
lru_list_elem *link = self->root.next;
while (link != &self->root) {
Expand Down Expand Up @@ -1827,7 +1844,7 @@ _functools_clear(PyObject *module)
static void
_functools_free(void *module)
{
_functools_clear((PyObject *)module);
(void)_functools_clear((PyObject *)module);
}

static struct PyModuleDef_Slot _functools_slots[] = {
Expand Down
Loading