Skip to content

Commit e82e8ec

Browse files
[mypyc] Add a str.format specializer which only supports empty brackets (#10697)
This PR adds a str.format specializer that only supports empty braces ("{}"). The detection code also handles escaped brace literals, for example "{{}}". The specializer first splits the format string on "{}". Then it converts each argument that is not already a str with PyObject_Str. Finally, the literal substrings and the converted arguments are passed, in order, to a C helper function that joins them into a single string.
1 parent 49bb90a commit e82e8ec

File tree

6 files changed

+227
-20
lines changed

6 files changed

+227
-20
lines changed

mypyc/irbuild/specialize.py

Lines changed: 100 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from mypy.nodes import (
1818
CallExpr, RefExpr, MemberExpr, NameExpr, TupleExpr, GeneratorExpr,
19-
ListExpr, DictExpr, ARG_POS
19+
ListExpr, DictExpr, StrExpr, ARG_POS
2020
)
2121
from mypy.types import AnyType, TypeOfAny
2222

@@ -25,20 +25,20 @@
2525
)
2626
from mypyc.ir.rtypes import (
2727
RType, RTuple, str_rprimitive, list_rprimitive, dict_rprimitive, set_rprimitive,
28-
bool_rprimitive, is_dict_rprimitive, c_int_rprimitive
28+
bool_rprimitive, is_dict_rprimitive, c_int_rprimitive, is_str_rprimitive
2929
)
3030
from mypyc.primitives.dict_ops import (
3131
dict_keys_op, dict_values_op, dict_items_op, dict_setdefault_spec_init_op
3232
)
3333
from mypyc.primitives.list_ops import new_list_set_item_op
3434
from mypyc.primitives.tuple_ops import new_tuple_set_item_op
35+
from mypyc.primitives.str_ops import str_op, str_build_op
3536
from mypyc.irbuild.builder import IRBuilder
3637
from mypyc.irbuild.for_helpers import (
3738
translate_list_comprehension, translate_set_comprehension,
3839
comprehension_helper, sequence_from_generator_preallocate_helper
3940
)
4041

