Skip to content

Commit

Permalink
Trace removal re-implemented as a IR rewrite rule (#5907)
Browse files Browse the repository at this point in the history
* refactor: Internal module for RewriteRules, Monoid Instance

* Trace removal re-implemented as a IR rewrite rule

* Test case for an impure trace message, added note.
  • Loading branch information
Unisay authored Apr 17, 2024
1 parent 16a986f commit 7db7e37
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 75 deletions.
5 changes: 3 additions & 2 deletions plutus-core/plutus-core.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ executable uplc
library plutus-ir
import: lang
visibility: public
hs-source-dirs: plutus-ir/src
exposed-modules:
PlutusIR
PlutusIR.Analysis.Builtins
Expand Down Expand Up @@ -537,14 +538,14 @@ library plutus-ir
PlutusIR.Transform.Rename
PlutusIR.Transform.RewriteRules
PlutusIR.Transform.RewriteRules.CommuteFnWithConst
PlutusIR.Transform.RewriteRules.RemoveTrace
PlutusIR.Transform.StrictifyBindings
PlutusIR.Transform.Substitute
PlutusIR.Transform.ThunkRecursions
PlutusIR.Transform.Unwrap
PlutusIR.TypeCheck
PlutusIR.TypeCheck.Internal

hs-source-dirs: plutus-ir/src
other-modules:
PlutusIR.Analysis.Definitions
PlutusIR.Analysis.Size
Expand All @@ -554,7 +555,7 @@ library plutus-ir
PlutusIR.Compiler.Recursion
PlutusIR.Normalize
PlutusIR.Transform.RewriteRules.Common
PlutusIR.Transform.RewriteRules.Rules
PlutusIR.Transform.RewriteRules.Internal
PlutusIR.Transform.RewriteRules.UnConstrConstrData

build-depends:
Expand Down
2 changes: 1 addition & 1 deletion plutus-core/plutus-ir/src/PlutusIR/Compiler/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import PlutusCore.Quote
import PlutusCore.StdLib.Type qualified as Types
import PlutusCore.TypeCheck.Internal qualified as PLC
import PlutusCore.Version qualified as PLC
import PlutusIR.Transform.RewriteRules.Rules
import PlutusIR.Transform.RewriteRules.Internal (RewriteRules)
import PlutusPrelude

import Control.Monad.Error.Lens (throwing)
Expand Down
9 changes: 5 additions & 4 deletions plutus-core/plutus-ir/src/PlutusIR/Transform/RewriteRules.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ module PlutusIR.Transform.RewriteRules
( rewriteWith
, rewritePass
, rewritePassSC
, RewriteRules (..)
, RewriteRules
, unRewriteRules
, defaultUniRewriteRules
) where

Expand All @@ -16,7 +17,7 @@ import PlutusCore.Name.Unique
import PlutusCore.Quote
import PlutusIR as PIR
import PlutusIR.Analysis.VarInfo
import PlutusIR.Transform.RewriteRules.Rules
import PlutusIR.Transform.RewriteRules.Internal

import Control.Lens
import PlutusIR.Pass
Expand Down Expand Up @@ -61,11 +62,11 @@ rewriteWith :: ( Monoid a, t ~ Term tyname Name uni fun a
=> RewriteRules uni fun
-> t
-> m t
rewriteWith (RewriteRules rules) t =
rewriteWith rules t =
-- We collect `VarsInfo` on the whole program term and pass it on as arg to each RewriteRule.
-- This has the limitation that any variables newly-introduced by the rules would
-- not be accounted in `VarsInfo`. This is currently fine, because we only rely on VarsInfo
-- for isPure; isPure is safe w.r.t "open" terms.
let vinfo = termVarInfo t
in transformMOf termSubterms (rules vinfo) t
in transformMOf termSubterms (unRewriteRules rules vinfo) t

Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}

module PlutusIR.Transform.RewriteRules.Internal
( RewriteRules (..)
, defaultUniRewriteRules
) where

import PlutusCore.Default (DefaultFun, DefaultUni)
import PlutusCore.Name.Unique (Name)
import PlutusCore.Quote (MonadQuote)
import PlutusIR.Analysis.VarInfo (VarsInfo)
import PlutusIR.Core.Type qualified as PIR
import PlutusIR.Transform.RewriteRules.CommuteFnWithConst (commuteFnWithConst)
import PlutusIR.Transform.RewriteRules.UnConstrConstrData (unConstrConstrData)
import PlutusPrelude (Default (..), (>=>))

-- | A bundle of composed `RewriteRules`, to be passed at entrypoint of the compiler.
newtype RewriteRules uni fun where
RewriteRules
:: { unRewriteRules
:: forall tyname m a
. (MonadQuote m, Monoid a)
=> VarsInfo tyname Name uni a
-> PIR.Term tyname Name uni fun a
-> m (PIR.Term tyname Name uni fun a)
}
-> RewriteRules uni fun

-- | The rules for the Default Universe/Builtin.
defaultUniRewriteRules :: RewriteRules DefaultUni DefaultFun
defaultUniRewriteRules = RewriteRules $ \varsInfo ->
-- The rules are composed from left to right.
pure . commuteFnWithConst >=> unConstrConstrData def varsInfo

instance Default (RewriteRules DefaultUni DefaultFun) where
def = defaultUniRewriteRules

instance Semigroup (RewriteRules uni fun) where
RewriteRules r1 <> RewriteRules r2 = RewriteRules (\varsInfo -> r1 varsInfo >=> r2 varsInfo)

instance Monoid (RewriteRules uni fun) where
mempty = RewriteRules (const pure)
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}

