From 6bb26e0187e04acb68db9891e57386ef694c85e4 Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Sun, 2 Jul 2023 18:50:43 -0400 Subject: [PATCH] Misc fixes (#410) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix corner case when typechecking scoped names with static compilation * Undo log * Fix nested loop domination; Minor aestethic fixes * clang-format * Add slice indices() method * Fix overloads with static arguments * Update itertools combinatorics functions * Fix import domination issue (missing stack insert) * Fix itertools * Remove log * Bump version --------- Co-authored-by: Ibrahim Numanagić --- CMakeLists.txt | 4 +- codon/parser/ast/types/link.cpp | 3 +- codon/parser/visitors/simplify/access.cpp | 15 +- codon/parser/visitors/simplify/ctx.cpp | 1 + codon/parser/visitors/simplify/loops.cpp | 3 +- codon/parser/visitors/translate/translate.cpp | 3 +- codon/parser/visitors/typecheck/ctx.cpp | 3 +- codon/parser/visitors/typecheck/ctx.h | 3 + codon/parser/visitors/typecheck/loops.cpp | 2 +- codon/parser/visitors/typecheck/op.cpp | 2 +- codon/parser/visitors/typecheck/typecheck.cpp | 3 +- stdlib/internal/types/collections/list.codon | 12 +- stdlib/internal/types/slice.codon | 8 + stdlib/itertools.codon | 512 ++++++++++++++---- test/core/containers.codon | 165 ++++++ test/stdlib/itertools_test.codon | 261 ++++++++- 16 files changed, 848 insertions(+), 152 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2327df0f..75a55c97 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,10 @@ cmake_minimum_required(VERSION 3.14) project( Codon - VERSION "0.16.1" + VERSION "0.16.2" HOMEPAGE_URL "https://github.com/exaloop/codon" DESCRIPTION "high-performance, extensible Python compiler") -set(CODON_JIT_PYTHON_VERSION "0.1.5") +set(CODON_JIT_PYTHON_VERSION "0.1.6") configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in" "${PROJECT_SOURCE_DIR}/codon/config/config.h") configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in" diff --git a/codon/parser/ast/types/link.cpp b/codon/parser/ast/types/link.cpp index c2ab722a..41913b67 100644 --- a/codon/parser/ast/types/link.cpp +++ b/codon/parser/ast/types/link.cpp @@ -154,7 +154,8 @@ bool LinkType::isInstantiated() const { return kind == Link && type->isInstantia std::string LinkType::debugString(char mode) const { if (kind == Unbound || kind == Generic) { if (mode == 2) { - return fmt::format("{}{}{}", kind == Unbound ? '?' : '#', id, + return fmt::format("{}{}{}{}", genericName.empty() ? "" : genericName + ":", + kind == Unbound ? '?' : '#', id, trait ? ":" + trait->debugString(mode) : ""); } if (trait) diff --git a/codon/parser/visitors/simplify/access.cpp b/codon/parser/visitors/simplify/access.cpp index aec276a1..a1078789 100644 --- a/codon/parser/visitors/simplify/access.cpp +++ b/codon/parser/visitors/simplify/access.cpp @@ -38,11 +38,16 @@ void SimplifyVisitor::visit(IdExpr *expr) { // while True: // if x > 10: break // x = x + 1 # x must be dominated after the loop to ensure that it gets updated - if (auto loop = ctx->getBase()->getLoop()) { - bool inside = val->scope.size() >= loop->scope.size() && - val->scope[loop->scope.size() - 1] == loop->scope.back(); - if (!inside) - loop->seenVars.insert(expr->value); + if (ctx->getBase()->getLoop()) { + for (size_t li = ctx->getBase()->loops.size(); li-- > 0;) { + auto &loop = ctx->getBase()->loops[li]; + bool inside = val->scope.size() >= loop.scope.size() && + val->scope[loop.scope.size() - 1] == loop.scope.back(); + if (!inside) + loop.seenVars.insert(expr->value); + else + break; + } } // Replace the variable with its canonical name diff --git a/codon/parser/visitors/simplify/ctx.cpp b/codon/parser/visitors/simplify/ctx.cpp index fec1499e..35b1112f 100644 --- a/codon/parser/visitors/simplify/ctx.cpp +++ b/codon/parser/visitors/simplify/ctx.cpp @@ -153,6 +153,7 @@ SimplifyContext::Item SimplifyContext::findDominatingBinding(const std::string & (*lastGood)->importPath); item->accessChecked = {(*lastGood)->scope}; lastGood = it->second.insert(++lastGood, item); + stack.front().push_back(name); // Make sure to prepend a binding declaration: `var` and `var__used__ = False` // to the dominating scope. scope.stmts[scope.blocks[prefix - 1]].push_back(std::make_unique( diff --git a/codon/parser/visitors/simplify/loops.cpp b/codon/parser/visitors/simplify/loops.cpp index 99f38fa1..0fdb71c0 100644 --- a/codon/parser/visitors/simplify/loops.cpp +++ b/codon/parser/visitors/simplify/loops.cpp @@ -125,8 +125,9 @@ void SimplifyVisitor::visit(ForStmt *stmt) { ctx->leaveConditionalBlock(&(stmt->suite->getSuite()->stmts)); // Dominate loop variables - for (auto &var : ctx->getBase()->getLoop()->seenVars) + for (auto &var : ctx->getBase()->getLoop()->seenVars) { ctx->findDominatingBinding(var); + } ctx->getBase()->loops.pop_back(); } diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index 5176c0ab..42a7934b 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -284,6 +284,7 @@ void TranslateVisitor::visit(PipeExpr *expr) { simplePipeline &= !isGen(fn); std::vector args; + args.reserve(call->args.size()); for (auto &a : call->args) args.emplace_back(a.value->getEllipsis() ? nullptr : transform(a.value)); stages.emplace_back(fn, args, isGen(fn), false); @@ -642,7 +643,7 @@ void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt ltrim(lp); rtrim(lp); // Extract declares and constants. - if (isDeclare && !startswith(lp, "declare ")) { + if (isDeclare && !startswith(lp, "declare ") && !startswith(lp, "@")) { bool isConst = lp.find("private constant") != std::string::npos; if (!isConst) { isDeclare = false; diff --git a/codon/parser/visitors/typecheck/ctx.cpp b/codon/parser/visitors/typecheck/ctx.cpp index a4698f9d..e1f42746 100644 --- a/codon/parser/visitors/typecheck/ctx.cpp +++ b/codon/parser/visitors/typecheck/ctx.cpp @@ -286,7 +286,8 @@ int TypeContext::reorderNamedArgs(types::FuncType *func, Emsg(Error::CALL_ARGS_MISSING, cache->rev(func->ast->name), cache->reverseIdentifierLookup[func->ast->args[i].name])); } - return score + onDone(starArgIndex, kwstarArgIndex, slots, partial); + auto s = onDone(starArgIndex, kwstarArgIndex, slots, partial); + return s != -1 ? score + s : -1; } void TypeContext::dump(int pad) { diff --git a/codon/parser/visitors/typecheck/ctx.h b/codon/parser/visitors/typecheck/ctx.h index 8021c89d..412ba5aa 100644 --- a/codon/parser/visitors/typecheck/ctx.h +++ b/codon/parser/visitors/typecheck/ctx.h @@ -86,6 +86,9 @@ struct TypeContext : public Context { return item; } std::shared_ptr find(const std::string &name) const override; + std::shared_ptr find(const char *name) const { + return find(std::string(name)); + } /// Find an internal type. Assumes that it exists. std::shared_ptr forceFind(const std::string &name) const; types::TypePtr getType(const std::string &name) const; diff --git a/codon/parser/visitors/typecheck/loops.cpp b/codon/parser/visitors/typecheck/loops.cpp index 435359f8..4fd3819e 100644 --- a/codon/parser/visitors/typecheck/loops.cpp +++ b/codon/parser/visitors/typecheck/loops.cpp @@ -103,7 +103,7 @@ void TypecheckVisitor::visit(ForStmt *stmt) { unify(stmt->var->type, iterType ? unify(val->type, iterType->generics[0].type) : val->type); - ctx->staticLoops.push_back(""); + ctx->staticLoops.emplace_back(); ctx->blockLevel++; transform(stmt->suite); ctx->blockLevel--; diff --git a/codon/parser/visitors/typecheck/op.cpp b/codon/parser/visitors/typecheck/op.cpp index 76a9b1a1..5c66a664 100644 --- a/codon/parser/visitors/typecheck/op.cpp +++ b/codon/parser/visitors/typecheck/op.cpp @@ -310,7 +310,7 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) { } else { if (expr->typeParams[i]->getNone()) // `None` -> `NoneType` transformType(expr->typeParams[i]); - if (!expr->typeParams[i]->isType()) + if (expr->typeParams[i]->type->getClass() && !expr->typeParams[i]->isType()) E(Error::EXPECTED_TYPE, expr->typeParams[i], "type"); t = ctx->instantiate(expr->typeParams[i]->getSrcInfo(), expr->typeParams[i]->getType()); diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index b0231a2a..76c61da1 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -258,8 +258,9 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn, if (slots[si].empty()) { // is this "real" type? if (in(niGenerics, fn->ast->args[si].name) && - !fn->ast->args[si].defaultValue) + !fn->ast->args[si].defaultValue) { return -1; + } reordered.push_back({nullptr, 0}); } else { reordered.push_back({args[slots[si][0]].value->type, slots[si][0]}); diff --git a/stdlib/internal/types/collections/list.codon b/stdlib/internal/types/collections/list.codon index fb96844c..502281d8 100644 --- a/stdlib/internal/types/collections/list.codon +++ b/stdlib/internal/types/collections/list.codon @@ -8,11 +8,9 @@ class List: self.arr = Array[T](10) self.len = 0 - def __init__(self, it: Generator[T]): - self.arr = Array[T](10) + def __init__(self, capacity: int): + self.arr = Array[T](capacity) self.len = 0 - for i in it: - self.append(i) def __init__(self, other: List[T]): self.arr = Array[T](other.len) @@ -20,9 +18,11 @@ class List: for i in other: self.append(i) - def __init__(self, capacity: int): - self.arr = Array[T](capacity) + def __init__(self, it: Generator[T]): + self.arr = Array[T](10) self.len = 0 + for i in it: + self.append(i) def __init__(self, arr: Array[T], len: int): self.arr = arr diff --git a/stdlib/internal/types/slice.codon b/stdlib/internal/types/slice.codon index a5db08d5..06cb4d7a 100644 --- a/stdlib/internal/types/slice.codon +++ b/stdlib/internal/types/slice.codon @@ -6,6 +6,9 @@ class Slice: stop: Optional[int] step: Optional[int] + def __new__(stop: Optional[int]): + return Slice(None, stop, None) + def adjust_indices(self, length: int) -> Tuple[int, int, int, int]: step: int = self.step if self.step is not None else 1 start: int = 0 @@ -47,6 +50,11 @@ class Slice: return start, stop, step, 0 + def indices(self, length: int): + if length < 0: + raise ValueError("length should not be negative") + return self.adjust_indices(length)[:-1] + def __repr__(self): return f"slice({self.start}, {self.stop}, {self.step})" diff --git a/stdlib/itertools.codon b/stdlib/itertools.codon index b5852e61..951799db 100644 --- a/stdlib/itertools.codon +++ b/stdlib/itertools.codon @@ -333,140 +333,426 @@ def zip_longest(*args): # Combinatoric iterators -def combinations(pool: Generator[T], r: int, T: type) -> Generator[List[T]]: - """ - Return successive r-length combinations of elements in the iterable. +def _as_list(x): + if isinstance(x, list): + return x + else: + return list(x) - combinations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3) - """ +def product(*iterables, repeat: int): + if repeat < 0: + raise ValueError("repeat must be non-negative") + + if repeat == 0: + nargs = 0 + else: + nargs = len(iterables) + + npools = nargs * repeat + indices = Ptr[int](npools) - def combinations_helper(pool: List[T], r: int, T: type) -> Generator[List[T]]: - n = len(pool) - if r > n: + pools = list(capacity=npools) + i = 0 + + while i < nargs: + p = _as_list(iterables[i]) + if len(p) == 0: return - indices = list(range(r)) - yield [pool[i] for i in indices] + pools.append(p) + indices[i] = 0 + i += 1 + + while i < npools: + pools.append(pools[i - nargs]) + indices[i] = 0 + i += 1 + + result = list(capacity=npools) + for i in range(npools): + result.append(pools[i][0]) + + while True: + yield result + + result = result.copy() + i = npools - 1 + while i >= 0: + pool = pools[i] + indices[i] += 1 + + if indices[i] == len(pool): + indices[i] = 0 + result[i] = pool[0] + else: + result[i] = pool[indices[i]] + break + + i -= 1 + + if i < 0: + break + +@overload +def product(*iterables, repeat: Static[int] = 1): + if repeat < 0: + compile_error("repeat must be non-negative") + + # handle some common cases + if repeat == 0: + yield () + elif repeat == 1 and staticlen(iterables) == 1: + it0 = iterables[0] + for a in it0: + yield (a,) + elif repeat == 1 and staticlen(iterables) == 2: + it0 = iterables[0] + it1 = iterables[1] + for a in it0: + for b in it1: + yield (a, b) + elif repeat == 1 and staticlen(iterables) == 3: + it0 = iterables[0] + it1 = iterables[1] + it2 = iterables[2] + for a in it0: + for b in it1: + for c in it2: + yield (a, b, c) + else: + nargs: Static[int] = staticlen(iterables) + npools: Static[int] = nargs * repeat + indices_tuple = (0,) * npools + indices = Ptr[int](__ptr__(indices_tuple).as_byte()) + pools = tuple(_as_list(it) for it in iterables) * repeat + + for i in staticrange(nargs): + if len(pools[i]) == 0: + return + + result = tuple(pool[0] for pool in pools) + while True: - b = -1 - for i in reversed(range(r)): - if indices[i] != i + n - r: - b = i + yield result + + i = npools - 1 + while i >= 0: + pool = pools[i] + indices[i] += 1 + + if indices[i] == len(pool): + indices[i] = 0 + else: break - if b == -1: - return - indices[b] += 1 - for j in range(b + 1, r): - indices[j] = indices[j - 1] + 1 - yield [pool[i] for i in indices] + i -= 1 + + if i < 0: + break + + result = tuple(pools[i][indices[i]] for i in staticrange(npools)) + +def combinations(pool, r: int): if r < 0: raise ValueError("r must be non-negative") - if hasattr(pool, "__getitem__") and hasattr(pool, "__len__"): - return combinations_helper(pool, r) + + pool_list = _as_list(pool) + n = len(pool) + + if r > n: + return + + pool = pool_list.arr.ptr + indices = Ptr[int](r) + result = list(capacity=r) + + for i in range(r): + indices[i] = i + result.append(pool[i]) + + while True: + yield result + + i = r - 1 + while i >= 0 and indices[i] == i + n - r: + i -= 1 + + if i < 0: + break + + indices[i] += 1 + + for j in range(i + 1, r): + indices[j] = indices[j-1] + 1 + + result = result.copy() + while i < r: + result[i] = pool[indices[i]] + i += 1 + +@overload +def combinations(pool, r: Static[int]): + def empty(T: type) -> T: + pass + + if r < 0: + compile_error("r must be non-negative") + + if isinstance(pool, list): + pool_list = pool else: - return combinations_helper([a for a in pool], r) + pool_list = list(pool) -def combinations_with_replacement( - pool: Generator[T], r: int, T: type -) -> Generator[List[T]]: - """ - Return successive r-length combinations of elements in the iterable - allowing individual elements to have successive repeats. - """ + n = len(pool) - def combinations_with_replacement_helper( - pool: List[T], r: int, T: type - ) -> Generator[List[T]]: - n = len(pool) - if not n and r: - return - indices = [0 for _ in range(r)] - yield [pool[i] for i in indices] - while True: - b = -1 - for i in reversed(range(r)): - if indices[i] != n - 1: - b = i - break - if b == -1: - return - newval = indices[b] + 1 - for j in range(r - b): - indices[b + j] = newval - yield [pool[i] for i in indices] + if r > n: + return + + pool = pool_list.arr.ptr + indices_tuple = (0,) * r + indices = Ptr[int](__ptr__(indices_tuple).as_byte()) + result_tuple = (empty(pool.T),) * r + result = Ptr[pool.T](__ptr__(result_tuple).as_byte()) + + for i in range(r): + indices[i] = i + result[i] = pool[i] + + while True: + yield result_tuple + + i = r - 1 + while i >= 0 and indices[i] == i + n - r: + i -= 1 + + if i < 0: + break + + indices[i] += 1 + + for j in range(i + 1, r): + indices[j] = indices[j-1] + 1 + while i < r: + result[i] = pool[indices[i]] + i += 1 + +def combinations_with_replacement(pool, r: int): if r < 0: raise ValueError("r must be non-negative") - if hasattr(pool, "__getitem__") and hasattr(pool, "__len__"): - return combinations_with_replacement_helper(pool, r) + + pool_list = _as_list(pool) + n = len(pool) + + if n == 0: + if r == 0: + yield List[pool_list.T](capacity=0) + return + + pool = pool_list.arr.ptr + indices = Ptr[int](r) + result = list(capacity=r) + + for i in range(r): + indices[i] = 0 + result.append(pool[0]) + + while True: + yield result + + i = r - 1 + while i >= 0 and indices[i] == n - 1: + i -= 1 + + if i < 0: + break + + result = result.copy() + index = indices[i] + 1 + elem = pool[index] + + while i < r: + indices[i] = index + result[i] = elem + i += 1 + +@overload +def combinations_with_replacement(pool, r: Static[int]): + def empty(T: type) -> T: + pass + + if r < 0: + compile_error("r must be non-negative") + + if r == 0: + yield () + return + + if isinstance(pool, list): + pool_list = pool else: - return combinations_with_replacement_helper([a for a in pool], r) + pool_list = list(pool) -def permutations( - pool: Generator[T], r: Optional[int] = None, T: type -) -> Generator[List[T]]: - """ - Return successive r-length permutations of elements in the iterable. - """ + n = len(pool) - def permutations_helper( - pool: List[T], r: Optional[int], T: type - ) -> Generator[List[T]]: - n = len(pool) - r: int = r if r is not None else n - if r > n: - return + if n == 0: + return - indices = list(range(n)) - cycles = list(range(n, n - r, -1)) - yield [pool[i] for i in indices[:r]] - while n: - b = -1 - for i in reversed(range(r)): - cycles[i] -= 1 - if cycles[i] == 0: - indices = indices[:i] + indices[i + 1 :] + indices[i : i + 1] - cycles[i] = n - i - else: - b = i - j = cycles[i] - indices[i], indices[-j] = indices[-j], indices[i] - yield [pool[i] for i in indices[:r]] - break - if b == -1: - return + pool = pool_list.arr.ptr + indices_tuple = (0,) * r + indices = Ptr[int](__ptr__(indices_tuple).as_byte()) + result_tuple = (empty(pool.T),) * r + result = Ptr[pool.T](__ptr__(result_tuple).as_byte()) + + for i in range(r): + result[i] = pool[0] + + while True: + yield result_tuple + + i = r - 1 + while i >= 0 and indices[i] == n - 1: + i -= 1 + + if i < 0: + break + + index = indices[i] + 1 + elem = pool[index] + + while i < r: + indices[i] = index + result[i] = elem + i += 1 - if r is not None and r.__val__() < 0: +def _permutations_non_static(pool, r = None): + pool_list = _as_list(pool) + n = len(pool) + + if r is None: + return _permutations_non_static(pool_list, n) + elif not isinstance(r, int): + compile_error("Expected int as r") + + if r < 0: raise ValueError("r must be non-negative") - if hasattr(pool, "__getitem__") and hasattr(pool, "__len__"): - return permutations_helper(pool, r) - else: - return permutations_helper([a for a in pool], r) -@inline -def product(*args): - """ - Cartesian product of input iterables. - """ - if staticlen(args) == 0: - yield () + if r > n: + return + + indices = Ptr[int](n) + cycles = Ptr[int](r) + + for i in range(n): + indices[i] = i + + for i in range(r): + cycles[i] = n - i + + pool = pool_list.arr.ptr + result = list(capacity=r) + + for i in range(r): + result.append(pool[i]) + + while True: + yield result + + if n == 0: + break + + result = result.copy() + i = r - 1 + while i >= 0: + cycles[i] -= 1 + if cycles[i] == 0: + index = indices[i] + for j in range(i, n - 1): + indices[j] = indices[j+1] + indices[n-1] = index + cycles[i] = n - i + else: + j = cycles[i] + index = indices[i] + indices[i] = indices[n - j] + indices[n - j] = index + + for k in range(i, r): + index = indices[k] + result[k] = pool[index] + + break + i -= 1 + + if i < 0: + break + +def _permutations_static(pool, r: Static[int]): + def empty(T: type) -> T: + pass + + pool_list = _as_list(pool) + n = len(pool) + + if r < 0: + raise compile_error("r must be non-negative") + + if r > n: + return + + indices = Ptr[int](n) + cycles_tuple = (0,) * r + cycles = Ptr[int](__ptr__(cycles_tuple).as_byte()) + + for i in range(n): + indices[i] = i + + for i in range(r): + cycles[i] = n - i + + pool = pool_list.arr.ptr + result_tuple = (empty(pool.T),) * r + result = Ptr[pool.T](__ptr__(result_tuple).as_byte()) + + for i in range(r): + result[i] = pool[i] + + while True: + yield result_tuple + + if n == 0: + break + + i = r - 1 + while i >= 0: + cycles[i] -= 1 + if cycles[i] == 0: + index = indices[i] + for j in range(i, n - 1): + indices[j] = indices[j+1] + indices[n-1] = index + cycles[i] = n - i + else: + j = cycles[i] + index = indices[i] + indices[i] = indices[n - j] + indices[n - j] = index + + for k in range(i, r): + index = indices[k] + result[k] = pool[index] + + break + i -= 1 + + if i < 0: + break + +def permutations(pool, r = None): + if isinstance(pool, Tuple) and r is None: + return _permutations_static(pool, staticlen(pool)) else: - for a in args[0]: - rest = args[1:] - for b in product(*rest): - yield (a, *b) + return _permutations_non_static(pool, r) -@inline @overload -def product(*args, repeat: int): - """ - Cartesian product of input iterables. - """ - if repeat < 0: - raise ValueError("repeat argument cannot be negative") - pools = [list(pool) for _ in range(repeat) for pool in args] - result = [List[type(pools[0][0])]()] - for pool in pools: - result = [x + [y] for x in result for y in pool] - for prod in result: - yield prod +def permutations(pool, r: Static[int]): + return _permutations_static(pool, r) diff --git a/test/core/containers.codon b/test/core/containers.codon index 4eb88b62..6aff7cf6 100644 --- a/test/core/containers.codon +++ b/test/core/containers.codon @@ -571,6 +571,171 @@ def test_dict(): assert repr(Dict[int,int]()) == '{}' test_dict() +def slice_indices(slice, length): + """ + Reference implementation for the slice.indices method. + + """ + # Compute step and length as integers. + #length = operator.index(length) + step: int = 1 if slice.step is None else slice.step + + # Raise ValueError for negative length or zero step. + if length < 0: + raise ValueError("length should not be negative") + if step == 0: + raise ValueError("slice step cannot be zero") + + # Find lower and upper bounds for start and stop. + lower = -1 if step < 0 else 0 + upper = length - 1 if step < 0 else length + + # Compute start. + if slice.start is None: + start = upper if step < 0 else lower + else: + start = slice.start + start = max(start + length, lower) if start < 0 else min(start, upper) + + # Compute stop. + if slice.stop is None: + stop = lower if step < 0 else upper + else: + stop = slice.stop + stop = max(stop + length, lower) if stop < 0 else min(stop, upper) + + return start, stop, step + +def check_indices(slice, length): + err1 = False + err2 = False + + try: + actual = slice.indices(length) + except ValueError: + err1 = True + + try: + expected = slice_indices(slice, length) + except ValueError: + err2 = True + + if err1 or err2: + return err1 and err2 + + if actual != expected: + return False + + if length >= 0 and slice.step != 0: + actual = range(*slice.indices(length)) + expected = range(length)[slice] + if actual != expected: + return False + + return True + +@test +def test_slice(): + assert repr(slice(1, 2, 3)) == 'slice(1, 2, 3)' + + s1 = slice(1, 2, 3) + s2 = slice(1, 2, 3) + s3 = slice(1, 2, 4) + + assert s1 == s2 + assert s1 != s3 + + s = slice(1) + assert s.start == None + assert s.stop == 1 + assert s.step == None + + s = slice(1, 2) + assert s.start == 1 + assert s.stop == 2 + assert s.step == None + + s = slice(1, 2, 3) + assert s.start == 1 + assert s.stop == 2 + assert s.step == 3 + + # TODO + assert slice(None ).indices(10) == (0, 10, 1) + assert slice(None, None, 2).indices(10) == (0, 10, 2) + assert slice(1, None, 2).indices(10) == (1, 10, 2) + assert slice(None, None, -1).indices(10) == (9, -1, -1) + assert slice(None, None, -2).indices(10) == (9, -1, -2) + assert slice(3, None, -2).indices(10) == (3, -1, -2) + # issue 3004 tests + assert slice(None, -9).indices(10) == (0, 1, 1) + assert slice(None, -10).indices(10) == (0, 0, 1) + assert slice(None, -11).indices(10) == (0, 0, 1) + assert slice(None, -10, -1).indices(10) == (9, 0, -1) + assert slice(None, -11, -1).indices(10) == (9, -1, -1) + assert slice(None, -12, -1).indices(10) == (9, -1, -1) + assert slice(None, 9).indices(10) == (0, 9, 1) + assert slice(None, 10).indices(10) == (0, 10, 1) + assert slice(None, 11).indices(10) == (0, 10, 1) + assert slice(None, 8, -1).indices(10) == (9, 8, -1) + assert slice(None, 9, -1).indices(10) == (9, 9, -1) + assert slice(None, 10, -1).indices(10) == (9, 9, -1) + + assert slice(-100, 100 ).indices(10) == slice(None).indices(10) + + assert slice(100, -100, -1).indices(10) == slice(None, None, -1).indices(10) + + assert slice(-100, 100, 2).indices(10) == (0, 10, 2) + + import sys + assert list(range(10))[::sys.maxsize - 1] == [0] + + # Check a variety of start, stop, step and length values, including + # values exceeding sys.maxsize (see issue #14794). + vals = [None, -2**100, -2**30, -53, -7, -1, 0, 1, 7, 53, 2**30, 2**100] + lengths = [0, 1, 7, 53, 2**30, 2**100] + #for slice_args in itertools.product(vals, repeat=3): + for a in vals: + for b in vals: + for c in vals: + slice_args = (a, b, c) + s = slice(*slice_args) + for length in lengths: + assert check_indices(s, length) + assert check_indices(slice(0, 10, 1), -3) + + # Negative length should raise ValueError + try: + slice(None).indices(-1) + assert False + except ValueError: + pass + + # Zero step should raise ValueError + try: + slice(0, 10, 0).indices(5) + except ValueError: + pass + + # ... but it should be fine to use a custom class that provides index. + assert slice(0, 10, 1).indices(5) == (0, 5, 1) + ''' # not yet supported in Codon + assert slice(MyIndexable(0), 10, 1).indices(5) == (0, 5, 1) + assert slice(0, MyIndexable(10), 1).indices(5) == (0, 5, 1) + assert slice(0, 10, MyIndexable(1)).indices(5) == (0, 5, 1) + assert slice(0, 10, 1).indices(MyIndexable(5)) == (0, 5, 1) + ''' + tmp = [] + class X[T](object): + tmp: T + def __setitem__(self, i, k): + self.tmp.append((i, k)) + + x = X(tmp) + x[1:2] = 42 + assert tmp == [(slice(1, 2), 42)] +test_slice() + @test def test_deque(): from collections import deque diff --git a/test/stdlib/itertools_test.codon b/test/stdlib/itertools_test.codon index 081f1c4a..001dc75e 100644 --- a/test/stdlib/itertools_test.codon +++ b/test/stdlib/itertools_test.codon @@ -61,7 +61,9 @@ def underten(x): @test def test_combinations(): - assert list(itertools.combinations("ABCD", 2)) == [ + f = lambda x: x # hack to get non-static argument + + assert list(itertools.combinations("ABCD", f(2))) == [ ["A", "B"], ["A", "C"], ["A", "D"], @@ -69,7 +71,7 @@ def test_combinations(): ["B", "D"], ["C", "D"], ] - test_intermediate = itertools.combinations("ABCD", 2) + test_intermediate = itertools.combinations("ABCD", f(2)) next(test_intermediate) assert list(test_intermediate) == [ ["A", "C"], @@ -78,20 +80,49 @@ def test_combinations(): ["B", "D"], ["C", "D"], ] - assert list(itertools.combinations(range(4), 3)) == [ + assert list(itertools.combinations(range(4), f(3))) == [ [0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3], ] - test_intermediate = itertools.combinations(range(4), 3) + test_intermediate = itertools.combinations(range(4), f(3)) next(test_intermediate) assert list(test_intermediate) == [[0, 1, 3], [0, 2, 3], [1, 2, 3]] + assert list(itertools.combinations("ABCD", 2)) == [ + ("A", "B"), + ("A", "C"), + ("A", "D"), + ("B", "C"), + ("B", "D"), + ("C", "D"), + ] + test_intermediate = itertools.combinations("ABCD", 2) + next(test_intermediate) + assert list(test_intermediate) == [ + ("A", "C"), + ("A", "D"), + ("B", "C"), + ("B", "D"), + ("C", "D"), + ] + assert list(itertools.combinations(range(4), 3)) == [ + (0, 1, 2), + (0, 1, 3), + (0, 2, 3), + (1, 2, 3), + ] + test_intermediate = itertools.combinations(range(4), 3) + next(test_intermediate) + assert list(test_intermediate) == [(0, 1, 3), (0, 2, 3), (1, 2, 3)] + @test def test_combinations_with_replacement(): - assert list(itertools.combinations_with_replacement(range(3), 3)) == [ + f = lambda x: x # hack to get non-static argument + + assert list(itertools.combinations_with_replacement(range(3), f(3))) == [ [0, 0, 0], [0, 0, 1], [0, 0, 2], @@ -103,7 +134,7 @@ def test_combinations_with_replacement(): [1, 2, 2], [2, 2, 2], ] - assert list(itertools.combinations_with_replacement("ABC", 2)) == [ + assert list(itertools.combinations_with_replacement("ABC", f(2))) == [ ["A", "A"], ["A", "B"], ["A", "C"], @@ -111,7 +142,7 @@ def test_combinations_with_replacement(): ["B", "C"], ["C", "C"], ] - test_intermediate = itertools.combinations_with_replacement("ABC", 2) + test_intermediate = itertools.combinations_with_replacement("ABC", f(2)) next(test_intermediate) assert list(test_intermediate) == [ ["A", "B"], @@ -121,6 +152,35 @@ def test_combinations_with_replacement(): ["C", "C"], ] + assert list(itertools.combinations_with_replacement(range(3), 3)) == [ + (0, 0, 0), + (0, 0, 1), + (0, 0, 2), + (0, 1, 1), + (0, 1, 2), + (0, 2, 2), + (1, 1, 1), + (1, 1, 2), + (1, 2, 2), + (2, 2, 2), + ] + assert list(itertools.combinations_with_replacement("ABC", 2)) == [ + ("A", "A"), + ("A", "B"), + ("A", "C"), + ("B", "B"), + ("B", "C"), + ("C", "C"), + ] + test_intermediate = itertools.combinations_with_replacement("ABC", 2) + next(test_intermediate) + assert list(test_intermediate) == [ + ("A", "B"), + ("A", "C"), + ("B", "B"), + ("B", "C"), + ("C", "C"), + ] @test def test_islice(): @@ -243,7 +303,9 @@ def test_filterfalse(): @test def test_permutations(): - assert list(itertools.permutations(range(3), 2)) == [ + f = lambda x: x # hack to get non-static argument + + assert list(itertools.permutations(range(3), f(2))) == [ [0, 1], [0, 2], [1, 0], @@ -255,6 +317,24 @@ def test_permutations(): for n in range(3): values = [5 * x - 12 for x in range(n)] for r in range(n + 2): + result = list(itertools.permutations(values, f(r))) + if r > n: # right number of perms + assert len(result) == 0 + # factorial is not yet implemented in math + # else: fact(n) / fact(n - r) + + assert list(itertools.permutations(range(3), 2)) == [ + (0, 1), + (0, 2), + (1, 0), + (1, 2), + (2, 0), + (2, 1), + ] + + for n in staticrange(3): + values = [5 * x - 12 for x in range(n)] + for r in staticrange(n + 2): result = list(itertools.permutations(values, r)) if r > n: # right number of perms assert len(result) == 0 @@ -487,18 +567,19 @@ test_chain_from_iterable_from_cpython() @test def test_combinations_from_cpython(): + f = lambda x: x # hack to get non-static argument from math import factorial as fact err = False try: - list(combinations("abc", -2)) + list(combinations("abc", f(-2))) assert False except ValueError: err = True assert err - assert list(combinations("abc", 32)) == [] # r > n - assert list(combinations("ABCD", 2)) == [ + assert list(combinations("abc", f(32))) == [] # r > n + assert list(combinations("ABCD", f(2))) == [ ["A", "B"], ["A", "C"], ["A", "D"], @@ -506,7 +587,7 @@ def test_combinations_from_cpython(): ["B", "D"], ["C", "D"], ] - assert list(combinations(range(4), 3)) == [ + assert list(combinations(range(4), f(3))) == [ [0, 1, 2], [0, 1, 3], [0, 2, 3], @@ -516,7 +597,7 @@ def test_combinations_from_cpython(): for n in range(7): values = [5 * x - 12 for x in range(n)] for r in range(n + 2): - result = list(combinations(values, r)) + result = list(combinations(values, f(r))) assert len(result) == (0 if r > n else fact(n) // fact(r) // fact(n - r)) assert len(result) == len(set(result)) # no repeats @@ -531,21 +612,55 @@ def test_combinations_from_cpython(): ] # comb is a subsequence of the input iterable + assert list(combinations("abc", 32)) == [] # r > n + assert list(combinations("ABCD", 2)) == [ + ("A", "B"), + ("A", "C"), + ("A", "D"), + ("B", "C"), + ("B", "D"), + ("C", "D"), + ] + assert list(combinations(range(4), 3)) == [ + (0, 1, 2), + (0, 1, 3), + (0, 2, 3), + (1, 2, 3), + ] + + for n in staticrange(7): + values = [5 * x - 12 for x in range(n)] + for r in staticrange(n + 2): + result = list(combinations(values, r)) + + assert len(result) == (0 if r > n else fact(n) // fact(r) // fact(n - r)) + assert len(result) == len(set(result)) # no repeats + # assert result == sorted(result) # lexicographic order + for c in result: + assert len(c) == r # r-length combinations + assert len(set(c)) == r # no duplicate elements + assert list(c) == sorted(c) # keep original ordering + assert all(e in values for e in c) # elements taken from input iterable + assert list(c) == [ + e for e in values if e in c + ] # comb is a subsequence of the input iterable + test_combinations_from_cpython() @test def test_combinations_with_replacement_from_cpython(): + f = lambda x: x # hack to get non-static argument cwr = combinations_with_replacement err = False try: - list(cwr("abc", -2)) + list(combinations_with_replacement("abc", f(-2))) assert False except ValueError: err = True assert err - assert list(cwr("ABC", 2)) == [ + assert list(combinations_with_replacement("ABC", f(2))) == [ ["A", "A"], ["A", "B"], ["A", "C"], @@ -564,7 +679,44 @@ def test_combinations_with_replacement_from_cpython(): for n in range(7): values = [5 * x - 12 for x in range(n)] for r in range(n + 2): - result = list(cwr(values, r)) + result = list(combinations_with_replacement(values, r)) + regular_combs = list(combinations(values, r)) + + assert len(result) == numcombs(n, r) + assert len(result) == len(set(result)) # no repeats + # assert result == sorted(result) # lexicographic order + + if n == 0 or r <= 1: + assert result == regular_combs # cases that should be identical + else: + assert set(result) >= set(regular_combs) + + for c in result: + assert len(c) == r # r-length combinations + noruns = [k for k, v in groupby(c)] # combo without consecutive repeats + assert len(noruns) == len( + set(noruns) + ) # no repeats other than consecutive + assert list(c) == sorted(c) # keep original ordering + assert all(e in values for e in c) # elements taken from input iterable + assert noruns == [ + e for e in values if e in c + ] # comb is a subsequence of the input iterable + + + assert list(combinations_with_replacement("ABC", 2)) == [ + ("A", "A"), + ("A", "B"), + ("A", "C"), + ("B", "B"), + ("B", "C"), + ("C", "C"), + ] + + for n in staticrange(7): + values = [5 * x - 12 for x in range(n)] + for r in staticrange(n + 2): + result = list(combinations_with_replacement(values, r)) regular_combs = list(combinations(values, r)) assert len(result) == numcombs(n, r) @@ -594,18 +746,19 @@ test_combinations_with_replacement_from_cpython() @test def test_permutations_from_cpython(): + f = lambda x: x # hack to get non-static argument from math import factorial as fact err = False try: - list(permutations("abc", -2)) + list(permutations("abc", f(-2))) assert False except ValueError: err = True assert err - assert list(permutations("abc", 32)) == [] - assert list(permutations(range(3), 2)) == [ + assert list(permutations("abc", f(32))) == [] + assert list(permutations(range(3), f(2))) == [ [0, 1], [0, 2], [1, 0], @@ -632,6 +785,33 @@ def test_permutations_from_cpython(): assert result == list(permutations(values, None)) # test r as None assert result == list(permutations(values)) # test default r + assert list(permutations("abc", 32)) == [] + assert list(permutations(range(3), 2)) == [ + (0, 1), + (0, 2), + (1, 0), + (1, 2), + (2, 0), + (2, 1), + ] + + for n in staticrange(7): + values = [5 * x - 12 for x in range(n)] + for r in staticrange(n + 2): + result = list(permutations(values, r)) + assert len(result) == ( + 0 if r > n else fact(n) // fact(n - r) + ) # right number of perms + assert len(result) == len(set(result)) # no repeats + # assert result == sorted(result) # lexicographic order + for p in result: + assert len(p) == r # r-length permutations + assert len(set(p)) == r # no duplicate elements + assert all(e in values for e in p) # elements taken from input iterable + + if r == n: + assert result == list(permutations(values, r)) + test_permutations_from_cpython() @@ -728,6 +908,49 @@ def test_combinatorics_from_cpython(): ) # comb: cwr that is a perm assert comb == sorted(set(cwr) & set(perm)) # comb: both a cwr and a perm + for n in staticrange(6): + s = "ABCDEFG"[:n] + for r in staticrange(8): + prod = list(product(s, repeat=r)) + cwr = list(combinations_with_replacement(s, r)) + perm = list(permutations(s, r)) + comb = list(combinations(s, r)) + + # Check size + assert len(prod) == n ** r + assert len(cwr) == ( + (fact(n + r - 1) // fact(r) // fact(n - 1)) if n else (0 if r else 1) + ) + assert len(perm) == (0 if r > n else fact(n) // fact(n - r)) + assert len(comb) == (0 if r > n else fact(n) // fact(r) // fact(n - r)) + + # Check lexicographic order without repeated tuples + assert prod == sorted(set(prod)) + assert cwr == sorted(set(cwr)) + assert perm == sorted(set(perm)) + assert comb == sorted(set(comb)) + + # Check interrelationships + assert cwr == [ + t for t in prod if sorted(t) == list(t) + ] # cwr: prods which are sorted + assert perm == [ + t for t in prod if len(set(t)) == r + ] # perm: prods with no dups + assert comb == [ + t for t in perm if sorted(t) == list(t) + ] # comb: perms that are sorted + assert comb == [ + t for t in cwr if len(set(t)) == r + ] # comb: cwrs without dups + assert comb == list( + filter(set(cwr).__contains__, perm) + ) # comb: perm that is a cwr + assert comb == list( + filter(set(perm).__contains__, cwr) + ) # comb: cwr that is a perm + assert comb == sorted(set(cwr) & set(perm)) # comb: both a cwr and a perm + test_combinatorics_from_cpython()