Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions otherTests/saw-core/Tests/Rewriter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ prelude_bveq_sameL_test =
`pureApp` n
`mkApp` (mkGlobalDef "Prelude.bvNat" `pureApp` n `mkApp` mkNatLit 0)
`pureApp` z
(_, lhs_term) <- rewriteSharedTerm sc ss =<< scMkTerm sc lhs
(_, rhs_term) <- rewriteSharedTerm sc ss =<< scMkTerm sc rhs
(_, lhs_term) <- rewriteSharedTermConvertibility sc ss =<< scMkTerm sc lhs
(_, rhs_term) <- rewriteSharedTermConvertibility sc ss =<< scMkTerm sc rhs
assertEqual "Incorrect conversion\n" lhs_term rhs_term
2 changes: 1 addition & 1 deletion saw-central/src/SAWCentral/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1521,7 +1521,7 @@ add_cryptol_defs = add_core_defs "Cryptol"
rewritePrim :: SV.SAWSimpset -> TypedTerm -> TopLevel TypedTerm
rewritePrim ss (TypedTerm schema t) = do
sc <- getSharedContext
(_,t') <- io $ rewriteSharedTerm sc ss t
(_,t') <- io $ rewriteSharedTermConvertibility sc ss t
return (TypedTerm schema t')

unfold_term :: [Text] -> TypedTerm -> TopLevel TypedTerm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import SAWCoreWhat4.ReturnTrip
import SAWCentral.Crucible.Common

import SAWCentral.Proof (TheoremNonce)
import SAWCore.Rewriter (Simpset, rewriteSharedTerm)
import SAWCore.Rewriter (Simpset, rewriteSharedTermConvertibility)
import qualified CryptolSAWCore.Simpset as Cryptol
import SAWCoreWhat4.What4(w4EvalAny, valueToSymExpr)

Expand Down Expand Up @@ -84,7 +84,7 @@ resolveTerm sym unint bt rr tm =
| rrWhat4Eval rr ->
do -- Try to use rewrites to simplify the term
cryptol_ss <- Cryptol.mkCryptolSimpset @TheoremNonce sc
tm'' <- snd <$> rewriteSharedTerm sc cryptol_ss tm'
tm'' <- snd <$> rewriteSharedTermConvertibility sc cryptol_ss tm'
tm''' <- basicRewrite sc tm''
if all isPreludeName (Map.elems (getConstantSet tm''')) then
do
Expand All @@ -105,7 +105,7 @@ resolveTerm sym unint bt rr tm =
basicRewrite sc =
case rrBasicSS rr of
Nothing -> pure
Just ss -> \t -> snd <$> rewriteSharedTerm sc ss t
Just ss -> \t -> snd <$> rewriteSharedTermConvertibility sc ss t

isPreludeName nm =
case nm of
Expand Down
2 changes: 1 addition & 1 deletion saw-central/src/SAWCentral/Proof.hs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ unfoldFixOnceProp sc unints (Prop tm) =
-- | Rewrite the proposition using the provided Simpset
simplifyProp :: Ord a => SharedContext -> Simpset a -> Prop -> IO (Set a, Prop)
simplifyProp sc ss (Prop tm) =
do (a, tm') <- rewriteSharedTerm sc ss tm
do (a, tm') <- rewriteSharedTermConvertibility sc ss tm
return (a, Prop tm')

-- | Rewrite the propositions using the provided Simpset
Expand Down
156 changes: 142 additions & 14 deletions saw-core/src/SAWCore/Rewriter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ module SAWCore.Rewriter
-- * Term rewriting
, rewriteSharedTerm
, rewriteSharedTermTypeSafe
, rewriteSharedTermConvertibility
-- * Matching
, scMatch
-- * Miscellaneous
Expand Down Expand Up @@ -106,15 +107,16 @@ data RewriteRule a
, permutative :: Bool
, shallow :: Bool
, annotation :: Maybe a
, convertible :: Bool -- ^ flag is true if the rule's LHS and RHS are convertible in SAWcore type system
}
deriving (Show)
-- ^ Invariant: The set of loose variables in @lhs@ must be exactly
-- @[0 .. length ctxt - 1]@. The @rhs@ may contain a subset of these.

-- NB, exclude the annotation from equality tests
instance Eq (RewriteRule a) where
RewriteRule c1 l1 r1 p1 s1 _a1 == RewriteRule c2 l2 r2 p2 s2 _a2 =
c1 == c2 && l1 == l2 && r1 == r2 && p1 == p2 && s1 == s2
RewriteRule c1 l1 r1 p1 s1 _a1 co1 == RewriteRule c2 l2 r2 p2 s2 _a2 co2 =
c1 == c2 && l1 == l2 && r1 == r2 && p1 == p2 && s1 == s2 && co1 == co2

ctxtRewriteRule :: RewriteRule a -> [(VarName, Term)]
ctxtRewriteRule = ctxt
Expand Down Expand Up @@ -354,7 +356,7 @@ ruleOfTerm :: Term -> Maybe a -> RewriteRule a
ruleOfTerm t ann =
do let (vars, body) = R.asPiList t
case R.asGlobalApply eqIdent body of
Just [_, x, y] -> mkRewriteRule vars x y False ann
Just [_, x, y] -> mkRewriteRule vars x y False ann False
_ -> panic "ruleOfTerm" ["Illegal argument"]

-- Test whether a rewrite rule is permutative
Expand All @@ -368,21 +370,22 @@ rulePermutes ctxt lhs rhs =
Nothing -> False -- but here we have a looping rule, not good!
Just _ -> True

mkRewriteRule :: [(VarName, Term)] -> Term -> Term -> Bool -> Maybe a -> RewriteRule a
mkRewriteRule c l r shallow ann =
mkRewriteRule :: [(VarName, Term)] -> Term -> Term -> Bool -> Maybe a -> Bool -> RewriteRule a
mkRewriteRule c l r shallow ann convFlag =
RewriteRule
{ ctxt = c
, lhs = l
, rhs = r
, permutative = rulePermutes c l r
, shallow = shallow
, annotation = ann
, convertible = convFlag
}

-- | Converts a universally quantified equality proposition between the
-- two given terms to a RewriteRule.
ruleOfTerms :: Term -> Term -> RewriteRule a
ruleOfTerms l r = mkRewriteRule [] l r False Nothing
ruleOfTerms l r = mkRewriteRule [] l r False Nothing False

-- | Converts a parameterized equality predicate to a RewriteRule,
-- returning 'Nothing' if the predicate is not an equation.
Expand Down Expand Up @@ -425,7 +428,7 @@ ruleOfProp sc term ann =
_ -> pure Nothing

where
eqRule x y = pure $ Just $ mkRewriteRule [] x y False ann
eqRule x y = pure $ Just $ mkRewriteRule [] x y False ann False

-- | Generate a rewrite rule from the type of an identifier, using 'ruleOfTerm'
scEqRewriteRule :: SharedContext -> Ident -> IO (RewriteRule a)
Expand All @@ -442,19 +445,19 @@ scEqsRewriteRules sc = mapM (scEqRewriteRule sc)
-- * If the rhs is a recursor, then split into a separate rule for each constructor.
-- * If the rhs is a record, then split into a separate rule for each accessor.
scExpandRewriteRule :: SharedContext -> RewriteRule a -> IO (Maybe [RewriteRule a])
scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _ shallow ann) =
scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _ shallow ann convFlag) =
case R.asLambda rhs of
Just (nm, tp, body) ->
do let ctxt' = ctxt ++ [(nm, tp)]
var0 <- scVariable sc nm tp
lhs' <- scApply sc lhs var0
pure $ Just [mkRewriteRule ctxt' lhs' body shallow ann]
pure $ Just [mkRewriteRule ctxt' lhs' body shallow ann convFlag]
Nothing ->
case rhs of
(R.asRecordValue -> Just m) ->
do let mkRule (k, x) =
do l <- scRecordSelect sc lhs k
return (mkRewriteRule ctxt l x shallow ann)
return (mkRewriteRule ctxt l x shallow ann convFlag)
Just <$> traverse mkRule (Map.assocs m)
(R.asApplyAll ->
(R.asRecursorApp -> Just (r, crec),
Expand Down Expand Up @@ -495,9 +498,9 @@ scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _ shallow ann) =
rhs2 <- scApplyAll sc rhs1 more'
rhs3 <- betaReduce rhs2
-- re-fold recursive occurrences of the original rhs
let ss = addRule (mkRewriteRule ctxt rhs lhs shallow Nothing) emptySimpset
let ss = addRule (mkRewriteRule ctxt rhs lhs shallow Nothing convFlag) emptySimpset
(_,rhs') <- rewriteSharedTerm sc (ss :: Simpset ()) rhs3
return (mkRewriteRule ctxt' lhs' rhs' shallow ann)
return (mkRewriteRule ctxt' lhs' rhs' shallow ann convFlag)
let d = recursorDataType crec
mm <- scGetModuleMap sc
dt <-
Expand Down Expand Up @@ -557,7 +560,7 @@ scDefRewriteRules sc d =
case defBody d of
Just rhs ->
do lhs <- scConst sc (defName d)
scExpandRewriteRules sc [mkRewriteRule [] lhs rhs False Nothing]
scExpandRewriteRules sc [mkRewriteRule [] lhs rhs False Nothing True]
Nothing ->
pure []

Expand Down Expand Up @@ -868,6 +871,131 @@ rewriteSharedTermTypeSafe sc ss t0 =
Nothing -> apply rules t
Just tb -> rewriteAll =<< runTermBuilder tb (scGlobalDef sc) (scTermF sc)

data Convertibility = AllRules | ConvertibleRulesOnly

rewriteSharedTermConvertibility :: forall a. Ord a => SharedContext -> Simpset a -> Term -> IO (Set a, Term)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to have both rewriteSharedTermTypeSafe and rewriteSharedTermConvertibility; your new version here should just replace rewriteSharedTermTypeSafe, as it's intended to be an improvement to it. Removing the original function will also make the diff easier to read, as it will highlight what you've changed vs what you've copied from the original.

rewriteSharedTermConvertibility sc ss t0 =
do cache <- newCache
let ?cache = cache
setRef <- newIORef mempty
let ?annSet = setRef
t <- rewriteAll AllRules t0
anns <- readIORef setRef
pure (anns, t)

where

rewriteAll :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Convertibility -> Term -> IO Term

rewriteAll convertibleFlag STApp{ stAppIndex = tidx, stAppTermF = tf } =
useCache ?cache tidx (rewriteTermF convertibleFlag tf >>= scTermF sc >>= rewriteTop convertibleFlag)

rewriteTermF :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
Convertibility -> TermF Term -> IO (TermF Term)
rewriteTermF convertibleFlag tf =
case tf of
FTermF ftf -> FTermF <$> rewriteFTermF convertibleFlag ftf
App e1 e2 ->
do t1 <- scTypeOf sc e1
case unwrapTermF t1 of
-- If type of e1 is not a dependent type, we can use any rule to rewrite e2
-- otherwise, we only rewrite using convertible rules
-- This prevents rewriting e2 from changing type of @App e1 e2@.
Pi x _ t
| IntSet.notMember (vnIndex x) (freeVars t) ->
App <$> rewriteAll convertibleFlag e1 <*> rewriteAll convertibleFlag e2
_ -> App <$> rewriteAll convertibleFlag e1 <*> rewriteAll ConvertibleRulesOnly e2
-- could compute WHNF of t1 to see if it's a thing that doesn't match Pi but behaves like Pi
Lambda pat t e -> Lambda pat t <$> rewriteAll convertibleFlag e -- pat is x or varnames etc, t types (so don't rewrite), e is body
Constant{} -> return tf
Variable{} -> return tf
Pi x t1 t2 -> Pi x <$> rewriteAll convertibleFlag t1 <*> rewriteAll convertibleFlag t2

rewriteFTermF :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
Convertibility -> FlatTermF Term -> IO (FlatTermF Term)
rewriteFTermF convertibleFlag ftf =
case ftf of
UnitValue -> return ftf
UnitType -> return ftf
PairValue{} -> traverse (rewriteAll convertibleFlag) ftf
PairType{} -> traverse (rewriteAll convertibleFlag) ftf
PairLeft{} -> traverse (rewriteAll convertibleFlag) ftf
PairRight{} -> traverse (rewriteAll convertibleFlag) ftf

-- NOTE: we don't rewrite arguments of constructors, datatypes, or
-- recursors because of dependent types, as we could potentially cause
-- a term to become ill-typed
Recursor{} -> return ftf -- just a function, atomic thus no subterms

RecordType{} -> traverse (rewriteAll convertibleFlag) ftf
RecordValue{} -> traverse (rewriteAll convertibleFlag) ftf
RecordProj{} -> traverse (rewriteAll convertibleFlag) ftf
Sort{} -> return ftf
NatLit{} -> return ftf
ArrayValue t es -> ArrayValue t <$> traverse (rewriteAll convertibleFlag) es -- specifically NOT rewriting type, only elts
StringLit{} -> return ftf

filterRules :: Convertibility -> Either (RewriteRule a) Conversion -> Bool
filterRules convertibleFlag (Left RewriteRule{convertible = ruleConvFlag}) =
case convertibleFlag of
ConvertibleRulesOnly -> ruleConvFlag
AllRules -> True
filterRules _ (Right _) = True

rewriteTop :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Convertibility -> Term -> IO Term
rewriteTop convertibleFlag t =
do mt <- reduceSharedTerm sc t
case mt of
Nothing -> apply convertibleFlag (filter (filterRules convertibleFlag) (Net.unify_term ss (termPat t))) t
Just t' -> rewriteAll convertibleFlag t'

recordAnn :: (?annSet :: IORef (Set a)) => Maybe a -> IO ()
recordAnn Nothing = return ()
recordAnn (Just a) = modifyIORef' ?annSet (Set.insert a)

apply :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
Convertibility -> [Either (RewriteRule a) Conversion] -> Term -> IO Term
apply _ [] t = return t
apply convertibleFlag (Left (RewriteRule {ctxt, lhs, rhs, permutative, shallow, annotation}) : rules) t = do
-- if rewrite rule
result <- scMatch sc ctxt lhs t
case result of
Nothing -> apply convertibleFlag rules t
Just inst
| lhs == rhs ->
-- This should never happen because we avoid inserting
-- reflexive rules into simp sets in the first place.
do putStrLn $ "rewriteSharedTerm: skipping reflexive rule " ++
"(THE IMPOSSIBLE HAPPENED!): " ++ scPrettyTerm PPS.defaultOpts lhs
apply convertibleFlag rules t
| IntMap.keysSet inst /= IntSet.fromList (map (vnIndex . fst) ctxt) ->
do putStrLn $ "rewriteSharedTerm: invalid lhs does not contain all variables: "
++ scPrettyTerm PPS.defaultOpts lhs
apply convertibleFlag rules t
| permutative ->
do
t' <- scInstantiate sc inst rhs
case termWeightLt t' t of
True -> recordAnn annotation >> rewriteAll convertibleFlag t' -- keep the result only if it is "smaller"
False -> apply convertibleFlag rules t
| shallow ->
-- do not to further rewriting to the result of a "shallow" rule
do recordAnn annotation
scInstantiate sc inst rhs
| otherwise ->
do -- putStrLn "REWRITING:"
-- print lhs
recordAnn annotation
rewriteAll convertibleFlag =<< scInstantiate sc inst rhs
-- instead of syntactic rhs, has a bit of code that rewrites lhs (Term -> Maybe Term)
-- for now, assume all conversions are convertible, maybe check later
apply convertibleFlag (Right conv : rules) t =
do -- putStrLn "REWRITING:"
-- print (Net.toPat conv)
case runConversion conv t of
Nothing -> apply convertibleFlag rules t
Just tb -> rewriteAll convertibleFlag =<< runTermBuilder tb (scGlobalDef sc) (scTermF sc)


-- FIXME: is there some way to have sensable term replacement in the presence of loose variables
-- and/or under binders?
Expand Down Expand Up @@ -924,7 +1052,7 @@ hoistIfs sc t = do
]
let ss :: Simpset () = addRules rules emptySimpset

(t', conds) <- doHoistIfs sc ss cache . snd =<< rewriteSharedTerm sc ss t
(t', conds) <- doHoistIfs sc ss cache . snd =<< rewriteSharedTermConvertibility sc ss t

-- remove duplicate conditions from the list, as muxing in SAW can result in
-- many copies of the same condition, which cause a performance issue
Expand Down
2 changes: 1 addition & 1 deletion saw-core/src/SAWCore/Term/Certified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ rawCtx (Term ctx _ _) = ctx
-- the SAWCore type system.
scTypeCheckWHNF :: SharedContext -> Raw.Term -> IO Raw.Term
scTypeCheckWHNF sc t =
do (_, t') <- rewriteSharedTerm sc (addConvs natConversions emptySimpset :: Simpset ()) t
do (_, t') <- rewriteSharedTermConvertibility sc (addConvs natConversions emptySimpset :: Simpset ()) t
Raw.scWhnf sc t'

-- | Check if two terms are "convertible for type-checking", meaning that they
Expand Down
Loading