module PlutusIR.Transform.RewriteRules.RemoveTrace
( rewriteRuleRemoveTrace
) where

import PlutusCore.Default (DefaultFun)
import PlutusCore.Default.Builtins qualified as Builtin
import PlutusIR.Transform.RewriteRules.Common (pattern A, pattern B, pattern I)
import PlutusIR.Transform.RewriteRules.Internal (RewriteRules (..))

{- Note [Impure trace messages]
Removing of traces could change behavior of those programs that use impure trace messages
e.g. `trace (error ()) foo`.
While it is possible to force evaluation of a trace message when removing a trace call
for the sake of a behavior preservation, this has a downside that pure messages remain
in the program and are not elimitated as a "dead" code.
This downside would defeat the purpose of removing traces, so we decided to not force.
-}

rewriteRuleRemoveTrace :: RewriteRules uni DefaultFun
rewriteRuleRemoveTrace = RewriteRules \_varsInfo -> \case
B Builtin.Trace `I` _argTy `A` _msg `A` arg -> pure arg
term -> pure term
32 changes: 0 additions & 32 deletions plutus-core/plutus-ir/src/PlutusIR/Transform/RewriteRules/Rules.hs

This file was deleted.

27 changes: 2 additions & 25 deletions plutus-tx-plugin/src/PlutusTx/Compiler/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import GHC.Types.TyThing qualified as GHC

import Language.Haskell.TH.Syntax qualified as TH

import Control.Monad.Reader (ask, asks)
import Control.Monad.Reader (asks)

import Data.ByteString qualified as BS
import Data.Foldable (for_)
Expand Down Expand Up @@ -301,8 +301,6 @@ defineBuiltinType name ty = do
-- | Add definitions for all the builtin terms to the environment.
defineBuiltinTerms :: CompilingDefault uni fun m ann => m ()
defineBuiltinTerms = do
CompileContext {ccOpts=compileOpts} <- ask

-- Error
-- See Note [Delaying error]
func <- delayedErrorFunc
Expand Down Expand Up @@ -380,28 +378,7 @@ defineBuiltinTerms = do
PLC.EqualsInteger -> defineBuiltinInl 'Builtins.equalsInteger

-- Tracing
-- When `remove-trace` is specified, we define `trace` as `\_ a -> a` instead of the
-- version.
PLC.Trace -> do
(traceTerm, ann) <-
if coRemoveTrace compileOpts
then liftQuote $ do
ta <- freshTyName "a"
t <- freshName "t"
a <- freshName "a"
pure
( PIR.tyAbs annMayInline ta (PLC.Type annMayInline) $
PIR.mkIterLamAbs
[ PIR.VarDecl annMayInline t $
PIR.mkTyBuiltin @_ @Text annMayInline
, PIR.VarDecl annMayInline a $
PLC.TyVar annMayInline ta
]
$ PIR.Var annMayInline a
, annMayInline
)
else pure (mkBuiltin PLC.Trace, annMayInline)
defineBuiltinTerm ann 'Builtins.trace traceTerm
PLC.Trace -> defineBuiltinInl 'Builtins.trace

