Skip to content

Commit

Permalink
talos: about to refactor assertions
Browse files Browse the repository at this point in the history
  • Loading branch information
simonjwinwood committed Aug 4, 2023
1 parent b8df82e commit 25a3cba
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 108 deletions.
191 changes: 99 additions & 92 deletions talos/src/Talos/Strategy/PathSymbolic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ module Talos.Strategy.PathSymbolic (pathSymbolicStrat) where

import Control.Lens (_1, _2, each, over, itraverse, preview, traverseOf)
import Control.Monad (forM_, when,
zipWithM, (<=<), unless)
zipWithM, (<=<), unless, join)
import Control.Monad.Reader
import Control.Monad.Writer.CPS (censor, pass)
import Data.Bifunctor (second)
Expand Down Expand Up @@ -71,10 +71,12 @@ import Talos.Solver.SolverT (declareName,
declareSymbol,
liftSolver, reset,
scoped, contextSize)
import Talos.Lib (tByte)
import Talos.Lib (tByte, andMany)
import Control.Monad.Except (catchError)
import Data.List ((\\))

import qualified Talos.Strategy.PathSymbolic.Branching as B

-- ----------------------------------------------------------------------------------------
-- Backtracking random strats

Expand Down Expand Up @@ -228,7 +230,7 @@ stratSlice = go

primBindName n n' $ do
pe <- synthesiseExpr p
assertSExpr (MV.asAssertion pe)
assertSExpr (MV.toSExpr pe)

-- Once we have a model, we need to convert all the free
-- variables in the inverse function expression into values,
Expand Down Expand Up @@ -285,8 +287,7 @@ stratChoice sls _ = do
unzip . catMaybes <$>
itraverse (\i -> guardedChoice pv i . stratSlice) sls

when (null vs) unreachable -- all paths failed
v <- liftSemiSolverM (MV.mux vs Nothing)
v <- maybe unreachable (liftSemiSolverM . MV.mux) (B.branchingMaybe vs Nothing)

let feasibleIxs = map fst paths

Expand All @@ -312,7 +313,7 @@ stratLoop lclass =
-- No need to constrain the processing of the body.
(v, m) <- stratSlice b

let xs = MV.VSequence [] (vsm, [v])
let xs = MV.VSequence (B.branching [] (vsm, [v]))
v' = case sem of
SemNo -> MV.VUnit
SemYes -> xs
Expand Down Expand Up @@ -356,7 +357,7 @@ stratLoop lclass =
-- something.
(elv, m) <- censor (extendPath pathGuard) (stratSlice b)

let v = MV.VSequence [] (vsm, [elv])
let v = MV.VSequence (B.singleton (vsm, [elv]))
node = PathLoopGenerator ltag (Just lv) m

pure (v, SelectedLoop node)
Expand Down Expand Up @@ -395,7 +396,7 @@ stratLoop lclass =
, vsmMinLength = clb
, vsmIsBuilder = False
}
v = MV.VSequence [] (vsm, els)
v = MV.VSequence (B.singleton (vsm, els))
node = PathLoopUnrolled (Just lv) ms
pure (v, SelectedLoop node)

Expand Down Expand Up @@ -451,97 +452,108 @@ stratLoop lclass =

(vs, pbs) <- unzip <$> go se [] 0
let node = PathLoopUnrolled (Just lv) pbs
v <- liftSemiSolverM (MV.mux vs (Just se))
v <- liftSemiSolverM (MV.mux (B.branching vs se))
pure (v, SelectedLoop node)

SMorphismLoop (FoldMorphism n e lc b) -> do
ltag <- freshSymbolicLoopTag

se <- synthesiseExpr e
m_col <- preview #_VSequence <$> synthesiseExpr (lcCol lc)

let (vars, base) = case m_col of
let bvs = case m_col of
Nothing -> panic "UNIMPLEMENTED: non-list fold" []
Just r -> r

-- c.f. SRepeatLoop
let guards vsm
| Just lv <- vsmLoopCountVar vsm =
\i -> ( PS.loopCountGeqConstraint lv i
, PS.loopCountEqConstraint lv i
)
| otherwise = const (mempty, mempty)

