From 7728516abfab950dc7a9120caad4870f1f962df5 Mon Sep 17 00:00:00 2001 From: Aaron Craelius Date: Wed, 31 Aug 2022 13:37:01 -0400 Subject: [PATCH] refactor(depinject)!: require exported functions & types (#12797) * refactor(depinject)!: require exported functions * unexport ProviderDescriptor * WIP on tests * fix tests and check for bound instance methods * address merge issues * WIP on checking valid types * WIP on checking valid types * WIP * tests passing * revert changes outside module * docs * docs * docs * add comment * revert * update depinject go.mod versions * remove go.work * add go.work back * go mod tidy * fix docs Co-authored-by: Julien Robert --- depinject/binding_test.go | 30 ++-- depinject/check_type.go | 69 ++++++++ depinject/check_type_test.go | 59 +++++++ depinject/config.go | 22 ++- depinject/container.go | 23 +-- depinject/container_test.go | 234 +++++++++++---------------- depinject/go.sum | 4 +- depinject/invoke_test.go | 55 ++++--- depinject/module_dep.go | 2 +- depinject/provider_desc.go | 123 +++++++++----- depinject/provider_desc_test.go | 138 +++++++++++----- depinject/simple.go | 2 +- depinject/struct_args.go | 26 +-- depinject/testdata/example.dot | 7 + depinject/testdata/example_error.dot | 12 +- orm/go.sum | 2 - 16 files changed, 525 insertions(+), 283 deletions(-) create mode 100644 depinject/check_type.go create mode 100644 depinject/check_type_test.go diff --git a/depinject/binding_test.go b/depinject/binding_test.go index be07b9ab5265..5cfacc628b3c 100644 --- a/depinject/binding_test.go +++ b/depinject/binding_test.go @@ -62,14 +62,18 @@ func (s bindingSuite) TwoImplementationsMallardAndCanvasback() { // we don't need to do anything because this is defined at the type level } +func ProvideMallard() Mallard { return Mallard{} } +func ProvideCanvasback() Canvasback { return Canvasback{} } +func ProvideMarbled() Marbled { return Marbled{} } + func (s *bindingSuite) IsProvided(a string) { switch a { case "Mallard": - s.addConfig(depinject.Provide(func() Mallard { return Mallard{} })) + s.addConfig(depinject.Provide(ProvideMallard)) case "Canvasback": - s.addConfig(depinject.Provide(func() Canvasback { return Canvasback{} })) + s.addConfig(depinject.Provide(ProvideCanvasback)) case "Marbled": - s.addConfig(depinject.Provide(func() Marbled { return Marbled{} })) + s.addConfig(depinject.Provide(ProvideMarbled)) default: s.Fatalf("unexpected duck type %s", a) } @@ -79,18 +83,22 @@ func (s *bindingSuite) addConfig(config depinject.Config) { s.configs = append(s.configs, config) } +func ProvideDuckWrapper(duck Duck) DuckWrapper { + return DuckWrapper{Module: "", Duck: duck} +} + func (s *bindingSuite) WeTryToResolveADuckInGlobalScope() { - s.addConfig(depinject.Provide(func(duck Duck) DuckWrapper { - return DuckWrapper{Module: "", Duck: duck} - })) + s.addConfig(depinject.Provide(ProvideDuckWrapper)) } +func ResolvePond(ducks []DuckWrapper) Pond { return Pond{Ducks: ducks} } + func (s *bindingSuite) resolvePond() *Pond { if s.pond != nil { return s.pond } - s.addConfig(depinject.Provide(func(ducks []DuckWrapper) Pond { return Pond{Ducks: ducks} })) + s.addConfig(depinject.Provide(ResolvePond)) var pond Pond s.err = depinject.Inject(depinject.Configs(s.configs...), &pond) s.pond = &pond @@ -131,10 +139,12 @@ func (s *bindingSuite) ThereIsABindingForAInModule(preferredType string, interfa s.addConfig(depinject.BindInterfaceInModule(moduleName, fullTypeName(interfaceType), fullTypeName(preferredType))) } +func ProvideModuleDuck(duck Duck, key depinject.OwnModuleKey) DuckWrapper { + return DuckWrapper{Module: depinject.ModuleKey(key).Name(), Duck: duck} +} + func (s *bindingSuite) ModuleWantsADuck(module string) { - s.addConfig(depinject.ProvideInModule(module, func(duck Duck) DuckWrapper { - return DuckWrapper{Module: module, Duck: duck} - })) + s.addConfig(depinject.ProvideInModule(module, ProvideModuleDuck)) } func (s *bindingSuite) ModuleResolvesA(module string, duckType string) { diff --git a/depinject/check_type.go b/depinject/check_type.go new file mode 100644 index 000000000000..6ac4de4a60b4 --- /dev/null +++ b/depinject/check_type.go @@ -0,0 +1,69 @@ +package depinject + +import ( + "reflect" + "strings" + "unicode" + + "github.com/pkg/errors" + "golang.org/x/exp/slices" +) + +// isExportedType checks if the type is exported and not in an internal +// package. NOTE: generic type parameters are not checked because this +// would involve complex parsing of type names (there is no reflect API for +// generic type parameters). Parsing of these parameters should be possible +// if someone chooses to do it in the future, but care should be taken to +// be exhaustive and cover all cases like pointers, map's, chan's, etc. which +// means you actually need a real parser and not just a regex. +func isExportedType(typ reflect.Type) error { + name := typ.Name() + pkgPath := typ.PkgPath() + if name != "" && pkgPath != "" { + if unicode.IsLower([]rune(name)[0]) { + return errors.Errorf("type must be exported: %s", typ) + } + + pkgParts := strings.Split(pkgPath, "/") + if slices.Contains(pkgParts, "internal") { + return errors.Errorf("type must not come from an internal package: %s", typ) + } + + return nil + } + + switch typ.Kind() { + case reflect.Array, reflect.Slice, reflect.Chan, reflect.Pointer: + return isExportedType(typ.Elem()) + + case reflect.Func: + numIn := typ.NumIn() + for i := 0; i < numIn; i++ { + err := isExportedType(typ.In(i)) + if err != nil { + return err + } + } + + numOut := typ.NumOut() + for i := 0; i < numOut; i++ { + err := isExportedType(typ.Out(i)) + if err != nil { + return err + } + } + + return nil + + case reflect.Map: + err := isExportedType(typ.Key()) + if err != nil { + return err + } + return isExportedType(typ.Elem()) + + default: + // all the remaining types are builtin, non-composite types (like integers), so they are fine to use + return nil + } +} diff --git a/depinject/check_type_test.go b/depinject/check_type_test.go new file mode 100644 index 000000000000..f01a24000a83 --- /dev/null +++ b/depinject/check_type_test.go @@ -0,0 +1,59 @@ +package depinject + +import ( + "os" + "reflect" + "testing" + + "gotest.tools/v3/assert" + + "cosmossdk.io/depinject/internal/graphviz" +) + +func TestCheckIsExportedType(t *testing.T) { + expectValidType(t, false) + expectValidType(t, uint(0)) + expectValidType(t, uint8(0)) + expectValidType(t, uint16(0)) + expectValidType(t, uint32(0)) + expectValidType(t, uint64(0)) + expectValidType(t, int(0)) + expectValidType(t, int8(0)) + expectValidType(t, int16(0)) + expectValidType(t, int32(0)) + expectValidType(t, int64(0)) + expectValidType(t, float32(0)) + expectValidType(t, float64(0)) + expectValidType(t, complex64(0)) + expectValidType(t, complex128(0)) + expectValidType(t, os.FileMode(0)) + expectValidType(t, [1]int{0}) + expectValidType(t, []int{}) + expectValidType(t, "") + expectValidType(t, make(chan int)) + expectValidType(t, make(<-chan int)) + expectValidType(t, make(chan<- int)) + expectValidType(t, func(int, string) (bool, error) { return false, nil }) + expectValidType(t, func(int, ...string) (bool, error) { return false, nil }) + expectValidType(t, In{}) + expectValidType(t, map[string]In{}) + expectValidType(t, &In{}) + expectValidType(t, uintptr(0)) + expectValidType(t, (*Location)(nil)) + + expectInvalidType(t, container{}, "must be exported") + expectInvalidType(t, &container{}, "must be exported") + expectInvalidType(t, graphviz.Attributes{}, "internal") + expectInvalidType(t, map[string]graphviz.Attributes{}, "internal") + expectInvalidType(t, []graphviz.Attributes{}, "internal") +} + +func expectValidType(t *testing.T, v interface{}) { + t.Helper() + assert.NilError(t, isExportedType(reflect.TypeOf(v))) +} + +func expectInvalidType(t *testing.T, v interface{}, errContains string) { + t.Helper() + assert.ErrorContains(t, isExportedType(reflect.TypeOf(v)), errContains) +} diff --git a/depinject/config.go b/depinject/config.go index 156cb7103240..ff947bc627bd 100644 --- a/depinject/config.go +++ b/depinject/config.go @@ -14,7 +14,10 @@ type Config interface { // Provide defines a container configuration which registers the provided dependency // injection providers. Each provider will be called at most once with the // exception of module-scoped providers which are called at most once per module -// (see ModuleKey). +// (see ModuleKey). All provider functions must be declared, exported functions not +// internal packages and all of their input and output types must also be declared +// and exported and not in internal packages. Note that generic type parameters +// will not be checked, but they should also be exported so that codegen is possible. func Provide(providers ...interface{}) Config { return containerConfig(func(ctr *container) error { return provide(ctr, nil, providers) @@ -23,7 +26,10 @@ func Provide(providers ...interface{}) Config { // ProvideInModule defines container configuration which registers the provided dependency // injection providers that are to be run in the named module. Each provider -// will be called at most once. +// will be called at most once. All provider functions must be declared, exported functions not +// internal packages and all of their input and output types must also be declared +// and exported and not in internal packages. Note that generic type parameters +// will not be checked, but they should also be exported so that codegen is possible. func ProvideInModule(moduleName string, providers ...interface{}) Config { return containerConfig(func(ctr *container) error { if moduleName == "" { @@ -36,7 +42,7 @@ func ProvideInModule(moduleName string, providers ...interface{}) Config { func provide(ctr *container, key *moduleKey, providers []interface{}) error { for _, c := range providers { - rc, err := ExtractProviderDescriptor(c) + rc, err := extractProviderDescriptor(c) if err != nil { return errors.WithStack(err) } @@ -52,6 +58,10 @@ func provide(ctr *container, key *moduleKey, providers []interface{}) error { // at the end of dependency graph configuration in the order in which it was defined. Invokers may not define output // parameters, although they may return an error, and all of their input parameters will be marked as optional so that // invokers impose no additional constraints on the dependency graph. Invoker functions should nil-check all inputs. +// All invoker functions must be declared, exported functions not +// internal packages and all of their input and output types must also be declared +// and exported and not in internal packages. Note that generic type parameters +// will not be checked, but they should also be exported so that codegen is possible. func Invoke(invokers ...interface{}) Config { return containerConfig(func(ctr *container) error { return invoke(ctr, nil, invokers) @@ -63,6 +73,10 @@ func Invoke(invokers ...interface{}) Config { // at the end of dependency graph configuration in the order in which it was defined. Invokers may not define output // parameters, although they may return an error, and all of their input parameters will be marked as optional so that // invokers impose no additional constraints on the dependency graph. Invoker functions should nil-check all inputs. +// All invoker functions must be declared, exported functions not +// internal packages and all of their input and output types must also be declared +// and exported and not in internal packages. Note that generic type parameters +// will not be checked, but they should also be exported so that codegen is possible. func InvokeInModule(moduleName string, invokers ...interface{}) Config { return containerConfig(func(ctr *container) error { if moduleName == "" { @@ -75,7 +89,7 @@ func InvokeInModule(moduleName string, invokers ...interface{}) Config { func invoke(ctr *container, key *moduleKey, invokers []interface{}) error { for _, c := range invokers { - rc, err := ExtractInvokerDescriptor(c) + rc, err := extractInvokerDescriptor(c) if err != nil { return errors.WithStack(err) } diff --git a/depinject/container.go b/depinject/container.go index 13965e9239d7..d2998f585828 100644 --- a/depinject/container.go +++ b/depinject/container.go @@ -25,7 +25,7 @@ type container struct { } type invoker struct { - fn *ProviderDescriptor + fn *providerDescriptor modKey *moduleKey } @@ -55,7 +55,7 @@ func newContainer(cfg *debugConfig) *container { } } -func (c *container) call(provider *ProviderDescriptor, moduleKey *moduleKey) ([]reflect.Value, error) { +func (c *container) call(provider *providerDescriptor, moduleKey *moduleKey) ([]reflect.Value, error) { loc := provider.Location graphNode := c.locationGraphNode(loc, moduleKey) @@ -205,7 +205,7 @@ func (c *container) getExplicitResolver(typ reflect.Type, key *moduleKey) (resol var stringType = reflect.TypeOf("") -func (c *container) addNode(provider *ProviderDescriptor, key *moduleKey) (interface{}, error) { +func (c *container) addNode(provider *providerDescriptor, key *moduleKey) (interface{}, error) { providerGraphNode := c.locationGraphNode(provider.Location, key) hasModuleKeyParam := false hasOwnModuleKeyParam := false @@ -359,7 +359,7 @@ func (c *container) supply(value reflect.Value, location Location) error { return nil } -func (c *container) addInvoker(provider *ProviderDescriptor, key *moduleKey) error { +func (c *container) addInvoker(provider *providerDescriptor, key *moduleKey) error { // make sure there are no outputs if len(provider.Outputs) > 0 { return fmt.Errorf("invoker function %s should not return any outputs", provider.Location) @@ -373,7 +373,7 @@ func (c *container) addInvoker(provider *ProviderDescriptor, key *moduleKey) err return nil } -func (c *container) resolve(in ProviderInput, moduleKey *moduleKey, caller Location) (reflect.Value, error) { +func (c *container) resolve(in providerInput, moduleKey *moduleKey, caller Location) (reflect.Value, error) { c.resolveStack = append(c.resolveStack, resolveFrame{loc: caller, typ: in.Type}) typeGraphNode := c.typeGraphNode(in.Type) @@ -426,17 +426,17 @@ func (c *container) resolve(in ProviderInput, moduleKey *moduleKey, caller Locat } func (c *container) build(loc Location, outputs ...interface{}) error { - var providerIn []ProviderInput + var providerIn []providerInput for _, output := range outputs { typ := reflect.TypeOf(output) if typ.Kind() != reflect.Pointer { return fmt.Errorf("output type must be a pointer, %s is invalid", typ) } - providerIn = append(providerIn, ProviderInput{Type: typ.Elem()}) + providerIn = append(providerIn, providerInput{Type: typ.Elem()}) } - desc := ProviderDescriptor{ + desc := providerDescriptor{ Inputs: providerIn, Outputs: nil, Fn: func(values []reflect.Value) ([]reflect.Value, error) { @@ -523,7 +523,12 @@ func fullyQualifiedTypeName(typ reflect.Type) string { if typ.Kind() == reflect.Pointer || typ.Kind() == reflect.Slice || typ.Kind() == reflect.Map || typ.Kind() == reflect.Array { pkgType = typ.Elem() } - return fmt.Sprintf("%s/%v", pkgType.PkgPath(), typ) + pkgPath := pkgType.PkgPath() + if pkgPath == "" { + return fmt.Sprintf("%v", typ) + } + + return fmt.Sprintf("%s/%v", pkgPath, typ) } func bindingKeyFromTypeName(typeName string, key *moduleKey) string { diff --git a/depinject/container_test.go b/depinject/container_test.go index 87a37f65123b..cf9394a56b11 100644 --- a/depinject/container_test.go +++ b/depinject/container_test.go @@ -3,7 +3,6 @@ package depinject_test import ( "fmt" "os" - "reflect" "testing" "github.com/stretchr/testify/require" @@ -179,23 +178,26 @@ func TestUnexportedField(t *testing.T) { scenarioConfigProvides = depinject.Configs( depinject.Provide(ProvideMsgClientA), depinject.ProvideInModule("runtime", ProvideKVStoreKey), - depinject.ProvideInModule("a", wrapMethod0(ModuleA{})), - depinject.ProvideInModule("c", wrapMethod0(ModuleUnexportedProvides{})), + depinject.ProvideInModule("a", ModuleA.Provide), + depinject.ProvideInModule("c", ModuleUnexportedProvides.Provide), + depinject.Supply(ModuleA{}, ModuleUnexportedProvides{}), ) scenarioConfigDependency = depinject.Configs( depinject.Provide(ProvideMsgClientA), depinject.ProvideInModule("runtime", ProvideKVStoreKey), - depinject.ProvideInModule("a", wrapMethod0(ModuleA{})), - depinject.ProvideInModule("c", wrapMethod0(ModuleUnexportedDependency{})), + depinject.ProvideInModule("a", ModuleA.Provide), + depinject.ProvideInModule("c", ModuleUnexportedDependency.Provide), + depinject.Supply(ModuleA{}, ModuleUnexportedDependency{}), ) scenarioConfigProvidesDependency = depinject.Configs( depinject.Provide(ProvideMsgClientA), depinject.ProvideInModule("runtime", ProvideKVStoreKey), - depinject.ProvideInModule("a", wrapMethod0(ModuleA{})), - depinject.ProvideInModule("c", wrapMethod0(ModuleUnexportedProvides{})), - depinject.ProvideInModule("d", wrapMethod0(ModuleD{})), + depinject.ProvideInModule("a", ModuleA.Provide), + depinject.ProvideInModule("c", ModuleUnexportedProvides.Provide), + depinject.ProvideInModule("d", ModuleD.Provide), + depinject.Supply(ModuleA{}, ModuleUnexportedProvides{}, ModuleD{}), ) ) @@ -237,8 +239,9 @@ func TestUnexportedField(t *testing.T) { var scenarioConfig = depinject.Configs( depinject.Provide(ProvideMsgClientA), depinject.ProvideInModule("runtime", ProvideKVStoreKey), - depinject.ProvideInModule("a", wrapMethod0(ModuleA{})), - depinject.ProvideInModule("b", wrapMethod0(ModuleB{})), + depinject.ProvideInModule("a", ModuleA.Provide), + depinject.ProvideInModule("b", ModuleB.Provide), + depinject.Supply(ModuleA{}, ModuleB{}), ) func TestScenario(t *testing.T) { @@ -273,21 +276,6 @@ func TestScenario(t *testing.T) { }, b) } -func wrapMethod0(module interface{}) interface{} { - methodFn := reflect.TypeOf(module).Method(0).Func.Interface() - ctrInfo, err := depinject.ExtractProviderDescriptor(methodFn) - if err != nil { - panic(err) - } - - ctrInfo.Inputs = ctrInfo.Inputs[1:] - fn := ctrInfo.Fn - ctrInfo.Fn = func(values []reflect.Value) ([]reflect.Value, error) { - return fn(append([]reflect.Value{reflect.ValueOf(module)}, values...)) - } - return ctrInfo -} - func TestResolveError(t *testing.T) { var x string require.Error(t, depinject.Inject( @@ -316,64 +304,41 @@ func TestErrorOption(t *testing.T) { require.Error(t, err) } -func TestBadCtr(t *testing.T) { - _, err := depinject.ExtractProviderDescriptor(KeeperA{}) - require.Error(t, err) -} - func TestTrivial(t *testing.T) { require.NoError(t, depinject.Inject(depinject.Configs())) } -func TestErrorFunc(t *testing.T) { - _, err := depinject.ExtractProviderDescriptor( - func() (error, int) { return nil, 0 }, - ) - require.Error(t, err) - - _, err = depinject.ExtractProviderDescriptor( - func() (int, error) { return 0, nil }, - ) - require.NoError(t, err) - - var x int - require.Error(t, - depinject.Inject( - depinject.Provide(func() (int, error) { - return 0, fmt.Errorf("the error") - }), - &x, - )) -} +func Provide0() int { return 0 } +func Provide1() int { return 1 } func TestSimple(t *testing.T) { var x int require.NoError(t, depinject.Inject( - depinject.Provide( - func() int { return 1 }, - ), + depinject.Provide(Provide1), &x, ), ) require.Error(t, depinject.Inject( - depinject.Provide( - func() int { return 0 }, - func() int { return 1 }, - ), + depinject.Provide(Provide0, Provide1), &x, ), ) } +func ProvideModuleScoped0(depinject.ModuleKey) int { return 0 } +func ProvideModuleScoped1(depinject.ModuleKey) int { return 1 } +func ProvideFloat64FromInt(x int) float64 { return float64(x) } +func ProvideFloat32FromInt(x int) float32 { return float32(x) } + func TestModuleScoped(t *testing.T) { var x int require.Error(t, depinject.Inject( depinject.Provide( - func(depinject.ModuleKey) int { return 0 }, + ProvideModuleScoped0, ), &x, ), @@ -384,12 +349,10 @@ func TestModuleScoped(t *testing.T) { depinject.Inject( depinject.Configs( depinject.Provide( - func(depinject.ModuleKey) int { return 0 }, - func() int { return 1 }, - ), - depinject.ProvideInModule("a", - func(x int) float64 { return float64(x) }, + ProvideModuleScoped0, + Provide1, ), + depinject.ProvideInModule("a", ProvideFloat64FromInt), ), &y, ), @@ -399,12 +362,10 @@ func TestModuleScoped(t *testing.T) { depinject.Inject( depinject.Configs( depinject.Provide( - func() int { return 0 }, - func(depinject.ModuleKey) int { return 1 }, - ), - depinject.ProvideInModule("a", - func(x int) float64 { return float64(x) }, + Provide0, + ProvideModuleScoped0, ), + depinject.ProvideInModule("a", ProvideFloat64FromInt), ), &y, ), @@ -414,12 +375,10 @@ func TestModuleScoped(t *testing.T) { depinject.Inject( depinject.Configs( depinject.Provide( - func(depinject.ModuleKey) int { return 0 }, - func(depinject.ModuleKey) int { return 1 }, - ), - depinject.ProvideInModule("a", - func(x int) float64 { return float64(x) }, + ProvideModuleScoped0, + ProvideModuleScoped1, ), + depinject.ProvideInModule("a", ProvideFloat64FromInt), ), &y, ), @@ -428,12 +387,8 @@ func TestModuleScoped(t *testing.T) { require.NoError(t, depinject.Inject( depinject.Configs( - depinject.Provide( - func(depinject.ModuleKey) int { return 0 }, - ), - depinject.ProvideInModule("a", - func(x int) float64 { return float64(x) }, - ), + depinject.Provide(ProvideModuleScoped0), + depinject.ProvideInModule("a", ProvideFloat64FromInt), ), &y, ), @@ -442,12 +397,8 @@ func TestModuleScoped(t *testing.T) { require.Error(t, depinject.Inject( depinject.Configs( - depinject.Provide( - func(depinject.ModuleKey) int { return 0 }, - ), - depinject.ProvideInModule("", - func(x int) float64 { return float64(x) }, - ), + depinject.Provide(ProvideModuleScoped0), + depinject.ProvideInModule("", ProvideFloat64FromInt), ), &y, ), @@ -457,12 +408,10 @@ func TestModuleScoped(t *testing.T) { require.NoError(t, depinject.Inject( depinject.Configs( - depinject.Provide( - func(depinject.ModuleKey) int { return 0 }, - ), + depinject.Provide(ProvideModuleScoped0), depinject.ProvideInModule("a", - func(x int) float64 { return float64(x) }, - func(x int) float32 { return float32(x) }, + ProvideFloat64FromInt, + ProvideFloat32FromInt, ), ), &y, &z, @@ -475,6 +424,18 @@ type OnePerModuleInt int func (OnePerModuleInt) IsOnePerModuleType() {} +func OnePerModuleInt3() OnePerModuleInt { return 3 } +func OnePerModuleInt4() OnePerModuleInt { return 4 } +func CollectOnePerModuleInts(x map[string]OnePerModuleInt) string { + sum := 0 + for _, v := range x { + sum += int(v) + } + return fmt.Sprintf("%d", sum) +} + +func ReturnOnePerModuleMap() map[string]OnePerModuleInt { return nil } + func TestOnePerModule(t *testing.T) { var x OnePerModuleInt require.Error(t, @@ -487,19 +448,9 @@ func TestOnePerModule(t *testing.T) { require.NoError(t, depinject.Inject( depinject.Configs( - depinject.ProvideInModule("a", - func() OnePerModuleInt { return 3 }, - ), - depinject.ProvideInModule("b", - func() OnePerModuleInt { return 4 }, - ), - depinject.Provide(func(x map[string]OnePerModuleInt) string { - sum := 0 - for _, v := range x { - sum += int(v) - } - return fmt.Sprintf("%d", sum) - }), + depinject.ProvideInModule("a", OnePerModuleInt3), + depinject.ProvideInModule("b", OnePerModuleInt4), + depinject.Provide(CollectOnePerModuleInts), ), &y, &z, @@ -516,8 +467,8 @@ func TestOnePerModule(t *testing.T) { require.Error(t, depinject.Inject( depinject.ProvideInModule("a", - func() OnePerModuleInt { return 0 }, - func() OnePerModuleInt { return 0 }, + OnePerModuleInt3, + OnePerModuleInt3, ), &m, ), @@ -527,7 +478,7 @@ func TestOnePerModule(t *testing.T) { require.Error(t, depinject.Inject( depinject.Provide( - func() OnePerModuleInt { return 0 }, + OnePerModuleInt3, ), &m, ), @@ -536,19 +487,14 @@ func TestOnePerModule(t *testing.T) { require.Error(t, depinject.Inject( - depinject.Provide( - func() map[string]OnePerModuleInt { return nil }, - ), + depinject.Provide(ReturnOnePerModuleMap), &m, ), "bad return type", ) require.NoError(t, - depinject.Inject( - depinject.Configs(), - &m, - ), + depinject.Inject(depinject.Configs(), &m), "no providers", ) } @@ -557,21 +503,24 @@ type ManyPerContainerInt int func (ManyPerContainerInt) IsManyPerContainerType() {} +func ManyPerContainerInt4() ManyPerContainerInt { return 4 } +func ManyPerContainerInt9() ManyPerContainerInt { return 9 } +func CollectManyPerContainerInts(xs []ManyPerContainerInt) string { + sum := 0 + for _, x := range xs { + sum += int(x) + } + return fmt.Sprintf("%d", sum) +} + func TestManyPerContainer(t *testing.T) { var xs []ManyPerContainerInt var sum string require.NoError(t, depinject.Inject( depinject.Provide( - func() ManyPerContainerInt { return 4 }, - func() ManyPerContainerInt { return 9 }, - func(xs []ManyPerContainerInt) string { - sum := 0 - for _, x := range xs { - sum += int(x) - } - return fmt.Sprintf("%d", sum) - }, + ManyPerContainerInt4, ManyPerContainerInt9, + CollectManyPerContainerInts, ), &xs, &sum, @@ -584,12 +533,7 @@ func TestManyPerContainer(t *testing.T) { var z ManyPerContainerInt require.Error(t, - depinject.Inject( - depinject.Provide( - func() ManyPerContainerInt { return 0 }, - ), - &z, - ), + depinject.Inject(depinject.Provide(ManyPerContainerInt4), &z), "bad input type", ) @@ -657,6 +601,14 @@ type TestOutput struct { Y int64 } +func ProvideTestOutput() (TestOutput, error) { + return TestOutput{X: "A", Y: -10}, nil +} + +func ProvideTestOutputErr() (TestOutput, error) { + return TestOutput{}, fmt.Errorf("error") +} + func TestStructArgs(t *testing.T) { var input TestInput require.Error(t, depinject.Inject(depinject.Configs(), &input)) @@ -678,18 +630,14 @@ func TestStructArgs(t *testing.T) { var x string var y int64 require.NoError(t, depinject.Inject( - depinject.Provide(func() (TestOutput, error) { - return TestOutput{X: "A", Y: -10}, nil - }), + depinject.Provide(ProvideTestOutput), &x, &y, )) require.Equal(t, "A", x) require.Equal(t, int64(-10), y) require.Error(t, depinject.Inject( - depinject.Provide(func() (TestOutput, error) { - return TestOutput{}, fmt.Errorf("error") - }), + depinject.Provide(ProvideTestOutputErr), &x, )) } @@ -703,11 +651,21 @@ func TestDebugOptions(t *testing.T) { stdout := os.Stdout os.Stdout = outfile defer func() { os.Stdout = stdout }() - defer os.Remove(outfile.Name()) + defer func() { + err := os.Remove(outfile.Name()) + if err != nil { + panic(err) + } + }() graphfile, err := os.CreateTemp("", "graph") require.NoError(t, err) - defer os.Remove(graphfile.Name()) + defer func() { + err := os.Remove(graphfile.Name()) + if err != nil { + panic(err) + } + }() require.NoError(t, depinject.InjectDebug( depinject.DebugOptions( @@ -748,8 +706,8 @@ func TestGraphAndLogOutput(t *testing.T) { badConfig := depinject.Configs( depinject.ProvideInModule("runtime", ProvideKVStoreKey), - depinject.ProvideInModule("a", wrapMethod0(ModuleA{})), - depinject.ProvideInModule("b", wrapMethod0(ModuleB{})), + depinject.ProvideInModule("a", ModuleA.Provide), + depinject.ProvideInModule("b", ModuleB.Provide), ) require.Error(t, depinject.InjectDebug(debugOpts, badConfig, &b)) golden.Assert(t, graphOut, "example_error.dot") diff --git a/depinject/go.sum b/depinject/go.sum index b77f07b8e003..7ef1565d8e60 100644 --- a/depinject/go.sum +++ b/depinject/go.sum @@ -83,5 +83,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.3.0 h1:MfDY1b1/0xN1CyMlQDac0ziEy9zJQd9CXBRRDHw2jJo= gotest.tools/v3 v3.3.0/go.mod h1:Mcr9QNxkg0uMvy/YElmo4SpXgJKWgQvYrT7Kw5RzJ1A= -pgregory.net/rapid v0.4.8 h1:d+5SGZWUbJPbl3ss6tmPFqnNeQR6VDOFly+eTjwPiEw= -pgregory.net/rapid v0.4.8/go.mod h1:Z5PbWqjvWR1I3UGjvboUuan4fe4ZYEYNLNQLExzCoUs= +pgregory.net/rapid v0.5.2 h1:zC+jmuzcz5yJvG/igG06aLx8kcGmZY435NcuyhblKjY= +pgregory.net/rapid v0.5.2/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= diff --git a/depinject/invoke_test.go b/depinject/invoke_test.go index 69dec08c57f0..3f2a7b023c53 100644 --- a/depinject/invoke_test.go +++ b/depinject/invoke_test.go @@ -10,62 +10,79 @@ import ( ) func TestInvoke(t *testing.T) { - gocuke.NewRunner(t, &invokeSuite{}). + gocuke.NewRunner(t, &InvokeSuite{}). Path("features/invoke.feature"). + Step("an int provider returning 5", (*InvokeSuite).AnIntProviderReturning5). + Step(`a string pointer provider pointing to "foo"`, (*InvokeSuite).AStringPointerProviderPointingToFoo). Run() } -type invokeSuite struct { +type InvokeSuite struct { gocuke.TestingT configs []depinject.Config i int sp *string } -func (s *invokeSuite) AnInvokerRequestingAnIntAndStringPointer() { - s.configs = append(s.configs, depinject.Invoke(s.intStringPointerInvoker)) +func (s *InvokeSuite) AnInvokerRequestingAnIntAndStringPointer() { + s.configs = append(s.configs, + depinject.Supply(s), + depinject.Invoke((*InvokeSuite).IntStringPointerInvoker), + ) } -func (s *invokeSuite) intStringPointerInvoker(i int, sp *string) { +func (s *InvokeSuite) IntStringPointerInvoker(i int, sp *string) { s.i = i s.sp = sp } -func (s *invokeSuite) TheContainerIsBuilt() { +func (s *InvokeSuite) TheContainerIsBuilt() { assert.NilError(s, depinject.Inject(depinject.Configs(s.configs...))) } -func (s *invokeSuite) TheInvokerWillGetTheIntParameterSetTo(a int64) { +func (s *InvokeSuite) TheInvokerWillGetTheIntParameterSetTo(a int64) { assert.Equal(s, int(a), s.i) } -func (s *invokeSuite) TheInvokerWillGetTheStringPointerParameterSetToNil() { +func (s *InvokeSuite) TheInvokerWillGetTheStringPointerParameterSetToNil() { if s.sp != nil { s.Fatalf("expected a nil string pointer, got %s", *s.sp) } } -func (s *invokeSuite) AnIntProviderReturning(a int64) { - s.configs = append(s.configs, depinject.Provide(func() int { return int(a) })) +func IntProvider5() int { return 5 } + +func (s *InvokeSuite) AnIntProviderReturning5() { + s.configs = append(s.configs, depinject.Provide(IntProvider5)) +} + +func StringPtrProviderFoo() *string { + x := "foo" + return &x } -func (s *invokeSuite) AStringPointerProviderPointingTo(a string) { - s.configs = append(s.configs, depinject.Provide(func() *string { return &a })) +func (s *InvokeSuite) AStringPointerProviderPointingToFoo() { + s.configs = append(s.configs, depinject.Provide(StringPtrProviderFoo)) } -func (s *invokeSuite) TheInvokerWillGetTheStringPointerParameterSetTo(a string) { +func (s *InvokeSuite) TheInvokerWillGetTheStringPointerParameterSetTo(a string) { if s.sp == nil { s.Fatalf("expected a non-nil string pointer") } assert.Equal(s, a, *s.sp) } -func (s *invokeSuite) AnInvokerRequestingAnIntAndStringPointerRunInModule(a string) { - s.configs = append(s.configs, depinject.InvokeInModule(a, s.intStringPointerInvoker)) +func (s *InvokeSuite) AnInvokerRequestingAnIntAndStringPointerRunInModule(a string) { + s.configs = append(s.configs, + depinject.Supply(s), + depinject.InvokeInModule(a, (*InvokeSuite).IntStringPointerInvoker), + ) +} + +func ProvideLenModuleKey(key depinject.ModuleKey) int { + return len(key.Name()) } -func (s *invokeSuite) AModulescopedIntProviderWhichReturnsTheLengthOfTheModuleName() { - s.configs = append(s.configs, depinject.Provide(func(key depinject.ModuleKey) int { - return len(key.Name()) - })) +func (s *InvokeSuite) AModulescopedIntProviderWhichReturnsTheLengthOfTheModuleName() { + s.configs = append(s.configs, depinject.Provide(ProvideLenModuleKey)) } diff --git a/depinject/module_dep.go b/depinject/module_dep.go index 08cf0f81bf96..02683cbf2700 100644 --- a/depinject/module_dep.go +++ b/depinject/module_dep.go @@ -7,7 +7,7 @@ import ( ) type moduleDepProvider struct { - provider *ProviderDescriptor + provider *providerDescriptor calledForModule map[*moduleKey]bool valueMap map[*moduleKey][]reflect.Value } diff --git a/depinject/provider_desc.go b/depinject/provider_desc.go index 90d1dfd624be..b9fd3b14daf0 100644 --- a/depinject/provider_desc.go +++ b/depinject/provider_desc.go @@ -2,21 +2,24 @@ package depinject import ( "reflect" + "strings" + "unicode" "github.com/pkg/errors" + "golang.org/x/exp/slices" ) -// ProviderDescriptor defines a special provider type that is defined by +// providerDescriptor defines a special provider type that is defined by // reflection. It should be passed as a value to the Provide function. // Ex: // -// option.Provide(ProviderDescriptor{ ... }) -type ProviderDescriptor struct { +// option.Provide(providerDescriptor{ ... }) +type providerDescriptor struct { // Inputs defines the in parameter types to Fn. - Inputs []ProviderInput + Inputs []providerInput // Outputs defines the out parameter types to Fn. - Outputs []ProviderOutput + Outputs []providerOutput // Fn defines the provider function. Fn func([]reflect.Value) ([]reflect.Value, error) @@ -26,85 +29,94 @@ type ProviderDescriptor struct { Location Location } -type ProviderInput struct { +type providerInput struct { Type reflect.Type Optional bool } -type ProviderOutput struct { +type providerOutput struct { Type reflect.Type } -func ExtractProviderDescriptor(provider interface{}) (ProviderDescriptor, error) { - rctr, ok := provider.(ProviderDescriptor) - if !ok { - var err error - rctr, err = doExtractProviderDescriptor(provider) - if err != nil { - return ProviderDescriptor{}, err - } +func extractProviderDescriptor(provider interface{}) (providerDescriptor, error) { + rctr, err := doExtractProviderDescriptor(provider) + if err != nil { + return providerDescriptor{}, err } - - return expandStructArgsProvider(rctr) + return postProcessProvider(rctr) } -func ExtractInvokerDescriptor(provider interface{}) (ProviderDescriptor, error) { - rctr, ok := provider.(ProviderDescriptor) - if !ok { - var err error - rctr, err = doExtractProviderDescriptor(provider) - - // mark all inputs as optional - for i, input := range rctr.Inputs { - input.Optional = true - rctr.Inputs[i] = input - } +func extractInvokerDescriptor(provider interface{}) (providerDescriptor, error) { + rctr, err := doExtractProviderDescriptor(provider) + if err != nil { + return providerDescriptor{}, err + } - if err != nil { - return ProviderDescriptor{}, err - } + // mark all inputs as optional + for i, input := range rctr.Inputs { + input.Optional = true + rctr.Inputs[i] = input } - return expandStructArgsProvider(rctr) + return postProcessProvider(rctr) } -func doExtractProviderDescriptor(ctr interface{}) (ProviderDescriptor, error) { +func doExtractProviderDescriptor(ctr interface{}) (providerDescriptor, error) { val := reflect.ValueOf(ctr) typ := val.Type() if typ.Kind() != reflect.Func { - return ProviderDescriptor{}, errors.Errorf("expected a Func type, got %v", typ) + return providerDescriptor{}, errors.Errorf("expected a Func type, got %v", typ) } - loc := LocationFromPC(val.Pointer()) + loc := LocationFromPC(val.Pointer()).(*location) + nameParts := strings.Split(loc.name, ".") + if len(nameParts) == 0 { + return providerDescriptor{}, errors.Errorf("missing function name %s", loc) + } + + lastNamePart := nameParts[len(nameParts)-1] + + if unicode.IsLower([]rune(lastNamePart)[0]) { + return providerDescriptor{}, errors.Errorf("function must be exported: %s", loc) + } + + if strings.Contains(lastNamePart, "-") { + return providerDescriptor{}, errors.Errorf("function can't be used as a provider (it might be a bound instance method): %s", loc) + } + + pkgParts := strings.Split(loc.pkg, "/") + if slices.Contains(pkgParts, "internal") { + return providerDescriptor{}, errors.Errorf("function must not be in an internal package: %s", loc) + } if typ.IsVariadic() { - return ProviderDescriptor{}, errors.Errorf("variadic function can't be used as a provider: %s", loc) + return providerDescriptor{}, errors.Errorf("variadic function can't be used as a provider: %s", loc) } numIn := typ.NumIn() - in := make([]ProviderInput, numIn) + in := make([]providerInput, numIn) for i := 0; i < numIn; i++ { - in[i] = ProviderInput{ + in[i] = providerInput{ Type: typ.In(i), } } errIdx := -1 numOut := typ.NumOut() - var out []ProviderOutput + var out []providerOutput for i := 0; i < numOut; i++ { t := typ.Out(i) if t == errType { if i != numOut-1 { - return ProviderDescriptor{}, errors.Errorf("output error parameter is not last parameter in function %s", loc) + return providerDescriptor{}, errors.Errorf("output error parameter is not last parameter in function %s", loc) } errIdx = i } else { - out = append(out, ProviderOutput{Type: t}) + out = append(out, providerOutput{Type: t}) } } - return ProviderDescriptor{ + return providerDescriptor{ Inputs: in, Outputs: out, Fn: func(values []reflect.Value) ([]reflect.Value, error) { @@ -123,3 +135,30 @@ func doExtractProviderDescriptor(ctr interface{}) (ProviderDescriptor, error) { } var errType = reflect.TypeOf((*error)(nil)).Elem() + +func postProcessProvider(descriptor providerDescriptor) (providerDescriptor, error) { + descriptor, err := expandStructArgsProvider(descriptor) + if err != nil { + return providerDescriptor{}, err + } + err = checkInputAndOutputTypes(descriptor) + return descriptor, err +} + +func checkInputAndOutputTypes(descriptor providerDescriptor) error { + for _, input := range descriptor.Inputs { + err := isExportedType(input.Type) + if err != nil { + return err + } + } + + for _, output := range descriptor.Outputs { + err := isExportedType(output.Type) + if err != nil { + return err + } + } + + return nil +} diff --git a/depinject/provider_desc_test.go b/depinject/provider_desc_test.go index a768e8787623..b7d6d349377b 100644 --- a/depinject/provider_desc_test.go +++ b/depinject/provider_desc_test.go @@ -1,29 +1,56 @@ -package depinject_test +package depinject import ( "reflect" "testing" - "cosmossdk.io/depinject" + "gotest.tools/v3/assert" + + "cosmossdk.io/depinject/internal/codegen" + "cosmossdk.io/depinject/internal/graphviz" ) type StructIn struct { - depinject.In + In X int Y float64 `optional:"true"` } type BadOptional struct { - depinject.In + In X int `optional:"foo"` } type StructOut struct { - depinject.Out + Out X string Y []byte } +func privateProvider(int, float64) (string, []byte) { return "", nil } + +func PrivateInAndOut(containerConfig) *container { return nil } + +func InternalInAndOut(graphviz.Attributes) *codegen.FileGen { return nil } + +type SomeStruct struct{} + +func (SomeStruct) privateMethod() int { return 0 } + +func SimpleArgs(int, float64) (string, []byte) { return "", nil } + +func SimpleArgsWithError(int, float64) (string, []byte, error) { return "", nil, nil } + +func StructInAndOut(_ float32, _ StructIn, _ byte) (int16, StructOut, int32, error) { + return int16(0), StructOut{}, int32(0), nil +} + +func BadErrorPosition() (error, int) { return nil, 0 } + +func BadOptionalFn(_ BadOptional) int { return 0 } + +func Variadic(...float64) int { return 0 } + func TestExtractProviderDescriptor(t *testing.T) { var ( intType = reflect.TypeOf(0) @@ -39,67 +66,102 @@ func TestExtractProviderDescriptor(t *testing.T) { tests := []struct { name string ctr interface{} - wantIn []depinject.ProviderInput - wantOut []depinject.ProviderOutput - wantErr bool + wantIn []providerInput + wantOut []providerOutput + wantErr string }{ + { + "private", + privateProvider, + nil, + nil, + "function must be exported", + }, + { + "private method", + SomeStruct.privateMethod, + nil, + nil, + "function must be exported", + }, + { + "private in and out", + PrivateInAndOut, + nil, + nil, + "type must be exported", + }, + { + "internal in and out", + InternalInAndOut, + nil, + nil, + "internal", + }, + { + "struct", + SomeStruct{}, + nil, + nil, + "expected a Func type", + }, { "simple args", - func(x int, y float64) (string, []byte) { return "", nil }, - []depinject.ProviderInput{{Type: intType}, {Type: float64Type}}, - []depinject.ProviderOutput{{Type: stringType}, {Type: bytesTyp}}, - false, + SimpleArgs, + []providerInput{{Type: intType}, {Type: float64Type}}, + []providerOutput{{Type: stringType}, {Type: bytesTyp}}, + "", }, { "simple args with error", - func(x int, y float64) (string, []byte, error) { return "", nil, nil }, - []depinject.ProviderInput{{Type: intType}, {Type: float64Type}}, - []depinject.ProviderOutput{{Type: stringType}, {Type: bytesTyp}}, - false, + SimpleArgsWithError, + []providerInput{{Type: intType}, {Type: float64Type}}, + []providerOutput{{Type: stringType}, {Type: bytesTyp}}, + "", }, { "struct in and out", - func(_ float32, _ StructIn, _ byte) (int16, StructOut, int32, error) { - return int16(0), StructOut{}, int32(0), nil - }, - []depinject.ProviderInput{{Type: float32Type}, {Type: intType}, {Type: float64Type, Optional: true}, {Type: byteTyp}}, - []depinject.ProviderOutput{{Type: int16Type}, {Type: stringType}, {Type: bytesTyp}, {Type: int32Type}}, - false, + StructInAndOut, + []providerInput{{Type: float32Type}, {Type: intType}, {Type: float64Type, Optional: true}, {Type: byteTyp}}, + []providerOutput{{Type: int16Type}, {Type: stringType}, {Type: bytesTyp}, {Type: int32Type}}, + "", }, { "error bad position", - func() (error, int) { return nil, 0 }, + BadErrorPosition, nil, nil, - true, + "error parameter is not last parameter", }, { "bad optional", - func(_ BadOptional) int { return 0 }, + BadOptionalFn, nil, nil, - true, + "bad optional tag", }, { "variadic", - func(...float64) int { return 0 }, + Variadic, nil, nil, - true, + "variadic function can't be used", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := depinject.ExtractProviderDescriptor(tt.ctr) - if (err != nil) != tt.wantErr { - t.Errorf("ExtractProviderDescriptor() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got.Inputs, tt.wantIn) { - t.Errorf("ExtractProviderDescriptor() got = %v, want %v", got.Inputs, tt.wantIn) - } - if !reflect.DeepEqual(got.Outputs, tt.wantOut) { - t.Errorf("ExtractProviderDescriptor() got = %v, want %v", got.Outputs, tt.wantOut) + got, err := extractProviderDescriptor(tt.ctr) + if tt.wantErr != "" { + assert.ErrorContains(t, err, tt.wantErr) + } else { + assert.NilError(t, err) + + if !reflect.DeepEqual(got.Inputs, tt.wantIn) { + t.Errorf("extractProviderDescriptor() got = %v, want %v", got.Inputs, tt.wantIn) + } + if !reflect.DeepEqual(got.Outputs, tt.wantOut) { + t.Errorf("extractProviderDescriptor() got = %v, want %v", got.Outputs, tt.wantOut) + } } }) } diff --git a/depinject/simple.go b/depinject/simple.go index 9b8caff8c549..ce2b8314e917 100644 --- a/depinject/simple.go +++ b/depinject/simple.go @@ -7,7 +7,7 @@ import ( ) type simpleProvider struct { - provider *ProviderDescriptor + provider *providerDescriptor called bool values []reflect.Value moduleKey *moduleKey diff --git a/depinject/struct_args.go b/depinject/struct_args.go index d82dd3c2f856..a62bc46f3326 100644 --- a/depinject/struct_args.go +++ b/depinject/struct_args.go @@ -37,15 +37,15 @@ type isOut interface{ isOut() } var isOutType = reflect.TypeOf((*isOut)(nil)).Elem() -func expandStructArgsProvider(provider ProviderDescriptor) (ProviderDescriptor, error) { +func expandStructArgsProvider(provider providerDescriptor) (providerDescriptor, error) { var structArgsInInput bool - var newIn []ProviderInput + var newIn []providerInput for _, in := range provider.Inputs { if in.Type.AssignableTo(isInType) { structArgsInInput = true inTypes, err := structArgsInTypes(in.Type) if err != nil { - return ProviderDescriptor{}, err + return providerDescriptor{}, err } newIn = append(newIn, inTypes...) } else { @@ -56,7 +56,7 @@ func expandStructArgsProvider(provider ProviderDescriptor) (ProviderDescriptor, newOut, structArgsInOutput := expandStructArgsOutTypes(provider.Outputs) if structArgsInInput || structArgsInOutput { - return ProviderDescriptor{ + return providerDescriptor{ Inputs: newIn, Outputs: newOut, Fn: expandStructArgsFn(provider), @@ -67,7 +67,7 @@ func expandStructArgsProvider(provider ProviderDescriptor) (ProviderDescriptor, return provider, nil } -func expandStructArgsFn(provider ProviderDescriptor) func(inputs []reflect.Value) ([]reflect.Value, error) { +func expandStructArgsFn(provider providerDescriptor) func(inputs []reflect.Value) ([]reflect.Value, error) { fn := provider.Fn inParams := provider.Inputs outParams := provider.Outputs @@ -106,9 +106,9 @@ func expandStructArgsFn(provider ProviderDescriptor) func(inputs []reflect.Value } } -func structArgsInTypes(typ reflect.Type) ([]ProviderInput, error) { +func structArgsInTypes(typ reflect.Type) ([]providerInput, error) { n := typ.NumField() - var res []ProviderInput + var res []providerInput for i := 0; i < n; i++ { f := typ.Field(i) if f.Type.AssignableTo(isInType) { @@ -125,7 +125,7 @@ func structArgsInTypes(typ reflect.Type) ([]ProviderInput, error) { } } - res = append(res, ProviderInput{ + res = append(res, providerInput{ Type: f.Type, Optional: optional, }) @@ -133,9 +133,9 @@ func structArgsInTypes(typ reflect.Type) ([]ProviderInput, error) { return res, nil } -func expandStructArgsOutTypes(outputs []ProviderOutput) ([]ProviderOutput, bool) { +func expandStructArgsOutTypes(outputs []providerOutput) ([]providerOutput, bool) { foundStructArgs := false - var newOut []ProviderOutput + var newOut []providerOutput for _, out := range outputs { if out.Type.AssignableTo(isOutType) { foundStructArgs = true @@ -147,16 +147,16 @@ func expandStructArgsOutTypes(outputs []ProviderOutput) ([]ProviderOutput, bool) return newOut, foundStructArgs } -func structArgsOutTypes(typ reflect.Type) []ProviderOutput { +func structArgsOutTypes(typ reflect.Type) []providerOutput { n := typ.NumField() - var res []ProviderOutput + var res []providerOutput for i := 0; i < n; i++ { f := typ.Field(i) if f.Type.AssignableTo(isOutType) { continue } - res = append(res, ProviderOutput{ + res = append(res, providerOutput{ Type: f.Type, }) } diff --git a/depinject/testdata/example.dot b/depinject/testdata/example.dot index ea61e3559e4d..76a636abfb4a 100644 --- a/depinject/testdata/example.dot +++ b/depinject/testdata/example.dot @@ -20,24 +20,31 @@ digraph "" { "cosmossdk.io/depinject_test.KVStoreKey"[color="black", fontcolor="black", penwidth="1.5"]; "cosmossdk.io/depinject_test.KeeperA"[color="lightgrey", fontcolor="dimgrey", penwidth="0.5"]; "cosmossdk.io/depinject_test.KeeperB"[color="black", fontcolor="black", penwidth="1.5"]; + "cosmossdk.io/depinject_test.ModuleA"[color="lightgrey", fontcolor="dimgrey", penwidth="0.5"]; + "cosmossdk.io/depinject_test.ModuleB"[color="black", fontcolor="black", penwidth="1.5"]; "cosmossdk.io/depinject_test.MsgClientA"[color="black", fontcolor="black", penwidth="1.5"]; "cosmossdk.io/depinject_test.ProvideMsgClientA"[color="black", fontcolor="black", penwidth="1.5", shape="box"]; "cosmossdk.io/depinject_test.TestGraphAndLogOutput"[color="black", fontcolor="black", penwidth="1.5", shape="hexagon"]; + "cosmossdk.io/depinject_test.init"[color="black", fontcolor="black", penwidth="1.5", shape="box"]; "map[string]cosmossdk.io/depinject_test.Handler"[color="lightgrey", comment="one-per-module", fontcolor="dimgrey", penwidth="0.5"]; "cosmossdk.io/depinject.ModuleKey" -> "cosmossdk.io/depinject_test.ProvideMsgClientA"; "cosmossdk.io/depinject_test.ProvideMsgClientA" -> "cosmossdk.io/depinject_test.MsgClientA"; "cosmossdk.io/depinject.ModuleKey" -> "cosmossdk.io/depinject_test.ProvideKVStoreKey"; "cosmossdk.io/depinject_test.ProvideKVStoreKey" -> "cosmossdk.io/depinject_test.KVStoreKey"; + "cosmossdk.io/depinject_test.ModuleA" -> "cosmossdk.io/depinject_test.ModuleA.Provide"; "cosmossdk.io/depinject_test.KVStoreKey" -> "cosmossdk.io/depinject_test.ModuleA.Provide"; "cosmossdk.io/depinject.OwnModuleKey" -> "cosmossdk.io/depinject_test.ModuleA.Provide"; "cosmossdk.io/depinject_test.ModuleA.Provide" -> "cosmossdk.io/depinject_test.KeeperA"; "cosmossdk.io/depinject_test.ModuleA.Provide" -> "map[string]cosmossdk.io/depinject_test.Handler"; "cosmossdk.io/depinject_test.ModuleA.Provide" -> "[]cosmossdk.io/depinject_test.Command"; + "cosmossdk.io/depinject_test.ModuleB" -> "cosmossdk.io/depinject_test.ModuleB.Provide"; "cosmossdk.io/depinject_test.KVStoreKey" -> "cosmossdk.io/depinject_test.ModuleB.Provide"; "cosmossdk.io/depinject_test.MsgClientA" -> "cosmossdk.io/depinject_test.ModuleB.Provide"; "cosmossdk.io/depinject_test.ModuleB.Provide" -> "cosmossdk.io/depinject_test.KeeperB"; "cosmossdk.io/depinject_test.ModuleB.Provide" -> "[]cosmossdk.io/depinject_test.Command"; "cosmossdk.io/depinject_test.ModuleB.Provide" -> "map[string]cosmossdk.io/depinject_test.Handler"; + "cosmossdk.io/depinject_test.init" -> "cosmossdk.io/depinject_test.ModuleA"; + "cosmossdk.io/depinject_test.init" -> "cosmossdk.io/depinject_test.ModuleB"; "cosmossdk.io/depinject_test.KeeperB" -> "cosmossdk.io/depinject_test.TestGraphAndLogOutput"; } diff --git a/depinject/testdata/example_error.dot b/depinject/testdata/example_error.dot index a47386cdd4d2..296de2b4a2a3 100644 --- a/depinject/testdata/example_error.dot +++ b/depinject/testdata/example_error.dot @@ -11,25 +11,29 @@ digraph "" { subgraph "cluster_runtime" { graph [fontsize="12.0", label="Module: runtime", penwidth="0.5", style="rounded"]; - "cosmossdk.io/depinject_test.ProvideKVStoreKey"[color="black", fontcolor="black", penwidth="1.5", shape="box"]; + "cosmossdk.io/depinject_test.ProvideKVStoreKey"[color="lightgrey", fontcolor="dimgrey", penwidth="0.5", shape="box"]; } "[]cosmossdk.io/depinject_test.Command"[color="lightgrey", comment="many-per-container", fontcolor="dimgrey", penwidth="0.5"]; - "cosmossdk.io/depinject.ModuleKey"[color="black", fontcolor="black", penwidth="1.5"]; + "cosmossdk.io/depinject.ModuleKey"[color="lightgrey", fontcolor="dimgrey", penwidth="0.5"]; "cosmossdk.io/depinject.OwnModuleKey"[color="lightgrey", fontcolor="dimgrey", penwidth="0.5"]; - "cosmossdk.io/depinject_test.KVStoreKey"[color="black", fontcolor="black", penwidth="1.5"]; + "cosmossdk.io/depinject_test.KVStoreKey"[color="lightgrey", fontcolor="dimgrey", penwidth="0.5"]; "cosmossdk.io/depinject_test.KeeperA"[color="lightgrey", fontcolor="dimgrey", penwidth="0.5"]; "cosmossdk.io/depinject_test.KeeperB"[color="red", fontcolor="red", penwidth="0.5"]; - "cosmossdk.io/depinject_test.MsgClientA"[color="red", fontcolor="red", penwidth="0.5"]; + "cosmossdk.io/depinject_test.ModuleA"[color="lightgrey", fontcolor="dimgrey", penwidth="0.5"]; + "cosmossdk.io/depinject_test.ModuleB"[color="red", fontcolor="red", penwidth="0.5"]; + "cosmossdk.io/depinject_test.MsgClientA"[color="lightgrey", fontcolor="dimgrey", penwidth="0.5"]; "cosmossdk.io/depinject_test.TestGraphAndLogOutput"[color="red", fontcolor="red", penwidth="0.5", shape="hexagon"]; "map[string]cosmossdk.io/depinject_test.Handler"[color="lightgrey", comment="one-per-module", fontcolor="dimgrey", penwidth="0.5"]; "cosmossdk.io/depinject.ModuleKey" -> "cosmossdk.io/depinject_test.ProvideKVStoreKey"; "cosmossdk.io/depinject_test.ProvideKVStoreKey" -> "cosmossdk.io/depinject_test.KVStoreKey"; + "cosmossdk.io/depinject_test.ModuleA" -> "cosmossdk.io/depinject_test.ModuleA.Provide"; "cosmossdk.io/depinject_test.KVStoreKey" -> "cosmossdk.io/depinject_test.ModuleA.Provide"; "cosmossdk.io/depinject.OwnModuleKey" -> "cosmossdk.io/depinject_test.ModuleA.Provide"; "cosmossdk.io/depinject_test.ModuleA.Provide" -> "cosmossdk.io/depinject_test.KeeperA"; "cosmossdk.io/depinject_test.ModuleA.Provide" -> "map[string]cosmossdk.io/depinject_test.Handler"; "cosmossdk.io/depinject_test.ModuleA.Provide" -> "[]cosmossdk.io/depinject_test.Command"; + "cosmossdk.io/depinject_test.ModuleB" -> "cosmossdk.io/depinject_test.ModuleB.Provide"; "cosmossdk.io/depinject_test.KVStoreKey" -> "cosmossdk.io/depinject_test.ModuleB.Provide"; "cosmossdk.io/depinject_test.MsgClientA" -> "cosmossdk.io/depinject_test.ModuleB.Provide"; "cosmossdk.io/depinject_test.ModuleB.Provide" -> "cosmossdk.io/depinject_test.KeeperB"; diff --git a/orm/go.sum b/orm/go.sum index 71a2f98a1dfd..0800d8dcf81a 100644 --- a/orm/go.sum +++ b/orm/go.sum @@ -257,7 +257,5 @@ gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.3.0 h1:MfDY1b1/0xN1CyMlQDac0ziEy9zJQd9CXBRRDHw2jJo= gotest.tools/v3 v3.3.0/go.mod h1:Mcr9QNxkg0uMvy/YElmo4SpXgJKWgQvYrT7Kw5RzJ1A= pgregory.net/rapid v0.4.7/go.mod h1:UYpPVyjFHzYBGHIxLFoupi8vwk6rXNzRY9OMvVxFIOU= -pgregory.net/rapid v0.5.1 h1:U7LVKOJavGH81G5buGiyztKCmpQLfepzitHEKDLQ8ug= -pgregory.net/rapid v0.5.1/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= pgregory.net/rapid v0.5.2 h1:zC+jmuzcz5yJvG/igG06aLx8kcGmZY435NcuyhblKjY= pgregory.net/rapid v0.5.2/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=