Skip to content

Commit

Permalink
Make rewriteSharedTerm ensure that rule instantiations are well-typed.
Browse files Browse the repository at this point in the history
Fixes #1312.
  • Loading branch information
Brian Huffman committed Jun 22, 2021
1 parent 6f52028 commit 2e3ea56
Showing 1 changed file with 43 additions and 14 deletions.
57 changes: 43 additions & 14 deletions saw-core/src/Verifier/SAW/Rewriter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,8 @@ reduceSharedTerm _ _ = Nothing
-- and returned in the result set.
rewriteSharedTerm :: forall a. Ord a => SharedContext -> Simpset a -> Term -> IO (Set a, Term)
rewriteSharedTerm sc ss t0 =
do cache <- newCache
do let ?env = []
cache <- newCache
let ?cache = cache
setRef <- newIORef mempty
let ?annSet = setRef
Expand All @@ -618,17 +619,39 @@ rewriteSharedTerm sc ss t0 =
pure (anns, t)

where
rewriteAll :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Term -> IO Term
rewriteAll ::
(?env :: [Term], ?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
Term -> IO Term
rewriteAll (Unshared tf) =
traverseTF rewriteAll tf >>= scTermF sc >>= rewriteTop
rewriteSubterms tf >>= scTermF sc >>= rewriteTop
rewriteAll STApp{ stAppIndex = tidx, stAppTermF = tf } =
useCache ?cache tidx (traverseTF rewriteAll tf >>= scTermF sc >>= rewriteTop)

traverseTF :: forall b. (b -> IO b) -> TermF b -> IO (TermF b)
traverseTF _ tf@(Constant {}) = pure tf
traverseTF f tf = traverse f tf

rewriteTop :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Term -> IO Term
useCache ?cache tidx (rewriteSubterms tf >>= scTermF sc >>= rewriteTop)

rewriteSubterms ::
(?env :: [Term], ?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
TermF Term -> IO (TermF Term)
rewriteSubterms tf =
case tf of
FTermF ftf -> FTermF <$> traverse rewriteAll ftf
App t1 t2 -> App <$> rewriteAll t1 <*> rewriteAll t2
Lambda x t1 t2 ->
do t1' <- rewriteAll t1
localCache <- newCache
let localEnv = t1' : ?env
t2' <- let ?cache = localCache; ?env = localEnv in rewriteAll t2
pure (Lambda x t1' t2')
Pi x t1 t2 ->
do t1' <- rewriteAll t1
localCache <- newCache
let localEnv = t1' : ?env
t2' <- let ?cache = localCache; ?env = localEnv in rewriteAll t2
pure (Pi x t1' t2')
LocalVar{} -> pure tf
Constant{} -> pure tf

rewriteTop ::
(?env :: [Term], ?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
Term -> IO Term
rewriteTop t =
case reduceSharedTerm sc t of
Nothing -> apply (Net.unify_term ss t) t
Expand All @@ -638,8 +661,9 @@ rewriteSharedTerm sc ss t0 =
recordAnn Nothing = return ()
recordAnn (Just a) = modifyIORef' ?annSet (Set.insert a)

apply :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
[Either (RewriteRule a) Conversion] -> Term -> IO Term
apply ::
(?env :: [Term], ?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
[Either (RewriteRule a) Conversion] -> Term -> IO Term
apply [] t = return t
apply (Left (RewriteRule {ctxt, lhs, rhs, permutative, annotation}) : rules) t = do
result <- scMatch sc lhs t
Expand All @@ -665,8 +689,13 @@ rewriteSharedTerm sc ss t0 =
| otherwise ->
do -- putStrLn "REWRITING:"
-- print lhs
recordAnn annotation
rewriteAll =<< instantiateVarList sc 0 (Map.elems inst) rhs
tys <- traverse (scTypeOf' sc ?env) (Map.elems inst)
if tys /= ctxt
then
do apply rules t
else
do recordAnn annotation
rewriteAll =<< instantiateVarList sc 0 (Map.elems inst) rhs
apply (Right conv : rules) t =
do -- putStrLn "REWRITING:"
-- print (Net.toPat conv)
Expand Down

0 comments on commit 2e3ea56

Please sign in to comment.