Skip to content

[wip] LLVM calc 4 #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: trunk
Choose a base branch
from
7 changes: 6 additions & 1 deletion llvm-calc4/llvm-calc4.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,13 @@ common shared
Calc.Parser.Function
Calc.Parser.Identifier
Calc.Parser.Module
Calc.Parser.Pattern
Calc.Parser.Primitives
Calc.Parser.Shared
Calc.Parser.Type
Calc.Parser.Types
Calc.Patterns.Flatten
Calc.PatternUtils
Calc.Repl
Calc.SourceSpan
Calc.Typecheck.Elaborate
Expand All @@ -77,9 +80,11 @@ common shared
Calc.Types.FunctionName
Calc.Types.Identifier
Calc.Types.Module
Calc.Types.Pattern
Calc.Types.Prim
Calc.Types.Type
Calc.TypeUtils
Calc.Utils

library
import: shared
Expand Down Expand Up @@ -110,8 +115,8 @@ test-suite llvm-calc4-tests
Test.Typecheck.TypecheckSpec

executable llvm-calc4
main-is: Main.hs
import: shared
main-is: Main.hs
hs-source-dirs: app
hs-source-dirs: src
ghc-options: -threaded -rtsopts -with-rtsopts=-N
Expand Down
4 changes: 4 additions & 0 deletions llvm-calc4/src/Calc/Compile/ToLLVM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ lookupArg identifier = do
printFunction :: (LLVM.MonadModuleBuilder m) => Type ann -> m LLVM.Operand
printFunction (TPrim _ TInt) = LLVM.extern "printint" [LLVM.i32] LLVM.void
printFunction (TPrim _ TBool) = LLVM.extern "printbool" [LLVM.i1] LLVM.void
printFunction (TTuple {}) = error "printFunction TTuple"
printFunction (TFunction _ _ tyRet) = printFunction tyRet -- maybe this should be an error instead

-- | given our `Module` type, turn it into an LLVM module
Expand Down Expand Up @@ -152,6 +153,7 @@ functionNameToLLVM (FunctionName fnName) =
typeToLLVM :: Type ann -> LLVM.Type
typeToLLVM (TPrim _ TBool) = LLVM.i1
typeToLLVM (TPrim _ TInt) = LLVM.i32
typeToLLVM TTuple {} = error "typeToLLVM TTuple"
typeToLLVM (TFunction _ tyArgs tyRet) =
LLVM.FunctionType (typeToLLVM tyRet) (typeToLLVM <$> tyArgs) False

Expand Down Expand Up @@ -243,6 +245,8 @@ exprToLLVM (EPrim _ prim) =
pure $ primToLLVM prim
exprToLLVM (EVar _ var) =
lookupArg var
exprToLLVM (ETuple {}) = error "exprToLLVM ETuple"
exprToLLVM (EPatternMatch {}) = error "exprToLLVM EPatternMatch"
exprToLLVM (EApply _ fnName args) = do
irFunc <- lookupFunction fnName
irArgs <- traverse exprToLLVM args
Expand Down
8 changes: 8 additions & 0 deletions llvm-calc4/src/Calc/ExprUtils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ module Calc.ExprUtils
where

import Calc.Types
import Data.Bifunctor (second)

-- | get the annotation in the first leaf found in an `Expr`.
-- useful for getting the overall type of an expression
Expand All @@ -17,6 +18,8 @@ getOuterAnnotation (EPrim ann _) = ann
getOuterAnnotation (EIf ann _ _ _) = ann
getOuterAnnotation (EVar ann _) = ann
getOuterAnnotation (EApply ann _ _) = ann
getOuterAnnotation (ETuple ann _ _) = ann
getOuterAnnotation (EPatternMatch ann _ _) = ann

-- | modify the outer annotation of an expression
-- useful for adding line numbers during parsing
Expand All @@ -28,6 +31,8 @@ mapOuterExprAnnotation f expr' =
EIf ann a b c -> EIf (f ann) a b c
EVar ann a -> EVar (f ann) a
EApply ann a b -> EApply (f ann) a b
ETuple ann a b -> ETuple (f ann) a b
EPatternMatch ann a b -> EPatternMatch (f ann) a b

-- | Given a function that changes `Expr` values, apply it throughout
-- an AST tree
Expand All @@ -38,3 +43,6 @@ mapExpr _ (EVar ann a) = EVar ann a
mapExpr f (EApply ann fn args) = EApply ann fn (f <$> args)
mapExpr f (EIf ann predExpr thenExpr elseExpr) =
EIf ann (f predExpr) (f thenExpr) (f elseExpr)
mapExpr f (ETuple ann a as) = ETuple ann (f a) (f <$> as)
mapExpr f (EPatternMatch ann matchExpr patterns) =
EPatternMatch ann (f matchExpr) (fmap (second f) patterns)
53 changes: 50 additions & 3 deletions llvm-calc4/src/Calc/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Coerce
import qualified Data.List.NonEmpty as NE
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Monoid (First (..))

