Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-97982: Reuse PyUnicode_Count in unicode_count #98025

Merged
merged 10 commits into from
Oct 12, 2022
10 changes: 10 additions & 0 deletions Lib/test/test_unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ def test_count(self):
self.checkequal(0, 'a' * 10, 'count', 'a\u0102')
self.checkequal(0, 'a' * 10, 'count', 'a\U00100304')
self.checkequal(0, '\u0102' * 10, 'count', '\u0102\U00100304')
# test subclass
class MyStr(str):
pass
self.checkequal(3, MyStr('aaa'), 'count', 'a')

def test_find(self):
string_tests.CommonTest.test_find(self)
Expand Down Expand Up @@ -2983,6 +2987,12 @@ def test_count(self):
self.assertEqual(unicode_count(uni, ch, 0, len(uni)), 1)
self.assertEqual(unicode_count(st, ch, 0, len(st)), 0)

# subclasses should still work
class MyStr(str):
pass

self.assertEqual(unicode_count(MyStr('aab'), 'a', 0, 3), 2)

# Test PyUnicode_FindChar()
@support.cpython_only
@unittest.skipIf(_testcapi is None, 'need _testcapi module')
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Create ``any_unicode_count`` private helper for
sobolevn marked this conversation as resolved.
Show resolved Hide resolved
both ``PyUnicode_Count`` and ``unicode_count`` in
``unicodeobject.c``.
81 changes: 24 additions & 57 deletions Objects/unicodeobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -8964,21 +8964,20 @@ _PyUnicode_InsertThousandsGrouping(
return count;
}


Py_ssize_t
PyUnicode_Count(PyObject *str,
PyObject *substr,
Py_ssize_t start,
Py_ssize_t end)
static Py_ssize_t
any_unicode_count(PyObject *str,
sobolevn marked this conversation as resolved.
Show resolved Hide resolved
PyObject *substr,
Py_ssize_t start,
Py_ssize_t end)
{
assert(PyUnicode_Check(str));
assert(PyUnicode_Check(substr));

Py_ssize_t result;
int kind1, kind2;
const void *buf1 = NULL, *buf2 = NULL;
Py_ssize_t len1, len2;

if (ensure_unicode(str) < 0 || ensure_unicode(substr) < 0)
return -1;

kind1 = PyUnicode_KIND(str);
kind2 = PyUnicode_KIND(substr);
if (kind1 < kind2)
Expand Down Expand Up @@ -9039,6 +9038,18 @@ PyUnicode_Count(PyObject *str,
return -1;
}

Py_ssize_t
PyUnicode_Count(PyObject *str,
PyObject *substr,
Py_ssize_t start,
Py_ssize_t end)
{
if (ensure_unicode(str) < 0 || ensure_unicode(substr) < 0)
return -1;

return any_unicode_count(str, substr, start, end);
}

Py_ssize_t
PyUnicode_Find(PyObject *str,
PyObject *substr,
Expand Down Expand Up @@ -10858,60 +10869,16 @@ unicode_count(PyObject *self, PyObject *args)
Py_ssize_t start = 0;
Py_ssize_t end = PY_SSIZE_T_MAX;
PyObject *result;
int kind1, kind2;
const void *buf1, *buf2;
Py_ssize_t len1, len2, iresult;
Py_ssize_t iresult;

if (!parse_args_finds_unicode("count", args, &substring, &start, &end))
return NULL;

kind1 = PyUnicode_KIND(self);
kind2 = PyUnicode_KIND(substring);
if (kind1 < kind2)
return PyLong_FromLong(0);

len1 = PyUnicode_GET_LENGTH(self);
len2 = PyUnicode_GET_LENGTH(substring);
ADJUST_INDICES(start, end, len1);
if (end - start < len2)
return PyLong_FromLong(0);

buf1 = PyUnicode_DATA(self);
buf2 = PyUnicode_DATA(substring);
if (kind2 != kind1) {
buf2 = unicode_askind(kind2, buf2, len2, kind1);
if (!buf2)
return NULL;
}
switch (kind1) {
case PyUnicode_1BYTE_KIND:
iresult = ucs1lib_count(
((const Py_UCS1*)buf1) + start, end - start,
buf2, len2, PY_SSIZE_T_MAX
);
break;
case PyUnicode_2BYTE_KIND:
iresult = ucs2lib_count(
((const Py_UCS2*)buf1) + start, end - start,
buf2, len2, PY_SSIZE_T_MAX
);
break;
case PyUnicode_4BYTE_KIND:
iresult = ucs4lib_count(
((const Py_UCS4*)buf1) + start, end - start,
buf2, len2, PY_SSIZE_T_MAX
);
break;
default:
Py_UNREACHABLE();
}
iresult = any_unicode_count(self, substring, start, end);
if (iresult == -1)
return NULL;

result = PyLong_FromSsize_t(iresult);
sobolevn marked this conversation as resolved.
Show resolved Hide resolved

assert((kind2 == kind1) == (buf2 == PyUnicode_DATA(substring)));
if (kind2 != kind1)
PyMem_Free((void *)buf2);

return result;
}

Expand Down