Skip to content

Commit a9a3bb0

Browse files
committed
Start adding type inference
1 parent 517bcbb commit a9a3bb0

File tree

7 files changed

+326
-15
lines changed

7 files changed

+326
-15
lines changed

package.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ dependencies:
2525
- text
2626
- megaparsec >= 7.0.4 && < 7.1.0
2727
- parser-combinators
28+
- mtl
29+
- containers
30+
- bifunctors
2831

2932
library:
3033
source-dirs: src

src/Infer.hs

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
{-# LANGUAGE FlexibleInstances #-}
2+
{-# LANGUAGE TypeSynonymInstances #-}
3+
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4+
5+
module Infer where
6+
7+
import Prelude hiding (foldr)
8+
9+
import Type
10+
import Syntax
11+
12+
import Control.Monad.State
13+
import Control.Monad.Except
14+
15+
import Data.Monoid
16+
import Data.List (nub)
17+
import Data.Foldable (foldr)
18+
import qualified Data.Map as Map
19+
import qualified Data.Set as Set
20+
21+
newtype TypeEnv = TypeEnv (Map.Map Var Scheme)
22+
deriving (Semigroup, Monoid, Show)
23+
24+
25+
data Unique = Unique { count :: Int }
26+
27+
type Infer = ExceptT TypeError (State Unique)
28+
type Subst = Map.Map TVar Type
29+
30+
data TypeError
31+
= UnificationFail Type Type
32+
| InfiniteType TVar Type
33+
| UnboundVariable String
34+
deriving (Show, Eq)
35+
36+
runInfer :: Infer (Subst, Type) -> Either TypeError Scheme
37+
runInfer m = case evalState (runExceptT m) initUnique of
38+
Left err -> Left err
39+
Right res -> Right $ closeOver res
40+
41+
closeOver :: (Map.Map TVar Type, Type) -> Scheme
42+
closeOver (sub, ty) = normalize sc
43+
where sc = generalize emptyTyenv (apply sub ty)
44+
45+
initUnique :: Unique
46+
initUnique = Unique { count = 0 }
47+
48+
extend :: TypeEnv -> (Var, Scheme) -> TypeEnv
49+
extend (TypeEnv env) (x, s) = TypeEnv $ Map.insert x s env
50+
51+
emptyTyenv :: TypeEnv
52+
emptyTyenv = TypeEnv Map.empty
53+
54+
typeof :: TypeEnv -> Var -> Maybe Type.Scheme
55+
typeof (TypeEnv env) name = Map.lookup name env
56+
57+
class Substitutable a where
58+
apply :: Subst -> a -> a
59+
ftv :: a -> Set.Set TVar
60+
61+
instance Substitutable Type where
62+
apply _ (TCon a) = TCon a
63+
apply s t@(TVar a) = Map.findWithDefault t a s
64+
apply s (t1 `TArr` t2) = apply s t1 `TArr` apply s t2
65+
66+
ftv TCon{} = Set.empty
67+
ftv (TVar a) = Set.singleton a
68+
ftv (t1 `TArr` t2) = ftv t1 `Set.union` ftv t2
69+
70+
instance Substitutable Scheme where
71+
apply s (Forall as t) = Forall as $ apply s' t
72+
where s' = foldr Map.delete s as
73+
ftv (Forall as t) = ftv t `Set.difference` Set.fromList as
74+
75+
instance Substitutable a => Substitutable [a] where
76+
apply = fmap . apply
77+
ftv = foldr (Set.union . ftv) Set.empty
78+
79+
instance Substitutable TypeEnv where
80+
apply s (TypeEnv env) = TypeEnv $ Map.map (apply s) env
81+
ftv (TypeEnv env) = ftv $ Map.elems env
82+
83+
84+
nullSubst :: Subst
85+
nullSubst = Map.empty
86+
87+
compose :: Subst -> Subst -> Subst
88+
s1 `compose` s2 = Map.map (apply s1) s2 `Map.union` s1
89+
90+
unify :: Type -> Type -> Infer Subst
91+
unify (l `TArr` r) (l' `TArr` r') = do
92+
s1 <- unify l l'
93+
s2 <- unify (apply s1 r) (apply s1 r')
94+
return (s2 `compose` s1)
95+
96+
unify (TVar a) t = bind a t
97+
unify t (TVar a) = bind a t
98+
unify (TCon a) (TCon b) | a == b = return nullSubst
99+
unify t1 t2 = throwError $ UnificationFail t1 t2
100+
101+
bind :: TVar -> Type -> Infer Subst
102+
bind a t
103+
| t == TVar a = return nullSubst
104+
| occursCheck a t = throwError $ InfiniteType a t
105+
| otherwise = return $ Map.singleton a t
106+
107+
occursCheck :: Substitutable a => TVar -> a -> Bool
108+
occursCheck a t = a `Set.member` ftv t
109+
110+
letters :: [String]
111+
letters = [1..] >>= flip replicateM ['a'..'z']
112+
113+
fresh :: Infer Type
114+
fresh = do
115+
s <- get
116+
put s{count = count s + 1}
117+
return $ TVar $ TV (letters !! count s)
118+
119+
instantiate :: Scheme -> Infer Type
120+
instantiate (Forall as t) = do
121+
as' <- mapM (const fresh) as
122+
let s = Map.fromList $ zip as as'
123+
return $ apply s t
124+
125+
generalize :: TypeEnv -> Type -> Scheme
126+
generalize env t = Forall as t
127+
where as = Set.toList $ ftv t `Set.difference` ftv env
128+
129+
ops :: Binop -> Type
130+
ops Add = typeInt `TArr` typeInt `TArr` typeInt
131+
ops Mul = typeInt `TArr` typeInt `TArr` typeInt
132+
ops Sub = typeInt `TArr` typeInt `TArr` typeInt
133+
ops Eql = typeInt `TArr` typeInt `TArr` typeBool
134+
135+
lookupEnv :: TypeEnv -> Var -> Infer (Subst, Type)
136+
lookupEnv (TypeEnv env) x =
137+
case Map.lookup x env of
138+
Nothing -> throwError $ UnboundVariable (show x)
139+
Just s -> do t <- instantiate s
140+
return (nullSubst, t)
141+
142+
extendDecl :: TypeEnv -> Decl -> Infer (Subst, TypeEnv)
143+
extendDecl env (name, e) = do
144+
(s, t) <- infer env e
145+
let env' = apply s env
146+
t' = generalize env' t
147+
pure $ (s, env' `extend` (name, t'))
148+
149+
extendDecls :: TypeEnv -> [Decl] -> Infer (Subst, TypeEnv)
150+
extendDecls env =
151+
foldM step (nullSubst, env)
152+
where
153+
step (s, e) decl = do
154+
(s1, e2) <- extendDecl e decl
155+
pure (s1 `compose` s, e2)
156+
157+
infer :: TypeEnv -> Expr -> Infer (Subst, Type)
158+
infer env ex = case ex of
159+
160+
Var x -> lookupEnv env x
161+
162+
Lam x e -> do
163+
tv <- fresh
164+
let env' = env `extend` (x, Forall [] tv)
165+
(s1, t1) <- infer env' e
166+
return (s1, apply s1 tv `TArr` t1)
167+
168+
App e1 e2 -> do
169+
tv <- fresh
170+
(s1, t1) <- infer env e1
171+
(s2, t2) <- infer (apply s1 env) e2
172+
s3 <- unify (apply s2 t1) (TArr t2 tv)
173+
return (s3 `compose` s2 `compose` s1, apply s3 tv)
174+
175+
Let decls e2 -> do
176+
(s1, env') <- extendDecls env decls
177+
(s2, t2) <- infer env' e2
178+
return (s2 `compose` s1, t2)
179+
180+
If cond tr fl -> do
181+
tv <- fresh
182+
inferPrim env [cond, tr, fl] (typeBool `TArr` tv `TArr` tv `TArr` tv)
183+
184+
Op op e1 e2 -> do
185+
inferPrim env [e1, e2] (ops op)
186+
187+
Lit (LInt _) -> return (nullSubst, typeInt)
188+
Lit (LBool _) -> return (nullSubst, typeBool)
189+
Lit (LFloat _) -> return (nullSubst, typeFloat)
190+
191+
inferPrim :: TypeEnv -> [Expr] -> Type -> Infer (Subst, Type)
192+
inferPrim env l t = do
193+
tv <- fresh
194+
(s1, tf) <- foldM inferStep (nullSubst, id) l
195+
s2 <- unify (apply s1 (tf tv)) t
196+
return (s2 `compose` s1, apply s2 tv)
197+
where
198+
inferStep (s, tf) exp = do
199+
(s', t) <- infer (apply s env) exp
200+
return (s' `compose` s, tf . (TArr t))
201+
202+
inferExpr :: TypeEnv -> Expr -> Either TypeError Scheme
203+
inferExpr env = runInfer . infer env
204+
205+
inferTop :: TypeEnv -> [Decl] -> Either TypeError TypeEnv
206+
inferTop env [] = Right env
207+
inferTop env ((name, ex):xs) = case inferExpr env ex of
208+
Left err -> Left err
209+
Right ty -> inferTop (extend env (name, ty)) xs
210+
211+
normalize :: Scheme -> Scheme
212+
normalize (Forall ts body) = Forall (fmap snd ord) (normtype body)
213+
where
214+
ord = zip (nub $ fv body) (fmap TV letters)
215+
216+
fv (TVar a) = [a]
217+
fv (TArr a b) = fv a ++ fv b
218+
fv (TCon _) = []
219+
220+
normtype (TArr a b) = TArr (normtype a) (normtype b)
221+
normtype (TCon a) = TCon a
222+
normtype (TVar a) =
223+
case lookup a ord of
224+
Just x -> TVar x
225+
Nothing -> error "type variable not in signature"

src/Lexer.hs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ symbol = L.symbol spaceConsumer
5555
integer :: Parser Integer
5656
integer = lexeme L.decimal
5757

58+
float :: Parser Float
59+
float = lexeme L.float
60+
5861
reserved :: String -> Parser ()
5962
reserved w = (lexeme . try) (string w *> notFollowedBy alphaNumChar)
6063

src/Parser.hs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@ variable = do
2828
int :: Parser Expr
2929
int = do
3030
n <- L.integer
31-
return (Lit (LInt (fromIntegral n)))
31+
return (Lit (LInt n))
32+
33+
float :: Parser Expr
34+
float = do
35+
f <- L.float
36+
pure . Lit $ LFloat f
3237

3338
bool :: Parser Expr
3439
bool = (L.reserved "true" >> return (Lit (LBool True)))
@@ -70,14 +75,23 @@ ifthen = do
7075
fl <- expr
7176
return (If cond tr fl)
7277

78+
swizzle :: Parser Expr
79+
swizzle = do
80+
name <- L.identifier
81+
L.symbol "."
82+
op <- L.identifier
83+
pure $ Swizzle name op
84+
7385
aexp :: Parser Expr
7486
aexp =
7587
L.parens expr
88+
<|> try float
7689
<|> bool
7790
<|> int
7891
<|> ifthen
7992
<|> letin
8093
<|> lambda
94+
<|> try swizzle
8195
<|> variable
8296

8397
term :: Parser Expr
@@ -87,9 +101,7 @@ term = aexp >>= \x ->
87101

88102
table :: [[Operator Parser Expr]]
89103
table =
90-
[ [ InfixL (Op Swizzle <$ L.symbol ".")
91-
]
92-
, [ InfixL (Op Mul <$ L.symbol "*")
104+
[ [ InfixL (Op Mul <$ L.symbol "*")
93105
, InfixL (Op Div <$ L.symbol "/")
94106
]
95107
, [ InfixL (Op Add <$ L.symbol "+")
@@ -172,5 +184,5 @@ modl = do
172184
parseExpr :: String -> Either StringError Expr
173185
parseExpr input = parse expr "<stdin>" input
174186

175-
parseModule :: FilePath -> String -> Either StringError [(String, Expr)]
187+
parseModule :: FilePath -> String -> Either StringError [Decl]
176188
parseModule fname input = parse modl fname input

src/Syntax.hs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@ data Expr
1212
| If Expr Expr Expr
1313
| Fix Expr
1414
| Op Binop Expr Expr
15-
| Ty TyLit
15+
| Swizzle Var Var
16+
| Ty GlslTypes
1617
deriving (Show, Eq, Ord)
1718

18-
data TyLit
19-
= Float
19+
data GlslTypes
20+
= Bool
21+
| Int
22+
| Float
2023
| Vec2
2124
| Vec3
2225
| Vec4
@@ -28,9 +31,10 @@ data TyLit
2831
data Lit
2932
= LInt Integer
3033
| LBool Bool
34+
| LFloat Float
3135
deriving (Show, Eq, Ord)
3236

33-
data Binop = Add | Sub | Mul | Eql | Div | Swizzle
37+
data Binop = Add | Sub | Mul | Eql | Div
3438
deriving (Eq, Ord, Show)
3539

3640
data Program = Program [Decl] Expr deriving (Show, Eq)

src/Type.hs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
module Type where
2+
3+
import qualified Syntax as S
4+
5+
newtype TVar = TV String
6+
deriving (Show, Eq, Ord)
7+
8+
data Type
9+
= TVar TVar
10+
| TCon S.GlslTypes
11+
| TArr Type Type
12+
deriving (Show, Eq, Ord)
13+
14+
infixr `TArr`
15+
16+
data Scheme = Forall [TVar] Type
17+
deriving (Show, Eq, Ord)
18+
19+
typeInt :: Type
20+
typeInt = TCon S.Int
21+
22+
typeFloat :: Type
23+
typeFloat = TCon S.Float
24+
25+
typeBool :: Type
26+
typeBool = TCon S.Bool

0 commit comments

Comments
 (0)