Skip to content

Commit f5419ec

Browse files
committed
Teach the PIR beta pass to handle multiple arguments at once
This has bothered me for a while. We rely on this to turn function applications into let-bindings so that we can inline them, but because of the way it was written it could only ever do this for at most one argument at a time! So we could only fully-inline a function of arity N with N rounds of the simplifier! Terrible! Now it handles any number at once. I wrote all the functions for this while working on #4365, since there we *really* need it, since constructors and destructors appear like multiple function arguments rather than multiple nested immediately-applied-lambdas. But it applies well here too, so I thought I'd just do this improvement on the side.
1 parent 3619837 commit f5419ec

File tree

8 files changed

+133
-48
lines changed

8 files changed

+133
-48
lines changed

plutus-core/plutus-ir/src/PlutusIR/Core/Type.hs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ module PlutusIR.Core.Type (
1818
Binding (..),
1919
Term (..),
2020
Program (..),
21-
applyProgram
21+
applyProgram,
22+
termAnn
2223
) where
2324

2425
import PlutusPrelude
@@ -158,3 +159,17 @@ applyProgram
158159
-> Program tyname name uni fun a
159160
-> Program tyname name uni fun a
160161
applyProgram (Program a1 t1) (Program a2 t2) = Program (a1 <> a2) (Apply mempty t1 t2)
162+
163+
termAnn :: Term tynam name uni fun a -> a
164+
termAnn t = case t of
165+
Let a _ _ _ -> a
166+
Var a _ -> a
167+
TyAbs a _ _ _ -> a
168+
LamAbs a _ _ _ -> a
169+
Apply a _ _ -> a
170+
Constant a _ -> a
171+
Builtin a _ -> a
172+
TyInst a _ _ -> a
173+
Error a _ -> a
174+
IWrap a _ _ _ -> a
175+
Unwrap a _ -> a
Lines changed: 84 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,84 @@
1-
{-# LANGUAGE LambdaCase #-}
1+
{-# LANGUAGE LambdaCase #-}
2+
{-# LANGUAGE ViewPatterns #-}
23
{-|
34
A simple beta-reduction pass.
45
-}
56
module PlutusIR.Transform.Beta (
67
beta
78
) where
89

9-
import PlutusPrelude
10-
1110
import PlutusIR
11+
import PlutusIR.Core.Type
1212

13-
import Control.Lens (transformOf)
13+
import Control.Lens.Setter ((%~))
14+
import Data.Function ((&))
15+
import Data.List.NonEmpty qualified as NE
1416

15-
{-|
16-
A single non-recursive application of the beta rule.
17+
{- Note [Multi-beta]
18+
Consider two examples where applying beta should be helpful.
19+
20+
1: [(\x . [(\y . t) b]) a]
21+
2: [[(\x . (\y . t)) a] b]
22+
23+
(1) is the typical "let-binding" pattern: each binding corresponds to an immediately-applied lambda.
24+
(2) is the typical "function application" pattern: a multi-argument function applied to multiple arguments.
25+
26+
In both cases we would like to produce something like
27+
28+
let
29+
x = a
30+
y = b
31+
in t
32+
33+
However, if we naively do a bottom-up pattern-matching transformation on the AST
34+
to look for immediately-applied lambda abstractions then we will get the following:
35+
36+
1:
37+
[(\x . [(\y . t) b]) a]
38+
-->
39+
[(\x . let y = b in t) a]
40+
->
41+
let x = a in let y = b in t
42+
43+
2:
44+
[[(\x . (\y . t)) a] b]
45+
-->
46+
[(let x = a in (\y . t)) b]
47+
48+
Now, if we later lift the let out, then we will be able to see that we can transform (2) further.
49+
But that means that a) we'd have to do the expensive let-floating pass in every iteration of the simplifier, and
50+
b) we can only inline one function argument per iteration of the simplifier, so for a function of
51+
arity N we *must* do at least N passes.
52+
53+
This isn't great, so the solution is to recognize case (2) properly and handle all the arguments in one go.
54+
That will also match cases like (1) just fine, since it's just made up of unary function applications.
55+
56+
That does mean that we need to do a manual traversal rather than doing standard bottom-up processing.
1757
-}
18-
betaStep
19-
:: Term tyname name uni fun a
20-
-> Term tyname name uni fun a
21-
betaStep = \case
22-
Apply a (LamAbs _ name typ body) arg ->
23-
let varDecl = VarDecl a name typ
24-
binding = TermBind a Strict varDecl arg
25-
bindings = binding :| []
26-
in
27-
Let a NonRec bindings body
28-
-- This case is disabled as it introduces a lot of type inlining (determined from profiling)
29-
-- and is currently unsound https://input-output.atlassian.net/browse/SCP-2570.
30-
-- TyInst a (TyAbs _ tyname kind body) typ ->
31-
-- let tyVarDecl = TyVarDecl a tyname kind
32-
-- tyBinding = TypeBind a tyVarDecl typ
33-
-- bindings = tyBinding :| []
34-
-- in
35-
-- Let a NonRec bindings body
36-
t -> t
58+
59+
{-| Extract the list of bindings from a term, a bit like a "multi-beta" reduction.
60+
61+
Some examples will help:
62+
63+
[(\x . t) a] -> Just ([x |-> a], t)
64+
65+
[[[(\x . (\y . (\z . t))) a] b] c] -> Just ([x |-> a, y |-> b, z |-> c]) t)
66+
67+
[[(\x . t) a] b] -> Nothing
68+
69+
When we decide that we want to do beta for types, we will need to extend this to handle type instantiations too.
70+
-}
71+
extractBindings :: Term tyname name uni fun a -> Maybe (NE.NonEmpty (Binding tyname name uni fun a), Term tyname name uni fun a)
72+
extractBindings = collectArgs []
73+
where
74+
collectArgs argStack (Apply _ f arg) = collectArgs (arg:argStack) f
75+
collectArgs argStack t = matchArgs argStack [] t
76+
matchArgs (arg:rest) acc (LamAbs a n ty body) = matchArgs rest (TermBind a Strict (VarDecl a n ty) arg:acc) body
77+
matchArgs [] acc t =
78+
case NE.nonEmpty (reverse acc) of
79+
Nothing -> Nothing
80+
Just acc' -> Just (acc', t)
81+
matchArgs (_:_) _ _ = Nothing
3782