41-
4242
# Specializers are attempted before compiling the arguments to the
4343
# function. Specializers can return None to indicate that they failed
4444
# and the call should be compiled normally. Otherwise they should emit
@@ -62,9 +62,11 @@ def specialize_function(
6262
There may exist multiple specializers for one function. When translating method
6363
calls, the earlier appended specializer has higher priority.
6464
"""
65+
6566
def wrapper(f: Specializer) -> Specializer:
6667
specializers.setdefault((name, typ), []).append(f)
6768
return f
69+
6870
return wrapper
6971

7072

@@ -189,13 +191,13 @@ def translate_safe_generator_call(
189191
return builder.gen_method_call(
190192
builder.accept(callee.expr), callee.name,
191193
([translate_list_comprehension(builder, expr.args[0])]
192-
+ [builder.accept(arg) for arg in expr.args[1:]]),
194+
+ [builder.accept(arg) for arg in expr.args[1:]]),
193195
builder.node_type(expr), expr.line, expr.arg_kinds, expr.arg_names)
194196
else:
195197
return builder.call_refexpr_with_args(
196198
expr, callee,
197199
([translate_list_comprehension(builder, expr.args[0])]
198-
+ [builder.accept(arg) for arg in expr.args[1:]]))
200+
+ [builder.accept(arg) for arg in expr.args[1:]]))
199201
return None
200202

201203

@@ -343,7 +345,7 @@ def translate_dict_setdefault(
343345
return None
344346
data_type = Integer(2, c_int_rprimitive, expr.line)
345347
elif (isinstance(arg, CallExpr) and isinstance(arg.callee, NameExpr)
346-
and arg.callee.fullname == 'builtins.set'):
348+
and arg.callee.fullname == 'builtins.set'):
347349
if len(arg.args):
348350
return None
349351
data_type = Integer(3, c_int_rprimitive, expr.line)
@@ -356,3 +358,95 @@ def translate_dict_setdefault(
356358
[callee_dict, key_val, data_type],
357359
expr.line)
358360
return None
361+
362+
363+
@specialize_function('format', str_rprimitive)
def translate_str_format(
        builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]:
    """Specialize ``'<literal>'.format(...)`` when the literal only uses ``{}`` fields.

    The format string is split on its empty replacement fields, every
    argument that is not already a str is converted with ``str_op``, and the
    interleaved pieces are joined with a single ``str_build_op`` call.

    Returns None (so the call is compiled normally) when:
      * the receiver is not a string literal,
      * any argument is not a plain positional argument,
      * the format string contains anything other than ``{}``, ``{{`` or ``}}``
        brace sequences, or
      * the number of ``{}`` fields differs from the number of arguments.
    """
    if not (isinstance(callee, MemberExpr) and isinstance(callee.expr, StrExpr)
            and expr.arg_kinds.count(ARG_POS) == len(expr.arg_kinds)):
        return None

    format_str = callee.expr.value
    if not can_optimize_format(format_str):
        return None

    # split_braces() always yields one more literal than there are "{}" fields.
    literals = split_braces(format_str)

    # Pairing literals with arguments via zip() silently truncates, so a
    # mismatch would produce a wrong string instead of the IndexError (too few
    # arguments) or the "ignore extras" behaviour of the real str.format().
    # Bail out before emitting any IR and let the generic call handle it.
    if len(literals) - 1 != len(expr.args):
        return None

    variables = [builder.accept(x) if is_str_rprimitive(builder.node_type(x))
                 else builder.call_c(str_op, [builder.accept(x)], expr.line)
                 for x in expr.args]

    # Interleave literal pieces and converted arguments, skipping empty
    # literals so that they are neither loaded nor concatenated.
    pieces: List[Value] = []
    for literal, variable in zip(literals, variables):
        if literal:
            pieces.append(builder.load_str(literal))
        pieces.append(variable)
    if literals[-1]:
        pieces.append(builder.load_str(literals[-1]))

    # Special cases: empty format string and pure literal string.
    if not pieces:
        return builder.load_str("")
    if not variables and len(pieces) == 1:
        return pieces[0]

    # The first parameter of the C helper is the number of PyObject* that follow.
    count = Integer(len(pieces), c_int_rprimitive)
    return builder.call_c(str_build_op, [count] + pieces, expr.line)
399+
400+
401+
def can_optimize_format(format_str: str) -> bool:
    """Return True if format_str only contains empty ``{}`` replacement fields.

    Scanning left to right, every brace must belong to one of the adjacent
    pairs ``{{`` (escaped '{'), ``}}`` (escaped '}') or ``{}`` (empty field).
    Anything else -- a field with a name, index, conversion or format spec, or
    a lone brace that would make str.format() raise ValueError -- makes the
    string ineligible, so the caller falls back to the generic implementation.
    """
    # TODO: support more than empty braces.
    pos = 0
    end = len(format_str)
    while pos < end:
        c = format_str[pos]
        if c == '{' or c == '}':
            # Slicing (instead of indexing) yields '' at the end of the string,
            # which correctly rejects a trailing lone brace.
            following = format_str[pos + 1:pos + 2]
            if c == '{':
                paired = following == '{' or following == '}'
            else:
                paired = following == '}'
            if not paired:
                return False
            pos += 2
        else:
            pos += 1
    return True
415+
416+
417+
def split_braces(format_str: str) -> List[str]:
    """Split format_str at every empty ``{}`` field and unescape brace literals.

    Must only be called on strings accepted by can_optimize_format(). The
    result always has one more element than there are ``{}`` fields; elements
    may be empty strings. ``{{`` and ``}}`` collapse to a single literal brace.
    """
    pieces: List[str] = []
    current: List[str] = []
    # The previous character while it may still pair up with the next one;
    # reset to '' as soon as a two-character brace sequence has been consumed.
    pending = ''
    for ch in format_str:
        pair = pending + ch
        if pair == '{}':
            # Empty replacement field: close the current literal piece.
            pieces.append(''.join(current))
            current = []
            pending = ''
        elif pair == '{{' or pair == '}}':
            # Escaped brace: contributes exactly one brace to the literal.
            current.append(ch)
            pending = ''
        else:
            # A brace only becomes meaningful together with the next
            # character, so it is remembered but not emitted yet.
            if ch != '{' and ch != '}':
                current.append(ch)
            pending = ch
    pieces.append(''.join(current))
    return pieces

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {
384384
// Str operations
385385

386386

387+
PyObject *CPyStr_Build(int len, ...);
387388
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
388389
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);
389390
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);

mypyc/lib-rt/str_ops.c

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,21 @@ PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {
4343
}
4444
}
4545

46+
PyObject *CPyStr_Build(int len, ...) {
47+
int i;
48+
va_list args;
49+
va_start(args, len);
50+
51+
PyObject *res = PyUnicode_FromObject(va_arg(args, PyObject *));
52+
for (i = 1; i < len; i++) {
53+
PyObject *str = va_arg(args, PyObject *);
54+
PyUnicode_Append(&res, str);
55+
}
56+
57+
va_end(args);
58+
return res;
59+
}
60+
4661
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split)
4762
{
4863
Py_ssize_t temp_max_split = CPyTagged_AsSsize_t(max_split);

mypyc/primitives/str_ops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
src='PyUnicode_Type')
2222

2323
# str(obj)
24-
function_op(
24+
str_op = function_op(
2525
name='builtins.str',
2626
arg_types=[object_rprimitive],
2727
return_type=str_rprimitive,
@@ -44,6 +44,14 @@
4444
error_kind=ERR_MAGIC
4545
)
4646

47+
# Join str pieces with the C helper CPyStr_Build(len, ...): the single fixed
# argument is the number of pieces (a C int), followed by that many str
# values passed as variadic arguments.
str_build_op = custom_op(
48+
arg_types=[c_int_rprimitive],
49+
return_type=str_rprimitive,
50+
c_function_name='CPyStr_Build',
51+
error_kind=ERR_MAGIC,
52+
var_arg_type=str_rprimitive
53+
)
54+
4755
# str.startswith(str)
4856
method_op(
4957
name='startswith',

mypyc/test-data/irbuild-str.test

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,46 @@ L2:
158158
return 0
159159
L3:
160160
unreachable
161+
162+
[case testStringFormatMethod]
163+
def f(s: str, num: int) -> None:
164+
s1 = "Hi! I'm {}, and I'm {} years old.".format(s, num)
165+
s2 = ''.format()
166+
s3 = 'abc'.format()
167+
s3 = '}}{}{{{}}}{{{}'.format(num, num, num)
168+
[out]
169+
def f(s, num):
170+
s :: str
171+
num :: int
172+
r0 :: object
173+
r1, r2, r3, r4, r5, s1, r6, s2, r7, s3 :: str
174+
r8 :: object
175+
r9 :: str
176+
r10 :: object
177+
r11 :: str
178+
r12 :: object
179+
r13, r14, r15, r16, r17 :: str
180+
L0:
181+
r0 = box(int, num)
182+
r1 = PyObject_Str(r0)
183+
r2 = "Hi! I'm "
184+
r3 = ", and I'm "
185+
r4 = ' years old.'
186+
r5 = CPyStr_Build(5, r2, s, r3, r1, r4)
187+
s1 = r5
188+
r6 = ''
189+
s2 = r6
190+
r7 = 'abc'
191+
s3 = r7
192+
r8 = box(int, num)
193+
r9 = PyObject_Str(r8)
194+
r10 = box(int, num)
195+
r11 = PyObject_Str(r10)
196+
r12 = box(int, num)
197+
r13 = PyObject_Str(r12)
198+
r14 = '}'
199+
r15 = '{'
200+
r16 = '}{'
201+
r17 = CPyStr_Build(6, r14, r9, r15, r11, r16, r13)
202+
s3 = r17
203+
return 1

mypyc/test-data/run-strings.test

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ class A:
198198
self.age = age
199199

200200
def __repr__(self):
201-
return f"{self.name} is {self.age} years old."
201+
return f'{self.name} is {self.age} years old.'
202202

203203
def test_fstring_datatype() -> None:
204204
u = A('John Doe', 14)
@@ -236,8 +236,8 @@ def test_fstring_conversion() -> None:
236236
assert f'{s}' == 'test: āĀēĒčČ..šŠūŪžŽ'
237237
assert f'{s!a}' == "'test: \\u0101\\u0100\\u0113\\u0112\\u010d\\u010c..\\u0161\\u0160\\u016b\\u016a\\u017e\\u017d'"
238238

239-
assert f'Hello {var!s}' == "Hello mypyc"
240-
assert f'Hello {num!s}' == "Hello 20"
239+
assert f'Hello {var!s}' == 'Hello mypyc'
240+
assert f'Hello {num!s}' == 'Hello 20'
241241

242242
def test_fstring_align() -> None:
243243
assert f'Hello {var:>20}' == "Hello mypyc"
@@ -252,34 +252,57 @@ def test_fstring_multi() -> None:
252252
assert s == 'mypymypymypymypymypymypymypymypymypymypymypymypymypymypymypymypymypymypymypymypymypymypymypymypy'
253253

254254
def test_fstring_python_doc() -> None:
255-
name = "Fred"
255+
name = 'Fred'
256256
assert f"He said his name is {name!r}." == "He said his name is 'Fred'."
257257
assert f"He said his name is {repr(name)}." == "He said his name is 'Fred'."
258258

259259
width = 10
260260
precision = 4
261-
value = decimal.Decimal("12.34567")
262-
assert f"result: {value:{width}.{precision}}" == 'result: 12.35' # nested field
261+
value = decimal.Decimal('12.34567')
262+
assert f'result: {value:{width}.{precision}}' == 'result: 12.35' # nested field
263263

264264
today = datetime(year=2017, month=1, day=27)
265-
assert f"{today:%B %d, %Y}" == 'January 27, 2017' # using date format specifier
265+
assert f'{today:%B %d, %Y}' == 'January 27, 2017' # using date format specifier
266266

267267
number = 1024
268-
assert f"{number:#0x}" == '0x400' # using integer format specifier
268+
assert f'{number:#0x}' == '0x400' # using integer format specifier
269269

270270
[case testStringFormatMethod]
271271
from typing import Tuple
272272

273273
def test_format_method_basics() -> None:
274-
assert "".format() == ""
275-
assert "abc".format() == "abc"
274+
assert ''.format() == ''
275+
assert 'abc'.format() == 'abc'
276+
assert '{}{}'.format(1, 2) == '12'
276277

277-
name = "Eric"
278+
name = 'Eric'
278279
age = 14
279280
assert "My name is {name}, I'm {age}.".format(name=name, age=age) == "My name is Eric, I'm 14."
280281
assert "My name is {A}, I'm {B}.".format(A=name, B=age) == "My name is Eric, I'm 14."
281282
assert "My name is {}, I'm {B}.".format(name, B=age) == "My name is Eric, I'm 14."
282283

284+
bool_var1 = True
285+
bool_var2 = False
286+
assert 'bool: {}, {}'.format(bool_var1, bool_var2) == 'bool: True, False'
287+
288+
def test_format_method_empty_braces() -> None:
289+
name = 'Eric'
290+
age = 14
291+
292+
assert 'Hello, {}!'.format(name) == 'Hello, Eric!'
293+
assert '{}'.format(name) == 'Eric'
294+
assert '{}! Hi!'.format(name) == 'Eric! Hi!'
295+
assert '{}, Hi, {}'.format(name, name) == 'Eric, Hi, Eric'
296+
assert 'Hi! {}'.format(name) == 'Hi! Eric'
297+
assert "Hi, I'm {}. I'm {}.".format(name, age) == "Hi, I'm Eric. I'm 14."
298+
299+
assert '{{}}'.format() == '{}'
300+
assert '{{{{}}}}'.format() == '{{}}'
301+
assert '{{}}{}'.format(name) == '{}Eric'
302+
assert 'Hi! {{{}}}'.format(name) == 'Hi! {Eric}'
303+
assert 'Hi! {{ {}'.format(name) == 'Hi! { Eric'
304+
assert 'Hi! {{ {} }}}}'.format(name) == 'Hi! { Eric }}'
305+
283306
def test_format_method_numbers() -> None:
284307
s = 'int: {0:d}; hex: {0:x}; oct: {0:o}; bin: {0:b}'.format(-233)
285308
assert s == 'int: -233; hex: -e9; oct: -351; bin: -11101001'
@@ -295,6 +318,29 @@ def test_format_method_numbers() -> None:
295318
assert 'negative integer: {}'.format(neg_num) == 'negative integer: -3'
296319
assert 'negative integer: {}'.format(-large_num) == 'negative integer: -36893488147419103232'
297320

321+
large_float = 1.23e30
322+
large_float2 = 1234123412341234123400000000000000000
323+
small_float = 1.23e-20
324+
assert '{}, {}, {}'.format(small_float, large_float, large_float2) == '1.23e-20, 1.23e+30, 1234123412341234123400000000000000000'
325+
nan_num = float('nan')
326+
inf_num = float('inf')
327+
assert '{}, {}'.format(nan_num, inf_num) == 'nan, inf'
328+
329+
def format_args(*args: int) -> str:
330+
return 'x{}y{}'.format(*args)
331+
def format_kwargs(**kwargs: int) -> str:
332+
return 'c{x}d{y}'.format(**kwargs)
333+
def format_args_self(*args: int) -> str:
334+
return '{}'.format(args)
335+
def format_kwargs_self(**kwargs: int) -> str:
336+
return '{}'.format(kwargs)
337+
338+
def test_format_method_args() -> None:
339+
assert format_args(10, 2) == 'x10y2'
340+
assert format_args_self(10, 2) == '(10, 2)'
341+
assert format_kwargs(x=10, y=2) == 'c10d2'
342+
assert format_kwargs(x=10, y=2, z=1) == 'c10d2'
343+
assert format_kwargs_self(x=10, y=2, z=1) == "{'x': 10, 'y': 2, 'z': 1}"
298344

299345
class Point:
300346
def __init__(self, x, y):
@@ -319,7 +365,7 @@ def test_format_method_python_doc() -> None:
319365
assert 'Coordinates: {latitude}, {longitude}'.format(**coord) == 'Coordinates: 37.24N, -115.81W'
320366

321367
# Accessing arguments’ attributes:
322-
assert str(Point(4, 2)) == "Point(4, 2)"
368+
assert str(Point(4, 2)) == 'Point(4, 2)'
323369

324370
# Accessing arguments’ items:
325371
coord2 = (3, 5)
@@ -371,7 +417,7 @@ def test_format_method_python_doc() -> None:
371417
width = 5
372418
tmp_strs = []
373419
for num in range(5,12):
374-
tmp_str = ""
420+
tmp_str = ''
375421
for base in 'dXob':
376422
tmp_str += ('{0:{width}{base}}'.format(num, base=base, width=width))
377423
tmp_strs.append(tmp_str)

0 commit comments

Comments
 (0)