Skip to content

Commit 3b9af5b

Browse files
authored
feat: Support singleton providers (#501)
* feat: Support singleton providers This change adds support for provider functions that are not reinvoked even if requested by multiple other providers. Instead, their value is cached and reused between invocations. To make this possible, we change how bindings are stored: instead of just a function reference, we now store a binding object which records whether the binding is a singleton, and records the resolved singleton value (if any). Resolves #500 * refac(bindings): hide singleton status Don't require callAnyFunction to be aware of whether a binding is a singleton or not.
1 parent 7f94c90 commit 3b9af5b

File tree

4 files changed

+140
-16
lines changed

4 files changed

+140
-16
lines changed

callbacks.go

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,59 @@ import (
66
"strings"
77
)
88

9+
// binding is a single binding registered with Kong.
10+
type binding struct {
11+
// fn is a function that returns a value of the target type.
12+
fn reflect.Value
13+
14+
// val is a value of the target type.
15+
// Must be set if done and singleton are true.
16+
val reflect.Value
17+
18+
// singleton indicates whether the binding is a singleton.
19+
// If true, the binding will be resolved once and cached.
20+
singleton bool
21+
22+
// done indicates whether a singleton binding has been resolved.
23+
// If singleton is false, this field is ignored.
24+
done bool
25+
}
26+
27+
// newValueBinding builds a binding with an already resolved value.
28+
func newValueBinding(v reflect.Value) *binding {
29+
return &binding{val: v, done: true, singleton: true}
30+
}
31+
32+
// newFunctionBinding builds a binding with a function
33+
// that will return a value of the target type.
34+
//
35+
// The function signature must be func(...) (T, error) or func(...) T
36+
// where parameters are recursively resolved.
37+
func newFunctionBinding(f reflect.Value, singleton bool) *binding {
38+
return &binding{fn: f, singleton: singleton}
39+
}
40+
41+
// Get returns the pre-resolved value for the binding,
42+
// or false if the binding is not resolved.
43+
func (b *binding) Get() (v reflect.Value, ok bool) {
44+
return b.val, b.done
45+
}
46+
47+
// Set sets the value of the binding to the given value,
48+
// marking it as resolved.
49+
//
50+
// If the binding is not a singleton, this method does nothing.
51+
func (b *binding) Set(v reflect.Value) {
52+
if b.singleton {
53+
b.val = v
54+
b.done = true
55+
}
56+
}
57+
958
// A map of type to function that returns a value of that type.
1059
//
1160
// The function should have the signature func(...) (T, error). Arguments are recursively resolved.
12-
type bindings map[reflect.Type]any
61+
type bindings map[reflect.Type]*binding
1362

1463
func (b bindings) String() string {
1564
out := []string{}
@@ -21,17 +70,18 @@ func (b bindings) String() string {
2170

2271
func (b bindings) add(values ...any) bindings {
2372
for _, v := range values {
24-
v := v
25-
b[reflect.TypeOf(v)] = func() (any, error) { return v, nil }
73+
val := reflect.ValueOf(v)
74+
b[val.Type()] = newValueBinding(val)
2675
}
2776
return b
2877
}
2978

3079
func (b bindings) addTo(impl, iface any) {
31-
b[reflect.TypeOf(iface).Elem()] = func() (any, error) { return impl, nil }
80+
val := reflect.ValueOf(impl)
81+
b[reflect.TypeOf(iface).Elem()] = newValueBinding(val)
3282
}
3383

34-
func (b bindings) addProvider(provider any) error {
84+
func (b bindings) addProvider(provider any, singleton bool) error {
3585
pv := reflect.ValueOf(provider)
3686
t := pv.Type()
3787
if t.Kind() != reflect.Func {
@@ -47,7 +97,7 @@ func (b bindings) addProvider(provider any) error {
4797
}
4898
}
4999
rt := pv.Type().Out(0)
50-
b[rt] = provider
100+
b[rt] = newFunctionBinding(pv, singleton)
51101
return nil
52102
}
53103

@@ -148,19 +198,29 @@ func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error)
148198
t := f.Type()
149199
for i := 0; i < t.NumIn(); i++ {
150200
pt := t.In(i)
151-
argf, ok := bindings[pt]
201+
binding, ok := bindings[pt]
152202
if !ok {
153203
return nil, fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt)
154204
}
205+
206+
// Don't need to call the function if the value is already resolved.
207+
if val, ok := binding.Get(); ok {
208+
in = append(in, val)
209+
continue
210+
}
211+
155212
// Recursively resolve binding functions.
156-
argv, err := callAnyFunction(reflect.ValueOf(argf), bindings)
213+
argv, err := callAnyFunction(binding.fn, bindings)
157214
if err != nil {
158215
return nil, fmt.Errorf("%s: %w", pt, err)
159216
}
160217
if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && ferrv.Type().Implements(callbackReturnSignature) && !ferrv.IsNil() {
161218
return nil, ferrv.Interface().(error) //nolint:forcetypeassert
162219
}
163-
in = append(in, reflect.ValueOf(argv[0]))
220+
221+
val := reflect.ValueOf(argv[0])
222+
binding.Set(val)
223+
in = append(in, val)
164224
}
165225
outv := f.Call(in)
166226
out = make([]any, len(outv))

context.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,19 @@ func (c *Context) BindTo(impl, iface any) {
120120
// This is useful when the Run() function of different commands require different values that may
121121
// not all be initialisable from the main() function.
122122
//
123-
// "provider" must be a function with the signature func(...) (T, error) or func(...) T, where
124-
// ... will be recursively injected with bound values.
123+
// "provider" must be a function with the signature func(...) (T, error) or func(...) T,
124+
// where ... will be recursively injected with bound values.
125125
func (c *Context) BindToProvider(provider any) error {
126-
return c.bindings.addProvider(provider)
126+
return c.bindings.addProvider(provider, false /* singleton */)
127+
}
128+
129+
// BindSingletonProvider allows binding of provider functions.
130+
// The provider will be called once and the result cached.
131+
//
132+
// "provider" must be a function with the signature func(...) (T, error) or func(...) T,
133+
// where ... will be recursively injected with bound values.
134+
func (c *Context) BindSingletonProvider(provider any) error {
135+
return c.bindings.addProvider(provider, true /* singleton */)
127136
}
128137

129138
// Value returns the value for a particular path element.
@@ -792,7 +801,7 @@ func (c *Context) RunNode(node *Node, binds ...any) (err error) {
792801
methodt := t.Method(i)
793802
if strings.HasPrefix(methodt.Name, "Provide") {
794803
method := p.Method(i)
795-
if err := methodBinds.addProvider(method.Interface()); err != nil {
804+
if err := methodBinds.addProvider(method.Interface(), false /* singleton */); err != nil {
796805
return fmt.Errorf("%s.%s: %w", t.Name(), methodt.Name, err)
797806
}
798807
}

options.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,33 @@ func BindTo(impl, iface any) Option {
210210

211211
// BindToProvider binds an injected value to a provider function.
212212
//
213-
// The provider function must have the signature:
213+
// The provider function must have one of the following signatures:
214+
//
215+
// func(...) (T, error)
216+
// func(...) T
214217
//
215-
// func() (any, error)
218+
// Where arguments to the function are injected by Kong.
216219
//
217220
// This is useful when the Run() function of different commands require different values that may
218221
// not all be initialisable from the main() function.
219222
func BindToProvider(provider any) Option {
220223
return OptionFunc(func(k *Kong) error {
221-
return k.bindings.addProvider(provider)
224+
return k.bindings.addProvider(provider, false /* singleton */)
225+
})
226+
}
227+
228+
// BindSingletonProvider binds an injected value to a provider function.
229+
// The provider function must have the signature:
230+
//
231+
// func(...) (T, error)
232+
// func(...) T
233+
//
234+
// Unlike [BindToProvider], the provider function will only be called
235+
// at most once, and the result will be cached and reused
236+
// across multiple recipients of the injected value.
237+
func BindSingletonProvider(provider any) Option {
238+
return OptionFunc(func(k *Kong) error {
239+
return k.bindings.addProvider(provider, true /* singleton */)
222240
})
223241
}
224242

options_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,43 @@ func TestBindToProvider(t *testing.T) {
119119
assert.True(t, cli.Called)
120120
}
121121

122+
func TestBindSingletonProvider(t *testing.T) {
123+
type (
124+
Connection struct{}
125+
ClientA struct{ conn *Connection }
126+
ClientB struct{ conn *Connection }
127+
)
128+
129+
var numConnections int
130+
newConnection := func() *Connection {
131+
numConnections++
132+
return &Connection{}
133+
}
134+
135+
var cli struct{}
136+
app, err := New(&cli,
137+
BindSingletonProvider(newConnection),
138+
BindToProvider(func(conn *Connection) *ClientA {
139+
return &ClientA{conn: conn}
140+
}),
141+
BindToProvider(func(conn *Connection) *ClientB {
142+
return &ClientB{conn: conn}
143+
}),
144+
)
145+
assert.NoError(t, err)
146+
147+
ctx, err := app.Parse([]string{})
148+
assert.NoError(t, err)
149+
150+
_, err = ctx.Call(func(a *ClientA, b *ClientB) {
151+
assert.NotZero(t, a.conn)
152+
assert.NotZero(t, b.conn)
153+
154+
assert.Equal(t, 1, numConnections, "expected newConnection to be called only once")
155+
})
156+
assert.NoError(t, err)
157+
}
158+
122159
func TestFlagNamer(t *testing.T) {
123160
var cli struct {
124161
SomeFlag string

0 commit comments

Comments
 (0)