Skip to content

Commit

Permalink
wire: give wire.Bind access to the arguments to the injector function (
Browse files Browse the repository at this point in the history
  • Loading branch information
vangent authored and zombiezen committed Nov 28, 2018
1 parent 67170e7 commit 6ea381b
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 89 deletions.
3 changes: 3 additions & 0 deletions cmd/wire/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
case pv.IsNil():
// This is an input.
inputVisited.Set(curr, -1)
case pv.IsArg():
// This is an injector argument.
inputVisited.Set(curr, -1)
case pv.IsProvider():
// Try to see if any args haven't been visited.
p := pv.Provider()
Expand Down
67 changes: 38 additions & 29 deletions internal/wire/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,37 +80,14 @@ type call struct {

// solve finds the sequence of calls required to produce an output type
// with an optional set of provided inputs.
func solve(fset *token.FileSet, out types.Type, given []types.Type, set *ProviderSet) ([]call, []error) {
func solve(fset *token.FileSet, out types.Type, given *types.Tuple, set *ProviderSet) ([]call, []error) {
ec := new(errorCollector)
for i, g := range given {
for _, h := range given[:i] {
if types.Identical(g, h) {
ec.add(fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil)))
}
}
}

// Start building the mapping of type to local variable of the given type.
// The first len(given) local variables are the given types.
index := new(typeutil.Map)
for i, g := range given {
if pv := set.For(g); !pv.IsNil() {
switch {
case pv.IsProvider():
ec.add(fmt.Errorf("input of %s conflicts with provider %s at %s",
types.TypeString(g, nil), pv.Provider().Name, fset.Position(pv.Provider().Pos)))
case pv.IsValue():
ec.add(fmt.Errorf("input of %s conflicts with value at %s",
types.TypeString(g, nil), fset.Position(pv.Value().Pos)))
default:
panic("unknown return value from ProviderSet.For")
}
} else {
index.Set(g, i)
}
}
if len(ec.errors) > 0 {
return nil, ec.errors
for i := 0; i < given.Len(); i++ {
index.Set(given.At(i).Type(), i)
}

// Topological sort of the directed graph defined by the providers
Expand Down Expand Up @@ -149,6 +126,19 @@ dfs:
ec.add(errors.New(sb.String()))
index.Set(curr.t, errAbort)
continue
case pv.IsArg():
src := set.srcMap.At(curr.t).(*providerSetSrc)
used = append(used, src)
if concrete := pv.ConcreteType(); !types.Identical(concrete, curr.t) {
// Interface binding.
i := index.At(concrete)
if i == nil {
stk = append(stk, curr, frame{t: concrete, from: curr.t, up: &curr})
continue
}
index.Set(curr.t, i)
}
continue
case pv.IsProvider():
p := pv.Provider()
src := set.srcMap.At(curr.t).(*providerSetSrc)
Expand Down Expand Up @@ -192,7 +182,7 @@ dfs:
}
args[i] = v.(int)
}
index.Set(curr.t, len(given)+len(calls))
index.Set(curr.t, given.Len()+len(calls))
kind := funcProviderCall
if p.IsStruct {
kind = structProvider
Expand Down Expand Up @@ -222,7 +212,7 @@ dfs:
}
src := set.srcMap.At(curr.t).(*providerSetSrc)
used = append(used, src)
index.Set(curr.t, len(given)+len(calls))
index.Set(curr.t, given.Len()+len(calls))
calls = append(calls, call{
kind: valueExpr,
out: curr.t,
Expand Down Expand Up @@ -308,8 +298,23 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
srcMap := new(typeutil.Map) // to *providerSetSrc
srcMap.SetHasher(hasher)

// Process imports first, verifying that there are no conflicts between sets.
ec := new(errorCollector)
// Process injector arguments.
if set.InjectorArgs != nil {
givens := set.InjectorArgs.Tuple
for i := 0; i < givens.Len(); i++ {
typ := givens.At(i).Type()
arg := &InjectorArg{Args: set.InjectorArgs, Index: i}
src := &providerSetSrc{InjectorArg: arg}
if prevSrc := srcMap.At(typ); prevSrc != nil {
ec.add(bindingConflictError(fset, typ, set, src, prevSrc.(*providerSetSrc)))
continue
}
providerMap.Set(typ, &ProvidedType{t: typ, a: arg})
srcMap.Set(typ, src)
}
}
// Process imports, verifying that there are no conflicts between sets.
for _, imp := range set.Imports {
src := &providerSetSrc{Import: imp}
imp.providerMap.Iterate(func(k types.Type, v interface{}) {
Expand Down Expand Up @@ -407,6 +412,10 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {
// Leaf: values do not have dependencies.
continue
}
if pt.IsArg() {
// Injector arguments do not have dependencies.
continue
}
if !pt.IsProvider() {
panic("invalid provider map value")
}
Expand Down
136 changes: 89 additions & 47 deletions internal/wire/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ import (
// A providerSetSrc captures the source for a type provided by a ProviderSet.
// Exactly one of the fields will be set.
type providerSetSrc struct {
Provider *Provider
Binding *IfaceBinding
Value *Value
Import *ProviderSet
Provider *Provider
Binding *IfaceBinding
Value *Value
Import *ProviderSet
InjectorArg *InjectorArg
}

// description returns a string describing the source of p, including line numbers.
Expand All @@ -59,6 +60,9 @@ func (p *providerSetSrc) description(fset *token.FileSet, typ types.Type) string
return fmt.Sprintf("wire.Value (%s)", fset.Position(p.Value.Pos))
case p.Import != nil:
return fmt.Sprintf("provider set %s(%s)", quoted(p.Import.VarName), fset.Position(p.Import.Pos))
case p.InjectorArg != nil:
args := p.InjectorArg.Args
return fmt.Sprintf("argument %s to injector function %s (%s)", args.Tuple.At(p.InjectorArg.Index).Name(), args.Name, fset.Position(args.Pos))
}
panic("providerSetSrc with no fields set")
}
Expand Down Expand Up @@ -93,6 +97,8 @@ type ProviderSet struct {
Bindings []*IfaceBinding
Values []*Value
Imports []*ProviderSet
// InjectorArgs is only filled in for wire.Build.
InjectorArgs *InjectorArgs

// providerMap maps from provided type to a *ProvidedType.
// It includes all of the imported types.
Expand Down Expand Up @@ -190,6 +196,24 @@ type Value struct {
info *types.Info
}

// InjectorArg describes a specific argument passed to an injector function.
type InjectorArg struct {
// Args is the full set of arguments.
Args *InjectorArgs
// Index is the index into Args.Tuple for this argument.
Index int
}

// InjectorArgs describes the arguments passed to an injector function.
type InjectorArgs struct {
// Name is the name of the injector function.
Name string
// Tuple represents the arguments.
Tuple *types.Tuple
// Pos is the source position of the injector function.
Pos token.Pos
}

// Load finds all the provider sets in the packages that match the given
// patterns, as well as the provider sets' transitive dependencies. It
// may return both errors and Info. The patterns are defined by the
Expand Down Expand Up @@ -252,11 +276,6 @@ func Load(ctx context.Context, wd string, env []string, patterns []string) (*Inf
if buildCall == nil {
continue
}
set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, "")
if len(errs) > 0 {
ec.add(notePositionAll(fset.Position(fn.Pos()), errs)...)
continue
}
sig := pkg.TypesInfo.ObjectOf(fn.Name).Type().(*types.Signature)
ins, out, err := injectorFuncSignature(sig)
if err != nil {
Expand All @@ -267,6 +286,16 @@ func Load(ctx context.Context, wd string, env []string, patterns []string) (*Inf
}
continue
}
injectorArgs := &InjectorArgs{
Name: fn.Name.Name,
Tuple: ins,
Pos: fn.Pos(),
}
set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, injectorArgs, "")
if len(errs) > 0 {
ec.add(notePositionAll(fset.Position(fn.Pos()), errs)...)
continue
}
_, errs = solve(fset, out.out, ins, set)
if len(errs) > 0 {
ec.add(mapErrors(errs, func(e error) error {
Expand Down Expand Up @@ -482,7 +511,7 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
}
switch fnObj.Name() {
case "NewSet":
pset, errs := oc.processNewSet(info, pkgPath, call, varName)
pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName)
return pset, notePositionAll(exprPos, errs)
case "Bind":
b, err := processBind(oc.fset, info, call)
Expand Down Expand Up @@ -516,13 +545,14 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))}
}

func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast.CallExpr, varName string) (*ProviderSet, []error) {
func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast.CallExpr, args *InjectorArgs, varName string) (*ProviderSet, []error) {
// Assumes that call.Fun is wire.NewSet or wire.Build.

pset := &ProviderSet{
Pos: call.Pos(),
PkgPath: pkgPath,
VarName: varName,
Pos: call.Pos(),
InjectorArgs: args,
PkgPath: pkgPath,
VarName: varName,
}
ec := new(errorCollector)
for _, arg := range call.Args {
Expand Down Expand Up @@ -626,17 +656,12 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro
return provider, nil
}

func injectorFuncSignature(sig *types.Signature) ([]types.Type, outputSignature, error) {
func injectorFuncSignature(sig *types.Signature) (*types.Tuple, outputSignature, error) {
out, err := funcOutput(sig)
if err != nil {
return nil, outputSignature{}, err
}
params := sig.Params()
given := make([]types.Type, params.Len())
for i := 0; i < params.Len(); i++ {
given[i] = params.At(i).Type()
}
return given, out, nil
return sig.Params(), out, nil
}

type outputSignature struct {
Expand Down Expand Up @@ -893,49 +918,66 @@ func isProviderSetType(t types.Type) bool {
return obj.Pkg() != nil && isWireImport(obj.Pkg().Path()) && obj.Name() == "ProviderSet"
}

// ProvidedType is a pointer to a Provider or a Value. The zero value is
// a nil pointer. It also holds the concrete type that the Provider or Value
// provided.
// ProvidedType represents a type provided from a source. The source
// can be a *Provider (a provider function), a *Value (wire.Value), or an
// *InjectorArgs (arguments to the injector function). The zero value has
// none of the above, and returns true for IsNil.
type ProvidedType struct {
// t is the provided concrete type.
t types.Type
p *Provider
v *Value
a *InjectorArg
}

// IsNil reports whether pv is the zero value.
func (pv ProvidedType) IsNil() bool {
return pv.p == nil && pv.v == nil
// IsNil reports whether pt is the zero value.
func (pt ProvidedType) IsNil() bool {
return pt.p == nil && pt.v == nil && pt.a == nil
}

// ConcreteType returns the concrete type that was provided.
func (pv ProvidedType) ConcreteType() types.Type {
return pv.t
func (pt ProvidedType) ConcreteType() types.Type {
return pt.t
}

// IsProvider reports whether pv points to a Provider.
func (pv ProvidedType) IsProvider() bool {
return pv.p != nil
// IsProvider reports whether pt points to a Provider.
func (pt ProvidedType) IsProvider() bool {
return pt.p != nil
}

// IsValue reports whether pv points to a Value.
func (pv ProvidedType) IsValue() bool {
return pv.v != nil
// IsValue reports whether pt points to a Value.
func (pt ProvidedType) IsValue() bool {
return pt.v != nil
}

// IsArg reports whether pt points to an injector argument.
func (pt ProvidedType) IsArg() bool {
return pt.a != nil
}

// Provider returns pt as a Provider pointer. It panics if pt does not point
// to a Provider.
func (pt ProvidedType) Provider() *Provider {
if pt.p == nil {
panic("ProvidedType does not hold a Provider")
}
return pt.p
}

// Provider returns pv as a Provider pointer. It panics if pv points to a
// Value.
func (pv ProvidedType) Provider() *Provider {
if pv.v != nil {
panic("Value pointer converted to a Provider")
// Value returns pt as a Value pointer. It panics if pt does not point
// to a Value.
func (pt ProvidedType) Value() *Value {
if pt.v == nil {
panic("ProvidedType does not hold a Value")
}
return pv.p
return pt.v
}

// Value returns pv as a Value pointer. It panics if pv points to a
// Provider.
func (pv ProvidedType) Value() *Value {
if pv.p != nil {
panic("Provider pointer converted to a Value")
// Arg returns pt as an *InjectorArg representing an injector argument. It
// panics if pt does not point to an arg.
func (pt ProvidedType) Arg() *InjectorArg {
if pt.a == nil {
panic("ProvidedType does not hold an Arg")
}
return pv.v
return pt.a
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
hello
1 change: 0 additions & 1 deletion internal/wire/testdata/BindInjectorArg/want/wire_errs.txt

This file was deleted.

13 changes: 13 additions & 0 deletions internal/wire/testdata/BindInjectorArg/want/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
example.com/foo/wire.go:x:y: inject injectBar: input of example.com/foo.Foo conflicts with provider provideFoo at example.com/foo/foo.go:x:y
example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo
current:
<- provider "provideFoo" (example.com/foo/foo.go:x:y)
<- provider set "Set" (example.com/foo/foo.go:x:y)
previous:
<- argument foo to injector function injectBar (example.com/foo/wire.go:x:y)
2 changes: 1 addition & 1 deletion internal/wire/testdata/InvalidInjector/foo/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (

func injectFoo() Foo {
// This non-call statement makes this an invalid injector.
_ = 42
_ = 42
panic(wire.Build(provideFoo))
}

Expand Down
Loading

0 comments on commit 6ea381b

Please sign in to comment.