-- this allows short-circuiting if we try to unfold too many
-- times.
go se' acc i
| i == nloops = pure (reverse acc)
| otherwise = do
m_v_pb <- handleUnreachable (gtGuard i)
(primBindName n se' (stratSlice b))
case m_v_pb of
Nothing -> pure (reverse acc)
Just (v, pb) -> do
-- these should work (no imcompatible assumptions about lv)
go v (((eqGuard i, v), pb) : acc) (i + 1)
-- -- c.f. SRepeatLoop
-- let guards vsm
-- | Just lv <- vsmLoopCountVar vsm =
-- \i -> ( PS.loopCountGeqConstraint lv i
-- , PS.loopCountEqConstraint lv i
-- )
-- | otherwise = const (mempty, mempty)

-- -- this allows short-circuiting if we try to unfold too many
-- -- times.
-- go se' acc i
-- | i == nloops = pure (reverse acc)
-- | otherwise = do
-- m_v_pb <- handleUnreachable (gtGuard i)
-- (primBindName n se' (stratSlice b))
-- case m_v_pb of
-- Nothing -> pure (reverse acc)
-- Just (v, pb) -> do
-- -- these should work (no imcompatible assumptions about lv)
-- go v (((eqGuard i, v), pb) : acc) (i + 1)


-- TODO: prune early as for Repeat above
goOne _vsm (_se', acc) [] = pure (reverse acc)
goOne vsm (se', acc) ((i, el) : rest) = do
(v, pb) <- guardedLoopCollection vsm lc (primBindName n se' (stratSlice b)) i el
let (gtGuard, eqGuard) = guards vsm (i + 1)
-- -- TODO: prune early as for Repeat above
-- goOne _vsm (_se', acc) [] = pure (reverse acc)
-- goOne vsm (se', acc) ((i, el) : rest) = do
-- (v, pb) <- guardedLoopCollection vsm lc (primBindName n se' (stratSlice b)) i el
-- let (gtGuard, eqGuard) = guards vsm (i + 1)

v' <- hoistMaybe (MV.refine eqGuard v)
se'' <- hoistMaybe (MV.refine gtGuard v)
goOne vsm (se'', (v', pb) : acc) rest
-- v' <- hoistMaybe (MV.refine eqGuard v)
-- se'' <- hoistMaybe (MV.refine gtGuard v)
-- goOne vsm (se'', (v', pb) : acc) rest

go (vsm, els) = do
(svs, pbs) <- unzip <$> goOne vsm (se, []) (zip [0..] els)
-- go (vsm, els) = do
-- (svs, pbs) <- unzip <$> goOne vsm (se, []) (zip [0..] els)

let (_, eqGuard) = guards vsm 0
base <- hoistMaybe (MV.refine eqGuard se)
-- let (_, eqGuard) = guards vsm 0
-- base <- hoistMaybe (MV.refine eqGuard se)

pure (MV.unions (base :| svs), (g, vsm, pbs))
-- pure (MV.unions (base :| svs), (g, vsm, pbs))

-- (vs, nodes) <- unzip <$> collectMaybes (map go (MV.guardedValues col))
-- -- (vs, nodes) <- unzip <$> collectMaybes (map go (MV.guardedValues col))

vars' <- traverseOf (each . _2) go vars
base' <- go base
let go :: (VSequenceMeta, [MuxValue]) ->
SymbolicM (B.Branching MuxValue, [S.SExpr], (VSequenceMeta, [PathBuilder]))
go (vsm, els) = do
(svs, pbs) <- unzip <$> goOne vsm se [] (zip [0..] els)
let svs' = drop (vsmMinLength vsm) ((eqGuard vsm 0, se) : svs)

undefined

(bvs', assns, nodes) <- B.unzip3 <$> traverse go bvs
let assn = B.fold (\ps ss acc -> [S.ite (PS.toSExpr ps) (andMany ss) (andMany acc)]) assns
-- re-assert guarded assertions
assertSExpr (andMany assn)

v <- liftSemiSolverM (MV.mux (join bvs'))

v <- hoistMaybe (MV.unions' vs)
let node' = PathLoopMorphism ltag nodes

pure (v, SelectedLoop node')

SMorphismLoop (MapMorphism lc b) -> do
ltag <- freshSymbolicLoopTag
undefined

-- -- TODO: prune early as for Repeat above
-- col <- synthesiseExpr (lcCol lc)
-- let go (g, sv)
-- | Just (vsm, els) <- MV.gseToList sv = do
-- (els', pbs) <- unzip <$> zipWithM (guardedLoopCollection vsm lc (stratSlice b)) [0..] els
-- -- We just propagate the vsm for the collection value
-- -- to the users of the result of this sequence. We
-- -- also inherit the guard from the collection
-- let v = MV.singleton g (VSequence vsm els')
-- pure (v, (g, vsm, pbs))

-- -- pure (v, node)
-- | otherwise = panic "UNIMPLEMENTED: map over non-lists" []

-- TODO: prune early as for Repeat above
col <- synthesiseExpr (lcCol lc)
let go (g, sv)
| Just (vsm, els) <- MV.gseToList sv = do
(els', pbs) <- unzip <$> zipWithM (guardedLoopCollection vsm lc (stratSlice b)) [0..] els
-- We just propagate the vsm for the collection value
-- to the users of the result of this sequence. We
-- also inherit the guard from the collection
let v = MV.singleton g (VSequence vsm els')
pure (v, (g, vsm, pbs))

-- pure (v, node)
| otherwise = panic "UNIMPLEMENTED: map over non-lists" []
-- (vs, nodes) <- unzip <$> collectMaybes (map go (MV.guardedValues col))

(vs, nodes) <- unzip <$> collectMaybes (map go (MV.guardedValues col))
-- v <- hoistMaybe (MV.unions' vs)
-- let node' = PathLoopMorphism ltag nodes

v <- hoistMaybe (MV.unions' vs)
let node' = PathLoopMorphism ltag nodes

pure (v, SelectedLoop node')
-- pure (v, SelectedLoop node')
where
execBnd :: (NonEmpty Int -> Int) -> Expr -> SymbolicM (MuxValue, Maybe Int)
execBnd g bnd = do
sbnd <- synthesiseExpr bnd
let m_cbnd = fmap (g . fmap (fromIntegral . V.valueToSize)) (MV.asValues sbnd)
let m_cbnd = g . fmap fromIntegral <$> MV.asIntegers sbnd
pure (sbnd, m_cbnd)

manyBoundsCheck slv slb m_sub = do
Expand All @@ -550,29 +562,24 @@ stratLoop lclass =
traverse_ (mkBound (slv `S.bvULeq`)) m_sub

-- Constructs a bounds check, and returns a concrete bound (if any)
mkBound f sbnd = do
ienv <- liftStrategy getIEnv
assertSExpr (f (MV.toSExpr (I.tEnv ienv) sizeType sbnd))
mkBound f sbnd = assertSExpr (f (MV.toSExpr sbnd))

-- Handles early failure as well.
guardedLoopCollection :: VSequenceMeta ->
LoopCollection' e ->
SymbolicM b ->
Int ->
MuxValue ->
SymbolicM b
SymbolicM (Maybe b)
guardedLoopCollection vsm lc m i el = do
el' <- hoistMaybe (MV.refine g el)
let kv = vUInt g 64 (fromIntegral i)
let kv = MV.vInteger sizeType (fromIntegral i)
bindK = maybe id (\kn -> primBindName kn kv) (lcKName lc)
bindE = primBindName (lcElName lc) el'
bindE = primBindName (lcElName lc) el

censor doCensor (bindK (bindE m))
handleUnreachable g (bindK (bindE m))
where
(g, doCensor)
| Just lcv <- vsmLoopCountVar vsm =
let g' = PS.insertLoopCount lcv (PS.LCCGt i) mempty
in ( g', extendPath (PS.toSExpr g') )
| otherwise = (mempty, id)
g | Just lcv <- vsmLoopCountVar vsm = PS.loopCountGtConstraint lcv i
| otherwise = PS.trivialPathSet

-- We represent disjunction by having multiple values; in this case,
-- if we have matching values for the case alternative v1, v2, v3,
Expand Down Expand Up @@ -608,8 +615,8 @@ guardedLoopCollection vsm lc m i el = do
stratCase :: Bool -> Case ExpSlice -> Maybe (Set SliceId) -> SymbolicM Result
stratCase _total cs m_sccs = do
v <- getName (caseVar cs)
pv <- freshPathVar
let (pred, missing) = MV.semiExecPatterns v pv (map fst (caseAlts cs))
pv <- freshPathVar (length (casePats cs))
let (pred, missing) = MV.semiExecPatterns v pv (map fst (casePats cs))
undefined

-- (alts, preds) <- liftSemiSolverM (MV.semiExecCase cs)
Expand Down Expand Up @@ -881,23 +888,23 @@ synthesiseByteSet bs b = go bs -- liftSymExecM $ SE.symExecByteSet b bs
go bs' =
case bs' of
SetAny -> pure (S.bool True)
SetSingle v -> S.eq b <$> symExecExpr v
SetRange l h -> S.and <$> (flip S.bvULeq b <$> symExecExpr l)
<*> (S.bvULeq b <$> symExecExpr h)
SetSingle v -> S.eq b <$> symE v
SetRange l h -> S.and <$> (flip S.bvULeq b <$> symE l)
<*> (S.bvULeq b <$> symE h)

SetComplement c -> S.not <$> go c
SetUnion l r -> S.or <$> go l <*> go r
SetIntersection l r -> S.and <$> go l <*> go r

SetLet n e bs'' -> do
(n', r) <- bindNameFreshIn n (go bs'')
mklet n' <$> symExecExpr e <*> pure r
e' <- synthesiseExpr e
primBindName n e' (go bs'')

SetCall f es -> S.fun (fnameToSMTName f) <$> ((++ [b]) <$> mapM symExecExpr es)
SetCase {} -> unimplemented
SetCall {} -> unexpected bs'
SetCase {} -> unexpected bs'

unimplemented = panic "SymExec (ByteSet): Unimplemented inside" [showPP bs]

unexpected bs' = panic "Unexpected constructor" [showPP bs']
symE = fmap MV.toSExpr . synthesiseExpr

-- traceGUIDChange :: String -> SymbolicM a -> SymbolicM a
-- traceGUIDChange msg m = do
Expand Down
Loading

0 comments on commit 25a3cba

Please sign in to comment.