From 25a3cba9b55bf749f80e5fcc7b8fbe74c10daa3d Mon Sep 17 00:00:00 2001 From: Simon Winwood Date: Thu, 3 Aug 2023 19:17:30 -0700 Subject: [PATCH] talos: about to refactor assertions --- talos/src/Talos/Strategy/PathSymbolic.hs | 191 +++++++++--------- .../Talos/Strategy/PathSymbolic/Branching.hs | 40 ++++ .../src/Talos/Strategy/PathSymbolic/Monad.hs | 4 +- .../Talos/Strategy/PathSymbolic/MuxValue.hs | 26 ++- .../Strategy/PathSymbolic/PathBuilder.hs | 7 +- 5 files changed, 160 insertions(+), 108 deletions(-) diff --git a/talos/src/Talos/Strategy/PathSymbolic.hs b/talos/src/Talos/Strategy/PathSymbolic.hs index c7b8457a..5a3b515b 100644 --- a/talos/src/Talos/Strategy/PathSymbolic.hs +++ b/talos/src/Talos/Strategy/PathSymbolic.hs @@ -15,7 +15,7 @@ module Talos.Strategy.PathSymbolic (pathSymbolicStrat) where import Control.Lens (_1, _2, each, over, itraverse, preview, traverseOf) import Control.Monad (forM_, when, - zipWithM, (<=<), unless) + zipWithM, (<=<), unless, join) import Control.Monad.Reader import Control.Monad.Writer.CPS (censor, pass) import Data.Bifunctor (second) @@ -71,10 +71,12 @@ import Talos.Solver.SolverT (declareName, declareSymbol, liftSolver, reset, scoped, contextSize) -import Talos.Lib (tByte) +import Talos.Lib (tByte, andMany) import Control.Monad.Except (catchError) import Data.List ((\\)) +import qualified Talos.Strategy.PathSymbolic.Branching as B + -- ---------------------------------------------------------------------------------------- -- Backtracking random strats @@ -228,7 +230,7 @@ stratSlice = go primBindName n n' $ do pe <- synthesiseExpr p - assertSExpr (MV.asAssertion pe) + assertSExpr (MV.toSExpr pe) -- Once we have a model, we need to convert all the free -- variables in the inverse function expression into values, @@ -285,8 +287,7 @@ stratChoice sls _ = do unzip . catMaybes <$> itraverse (\i -> guardedChoice pv i . stratSlice) sls - when (null vs) unreachable -- all paths failed - v <- liftSemiSolverM (MV.mux vs Nothing) + v <- maybe unreachable (liftSemiSolverM . MV.mux) (B.branchingMaybe vs Nothing) let feasibleIxs = map fst paths @@ -312,7 +313,7 @@ stratLoop lclass = -- No need to constrain the processing of the body. (v, m) <- stratSlice b - let xs = MV.VSequence [] (vsm, [v]) + let xs = MV.VSequence (B.branching [] (vsm, [v])) v' = case sem of SemNo -> MV.VUnit SemYes -> xs @@ -356,7 +357,7 @@ stratLoop lclass = -- something. (elv, m) <- censor (extendPath pathGuard) (stratSlice b) - let v = MV.VSequence [] (vsm, [elv]) + let v = MV.VSequence (B.singleton (vsm, [elv])) node = PathLoopGenerator ltag (Just lv) m pure (v, SelectedLoop node) @@ -395,7 +396,7 @@ stratLoop lclass = , vsmMinLength = clb , vsmIsBuilder = False } - v = MV.VSequence [] (vsm, els) + v = MV.VSequence (B.singleton (vsm, els)) node = PathLoopUnrolled (Just lv) ms pure (v, SelectedLoop node) @@ -451,7 +452,7 @@ stratLoop lclass = (vs, pbs) <- unzip <$> go se [] 0 let node = PathLoopUnrolled (Just lv) pbs - v <- liftSemiSolverM (MV.mux vs (Just se)) + v <- liftSemiSolverM (MV.mux (B.branching vs se)) pure (v, SelectedLoop node) SMorphismLoop (FoldMorphism n e lc b) -> do @@ -459,89 +460,100 @@ stratLoop lclass = se <- synthesiseExpr e m_col <- preview #_VSequence <$> synthesiseExpr (lcCol lc) - - let (vars, base) = case m_col of + + let bvs = case m_col of Nothing -> panic "UNIMPLEMENTED: non-list fold" [] Just r -> r - -- c.f. SRepeatLoop - let guards vsm - | Just lv <- vsmLoopCountVar vsm = - \i -> ( PS.loopCountGeqConstraint lv i - , PS.loopCountEqConstraint lv i - ) - | otherwise = const (mempty, mempty) - - -- this allows short-circuiting if we try to unfold too many - -- times. - go se' acc i - | i == nloops = pure (reverse acc) - | otherwise = do - m_v_pb <- handleUnreachable (gtGuard i) - (primBindName n se' (stratSlice b)) - case m_v_pb of - Nothing -> pure (reverse acc) - Just (v, pb) -> do - -- these should work (no imcompatible assumptions about lv) - go v (((eqGuard i, v), pb) : acc) (i + 1) + -- -- c.f. SRepeatLoop + -- let guards vsm + -- | Just lv <- vsmLoopCountVar vsm = + -- \i -> ( PS.loopCountGeqConstraint lv i + -- , PS.loopCountEqConstraint lv i + -- ) + -- | otherwise = const (mempty, mempty) + + -- -- this allows short-circuiting if we try to unfold too many + -- -- times. + -- go se' acc i + -- | i == nloops = pure (reverse acc) + -- | otherwise = do + -- m_v_pb <- handleUnreachable (gtGuard i) + -- (primBindName n se' (stratSlice b)) + -- case m_v_pb of + -- Nothing -> pure (reverse acc) + -- Just (v, pb) -> do + -- -- these should work (no imcompatible assumptions about lv) + -- go v (((eqGuard i, v), pb) : acc) (i + 1) - -- TODO: prune early as for Repeat above - goOne _vsm (_se', acc) [] = pure (reverse acc) - goOne vsm (se', acc) ((i, el) : rest) = do - (v, pb) <- guardedLoopCollection vsm lc (primBindName n se' (stratSlice b)) i el - let (gtGuard, eqGuard) = guards vsm (i + 1) + -- -- TODO: prune early as for Repeat above + -- goOne _vsm (_se', acc) [] = pure (reverse acc) + -- goOne vsm (se', acc) ((i, el) : rest) = do + -- (v, pb) <- guardedLoopCollection vsm lc (primBindName n se' (stratSlice b)) i el + -- let (gtGuard, eqGuard) = guards vsm (i + 1) - v' <- hoistMaybe (MV.refine eqGuard v) - se'' <- hoistMaybe (MV.refine gtGuard v) - goOne vsm (se'', (v', pb) : acc) rest + -- v' <- hoistMaybe (MV.refine eqGuard v) + -- se'' <- hoistMaybe (MV.refine gtGuard v) + -- goOne vsm (se'', (v', pb) : acc) rest - go (vsm, els) = do - (svs, pbs) <- unzip <$> goOne vsm (se, []) (zip [0..] els) + -- go (vsm, els) = do + -- (svs, pbs) <- unzip <$> goOne vsm (se, []) (zip [0..] els) - let (_, eqGuard) = guards vsm 0 - base <- hoistMaybe (MV.refine eqGuard se) + -- let (_, eqGuard) = guards vsm 0 + -- base <- hoistMaybe (MV.refine eqGuard se) - pure (MV.unions (base :| svs), (g, vsm, pbs)) + -- pure (MV.unions (base :| svs), (g, vsm, pbs)) - -- (vs, nodes) <- unzip <$> collectMaybes (map go (MV.guardedValues col)) + -- -- (vs, nodes) <- unzip <$> collectMaybes (map go (MV.guardedValues col)) - vars' <- traverseOf (each . _2) go vars - base' <- go base + let go :: (VSequenceMeta, [MuxValue]) -> + SymbolicM (B.Branching MuxValue, [S.SExpr], (VSequenceMeta, [PathBuilder])) + go (vsm, els) = do + (svs, pbs) <- unzip <$> goOne vsm se [] (zip [0..] els) + let svs' = drop (vsmMinLength vsm) ((eqGuard vsm 0, se) : svs) + + undefined + + (bvs', assns, nodes) <- B.unzip3 <$> traverse go bvs + let assn = B.fold (\ps ss acc -> [S.ite (PS.toSExpr ps) (andMany ss) (andMany acc)]) assns + -- re-assert guarded assertions + assertSExpr (andMany assn) + + v <- liftSemiSolverM (MV.mux (join bvs')) - v <- hoistMaybe (MV.unions' vs) let node' = PathLoopMorphism ltag nodes - pure (v, SelectedLoop node') SMorphismLoop (MapMorphism lc b) -> do ltag <- freshSymbolicLoopTag + undefined + + -- -- TODO: prune early as for Repeat above + -- col <- synthesiseExpr (lcCol lc) + -- let go (g, sv) + -- | Just (vsm, els) <- MV.gseToList sv = do + -- (els', pbs) <- unzip <$> zipWithM (guardedLoopCollection vsm lc (stratSlice b)) [0..] els + -- -- We just propagate the vsm for the collection value + -- -- to the users of the result of this sequence. We + -- -- also inherit the guard from the collection + -- let v = MV.singleton g (VSequence vsm els') + -- pure (v, (g, vsm, pbs)) + + -- -- pure (v, node) + -- | otherwise = panic "UNIMPLEMENTED: map over non-lists" [] - -- TODO: prune early as for Repeat above - col <- synthesiseExpr (lcCol lc) - let go (g, sv) - | Just (vsm, els) <- MV.gseToList sv = do - (els', pbs) <- unzip <$> zipWithM (guardedLoopCollection vsm lc (stratSlice b)) [0..] els - -- We just propagate the vsm for the collection value - -- to the users of the result of this sequence. We - -- also inherit the guard from the collection - let v = MV.singleton g (VSequence vsm els') - pure (v, (g, vsm, pbs)) - - -- pure (v, node) - | otherwise = panic "UNIMPLEMENTED: map over non-lists" [] + -- (vs, nodes) <- unzip <$> collectMaybes (map go (MV.guardedValues col)) - (vs, nodes) <- unzip <$> collectMaybes (map go (MV.guardedValues col)) + -- v <- hoistMaybe (MV.unions' vs) + -- let node' = PathLoopMorphism ltag nodes - v <- hoistMaybe (MV.unions' vs) - let node' = PathLoopMorphism ltag nodes - - pure (v, SelectedLoop node') + -- pure (v, SelectedLoop node') where execBnd :: (NonEmpty Int -> Int) -> Expr -> SymbolicM (MuxValue, Maybe Int) execBnd g bnd = do sbnd <- synthesiseExpr bnd - let m_cbnd = fmap (g . fmap (fromIntegral . V.valueToSize)) (MV.asValues sbnd) + let m_cbnd = g . fmap fromIntegral <$> MV.asIntegers sbnd pure (sbnd, m_cbnd) manyBoundsCheck slv slb m_sub = do @@ -550,29 +562,24 @@ stratLoop lclass = traverse_ (mkBound (slv `S.bvULeq`)) m_sub -- Constructs a bounds check, and returns a concrete bound (if any) - mkBound f sbnd = do - ienv <- liftStrategy getIEnv - assertSExpr (f (MV.toSExpr (I.tEnv ienv) sizeType sbnd)) + mkBound f sbnd = assertSExpr (f (MV.toSExpr sbnd)) +-- Handles early failure as well. guardedLoopCollection :: VSequenceMeta -> LoopCollection' e -> SymbolicM b -> Int -> MuxValue -> - SymbolicM b + SymbolicM (Maybe b) guardedLoopCollection vsm lc m i el = do - el' <- hoistMaybe (MV.refine g el) - let kv = vUInt g 64 (fromIntegral i) + let kv = MV.vInteger sizeType (fromIntegral i) bindK = maybe id (\kn -> primBindName kn kv) (lcKName lc) - bindE = primBindName (lcElName lc) el' + bindE = primBindName (lcElName lc) el - censor doCensor (bindK (bindE m)) + handleUnreachable g (bindK (bindE m)) where - (g, doCensor) - | Just lcv <- vsmLoopCountVar vsm = - let g' = PS.insertLoopCount lcv (PS.LCCGt i) mempty - in ( g', extendPath (PS.toSExpr g') ) - | otherwise = (mempty, id) + g | Just lcv <- vsmLoopCountVar vsm = PS.loopCountGtConstraint lcv i + | otherwise = PS.trivialPathSet -- We represent disjunction by having multiple values; in this case, -- if we have matching values for the case alternative v1, v2, v3, @@ -608,8 +615,8 @@ guardedLoopCollection vsm lc m i el = do stratCase :: Bool -> Case ExpSlice -> Maybe (Set SliceId) -> SymbolicM Result stratCase _total cs m_sccs = do v <- getName (caseVar cs) - pv <- freshPathVar - let (pred, missing) = MV.semiExecPatterns v pv (map fst (caseAlts cs)) + pv <- freshPathVar (length (casePats cs)) + let (pred, missing) = MV.semiExecPatterns v pv (map fst (casePats cs)) undefined -- (alts, preds) <- liftSemiSolverM (MV.semiExecCase cs) @@ -881,23 +888,23 @@ synthesiseByteSet bs b = go bs -- liftSymExecM $ SE.symExecByteSet b bs go bs' = case bs' of SetAny -> pure (S.bool True) - SetSingle v -> S.eq b <$> symExecExpr v - SetRange l h -> S.and <$> (flip S.bvULeq b <$> symExecExpr l) - <*> (S.bvULeq b <$> symExecExpr h) + SetSingle v -> S.eq b <$> symE v + SetRange l h -> S.and <$> (flip S.bvULeq b <$> symE l) + <*> (S.bvULeq b <$> symE h) SetComplement c -> S.not <$> go c SetUnion l r -> S.or <$> go l <*> go r SetIntersection l r -> S.and <$> go l <*> go r SetLet n e bs'' -> do - (n', r) <- bindNameFreshIn n (go bs'') - mklet n' <$> symExecExpr e <*> pure r + e' <- synthesiseExpr e + primBindName n e' (go bs'') - SetCall f es -> S.fun (fnameToSMTName f) <$> ((++ [b]) <$> mapM symExecExpr es) - SetCase {} -> unimplemented + SetCall {} -> unexpected bs' + SetCase {} -> unexpected bs' - unimplemented = panic "SymExec (ByteSet): Unimplemented inside" [showPP bs] - + unexpected bs' = panic "Unexpected constructor" [showPP bs'] + symE = fmap MV.toSExpr . synthesiseExpr -- traceGUIDChange :: String -> SymbolicM a -> SymbolicM a -- traceGUIDChange msg m = do diff --git a/talos/src/Talos/Strategy/PathSymbolic/Branching.hs b/talos/src/Talos/Strategy/PathSymbolic/Branching.hs index 04f4a57d..3c7e707f 100644 --- a/talos/src/Talos/Strategy/PathSymbolic/Branching.hs +++ b/talos/src/Talos/Strategy/PathSymbolic/Branching.hs @@ -5,10 +5,17 @@ module Talos.Strategy.PathSymbolic.Branching ( Branching(..) + -- * Constructors + , singleton , branching , branchingMaybe + , branchingNE + + -- * Operations , fold , foldM + , unzip + , unzip3 , mapVariants , resolve @@ -16,6 +23,9 @@ module Talos.Strategy.PathSymbolic.Branching ) where +import Prelude hiding (unzip, unzip3) +import qualified Prelude + import GHC.Generics (Generic) import Talos.Strategy.PathSymbolic.PathSet ( PathSet ) @@ -28,6 +38,7 @@ import qualified Data.Map.Merge.Strict as Map import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map import Daedalus.Panic (panic) +import Data.List.NonEmpty (NonEmpty(..)) data Branching a = Branching { variants :: [ (PathSet, a) ] @@ -51,9 +62,16 @@ instance Monad Branching where vs = variants (base bs) ++ concatMap vsOne (variants bs) base' = base (base bs) +singleton :: a -> Branching a +singleton = pure + branching :: [(PathSet, a)] -> a -> Branching a branching = Branching + +branchingNE :: NonEmpty (PathSet, a) -> Branching a +branchingNE ((_, a) :| rest) = branching rest a + branchingMaybe :: [(PathSet, a)] -> Maybe a -> Maybe (Branching a) branchingMaybe vs m_b | (_, v) : vs' <- vs, Nothing <- m_b = Just (Branching vs' v) @@ -64,12 +82,34 @@ branchingMaybe' vs m_b = fromMaybe err (branchingMaybe vs m_b) where err = panic "Expecting non-empty branching" [] +-- Standard operations + fold :: (PathSet -> a -> a -> a) -> Branching a -> a fold f b = foldl' (\a' (ps, a) -> f ps a a') (base b) (variants b) foldM :: Monad m => (PathSet -> a -> a -> m a) -> Branching a -> m a foldM f b = foldlM (\a' (ps, a) -> f ps a a') (base b) (variants b) +-- FIXME: duplicates the pathsets +unzip :: Branching (a, b) -> (Branching a, Branching b) +unzip b = ( Branching { variants = zip pss vs1, base = fst (base b) } + , Branching { variants = zip pss vs2, base = snd (base b) } + ) + where + (pss, vs) = Prelude.unzip (variants b) + (vs1, vs2) = Prelude.unzip vs + +-- FIXME: duplicates pathsets +unzip3 :: Branching (a, b, c) -> (Branching a, Branching b, Branching c) +unzip3 b = ( Branching { variants = zip pss vs1, base = b1 } + , Branching { variants = zip pss vs2, base = b2 } + , Branching { variants = zip pss vs3, base = b3 } + ) + where + (pss, vs) = Prelude.unzip (variants b) + (vs1, vs2, vs3) = Prelude.unzip3 vs + (b1, b2, b3) = base b + mapVariants :: (PathSet -> a -> Maybe (PathSet, a)) -> Branching a -> Branching a mapVariants f bvs = bvs { variants = new } where diff --git a/talos/src/Talos/Strategy/PathSymbolic/Monad.hs b/talos/src/Talos/Strategy/PathSymbolic/Monad.hs index 586931d4..7050cc22 100644 --- a/talos/src/Talos/Strategy/PathSymbolic/Monad.hs +++ b/talos/src/Talos/Strategy/PathSymbolic/Monad.hs @@ -25,7 +25,7 @@ import GHC.Generics (Generic) import qualified SimpleSMT as SMT -- FIXME: use .CPS import Control.Monad.Except (ExceptT, runExceptT, - throwError) + throwError, MonadError) import Control.Monad.Writer (MonadWriter, WriterT, runWriterT, tell) import Data.List.NonEmpty (NonEmpty) @@ -191,7 +191,7 @@ instance Monoid SymbolicModel where newtype SymbolicM a = SymbolicM { getSymbolicM :: ExceptT () (WriterT SymbolicModel (ReaderT SymbolicEnv (SolverT StrategyM))) a } - deriving (Applicative, Functor, Monad, MonadIO + deriving (Applicative, Functor, Monad, MonadIO, MonadError () , MonadReader SymbolicEnv, MonadWriter SymbolicModel, MonadSolver, LiftTalosM) instance LiftStrategyM SymbolicM where diff --git a/talos/src/Talos/Strategy/PathSymbolic/MuxValue.hs b/talos/src/Talos/Strategy/PathSymbolic/MuxValue.hs index 55249832..83d7af75 100644 --- a/talos/src/Talos/Strategy/PathSymbolic/MuxValue.hs +++ b/talos/src/Talos/Strategy/PathSymbolic/MuxValue.hs @@ -29,10 +29,12 @@ module Talos.Strategy.PathSymbolic.MuxValue ( -- * Constructors , vSymbolicBool , vSymbolicInteger + , vInteger -- * Combinators , mux -- * Destructors - , asAssertion + , toSExpr + , asIntegers -- * Monad , SemiSolverM , runSemiSolverM @@ -95,6 +97,7 @@ import Talos.Strategy.PathSymbolic.SymExec (symExecTy) import qualified Talos.Strategy.PathSymbolic.Branching as B import Talos.Strategy.PathSymbolic.Branching (Branching) +import Data.List.NonEmpty (NonEmpty) -------------------------------------------------------------------------------- -- Logging and stats @@ -196,7 +199,6 @@ type MuxValue = MuxValueF SMTVar -- Used internally to avoid naming after every single operation. type MuxValueSExpr = MuxValueF SExpr - -- ---------------------------------------------------------------------------------------- -- ValueLens -- @@ -343,6 +345,11 @@ vNothing = VMaybe (singletonSumTypeMuxValueF Nothing VUnit) vJust :: MuxValueF b -> MuxValueF b vJust = VMaybe . singletonSumTypeMuxValueF (Just ()) +asIntegers :: MuxValueF s -> Maybe (NonEmpty Integer) +asIntegers (VIntegers (Typed _ bvs)) + | Nothing <- bvSymbolic bvs = NE.nonEmpty (Map.keys (bvConcrete bvs)) +asIntegers _ = Nothing + -- ---------------------------------------------------------------------------------------- -- Multiplexing values @@ -426,10 +433,11 @@ baseValueToSExpr ty vl bvs where sexps = map (first (vlToSExpr vl ty)) (Map.toList (bvConcrete bvs)) -asAssertion :: MuxValue -> SExpr -asAssertion (VBools bvs) = baseValueToSExpr TBool boolVL (S.const <$> bvs) -asAssertion _ = panic "Expecting a boolean value" [] - +toSExpr :: MuxValue -> SExpr +toSExpr (VBools bvs) = baseValueToSExpr TBool boolVL (S.const <$> bvs) +toSExpr (VIntegers (Typed ty bvs)) = baseValueToSExpr ty integerVL (S.const <$> bvs) +toSExpr _ = panic "Non-base value" [] + semiExecName :: (SemiCtxt m, HasCallStack) => Name -> SemiSolverM m MuxValueSExpr semiExecName n = asks (fromMaybe missing . Map.lookup n . localBoundNames) @@ -834,8 +842,8 @@ unreachable = throwError () -- Just fn -> pure fn -- Nothing -> panic "Missing pure function" [showPP f] --- -------------------------------------------------------------------------------- --- -- Exprs +-------------------------------------------------------------------------------- +-- Exprs semiExecExpr :: (SemiCtxt m, HasCallStack) => Expr -> SemiSolverM m MuxValue semiExecExpr e = nameSExprs =<< semiExecExpr' e @@ -883,7 +891,7 @@ semiExecOp0 op = FloatL {} -> unimplemented BoolL b -> vBool b ByteArrayL bs -> vFixedLenSequence (map (vInteger tByte . fromIntegral) (BS.unpack bs)) - NewBuilder _ty -> VSequence (pure (emptyVSequenceMeta {vsmIsBuilder = True}, [])) + NewBuilder _ty -> VSequence (B.singleton (emptyVSequenceMeta {vsmIsBuilder = True}, [])) MapEmpty _kty _vty -> VMap [] ENothing _ty -> vNothing where diff --git a/talos/src/Talos/Strategy/PathSymbolic/PathBuilder.hs b/talos/src/Talos/Strategy/PathSymbolic/PathBuilder.hs index 7f41ffb2..96c92706 100644 --- a/talos/src/Talos/Strategy/PathSymbolic/PathBuilder.hs +++ b/talos/src/Talos/Strategy/PathSymbolic/PathBuilder.hs @@ -86,6 +86,7 @@ import Talos.Path import qualified Talos.Solver.SolverT as Solv import Talos.Solver.SolverT (SMTVar, SolverT) import Talos.Lib (findM, andMany) +import qualified Talos.Strategy.PathSymbolic.Branching as B -- ---------------------------------------------------------------------------------------- -- Model parsing and enumeration. @@ -530,11 +531,7 @@ buildLoop (PathLoopGenerator ltag m_lv el) = -- We will fill this in when we have all the models pure SelectedHole -buildLoop (PathLoopMorphism _ltag variants base) = do - m_el <- findM (PS.fromModel . fst) variants - case m_el of - Nothing -> go base - Just (_, r) -> go r +buildLoop (PathLoopMorphism _ltag bvs) = go =<< B.resolve bvs where go (vsm, els) -- This is the unrolled case, so we just emit the path. We re-use the Unrolled case above.