-- Pairs
PLC.FstPair -> defineBuiltinInl 'Builtins.fst
Expand Down
16 changes: 15 additions & 1 deletion plutus-tx-plugin/src/PlutusTx/Plugin.hs
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,16 @@ import Data.ByteString qualified as BS
import Data.ByteString.Unsafe qualified as BSUnsafe
import Data.Either.Validation
import Data.Map qualified as Map
import Data.Monoid.Extra (mwhen)
import Data.Set qualified as Set
import Data.Type.Bool qualified as PlutusTx.Bool
import GHC.Num.Integer qualified
import PlutusCore.Default (DefaultFun, DefaultUni)
import PlutusIR.Analysis.Builtins
import PlutusIR.Compiler.Provenance (noProvenance, original)
import PlutusIR.Compiler.Types qualified as PIR
import PlutusIR.Transform.RewriteRules
import PlutusIR.Transform.RewriteRules.RemoveTrace (rewriteRuleRemoveTrace)
import Prettyprinter qualified as PP
import System.IO (openTempFile)
import System.IO.Unsafe (unsafePerformIO)
Expand Down Expand Up @@ -423,7 +426,7 @@ compileMarkedExpr locStr codeTy origE = do
ccBuiltinsInfo = def,
ccBuiltinCostModel = def,
ccDebugTraceOn = _posDumpCompilationTrace opts,
ccRewriteRules = def
ccRewriteRules = makeRewriteRules opts
}
st = CompileState 0 mempty
-- See Note [Occurrence analysis]
Expand Down Expand Up @@ -482,6 +485,9 @@ runCompiler moduleName opts expr = do
PIR.DatatypeComponent PIR.Destructor _ -> True
_ ->
AlwaysInline `elem` fmap annInline (toList ann)

rewriteRules <- asks ccRewriteRules

-- Compilation configuration
-- pir's tc-config is based on plc tcconfig
let pirTcConfig = PIR.PirTCConfig plcTcConfig PIR.YesEscape
Expand Down Expand Up @@ -524,6 +530,7 @@ runCompiler moduleName opts expr = do
-- TODO: ensure the same as the one used in the plugin
& set PIR.ccBuiltinsInfo def
& set PIR.ccBuiltinCostModel def
& set PIR.ccRewriteRules rewriteRules
plcOpts = PLC.defaultCompilationOpts
& set (PLC.coSimplifyOpts . UPLC.soMaxSimplifierIterations)
(opts ^. posMaxSimplifierIterationsUPlc)
Expand Down Expand Up @@ -642,3 +649,10 @@ makePrimitiveNameInfo names = do
thing <- lift . lift $ GHC.lookupThing ghcName
pure (name, thing)
pure $ Map.fromList infos

makeRewriteRules :: PluginOptions -> RewriteRules DefaultUni DefaultFun
makeRewriteRules options =
fold
[ mwhen (options ^. posRemoveTrace) rewriteRuleRemoveTrace
, defaultUniRewriteRules
]
29 changes: 22 additions & 7 deletions plutus-tx-plugin/test/Plugin/NoTrace/Lib.hs
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-unused-foralls #-}

module Plugin.NoTrace.Lib where

import Control.Lens (universeOf, (^.))
import Data.Int (Int)
import Data.List (length)
import Prelude hiding (Show, show, (+))

import Control.Lens (universeOf, (&), (^.))
import GHC.Exts (noinline)
import PlutusCore.Builtin.Debug qualified as Builtin
import PlutusTx.Bool (Bool)
import PlutusTx.Builtins (BuiltinString, Integer, appendString)
import PlutusTx.Code (CompiledCode, getPlcNoAnn)
import PlutusCore.Default.Builtins qualified as Builtin
import PlutusCore.Evaluation.Machine.ExBudgetingDefaults (defaultCekParameters)
import PlutusTx.Builtins (BuiltinString, appendString, error)
import PlutusTx.Code (CompiledCode, getPlc, getPlcNoAnn)
import PlutusTx.Numeric ((+))
import PlutusTx.Show.TH (Show (show))
import PlutusTx.Trace (trace, traceError)
import UntypedPlutusCore qualified as UPLC
import UntypedPlutusCore.Evaluation.Machine.Cek (counting, noEmitter)
import UntypedPlutusCore.Evaluation.Machine.Cek.Internal (runCekDeBruijn)

