From d7229df5d1c22f4597c29a12ba9356305e2b06fd Mon Sep 17 00:00:00 2001 From: Georgy Lukyanov Date: Mon, 10 Jul 2023 15:21:31 +0200 Subject: [PATCH] Don't trace rewrites and equation applications if not requested (#225) Fixes #224 and #126 * Refactor `EquatioM` to `EquationT` - make it a monad transformer - Move immutable data to a `ReaderT` - Add `MonadLoggerIO` constraint on interface functions * Refactor `RewriteM` to `RewriteT` - make it a monad transformer - Make the reader config a proper product type, rather then a tuple - Add `MonadLoggerIO` constraint on interface functions * Thread tracing flag down to rewrites and equation applications. Only accumulate traces if the flag is set. Currently, the `doTracing` flag is less granular than the RPC request tracing options: we will collect all traces if any of the four RPC options is set. We should probably address this in a follow-up PR. --- library/Booster/JsonRpc.hs | 26 ++- library/Booster/Pattern/ApplyEquations.hs | 186 ++++++++++-------- library/Booster/Pattern/Rewrite.hs | 70 ++++--- .../Test/Booster/Pattern/ApplyEquations.hs | 8 +- unit-tests/Test/Booster/Pattern/Rewrite.hs | 35 ++-- 5 files changed, 194 insertions(+), 131 deletions(-) diff --git a/library/Booster/JsonRpc.hs b/library/Booster/JsonRpc.hs index fb91881db..87079385a 100644 --- a/library/Booster/JsonRpc.hs +++ b/library/Booster/JsonRpc.hs @@ -29,7 +29,6 @@ import Data.Text qualified as Text import Data.Text.Encoding qualified as Text import GHC.Records import Numeric.Natural -import Prettyprinter import Booster.Definition.Attributes.Base (getUniqueId, uniqueId) import Booster.Definition.Base (KoreDefinition (..)) @@ -44,7 +43,6 @@ import Booster.Pattern.Rewrite ( performRewrite, ) import Booster.Pattern.Util (sortOfTerm) -import Booster.Prettyprinter (renderDefault) import Booster.Syntax.Json (KoreJson (..), addHeader, sortOfJson) import Booster.Syntax.Json.Externalise import Booster.Syntax.Json.Internalise (internalisePattern, internaliseTermOrPredicate) @@ -83,7 +81,15 @@ respond stateVar = let cutPoints = fromMaybe [] req.cutPointRules terminals = fromMaybe [] req.terminalRules mbDepth = fmap getNat req.maxDepth - execResponse req <$> performRewrite def mLlvmLibrary mbDepth cutPoints terminals pat + doTracing = + any + (fromMaybe False) + [ req.logSuccessfulRewrites + , req.logFailedRewrites + , req.logSuccessfulSimplifications + , req.logFailedSimplifications + ] + execResponse req <$> performRewrite doTracing def mLlvmLibrary mbDepth cutPoints terminals pat AddModule req -> do -- block other request executions while modifying the server state state <- liftIO $ takeMVar stateVar @@ -124,8 +130,12 @@ respond stateVar = Just . mapMaybe (mkLogEquationTrace (req.logSuccessfulSimplifications, req.logFailedSimplifications)) . toList - logTraces = - mapM_ (Log.logOther (Log.LevelOther "Simplify") . pack . renderDefault . pretty) + doTracing = + any + (fromMaybe False) + [ req.logSuccessfulSimplifications + , req.logFailedSimplifications + ] case internalised of Left patternErrors -> do Log.logError $ "Error internalising cterm: " <> Text.pack (show patternErrors) @@ -133,9 +143,8 @@ respond stateVar = -- term and predicate (pattern) Right (TermAndPredicate Pattern{term, constraints}) -> do Log.logInfoNS "booster" "Simplifying term of a pattern" - case ApplyEquations.evaluateTerm ApplyEquations.TopDown def mLlvmLibrary term of + ApplyEquations.evaluateTerm doTracing ApplyEquations.TopDown def mLlvmLibrary term >>= \case Right (newTerm, traces) -> do - logTraces $ filter (not . ApplyEquations.isMatchFailure) traces let (t, p) = externalisePattern Pattern{constraints, term = newTerm} tSort = externaliseSort (sortOfTerm newTerm) result = maybe t (KoreJson.KJAnd tSort t) p @@ -151,9 +160,8 @@ respond stateVar = -- predicate only Right (APredicate predicate) -> do Log.logInfoNS "booster" "Simplifying a predicate" - case ApplyEquations.traceSimplifyConstraint def mLlvmLibrary predicate of + ApplyEquations.simplifyConstraint doTracing def mLlvmLibrary predicate >>= \case Right (newPred, traces) -> do - logTraces $ filter (not . ApplyEquations.isMatchFailure) traces let predicateSort = fromMaybe (error "not a predicate") $ sortOfJson req.state.term diff --git a/library/Booster/Pattern/ApplyEquations.hs b/library/Booster/Pattern/ApplyEquations.hs index 54dc5cc9f..8f88d54e7 100644 --- a/library/Booster/Pattern/ApplyEquations.hs +++ b/library/Booster/Pattern/ApplyEquations.hs @@ -14,14 +14,21 @@ module Booster.Pattern.ApplyEquations ( isMatchFailure, isSuccess, simplifyConstraint, - traceSimplifyConstraint, ) where import Control.Monad import Control.Monad.Extra +import Control.Monad.IO.Class (MonadIO (..)) +import Control.Monad.Logger.CallStack ( + LogLevel (..), + MonadLogger, + MonadLoggerIO, + logOther, + ) import Control.Monad.Trans.Class import Control.Monad.Trans.Except import Control.Monad.Trans.Maybe +import Control.Monad.Trans.Reader (ReaderT (..), ask) import Control.Monad.Trans.State import Data.Foldable (toList) import Data.Functor.Foldable @@ -31,7 +38,7 @@ import Data.Map qualified as Map import Data.Maybe (catMaybes, fromJust, fromMaybe, isJust) import Data.Sequence (Seq (..)) import Data.Set qualified as Set -import Data.Text (Text) +import Data.Text (Text, pack) import Data.Text qualified as Text import Prettyprinter @@ -44,12 +51,14 @@ import Booster.Pattern.Index import Booster.Pattern.Match import Booster.Pattern.Simplify import Booster.Pattern.Util +import Booster.Prettyprinter (renderDefault) -newtype EquationM a = EquationM (StateT EquationState (Except EquationFailure) a) - deriving newtype (Functor, Applicative, Monad) +newtype EquationT io a + = EquationT (StateT EquationState (ReaderT EquationConfig (ExceptT EquationFailure io)) a) + deriving newtype (Functor, Applicative, Monad, MonadIO, MonadLogger, MonadLoggerIO) -throw :: EquationFailure -> EquationM a -throw = EquationM . lift . throwE +throw :: MonadLoggerIO io => EquationFailure -> EquationT io a +throw = EquationT . lift . lift . throwE data EquationFailure = IndexIsNone Term @@ -58,10 +67,14 @@ data EquationFailure | InternalError Text deriving stock (Eq, Show) -data EquationState = EquationState +data EquationConfig = EquationConfig { definition :: KoreDefinition , llvmApi :: Maybe LLVM.API - , termStack :: [Term] + , doTracing :: Bool + } + +data EquationState = EquationState + { termStack :: [Term] , changed :: Bool , trace :: Seq EquationTrace } @@ -130,27 +143,30 @@ isMatchFailure _ = False isSuccess EquationTrace{result = Success{}} = True isSuccess _ = False -startState :: KoreDefinition -> Maybe LLVM.API -> EquationState -startState definition llvmApi = - EquationState{definition, llvmApi, termStack = [], changed = False, trace = mempty} +startState :: EquationState +startState = + EquationState{termStack = [], changed = False, trace = mempty} + +getState :: MonadLoggerIO io => EquationT io EquationState +getState = EquationT get -getState :: EquationM EquationState -getState = EquationM get +getConfig :: MonadLoggerIO io => EquationT io EquationConfig +getConfig = EquationT (lift ask) -countSteps :: EquationM Int +countSteps :: MonadLoggerIO io => EquationT io Int countSteps = length . (.termStack) <$> getState -pushTerm :: Term -> EquationM () -pushTerm t = EquationM . modify $ \s -> s{termStack = t : s.termStack} +pushTerm :: MonadLoggerIO io => Term -> EquationT io () +pushTerm t = EquationT . modify $ \s -> s{termStack = t : s.termStack} -setChanged, resetChanged :: EquationM () -setChanged = EquationM . modify $ \s -> s{changed = True} -resetChanged = EquationM . modify $ \s -> s{changed = False} +setChanged, resetChanged :: MonadLoggerIO io => EquationT io () +setChanged = EquationT . modify $ \s -> s{changed = True} +resetChanged = EquationT . modify $ \s -> s{changed = False} -getChanged :: EquationM Bool -getChanged = EquationM $ gets (.changed) +getChanged :: MonadLoggerIO io => EquationT io Bool +getChanged = EquationT $ gets (.changed) -checkForLoop :: Term -> EquationM () +checkForLoop :: MonadLoggerIO io => Term -> EquationT io () checkForLoop t = do EquationState{termStack, trace} <- getState whenJust (elemIndex t termStack) $ \i -> do @@ -162,24 +178,32 @@ data Direction = TopDown | BottomUp data EquationPreference = PreferFunctions | PreferSimplifications deriving stock (Eq, Show) -runEquationM :: +runEquationT :: + MonadLoggerIO io => + Bool -> KoreDefinition -> Maybe LLVM.API -> - EquationM a -> - Either EquationFailure (a, [EquationTrace]) -runEquationM definition llvmApi (EquationM m) = - fmap (fmap $ toList . trace) <$> runExcept $ runStateT m $ startState definition llvmApi + EquationT io a -> + io (Either EquationFailure (a, [EquationTrace])) +runEquationT doTracing definition llvmApi (EquationT m) = do + endState <- + runExceptT + . flip runReaderT EquationConfig{definition, llvmApi, doTracing} + . runStateT m + $ startState + pure (fmap (toList . trace) <$> endState) iterateEquations :: + MonadLoggerIO io => Int -> Direction -> EquationPreference -> Term -> - EquationM Term + EquationT io Term iterateEquations maxIterations direction preference startTerm = go startTerm where - go :: Term -> EquationM Term + go :: MonadLoggerIO io => Term -> EquationT io Term go currentTerm | (getAttributes currentTerm).isEvaluated = pure currentTerm | otherwise = do @@ -198,13 +222,15 @@ iterateEquations maxIterations direction preference startTerm = ---------------------------------------- -- Interface function evaluateTerm :: + MonadLoggerIO io => + Bool -> Direction -> KoreDefinition -> Maybe LLVM.API -> Term -> - Either EquationFailure (Term, [EquationTrace]) -evaluateTerm direction def llvmApi = - runEquationM def llvmApi + io (Either EquationFailure (Term, [EquationTrace])) +evaluateTerm doTracing direction def llvmApi = + runEquationT doTracing def llvmApi . iterateEquations 100 direction PreferFunctions ---------------------------------------- @@ -216,10 +242,11 @@ evaluateTerm direction def llvmApi = one equation will be applied per level (if any). -} applyTerm :: + MonadLoggerIO io => Direction -> EquationPreference -> Term -> - EquationM Term + EquationT io Term applyTerm BottomUp pref = cataA $ \case DomainValueF s val -> @@ -239,11 +266,11 @@ applyTerm TopDown pref = \t@(Term attributes _) -> if attributes.isEvaluated then pure t else do - s <- getState + config <- getConfig -- All fully concrete values go to the LLVM backend (top-down only) - if isConcrete t && isJust s.llvmApi + if isConcrete t && isJust config.llvmApi then do - let result = simplifyTerm (fromJust s.llvmApi) s.definition t (sortOfTerm t) + let result = simplifyTerm (fromJust config.llvmApi) config.definition t (sortOfTerm t) when (result /= t) setChanged pure result else apply t @@ -282,11 +309,12 @@ applyTerm TopDown pref = \t@(Term attributes _) -> top-level term, in priority order and per group. -} applyAtTop :: + MonadLoggerIO io => EquationPreference -> Term -> - EquationM Term + EquationT io Term applyAtTop pref term = do - def <- (.definition) <$> getState + def <- (.definition) <$> getConfig case pref of PreferFunctions -> do -- when applying equations, we want to catch DoesNotPreserveDefinedness/incosistentmatch/etc @@ -311,17 +339,17 @@ data ApplyEquationResult | MatchConstraintViolated Constrained VarName deriving stock (Eq, Show) -type ResultHandler = +type ResultHandler io = -- | action on successful equation application - (Term -> EquationM Term) -> + (Term -> EquationT io Term) -> -- | action on failed match - EquationM Term -> + EquationT io Term -> -- | action on aborted equation application - EquationM Term -> + EquationT io Term -> ApplyEquationResult -> - EquationM Term + EquationT io Term -handleFunctionEquation :: ResultHandler +handleFunctionEquation :: ResultHandler io handleFunctionEquation success continue abort = \case Success rewritten -> success rewritten FailedMatch _ -> continue @@ -331,7 +359,7 @@ handleFunctionEquation success continue abort = \case RuleNotPreservingDefinedness -> abort MatchConstraintViolated{} -> continue -handleSimplificationEquation :: ResultHandler +handleSimplificationEquation :: ResultHandler io handleSimplificationEquation success continue _abort = \case Success rewritten -> success rewritten FailedMatch _ -> continue @@ -342,11 +370,12 @@ handleSimplificationEquation success continue _abort = \case MatchConstraintViolated{} -> continue applyEquations :: - forall tag. + forall io tag. + MonadLoggerIO io => Theory (RewriteRule tag) -> - ResultHandler -> + ResultHandler io -> Term -> - EquationM Term + EquationT io Term applyEquations theory handler term = do let index = termTopIndex term when (index == None) $ @@ -374,7 +403,7 @@ applyEquations theory handler term = do -- process one equation at a time, until something has happened processEquations :: [RewriteRule tag] -> - EquationM Term + EquationT io Term processEquations [] = pure term -- nothing to do, term stays the same processEquations (eq : rest) = do @@ -383,21 +412,27 @@ applyEquations theory handler term = do handler (\t -> setChanged >> pure t) (processEquations rest) (pure term) res traceRuleApplication :: + MonadLoggerIO io => Term -> Maybe Location -> Maybe Label -> Maybe UniqueId -> ApplyEquationResult -> - EquationM () -traceRuleApplication t loc lbl uid res = - EquationM . modify $ - \s -> s{trace = s.trace :|> EquationTrace t loc lbl uid res} + EquationT io () +traceRuleApplication t loc lbl uid res = do + let newTraceItem = EquationTrace t loc lbl uid res + logOther (LevelOther "Simplify") (pack . renderDefault . pretty $ newTraceItem) + config <- getConfig + when (config.doTracing) $ + EquationT . modify $ + \s -> s{trace = s.trace :|> newTraceItem} applyEquation :: - forall tag. + forall io tag. + MonadLoggerIO io => Term -> RewriteRule tag -> - EquationM ApplyEquationResult + EquationT io ApplyEquationResult applyEquation term rule = fmap (either id Success) $ runExceptT $ do -- ensured by internalisation: no existentials in equations unless (null rule.existentials) $ @@ -410,7 +445,7 @@ applyEquation term rule = fmap (either id Success) $ runExceptT $ do when (allMustBeConcrete rule.attributes.concreteness && not (Set.null (freeVariables term))) $ throwE (MatchConstraintViolated Concrete "* (term has variables)") -- match lhs - koreDef <- (.definition) <$> lift getState + koreDef <- (.definition) <$> lift getConfig case matchTerm koreDef rule.lhs term of MatchFailed failReason -> throwE $ FailedMatch failReason MatchIndeterminate _pat _subj -> throwE IndeterminateMatch @@ -444,9 +479,9 @@ applyEquation term rule = fmap (either id Success) $ runExceptT $ do -- is Bottom. checkConstraint :: Predicate -> - MaybeT (ExceptT ApplyEquationResult EquationM) (Maybe Predicate) + MaybeT (ExceptT ApplyEquationResult (EquationT io)) (Maybe Predicate) checkConstraint p = do - mApi <- (.llvmApi) <$> lift (lift getState) + mApi <- (.llvmApi) <$> lift (lift getConfig) case simplifyPredicate mApi p of Bottom -> fail "side condition was false" Top -> pure Nothing @@ -458,7 +493,7 @@ applyEquation term rule = fmap (either id Success) $ runExceptT $ do checkConcreteness :: Concreteness -> Map Variable Term -> - ExceptT ApplyEquationResult EquationM () + ExceptT ApplyEquationResult (EquationT io) () checkConcreteness Unconstrained _ = pure () checkConcreteness (AllConstrained constrained) subst = mapM_ (\(var, t) -> mkCheck (toPair var) constrained t) $ Map.assocs subst @@ -474,7 +509,7 @@ applyEquation term rule = fmap (either id Success) $ runExceptT $ do (VarName, SortName) -> Constrained -> Term -> - ExceptT ApplyEquationResult EquationM () + ExceptT ApplyEquationResult (EquationT io) () mkCheck (varName, _) constrained (Term attributes _) | not test = throwE $ MatchConstraintViolated constrained varName | otherwise = pure () @@ -486,8 +521,8 @@ applyEquation term rule = fmap (either id Success) $ runExceptT $ do verifyVar :: Map Variable Term -> (VarName, SortName) -> - (Term -> ExceptT ApplyEquationResult EquationM ()) -> - ExceptT ApplyEquationResult EquationM () + (Term -> ExceptT ApplyEquationResult (EquationT io) ()) -> + ExceptT ApplyEquationResult (EquationT io) () verifyVar subst (variableName, sortName) check = maybe ( lift . throw . InternalError . Text.pack $ @@ -512,35 +547,28 @@ pattern FalseBool = DomainValue SortBool "false" ensured conditions). If and as soon as this function is used inside equation - application, it needs to run within the same 'EquationM' context + application, it needs to run within the same 'EquationT' context so we can detect simplification loops and avoid monad nesting. -} simplifyConstraint :: + MonadLoggerIO io => + Bool -> KoreDefinition -> Maybe LLVM.API -> Predicate -> - Predicate -simplifyConstraint def mbApi p = - either (const p) fst $ traceSimplifyConstraint def mbApi p - --- | Constraint simplification that collects a simplification trace -traceSimplifyConstraint :: - KoreDefinition -> - Maybe LLVM.API -> - Predicate -> - Either EquationFailure (Predicate, [EquationTrace]) -traceSimplifyConstraint def mbApi p = - runEquationM def mbApi $ simplifyConstraint' p + io (Either EquationFailure (Predicate, [EquationTrace])) +simplifyConstraint doTracing def mbApi p = + runEquationT doTracing def mbApi $ simplifyConstraint' p -- version for internal nested evaluation -simplifyConstraint' :: Predicate -> EquationM Predicate +simplifyConstraint' :: MonadLoggerIO io => Predicate -> EquationT io Predicate -- We are assuming all predicates are of the form 'P ==Bool true' and -- evaluating them using simplifyBool if they are concrete. -- Non-concrete \equals predicates are simplified using evaluateTerm. simplifyConstraint' = \case EqualsTerm t TrueBool | isConcrete t -> do - mbApi <- (.llvmApi) <$> getState + mbApi <- (.llvmApi) <$> getConfig case mbApi of Just api -> if simplifyBool api t @@ -562,9 +590,9 @@ simplifyConstraint' = \case FalseBool -> Bottom other -> EqualsTerm other TrueBool - evalBool :: Term -> EquationM Term + evalBool :: MonadLoggerIO io => Term -> EquationT io Term evalBool t = do prior <- getState -- save prior state so we can revert result <- iterateEquations 100 TopDown PreferFunctions t - EquationM $ put prior + EquationT $ put prior pure result diff --git a/library/Booster/Pattern/Rewrite.hs b/library/Booster/Pattern/Rewrite.hs index 592742426..97081103f 100644 --- a/library/Booster/Pattern/Rewrite.hs +++ b/library/Booster/Pattern/Rewrite.hs @@ -10,11 +10,12 @@ module Booster.Pattern.Rewrite ( RewriteFailed (..), RewriteResult (..), RewriteTrace (..), - runRewriteM, + runRewriteT, ) where import Control.Applicative ((<|>)) import Control.Monad +import Control.Monad.IO.Class (MonadIO (..)) import Control.Monad.Logger.CallStack import Control.Monad.Trans.Class import Control.Monad.Trans.Except @@ -51,17 +52,26 @@ import Booster.Pattern.Unify import Booster.Pattern.Util import Booster.Prettyprinter -newtype RewriteM err a = RewriteM {unRewriteM :: ReaderT (KoreDefinition, Maybe LLVM.API) (Except err) a} - deriving newtype (Functor, Applicative, Monad) +newtype RewriteT io err a = RewriteT {unRewriteT :: ReaderT RewriteConfig (ExceptT err io) a} + deriving newtype (Functor, Applicative, Monad, MonadLogger, MonadIO, MonadLoggerIO) -runRewriteM :: KoreDefinition -> Maybe LLVM.API -> RewriteM err a -> Either err a -runRewriteM def mLlvmLibrary = runExcept . flip runReaderT (def, mLlvmLibrary) . unRewriteM +data RewriteConfig = RewriteConfig + { definition :: KoreDefinition + , llvmApi :: Maybe LLVM.API + , doTracing :: Bool + } -throw :: err -> RewriteM err a -throw = RewriteM . lift . throwE +runRewriteT :: Bool -> KoreDefinition -> Maybe LLVM.API -> RewriteT io err a -> io (Either err a) +runRewriteT doTracing def mLlvmLibrary = + runExceptT + . flip runReaderT RewriteConfig{definition = def, llvmApi = mLlvmLibrary, doTracing} + . unRewriteT -getDefinition :: RewriteM err KoreDefinition -getDefinition = RewriteM $ fst <$> ask +throw :: MonadLoggerIO io => err -> RewriteT io err a +throw = RewriteT . lift . throwE + +getDefinition :: MonadLoggerIO io => RewriteT io err KoreDefinition +getDefinition = RewriteT $ definition <$> ask {- | Performs a rewrite step (using suitable rewrite rules from the definition). @@ -71,7 +81,11 @@ getDefinition = RewriteM $ fst <$> ask additional constraints. -} rewriteStep :: - [Text] -> [Text] -> Pattern -> RewriteM (RewriteFailed "Rewrite") (RewriteResult Pattern) + MonadLoggerIO io => + [Text] -> + [Text] -> + Pattern -> + RewriteT io (RewriteFailed "Rewrite") (RewriteResult Pattern) rewriteStep cutLabels terminalLabels pat = do let termIdx = kCellTermIndex pat.term when (termIdx == None) $ throw (TermIndexIsNone pat.term) @@ -90,7 +104,8 @@ rewriteStep cutLabels terminalLabels pat = do -- until a result is obtained or the entire rewrite fails. processGroups rules where - processGroups :: [[RewriteRule k]] -> RewriteM (RewriteFailed k) (RewriteResult Pattern) + processGroups :: + MonadLoggerIO io => [[RewriteRule k]] -> RewriteT io (RewriteFailed k) (RewriteResult Pattern) processGroups [] = throw (NoApplicableRules pat) processGroups (rules : rest) = do @@ -139,10 +154,11 @@ exception is thrown which indicates the exact reason why (this will abort the entire rewrite). -} applyRule :: - forall k. + forall io k. + MonadLoggerIO io => Pattern -> RewriteRule k -> - RewriteM (RewriteFailed k) (Maybe (RewriteRule k, Pattern)) + RewriteT io (RewriteFailed k) (Maybe (RewriteRule k, Pattern)) applyRule pat rule = runMaybeT $ do def <- lift getDefinition -- unify terms @@ -209,13 +225,15 @@ applyRule pat rule = runMaybeT $ do checkConstraint :: (Predicate -> a) -> Predicate -> - MaybeT (RewriteM (RewriteFailed k)) (Maybe a) + MaybeT (RewriteT io (RewriteFailed k)) (Maybe a) checkConstraint onUnclear p = do - (def, mApi) <- lift $ RewriteM ask - case simplifyConstraint def mApi p of - Bottom -> fail "Rule condition was False" - Top -> pure Nothing - other -> pure $ Just $ onUnclear other + RewriteConfig{definition, llvmApi, doTracing} <- lift $ RewriteT ask + simplified <- simplifyConstraint doTracing definition llvmApi p + case simplified of + Right (Bottom, _) -> fail "Rule condition was False" + Right (Top, _) -> pure Nothing + Right (other, _) -> pure $ Just $ onUnclear other + Left _ -> pure $ Just $ onUnclear p {- | Reason why a rewrite did not produce a result. Contains additional information for logging what happened during the rewrite. @@ -399,6 +417,8 @@ showPattern title pat = hang 4 $ vsep [title, pretty pat.term] performRewrite :: forall io. MonadLoggerIO io => + -- | whether to accumulate rewrite traces + Bool -> KoreDefinition -> Maybe LLVM.API -> -- | maximum depth @@ -409,7 +429,7 @@ performRewrite :: [Text] -> Pattern -> io (Natural, Seq (RewriteTrace Pattern), RewriteResult Pattern) -performRewrite def mLlvmLibrary mbMaxDepth cutLabels terminalLabels pat = do +performRewrite doTracing def mLlvmLibrary mbMaxDepth cutLabels terminalLabels pat = do (rr, (counter, traces)) <- flip runStateT (0, mempty) $ doSteps False pat pure (counter, traces, rr) where @@ -425,13 +445,13 @@ performRewrite def mLlvmLibrary mbMaxDepth cutLabels terminalLabels pat = do rewriteTrace t = do logRewrite $ pack $ renderDefault $ pretty t - modify $ \(counter, traces) -> (counter, traces |> t) + when doTracing $ modify $ \(counter, traces) -> (counter, traces |> t) incrementCounter = modify $ \(counter, traces) -> (counter + 1, traces) simplifyP :: Pattern -> StateT (Natural, Seq (RewriteTrace Pattern)) io Pattern simplifyP p = do - let result = evaluateTerm TopDown def mLlvmLibrary p.term - case result of + let result = evaluateTerm doTracing TopDown def mLlvmLibrary p.term + result >>= \case Left r@(TooManyIterations n _ t) -> do logWarn $ "Simplification unable to finish in " <> prettyText n <> " steps." -- could output term before and after at debug or custom log level @@ -469,9 +489,9 @@ performRewrite def mLlvmLibrary mbMaxDepth cutLabels terminalLabels pat = do (if wasSimplified then pure else mapM simplifyP) $ RewriteFinished Nothing Nothing pat' else do let res = - runRewriteM def mLlvmLibrary $ + runRewriteT doTracing def mLlvmLibrary $ rewriteStep cutLabels terminalLabels pat' - case res of + res >>= \case Right (RewriteFinished mlbl uniqueId single) -> do case mlbl of Just lbl -> rewriteTrace $ RewriteSingleStep lbl uniqueId pat' single diff --git a/unit-tests/Test/Booster/Pattern/ApplyEquations.hs b/unit-tests/Test/Booster/Pattern/ApplyEquations.hs index 989ff08af..74eaa0d63 100644 --- a/unit-tests/Test/Booster/Pattern/ApplyEquations.hs +++ b/unit-tests/Test/Booster/Pattern/ApplyEquations.hs @@ -10,9 +10,11 @@ module Test.Booster.Pattern.ApplyEquations ( test_errors, ) where +import Control.Monad.Logger (runNoLoggingT) import Data.Map (Map) import Data.Map qualified as Map import Data.Text (Text) +import GHC.IO.Unsafe (unsafePerformIO) import Test.Tasty import Test.Tasty.HUnit @@ -76,7 +78,7 @@ test_evaluateFunction = -- eval BottomUp subj @?= Right result ] where - eval direction = fmap fst . evaluateTerm direction funDef Nothing + eval direction = fmap fst . unsafePerformIO . runNoLoggingT . evaluateTerm False direction funDef Nothing a = var "A" someSort d = dv someSort "hey" apply f = app f . (: []) @@ -107,7 +109,7 @@ test_simplify = simpl BottomUp subj @?= Right result ] where - simpl direction = fmap fst . evaluateTerm direction simplDef Nothing + simpl direction = fmap fst . unsafePerformIO . runNoLoggingT . evaluateTerm False direction simplDef Nothing a = var "A" someSort test_errors :: TestTree @@ -120,7 +122,7 @@ test_errors = subj = f $ app con1 [a] loopTerms = [f $ app con1 [a], f $ app con2 [a], f $ app con3 [a, a], f $ app con1 [a]] - isLoop loopTerms $ evaluateTerm TopDown loopDef Nothing subj + isLoop loopTerms . unsafePerformIO . runNoLoggingT $ evaluateTerm False TopDown loopDef Nothing subj ] where isLoop ts (Left (EquationLoop _ ts')) = ts @?= ts' diff --git a/unit-tests/Test/Booster/Pattern/Rewrite.hs b/unit-tests/Test/Booster/Pattern/Rewrite.hs index 9388955af..f76cf45b8 100644 --- a/unit-tests/Test/Booster/Pattern/Rewrite.hs +++ b/unit-tests/Test/Booster/Pattern/Rewrite.hs @@ -15,6 +15,7 @@ import Data.List.NonEmpty qualified as NE import Data.Map (Map) import Data.Map qualified as Map import Data.Text (Text) +import GHC.IO.Unsafe (unsafePerformIO) import Numeric.Natural import Test.Tasty import Test.Tasty.HUnit @@ -193,18 +194,18 @@ errorCases = [trm| kCell{}( kseq{}( inj{SomeSort{}, SortKItem{}}( con2{}( \dv{SomeSort{}}("thing") ) ), Thing:SortK{}) ) |] , testCase "Index is None" $ do let t = - [trm| - kCell{}( - kseq{}( - inj{SomeSort{}, SortKItem{}}( - \and{SomeSort{}}( - con1{}( \dv{SomeSort{}}("thing") ), + [trm| + kCell{}( + kseq{}( + inj{SomeSort{}, SortKItem{}}( + \and{SomeSort{}}( + con1{}( \dv{SomeSort{}}("thing") ), con2{}( \dv{SomeSort{}}("thing") ) ) - ), + ), Thing:SortK{} ) - ) + ) |] t `failsWith` TermIndexIsNone t ] @@ -252,25 +253,26 @@ rulePriority = rewritesTo :: Term -> (Text, Term) -> IO () t1 `rewritesTo` (lbl, t2) = - runRewriteM def Nothing (rewriteStep [] [] $ Pattern t1 []) + unsafePerformIO (runNoLoggingT $ runRewriteT False def Nothing (rewriteStep [] [] $ Pattern t1 [])) @?= Right (RewriteFinished (Just lbl) Nothing $ Pattern t2 []) branchesTo :: Term -> [(Text, Term)] -> IO () t `branchesTo` ts = - runRewriteM def Nothing (rewriteStep [] [] $ Pattern t []) + unsafePerformIO (runNoLoggingT $ runRewriteT False def Nothing (rewriteStep [] [] $ Pattern t [])) @?= Right (RewriteBranch (Pattern t []) $ NE.fromList $ map (\(lbl, t') -> (lbl, Nothing, Pattern t' [])) ts) failsWith :: Term -> RewriteFailed "Rewrite" -> IO () failsWith t err = - runRewriteM def Nothing (rewriteStep [] [] $ Pattern t []) @?= Left err + unsafePerformIO (runNoLoggingT $ runRewriteT False def Nothing (rewriteStep [] [] $ Pattern t [])) + @?= Left err ---------------------------------------- -- tests for performRewrite (iterated rewrite in IO with logging) runRewrite :: Term -> IO (Natural, RewriteResult Term) runRewrite t = do - (counter, _, res) <- runNoLoggingT $ performRewrite def Nothing Nothing [] [] $ Pattern t [] + (counter, _, res) <- runNoLoggingT $ performRewrite False def Nothing Nothing [] [] $ Pattern t [] pure (counter, fmap (.term) res) aborts :: Term -> IO () @@ -398,7 +400,8 @@ supportsDepthControl = where rewritesToDepth :: MaxDepth -> Steps -> Term -> t -> (t -> RewriteResult Term) -> IO () rewritesToDepth (MaxDepth depth) (Steps n) t t' f = do - (counter, _, res) <- runNoLoggingT $ performRewrite def Nothing (Just depth) [] [] $ Pattern t [] + (counter, _, res) <- + runNoLoggingT $ performRewrite False def Nothing (Just depth) [] [] $ Pattern t [] (counter, fmap (.term) res) @?= (n, f t') supportsCutPoints :: TestTree @@ -445,7 +448,8 @@ supportsCutPoints = where rewritesToCutPoint :: Text -> Steps -> Term -> t -> (t -> RewriteResult Term) -> IO () rewritesToCutPoint lbl (Steps n) t t' f = do - (counter, _, res) <- runNoLoggingT $ performRewrite def Nothing Nothing [lbl] [] $ Pattern t [] + (counter, _, res) <- + runNoLoggingT $ performRewrite False def Nothing Nothing [lbl] [] $ Pattern t [] (counter, fmap (.term) res) @?= (n, f t') supportsTerminalRules :: TestTree @@ -470,5 +474,6 @@ supportsTerminalRules = where rewritesToTerminal :: Text -> Steps -> Term -> t -> (t -> RewriteResult Term) -> IO () rewritesToTerminal lbl (Steps n) t t' f = do - (counter, _, res) <- runNoLoggingT $ performRewrite def Nothing Nothing [] [lbl] $ Pattern t [] + (counter, _, res) <- + runNoLoggingT $ performRewrite False def Nothing Nothing [] [lbl] $ Pattern t [] (counter, fmap (.term) res) @?= (n, f t')