Alignment fault in PyMember_SetOne under freethreading on aarch64 #129675

Open
@hawkinsp

Bug report

Bug description:

I saw a SIGBUS crash with the following backtrace immediately on startup when running a free-threaded build of https://github.com/jax-ml/jax on Python 3.13t:

#0  PyMember_SetOne (addr=0x20006fa48bc "", addr@entry=0x20006fa4890 "`\367N~", <incomplete sequence \346>, l=<optimized out>,
    v=v@entry='Context manager for `jax2tf_associative_scan_reductions` config option.\n\nJAX has two separate lowering rules for the cumulative reduction primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. The latter has a slow implementation on CPUs and GPUs. By default, jax2tf uses the TPU lowering. Set this flag to True to use the associative scan lowering usage, and only if it makes a difference for your application. See the jax2tf README.md for more details.') at Python/structmember.c:308
#1  0x0000b1fea968e908 in member_set (self=<member_descriptor at remote 0x20006384210>, obj=<State at remote 0x20006fa4890>,
    value='Context manager for `jax2tf_associative_scan_reductions` config option.\n\nJAX has two separate lowering rules for the cumulative reduction primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. The latter has a slow implementation on CPUs and GPUs. By default, jax2tf uses the TPU lowering. Set this flag to True to use the associative scan lowering usage, and only if it makes a difference for your application. See the jax2tf README.md for more details.') at Objects/descrobject.c:238
#2  0x0000b1fea96e9928 in _PyObject_GenericSetAttrWithDict (obj=<State at remote 0x20006fa4890>, name='__doc__',
    value='Context manager for `jax2tf_associative_scan_reductions` config option.\n\nJAX has two separate lowering rules for the cumulative reduction primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. The latter has a slow implementation on CPUs and GPUs. By default, jax2tf uses the TPU lowering. Set this flag to True to use the associative scan lowering usage, and only if it makes a difference for your application. See the jax2tf README.md for more details.', dict=dict@entry=0x0) at Objects/object.c:1778

The crash happens while executing an attribute assignment of this form:

obj.__doc__ = "..."

The faulting line, at Python/structmember.c:308, is:

FT_ATOMIC_STORE_PTR_RELEASE(*(PyObject **)addr, Py_XNewRef(v));

Digging a little deeper, the relevant disassembly is:

(gdb) display /10i $pc-20
10: x/10i $pc-20
   0xb1fea9860094 <PyMember_SetOne+1596>:       bl      0xb1fea985f690 <_PyCriticalSection_BeginMutex>
   0xb1fea9860098 <PyMember_SetOne+1600>:       ldr     x19, [x19, x21]
   0xb1fea986009c <PyMember_SetOne+1604>:       cbz     x20, 0xb1fea98600a8 <PyMember_SetOne+1616>
   0xb1fea98600a0 <PyMember_SetOne+1608>:       mov     x0, x20
   0xb1fea98600a4 <PyMember_SetOne+1612>:       bl      0xb1fea985f42c <Py_INCREF>
=> 0xb1fea98600a8 <PyMember_SetOne+1616>:       stlr    x20, [x24]
   0xb1fea98600ac <PyMember_SetOne+1620>:       add     x0, sp, #0x8
   0xb1fea98600b0 <PyMember_SetOne+1624>:       bl      0xb1fea985f774 <_PyCriticalSection_End>
   0xb1fea98600b4 <PyMember_SetOne+1628>:       mov     x0, x19
   0xb1fea98600b8 <PyMember_SetOne+1632>:       bl      0xb1fea985f664 <Py_XDECREF>

(gdb) info registers
...
x20            0x200064f9610       2199129134608
x21            0x2c                44
x22            0x10                16
x23            0x0                 0
x24            0x20006fa48bc       2199140321468
x25            0x0                 0
...

The problem is that the target address of stlr must be 8-byte aligned on aarch64: https://developer.arm.com/documentation/102336/0100/Load-Acquire-and-Store-Release-instructions

However, the __doc__ field of this object is only 4-byte aligned. Its offset is 44, so the store address in x24 (0x20006fa48bc) is the object address 0x20006fa4890 plus 44.

(gdb) up
(gdb) print *descr->d_member
$38 = {name = 0xb1fea9bdc950 <_PyRuntime+63952> "__doc__", type = 16, offset = 44, flags = 0, doc = 0x0}
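
The arithmetic can be double-checked with a few lines of Python. This is only a sanity check on the numbers copied from the gdb output above; the constant names and the misalignment helper are made up for illustration, and the 8-byte pointer size is an assumption about this aarch64 Linux build.

```python
# Values copied from the gdb session above.
OBJ_ADDR = 0x20006FA4890    # obj=<State at remote 0x20006fa4890>
MEMBER_OFFSET = 44          # descr->d_member->offset
STORE_ADDR = 0x20006FA48BC  # x24, the address stlr writes to
POINTER_ALIGN = 8           # assumption: 8-byte pointers on this target


def misalignment(addr, align=POINTER_ALIGN):
    """Return how many bytes addr sits past the previous aligned boundary."""
    return addr % align


slot_addr = OBJ_ADDR + MEMBER_OFFSET
print(hex(slot_addr), misalignment(slot_addr))  # 0x20006fa48bc 4
```

The object itself is 8-byte aligned, so the misalignment comes entirely from the slot offset of 44.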

So an object field was placed at a misaligned address and then written with an atomic store that requires alignment, which is what raises the SIGBUS.

How should we fix this?

This seems like it's a CPython bug: CPython shouldn't choose underaligned slot offsets for object fields.

However, in this particular case the offset comes from a Python subclass of a C extension base class that has tp_basicsize=44, and I'm not aware of any rule that says tp_basicsize has to be a multiple of the word size. Perhaps CPython should either enforce that or round up the size of base classes to ensure alignment.
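
To make the "round up" option concrete, here is a minimal sketch of the usual align-up computation, written in Python for brevity. It is not CPython's code: round_up is a hypothetical helper, and the alignment of 8 is an assumption matching the pointer size on this platform.

```python
def round_up(size, align=8):
    """Round size up to the next multiple of align (align must be a power of two)."""
    if align <= 0 or align & (align - 1):
        raise ValueError("align must be a positive power of two")
    return (size + align - 1) & ~(align - 1)


base_basicsize = 44                         # tp_basicsize of the C extension base class
first_slot_offset = round_up(base_basicsize)
print(first_slot_offset)                    # 48
```

With the base size rounded to 48, the first slot added by the subclass would start on an 8-byte boundary, at the cost of 4 bytes of padding per instance.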

Alternatively, we could argue that CPython shouldn't use an atomic store that requires alignment in this case.

What do you think?

CPython versions tested on:

3.13

Operating systems tested on:

Linux
