Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleaned up and renamed main solver loop #173

Merged
merged 2 commits into from
Mar 16, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Cleaned up and renamed main solver loop
  • Loading branch information
simonjwinwood committed Mar 15, 2022
commit 050763c0965441e33529bff28fe54b6e2556b0b7
112 changes: 68 additions & 44 deletions src/Reopt/TypeInference/Solver/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE FlexibleContexts #-}

module Reopt.TypeInference.Solver.Solver
( unifyConstraints
( unifyConstraints
)
where

import Control.Lens ((.=), (<<+=), (<<.=))
import Control.Lens ((<<+=), (<<.=), Lens', (%=))
import Control.Monad (when)
import Control.Monad.State (MonadState (get))
import Data.Bifunctor (first)
Expand All @@ -30,15 +31,13 @@ import Reopt.TypeInference.Solver.Monad (Conditional (..),
addTyVarEq',
defineRowVar,
defineTyVar,
dequeueEqC,
dequeueEqRowC,
lookupRowExpr,
lookupTyVar,
traceUnification,
undefineRowVar,
undefineTyVar,
unsafeUnifyRowVars,
unsafeUnifyTyVars, Conditional', condEnabled, addEqC, addEqRowC)
unsafeUnifyTyVars, Conditional', condEnabled, addEqC, addEqRowC, ConstraintSolvingState, popField)
import Reopt.TypeInference.Solver.RowVariables (RowExpr (..),
emptyFieldMap,
rowExprShift,
Expand All @@ -55,7 +54,7 @@ import Reopt.TypeInference.Solver.Types (ITy (..), ITy',
-- FIXME: probably want to export the Eqv map somehow
unifyConstraints :: SolverM ConstraintSolution
unifyConstraints = do
processAtomicConstraints
solverLoop
finalizeTypeDefs

-- | @traceContext description ctx ctx'@ reports how the context changed via @trace@.
Expand All @@ -82,74 +81,99 @@ traceContext description action = do
trace (show msg) (return ())
pure r

-- | Process all atomic (i.e., non-disjunctive) constraints, updating the
-- context with each.
processAtomicConstraints :: SolverM ()
processAtomicConstraints = traceContext "processAtomicConstraints" $ do
dequeueEqC >>= \case
Just c -> solveEqC c >> processAtomicConstraints
Nothing -> dequeueEqRowC >>= \case
Just c -> solveEqRowC c >> processAtomicConstraints
Nothing -> condEqSolver >>= \case
True -> processAtomicConstraints
False -> pure ()
where
-- This solver will solve one at a time, which is important to get
-- around the +p++ case. We solve a single ptr add, then propagate
-- the new eq and roweq constraints.
condEqSolver = do
ceqs <- field @"ctxCondEqs" <<.= mempty -- get constraints and c
go [] ceqs
traceContext' :: PP.Pretty v => PP.Doc () -> v -> SolverM a -> SolverM a
traceContext' msg v = traceContext (msg <> ": " <> PP.pretty v)

restore :: [Conditional'] -> SolverM ()
restore cs = field @"ctxCondEqs" .= cs
--------------------------------------------------------------------------------
-- Solver loop

data Retain = Retain | Discard
deriving Eq

data Progress = Progress | NoProgress
deriving Eq

madeProgress :: Progress -> Bool
madeProgress Progress = True
madeProgress _ = False

solveHead :: Lens' ConstraintSolvingState [a] ->
(a -> SolverM ()) ->
SolverM Progress
solveHead fld doit = do
v <- popField fld
case v of
Nothing -> pure NoProgress
Just v' -> doit v' $> Progress

solveFirst :: Lens' ConstraintSolvingState [a] ->
(a -> SolverM (Retain, Progress)) ->
SolverM Progress
solveFirst fld solve = do
cstrs <- fld <<.= [] -- get constraints and c
go [] cstrs
where
restore cs = fld %= (++ cs)

-- FIXME: we might want to drop conditionals when they are never
-- FIXME: we might want to drop constraints when they are never
-- going to be satisfiable.
go acc [] = restore acc $> False -- finished here, we didn't so anything.
go acc [] = restore acc $> NoProgress -- finished here, we didn't so anything.
go acc (c : cs) = do
solved <- solveConditional c
if solved
-- Conditional fired, remove it and continue solving
then restore (cs ++ acc) $> True
-- Conditional couldn't be fired, try next conditionals
else go (c : acc) cs
(retain, progress) <- solve c
let acc' = if retain == Retain then c : acc else acc
if madeProgress progress
then restore (cs ++ acc') $> Progress
else go acc' cs

solverLoop :: SolverM ()
solverLoop = do
keepGoing <- foldr once (pure False) solvers
when keepGoing solverLoop
where
solvers = [ solveHead (field @"ctxEqCs") solveEqC
, solveHead (field @"ctxEqRowCs") solveEqRowC
, solveFirst (field @"ctxCondEqs") solveConditional
]

once m rest = do
progress <- m
if madeProgress progress then pure True else rest
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use Control.Monad.Extra.ifM.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already depend on extra, so we could even change the types and use ||^ which is really what is going on here


--------------------------------------------------------------------------------
-- Conditionals

solveConditional :: Conditional' -> SolverM Bool
solveConditional c = do
solveConditional :: Conditional' -> SolverM (Retain, Progress)
solveConditional c = traceContext' "solveConditional" c $ do
m_newEqs <- condEnabled c
case m_newEqs of
Just newEqs -> do
mapM_ addEqC eqcs
mapM_ addEqRowC (newEqs ++ eqrowcs)
pure True
Nothing -> pure False
pure (Discard, Progress)
Nothing -> pure (Retain, NoProgress)
where
(eqcs, eqrowcs) = cConstraints c

--------------------------------------------------------------------------------
-- Row unification

solveEqRowC :: EqRowC -> SolverM ()
solveEqRowC eqc = do
solveEqRowC eqc = traceContext' "solveEqRowC" eqc $ do
(le, m_lfm) <- lookupRowExpr (eqRowLHS eqc)
let lo = rowExprShift le
lv = rowExprVar le
lfm = fromMaybe emptyFieldMap m_lfm

(re, m_rfm) <- lookupRowExpr (eqRowRHS eqc)
let ro = rowExprShift re
rv = rowExprVar re
rfm = fromMaybe emptyFieldMap m_rfm

case () of
_ | (lo, lv) == (ro, rv) -> pure () -- trivial up to eqv.
| lv == rv -> trace "Recursive row var equation, ignoring" $ pure ()
| lo < ro -> unify (ro - lo) rv rfm lv lfm
| otherwise -> unify (lo - ro) lv lfm rv rfm
| lo < ro -> unify (ro - lo) rv rfm lv lfm
| otherwise -> unify (lo - ro) lv lfm rv rfm
where
unify delta lowv lowfm highv highfm = do
undefineRowVar highv
Expand All @@ -163,7 +187,7 @@ solveEqRowC eqc = do
-- Type unification

solveEqC :: EqC -> SolverM ()
solveEqC eqc = do
solveEqC eqc = traceContext' "solveEqC" eqc $ do
(lv, m_lty) <- lookupTyVar (eqLhs eqc)
(m_rv, m_rty) <- case eqRhs eqc of
VarTy tv -> first Just <$> lookupTyVar tv
Expand Down