Skip to content

Commit

Permalink
Add option to use inaccessible accumulator var (#1097)
Browse files Browse the repository at this point in the history
Add an option to use '@Result' as the accumulator variable for builtin
comprehensions. The current default "__result__" is accesible in the CEL
Source, allowing for expressions to type check but lead to unexpected or
incorrect results. '@Result' isn't a normally accessible identifier in the
source expression so a bit safer as a default.
  • Loading branch information
jnthntatum authored Jan 7, 2025
1 parent 7c5909e commit fa6eb32
Show file tree
Hide file tree
Showing 11 changed files with 331 additions and 38 deletions.
15 changes: 3 additions & 12 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,18 +777,9 @@ func TestMacroInterop(t *testing.T) {
}

func TestMacroModern(t *testing.T) {
existsOneMacro := ReceiverMacro("exists_one", 2,
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return parser.MakeExistsOne(mef, iterRange, args)
})
transformMacro := ReceiverMacro("transform", 2,
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return parser.MakeMap(mef, iterRange, args)
})
filterMacro := ReceiverMacro("filter", 2,
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return parser.MakeFilter(mef, iterRange, args)
})
existsOneMacro := ReceiverMacro("exists_one", 2, parser.MakeExistsOne)
transformMacro := ReceiverMacro("transform", 2, parser.MakeMap)
filterMacro := ReceiverMacro("filter", 2, parser.MakeFilter)
pairMacro := GlobalMacro("pair", 2,
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
return mef.NewMap(mef.NewMapEntry(args[0], args[1], false)), nil
Expand Down
9 changes: 9 additions & 0 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,15 @@ func ParserExpressionSizeLimit(limit int) EnvOption {
}
}

// EnableHiddenAccumulatorName sets the parser to use the identifier '@result' for accumulators
// which is not normally accessible from CEL source.
func EnableHiddenAccumulatorName(enabled bool) EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.EnableHiddenAccumulatorName(enabled))
return e, nil
}
}

