diff --git a/typed_python/compiler/tests/string_compilation_test.py b/typed_python/compiler/tests/string_compilation_test.py index 7c4771e38..90a27ebbf 100644 --- a/typed_python/compiler/tests/string_compilation_test.py +++ b/typed_python/compiler/tests/string_compilation_test.py @@ -186,6 +186,31 @@ def endswith(x: str, y: str): self.assertEqual(startswith(s1, s2), compiledSW(s1, s2)) self.assertEqual(endswith(s1, s2), compiledEW(s1, s2)) + def test_string_replace(self): + def replace(x: str, y: str, z: str): + return x.replace(y, z) + + replaceCompiled = Compiled(replace) + + def replace2(x: str, y: str, z: str, i: int): + return x.replace(y, z, i) + + replaceCompiled = Compiled(replace) + replace2Compiled = Compiled(replace2) + + strings = [""] + for _ in range(6): + for s in ["ab"]: + strings = [x + s for x in strings] + + for s1 in strings: + for s2 in strings: + for s3 in strings: + self.assertEqual(replace(s1, s2, s3), replaceCompiled(s1, s2, s3)) + + for i in [-1, 0, 1, 2]: + self.assertEqual(replace2(s1, s2, s3, i), replace2Compiled(s1, s2, s3, i)) + def test_string_getitem_slice(self): def getitem1(x: str, y: int): return x[:y] diff --git a/typed_python/compiler/type_wrappers/string_wrapper.py b/typed_python/compiler/type_wrappers/string_wrapper.py index 9bc4931cc..b16ecc665 100644 --- a/typed_python/compiler/type_wrappers/string_wrapper.py +++ b/typed_python/compiler/type_wrappers/string_wrapper.py @@ -59,6 +59,35 @@ def strEndswith(s, suffix): return s[-len(suffix):] == suffix +def strReplace(s, old, new, maxCount): + if maxCount == 0: + return s + + accumulator = ListOf(str)() + + pos = 0 + seen = 0 + + while True: + if maxCount >= 0 and seen >= maxCount: + nextLoc = -1 + else: + nextLoc = s.find(old, pos) + + if nextLoc >= 0: + accumulator.append(s[pos:nextLoc]) + + if len(old): + pos = nextLoc + len(old) + else: + pos += 1 + + seen += 1 + else: + accumulator.append(s[pos:]) + return new.join(accumulator) + + class StringWrapper(RefcountedWrapper): is_pod = False is_empty = False @@ -287,7 +316,7 @@ def constant(self, context, s): def convert_attribute(self, context, instance, attr): if ( - attr in ("find", "split", "join", 'strip', 'rstrip', 'lstrip', "startswith", "endswith") + attr in ("find", "split", "join", 'strip', 'rstrip', 'lstrip', "startswith", "endswith", "replace") or attr in self._str_methods or attr in self._bool_methods ): @@ -296,7 +325,7 @@ def convert_attribute(self, context, instance, attr): return super().convert_attribute(context, instance, attr) def convert_method_call(self, context, instance, methodname, args, kwargs): - if not (methodname in ("find", "split", "join", 'strip', 'rstrip', 'lstrip', "startswith", "endswith") + if not (methodname in ("find", "split", "join", 'strip', 'rstrip', 'lstrip', "startswith", "endswith", "replace") or methodname in self._str_methods or methodname in self._bool_methods): return context.pushException(AttributeError, methodname) @@ -354,12 +383,34 @@ def convert_method_call(self, context, instance, methodname, args, kwargs): if args[0].expr_type != self: context.pushException( TypeError, - "startswith first arg must be str (tuple of str not supported yet)" + "endswith first arg must be str (tuple of str not supported yet)" ) return return context.call_py_function(strEndswith, (instance, args[0]), {}) + elif methodname == "replace": + if len(args) in [2, 3]: + for i in [0, 1]: + if args[i].expr_type != self: + context.pushException( + TypeError, + f"replace() argument {i + 1} must be str" + ) + return + + if len(args) == 3 and args[2].expr_type.typeRepresentation != Int64: + context.pushException( + TypeError, + f"replace() argument 3 must be int, not {args[2].expr_type.typeRepresentation}" + ) + return + + if len(args) == 2: + return context.call_py_function(strReplace, (instance, args[0], args[1], context.constant(-1)), {}) + else: + return context.call_py_function(strReplace, (instance, args[0], args[1], args[2]), {}) + elif methodname == "find": if len(args) == 1: return context.push(