3883
{-|
3984
Recursively apply the beta transformation on the code, both for the terms
@@ -57,4 +102,17 @@ and types
57102
beta
58103
:: Term tyname name uni fun a
59104
-> Term tyname name uni fun a
60-
beta = transformOf termSubterms betaStep
105+
beta = \case
106+
-- See Note [Multi-beta]
107+
-- This maybe isn't the best annotation for this term, but it will do.
108+
(extractBindings -> Just (bs, t)) -> Let (termAnn t) NonRec bs (beta t)
109+
-- This case is disabled as it introduces a lot of type inlining (determined from profiling)
110+
-- and is currently unsound https://input-output.atlassian.net/browse/SCP-2570.
111+
-- TyInst a (TyAbs _ tyname kind body) typ ->
112+
-- let tyVarDecl = TyVarDecl a tyname kind
113+
-- tyBinding = TypeBind a tyVarDecl typ
114+
-- bindings = tyBinding :| []
115+
-- in
116+
-- Let a NonRec bindings body
117+
118+
t -> t & termSubterms %~ beta

plutus-core/plutus-ir/test/TransformSpec.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ beta =
146146
$ map (goldenPir (Beta.beta . runQuote . PLC.rename) $ term @PLC.DefaultUni @PLC.DefaultFun)
147147
[ "lamapp"
148148
, "absapp"
149+
, "multiapp"
150+
, "multilet"
149151
]
150152

151153
unwrapCancel :: TestNested

plutus-core/plutus-ir/test/recursion/stupidZero.golden

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -84,29 +84,17 @@
8484
x_i0
8585
Nat_i6
8686
[
87-
[
88-
(lam
89-
stupidZero_i0
90-
(fun Nat_i7 Nat_i7)
87+
(lam
88+
stupidZero_i0
89+
(fun Nat_i7 Nat_i7)
90+
[
91+
[ { [ match_Nat_i4 x_i2 ] Nat_i7 } Zero_i6 ]
9192
(lam
92-
n_i0
93-
Nat_i8
94-
[
95-
[
96-
{ [ match_Nat_i5 n_i1 ] Nat_i8 }
97-
Zero_i7
98-
]
99-
(lam
100-
pred_i0
101-
Nat_i9
102-
[ stupidZero_i3 pred_i1 ]
103-
)
104-
]
93+
pred_i0 Nat_i8 [ stupidZero_i2 pred_i1 ]
10594
)
106-
)
107-
[ (unwrap s_i2) s_i2 ]
108-
]
109-
x_i1
95+
]
96+
)
97+
[ (unwrap s_i2) s_i2 ]
11098
]
11199
)
112100
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[ (lam x (con integer) (lam y (con integer) (lam z (con integer) [ y x z ]))) (con integer 1) (con integer 2) (con integer 3) ]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
(let
2+
(nonrec)
3+
(termbind (strict) (vardecl x (con integer)) (con integer 1))
4+
(termbind (strict) (vardecl y (con integer)) (con integer 2))
5+
(termbind (strict) (vardecl z (con integer)) (con integer 3))
6+
[ [ y x ] z ]
7+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[ (lam x (con integer) [ (lam y (con integer) [ (lam z (con integer) [ y x z ]) (con integer 3) ] ) (con integer 2) ]) (con integer 1)]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
(let
2+
(nonrec)
3+
(termbind (strict) (vardecl x (con integer)) (con integer 1))
4+
(let
5+
(nonrec)
6+
(termbind (strict) (vardecl y (con integer)) (con integer 2))
7+
(let
8+
(nonrec)
9+
(termbind (strict) (vardecl z (con integer)) (con integer 3))
10+
[ [ y x ] z ]
11+
)
12+
)
13+
)

0 commit comments

Comments
 (0)