func maybeInteropProvider(provider any) (types.Provider, error) {
switch p := provider.(type) {
case types.Provider:
Expand Down
6 changes: 5 additions & 1 deletion checker/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,9 +672,13 @@ func (c *coster) addPath(e ast.Expr, path []string) {
c.exprPath[e.ID()] = path
}

func isAccumulatorVar(name string) bool {
return name == parser.AccumulatorName || name == parser.HiddenAccumulatorName
}

func (c *coster) newAstNode(e ast.Expr) *astNode {
path := c.getPath(e)
if len(path) > 0 && path[0] == parser.AccumulatorName {
if len(path) > 0 && isAccumulatorVar(path[0]) {
// only provide paths to root vars; omit accumulator vars
path = nil
}
Expand Down
25 changes: 22 additions & 3 deletions common/ast/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ type ExprFactory interface {
//comprehension.
NewAccuIdent(id int64) Expr

// AccuIdentName reports the name of the accumulator variable to be used within a comprehension.
AccuIdentName() string

// NewLiteral creates an Expr value representing a literal value, such as a string or integer.
NewLiteral(id int64, value ref.Val) Expr

Expand Down Expand Up @@ -78,11 +81,23 @@ type ExprFactory interface {
isExprFactory()
}

type baseExprFactory struct{}
type baseExprFactory struct {
accumulatorName string
}

// NewExprFactory creates an ExprFactory instance.
func NewExprFactory() ExprFactory {
return &baseExprFactory{}
return &baseExprFactory{
"__result__",
}
}

// NewExprFactoryWithAccumulator creates an ExprFactory instance with a custom
// accumulator identifier name.
func NewExprFactoryWithAccumulator(id string) ExprFactory {
return &baseExprFactory{
id,
}
}

func (fac *baseExprFactory) NewCall(id int64, function string, args ...Expr) Expr {
Expand Down Expand Up @@ -138,7 +153,11 @@ func (fac *baseExprFactory) NewIdent(id int64, name string) Expr {
}

func (fac *baseExprFactory) NewAccuIdent(id int64) Expr {
return fac.NewIdent(id, "__result__")
return fac.NewIdent(id, fac.AccuIdentName())
}

func (fac *baseExprFactory) AccuIdentName() string {
return fac.accumulatorName
}

func (fac *baseExprFactory) NewLiteral(id int64, value ref.Val) Expr {
Expand Down
28 changes: 14 additions & 14 deletions ext/comprehensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func quantifierAll(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
mef.AccuIdentName(),
/*accuInit=*/ mef.NewLiteral(types.True),
/*condition=*/ mef.NewCall(operators.NotStrictlyFalse, mef.NewAccuIdent()),
/*step=*/ mef.NewCall(operators.LogicalAnd, mef.NewAccuIdent(), args[2]),
Expand All @@ -267,7 +267,7 @@ func quantifierExists(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
mef.AccuIdentName(),
/*accuInit=*/ mef.NewLiteral(types.False),
/*condition=*/ mef.NewCall(operators.NotStrictlyFalse, mef.NewCall(operators.LogicalNot, mef.NewAccuIdent())),
/*step=*/ mef.NewCall(operators.LogicalOr, mef.NewAccuIdent(), args[2]),
Expand All @@ -285,7 +285,7 @@ func quantifierExistsOne(mef cel.MacroExprFactory, target ast.Expr, args []ast.E
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
mef.AccuIdentName(),
/*accuInit=*/ mef.NewLiteral(types.Int(0)),
/*condition=*/ mef.NewLiteral(types.True),
/*step=*/ mef.NewCall(operators.Conditional, args[2],
Expand All @@ -311,18 +311,18 @@ func transformList(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (
transform = args[2]
}

// __result__ = __result__ + [transform]
// accumulator = accumulator + [transform]
step := mef.NewCall(operators.Add, mef.NewAccuIdent(), mef.NewList(transform))
if filter != nil {
// __result__ = (filter) ? __result__ + [transform] : __result__
// accumulator = (filter) ? accumulator + [transform] : accumulator
step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent())
}

return mef.NewComprehensionTwoVar(
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
mef.AccuIdentName(),
/*accuInit=*/ mef.NewList(),
/*condition=*/ mef.NewLiteral(types.True),
step,
Expand All @@ -346,17 +346,17 @@ func transformMap(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (a
transform = args[2]
}

// __result__ = cel.@mapInsert(__result__, iterVar1, transform)
// accumulator = cel.@mapInsert(accumulator, iterVar1, transform)
step := mef.NewCall(mapInsert, mef.NewAccuIdent(), mef.NewIdent(iterVar1), transform)
if filter != nil {
// __result__ = (filter) ? cel.@mapInsert(__result__, iterVar1, transform) : __result__
// accumulator = (filter) ? cel.@mapInsert(accumulator, iterVar1, transform) : accumulator
step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent())
}
return mef.NewComprehensionTwoVar(
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
mef.AccuIdentName(),
/*accuInit=*/ mef.NewMap(),
/*condition=*/ mef.NewLiteral(types.True),
step,
Expand All @@ -380,17 +380,17 @@ func transformMapEntry(mef cel.MacroExprFactory, target ast.Expr, args []ast.Exp
transform = args[2]
}

// __result__ = cel.@mapInsert(__result__, transform)
// accumulator = cel.@mapInsert(accumulator, transform)
step := mef.NewCall(mapInsert, mef.NewAccuIdent(), transform)
if filter != nil {
// __result__ = (filter) ? cel.@mapInsert(__result__, transform) : __result__
// accumulator = (filter) ? cel.@mapInsert(accumulator, transform) : accumulator
step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent())
}
return mef.NewComprehensionTwoVar(
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
mef.AccuIdentName(),
/*accuInit=*/ mef.NewMap(),
/*condition=*/ mef.NewLiteral(types.True),
step,
Expand All @@ -410,10 +410,10 @@ func extractIterVars(mef cel.MacroExprFactory, arg0, arg1 ast.Expr) (string, str
if iterVar1 == iterVar2 {
return "", "", mef.NewError(arg1.ID(), fmt.Sprintf("duplicate variable name: %s", iterVar1))
}
if iterVar1 == parser.AccumulatorName {
if iterVar1 == mef.AccuIdentName() || iterVar1 == parser.AccumulatorName {
return "", "", mef.NewError(arg0.ID(), "iteration variable overwrites accumulator variable")
}
if iterVar2 == parser.AccumulatorName {
if iterVar2 == mef.AccuIdentName() || iterVar2 == parser.AccumulatorName {
return "", "", mef.NewError(arg1.ID(), "iteration variable overwrites accumulator variable")
}
return iterVar1, iterVar2, nil
Expand Down
4 changes: 3 additions & 1 deletion ext/comprehensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,9 @@ func testCompreEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env {
Lists(),
Strings(),
cel.OptionalTypes(),
cel.EnableMacroCallTracking()}
cel.EnableMacroCallTracking(),
cel.EnableHiddenAccumulatorName(true),
}
env, err := cel.NewEnv(append(baseOpts, opts...)...)
if err != nil {
t.Fatalf("cel.NewEnv(TwoVarComprehensions()) failed: %v", err)
Expand Down
5 changes: 5 additions & 0 deletions parser/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,11 @@ func (e *exprHelper) NewAccuIdent() ast.Expr {
return e.exprFactory.NewAccuIdent(e.nextMacroID())
}

// AccuIdentName implements the ExprHelper interface method.
func (e *exprHelper) AccuIdentName() string {
return e.exprFactory.AccuIdentName()
}

// NewGlobalCall implements the ExprHelper interface method.
func (e *exprHelper) NewCall(function string, args ...ast.Expr) ast.Expr {
return e.exprFactory.NewCall(e.nextMacroID(), function, args...)
Expand Down
23 changes: 17 additions & 6 deletions parser/macro.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ type ExprHelper interface {
// NewAccuIdent returns an accumulator identifier for use with comprehension results.
NewAccuIdent() ast.Expr

// AccuIdentName returns the name of the accumulator identifier.
AccuIdentName() string

// NewCall creates a function call Expr value for a global (free) function.
NewCall(function string, args ...ast.Expr) ast.Expr

Expand Down Expand Up @@ -298,6 +301,11 @@ var (
// AccumulatorName is the traditional variable name assigned to the fold accumulator variable.
const AccumulatorName = "__result__"

// HiddenAccumulatorName is a proposed update to the default fold accumlator variable.
// @result is not normally accessible from source, preventing accidental or intentional collisions
// in user expressions.
const HiddenAccumulatorName = "@result"

type quantifierKind int

const (
Expand Down Expand Up @@ -342,7 +350,8 @@ func MakeMap(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common
if !found {
return nil, eh.NewError(args[0].ID(), "argument is not an identifier")
}
if v == AccumulatorName {
accu := eh.AccuIdentName()
if v == accu || v == AccumulatorName {
return nil, eh.NewError(args[0].ID(), "iteration variable overwrites accumulator variable")
}

Expand All @@ -364,7 +373,7 @@ func MakeMap(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common
if filter != nil {
step = eh.NewCall(operators.Conditional, filter, step, eh.NewAccuIdent())
}
return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, eh.NewAccuIdent()), nil
return eh.NewComprehension(target, v, accu, init, condition, step, eh.NewAccuIdent()), nil
}

// MakeFilter expands the input call arguments into a comprehension which produces a list which contains
Expand All @@ -375,7 +384,8 @@ func MakeFilter(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *com
if !found {
return nil, eh.NewError(args[0].ID(), "argument is not an identifier")
}
if v == AccumulatorName {
accu := eh.AccuIdentName()
if v == accu || v == AccumulatorName {
return nil, eh.NewError(args[0].ID(), "iteration variable overwrites accumulator variable")
}

Expand All @@ -384,7 +394,7 @@ func MakeFilter(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *com
condition := eh.NewLiteral(types.True)
step := eh.NewCall(operators.Add, eh.NewAccuIdent(), eh.NewList(args[0]))
step = eh.NewCall(operators.Conditional, filter, step, eh.NewAccuIdent())
return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, eh.NewAccuIdent()), nil
return eh.NewComprehension(target, v, accu, init, condition, step, eh.NewAccuIdent()), nil
}

// MakeHas expands the input call arguments into a presence test, e.g. has(<operand>.field)
Expand All @@ -401,7 +411,8 @@ func makeQuantifier(kind quantifierKind, eh ExprHelper, target ast.Expr, args []
if !found {
return nil, eh.NewError(args[0].ID(), "argument must be a simple name")
}
if v == AccumulatorName {
accu := eh.AccuIdentName()
if v == accu || v == AccumulatorName {
return nil, eh.NewError(args[0].ID(), "iteration variable overwrites accumulator variable")
}

Expand Down Expand Up @@ -431,7 +442,7 @@ func makeQuantifier(kind quantifierKind, eh ExprHelper, target ast.Expr, args []
default:
return nil, eh.NewError(args[0].ID(), fmt.Sprintf("unrecognized quantifier '%v'", kind))
}
return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, result), nil
return eh.NewComprehension(target, v, accu, init, condition, step, result), nil
}

func extractIdent(e ast.Expr) (string, bool) {
Expand Down
13 changes: 13 additions & 0 deletions parser/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type options struct {
enableOptionalSyntax bool
enableVariadicOperatorASTs bool
enableIdentEscapeSyntax bool
enableHiddenAccumulatorName bool
}

// Option configures the behavior of the parser.
Expand Down Expand Up @@ -137,6 +138,18 @@ func EnableIdentEscapeSyntax(enableIdentEscapeSyntax bool) Option {
}
}

// EnableHiddenAccumulatorName uses an accumulator variable name that is not a
// normally accessible identifier in source for comprehension macros. Compatibility notes:
// with this option enabled, a parsed AST would be semantically the same as if disabled, but would
// have different internal identifiers in any of the built-in comprehension sub-expressions. When
// disabled, it is possible but almost certainly a logic error to access the accumulator variable.
func EnableHiddenAccumulatorName(enabled bool) Option {
return func(opts *options) error {
opts.enableHiddenAccumulatorName = enabled
return nil
}
}

// EnableVariadicOperatorASTs enables a compact representation of chained like-kind commutative
// operators. e.g. `a || b || c || d` -> `call(op='||', args=[a, b, c, d])`
//
Expand Down
6 changes: 5 additions & 1 deletion parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ func mustNewParser(opts ...Option) *Parser {
// Parse parses the expression represented by source and returns the result.
func (p *Parser) Parse(source common.Source) (*ast.AST, *common.Errors) {
errs := common.NewErrors(source)
fac := ast.NewExprFactory()
accu := AccumulatorName
if p.enableHiddenAccumulatorName {
accu = HiddenAccumulatorName
}
fac := ast.NewExprFactoryWithAccumulator(accu)
impl := parser{
errors: &parseErrors{errs},
exprFactory: fac,
Expand Down
Loading

0 comments on commit fa6eb32

Please sign in to comment.