From 82500d87090d4d530951c52d0a149626d7a3525e Mon Sep 17 00:00:00 2001 From: ahf <41027529+ahfriedman@users.noreply.github.com> Date: Thu, 2 May 2024 07:41:36 -0400 Subject: [PATCH] Added generic programs, made functions that return Unit not have to say return, and added test cases --- programs/cancelable/BranchCancel1.bismuth | 2 +- programs/cursed/README.md | 26 +++ programs/cursed/cursed-no-ret.bismuth | 222 +++++++++++++++++++++ programs/cursed/cursed.bismuth | 227 ++++++++++++++++++++++ programs/generics/GenericProg.bismuth | 13 ++ src/bismuthc.cpp | 1 + src/cli/Compile.cpp | 21 +- src/cli/include/Compile.h | 3 +- src/codegen/CodegenVisitor.cpp | 3 - src/codegen/DynArrayVisitor.cpp | 12 ++ src/codegen/include/DynArrayVisitor.h | 29 +++ src/semantic/SemanticVisitor.cpp | 61 ++++-- src/semantic/include/SemanticVisitor.h | 26 +-- src/symbol/Protocol.cpp | 83 ++++++++ src/symbol/Type.cpp | 55 +++++- src/symbol/include/Protocol.h | 21 +- src/symbol/include/Symbol.h | 3 - src/symbol/include/Type.h | 12 +- test/codegen/codegen_tests.cpp | 22 +++ test/semantic/conditional_tests.cpp | 4 +- test/semantic/program_tests.cpp | 110 +++++------ test/semantic/semantic_tests.cpp | 2 +- 22 files changed, 814 insertions(+), 144 deletions(-) create mode 100644 programs/cursed/README.md create mode 100644 programs/cursed/cursed-no-ret.bismuth create mode 100644 programs/cursed/cursed.bismuth create mode 100644 programs/generics/GenericProg.bismuth create mode 100644 src/codegen/DynArrayVisitor.cpp create mode 100644 src/codegen/include/DynArrayVisitor.h diff --git a/programs/cancelable/BranchCancel1.bismuth b/programs/cancelable/BranchCancel1.bismuth index 24a8303..b8d35ec 100644 --- a/programs/cancelable/BranchCancel1.bismuth +++ b/programs/cancelable/BranchCancel1.bismuth @@ -61,7 +61,7 @@ define c2 :: c : Channel< c.send(2); var b := c.recv(); - cancel(c); # FIXME: when missing, errors are misleading! + cancel(c); # FIXME: when missing, errors are misleading! -> Part of it is that we can visit branches in an ext choice out of order relative to how they are written in code (so if you get an error on branch 2 only, you may think branch 1 is fine) # match a # | Unit u => printf("WRONG! Got unit!\n"); diff --git a/programs/cursed/README.md b/programs/cursed/README.md new file mode 100644 index 0000000..b1aea57 --- /dev/null +++ b/programs/cursed/README.md @@ -0,0 +1,26 @@ +# "Cursed" Tests + +These programs test a few things via cursed implementations of a basic program that, given an integer i, prints: +* Never if i is zero, +* Once if i is one, +* Twice if i is two, +* Thrice if i is three, +* or ith otherwise. + +These programs can be thought of testing a few things, however, specifically, they introduce tests for **binary operators** that were previously untested. + +They also test that (aside from the name of the file/module), `return` can be removed from functions that return a `Unit` without any impact on the generated code or semantics (TODO: test that this is the case even with branches where we may rely on a branch ending in return to help perform type checking). + +Finally, they test that the following function type checks. +```bismuth +define func sil_sel_array(int i) : int { + var arr := [0, 1, 2, 3]; + select { + i < 4: match arr[i] + | int s => return s; + | Unit u => return -1; + true: return i; + } +} +``` +This ensures against an error that used to be raised that would report `(Unit + Var)` is not able to act as an `int` for the match statement. \ No newline at end of file diff --git a/programs/cursed/cursed-no-ret.bismuth b/programs/cursed/cursed-no-ret.bismuth new file mode 100644 index 0000000..396d515 --- /dev/null +++ b/programs/cursed/cursed-no-ret.bismuth @@ -0,0 +1,222 @@ +extern func printf(str s, ...) : int; # Import printf + +define func more_cursed(int i) : Unit { + var r0 := i & ~3; #0b11; // Get the number w/o the last 2 bits + var r1 := r0 | (r0 >> 1); #// Group each 2 bits + var r2 := r1 | (r1 >> 2); #// Groups of 4 bits + var r3 := r2 | (r2 >> 4); #// Groups of 8 bits + var r4 := r3 | (r3 >> 8); #// Groups of 16 bits + var r5 := r4 | (r4 >> 16); #// Groups of 32 bits + var r6 := (r5 & 0b1) << 2; #// Isolate the last bit, and shift it such that the number is 4 + + # // r5 will be all 1s IFF i > 3. So ~r5 & i ensures we combine r6 with i such that: + # // r7 = i > 3 ? 4 : i; + var r7 := r6 | (~r5 & i); + + var arr := [ + (int i) : Unit { printf("Never"); }, + (int i) : Unit { printf("Once"); }, + (int i) : Unit { printf("Twice"); }, + (int i) : Unit { printf("Thrice"); }, + (int i) : Unit { printf("%uth", i);} + ]; + + match arr[r7] + | Unit u => printf("IMPOSSIBLE"); + | (int -> Unit) fn => fn(i); +} + +define func less_cursed(int i) : Unit { + var idx := i; + if i > 3 { + idx := 4 + } + + var arr := [ + (int i) : Unit { printf("Never"); }, + (int i) : Unit { printf("Once"); }, + (int i) : Unit { printf("Twice"); }, + (int i) : Unit { printf("Thrice"); }, + (int i) : Unit { printf("%uth", i);} + ]; + + match arr[idx] + | Unit u => printf("IMPOSSIBLE"); + | (int -> Unit) fn => fn(i); +} + +define func sel_array(int i) : Unit { + var arr := ["Never", "Once", "Twice", "Thrice"]; + select { + i < 4: match arr[i] + | str s => printf(s); + | Unit u => printf("IMPOSSIBLE"); + true: printf("%uth", i); + } +} + +# Define the main program in the system. +# - c is the name of the channel used by the program +# - -int indicates that we have to send an int over the channel. +define program :: c : Channel<-int> { + printf("Hello, World!\n"); + c.send(0) + + for(int i := 0; i < 500000; i := i + 1) { + /* + 500000 + real = [1.496s, 1.245s, 1.235s, 1.243s] + user = [0.319s, 0.265s, 0.190s, 0.210s] + sys = [0.859s, 0.668s, 0.777s, 0.720s] + + 500000 % 5 + real = [1.057s, 1.071s, 1.116s, 1.022s] + user = [0.212s, 0.229s, 0.172s, 0.191s] + sys = [0.685s, 0.668s, 0.730s, 0.735s] + */ + + # less_cursed(i % 5); + + /* + 500000 + real = [1.200s, 1.201s, 1.258s, 1.312s] + user = [0.233s, 0.231s, 0.229s, 0.260s] + sys = [0.768s, 0.727s, 0.778s, 0.676s] + + 500000 % 5 + real = [1.076s, 1.054s, 1.075s, 1.126s] + user = [0.212s, 0.221s, 0.191s, 0.265s] + sys = [0.711s, 0.702s, 0.760s, 0.675s] + */ + + # more_cursed(i % 5); + + /* + 500000 + real = [0.042s, 0.044s, 0.016s, 0.041s] + user = [0.034s, 0.035s, 0.016s, 0.037s] + sys = [0.008s, 0.008s, 0.000s, 0.004s] + + 500000 % 5 + + real = [0.035s, 0.047s, 0.049s, 0.053s] + user = [0.035s, 0.046s, 0.041s, 0.052s] + sys = [0.000s, 0.001s, 0.009s, 0.001s] + */ + + # sil_more_cursed(i % 5); + + + /* + 500000 + real = [1.389s, 1.100s, 1.209s, 1.532s] + user = [0.184s, 0.211s, 0.223s, 0.270s] + sys = [0.803s, 0.711s, 0.687s, 0.928s] + + 500000 % 5 + real = [1.061s, 1.099s, 1.014s, 1.033s] + user = [0.186s, 0.175s, 0.184s, 0.219s] + sys = [0.694s, 0.729s, 0.689s, 0.702s] + */ + + # sel_array(i) + + /* + 500000 + real = [0.006s, 0.020s, 0.017s, 0.018s] + user = [0.005s, 0.015s, 0.013s, 0.008s] + sys = [0.000s, 0.004s, 0.004s, 0.010s] + + 500000 % 5 + real = [0.033s, 0.033s, 0.014s, 0.012s] + user = [0.025s, 0.028s, 0.014s, 0.012s] + sys = [0.009s, 0.005s, 0.000s, 0.001s] + */ + + # sil_sel_array(i % 5); + # printf("\n"); + + /* + 500000 + real = [0.016s, 0.020s, 0.011s, 0.017s] + user = [0.012s, 0.016s, 0.004s, 0.013s] + sys = [0.004s, 0.005s, 0.008s, 0.004s] + + 500000 % 5 + real = [0.009s, 0.020s, 0.004s, 0.018s] + user = [0.009s, 0.014s, 0.000s, 0.018s] + sys = [0.001s, 0.006s, 0.004s, 0.001s] + */ + + + sil_sel_each(i % 5); + } + + printf("%d\n", ~0); +} + + +define func sil_more_cursed(int i) : int { + var r0 := i & ~3; #0b11; // Get the number w/o the last 2 bits + var r1 := r0 | (r0 >> 1); #// Group each 2 bits + var r2 := r1 | (r1 >> 2); #// Groups of 4 bits + var r3 := r2 | (r2 >> 4); #// Groups of 8 bits + var r4 := r3 | (r3 >> 8); #// Groups of 16 bits + var r5 := r4 | (r4 >> 16); #// Groups of 32 bits + var r6 := (r5 & 0b1) << 2; #// Isolate the last bit, and shift it such that the number is 4 + + # // r5 will be all 1s IFF i > 3. So ~r5 & i ensures we combine r6 with i such that: + # // r7 = i > 3 ? 4 : i; + var r7 := r6 | (~r5 & i); + + var arr := [ + (int i) : int { return 0; }, + (int i) : int { return 1; }, + (int i) : int { return 2; }, + (int i) : int { return 3; }, + (int i) : int { return i; } + ]; + + match arr[r7] + | Unit u => return -1; + | (int -> int) fn => return fn(i); +} + +define func sil_less_cursed(int i) : int { + var idx := i; + if i > 3 { + idx := 4 + } + + var arr := [ + (int i) : int { return 0; }, + (int i) : int { return 1; }, + (int i) : int { return 2; }, + (int i) : int { return 3; }, + (int i) : int { return i; } + ]; + + match arr[idx] + | Unit u => return -1; + | (int -> int) fn => return fn(i); +} + +define func sil_sel_array(int i) : int { + var arr := [0, 1, 2, 3]; + select { + i < 4: match arr[i] + | int s => return s; + | Unit u => return -1; + true: return i; + } +} + +define func sil_sel_each(int i) : int { + select { + i == 0: return 0; + i == 1: return 1; + i == 2: return 2; + i == 3: return 3; + true: return i; + } +} \ No newline at end of file diff --git a/programs/cursed/cursed.bismuth b/programs/cursed/cursed.bismuth new file mode 100644 index 0000000..ab9f5fa --- /dev/null +++ b/programs/cursed/cursed.bismuth @@ -0,0 +1,227 @@ +extern func printf(str s, ...) : int; # Import printf + +define func more_cursed(int i) : Unit { + var r0 := i & ~3; #0b11; // Get the number w/o the last 2 bits + var r1 := r0 | (r0 >> 1); #// Group each 2 bits + var r2 := r1 | (r1 >> 2); #// Groups of 4 bits + var r3 := r2 | (r2 >> 4); #// Groups of 8 bits + var r4 := r3 | (r3 >> 8); #// Groups of 16 bits + var r5 := r4 | (r4 >> 16); #// Groups of 32 bits + var r6 := (r5 & 0b1) << 2; #// Isolate the last bit, and shift it such that the number is 4 + + # // r5 will be all 1s IFF i > 3. So ~r5 & i ensures we combine r6 with i such that: + # // r7 = i > 3 ? 4 : i; + var r7 := r6 | (~r5 & i); + + var arr := [ + (int i) : Unit { printf("Never"); return; }, + (int i) : Unit { printf("Once"); return; }, + (int i) : Unit { printf("Twice"); return; }, + (int i) : Unit { printf("Thrice"); return; }, + (int i) : Unit { printf("%uth", i); return; } + ]; + + match arr[r7] + | Unit u => printf("IMPOSSIBLE"); + | (int -> Unit) fn => fn(i); + return; +} + +define func less_cursed(int i) : Unit { + var idx := i; + if i > 3 { + idx := 4 + } + + var arr := [ + (int i) : Unit { printf("Never"); return; }, + (int i) : Unit { printf("Once"); return; }, + (int i) : Unit { printf("Twice"); return; }, + (int i) : Unit { printf("Thrice"); return; }, + (int i) : Unit { printf("%uth", i); return; } + ]; + + match arr[idx] + | Unit u => printf("IMPOSSIBLE"); + | (int -> Unit) fn => fn(i); + return; +} + +# TODO: ALLOW ["Never", "Once", "Twice", "Thrice"][i] +define func sel_array(int i) : Unit { + var arr := ["Never", "Once", "Twice", "Thrice"]; + select { + i < 4: match arr[i] + | str s => printf(s); + | Unit u => printf("IMPOSSIBLE"); + true: printf("%uth", i); + } + + return; +} + +# Define the main program in the system. +# - c is the name of the channel used by the program +# - -int indicates that we have to send an int over the channel. +define program :: c : Channel<-int> { + printf("Hello, World!\n"); + c.send(0) + + for(int i := 0; i < 500000; i := i + 1) { + /* + 500000 + real = [1.496s, 1.245s, 1.235s, 1.243s] + user = [0.319s, 0.265s, 0.190s, 0.210s] + sys = [0.859s, 0.668s, 0.777s, 0.720s] + + 500000 % 5 + real = [1.057s, 1.071s, 1.116s, 1.022s] + user = [0.212s, 0.229s, 0.172s, 0.191s] + sys = [0.685s, 0.668s, 0.730s, 0.735s] + */ + + # less_cursed(i % 5); + + /* + 500000 + real = [1.200s, 1.201s, 1.258s, 1.312s] + user = [0.233s, 0.231s, 0.229s, 0.260s] + sys = [0.768s, 0.727s, 0.778s, 0.676s] + + 500000 % 5 + real = [1.076s, 1.054s, 1.075s, 1.126s] + user = [0.212s, 0.221s, 0.191s, 0.265s] + sys = [0.711s, 0.702s, 0.760s, 0.675s] + */ + + # more_cursed(i % 5); + + /* + 500000 + real = [0.042s, 0.044s, 0.016s, 0.041s] + user = [0.034s, 0.035s, 0.016s, 0.037s] + sys = [0.008s, 0.008s, 0.000s, 0.004s] + + 500000 % 5 + + real = [0.035s, 0.047s, 0.049s, 0.053s] + user = [0.035s, 0.046s, 0.041s, 0.052s] + sys = [0.000s, 0.001s, 0.009s, 0.001s] + */ + + # sil_more_cursed(i % 5); + + + /* + 500000 + real = [1.389s, 1.100s, 1.209s, 1.532s] + user = [0.184s, 0.211s, 0.223s, 0.270s] + sys = [0.803s, 0.711s, 0.687s, 0.928s] + + 500000 % 5 + real = [1.061s, 1.099s, 1.014s, 1.033s] + user = [0.186s, 0.175s, 0.184s, 0.219s] + sys = [0.694s, 0.729s, 0.689s, 0.702s] + */ + + # sel_array(i) + + /* + 500000 + real = [0.006s, 0.020s, 0.017s, 0.018s] + user = [0.005s, 0.015s, 0.013s, 0.008s] + sys = [0.000s, 0.004s, 0.004s, 0.010s] + + 500000 % 5 + real = [0.033s, 0.033s, 0.014s, 0.012s] + user = [0.025s, 0.028s, 0.014s, 0.012s] + sys = [0.009s, 0.005s, 0.000s, 0.001s] + */ + + # sil_sel_array(i % 5); + # printf("\n"); + + /* + 500000 + real = [0.016s, 0.020s, 0.011s, 0.017s] + user = [0.012s, 0.016s, 0.004s, 0.013s] + sys = [0.004s, 0.005s, 0.008s, 0.004s] + + 500000 % 5 + real = [0.009s, 0.020s, 0.004s, 0.018s] + user = [0.009s, 0.014s, 0.000s, 0.018s] + sys = [0.001s, 0.006s, 0.004s, 0.001s] + */ + + + sil_sel_each(i % 5); + } + + printf("%d\n", ~0); +} + + +define func sil_more_cursed(int i) : int { + var r0 := i & ~3; #0b11; // Get the number w/o the last 2 bits + var r1 := r0 | (r0 >> 1); #// Group each 2 bits + var r2 := r1 | (r1 >> 2); #// Groups of 4 bits + var r3 := r2 | (r2 >> 4); #// Groups of 8 bits + var r4 := r3 | (r3 >> 8); #// Groups of 16 bits + var r5 := r4 | (r4 >> 16); #// Groups of 32 bits + var r6 := (r5 & 0b1) << 2; #// Isolate the last bit, and shift it such that the number is 4 + + # // r5 will be all 1s IFF i > 3. So ~r5 & i ensures we combine r6 with i such that: + # // r7 = i > 3 ? 4 : i; + var r7 := r6 | (~r5 & i); + + var arr := [ + (int i) : int { return 0; }, + (int i) : int { return 1; }, + (int i) : int { return 2; }, + (int i) : int { return 3; }, + (int i) : int { return i; } + ]; + + match arr[r7] + | Unit u => return -1; + | (int -> int) fn => return fn(i); +} + +define func sil_less_cursed(int i) : int { + var idx := i; + if i > 3 { + idx := 4 + } + + var arr := [ + (int i) : int { return 0; }, + (int i) : int { return 1; }, + (int i) : int { return 2; }, + (int i) : int { return 3; }, + (int i) : int { return i; } + ]; + + match arr[idx] + | Unit u => return -1; + | (int -> int) fn => return fn(i); +} + +define func sil_sel_array(int i) : int { + var arr := [0, 1, 2, 3]; + select { + i < 4: match arr[i] + | int s => return s; + | Unit u => return -1; + true: return i; + } +} + +define func sil_sel_each(int i) : int { + select { + i == 0: return 0; + i == 1: return 1; + i == 2: return 2; + i == 3: return 3; + true: return i; + } +} \ No newline at end of file diff --git a/programs/generics/GenericProg.bismuth b/programs/generics/GenericProg.bismuth new file mode 100644 index 0000000..d1d2cfa --- /dev/null +++ b/programs/generics/GenericProg.bismuth @@ -0,0 +1,13 @@ +extern func printf(str s, ...); + +define foo :: c : Channel<+T;-T> { + T t := c.recv(); + c.send(t); +} + +define program :: c : Channel<-int> { + var c1 := exec foo; + c1.send(5); + printf("%u\n", c1.recv()); # should be 5 + c.send(0); +} \ No newline at end of file diff --git a/src/bismuthc.cpp b/src/bismuthc.cpp index 3e29fb9..a89ee33 100755 --- a/src/bismuthc.cpp +++ b/src/bismuthc.cpp @@ -111,6 +111,7 @@ Version: Pre-Alpha 1.3.4 @ )"""" << GIT_COMMIT_HASH ChangeLog ========= 1.3.4 - XXX: + - Return statements can be omitted in functions that return Unit. - Added Logical & Arithmetic Right Bit Shift, Left Bit Shift, Bit XOR/AND/OR - Added imports, basic name mangling, and generics/templates - Added u32, i64, and u64 diff --git a/src/cli/Compile.cpp b/src/cli/Compile.cpp index ddee9cb..2dce267 100644 --- a/src/cli/Compile.cpp +++ b/src/cli/Compile.cpp @@ -27,10 +27,8 @@ std::vector pathToIdentifierSteps(std::filesystem::path& relPath)// for(auto it : relPath) { - std::cout << "::" << it; parts.push_back(it); } - std::cout << std::endl; return parts; } @@ -372,7 +370,6 @@ void Stage_CodeGen(std::vectorgetIROut(); if (std::error_code *ec = std::get_if(&irOutOpt)) { @@ -471,15 +468,15 @@ int compile( CompileType compileWith) { - std::cout << "517 compile w/ " << std::endl; - std::cout << "\t inputs : " << std::endl; - std::cout << "\t outputFileName : " << outputFileName << std::endl; - std::cout << "\t demoMode : " << demoMode << std::endl; - std::cout << "\t isVerbose : " << isVerbose << std::endl; - std::cout << "\t toStringMode : " << toStringMode << std::endl; - std::cout << "\t printOutput : " << printOutput << std::endl; - std::cout << "\t noCode : " << noCode << std::endl; - std::cout << "\t compileWith : " << compileWith << std::endl; + // std::cout << "517 compile w/ " << std::endl; + // std::cout << "\t inputs : " << std::endl; + // std::cout << "\t outputFileName : " << outputFileName << std::endl; + // std::cout << "\t demoMode : " << demoMode << std::endl; + // std::cout << "\t isVerbose : " << isVerbose << std::endl; + // std::cout << "\t toStringMode : " << toStringMode << std::endl; + // std::cout << "\t printOutput : " << printOutput << std::endl; + // std::cout << "\t noCode : " << noCode << std::endl; + // std::cout << "\t compileWith : " << compileWith << std::endl; /****************************************************************** * Now that we have the input, we can perform the first stage: diff --git a/src/cli/include/Compile.h b/src/cli/include/Compile.h index 6e0c947..ed1595d 100644 --- a/src/cli/include/Compile.h +++ b/src/cli/include/Compile.h @@ -76,8 +76,7 @@ class CompilerInput { for(auto s : pathSteps) { ans += "::" + s; - } - std::cout << "80 PS " << ans << std::endl; + } return pathSteps; } virtual std::string getSourceName() { return inputStream->getSourceName(); } diff --git a/src/codegen/CodegenVisitor.cpp b/src/codegen/CodegenVisitor.cpp index 5470139..ddeeaa3 100644 --- a/src/codegen/CodegenVisitor.cpp +++ b/src/codegen/CodegenVisitor.cpp @@ -2295,14 +2295,11 @@ std::optional CodegenVisitor::visit(TProgramDefNode & n) if(!fn) fn = Function::Create( fnType, - // GlobalValue::PrivateLinkage, getLinkageType(n.getVisibility()), funcFullName, module ); - // prog->setName(fn.getName().str());// Note: NOT ALWAYS NEEDED -> Probably not needed - // Create basic block BasicBlock *bBlk = BasicBlock::Create(module->getContext(), "entry", fn); builder->SetInsertPoint(bBlk); diff --git a/src/codegen/DynArrayVisitor.cpp b/src/codegen/DynArrayVisitor.cpp new file mode 100644 index 0000000..955eda7 --- /dev/null +++ b/src/codegen/DynArrayVisitor.cpp @@ -0,0 +1,12 @@ +/** + * @file DynArrayVisitor.cpp + * @author Alex Friedman (ahf.dev) + * @brief + * @version 0.1 + * @date 2023-11-19 + * + * @copyright Copyright (c) 2023 + * + */ + +#include "DynArrayVisitor.h" \ No newline at end of file diff --git a/src/codegen/include/DynArrayVisitor.h b/src/codegen/include/DynArrayVisitor.h new file mode 100644 index 0000000..6bc8967 --- /dev/null +++ b/src/codegen/include/DynArrayVisitor.h @@ -0,0 +1,29 @@ +/** + * @file DynArrayVisitor.h + * @author Alex Friedman (ahf.dev) + * @brief + * @version 0.1 + * @date 2023-11-19 + * + * @copyright Copyright (c) 2023 + * + */ +#pragma once + +#include "CodegenUtils.h" +#include "BismuthErrorHandler.h" + + +class DynArrayVisitor : public CodegenModule +{ +// private: + // BismuthErrorHandler errorHandler; + + +public: + DynArrayVisitor(Module *m, DisplayMode mode, int f, BismuthErrorHandler e) : CodegenModule(m, mode, f, e) + { + // errorHandler = e; // TODO: REMOVE FROM DEEPCOPYVISITOR OR NOT? + } + +}; \ No newline at end of file diff --git a/src/semantic/SemanticVisitor.cpp b/src/semantic/SemanticVisitor.cpp index d62429a..7811a2e 100644 --- a/src/semantic/SemanticVisitor.cpp +++ b/src/semantic/SemanticVisitor.cpp @@ -117,14 +117,14 @@ std::variant, ErrorChain *> SemanticVisitor::visit // Note: re-applying template symbols happens in each visitor for now! if (BismuthParser::DefineProgramContext * progCtx = dynamic_cast(e)) { - std::variant progOpt = visitCtx(progCtx); + std::variant progOpt = visitCtx(progCtx); if (ErrorChain **e = std::get_if(&progOpt)) { return (*e)->addError(ctx->getStart(), "Failed to type check program"); } - defs.push_back(std::get(progOpt)); + defs.push_back(std::get(progOpt)); } else if (BismuthParser::DefineFunctionContext * fnCtx = dynamic_cast(e)) { @@ -223,7 +223,7 @@ std::optional SemanticVisitor::postCUVisitChecks(BismuthParser::Co // FIXME: DO SUBTYPING BETTER! if (!(TypeChannel(inv->getProtocol())).isSubtype(new TypeChannel(new ProtocolSequence(false, {new ProtocolSend(false, Types::DYN_INT)})))) { - errorHandler.addError(ctx->getStart(), "Program must recognize a channel of protocol -int, not " + inv->toString(toStringMode)); + errorHandler.addError(ctx->getStart(), "In demo mode, 'program' must recognize a channel of protocol -int, not " + inv->getProtocol()->toString(toStringMode)); } else { @@ -232,7 +232,7 @@ std::optional SemanticVisitor::postCUVisitChecks(BismuthParser::Co } else if(demoMode) { - errorHandler.addError(ctx->getStart(), "When compiling in demo, 'program :: * : Channel<-int>' (the entry point) is required"); + errorHandler.addError(ctx->getStart(), "When compiling in demo mode identifier 'program' must be defined as 'program :: * : Channel<-int>' (the entry point)"); } } else if(demoMode) @@ -414,7 +414,6 @@ SemanticVisitor::phasedVisit(BismuthParser::CompilationUnitContext *ctx, std::ve std::variant SemanticVisitor::visitCtx(BismuthParser::CallExprContext *ctx) { - std::cout << "253 " << ctx->getText() << std::endl; // Need RValue std::variant typeOpt = anyOpt2VarError(errorHandler, ctx->expr->accept(this)); if (ErrorChain **e = std::get_if(&typeOpt)) @@ -511,7 +510,6 @@ std::variant SemanticVisitor::visitCtx(BismuthP std::variant SemanticVisitor::visitCtx(BismuthParser::ExpressionStatementContext *ctx) { - std::cout << "350 " << ctx->getText() << std::endl; if(!dynamic_cast(ctx->expression())) return errorHandler.addError(ctx->getStart(), "Using an expression as statement in a manner that results in dead code."); @@ -532,7 +530,7 @@ std::variant SemanticVisitor::visitCtx(BismuthParser: * @param ctx The parser rule context * @return TProgramDefNode * if successful, ErrorChain * if error */ -std::variant SemanticVisitor::visitCtx(BismuthParser::DefineProgramContext *ctx) +std::variant SemanticVisitor::visitCtx(BismuthParser::DefineProgramContext *ctx) { std::variant symOpt = defineAndGetSymbolFor(ctx); @@ -543,12 +541,7 @@ std::variant SemanticVisitor::visitCtx(BismuthP DefinitionSymbol * defSym = std::get(symOpt); - // Symbol * sym = symScope.first; - // Scope * innerScope = symScope.second; - - - if (const TypeProgram *progType = dynamic_cast(defSym->getType())) - { + auto generateProgram = [this, ctx, defSym](const TypeProgram * progType) -> std::variant { std::string funcId = ctx->name->getText(); // Lookup the function in the current scope and prevent re-declarations @@ -565,7 +558,6 @@ std::variant SemanticVisitor::visitCtx(BismuthP std::variant blkOpt = this->safeVisitBlock(ctx->block(), false); if (ErrorChain **e = std::get_if(&blkOpt)) { - std::cout << "497" << std::endl; return (*e)->addError(ctx->getStart(), "Failed to safe visit block"); } @@ -580,11 +572,35 @@ std::variant SemanticVisitor::visitCtx(BismuthP stmgr->enterScope(orig); return new TProgramDefNode(defSym, channelSymbol, std::get(blkOpt), progType, ctx->getStart()); + }; + + if(const TypeTemplate * templateTy = dynamic_cast(defSym->getType())) + { + if(!templateTy->getValueType()) + return errorHandler.addCompilerError(ctx->getStart(), "template type does not have value type to template"); + + if(const TypeProgram * progTy = dynamic_cast(templateTy->getValueType().value())) + { + auto progNodeOpt = generateProgram(progTy); + if (ErrorChain **e = std::get_if(&progNodeOpt)) + { + return (*e); //->addError(ctx->getStart(), "Failed to safe visit block"); + } + + return new TDefineTemplateNode( + defSym, + templateTy, + std::get(progNodeOpt), + ctx->getStart() + ); + } } - else + else if (const TypeProgram *progType = dynamic_cast(defSym->getType())) { - return errorHandler.addError(ctx->getStart(), "Cannot execute " + defSym->toString()); + return generateProgram(progType); } + + return errorHandler.addError(ctx->getStart(), "Cannot execute " + defSym->toString()); } std::variant SemanticVisitor::visitCtx(BismuthParser::DefineFunctionContext *ctx) @@ -622,7 +638,6 @@ std::variant SemanticVisitor::visitCtx(BismuthPa ); return templateNode; - } return lam; @@ -1645,6 +1660,7 @@ std::variant SemanticVisitor::visitCtx(BismuthPars } // Note: This automatically performs checks to prevent issues with setting VAR = VAR + std::cout << "1663 " << exprType->toString(C_STYLE) << " <: " << newAssignType->toString(C_STYLE) << std::endl; if (e->a && exprType->isNotSubtype(newAssignType)) { return errorHandler.addError(e->getStart(), "Expression of type " + exprType->toString(toStringMode) + " cannot be assigned to " + newAssignType->toString(toStringMode)); @@ -2138,7 +2154,15 @@ std::variant SemanticVisitor::visitCtx(Bismuth // If we have a return type, make sure that we return as the last statement in the FUNC. The type of the return is managed when we visited it. if (!TypedAST::endsInReturn(*blk)) { - errorHandler.addError(ctx->getStart(), "Lambda must end in return statement"); + if(retType->isNotSubtype(Types::UNIT)) + { + errorHandler.addError(ctx->getStart(), "Expected function to return type of " + retType->toString(toStringMode) + "; however, no return instruction was provided."); + } + else + { + // One of the first bits of syntactic sugar in bismuth! + blk->exprs.push_back(new TReturnNode(nullptr, std::nullopt)); + } } safeExitScope(ctx); stmgr->enterScope(origScope); @@ -2405,7 +2429,6 @@ SemanticVisitor::visitPathType(BismuthParser::PathContext *ctx) if(!opt) { - std::cout << stmgr->toString() << std::endl; return errorHandler.addError(pCtx->getStart(), "Could not find " + stepId + " in " + lookupScope->getIdentifier()->getFullyQualifiedName()); } diff --git a/src/semantic/include/SemanticVisitor.h b/src/semantic/include/SemanticVisitor.h index e714a72..d1f8814 100644 --- a/src/semantic/include/SemanticVisitor.h +++ b/src/semantic/include/SemanticVisitor.h @@ -140,8 +140,8 @@ class SemanticVisitor : public BismuthBaseVisitor std::any visitTypeDef(BismuthParser::TypeDefContext *ctx) override { return ctx->defineType()->accept(this); } - std::variant visitCtx(BismuthParser::DefineProgramContext *ctx); - std::any visitDefineProgram(BismuthParser::DefineProgramContext *ctx) override { return TNVariantCast(visitCtx(ctx)); } + std::variant visitCtx(BismuthParser::DefineProgramContext *ctx); + std::any visitDefineProgram(BismuthParser::DefineProgramContext *ctx) override { return TNVariantCast(visitCtx(ctx)); } std::variant visitCtx(BismuthParser::DefineFunctionContext *ctx); std::any visitDefineFunction(BismuthParser::DefineFunctionContext *ctx) override { return TNVariantCast(visitCtx(ctx)); } @@ -324,26 +324,22 @@ std::variant< std::variant visitCondition(BismuthParser::ExpressionContext *ex) { - std::cout << "338" << std::endl; auto a = ex->accept(this); - std::cout << "340-pre" << std::endl; - std::cout << a.type().name() << std::endl; std::variant condOpt = anyOpt2VarError(errorHandler, a); -std::cout << "340" << std::endl; + if (ErrorChain **e = std::get_if(&condOpt)) { - std::cout << "346" << std::endl; return (*e)->addError(ex->getStart(), "Unable to type check condition expression"); } -std::cout << "345" << std::endl; + TypedNode *cond = std::get(condOpt); const Type *conditionType = cond->getType(); -std::cout << "348" << std::endl; + if (conditionType->isNotSubtype(Types::DYN_BOOL)) { return errorHandler.addError(ex->getStart(), "Condition expected boolean, but was given " + conditionType->toString(toStringMode)); } -std::cout << "353" << std::endl; + return cond; } @@ -369,16 +365,14 @@ std::cout << "353" << std::endl; bool foundReturn = false; for (auto e : ctx->stmts) { - std::cout << "383 " << e->getText() << std::endl; // Visit all the statements in the block std::variant tnOpt = anyOpt2VarError(errorHandler, e->accept(this)); -std::cout << "386" << std::endl; + if (ErrorChain **e = std::get_if(&tnOpt)) { - std::cout << "389" << std::endl; return (*e)->addError(ctx->getStart(), "Failed to type check statement in block"); } -std::cout << "391" << std::endl; + nodes.push_back(std::get(tnOpt)); // If we found a return, then this is dead code, and we can break out of the loop. if (foundReturn) @@ -386,7 +380,7 @@ std::cout << "391" << std::endl; errorHandler.addError(ctx->getStart(), "Dead code"); break; } -std::cout << "399" << std::endl; + // If the current statement is a return, set foundReturn = true if (dynamic_cast(e)) foundReturn = true; @@ -553,9 +547,7 @@ std::cout << "399" << std::endl; // Should always be inferrable if(const TypeInfer * inf = dynamic_cast(sym->getType())) { - std::cout << "PRE Unify " << sym->toString() << std::endl; inf->unify(); - std::cout << "POST Unify " << sym->toString() << std::endl; } } diff --git a/src/symbol/Protocol.cpp b/src/symbol/Protocol.cpp index 49d23f4..2155a19 100644 --- a/src/symbol/Protocol.cpp +++ b/src/symbol/Protocol.cpp @@ -43,6 +43,14 @@ bool ProtocolRecv::isSupertypeFor(const Protocol *other) const return false; } + +const ProtocolRecv * ProtocolRecv::getCopySubst(std::map & existing) const +{ + return new ProtocolRecv( + this->isInCloseable(), + recvType->getCopySubst(existing) + ); +} /********************************************* * * ProtocolSend @@ -77,6 +85,14 @@ bool ProtocolSend::isSupertypeFor(const Protocol *other) const return false; } + +const ProtocolSend * ProtocolSend::getCopySubst(std::map & existing) const +{ + return new ProtocolSend( + this->isInCloseable(), + sendType->getCopySubst(existing) + ); +} /********************************************* * * ProtocolWN @@ -110,6 +126,14 @@ bool ProtocolWN::isSupertypeFor(const Protocol *other) const return false; } +const ProtocolWN * ProtocolWN::getCopySubst(std::map & existing) const +{ + return new ProtocolWN( + this->isInCloseable(), + proto->getCopySubst(existing) + ); +} + /********************************************* * * ProtocolOC @@ -142,6 +166,14 @@ bool ProtocolOC::isSupertypeFor(const Protocol *other) const return false; } +const ProtocolOC * ProtocolOC::getCopySubst(std::map & existing) const +{ + return new ProtocolOC( + this->isInCloseable(), + proto->getCopySubst(existing) + ); +} + /********************************************* * * ProtocolIChoice @@ -229,6 +261,20 @@ bool ProtocolIChoice::isSupertypeFor(const Protocol *other) const return false; } +const ProtocolIChoice * ProtocolIChoice::getCopySubst(std::map & existing) const +{ + std::set opts; + + for (auto p : this->opts) + { + opts.insert(p->getCopySubst(existing)); + } + + auto ans = new ProtocolIChoice(this->inCloseable, opts); + ans->guardCount = this->guardCount; + return ans; +} + /********************************************* * * ProtocolEChoice @@ -348,6 +394,20 @@ bool ProtocolEChoice::isSupertypeFor(const Protocol *other) const return false; } + +const ProtocolEChoice * ProtocolEChoice::getCopySubst(std::map & existing) const +{ + std::set opts; + + for (auto p : this->opts) + { + opts.insert(p->getCopySubst(existing)); + } + + auto ans = new ProtocolEChoice(this->inCloseable, opts); + ans->guardCount = this->guardCount; + return ans; +} /********************************************* * * ProtocolSequence @@ -963,6 +1023,20 @@ bool ProtocolSequence::isSupertypeFor(const Protocol *other) const return false; } +const ProtocolSequence * ProtocolSequence::getCopySubst(std::map & existing) const +{ + std::vector substSteps; + + for(auto step : this->steps) + { + substSteps.push_back( + step->getCopySubst(existing) + ); + } + + return new ProtocolSequence(this->isInCloseable(), substSteps); +} + /********************************************* * * ProtocolClose @@ -1043,4 +1117,13 @@ bool ProtocolClose::isSupertypeFor(const Protocol *other) const } return false; +} + +const ProtocolClose * ProtocolClose::getCopySubst(std::map & existing) const +{ + return new ProtocolClose( + this->isInCloseable(), + proto->getCopySubst(existing), + closeNumber + ); } \ No newline at end of file diff --git a/src/symbol/Type.cpp b/src/symbol/Type.cpp index b5f6767..995e2ee 100644 --- a/src/symbol/Type.cpp +++ b/src/symbol/Type.cpp @@ -2,7 +2,6 @@ bool Type::isSubtype(const Type *other, InferenceMode mode) const { - std::cout << "5 " << this->toString(C_STYLE) << " <: " << other->toString(C_STYLE) << std::endl; if (const TypeInfer *inf = dynamic_cast(this)) { // return false; @@ -572,6 +571,10 @@ const Type * TypeProgram::getCopySubst(std::map exis existing.insert({this, ans}); + ans->setProtocol( + protocol->getCopySubst(existing) + ); + // FIXME: NEED TO IMPL THIS!!! -> but it seems to work? though we don't have generics for programs // TODO: use ->define() func! @@ -785,7 +788,6 @@ const TypeInfer * TypeInfer::getCopy() const { return this; }; bool TypeInfer::setValue(const Type *other, InferenceMode mode) const { - std::cout << "844 " << this->toString(C_STYLE) << " " << (mode == InferenceMode::QUERY) << std::endl; // Prevent us from being sent another TypeInfer. There's no reason for this to happen // as it should have been added as a dependency (and doing this would break things) if (dynamic_cast(other)) @@ -825,7 +827,7 @@ bool TypeInfer::setValue(const Type *other, InferenceMode mode) const valid: if(mode != InferenceMode::SET) return true; -std::cout << "883" << std::endl; + // Set our valueType to be the provided type to see if anything breaks... TypeInfer *u_this = const_cast(this); *u_this->valueType = other; @@ -852,12 +854,20 @@ bool TypeInfer::isSupertypeFor(const Type *other) const bool TypeInfer::isSupertypeFor(const Type *other, InferenceMode mode) const { - std::cout << "858 " << this->toString(C_STYLE) << " VS " << other->toString(C_STYLE) << " " << (mode == InferenceMode::QUERY) << std::endl; + std::cout << "857 " << this->toString(C_STYLE) << ".isSupertypeFor " << other->toString(C_STYLE) << std::endl; + if(possibleTypes.size()) + { + std::cout << "POSSIBLE TYPES: "; + for(auto t : possibleTypes) + { + std::cout << t->toString(C_STYLE) << ", "; + } + std::cout << endl; + } // If we already have an inferred type, we can simply // check if that type is a subtype of other. if (valueType->has_value()) { - std::cout << "881" << std::endl; return valueType->value()->isSubtype(other); } /* @@ -882,6 +892,13 @@ bool TypeInfer::isSupertypeFor(const Type *other, InferenceMode mode) const TypeInfer *moth = const_cast(oInf); moth->infTypes.push_back(this); + + // TODO: handle this better so that way we can compare and unify across + // the two when both are non-empty (DO AN INTERSECT!) + if(moth->possibleTypes.empty() && !this->possibleTypes.empty()) + { + moth->possibleTypes = this->possibleTypes; + } return true; } @@ -893,17 +910,14 @@ bool TypeInfer::isSupertypeFor(const Type *other, InferenceMode mode) const // is going to error anyways. We won't want to show this as a var. if(!ans && mode == InferenceMode::SET) unify(); - std::cout << "910!" << this->toString(C_STYLE) << std::endl; return ans; } bool TypeInfer::unify() const { - std::cout << "909 " << this->hasBeenInferred() << " " << possibleTypes.size() << std::endl; if(this->hasBeenInferred()) return true; if(possibleTypes.size() != 0) { - std::cout << "912 " << (*(possibleTypes.begin()))->toString(C_STYLE) << std::endl; return this->isSupertypeFor(*(possibleTypes.begin())); } @@ -945,7 +959,16 @@ bool TypeSum::contains(const Type *ty) const { if(const TypeInfer * infType = dynamic_cast(ty)) { - if(!infType->hasBeenInferred()) return false; + if(!infType->hasBeenInferred()) + return false; + // { + // for(const Type * t : this->cases) + // { + // if(t->isSubtype(infType)) + // return true; + // } + // return false; + // } return this->contains(infType->getValueType().value()); } @@ -1517,4 +1540,18 @@ const Type * TypeModule::getCopySubst(std::map exist bool TypeModule::isSupertypeFor(const Type * other) const { return this == other; +} + + +bool TypeCompare::operator()(const Type *a, const Type *b) const +{ + // Only needed b/c of int types giving + // type infer due to trying to allow for + // inference of specific int type... + // if(const TypeInfer * infA = dynamic_cast(a)) + // { + // infA->isSubtype(b); + // } + if(dynamic_cast(a)) b->isSubtype(a, InferenceMode::QUERY); + return a->toString(C_STYLE) < b->toString(C_STYLE); } \ No newline at end of file diff --git a/src/symbol/include/Protocol.h b/src/symbol/include/Protocol.h index d676966..d8126b0 100644 --- a/src/symbol/include/Protocol.h +++ b/src/symbol/include/Protocol.h @@ -95,6 +95,8 @@ class Protocol } bool isSubtype(const Protocol * other) const; + + virtual const Protocol * getCopySubst(std::map & existing) const = 0; protected: virtual bool isSupertypeFor(const Protocol *other) const = 0; }; @@ -193,6 +195,7 @@ class ProtocolSequence : public Protocol void guard() const override; bool unguard() const override; + const ProtocolSequence * getCopySubst(std::map & existing) const override; protected: virtual bool isSupertypeFor(const Protocol *other) const override; @@ -237,6 +240,7 @@ class ProtocolRecv : public Protocol const Type* getRecvType() const { return recvType; } + const ProtocolRecv * getCopySubst(std::map & existing) const override; protected: virtual bool isSupertypeFor(const Protocol *other) const override; }; @@ -263,6 +267,7 @@ class ProtocolSend : public Protocol const Type *getSendType() const { return sendType; } + const ProtocolSend * getCopySubst(std::map & existing) const override; protected: virtual bool isSupertypeFor(const Protocol *other) const override; }; @@ -289,6 +294,8 @@ class ProtocolWN : public Protocol const ProtocolSequence *getInnerProtocol() const { return proto; } + const ProtocolWN * getCopySubst(std::map & existing) const override; + protected: virtual bool isSupertypeFor(const Protocol *other) const override; }; @@ -315,6 +322,8 @@ class ProtocolOC : public Protocol const ProtocolSequence *getInnerProtocol() const { return proto; } + const ProtocolOC * getCopySubst(std::map & existing) const override; + protected: virtual bool isSupertypeFor(const Protocol *other) const override; }; @@ -366,6 +375,13 @@ struct ProtocolBranchOption { seq->getCopy() ); } + + const ProtocolBranchOption * getCopySubst(std::map & existing) const { + return new ProtocolBranchOption( + label, + seq->getCopySubst(existing) + ); + } }; @@ -412,6 +428,7 @@ class ProtocolEChoice : public Protocol std::set getOptions() const { return opts; } std::optional lookup(std::variant opt) const; + const ProtocolEChoice * getCopySubst(std::map & existing) const override; protected: virtual bool isSupertypeFor(const Protocol *other) const override; }; @@ -437,7 +454,7 @@ class ProtocolIChoice : public Protocol const Protocol *getCopy() const override; std::set getOptions() const { return opts; } - + const ProtocolIChoice * getCopySubst(std::map & existing) const override; protected: virtual bool isSupertypeFor(const Protocol *other) const override; }; @@ -471,6 +488,8 @@ class ProtocolClose : public Protocol void guard() const override; bool unguard() const override; + const ProtocolClose * getCopySubst(std::map & existing) const override; + protected: virtual bool isSupertypeFor(const Protocol *other) const override; }; \ No newline at end of file diff --git a/src/symbol/include/Symbol.h b/src/symbol/include/Symbol.h index cd33e35..193d1b4 100644 --- a/src/symbol/include/Symbol.h +++ b/src/symbol/include/Symbol.h @@ -134,8 +134,6 @@ class DefinitionSymbol : public LocatableSymbol VisibilityModifier visibility; }; - -// FIXME: finish this and impl paths! class AliasSymbol : public LocatableSymbol { public: @@ -155,5 +153,4 @@ class AliasSymbol : public LocatableSymbol private: Identifier * orig; - // Symbol * orig; }; \ No newline at end of file diff --git a/src/symbol/include/Type.h b/src/symbol/include/Type.h index 6913835..5434b77 100644 --- a/src/symbol/include/Type.h +++ b/src/symbol/include/Type.h @@ -419,16 +419,7 @@ class TypeAbsurd : public Type struct TypeCompare { - bool operator()(const Type *a, const Type *b) const - { - // Only needed b/c of int types giving - // type infer due to trying to allow for - // inference of specific int type... - // if(const TypeInfer * infA = dynamic_cast(a)) - // { - // } - return a->toString(C_STYLE) < b->toString(C_STYLE); - } + bool operator()(const Type *a, const Type *b) const; }; namespace Types @@ -875,6 +866,7 @@ class TypeInfer : public Type * @return false */ bool hasBeenInferred() const; + bool hasPossibleTypes() const { return !possibleTypes.empty(); } std::optional getValueType() const; diff --git a/test/codegen/codegen_tests.cpp b/test/codegen/codegen_tests.cpp index 94c2994..d60d599 100644 --- a/test/codegen/codegen_tests.cpp +++ b/test/codegen/codegen_tests.cpp @@ -883,6 +883,14 @@ TEST_CASE("programs/generics/ReferenceGenericS - Co-dependent structs", "[codege "950b9c5cf40bd673a8250726ae932b20089332f552fae683d85a402f471e054c"); } +TEST_CASE("programs/generics/GenericProg - Generic Program", "[codegen][generic]") +{ + auto stream = std::fstream("/home/shared/programs/generics/GenericProg.bismuth"); + EnsureCompilesTo( + antlr4::ANTLRInputStream(stream), + "4d1e7ff027a665ed171e1e338f9a1d2e3fe0c972b854219acf505a7b5670c86c"); +} + TEST_CASE("programs/inferint - Infer the type of a number", "[codegen][infer integers]") { @@ -894,6 +902,20 @@ TEST_CASE("programs/inferint - Infer the type of a number", "[codegen][infer int } +TEST_CASE("programs/cursed - Binary Operators, Functions with inferred returns, and inference of ints through array matching", "[codegen]") +{ + std::string hash = "0df6327622a38a2a07439f1ed3caee1dc9cdc35f5887d231b5393fde3aacf37a"; + auto stream = std::fstream("/home/shared/programs/cursed/cursed.bismuth"); + EnsureCompilesTo( + antlr4::ANTLRInputStream(stream), + hash); + + auto stream2 = std::fstream("/home/shared/programs/cursed/cursed-no-ret.bismuth"); + EnsureCompilesTo( + antlr4::ANTLRInputStream(stream2), + hash); +} + /************************************ * Example C-Level Tests ************************************/ diff --git a/test/semantic/conditional_tests.cpp b/test/semantic/conditional_tests.cpp index 618abc7..c3e5a4d 100644 --- a/test/semantic/conditional_tests.cpp +++ b/test/semantic/conditional_tests.cpp @@ -14,9 +14,9 @@ using Catch::Matchers::ContainsSubstring; -void EnsureErrorsWithMessage(antlr4::ANTLRInputStream input, std::string message); +void EnsureErrorsWithMessage(antlr4::ANTLRInputStream input, std::string message, int flags=0); -void EnsureErrorsWithMessage(std::string program, std::string message); +void EnsureErrorsWithMessage(std::string program, std::string message, int flags=0); TEST_CASE("Inference If Errors - 1", "[semantic]") diff --git a/test/semantic/program_tests.cpp b/test/semantic/program_tests.cpp index 04cc82a..83db839 100644 --- a/test/semantic/program_tests.cpp +++ b/test/semantic/program_tests.cpp @@ -12,7 +12,7 @@ using Catch::Matchers::ContainsSubstring; -void EnsureErrorsWithMessage(antlr4::ANTLRInputStream input, std::string message) +void EnsureErrorsWithMessage(antlr4::ANTLRInputStream input, std::string message, int flags=0) { BismuthLexer lexer(&input); antlr4::CommonTokenStream tokens(&lexer); @@ -22,17 +22,17 @@ void EnsureErrorsWithMessage(antlr4::ANTLRInputStream input, std::string message REQUIRE_NOTHROW(tree = parser.compilationUnit()); REQUIRE(tree != NULL); STManager stm = STManager(); - SemanticVisitor sv = SemanticVisitor(&stm, DisplayMode::C_STYLE, 0); + SemanticVisitor sv = SemanticVisitor(&stm, DisplayMode::C_STYLE, flags); auto cuOpt = sv.visitCtx(tree); REQUIRE(sv.hasErrors(0)); REQUIRE_THAT(sv.getErrors(), ContainsSubstring(message)); } -void EnsureErrorsWithMessage(std::string program, std::string message) +void EnsureErrorsWithMessage(std::string program, std::string message, int flags=0) { antlr4::ANTLRInputStream input(program); - EnsureErrorsWithMessage(input, message); + EnsureErrorsWithMessage(input, message, flags); } // TODO: does this use excess memory bc we dont free news? @@ -194,70 +194,41 @@ TEST_CASE("programs/test16f - var loop", "[semantic]") REQUIRE_FALSE(sv.hasErrors(0)); } -// FIXME: REENABLE AGAIN! -/* -TEST_CASE("Test program() should return int warning", "[semantic][conditional]") +TEST_CASE("Demo Mode: Program is required", "[semantic][conditional]") { - antlr4::ANTLRInputStream input( - R""""( - define func program () { - return; # FIXME: DO THESE HAVE TO END IN RETURN? - } - )""""); - BismuthLexer lexer(&input); - // lexer.removeErrorListeners(); - // auto lListener = TestErrorListener(); - // lexer.addErrorListener(&lListener); - antlr4::CommonTokenStream tokens(&lexer); - BismuthParser parser(&tokens); - parser.removeErrorListeners(); - auto pListener = TestErrorListener(); - parser.addErrorListener(&pListener); - - BismuthParser::CompilationUnitContext *tree = NULL; - REQUIRE_NOTHROW(tree = parser.compilationUnit()); - REQUIRE(tree != NULL); - REQUIRE(tree->getText() != ""); - - STManager stmgr = STManager(); - SemanticVisitor sv = SemanticVisitor(&stmgr); - - sv.visitCompilationUnit(tree); - CHECK_FALSE(sv.hasErrors(ERROR)); - CHECK(sv.hasErrors(CRITICAL_WARNING)); + EnsureErrorsWithMessage( + R""""( + )"""", + "When compiling in demo mode, 'program :: * : Channel<-int>' (the entry point) must be defined", + DEMO_MODE + ); } -TEST_CASE("Test program() should not have parameters warning", "[semantic][conditional]") +TEST_CASE("Demo Mode: Program is wrong type", "[semantic][conditional]") { - antlr4::ANTLRInputStream input( - R""""( - int func program (int a) { - return 0; - } - )""""); - BismuthLexer lexer(&input); - // lexer.removeErrorListeners(); - // auto lListener = TestErrorListener(); - // lexer.addErrorListener(&lListener); - antlr4::CommonTokenStream tokens(&lexer); - BismuthParser parser(&tokens); - parser.removeErrorListeners(); - auto pListener = TestErrorListener(); - parser.addErrorListener(&pListener); - - BismuthParser::CompilationUnitContext *tree = NULL; - REQUIRE_NOTHROW(tree = parser.compilationUnit()); - REQUIRE(tree != NULL); - REQUIRE(tree->getText() != ""); - - STManager stmgr = STManager(); - SemanticVisitor sv = SemanticVisitor(&stmgr); + EnsureErrorsWithMessage( + R""""( +define func program () { + return; +} + )"""", + "When compiling in demo mode identifier 'program' must be defined as 'program :: * : Channel<-int>' (the entry point)", + DEMO_MODE + ); +} - sv.visitCompilationUnit(tree); - CHECK_FALSE(sv.hasErrors(ERROR)); - CHECK(sv.hasErrors(CRITICAL_WARNING)); +TEST_CASE("Demo Mode: Program follows wrong protocol", "[semantic][conditional]") +{ + EnsureErrorsWithMessage( + R""""( +define program :: c : Channel<-boolean> { + c.send(false) +} + )"""", + "In demo mode, 'program' must recognize a channel of protocol -int, not -boolean", + DEMO_MODE + ); } -*/ TEST_CASE("Dead code in program block", "[semantic][program]") { @@ -671,8 +642,6 @@ define foo :: c : Channel<+int> = { "Unsupported redeclaration of foo"); } -// FIXME: TEST EXTERN PROGRAMS? - TEST_CASE("Forward Decl with wrong num args", "[semantic][program][function][forward-decl]") { EnsureErrorsWithMessage( @@ -2233,6 +2202,19 @@ define program :: c : Channel<-int> { "Evaluation of expression would result in introducing a linear resource that is impossible to use"); } + +TEST_CASE("Error message during inference of number reports proper type mismatch", "[semantic][program]") +{ + EnsureErrorsWithMessage( + R""""( +define program :: c : Channel<-int> { + str[4] arr := [0, 1, 2, 3]; + c.send(0); +} + )"""", + "Expression of type int[4] cannot be assigned to str[4]"); +} + /********************************* * B-Level Example tests *********************************/ diff --git a/test/semantic/semantic_tests.cpp b/test/semantic/semantic_tests.cpp index 140083c..d172da3 100644 --- a/test/semantic/semantic_tests.cpp +++ b/test/semantic/semantic_tests.cpp @@ -78,7 +78,7 @@ TEST_CASE("Test Type Equality - Subtypes", "[semantic]") SECTION("Bot Type Tests") { REQUIRE(BotTy->isSubtype(TopTy)); - REQUIRE_FALSE(BotTy->isSubtype(IntTy)); //FIXME: THESE SEEM WRONG.... MAYBE? + REQUIRE_FALSE(BotTy->isSubtype(IntTy)); REQUIRE_FALSE(BotTy->isSubtype(StrTy)); REQUIRE_FALSE(BotTy->isSubtype(BoolTy)); REQUIRE_FALSE(BotTy->isSubtype(BotTy));