diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 786ef4ca..ba70b777 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -172,6 +172,7 @@ jobs: ln -s build/libcodonrt.${LIBEXT} . build/codon_test test/app/test.sh build + (cd test/python && CODON_DIR=$(pwd)/../../codon-deploy python3 setup.py build_ext --inplace && python3 pyext.py) env: CODON_PATH: ./stdlib PYTHONPATH: .:./test/python @@ -232,7 +233,7 @@ jobs: path: codon-linux-x86_64.tar.gz - name: Publish on TestPyPI - if: startsWith(matrix.os, 'ubuntu') + if: github.ref == 'refs/heads/develop' && startsWith(matrix.os, 'ubuntu') uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ diff --git a/.gitignore b/.gitignore index a10c998a..d0cd885e 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ build/ build_*/ install/ +install_*/ extra/python/src/jit.cpp extra/jupyter/build/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 0a06e6f1..66bd526d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,16 +1,15 @@ cmake_minimum_required(VERSION 3.14) project( Codon - VERSION "0.15.5" + VERSION "0.16.0" HOMEPAGE_URL "https://github.com/exaloop/codon" DESCRIPTION "high-performance, extensible Python compiler") -set(CODON_JIT_PYTHON_VERSION "0.1.3") +set(CODON_JIT_PYTHON_VERSION "0.1.4") configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in" "${PROJECT_SOURCE_DIR}/codon/config/config.h") configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in" "${PROJECT_SOURCE_DIR}/extra/python/codon/version.py") -option(CODON_JUPYTER "build Codon Jupyter server" OFF) option(CODON_GPU "build Codon GPU backend" OFF) set(CMAKE_CXX_STANDARD 17) @@ -69,6 +68,10 @@ add_custom_command( omp DEPENDS peg2cpp codon/parser/peg/openmp.peg) +# Codon Jupyter library +set(CODON_JUPYTER_FILES codon/util/jupyter.h codon/util/jupyter.cpp) +add_library(codon_jupyter SHARED ${CODON_JUPYTER_FILES}) + # Codon runtime library set(CODONRT_FILES codon/runtime/lib.h codon/runtime/lib.cpp codon/runtime/re.cpp codon/runtime/exc.cpp @@ -181,6 +184,7 @@ set(CODON_HPPFILES codon/cir/llvm/llvm.h codon/cir/llvm/optimize.h codon/cir/module.h + codon/cir/pyextension.h codon/cir/cir.h codon/cir/transform/cleanup/canonical.h codon/cir/transform/cleanup/dead_code.h @@ -218,7 +222,6 @@ set(CODON_HPPFILES codon/cir/value.h codon/cir/var.h codon/util/common.h - extra/jupyter/jupyter.h codon/compiler/jit_extern.h) set(CODON_CPPFILES codon/compiler/compiler.cpp @@ -320,16 +323,10 @@ set(CODON_CPPFILES codon/cir/util/visitor.cpp codon/cir/value.cpp codon/cir/var.cpp - codon/util/common.cpp - extra/jupyter/jupyter.cpp) + codon/util/common.cpp) add_library(codonc SHARED ${CODON_HPPFILES}) target_include_directories(codonc PRIVATE ${peglib_SOURCE_DIR} ${toml_SOURCE_DIR}/include ${semver_SOURCE_DIR}/include) target_sources(codonc PRIVATE ${CODON_CPPFILES} codon_rules.cpp omp_rules.cpp) -if(CODON_JUPYTER) - add_compile_definitions(CODON_JUPYTER) - add_dependencies(codonc xeus-static nlohmann_json) - target_link_libraries(codonc PRIVATE xeus-static) -endif() if(ASAN) target_compile_options( codonc PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address" @@ -427,7 +424,7 @@ endif() # Codon command-line tool add_executable(codon codon/app/main.cpp) -target_link_libraries(codon PUBLIC ${STATIC_LIBCPP} fmt codonc Threads::Threads) +target_link_libraries(codon PUBLIC ${STATIC_LIBCPP} fmt codonc codon_jupyter Threads::Threads) # Codon test Download and unpack googletest at configure time include(FetchContent) @@ -463,7 +460,7 @@ target_link_libraries(codon_test fmt codonc codonrt gtest_main) target_compile_definitions(codon_test PRIVATE TEST_DIR="${CMAKE_CURRENT_SOURCE_DIR}/test") -install(TARGETS codonrt codonc DESTINATION lib/codon) +install(TARGETS codonrt codonc codon_jupyter DESTINATION lib/codon) install(FILES ${CMAKE_BINARY_DIR}/libomp${CMAKE_SHARED_LIBRARY_SUFFIX} DESTINATION lib/codon) install(TARGETS codon DESTINATION bin) install(DIRECTORY ${CMAKE_BINARY_DIR}/include/codon DESTINATION include) diff --git a/cmake/deps.cmake b/cmake/deps.cmake index 172bb291..90687bb6 100644 --- a/cmake/deps.cmake +++ b/cmake/deps.cmake @@ -169,50 +169,3 @@ if(APPLE AND APPLE_ARM) "LIBUNWIND_ENABLE_SHARED ON" "LIBUNWIND_INCLUDE_DOCS OFF") endif() - -if(CODON_JUPYTER) - CPMAddPackage( - NAME libzmq - VERSION 4.3.4 - URL https://github.com/zeromq/libzmq/releases/download/v4.3.4/zeromq-4.3.4.tar.gz - EXCLUDE_FROM_ALL YES - OPTIONS "WITH_PERF_TOOL OFF" - "ZMQ_BUILD_TESTS OFF" - "ENABLE_CPACK OFF" - "BUILD_SHARED ON" - "WITH_LIBSODIUM OFF" - "WITH_TLS OFF") - CPMAddPackage( - NAME cppzmq - URL https://github.com/zeromq/cppzmq/archive/refs/tags/v4.8.1.tar.gz - VERSION 4.8.1 - EXCLUDE_FROM_ALL YES - OPTIONS "CPPZMQ_BUILD_TESTS OFF") - CPMAddPackage( - NAME xtl - GITHUB_REPOSITORY "xtensor-stack/xtl" - VERSION 0.7.3 - GIT_TAG 0.7.3 - EXCLUDE_FROM_ALL YES - OPTIONS "BUILD_TESTS OFF") - CPMAddPackage( - NAME json - GITHUB_REPOSITORY "nlohmann/json" - VERSION 3.10.1) - CPMAddPackage( - NAME xeus - GITHUB_REPOSITORY "jupyter-xeus/xeus" - VERSION 2.2.0 - GIT_TAG 2.2.0 - EXCLUDE_FROM_ALL YES - PATCH_COMMAND patch -N -u CMakeLists.txt -b ${CMAKE_SOURCE_DIR}/cmake/xeus.patch || true - OPTIONS "BUILD_EXAMPLES OFF" - "XEUS_BUILD_SHARED_LIBS OFF" - "XEUS_STATIC_DEPENDENCIES ON" - "CMAKE_POSITION_INDEPENDENT_CODE ON" - "XEUS_DISABLE_ARCH_NATIVE ON" - "XEUS_USE_DYNAMIC_UUID ${XEUS_USE_DYNAMIC_UUID}") - if (xeus_ADDED) - install(TARGETS nlohmann_json EXPORT xeus-targets) - endif() -endif() diff --git a/codon/app/main.cpp b/codon/app/main.cpp index 90adba77..64b8255d 100644 --- a/codon/app/main.cpp +++ b/codon/app/main.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -14,7 +15,9 @@ #include "codon/compiler/error.h" #include "codon/compiler/jit.h" #include "codon/util/common.h" +#include "codon/util/jupyter.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileSystem.h" namespace { void versMsg(llvm::raw_ostream &out) { @@ -83,7 +86,7 @@ void initLogFlags(const llvm::cl::opt &log) { codon::getLogger().parse(std::string(d)); } -enum BuildKind { LLVM, Bitcode, Object, Executable, Library, Detect }; +enum BuildKind { LLVM, Bitcode, Object, Executable, Library, PyExtension, Detect }; enum OptMode { Debug, Release }; enum Numerics { C, Python }; } // namespace @@ -109,8 +112,9 @@ int docMode(const std::vector &args, const std::string &argv0) { return EXIT_SUCCESS; } -std::unique_ptr processSource(const std::vector &args, - bool standalone) { +std::unique_ptr processSource( + const std::vector &args, bool standalone, + std::function pyExtension = [] { return false; }) { llvm::cl::opt input(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-")); auto regs = llvm::cl::getRegisteredOptions(); @@ -163,9 +167,9 @@ std::unique_ptr processSource(const std::vector & const bool isDebug = (optMode == OptMode::Debug); std::vector disabledOptsVec(disabledOpts); - auto compiler = std::make_unique(args[0], isDebug, disabledOptsVec, - /*isTest=*/false, - (numerics == Numerics::Python)); + auto compiler = std::make_unique( + args[0], isDebug, disabledOptsVec, + /*isTest=*/false, (numerics == Numerics::Python), pyExtension()); compiler->getLLVMVisitor()->setStandalone(standalone); // load plugins @@ -296,21 +300,27 @@ int buildMode(const std::vector &args, const std::string &argv0) { llvm::cl::desc("Pass given flags to linker")); llvm::cl::opt buildKind( llvm::cl::desc("output type"), - llvm::cl::values(clEnumValN(LLVM, "llvm", "Generate LLVM IR"), - clEnumValN(Bitcode, "bc", "Generate LLVM bitcode"), - clEnumValN(Object, "obj", "Generate native object file"), - clEnumValN(Executable, "exe", "Generate executable"), - clEnumValN(Library, "lib", "Generate shared library"), - clEnumValN(Detect, "detect", - "Detect output type based on output file extension")), + llvm::cl::values( + clEnumValN(LLVM, "llvm", "Generate LLVM IR"), + clEnumValN(Bitcode, "bc", "Generate LLVM bitcode"), + clEnumValN(Object, "obj", "Generate native object file"), + clEnumValN(Executable, "exe", "Generate executable"), + clEnumValN(Library, "lib", "Generate shared library"), + clEnumValN(PyExtension, "pyext", "Generate Python extension module"), + clEnumValN(Detect, "detect", + "Detect output type based on output file extension")), llvm::cl::init(Detect)); llvm::cl::opt output( "o", llvm::cl::desc( "Write compiled output to specified file. Supported extensions: " "none (executable), .o (object file), .ll (LLVM IR), .bc (LLVM bitcode)")); + llvm::cl::opt pyModule( + "module", llvm::cl::desc("Python extension module name (only applicable when " + "building Python extension module)")); - auto compiler = processSource(args, /*standalone=*/true); + auto compiler = processSource(args, /*standalone=*/true, + [&] { return buildKind == BuildKind::PyExtension; }); if (!compiler) return EXIT_FAILURE; std::vector libsVec(libs); @@ -326,6 +336,7 @@ int buildMode(const std::vector &args, const std::string &argv0) { extension = ".bc"; break; case BuildKind::Object: + case BuildKind::PyExtension: extension = ".o"; break; case BuildKind::Library: @@ -358,6 +369,12 @@ int buildMode(const std::vector &args, const std::string &argv0) { compiler->getLLVMVisitor()->writeToExecutable(filename, argv0, true, libsVec, lflags); break; + case BuildKind::PyExtension: + compiler->getCache()->pyModule->name = + pyModule.empty() ? llvm::sys::path::stem(compiler->getInput()).str() : pyModule; + compiler->getLLVMVisitor()->writeToPythonExtension(*compiler->getCache()->pyModule, + filename); + break; case BuildKind::Detect: compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags); break; @@ -368,15 +385,7 @@ int buildMode(const std::vector &args, const std::string &argv0) { return EXIT_SUCCESS; } -#ifdef CODON_JUPYTER -namespace codon { -int startJupyterKernel(const std::string &argv0, - const std::vector &plugins, - const std::string &configPath); -} -#endif int jupyterMode(const std::vector &args) { -#ifdef CODON_JUPYTER llvm::cl::list plugins("plugin", llvm::cl::desc("Load specified plugin")); llvm::cl::opt input(llvm::cl::Positional, @@ -385,11 +394,6 @@ int jupyterMode(const std::vector &args) { llvm::cl::ParseCommandLineOptions(args.size(), args.data()); int code = codon::startJupyterKernel(args[0], plugins, input); return code; -#else - fmt::print("Jupyter support not included. Please recompile with " - "-DCODON_JUPYTER."); - return EXIT_FAILURE; -#endif } void showCommandsAndExit() { diff --git a/codon/cir/attribute.cpp b/codon/cir/attribute.cpp index 10153adb..0f104e99 100644 --- a/codon/cir/attribute.cpp +++ b/codon/cir/attribute.cpp @@ -41,6 +41,8 @@ std::ostream &MemberAttribute::doFormat(std::ostream &os) const { const std::string SrcInfoAttribute::AttributeName = "srcInfoAttribute"; +const std::string DocstringAttribute::AttributeName = "docstringAttribute"; + const std::string TupleLiteralAttribute::AttributeName = "tupleLiteralAttribute"; std::unique_ptr TupleLiteralAttribute::clone(util::CloneVisitor &cv) const { diff --git a/codon/cir/attribute.h b/codon/cir/attribute.h index 774b9f0e..2fc8a841 100644 --- a/codon/cir/attribute.h +++ b/codon/cir/attribute.h @@ -64,6 +64,26 @@ struct SrcInfoAttribute : public Attribute { std::ostream &doFormat(std::ostream &os) const override { return os << info; } }; +/// Attribute containing docstring from source +struct DocstringAttribute : public Attribute { + static const std::string AttributeName; + + /// the docstring + std::string docstring; + + DocstringAttribute() = default; + /// Constructs a DocstringAttribute. + /// @param docstring the docstring + explicit DocstringAttribute(const std::string &docstring) : docstring(docstring) {} + + std::unique_ptr clone(util::CloneVisitor &cv) const override { + return std::make_unique(*this); + } + +private: + std::ostream &doFormat(std::ostream &os) const override { return os << docstring; } +}; + /// Attribute containing function information struct KeyValueAttribute : public Attribute { static const std::string AttributeName; diff --git a/codon/cir/instr.cpp b/codon/cir/instr.cpp index f8f59e04..4ea18c5b 100644 --- a/codon/cir/instr.cpp +++ b/codon/cir/instr.cpp @@ -114,6 +114,8 @@ types::Type *TypePropertyInstr::doGetType() const { switch (property) { case Property::IS_ATOMIC: return getModule()->getBoolType(); + case Property::IS_CONTENT_ATOMIC: + return getModule()->getBoolType(); case Property::SIZEOF: return getModule()->getIntType(); default: diff --git a/codon/cir/instr.h b/codon/cir/instr.h index 1a5ff045..ff6bd71b 100644 --- a/codon/cir/instr.h +++ b/codon/cir/instr.h @@ -269,7 +269,7 @@ class StackAllocInstr : public AcceptorExtend { /// Instr representing getting information about a type. class TypePropertyInstr : public AcceptorExtend { public: - enum Property { IS_ATOMIC, SIZEOF }; + enum Property { IS_ATOMIC, IS_CONTENT_ATOMIC, SIZEOF }; private: /// the type being inspected diff --git a/codon/cir/llvm/llvisitor.cpp b/codon/cir/llvm/llvisitor.cpp index c018a4a9..ff79a8a7 100644 --- a/codon/cir/llvm/llvisitor.cpp +++ b/codon/cir/llvm/llvisitor.cpp @@ -52,6 +52,17 @@ std::string LLVMVisitor::getNameForFunction(const Func *x) { } } +std::string LLVMVisitor::getNameForVar(const Var *x) { + if (auto *f = cast(x)) + return getNameForFunction(f); + + if (x->isExternal()) { + return x->getName(); + } else { + return "." + x->getName(); + } +} + LLVMVisitor::LLVMVisitor() : util::ConstVisitor(), context(std::make_unique()), M(), B(std::make_unique>(*context)), func(nullptr), block(nullptr), @@ -118,11 +129,16 @@ void LLVMVisitor::registerGlobal(const Var *var) { : llvm::GlobalValue::PrivateLinkage; auto *storage = new llvm::GlobalVariable( *M, llvmType, /*isConstant=*/false, linkage, - external ? nullptr : llvm::Constant::getNullValue(llvmType), var->getName()); + external ? nullptr : llvm::Constant::getNullValue(llvmType), + getNameForVar(var)); insertVar(var, storage); if (external) { - storage->setDSOLocal(true); + if (db.jit) { + storage->setDSOLocal(true); + } else { + storage->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Local); + } } else { // debug info auto *srcInfo = getSrcInfo(var); @@ -145,13 +161,14 @@ llvm::Value *LLVMVisitor::getVar(const Var *var) { if (!it->second) { // if value is null, it's from another module // see if it's in the module already auto name = var->getName(); - if (auto *global = M->getNamedValue(name)) + auto privName = getNameForVar(var); + if (auto *global = M->getNamedValue(privName)) return global; llvm::Type *llvmType = getLLVMType(var->getType()); auto *storage = new llvm::GlobalVariable(*M, llvmType, /*isConstant=*/false, llvm::GlobalValue::ExternalLinkage, - /*Initializer=*/nullptr, name); + /*Initializer=*/nullptr, privName); storage->setExternallyInitialized(true); // debug info @@ -410,9 +427,9 @@ void executeCommand(const std::vector &args) { void LLVMVisitor::setupGlobalCtorForSharedLibrary() { const std::string llvmCtor = "llvm.global_ctors"; auto *main = M->getFunction("main"); - main->setName(".main"); // avoid clash with other main if (M->getNamedValue(llvmCtor) || !main) return; + main->setName(".main"); // avoid clash with other main auto *ctorFuncTy = llvm::FunctionType::get(B->getVoidTy(), {}, /*isVarArg=*/false); auto *ctorEntryTy = llvm::StructType::get(B->getInt32Ty(), ctorFuncTy->getPointerTo(), @@ -467,7 +484,7 @@ void LLVMVisitor::writeToExecutable(const std::string &filename, rpaths.push_back(std::string(path)); } - std::vector command = {"gcc"}; + std::vector command = {"g++"}; // Avoid "argument unused during compilation" warning command.push_back("-Wno-unused-command-line-argument"); // MUST go before -llib to compile on Linux @@ -505,15 +522,20 @@ void LLVMVisitor::writeToExecutable(const std::string &filename, if (plugins) { for (auto *plugin : *plugins) { - auto dylibPath = plugin->info.dylibPath; - if (dylibPath.empty()) - continue; + if (plugin->info.linkArgs.empty()) { + auto dylibPath = plugin->info.dylibPath; + if (dylibPath.empty()) + continue; - auto stem = llvm::sys::path::stem(dylibPath); - if (stem.startswith("lib")) - stem = stem.substr(3); + auto stem = llvm::sys::path::stem(dylibPath); + if (stem.startswith("lib")) + stem = stem.substr(3); - command.push_back("-l" + stem.str()); + command.push_back("-l" + stem.str()); + } else { + for (auto &l : plugin->info.linkArgs) + command.push_back(l); + } } } @@ -547,6 +569,626 @@ void LLVMVisitor::writeToExecutable(const std::string &filename, llvm::sys::fs::remove(objFile); } +namespace { +// https://github.com/python/cpython/blob/main/Include/methodobject.h +constexpr int PYEXT_METH_VARARGS = 0x0001; +constexpr int PYEXT_METH_KEYWORDS = 0x0002; +constexpr int PYEXT_METH_NOARGS = 0x0004; +constexpr int PYEXT_METH_O = 0x0008; +constexpr int PYEXT_METH_CLASS = 0x0010; +constexpr int PYEXT_METH_STATIC = 0x0020; +constexpr int PYEXT_METH_COEXIST = 0x0040; +constexpr int PYEXT_METH_FASTCALL = 0x0080; +constexpr int PYEXT_METH_METHOD = 0x0200; +// https://github.com/python/cpython/blob/main/Include/modsupport.h +constexpr int PYEXT_PYTHON_ABI_VERSION = 1013; +// https://github.com/python/cpython/blob/main/Include/descrobject.h +constexpr int PYEXT_READONLY = 1; +} // namespace + +llvm::Function *LLVMVisitor::createPyTryCatchWrapper(llvm::Function *func) { + auto *wrap = + cast(M->getOrInsertFunction((func->getName() + ".tc_wrap").str(), + func->getFunctionType()) + .getCallee()); + wrap->setPersonalityFn(llvm::cast(makePersonalityFunc().getCallee())); + auto *entry = llvm::BasicBlock::Create(*context, "entry", wrap); + auto *normal = llvm::BasicBlock::Create(*context, "normal", wrap); + auto *unwind = llvm::BasicBlock::Create(*context, "unwind", wrap); + + B->SetInsertPoint(entry); + std::vector args; + for (auto &arg : wrap->args()) { + args.push_back(&arg); + } + auto *result = B->CreateInvoke(func, normal, unwind, args); + + B->SetInsertPoint(normal); + B->CreateRet(result); + + B->SetInsertPoint(unwind); + auto *caughtResult = B->CreateLandingPad(getPadType(), 1); + caughtResult->setCleanup(true); + caughtResult->addClause(getTypeIdxVar(nullptr)); + auto *unwindType = llvm::StructType::get(B->getInt64Ty()); // header only + auto *unwindException = B->CreateExtractValue(caughtResult, 0); + auto *unwindExceptionClass = B->CreateLoad( + B->getInt64Ty(), + B->CreateStructGEP( + unwindType, B->CreatePointerCast(unwindException, unwindType->getPointerTo()), + 0)); + unwindException = B->CreateExtractValue(caughtResult, 0); + auto *excType = llvm::StructType::get(getTypeInfoType(), B->getInt8PtrTy()); + auto *excVal = + B->CreatePointerCast(B->CreateConstGEP1_64(B->getInt8Ty(), unwindException, + (uint64_t)seq_exc_offset()), + excType->getPointerTo()); + auto *loadedExc = B->CreateLoad(excType, excVal); + auto *objPtr = B->CreateExtractValue(loadedExc, 1); + + auto *strType = llvm::StructType::get(B->getInt64Ty(), B->getInt8PtrTy()); + auto *excHeader = + llvm::StructType::get(strType, strType, strType, strType, B->getInt64Ty(), + B->getInt64Ty(), B->getInt8PtrTy()); + auto *header = B->CreateLoad(excHeader, objPtr); + auto *msg = B->CreateExtractValue(header, 1); + auto *msgLen = B->CreateExtractValue(msg, 0); + auto *msgPtr = B->CreateExtractValue(msg, 1); + auto *pyType = B->CreateExtractValue(header, 6); + + // copy msg into new null-terminated buffer + auto alloc = makeAllocFunc(/*atomic=*/true); + auto *buf = B->CreateCall(alloc, B->CreateAdd(msgLen, B->getInt64(1))); + B->CreateMemCpy(buf, {}, msgPtr, {}, msgLen); + auto *last = B->CreateInBoundsGEP(B->getInt8Ty(), buf, msgLen); + B->CreateStore(B->getInt8(0), last); + + auto *pyErrSetString = llvm::cast( + M->getOrInsertFunction("PyErr_SetString", B->getVoidTy(), B->getInt8PtrTy(), + B->getInt8PtrTy()) + .getCallee()); + + const std::string pyExcRuntimeErrorName = "PyExc_RuntimeError"; + llvm::Value *pyExcRuntimeError = M->getNamedValue(pyExcRuntimeErrorName); + if (!pyExcRuntimeError) { + auto *pyExcRuntimeErrorVar = new llvm::GlobalVariable( + *M, B->getInt8PtrTy(), /*isConstant=*/false, llvm::GlobalValue::ExternalLinkage, + /*Initializer=*/nullptr, pyExcRuntimeErrorName); + pyExcRuntimeErrorVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); + pyExcRuntimeError = pyExcRuntimeErrorVar; + } + pyExcRuntimeError = B->CreateLoad(B->getInt8PtrTy(), pyExcRuntimeError); + + auto *havePyType = + B->CreateICmpNE(pyType, llvm::ConstantPointerNull::get(B->getInt8PtrTy())); + B->CreateCall(pyErrSetString, + {B->CreateSelect(havePyType, pyType, pyExcRuntimeError), buf}); + + auto *retType = wrap->getReturnType(); + if (retType == B->getInt32Ty()) { + B->CreateRet(B->getInt32(-1)); + } else { + B->CreateRet(llvm::Constant::getNullValue(retType)); + } + + return wrap; +} + +void LLVMVisitor::writeToPythonExtension(const PyModule &pymod, + const std::string &filename) { + // Setup LLVM types & constants + auto *i64 = B->getInt64Ty(); + auto *i32 = B->getInt32Ty(); + auto *i8 = B->getInt8Ty(); + auto *ptr = B->getInt8PtrTy(); + auto *pyMethodDefType = llvm::StructType::create("PyMethodDef", ptr, ptr, i32, ptr); + auto *pyObjectType = llvm::StructType::create("PyObject", i64, ptr); + auto *pyVarObjectType = llvm::StructType::create("PyVarObject", pyObjectType, i64); + auto *pyModuleDefBaseType = + llvm::StructType::create("PyMethodDefBase", pyObjectType, ptr, i64, ptr); + auto *pyModuleDefType = + llvm::StructType::create("PyModuleDef", pyModuleDefBaseType, ptr, ptr, i64, + pyMethodDefType->getPointerTo(), ptr, ptr, ptr, ptr); + auto *pyMemberDefType = + llvm::StructType::create("PyMemberDef", ptr, i32, i64, i32, ptr); + auto *pyGetSetDefType = + llvm::StructType::create("PyGetSetDef", ptr, ptr, ptr, ptr, ptr); + std::vector pyNumberMethodsFields(36, ptr); + auto *pyNumberMethodsType = + llvm::StructType::create(*context, pyNumberMethodsFields, "PyNumberMethods"); + std::vector pySequenceMethodsFields(10, ptr); + auto *pySequenceMethodsType = + llvm::StructType::create(*context, pySequenceMethodsFields, "PySequenceMethods"); + std::vector pyMappingMethodsFields(3, ptr); + auto *pyMappingMethodsType = + llvm::StructType::create(*context, pyMappingMethodsFields, "PyMappingMethods"); + std::vector pyAsyncMethodsFields(4, ptr); + auto *pyAsyncMethodsType = + llvm::StructType::create(*context, pyAsyncMethodsFields, "PyAsyncMethods"); + auto *pyBufferProcsType = llvm::StructType::create("PyBufferProcs", ptr, ptr); + auto *pyTypeObjectType = llvm::StructType::create( + "PyTypeObject", pyVarObjectType, ptr, i64, i64, ptr, i64, ptr, ptr, ptr, ptr, ptr, + ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, i64, ptr, ptr, ptr, ptr, i64, ptr, ptr, + ptr, ptr, ptr, ptr, ptr, ptr, ptr, i64, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, + ptr, ptr, ptr, i32, ptr, ptr, i8); + auto *zero64 = B->getInt64(0); + auto *zero32 = B->getInt32(0); + auto *zero8 = B->getInt8(0); + auto *null = llvm::Constant::getNullValue(ptr); + auto *pyTypeType = new llvm::GlobalVariable(*M, ptr, /*isConstant=*/false, + llvm::GlobalValue::ExternalLinkage, + /*Initializer=*/nullptr, "PyType_Type"); + + auto allocUncollectable = llvm::cast( + M->getOrInsertFunction("seq_alloc_uncollectable", ptr, i64).getCallee()); + allocUncollectable->setDoesNotThrow(); + allocUncollectable->setReturnDoesNotAlias(); + allocUncollectable->setOnlyAccessesInaccessibleMemory(); + + auto free = llvm::cast( + M->getOrInsertFunction("seq_free", B->getVoidTy(), ptr).getCallee()); + free->setDoesNotThrow(); + + // Helpers + auto pyFuncWrap = [&](Func *func, bool wrap) -> llvm::Constant * { + if (!func) + return null; + auto llvmName = getNameForFunction(func); + auto *llvmFunc = M->getFunction(llvmName); + seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName); + if (wrap) + llvmFunc = createPyTryCatchWrapper(llvmFunc); + return llvmFunc; + }; + + auto pyFunc = [&](Func *func) -> llvm::Constant * { return pyFuncWrap(func, true); }; + + auto pyString = [&](const std::string &str) -> llvm::Constant * { + if (str.empty()) + return null; + auto *var = new llvm::GlobalVariable( + *M, llvm::ArrayType::get(i8, str.length() + 1), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantDataArray::getString(*context, str), ".pyext_str"); + var->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); + return var; + }; + + auto pyFunctions = [&](const std::vector &functions) -> llvm::Constant * { + if (functions.empty()) + return null; + + std::vector pyMethods; + for (auto &pyfunc : functions) { + int flag = 0; + if (pyfunc.keywords) { + flag = PYEXT_METH_FASTCALL | PYEXT_METH_KEYWORDS; + } else { + switch (pyfunc.nargs) { + case 0: + flag = PYEXT_METH_NOARGS; + break; + case 1: + flag = PYEXT_METH_O; + break; + default: + flag = PYEXT_METH_FASTCALL; + break; + } + } + + switch (pyfunc.type) { + case PyFunction::CLASS: + flag |= PYEXT_METH_CLASS; + break; + case PyFunction::STATIC: + flag |= PYEXT_METH_STATIC; + break; + default: + break; + } + + if (pyfunc.coexist) + flag |= PYEXT_METH_COEXIST; + + pyMethods.push_back(llvm::ConstantStruct::get( + pyMethodDefType, pyString(pyfunc.name), pyFunc(pyfunc.func), + B->getInt32(flag), pyString(pyfunc.doc))); + } + pyMethods.push_back( + llvm::ConstantStruct::get(pyMethodDefType, null, null, zero32, null)); + + auto *pyMethodDefArrayType = + llvm::ArrayType::get(pyMethodDefType, pyMethods.size()); + auto *pyMethodDefArray = new llvm::GlobalVariable( + *M, pyMethodDefArrayType, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantArray::get(pyMethodDefArrayType, pyMethods), ".pyext_methods"); + return pyMethodDefArray; + }; + + auto pyMembers = [&](const std::vector &members, + llvm::StructType *type) -> llvm::Constant * { + if (members.empty()) + return null; + + std::vector pyMemb; + for (auto &memb : members) { + // Calculate offset by creating const GEP into null ptr + std::vector indexes = {zero64, B->getInt32(1)}; + for (auto idx : memb.indexes) { + indexes.push_back(B->getInt32(idx)); + } + auto offset = llvm::ConstantExpr::getPtrToInt( + llvm::ConstantExpr::getGetElementPtr(type, null, indexes), i64); + + pyMemb.push_back(llvm::ConstantStruct::get( + pyMemberDefType, pyString(memb.name), B->getInt32(memb.type), offset, + B->getInt32(memb.readonly ? PYEXT_READONLY : 0), pyString(memb.doc))); + } + pyMemb.push_back( + llvm::ConstantStruct::get(pyMemberDefType, null, zero32, zero64, zero32, null)); + + auto *pyMemberDefArrayType = llvm::ArrayType::get(pyMemberDefType, pyMemb.size()); + auto *pyMemberDefArray = new llvm::GlobalVariable( + *M, pyMemberDefArrayType, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantArray::get(pyMemberDefArrayType, pyMemb), ".pyext_members"); + return pyMemberDefArray; + }; + + auto pyGetSet = [&](const std::vector &getset) -> llvm::Constant * { + if (getset.empty()) + return null; + + std::vector pyGS; + for (auto &gs : getset) { + pyGS.push_back(llvm::ConstantStruct::get(pyGetSetDefType, pyString(gs.name), + pyFunc(gs.get), pyFunc(gs.set), + pyString(gs.doc), null)); + } + pyGS.push_back( + llvm::ConstantStruct::get(pyGetSetDefType, null, null, null, null, null)); + + auto *pyGetSetDefArrayType = llvm::ArrayType::get(pyGetSetDefType, pyGS.size()); + auto *pyGetSetDefArray = new llvm::GlobalVariable( + *M, pyGetSetDefArrayType, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantArray::get(pyGetSetDefArrayType, pyGS), ".pyext_getset"); + return pyGetSetDefArray; + }; + + // Construct PyModuleDef array + auto *pyObjectConst = llvm::ConstantStruct::get(pyObjectType, B->getInt64(1), null); + auto *pyModuleDefBaseConst = + llvm::ConstantStruct::get(pyModuleDefBaseType, pyObjectConst, null, zero64, null); + + auto *pyModuleDef = llvm::ConstantStruct::get( + pyModuleDefType, pyModuleDefBaseConst, pyString(pymod.name), pyString(pymod.doc), + B->getInt64(-1), pyFunctions(pymod.functions), null, null, null, null); + auto *pyModuleVar = + new llvm::GlobalVariable(*M, pyModuleDef->getType(), + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + pyModuleDef, ".pyext_module"); + + std::unordered_map typeVars; + for (auto &pytype : pymod.types) { + std::vector numberSlots = { + pyFunc(pytype.add), // nb_add + pyFunc(pytype.sub), // nb_subtract + pyFunc(pytype.mul), // nb_multiply + pyFunc(pytype.mod), // nb_remainder + pyFunc(pytype.divmod), // nb_divmod + pyFunc(pytype.pow), // nb_power + pyFunc(pytype.neg), // nb_negative + pyFunc(pytype.pos), // nb_positive + pyFunc(pytype.abs), // nb_absolute + pyFunc(pytype.bool_), // nb_bool + pyFunc(pytype.invert), // nb_invert + pyFunc(pytype.lshift), // nb_lshift + pyFunc(pytype.rshift), // nb_rshift + pyFunc(pytype.and_), // nb_and + pyFunc(pytype.xor_), // nb_xor + pyFunc(pytype.or_), // nb_or + pyFunc(pytype.int_), // nb_int + null, // nb_reserved + pyFunc(pytype.float_), // nb_float + pyFunc(pytype.iadd), // nb_inplace_add + pyFunc(pytype.isub), // nb_inplace_subtract + pyFunc(pytype.imul), // nb_inplace_multiply + pyFunc(pytype.imod), // nb_inplace_remainder + pyFunc(pytype.ipow), // nb_inplace_power + pyFunc(pytype.ilshift), // nb_inplace_lshift + pyFunc(pytype.irshift), // nb_inplace_rshift + pyFunc(pytype.iand), // nb_inplace_and + pyFunc(pytype.ixor), // nb_inplace_xor + pyFunc(pytype.ior), // nb_inplace_or + pyFunc(pytype.floordiv), // nb_floor_divide + pyFunc(pytype.truediv), // nb_true_divide + pyFunc(pytype.ifloordiv), // nb_inplace_floor_divide + pyFunc(pytype.itruediv), // nb_inplace_true_divide + pyFunc(pytype.index), // nb_index + pyFunc(pytype.matmul), // nb_matrix_multiply + pyFunc(pytype.imatmul), // nb_inplace_matrix_multiply + }; + + std::vector sequenceSlots = { + pyFunc(pytype.len), // sq_length + null, // sq_concat + null, // sq_repeat + null, // sq_item + null, // was_sq_slice + null, // sq_ass_item + null, // was_sq_ass_slice + pyFunc(pytype.contains), // sq_contains + null, // sq_inplace_concat + null, // sq_inplace_repeat + }; + + std::vector mappingSlots = { + null, // mp_length + pyFunc(pytype.getitem), // mp_subscript + pyFunc(pytype.setitem), // mp_ass_subscript + }; + + bool needNumberSlots = + std::find_if(numberSlots.begin(), numberSlots.end(), + [&](auto *v) { return v != null; }) != numberSlots.end(); + bool needSequenceSlots = + std::find_if(sequenceSlots.begin(), sequenceSlots.end(), + [&](auto *v) { return v != null; }) != sequenceSlots.end(); + bool needMappingSlots = + std::find_if(mappingSlots.begin(), mappingSlots.end(), + [&](auto *v) { return v != null; }) != mappingSlots.end(); + + llvm::Constant *numberSlotsConst = null; + llvm::Constant *sequenceSlotsConst = null; + llvm::Constant *mappingSlotsConst = null; + + if (needNumberSlots) { + auto *pyNumberSlotsVar = new llvm::GlobalVariable( + *M, pyNumberMethodsType, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantStruct::get(pyNumberMethodsType, numberSlots), + ".pyext_number_slots." + pytype.name); + numberSlotsConst = pyNumberSlotsVar; + } + + if (needSequenceSlots) { + auto *pySequenceSlotsVar = new llvm::GlobalVariable( + *M, pySequenceMethodsType, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantStruct::get(pySequenceMethodsType, sequenceSlots), + ".pyext_sequence_slots." + pytype.name); + sequenceSlotsConst = pySequenceSlotsVar; + } + + if (needMappingSlots) { + auto *pyMappingSlotsVar = new llvm::GlobalVariable( + *M, pyMappingMethodsType, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantStruct::get(pyMappingMethodsType, mappingSlots), + ".pyext_mapping_slots." + pytype.name); + mappingSlotsConst = pyMappingSlotsVar; + } + + auto *refType = cast(pytype.type); + auto *llvmType = getLLVMType(pytype.type); + auto *objectType = llvm::StructType::get(pyObjectType, llvmType); + auto codonSize = + refType + ? M->getDataLayout().getTypeAllocSize(getLLVMType(refType->getContents())) + : 0; + auto pySize = M->getDataLayout().getTypeAllocSize(objectType); + + auto *alloc = llvm::cast( + M->getOrInsertFunction(pytype.name + ".py_alloc", ptr, ptr, i64).getCallee()); + { + auto *entry = llvm::BasicBlock::Create(*context, "entry", alloc); + B->SetInsertPoint(entry); + auto *pythonObject = B->CreateCall(allocUncollectable, B->getInt64(pySize)); + auto *header = B->CreateInsertValue( + llvm::ConstantStruct::get(pyObjectType, B->getInt64(1), null), + alloc->arg_begin(), 1); + B->CreateStore(header, pythonObject); + if (refType) { + auto *codonObject = B->CreateCall( + makeAllocFunc(refType->getContents()->isAtomic()), B->getInt64(codonSize)); + B->CreateStore(codonObject, B->CreateGEP(objectType, pythonObject, + {zero64, B->getInt32(1)})); + } + B->CreateRet(pythonObject); + } + + auto *delFn = pyFuncWrap(pytype.del, /*wrap=*/false); + auto *dealloc = llvm::cast( + M->getOrInsertFunction(pytype.name + ".py_dealloc", B->getVoidTy(), ptr) + .getCallee()); + { + llvm::Value *obj = dealloc->arg_begin(); + auto *entry = llvm::BasicBlock::Create(*context, "entry", dealloc); + B->SetInsertPoint(entry); + if (delFn != null) + B->CreateCall(llvm::FunctionCallee(dealloc->getFunctionType(), delFn), obj); + B->CreateCall(free, obj); + B->CreateRetVoid(); + } + + auto *pyNew = llvm::cast( + M->getOrInsertFunction("PyType_GenericNew", ptr, ptr, ptr, ptr).getCallee()); + + std::vector typeSlots = { + llvm::ConstantStruct::get( + pyVarObjectType, + llvm::ConstantStruct::get(pyObjectType, B->getInt64(1), pyTypeType), + zero64), // PyObject_VAR_HEAD + pyString(pymod.name + "." + pytype.name), // tp_name + B->getInt64(pySize), // tp_basicsize + zero64, // tp_itemsize + dealloc, // tp_dealloc + zero64, // tp_vectorcall_offset + null, // tp_getattr + null, // tp_setattr + null, // tp_as_async + pyFunc(pytype.repr), // tp_repr + numberSlotsConst, // tp_as_number + sequenceSlotsConst, // tp_as_sequence + mappingSlotsConst, // tp_as_mapping + pyFunc(pytype.hash), // tp_hash + pyFunc(pytype.call), // tp_call + pyFunc(pytype.str), // tp_str + null, // tp_getattro + null, // tp_setattro + null, // tp_as_buffer + zero64, // tp_flags + pyString(pytype.doc), // tp_doc + null, // tp_traverse + null, // tp_clear + pyFunc(pytype.cmp), // tp_richcompare + zero64, // tp_weaklistoffset + pyFunc(pytype.iter), // tp_iter + pyFunc(pytype.iternext), // tp_iternext + pyFunctions(pytype.methods), // tp_methods + pyMembers(pytype.members, objectType), // tp_members + pyGetSet(pytype.getset), // tp_getset + null, // tp_base + null, // tp_dict + null, // tp_descr_get + null, // tp_descr_set + zero64, // tp_dictoffset + pyFunc(pytype.init), // tp_init + alloc, // tp_alloc + pyNew, // tp_new + free, // tp_free + null, // tp_is_gc + null, // tp_bases + null, // tp_mro + null, // tp_cache + null, // tp_subclasses + null, // tp_weaklist + null, // tp_del + zero32, // tp_version_tag + free, // tp_finalize + null, // tp_vectorcall + B->getInt8(0), // tp_watched + }; + + auto *pyTypeObjectVar = new llvm::GlobalVariable( + *M, pyTypeObjectType, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantStruct::get(pyTypeObjectType, typeSlots), + ".pyext_type." + pytype.name); + + if (pytype.typePtrHook) { + auto *hook = llvm::cast(pyFuncWrap(pytype.typePtrHook, false)); + for (auto it = llvm::inst_begin(hook), end = llvm::inst_end(hook); it != end; + ++it) { + if (auto *ret = llvm::dyn_cast(&*it)) + ret->setOperand(0, pyTypeObjectVar); + } + } + + typeVars.emplace(pytype.type, pyTypeObjectVar); + } + + // Construct initialization hook + auto pyIncRef = llvm::cast( + M->getOrInsertFunction("Py_IncRef", B->getVoidTy(), ptr).getCallee()); + pyIncRef->setDoesNotThrow(); + + auto pyDecRef = llvm::cast( + M->getOrInsertFunction("Py_DecRef", B->getVoidTy(), ptr).getCallee()); + pyDecRef->setDoesNotThrow(); + + auto *pyModuleCreate = llvm::cast( + M->getOrInsertFunction("PyModule_Create2", ptr, ptr, i32).getCallee()); + pyModuleCreate->setDoesNotThrow(); + + auto *pyTypeReady = llvm::cast( + M->getOrInsertFunction("PyType_Ready", i32, ptr).getCallee()); + pyTypeReady->setDoesNotThrow(); + + auto *pyModuleAddObject = llvm::cast( + M->getOrInsertFunction("PyModule_AddObject", i32, ptr, ptr, ptr).getCallee()); + pyModuleAddObject->setDoesNotThrow(); + + auto *pyModuleInit = llvm::cast( + M->getOrInsertFunction("PyInit_" + pymod.name, ptr).getCallee()); + auto *block = llvm::BasicBlock::Create(*context, "entry", pyModuleInit); + B->SetInsertPoint(block); + + if (auto *main = M->getFunction("main")) { + main->setName(".main"); + B->CreateCall({main->getFunctionType(), main}, {zero32, null}); + } + + // Set base types + for (auto &pytype : pymod.types) { + if (pytype.base) { + auto subcIt = typeVars.find(pytype.type); + auto baseIt = typeVars.find(pytype.base->type); + seqassertn(subcIt != typeVars.end() && baseIt != typeVars.end(), + "types not found"); + // 30 is the index of tp_base + B->CreateStore(baseIt->second, B->CreateConstInBoundsGEP2_64( + pyTypeObjectType, subcIt->second, 0, 30)); + } + } + + // Call PyType_Ready + for (auto &pytype : pymod.types) { + auto it = typeVars.find(pytype.type); + seqassertn(it != typeVars.end(), "type not found"); + auto *typeVar = it->second; + + auto *fail = llvm::BasicBlock::Create(*context, "failure", pyModuleInit); + block = llvm::BasicBlock::Create(*context, "success", pyModuleInit); + auto *status = B->CreateCall(pyTypeReady, typeVar); + B->CreateCondBr(B->CreateICmpSLT(status, zero32), fail, block); + + B->SetInsertPoint(fail); + B->CreateRet(null); + + B->SetInsertPoint(block); + } + + // Create module + auto *mod = B->CreateCall(pyModuleCreate, + {pyModuleVar, B->getInt32(PYEXT_PYTHON_ABI_VERSION)}); + auto *fail = llvm::BasicBlock::Create(*context, "failure", pyModuleInit); + block = llvm::BasicBlock::Create(*context, "success", pyModuleInit); + + B->CreateCondBr(B->CreateICmpEQ(mod, null), fail, block); + B->SetInsertPoint(fail); + B->CreateRet(null); + + B->SetInsertPoint(block); + + // Add types + for (auto &pytype : pymod.types) { + auto it = typeVars.find(pytype.type); + seqassertn(it != typeVars.end(), "type not found"); + auto *typeVar = it->second; + + B->CreateCall(pyIncRef, typeVar); + auto *status = + B->CreateCall(pyModuleAddObject, {mod, pyString(pytype.name), typeVar}); + fail = llvm::BasicBlock::Create(*context, "failure", pyModuleInit); + block = llvm::BasicBlock::Create(*context, "success", pyModuleInit); + B->CreateCondBr(B->CreateICmpSLT(status, zero32), fail, block); + + B->SetInsertPoint(fail); + B->CreateCall(pyDecRef, typeVar); + B->CreateCall(pyDecRef, mod); + B->CreateRet(null); + + B->SetInsertPoint(block); + } + B->CreateRet(mod); + + writeToObjectFile(filename); +} + void LLVMVisitor::compile(const std::string &filename, const std::string &argv0, const std::vector &libs, const std::string &lflags) { @@ -706,6 +1348,7 @@ llvm::GlobalVariable *LLVMVisitor::getTypeIdxVar(const std::string &name) { tidx = new llvm::GlobalVariable( *M, typeInfoType, /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, llvm::ConstantStruct::get(typeInfoType, B->getInt32(idx)), typeVarName); + tidx->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); } return tidx; } @@ -1728,10 +2371,10 @@ void LLVMVisitor::visit(const BoolConst *x) { void LLVMVisitor::visit(const StringConst *x) { B->SetInsertPoint(block); std::string s = x->getVal(); - auto *strVar = new llvm::GlobalVariable( - *M, llvm::ArrayType::get(B->getInt8Ty(), s.length() + 1), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, - llvm::ConstantDataArray::getString(*context, s), "str_literal"); + auto *strVar = + new llvm::GlobalVariable(*M, llvm::ArrayType::get(B->getInt8Ty(), s.length() + 1), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantDataArray::getString(*context, s), ".str"); strVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); auto *strType = llvm::StructType::get(B->getInt64Ty(), B->getInt8PtrTy()); llvm::Value *ptr = B->CreateBitCast(strVar, B->getInt8PtrTy()); @@ -2419,6 +3062,9 @@ void LLVMVisitor::visit(const TypePropertyInstr *x) { case TypePropertyInstr::Property::IS_ATOMIC: value = B->getInt8(x->getInspectType()->isAtomic() ? 1 : 0); break; + case TypePropertyInstr::Property::IS_CONTENT_ATOMIC: + value = B->getInt8(x->getInspectType()->isContentAtomic() ? 1 : 0); + break; default: seqassertn(0, "unknown type property"); } diff --git a/codon/cir/llvm/llvisitor.h b/codon/cir/llvm/llvisitor.h index 499bf98a..7ab9a0eb 100644 --- a/codon/cir/llvm/llvisitor.h +++ b/codon/cir/llvm/llvisitor.h @@ -4,6 +4,7 @@ #include "codon/cir/cir.h" #include "codon/cir/llvm/llvm.h" +#include "codon/cir/pyextension.h" #include "codon/dsl/plugins.h" #include "codon/util/common.h" @@ -198,6 +199,9 @@ class LLVMVisitor : public util::ConstVisitor { // Shared library setup void setupGlobalCtorForSharedLibrary(); + // Python extension setup + llvm::Function *createPyTryCatchWrapper(llvm::Function *func); + // LLVM passes void runLLVMPipeline(); @@ -213,6 +217,7 @@ class LLVMVisitor : public util::ConstVisitor { public: static std::string getNameForFunction(const Func *x); + static std::string getNameForVar(const Var *x); static std::string getDebugNameForVariable(const Var *x) { std::string name = x->getName(); @@ -342,6 +347,10 @@ class LLVMVisitor : public util::ConstVisitor { bool library = false, const std::vector &libs = {}, const std::string &lflags = ""); + /// Writes module as Python extension object. + /// @param pymod extension module + /// @param filename the file to write to + void writeToPythonExtension(const PyModule &pymod, const std::string &filename); /// Runs optimization passes on module and writes the result /// to the specified file. The output type is determined by /// the file extension (.ll for LLVM IR, .bc for LLVM bitcode diff --git a/codon/cir/llvm/optimize.cpp b/codon/cir/llvm/optimize.cpp index ff9e05fb..87a4aa6a 100644 --- a/codon/cir/llvm/optimize.cpp +++ b/codon/cir/llvm/optimize.cpp @@ -53,11 +53,9 @@ void applyDebugTransformations(llvm::Module *module, bool debug, bool jit) { if (debug) { // remove tail calls and fix linkage for stack traces for (auto &f : *module) { -#ifdef __APPLE__ // needed for debug symbols if (!jit) f.setLinkage(llvm::GlobalValue::ExternalLinkage); -#endif if (!f.hasFnAttribute(llvm::Attribute::AttrKind::AlwaysInline)) f.addFnAttr(llvm::Attribute::AttrKind::NoInline); f.setUWTableKind(llvm::UWTableKind::Default); @@ -81,16 +79,16 @@ void applyDebugTransformations(llvm::Module *module, bool debug, bool jit) { /// Lowers allocations of known, small size to alloca when possible. /// Also removes unused allocations. struct AllocationRemover : public llvm::PassInfoMixin { - std::string alloc; - std::string allocAtomic; + std::vector allocators; std::string realloc; std::string free; - AllocationRemover(const std::string &alloc = "seq_alloc", - const std::string &allocAtomic = "seq_alloc_atomic", - const std::string &realloc = "seq_realloc", - const std::string &free = "seq_free") - : alloc(alloc), allocAtomic(allocAtomic), realloc(realloc), free(free) {} + explicit AllocationRemover( + std::vector allocators = {"seq_alloc", "seq_alloc_atomic", + "seq_alloc_uncollectable", + "seq_alloc_atomic_uncollectable"}, + const std::string &realloc = "seq_realloc", const std::string &free = "seq_free") + : allocators(std::move(allocators)), realloc(realloc), free(free) {} static bool sizeOkToDemote(uint64_t size) { return 0 < size && size <= 1024; } @@ -110,8 +108,8 @@ struct AllocationRemover : public llvm::PassInfoMixin { bool isAlloc(const llvm::Value *value) { if (auto *func = getCalledFunction(value)) { - return func->arg_size() == 1 && - (func->getName() == alloc || func->getName() == allocAtomic); + return func->arg_size() == 1 && std::find(allocators.begin(), allocators.end(), + func->getName()) != allocators.end(); } return false; } diff --git a/codon/cir/pyextension.h b/codon/cir/pyextension.h new file mode 100644 index 00000000..f01f97b7 --- /dev/null +++ b/codon/cir/pyextension.h @@ -0,0 +1,134 @@ +// Copyright (C) 2022-2023 Exaloop Inc. + +#pragma once + +#include +#include + +#include "codon/cir/func.h" +#include "codon/cir/types/types.h" + +namespace codon { +namespace ir { + +struct PyFunction { + enum Type { TOPLEVEL, METHOD, CLASS, STATIC }; + std::string name; + std::string doc; + Func *func = nullptr; + Type type = Type::TOPLEVEL; + int nargs = 0; + bool keywords = false; + bool coexist = false; +}; + +struct PyMember { + enum Type { + SHORT = 0, + INT = 1, + LONG = 2, + FLOAT = 3, + DOUBLE = 4, + STRING = 5, + OBJECT = 6, + CHAR = 7, + BYTE = 8, + UBYTE = 9, + USHORT = 10, + UINT = 11, + ULONG = 12, + STRING_INPLACE = 13, + BOOL = 14, + OBJECT_EX = 16, + LONGLONG = 17, + ULONGLONG = 18, + PYSSIZET = 19, + }; + + std::string name; + std::string doc; + Type type = Type::SHORT; + bool readonly = false; + /// Indexes of the member. For example, in the + /// tuple (a, (b, c, (d,))), 'a' would have indexes + /// [0], 'b' would have indexes [1, 0], 'c' would + /// have indexes [1, 1], and 'd' would have indexes + /// [1, 2, 0]. This corresponds to an LLVM GEP. + std::vector indexes; +}; + +struct PyGetSet { + std::string name; + std::string doc; + Func *get = nullptr; + Func *set = nullptr; +}; + +struct PyType { + std::string name; + std::string doc; + types::Type *type = nullptr; + PyType *base = nullptr; + Func *repr = nullptr; + Func *add = nullptr; + Func *iadd = nullptr; + Func *sub = nullptr; + Func *isub = nullptr; + Func *mul = nullptr; + Func *imul = nullptr; + Func *mod = nullptr; + Func *imod = nullptr; + Func *divmod = nullptr; + Func *pow = nullptr; + Func *ipow = nullptr; + Func *neg = nullptr; + Func *pos = nullptr; + Func *abs = nullptr; + Func *bool_ = nullptr; + Func *invert = nullptr; + Func *lshift = nullptr; + Func *ilshift = nullptr; + Func *rshift = nullptr; + Func *irshift = nullptr; + Func *and_ = nullptr; + Func *iand = nullptr; + Func *xor_ = nullptr; + Func *ixor = nullptr; + Func *or_ = nullptr; + Func *ior = nullptr; + Func *int_ = nullptr; + Func *float_ = nullptr; + Func *floordiv = nullptr; + Func *ifloordiv = nullptr; + Func *truediv = nullptr; + Func *itruediv = nullptr; + Func *index = nullptr; + Func *matmul = nullptr; + Func *imatmul = nullptr; + Func *len = nullptr; + Func *getitem = nullptr; + Func *setitem = nullptr; + Func *contains = nullptr; + Func *hash = nullptr; + Func *call = nullptr; + Func *str = nullptr; + Func *cmp = nullptr; + Func *iter = nullptr; + Func *iternext = nullptr; + Func *del = nullptr; + Func *init = nullptr; + std::vector methods; + std::vector members; + std::vector getset; + Func *typePtrHook = nullptr; +}; + +struct PyModule { + std::string name; + std::string doc; + std::vector functions; + std::vector types; +}; + +} // namespace ir +} // namespace codon diff --git a/codon/cir/transform/manager.h b/codon/cir/transform/manager.h index 9ef74bee..c36cf41b 100644 --- a/codon/cir/transform/manager.h +++ b/codon/cir/transform/manager.h @@ -94,6 +94,9 @@ class PassManager { /// whether to use Python (vs. C) numeric semantics in passes bool pyNumerics; + /// true if we are compiling as a Python extension + bool pyExtension; + public: /// PassManager initialization mode. enum Init { @@ -104,16 +107,17 @@ class PassManager { }; explicit PassManager(Init init, std::vector disabled = {}, - bool pyNumerics = false) + bool pyNumerics = false, bool pyExtension = false) : km(), passes(), analyses(), executionOrder(), results(), - disabled(std::move(disabled)), pyNumerics(pyNumerics) { + disabled(std::move(disabled)), pyNumerics(pyNumerics), + pyExtension(pyExtension) { registerStandardPasses(init); } explicit PassManager(bool debug = false, std::vector disabled = {}, - bool pyNumerics = false) + bool pyNumerics = false, bool pyExtension = false) : PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled), - pyNumerics) {} + pyNumerics, pyExtension) {} /// Checks if the given pass is included in this manager. /// @param key the pass key diff --git a/codon/cir/types/types.cpp b/codon/cir/types/types.cpp index 078230b0..1688f4f2 100644 --- a/codon/cir/types/types.cpp +++ b/codon/cir/types/types.cpp @@ -123,6 +123,13 @@ void RecordType::realize(std::vector mTypes, std::vector mN const char RefType::NodeId = 0; +bool RefType::doIsContentAtomic() const { + auto *contents = getContents(); + return !std::any_of(contents->begin(), contents->end(), [](auto &field) { + return field.getName().rfind(".__vtable__", 0) != 0 && !field.getType()->isAtomic(); + }); +} + Value *RefType::doConstruct(std::vector args) { auto *module = getModule(); auto *argsTuple = util::makeTuple(args, module); diff --git a/codon/cir/types/types.h b/codon/cir/types/types.h index fff29aa8..e2952dbb 100644 --- a/codon/cir/types/types.h +++ b/codon/cir/types/types.h @@ -95,6 +95,11 @@ class Type : public ReplaceableNodeBase { /// @return true if the type is atomic bool isAtomic() const { return getActual()->doIsAtomic(); } + /// Checks if the contents (i.e. within an allocated block of memory) + /// of a type are atomic. Currently only meaningful for reference types. + /// @return true if the type's content is atomic + bool isContentAtomic() const { return getActual()->doIsContentAtomic(); } + /// @return the ast type ast::types::TypePtr getAstType() const { return getActual()->astType; } /// Sets the ast type. Should not generally be used. @@ -121,6 +126,7 @@ class Type : public ReplaceableNodeBase { virtual std::vector doGetUsedTypes() const { return {}; } virtual bool doIsAtomic() const = 0; + virtual bool doIsContentAtomic() const { return true; } virtual Value *doConstruct(std::vector args); }; @@ -329,6 +335,8 @@ class RefType : public AcceptorExtend { bool doIsAtomic() const override { return false; } + bool doIsContentAtomic() const override; + Value *doConstruct(std::vector args) override; }; diff --git a/codon/compiler/compiler.cpp b/codon/compiler/compiler.cpp index 54f7dd86..5ee5b509 100644 --- a/codon/compiler/compiler.cpp +++ b/codon/compiler/compiler.cpp @@ -32,15 +32,16 @@ ir::transform::PassManager::Init getPassManagerInit(Compiler::Mode mode, bool is Compiler::Compiler(const std::string &argv0, Compiler::Mode mode, const std::vector &disabledPasses, bool isTest, - bool pyNumerics) - : argv0(argv0), debug(mode == Mode::DEBUG), pyNumerics(pyNumerics), input(), - plm(std::make_unique(argv0)), + bool pyNumerics, bool pyExtension) + : argv0(argv0), debug(mode == Mode::DEBUG), pyNumerics(pyNumerics), + pyExtension(pyExtension), input(), plm(std::make_unique(argv0)), cache(std::make_unique(argv0)), module(std::make_unique()), - pm(std::make_unique(getPassManagerInit(mode, isTest), - disabledPasses, pyNumerics)), + pm(std::make_unique( + getPassManagerInit(mode, isTest), disabledPasses, pyNumerics, pyExtension)), llvisitor(std::make_unique()) { cache->module = module.get(); + cache->pythonExt = pyExtension; cache->pythonCompat = pyNumerics; module->setCache(cache.get()); llvisitor->setDebug(debug); @@ -181,6 +182,14 @@ std::unordered_map Compiler::getEarlyDefines() { std::unordered_map earlyDefines; earlyDefines.emplace("__debug__", debug ? "1" : "0"); earlyDefines.emplace("__py_numerics__", pyNumerics ? "1" : "0"); + earlyDefines.emplace("__py_extension__", pyExtension ? "1" : "0"); + earlyDefines.emplace("__apple__", +#if __APPLE__ + "1" +#else + "0" +#endif + ); return earlyDefines; } diff --git a/codon/compiler/compiler.h b/codon/compiler/compiler.h index cc13bfa7..de188e67 100644 --- a/codon/compiler/compiler.h +++ b/codon/compiler/compiler.h @@ -28,6 +28,7 @@ class Compiler { std::string argv0; bool debug; bool pyNumerics; + bool pyExtension; std::string input; std::unique_ptr plm; std::unique_ptr cache; @@ -42,13 +43,14 @@ class Compiler { public: Compiler(const std::string &argv0, Mode mode, const std::vector &disabledPasses = {}, bool isTest = false, - bool pyNumerics = false); + bool pyNumerics = false, bool pyExtension = false); explicit Compiler(const std::string &argv0, bool debug = false, const std::vector &disabledPasses = {}, - bool isTest = false, bool pyNumerics = false) + bool isTest = false, bool pyNumerics = false, + bool pyExtension = false) : Compiler(argv0, debug ? Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest, - pyNumerics) {} + pyNumerics, pyExtension) {} std::string getInput() const { return input; } PluginManager *getPluginManager() const { return plm.get(); } diff --git a/codon/dsl/dsl.h b/codon/dsl/dsl.h index 8bdeb8ca..3f59e690 100644 --- a/codon/dsl/dsl.h +++ b/codon/dsl/dsl.h @@ -34,6 +34,8 @@ class DSL { std::string stdlibPath; /// Plugin dynamic library path std::string dylibPath; + /// Linker arguments (to replace "-l dylibPath" if present) + std::vector linkArgs; }; using KeywordCallback = diff --git a/codon/dsl/plugins.cpp b/codon/dsl/plugins.cpp index 06dc61c5..83179002 100644 --- a/codon/dsl/plugins.cpp +++ b/codon/dsl/plugins.cpp @@ -57,6 +57,22 @@ llvm::Expected PluginManager::load(const std::string &path) { dylibPath = p.str(); } + auto link = library["link"]; + std::vector linkArgs; + if (auto arr = link.as_array()) { + arr->for_each([&linkArgs](auto &&el) { + std::string l = el.value_or(""); + if (!l.empty()) + linkArgs.push_back(l); + }); + } else { + std::string l = link.value_or(""); + if (!l.empty()) + linkArgs.push_back(l); + } + for (auto &l : linkArgs) + l = fmt::format(l, fmt::arg("root", llvm::sys::path::parent_path(tomlPath))); + std::string codonLib = library["codon"].value_or(""); std::string stdlibPath; if (!codonLib.empty()) { @@ -71,7 +87,8 @@ llvm::Expected PluginManager::load(const std::string &path) { about["url"].value_or(""), about["supported"].value_or(""), stdlibPath, - dylibPath}; + dylibPath, + linkArgs}; bool versionOk = false; try { diff --git a/codon/parser/ast/stmt.cpp b/codon/parser/ast/stmt.cpp index 457dcf4c..aae0e0c1 100644 --- a/codon/parser/ast/stmt.cpp +++ b/codon/parser/ast/stmt.cpp @@ -470,6 +470,65 @@ size_t FunctionStmt::getKwStarArgs() const { } return i; } +std::string FunctionStmt::getDocstr() { + if (auto s = suite->firstInBlock()) { + if (auto e = s->getExpr()) { + if (auto ss = e->expr->getString()) + return ss->getValue(); + } + } + return ""; +} + +// Search expression tree for a identifier +class IdSearchVisitor : public CallbackASTVisitor { + std::string what; + bool result; + +public: + IdSearchVisitor(std::string what) : what(std::move(what)), result(false) {} + bool transform(const std::shared_ptr &expr) override { + if (result) + return result; + IdSearchVisitor v(what); + if (expr) + expr->accept(v); + return result = v.result; + } + bool transform(const std::shared_ptr &stmt) override { + if (result) + return result; + IdSearchVisitor v(what); + if (stmt) + stmt->accept(v); + return result = v.result; + } + void visit(IdExpr *expr) override { + if (expr->value == what) + result = true; + } +}; + +/// Check if a function can be called with the given arguments. +/// See @c reorderNamedArgs for details. +std::unordered_set FunctionStmt::getNonInferrableGenerics() { + std::unordered_set nonInferrableGenerics; + for (auto &a : args) { + if (a.status == Param::Generic && !a.defaultValue) { + bool inferrable = false; + for (auto &b : args) + if (b.type && IdSearchVisitor(a.name).transform(b.type)) { + inferrable = true; + break; + } + if (ret && IdSearchVisitor(a.name).transform(ret)) + inferrable = true; + if (!inferrable) + nonInferrableGenerics.insert(a.name); + } + } + return nonInferrableGenerics; +} ClassStmt::ClassStmt(std::string name, std::vector args, StmtPtr suite, std::vector decorators, std::vector baseClasses, @@ -555,6 +614,8 @@ void ClassStmt::parseDecorators() { for (auto &d : decorators) { if (d->isId("deduce")) { attributes.customAttr.insert("deduce"); + } else if (d->isId("__notuple__")) { + attributes.customAttr.insert("__notuple__"); } else if (auto c = d->getCall()) { if (c->expr->isId(Attr::Tuple)) { attributes.set(Attr::Tuple); @@ -622,6 +683,7 @@ void ClassStmt::parseDecorators() { } if (startswith(name, TYPE_TUPLE)) { tupleMagics["add"] = true; + tupleMagics["mul"] = true; } else { tupleMagics["dict"] = true; } @@ -644,6 +706,15 @@ bool ClassStmt::isClassVar(const Param &p) { return i->expr->isId("ClassVar"); return false; } +std::string ClassStmt::getDocstr() { + if (auto s = suite->firstInBlock()) { + if (auto e = s->getExpr()) { + if (auto ss = e->expr->getString()) + return ss->getValue(); + } + } + return ""; +} YieldFromStmt::YieldFromStmt(ExprPtr expr) : Stmt(), expr(std::move(expr)) {} YieldFromStmt::YieldFromStmt(const YieldFromStmt &stmt) diff --git a/codon/parser/ast/stmt.h b/codon/parser/ast/stmt.h index b7e57f1b..293ed5fa 100644 --- a/codon/parser/ast/stmt.h +++ b/codon/parser/ast/stmt.h @@ -481,6 +481,8 @@ struct FunctionStmt : public Stmt { size_t getKwStarArgs() const; FunctionStmt *getFunction() override { return this; } + std::string getDocstr(); + std::unordered_set getNonInferrableGenerics(); }; /// Class statement (@(attributes...) class name[generics...]: args... ; suite). @@ -515,6 +517,7 @@ struct ClassStmt : public Stmt { void parseDecorators(); static bool isClassVar(const Param &p); + std::string getDocstr(); }; /// Yield-from statement (yield from expr). diff --git a/codon/parser/ast/types/class.cpp b/codon/parser/ast/types/class.cpp index ef272418..d3ef83f1 100644 --- a/codon/parser/ast/types/class.cpp +++ b/codon/parser/ast/types/class.cpp @@ -129,12 +129,14 @@ std::string ClassType::realizedTypeName() const { } RecordType::RecordType(Cache *cache, std::string name, std::string niceName, - std::vector generics, std::vector args) + std::vector generics, std::vector args, + bool noTuple) : ClassType(cache, std::move(name), std::move(niceName), std::move(generics)), - args(std::move(args)) {} + args(std::move(args)), noTuple(false) {} -RecordType::RecordType(const ClassTypePtr &base, std::vector args) - : ClassType(base), args(std::move(args)) {} +RecordType::RecordType(const ClassTypePtr &base, std::vector args, + bool noTuple) + : ClassType(base), args(std::move(args)), noTuple(noTuple) {} int RecordType::unify(Type *typ, Unification *us) { if (auto tr = typ->getRecord()) { @@ -157,7 +159,10 @@ int RecordType::unify(Type *typ, Unification *us) { } // Handle Tuple<->@tuple: when unifying tuples, only record members matter. if (startswith(name, TYPE_TUPLE) || startswith(tr->name, TYPE_TUPLE)) { - return s1 + int(name == tr->name); + if (!args.empty() || (!noTuple && !tr->noTuple)) // prevent POD<->() unification + return s1 + int(name == tr->name); + else + return -1; } return this->ClassType::unify(tr.get(), us); } else if (auto t = typ->getLink()) { @@ -172,7 +177,7 @@ TypePtr RecordType::generalize(int atLevel) { auto a = args; for (auto &t : a) t = t->generalize(atLevel); - return std::make_shared(c, a); + return std::make_shared(c, a, noTuple); } TypePtr RecordType::instantiate(int atLevel, int *unboundCount, @@ -182,7 +187,7 @@ TypePtr RecordType::instantiate(int atLevel, int *unboundCount, auto a = args; for (auto &t : a) t = t->instantiate(atLevel, unboundCount, cache); - return std::make_shared(c, a); + return std::make_shared(c, a, noTuple); } std::vector RecordType::getUnbounds() const { diff --git a/codon/parser/ast/types/class.h b/codon/parser/ast/types/class.h index 1865a9b7..973a07ee 100644 --- a/codon/parser/ast/types/class.h +++ b/codon/parser/ast/types/class.h @@ -77,13 +77,13 @@ using ClassTypePtr = std::shared_ptr; struct RecordType : public ClassType { /// List of tuple arguments. std::vector args; - char flags = 0; + bool noTuple; explicit RecordType( Cache *cache, std::string name, std::string niceName, std::vector generics = std::vector(), - std::vector args = std::vector()); - RecordType(const ClassTypePtr &base, std::vector args); + std::vector args = std::vector(), bool noTuple = false); + RecordType(const ClassTypePtr &base, std::vector args, bool noTuple = false); public: int unify(Type *typ, Unification *undo) override; diff --git a/codon/parser/ast/types/function.cpp b/codon/parser/ast/types/function.cpp index 19a81d88..677f07cf 100644 --- a/codon/parser/ast/types/function.cpp +++ b/codon/parser/ast/types/function.cpp @@ -158,6 +158,7 @@ std::string FuncType::realizedName() const { PartialType::PartialType(const std::shared_ptr &baseType, std::shared_ptr func, std::vector known) : RecordType(*baseType), func(std::move(func)), known(std::move(known)) {} + int PartialType::unify(Type *typ, Unification *us) { return this->RecordType::unify(typ, us); } diff --git a/codon/parser/ast/types/traits.cpp b/codon/parser/ast/types/traits.cpp index 0cf36a16..84608bc0 100644 --- a/codon/parser/ast/types/traits.cpp +++ b/codon/parser/ast/types/traits.cpp @@ -167,7 +167,8 @@ TypePtr CallableTrait::instantiate(int atLevel, int *unboundCount, } std::string CallableTrait::debugString(char mode) const { - return fmt::format("Callable[{},{}]", args[0]->debugString(mode).substr(5), + auto s = args[0]->debugString(mode); + return fmt::format("Callable[{},{}]", startswith(s, "Tuple") ? s.substr(5) : s, args[1]->debugString(mode)); } diff --git a/codon/parser/cache.cpp b/codon/parser/cache.cpp index f7ab2a68..75df1de2 100644 --- a/codon/parser/cache.cpp +++ b/codon/parser/cache.cpp @@ -6,6 +6,8 @@ #include #include +#include "codon/cir/pyextension.h" +#include "codon/cir/util/irtools.h" #include "codon/parser/common.h" #include "codon/parser/peg/peg.h" #include "codon/parser/visitors/simplify/simplify.h" @@ -18,7 +20,7 @@ namespace codon::ast { Cache::Cache(std::string argv0) : generatedSrcInfoCount(0), unboundCount(256), varCount(0), age(0), argv0(std::move(argv0)), typeCtx(nullptr), codegenCtx(nullptr), isJit(false), - jitCell(0) {} + jitCell(0), pythonExt(false), pyModule(nullptr) {} std::string Cache::getTemporaryVar(const std::string &prefix, char sigil) { return fmt::format("{}{}_{}", sigil ? fmt::format("{}_", sigil) : "", prefix, @@ -229,4 +231,327 @@ std::vector Cache::mergeC3(std::vector> &seqs) { return result; } +/** + * Generate Python bindings for Cython-like access. + * + * TODO: this function is total mess. Needs refactoring. + */ +void Cache::populatePythonModule() { + if (!pythonExt) + return; + + LOG_USER("[py] ====== module generation ======="); + +#define N std::make_shared + + if (!pyModule) + pyModule = std::make_shared(); + using namespace ast; + + int oldAge = typeCtx->age; + typeCtx->age = 99999; + + auto realizeIR = [&](const types::FuncTypePtr &fn, + const std::vector &generics = {}) -> ir::Func * { + auto fnType = typeCtx->instantiate(fn); + types::Type::Unification u; + for (size_t i = 0; i < generics.size(); i++) + fnType->getFunc()->funcGenerics[i].type->unify(generics[i].get(), &u); + fnType = TypecheckVisitor(typeCtx).realize(fnType); + if (!fnType) + return nullptr; + + auto pr = pendingRealizations; // copy it as it might be modified + for (auto &fn : pr) + TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone()); + return functions[fn->ast->name].realizations[fnType->realizedName()]->ir; + }; + + const std::string pyWrap = "std.internal.python._PyWrap"; + auto clss = classes; // needs copy as below fns can mutate this + for (const auto &[cn, c] : clss) { + if (c.module.empty()) { + if (!in(c.methods, "__to_py__") || !in(c.methods, "__from_py__")) + continue; + + LOG_USER("[py] Cythonizing {}", cn); + ir::PyType py{rev(cn), c.ast->getDocstr()}; + + auto tc = typeCtx->forceFind(cn)->type; + if (!tc->canRealize()) + compilationError(fmt::format("cannot realize '{}' for Python export", rev(cn))); + tc = TypecheckVisitor(typeCtx).realize(tc); + seqassertn(tc, "cannot realize '{}'", cn); + + // 1. Replace to_py / from_py with _PyWrap.wrap_to_py/from_py + if (auto ofnn = in(c.methods, "__to_py__")) { + auto fnn = overloads[*ofnn].begin()->name; // default first overload! + auto &fna = functions[fnn].ast; + fna->getFunction()->suite = N(N( + N(pyWrap + ".wrap_to_py:0"), N(fna->args[0].name))); + } + if (auto ofnn = in(c.methods, "__from_py__")) { + auto fnn = overloads[*ofnn].begin()->name; // default first overload! + auto &fna = functions[fnn].ast; + fna->getFunction()->suite = + N(N(N(pyWrap + ".wrap_from_py:0"), + N(fna->args[0].name), N(cn))); + } + for (auto &n : std::vector{"__from_py__", "__to_py__"}) { + auto fnn = overloads[*in(c.methods, n)].begin()->name; + ir::Func *oldIR = nullptr; + if (!functions[fnn].realizations.empty()) + oldIR = functions[fnn].realizations.begin()->second->ir; + functions[fnn].realizations.clear(); + auto tf = TypecheckVisitor(typeCtx).realize(functions[fnn].type); + seqassertn(tf, "cannot re-realize '{}'", fnn); + if (oldIR) { + std::vector args; + for (auto it = oldIR->arg_begin(); it != oldIR->arg_end(); ++it) { + args.push_back(module->Nr(*it)); + } + ir::cast(oldIR)->setBody(ir::util::series( + ir::util::call(functions[fnn].realizations.begin()->second->ir, args))); + } + } + for (auto &[rn, r] : functions[pyWrap + ".py_type:0"].realizations) { + if (r->type->funcGenerics[0].type->unify(tc.get(), nullptr) >= 0) { + py.typePtrHook = r->ir; + break; + } + } + + // 2. Handle methods + auto methods = c.methods; + for (const auto &[n, ofnn] : methods) { + auto canonicalName = overloads[ofnn].back().name; + if (overloads[ofnn].size() == 1 && + functions[canonicalName].ast->hasAttr("autogenerated")) + continue; + auto fna = functions[canonicalName].ast; + bool isMethod = fna->hasAttr(Attr::Method); + bool isProperty = fna->hasAttr(Attr::Property); + + std::string call = pyWrap + ".wrap_multiple"; + bool isMagic = false; + if (startswith(n, "__") && endswith(n, "__")) { + auto m = n.substr(2, n.size() - 4); + if (m == "new" && c.ast->hasAttr(Attr::Tuple)) + m = "init"; + if (auto i = in(classes[pyWrap].methods, "wrap_magic_" + m)) { + call = *i; + isMagic = true; + } + } + if (isProperty) + call = pyWrap + ".wrap_get"; + + auto fnName = call + ":0"; + seqassertn(in(functions, fnName), "bad name"); + auto generics = std::vector{tc}; + if (isProperty) { + generics.push_back( + std::make_shared(this, rev(canonicalName))); + } else if (!isMagic) { + generics.push_back(std::make_shared(this, n)); + generics.push_back(std::make_shared(this, (int)isMethod)); + } + auto f = realizeIR(functions[fnName].type, generics); + if (!f) + continue; + + LOG_USER("[py] {} -> {} ({}; {})", n, call, isMethod, isProperty); + if (isProperty) { + py.getset.push_back({rev(canonicalName), "", f, nullptr}); + } else if (n == "__repr__") { + py.repr = f; + } else if (n == "__add__") { + py.add = f; + } else if (n == "__iadd__") { + py.iadd = f; + } else if (n == "__sub__") { + py.sub = f; + } else if (n == "__isub__") { + py.isub = f; + } else if (n == "__mul__") { + py.mul = f; + } else if (n == "__imul__") { + py.imul = f; + } else if (n == "__mod__") { + py.mod = f; + } else if (n == "__imod__") { + py.imod = f; + } else if (n == "__divmod__") { + py.divmod = f; + } else if (n == "__pow__") { + py.pow = f; + } else if (n == "__ipow__") { + py.ipow = f; + } else if (n == "__neg__") { + py.neg = f; + } else if (n == "__pos__") { + py.pos = f; + } else if (n == "__abs__") { + py.abs = f; + } else if (n == "__bool__") { + py.bool_ = f; + } else if (n == "__invert__") { + py.invert = f; + } else if (n == "__lshift__") { + py.lshift = f; + } else if (n == "__ilshift__") { + py.ilshift = f; + } else if (n == "__rshift__") { + py.rshift = f; + } else if (n == "__irshift__") { + py.irshift = f; + } else if (n == "__and__") { + py.and_ = f; + } else if (n == "__iand__") { + py.iand = f; + } else if (n == "__xor__") { + py.xor_ = f; + } else if (n == "__ixor__") { + py.ixor = f; + } else if (n == "__or__") { + py.or_ = f; + } else if (n == "__ior__") { + py.ior = f; + } else if (n == "__int__") { + py.int_ = f; + } else if (n == "__float__") { + py.float_ = f; + } else if (n == "__floordiv__") { + py.floordiv = f; + } else if (n == "__ifloordiv__") { + py.ifloordiv = f; + } else if (n == "__truediv__") { + py.truediv = f; + } else if (n == "__itruediv__") { + py.itruediv = f; + } else if (n == "__index__") { + py.index = f; + } else if (n == "__matmul__") { + py.matmul = f; + } else if (n == "__imatmul__") { + py.imatmul = f; + } else if (n == "__len__") { + py.len = f; + } else if (n == "__getitem__") { + py.getitem = f; + } else if (n == "__setitem__") { + py.setitem = f; + } else if (n == "__contains__") { + py.contains = f; + } else if (n == "__hash__") { + py.hash = f; + } else if (n == "__call__") { + py.call = f; + } else if (n == "__str__") { + py.str = f; + } else if (n == "__iter__") { + py.iter = f; + } else if (n == "__del__") { + py.del = f; + } else if (n == "__init__" || (c.ast->hasAttr(Attr::Tuple) && n == "__new__")) { + py.init = f; + } else { + py.methods.push_back(ir::PyFunction{ + n, fna->getDocstr(), f, + fna->hasAttr(Attr::Method) ? ir::PyFunction::Type::METHOD + : ir::PyFunction::Type::CLASS, + // always use FASTCALL for now; works even for 0- or 1- arg methods + 2}); + py.methods.back().keywords = true; + } + } + + for (auto &m : py.methods) { + if (in(std::set{"__lt__", "__le__", "__eq__", "__ne__", "__gt__", + "__ge__"}, + m.name)) { + py.cmp = realizeIR( + typeCtx->forceFind(pyWrap + ".wrap_cmp:0")->type->getFunc(), {tc}); + break; + } + } + + if (c.realizations.size() != 1) + compilationError(fmt::format("cannot pythonize generic class '{}'", cn)); + auto &r = c.realizations.begin()->second; + py.type = realizeType(r->type); + for (auto &[mn, mt] : r->fields) { + /// TODO: handle PyMember for tuples + // Generate getters & setters + auto generics = std::vector{ + tc, std::make_shared(this, mn)}; + auto gf = realizeIR(functions[pyWrap + ".wrap_get:0"].type, generics); + ir::Func *sf = nullptr; + if (!c.ast->hasAttr(Attr::Tuple)) + sf = realizeIR(functions[pyWrap + ".wrap_set:0"].type, generics); + py.getset.push_back({mn, "", gf, sf}); + LOG_USER("[py] {}: {} . {}", "member", cn, mn); + } + pyModule->types.push_back(py); + } + } + + // Handle __iternext__ wrappers + auto cin = "_PyWrap.IterWrap"; + for (auto &[cn, cr] : classes[cin].realizations) { + LOG_USER("[py] iterfn: {}", cn); + ir::PyType py{cn, ""}; + auto tc = cr->type; + for (auto &[rn, r] : functions[pyWrap + ".py_type:0"].realizations) { + if (r->type->funcGenerics[0].type->unify(tc.get(), nullptr) >= 0) { + py.typePtrHook = r->ir; + break; + } + } + + auto &methods = classes[cin].methods; + for (auto &n : std::vector{"_iter", "_iternext"}) { + auto fnn = overloads[methods[n]].begin()->name; + auto &fna = functions[fnn]; + auto ft = typeCtx->instantiate(fna.type, tc->getClass()); + auto rtv = TypecheckVisitor(typeCtx).realize(ft); + auto f = + functions[rtv->getFunc()->ast->name].realizations[rtv->realizedName()]->ir; + if (n == "_iter") + py.iter = f; + else + py.iternext = f; + } + py.type = cr->ir; + pyModule->types.push_back(py); + } +#undef N + + auto fns = functions; // needs copy as below fns can mutate this + for (const auto &[fn, f] : fns) { + if (f.isToplevel) { + std::string call = pyWrap + ".wrap_multiple"; + auto fnName = call + ":0"; + seqassertn(in(functions, fnName), "bad name"); + auto generics = std::vector{ + typeCtx->forceFind(".toplevel")->type, + std::make_shared(this, rev(f.ast->name)), + std::make_shared(this, 0)}; + if (auto ir = realizeIR(functions[fnName].type, generics)) { + LOG_USER("[py] {}: {}", "toplevel", fn); + pyModule->functions.push_back(ir::PyFunction{rev(fn), f.ast->getDocstr(), ir, + ir::PyFunction::Type::TOPLEVEL, + int(f.ast->args.size())}); + pyModule->functions.back().keywords = true; + } + } + } + + // Handle pending realizations! + auto pr = pendingRealizations; // copy it as it might be modified + for (auto &fn : pr) + TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone()); + typeCtx->age = oldAge; +} + } // namespace codon::ast diff --git a/codon/parser/cache.h b/codon/parser/cache.h index d230bac1..498f64e3 100644 --- a/codon/parser/cache.h +++ b/codon/parser/cache.h @@ -9,6 +9,7 @@ #include #include "codon/cir/cir.h" +#include "codon/cir/pyextension.h" #include "codon/parser/ast.h" #include "codon/parser/common.h" #include "codon/parser/ctx.h" @@ -156,6 +157,9 @@ struct Cache : public std::enable_shared_from_this { /// List of statically inherited classes. std::vector staticParentClasses; + /// Module information + std::string module; + Class() : ast(nullptr), originalAst(nullptr) {} }; /// Class lookup table that maps a canonical class identifier to the corresponding @@ -186,7 +190,13 @@ struct Cache : public std::enable_shared_from_this { /// Unrealized function type. types::FuncTypePtr type; - Function() : ast(nullptr), origAst(nullptr), type(nullptr) {} + /// Module information + std::string rootName = ""; + bool isToplevel = false; + + Function() + : ast(nullptr), origAst(nullptr), type(nullptr), rootName(""), + isToplevel(false) {} }; /// Function lookup table that maps a canonical function identifier to the /// corresponding Function instance. @@ -236,6 +246,8 @@ struct Cache : public std::enable_shared_from_this { /// Set if Codon operates in Python compatibility mode (e.g., with Python numerics) bool pythonCompat = false; + /// Set if Codon operates in Python extension mode + bool pythonExt = false; public: explicit Cache(std::string argv0 = ""); @@ -298,6 +310,9 @@ struct Cache : public std::enable_shared_from_this { void parseCode(const std::string &code); static std::vector mergeC3(std::vector> &); + + std::shared_ptr pyModule = nullptr; + void populatePythonModule(); }; } // namespace codon::ast diff --git a/codon/parser/ctx.h b/codon/parser/ctx.h index 54c94b5f..6236ff95 100644 --- a/codon/parser/ctx.h +++ b/codon/parser/ctx.h @@ -68,7 +68,6 @@ template class Context : public std::enable_shared_from_this initParser() { init_codon_actions(*g); ~(*g)["NLP"] <= peg::usr([](const char *s, size_t n, peg::SemanticValues &, std::any &dt) { - return std::any_cast(dt).parens - ? 0 - : (n >= 1 && s[0] == '\\' ? 1 : -1); + auto e = (n >= 1 && s[0] == '\\' ? 1 : -1); + if (std::any_cast(dt).parens && e == -1) + e = 0; + return e; }); for (auto &x : *g) { auto v = peg::LinkReferences(*g, x.second.params); diff --git a/codon/parser/visitors/simplify/access.cpp b/codon/parser/visitors/simplify/access.cpp index 99fe50ed..aec276a1 100644 --- a/codon/parser/visitors/simplify/access.cpp +++ b/codon/parser/visitors/simplify/access.cpp @@ -19,8 +19,14 @@ void SimplifyVisitor::visit(IdExpr *expr) { return; } auto val = ctx->findDominatingBinding(expr->value); - if (!val) + + if (!val && ctx->getBase()->pyCaptures) { + ctx->getBase()->pyCaptures->insert(expr->value); + resultExpr = N(N("__pyenv__"), N(expr->value)); + return; + } else if (!val) { E(Error::ID_NOT_FOUND, expr, expr->value); + } // If we are accessing an outside variable, capture it or raise an error auto captured = checkCapture(val); @@ -107,7 +113,11 @@ void SimplifyVisitor::visit(DotExpr *expr) { std::reverse(chain.begin(), chain.end()); auto p = getImport(chain); - if (p.second->getModule() == "std.python") { + if (!p.second) { + seqassert(ctx->getBase()->pyCaptures, "unexpected py capture"); + ctx->getBase()->pyCaptures->insert(chain[0]); + resultExpr = N(N("__pyenv__"), N(chain[0])); + } else if (p.second->getModule() == "std.python") { resultExpr = transform(N( N(N(N("internal"), "python"), "_get_identifier"), N(chain[p.first++]))); @@ -157,7 +167,6 @@ bool SimplifyVisitor::checkCapture(const SimplifyContext::Item &val) { if (!localGeneric && !parentClassGeneric && !ctx->bases[i].captures) crossCaptureBoundary = true; } - seqassert(i < ctx->bases.size(), "invalid base for '{}'", val->canonicalName); // Mark methods (class functions that access class generics) if (parentClassGeneric) @@ -250,8 +259,11 @@ SimplifyVisitor::getImport(const std::vector &chain) { } } } - if (itemName.empty() && importName.empty()) + if (itemName.empty() && importName.empty()) { + if (ctx->getBase()->pyCaptures) + return {1, nullptr}; E(Error::IMPORT_NO_MODULE, getSrcInfo(), chain[importEnd]); + } if (itemName.empty()) E(Error::IMPORT_NO_NAME, getSrcInfo(), chain[importEnd], ctx->cache->imports[importName].moduleName); diff --git a/codon/parser/visitors/simplify/call.cpp b/codon/parser/visitors/simplify/call.cpp index 50a9ba03..4158575a 100644 --- a/codon/parser/visitors/simplify/call.cpp +++ b/codon/parser/visitors/simplify/call.cpp @@ -78,6 +78,9 @@ SimplifyVisitor::transformTupleGenerator(const std::vector &args) E(Error::CALL_TUPLE_COMPREHENSION, args[0].value); auto var = clone(g->loops[0].vars); auto ex = clone(g->expr); + + ctx->enterConditionalBlock(); + ctx->getBase()->loops.push_back({"", ctx->scope.blocks, {}}); if (auto i = var->getId()) { ctx->addVar(i->value, ctx->generateCanonicalName(i->value), var->getSrcInfo()); var = transform(var); @@ -89,6 +92,11 @@ SimplifyVisitor::transformTupleGenerator(const std::vector &args) auto head = transform(N(clone(g->loops[0].vars), clone(var))); ex = N(head, transform(ex)); } + ctx->leaveConditionalBlock(); + // Dominate loop variables + for (auto &var : ctx->getBase()->getLoop()->seenVars) + ctx->findDominatingBinding(var); + ctx->getBase()->loops.pop_back(); return N( GeneratorExpr::Generator, ex, std::vector{{var, transform(g->loops[0].gen), {}}}); diff --git a/codon/parser/visitors/simplify/class.cpp b/codon/parser/visitors/simplify/class.cpp index 575d7474..3ac6c2b9 100644 --- a/codon/parser/visitors/simplify/class.cpp +++ b/codon/parser/visitors/simplify/class.cpp @@ -107,13 +107,16 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { } // Collect classes (and their fields) that are to be statically inherited - auto staticBaseASTs = parseBaseClasses(stmt->staticBaseClasses, args, - stmt->attributes, canonicalName); - if (ctx->cache->isJit && !stmt->baseClasses.empty()) - E(Error::CUSTOM, stmt->baseClasses[0], - "inheritance is not yet supported in JIT mode"); - auto baseASTs = parseBaseClasses(stmt->baseClasses, args, stmt->attributes, - canonicalName, transformedTypeAst); + std::vector staticBaseASTs, baseASTs; + if (!stmt->attributes.has(Attr::Extend)) { + staticBaseASTs = parseBaseClasses(stmt->staticBaseClasses, args, stmt->attributes, + canonicalName); + if (ctx->cache->isJit && !stmt->baseClasses.empty()) + E(Error::CUSTOM, stmt->baseClasses[0], + "inheritance is not yet supported in JIT mode"); + parseBaseClasses(stmt->baseClasses, args, stmt->attributes, canonicalName, + transformedTypeAst); + } // A ClassStmt will be separated into class variable assignments, method-free // ClassStmts (that include nested classes) and method FunctionStmts @@ -177,6 +180,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { for (auto &b : staticBaseASTs) ctx->cache->classes[canonicalName].staticParentClasses.emplace_back(b->name); ctx->cache->classes[canonicalName].ast->validate(); + ctx->cache->classes[canonicalName].module = ctx->getModule(); // Codegen default magic methods for (auto &m : stmt->attributes.magics) { @@ -324,9 +328,8 @@ std::vector SimplifyVisitor::parseBaseClasses( E(Error::CLASS_NO_INHERIT, getSrcInfo(), "internal"); // Add __vtable__ to parent classes if it is not there already - if (typeAst && (cachedCls->fields.empty() || - cachedCls->fields[0].name != format("{}.{}", VAR_VTABLE, name))) { - auto var = format("{}.{}", VAR_VTABLE, name); + auto var = format("{}.{}", VAR_VTABLE, name); + if (typeAst && (cachedCls->fields.empty() || cachedCls->fields[0].name != var)) { // LOG("[virtual] vtable({}) := {}", name, var); cachedCls->fields.insert(cachedCls->fields.begin(), {var, nullptr}); cachedCls->ast->args.insert( @@ -367,8 +370,17 @@ std::vector SimplifyVisitor::parseBaseClasses( // Add normal fields for (auto &ast : asts) { for (auto &a : ast->args) { - if (a.status == Param::Normal && !ClassStmt::isClassVar(a)) - args.emplace_back(Param{a.name, a.type, a.defaultValue}); + if (a.status == Param::Normal && !ClassStmt::isClassVar(a)) { + auto name = a.name; + if (startswith(name, VAR_VTABLE)) { // prevent clashing names + int i = 0; + for (auto &aa : args) + i += bool(startswith(aa.name, a.name)); + if (i) + name = format("{}#{}", name, i); + } + args.emplace_back(Param{name, a.type, a.defaultValue}); + } } } if (typeAst) { @@ -800,11 +812,18 @@ StmtPtr SimplifyVisitor::codegenMagic(const std::string &op, const ExprPtr &typE stmts.emplace_back(N(I("d"))); } else if (op == "add") { // def __add__(self, tup): - // return (*self, *t) + // return __internal__.tuple_add(self, tup) fargs.emplace_back(Param{"self", typExpr->clone()}); fargs.emplace_back(Param{"tup", nullptr}); - stmts.emplace_back(N(N( - std::vector{N(I("self")), N(I("tup"))}))); + stmts.emplace_back(N( + N(N(I("__internal__"), "tuple_add"), I("self"), I("tup")))); + } else if (op == "mul") { + // def __mul__(self, i: Static[int]): + // return __internal__.tuple_add(self, tup) + fargs.emplace_back(Param{"self", typExpr->clone()}); + fargs.emplace_back(Param{"i", N(I("Static"), I("int"))}); + stmts.emplace_back(N( + N(N(I("__internal__"), "tuple_mul"), I("self"), I("i")))); } else if (op == "tuplesize") { // def __tuplesize__() -> int: // return Tuple[arg_types...].__elemsize__ diff --git a/codon/parser/visitors/simplify/ctx.cpp b/codon/parser/visitors/simplify/ctx.cpp index c5a95493..e77c9407 100644 --- a/codon/parser/visitors/simplify/ctx.cpp +++ b/codon/parser/visitors/simplify/ctx.cpp @@ -25,7 +25,7 @@ SimplifyContext::SimplifyContext(std::string filename, Cache *cache) SimplifyContext::Base::Base(std::string name, Attr *attributes) : name(move(name)), attributes(attributes), deducedMembers(nullptr), selfName(), - captures(nullptr) {} + captures(nullptr), pyCaptures(nullptr) {} void SimplifyContext::add(const std::string &name, const SimplifyContext::Item &var) { auto v = find(name); diff --git a/codon/parser/visitors/simplify/ctx.h b/codon/parser/visitors/simplify/ctx.h index 1f959170..31c11bf7 100644 --- a/codon/parser/visitors/simplify/ctx.h +++ b/codon/parser/visitors/simplify/ctx.h @@ -112,6 +112,10 @@ struct SimplifyContext : public Context { /// function after processing) and their types (indicating if they are a type, a /// static or a variable). std::unordered_map> *captures; + + /// Map of identifiers that are to be fetched from Python. + std::unordered_set *pyCaptures; + /// Scope that defines the base. std::vector scope; diff --git a/codon/parser/visitors/simplify/function.cpp b/codon/parser/visitors/simplify/function.cpp index f8a6ee6e..8454512b 100644 --- a/codon/parser/visitors/simplify/function.cpp +++ b/codon/parser/visitors/simplify/function.cpp @@ -161,6 +161,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { StmtPtr suite = nullptr; ExprPtr ret = nullptr; std::unordered_map> captures; + std::unordered_set pyCaptures; { // Set up the base SimplifyContext::BaseGuard br(ctx.get(), canonicalName); @@ -239,6 +240,8 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { } else { if ((isEnclosedFunc || stmt->attributes.has(Attr::Capture)) && !isClassMember) ctx->getBase()->captures = &captures; + if (stmt->attributes.has("std.internal.attributes.pycapture")) + ctx->getBase()->pyCaptures = &pyCaptures; suite = SimplifyVisitor(ctx, preamble).transformConditionalScope(stmt->suite); } } @@ -254,6 +257,9 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { stmt->attributes.parentClass = ctx->getBase()->name; // Add the method to the class' method list ctx->cache->classes[ctx->getBase()->name].methods[stmt->name] = rootName; + } else { + // Hack so that we can later use same helpers for class overloads + ctx->cache->classes[".toplevel"].methods[stmt->name] = rootName; } // Handle captures. Add additional argument to the function for every capture. @@ -278,6 +284,9 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { ctx->cache->functions[canonicalName].ast = f; ctx->cache->functions[canonicalName].origAst = std::static_pointer_cast(stmt->clone()); + ctx->cache->functions[canonicalName].isToplevel = + ctx->getModule().empty() && ctx->isGlobal(); + ctx->cache->functions[canonicalName].rootName = rootName; // Expression to be used if function binding is modified by captures or decorators ExprPtr finalExpr = nullptr; diff --git a/codon/parser/visitors/simplify/import.cpp b/codon/parser/visitors/simplify/import.cpp index 77ab4aca..2e149040 100644 --- a/codon/parser/visitors/simplify/import.cpp +++ b/codon/parser/visitors/simplify/import.cpp @@ -90,7 +90,6 @@ void SimplifyVisitor::visit(ImportStmt *stmt) { // `__` while the standard library is being loaded auto c = i.second.front(); if (c->isConditional() && i.first.find('.') == std::string::npos) { - LOG("-> fix {} :: {}", import.moduleName, i.first); c = import.ctx->findDominatingBinding(i.first); } // Imports should ignore noShadow property @@ -193,7 +192,6 @@ StmtPtr SimplifyVisitor::transformCImport(const std::string &name, auto val = ctx->forceFind(name); ctx->add(altName, val); ctx->remove(name); - seqassert(ctx->find(name) == nullptr, "import not properly handled"); } return f; } @@ -325,8 +323,10 @@ StmtPtr SimplifyVisitor::transformNewImport(const ImportFile &file) { // str is not defined when loading internal.core; __name__ is not needed anyway n = nullptr; } - n = SimplifyVisitor(ictx, preamble) - .transform(N(n, parseFile(ctx->cache, file.path))); + n = N(n, parseFile(ctx->cache, file.path)); + n = SimplifyVisitor(ictx, preamble).transform(n); + if (!ctx->cache->errors.empty()) + throw exc::ParserException(); // Add comment to the top of import for easier dump inspection auto comment = N(format("import: {} at {}", file.module, file.path)); if (ctx->isStdlibLoading) { diff --git a/codon/parser/visitors/simplify/loops.cpp b/codon/parser/visitors/simplify/loops.cpp index b01cc08d..99f38fa1 100644 --- a/codon/parser/visitors/simplify/loops.cpp +++ b/codon/parser/visitors/simplify/loops.cpp @@ -57,12 +57,6 @@ void SimplifyVisitor::visit(WhileStmt *stmt) { ctx->getBase()->loops.push_back({breakVar, ctx->scope.blocks, {}}); stmt->cond = transform(N(N(stmt->cond, "__bool__"))); transformConditionalScope(stmt->suite); - ctx->leaveConditionalBlock(); - // Dominate loop variables - for (auto &var : ctx->getBase()->getLoop()->seenVars) { - ctx->findDominatingBinding(var); - } - ctx->getBase()->loops.pop_back(); // Complete while-else clause if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) { @@ -70,6 +64,13 @@ void SimplifyVisitor::visit(WhileStmt *stmt) { N(transform(N(breakVar)), transformConditionalScope(stmt->elseSuite))); } + + ctx->leaveConditionalBlock(); + // Dominate loop variables + for (auto &var : ctx->getBase()->getLoop()->seenVars) { + ctx->findDominatingBinding(var); + } + ctx->getBase()->loops.pop_back(); } /// Transform for loop. @@ -114,11 +115,6 @@ void SimplifyVisitor::visit(ForStmt *stmt) { stmts.push_back(stmt->suite); stmt->suite = transform(N(stmts)); } - ctx->leaveConditionalBlock(&(stmt->suite->getSuite()->stmts)); - // Dominate loop variables - for (auto &var : ctx->getBase()->getLoop()->seenVars) - ctx->findDominatingBinding(var); - ctx->getBase()->loops.pop_back(); // Complete while-else clause if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) { @@ -126,6 +122,12 @@ void SimplifyVisitor::visit(ForStmt *stmt) { N(transform(N(breakVar)), transformConditionalScope(stmt->elseSuite))); } + + ctx->leaveConditionalBlock(&(stmt->suite->getSuite()->stmts)); + // Dominate loop variables + for (auto &var : ctx->getBase()->getLoop()->seenVars) + ctx->findDominatingBinding(var); + ctx->getBase()->loops.pop_back(); } /// Transform and check for OpenMP decorator. diff --git a/codon/parser/visitors/simplify/op.cpp b/codon/parser/visitors/simplify/op.cpp index b5d9e503..9b9904d0 100644 --- a/codon/parser/visitors/simplify/op.cpp +++ b/codon/parser/visitors/simplify/op.cpp @@ -74,6 +74,7 @@ void SimplifyVisitor::visit(IndexExpr *expr) { // IndexExpr[i1, ..., iN] is internally represented as // IndexExpr[TupleExpr[i1, ..., iN]] for N > 1 std::vector items; + bool isTuple = expr->index->getTuple(); if (auto t = expr->index->getTuple()) { items = t->items; } else { @@ -91,7 +92,7 @@ void SimplifyVisitor::visit(IndexExpr *expr) { resultExpr = N(expr->expr, items); resultExpr->markType(); } else { - expr->index = items.size() == 1 ? items[0] : N(items); + expr->index = (!isTuple && items.size() == 1) ? items[0] : N(items); } } diff --git a/codon/parser/visitors/simplify/simplify.cpp b/codon/parser/visitors/simplify/simplify.cpp index 18efa584..3f1acd70 100644 --- a/codon/parser/visitors/simplify/simplify.cpp +++ b/codon/parser/visitors/simplify/simplify.cpp @@ -32,6 +32,7 @@ SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &fil auto preamble = std::make_shared>(); seqassertn(cache->module, "cache's module is not set"); +#define N std::make_shared // Load standard library if it has not been loaded if (!in(cache->imports, STDLIB_IMPORT)) { // Load the internal.__init__ @@ -63,11 +64,9 @@ SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &fil // Load early compile-time defines (for standard library) preamble->push_back( SimplifyVisitor(stdlib, preamble) - .transform(std::make_shared( - std::make_shared(d.first), - std::make_shared(d.second), - std::make_shared(std::make_shared("Static"), - std::make_shared("int"))))); + .transform( + N(N(d.first), N(d.second), + N(N("Static"), N("int"))))); } preamble->push_back(SimplifyVisitor(stdlib, preamble) .transform(parseFile(stdlib->cache, stdlibPath->path))); @@ -86,28 +85,30 @@ SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &fil ctx->moduleName = {ImportFile::PACKAGE, file, MODULE_MAIN}; // Prepare the code - auto suite = std::make_shared(); + auto suite = N(); + suite->stmts.push_back(N(".toplevel", std::vector{}, nullptr, + std::vector{N(Attr::Internal)})); for (auto &d : defines) { // Load compile-time defines (e.g., codon run -DFOO=1 ...) - suite->stmts.push_back(std::make_shared( - std::make_shared(d.first), std::make_shared(d.second), - std::make_shared(std::make_shared("Static"), - std::make_shared("int")))); + suite->stmts.push_back( + N(N(d.first), N(d.second), + N(N("Static"), N("int")))); } // Set up __name__ - suite->stmts.push_back(std::make_shared( - std::make_shared("__name__"), std::make_shared(MODULE_MAIN))); + suite->stmts.push_back( + N(N("__name__"), N(MODULE_MAIN))); suite->stmts.push_back(node); auto n = SimplifyVisitor(ctx, preamble).transform(suite); - suite = std::make_shared(); - suite->stmts.push_back(std::make_shared(*preamble)); + suite = N(); + suite->stmts.push_back(N(*preamble)); // Add dominated assignment declarations if (in(ctx->scope.stmts, ctx->scope.blocks.back())) suite->stmts.insert(suite->stmts.end(), ctx->scope.stmts[ctx->scope.blocks.back()].begin(), ctx->scope.stmts[ctx->scope.blocks.back()].end()); suite->stmts.push_back(n); +#undef N if (!ctx->cache->errors.empty()) throw exc::ParserException(); diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index 43be4743..83b57122 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -56,6 +56,7 @@ ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) { } TranslateVisitor(cache->codegenCtx).transform(stmts); + cache->populatePythonModule(); return main; } @@ -205,6 +206,10 @@ void TranslateVisitor::visit(CallExpr *expr) { arrayType->setAstType(expr->getType()); result = make(expr, arrayType, sz); return; + } else if (expr->expr->getId() && startswith(expr->expr->getId()->value, + "__internal__.yield_in_no_suspend:0")) { + result = make(expr, getType(expr->getType()), false); + return; } auto ft = expr->expr->type->getFunc(); @@ -229,14 +234,18 @@ void TranslateVisitor::visit(CallExpr *expr) { } void TranslateVisitor::visit(DotExpr *expr) { - if (expr->member == "__atomic__" || expr->member == "__elemsize__") { + if (expr->member == "__atomic__" || expr->member == "__elemsize__" || + expr->member == "__contents_atomic__") { seqassert(expr->expr->getId(), "expected IdExpr, got {}", expr->expr); auto type = ctx->find(expr->expr->getId()->value)->getType(); seqassert(type, "{} is not a type", expr->expr->getId()->value); result = make( expr, type, - expr->member == "__atomic__" ? ir::TypePropertyInstr::Property::IS_ATOMIC - : ir::TypePropertyInstr::Property::SIZEOF); + expr->member == "__atomic__" + ? ir::TypePropertyInstr::Property::IS_ATOMIC + : (expr->member == "__contents_atomic__" + ? ir::TypePropertyInstr::Property::IS_CONTENT_ATOMIC + : ir::TypePropertyInstr::Property::SIZEOF)); } else { result = make(expr, transform(expr->expr), expr->member); } @@ -332,7 +341,16 @@ void TranslateVisitor::visit(ContinueStmt *stmt) { result = make(stmt); } -void TranslateVisitor::visit(ExprStmt *stmt) { result = transform(stmt->expr); } +void TranslateVisitor::visit(ExprStmt *stmt) { + if (stmt->expr->getCall() && + stmt->expr->getCall()->expr->isId("__internal__.yield_final:0")) { + result = make(stmt, transform(stmt->expr->getCall()->args[0].value), + true); + ctx->getBase()->setGenerator(); + } else { + result = transform(stmt->expr); + } +} void TranslateVisitor::visit(AssignStmt *stmt) { if (stmt->lhs && stmt->lhs->isId(VAR_ARGV)) @@ -418,7 +436,6 @@ void TranslateVisitor::visit(ForStmt *stmt) { fc->funcGenerics[2].type->getStatic()->expr->staticValue.getInt(); bool gpu = fc->funcGenerics[3].type->getStatic()->expr->staticValue.getInt(); os = std::make_unique(schedule, threads, chunk, ordered, collapse, gpu); - LOG_TYPECHECK("parsed {}", stmt->decorator); } seqassert(stmt->var->getId(), "expected IdExpr, got {}", stmt->var); diff --git a/codon/parser/visitors/typecheck/access.cpp b/codon/parser/visitors/typecheck/access.cpp index 9caa87d0..76a5d797 100644 --- a/codon/parser/visitors/typecheck/access.cpp +++ b/codon/parser/visitors/typecheck/access.cpp @@ -166,6 +166,12 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, return transform(N(expr->expr->type->prettyString())); return nullptr; } + // Special case: expr.__is_static__ + if (expr->member == "__is_static__") { + if (expr->expr->isDone()) + return transform(N(expr->expr->isStatic())); + return nullptr; + } // Special case: cls.__vtable_id__ if (expr->expr->isType() && expr->member == "__vtable_id__") { if (auto c = realize(expr->expr->type)) @@ -189,18 +195,18 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, unify(expr->type, ctx->instantiate(bestMethod, typ)); // Handle virtual calls - auto baseClass = expr->expr->type->getClass(); - auto vtableName = format("{}.{}", VAR_VTABLE, baseClass->name); + auto vtableName = format("{}.{}", VAR_VTABLE, typ->name); // A function is deemed virtual if it is marked as such and if a base class has a // vtable - bool isVirtual = in(ctx->cache->classes[baseClass->name].virtuals, expr->member); - isVirtual &= ctx->findMember(baseClass->name, vtableName) != nullptr; + bool isVirtual = in(ctx->cache->classes[typ->name].virtuals, expr->member); + isVirtual &= ctx->findMember(typ->name, vtableName) != nullptr; + isVirtual &= !expr->expr->isType(); if (isVirtual && !bestMethod->ast->attributes.has(Attr::StaticMethod) && !bestMethod->ast->attributes.has(Attr::Property)) { // Special case: route the call through a vtable if (realize(expr->type)) { auto fn = expr->type->getFunc(); - auto vid = getRealizationID(expr->expr->type->getClass().get(), fn.get()); + auto vid = getRealizationID(typ.get(), fn.get()); // Function[Tuple[TArg1, TArg2, ...], TRet] std::vector ids; @@ -341,6 +347,8 @@ TypePtr TypecheckVisitor::findSpecialMember(const std::string &member) { return ctx->getType("int"); if (member == "__atomic__") return ctx->getType("bool"); + if (member == "__contents_atomic__") + return ctx->getType("bool"); if (member == "__name__") return ctx->getType("str"); return nullptr; diff --git a/codon/parser/visitors/typecheck/call.cpp b/codon/parser/visitors/typecheck/call.cpp index 4e9f391f..21902156 100644 --- a/codon/parser/visitors/typecheck/call.cpp +++ b/codon/parser/visitors/typecheck/call.cpp @@ -281,8 +281,10 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e }; // Handle reordered arguments (see @c reorderNamedArgs for details) + bool partial = false; auto reorderFn = [&](int starArgIndex, int kwstarArgIndex, - const std::vector> &slots, bool partial) { + const std::vector> &slots, bool _partial) { + partial = _partial; ctx->addBlock(); // add function generics to typecheck default arguments addFunctionGenerics(calleeFn->getFunc().get()); for (size_t si = 0, pi = 0; si < slots.size(); si++) { @@ -410,18 +412,27 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e (!expr->hasAttr(ExprAttr::OrderedCall) && typeArgs.size() == calleeFn->funcGenerics.size()), "bad vector sizes"); - for (size_t si = 0; - !expr->hasAttr(ExprAttr::OrderedCall) && si < calleeFn->funcGenerics.size(); - si++) { - if (typeArgs[si]) { - auto typ = typeArgs[si]->type; - if (calleeFn->funcGenerics[si].type->isStaticType()) { - if (!typeArgs[si]->isStatic()) { - E(Error::EXPECTED_STATIC, typeArgs[si]); + if (!calleeFn->funcGenerics.empty()) { + auto niGenerics = calleeFn->ast->getNonInferrableGenerics(); + for (size_t si = 0; + !expr->hasAttr(ExprAttr::OrderedCall) && si < calleeFn->funcGenerics.size(); + si++) { + if (typeArgs[si]) { + auto typ = typeArgs[si]->type; + if (calleeFn->funcGenerics[si].type->isStaticType()) { + if (!typeArgs[si]->isStatic()) { + E(Error::EXPECTED_STATIC, typeArgs[si]); + } + typ = Type::makeStatic(ctx->cache, typeArgs[si]); + } + unify(typ, calleeFn->funcGenerics[si].type); + } else { + if (calleeFn->funcGenerics[si].type->getUnbound() && + !calleeFn->ast->args[si].defaultValue && !partial && + in(niGenerics, calleeFn->funcGenerics[si].name)) { + error("generic '{}' not provided", calleeFn->funcGenerics[si].niceName); } - typ = Type::makeStatic(ctx->cache, typeArgs[si]); } - unify(typ, calleeFn->funcGenerics[si].type); } } @@ -495,7 +506,10 @@ bool TypecheckVisitor::typecheckCallArgs(const FuncTypePtr &calleeFn, if (calleeFn->ast->args[i].status == Param::Generic) { if (calleeFn->ast->args[i].defaultValue && calleeFn->funcGenerics[j].type->getUnbound()) { + ctx->addBlock(); // add function generics to typecheck default arguments + addFunctionGenerics(calleeFn->getFunc().get()); auto def = transform(clone(calleeFn->ast->args[i].defaultValue)); + ctx->popBlock(); unify(calleeFn->funcGenerics[j].type, def->isStatic() ? Type::makeStatic(ctx->cache, def) : def->getType()); } @@ -543,6 +557,8 @@ std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) return {true, transformHasAttr(expr)}; } else if (val == "getattr") { return {true, transformGetAttr(expr)}; + } else if (val == "setattr") { + return {true, transformSetAttr(expr)}; } else if (val == "type.__new__:0") { return {true, transformTypeFn(expr)}; } else if (val == "compile_error") { @@ -551,10 +567,10 @@ std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) return {true, transformTupleFn(expr)}; } else if (val == "__realized__") { return {true, transformRealizedFn(expr)}; - } else if (val == "__static_print__") { + } else if (val == "std.internal.static.static_print") { return {false, transformStaticPrintFn(expr)}; } else { - return {false, nullptr}; + return transformInternalStaticFn(expr); } } @@ -622,10 +638,13 @@ ExprPtr TypecheckVisitor::transformSuper() { E(Error::CALL_SUPER_PARENT, getSrcInfo()); auto superTyp = ctx->instantiate(vCands[1]->type, typ)->getClass(); - // LOG("-> {}", superTyp); auto self = N(funcTyp->ast->args[0].name); self->type = typ; - return castToSuperClass(self, superTyp, true); + + auto typExpr = N(superTyp->name); + typExpr->setType(superTyp); + return transform(N(N(N("__internal__"), "class_super"), + self, typExpr)); } auto name = cands.front(); // the first inherited type @@ -827,6 +846,18 @@ ExprPtr TypecheckVisitor::transformGetAttr(CallExpr *expr) { return transform(N(expr->args[0].value, staticTyp->evaluate().getString())); } +/// Transform setattr method to a AssignMemberStmt. +ExprPtr TypecheckVisitor::transformSetAttr(CallExpr *expr) { + auto funcTyp = expr->expr->type->getFunc(); + auto staticTyp = funcTyp->funcGenerics[0].type->getStatic(); + if (!staticTyp->canRealize()) + return nullptr; + return transform(N(N(expr->args[0].value, + staticTyp->evaluate().getString(), + expr->args[1].value), + N(N("NoneType")))); +} + /// Raise a compiler error. ExprPtr TypecheckVisitor::transformCompileError(CallExpr *expr) { auto funcTyp = expr->expr->type->getFunc(); @@ -842,11 +873,29 @@ ExprPtr TypecheckVisitor::transformTupleFn(CallExpr *expr) { if (!cls) return nullptr; + // tuple(ClassType) is a tuple type that corresponds to a class + if (expr->args.front().value->isType()) { + if (!realize(cls)) + return expr->clone(); + + std::vector items; + auto tn = generateTuple(ctx->cache->classes[cls->name].fields.size()); + for (auto &ft : ctx->cache->classes[cls->name].fields) { + auto t = ctx->instantiate(ft.type, cls); + auto rt = realize(t); + seqassert(rt, "cannot realize '{}' in {}", t, ft.name); + items.push_back(NT(t->realizedName())); + } + auto e = transform(NT(N(tn), items)); + return e; + } + std::vector args; args.reserve(ctx->cache->classes[cls->name].fields.size()); std::string var = ctx->cache->getTemporaryVar("tup"); for (auto &field : ctx->cache->classes[cls->name].fields) args.emplace_back(N(N(var), field.name)); + return transform(N( N(N(var), expr->args.front().value), N(N(format("{}{}", TYPE_TUPLE, args.size())), args))); @@ -856,6 +905,7 @@ ExprPtr TypecheckVisitor::transformTupleFn(CallExpr *expr) { ExprPtr TypecheckVisitor::transformTypeFn(CallExpr *expr) { expr->markType(); transform(expr->args[0].value); + unify(expr->type, expr->args[0].value->getType()); if (!realize(expr->type)) @@ -887,13 +937,128 @@ ExprPtr TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) { auto &args = expr->args[0].value->getCall()->args; for (size_t i = 0; i < args.size(); i++) { realize(args[i].value->type); - fmt::print(stderr, "[static_print] {}: {} := {}\n", getSrcInfo(), + fmt::print(stderr, "[static_print] {}: {} := {}{}\n", getSrcInfo(), FormatVisitor::apply(args[i].value), - args[i].value->type ? args[i].value->type->debugString(1) : "-"); + args[i].value->type ? args[i].value->type->debugString(1) : "-", + args[i].value->isStatic() ? " [static]" : ""); } return nullptr; } +// Transform internal.static calls +std::pair TypecheckVisitor::transformInternalStaticFn(CallExpr *expr) { + if (expr->expr->isId("std.internal.static.fn_can_call")) { + expr->staticValue.type = StaticValue::INT; + auto typ = expr->args[0].value->getType()->getClass(); + if (!typ) + return {true, nullptr}; + + auto fn = expr->args[0].value->type->getFunc(); + if (!fn) + error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); + + auto inargs = unpackTupleTypes(expr->args[1].value); + auto kwargs = unpackTupleTypes(expr->args[2].value); + seqassert(inargs && kwargs, "bad call to fn_can_call"); + + std::vector callArgs; + for (auto &a : *inargs) { + callArgs.push_back({a.first, std::make_shared()}); // dummy expression + callArgs.back().value->setType(a.second); + } + for (auto &a : *kwargs) { + callArgs.push_back({a.first, std::make_shared()}); // dummy expression + callArgs.back().value->setType(a.second); + } + return {true, transform(N(canCall(fn, callArgs) >= 0))}; + } else if (expr->expr->isId("std.internal.static.fn_arg_has_type")) { + expr->staticValue.type = StaticValue::INT; + auto fn = ctx->extractFunction(expr->args[0].value->type); + if (!fn) + error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); + auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); + seqassert(idx, "expected a static integer"); + auto &args = fn->getArgTypes(); + return {true, transform(N(*idx >= 0 && *idx < args.size() && + args[*idx]->canRealize()))}; + } else if (expr->expr->isId("std.internal.static.fn_arg_get_type")) { + auto fn = ctx->extractFunction(expr->args[0].value->type); + if (!fn) + error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); + auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); + seqassert(idx, "expected a static integer"); + auto &args = fn->getArgTypes(); + if (*idx < 0 || *idx >= args.size() || !args[*idx]->canRealize()) + error("argument does not have type"); + return {true, transform(NT(args[*idx]->realizedName()))}; + } else if (expr->expr->isId("std.internal.static.fn_args")) { + auto fn = ctx->extractFunction(expr->args[0].value->type); + if (!fn) + error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); + std::vector v; + for (size_t i = 0; i < fn->ast->args.size(); i++) { + auto n = fn->ast->args[i].name; + trimStars(n); + n = ctx->cache->rev(n); + v.push_back(N(n)); + } + return {true, transform(N(v))}; + } else if (expr->expr->isId("std.internal.static.fn_has_default")) { + expr->staticValue.type = StaticValue::INT; + auto fn = ctx->extractFunction(expr->args[0].value->type); + if (!fn) + error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); + auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); + seqassert(idx, "expected a static integer"); + auto &args = fn->ast->args; + if (*idx < 0 || *idx >= args.size()) + error("argument out of bounds"); + return {true, transform(N(args[*idx].defaultValue != nullptr))}; + } else if (expr->expr->isId("std.internal.static.fn_get_default")) { + auto fn = ctx->extractFunction(expr->args[0].value->type); + if (!fn) + error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); + auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); + seqassert(idx, "expected a static integer"); + auto &args = fn->ast->args; + if (*idx < 0 || *idx >= args.size()) + error("argument out of bounds"); + return {true, transform(args[*idx].defaultValue)}; + } else if (expr->expr->isId("std.internal.static.fn_wrap_call_args")) { + auto typ = expr->args[0].value->getType()->getClass(); + if (!typ) + return {true, nullptr}; + + auto fn = ctx->extractFunction(expr->args[0].value->type); + if (!fn) + error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); + + std::vector callArgs; + if (auto tup = expr->args[1].value->origExpr->getTuple()) { + for (auto &a : tup->items) { + callArgs.push_back({"", a}); + } + } + if (auto kw = expr->args[1].value->origExpr->getCall()) { + auto kwCls = in(ctx->cache->classes, expr->getType()->getClass()->name); + seqassert(kwCls, "cannot find {}", expr->getType()->getClass()->name); + for (size_t i = 0; i < kw->args.size(); i++) { + callArgs.push_back({kwCls->fields[i].name, kw->args[i].value}); + } + } + auto zzz = transform(N(N(fn->ast->name), callArgs)); + if (!zzz->isDone()) + return {true, nullptr}; + + std::vector tupArgs; + for (auto &a : zzz->getCall()->args) + tupArgs.push_back(a.value); + return {true, transform(N(tupArgs))}; + } else { + return {false, nullptr}; + } +} + /// Get the list that describes the inheritance hierarchy of a given type. /// The first type in the list is the most recently inherited type. std::vector TypecheckVisitor::getSuperTypes(const ClassTypePtr &cls) { diff --git a/codon/parser/visitors/typecheck/class.cpp b/codon/parser/visitors/typecheck/class.cpp index 9e5b781b..fa615dc5 100644 --- a/codon/parser/visitors/typecheck/class.cpp +++ b/codon/parser/visitors/typecheck/class.cpp @@ -26,6 +26,8 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { auto typ = Type::makeType(ctx->cache, stmt->name, ctx->cache->rev(stmt->name), stmt->isRecord()) ->getClass(); + if (stmt->isRecord() && stmt->hasAttr("__notuple__")) + typ->getRecord()->noTuple = true; if (stmt->isRecord() && startswith(stmt->name, TYPE_PARTIAL)) { // Special handling of partial types (e.g., `Partial.0001.foo`) if (auto p = in(ctx->cache->partials, stmt->name)) @@ -90,8 +92,9 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { ctx->typecheckLevel--; // Handle MRO - for (auto &m : ctx->cache->classes[stmt->name].mro) + for (auto &m : ctx->cache->classes[stmt->name].mro) { m = transformType(m); + } // Generalize generics and remove them from the context for (const auto &g : stmt->args) diff --git a/codon/parser/visitors/typecheck/collections.cpp b/codon/parser/visitors/typecheck/collections.cpp index d41ed6a2..6e136491 100644 --- a/codon/parser/visitors/typecheck/collections.cpp +++ b/codon/parser/visitors/typecheck/collections.cpp @@ -199,6 +199,7 @@ ExprPtr TypecheckVisitor::transformComprehension(const std::string &type, /// @example /// `(a1, ..., aN)` -> `Tuple.N.__new__(a1, ..., aN)` void TypecheckVisitor::visit(TupleExpr *expr) { + expr->setType(ctx->getUnbound()); for (int ai = 0; ai < expr->items.size(); ai++) if (auto star = expr->items[ai]->getStar()) { // Case: unpack star expressions (e.g., `*arg` -> `arg.item1, arg.item2, ...`) @@ -226,6 +227,7 @@ void TypecheckVisitor::visit(TupleExpr *expr) { auto tupleName = generateTuple(expr->items.size()); resultExpr = transform(N(N(tupleName, "__new__"), clone(expr->items))); + unify(expr->type, resultExpr->type); } /// Transform a tuple generator expression. @@ -242,17 +244,49 @@ void TypecheckVisitor::visit(GeneratorExpr *expr) { if (!gen->type->canRealize()) return; // Wait until the iterator can be realized - auto tuple = gen->type->getRecord(); - if (!tuple || - !(startswith(tuple->name, TYPE_TUPLE) || startswith(tuple->name, TYPE_KWTUPLE))) - E(Error::CALL_BAD_ITER, gen, gen->type->prettyString()); - auto block = N(); // `tuple = tuple_generator` auto tupleVar = ctx->cache->getTemporaryVar("tuple"); block->stmts.push_back(N(N(tupleVar), gen)); - // `a := tuple[i]` for each i + seqassert(expr->loops[0].vars->getId(), "tuple() not simplified"); + std::vector vars{expr->loops[0].vars->getId()->value}; + auto suiteVec = expr->expr->getStmtExpr() + ? expr->expr->getStmtExpr()->stmts[0]->getSuite() + : nullptr; + auto oldSuite = suiteVec ? suiteVec->clone() : nullptr; + for (int validI = 0; suiteVec && validI < suiteVec->stmts.size(); validI++) { + if (auto a = suiteVec->stmts[validI]->getAssign()) + if (a->rhs && a->rhs->getIndex()) + if (a->rhs->getIndex()->expr->isId(vars[0])) { + vars.push_back(a->lhs->getId()->value); + suiteVec->stmts[validI] = nullptr; + continue; + } + break; + } + if (vars.size() > 1) + vars.erase(vars.begin()); + auto [ok, staticItems] = + transformStaticLoopCall(vars, expr->loops[0].gen, [&](StmtPtr wrap) { + return N(wrap, clone(expr->expr)); + }); + if (ok) { + std::vector tupleItems; + for (auto &i : staticItems) + tupleItems.push_back(std::dynamic_pointer_cast(i)); + resultExpr = transform(N(block, N(tupleItems))); + return; + } else if (oldSuite) { + expr->expr->getStmtExpr()->stmts[0] = oldSuite; + } + + auto tuple = gen->type->getRecord(); + if (!tuple || + !(startswith(tuple->name, TYPE_TUPLE) || startswith(tuple->name, TYPE_KWTUPLE))) + E(Error::CALL_BAD_ITER, gen, gen->type->prettyString()); + + // `a := tuple[i]; expr...` for each i std::vector items; items.reserve(tuple->args.size()); for (int ai = 0; ai < tuple->args.size(); ai++) { @@ -262,7 +296,7 @@ void TypecheckVisitor::visit(GeneratorExpr *expr) { clone(expr->expr))); } - // `((a := tuple[0]), (a := tuple[1]), ...)` + // `((a := tuple[0]; expr), (a := tuple[1]; expr), ...)` resultExpr = transform(N(block, N(items))); } diff --git a/codon/parser/visitors/typecheck/cond.cpp b/codon/parser/visitors/typecheck/cond.cpp index f73d38d3..45f06076 100644 --- a/codon/parser/visitors/typecheck/cond.cpp +++ b/codon/parser/visitors/typecheck/cond.cpp @@ -65,7 +65,7 @@ void TypecheckVisitor::visit(IfExpr *expr) { transform(expr->ifexpr); transform(expr->elsexpr); // Add __bool__ wrapper - if (expr->cond->type->getClass() && !expr->cond->type->is("bool")) + while (expr->cond->type->getClass() && !expr->cond->type->is("bool")) expr->cond = transform(N(N(expr->cond, "__bool__"))); // Add wrappers and unify both sides wrapExpr(expr->elsexpr, expr->ifexpr->getType(), nullptr, /*allowUnwrap*/ false); @@ -96,7 +96,7 @@ void TypecheckVisitor::visit(IfStmt *stmt) { return; } - if (stmt->cond->type->getClass() && !stmt->cond->type->is("bool")) + while (stmt->cond->type->getClass() && !stmt->cond->type->is("bool")) stmt->cond = transform(N(N(stmt->cond, "__bool__"))); ctx->blockLevel++; transform(stmt->ifSuite); diff --git a/codon/parser/visitors/typecheck/ctx.cpp b/codon/parser/visitors/typecheck/ctx.cpp index 97adfec5..76146085 100644 --- a/codon/parser/visitors/typecheck/ctx.cpp +++ b/codon/parser/visitors/typecheck/ctx.cpp @@ -305,4 +305,44 @@ std::string TypeContext::debugInfo() { getRealizationBase()->iteration, getSrcInfo()); } +std::shared_ptr, std::vector>> +TypeContext::getFunctionArgs(types::TypePtr t) { + if (!t->getFunc()) + return nullptr; + auto fn = t->getFunc(); + auto ret = std::make_shared< + std::pair, std::vector>>(); + for (auto &t : fn->funcGenerics) + ret->first.push_back(t.type); + for (auto &t : fn->generics[0].type->getRecord()->args) + ret->second.push_back(t); + return ret; +} + +std::shared_ptr TypeContext::getStaticString(types::TypePtr t) { + if (auto s = t->getStatic()) { + auto r = s->evaluate(); + if (r.type == StaticValue::STRING) + return std::make_shared(r.getString()); + } + return nullptr; +} + +std::shared_ptr TypeContext::getStaticInt(types::TypePtr t) { + if (auto s = t->getStatic()) { + auto r = s->evaluate(); + if (r.type == StaticValue::INT) + return std::make_shared(r.getInt()); + } + return nullptr; +} + +types::FuncTypePtr TypeContext::extractFunction(types::TypePtr t) { + if (auto f = t->getFunc()) + return f; + if (auto p = t->getPartial()) + return p->func; + return nullptr; +} + } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/ctx.h b/codon/parser/visitors/typecheck/ctx.h index 06c085c4..8021c89d 100644 --- a/codon/parser/visitors/typecheck/ctx.h +++ b/codon/parser/visitors/typecheck/ctx.h @@ -157,6 +157,13 @@ struct TypeContext : public Context { void dump(int pad); /// Pretty-print the current realization context. std::string debugInfo(); + +public: + std::shared_ptr, std::vector>> + getFunctionArgs(types::TypePtr t); + std::shared_ptr getStaticString(types::TypePtr t); + std::shared_ptr getStaticInt(types::TypePtr t); + types::FuncTypePtr extractFunction(types::TypePtr t); }; } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/infer.cpp b/codon/parser/visitors/typecheck/infer.cpp index 29ab77ef..2abc3173 100644 --- a/codon/parser/visitors/typecheck/infer.cpp +++ b/codon/parser/visitors/typecheck/infer.cpp @@ -16,6 +16,8 @@ using fmt::format; using namespace codon::error; +const int MAX_TYPECHECK_ITER = 1000; + namespace codon::ast { using namespace types; @@ -49,6 +51,12 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) { for (ctx->getRealizationBase()->iteration = 1;; ctx->getRealizationBase()->iteration++) { + LOG_TYPECHECK("[iter] {} :: {}", ctx->getRealizationBase()->name, + ctx->getRealizationBase()->iteration); + if (ctx->getRealizationBase()->iteration >= MAX_TYPECHECK_ITER) + error(result, "cannot typecheck '{}' in reasonable time", + ctx->cache->rev(ctx->getRealizationBase()->name)); + // Keep iterating until: // (1) success: the statement is marked as done; or // (2) failure: no expression or statements were marked as done during an @@ -215,7 +223,7 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) { return nullptr; } - LOG_TYPECHECK("[realize] ty {} -> {}", realized->name, realized->realizedTypeName()); + LOG_REALIZE("[realize] ty {} -> {}", realized->name, realized->realizedTypeName()); // Realizations should always be visible, so add them to the toplevel ctx->addToplevel(realized->realizedTypeName(), @@ -259,6 +267,16 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) { cls->getContents()->setAttribute( std::make_unique(memberInfo)); } + + // Fix for partial types + if (auto p = type->getPartial()) { + auto pt = std::make_shared(realized->getRecord(), p->func, p->known); + ctx->addToplevel(pt->realizedName(), + std::make_shared(TypecheckItem::Type, pt)); + ctx->cache->classes[pt->name].realizations[pt->realizedName()] = + ctx->cache->classes[realized->name].realizations[realized->realizedTypeName()]; + } + return realized; } @@ -407,6 +425,8 @@ StmtPtr TypecheckVisitor::prepareVTables() { if (!vtable.ir) vtSz += vtable.table.size(); } + if (!vtSz) + continue; auto var = initFn.ast->args[0].name; // p.__setitem__(real.ID) = Ptr[cobj](real.vtables.size() + 2) suite->stmts.push_back(N(N( @@ -469,9 +489,8 @@ StmtPtr TypecheckVisitor::prepareVTables() { auto suite = N(); for (auto &f : fields) if (startswith(f.name, VAR_VTABLE)) { - auto name = f.name.substr(std::string(VAR_VTABLE).size() + 1); suite->stmts.push_back(N( - N(varName), format("{}.{}", VAR_VTABLE, name), + N(varName), f.name, N( N("__vtables__"), N(N(clsTyp->realizedName()), "__vtable_id__")))); @@ -576,8 +595,15 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType std::vector args = fp->getArgTypes(); args[0] = ct; auto m = findBestMethod(ct, fnName, args); - if (!m) - E(Error::DOT_NO_ATTR_ARGS, getSrcInfo(), ct->prettyString(), fnName); + if (!m) { + // Print a nice error message + std::vector a; + for (auto &t : args) + a.emplace_back(fmt::format("{}", t->prettyString())); + std::string argsNice = fmt::format("({})", fmt::join(a, ", ")); + E(Error::DOT_NO_ATTR_ARGS, getSrcInfo(), ct->prettyString(), fnName, + argsNice); + } std::vector ns; for (auto &a : args) diff --git a/codon/parser/visitors/typecheck/loops.cpp b/codon/parser/visitors/typecheck/loops.cpp index 479241da..d415f5e8 100644 --- a/codon/parser/visitors/typecheck/loops.cpp +++ b/codon/parser/visitors/typecheck/loops.cpp @@ -166,62 +166,50 @@ StmtPtr TypecheckVisitor::transformHeterogenousTupleFor(ForStmt *stmt) { /// loop = False # also set to False on break /// A separate suite is generated for each static iteration. StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) { + auto var = stmt->var->getId()->value; + if (!stmt->iter->getCall() || !stmt->iter->getCall()->expr->getId()) + return nullptr; + auto iter = stmt->iter->getCall()->expr->getId(); auto loopVar = ctx->cache->getTemporaryVar("loop"); - auto fn = [&](const std::string &var, const ExprPtr &expr) { - bool staticInt = expr->isStatic(); - auto t = NT( - N("Static"), - N(expr->staticValue.type == StaticValue::INT ? "int" : "str")); + + std::vector vars{var}; + auto suiteVec = stmt->suite->getSuite(); + auto oldSuite = suiteVec ? suiteVec->clone() : nullptr; + for (int validI = 0; suiteVec && validI < suiteVec->stmts.size(); validI++) { + if (auto a = suiteVec->stmts[validI]->getAssign()) + if (a->rhs && a->rhs->getIndex()) + if (a->rhs->getIndex()->expr->isId(var)) { + vars.push_back(a->lhs->getId()->value); + suiteVec->stmts[validI] = nullptr; + continue; + } + break; + } + if (vars.size() > 1) + vars.erase(vars.begin()); + auto [ok, items] = transformStaticLoopCall(vars, stmt->iter, [&](StmtPtr assigns) { auto brk = N(); brk->setDone(); // Avoid transforming this one to continue // var [: Static] := expr; suite... auto loop = N(N(loopVar), - N(N(N(var), expr->clone(), - staticInt ? t : nullptr), - clone(stmt->suite), brk)); + N(assigns, clone(stmt->suite), brk)); loop->gotoVar = loopVar; return loop; - }; - - auto var = stmt->var->getId()->value; - if (!stmt->iter->getCall() || !stmt->iter->getCall()->expr->getId()) - return nullptr; - auto iter = stmt->iter->getCall()->expr->getId(); - auto block = N(); - if (iter && startswith(iter->value, "statictuple:0")) { - auto &args = stmt->iter->getCall()->args[0].value->getCall()->args; - for (size_t i = 0; i < args.size(); i++) - block->stmts.push_back(fn(var, args[i].value)); - } else if (iter && - startswith(iter->value, "std.internal.types.range.staticrange:0")) { - int st = - iter->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt(); - int ed = - iter->type->getFunc()->funcGenerics[1].type->getStatic()->evaluate().getInt(); - int step = - iter->type->getFunc()->funcGenerics[2].type->getStatic()->evaluate().getInt(); - if (abs(st - ed) / abs(step) > MAX_STATIC_ITER) - E(Error::STATIC_RANGE_BOUNDS, iter, MAX_STATIC_ITER, abs(st - ed) / abs(step)); - for (int i = st; step > 0 ? i < ed : i > ed; i += step) - block->stmts.push_back(fn(var, N(i))); - } else if (iter && - startswith(iter->value, "std.internal.types.range.staticrange:1")) { - int ed = - iter->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt(); - if (ed > MAX_STATIC_ITER) - E(Error::STATIC_RANGE_BOUNDS, iter, MAX_STATIC_ITER, ed); - for (int i = 0; i < ed; i++) - block->stmts.push_back(fn(var, N(i))); - } else { + }); + if (!ok) { + if (oldSuite) + stmt->suite = oldSuite; return nullptr; } - ctx->blockLevel++; // Close the loop + ctx->blockLevel++; auto a = N(N(loopVar), N(false)); a->setUpdate(); + auto block = N(); + for (auto &i : items) + block->stmts.push_back(std::dynamic_pointer_cast(i)); block->stmts.push_back(a); - auto loop = transform(N(N(N(loopVar), N(true)), N(N(loopVar), block))); @@ -229,4 +217,119 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) { return loop; } +std::pair>> +TypecheckVisitor::transformStaticLoopCall( + const std::vector &vars, ExprPtr iter, + std::function(StmtPtr)> wrap) { + if (!iter->getCall()) + return {false, {}}; + auto fn = iter->getCall()->expr->getId(); + if (!fn || vars.empty()) + return {false, {}}; + + auto stmt = N(N(vars[0]), nullptr, nullptr); + + std::vector> block; + if (startswith(fn->value, "statictuple:0")) { + auto &args = iter->getCall()->args[0].value->getCall()->args; + if (vars.size() != 1) + error("expected one item"); + for (size_t i = 0; i < args.size(); i++) { + stmt->rhs = args[i].value; + if (stmt->rhs->isStatic()) { + stmt->type = NT( + N("Static"), + N(stmt->rhs->staticValue.type == StaticValue::INT ? "int" : "str")); + } else { + stmt->type = nullptr; + } + block.push_back(wrap(stmt->clone())); + } + } else if (fn && startswith(fn->value, "std.internal.types.range.staticrange:0")) { + if (vars.size() != 1) + error("expected one item"); + int st = + fn->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt(); + int ed = + fn->type->getFunc()->funcGenerics[1].type->getStatic()->evaluate().getInt(); + int step = + fn->type->getFunc()->funcGenerics[2].type->getStatic()->evaluate().getInt(); + if (abs(st - ed) / abs(step) > MAX_STATIC_ITER) + E(Error::STATIC_RANGE_BOUNDS, fn, MAX_STATIC_ITER, abs(st - ed) / abs(step)); + for (int i = st; step > 0 ? i < ed : i > ed; i += step) { + stmt->rhs = N(i); + stmt->type = NT(N("Static"), N("int")); + block.push_back(wrap(stmt->clone())); + } + } else if (fn && startswith(fn->value, "std.internal.types.range.staticrange:1")) { + if (vars.size() != 1) + error("expected one item"); + int ed = + fn->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt(); + if (ed > MAX_STATIC_ITER) + E(Error::STATIC_RANGE_BOUNDS, fn, MAX_STATIC_ITER, ed); + for (int i = 0; i < ed; i++) { + stmt->rhs = N(i); + stmt->type = NT(N("Static"), N("int")); + block.push_back(wrap(stmt->clone())); + } + } else if (fn && startswith(fn->value, "std.internal.static.fn_overloads")) { + if (vars.size() != 1) + error("expected one item"); + if (auto fna = ctx->getFunctionArgs(fn->type)) { + auto [generics, args] = *fna; + auto typ = generics[0]->getClass(); + auto name = ctx->getStaticString(generics[1]); + seqassert(name, "bad static string"); + if (auto n = in(ctx->cache->classes[typ->name].methods, *name)) { + auto &mt = ctx->cache->overloads[*n]; + for (int mti = int(mt.size()) - 1; mti >= 0; mti--) { + auto &method = mt[mti]; + if (endswith(method.name, ":dispatch") || + !ctx->cache->functions[method.name].type) + continue; + if (method.age <= ctx->age) { + if (typ->getHeterogenousTuple()) { + auto &ast = ctx->cache->functions[method.name].ast; + if (ast->hasAttr("autogenerated") && + (endswith(ast->name, ".__iter__:0") || + endswith(ast->name, ".__getitem__:0"))) { + // ignore __getitem__ and other heterogenuous methods + continue; + } + } + stmt->rhs = N(method.name); + block.push_back(wrap(stmt->clone())); + } + } + } + } else { + error("bad call to fn_overloads"); + } + } else if (fn && startswith(fn->value, "std.internal.builtin.staticenumerate")) { + if (vars.size() != 2) + error("expected two items"); + if (auto fna = ctx->getFunctionArgs(fn->type)) { + auto [generics, args] = *fna; + auto typ = args[0]->getRecord(); + if (!typ) + error("staticenumerate needs a tuple"); + for (size_t i = 0; i < typ->args.size(); i++) { + auto b = N( + {N(N(vars[0]), N(i), + NT(NT("Static"), NT("int"))), + N(N(vars[1]), + N(iter->getCall()->args[0].value->clone(), + N(i)))}); + block.push_back(wrap(b)); + } + } else { + error("bad call to staticenumerate"); + } + } else { + return {false, {}}; + } + return {true, block}; +} + } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/op.cpp b/codon/parser/visitors/typecheck/op.cpp index b39553e1..7680839a 100644 --- a/codon/parser/visitors/typecheck/op.cpp +++ b/codon/parser/visitors/typecheck/op.cpp @@ -21,8 +21,10 @@ using namespace types; void TypecheckVisitor::visit(UnaryExpr *expr) { transform(expr->expr); + static std::unordered_map> + staticOps = {{StaticValue::INT, {"-", "+", "!"}}, {StaticValue::STRING, {"@"}}}; // Handle static expressions - if (expr->expr->isStatic()) { + if (expr->expr->isStatic() && in(staticOps[expr->expr->staticValue.type], expr->op)) { resultExpr = evaluateStaticUnary(expr); return; } @@ -569,6 +571,8 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) { unify(expr->type, ctx->getType("bool")); return nullptr; } + if (expr->lexpr->isType() && expr->rexpr->isType()) + return transform(N(lc->realizedName() == rc->realizedName())); if (!lc->getRecord() && !rc->getRecord()) { // Both reference types: `return lhs.__raw__() == rhs.__raw__()` return transform( diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index 7f1131d2..1392d755 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -27,9 +27,8 @@ StmtPtr TypecheckVisitor::apply(Cache *cache, const StmtPtr &stmts) { if (!s) { v.error("cannot typecheck the program"); } - if (s->getSuite()) { + if (s->getSuite()) v.prepareVTables(); - } return s; } @@ -66,6 +65,7 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) { ctx->changedNodes++; } realize(typ); + LOG_TYPECHECK("[expr] {}: {}{}", getSrcInfo(), expr, expr->isDone() ? "[done]" : ""); return expr; } @@ -215,75 +215,118 @@ types::FuncTypePtr TypecheckVisitor::findBestMethod( return m.empty() ? nullptr : m[0]; } -/// Select the best method among the provided methods given the list of arguments. +// Search expression tree for a identifier +class IdSearchVisitor : public CallbackASTVisitor { + std::string what; + bool result; + +public: + IdSearchVisitor(std::string what) : what(std::move(what)), result(false) {} + bool transform(const std::shared_ptr &expr) override { + if (result) + return result; + IdSearchVisitor v(what); + if (expr) + expr->accept(v); + return v.result; + } + bool transform(const std::shared_ptr &stmt) override { + if (result) + return result; + IdSearchVisitor v(what); + if (stmt) + stmt->accept(v); + return v.result; + } + void visit(IdExpr *expr) override { + if (expr->value == what) + result = true; + } +}; + +/// Check if a function can be called with the given arguments. /// See @c reorderNamedArgs for details. -std::vector -TypecheckVisitor::findMatchingMethods(const types::ClassTypePtr &typ, - const std::vector &methods, - const std::vector &args) { - // Pick the last method that accepts the given arguments. - std::vector results; - for (const auto &mi : methods) { - if (!mi) - continue; // avoid overloads that have not been seen yet - auto method = ctx->instantiate(mi, typ)->getFunc(); - std::vector> reordered; - auto score = ctx->reorderNamedArgs( - method.get(), args, - [&](int s, int k, const std::vector> &slots, bool _) { - for (int si = 0; si < slots.size(); si++) { - if (method->ast->args[si].status == Param::Generic) { - if (slots[si].empty()) - reordered.push_back({nullptr, 0}); - else - reordered.push_back({args[slots[si][0]].value->type, slots[si][0]}); - } else if (si == s || si == k || slots[si].size() != 1) { - // Ignore *args, *kwargs and default arguments +int TypecheckVisitor::canCall(const types::FuncTypePtr &fn, + const std::vector &args) { + std::vector> reordered; + auto niGenerics = fn->ast->getNonInferrableGenerics(); + auto score = ctx->reorderNamedArgs( + fn.get(), args, + [&](int s, int k, const std::vector> &slots, bool _) { + for (int si = 0; si < slots.size(); si++) { + if (fn->ast->args[si].status == Param::Generic) { + if (slots[si].empty()) { + // is this "real" type? + if (in(niGenerics, fn->ast->args[si].name) && + !fn->ast->args[si].defaultValue) + return -1; reordered.push_back({nullptr, 0}); } else { reordered.push_back({args[slots[si][0]].value->type, slots[si][0]}); } - } - return 0; - }, - [](error::Error, const SrcInfo &, const std::string &) { return -1; }); - for (int ai = 0, mai = 0, gi = 0; score != -1 && ai < reordered.size(); ai++) { - auto expectTyp = method->ast->args[ai].status == Param::Normal - ? method->getArgTypes()[mai++] - : method->funcGenerics[gi++].type; - auto [argType, argTypeIdx] = reordered[ai]; - if (!argType) - continue; - if (method->ast->args[ai].status != Param::Normal) { - // Check if this is a good generic! - if (expectTyp && expectTyp->isStaticType()) { - if (!args[argTypeIdx].value->isStatic()) { - score = -1; - break; + } else if (si == s || si == k || slots[si].size() != 1) { + // Ignore *args, *kwargs and default arguments + reordered.push_back({nullptr, 0}); } else { - argType = Type::makeStatic(ctx->cache, args[argTypeIdx].value); + reordered.push_back({args[slots[si][0]].value->type, slots[si][0]}); } - } else { - /// TODO: check if these are real types or if traits are satisfied - continue; } - } - try { - ExprPtr dummy = std::make_shared(""); - dummy->type = argType; - dummy->setDone(); - wrapExpr(dummy, expectTyp, method); - types::Type::Unification undo; - if (dummy->type->unify(expectTyp.get(), &undo) >= 0) { - undo.undo(); - } else { + return 0; + }, + [](error::Error, const SrcInfo &, const std::string &) { return -1; }); + for (int ai = 0, mai = 0, gi = 0; score != -1 && ai < reordered.size(); ai++) { + auto expectTyp = fn->ast->args[ai].status == Param::Normal + ? fn->getArgTypes()[mai++] + : fn->funcGenerics[gi++].type; + auto [argType, argTypeIdx] = reordered[ai]; + if (!argType) + continue; + if (fn->ast->args[ai].status != Param::Normal) { + // Check if this is a good generic! + if (expectTyp && expectTyp->isStaticType()) { + if (!args[argTypeIdx].value->isStatic()) { score = -1; + break; + } else { + argType = Type::makeStatic(ctx->cache, args[argTypeIdx].value); } - } catch (const exc::ParserException &) { - // Ignore failed wraps + } else { + /// TODO: check if these are real types or if traits are satisfied + continue; + } + } + try { + ExprPtr dummy = std::make_shared(""); + dummy->type = argType; + dummy->setDone(); + wrapExpr(dummy, expectTyp, fn); + types::Type::Unification undo; + if (dummy->type->unify(expectTyp.get(), &undo) >= 0) { + undo.undo(); + } else { score = -1; } + } catch (const exc::ParserException &) { + // Ignore failed wraps + score = -1; } + } + return score; +} + +/// Select the best method among the provided methods given the list of arguments. +/// See @c reorderNamedArgs for details. +std::vector +TypecheckVisitor::findMatchingMethods(const types::ClassTypePtr &typ, + const std::vector &methods, + const std::vector &args) { + // Pick the last method that accepts the given arguments. + std::vector results; + for (const auto &mi : methods) { + if (!mi) + continue; // avoid overloads that have not been seen yet + auto method = ctx->instantiate(mi, typ)->getFunc(); + int score = canCall(method, args); if (score != -1) { results.push_back(mi); } @@ -307,16 +350,24 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType, const FuncTypePtr &callee, bool allowUnwrap) { auto expectedClass = expectedType->getClass(); auto exprClass = expr->getType()->getClass(); + auto doArgWrap = + !callee || !callee->ast->hasAttr("std.internal.attributes.no_argument_wrap"); + if (!doArgWrap) + return true; + auto doTypeWrap = + !callee || !callee->ast->hasAttr("std.internal.attributes.no_type_wrap"); if (callee && expr->isType()) { auto c = expr->type->getClass(); if (!c) return false; - if (c->getRecord()) - expr = transform(N(expr, N())); - else - expr = transform(N( - N("__internal__.class_ctr:0"), - std::vector{{"T", expr}, {"", N()}})); + if (doTypeWrap) { + if (c->getRecord()) + expr = transform(N(expr, N())); + else + expr = transform(N( + N("__internal__.class_ctr:0"), + std::vector{{"T", expr}, {"", N()}})); + } } std::unordered_set hints = {"Generator", "float", TYPE_OPTIONAL, @@ -419,4 +470,32 @@ ExprPtr TypecheckVisitor::castToSuperClass(ExprPtr expr, ClassTypePtr superTyp, dist, typExpr)); } +/// Unpack a Tuple or KwTuple expression into (name, type) vector. +/// Name is empty when handling Tuple; otherwise it matches names of KwTuple. +std::shared_ptr>> +TypecheckVisitor::unpackTupleTypes(ExprPtr expr) { + auto ret = std::make_shared>>(); + if (auto tup = expr->origExpr->getTuple()) { + for (auto &a : tup->items) { + transform(a); + if (!a->getType()->getClass()) + return nullptr; + ret->push_back({"", a->getType()}); + } + } else if (auto kw = expr->origExpr->getCall()) { // origExpr? + auto kwCls = in(ctx->cache->classes, expr->getType()->getClass()->name); + seqassert(kwCls, "cannot find {}", expr->getType()->getClass()->name); + for (size_t i = 0; i < kw->args.size(); i++) { + auto &a = kw->args[i].value; + transform(a); + if (!a->getType()->getClass()) + return nullptr; + ret->push_back({kwCls->fields[i].name, a->getType()}); + } + } else { + return nullptr; + } + return ret; +} + } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index 8b07c196..b89a0d04 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -139,11 +139,13 @@ class TypecheckVisitor : public CallbackASTVisitor { ExprPtr transformStaticLen(CallExpr *expr); ExprPtr transformHasAttr(CallExpr *expr); ExprPtr transformGetAttr(CallExpr *expr); + ExprPtr transformSetAttr(CallExpr *expr); ExprPtr transformCompileError(CallExpr *expr); ExprPtr transformTupleFn(CallExpr *expr); ExprPtr transformTypeFn(CallExpr *expr); ExprPtr transformRealizedFn(CallExpr *expr); ExprPtr transformStaticPrintFn(CallExpr *expr); + std::pair transformInternalStaticFn(CallExpr *expr); std::vector getSuperTypes(const types::ClassTypePtr &cls); void addFunctionGenerics(const types::FuncType *t); std::string generatePartialStub(const std::vector &mask, types::FuncType *fn); @@ -213,6 +215,7 @@ class TypecheckVisitor : public CallbackASTVisitor { types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ, const std::string &member, const std::vector> &args); + int canCall(const types::FuncTypePtr &, const std::vector &); std::vector findMatchingMethods(const types::ClassTypePtr &typ, const std::vector &methods, @@ -228,6 +231,13 @@ class TypecheckVisitor : public CallbackASTVisitor { friend class Cache; friend class types::CallableTrait; friend class types::UnionType; + +private: // Helpers + std::shared_ptr>> + unpackTupleTypes(ExprPtr); + std::pair>> + transformStaticLoopCall(const std::vector &, ExprPtr, + std::function(StmtPtr)>); }; } // namespace codon::ast diff --git a/codon/runtime/exc.cpp b/codon/runtime/exc.cpp index e5776c85..e75ade06 100644 --- a/codon/runtime/exc.cpp +++ b/codon/runtime/exc.cpp @@ -128,6 +128,7 @@ struct SeqExcHeader_t { seq_str_t file; seq_int_t line; seq_int_t col; + void *python_type; }; void seq_exc_init() { diff --git a/codon/runtime/lib.cpp b/codon/runtime/lib.cpp index f3d148ae..ba99a6cb 100644 --- a/codon/runtime/lib.cpp +++ b/codon/runtime/lib.cpp @@ -165,6 +165,22 @@ SEQ_FUNC void *seq_alloc_atomic(size_t n) { #endif } +SEQ_FUNC void *seq_alloc_uncollectable(size_t n) { +#if USE_STANDARD_MALLOC + return malloc(n); +#else + return GC_MALLOC_UNCOLLECTABLE(n); +#endif +} + +SEQ_FUNC void *seq_alloc_atomic_uncollectable(size_t n) { +#if USE_STANDARD_MALLOC + return malloc(n); +#else + return GC_MALLOC_ATOMIC_UNCOLLECTABLE(n); +#endif +} + SEQ_FUNC void *seq_calloc(size_t m, size_t n) { #if USE_STANDARD_MALLOC return calloc(m, n); diff --git a/codon/util/common.cpp b/codon/util/common.cpp index 7eb1c561..bf08dcf5 100644 --- a/codon/util/common.cpp +++ b/codon/util/common.cpp @@ -2,6 +2,7 @@ #include "common.h" +#include "llvm/Support/Path.h" #include #include #include diff --git a/codon/util/common.h b/codon/util/common.h index 2d3ce908..d656bbf8 100644 --- a/codon/util/common.h +++ b/codon/util/common.h @@ -2,7 +2,6 @@ #pragma once -#include "llvm/Support/Path.h" #include #include #include diff --git a/codon/util/jupyter.cpp b/codon/util/jupyter.cpp new file mode 100644 index 00000000..c9e6a4f9 --- /dev/null +++ b/codon/util/jupyter.cpp @@ -0,0 +1,15 @@ +// Copyright (C) 2022-2023 Exaloop Inc. + +#include "codon/util/jupyter.h" +#include + +namespace codon { +int startJupyterKernel(const std::string &argv0, + const std::vector &plugins, + const std::string &configPath) { + fprintf(stderr, + "Jupyter support not included. Please install Codon Jupyter plugin.\n"); + return EXIT_FAILURE; +} + +} // namespace codon diff --git a/codon/util/jupyter.h b/codon/util/jupyter.h new file mode 100644 index 00000000..f000d420 --- /dev/null +++ b/codon/util/jupyter.h @@ -0,0 +1,12 @@ +// Copyright (C) 2022-2023 Exaloop Inc. + +#pragma once + +#include +#include + +namespace codon { +int startJupyterKernel(const std::string &argv0, + const std::vector &plugins, + const std::string &configPath); +} // namespace codon diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 1ac2ca22..09a72b38 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -25,6 +25,7 @@ * [Python integration](interop/python.md) * [Python decorator](interop/decorator.md) +* [Python extensions](interop/pyext.md) * [C/C++ integration](interop/cpp.md) * [Jupyter integration](interop/jupyter.md) diff --git a/docs/advanced/build.md b/docs/advanced/build.md index f82d0369..0a92b362 100644 --- a/docs/advanced/build.md +++ b/docs/advanced/build.md @@ -15,34 +15,74 @@ cmake -S llvm-project/llvm -B llvm-project/build -G Ninja \ -DLLVM_ENABLE_TERMINFO=OFF \ -DLLVM_TARGETS_TO_BUILD=all cmake --build llvm-project/build -cmake --install llvm-project/build +cmake --install llvm-project/build --prefix=llvm-project/install ``` +You can also add `-DLLVM_ENABLE_PROJECTS=clang` if you do not have `clang` installed +on your system. We also recommend setting a local prefix during installation to +avoid clashes with the system LLVM. + # Build The following can generally be used to build Codon. The build process will automatically download and build several smaller dependencies. -``` bash +```bash cmake -S . -B build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_DIR=$(llvm-config --cmakedir) \ -DCMAKE_C_COMPILER=clang \ -DCMAKE_CXX_COMPILER=clang++ cmake --build build --config Release +cmake --install build --prefix=install ``` -This will produce the `codon` executable in the `build` directory, as -well as `codon_test` which runs the test suite. Additionally, a number -of shared libraries are produced: +This will produce the `codon` executable in the `install/bin` directory, as +well as `codon_test` in the `build` directory which runs the test suite. +Additionally, a number of shared libraries are produced in `install/lib/codon`: -- `libcodonc`: The compiler library used by the `codon` compiler. +- `libcodonc`: The compiler library used by the `codon` command-line tool. - `libcodonrt`: The runtime library used during execution. - `libomp`: OpenMP runtime used to execute parallel code. -# Build options +{% hint style="warning" %} +Make sure the `llvm-config` being used corresponds to Codon's LLVM. You can also use +`-DLLVM_DIR=llvm-project/install/lib/cmake/llvm` on the first `cmake` command if you +followed the instructions above for compiling LLVM. +{% endhint %} + +# GPU support + +Add `-DCODON_GPU=ON` to the first `cmake` command above to enable GPU support. + +# Jupyter support + +To enable Jupyter support, you will need to build the Jupyter plugin: + +```bash +# Linux version: +cmake -S jupyter -B jupyter/build \ + -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DLLVM_DIR=$(llvm-config --cmakedir) \ + -DCODON_PATH=install \ + -DOPENSSL_ROOT_DIR=$(openssl version -d | cut -d' ' -f2 | tr -d '"') \ + -DOPENSSL_CRYPTO_LIBRARY=/usr/lib64/libssl.so \ + -DXEUS_USE_DYNAMIC_UUID=ON +# n.b. OPENSSL_CRYPTO_LIBRARY might differ on your system. -The following additional flags can be passed to CMake: +# On macOS, do this instead: +OPENSSL_ROOT_DIR=/usr/local/opt/openssl cmake -S jupyter -B jupyter/build \ + -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DLLVM_DIR=$(llvm-config --cmakedir) \ + -DCODON_PATH=install -- *CODON_JUPYTER* = `ON|OFF`: Enable or disable Jupyter support (default: `OFF`) -- *CODON_GPU* = `ON|OFF`: Enable or disable GPU support; requires CUDA (default: `OFF`) +# Then: +cmake --build jupyter/build +cmake --install jupyter/build +``` diff --git a/docs/interop/pyext.md b/docs/interop/pyext.md new file mode 100644 index 00000000..a62976ef --- /dev/null +++ b/docs/interop/pyext.md @@ -0,0 +1,189 @@ +Codon includes a build mode called `pyext` for generating +[Python extensions](https://docs.python.org/3/extending/extending.html) +(which are traditionally written in C, C++ or Cython): + +``` bash +codon build -pyext extension.codon # add -release to enable optimizations +``` + +`codon build -pyext` accepts the following options: + +- `-o `: Writes the compilation result to the specified file. +- `-module `: Specifies the generated Python module's name. + +{% hint style="warning" %} +It is recommended to use the `pyext` build mode with Python versions 3.9 +and up. +{% endhint %} + +# Functions + +Extension functions written in Codon should generally be fully typed: + +``` python +def foo(a: int, b: float, c: str): # return type will be deduced + return a * b + float(c) +``` + +The `pyext` build mode will automatically generate all the necessary wrappers +and hooks for converting a function written in Codon into a function that's +callable from Python. + +Function arguments that are not explicitly typed will be treated as generic +Python objects, and operated on through the CPython API. + +Function overloads are also possible in Codon: + +``` python +def bar(x: int): + return x + 2 + +@overload +def bar(x: str): + return x * 2 +``` + +This will result in a single Python function `bar()` that dispatches to the +correct Codon `bar()` at runtime based on the argument's type (or raise a +`TypeError` on an invalid input type). + +# Types + +Codon class definitions can also be converted to Python extension types via +the `@dataclass(python=True)` decorator: + +``` python +@dataclass(python=True) +class Vec: + x: float + y: float + + def __init__(self, x: float = 0.0, y: float = 0.0): + self.x = x + self.y = y + + def __add__(self, other: Vec): + return Vec(self.x + other.x, self.y + other.y) + + def __add__(self, other: float): + return Vec(self.x + other, self.y + other) + + def __repr__(self): + return f'Vec({self.x}, {self.y})' +``` + +Now in Python (assuming we compile to a module `vec`): + +``` python +from vec import Vec + +a = Vec(x=3.0, y=4.0) # Vec(3.0, 4.0) +b = a + Vec(1, 2) # Vec(4.0, 6.0) +c = b + 10.0 # Vec(14.0, 16.0) +``` + +# Building with `setuptools` + +Codon's `pyext` build mode can be used with `setuptools`. Here is a minimal example: + +``` python +# setup.py +import os +import sys +import shutil +from pathlib import Path +from setuptools import setup, Extension +from setuptools.command.build_ext import build_ext + +# Find Codon +codon_path = os.environ.get('CODON_DIR') +if not codon_path: + c = shutil.which('codon') + if c: + codon_path = Path(c).parent / '..' +else: + codon_path = Path(codon_path) +for path in [ + os.path.expanduser('~') + '/.codon', + os.getcwd() + '/..', +]: + path = Path(path) + if not codon_path and path.exists(): + codon_path = path + break + +if ( + not codon_path + or not (codon_path / 'include' / 'codon').exists() + or not (codon_path / 'lib' / 'codon').exists() +): + print( + 'Cannot find Codon.', + 'Please either install Codon (https://github.com/exaloop/codon),', + 'or set CODON_DIR if Codon is not in PATH.', + file=sys.stderr, + ) + sys.exit(1) +codon_path = codon_path.resolve() +print('Found Codon:', str(codon_path)) + +# Build with Codon +class CodonExtension(Extension): + def __init__(self, name, source): + self.source = source + super().__init__(name, sources=[], language='c') + +class BuildCodonExt(build_ext): + def build_extensions(self): + pass + + def run(self): + inplace, self.inplace = self.inplace, False + super().run() + for ext in self.extensions: + self.build_codon(ext) + if inplace: + self.copy_extensions_to_source() + + def build_codon(self, ext): + extension_path = Path(self.get_ext_fullpath(ext.name)) + build_dir = Path(self.build_temp) + os.makedirs(build_dir, exist_ok=True) + os.makedirs(extension_path.parent.absolute(), exist_ok=True) + + codon_cmd = str(codon_path / 'bin' / 'codon') + optimization = '-debug' if self.debug else '-release' + self.spawn([codon_cmd, 'build', optimization, '--relocation-model=pic', '-pyext', + '-o', str(extension_path) + ".o", '-module', ext.name, ext.source]) + + ext.runtime_library_dirs = [str(codon_path / 'lib' / 'codon')] + self.compiler.link_shared_object( + [str(extension_path) + '.o'], + str(extension_path), + libraries=['codonrt'], + library_dirs=ext.runtime_library_dirs, + runtime_library_dirs=ext.runtime_library_dirs, + extra_preargs=['-Wl,-rpath,@loader_path'], + debug=self.debug, + build_temp=self.build_temp, + ) + self.distribution.codon_lib = extension_path + +setup( + name='mymodule', + version='0.1', + packages=['mymodule'], + ext_modules=[ + CodonExtension('mymodule', 'mymodule.codon'), + ], + cmdclass={'build_ext': BuildCodonExt} +) +``` + +Then, for example, we can build with: + +``` bash +python3 setup.py build_ext --inplace +``` + +Finally, we can `import mymodule` in Python and use the module. diff --git a/docs/interop/python.md b/docs/interop/python.md index 3994d201..25c192ba 100644 --- a/docs/interop/python.md +++ b/docs/interop/python.md @@ -39,6 +39,16 @@ print(multiply(3, 4)) # 12 (Be sure the `PYTHONPATH` environment variable includes the path of *mymodule.py*!) +`from python import` does not need to specify explicit types, in which case +Codon will operate directly on the Python objects, and convert Codon types +to Python types as necessary: + +``` python +from python import numpy as np # Codon will call NumPy through CPython's API +x = np.array([1, 2, 3, 4]) * 10 +print(x) # [10 20 30 40] +``` + # `@python` Codon programs can contain functions that will be executed by Python via @@ -62,3 +72,24 @@ def myrange(n: int) -> List[int]: print(myrange(5)) # [0, 1, 2, 3, 4] ``` + +# Data conversions + +Codon uses two new magic methods to transfer data to and from Python: + +- `__to_py__`: Produces a Python object (`PyObject*` in C) given a Codon object. +- `__from_py__`: Produces a Codon object given a Python object. + +``` python +import python # needed to initialize the Python runtime + +o = (42).__to_py__() # type of 'o' is 'cobj', equivalent to a pointer in C +print(o) # 0x100e00610 + +n = int.__from_py__(o) +print(n) # 42 +``` + +Codon stores the results of `__to_py__` calls by wrapping them in an instance of +a new class called `pyobj`, which correctly handles the underlying Python object's +reference count. All operations on `pyobj`s then go through CPython's API. diff --git a/docs/intro/faq.md b/docs/intro/faq.md index 5ac32284..111102b0 100644 --- a/docs/intro/faq.md +++ b/docs/intro/faq.md @@ -33,6 +33,10 @@ handle cases where specific Python libraries or dynamism are required. Codon differs in a few places in order to eliminate any dynamic runtime or virtual machine, and thereby attain much better performance. +- **Cython?** Like Cython, Codon has a [Python-extension build mode](../interop/pyext.md) that + compiles to Python extension modules, allowing Codon-compiled code to be imported and called + from plain Python. + - **C++?** Codon often generates the same code as an equivalent C or C++ program. Codon can sometimes generate *better* code than C/C++ compilers for a variety of reasons, such as better container implementations, the fact that Codon does not use object files and diff --git a/docs/intro/releases.md b/docs/intro/releases.md index 46413077..9a6d1fed 100644 --- a/docs/intro/releases.md +++ b/docs/intro/releases.md @@ -2,6 +2,40 @@ Below you can find release notes for each major Codon release, listing improvements, updates, optimizations and more for each new version. +These release notes generally do not include small bug fixes. See the +[closed issues](https://github.com/exaloop/codon/issues?q=is%3Aissue+is%3Aclosed)) +for more information. + +# v0.16 + +## Python extensions + +A new build mode is added to `codon` called `pyext` which compiles +to Python extension modules, allowing Codon code to be imported and +called directly from Python (similar to Cython). Please see the +[docs](../interop/pyext.md) for more information and usage examples. + +## Standard library updates + +- Various additions to the standard library, such as `math.fsum()` and + the built-in `pow()`. + +- Added `complex64`, which is a complex number with 32-bit float real and + imaginary components. + +- Better `Int[N]` and `UInt[N]` support: can now convert ints wider than + 64-bit to string; now supports more operators. + +## More Python-specific optimizations + +New optimizations for specific patterns including `any()`/`all()` and +multiple list concatenations. These patterns are now recognized and +optimized in Codon's IR. + +## Static expressions + +Codon now supports more compile-time static functions, such as `staticenumerate`. + # v0.15 ## Union types diff --git a/jupyter/CMakeLists.txt b/jupyter/CMakeLists.txt new file mode 100644 index 00000000..3b42fc94 --- /dev/null +++ b/jupyter/CMakeLists.txt @@ -0,0 +1,101 @@ +cmake_minimum_required(VERSION 3.14) +project( + CodonJupyter + VERSION "0.1" + HOMEPAGE_URL "https://github.com/exaloop/codon" + DESCRIPTION "Jupyter support for Codon") + +if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") + cmake_policy(SET CMP0135 NEW) +endif() + +if(NOT CODON_PATH) + set(CODON_PATH "$ENV{HOME}/.codon") +endif() +message(STATUS "Found Codon in ${CODON_PATH}") +if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) + set(CMAKE_INSTALL_PREFIX "${CODON_PATH}/lib/codon/" CACHE PATH "Use the existing Codon installation" FORCE) +endif(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) + + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") +if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + set(CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS} -pedantic -fvisibility-inlines-hidden -Wno-return-type-c-linkage -Wno-gnu-zero-variadic-macro-arguments -Wno-deprecated-declarations" + ) +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-return-type") +endif() +set(CMAKE_CXX_FLAGS_DEBUG "-g") +if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fno-limit-debug-info") +endif() +set(CMAKE_CXX_FLAGS_RELEASE "-O3") +include_directories(.) + +find_package(LLVM REQUIRED CONFIG) +separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) +add_definitions(${LLVM_DEFINITIONS_LIST}) + +set(CPM_DOWNLOAD_VERSION 0.32.3) +set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake") +if(NOT (EXISTS ${CPM_DOWNLOAD_LOCATION})) + message(STATUS "Downloading CPM.cmake...") + file(DOWNLOAD https://github.com/TheLartians/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake ${CPM_DOWNLOAD_LOCATION}) +endif() +include(${CPM_DOWNLOAD_LOCATION}) +CPMAddPackage( + NAME libzmq + VERSION 4.3.4 + URL https://github.com/zeromq/libzmq/releases/download/v4.3.4/zeromq-4.3.4.tar.gz + EXCLUDE_FROM_ALL YES + OPTIONS "WITH_PERF_TOOL OFF" + "ZMQ_BUILD_TESTS OFF" + "ENABLE_CPACK OFF" + "BUILD_SHARED ON" + "WITH_LIBSODIUM OFF" + "WITH_TLS OFF") +CPMAddPackage( + NAME cppzmq + URL https://github.com/zeromq/cppzmq/archive/refs/tags/v4.8.1.tar.gz + VERSION 4.8.1 + EXCLUDE_FROM_ALL YES + OPTIONS "CPPZMQ_BUILD_TESTS OFF") +CPMAddPackage( + NAME xtl + GITHUB_REPOSITORY "xtensor-stack/xtl" + VERSION 0.7.3 + GIT_TAG 0.7.3 + EXCLUDE_FROM_ALL YES + OPTIONS "BUILD_TESTS OFF") +CPMAddPackage( + NAME json + GITHUB_REPOSITORY "nlohmann/json" + VERSION 3.10.1) +CPMAddPackage( + NAME xeus + GITHUB_REPOSITORY "jupyter-xeus/xeus" + VERSION 2.2.0 + GIT_TAG 2.2.0 + EXCLUDE_FROM_ALL YES + PATCH_COMMAND patch -N -u CMakeLists.txt --ignore-whitespace -b ${CMAKE_SOURCE_DIR}/xeus.patch || true + OPTIONS "BUILD_EXAMPLES OFF" + "XEUS_BUILD_SHARED_LIBS OFF" + "XEUS_STATIC_DEPENDENCIES ON" + "CMAKE_POSITION_INDEPENDENT_CODE ON" + "XEUS_DISABLE_ARCH_NATIVE ON" + "XEUS_USE_DYNAMIC_UUID ${XEUS_USE_DYNAMIC_UUID}") +if (xeus_ADDED) + install(TARGETS nlohmann_json EXPORT xeus-targets) +endif() + +# Codon Jupyter library +set(CODON_JUPYTER_FILES jupyter.h jupyter.cpp) +add_library(codon_jupyter SHARED ${CODON_JUPYTER_FILES}) +target_include_directories(codon_jupyter PRIVATE "${CODON_PATH}/include" ${LLVM_INCLUDE_DIRS}) +add_dependencies(codon_jupyter xeus-static nlohmann_json) +target_link_directories(codon_jupyter PRIVATE "${CODON_PATH}/lib/codon") +target_link_libraries(codon_jupyter PRIVATE xeus-static codonc) + +install(TARGETS codon_jupyter DESTINATION .) diff --git a/extra/jupyter/jupyter.cpp b/jupyter/jupyter.cpp similarity index 98% rename from extra/jupyter/jupyter.cpp rename to jupyter/jupyter.cpp index 3a82d0f6..3c62243c 100644 --- a/extra/jupyter/jupyter.cpp +++ b/jupyter/jupyter.cpp @@ -1,8 +1,7 @@ -// Copyright (C) 2022 Exaloop Inc. +// Copyright (C) 2022-2023 Exaloop Inc. #include "jupyter.h" -#ifdef CODON_JUPYTER #include #include #include @@ -141,4 +140,3 @@ int startJupyterKernel(const std::string &argv0, } } // namespace codon -#endif diff --git a/extra/jupyter/jupyter.h b/jupyter/jupyter.h similarity index 93% rename from extra/jupyter/jupyter.h rename to jupyter/jupyter.h index 8959e171..b379e00f 100644 --- a/extra/jupyter/jupyter.h +++ b/jupyter/jupyter.h @@ -1,7 +1,6 @@ -// Copyright (C) 2022 Exaloop Inc. +// Copyright (C) 2022-2023 Exaloop Inc. #pragma once -#ifdef CODON_JUPYTER #include #include #include @@ -42,4 +41,3 @@ int startJupyterKernel(const std::string &argv0, const std::string &configPath); } // namespace codon -#endif diff --git a/jupyter/share/jupyter/kernels/codon/kernel.json b/jupyter/share/jupyter/kernels/codon/kernel.json new file mode 100644 index 00000000..0651d080 --- /dev/null +++ b/jupyter/share/jupyter/kernels/codon/kernel.json @@ -0,0 +1,9 @@ +{ + "display_name": "Codon", + "argv": [ + "/usr/local/bin/", + "jupyter", + "{connection_file}" + ], + "language": "python" +} diff --git a/extra/jupyter/share/jupyter/kernels/codon/kernel.json.in b/jupyter/share/jupyter/kernels/codon/kernel.json.in similarity index 100% rename from extra/jupyter/share/jupyter/kernels/codon/kernel.json.in rename to jupyter/share/jupyter/kernels/codon/kernel.json.in diff --git a/cmake/xeus.patch b/jupyter/xeus.patch similarity index 100% rename from cmake/xeus.patch rename to jupyter/xeus.patch diff --git a/scripts/Dockerfile.codon-jupyter b/scripts/Dockerfile.codon-jupyter new file mode 100644 index 00000000..fc93bd7d --- /dev/null +++ b/scripts/Dockerfile.codon-jupyter @@ -0,0 +1,52 @@ +FROM exaloop/codon-llvm:15.0.1 +ENV pass="codon-jupyter" + +# Install dependencies +RUN yum -y install openssl-devel libsodium-devel libuuid-devel + +# Build Codon core +RUN git clone -b pyext https://github.com/exaloop/codon /github/codon +RUN cmake3 -S /github/codon -B /github/codon/build \ + -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=/opt/llvm-codon/bin/clang \ + -DCMAKE_CXX_COMPILER=/opt/llvm-codon/bin/clang++ \ + -DLLVM_DIR=/opt/llvm-codon/lib/cmake/llvm \ + -DCMAKE_INSTALL_PREFIX=/opt/codon +RUN cmake3 --build /github/codon/build +RUN cmake3 --install /github/codon/build + +# Build Codon Jupyter support +RUN cmake3 -S /github/codon/jupyter -B /github/codon/jupyter/build \ + -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=/opt/llvm-codon/bin/clang \ + -DCMAKE_CXX_COMPILER=/opt/llvm-codon/bin/clang++ \ + -DLLVM_DIR=/opt/llvm-codon/lib/cmake/llvm \ + -DCODON_PATH=/opt/codon \ + -DOPENSSL_ROOT_DIR=$(openssl version -d | cut -d' ' -f2 | tr -d '"') \ + -DOPENSSL_CRYPTO_LIBRARY=/usr/lib64/libssl.so \ + -DXEUS_USE_DYNAMIC_UUID=ON +RUN cmake3 --build /github/codon/jupyter/build +RUN cmake3 --install /github/codon/jupyter/build + +# Build Seq (bioinformatics plugin) for Codon +RUN git clone -b develop https://github.com/exaloop/seq /github/seq +RUN cmake3 -S /github/seq -B /github/seq/build \ + -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=/opt/llvm-codon/bin/clang \ + -DCMAKE_CXX_COMPILER=/opt/llvm-codon/bin/clang++ \ + -DLLVM_DIR=/opt/llvm-codon/lib/cmake/llvm \ + -DCODON_PATH=/opt/codon +RUN cmake3 --build /github/seq/build +RUN cmake3 --install /github/seq/build + +# Set up Codon Jupyter kernel +RUN pip3 install ipywidgets==7.6.5 numpy matplotlib pandas scipy jupyter plotly +RUN mkdir -p /usr/share/jupyter/kernels/codon +RUN echo '{"display_name": "Codon", "argv": [ "/opt/codon/bin/codon", "jupyter", "-plugin", "seq", "{connection_file}" ], "language": "python"}' > /usr/share/jupyter/kernels/codon/kernel.json + +# Launch Jupyter +ENV CODON_PYTHON="/usr/lib64/libpython3.so" +CMD jupyter notebook --port=8888 --no-browser --ip=0.0.0.0 --allow-root --NotebookApp.token=${pass} diff --git a/scripts/Dockerfile.llvm-build b/scripts/Dockerfile.llvm-build index e2582216..8e145637 100644 --- a/scripts/Dockerfile.llvm-build +++ b/scripts/Dockerfile.llvm-build @@ -1,7 +1,10 @@ FROM quay.io/pypa/manylinux2014_x86_64 RUN yum -y update -RUN yum -y install devtoolset-7 ninja-build cmake3 +RUN yum -y install devtoolset-7 ninja-build cmake3 bzip2-devel +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --upgrade twine setuptools wheel + RUN mkdir -p /opt/llvm-codon RUN git clone --depth 1 -b codon https://github.com/exaloop/llvm-project /github/llvm-src RUN scl enable devtoolset-7 -- \ @@ -19,4 +22,5 @@ RUN scl enable devtoolset-7 -- \ RUN scl enable devtoolset-7 -- cmake3 --build /github/llvm-src/llvm/build RUN cmake3 --install /github/llvm-src/llvm/build RUN cd /github/llvm-src && tar cjvf /opt/llvm-codon-$(git rev-parse --short HEAD).tar.bz2 -C /opt llvm-codon/ -CMD cp /opt/llvm-codon-*.tar.bz2 /mnt/ +RUN cp /opt/llvm-codon-*.tar.bz2 /mnt/ +CMD echo "done" diff --git a/stdlib/internal/__init__.codon b/stdlib/internal/__init__.codon index fcf53831..21b15a76 100644 --- a/stdlib/internal/__init__.codon +++ b/stdlib/internal/__init__.codon @@ -3,6 +3,7 @@ # Core library from internal.attributes import * +from internal.static import static_print as __static_print__ from internal.types.ptr import * from internal.types.str import * from internal.types.int import * @@ -46,3 +47,5 @@ import internal.python if __py_numerics__: import internal.pynumerics +if __py_extension__: + internal.python.ensure_initialized() diff --git a/stdlib/internal/attributes.codon b/stdlib/internal/attributes.codon index cdc50612..33143109 100644 --- a/stdlib/internal/attributes.codon +++ b/stdlib/internal/attributes.codon @@ -32,6 +32,10 @@ def no_side_effect(): def nocapture(): pass +@__attribute__ +def pycapture(): + pass + @__attribute__ def derives(): pass @@ -59,3 +63,11 @@ def realize_without_self(): @__attribute__ def virtual(): pass + +@__attribute__ +def no_argument_wrap(): + pass + +@__attribute__ +def no_type_wrap(): + pass diff --git a/stdlib/internal/builtin.codon b/stdlib/internal/builtin.codon index be18e3cf..46854ecf 100644 --- a/stdlib/internal/builtin.codon +++ b/stdlib/internal/builtin.codon @@ -221,6 +221,11 @@ def enumerate(x, start: int = 0): yield (i, a) i += 1 +def staticenumerate(tup): + i = -1 + return tuple(((i := i + 1), t) for t in tup) + i + def echo(x): """ Print and return argument @@ -318,6 +323,29 @@ def oct(n): def hex(n): return _int_format(n.__index__(), 16, "0x") +def pow(base: float, exp: float): + return base ** exp + +@overload +def pow(base: int, exp: int, mod: Optional[int] = None): + if exp < 0: + raise ValueError("pow() negative int exponent not supported") + + if mod is not None: + if mod == 0: + raise ValueError("pow() 3rd argument cannot be 0") + base %= mod + + result = 1 + while exp > 0: + if exp & 1: + x = result * base + result = x % mod if mod is not None else x + y = base * base + base = y % mod if mod is not None else y + exp >>= 1 + return result % mod if mod is not None else result + @extend class int: def _from_str(s: str, base: int): diff --git a/stdlib/internal/c_stubs.codon b/stdlib/internal/c_stubs.codon index 0660c31c..6a292850 100644 --- a/stdlib/internal/c_stubs.codon +++ b/stdlib/internal/c_stubs.codon @@ -683,6 +683,11 @@ def gztell(a: cobj) -> int: def gzseek(a: cobj, b: int, c: i32) -> int: pass +@nocapture +@C +def gzflush(a: cobj, b: i32) -> i32: + pass + # @pure @C diff --git a/stdlib/internal/core.codon b/stdlib/internal/core.codon index d2927cc4..5cca77db 100644 --- a/stdlib/internal/core.codon +++ b/stdlib/internal/core.codon @@ -6,28 +6,33 @@ class __internal__: @tuple @__internal__ +@__notuple__ class bool: pass @tuple @__internal__ +@__notuple__ class byte: pass @tuple @__internal__ +@__notuple__ class int: MAX = 9223372036854775807 pass @tuple @__internal__ +@__notuple__ class float: MIN_10_EXP = -307 pass @tuple @__internal__ +@__notuple__ class float32: MIN_10_EXP = -37 pass @@ -44,6 +49,7 @@ class type: @tuple @__internal__ +@__notuple__ class Function[T, TR]: pass @@ -54,27 +60,32 @@ class Callable[T, TR]: @tuple @__internal__ +@__notuple__ class Ptr[T]: pass cobj = Ptr[byte] @tuple @__internal__ +@__notuple__ class Generator[T]: pass @tuple @__internal__ +@__notuple__ class Optional: T: type = NoneType @tuple @__internal__ +@__notuple__ class Int[N: Static[int]]: pass @tuple @__internal__ +@__notuple__ class UInt[N: Static[int]]: pass @@ -105,8 +116,9 @@ function = Function class Ref[T]: pass -@__internal__ @tuple +@__internal__ +@__notuple__ class Union[TU]: # compiler-generated def __new__(val): @@ -166,6 +178,9 @@ def hasattr(obj, attr: Static[str]): def getattr(obj, attr: Static[str]): pass +def setattr(obj, attr: Static[str], what): + pass + def tuple(iterable): pass @@ -181,6 +196,3 @@ def __realized__(fn, args): def statictuple(*args): return args - -def __static_print__(*args): - pass diff --git a/stdlib/internal/file.codon b/stdlib/internal/file.codon index b67d3611..4c0f1560 100644 --- a/stdlib/internal/file.codon +++ b/stdlib/internal/file.codon @@ -59,15 +59,18 @@ class File: return str(buf, ret) def tell(self) -> int: + self._ensure_open() ret = _C.ftell(self.fp) self._errcheck("error in tell") return ret def seek(self, offset: int, whence: int): + self._ensure_open() _C.fseek(self.fp, offset, i32(whence)) self._errcheck("error in seek") def flush(self): + self._ensure_open() _C.fflush(self.fp) def close(self): @@ -198,14 +201,22 @@ class gzFile: return str(buf, int(ret)) def tell(self) -> int: + self._ensure_open() ret = _C.gztell(self.fp) _gz_errcheck(self.fp) return ret def seek(self, offset: int, whence: int): + self._ensure_open() _C.gzseek(self.fp, offset, i32(whence)) _gz_errcheck(self.fp) + def flush(self): + Z_FINISH = 4 + self._ensure_open() + _C.gzflush(self.fp, i32(Z_FINISH)) + _gz_errcheck(self.fp) + def _iter(self) -> Generator[str]: self._ensure_open() while True: diff --git a/stdlib/internal/format.codon b/stdlib/internal/format.codon index 9c3b6251..b47acf5e 100644 --- a/stdlib/internal/format.codon +++ b/stdlib/internal/format.codon @@ -56,3 +56,79 @@ class Ptr: if format_spec and err: _format_error(ret) return ret + +def _divmod_10(dividend, N: Static[int]): + T = type(dividend) + zero, one = T(0), T(1) + neg = dividend < zero + dvd = dividend.__abs__() + + remainder = 0 + quotient = zero + + # Euclidean division + for bit_idx in range(N - 1, -1, -1): + mask = int((dvd & (one << T(bit_idx))) != zero) + remainder = (remainder << 1) + mask + if remainder >= 10: + quotient = (quotient << one) + one + remainder -= 10 + else: + quotient = quotient << one + + if neg: + quotient = -quotient + remainder = -remainder + + return quotient, remainder + +@extend +class Int: + def __str__(self) -> str: + if N <= 64: + return str(int(self)) + + if not self: + return '0' + + s = _strbuf() + d = self + + if d >= Int[N](0): + while True: + d, m = _divmod_10(d, N) + b = byte(48 + m) # 48 == ord('0') + s.append(str(__ptr__(b), 1)) + if not d: + break + else: + while True: + d, m = _divmod_10(d, N) + b = byte(48 - m) # 48 == ord('0') + s.append(str(__ptr__(b), 1)) + + if not d: + break + s.append('-') + + s.reverse() + return s.__str__() + +@extend +class UInt: + def __str__(self) -> str: + if N <= 64: + return self.__format__("") + + s = _strbuf() + d = self + + while True: + d, m = _divmod_10(d, N) + b = byte(48 + int(m)) # 48 == ord('0') + s.append(str(__ptr__(b), 1)) + if not d: + break + + s.reverse() + return s.__str__() diff --git a/stdlib/internal/gc.codon b/stdlib/internal/gc.codon index 21ea5ed8..47f57fbd 100644 --- a/stdlib/internal/gc.codon +++ b/stdlib/internal/gc.codon @@ -11,6 +11,16 @@ def seq_alloc(a: int) -> cobj: def seq_alloc_atomic(a: int) -> cobj: pass +@pure +@C +def seq_alloc_uncollectable(a: int) -> cobj: + pass + +@pure +@C +def seq_alloc_atomic_uncollectable(a: int) -> cobj: + pass + @nocapture @derives @C @@ -61,6 +71,16 @@ def alloc(sz: int): def alloc_atomic(sz: int): return seq_alloc_atomic(sz) +# Allocates a block of memory via GC that is scanned, +# but not collected itself. Should be free'd explicitly. +def alloc_uncollectable(sz: int): + return seq_alloc_uncollectable(sz) + +# Allocates a block of memory via GC that is scanned, +# but not collected itself. Should be free'd explicitly. +def alloc_atomic_uncollectable(sz: int): + return seq_alloc_atomic_uncollectable(sz) + def realloc(p: cobj, newsz: int, oldsz: int): return seq_realloc(p, newsz, oldsz) diff --git a/stdlib/internal/internal.codon b/stdlib/internal/internal.codon index 36fd0802..4388c600 100644 --- a/stdlib/internal/internal.codon +++ b/stdlib/internal/internal.codon @@ -1,34 +1,35 @@ # Copyright (C) 2022-2023 Exaloop Inc. -from internal.gc import free, register_finalizer, seq_alloc, seq_alloc_atomic, seq_gc_add_roots - -@pure -@C -def seq_check_errno() -> str: - pass - -from C import seq_print(str) -from C import exit(int) -from C import malloc(int) -> cobj as c_malloc +from internal.gc import ( + alloc, alloc_atomic, alloc_uncollectable, alloc_atomic_uncollectable, + free, atomic, sizeof, register_finalizer +) __vtable_size__ = 0 @extend class __internal__: + def yield_final(val): + pass + def yield_in_no_suspend(T: type) -> T: + pass + @pure @derives @llvm def class_raw(obj) -> Ptr[byte]: ret ptr %obj - @pure - @derives + def class_alloc(T: type) -> T: + """Allocates a new reference (class) type""" + sz = sizeof(tuple(T)) + p = alloc_atomic(sz) if T.__contents_atomic__ else alloc(sz) + register_finalizer(p) + return __internal__.to_class_ptr(p, T) + def class_new(T: type) -> T: """Create a new reference (class) type""" - sz = T.__tuplesize__() - p = seq_alloc_atomic(sz) if T.__atomic__ else seq_alloc(sz) - register_finalizer(p) - pf = __internal__.to_class_ptr(p, T) + pf = __internal__.class_alloc(T) __internal__.class_set_obj_vtable(pf) return pf @@ -54,8 +55,7 @@ class __internal__: def class_make_n_vtables(sz: int) -> Ptr[Ptr[cobj]]: """Create a global vtable.""" - p = Ptr[Ptr[cobj]](sz) - seq_gc_add_roots(p.as_byte(), (p + sz).as_byte()) + p = Ptr[Ptr[cobj]](alloc_atomic_uncollectable(sz * sizeof(Ptr[cobj]))) __internal__.class_populate_vtables(p) return p @@ -86,6 +86,18 @@ class __internal__: """Calculates the byte distance of base class B and derived class D. Compiler generated.""" return 0 + def class_copy(obj: T, T: type) -> T: + p = __internal__.class_alloc(T) + str.memcpy(p.__raw__(), obj.__raw__(), sizeof(tuple(T))) + return p + + def class_super(obj: D, B: type, D: type) -> B: + pf = __internal__.to_class_ptr(obj.__raw__() + __internal__.class_base_derived_dist(B, D), B) + pn = __internal__.class_copy(pf) + # Replace vtables + __internal__.class_set_obj_vtable(pn) # replace vtables to point to its vtables! + return pn + # Unions @llvm @@ -153,6 +165,20 @@ class __internal__: t, __internal__.tuple_fix_index(idx, staticlen(t)), T, E ) + def tuple_add(t, i): + if isinstance(i, Tuple): + return (*t, *i) + else: + compile_error("can only concatenate tuple to tuple") + + def tuple_mul(t, i: Static[int]): + if i < 1: + return () + elif i == 1: + return t + else: + return (*(__internal__.tuple_mul(t, i - 1)), *t) + # ... @pure @@ -191,11 +217,17 @@ class __internal__: return AssertionError(s) def seq_assert_test(file: str, line: int, msg: str): + from C import seq_print(str) s = f": {msg}" if msg else "" s = f"\033[1;31mTEST FAILED:\033[0m {file} (line {line}){s}\n" seq_print(s) def check_errno(prefix: str): + @pure + @C + def seq_check_errno() -> str: + pass + msg = seq_check_errno() if msg: raise OSError(prefix + msg) @@ -448,3 +480,15 @@ class Function: __vtables__ = __internal__.class_init_vtables() def _____(): __vtables__ # make it global! + + +@tuple +class PyObject: + refcnt: int + pytype: Ptr[byte] + + +@tuple +class PyWrapper[T]: + head: PyObject + data: T diff --git a/stdlib/internal/python.codon b/stdlib/internal/python.codon index 276a244a..d2f9a54c 100644 --- a/stdlib/internal/python.codon +++ b/stdlib/internal/python.codon @@ -11,8 +11,6 @@ Py_Initialize = Function[[], NoneType](cobj()) PyImport_AddModule = Function[[cobj], cobj](cobj()) PyImport_AddModuleObject = Function[[cobj], cobj](cobj()) PyImport_ImportModule = Function[[cobj], cobj](cobj()) -PyErr_Fetch = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], NoneType](cobj()) -PyErr_NormalizeException = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], NoneType](cobj()) PyRun_SimpleString = Function[[cobj], NoneType](cobj()) PyEval_GetGlobals = Function[[], cobj](cobj()) PyEval_GetBuiltins = Function[[], cobj](cobj()) @@ -25,12 +23,15 @@ PyFloat_FromDouble = Function[[float], cobj](cobj()) PyBool_FromLong = Function[[int], cobj](cobj()) PyBytes_AsString = Function[[cobj], cobj](cobj()) PyList_New = Function[[int], cobj](cobj()) +PyList_Size = Function[[cobj], int](cobj()) PyList_GetItem = Function[[cobj, int], cobj](cobj()) PyList_SetItem = Function[[cobj, int, cobj], cobj](cobj()) PyDict_New = Function[[], cobj](cobj()) PyDict_Next = Function[[cobj, Ptr[int], Ptr[cobj], Ptr[cobj]], int](cobj()) PyDict_GetItem = Function[[cobj, cobj], cobj](cobj()) +PyDict_GetItemString = Function[[cobj, cobj], cobj](cobj()) PyDict_SetItem = Function[[cobj, cobj, cobj], cobj](cobj()) +PyDict_Size = Function[[cobj], int](cobj()) PySet_Add = Function[[cobj, cobj], cobj](cobj()) PySet_New = Function[[cobj], cobj](cobj()) PyTuple_New = Function[[int], cobj](cobj()) @@ -102,10 +103,17 @@ PyObject_DelItem = Function[[cobj, cobj], int](cobj()) PyObject_RichCompare = Function[[cobj, cobj, i32], cobj](cobj()) PyObject_IsInstance = Function[[cobj, cobj], i32](cobj()) +# error handling +PyErr_Fetch = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], NoneType](cobj()) +PyErr_NormalizeException = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], NoneType](cobj()) +PyErr_SetString = Function[[cobj, cobj], NoneType](cobj()) + # constants Py_None = cobj() Py_True = cobj() Py_False = cobj() +Py_Ellipsis = cobj() +Py_NotImplemented = cobj() Py_LT = 0 Py_LE = 1 Py_EQ = 2 @@ -113,6 +121,39 @@ Py_NE = 3 Py_GT = 4 Py_GE = 5 +# types +PyLong_Type = cobj() +PyFloat_Type = cobj() +PyBool_Type = cobj() +PyUnicode_Type = cobj() +PyComplex_Type = cobj() +PyList_Type = cobj() +PyDict_Type = cobj() +PySet_Type = cobj() +PyTuple_Type = cobj() +PySlice_Type = cobj() + +# exceptions +PyExc_BaseException = cobj() +PyExc_Exception = cobj() +PyExc_NameError = cobj() +PyExc_OSError = cobj() +PyExc_IOError = cobj() +PyExc_ValueError = cobj() +PyExc_LookupError = cobj() +PyExc_IndexError = cobj() +PyExc_KeyError = cobj() +PyExc_TypeError = cobj() +PyExc_ArithmeticError = cobj() +PyExc_ZeroDivisionError = cobj() +PyExc_OverflowError = cobj() +PyExc_AttributeError = cobj() +PyExc_RuntimeError = cobj() +PyExc_NotImplementedError = cobj() +PyExc_StopIteration = cobj() +PyExc_AssertionError = cobj() +PyExc_SystemExit = cobj() + _PY_MODULE_CACHE = Dict[str, pyobj]() _PY_INIT = """ @@ -149,15 +190,13 @@ def __codon_repr__(fig): _PY_INITIALIZED = False -def init_dl_handles(py_handle: cobj): +def init_handles_dlopen(py_handle: cobj): global Py_DecRef global Py_IncRef global Py_Initialize global PyImport_AddModule global PyImport_AddModuleObject global PyImport_ImportModule - global PyErr_Fetch - global PyErr_NormalizeException global PyRun_SimpleString global PyEval_GetGlobals global PyEval_GetBuiltins @@ -168,12 +207,15 @@ def init_dl_handles(py_handle: cobj): global PyBool_FromLong global PyBytes_AsString global PyList_New + global PyList_Size global PyList_GetItem global PyList_SetItem global PyDict_New global PyDict_Next global PyDict_GetItem + global PyDict_GetItemString global PyDict_SetItem + global PyDict_Size global PySet_Add global PySet_New global PyTuple_New @@ -240,9 +282,43 @@ def init_dl_handles(py_handle: cobj): global PyObject_DelItem global PyObject_RichCompare global PyObject_IsInstance + global PyErr_Fetch + global PyErr_NormalizeException + global PyErr_SetString global Py_None global Py_True global Py_False + global Py_Ellipsis + global Py_NotImplemented + global PyLong_Type + global PyFloat_Type + global PyBool_Type + global PyUnicode_Type + global PyComplex_Type + global PyList_Type + global PyDict_Type + global PySet_Type + global PyTuple_Type + global PySlice_Type + global PyExc_BaseException + global PyExc_Exception + global PyExc_NameError + global PyExc_OSError + global PyExc_IOError + global PyExc_ValueError + global PyExc_LookupError + global PyExc_IndexError + global PyExc_KeyError + global PyExc_TypeError + global PyExc_ArithmeticError + global PyExc_ZeroDivisionError + global PyExc_OverflowError + global PyExc_AttributeError + global PyExc_RuntimeError + global PyExc_NotImplementedError + global PyExc_StopIteration + global PyExc_AssertionError + global PyExc_SystemExit Py_DecRef = dlsym(py_handle, "Py_DecRef") Py_IncRef = dlsym(py_handle, "Py_IncRef") @@ -250,8 +326,6 @@ def init_dl_handles(py_handle: cobj): PyImport_AddModule = dlsym(py_handle, "PyImport_AddModule") PyImport_AddModuleObject = dlsym(py_handle, "PyImport_AddModuleObject") PyImport_ImportModule = dlsym(py_handle, "PyImport_ImportModule") - PyErr_Fetch = dlsym(py_handle, "PyErr_Fetch") - PyErr_NormalizeException = dlsym(py_handle, "PyErr_NormalizeException") PyRun_SimpleString = dlsym(py_handle, "PyRun_SimpleString") PyEval_GetGlobals = dlsym(py_handle, "PyEval_GetGlobals") PyEval_GetBuiltins = dlsym(py_handle, "PyEval_GetBuiltins") @@ -262,12 +336,15 @@ def init_dl_handles(py_handle: cobj): PyBool_FromLong = dlsym(py_handle, "PyBool_FromLong") PyBytes_AsString = dlsym(py_handle, "PyBytes_AsString") PyList_New = dlsym(py_handle, "PyList_New") + PyList_Size = dlsym(py_handle, "PyList_Size") PyList_GetItem = dlsym(py_handle, "PyList_GetItem") PyList_SetItem = dlsym(py_handle, "PyList_SetItem") PyDict_New = dlsym(py_handle, "PyDict_New") PyDict_Next = dlsym(py_handle, "PyDict_Next") PyDict_GetItem = dlsym(py_handle, "PyDict_GetItem") + PyDict_GetItemString = dlsym(py_handle, "PyDict_GetItemString") PyDict_SetItem = dlsym(py_handle, "PyDict_SetItem") + PyDict_Size = dlsym(py_handle, "PyDict_Size") PySet_Add = dlsym(py_handle, "PySet_Add") PySet_New = dlsym(py_handle, "PySet_New") PyTuple_New = dlsym(py_handle, "PyTuple_New") @@ -334,9 +411,452 @@ def init_dl_handles(py_handle: cobj): PyObject_DelItem = dlsym(py_handle, "PyObject_DelItem") PyObject_RichCompare = dlsym(py_handle, "PyObject_RichCompare") PyObject_IsInstance = dlsym(py_handle, "PyObject_IsInstance") + PyErr_Fetch = dlsym(py_handle, "PyErr_Fetch") + PyErr_NormalizeException = dlsym(py_handle, "PyErr_NormalizeException") + PyErr_SetString = dlsym(py_handle, "PyErr_SetString") Py_None = dlsym(py_handle, "_Py_NoneStruct") Py_True = dlsym(py_handle, "_Py_TrueStruct") Py_False = dlsym(py_handle, "_Py_FalseStruct") + Py_Ellipsis = dlsym(py_handle, "_Py_EllipsisObject") + Py_NotImplemented = dlsym(py_handle, "_Py_NotImplementedStruct") + PyLong_Type = dlsym(py_handle, "PyLong_Type") + PyFloat_Type = dlsym(py_handle, "PyFloat_Type") + PyBool_Type = dlsym(py_handle, "PyBool_Type") + PyUnicode_Type = dlsym(py_handle, "PyUnicode_Type") + PyComplex_Type = dlsym(py_handle, "PyComplex_Type") + PyList_Type = dlsym(py_handle, "PyList_Type") + PyDict_Type = dlsym(py_handle, "PyDict_Type") + PySet_Type = dlsym(py_handle, "PySet_Type") + PyTuple_Type = dlsym(py_handle, "PyTuple_Type") + PySlice_Type = dlsym(py_handle, "PySlice_Type") + PyExc_BaseException = Ptr[cobj](dlsym(py_handle, "PyExc_BaseException"))[0] + PyExc_Exception = Ptr[cobj](dlsym(py_handle, "PyExc_Exception"))[0] + PyExc_NameError = Ptr[cobj](dlsym(py_handle, "PyExc_NameError"))[0] + PyExc_OSError = Ptr[cobj](dlsym(py_handle, "PyExc_OSError"))[0] + PyExc_IOError = Ptr[cobj](dlsym(py_handle, "PyExc_IOError"))[0] + PyExc_ValueError = Ptr[cobj](dlsym(py_handle, "PyExc_ValueError"))[0] + PyExc_LookupError = Ptr[cobj](dlsym(py_handle, "PyExc_LookupError"))[0] + PyExc_IndexError = Ptr[cobj](dlsym(py_handle, "PyExc_IndexError"))[0] + PyExc_KeyError = Ptr[cobj](dlsym(py_handle, "PyExc_KeyError"))[0] + PyExc_TypeError = Ptr[cobj](dlsym(py_handle, "PyExc_TypeError"))[0] + PyExc_ArithmeticError = Ptr[cobj](dlsym(py_handle, "PyExc_ArithmeticError"))[0] + PyExc_ZeroDivisionError = Ptr[cobj](dlsym(py_handle, "PyExc_ZeroDivisionError"))[0] + PyExc_OverflowError = Ptr[cobj](dlsym(py_handle, "PyExc_OverflowError"))[0] + PyExc_AttributeError = Ptr[cobj](dlsym(py_handle, "PyExc_AttributeError"))[0] + PyExc_RuntimeError = Ptr[cobj](dlsym(py_handle, "PyExc_RuntimeError"))[0] + PyExc_NotImplementedError = Ptr[cobj](dlsym(py_handle, "PyExc_NotImplementedError"))[0] + PyExc_StopIteration = Ptr[cobj](dlsym(py_handle, "PyExc_StopIteration"))[0] + PyExc_AssertionError = Ptr[cobj](dlsym(py_handle, "PyExc_AssertionError"))[0] + PyExc_SystemExit = Ptr[cobj](dlsym(py_handle, "PyExc_SystemExit"))[0] + +def init_handles_static(): + from C import Py_DecRef(cobj) as _Py_DecRef + from C import Py_IncRef(cobj) as _Py_IncRef + from C import Py_Initialize() as _Py_Initialize + from C import PyImport_AddModule(cobj) -> cobj as _PyImport_AddModule + from C import PyImport_AddModuleObject(cobj) -> cobj as _PyImport_AddModuleObject + from C import PyImport_ImportModule(cobj) -> cobj as _PyImport_ImportModule + from C import PyRun_SimpleString(cobj) as _PyRun_SimpleString + from C import PyEval_GetGlobals() -> cobj as _PyEval_GetGlobals + from C import PyEval_GetBuiltins() -> cobj as _PyEval_GetBuiltins + from C import PyLong_AsLong(cobj) -> int as _PyLong_AsLong + from C import PyLong_FromLong(int) -> cobj as _PyLong_FromLong + from C import PyFloat_AsDouble(cobj) -> float as _PyFloat_AsDouble + from C import PyFloat_FromDouble(float) -> cobj as _PyFloat_FromDouble + from C import PyBool_FromLong(int) -> cobj as _PyBool_FromLong + from C import PyBytes_AsString(cobj) -> cobj as _PyBytes_AsString + from C import PyList_New(int) -> cobj as _PyList_New + from C import PyList_Size(cobj) -> int as _PyList_Size + from C import PyList_GetItem(cobj, int) -> cobj as _PyList_GetItem + from C import PyList_SetItem(cobj, int, cobj) -> cobj as _PyList_SetItem + from C import PyDict_New() -> cobj as _PyDict_New + from C import PyDict_Next(cobj, Ptr[int], Ptr[cobj], Ptr[cobj]) -> int as _PyDict_Next + from C import PyDict_GetItem(cobj, cobj) -> cobj as _PyDict_GetItem + from C import PyDict_GetItemString(cobj, cobj) -> cobj as _PyDict_GetItemString + from C import PyDict_SetItem(cobj, cobj, cobj) -> cobj as _PyDict_SetItem + from C import PyDict_Size(cobj) -> int as _PyDict_Size + from C import PySet_Add(cobj, cobj) -> cobj as _PySet_Add + from C import PySet_New(cobj) -> cobj as _PySet_New + from C import PyTuple_New(int) -> cobj as _PyTuple_New + from C import PyTuple_Size(cobj) -> int as _PyTuple_Size + from C import PyTuple_GetItem(cobj, int) -> cobj as _PyTuple_GetItem + from C import PyTuple_SetItem(cobj, int, cobj) as _PyTuple_SetItem + from C import PyUnicode_AsEncodedString(cobj, cobj, cobj) -> cobj as _PyUnicode_AsEncodedString + from C import PyUnicode_DecodeFSDefaultAndSize(cobj, int) -> cobj as _PyUnicode_DecodeFSDefaultAndSize + from C import PyUnicode_FromString(cobj) -> cobj as _PyUnicode_FromString + from C import PyComplex_FromDoubles(float, float) -> cobj as _PyComplex_FromDoubles + from C import PyComplex_RealAsDouble(cobj) -> float as _PyComplex_RealAsDouble + from C import PyComplex_ImagAsDouble(cobj) -> float as _PyComplex_ImagAsDouble + from C import PyIter_Next(cobj) -> cobj as _PyIter_Next + from C import PySlice_New(cobj, cobj, cobj) -> cobj as _PySlice_New + from C import PySlice_Unpack(cobj, Ptr[int], Ptr[int], Ptr[int]) -> int as _PySlice_Unpack + from C import PyNumber_Add(cobj, cobj) -> cobj as _PyNumber_Add + from C import PyNumber_Subtract(cobj, cobj) -> cobj as _PyNumber_Subtract + from C import PyNumber_Multiply(cobj, cobj) -> cobj as _PyNumber_Multiply + from C import PyNumber_MatrixMultiply(cobj, cobj) -> cobj as _PyNumber_MatrixMultiply + from C import PyNumber_FloorDivide(cobj, cobj) -> cobj as _PyNumber_FloorDivide + from C import PyNumber_TrueDivide(cobj, cobj) -> cobj as _PyNumber_TrueDivide + from C import PyNumber_Remainder(cobj, cobj) -> cobj as _PyNumber_Remainder + from C import PyNumber_Divmod(cobj, cobj) -> cobj as _PyNumber_Divmod + from C import PyNumber_Power(cobj, cobj, cobj) -> cobj as _PyNumber_Power + from C import PyNumber_Negative(cobj) -> cobj as _PyNumber_Negative + from C import PyNumber_Positive(cobj) -> cobj as _PyNumber_Positive + from C import PyNumber_Absolute(cobj) -> cobj as _PyNumber_Absolute + from C import PyNumber_Invert(cobj) -> cobj as _PyNumber_Invert + from C import PyNumber_Lshift(cobj, cobj) -> cobj as _PyNumber_Lshift + from C import PyNumber_Rshift(cobj, cobj) -> cobj as _PyNumber_Rshift + from C import PyNumber_And(cobj, cobj) -> cobj as _PyNumber_And + from C import PyNumber_Xor(cobj, cobj) -> cobj as _PyNumber_Xor + from C import PyNumber_Or(cobj, cobj) -> cobj as _PyNumber_Or + from C import PyNumber_InPlaceAdd(cobj, cobj) -> cobj as _PyNumber_InPlaceAdd + from C import PyNumber_InPlaceSubtract(cobj, cobj) -> cobj as _PyNumber_InPlaceSubtract + from C import PyNumber_InPlaceMultiply(cobj, cobj) -> cobj as _PyNumber_InPlaceMultiply + from C import PyNumber_InPlaceMatrixMultiply(cobj, cobj) -> cobj as _PyNumber_InPlaceMatrixMultiply + from C import PyNumber_InPlaceFloorDivide(cobj, cobj) -> cobj as _PyNumber_InPlaceFloorDivide + from C import PyNumber_InPlaceTrueDivide(cobj, cobj) -> cobj as _PyNumber_InPlaceTrueDivide + from C import PyNumber_InPlaceRemainder(cobj, cobj) -> cobj as _PyNumber_InPlaceRemainder + from C import PyNumber_InPlacePower(cobj, cobj, cobj) -> cobj as _PyNumber_InPlacePower + from C import PyNumber_InPlaceLshift(cobj, cobj) -> cobj as _PyNumber_InPlaceLshift + from C import PyNumber_InPlaceRshift(cobj, cobj) -> cobj as _PyNumber_InPlaceRshift + from C import PyNumber_InPlaceAnd(cobj, cobj) -> cobj as _PyNumber_InPlaceAnd + from C import PyNumber_InPlaceXor(cobj, cobj) -> cobj as _PyNumber_InPlaceXor + from C import PyNumber_InPlaceOr(cobj, cobj) -> cobj as _PyNumber_InPlaceOr + from C import PyNumber_Long(cobj) -> cobj as _PyNumber_Long + from C import PyNumber_Float(cobj) -> cobj as _PyNumber_Float + from C import PyNumber_Index(cobj) -> cobj as _PyNumber_Index + from C import PyObject_Call(cobj, cobj, cobj) -> cobj as _PyObject_Call + from C import PyObject_GetAttr(cobj, cobj) -> cobj as _PyObject_GetAttr + from C import PyObject_GetAttrString(cobj, cobj) -> cobj as _PyObject_GetAttrString + from C import PyObject_GetIter(cobj) -> cobj as _PyObject_GetIter + from C import PyObject_HasAttrString(cobj, cobj) -> int as _PyObject_HasAttrString + from C import PyObject_IsTrue(cobj) -> int as _PyObject_IsTrue + from C import PyObject_Length(cobj) -> int as _PyObject_Length + from C import PyObject_LengthHint(cobj, int) -> int as _PyObject_LengthHint + from C import PyObject_SetAttrString(cobj, cobj, cobj) -> cobj as _PyObject_SetAttrString + from C import PyObject_Str(cobj) -> cobj as _PyObject_Str + from C import PyObject_Repr(cobj) -> cobj as _PyObject_Repr + from C import PyObject_Hash(cobj) -> int as _PyObject_Hash + from C import PyObject_GetItem(cobj, cobj) -> cobj as _PyObject_GetItem + from C import PyObject_SetItem(cobj, cobj, cobj) -> int as _PyObject_SetItem + from C import PyObject_DelItem(cobj, cobj) -> int as _PyObject_DelItem + from C import PyObject_RichCompare(cobj, cobj, i32) -> cobj as _PyObject_RichCompare + from C import PyObject_IsInstance(cobj, cobj) -> i32 as _PyObject_IsInstance + from C import PyErr_Fetch(Ptr[cobj], Ptr[cobj], Ptr[cobj]) as _PyErr_Fetch + from C import PyErr_NormalizeException(Ptr[cobj], Ptr[cobj], Ptr[cobj]) as _PyErr_NormalizeException + from C import PyErr_SetString(cobj, cobj) as _PyErr_SetString + from C import _Py_NoneStruct: cobj + from C import _Py_TrueStruct: cobj + from C import _Py_FalseStruct: cobj + from C import _Py_EllipsisObject: cobj + from C import _Py_NotImplementedStruct: cobj + from C import PyLong_Type: cobj as _PyLong_Type + from C import PyFloat_Type: cobj as _PyFloat_Type + from C import PyBool_Type: cobj as _PyBool_Type + from C import PyUnicode_Type: cobj as _PyUnicode_Type + from C import PyComplex_Type: cobj as _PyComplex_Type + from C import PyList_Type: cobj as _PyList_Type + from C import PyDict_Type: cobj as _PyDict_Type + from C import PySet_Type: cobj as _PySet_Type + from C import PyTuple_Type: cobj as _PyTuple_Type + from C import PySlice_Type: cobj as _PySlice_Type + from C import PyExc_BaseException: cobj as _PyExc_BaseException + from C import PyExc_Exception: cobj as _PyExc_Exception + from C import PyExc_NameError: cobj as _PyExc_NameError + from C import PyExc_OSError: cobj as _PyExc_OSError + from C import PyExc_IOError: cobj as _PyExc_IOError + from C import PyExc_ValueError: cobj as _PyExc_ValueError + from C import PyExc_LookupError: cobj as _PyExc_LookupError + from C import PyExc_IndexError: cobj as _PyExc_IndexError + from C import PyExc_KeyError: cobj as _PyExc_KeyError + from C import PyExc_TypeError: cobj as _PyExc_TypeError + from C import PyExc_ArithmeticError: cobj as _PyExc_ArithmeticError + from C import PyExc_ZeroDivisionError: cobj as _PyExc_ZeroDivisionError + from C import PyExc_OverflowError: cobj as _PyExc_OverflowError + from C import PyExc_AttributeError: cobj as _PyExc_AttributeError + from C import PyExc_RuntimeError: cobj as _PyExc_RuntimeError + from C import PyExc_NotImplementedError: cobj as _PyExc_NotImplementedError + from C import PyExc_StopIteration: cobj as _PyExc_StopIteration + from C import PyExc_AssertionError: cobj as _PyExc_AssertionError + from C import PyExc_SystemExit: cobj as _PyExc_SystemExit + + global Py_DecRef + global Py_IncRef + global Py_Initialize + global PyImport_AddModule + global PyImport_AddModuleObject + global PyImport_ImportModule + global PyRun_SimpleString + global PyEval_GetGlobals + global PyEval_GetBuiltins + global PyLong_AsLong + global PyLong_FromLong + global PyFloat_AsDouble + global PyFloat_FromDouble + global PyBool_FromLong + global PyBytes_AsString + global PyList_New + global PyList_Size + global PyList_GetItem + global PyList_SetItem + global PyDict_New + global PyDict_Next + global PyDict_GetItem + global PyDict_GetItemString + global PyDict_SetItem + global PyDict_Size + global PySet_Add + global PySet_New + global PyTuple_New + global PyTuple_Size + global PyTuple_GetItem + global PyTuple_SetItem + global PyUnicode_AsEncodedString + global PyUnicode_DecodeFSDefaultAndSize + global PyUnicode_FromString + global PyComplex_FromDoubles + global PyComplex_RealAsDouble + global PyComplex_ImagAsDouble + global PyIter_Next + global PySlice_New + global PySlice_Unpack + global PyNumber_Add + global PyNumber_Subtract + global PyNumber_Multiply + global PyNumber_MatrixMultiply + global PyNumber_FloorDivide + global PyNumber_TrueDivide + global PyNumber_Remainder + global PyNumber_Divmod + global PyNumber_Power + global PyNumber_Negative + global PyNumber_Positive + global PyNumber_Absolute + global PyNumber_Invert + global PyNumber_Lshift + global PyNumber_Rshift + global PyNumber_And + global PyNumber_Xor + global PyNumber_Or + global PyNumber_InPlaceAdd + global PyNumber_InPlaceSubtract + global PyNumber_InPlaceMultiply + global PyNumber_InPlaceMatrixMultiply + global PyNumber_InPlaceFloorDivide + global PyNumber_InPlaceTrueDivide + global PyNumber_InPlaceRemainder + global PyNumber_InPlacePower + global PyNumber_InPlaceLshift + global PyNumber_InPlaceRshift + global PyNumber_InPlaceAnd + global PyNumber_InPlaceXor + global PyNumber_InPlaceOr + global PyNumber_Long + global PyNumber_Float + global PyNumber_Index + global PyObject_Call + global PyObject_GetAttr + global PyObject_GetAttrString + global PyObject_GetIter + global PyObject_HasAttrString + global PyObject_IsTrue + global PyObject_Length + global PyObject_LengthHint + global PyObject_SetAttrString + global PyObject_Str + global PyObject_Repr + global PyObject_Hash + global PyObject_GetItem + global PyObject_SetItem + global PyObject_DelItem + global PyObject_RichCompare + global PyObject_IsInstance + global PyErr_Fetch + global PyErr_NormalizeException + global PyErr_SetString + global Py_None + global Py_True + global Py_False + global Py_Ellipsis + global Py_NotImplemented + global PyLong_Type + global PyFloat_Type + global PyBool_Type + global PyUnicode_Type + global PyComplex_Type + global PyList_Type + global PyDict_Type + global PySet_Type + global PyTuple_Type + global PySlice_Type + global PyExc_BaseException + global PyExc_Exception + global PyExc_NameError + global PyExc_OSError + global PyExc_IOError + global PyExc_ValueError + global PyExc_LookupError + global PyExc_IndexError + global PyExc_KeyError + global PyExc_TypeError + global PyExc_ArithmeticError + global PyExc_ZeroDivisionError + global PyExc_OverflowError + global PyExc_AttributeError + global PyExc_RuntimeError + global PyExc_NotImplementedError + global PyExc_StopIteration + global PyExc_AssertionError + global PyExc_SystemExit + + Py_DecRef = _Py_DecRef + Py_IncRef = _Py_IncRef + Py_Initialize = _Py_Initialize + PyImport_AddModule = _PyImport_AddModule + PyImport_AddModuleObject = _PyImport_AddModuleObject + PyImport_ImportModule = _PyImport_ImportModule + PyRun_SimpleString = _PyRun_SimpleString + PyEval_GetGlobals = _PyEval_GetGlobals + PyEval_GetBuiltins = _PyEval_GetBuiltins + PyLong_AsLong = _PyLong_AsLong + PyLong_FromLong = _PyLong_FromLong + PyFloat_AsDouble = _PyFloat_AsDouble + PyFloat_FromDouble = _PyFloat_FromDouble + PyBool_FromLong = _PyBool_FromLong + PyBytes_AsString = _PyBytes_AsString + PyList_New = _PyList_New + PyList_Size = _PyList_Size + PyList_GetItem = _PyList_GetItem + PyList_SetItem = _PyList_SetItem + PyDict_New = _PyDict_New + PyDict_Next = _PyDict_Next + PyDict_GetItem = _PyDict_GetItem + PyDict_GetItemString = _PyDict_GetItemString + PyDict_SetItem = _PyDict_SetItem + PyDict_Size = _PyDict_Size + PySet_Add = _PySet_Add + PySet_New = _PySet_New + PyTuple_New = _PyTuple_New + PyTuple_Size = _PyTuple_Size + PyTuple_GetItem = _PyTuple_GetItem + PyTuple_SetItem = _PyTuple_SetItem + PyUnicode_AsEncodedString = _PyUnicode_AsEncodedString + PyUnicode_DecodeFSDefaultAndSize = _PyUnicode_DecodeFSDefaultAndSize + PyUnicode_FromString = _PyUnicode_FromString + PyComplex_FromDoubles = _PyComplex_FromDoubles + PyComplex_RealAsDouble = _PyComplex_RealAsDouble + PyComplex_ImagAsDouble = _PyComplex_ImagAsDouble + PyIter_Next = _PyIter_Next + PySlice_New = _PySlice_New + PySlice_Unpack = _PySlice_Unpack + PyNumber_Add = _PyNumber_Add + PyNumber_Subtract = _PyNumber_Subtract + PyNumber_Multiply = _PyNumber_Multiply + PyNumber_MatrixMultiply = _PyNumber_MatrixMultiply + PyNumber_FloorDivide = _PyNumber_FloorDivide + PyNumber_TrueDivide = _PyNumber_TrueDivide + PyNumber_Remainder = _PyNumber_Remainder + PyNumber_Divmod = _PyNumber_Divmod + PyNumber_Power = _PyNumber_Power + PyNumber_Negative = _PyNumber_Negative + PyNumber_Positive = _PyNumber_Positive + PyNumber_Absolute = _PyNumber_Absolute + PyNumber_Invert = _PyNumber_Invert + PyNumber_Lshift = _PyNumber_Lshift + PyNumber_Rshift = _PyNumber_Rshift + PyNumber_And = _PyNumber_And + PyNumber_Xor = _PyNumber_Xor + PyNumber_Or = _PyNumber_Or + PyNumber_InPlaceAdd = _PyNumber_InPlaceAdd + PyNumber_InPlaceSubtract = _PyNumber_InPlaceSubtract + PyNumber_InPlaceMultiply = _PyNumber_InPlaceMultiply + PyNumber_InPlaceMatrixMultiply = _PyNumber_InPlaceMatrixMultiply + PyNumber_InPlaceFloorDivide = _PyNumber_InPlaceFloorDivide + PyNumber_InPlaceTrueDivide = _PyNumber_InPlaceTrueDivide + PyNumber_InPlaceRemainder = _PyNumber_InPlaceRemainder + PyNumber_InPlacePower = _PyNumber_InPlacePower + PyNumber_InPlaceLshift = _PyNumber_InPlaceLshift + PyNumber_InPlaceRshift = _PyNumber_InPlaceRshift + PyNumber_InPlaceAnd = _PyNumber_InPlaceAnd + PyNumber_InPlaceXor = _PyNumber_InPlaceXor + PyNumber_InPlaceOr = _PyNumber_InPlaceOr + PyNumber_Long = _PyNumber_Long + PyNumber_Float = _PyNumber_Float + PyNumber_Index = _PyNumber_Index + PyObject_Call = _PyObject_Call + PyObject_GetAttr = _PyObject_GetAttr + PyObject_GetAttrString = _PyObject_GetAttrString + PyObject_GetIter = _PyObject_GetIter + PyObject_HasAttrString = _PyObject_HasAttrString + PyObject_IsTrue = _PyObject_IsTrue + PyObject_Length = _PyObject_Length + PyObject_LengthHint = _PyObject_LengthHint + PyObject_SetAttrString = _PyObject_SetAttrString + PyObject_Str = _PyObject_Str + PyObject_Repr = _PyObject_Repr + PyObject_Hash = _PyObject_Hash + PyObject_GetItem = _PyObject_GetItem + PyObject_SetItem = _PyObject_SetItem + PyObject_DelItem = _PyObject_DelItem + PyObject_RichCompare = _PyObject_RichCompare + PyObject_IsInstance = _PyObject_IsInstance + PyErr_Fetch = _PyErr_Fetch + PyErr_NormalizeException = _PyErr_NormalizeException + PyErr_SetString = _PyErr_SetString + Py_None = __ptr__(_Py_NoneStruct).as_byte() + Py_True = __ptr__(_Py_TrueStruct).as_byte() + Py_False = __ptr__(_Py_FalseStruct).as_byte() + Py_Ellipsis = __ptr__(_Py_EllipsisObject).as_byte() + Py_NotImplemented = __ptr__(_Py_NotImplementedStruct).as_byte() + PyLong_Type = __ptr__(_PyLong_Type).as_byte() + PyFloat_Type = __ptr__(_PyFloat_Type).as_byte() + PyBool_Type = __ptr__(_PyBool_Type).as_byte() + PyUnicode_Type = __ptr__(_PyUnicode_Type).as_byte() + PyComplex_Type = __ptr__(_PyComplex_Type).as_byte() + PyList_Type = __ptr__(_PyList_Type).as_byte() + PyDict_Type = __ptr__(_PyDict_Type).as_byte() + PySet_Type = __ptr__(_PySet_Type).as_byte() + PyTuple_Type = __ptr__(_PyTuple_Type).as_byte() + PySlice_Type = __ptr__(_PySlice_Type).as_byte() + PyExc_BaseException = _PyExc_BaseException + PyExc_Exception = _PyExc_Exception + PyExc_NameError = _PyExc_NameError + PyExc_OSError = _PyExc_OSError + PyExc_IOError = _PyExc_IOError + PyExc_ValueError = _PyExc_ValueError + PyExc_LookupError = _PyExc_LookupError + PyExc_IndexError = _PyExc_IndexError + PyExc_KeyError = _PyExc_KeyError + PyExc_TypeError = _PyExc_TypeError + PyExc_ArithmeticError = _PyExc_ArithmeticError + PyExc_ZeroDivisionError = _PyExc_ZeroDivisionError + PyExc_OverflowError = _PyExc_OverflowError + PyExc_AttributeError = _PyExc_AttributeError + PyExc_RuntimeError = _PyExc_RuntimeError + PyExc_NotImplementedError = _PyExc_NotImplementedError + PyExc_StopIteration = _PyExc_StopIteration + PyExc_AssertionError = _PyExc_AssertionError + PyExc_SystemExit = _PyExc_SystemExit + +def init_error_py_types(): + BaseException._pytype = PyExc_BaseException + Exception._pytype = PyExc_Exception + NameError._pytype = PyExc_NameError + OSError._pytype = PyExc_OSError + IOError._pytype = PyExc_IOError + ValueError._pytype = PyExc_ValueError + LookupError._pytype = PyExc_LookupError + IndexError._pytype = PyExc_IndexError + KeyError._pytype = PyExc_KeyError + TypeError._pytype = PyExc_TypeError + ArithmeticError._pytype = PyExc_ArithmeticError + ZeroDivisionError._pytype = PyExc_ZeroDivisionError + OverflowError._pytype = PyExc_OverflowError + AttributeError._pytype = PyExc_AttributeError + RuntimeError._pytype = PyExc_RuntimeError + NotImplementedError._pytype = PyExc_NotImplementedError + StopIteration._pytype = PyExc_StopIteration + AssertionError._pytype = PyExc_AssertionError + SystemExit._pytype = PyExc_SystemExit def setup_python(python_loaded: bool): global _PY_INITIALIZED @@ -350,7 +870,8 @@ def setup_python(python_loaded: bool): LD = os.getenv("CODON_PYTHON", default="libpython." + dlext()) py_handle = dlopen(LD, RTLD_LOCAL | RTLD_NOW) - init_dl_handles(py_handle) + init_handles_dlopen(py_handle) + init_error_py_types() if not python_loaded: Py_Initialize() @@ -358,12 +879,34 @@ def setup_python(python_loaded: bool): _PY_INITIALIZED = True def ensure_initialized(python_loaded: bool = False): - setup_python(python_loaded) - PyRun_SimpleString(_PY_INIT.c_str()) + if __py_extension__: + init_handles_static() + init_error_py_types() + else: + setup_python(python_loaded) + PyRun_SimpleString(_PY_INIT.c_str()) def setup_decorator(): setup_python(True) +@tuple +class _PyArg_Parser: + initialized: i32 + format: cobj + keywords: Ptr[cobj] + fname: cobj + custom_msg: cobj + pos: i32 + min: i32 + max: i32 + kwtuple: cobj + next: cobj + + def __new__(fname: cobj, keywords: Ptr[cobj], format: cobj): + z = i32(0) + o = cobj() + return _PyArg_Parser(z, format, keywords, fname, o, z, z, z, o, o) + @extend class pyobj: @__internal__ @@ -587,7 +1130,7 @@ class pyobj: def __iter__(self) -> Generator[pyobj]: it = PyObject_GetIter(self.p) if not it: - raise ValueError("Python object is not iterable") + raise TypeError("Python object is not iterable") try: while i := PyIter_Next(it): yield pyobj(pyobj.exc_wrap(i), steal=True) @@ -709,6 +1252,19 @@ def _get_identifier(typ: str) -> pyobj: def _isinstance(what: pyobj, typ: pyobj) -> bool: return bool(pyobj.exc_wrap(PyObject_IsInstance(what.p, typ.p))) +@tuple +class _PyObject_Struct: + refcnt: int + pytype: cobj + +def _conversion_error(name: Static[str]): + raise PyError("conversion error: Python object did not have type '" + name + "'") + +def _ensure_type(o: cobj, t: cobj, name: Static[str]): + if Ptr[_PyObject_Struct](o)[0].pytype != t: + _conversion_error(name) + + # Type conversions @extend @@ -717,7 +1273,9 @@ class NoneType: Py_IncRef(Py_None) return Py_None - def __from_py__(i: cobj) -> None: + def __from_py__(x: cobj) -> None: + if x != Py_None: + _conversion_error("NoneType") return @extend @@ -726,7 +1284,8 @@ class int: return pyobj.exc_wrap(PyLong_FromLong(self)) def __from_py__(i: cobj) -> int: - return pyobj.exc_wrap(PyLong_AsLong(i)) + _ensure_type(i, PyLong_Type, "int") + return PyLong_AsLong(i) @extend class float: @@ -742,7 +1301,8 @@ class bool: return pyobj.exc_wrap(PyBool_FromLong(int(self))) def __from_py__(b: cobj) -> bool: - return pyobj.exc_wrap(PyObject_IsTrue(b)) != 0 + _ensure_type(b, PyBool_Type, "bool") + return PyObject_IsTrue(b) != 0 @extend class byte: @@ -766,8 +1326,9 @@ class complex: return pyobj.exc_wrap(PyComplex_FromDoubles(self.real, self.imag)) def __from_py__(c: cobj) -> complex: - real = pyobj.exc_wrap(PyComplex_RealAsDouble(c)) - imag = pyobj.exc_wrap(PyComplex_ImagAsDouble(c)) + _ensure_type(c, PyComplex_Type, "complex") + real = PyComplex_RealAsDouble(c) + imag = PyComplex_ImagAsDouble(c) return complex(real, imag) @extend @@ -783,11 +1344,11 @@ class List: return pylist def __from_py__(v: cobj) -> List[T]: - n = pyobj.exc_wrap(PyObject_Length(v)) + _ensure_type(v, PyList_Type, "list") + n = PyList_Size(v) t = List[T](n) for i in range(n): elem = PyList_GetItem(v, i) - pyobj.exc_check() t.append(T.__from_py__(elem)) return t @@ -802,12 +1363,12 @@ class Dict: return pydict def __from_py__(d: cobj) -> Dict[K, V]: + _ensure_type(d, PyDict_Type, "dict") b = dict[K, V]() pos = 0 k_ptr = cobj() v_ptr = cobj() while PyDict_Next(d, __ptr__(pos), __ptr__(k_ptr), __ptr__(v_ptr)): - pyobj.exc_check() k = K.__from_py__(k_ptr) v = V.__from_py__(v_ptr) b[k] = v @@ -824,8 +1385,9 @@ class Set: return pyset def __from_py__(s: cobj) -> Set[K]: + _ensure_type(s, PySet_Type, "set") b = set[K]() - s_iter = pyobj.exc_wrap(PyObject_GetIter(s)) + s_iter = PyObject_GetIter(s) while True: k_ptr = pyobj.exc_wrap(PyIter_Next(s_iter)) if not k_ptr: @@ -848,7 +1410,8 @@ class DynamicTuple: return pytup def __from_py__(t: cobj) -> DynamicTuple[T]: - n = pyobj.exc_wrap(PyTuple_Size(t)) + _ensure_type(t, PyTuple_Type, "tuple") + n = PyTuple_Size(t) p = Ptr[T](n) for i in range(n): p[i] = T.__from_py__(PyTuple_GetItem(t, i)) @@ -866,10 +1429,11 @@ class Slice: return PySlice_New(start_py, stop_py, step_py) def __from_py__(s: cobj) -> Slice: + _ensure_type(s, PySlice_Type, "slice") start = 0 stop = 0 step = 0 - pyobj.exc_wrap(PySlice_Unpack(s, __ptr__(start), __ptr__(stop), __ptr__(step))) + PySlice_Unpack(s, __ptr__(start), __ptr__(stop), __ptr__(step)) return Slice(Optional(start), Optional(stop), Optional(step)) @extend @@ -885,3 +1449,492 @@ class Optional: return Optional[T]() else: return Optional[T](T.__from_py__(o)) + + +__pyenv__: Optional[pyobj] = None +def _____(): __pyenv__ # make it global! + + +import internal.static as _S + + +class _PyWrapError(Static[PyError]): + def __init__(self, message: str, pytype: pyobj = pyobj(cobj(), steal=True)): + super().__init__("_PyWrapError", message) + self.pytype = pytype + + def __init__(self, e: PyError): + self.__init__("_PyWrapError", e.message, e.pytype) + + +class _PyWrap: + def _dispatch_error(F: Static[str]): + raise TypeError("could not find callable method '" + F + "' for given arguments") + + def _wrap(args, T: type, F: Static[str], map): + for fn in _S.fn_overloads(T, F): + a = _PyWrap._args_from_py(fn, args) + if a is None: + continue + if _S.fn_can_call(fn, *a): + try: + return map(fn, a) + except PyError as e: + pass + _PyWrap._dispatch_error(F) + + def _wrap_unary(obj: cobj, T: type, F: Static[str]) -> cobj: + return _PyWrap._wrap( + (obj,), T=T, F=F, + map=lambda f, a: f(*a).__to_py__() + ) + + def _args_from_py(fn, args): + def err(fail: Ptr[bool], T: type = NoneType) -> T: + fail[0] = True + # auto-return zero-initialized T + + def get_arg(F, p, k, fail: Ptr[bool], i: Static[int]): + if _S.fn_arg_has_type(F, i): + return _S.fn_arg_get_type(F, i).__from_py__(p[i]) if p[i] != cobj() else ( + _S.fn_get_default(F, i) if _S.fn_has_default(F, i) + else err(fail, _S.fn_arg_get_type(F, i)) + ) + else: + return pyobj(p[i], steal=False) if p[i] != cobj() else ( + _S.fn_get_default(F, i) if _S.fn_has_default(F, i) else err(fail) + ) + + fail = False + pargs = Ptr[cobj](__ptr__(args).as_byte()) + try: + ta = tuple(get_arg(fn, pargs, k, __ptr__(fail), i) for i, k in staticenumerate(_S.fn_args(fn))) + if fail: + return None + return _S.fn_wrap_call_args(fn, *ta) + except PyError: + return None + + def _reorder_args(fn, self: cobj, args: cobj, kwargs: cobj, M: Static[int] = 1): + nargs = PyTuple_Size(args) + nkwargs = PyDict_Size(kwargs) if kwargs != cobj() else 0 + + args_ordered = tuple(cobj() for _ in _S.fn_args(fn)) + pargs = Ptr[cobj](__ptr__(args_ordered).as_byte()) + + if nargs + nkwargs + M > len(args_ordered): + return None + + if M: + pargs[0] = self + + for i in range(nargs): + pargs[i + M] = PyTuple_GetItem(args, i) + + kwused = 0 + for i, k in staticenumerate(_S.fn_args(fn)): + if i < nargs + M: + continue + + p = PyDict_GetItemString(kwargs, k.ptr) if nkwargs else cobj() + if p != cobj(): + pargs[i] = p + kwused += 1 + + if kwused != nkwargs: + return None + + return _PyWrap._args_from_py(fn, args_ordered) + + def _reorder_args_fastcall( + fn, self: cobj, args: Ptr[cobj], nargs: int, + kwds: Ptr[str], nkw: int, M: Static[int] = 1 + ): + args_ordered = tuple(cobj() for _ in _S.fn_args(fn)) + pargs = Ptr[cobj](__ptr__(args_ordered).as_byte()) + + if nargs + M > len(args_ordered): + return None + + if M: + pargs[0] = self + + for i in range(nargs): + pargs[i + M] = args[i] + + for i in range(nargs, nargs + nkw): + kw = kwds[i - nargs] + o = args[i] + + found = False + j = M + for i, k in staticenumerate(_S.fn_args(fn)): + if M and i == 0: + continue + if kw == k: + if not pargs[j]: + pargs[j] = o + else: + return None + found = True + break + j += 1 + if not found: + return None + + return _PyWrap._args_from_py(fn, args_ordered) + + def wrap_magic_abs(obj: cobj, T: type): + return _PyWrap._wrap_unary(obj, T, "__abs__") + + def wrap_magic_pos(obj: cobj, T: type): + return _PyWrap._wrap_unary(obj, T, "__pos__") + + def wrap_magic_neg(obj: cobj, T: type): + return _PyWrap._wrap_unary(obj, T, "__neg__") + + def wrap_magic_invert(obj: cobj, T: type): + return _PyWrap._wrap_unary(obj, T, "__invert__") + + def wrap_magic_int(obj: cobj, T: type): + return _PyWrap._wrap_unary(obj, T, "__int__") + + def wrap_magic_float(obj: cobj, T: type): + return _PyWrap._wrap_unary(obj, T, "__float__") + + def wrap_magic_index(obj: cobj, T: type): + return _PyWrap._wrap_unary(obj, T, "__index__") + + def wrap_magic_repr(obj: cobj, T: type): + return _PyWrap._wrap_unary(obj, T, "__repr__") + + def wrap_magic_str(obj: cobj, T: type): + return _PyWrap._wrap_unary(obj, T, "__str__") + + def _wrap_binary(obj: cobj, obj2: cobj, T: type, F: Static[str]) -> cobj: + return _PyWrap._wrap( + (obj, obj2), T=T, F=F, + map=lambda f, a: f(*a).__to_py__() + ) + + def wrap_magic_add(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__add__") + + def wrap_magic_radd(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__radd__") + + def wrap_magic_iadd(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__iadd__") + + def wrap_magic_sub(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__sub__") + + def wrap_magic_rsub(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__rsub__") + + def wrap_magic_isub(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__isub__") + + def wrap_magic_mul(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__mul__") + + def wrap_magic_rmul(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__rmul__") + + def wrap_magic_imul(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__imul__") + + def wrap_magic_mod(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__mod__") + + def wrap_magic_rmod(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__rmod__") + + def wrap_magic_imod(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__imod__") + + def wrap_magic_divmod(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__divmod__") + + def wrap_magic_rdivmod(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__rdivmod__") + + def wrap_magic_lshift(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__lshift__") + + def wrap_magic_rlshift(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__rlshift__") + + def wrap_magic_ilshift(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__ilshift__") + + def wrap_magic_rshift(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__rshift__") + + def wrap_magic_rrshift(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__rrshift__") + + def wrap_magic_irshift(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__irshift__") + + def wrap_magic_and(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__and__") + + def wrap_magic_rand(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__rand__") + + def wrap_magic_iand(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__iand__") + + def wrap_magic_xor(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__xor__") + + def wrap_magic_rxor(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__rxor__") + + def wrap_magic_ixor(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__ixor__") + + def wrap_magic_or(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__or__") + + def wrap_magic_ror(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__ror__") + + def wrap_magic_ior(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__ior__") + + def wrap_magic_floordiv(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__floordiv__") + + def wrap_magic_ifloordiv(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__ifloordiv__") + + def wrap_magic_truediv(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__truediv__") + + def wrap_magic_itruediv(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__itruediv__") + + def wrap_magic_matmul(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__matmul__") + + def wrap_magic_rmatmul(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__rmatmul__") + + def wrap_magic_imatmul(obj: cobj, obj2: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__imatmul__") + + def wrap_magic_pow(obj: cobj, obj2: cobj, obj3: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__pow__") + + def wrap_magic_rpow(obj: cobj, obj2: cobj, obj3: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__rpow__") + + def wrap_magic_ipow(obj: cobj, obj2: cobj, obj3: cobj, T: type): + return _PyWrap._wrap_binary(obj, obj2, T, "__ipow__") + + def _wrap_hash(obj: cobj, T: type, F: Static[str]) -> i64: + return _PyWrap._wrap( + (obj,), T=T, F=F, + map=lambda f, a: f(*a) + ) + + def wrap_magic_len(obj: cobj, T: type): + return _PyWrap._wrap_hash(obj, T, "__len__") + + def wrap_magic_hash(obj: cobj, T: type): + return _PyWrap._wrap_hash(obj, T, "__hash__") + + def wrap_magic_bool(obj: cobj, T: type) -> i32: + return _PyWrap._wrap( + (obj,), T=T, F="__bool__", + map=lambda f, a: i32(1) if f(*a) else i32(0) + ) + + def wrap_magic_del(obj: cobj, T: type): + _PyWrap._wrap( + (obj,), T=T, F="__del__", + map=lambda f, a: f(*a) + ) + + def wrap_magic_contains(obj: cobj, arg: cobj, T: type) -> i32: + return _PyWrap._wrap( + (obj, arg,), T=T, F="__contains__", + map=lambda f, a: i32(1) if f(*a) else i32(0) + ) + + def wrap_magic_init(obj: cobj, args: cobj, kwargs: cobj, T: type) -> i32: + if isinstance(T, ByRef): + F: Static[str] = "__init__" + for fn in _S.fn_overloads(T, F): + a = _PyWrap._reorder_args(fn, obj, args, kwargs, M=1) + if a is not None and _S.fn_can_call(fn, *a): + fn(*a) + return i32(0) + _PyWrap._dispatch_error(F) + else: + F: Static[str] = "__new__" + for fn in _S.fn_overloads(T, F): + a = _PyWrap._reorder_args(fn, obj, args, kwargs, M=0) + if a is not None and _S.fn_can_call(fn, *a): + x = fn(*a) + p = Ptr[PyObject](obj) + 1 + Ptr[T](p.as_byte())[0] = x + return i32(0) + _PyWrap._dispatch_error(F) + + def wrap_magic_call(obj: cobj, args: cobj, kwargs: cobj, T: type) -> cobj: + F: Static[str] = "__call__" + for fn in _S.fn_overloads(T, F): + a = _PyWrap._reorder_args(fn, obj, args, kwargs, M=1) + if a is not None and _S.fn_can_call(fn, *a): + return fn(*a).__to_py__() + _PyWrap._dispatch_error(F) + + def _wrap_cmp(obj: cobj, other: cobj, T: type, F: Static[str]) -> cobj: + return _PyWrap._wrap( + (obj, other), T=T, F=F, + map=lambda f, a: f(*a).__to_py__() + ) + + def wrap_magic_lt(obj: cobj, other: cobj, T: type): + return _PyWrap._wrap_cmp(obj, other, T, "__lt__") + + def wrap_magic_le(obj: cobj, other: cobj, T: type): + return _PyWrap._wrap_cmp(obj, other, T, "__le__") + + def wrap_magic_eq(obj: cobj, other: cobj, T: type): + return _PyWrap._wrap_cmp(obj, other, T, "__eq__") + + def wrap_magic_ne(obj: cobj, other: cobj, T: type): + return _PyWrap._wrap_cmp(obj, other, T, "__ne__") + + def wrap_magic_gt(obj: cobj, other: cobj, T: type): + return _PyWrap._wrap_cmp(obj, other, T, "__gt__") + + def wrap_magic_ge(obj: cobj, other: cobj, T: type): + return _PyWrap._wrap_cmp(obj, other, T, "__ge__") + + def wrap_cmp(obj: cobj, other: cobj, op: i32, C: type) -> cobj: + if hasattr(C, "__lt__") and op == 0i32: + return _PyWrap.wrap_magic_lt(obj, other, C) + elif hasattr(C, "__le__") and op == 1i32: + return _PyWrap.wrap_magic_le(obj, other, C) + elif hasattr(C, "__eq__") and op == 2i32: + return _PyWrap.wrap_magic_eq(obj, other, C) + elif hasattr(C, "__ne__") and op == 3i32: + return _PyWrap.wrap_magic_ne(obj, other, C) + elif hasattr(C, "__gt__") and op == 4i32: + return _PyWrap.wrap_magic_gt(obj, other, C) + elif hasattr(C, "__ge__") and op == 5i32: + return _PyWrap.wrap_magic_ge(obj, other, C) + else: + Py_IncRef(Py_NotImplemented) + return Py_NotImplemented + + def wrap_magic_getitem(obj: cobj, idx: cobj, T: type): + return _PyWrap._wrap( + (obj, idx), T=T, F="__getitem__", + map=lambda f, a: f(*a).__to_py__() + ) + + def wrap_magic_setitem(obj: cobj, idx: cobj, val: cobj, T: type) -> i32: + if val == cobj(): + _PyWrap._wrap( + (obj, idx), T=T, F="__delitem__", + map=lambda f, a: f(*a) + ) + else: + _PyWrap._wrap( + (obj, idx, val), T=T, F="__setitem__", + map=lambda f, a: f(*a) + ) + return i32(0) + + class IterWrap: + _gen: cobj + T: type + + def _init(obj: cobj, T: type) -> cobj: + return _PyWrap.IterWrap(T.__from_py__(obj)).__to_py__() + + @realize_without_self + def __init__(self, it: T): + self._gen = it.__iter__().__raw__() + + def _iter(obj: cobj) -> cobj: + T # need separate fn for each instantiation + p = Ptr[PyObject](obj) + o = p[0] + p[0] = PyObject(o.refcnt + 1, o.pytype) + return obj + + def _iternext(self: cobj) -> cobj: + pt = _PyWrap.IterWrap[T].__from_py__(self) + if pt._gen == cobj(): + return cobj() + + gt = type(T().__iter__())(pt._gen) + if gt.done(): + pt._gen = cobj() + return cobj() + else: + return gt.next().__to_py__() + + def __to_py__(self): + return _PyWrap.wrap_to_py(self) + + def __from_py__(obj: cobj): + return _PyWrap.wrap_from_py(obj, _PyWrap.IterWrap[T]) + + def wrap_magic_iter(obj: cobj, T: type) -> cobj: + return _PyWrap.IterWrap._init(obj, T) + + def wrap_multiple( + obj: cobj, args: Ptr[cobj], nargs: int, _kwds: cobj, T: type, F: Static[str], + M: Static[int] = 1 + ): + kwds = Ptr[str]() + nkw = 0 + if _kwds: + nkw = PyTuple_Size(_kwds) + kwds = Ptr[str](nkw) + for i in range(nkw): + kwds[i] = str.__from_py__(PyTuple_GetItem(_kwds, i)) + + for fn in _S.fn_overloads(T, F): + a = _PyWrap._reorder_args_fastcall(fn, obj, args, nargs, kwds, nkw, M) + if a is not None and _S.fn_can_call(fn, *a): + return fn(*a).__to_py__() + + _PyWrap._dispatch_error(F) + + def wrap_get(obj: cobj, closure: cobj, T: type, S: Static[str]): + return getattr(T.__from_py__(obj), S).__to_py__() + + def wrap_set(obj: cobj, what: cobj, closure: cobj, T: type, S: Static[str]) -> i32: + t = T.__from_py__(obj) + val = type(getattr(t, S)).__from_py__(what) + setattr(t, S, val) + return i32(0) + + def py_type(T: type) -> cobj: + return cobj() + + def wrap_to_py(o) -> cobj: + O = type(o) + P = PyWrapper[O] + sz = sizeof(P) + pytype = _PyWrap.py_type(O) + mem = alloc_atomic_uncollectable(sz) if atomic(O) else alloc_uncollectable(sz) + obj = Ptr[P](mem.as_byte()) + obj[0] = PyWrapper(PyObject(1, pytype), o) + return obj.as_byte() + + def wrap_from_py(o: cobj, T: type) -> T: + obj = Ptr[PyWrapper[T]](o)[0] + pytype = _PyWrap.py_type(T) + if obj.head.pytype != pytype: + _conversion_error(T.__name__) + return obj.data diff --git a/stdlib/internal/static.codon b/stdlib/internal/static.codon new file mode 100644 index 00000000..25d70222 --- /dev/null +++ b/stdlib/internal/static.codon @@ -0,0 +1,35 @@ +# Copyright (C) 2022-2023 Exaloop Inc. + +# Methods for static reflection. Implemented within call.cpp and/or loops.cpp. +# !! Not intended for public use !! + +def fn_overloads(T: type, F: Static[str]): + pass + +def fn_args(F): # function: (i, name) + pass + +def fn_arg_has_type(F, i: Static[int]): + pass + +def fn_arg_get_type(F, i: Static[int]): + pass + +def fn_can_call(F, *args, **kwargs): + pass + +def fn_wrap_call_args(F, *args, **kwargs): + pass + +def fn_has_default(F, i: Static[int]): + pass + +def fn_get_default(F, i: Static[int]): + pass + +def class_args(T: type): + pass + +@no_type_wrap +def static_print(*args): + pass diff --git a/stdlib/internal/types/complex.codon b/stdlib/internal/types/complex.codon index d78d8ee1..ba4bbedd 100644 --- a/stdlib/internal/types/complex.codon +++ b/stdlib/internal/types/complex.codon @@ -1,5 +1,10 @@ # Copyright (C) 2022-2023 Exaloop Inc. +@tuple +class complex64: + real: float32 + imag: float32 + @tuple class complex: real: float @@ -284,3 +289,280 @@ class int: class float: def __suffix_j__(x: float) -> complex: return complex(0, x) + +f32 = float32 + +@extend +class complex64: + def __new__() -> complex64: + return (f32(0.0), f32(0.0)) + + def __new__(other): + return complex64(other.__complex__()) + + def __new__(other: complex) -> complex64: + return (f32(other.real), f32(other.imag)) + + def __new__(real, imag) -> complex64: + return (f32(float(real)), f32(float(imag))) + + def __complex__(self) -> complex: + return complex(float(self.real), float(self.imag)) + + def __bool__(self) -> bool: + return self.real != f32(0.0) and self.imag != f32(0.0) + + def __pos__(self) -> complex64: + return self + + def __neg__(self) -> complex64: + return complex64(-self.real, -self.imag) + + def __abs__(self) -> f32: + @pure + @C + def hypotf(a: f32, b: f32) -> f32: + pass + + return hypotf(self.real, self.imag) + + def __copy__(self) -> complex64: + return self + + def __hash__(self) -> int: + return self.real.__hash__() + self.imag.__hash__() * 1000003 + + def __add__(self, other) -> complex64: + return self + complex64(other) + + def __sub__(self, other) -> complex64: + return self - complex64(other) + + def __mul__(self, other) -> complex64: + return self * complex64(other) + + def __truediv__(self, other) -> complex64: + return self / complex64(other) + + def __eq__(self, other) -> bool: + return self == complex64(other) + + def __ne__(self, other) -> bool: + return self != complex64(other) + + def __pow__(self, other) -> complex64: + return self ** complex64(other) + + def __radd__(self, other) -> complex64: + return complex64(other) + self + + def __rsub__(self, other) -> complex64: + return complex64(other) - self + + def __rmul__(self, other) -> complex64: + return complex64(other) * self + + def __rtruediv__(self, other) -> complex64: + return complex64(other) / self + + def __rpow__(self, other) -> complex64: + return complex64(other) ** self + + def __add__(self, other: complex64) -> complex64: + return complex64(self.real + other.real, self.imag + other.imag) + + def __sub__(self, other: complex64) -> complex64: + return complex64(self.real - other.real, self.imag - other.imag) + + def __mul__(self, other: complex64) -> complex64: + a = (self.real * other.real) - (self.imag * other.imag) + b = (self.real * other.imag) + (self.imag * other.real) + return complex64(a, b) + + def __truediv__(self, other: complex64) -> complex64: + a = self + b = other + abs_breal = (-b.real) if b.real < f32(0) else b.real + abs_bimag = (-b.imag) if b.imag < f32(0) else b.imag + + if abs_breal >= abs_bimag: + # divide tops and bottom by b.real + if abs_breal == f32(0.0): + # errno = EDOM + return complex64(0.0, 0.0) + else: + ratio = b.imag / b.real + denom = b.real + b.imag * ratio + return complex64( + (a.real + a.imag * ratio) / denom, (a.imag - a.real * ratio) / denom + ) + elif abs_bimag >= abs_breal: + # divide tops and bottom by b.imag + ratio = b.real / b.imag + denom = b.real * ratio + b.imag + # assert b.imag != 0.0 + return complex64( + (a.real * ratio + a.imag) / denom, (a.imag * ratio - a.real) / denom + ) + else: + nan = 0.0 / 0.0 + return complex64(nan, nan) + + def __eq__(self, other: complex64) -> bool: + return self.real == other.real and self.imag == other.imag + + def __ne__(self, other: complex64) -> bool: + return not (self == other) + + def __pow__(self, other: int) -> complex64: + def powu(x: complex64, n: int) -> complex64: + mask = 1 + r = complex64(1.0, 0.0) + p = x + while mask > 0 and n >= mask: + if n & mask: + r = r * p + mask <<= 1 + p = p * p + return r + + if other > 0: + return powu(self, other) + else: + return complex64(1.0, 0.0) / powu(self, -other) + + def __pow__(self, other: complex64) -> complex64: + @pure + @C + def hypotf(a: f32, b: f32) -> f32: + pass + + @pure + @C + def atan2f(a: f32, b: f32) -> f32: + pass + + @pure + @llvm + def exp(x: f32) -> f32: + declare float @llvm.exp.f32(float) + %y = call float @llvm.exp.f32(float %x) + ret float %y + + @pure + @llvm + def pow(x: f32, y: f32) -> f32: + declare float @llvm.pow.f32(float, float) + %z = call float @llvm.pow.f32(float %x, float %y) + ret float %z + + @pure + @llvm + def log(x: f32) -> f32: + declare float @llvm.log.f32(float) + %y = call float @llvm.log.f32(float %x) + ret float %y + + @pure + @llvm + def sin(x: f32) -> f32: + declare float @llvm.sin.f32(float) + %y = call float @llvm.sin.f32(float %x) + ret float %y + + @pure + @llvm + def cos(x: f32) -> f32: + declare float @llvm.cos.f32(float) + %y = call float @llvm.cos.f32(float %x) + ret float %y + + if other.real == f32(0.0) and other.imag == f32(0.0): + return complex64(1.0, 0.0) + elif self.real == f32(0.0) and self.imag == f32(0.0): + # if other.imag != 0. or other.real < 0.: errno = EDOM + return complex64(0.0, 0.0) + else: + vabs = hypotf(self.real, self.imag) + len = pow(vabs, other.real) + at = atan2f(self.imag, self.real) + phase = at * other.real + if other.imag != f32(0.0): + len /= exp(at * other.imag) + phase += other.imag * log(vabs) + return complex64(len * cos(phase), len * sin(phase)) + + def __repr__(self) -> str: + @pure + @llvm + def copysign(x: f32, y: f32) -> f32: + declare float @llvm.copysign.f32(float, float) + %z = call float @llvm.copysign.f32(float %x, float %y) + ret float %z + + @pure + @llvm + def fabs(x: f32) -> f32: + declare float @llvm.fabs.f32(float) + %y = call float @llvm.fabs.f32(float %x) + ret float %y + + if self.real == f32(0.0) and copysign(f32(1.0), self.real) == f32(1.0): + return f"complex64({self.imag}j)" + else: + sign = "+" + if self.imag < f32(0.0) or ( + self.imag == f32(0.0) and copysign(f32(1.0), self.imag) == f32(-1.0) + ): + sign = "-" + return f"complex64({self.real}{sign}{fabs(self.imag)}j)" + + def conjugate(self) -> complex64: + return complex64(self.real, -self.imag) + + # helpers + def _phase(self) -> f32: + @pure + @C + def atan2f(a: f32, b: f32) -> f32: + pass + + return atan2f(self.imag, self.real) + + def _polar(self) -> Tuple[f32, f32]: + return (self.__abs__(), self._phase()) + + @pure + @llvm + def _exp(x: f32) -> f32: + declare float @llvm.exp.f32(float) + %y = call float @llvm.exp.f32(float %x) + ret float %y + + @pure + @llvm + def _sqrt(x: f32) -> f32: + declare float @llvm.sqrt.f32(float) + %y = call float @llvm.sqrt.f32(float %x) + ret float %y + + @pure + @llvm + def _cos(x: f32) -> f32: + declare float @llvm.cos.f32(float) + %y = call float @llvm.cos.f32(float %x) + ret float %y + + @pure + @llvm + def _sin(x: f32) -> f32: + declare float @llvm.sin.f32(float) + %y = call float @llvm.sin.f32(float %x) + ret float %y + + @pure + @llvm + def _log(x: f32) -> f32: + declare float @llvm.log.f32(float) + %y = call float @llvm.log.f32(float %x) + ret float %y diff --git a/stdlib/internal/types/error.codon b/stdlib/internal/types/error.codon index 400a00bc..a0f8ea90 100644 --- a/stdlib/internal/types/error.codon +++ b/stdlib/internal/types/error.codon @@ -3,12 +3,14 @@ # Warning(!): This type must be consistent with the exception # header type defined in runtime/exc.cpp. class BaseException: + _pytype: ClassVar[cobj] = cobj() typename: str message: str func: str file: str line: int col: int + python_type: cobj def __init__(self, typename: str, message: str = ""): self.typename = typename @@ -17,6 +19,7 @@ class BaseException: self.file = "" self.line = 0 self.col = 0 + self.python_type = BaseException._pytype def __str__(self): return self.message @@ -25,78 +28,138 @@ class BaseException: return f'{self.typename}({self.message.__repr__()})' class Exception(Static[BaseException]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, typename: str, msg: str = ""): super().__init__(typename, msg) + if (hasattr(self.__class__, "_pytype")): + self.python_type = self.__class__._pytype class NameError(Static[Exception]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): super().__init__("NameError", message) + self.python_type = self.__class__._pytype class OSError(Static[Exception]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): super().__init__("OSError", message) + self.python_type = self.__class__._pytype class IOError(Static[Exception]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): super().__init__("IOError", message) + self.python_type = self.__class__._pytype class ValueError(Static[Exception]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): super().__init__("ValueError", message) + self.python_type = self.__class__._pytype -class IndexError(Static[Exception]): +class LookupError(Static[Exception]): + _pytype: ClassVar[cobj] = cobj() + def __init__(self, typename: str, message: str = ""): + super().__init__(typename, message) + self.python_type = self.__class__._pytype + def __init__(self, msg: str = ""): + super().__init__("LookupError", msg) + self.python_type = self.__class__._pytype + +class IndexError(Static[LookupError]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): super().__init__("IndexError", message) + self.python_type = self.__class__._pytype -class KeyError(Static[Exception]): +class KeyError(Static[LookupError]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): super().__init__("KeyError", message) - -class OverflowError(Static[Exception]): - def __init__(self, message: str = ""): - super().__init__("OverflowError", message) + self.python_type = self.__class__._pytype class CError(Static[Exception]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): super().__init__("CError", message) + self.python_type = self.__class__._pytype class PyError(Static[Exception]): pytype: pyobj - def __init__(self, message: str, pytype: pyobj): + def __init__(self, message: str, pytype: pyobj = pyobj(cobj(), steal=True)): super().__init__("PyError", message) self.pytype = pytype class TypeError(Static[Exception]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): super().__init__("TypeError", message) + self.python_type = self.__class__._pytype + +class ArithmeticError(Static[Exception]): + _pytype: ClassVar[cobj] = cobj() + def __init__(self, msg: str = ""): + super().__init__("ArithmeticError", msg) + self.python_type = self.__class__._pytype -class ZeroDivisionError(Static[Exception]): +class ZeroDivisionError(Static[ArithmeticError]): + _pytype: ClassVar[cobj] = cobj() + def __init__(self, typename: str, message: str = ""): + super().__init__(typename, message) + self.python_type = self.__class__._pytype def __init__(self, message: str = ""): super().__init__("ZeroDivisionError", message) + self.python_type = self.__class__._pytype + +class OverflowError(Static[ArithmeticError]): + _pytype: ClassVar[cobj] = cobj() + def __init__(self, message: str = ""): + super().__init__("OverflowError", message) + self.python_type = self.__class__._pytype class AttributeError(Static[Exception]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): super().__init__("AttributeError", message) + self.python_type = self.__class__._pytype + +class RuntimeError(Static[Exception]): + _pytype: ClassVar[cobj] = cobj() + def __init__(self, typename: str, message: str = ""): + super().__init__(typename, message) + self.python_type = self.__class__._pytype + def __init__(self, message: str = ""): + super().__init__("RuntimeError", message) + self.python_type = self.__class__._pytype -class NotImplementedError(Static[Exception]): +class NotImplementedError(Static[RuntimeError]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): super().__init__("NotImplementedError", message) + self.python_type = self.__class__._pytype class StopIteration(Static[Exception]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): super().__init__("StopIteration", message) + self.python_type = self.__class__._pytype class AssertionError(Static[Exception]): + _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): super().__init__("AssertionError", message) + self.python_type = self.__class__._pytype class SystemExit(Static[BaseException]): + _pytype: ClassVar[cobj] = cobj() _status: int def __init__(self, message: str = "", status: int = 0): super().__init__("SystemExit", message) self._status = status + self.python_type = self.__class__._pytype def __init__(self, status: int): self.__init__("", status) diff --git a/stdlib/internal/types/generator.codon b/stdlib/internal/types/generator.codon index 57cf13ae..80ca6a5e 100644 --- a/stdlib/internal/types/generator.codon +++ b/stdlib/internal/types/generator.codon @@ -24,6 +24,12 @@ class Generator: def __raw__(self) -> Ptr[byte]: ret ptr %self + @pure + @derives + @llvm + def __new__(ptr: cobj) -> Generator[T]: + ret ptr %ptr + @pure @llvm def __done__(self) -> bool: diff --git a/stdlib/internal/types/int.codon b/stdlib/internal/types/int.codon index 58dcb268..77ecdb8e 100644 --- a/stdlib/internal/types/int.codon +++ b/stdlib/internal/types/int.codon @@ -63,10 +63,9 @@ class int: @pure @llvm def __abs__(self) -> int: - %0 = icmp sgt i64 %self, 0 - %1 = sub i64 0, %self - %2 = select i1 %0, i64 %self, i64 %1 - ret i64 %2 + declare i64 @llvm.abs.i64(i64, i1) + %0 = call i64 @llvm.abs.i64(i64 %self, i1 false) + ret i64 %0 @pure @llvm diff --git a/stdlib/internal/types/intn.codon b/stdlib/internal/types/intn.codon index 2973f025..f1be1961 100644 --- a/stdlib/internal/types/intn.codon +++ b/stdlib/internal/types/intn.codon @@ -94,6 +94,13 @@ class Int: %0 = xor i{=N} %self, -1 ret i{=N} %0 + @pure + @llvm + def __abs__(self) -> Int[N]: + declare i{=N} @llvm.abs.i{=N}(i{=N}, i1) + %0 = call i{=N} @llvm.abs.i{=N}(i{=N} %self, i1 false) + ret i{=N} %0 + @pure @commutative @associative @@ -231,11 +238,24 @@ class Int: %0 = xor i{=N} %self, %other ret i{=N} %0 - def __repr__(self) -> str: - return f"Int[{N}]({int(self)})" + def __pow__(self, exp: Int[N]) -> Int[N]: + zero = Int[N](0) + one = Int[N](1) + + if exp < zero: + return zero + result = one + while True: + if exp & one: + result *= self + exp >>= one + if not exp: + break + self *= self + return result - def __str__(self) -> str: - return int(self).__repr__() + def __repr__(self) -> str: + return f"Int[{N}]({self.__str__()})" @pure @llvm @@ -327,6 +347,9 @@ class UInt: %0 = xor i{=N} %self, -1 ret i{=N} %0 + def __abs__(self) -> UInt[N]: + return self + @pure @commutative @associative @@ -459,12 +482,25 @@ class UInt: %0 = xor i{=N} %self, %other ret i{=N} %0 + def __pow__(self, exp: UInt[N]) -> UInt[N]: + zero = UInt[N](0) + one = UInt[N](1) + + if exp < zero: + return zero + result = one + while True: + if exp & one: + result *= self + exp >>= one + if not exp: + break + self *= self + return result + def __repr__(self) -> str: return f"UInt[{N}]({self.__str__()})" - def __str__(self) -> str: - return self.__format__("") - def popcnt(self) -> int: return int(Int[N](self)._popcnt()) diff --git a/stdlib/internal/types/strbuf.codon b/stdlib/internal/types/strbuf.codon index e143788c..67ec874c 100644 --- a/stdlib/internal/types/strbuf.codon +++ b/stdlib/internal/types/strbuf.codon @@ -23,5 +23,14 @@ class strbuf: str.memcpy(self.data + self.n, s.ptr, adding) self.n = needed + def reverse(self): + a = 0 + b = self.n - 1 + p = self.data + while a < b: + p[a], p[b] = p[b], p[a] + a += 1 + b -= 1 + def __str__(self): return str(self.data, self.n) diff --git a/test/core/arithmetic.codon b/test/core/arithmetic.codon index b19553e9..dc8c90be 100644 --- a/test/core/arithmetic.codon +++ b/test/core/arithmetic.codon @@ -105,3 +105,40 @@ def test_conversions(): assert bool(Int[80](42)) == True assert str(Int[80](42)) == '42' test_conversions() + +@test +def test_int_pow(): + @nonpure + def f(n): + return n + + assert f(3) ** f(2) == 9 + assert f(27) ** f(7) == 10460353203 + assert f(-27) ** f(7) == -10460353203 + assert f(-27) ** f(6) == 387420489 + assert f(1) ** f(0) == 1 + assert f(1) ** f(1000) == 1 + assert f(0) ** f(3) == 0 + assert f(0) ** f(0) == 1 + + T1 = Int[512] + assert f(T1(3)) ** f(T1(2)) == T1(9) + assert f(T1(27)) ** f(T1(7)) == T1(10460353203) + assert f(T1(-27)) ** f(T1(7)) == T1(-10460353203) + assert f(T1(-27)) ** f(T1(6)) == T1(387420489) + assert f(T1(1)) ** f(T1(0)) == T1(1) + assert f(T1(1)) ** f(T1(1000)) == T1(1) + assert f(T1(0)) ** f(T1(3)) == T1(0) + assert f(T1(0)) ** f(T1(0)) == T1(1) + assert str(f(T1(31)) ** f(T1(31))) == '17069174130723235958610643029059314756044734431' + assert str(f(T1(-31)) ** f(T1(31))) == '-17069174130723235958610643029059314756044734431' + + T2 = UInt[200] + assert f(T2(3)) ** f(T2(2)) == T2(9) + assert f(T2(27)) ** f(T2(7)) == T2(10460353203) + assert f(T2(1)) ** f(T2(0)) == T2(1) + assert f(T2(1)) ** f(T2(1000)) == T2(1) + assert f(T2(0)) ** f(T2(3)) == T2(0) + assert f(T2(0)) ** f(T2(0)) == T2(1) + assert str(f(T2(31)) ** f(T2(31))) == '17069174130723235958610643029059314756044734431' +test_int_pow() diff --git a/test/core/bltin.codon b/test/core/bltin.codon index ec334bf4..0a2df0a4 100644 --- a/test/core/bltin.codon +++ b/test/core/bltin.codon @@ -1,3 +1,5 @@ +# Python-specific + @test def test_min_max(): assert max(2, 1, 1, 1, 1) == 2 @@ -258,6 +260,34 @@ def test_divmod(): assert math.isclose(result[0], exp_result[0]) assert math.isclose(result[1], exp_result[1]) +@test +def test_pow(): + assert pow(3, 4) == 81 + assert pow(-3, 3) == -27 + assert pow(1, 0) == 1 + assert pow(-1, 0) == 1 + assert pow(0, 0) == 1 + assert pow(12, 12, 42) == 36 + assert pow(1234, 4321, 99) == 46 + assert pow(9999, 9999, 2) == 1 + assert pow(0, 0, 1) == 0 + + try: + pow(1, -1, 2) + assert False + except ValueError as e: + assert 'negative' in str(e) + + try: + pow(1, 1, 0) + assert False + except ValueError as e: + assert 'cannot be 0' in str(e) + + assert pow(1.5, 2) == 2.25 + assert pow(9, 0.5) == 3.0 + assert pow(2.0, -1.0) == 0.5 + @test def test_num_from_str(): assert int('0') == 0 @@ -319,9 +349,31 @@ def test_files(open_fn): with open_fn(path) as f: assert f.read(3) == 'hel' assert f.read() == 'lo\nworld\n' + f.seek(0, 0) + assert f.tell() == 0 + assert f.read() == 'hello\nworld\n' + + try: + f.tell() + assert False + except IOError: + pass + + try: + f.seek(0, 0) + assert False + except IOError: + pass + + try: + f.flush() + assert False + except IOError: + pass f = open_fn(path, 'a') f.write('goodbye') + f.flush() f.close() with open_fn(path) as f: @@ -337,7 +389,145 @@ test_gen_builtins() test_int_format() test_reversed() test_divmod() +test_pow() test_num_from_str() test_files(open) import gzip test_files(gzip.open) + + +# Codon-specific + +@pure +@llvm +def zext(x: int, T: type) -> T: + %0 = zext i64 %x to {=T} + ret {=T} %0 + +@test +def test_narrow_int_str(T: type): + z = T(0) + o = T(1) + a = T(42) + b = T(-9) + + assert str(z) == '0' + assert str(-z) == '0' + assert str(o) == '1' + assert str(-o) == '-1' + assert str(o + o + o) == '3' + assert str((o + o + o + o) * (o + o + o)) == '12' + assert str(a) == '42' + assert str(b) == '-9' + assert repr(a) == f'Int[{T.N}](42)' + +@test +def test_narrow_uint_str(T: type): + z = T(0) + o = T(1) + a = T(42) + + assert str(z) == '0' + assert str(-z) == '0' + assert str(o) == '1' + assert str(o + o + o) == '3' + assert str((o + o + o + o) * (o + o + o)) == '12' + assert str(a) == '42' + + if T.N == 32: + assert str(T(0xffffffff)) == '4294967295' + + if T.N == 64: + assert str(T(0xffffffffffffffff)) == '18446744073709551615' + + assert repr(a) == f'UInt[{T.N}](42)' + +@test +def test_wide_int_str(T: type): + z = T(0) + o = T(1) + a = T(0xf23ff2341234) + b = T(-77777) + + assert str(z) == '0' + assert str(-z) == '0' + assert str(o) == '1' + assert str(-o) == '-1' + assert str(o + o + o) == '3' + assert str((o + o + o + o) * (o + o + o)) == '12' + assert str(a) == '266356460360244' + assert str(b) == '-77777' + assert str(a * a) == '70945763975638233282255739536' + assert str(b * b) == '6049261729' + assert str(a * b) == '-20716406417438697588' + + n = zext(0x7fffffffffffffff, T) + m = zext(0xffffffffffffffff, T) + s = T(64) + assert str((n << s) | m) == '170141183460469231731687303715884105727' + if T.N == 128: + assert str(T(1) << T(127)) == '-170141183460469231731687303715884105728' + if T.N > 500: + assert str(a * a * a * a * a * a * a * a) == '25334123245849102734940743817373556303530349383588924760280652082676453679304226528003335153202090430651964934127616' + assert str(a * a * a * a * a * a * a * a * b) == '-1970412103692405663415486231883863088619679984007395801080348277034326537815244826668515398210598987424817876681643589632' + + assert repr(a) == f'Int[{T.N}](266356460360244)' + assert repr(a * b) == f'Int[{T.N}](-20716406417438697588)' + +@test +def test_wide_uint_str(T: type): + z = T(0) + o = T(1) + a = T(0xf23ff2341234) + + assert str(z) == '0' + assert str(-z) == '0' + assert str(o) == '1' + assert str(o + o + o) == '3' + assert str((o + o + o + o) * (o + o + o)) == '12' + assert str(a) == '266356460360244' + assert str(a * a) == '70945763975638233282255739536' + + n = zext(0xffffffffffffffff, T) + s = T(64) + assert str((n << s) | n) == '340282366920938463463374607431768211455' + assert str((n << s) | (n - T(1))) == '340282366920938463463374607431768211454' + if T.N > 500: + assert str(a * a * a * a * a * a * a * a) == '25334123245849102734940743817373556303530349383588924760280652082676453679304226528003335153202090430651964934127616' + + assert repr(a) == f'UInt[{T.N}](266356460360244)' + assert repr(a * a) == f'UInt[{T.N}](70945763975638233282255739536)' + +test_narrow_int_str(Int[7]) +test_narrow_int_str(Int[8]) +test_narrow_int_str(Int[10]) +test_narrow_int_str(Int[16]) +test_narrow_int_str(Int[32]) +test_narrow_int_str(Int[60]) +test_narrow_int_str(Int[63]) +test_narrow_int_str(Int[64]) + +test_narrow_uint_str(UInt[7]) +test_narrow_uint_str(UInt[8]) +test_narrow_uint_str(UInt[10]) +test_narrow_uint_str(UInt[16]) +test_narrow_uint_str(UInt[32]) +test_narrow_uint_str(UInt[60]) +test_narrow_uint_str(UInt[63]) +test_narrow_uint_str(UInt[64]) + +test_wide_int_str(Int[128]) +test_wide_int_str(Int[200]) +test_wide_int_str(Int[256]) +test_wide_int_str(Int[512]) +test_wide_int_str(Int[1024]) +test_wide_int_str(Int[2048]) +test_wide_int_str(Int[4096]) + +test_wide_uint_str(UInt[128]) +test_wide_uint_str(UInt[200]) +test_wide_uint_str(UInt[256]) +test_wide_uint_str(UInt[512]) +test_wide_uint_str(UInt[1024]) +test_wide_uint_str(UInt[2048]) +test_wide_uint_str(UInt[4096]) diff --git a/test/core/containers.codon b/test/core/containers.codon index add7a0a5..c5de81c4 100644 --- a/test/core/containers.codon +++ b/test/core/containers.codon @@ -159,6 +159,19 @@ def test_dyn_tuple(): assert t[3:0:-1] == D(3, 2) assert hash(D(1,2,3,4,5)) == hash((1,2,3,4,5)) + + assert (1, 2) + (3,) == (1, 2, 3) + assert (1,) + (2, 3) == (1, 2, 3) + assert (1, 2) + () == (1, 2) + assert () + () == () + assert () + (1, 2) == (1, 2) + assert (1,) + (2,) == (1, 2) + assert (1, 2) * 3 == (1, 2, 1, 2, 1, 2) + assert () * 99 == () + assert (1, 2, 3, 4) * 1 == (1, 2, 3, 4) + assert (1, 2) * 0 == () + assert (1, 2) * (-1) == () + assert () * -1 == () test_dyn_tuple() @test diff --git a/test/parser/llvm.codon b/test/parser/llvm.codon index 133dbf6e..f8c2e633 100644 --- a/test/parser/llvm.codon +++ b/test/parser/llvm.codon @@ -91,8 +91,8 @@ print a.__copy__() #: 15 print a.__hash__() #: 15 print a.__bool__() #: True print a.__pos__() #: 15 -print a.__neg__() #: 18446744073709551601 -print a.__invert__() #: 18446744073709551600 +print a.__neg__() #: 10633823966279326983230456482242756593 +print a.__invert__() #: 10633823966279326983230456482242756592 m = UInt[123](4) print a.__add__(m), a.__sub__(m), a.__mul__(m), a.__floordiv__(m), a.__truediv__(m) #: 19 11 60 3 3.75 print a.__mod__(m), a.__lshift__(m), a.__rshift__(m) #: 3 240 0 diff --git a/test/parser/simplify_expr.codon b/test/parser/simplify_expr.codon index 7723dd2b..2619765a 100644 --- a/test/parser/simplify_expr.codon +++ b/test/parser/simplify_expr.codon @@ -154,6 +154,9 @@ print s #: {0, 1, 2} d = {i: j for i in range(10) if i < 1 for j in range(10)} print d #: {0: 9} +x = {t: lambda x: x * t for t in range(5)} +print(x[3](10)) #: 30 + #%% comprehension_opt,barebones @extend class List: diff --git a/test/parser/simplify_stmt.codon b/test/parser/simplify_stmt.codon index 3d8a9541..eae20112 100644 --- a/test/parser/simplify_stmt.codon +++ b/test/parser/simplify_stmt.codon @@ -235,6 +235,19 @@ for i in [1]: else: print 'nope' +best = 4 +for s in [3, 4, 5]: + for i in [s]: + if s >= best: + print('b:', best) + break + else: + print('s:', s) + best = s +#: s: 3 +#: b: 3 +#: b: 3 + #%% match def foo(x): match x: @@ -471,6 +484,8 @@ bar() from a.b.rec1_err import bar #! cannot import name 'bar' from 'a.b.rec1_err' #! name 'bar' is not defined +# TODO: get rid of this! +#! no module named 'rec2_err' #%% import_err_1,barebones class Foo: @@ -590,7 +605,7 @@ def f(x): return g(x) print f(5) #: 6 -##% nested_generic_static,barebones +#%% nested_generic_static,barebones def foo(): N: Static[int] = 5 Z: Static[int] = 15 @@ -920,8 +935,8 @@ print(1 - a) #: [ 0 -1 -2] from python import re.split(str, str) -> List[str] as rs print rs(r'\W+', 'Words, words, words.') #: ['Words', 'words', 'words', ''] -#%% python_import_void -from python import os.system(str) -> None +#%% python_import_fn_2 +from python import os.system(str) -> int system("echo 'hello!'") #: hello! #%% python_pydef diff --git a/test/parser/typecheck_expr.codon b/test/parser/typecheck_expr.codon index 4b1da720..cd377433 100644 --- a/test/parser/typecheck_expr.codon +++ b/test/parser/typecheck_expr.codon @@ -191,7 +191,7 @@ foo(int, float) #! foo() takes 1 arguments (2 given) #%% instantiate_err_2,barebones def foo[N, T](): return N() -foo(int) #! cannot typecheck the program +foo(int) #! generic 'T' not provided #%% instantiate_err_3,barebones Ptr[int, float]() #! Ptr takes 1 generics (2 given) @@ -383,7 +383,7 @@ class A: a = A() #! argument 'ga' has recursive default value #%% call_err_4,barebones -seq_print(1, name="56", name=2) #! keyword argument repeated: name +seq_print_full(1, name="56", name=2) #! keyword argument repeated: name #%% call_partial,barebones def foo(i, j, k): @@ -824,6 +824,28 @@ print c.__class__.__name__, c.foo() #: AX[str] #: CX[str,bool] CX:[BX:[AX:a]:False]:1 +#%% super_vtable_2 +class Base: + def test(self): + print('base.test') +class A(Base): + def test(self): + super().test() + Base.test(self) + print('a.test') +a = A() +a.test() +def moo(x: Base): + x.test() +moo(a) +Base.test(a) +#: base.test +#: base.test +#: a.test +#: base.test +#: base.test +#: a.test +#: base.test #%% super_tuple,barebones @tuple diff --git a/test/parser/types.codon b/test/parser/types.codon index 371df2f4..dc7ee80b 100644 --- a/test/parser/types.codon +++ b/test/parser/types.codon @@ -1019,7 +1019,8 @@ print [a for a in ()] #: [] def foo(*args): return [a for a in args] args, result = ((), [()]) -print list(foo(*args)) == result #: False +print list(foo(*args)) #: [] +print result #: [()] #%% type_error_reporting @@ -1277,6 +1278,9 @@ for i in statictuple("x", 1, 3.3, 2): #: 3.3 #: 2 +print tuple(Int[i+10](i) for i in statictuple(1, 2, 3)).__class__.__name__ +#: Tuple[Int[11],Int[12],Int[13]] + for i in staticrange(0, 10): if i % 2 == 0: continue if i > 8: break @@ -1299,6 +1303,9 @@ print('whoa') #: xyz Int[7] #: whoa +print tuple(Int[i-10](i) for i in staticrange(30,33)).__class__.__name__ +#: Tuple[Int[20],Int[21],Int[22]] + for i in statictuple(0, 2, 4, 7, 11, 12, 13): if i % 2 == 0: continue if i > 8: break @@ -1312,6 +1319,23 @@ for i in staticrange(10): # TODO: large values are too slow! print('done') #: done +tt = (5, 'x', 3.14, False, [1, 2]) +for i, j in staticenumerate(tt): + print(foo(i * 2 + 1), j) +#: static 1 Int[1] +#: None 5 +#: static 3 Int[3] +#: None x +#: static 5 Int[5] +#: None 3.14 +#: static 7 Int[7] +#: None False +#: static 9 Int[9] +#: None [1, 2] + +print tuple((Int[i+1](i), j) for i, j in staticenumerate(tt)).__class__.__name__ +#: Tuple[Tuple[Int[1],int],Tuple[Int[2],str],Tuple[Int[3],float],Tuple[Int[4],bool],Tuple[Int[5],List[int]]] + #%% static_range_error,barebones for i in staticrange(1000, -2000, -2): pass @@ -1602,6 +1626,35 @@ B[Ho]() #: init B Ho +class Vehicle: + def drive(self): + return "I'm driving a vehicle" + +class Car(Vehicle): + def drive(self): + return "I'm driving a car" + +class Truck(Vehicle): + def drive(self): + return "I'm driving a truck" + +class SUV(Car, Truck): + def drive(self): + return "I'm driving an SUV" + +suv = SUV() +def moo(s): + print(s.drive()) +moo(suv) +moo(Truck()) +moo(Car()) +moo(Vehicle()) +#: I'm driving an SUV +#: I'm driving a truck +#: I'm driving a car +#: I'm driving a vehicle + + #%% polymorphism_error_1,barebones class M[T]: m: T @@ -1832,3 +1885,14 @@ a.test2(2) #: test:B 1 #: test2:B 2 + +#%% no_generic,barebones +def foo(a, b: Static[int]): + pass +foo(5) #! generic 'b' not provided + + +#%% no_generic_2,barebones +def f(a, b, T: type): + print(a, b) +f(1, 2) #! generic 'T' not provided diff --git a/test/python/myextension.codon b/test/python/myextension.codon new file mode 100644 index 00000000..48f7a45c --- /dev/null +++ b/test/python/myextension.codon @@ -0,0 +1,344 @@ +print('Hello from Codon') + +def f1(a: float = 1.11, b: float = 2.22): + return (a, b) + +def f2(): + return ({1: 'one'}, {2}, [3]) + +def f3(x): + return x * 2 + +def f4(x): + return x + +@overload +def f4(a: float = 1.11, b: float = 2.22): + return f1(a=a, b=b) + +@overload +def f4(): + return ['f4()'] + +@overload +def f4(x: str): + return x, x + +def f5(): + pass + +@dataclass(python=True) +class Vec: + a: float + b: float + tag: str + n: ClassVar[int] = 0 + d: ClassVar[int] = 0 + + def __init__(self, a: float, b: float, tag: str): + self.a = a + self.b = b + self.tag = tag + + def __init__(self, a: float = 0.0, b: float = 0.0): + self.__init__(a, b, 'v' + str(Vec.n)) + Vec.n += 1 + + def foo(self, a: float = 1.11, b: float = 2.22): + return (self.a, a, b) + + def bar(self): + return self + + def baz(a: float = 1.11, b: float = 2.22): + return (a, b) + + def nop(): + return 'nop' + + @property + def c(self): + return self.a + self.b + + def __str__(self): + return f'{self.tag}: <{self.a}, {self.b}>' + + def __repr__(self): + return f'Vec({self.a}, {self.b}, {repr(self.tag)})' + + def __add__(self, other: Vec): + return Vec(self.a + other.a, self.b + other.b, f'({self.tag}+{other.tag})') + + def __iadd__(self, other: Vec): + self.a += other.a + self.b += other.b + self.tag = f'({self.tag}+={other.tag})' + return self + + def __add__(self, x: float): + return Vec(self.a + x, self.b + x, f'({self.tag}+{x})') + + def __iadd__(self, x: float): + self.a += x + self.b += x + self.tag = f'({self.tag}+={x})' + return self + + def __add__(self, x: int): + return Vec(self.a + x, self.b + x, f'({self.tag}++{x})') + + def __sub__(self, other: Vec): + return Vec(self.a - other.a, self.b - other.b, f'({self.tag}-{other.tag})') + + def __isub__(self, other: Vec): + self.a -= other.a + self.b -= other.b + self.tag = f'({self.tag}-={other.tag})' + return self + + def __mul__(self, x: float): + return Vec(self.a * x, self.b * x, f'({self.tag}*{x})') + + def __imul__(self, x: float): + self.a *= x + self.b *= x + self.tag = f'({self.tag}*={x})' + return self + + def __mod__(self, x: float): + return Vec(self.a % x, self.b % x, f'({self.tag}%{x})') + + def __imod__(self, x: float): + self.a %= x + self.b %= x + self.tag = f'({self.tag}%={x})' + return self + + def __divmod__(self, x: float): + raise ArithmeticError('no divmod') + # return self.a / x, self.a % x + + def __pow__(self, x: float): + return Vec(self.a ** x, self.b ** x, f'({self.tag}**{x})') + + def __ipow__(self, x: float): + self.a **= x + self.b **= x + self.tag = f'({self.tag}**={x})' + return self + + def __neg__(self): + return Vec(-self.a, -self.b, f'(-{self.tag})') + + def __pos__(self): + return Vec(self.a, self.b, f'(+{self.tag})') + + def __abs__(self): + import math + return math.hypot(self.a, self.b) + + def __bool__(self): + return bool(self.a) and bool(self.b) + + def __invert__(self): + return Vec(-self.a - 1, -self.b - 1, f'(~{self.tag})') + + def __lshift__(self, x: int): + y = 1 << x + return Vec(self.a * y, self.b * y, f'({self.tag}<<{x})') + + def __ilshift__(self, x: int): + y = 1 << x + self.a *= y + self.b *= y + self.tag = f'({self.tag}<<={x})' + return self + + def __rshift__(self, x: int): + y = 1 << x + return Vec(self.a / y, self.b / y, f'({self.tag}>>{x})') + + def __irshift__(self, x: int): + y = 1 << x + self.a /= y + self.b /= y + self.tag = f'({self.tag}>>={x})' + return self + + def __and__(self, x: int): + return Vec(int(self.a) & x, int(self.b) & x, f'({self.tag}&{x})') + + def __iand__(self, x: int): + self.a = int(self.a) & x + self.b = int(self.b) & x + self.tag = f'({self.tag}&={x})' + return self + + def __or__(self, x: int): + return Vec(int(self.a) | x, int(self.b) | x, f'({self.tag}|{x})') + + def __ior__(self, x: int): + self.a = int(self.a) | x + self.b = int(self.b) | x + self.tag = f'({self.tag}|={x})' + return self + + def __xor__(self, x: int): + return Vec(int(self.a) ^ x, int(self.b) ^ x, f'({self.tag}^{x})') + + def __ixor__(self, x: int): + self.a = int(self.a) ^ x + self.b = int(self.b) ^ x + self.tag = f'({self.tag}^={x})' + return self + + #def __int__(self): + # return int(self.b) + + #def __float__(self): + # return self.b + + #def __index__(self): + # return int(self.a) + + def __floordiv__(self, x: float): + return Vec(self.a // x, self.b // x, f'({self.tag}//{x})') + + def __ifloordiv__(self, x: float): + self.a //= x + self.b //= x + self.tag = f'({self.tag}//={x})' + return self + + def __truediv__(self, x: float): + return Vec(self.a / x, self.b / x, f'({self.tag}/{x})') + + def __itruediv__(self, x: float): + self.a /= x + self.b /= x + self.tag = f'({self.tag}/={x})' + return self + + def __matmul__(self, other: Vec): + return (self.a * other.a) + (self.b * other.b) + + def __imatmul__(self, x: float): + self.a *= x + self.b *= x + self.tag = f'({self.tag}@={x})' + return self + + def __len__(self): + return len(self.tag) + + def __getitem__(self, idx: int): + if idx == 0: + return self.a + elif idx == 1: + return self.b + elif idx == 11: + return self.a + self.b + else: + raise KeyError('bad vec key ' + str(idx)) + + def __setitem__(self, idx: int, val: float): + if idx == 0: + self.a = val + elif idx == 1: + self.b = val + elif idx == 11: + self.a = val + self.b = val + else: + raise KeyError('bad vec key ' + str(idx) + ' with val ' + str(val)) + + def __delitem__(self, idx: int): + self[idx] = 0.0 + + def __contains__(self, val: float): + return self.a == val or self.b == val + + def __contains__(self, val: str): + return self.tag == val + + def __hash__(self): + return int(self.a + self.b) + + def __call__(self, a: float = 1.11, b: float = 2.22): + return self.foo(a=a, b=b) + + def __call__(self, x: str): + return (self.a, self.b, x) + + def __call__(self): + return Vec(self.a, self.b, f'({self.tag}())') + + def __iter__(self): + for c in self.tag: + yield c + + def __eq__(self, other: Vec): + return self.a == other.a and self.b == other.b + + def __eq__(self, x: float): + return self.a == x and self.b == x + + def __ne__(self, other: Vec): + return self.a != other.a or self.b != other.b + + def __lt__(self, other: Vec): + return abs(self) < abs(other) + + def __le__(self, other: Vec): + return abs(self) <= abs(other) + + def __gt__(self, other: Vec): + return abs(self) > abs(other) + + def __ge__(self, other: Vec): + return abs(self) >= abs(other) + + def __del__(self): + Vec.d += 1 + + def nd(): + return Vec.d + +def f6(x: float, t: str): + return Vec(x, x, t) + +def reset(): + Vec.n = 0 + +def par_sum(n: int): + m = 0 + @par(num_threads=4) + for i in range(n): + m += 3*i + 7 + return m + +@tuple +class Foo: + a: List[str] + x: Dict[str, int] + + def __new__(a: List[str]) -> Foo: + return (a, {s: i for i, s in enumerate(a)}) + + def __iter__(self): + return iter(self.a) + + def __repr__(self): + return f'Foo({self.a}, {self.x})' + + def hello(self): + return 'x' + + def __int__(self): + return 42 + + def __float__(self): + return 3.14 + + def __index__(self): + return 99 diff --git a/test/python/myextension2.codon b/test/python/myextension2.codon new file mode 100644 index 00000000..6e8dcb3f --- /dev/null +++ b/test/python/myextension2.codon @@ -0,0 +1,287 @@ +print('Hello from Codon 2') + +@tuple +class Vec: + a: float + b: float + tag: str + n: ClassVar[int] = 0 + d: ClassVar[int] = 0 + + def __new__(a: float, b: float, tag: str) -> Vec: + return (a, b, tag) + + def __new__(a: float = 0.0, b: float = 0.0): + v = Vec(a, b, 'v' + str(Vec.n)) + Vec.n += 1 + return v + + def foo(self, a: float = 1.11, b: float = 2.22): + return (self.a, a, b) + + def bar(self): + return self + + def baz(a: float = 1.11, b: float = 2.22): + return (a, b) + + def nop(): + return 'nop' + + @property + def c(self): + return self.a + self.b + + def __str__(self): + return f'{self.tag}: <{self.a}, {self.b}>' + + def __repr__(self): + return f'Vec({self.a}, {self.b}, {repr(self.tag)})' + + def __add__(self, other: Vec): + return Vec(self.a + other.a, self.b + other.b, f'({self.tag}+{other.tag})') + + def __iadd__(self, other: Vec): + a, b = self.a, self.b + a += other.a + b += other.b + tag = f'({self.tag}+={other.tag})' + return Vec(a, b, tag) + + def __add__(self, x: float): + return Vec(self.a + x, self.b + x, f'({self.tag}+{x})') + + def __iadd__(self, x: float): + a, b = self.a, self.b + a += x + b += x + tag = f'({self.tag}+={x})' + return Vec(a, b, tag) + + def __add__(self, x: int): + return Vec(self.a + x, self.b + x, f'({self.tag}++{x})') + + def __sub__(self, other: Vec): + return Vec(self.a - other.a, self.b - other.b, f'({self.tag}-{other.tag})') + + def __isub__(self, other: Vec): + a, b = self.a, self.b + a -= other.a + b -= other.b + tag = f'({self.tag}-={other.tag})' + return Vec(a, b, tag) + + def __mul__(self, x: float): + return Vec(self.a * x, self.b * x, f'({self.tag}*{x})') + + def __imul__(self, x: float): + a, b = self.a, self.b + a *= x + b *= x + tag = f'({self.tag}*={x})' + return Vec(a, b, tag) + + def __mod__(self, x: float): + return Vec(self.a % x, self.b % x, f'({self.tag}%{x})') + + def __imod__(self, x: float): + a, b = self.a, self.b + a %= x + b %= x + tag = f'({self.tag}%={x})' + return Vec(a, b, tag) + + def __divmod__(self, x: float): + raise ArithmeticError('no divmod') + # return self.a / x, self.a % x + + def __pow__(self, x: float): + return Vec(self.a ** x, self.b ** x, f'({self.tag}**{x})') + + def __ipow__(self, x: float): + a, b = self.a, self.b + a **= x + b **= x + tag = f'({self.tag}**={x})' + return Vec(a, b, tag) + + def __neg__(self): + return Vec(-self.a, -self.b, f'(-{self.tag})') + + def __pos__(self): + return Vec(self.a, self.b, f'(+{self.tag})') + + def __abs__(self): + import math + return math.hypot(self.a, self.b) + + def __bool__(self): + return bool(self.a) and bool(self.b) + + def __invert__(self): + return Vec(-self.a - 1, -self.b - 1, f'(~{self.tag})') + + def __lshift__(self, x: int): + y = 1 << x + return Vec(self.a * y, self.b * y, f'({self.tag}<<{x})') + + def __ilshift__(self, x: int): + a, b = self.a, self.b + y = 1 << x + a *= y + b *= y + tag = f'({self.tag}<<={x})' + return Vec(a, b, tag) + + def __rshift__(self, x: int): + y = 1 << x + return Vec(self.a / y, self.b / y, f'({self.tag}>>{x})') + + def __irshift__(self, x: int): + a, b = self.a, self.b + y = 1 << x + a /= y + b /= y + tag = f'({self.tag}>>={x})' + return Vec(a, b, tag) + + def __and__(self, x: int): + return Vec(int(self.a) & x, int(self.b) & x, f'({self.tag}&{x})') + + def __iand__(self, x: int): + a, b = self.a, self.b + a = int(self.a) & x + b = int(self.b) & x + tag = f'({self.tag}&={x})' + return Vec(a, b, tag) + + def __or__(self, x: int): + return Vec(int(self.a) | x, int(self.b) | x, f'({self.tag}|{x})') + + def __ior__(self, x: int): + a, b = self.a, self.b + a = int(self.a) | x + b = int(self.b) | x + tag = f'({self.tag}|={x})' + return Vec(a, b, tag) + + def __xor__(self, x: int): + return Vec(int(self.a) ^ x, int(self.b) ^ x, f'({self.tag}^{x})') + + def __ixor__(self, x: int): + a, b = self.a, self.b + a = int(self.a) ^ x + b = int(self.b) ^ x + tag = f'({self.tag}^={x})' + return Vec(a, b, tag) + + #def __int__(self): + # return int(self.b) + + #def __float__(self): + # return self.b + + #def __index__(self): + # return int(self.a) + + def __floordiv__(self, x: float): + return Vec(self.a // x, self.b // x, f'({self.tag}//{x})') + + def __ifloordiv__(self, x: float): + a, b = self.a, self.b + a //= x + b //= x + tag = f'({self.tag}//={x})' + return Vec(a, b, tag) + + def __truediv__(self, x: float): + return Vec(self.a / x, self.b / x, f'({self.tag}/{x})') + + def __itruediv__(self, x: float): + a, b = self.a, self.b + a /= x + b /= x + tag = f'({self.tag}/={x})' + return Vec(a, b, tag) + + def __matmul__(self, other: Vec): + return (self.a * other.a) + (self.b * other.b) + + def __imatmul__(self, x: float): + a, b = self.a, self.b + a *= x + b *= x + tag = f'({self.tag}@={x})' + return Vec(a, b, tag) + + def __len__(self): + return len(self.tag) + + def __getitem__(self, idx: int): + if idx == 0: + return self.a + elif idx == 1: + return self.b + elif idx == 11: + return self.a + self.b + else: + raise KeyError('bad vec key ' + str(idx)) + + def __contains__(self, val: float): + return self.a == val or self.b == val + + def __contains__(self, val: str): + return self.tag == val + + def __hash__(self): + return int(self.a + self.b) + + def __call__(self, a: float = 1.11, b: float = 2.22): + return self.foo(a=a, b=b) + + def __call__(self, x: str): + return (self.a, self.b, x) + + def __call__(self): + return Vec(self.a, self.b, f'({self.tag}())') + + def __iter__(self): + for c in self.tag: + yield c + + def __eq__(self, other: Vec): + return self.a == other.a and self.b == other.b + + def __eq__(self, x: float): + return self.a == x and self.b == x + + def __ne__(self, other: Vec): + return self.a != other.a or self.b != other.b + + def __lt__(self, other: Vec): + return abs(self) < abs(other) + + def __le__(self, other: Vec): + return abs(self) <= abs(other) + + def __gt__(self, other: Vec): + return abs(self) > abs(other) + + def __ge__(self, other: Vec): + return abs(self) >= abs(other) + + def __del__(self): + Vec.d += 1 + + def nd(): + return Vec.d + +def reset(): + Vec.n = 0 + +def par_sum(n: int): + m = 0 + @par(num_threads=4) + for i in range(n): + m += 3*i + 7 + return m diff --git a/test/python/pyext.py b/test/python/pyext.py new file mode 100644 index 00000000..b9f19599 --- /dev/null +++ b/test/python/pyext.py @@ -0,0 +1,417 @@ +import myext as m +import myext2 as m2 + +def equal(v, a, b, tag): + ok = (v.a == a and v.b == b and v.tag == tag) + if not ok: + print('GOT:', v.a, v.b, v.tag) + print('EXP:', a, b, tag) + return ok + +saw_fun = False +saw_set = False +saw_foo = False + +def test_codon_extensions(m): + m.reset() + + # functions # + ############# + + global saw_fun + + if hasattr(m, 'f1'): + assert m.f1(2.2, 3.3) == (2.2, 3.3) + assert m.f1(2.2, 3.3) == (2.2, 3.3) + assert m.f1(3.3) == (3.3, 2.22) + assert m.f1() == (1.11, 2.22) + assert m.f1(a=2.2, b=3.3) == (2.2, 3.3) + assert m.f1(2.2, b=3.3) == (2.2, 3.3) + assert m.f1(b=3.3, a=2.2) == (2.2, 3.3) + assert m.f1(a=2.2) == (2.2, 2.22) + assert m.f1(b=3.3) == (1.11, 3.3) + + try: + m.f1(1.1, 2.2, 3.3) + except: + pass + else: + assert False + + try: + m.f1(z=1) + except: + pass + else: + assert False + + assert m.f2() == ({1: 'one'}, {2}, [3]) + + try: + m.f2(1) + except: + pass + else: + assert False + + try: + m.f2(z=1, y=5) + except: + pass + else: + assert False + + assert m.f3(42) == 84 + assert m.f3(1.5) == 3.0 + assert m.f3('x') == 'xx' + + try: + m.f3(1, 2) + except: + pass + else: + assert False + + try: + m.f3(a=1, b=2) + except: + pass + else: + assert False + + assert m.f4() == ['f4()'] + assert m.f4(2.2, 3.3) == (2.2, 3.3) + assert m.f4(3.3) == (3.3, 2.22) + assert m.f4(a=2.2, b=3.3) == (2.2, 3.3) + assert m.f4(2.2, b=3.3) == (2.2, 3.3) + assert m.f4(b=3.3, a=2.2) == (2.2, 3.3) + assert m.f4(a=2.2) == (2.2, 2.22) + assert m.f4(b=3.3) == (1.11, 3.3) + assert m.f4('foo') == ('foo', 'foo') + assert m.f4({1}) == {1} + assert m.f5() is None + assert equal(m.f6(1.9, 't'), 1.9, 1.9, 't') + + saw_fun = True + + # constructors # + ################ + + x = m.Vec(3.14, 4.2, 'x') + y = m.Vec(100, 1000, tag='y') + z = m.Vec(b=2.2, a=1.1) + s = m.Vec(10) + t = m.Vec(b=11) + r = m.Vec(3, 4) + + assert equal(x, 3.14, 4.2, 'x') + assert equal(y, 100, 1000, 'y') + assert equal(z, 1.1, 2.2, 'v0') + assert equal(s, 10, 0.0, 'v1') + assert equal(t, 0.0, 11, 'v2') + + try: + m.Vec(tag=10, a=1, b=2) + except: + pass + else: + assert False + + # to-str # + ########## + + assert str(x) == 'x: <3.14, 4.2>' + assert repr(x) == "Vec(3.14, 4.2, 'x')" + + # methods # + ########### + + assert x.foo(2.2, 3.3) == (3.14, 2.2, 3.3) + assert y.foo(3.3) == (100, 3.3, 2.22) + assert z.foo() == (1.1, 1.11, 2.22) + assert x.foo(a=2.2, b=3.3) == (3.14, 2.2, 3.3) + assert x.foo(2.2, b=3.3) == (3.14, 2.2, 3.3) + assert x.foo(b=3.3, a=2.2) == (3.14, 2.2, 3.3) + assert x.foo(a=2.2) == (3.14, 2.2, 2.22) + assert x.foo(b=3.3) == (3.14, 1.11, 3.3) + + try: + x.foo(1, a=1) + except: + pass + else: + assert False + + try: + x.foo(1, 2, b=2) + except: + pass + else: + assert False + + try: + x.foo(1, z=2) + except: + pass + else: + assert False + + assert equal(x.bar(), 3.14, 4.2, 'x') + assert equal(y.bar(), 100, 1000, 'y') + assert equal(z.bar(), 1.1, 2.2, 'v0') + assert equal(s.bar(), 10, 0.0, 'v1') + assert equal(t.bar(), 0.0, 11, 'v2') + + try: + x.bar(1) + except: + pass + else: + assert False + + try: + x.bar(z=1) + except: + pass + else: + assert False + + assert m.Vec.baz(2.2, 3.3) == (2.2, 3.3) + assert x.baz(2.2, 3.3) == (2.2, 3.3) + assert m.Vec.baz(3.3) == (3.3, 2.22) + assert m.Vec.baz() == (1.11, 2.22) + assert m.Vec.baz(a=2.2, b=3.3) == (2.2, 3.3) + assert m.Vec.baz(2.2, b=3.3) == (2.2, 3.3) + assert m.Vec.baz(b=3.3, a=2.2) == (2.2, 3.3) + assert m.Vec.baz(a=2.2) == (2.2, 2.22) + assert m.Vec.baz(b=3.3) == (1.11, 3.3) + + try: + m.Vec.baz(1, a=1) + except: + pass + else: + assert False + + try: + m.Vec.baz(1, 2, b=2) + except: + pass + else: + assert False + + assert m.Vec.nop() == 'nop' + assert x.nop() == 'nop' + assert y.c == 1100 + + # fields # + ########## + + if hasattr(t, '__setitem__'): + t.a = 99 + assert equal(t, 99, 11, 'v2') + t.tag = 't' + assert equal(t, 99, 11, 't') + + # magics # + ########## + + assert equal(+y, 100, 1000, '(+y)') + assert equal(-y, -100, -1000, '(-y)') + assert equal(~y, -101, -1001, '(~y)') + assert abs(r) == 5.0 + assert bool(y) + assert not bool(m.Vec()) + assert len(x) == 1 + assert len(x + y) == 5 + assert hash(y) == 1100 + + assert equal(x + y, 103.14, 1004.2, '(x+y)') + try: + x + 'x' + except: + pass + else: + assert False + assert equal(x + y + y, 203.14, 2004.2, '((x+y)+y)') + assert equal(y + 50.5, 150.5, 1050.5, '(y+50.5)') + assert equal(y + 50, 150, 1050, '(y++50)') + # assert equal(50.5 + y, 150.5, 1050.5, '(y+50.5)') # support for r-magics? + assert equal(y - x, 96.86, 995.8, '(y-x)') + assert equal(y * 3.5, 350.0, 3500.0, '(y*3.5)') + assert equal(y // 3, 33, 333, '(y//3)') + assert equal(y / 2.5, 40.0, 400.0, '(y/2.5)') + try: + divmod(y, 1) + except ArithmeticError as e: + assert str(e) == 'no divmod' + else: + assert False + assert equal(y % 7, 2, 6, '(y%7)') + assert equal(y ** 2, 10000, 1000000, '(y**2)') + assert equal(y << 1, 200, 2000, '(y<<1)') + assert equal(y >> 2, 25, 250, '(y>>2)') + assert equal(y & 77, 68, 72, '(y&77)') + assert equal(y | 77, 109, 1005, '(y|77)') + assert equal(y ^ 77, 41, 933, '(y^77)') + assert y @ r == 4300 + + def dup(v): + return m.Vec(v.a, v.b, v.tag + '1') + + y1 = dup(y) + y1 += x + assert equal(y1, 103.14, 1004.2, '(y1+=x)') + + y1 = dup(y) + y1 += 1.5 + assert equal(y1, 101.5, 1001.5, '(y1+=1.5)') + + y1 = dup(y) + y1 -= x + assert equal(y1, 96.86, 995.8, '(y1-=x)') + + y1 = dup(y) + y1 *= 3.5 + assert equal(y1, 350.0, 3500.0, '(y1*=3.5)') + + y1 = dup(y) + y1 //= 3 + assert equal(y1, 33, 333, '(y1//=3)') + + y1 = dup(y) + y1 /= 2.5 + assert equal(y1, 40.0, 400.0, '(y1/=2.5)') + + y1 = dup(y) + y1 %= 7 + assert equal(y1, 2, 6, '(y1%=7)') + + y1 = dup(y) + y1 **= 2 + assert equal(y1, 10000, 1000000, '(y1**=2)') + + y1 = dup(y) + y1 <<= 1 + assert equal(y1, 200, 2000, '(y1<<=1)') + + y1 = dup(y) + y1 >>= 2 + assert equal(y1, 25, 250, '(y1>>=2)') + + y1 = dup(y) + y1 &= 77 + assert equal(y1, 68, 72, '(y1&=77)') + + y1 = dup(y) + y1 |= 77 + assert equal(y1, 109, 1005, '(y1|=77)') + + y1 = dup(y) + y1 ^= 77 + assert equal(y1, 41, 933, '(y1^=77)') + + y1 = dup(y) + y1 @= 3.5 + assert equal(y1, 350.0, 3500.0, '(y1@=3.5)') + + assert equal(y(), 100, 1000, '(y())') + assert x(2.2, 3.3) == (3.14, 2.2, 3.3) + assert y(3.3) == (100, 3.3, 2.22) + assert x(a=2.2, b=3.3) == (3.14, 2.2, 3.3) + assert x(2.2, b=3.3) == (3.14, 2.2, 3.3) + assert x(b=3.3, a=2.2) == (3.14, 2.2, 3.3) + assert x(a=2.2) == (3.14, 2.2, 2.22) + assert x(b=3.3) == (3.14, 1.11, 3.3) + assert y('foo') == (100.0, 1000.0, 'foo') + + assert x == x + assert x != y + assert r == m.Vec(3, 4, '?') + assert x < y + assert y > x + assert x <= y + assert y >= x + assert y <= y + assert x >= x + + assert list(iter(x)) == ['x'] + assert list(iter(x+y+y)) == list('((x+y)+y)') + + assert 100 in y + assert 1000 in y + assert 100.5 not in y + assert 'y' in y + assert 'x' not in y + + assert y[0] == 100 + assert y[1] == 1000 + assert y[11] == 1100 + try: + y[-1] + except KeyError as e: + assert str(e) == "'bad vec key -1'" + else: + assert False + + global saw_set + if hasattr(y, '__setitem__'): + y[0] = 99.9 + assert equal(y, 99.9, 1000, 'y') + y[1] = -42.6 + assert equal(y, 99.9, -42.6, 'y') + y[11] = 7.7 + assert equal(y, 7.7, 7.7, 'y') + try: + y[2] = 1.2 + except KeyError as e: + assert str(e) == "'bad vec key 2 with val 1.2'" + else: + assert False + + del y[1] + assert equal(y, 7.7, 0.0, 'y') + + saw_set = True + + assert m.Vec.nd() > 0 + + # tuple classes # + ################# + global saw_foo + if hasattr(m, 'Foo'): + x = m.Foo(list('hello')) + assert x.a == list('hello') + assert x.x == {s: i for i, s in enumerate('hello')} + assert x.hello() == 'x' + + try: + x.a = ['bye'] + except AttributeError: + pass + else: + assert False + + assert int(x) == 42 + assert float(x) == 3.14 + assert x.__index__() == 99 + saw_foo = True + + # Codon-specific # + ################## + def par_sum_check(n): + m = 0 + for i in range(n): + m += 3*i + 7 + return m + + for n in (0, 1, 10, 33, 999, 1237): + assert m.par_sum(n) == par_sum_check(n) + +for _ in range(3000): + test_codon_extensions(m) + test_codon_extensions(m2) + +assert saw_fun +assert saw_set +assert saw_foo diff --git a/test/python/setup.py b/test/python/setup.py new file mode 100644 index 00000000..ff98e155 --- /dev/null +++ b/test/python/setup.py @@ -0,0 +1,103 @@ +import os +import sys +import shutil + +from pathlib import Path +from setuptools import setup, Extension +from setuptools.command.build_ext import build_ext + + +codon_path = os.environ.get("CODON_DIR") +if not codon_path: + c = shutil.which("codon") + if c: + codon_path = Path(c).parent / ".." +else: + codon_path = Path(codon_path) +for path in [ + os.path.expanduser("~") + "/.codon", + os.getcwd() + "/..", +]: + path = Path(path) + if not codon_path and path.exists(): + codon_path = path + break + +if ( + not codon_path + or not (codon_path / "include" / "codon").exists() + or not (codon_path / "lib" / "codon").exists() +): + print( + "Cannot find Codon.", + 'Please either install Codon (/bin/bash -c "$(curl -fsSL https://exaloop.io/install.sh)"),', + "or set CODON_DIR if Codon is not in PATH or installed in ~/.codon", + file=sys.stderr, + ) + sys.exit(1) +codon_path = codon_path.resolve() +print("Codon: " + str(codon_path)) + + +class CodonExtension(Extension): + def __init__(self, name, source): + self.source = source + super().__init__(name, sources=[], language='c') + +class BuildCodonExt(build_ext): + def build_extensions(self): + pass + + def run(self): + inplace, self.inplace = self.inplace, False + super().run() + for ext in self.extensions: + self.build_codon(ext) + if inplace: + self.copy_extensions_to_source() + + def build_codon(self, ext): + extension_path = Path(self.get_ext_fullpath(ext.name)) + build_dir = Path(self.build_temp) + os.makedirs(build_dir, exist_ok=True) + os.makedirs(extension_path.parent.absolute(), exist_ok=True) + + optimization = '-debug' if self.debug else '-release' + self.spawn([ + str(codon_path / "bin" / "codon"), 'build', optimization, "--relocation-model=pic", + '-pyext', '-o', str(extension_path) + ".o", '-module', ext.name, ext.source]) + + print('-->', extension_path) + ext.runtime_library_dirs = [str(codon_path / "lib" / "codon")] + self.compiler.link_shared_object( + [str(extension_path) + ".o"], + str(extension_path), + libraries=["codonrt"], + library_dirs=ext.runtime_library_dirs, + runtime_library_dirs=ext.runtime_library_dirs, + extra_preargs=['-Wl,-rpath,@loader_path'], + # export_symbols=self.get_export_symbols(ext), + debug=self.debug, + build_temp=self.build_temp, + ) + self.distribution.codon_lib = extension_path + +setup( + name='myext', + version='0.1', + packages=['myext'], + ext_modules=[ + CodonExtension('myext', 'myextension.codon'), + ], + cmdclass={'build_ext': BuildCodonExt} +) + +setup( + name='myext2', + version='0.1', + packages=['myext2'], + ext_modules=[ + CodonExtension('myext2', 'myextension2.codon'), + ], + cmdclass={'build_ext': BuildCodonExt} +) diff --git a/test/stdlib/cmath_test.codon b/test/stdlib/cmath_test.codon index 1d2da3bf..3a501d88 100644 --- a/test/stdlib/cmath_test.codon +++ b/test/stdlib/cmath_test.codon @@ -794,3 +794,36 @@ def test_cmath_testcases(): test_cmath_testcases() + + +@test +def test_complex64(): + c64 = complex64 + z = c64(.5 + .5j) + assert c64() == z * 0 + assert z + 1 == c64(1.5, .5) + assert bool(z) == True + assert bool(0 * z) == False + assert +z == z + assert -z == c64(-.5 - .5j) + assert abs(z) == float32(0.7071067811865476) + assert z + 1 == c64(1.5 + .5j) + assert 1j + z == c64(.5 + 1.5j) + assert z * 2 == c64(1 + 1j) + assert 2j * z == c64(-1 + 1j) + assert z / .5 == c64(1 + 1j) + assert 1j / z == c64(1 + 1j) + assert z ** 2 == c64(.5j) + y = 1j ** z + assert math.isclose(float(y.real), 0.32239694194483454) + assert math.isclose(float(y.imag), 0.32239694194483454) + assert z != -z + assert z != 0 + assert z.real == float32(.5) + assert (z + 1j).imag == float32(1.5) + assert z.conjugate() == c64(.5 - .5j) + assert z.__copy__() == z + assert hash(z) + assert c64(complex(z)) == z + +test_complex64()