data Arg = MkArg

Expand All @@ -32,6 +35,15 @@ countTraces code =
, subterm@(UPLC.Builtin _ Builtin.Trace) <- universeOf UPLC.termSubterms term
]

evaluatesToError :: CompiledCode a -> Bool
evaluatesToError = not . evaluatesWithoutError

evaluatesWithoutError :: CompiledCode a -> Bool
evaluatesWithoutError code =
runCekDeBruijn defaultCekParameters counting noEmitter (getPlc code ^. UPLC.progTerm) & \case
(Left _exception, _counter, _logs) -> False
(Right _result, _counter, _logs) -> True

----------------------------------------------------------------------------------------------------
-- Functions that contain traces -------------------------------------------------------------------

Expand Down Expand Up @@ -62,3 +74,6 @@ traceRepeatedly =
i2 = trace "Making my second int" (2 :: Integer)
i3 = trace "Adding them up" (i1 + i2)
in i3

traceImpure :: ()
traceImpure = trace ("Message: " `appendString` PlutusTx.Builtins.error ()) ()
15 changes: 12 additions & 3 deletions plutus-tx-plugin/test/Plugin/NoTrace/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@ module Plugin.NoTrace.Spec where
import Prelude

import Plugin.NoTrace.Lib (countTraces)
import Plugin.NoTrace.Lib qualified as Lib
import Plugin.NoTrace.WithoutTraces qualified as WithoutTraces
import Plugin.NoTrace.WithTraces qualified as WithTraces
import Test.Tasty (testGroup)
import Test.Tasty.Extras (TestNested)
import Test.Tasty.HUnit (testCase, (@=?))
import Test.Tasty.HUnit (assertBool, testCase, (@=?))

noTrace :: TestNested
noTrace = pure do
testGroup
"remove-trace"
[ testGroup
"Trace calls are present"
"Trace calls are preserved"
[ testCase "trace-argument" $
1 @=? countTraces WithTraces.traceArgument
, testCase "trace-show" $
Expand All @@ -33,9 +34,13 @@ noTrace = pure do
1 @=? countTraces WithTraces.traceNonConstant
, testCase "trace-repeatedly" $
3 @=? countTraces WithTraces.traceRepeatedly
, testCase "trace-impure" $
1 @=? countTraces WithTraces.traceImpure
, testCase "trace-impure with effect" $ -- See note [Impure trace messages]
assertBool "Effect is missing" (Lib.evaluatesToError WithTraces.traceImpure)
]
, testGroup
"Trace calls are absent"
"Trace calls are removed"
[ testCase "trace-argument" $
0 @=? countTraces WithoutTraces.traceArgument
, testCase "trace-show" $
Expand All @@ -48,5 +53,9 @@ noTrace = pure do
0 @=? countTraces WithoutTraces.traceNonConstant
, testCase "trace-repeatedly" $
0 @=? countTraces WithoutTraces.traceRepeatedly
, testCase "trace-impure" $
0 @=? countTraces WithoutTraces.traceImpure
, testCase "trace-impure without effect" $ -- See note [Impure trace messages]
assertBool "Effect wasn't erased" (Lib.evaluatesWithoutError WithoutTraces.traceImpure)
]
]
3 changes: 3 additions & 0 deletions plutus-tx-plugin/test/Plugin/NoTrace/WithTraces.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,6 @@ traceComplex = plc (Proxy @"traceComplex") Lib.traceComplex

traceRepeatedly :: CompiledCode Integer
traceRepeatedly = plc (Proxy @"traceRepeatedly") Lib.traceRepeatedly

traceImpure :: CompiledCode ()
traceImpure = plc (Proxy @"traceImpure") Lib.traceImpure
Loading

0 comments on commit 7db7e37

Please sign in to comment.