Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,41 @@ func TestContextEval(t *testing.T) {
}
}

func TestContextEvalPropagation(t *testing.T) {
env, err := NewEnv(Function("test",
Overload("test_int", []*Type{}, IntType,
FunctionBindingContext(func(ctx context.Context, _ ...ref.Val) ref.Val {
md := ctx.Value("metadata")
if md == nil {
return types.NewErr("cannot find metadata value")
}
return types.Int(md.(int))
}),
),
))
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
ast, iss := env.Compile("test()")
if iss.Err() != nil {
t.Fatalf("env.Compile(expr) failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}

expected := 10
ctx := context.WithValue(context.Background(), "metadata", expected)
out, _, err := prg.ContextEval(ctx, map[string]interface{}{})
if err != nil {
t.Fatalf("prg.ContextEval() failed: %v", err)
}
if out != types.Int(expected) {
t.Errorf("prg.ContextEval() got %v, but wanted %d", out, expected)
}
}

func BenchmarkContextEval(b *testing.B) {
env := testEnv(b,
Variable("items", ListType(IntType)),
Expand Down Expand Up @@ -1428,7 +1463,7 @@ func TestCustomInterpreterDecorator(t *testing.T) {
if !lhsIsConst || !rhsIsConst {
return i, nil
}
val := call.Eval(interpreter.EmptyActivation())
val := call.Eval(context.Background(), interpreter.EmptyActivation())
if types.IsError(val) {
return nil, val.(*types.Err)
}
Expand Down
18 changes: 18 additions & 0 deletions cel/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,24 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
return decls.FunctionBinding(binding)
}

// UnaryBindingContext provides the implementation of a unary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func UnaryBindingContext(binding functions.UnaryContextOp) OverloadOpt {
return decls.UnaryBindingContext(binding)
}

// BinaryBindingContext provides the implementation of a binary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func BinaryBindingContext(binding functions.BinaryContextOp) OverloadOpt {
return decls.BinaryBindingContext(binding)
}

// FunctionBindingContext provides the implementation of a variadic overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func FunctionBindingContext(binding functions.FunctionContextOp) OverloadOpt {
return decls.FunctionBindingContext(binding)
}

// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
//
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.
Expand Down
5 changes: 3 additions & 2 deletions cel/decls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package cel

import (
"context"
"fmt"
"math"
"reflect"
Expand Down Expand Up @@ -673,7 +674,7 @@ func TestExprDeclToDeclaration(t *testing.T) {
}
prg, err := e.Program(ast, Functions(&functions.Overload{
Operator: overloads.SizeString,
Unary: func(arg ref.Val) ref.Val {
Unary: func(ctx context.Context, arg ref.Val) ref.Val {
str, ok := arg.(types.String)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
Expand All @@ -682,7 +683,7 @@ func TestExprDeclToDeclaration(t *testing.T) {
},
}, &functions.Overload{
Operator: overloads.SizeStringInst,
Unary: func(arg ref.Val) ref.Val {
Unary: func(ctx context.Context, arg ref.Val) ref.Val {
str, ok := arg.(types.String)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
Expand Down
13 changes: 7 additions & 6 deletions cel/library.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package cel

import (
"context"
"math"
"strconv"
"strings"
Expand Down Expand Up @@ -494,17 +495,17 @@ func (opt *evalOptionalOr) ID() int64 {

// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val {
func (opt *evalOptionalOr) Eval(ctx context.Context, vars interpreter.Activation) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optLHS := opt.lhs.Eval(ctx, vars)
optVal, ok := optLHS.(*types.Optional)
if !ok {
return optLHS
}
if optVal.HasValue() {
return optVal
}
return opt.rhs.Eval(ctx)
return opt.rhs.Eval(ctx, vars)
}

// evalOptionalOrValue selects between an optional or a concrete value. If the optional has a value,
Expand All @@ -522,17 +523,17 @@ func (opt *evalOptionalOrValue) ID() int64 {

// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val {
func (opt *evalOptionalOrValue) Eval(ctx context.Context, vars interpreter.Activation) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optLHS := opt.lhs.Eval(ctx, vars)
optVal, ok := optLHS.(*types.Optional)
if !ok {
return optLHS
}
if optVal.HasValue() {
return optVal.GetValue()
}
return opt.rhs.Eval(ctx)
return opt.rhs.Eval(ctx, vars)
}

type timeUTCLibrary struct{}
Expand Down
16 changes: 13 additions & 3 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorat

// Eval implements the Program interface method.
func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
return p.eval(context.Background(), input)
}

// Eval implements the Program interface method.
func (p *prog) eval(ctx context.Context, input any) (v ref.Val, det *EvalDetails, err error) {
// Configure error recovery for unexpected panics during evaluation. Note, the use of named
// return values makes it possible to modify the error response during the recovery
// function.
Expand Down Expand Up @@ -291,7 +296,7 @@ func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
if p.defaultVars != nil {
vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars)
}
v = p.interpretable.Eval(vars)
v = p.interpretable.Eval(ctx, vars)
// The output of an internal Eval may have a value (`v`) that is a types.Err. This step
// translates the CEL value to a Go error response. This interface does not quite match the
// RPC signature which allows for multiple errors to be returned, but should be sufficient.
Expand Down Expand Up @@ -321,7 +326,7 @@ func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetail
default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input)
}
return p.Eval(vars)
return p.eval(ctx, vars)
}

// progFactory is a helper alias for marking a program creation factory function.
Expand Down Expand Up @@ -349,6 +354,11 @@ func newProgGen(factory progFactory) (Program, error) {

// Eval implements the Program interface method.
func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
return gen.eval(context.Background(), input)
}

// Eval implements the Program interface method.
func (gen *progGen) eval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
// The factory based Eval() differs from the standard evaluation model in that it generates a
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
Expand All @@ -368,7 +378,7 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
}

// Evaluate the input, returning the result and the 'state' within EvalDetails.
v, _, err := p.Eval(input)
v, _, err := p.ContextEval(ctx, input)
if err != nil {
return v, det, err
}
Expand Down
83 changes: 57 additions & 26 deletions common/decls/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package decls

import (
"context"
"fmt"
"strings"

Expand Down Expand Up @@ -242,23 +243,23 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
// All of the defined overloads are wrapped into a top-level function which
// performs dynamic dispatch to the proper overload based on the argument types.
bindings := append([]*functions.Overload{}, overloads...)
funcDispatch := func(args ...ref.Val) ref.Val {
funcDispatch := func(ctx context.Context, args ...ref.Val) ref.Val {
for _, oID := range f.overloadOrdinals {
o := f.overloads[oID]
// During dynamic dispatch over multiple functions, signature agreement checks
// are preserved in order to assist with the function resolution step.
switch len(args) {
case 1:
if o.unaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
return o.unaryOp(args[0])
return o.unaryOp(ctx, args[0])
}
case 2:
if o.binaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
return o.binaryOp(args[0], args[1])
return o.binaryOp(ctx, args[0], args[1])
}
}
if o.functionOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
return o.functionOp(args...)
return o.functionOp(ctx, args...)
}
// eventually this will fall through to the noSuchOverload below.
}
Expand Down Expand Up @@ -333,8 +334,10 @@ func SingletonUnaryBinding(fn functions.UnaryOp, traits ...int) FunctionOpt {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Unary: fn,
Operator: f.Name(),
Unary: func(ctx context.Context, val ref.Val) ref.Val {
return fn(val)
},
OperandTrait: trait,
}
return f, nil
Expand All @@ -355,8 +358,10 @@ func SingletonBinaryBinding(fn functions.BinaryOp, traits ...int) FunctionOpt {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Binary: fn,
Operator: f.Name(),
Binary: func(ctx context.Context, lhs ref.Val, rhs ref.Val) ref.Val {
return fn(lhs, rhs)
},
OperandTrait: trait,
}
return f, nil
Expand All @@ -377,8 +382,10 @@ func SingletonFunctionBinding(fn functions.FunctionOp, traits ...int) FunctionOp
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Function: fn,
Operator: f.Name(),
Function: func(ctx context.Context, values ...ref.Val) ref.Val {
return fn(values...)
},
OperandTrait: trait,
}
return f, nil
Expand Down Expand Up @@ -460,11 +467,11 @@ type OverloadDecl struct {

// Function implementation options. Optional, but encouraged.
// unaryOp is a function binding that takes a single argument.
unaryOp functions.UnaryOp
unaryOp functions.UnaryContextOp
// binaryOp is a function binding that takes two arguments.
binaryOp functions.BinaryOp
binaryOp functions.BinaryContextOp
// functionOp is a catch-all for zero-arity and three-plus arity functions.
functionOp functions.FunctionOp
functionOp functions.FunctionContextOp
}

// ID mirrors the overload signature and provides a unique id which may be referenced within the type-checker
Expand Down Expand Up @@ -580,41 +587,41 @@ func (o *OverloadDecl) hasBinding() bool {
}

// guardedUnaryOp creates an invocation guard around the provided unary operator, if one is defined.
func (o *OverloadDecl) guardedUnaryOp(funcName string, disableTypeGuards bool) functions.UnaryOp {
func (o *OverloadDecl) guardedUnaryOp(funcName string, disableTypeGuards bool) functions.UnaryContextOp {
if o.unaryOp == nil {
return nil
}
return func(arg ref.Val) ref.Val {
return func(ctx context.Context, arg ref.Val) ref.Val {
if !o.matchesRuntimeUnarySignature(disableTypeGuards, arg) {
return MaybeNoSuchOverload(funcName, arg)
}
return o.unaryOp(arg)
return o.unaryOp(ctx, arg)
}
}

// guardedBinaryOp creates an invocation guard around the provided binary operator, if one is defined.
func (o *OverloadDecl) guardedBinaryOp(funcName string, disableTypeGuards bool) functions.BinaryOp {
func (o *OverloadDecl) guardedBinaryOp(funcName string, disableTypeGuards bool) functions.BinaryContextOp {
if o.binaryOp == nil {
return nil
}
return func(arg1, arg2 ref.Val) ref.Val {
return func(ctx context.Context, arg1, arg2 ref.Val) ref.Val {
if !o.matchesRuntimeBinarySignature(disableTypeGuards, arg1, arg2) {
return MaybeNoSuchOverload(funcName, arg1, arg2)
}
return o.binaryOp(arg1, arg2)
return o.binaryOp(ctx, arg1, arg2)
}
}

// guardedFunctionOp creates an invocation guard around the provided variadic function binding, if one is provided.
func (o *OverloadDecl) guardedFunctionOp(funcName string, disableTypeGuards bool) functions.FunctionOp {
func (o *OverloadDecl) guardedFunctionOp(funcName string, disableTypeGuards bool) functions.FunctionContextOp {
if o.functionOp == nil {
return nil
}
return func(args ...ref.Val) ref.Val {
return func(ctx context.Context, args ...ref.Val) ref.Val {
if !o.matchesRuntimeSignature(disableTypeGuards, args...) {
return MaybeNoSuchOverload(funcName, args...)
}
return o.functionOp(args...)
return o.functionOp(ctx, args...)
}
}

Expand Down Expand Up @@ -667,6 +674,30 @@ type OverloadOpt func(*OverloadDecl) (*OverloadDecl, error)
// UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
return UnaryBindingContext(func(ctx context.Context, val ref.Val) ref.Val {
return binding(val)
})
}

// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
return BinaryBindingContext(func(ctx context.Context, lhs ref.Val, rhs ref.Val) ref.Val {
return binding(lhs, rhs)
})
}

// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
return FunctionBindingContext(func(ctx context.Context, values ...ref.Val) ref.Val {
return binding(values...)
})
}

// UnaryBindingContext provides the implementation of a unary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func UnaryBindingContext(binding functions.UnaryContextOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
Expand All @@ -679,9 +710,9 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
}
}

// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime
// BinaryBindingContext provides the implementation of a binary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
func BinaryBindingContext(binding functions.BinaryContextOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
Expand All @@ -694,9 +725,9 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
}
}

// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime
// FunctionBindingContext provides the implementation of a variadic overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
func FunctionBindingContext(binding functions.FunctionContextOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
Expand Down
Loading