Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 69 additions & 9 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,59 @@ import (
"strings"
)

// binding is a single binding registered with Kong.
type binding struct {
// fn is a function that returns a value of the target type.
fn reflect.Value

// val is a value of the target type.
// Must be set if done and singleton are true.
val reflect.Value

// singleton indicates whether the binding is a singleton.
// If true, the binding will be resolved once and cached.
singleton bool

// done indicates whether a singleton binding has been resolved.
// If singleton is false, this field is ignored.
done bool
}

// newValueBinding builds a binding with an already resolved value.
func newValueBinding(v reflect.Value) *binding {
return &binding{val: v, done: true, singleton: true}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This considers bindings with prior values to be singletons since there's no need to invoke a function.

}

// newFunctionBinding builds a binding with a function
// that will return a value of the target type.
//
// The function signature must be func(...) (T, error) or func(...) T
// where parameters are recursively resolved.
func newFunctionBinding(f reflect.Value, singleton bool) *binding {
return &binding{fn: f, singleton: singleton}
}

// Get returns the pre-resolved value for the binding,
// or false if the binding is not resolved.
func (b *binding) Get() (v reflect.Value, ok bool) {
return b.val, b.done
}

// Set sets the value of the binding to the given value,
// marking it as resolved.
//
// If the binding is not a singleton, this method does nothing.
func (b *binding) Set(v reflect.Value) {
if b.singleton {
b.val = v
b.done = true
}
}

// A map of type to function that returns a value of that type.
//
// The function should have the signature func(...) (T, error). Arguments are recursively resolved.
type bindings map[reflect.Type]any
type bindings map[reflect.Type]*binding

func (b bindings) String() string {
out := []string{}
Expand All @@ -21,17 +70,18 @@ func (b bindings) String() string {

func (b bindings) add(values ...any) bindings {
for _, v := range values {
v := v
b[reflect.TypeOf(v)] = func() (any, error) { return v, nil }
val := reflect.ValueOf(v)
b[val.Type()] = newValueBinding(val)
}
return b
}

func (b bindings) addTo(impl, iface any) {
b[reflect.TypeOf(iface).Elem()] = func() (any, error) { return impl, nil }
val := reflect.ValueOf(impl)
b[reflect.TypeOf(iface).Elem()] = newValueBinding(val)
}

func (b bindings) addProvider(provider any) error {
func (b bindings) addProvider(provider any, singleton bool) error {
pv := reflect.ValueOf(provider)
t := pv.Type()
if t.Kind() != reflect.Func {
Expand All @@ -47,7 +97,7 @@ func (b bindings) addProvider(provider any) error {
}
}
rt := pv.Type().Out(0)
b[rt] = provider
b[rt] = newFunctionBinding(pv, singleton)
return nil
}

Expand Down Expand Up @@ -148,19 +198,29 @@ func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error)
t := f.Type()
for i := 0; i < t.NumIn(); i++ {
pt := t.In(i)
argf, ok := bindings[pt]
binding, ok := bindings[pt]
if !ok {
return nil, fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt)
}

// Don't need to call the function if the value is already resolved.
if val, ok := binding.Get(); ok {
in = append(in, val)
continue
}

// Recursively resolve binding functions.
argv, err := callAnyFunction(reflect.ValueOf(argf), bindings)
argv, err := callAnyFunction(binding.fn, bindings)
if err != nil {
return nil, fmt.Errorf("%s: %w", pt, err)
}
if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && ferrv.Type().Implements(callbackReturnSignature) && !ferrv.IsNil() {
return nil, ferrv.Interface().(error) //nolint:forcetypeassert
}
in = append(in, reflect.ValueOf(argv[0]))

val := reflect.ValueOf(argv[0])
binding.Set(val)
in = append(in, val)
}
outv := f.Call(in)
out = make([]any, len(outv))
Expand Down
17 changes: 13 additions & 4 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,19 @@ func (c *Context) BindTo(impl, iface any) {
// This is useful when the Run() function of different commands require different values that may
// not all be initialisable from the main() function.
//
// "provider" must be a function with the signature func(...) (T, error) or func(...) T, where
// ... will be recursively injected with bound values.
// "provider" must be a function with the signature func(...) (T, error) or func(...) T,
// where ... will be recursively injected with bound values.
func (c *Context) BindToProvider(provider any) error {
return c.bindings.addProvider(provider)
return c.bindings.addProvider(provider, false /* singleton */)
}

// BindSingletonProvider allows binding of provider functions.
// The provider will be called once and the result cached.
//
// "provider" must be a function with the signature func(...) (T, error) or func(...) T,
// where ... will be recursively injected with bound values.
func (c *Context) BindSingletonProvider(provider any) error {
return c.bindings.addProvider(provider, true /* singleton */)
}

// Value returns the value for a particular path element.
Expand Down Expand Up @@ -792,7 +801,7 @@ func (c *Context) RunNode(node *Node, binds ...any) (err error) {
methodt := t.Method(i)
if strings.HasPrefix(methodt.Name, "Provide") {
method := p.Method(i)
if err := methodBinds.addProvider(method.Interface()); err != nil {
if err := methodBinds.addProvider(method.Interface(), false /* singleton */); err != nil {
return fmt.Errorf("%s.%s: %w", t.Name(), methodt.Name, err)
}
}
Expand Down
24 changes: 21 additions & 3 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,33 @@ func BindTo(impl, iface any) Option {

// BindToProvider binds an injected value to a provider function.
//
// The provider function must have the signature:
// The provider function must have one of the following signatures:
//
// func(...) (T, error)
// func(...) T
//
// func() (any, error)
// Where arguments to the function are injected by Kong.
//
// This is useful when the Run() function of different commands require different values that may
// not all be initialisable from the main() function.
func BindToProvider(provider any) Option {
return OptionFunc(func(k *Kong) error {
return k.bindings.addProvider(provider)
return k.bindings.addProvider(provider, false /* singleton */)
})
}

// BindSingletonProvider binds an injected value to a provider function.
// The provider function must have the signature:
//
// func(...) (T, error)
// func(...) T
//
// Unlike [BindToProvider], the provider function will only be called
// at most once, and the result will be cached and reused
// across multiple recipients of the injected value.
func BindSingletonProvider(provider any) Option {
return OptionFunc(func(k *Kong) error {
return k.bindings.addProvider(provider, true /* singleton */)
})
}

Expand Down
37 changes: 37 additions & 0 deletions options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,43 @@ func TestBindToProvider(t *testing.T) {
assert.True(t, cli.Called)
}

func TestBindSingletonProvider(t *testing.T) {
type (
Connection struct{}
ClientA struct{ conn *Connection }
ClientB struct{ conn *Connection }
)

var numConnections int
newConnection := func() *Connection {
numConnections++
return &Connection{}
}

var cli struct{}
app, err := New(&cli,
BindSingletonProvider(newConnection),
BindToProvider(func(conn *Connection) *ClientA {
return &ClientA{conn: conn}
}),
BindToProvider(func(conn *Connection) *ClientB {
return &ClientB{conn: conn}
}),
)
assert.NoError(t, err)

ctx, err := app.Parse([]string{})
assert.NoError(t, err)

_, err = ctx.Call(func(a *ClientA, b *ClientB) {
assert.NotZero(t, a.conn)
assert.NotZero(t, b.conn)

assert.Equal(t, 1, numConnections, "expected newConnection to be called only once")
})
assert.NoError(t, err)
}

func TestFlagNamer(t *testing.T) {
var cli struct {
SomeFlag string
Expand Down