-
Notifications
You must be signed in to change notification settings - Fork 77
rewriteSharedTermConvertibility: rewrite terms without breaking type safety #2750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,6 +50,7 @@ module SAWCore.Rewriter | |
| -- * Term rewriting | ||
| , rewriteSharedTerm | ||
| , rewriteSharedTermTypeSafe | ||
| , rewriteSharedTermConvertibility | ||
| -- * Matching | ||
| , scMatch | ||
| -- * Miscellaneous | ||
|
|
@@ -106,15 +107,16 @@ data RewriteRule a | |
| , permutative :: Bool | ||
| , shallow :: Bool | ||
| , annotation :: Maybe a | ||
| , convertible :: Bool -- ^ flag is true if the rule's LHS and RHS are convertible in SAWcore type system | ||
| } | ||
| deriving (Show) | ||
| -- ^ Invariant: The set of loose variables in @lhs@ must be exactly | ||
| -- @[0 .. length ctxt - 1]@. The @rhs@ may contain a subset of these. | ||
|
|
||
| -- NB, exclude the annotation from equality tests | ||
| instance Eq (RewriteRule a) where | ||
| RewriteRule c1 l1 r1 p1 s1 _a1 == RewriteRule c2 l2 r2 p2 s2 _a2 = | ||
| c1 == c2 && l1 == l2 && r1 == r2 && p1 == p2 && s1 == s2 | ||
| RewriteRule c1 l1 r1 p1 s1 _a1 co1 == RewriteRule c2 l2 r2 p2 s2 _a2 co2 = | ||
| c1 == c2 && l1 == l2 && r1 == r2 && p1 == p2 && s1 == s2 && co1 == co2 | ||
|
|
||
| ctxtRewriteRule :: RewriteRule a -> [(VarName, Term)] | ||
| ctxtRewriteRule = ctxt | ||
|
|
@@ -354,7 +356,7 @@ ruleOfTerm :: Term -> Maybe a -> RewriteRule a | |
| ruleOfTerm t ann = | ||
| do let (vars, body) = R.asPiList t | ||
| case R.asGlobalApply eqIdent body of | ||
| Just [_, x, y] -> mkRewriteRule vars x y False ann | ||
| Just [_, x, y] -> mkRewriteRule vars x y False ann False | ||
| _ -> panic "ruleOfTerm" ["Illegal argument"] | ||
|
|
||
| -- Test whether a rewrite rule is permutative | ||
|
|
@@ -368,21 +370,22 @@ rulePermutes ctxt lhs rhs = | |
| Nothing -> False -- but here we have a looping rule, not good! | ||
| Just _ -> True | ||
|
|
||
| mkRewriteRule :: [(VarName, Term)] -> Term -> Term -> Bool -> Maybe a -> RewriteRule a | ||
| mkRewriteRule c l r shallow ann = | ||
| mkRewriteRule :: [(VarName, Term)] -> Term -> Term -> Bool -> Maybe a -> Bool -> RewriteRule a | ||
| mkRewriteRule c l r shallow ann convFlag = | ||
| RewriteRule | ||
| { ctxt = c | ||
| , lhs = l | ||
| , rhs = r | ||
| , permutative = rulePermutes c l r | ||
| , shallow = shallow | ||
| , annotation = ann | ||
| , convertible = convFlag | ||
| } | ||
|
|
||
| -- | Converts a universally quantified equality proposition between the | ||
| -- two given terms to a RewriteRule. | ||
| ruleOfTerms :: Term -> Term -> RewriteRule a | ||
| ruleOfTerms l r = mkRewriteRule [] l r False Nothing | ||
| ruleOfTerms l r = mkRewriteRule [] l r False Nothing False | ||
|
|
||
| -- | Converts a parameterized equality predicate to a RewriteRule, | ||
| -- returning 'Nothing' if the predicate is not an equation. | ||
|
|
@@ -425,7 +428,7 @@ ruleOfProp sc term ann = | |
| _ -> pure Nothing | ||
|
|
||
| where | ||
| eqRule x y = pure $ Just $ mkRewriteRule [] x y False ann | ||
| eqRule x y = pure $ Just $ mkRewriteRule [] x y False ann False | ||
|
|
||
| -- | Generate a rewrite rule from the type of an identifier, using 'ruleOfTerm' | ||
| scEqRewriteRule :: SharedContext -> Ident -> IO (RewriteRule a) | ||
|
|
@@ -442,19 +445,19 @@ scEqsRewriteRules sc = mapM (scEqRewriteRule sc) | |
| -- * If the rhs is a recursor, then split into a separate rule for each constructor. | ||
| -- * If the rhs is a record, then split into a separate rule for each accessor. | ||
| scExpandRewriteRule :: SharedContext -> RewriteRule a -> IO (Maybe [RewriteRule a]) | ||
| scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _ shallow ann) = | ||
| scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _ shallow ann convFlag) = | ||
| case R.asLambda rhs of | ||
| Just (nm, tp, body) -> | ||
| do let ctxt' = ctxt ++ [(nm, tp)] | ||
| var0 <- scVariable sc nm tp | ||
| lhs' <- scApply sc lhs var0 | ||
| pure $ Just [mkRewriteRule ctxt' lhs' body shallow ann] | ||
| pure $ Just [mkRewriteRule ctxt' lhs' body shallow ann convFlag] | ||
| Nothing -> | ||
| case rhs of | ||
| (R.asRecordValue -> Just m) -> | ||
| do let mkRule (k, x) = | ||
| do l <- scRecordSelect sc lhs k | ||
| return (mkRewriteRule ctxt l x shallow ann) | ||
| return (mkRewriteRule ctxt l x shallow ann convFlag) | ||
| Just <$> traverse mkRule (Map.assocs m) | ||
| (R.asApplyAll -> | ||
| (R.asRecursorApp -> Just (r, crec), | ||
|
|
@@ -495,9 +498,9 @@ scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _ shallow ann) = | |
| rhs2 <- scApplyAll sc rhs1 more' | ||
| rhs3 <- betaReduce rhs2 | ||
| -- re-fold recursive occurrences of the original rhs | ||
| let ss = addRule (mkRewriteRule ctxt rhs lhs shallow Nothing) emptySimpset | ||
| let ss = addRule (mkRewriteRule ctxt rhs lhs shallow Nothing convFlag) emptySimpset | ||
| (_,rhs') <- rewriteSharedTerm sc (ss :: Simpset ()) rhs3 | ||
| return (mkRewriteRule ctxt' lhs' rhs' shallow ann) | ||
| return (mkRewriteRule ctxt' lhs' rhs' shallow ann convFlag) | ||
| let d = recursorDataType crec | ||
| mm <- scGetModuleMap sc | ||
| dt <- | ||
|
|
@@ -557,7 +560,7 @@ scDefRewriteRules sc d = | |
| case defBody d of | ||
| Just rhs -> | ||
| do lhs <- scConst sc (defName d) | ||
| scExpandRewriteRules sc [mkRewriteRule [] lhs rhs False Nothing] | ||
| scExpandRewriteRules sc [mkRewriteRule [] lhs rhs False Nothing True] | ||
| Nothing -> | ||
| pure [] | ||
|
|
||
|
|
@@ -868,6 +871,131 @@ rewriteSharedTermTypeSafe sc ss t0 = | |
| Nothing -> apply rules t | ||
| Just tb -> rewriteAll =<< runTermBuilder tb (scGlobalDef sc) (scTermF sc) | ||
|
|
||
| data Convertibility = AllRules | ConvertibleRulesOnly | ||
|
|
||
| rewriteSharedTermConvertibility :: forall a. Ord a => SharedContext -> Simpset a -> Term -> IO (Set a, Term) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to have both |
||
| rewriteSharedTermConvertibility sc ss t0 = | ||
| do cache <- newCache | ||
| let ?cache = cache | ||
| setRef <- newIORef mempty | ||
| let ?annSet = setRef | ||
| t <- rewriteAll AllRules t0 | ||
| anns <- readIORef setRef | ||
| pure (anns, t) | ||
|
|
||
| where | ||
|
|
||
| rewriteAll :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Convertibility -> Term -> IO Term | ||
|
|
||
| rewriteAll convertibleFlag STApp{ stAppIndex = tidx, stAppTermF = tf } = | ||
| useCache ?cache tidx (rewriteTermF convertibleFlag tf >>= scTermF sc >>= rewriteTop convertibleFlag) | ||
|
|
||
| rewriteTermF :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => | ||
| Convertibility -> TermF Term -> IO (TermF Term) | ||
| rewriteTermF convertibleFlag tf = | ||
| case tf of | ||
| FTermF ftf -> FTermF <$> rewriteFTermF convertibleFlag ftf | ||
| App e1 e2 -> | ||
| do t1 <- scTypeOf sc e1 | ||
| case unwrapTermF t1 of | ||
| -- If type of e1 is not a dependent type, we can use any rule to rewrite e2 | ||
| -- otherwise, we only rewrite using convertible rules | ||
| -- This prevents rewriting e2 from changing type of @App e1 e2@. | ||
| Pi x _ t | ||
| | IntSet.notMember (vnIndex x) (freeVars t) -> | ||
| App <$> rewriteAll convertibleFlag e1 <*> rewriteAll convertibleFlag e2 | ||
| _ -> App <$> rewriteAll convertibleFlag e1 <*> rewriteAll ConvertibleRulesOnly e2 | ||
| -- could compute WHNF of t1 to see if it's a thing that doesn't match Pi but behaves like Pi | ||
| Lambda pat t e -> Lambda pat t <$> rewriteAll convertibleFlag e -- pat is x or varnames etc, t types (so don't rewrite), e is body | ||
| Constant{} -> return tf | ||
| Variable{} -> return tf | ||
| Pi x t1 t2 -> Pi x <$> rewriteAll convertibleFlag t1 <*> rewriteAll convertibleFlag t2 | ||
|
|
||
| rewriteFTermF :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => | ||
| Convertibility -> FlatTermF Term -> IO (FlatTermF Term) | ||
| rewriteFTermF convertibleFlag ftf = | ||
| case ftf of | ||
| UnitValue -> return ftf | ||
| UnitType -> return ftf | ||
| PairValue{} -> traverse (rewriteAll convertibleFlag) ftf | ||
| PairType{} -> traverse (rewriteAll convertibleFlag) ftf | ||
| PairLeft{} -> traverse (rewriteAll convertibleFlag) ftf | ||
| PairRight{} -> traverse (rewriteAll convertibleFlag) ftf | ||
|
|
||
| -- NOTE: we don't rewrite arguments of constructors, datatypes, or | ||
| -- recursors because of dependent types, as we could potentially cause | ||
| -- a term to become ill-typed | ||
jn80842 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Recursor{} -> return ftf -- just a function, atomic thus no subterms | ||
|
|
||
| RecordType{} -> traverse (rewriteAll convertibleFlag) ftf | ||
| RecordValue{} -> traverse (rewriteAll convertibleFlag) ftf | ||
| RecordProj{} -> traverse (rewriteAll convertibleFlag) ftf | ||
| Sort{} -> return ftf | ||
| NatLit{} -> return ftf | ||
| ArrayValue t es -> ArrayValue t <$> traverse (rewriteAll convertibleFlag) es -- specifically NOT rewriting type, only elts | ||
| StringLit{} -> return ftf | ||
|
|
||
| filterRules :: Convertibility -> Either (RewriteRule a) Conversion -> Bool | ||
| filterRules convertibleFlag (Left RewriteRule{convertible = ruleConvFlag}) = | ||
| case convertibleFlag of | ||
| ConvertibleRulesOnly -> ruleConvFlag | ||
| AllRules -> True | ||
| filterRules _ (Right _) = True | ||
|
|
||
| rewriteTop :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Convertibility -> Term -> IO Term | ||
| rewriteTop convertibleFlag t = | ||
| do mt <- reduceSharedTerm sc t | ||
| case mt of | ||
| Nothing -> apply convertibleFlag (filter (filterRules convertibleFlag) (Net.unify_term ss (termPat t))) t | ||
| Just t' -> rewriteAll convertibleFlag t' | ||
|
|
||
| recordAnn :: (?annSet :: IORef (Set a)) => Maybe a -> IO () | ||
| recordAnn Nothing = return () | ||
| recordAnn (Just a) = modifyIORef' ?annSet (Set.insert a) | ||
|
|
||
| apply :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => | ||
| Convertibility -> [Either (RewriteRule a) Conversion] -> Term -> IO Term | ||
| apply _ [] t = return t | ||
| apply convertibleFlag (Left (RewriteRule {ctxt, lhs, rhs, permutative, shallow, annotation}) : rules) t = do | ||
| -- if rewrite rule | ||
| result <- scMatch sc ctxt lhs t | ||
| case result of | ||
| Nothing -> apply convertibleFlag rules t | ||
| Just inst | ||
| | lhs == rhs -> | ||
| -- This should never happen because we avoid inserting | ||
| -- reflexive rules into simp sets in the first place. | ||
| do putStrLn $ "rewriteSharedTerm: skipping reflexive rule " ++ | ||
| "(THE IMPOSSIBLE HAPPENED!): " ++ scPrettyTerm PPS.defaultOpts lhs | ||
| apply convertibleFlag rules t | ||
| | IntMap.keysSet inst /= IntSet.fromList (map (vnIndex . fst) ctxt) -> | ||
| do putStrLn $ "rewriteSharedTerm: invalid lhs does not contain all variables: " | ||
| ++ scPrettyTerm PPS.defaultOpts lhs | ||
| apply convertibleFlag rules t | ||
| | permutative -> | ||
| do | ||
| t' <- scInstantiate sc inst rhs | ||
| case termWeightLt t' t of | ||
| True -> recordAnn annotation >> rewriteAll convertibleFlag t' -- keep the result only if it is "smaller" | ||
| False -> apply convertibleFlag rules t | ||
| | shallow -> | ||
| -- do not to further rewriting to the result of a "shallow" rule | ||
| do recordAnn annotation | ||
| scInstantiate sc inst rhs | ||
| | otherwise -> | ||
| do -- putStrLn "REWRITING:" | ||
| -- print lhs | ||
| recordAnn annotation | ||
| rewriteAll convertibleFlag =<< scInstantiate sc inst rhs | ||
| -- instead of syntactic rhs, has a bit of code that rewrites lhs (Term -> Maybe Term) | ||
| -- for now, assume all conversions are convertible, maybe check later | ||
| apply convertibleFlag (Right conv : rules) t = | ||
| do -- putStrLn "REWRITING:" | ||
| -- print (Net.toPat conv) | ||
| case runConversion conv t of | ||
| Nothing -> apply convertibleFlag rules t | ||
| Just tb -> rewriteAll convertibleFlag =<< runTermBuilder tb (scGlobalDef sc) (scTermF sc) | ||
|
|
||
|
|
||
| -- FIXME: is there some way to have sensable term replacement in the presence of loose variables | ||
| -- and/or under binders? | ||
|
|
@@ -924,7 +1052,7 @@ hoistIfs sc t = do | |
| ] | ||
| let ss :: Simpset () = addRules rules emptySimpset | ||
|
|
||
| (t', conds) <- doHoistIfs sc ss cache . snd =<< rewriteSharedTerm sc ss t | ||
| (t', conds) <- doHoistIfs sc ss cache . snd =<< rewriteSharedTermConvertibility sc ss t | ||
|
|
||
| -- remove duplicate conditions from the list, as muxing in SAW can result in | ||
| -- many copies of the same condition, which cause a performance issue | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.