Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More typechecking error message improvements #1308

Merged
merged 17 commits into from
Jun 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactoring/improvements
- Rename `expect` to `unify`
- Add a `Syntax` argument to all `decomposeXXX` functions
- Add `Source` and `Join` machinery to properly keep track of which
  type is "expected" and which is "actual"
  • Loading branch information
byorgey committed Jun 6, 2023
commit 833f4a3df38044d3212f1efb34d59786086a50ee
8 changes: 4 additions & 4 deletions src/Swarm/Language/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,17 @@ appliedTermPrec _ = 10
instance PrettyPrec TypeErr where
prettyPrec _ (UnifyErr ty1 ty2) =
"Can't unify" <+> ppr ty1 <+> "and" <+> ppr ty2
prettyPrec _ (Mismatch Nothing ty1 ty2) =
prettyPrec _ (Mismatch Nothing (getJoin -> (ty1,ty2))) =
"Type mismatch: expected" <+> ppr ty1 <> ", but got" <+> ppr ty2
prettyPrec _ (Mismatch (Just t) ty1 ty2) =
prettyPrec _ (Mismatch (Just t) (getJoin -> (ty1,ty2))) =
nest 2 . vcat $
[ "Type mismatch:"
, "From context, expected" <+> bquote (ppr t) <+> "to have type" <+> bquote (ppr ty1) <> ","
, "but it actually has type" <+> bquote (ppr ty2)
]
prettyPrec _ (LambdaArgMismatch ty1 ty2) =
prettyPrec _ (LambdaArgMismatch (getJoin -> (ty1,ty2))) =
"Lambda argument has type annotation" <+> ppr ty2 <> ", but expected argument type" <+> ppr ty1
prettyPrec _ (FieldsMismatch expFs actFs) = fieldMismatchMsg expFs actFs
prettyPrec _ (FieldsMismatch (getJoin -> (expFs,actFs))) = fieldMismatchMsg expFs actFs
prettyPrec _ (EscapedSkolem x) =
"Skolem variable" <+> pretty x <+> "would escape its scope"
prettyPrec _ (UnboundVar x) =
Expand Down
140 changes: 88 additions & 52 deletions src/Swarm/Language/Typecheck.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- For 'Ord IntVar' instance
Expand All @@ -17,6 +18,9 @@ module Swarm.Language.Typecheck (
TypeErr (..),
InvalidAtomicReason (..),

-- * Type provenance
Source(..), Join, getJoin,

-- * Typechecking stack

TCFrame(..), TCStack, withFrame, getTCStack,
Expand All @@ -28,8 +32,7 @@ module Swarm.Language.Typecheck (

-- * Unification
substU,
expect,
(=:=),
unify,
HasBindings (..),
instantiate,
skolemize,
Expand All @@ -41,8 +44,6 @@ module Swarm.Language.Typecheck (
infer,
inferConst,
check,
decomposeCmdTy,
decomposeFunTy,
isSimpleUType,
) where

Expand All @@ -52,7 +53,7 @@ import Control.Lens ((^.))
import Control.Lens.Indexed (itraverse)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Unification hiding (applyBindings, (=:=))
import Control.Unification hiding (applyBindings, (=:=), unify)
import Control.Unification qualified as U
import Control.Unification.IntVar
import Data.Data (Data, gmapM)
Expand Down Expand Up @@ -83,6 +84,39 @@ data TCFrame where

type TCStack = [(SrcLoc, TCFrame)]

------------------------------------------------------------
-- Type source

-- | The source of a type during typechecking.
data Source
= Expected -- ^ An expected type that was "pushed down" from the context.
| Actual -- ^ An actual/inferred type that was "pulled up" from a term.
deriving (Show, Eq, Ord, Bounded, Enum)

-- | A value along with its source (expected vs actual).
type Sourced a = (Source, a)

-- | A "join" where an expected thing meets an actual thing.
data Join a = Join (Source -> a)

instance Show a => Show (Join a) where
show (getJoin -> (e,a)) = "(expected: " <> show e <> ", actual: " <> show a <> ")"

type TypeJoin = Join UType

-- | Create a 'Join' from an expected thing and an actual thing (in that order).
joined :: a -> a -> Join a
joined expect actual = Join (\case {Expected -> expect; Actual -> actual})

-- | Create a 'Join' from a 'Sourced' thing together with another
-- thing (which is assumed to have the opposite 'Source').
mkJoin :: Sourced a -> a -> Join a
mkJoin (src, a1) a2 = Join $ \s -> if s == src then a1 else a2

-- | Convert a 'Join' into a pair of (expected, actual).
getJoin :: Join a -> (a,a)
getJoin (Join j) = (j Expected, j Actual)

------------------------------------------------------------
-- Type checking monad

Expand Down Expand Up @@ -209,25 +243,27 @@ noSkolems l (Forall xs upty) = do

infix 4 =:=

-- | @expect t expTy actTy@ expects that term @t@ has type @expTy@,
-- where it actually has type @actTy@. Ensure those types are the
-- same.
expect :: Maybe Syntax -> UType -> UType -> TC UType
expect ms expected actual = case unifyCheck expected actual of
Apart -> throwTypeErr NoLoc $ Mismatch ms expected actual
-- | @unify t expTy actTy@ ensures that the given two types are equal.
-- The first type is the expected type (passed down from the
-- context) whereas the second is the actual type (extracted from a
-- term). If we know the actual term @t@ which is supposed to have
-- these types, we can use it to generate better error
-- messages.
--
-- We first do a quick-and-dirty check to see whether we know for
-- sure the types either are or cannot be equal, generating an
-- equality constraint for the unifier as a last resort.
unify :: Maybe Syntax -> TypeJoin -> TC UType
unify ms j = case unifyCheck expected actual of
Apart -> throwTypeErr NoLoc $ Mismatch ms j
Equal -> return expected
MightUnify -> lift . lift $ expected U.=:= actual
where
(expected, actual) = getJoin j

-- | Constrain two types to be equal, first with a quick-and-dirty
-- check to see whether we know for sure they either are or cannot
-- be equal, generating an equality constraint for the unifier as a
-- last resort.
--
-- Important: the first given type should be the "expected" type
-- from context, and the second should be the "actual" type (in any
-- situation when this distinction makes sense).
-- | Ensure two types are the same.
(=:=) :: UType -> UType -> TC UType
(=:=) = expect Nothing
ty1 =:= ty2 = unify Nothing (joined ty1 ty2)

-- | @unification-fd@ provides a function 'U.applyBindings' which
-- fully substitutes for any bound unification variables (for
Expand Down Expand Up @@ -344,13 +380,13 @@ data TypeErr
| -- | Type mismatch caught by 'unifyCheck'. The given term was
-- expected to have a certain type, but has a different type
-- instead.
Mismatch (Maybe Syntax) UType UType -- expected, actual
Mismatch (Maybe Syntax) TypeJoin
| -- | Lambda argument type mismatch.
LambdaArgMismatch UType UType -- expected, actual
LambdaArgMismatch TypeJoin
| -- | Record field mismatch, i.e. based on the expected type we
-- were expecting a record with certain fields, but found one with
-- a different field set.
FieldsMismatch (Set Var) (Set Var) -- expected fields, actual
FieldsMismatch (Join (Set Var))
| -- | A definition was encountered not at the top level.
DefNotTopLevel Term
| -- | A term was encountered which we cannot infer the type of.
Expand Down Expand Up @@ -387,37 +423,37 @@ instance Fallible TypeF IntVar ContextualTypeErr where
-- Type decomposition

-- | Decompose a type that is supposed to be a delay type.
decomposeDelayTy :: UType -> TC UType
decomposeDelayTy (UTyDelay a) = return a
decomposeDelayTy ty = do
decomposeDelayTy :: Syntax -> Sourced UType -> TC UType
decomposeDelayTy _ (_, UTyDelay a) = return a
decomposeDelayTy t ty = do
a <- fresh
_ <- UTyDelay a =:= ty
_ <- unify (Just t) (mkJoin ty (UTyDelay a))
return a

-- | Decompose a type that is supposed to be a command type.
decomposeCmdTy :: UType -> TC UType
decomposeCmdTy (UTyCmd a) = return a
decomposeCmdTy ty = do
decomposeCmdTy :: Syntax -> Sourced UType -> TC UType
decomposeCmdTy _ (_, UTyCmd a) = return a
decomposeCmdTy t ty = do
a <- fresh
_ <- UTyCmd a =:= ty
_ <- unify (Just t) (mkJoin ty (UTyCmd a))
return a

-- | Decompose a type that is supposed to be a function type.
decomposeFunTy :: UType -> TC (UType, UType)
decomposeFunTy (UTyFun ty1 ty2) = return (ty1, ty2)
decomposeFunTy ty = do
decomposeFunTy :: Syntax -> Sourced UType -> TC (UType, UType)
decomposeFunTy _ (_, UTyFun ty1 ty2) = return (ty1, ty2)
decomposeFunTy t ty = do
ty1 <- fresh
ty2 <- fresh
_ <- UTyFun ty1 ty2 =:= ty
_ <- unify (Just t) (mkJoin ty (UTyFun ty1 ty2))
return (ty1, ty2)

-- | Decompose a type that is supposed to be a product type.
decomposeProdTy :: UType -> TC (UType, UType)
decomposeProdTy (UTyProd ty1 ty2) = return (ty1, ty2)
decomposeProdTy ty = do
decomposeProdTy :: Syntax -> Sourced UType -> TC (UType, UType)
decomposeProdTy _ (_, UTyProd ty1 ty2) = return (ty1, ty2)
decomposeProdTy t ty = do
ty1 <- fresh
ty2 <- fresh
_ <- UTyProd ty1 ty2 =:= ty
_ <- unify (Just t) (mkJoin ty (UTyProd ty1 ty2))
return (ty1, ty2)

------------------------------------------------------------
Expand All @@ -440,7 +476,7 @@ inferModule s@(Syntax l t) = addLocToTypeErr l $ case t of
SDef r x Nothing t1 -> do
xTy <- fresh
t1' <- withBinding (lvVar x) (Forall [] xTy) $ infer t1
_ <- xTy =:= t1' ^. sType
_ <- unify (Just t1) (joined xTy (t1' ^. sType))
pty <- generalize (t1' ^. sType)
return $ Module (Syntax' l (SDef r x Nothing t1') (UTyCmd UTyUnit)) (singleton (lvVar x) pty)

Expand All @@ -458,7 +494,7 @@ inferModule s@(Syntax l t) = addLocToTypeErr l $ case t of
SBind mx c1 c2 -> do
-- First, infer the left side.
Module c1' ctx1 <- inferModule c1
a <- decomposeCmdTy (c1' ^. sType)
a <- decomposeCmdTy c1 (Actual, c1' ^. sType)

-- Now infer the right side under an extended context: things in
-- scope on the right-hand side include both any definitions
Expand All @@ -475,7 +511,7 @@ inferModule s@(Syntax l t) = addLocToTypeErr l $ case t of
-- going to return the entire type, but it's important to
-- ensure it's a command type anyway. Otherwise something
-- like 'move; 3' would be accepted with type int.
_ <- decomposeCmdTy (c2' ^. sType)
_ <- decomposeCmdTy c2 (Actual, c2' ^. sType)

-- Ctx.union is right-biased, so ctx1 `union` ctx2 means later
-- definitions will shadow previous ones. Include the binder
Expand Down Expand Up @@ -556,7 +592,7 @@ infer s@(Syntax l t) = addLocToTypeErr l $ case t of
SApp f x -> do
-- Infer the type of the left-hand side and make sure it has a function type.
f' <- infer f
(argTy, resTy) <- decomposeFunTy (f' ^. sType)
(argTy, resTy) <- decomposeFunTy f (Actual, f' ^. sType)

-- Then check that the argument has the right type.
x' <- check x argTy
Expand All @@ -566,9 +602,9 @@ infer s@(Syntax l t) = addLocToTypeErr l $ case t of
-- application.
SBind mx c1 c2 -> do
c1' <- infer c1
a <- decomposeCmdTy (c1' ^. sType)
a <- decomposeCmdTy c1 (Actual, c1' ^. sType)
c2' <- maybe id ((`withBinding` Forall [] a) . lvVar) mx $ infer c2
_ <- decomposeCmdTy (c2' ^. sType)
_ <- decomposeCmdTy c2 (Actual, c2' ^. sType)
return $ Syntax' l (SBind mx c1' c2') (c2' ^. sType)

-- Handle record projection in inference mode. Knowing the expected
Expand Down Expand Up @@ -732,28 +768,28 @@ check s@(Syntax l t) expected = addLocToTypeErr l $ case t of
-- dynamically at runtime when evaluating recursive let or def expressions,
-- so we don't have to worry about typechecking them here.
SDelay d s1 -> do
ty1 <- decomposeDelayTy expected
ty1 <- decomposeDelayTy s (Expected, expected)
s1' <- check s1 ty1
return $ Syntax' l (SDelay d s1') (UTyDelay ty1)

-- To check the type of a pair, make sure the expected type is a
-- product type, and push the two types down into the left and right.
SPair s1 s2 -> do
(ty1, ty2) <- decomposeProdTy expected
(ty1, ty2) <- decomposeProdTy s (Expected, expected)
s1' <- check s1 ty1
s2' <- check s2 ty2
return $ Syntax' l (SPair s1' s2') (UTyProd ty1 ty2)

-- To check a lambda, make sure the expected type is a function type.
SLam x mxTy body -> do
(argTy, resTy) <- decomposeFunTy expected
(argTy, resTy) <- decomposeFunTy s (Expected, expected)
case toU mxTy of
Just xTy -> case unifyCheck argTy xTy of
-- Generate a special error when the explicit type annotation
-- on a lambda doesn't match the expected type,
-- e.g. (\x:int. x + 2) : text -> int, since the usual
-- "expected/but got" language would probably be confusing.
Apart -> throwTypeErr l $ LambdaArgMismatch argTy xTy
Apart -> throwTypeErr l $ LambdaArgMismatch (joined argTy xTy)
-- Otherwise, make sure to unify the annotation with the
-- expected argument type.
_ -> void $ argTy =:= xTy
Expand All @@ -767,7 +803,7 @@ check s@(Syntax l t) expected = addLocToTypeErr l $ case t of

TConst c :$: at
| c `elem` [Atomic, Instant] -> do
argTy <- decomposeCmdTy expected
argTy <- decomposeCmdTy s (Expected, expected)
at' <- check at (UTyCmd argTy)
atomic' <- infer (Syntax l (TConst c))
-- It's important that we typecheck the subterm @at@ *before* we
Expand Down Expand Up @@ -829,15 +865,15 @@ check s@(Syntax l t) expected = addLocToTypeErr l $ case t of
actualFields = M.keysSet fields
when (actualFields /= expectedFields) $
throwTypeErr l $
FieldsMismatch expectedFields actualFields
FieldsMismatch (joined expectedFields actualFields)
m' <- itraverse (\x ms -> check (fromMaybe (STerm (TVar x)) ms) (tyMap ! x)) fields
return $ Syntax' l (SRcd (Just <$> m')) expected

-- Fallback: switch into inference mode, and check that the type we
-- get is what we expected.
_ -> do
Syntax' l' t' actual <- infer s
Syntax' l' t' <$> expect (Just s) expected actual
Syntax' l' t' <$> unify (Just s) (joined expected actual)

-- ~~~~ Note [Checking and inference for record literals]
--
Expand Down
4 changes: 2 additions & 2 deletions src/Swarm/Language/Typecheck/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ instance Monoid UnifyStatus where
-- 'Equal'. In case (1), we can generate a much better error
-- message at the instant the two types come together than we could
-- if we threw a constraint into the unifier. In case (2), we don't
-- have to bother with generating a constraint. If we don't know for
-- sure whether they will unify, return 'MightUnify'.
-- have to bother with generating a trivial constraint. If we don't
-- know for sure whether they will unify, return 'MightUnify'.
unifyCheck :: UType -> UType -> UnifyStatus
unifyCheck ty1 ty2 = case (ty1, ty2) of
(UVar x, UVar y)
Expand Down