-- | type for interpreter state
newtype InterpreterState ann = InterpreterState
Expand All @@ -31,6 +33,7 @@ data InterpreterError ann
= NonBooleanPredicate ann (Expr ann)
| FunctionNotFound FunctionName [FunctionName]
| VarNotFound Identifier [Identifier]
| NoPatternsMatched (Expr ann) (NE.NonEmpty (Pattern ann))
deriving stock (Eq, Ord, Show)

-- | type of Reader env for interpreter state
Expand Down Expand Up @@ -62,12 +65,12 @@ runInterpreter = flip evalStateT initialState . flip runReaderT initialEnv . run
-- we use the Reader env here because the vars disappear after we use them,
-- say, in a function
withVars ::
[(ArgumentName, b)] ->
[ArgumentName] ->
[Expr ann] ->
InterpretM ann a ->
InterpretM ann a
withVars fnArgs inputs =
let newVars = M.fromList $ zip (coerce . fst <$> fnArgs) inputs
let newVars = M.fromList $ zip (coerce <$> fnArgs) inputs
in local
( \(InterpreterEnv ieVars) ->
InterpreterEnv $ ieVars <> newVars
Expand Down Expand Up @@ -109,7 +112,7 @@ interpretApply fnName args = do
fn <- gets (M.lookup fnName . isFunctions)
case fn of
Just (Function {fnArgs, fnBody}) ->
withVars fnArgs args (interpret fnBody)
withVars (fst <$> fnArgs) args (interpret fnBody)
Nothing -> do
allFnNames <- gets (M.keys . isFunctions)
throwError (FunctionNotFound fnName allFnNames)
Expand All @@ -126,13 +129,57 @@ interpret (EApply _ fnName args) =
interpretApply fnName args
interpret (EInfix ann op a b) =
interpretInfix ann op a b
interpret (ETuple ann a as) = do
aA <- interpret a
asA <- traverse interpret as
pure (ETuple ann aA asA)
interpret (EPatternMatch _ expr pats) = do
exprA <- interpret expr
interpretPatternMatch exprA pats
interpret (EIf ann predExpr thenExpr elseExpr) = do
predA <- interpret predExpr
case predA of
(EPrim _ (PBool True)) -> interpret thenExpr
(EPrim _ (PBool False)) -> interpret elseExpr
other -> throwError (NonBooleanPredicate ann other)

interpretPatternMatch ::
Expr ann ->
NE.NonEmpty (Pattern ann, Expr ann) ->
InterpretM ann (Expr ann)
interpretPatternMatch expr' patterns = do
-- interpret match expression
intExpr <- interpret expr'
let foldF (pat, patExpr) = case patternMatches pat intExpr of
Just bindings -> First (Just (patExpr, bindings))
_ -> First Nothing

-- get first matching pattern
case getFirst (foldMap foldF patterns) of
Just (patExpr, bindings) ->
let vars = fmap (coerce . fst) bindings
exprs = fmap snd bindings
in withVars vars exprs (interpret patExpr)
_ -> throwError (NoPatternsMatched expr' (fst <$> patterns))

-- pull vars out of expr to match patterns
patternMatches ::
Pattern ann ->
Expr ann ->
Maybe [(Identifier, Expr ann)]
patternMatches (PWildcard _) _ = pure []
patternMatches (PVar _ name) expr = pure [(name, expr)]
patternMatches (PTuple _ pA pAs) (ETuple _ a as) = do
matchA <- patternMatches pA a
matchAs <-
traverse
(uncurry patternMatches)
(zip (NE.toList pAs) (NE.toList as))
pure $ matchA <> mconcat matchAs
patternMatches (PLiteral _ pB) (EPrim _ b)
| pB == b = pure mempty
patternMatches _ _ = Nothing

interpretModule ::
Module ann ->
InterpretM ann (Expr ann)
Expand Down
51 changes: 49 additions & 2 deletions llvm-calc4/src/Calc/Parser/Expr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
module Calc.Parser.Expr (exprParser) where

import Calc.Parser.Identifier
import Calc.Parser.Pattern
import Calc.Parser.Primitives
import Calc.Parser.Shared
import Calc.Parser.Types
import Calc.Types.Annotation
import Calc.Types.Expr
import Control.Monad.Combinators.Expr
import qualified Data.List.NonEmpty as NE
import Data.Text
import Text.Megaparsec

Expand All @@ -17,8 +19,10 @@ exprParser = addLocation (makeExprParser exprPart table) <?> "expression"

exprPart :: Parser (Expr Annotation)
exprPart =
inBrackets (addLocation exprParser)
<|> primParser
try tupleParser
<|> inBrackets (addLocation exprParser)
<|> patternMatchParser
<|> primExprParser
<|> ifParser
<|> try applyParser
<|> varParser
Expand Down Expand Up @@ -55,3 +59,46 @@ applyParser = addLocation $ do
args <- sepBy exprParser (stringLiteral ",")
stringLiteral ")"
pure (EApply mempty fnName args)

tupleParser :: Parser (Expr Annotation)
tupleParser = label "tuple" $
addLocation $ do
_ <- stringLiteral "("
neArgs <- NE.fromList <$> sepBy1 exprParser (stringLiteral ",")
neTail <- case NE.nonEmpty (NE.tail neArgs) of
Just ne -> pure ne
_ -> fail "Expected at least two items in a tuple"
_ <- stringLiteral ")"
pure (ETuple mempty (NE.head neArgs) neTail)

-----

patternMatchParser :: Parser ParserExpr
patternMatchParser = addLocation $ do
matchExpr <- matchExprWithParser
patterns <-
try patternMatchesParser
<|> pure <$> patternCaseParser
case NE.nonEmpty patterns of
(Just nePatterns) -> pure $ EPatternMatch mempty matchExpr nePatterns
_ -> error "need at least one pattern"

matchExprWithParser :: Parser ParserExpr
matchExprWithParser = do
stringLiteral "case"
sumExpr <- exprParser
stringLiteral "of"
pure sumExpr

patternMatchesParser :: Parser [(ParserPattern, ParserExpr)]
patternMatchesParser =
sepBy
patternCaseParser
(stringLiteral "|")

patternCaseParser :: Parser (ParserPattern, ParserExpr)
patternCaseParser = do
pat <- orInBrackets patternParser
stringLiteral "->"
patExpr <- exprParser
pure (pat, patExpr)
63 changes: 63 additions & 0 deletions llvm-calc4/src/Calc/Parser/Pattern.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{-# LANGUAGE OverloadedStrings #-}

module Calc.Parser.Pattern
( patternParser,
ParserPattern,
)
where

import Calc.Parser.Identifier
import Calc.Parser.Primitives
import Calc.Parser.Shared
import Calc.Parser.Types
import Calc.Types
import qualified Data.List.NonEmpty as NE
import Text.Megaparsec
import Text.Megaparsec.Char

type ParserPattern = Pattern Annotation

patternParser :: Parser ParserPattern
patternParser =
label
"pattern match"
( orInBrackets
( try patTupleParser
<|> try patWildcardParser
<|> try patVariableParser
<|> patLitParser
)
)

----

patWildcardParser :: Parser ParserPattern
patWildcardParser =
myLexeme $
withLocation
(\loc _ -> PWildcard loc)
(string "_")

----

patVariableParser :: Parser ParserPattern
patVariableParser =
myLexeme $ withLocation PVar identifierParser

----

patTupleParser :: Parser ParserPattern
patTupleParser = label "tuple" $
withLocation (\loc (pHead, pTail) -> PTuple loc pHead pTail) $ do
_ <- stringLiteral "("
neArgs <- NE.fromList <$> sepBy1 patternParser (stringLiteral ",")
neTail <- case NE.nonEmpty (NE.tail neArgs) of
Just ne -> pure ne
_ -> fail "Expected at least two items in a tuple"
_ <- stringLiteral ")"
pure (NE.head neArgs, neTail)

----

patLitParser :: Parser ParserPattern
patLitParser = myLexeme $ withLocation PLiteral primParser
15 changes: 12 additions & 3 deletions llvm-calc4/src/Calc/Parser/Primitives.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{-# LANGUAGE OverloadedStrings #-}

module Calc.Parser.Primitives
( primParser,
( primExprParser,
primParser,
intParser,
)
where
Expand Down Expand Up @@ -37,10 +38,18 @@ falseParser = stringLiteral "False" $> False

---

primParser :: Parser ParserExpr
primParser =
primExprParser :: Parser ParserExpr
primExprParser =
myLexeme $
addLocation $
EPrim mempty . PInt <$> intParser
<|> EPrim mempty <$> truePrimParser
<|> EPrim mempty <$> falsePrimParser

----

primParser :: Parser Prim
primParser =
PInt <$> intParser
<|> truePrimParser
<|> falsePrimParser
6 changes: 5 additions & 1 deletion llvm-calc4/src/Calc/Parser/Shared.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{-# LANGUAGE OverloadedStrings #-}

module Calc.Parser.Shared
( inBrackets,
( orInBrackets,
inBrackets,
myLexeme,
withLocation,
stringLiteral,
Expand Down Expand Up @@ -47,6 +48,9 @@ addTypeLocation = withLocation (mapOuterTypeAnnotation . const)
inBrackets :: Parser a -> Parser a
inBrackets = between2 '(' ')'

orInBrackets :: Parser a -> Parser a
orInBrackets parser = try parser <|> try (inBrackets parser)

myLexeme :: Parser a -> Parser a
myLexeme = L.lexeme (L.space space1 empty empty)

Expand Down
Loading