Skip to content

Commit 937e03e

Browse files
committed
use nil for func args
1 parent 1727629 commit 937e03e

File tree

2 files changed

+66
-55
lines changed

2 files changed

+66
-55
lines changed

argsdef.go

+32-48
Original file line numberDiff line numberDiff line change
@@ -188,16 +188,14 @@ func (def *ArgsDef) StringArgsFunc(commandFunc interface{}, resultsHandlers []Re
188188
return nil, err
189189
}
190190

191-
return func(ctx context.Context, callerArgs ...string) error {
191+
f := func(ctx context.Context, callerArgs ...string) error {
192192
argVals, err := def.argValsFromStringArgs(callerArgs)
193193
if err != nil {
194194
return err
195195
}
196-
if dispatcher.firstArgIsContext {
197-
argVals = append([]reflect.Value{reflect.ValueOf(ctx)}, argVals...)
198-
}
199-
return dispatcher.callWithResultsHandlers(argVals, resultsHandlers)
200-
}, nil
196+
return dispatcher.callWithResultsHandlers(ctx, argVals, resultsHandlers)
197+
}
198+
return f, nil
201199
}
202200

203201
func (def *ArgsDef) StringMapArgsFunc(commandFunc interface{}, resultsHandlers []ResultsHandler) (StringMapArgsFunc, error) {
@@ -206,16 +204,14 @@ func (def *ArgsDef) StringMapArgsFunc(commandFunc interface{}, resultsHandlers [
206204
return nil, err
207205
}
208206

209-
return func(ctx context.Context, callerArgs map[string]string) (err error) {
207+
f := func(ctx context.Context, callerArgs map[string]string) (err error) {
210208
argVals, err := def.argValsFromStringMapArgs(callerArgs)
211209
if err != nil {
212210
return err
213211
}
214-
if dispatcher.firstArgIsContext {
215-
argVals = append([]reflect.Value{reflect.ValueOf(ctx)}, argVals...)
216-
}
217-
return dispatcher.callWithResultsHandlers(argVals, resultsHandlers)
218-
}, nil
212+
return dispatcher.callWithResultsHandlers(ctx, argVals, resultsHandlers)
213+
}
214+
return f, nil
219215
}
220216

221217
func (def *ArgsDef) MapArgsFunc(commandFunc interface{}, resultsHandlers []ResultsHandler) (MapArgsFunc, error) {
@@ -224,16 +220,14 @@ func (def *ArgsDef) MapArgsFunc(commandFunc interface{}, resultsHandlers []Resul
224220
return nil, err
225221
}
226222

227-
return func(ctx context.Context, callerArgs map[string]interface{}) (err error) {
223+
f := func(ctx context.Context, callerArgs map[string]interface{}) (err error) {
228224
argVals, err := def.argValsFromMapArgs(callerArgs)
229225
if err != nil {
230226
return err
231227
}
232-
if dispatcher.firstArgIsContext {
233-
argVals = append([]reflect.Value{reflect.ValueOf(ctx)}, argVals...)
234-
}
235-
return dispatcher.callWithResultsHandlers(argVals, resultsHandlers)
236-
}, nil
228+
return dispatcher.callWithResultsHandlers(ctx, argVals, resultsHandlers)
229+
}
230+
return f, nil
237231
}
238232

239233
func (def *ArgsDef) JSONArgsFunc(commandFunc interface{}, resultsHandlers []ResultsHandler) (JSONArgsFunc, error) {
@@ -242,16 +236,14 @@ func (def *ArgsDef) JSONArgsFunc(commandFunc interface{}, resultsHandlers []Resu
242236
return nil, err
243237
}
244238

245-
return func(ctx context.Context, callerArgs []byte) (err error) {
239+
f := func(ctx context.Context, callerArgs []byte) (err error) {
246240
argVals, err := def.argValsFromJSON(callerArgs)
247241
if err != nil {
248242
return err
249243
}
250-
if dispatcher.firstArgIsContext {
251-
argVals = append([]reflect.Value{reflect.ValueOf(ctx)}, argVals...)
252-
}
253-
return dispatcher.callWithResultsHandlers(argVals, resultsHandlers)
254-
}, nil
244+
return dispatcher.callWithResultsHandlers(ctx, argVals, resultsHandlers)
245+
}
246+
return f, nil
255247
}
256248

257249
func (def *ArgsDef) StringArgsResultValuesFunc(commandFunc interface{}) (StringArgsResultValuesFunc, error) {
@@ -260,16 +252,14 @@ func (def *ArgsDef) StringArgsResultValuesFunc(commandFunc interface{}) (StringA
260252
return nil, err
261253
}
262254

263-
return func(ctx context.Context, args []string) ([]reflect.Value, error) {
255+
f := func(ctx context.Context, args []string) ([]reflect.Value, error) {
264256
argVals, err := def.argValsFromStringArgs(args)
265257
if err != nil {
266258
return nil, err
267259
}
268-
if dispatcher.firstArgIsContext {
269-
argVals = append([]reflect.Value{reflect.ValueOf(ctx)}, argVals...)
270-
}
271-
return dispatcher.callAndReturnResults(argVals)
272-
}, nil
260+
return dispatcher.callAndReturnResults(ctx, argVals)
261+
}
262+
return f, nil
273263
}
274264

275265
func (def *ArgsDef) StringMapArgsResultValuesFunc(commandFunc interface{}) (StringMapArgsResultValuesFunc, error) {
@@ -278,16 +268,14 @@ func (def *ArgsDef) StringMapArgsResultValuesFunc(commandFunc interface{}) (Stri
278268
return nil, err
279269
}
280270

281-
return func(ctx context.Context, args map[string]string) ([]reflect.Value, error) {
271+
f := func(ctx context.Context, args map[string]string) ([]reflect.Value, error) {
282272
argVals, err := def.argValsFromStringMapArgs(args)
283273
if err != nil {
284274
return nil, err
285275
}
286-
if dispatcher.firstArgIsContext {
287-
argVals = append([]reflect.Value{reflect.ValueOf(ctx)}, argVals...)
288-
}
289-
return dispatcher.callAndReturnResults(argVals)
290-
}, nil
276+
return dispatcher.callAndReturnResults(ctx, argVals)
277+
}
278+
return f, nil
291279
}
292280

293281
func (def *ArgsDef) MapArgsResultValuesFunc(commandFunc interface{}) (MapArgsResultValuesFunc, error) {
@@ -296,16 +284,14 @@ func (def *ArgsDef) MapArgsResultValuesFunc(commandFunc interface{}) (MapArgsRes
296284
return nil, err
297285
}
298286

299-
return func(ctx context.Context, args map[string]interface{}) ([]reflect.Value, error) {
287+
f := func(ctx context.Context, args map[string]interface{}) ([]reflect.Value, error) {
300288
argVals, err := def.argValsFromMapArgs(args)
301289
if err != nil {
302290
return nil, err
303291
}
304-
if dispatcher.firstArgIsContext {
305-
argVals = append([]reflect.Value{reflect.ValueOf(ctx)}, argVals...)
306-
}
307-
return dispatcher.callAndReturnResults(argVals)
308-
}, nil
292+
return dispatcher.callAndReturnResults(ctx, argVals)
293+
}
294+
return f, nil
309295
}
310296

311297
func (def *ArgsDef) JSONArgsResultValuesFunc(commandFunc interface{}) (JSONArgsResultValuesFunc, error) {
@@ -314,14 +300,12 @@ func (def *ArgsDef) JSONArgsResultValuesFunc(commandFunc interface{}) (JSONArgsR
314300
return nil, err
315301
}
316302

317-
return func(ctx context.Context, argsJSON []byte) ([]reflect.Value, error) {
303+
f := func(ctx context.Context, argsJSON []byte) ([]reflect.Value, error) {
318304
argVals, err := def.argValsFromJSON(argsJSON)
319305
if err != nil {
320306
return nil, err
321307
}
322-
if dispatcher.firstArgIsContext {
323-
argVals = append([]reflect.Value{reflect.ValueOf(ctx)}, argVals...)
324-
}
325-
return dispatcher.callAndReturnResults(argVals)
326-
}, nil
308+
return dispatcher.callAndReturnResults(ctx, argVals)
309+
}
310+
return f, nil
327311
}

funcdispatcher.go

+34-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package command
22

33
import (
4+
"context"
45
"fmt"
56
"reflect"
67
)
@@ -19,12 +20,19 @@ import (
1920
// getReplacementVal getReplacementValFunc
2021
// }
2122

22-
func functionArgTypesWithoutReplaceables(funcType reflect.Type) (argTypes []reflect.Type) {
23+
// functionArgTypesWithoutReplaceables returns the function argument types except for
24+
// the first argument of type context.Context and callback function arguments.
25+
func functionArgTypesWithoutReplaceables(funcType reflect.Type) (argTypes []reflect.Type, firstArgIsContext bool, insertArgs []insertArg) {
2326
numArgs := funcType.NumIn()
2427
argTypes = make([]reflect.Type, 0, numArgs)
2528
for i := 0; i < numArgs; i++ {
2629
t := funcType.In(i)
2730
if i == 0 && t == typeOfContext {
31+
firstArgIsContext = true
32+
continue
33+
}
34+
if t.Kind() == reflect.Func {
35+
insertArgs = append(insertArgs, insertArg{index: i, value: reflect.Zero(t)})
2836
continue
2937
}
3038
// _, hasPlaceholder := ReplaceArgTypes[t]
@@ -33,7 +41,12 @@ func functionArgTypesWithoutReplaceables(funcType reflect.Type) (argTypes []refl
3341
// }
3442
argTypes = append(argTypes, t)
3543
}
36-
return argTypes
44+
return argTypes, firstArgIsContext, insertArgs
45+
}
46+
47+
type insertArg struct {
48+
index int
49+
value reflect.Value
3750
}
3851

3952
type funcDispatcher struct {
@@ -45,6 +58,7 @@ type funcDispatcher struct {
4558
// argReplacements []argReplacement
4659

4760
firstArgIsContext bool
61+
insertArgs []insertArg
4862
errorIndex int
4963
}
5064

@@ -58,8 +72,6 @@ func newFuncDispatcher(argsDef *ArgsDef, commandFunc interface{}) (disp *funcDis
5872
return nil, fmt.Errorf("expected a function or method, but got %s", disp.funcType)
5973
}
6074

61-
disp.firstArgIsContext = disp.funcType.NumIn() > 0 && disp.funcType.In(0) == typeOfContext
62-
6375
numResults := disp.funcType.NumOut()
6476
if numResults > 0 && disp.funcType.Out(numResults-1) == typeOfError {
6577
disp.errorIndex = numResults - 1
@@ -69,7 +81,8 @@ func newFuncDispatcher(argsDef *ArgsDef, commandFunc interface{}) (disp *funcDis
6981

7082
// disp.argReplacements = nil // TODO
7183

72-
funcArgTypes := functionArgTypesWithoutReplaceables(disp.funcType)
84+
var funcArgTypes []reflect.Type
85+
funcArgTypes, disp.firstArgIsContext, disp.insertArgs = functionArgTypesWithoutReplaceables(disp.funcType)
7386
numArgsDef := len(argsDef.argStructFields)
7487
if numArgsDef != len(funcArgTypes) {
7588
return nil, fmt.Errorf("number of fields in command.Args struct (%d) does not match number of function arguments (%d)", numArgsDef, len(funcArgTypes))
@@ -89,7 +102,14 @@ func newFuncDispatcher(argsDef *ArgsDef, commandFunc interface{}) (disp *funcDis
89102
return disp, nil
90103
}
91104

92-
func (disp *funcDispatcher) callWithResultsHandlers(argVals []reflect.Value, resultsHandlers []ResultsHandler) error {
105+
func (disp *funcDispatcher) callWithResultsHandlers(ctx context.Context, argVals []reflect.Value, resultsHandlers []ResultsHandler) error {
106+
if disp.firstArgIsContext {
107+
argVals = append([]reflect.Value{reflect.ValueOf(ctx)}, argVals...)
108+
}
109+
for _, insert := range disp.insertArgs {
110+
argVals = append(argVals[:insert.index], append([]reflect.Value{insert.value}, argVals[insert.index:]...)...)
111+
}
112+
93113
var resultVals []reflect.Value
94114
if disp.funcType.IsVariadic() {
95115
resultVals = disp.funcVal.CallSlice(argVals)
@@ -112,7 +132,14 @@ func (disp *funcDispatcher) callWithResultsHandlers(argVals []reflect.Value, res
112132
return resultErr
113133
}
114134

115-
func (disp *funcDispatcher) callAndReturnResults(argVals []reflect.Value) ([]reflect.Value, error) {
135+
func (disp *funcDispatcher) callAndReturnResults(ctx context.Context, argVals []reflect.Value) ([]reflect.Value, error) {
136+
if disp.firstArgIsContext {
137+
argVals = append([]reflect.Value{reflect.ValueOf(ctx)}, argVals...)
138+
}
139+
for _, insert := range disp.insertArgs {
140+
argVals = append(argVals[:insert.index], append([]reflect.Value{insert.value}, argVals[insert.index:]...)...)
141+
}
142+
116143
var resultVals []reflect.Value
117144
if disp.funcType.IsVariadic() {
118145
resultVals = disp.funcVal.CallSlice(argVals)

0 commit comments

Comments
 (0)