Skip to content

Commit 35e843c

Browse files
authored
[mypyc] Add efficient librt.base64.b64decode (#20263)
The performance can be 10x faster than stdlib if input is valid base64, or if input has extra non-base64 characters only at the end of input. Similar to the base64 encode implementation I added recently, this uses SIMD instructions when available. The implementation first tries to decode the input optimistically assuming valid base64. If this fails, we'll perform a slow path with a preprocessing step that removes extra characters, and we'll perform a strict base64 decode on the cleaned-up input. The semantics aren't 100% compatible with stdlib. First, we raise ValueError on invalid padding instead of `binascii.Error`, since I don't want a runtime dependency on the unrelated `binascii` module. This needs to be documented, but stdlib can already raise ValueError on other conditions, so the deviation is not huge. Also, some invalid inputs are checked more strictly for padding violations. The stdlib implementation has some mysterious behaviors with invalid inputs that didn't seem worth replicating. The function only accepts a single ASCII str or bytes argument for now, since that seems to be by far the most common use case. The stdlib function also accepts buffer objects and a `validate` argument. The slow path is still somewhat faster than stdlib (on the order of 1.3x to 2x for longer inputs), at least if the input is much smaller than L1 cache size. Got the initial fast path implementation from ChatGPT, but did a bunch of manual edits afterwards and reviewed carefully.
1 parent 094f66d commit 35e843c

File tree

3 files changed

+297
-3
lines changed

3 files changed

+297
-3
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
def b64encode(s: bytes) -> bytes: ...
2+
def b64decode(s: bytes | str) -> bytes: ...

mypyc/lib-rt/librt_base64.c

Lines changed: 189 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
#define PY_SSIZE_T_CLEAN
22
#include <Python.h>
3+
#include <stdbool.h>
34
#include "librt_base64.h"
45
#include "libbase64.h"
56
#include "pythoncapi_compat.h"
67

78
#ifdef MYPYC_EXPERIMENTAL
89

10+
static PyObject *
11+
b64decode_handle_invalid_input(
12+
PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen);
13+
914
#define BASE64_MAXBIN ((PY_SSIZE_T_MAX - 3) / 2)
1015

1116
#define STACK_BUFFER_SIZE 1024
@@ -63,11 +68,193 @@ b64encode(PyObject *self, PyObject *const *args, size_t nargs) {
6368
return b64encode_internal(args[0]);
6469
}
6570

71+
// Return 1 if c is a character of the standard base64 alphabet
// (A-Z, a-z, 0-9, '+', '/'), and 0 otherwise. The padding character '='
// is accepted only when allow_padding is true.
static inline int
is_valid_base64_char(char c, bool allow_padding) {
    if (c == '=') {
        return allow_padding;
    }
    if (c >= 'A' && c <= 'Z') {
        return 1;
    }
    if (c >= 'a' && c <= 'z') {
        return 1;
    }
    if (c >= '0' && c <= '9') {
        return 1;
    }
    return c == '+' || c == '/';
}
76+
77+
// Decode a base64-encoded bytes object or ASCII-only str into a new bytes object.
//
// The input is first decoded optimistically in a single pass. If the decoder
// rejects it, the work is handed over to b64decode_handle_invalid_input.
//
// Returns a new reference on success, or NULL with an exception set:
//   TypeError            - arg is neither bytes nor str
//   ValueError           - arg is a str with non-ASCII characters
//   OverflowError        - the computed output capacity exceeds PY_SSIZE_T_MAX
//   NotImplementedError  - base64_decode returned -1
//   RuntimeError         - base64_decode returned an unexpected code, or reported
//                          more output than the buffer can hold
static PyObject *
b64decode_internal(PyObject *arg) {
    const char *data;
    Py_ssize_t data_len;

    // Obtain a raw pointer to the input characters and their count
    if (PyBytes_Check(arg)) {
        data = PyBytes_AS_STRING(arg);
        data_len = PyBytes_GET_SIZE(arg);
    } else if (!PyUnicode_Check(arg)) {
        PyErr_SetString(PyExc_TypeError,
                        "argument should be a bytes-like object or ASCII string");
        return NULL;
    } else if (!PyUnicode_IS_ASCII(arg)) {
        PyErr_SetString(PyExc_ValueError,
                        "string argument should contain only ASCII characters");
        return NULL;
    } else {
        data = (const char *)PyUnicode_1BYTE_DATA(arg);
        data_len = PyUnicode_GET_LENGTH(arg);
    }

    // Empty input decodes to empty bytes
    if (data_len == 0) {
        return PyBytes_FromStringAndSize(NULL, 0);
    }

    // Drop trailing characters that are neither base64 alphabet nor '='.
    // Invalid characters elsewhere are also tolerated, but only via the slow path.
    for (; data_len > 0; data_len--) {
        if (is_valid_base64_char(data[data_len - 1], true)) {
            break;
        }
    }

    // Output capacity of at least 3/4 of the input, computed without overflow:
    // ceil(3/4 * N) == N - floor(N/4)
    size_t srclen = (size_t)data_len;
    size_t capacity = srclen - (srclen / 4);
    if (capacity == 0) {
        capacity = 1;  // never request a zero-sized output buffer
    }
    if (capacity > (size_t)PY_SSIZE_T_MAX) {
        PyErr_SetString(PyExc_OverflowError, "input too large");
        return NULL;
    }

    // Uninitialized bytes object sized for the worst case
    PyObject *result = PyBytes_FromStringAndSize(NULL, (Py_ssize_t)capacity);
    if (result == NULL) {
        return NULL;  // exception already set by the allocator
    }

    char *dest = PyBytes_AS_STRING(result);
    size_t decoded_len = capacity;

    switch (base64_decode(data, srclen, dest, &decoded_len, 0)) {
    case 1:
        break;
    case 0:
        // Slow path: the input contains non-base64 content. The handler releases
        // or returns `result` itself, so it must not be touched afterwards.
        return b64decode_handle_invalid_input(result, dest, capacity, data, srclen);
    case -1:
        Py_DECREF(result);
        PyErr_SetString(PyExc_NotImplementedError, "base64 codec not available in this build");
        return NULL;
    default:
        Py_DECREF(result);
        PyErr_SetString(PyExc_RuntimeError, "base64_decode failed");
        return NULL;
    }

    // The decoder must never report more bytes than the buffer can hold
    if (decoded_len > capacity) {
        Py_DECREF(result);
        PyErr_SetString(PyExc_RuntimeError, "decoder wrote past output buffer");
        return NULL;
    }

    // Trim the bytes object down to the number of bytes actually produced.
    // On failure _PyBytes_Resize sets an exception and may free the old object.
    if (_PyBytes_Resize(&result, (Py_ssize_t)decoded_len) < 0) {
        return NULL;
    }
    return result;
}
162+
163+
// Process non-base64 input by ignoring non-base64 characters, for compatibility
164+
// with stdlib b64decode.
165+
static PyObject *
166+
b64decode_handle_invalid_input(
167+
PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen)
168+
{
169+
// Copy input to a temporary buffer, with non-base64 characters and extra suffix
170+
// characters removed
171+
size_t newbuf_len = 0;
172+
char *newbuf = PyMem_Malloc(srclen);
173+
if (newbuf == NULL) {
174+
Py_DECREF(out_bytes);
175+
return PyErr_NoMemory();
176+
}
177+
178+
// Copy base64 characters and some padding to the new buffer
179+
for (size_t i = 0; i < srclen; i++) {
180+
char c = src[i];
181+
if (is_valid_base64_char(c, false)) {
182+
newbuf[newbuf_len++] = c;
183+
} else if (c == '=') {
184+
// Copy a necessary amount of padding
185+
int remainder = newbuf_len % 4;
186+
if (remainder == 0) {
187+
// No padding needed
188+
break;
189+
}
190+
int numpad = 4 - remainder;
191+
// Check that there is at least the required amount padding (CPython ignores
192+
// extra padding)
193+
while (numpad > 0) {
194+
if (i == srclen || src[i] != '=') {
195+
break;
196+
}
197+
newbuf[newbuf_len++] = '=';
198+
i++;
199+
numpad--;
200+
// Skip non-base64 alphabet characters within padding
201+
while (i < srclen && !is_valid_base64_char(src[i], true)) {
202+
i++;
203+
}
204+
}
205+
break;
206+
}
207+
}
208+
209+
// Stdlib always performs a non-strict padding check
210+
if (newbuf_len % 4 != 0) {
211+
Py_DECREF(out_bytes);
212+
PyMem_Free(newbuf);
213+
PyErr_SetString(PyExc_ValueError, "Incorrect padding");
214+
return NULL;
215+
}
216+
217+
size_t outlen = max_out;
218+
int ret = base64_decode(newbuf, newbuf_len, outbuf, &outlen, 0);
219+
PyMem_Free(newbuf);
220+
221+
if (ret != 1) {
222+
Py_DECREF(out_bytes);
223+
if (ret == 0) {
224+
PyErr_SetString(PyExc_ValueError, "Only base64 data is allowed");
225+
}
226+
if (ret == -1) {
227+
PyErr_SetString(PyExc_NotImplementedError, "base64 codec not available in this build");
228+
} else {
229+
PyErr_SetString(PyExc_RuntimeError, "base64_decode failed");
230+
}
231+
return NULL;
232+
}
233+
234+
// Shrink in place to the actual decoded length
235+
if (_PyBytes_Resize(&out_bytes, (Py_ssize_t)outlen) < 0) {
236+
// _PyBytes_Resize sets an exception and may free the old object
237+
return NULL;
238+
}
239+
return out_bytes;
240+
}
241+
242+
243+
// Fastcall entry point registered as "b64decode" in the module method table.
// Accepts exactly one positional argument and forwards it to
// b64decode_internal; any other argument count raises TypeError.
static PyObject*
b64decode(PyObject *self, PyObject *const *args, size_t nargs) {
    if (nargs == 1) {
        return b64decode_internal(args[0]);
    }
    PyErr_SetString(PyExc_TypeError, "b64decode() takes exactly one argument");
    return NULL;
}
251+
66252
#endif
67253

68254
static PyMethodDef librt_base64_module_methods[] = {
69255
#ifdef MYPYC_EXPERIMENTAL
70-
{"b64encode", (PyCFunction)b64encode, METH_FASTCALL, PyDoc_STR("Encode bytes-like object using Base64.")},
256+
{"b64encode", (PyCFunction)b64encode, METH_FASTCALL, PyDoc_STR("Encode bytes object using Base64.")},
257+
{"b64decode", (PyCFunction)b64decode, METH_FASTCALL, PyDoc_STR("Decode a Base64 encoded bytes object or ASCII string.")},
71258
#endif
72259
{NULL, NULL, 0, NULL}
73260
};
@@ -111,7 +298,7 @@ static PyModuleDef_Slot librt_base64_module_slots[] = {
111298
static PyModuleDef librt_base64_module = {
112299
.m_base = PyModuleDef_HEAD_INIT,
113300
.m_name = "base64",
114-
.m_doc = "base64 encoding and decoding optimized for mypyc",
301+
.m_doc = "Fast base64 encoding and decoding optimized for mypyc",
115302
.m_size = 0,
116303
.m_methods = librt_base64_module_methods,
117304
.m_slots = librt_base64_module_slots,

mypyc/test-data/run-base64.test

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
[case testAllBase64Features_librt_experimental]
22
from typing import Any
33
import base64
4+
import binascii
45

5-
from librt.base64 import b64encode
6+
from librt.base64 import b64encode, b64decode
67

78
from testutil import assertRaises
89

@@ -44,6 +45,111 @@ def test_encode_wrapper() -> None:
4445
with assertRaises(TypeError):
4546
enc(b"x", b"y")
4647

48+
def test_decode_basic() -> None:
    """Smoke-test b64decode: one valid input, one bad type, several non-ASCII strs."""
    assert b64decode(b"eA==") == b"x"

    # bytearray is rejected with TypeError
    with assertRaises(TypeError):
        b64decode(bytearray(b"eA=="))

    # str arguments containing non-ASCII characters raise ValueError
    non_ascii_inputs = ("\x80", "foo\u100bar", "foo\ua1234bar")
    for text in non_ascii_inputs:
        with assertRaises(ValueError):
            b64decode(text)
57+
58+
def check_decode(b: bytes, encoded: bool = False) -> None:
    """Assert that librt's b64decode agrees with stdlib base64.b64decode.

    If `encoded` is false, `b` is raw data that is first encoded with
    b64encode; otherwise `b` is passed to the decoders as is. When the
    encoded bytes are pure ASCII, the str form is checked as well.
    """
    stdlib_decode = getattr(base64, "b64decode")
    enc = b if encoded else b64encode(b)
    assert b64decode(enc) == stdlib_decode(enc)
    if not getattr(enc, "isascii")():  # Test stub has no "isascii"
        return
    enc_str = enc.decode("ascii")
    assert b64decode(enc_str) == stdlib_decode(enc_str)
67+
68+
def test_decode_different_strings() -> None:
    """Round-trip many payloads of varying content and length through check_decode."""
    fillers = (b"x", b"xy", b"xyz", b"xyza")
    for value in range(256):
        one = bytes([value])
        # Every byte value alone, then followed by and preceded by 1-4 filler bytes
        check_decode(one)
        for tail in fillers:
            check_decode(one + tail)
        for head in fillers:
            check_decode(head + one)

    # Every prefix length from 0 to 999 of a repeating 3-byte pattern
    pattern = b"a\x00\xb7" * 1000
    for size in range(1000):
        check_decode(pattern[:size])

    for sample in (b"", b"ab", b"bac", b"1234", b"xyz88", b"abc" * 200):
        check_decode(sample)
86+
87+
def is_base64_char(x: int) -> bool:
    """Return True if code point `x` is a base64 alphabet character or '='."""
    ch = chr(x)
    if 'a' <= ch <= 'z' or 'A' <= ch <= 'Z':
        return True
    if '0' <= ch <= '9':
        return True
    return ch in '+/='
90+
91+
def test_decode_with_non_base64_chars() -> None:
    """Non-base64 characters should be ignored, for stdlib compatibility."""
    # Invalid characters as a suffix use a fast path.
    for with_suffix in (b"eA== ", b"eA==\n", b"eA== \t\n", b"\n"):
        check_decode(with_suffix, encoded=True)

    # Invalid characters interleaved with the payload
    check_decode(b" e A = = ", encoded=True)

    # Special case: Two different encodings of the same data
    for variant in (b"eAa=", b"eAY="):
        check_decode(variant, encoded=True)

    # Each non-base64 byte value: alone, after the payload, inside the data,
    # and inside the padding
    for code in range(256):
        if is_base64_char(code):
            continue
        junk = bytes([code])
        check_decode(junk, encoded=True)
        check_decode(b"eA==" + junk, encoded=True)
        check_decode(b"e" + junk + b"A==", encoded=True)
        check_decode(b"eA=" + junk + b"=", encoded=True)
113+
114+
def check_decode_error(b: bytes, ignore_stdlib: bool = False) -> None:
    """Assert that decoding `b` fails.

    Unless `ignore_stdlib` is set, stdlib base64.b64decode must raise
    binascii.Error for the same input. librt's b64decode must always raise
    ValueError.
    """
    if not ignore_stdlib:
        stdlib_decode = getattr(base64, "b64decode")
        with assertRaises(binascii.Error):
            stdlib_decode(b)

    # The raised error is different, since librt shouldn't depend on binascii
    with assertRaises(ValueError):
        b64decode(b)
122+
123+
def test_decode_with_invalid_padding() -> None:
    """Inputs with missing or misplaced padding must be rejected."""
    for bad_input in (b"eA", b"eA=", b"eHk", b"eA = "):
        check_decode_error(bad_input)

    # Here stdlib behavior seems nonsensical, so we don't try to duplicate it
    check_decode_error(b"eA=a=", ignore_stdlib=True)
131+
132+
def test_decode_with_extra_data_after_padding() -> None:
    """Surplus '=' characters and data after the padding must decode like stdlib."""
    cases = (
        b"=",
        b"==",
        b"===",
        b"====",
        b"eA===",
        b"eHk==",
        b"eA==x",
        b"eHk=x",
        b"eA==abc=======efg",
    )
    for case in cases:
        check_decode(case, encoded=True)
142+
143+
def test_decode_wrapper() -> None:
    """Call b64decode through an Any-typed reference and check argument-count errors."""
    dec: Any = b64decode
    assert dec(b"eA==") == b"x"

    # Too few arguments
    with assertRaises(TypeError):
        dec()

    # Too many arguments
    with assertRaises(TypeError):
        dec(b"x", b"y")
152+
47153
[case testBase64FeaturesNotAvailableInNonExperimentalBuild_librt_base64]
48154
# This also ensures librt.base64 can be built without experimental features
49155
import librt.base64

0 commit comments

Comments
 (0)