Skip to content

Commit

Permalink
refactor(depinject)!: require exported functions & types (cosmos#12797)
Browse files Browse the repository at this point in the history
* refactor(depinject)!: require exported functions

* unexport ProviderDescriptor

* WIP on tests

* fix tests and check for bound instance methods

* address merge issues

* WIP on checking valid types

* WIP on checking valid types

* WIP

* tests passing

* revert changes outside module

* docs

* docs

* docs

* add comment

* revert

* update depinject go.mod versions

* remove go.work

* add go.work back

* go mod tidy

* fix docs

Co-authored-by: Julien Robert <julien@rbrt.fr>
  • Loading branch information
aaronc and julienrbrt authored Aug 31, 2022
1 parent 372a8f1 commit 7728516
Show file tree
Hide file tree
Showing 16 changed files with 525 additions and 283 deletions.
30 changes: 20 additions & 10 deletions depinject/binding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,18 @@ func (s bindingSuite) TwoImplementationsMallardAndCanvasback() {
// we don't need to do anything because this is defined at the type level
}

func ProvideMallard() Mallard { return Mallard{} }
func ProvideCanvasback() Canvasback { return Canvasback{} }
func ProvideMarbled() Marbled { return Marbled{} }

func (s *bindingSuite) IsProvided(a string) {
switch a {
case "Mallard":
s.addConfig(depinject.Provide(func() Mallard { return Mallard{} }))
s.addConfig(depinject.Provide(ProvideMallard))
case "Canvasback":
s.addConfig(depinject.Provide(func() Canvasback { return Canvasback{} }))
s.addConfig(depinject.Provide(ProvideCanvasback))
case "Marbled":
s.addConfig(depinject.Provide(func() Marbled { return Marbled{} }))
s.addConfig(depinject.Provide(ProvideMarbled))
default:
s.Fatalf("unexpected duck type %s", a)
}
Expand All @@ -79,18 +83,22 @@ func (s *bindingSuite) addConfig(config depinject.Config) {
s.configs = append(s.configs, config)
}

func ProvideDuckWrapper(duck Duck) DuckWrapper {
return DuckWrapper{Module: "", Duck: duck}
}

func (s *bindingSuite) WeTryToResolveADuckInGlobalScope() {
s.addConfig(depinject.Provide(func(duck Duck) DuckWrapper {
return DuckWrapper{Module: "", Duck: duck}
}))
s.addConfig(depinject.Provide(ProvideDuckWrapper))
}

func ResolvePond(ducks []DuckWrapper) Pond { return Pond{Ducks: ducks} }

func (s *bindingSuite) resolvePond() *Pond {
if s.pond != nil {
return s.pond
}

s.addConfig(depinject.Provide(func(ducks []DuckWrapper) Pond { return Pond{Ducks: ducks} }))
s.addConfig(depinject.Provide(ResolvePond))
var pond Pond
s.err = depinject.Inject(depinject.Configs(s.configs...), &pond)
s.pond = &pond
Expand Down Expand Up @@ -131,10 +139,12 @@ func (s *bindingSuite) ThereIsABindingForAInModule(preferredType string, interfa
s.addConfig(depinject.BindInterfaceInModule(moduleName, fullTypeName(interfaceType), fullTypeName(preferredType)))
}

func ProvideModuleDuck(duck Duck, key depinject.OwnModuleKey) DuckWrapper {
return DuckWrapper{Module: depinject.ModuleKey(key).Name(), Duck: duck}
}

func (s *bindingSuite) ModuleWantsADuck(module string) {
s.addConfig(depinject.ProvideInModule(module, func(duck Duck) DuckWrapper {
return DuckWrapper{Module: module, Duck: duck}
}))
s.addConfig(depinject.ProvideInModule(module, ProvideModuleDuck))
}

func (s *bindingSuite) ModuleResolvesA(module string, duckType string) {
Expand Down
69 changes: 69 additions & 0 deletions depinject/check_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package depinject

import (
"reflect"
"strings"
"unicode"

"github.com/pkg/errors"
"golang.org/x/exp/slices"
)

// isExportedType checks if the type is exported and not in an internal
// package. NOTE: generic type parameters are not checked because this
// would involve complex parsing of type names (there is no reflect API for
// generic type parameters). Parsing of these parameters should be possible
// if someone chooses to do it in the future, but care should be taken to
// be exhaustive and cover all cases like pointers, map's, chan's, etc. which
// means you actually need a real parser and not just a regex.
func isExportedType(typ reflect.Type) error {
name := typ.Name()
pkgPath := typ.PkgPath()
if name != "" && pkgPath != "" {
if unicode.IsLower([]rune(name)[0]) {
return errors.Errorf("type must be exported: %s", typ)
}

pkgParts := strings.Split(pkgPath, "/")
if slices.Contains(pkgParts, "internal") {
return errors.Errorf("type must not come from an internal package: %s", typ)
}

return nil
}

switch typ.Kind() {
case reflect.Array, reflect.Slice, reflect.Chan, reflect.Pointer:
return isExportedType(typ.Elem())

case reflect.Func:
numIn := typ.NumIn()
for i := 0; i < numIn; i++ {
err := isExportedType(typ.In(i))
if err != nil {
return err
}
}

numOut := typ.NumOut()
for i := 0; i < numOut; i++ {
err := isExportedType(typ.Out(i))
if err != nil {
return err
}
}

return nil

case reflect.Map:
err := isExportedType(typ.Key())
if err != nil {
return err
}
return isExportedType(typ.Elem())

default:
// all the remaining types are builtin, non-composite types (like integers), so they are fine to use
return nil
}
}
59 changes: 59 additions & 0 deletions depinject/check_type_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package depinject

import (
"os"
"reflect"
"testing"

"gotest.tools/v3/assert"

"cosmossdk.io/depinject/internal/graphviz"
)

func TestCheckIsExportedType(t *testing.T) {
expectValidType(t, false)
expectValidType(t, uint(0))
expectValidType(t, uint8(0))
expectValidType(t, uint16(0))
expectValidType(t, uint32(0))
expectValidType(t, uint64(0))
expectValidType(t, int(0))
expectValidType(t, int8(0))
expectValidType(t, int16(0))
expectValidType(t, int32(0))
expectValidType(t, int64(0))
expectValidType(t, float32(0))
expectValidType(t, float64(0))
expectValidType(t, complex64(0))
expectValidType(t, complex128(0))
expectValidType(t, os.FileMode(0))
expectValidType(t, [1]int{0})
expectValidType(t, []int{})
expectValidType(t, "")
expectValidType(t, make(chan int))
expectValidType(t, make(<-chan int))
expectValidType(t, make(chan<- int))
expectValidType(t, func(int, string) (bool, error) { return false, nil })
expectValidType(t, func(int, ...string) (bool, error) { return false, nil })
expectValidType(t, In{})
expectValidType(t, map[string]In{})
expectValidType(t, &In{})
expectValidType(t, uintptr(0))
expectValidType(t, (*Location)(nil))

expectInvalidType(t, container{}, "must be exported")
expectInvalidType(t, &container{}, "must be exported")
expectInvalidType(t, graphviz.Attributes{}, "internal")
expectInvalidType(t, map[string]graphviz.Attributes{}, "internal")
expectInvalidType(t, []graphviz.Attributes{}, "internal")
}

func expectValidType(t *testing.T, v interface{}) {
t.Helper()
assert.NilError(t, isExportedType(reflect.TypeOf(v)))
}

func expectInvalidType(t *testing.T, v interface{}, errContains string) {
t.Helper()
assert.ErrorContains(t, isExportedType(reflect.TypeOf(v)), errContains)
}
22 changes: 18 additions & 4 deletions depinject/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ type Config interface {
// Provide defines a container configuration which registers the provided dependency
// injection providers. Each provider will be called at most once with the
// exception of module-scoped providers which are called at most once per module
// (see ModuleKey).
// (see ModuleKey). All provider functions must be declared, exported functions not
// internal packages and all of their input and output types must also be declared
// and exported and not in internal packages. Note that generic type parameters
// will not be checked, but they should also be exported so that codegen is possible.
func Provide(providers ...interface{}) Config {
return containerConfig(func(ctr *container) error {
return provide(ctr, nil, providers)
Expand All @@ -23,7 +26,10 @@ func Provide(providers ...interface{}) Config {

// ProvideInModule defines container configuration which registers the provided dependency
// injection providers that are to be run in the named module. Each provider
// will be called at most once.
// will be called at most once. All provider functions must be declared, exported functions not
// internal packages and all of their input and output types must also be declared
// and exported and not in internal packages. Note that generic type parameters
// will not be checked, but they should also be exported so that codegen is possible.
func ProvideInModule(moduleName string, providers ...interface{}) Config {
return containerConfig(func(ctr *container) error {
if moduleName == "" {
Expand All @@ -36,7 +42,7 @@ func ProvideInModule(moduleName string, providers ...interface{}) Config {

func provide(ctr *container, key *moduleKey, providers []interface{}) error {
for _, c := range providers {
rc, err := ExtractProviderDescriptor(c)
rc, err := extractProviderDescriptor(c)
if err != nil {
return errors.WithStack(err)
}
Expand All @@ -52,6 +58,10 @@ func provide(ctr *container, key *moduleKey, providers []interface{}) error {
// at the end of dependency graph configuration in the order in which it was defined. Invokers may not define output
// parameters, although they may return an error, and all of their input parameters will be marked as optional so that
// invokers impose no additional constraints on the dependency graph. Invoker functions should nil-check all inputs.
// All invoker functions must be declared, exported functions not
// internal packages and all of their input and output types must also be declared
// and exported and not in internal packages. Note that generic type parameters
// will not be checked, but they should also be exported so that codegen is possible.
func Invoke(invokers ...interface{}) Config {
return containerConfig(func(ctr *container) error {
return invoke(ctr, nil, invokers)
Expand All @@ -63,6 +73,10 @@ func Invoke(invokers ...interface{}) Config {
// at the end of dependency graph configuration in the order in which it was defined. Invokers may not define output
// parameters, although they may return an error, and all of their input parameters will be marked as optional so that
// invokers impose no additional constraints on the dependency graph. Invoker functions should nil-check all inputs.
// All invoker functions must be declared, exported functions not
// internal packages and all of their input and output types must also be declared
// and exported and not in internal packages. Note that generic type parameters
// will not be checked, but they should also be exported so that codegen is possible.
func InvokeInModule(moduleName string, invokers ...interface{}) Config {
return containerConfig(func(ctr *container) error {
if moduleName == "" {
Expand All @@ -75,7 +89,7 @@ func InvokeInModule(moduleName string, invokers ...interface{}) Config {

func invoke(ctr *container, key *moduleKey, invokers []interface{}) error {
for _, c := range invokers {
rc, err := ExtractInvokerDescriptor(c)
rc, err := extractInvokerDescriptor(c)
if err != nil {
return errors.WithStack(err)
}
Expand Down
23 changes: 14 additions & 9 deletions depinject/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type container struct {
}

type invoker struct {
fn *ProviderDescriptor
fn *providerDescriptor
modKey *moduleKey
}

Expand Down Expand Up @@ -55,7 +55,7 @@ func newContainer(cfg *debugConfig) *container {
}
}

func (c *container) call(provider *ProviderDescriptor, moduleKey *moduleKey) ([]reflect.Value, error) {
func (c *container) call(provider *providerDescriptor, moduleKey *moduleKey) ([]reflect.Value, error) {
loc := provider.Location
graphNode := c.locationGraphNode(loc, moduleKey)

Expand Down Expand Up @@ -205,7 +205,7 @@ func (c *container) getExplicitResolver(typ reflect.Type, key *moduleKey) (resol

var stringType = reflect.TypeOf("")

func (c *container) addNode(provider *ProviderDescriptor, key *moduleKey) (interface{}, error) {
func (c *container) addNode(provider *providerDescriptor, key *moduleKey) (interface{}, error) {
providerGraphNode := c.locationGraphNode(provider.Location, key)
hasModuleKeyParam := false
hasOwnModuleKeyParam := false
Expand Down Expand Up @@ -359,7 +359,7 @@ func (c *container) supply(value reflect.Value, location Location) error {
return nil
}

func (c *container) addInvoker(provider *ProviderDescriptor, key *moduleKey) error {
func (c *container) addInvoker(provider *providerDescriptor, key *moduleKey) error {
// make sure there are no outputs
if len(provider.Outputs) > 0 {
return fmt.Errorf("invoker function %s should not return any outputs", provider.Location)
Expand All @@ -373,7 +373,7 @@ func (c *container) addInvoker(provider *ProviderDescriptor, key *moduleKey) err
return nil
}

func (c *container) resolve(in ProviderInput, moduleKey *moduleKey, caller Location) (reflect.Value, error) {
func (c *container) resolve(in providerInput, moduleKey *moduleKey, caller Location) (reflect.Value, error) {
c.resolveStack = append(c.resolveStack, resolveFrame{loc: caller, typ: in.Type})

typeGraphNode := c.typeGraphNode(in.Type)
Expand Down Expand Up @@ -426,17 +426,17 @@ func (c *container) resolve(in ProviderInput, moduleKey *moduleKey, caller Locat
}

func (c *container) build(loc Location, outputs ...interface{}) error {
var providerIn []ProviderInput
var providerIn []providerInput
for _, output := range outputs {
typ := reflect.TypeOf(output)
if typ.Kind() != reflect.Pointer {
return fmt.Errorf("output type must be a pointer, %s is invalid", typ)
}

providerIn = append(providerIn, ProviderInput{Type: typ.Elem()})
providerIn = append(providerIn, providerInput{Type: typ.Elem()})
}

desc := ProviderDescriptor{
desc := providerDescriptor{
Inputs: providerIn,
Outputs: nil,
Fn: func(values []reflect.Value) ([]reflect.Value, error) {
Expand Down Expand Up @@ -523,7 +523,12 @@ func fullyQualifiedTypeName(typ reflect.Type) string {
if typ.Kind() == reflect.Pointer || typ.Kind() == reflect.Slice || typ.Kind() == reflect.Map || typ.Kind() == reflect.Array {
pkgType = typ.Elem()
}
return fmt.Sprintf("%s/%v", pkgType.PkgPath(), typ)
pkgPath := pkgType.PkgPath()
if pkgPath == "" {
return fmt.Sprintf("%v", typ)
}

return fmt.Sprintf("%s/%v", pkgPath, typ)
}

func bindingKeyFromTypeName(typeName string, key *moduleKey) string {
Expand Down
Loading

0 comments on commit 7728516

Please sign in to comment.