From ba16de261411794f52409934734ebace68b58642 Mon Sep 17 00:00:00 2001 From: Sung Yoon Whang Date: Tue, 8 Feb 2022 09:51:33 -0800 Subject: [PATCH] Add fx.Module (#830) This adds fx.Module Option which is a first-class object for supporting scoped operations on dependencies. A Module can consist of zero or more fx.Options. By default, Provides to a Module is provided to the entire App, but there is a room for adding an option to scope that to a Module. Module can wrap Options such asSupply/Extract, Provide, and Invoke but there are some Options that don't make sense to put under Module. For example, StartTimeout, StopTimeout, WithLogger explicitly errors out when supplied to a Module. Implementation-wise, a Module corresponds to dig.Scope which was added in uber-go/dig#305. Extra bookkeeping is done by the module struct which contains the provides and invokes to a Scope. Co-authored-by: Abhinav Gupta Co-authored-by: Abhinav Gupta Co-authored-by: Abhinav Gupta --- app.go | 176 +++++++++++++--------- app_internal_test.go | 8 +- app_test.go | 27 +++- module.go | 103 +++++++++++++ module_test.go | 340 +++++++++++++++++++++++++++++++++++++++++++ supply.go | 4 +- supply_test.go | 23 ++- 7 files changed, 603 insertions(+), 78 deletions(-) create mode 100644 module.go create mode 100644 module_test.go diff --git a/app.go b/app.go index da3ca2307..a56ae3c4b 100644 --- a/app.go +++ b/app.go @@ -53,7 +53,7 @@ const DefaultTimeout = 15 * time.Second type Option interface { fmt.Stringer - apply(*App) + apply(*module) } // Provide registers any number of constructor functions, teaching the @@ -98,9 +98,9 @@ type provideOption struct { Stack fxreflect.Stack } -func (o provideOption) apply(app *App) { +func (o provideOption) apply(mod *module) { for _, target := range o.Targets { - app.provides = append(app.provides, provide{ + mod.provides = append(mod.provides, provide{ Target: target, Stack: o.Stack, }) @@ -145,9 +145,9 @@ type invokeOption struct { Stack fxreflect.Stack } -func (o invokeOption) apply(app *App) { +func (o invokeOption) apply(mod *module) { for _, target := range o.Targets { - app.invokes = append(app.invokes, invoke{ + mod.invokes = append(mod.invokes, invoke{ Target: target, Stack: o.Stack, }) @@ -174,8 +174,8 @@ func Error(errs ...error) Option { type errorOption []error -func (errs errorOption) apply(app *App) { - app.err = multierr.Append(app.err, multierr.Combine(errs...)) +func (errs errorOption) apply(mod *module) { + mod.app.err = multierr.Append(mod.app.err, multierr.Combine(errs...)) } func (errs errorOption) String() string { @@ -219,9 +219,9 @@ func Options(opts ...Option) Option { type optionGroup []Option -func (og optionGroup) apply(app *App) { +func (og optionGroup) apply(mod *module) { for _, opt := range og { - opt.apply(app) + opt.apply(mod) } } @@ -240,8 +240,13 @@ func StartTimeout(v time.Duration) Option { type startTimeoutOption time.Duration -func (t startTimeoutOption) apply(app *App) { - app.startTimeout = time.Duration(t) +func (t startTimeoutOption) apply(m *module) { + if m.parent != nil { + m.app.err = fmt.Errorf("fx.StartTimeout Option should be passed to top-level App, " + + "not to fx.Module") + } else { + m.app.startTimeout = time.Duration(t) + } } func (t startTimeoutOption) String() string { @@ -255,8 +260,13 @@ func StopTimeout(v time.Duration) Option { type stopTimeoutOption time.Duration -func (t stopTimeoutOption) apply(app *App) { - app.stopTimeout = time.Duration(t) +func (t stopTimeoutOption) apply(m *module) { + if m.parent != nil { + m.app.err = fmt.Errorf("fx.StopTimeout Option should be passed to top-level App, " + + "not to fx.Module") + } else { + m.app.stopTimeout = time.Duration(t) + } } func (t stopTimeoutOption) String() string { @@ -288,10 +298,16 @@ type withLoggerOption struct { Stack fxreflect.Stack } -func (l withLoggerOption) apply(app *App) { - app.logConstructor = &provide{ - Target: l.constructor, - Stack: l.Stack, +func (l withLoggerOption) apply(m *module) { + if m.parent != nil { + // loggers shouldn't differ based on Module. + m.app.err = fmt.Errorf("fx.WithLogger Option should be passed to top-level App, " + + "not to fx.Module") + } else { + m.app.logConstructor = &provide{ + Target: l.constructor, + Stack: l.Stack, + } } } @@ -316,9 +332,14 @@ func Logger(p Printer) Option { type loggerOption struct{ p Printer } -func (l loggerOption) apply(app *App) { - np := writerFromPrinter(l.p) - app.log = fxlog.DefaultLogger(np) // assuming np is thread-safe. +func (l loggerOption) apply(m *module) { + if m.parent != nil { + m.app.err = fmt.Errorf("fx.StartTimeout Option should be passed to top-level App, " + + "not to fx.Module") + } else { + np := writerFromPrinter(l.p) + m.app.log = fxlog.DefaultLogger(np) // assuming np is thread-safe. + } } func (l loggerOption) String() string { @@ -366,11 +387,12 @@ var NopLogger = WithLogger(func() fxevent.Logger { return fxevent.NopLogger }) type App struct { err error clock fxclock.Clock - container *dig.Container lifecycle *lifecycleWrapper - // Constructors and its dependencies. - provides []provide - invokes []invoke + + container *dig.Container + root *module + modules []*module + // Used to setup logging within fx. log fxevent.Logger logConstructor *provide // set only if fx.WithLogger was used @@ -424,8 +446,8 @@ func ErrorHook(funcs ...ErrorHandler) Option { type errorHookOption []ErrorHandler -func (eho errorHookOption) apply(app *App) { - app.errorHooks = append(app.errorHooks, eho...) +func (eho errorHookOption) apply(m *module) { + m.app.errorHooks = append(m.app.errorHooks, eho...) } func (eho errorHookOption) String() string { @@ -455,8 +477,13 @@ type validateOption struct { validate bool } -func (o validateOption) apply(app *App) { - app.validate = o.validate +func (o validateOption) apply(m *module) { + if m.parent != nil { + m.app.err = fmt.Errorf("fx.validate Option should be passed to top-level App, " + + "not to fx.Module") + } else { + m.app.validate = o.validate + } } func (o validateOption) String() string { @@ -521,9 +548,11 @@ func New(opts ...Option) *App { startTimeout: DefaultTimeout, stopTimeout: DefaultTimeout, } + app.root = &module{app: app} + app.modules = append(app.modules, app.root) for _, opt := range opts { - opt.apply(app) + opt.apply(app.root) } // There are a few levels of wrapping on the lifecycle here. To quickly @@ -562,17 +591,21 @@ func New(opts ...Option) *App { dig.DryRun(app.validate), ) - for _, p := range app.provides { - app.provide(p) + for _, m := range app.modules { + m.build(app, app.container) + } + + for _, m := range app.modules { + m.provideAll() } frames := fxreflect.CallerStack(0, 0) // include New in the stack for default Provides - app.provide(provide{ + app.root.provide(provide{ Target: func() Lifecycle { return app.lifecycle }, Stack: frames, }) - app.provide(provide{Target: app.shutdowner, Stack: frames}) - app.provide(provide{Target: app.dotGraph, Stack: frames}) + app.root.provide(provide{Target: app.shutdowner, Stack: frames}) + app.root.provide(provide{Target: app.dotGraph, Stack: frames}) // If you are thinking about returning here after provides: do not (just yet)! // If a custom logger was being used, we're still buffering messages. @@ -597,7 +630,7 @@ func New(opts ...Option) *App { return app } - if err := app.executeInvokes(); err != nil { + if err := app.root.executeInvokes(); err != nil { app.err = err if dig.CanVisualizeError(err) { @@ -827,14 +860,14 @@ func (app *App) dotGraph() (DotGraph, error) { return DotGraph(b.String()), err } -func (app *App) provide(p provide) { - if app.err != nil { +func (m *module) provide(p provide) { + if m.app.err != nil { return } constructor := p.Target if _, ok := constructor.(Option); ok { - app.err = fmt.Errorf("fx.Option should be passed to fx.New directly, "+ + m.app.err = fmt.Errorf("fx.Option should be passed to fx.New directly, "+ "not to fx.Provide: fx.Provide received %v from:\n%+v", constructor, p.Stack) return @@ -843,6 +876,7 @@ func (app *App) provide(p provide) { var info dig.ProvideInfo opts := []dig.ProvideOption{ dig.FillProvideInfo(&info), + dig.Export(true), } defer func() { var ev fxevent.Event @@ -851,7 +885,7 @@ func (app *App) provide(p provide) { case p.IsSupply: ev = &fxevent.Supplied{ TypeName: p.SupplyType.String(), - Err: app.err, + Err: m.app.err, } default: @@ -861,39 +895,40 @@ func (app *App) provide(p provide) { } ev = &fxevent.Provided{ - ConstructorName: fxreflect.FuncName(constructor), + ConstructorName: fxreflect.FuncName(p.Target), OutputTypeNames: outputNames, - Err: app.err, + Err: m.app.err, } } - - app.log.LogEvent(ev) + m.app.log.LogEvent(ev) }() + c := m.scope switch constructor := constructor.(type) { case annotationError: // fx.Annotate failed. Turn it into an Fx error. - app.err = fmt.Errorf( + m.app.err = fmt.Errorf( "encountered error while applying annotation using fx.Annotate to %s: %+v", fxreflect.FuncName(constructor.target), constructor.err) return case annotated: - c, err := constructor.Build() + ctor, err := constructor.Build() if err != nil { - app.err = fmt.Errorf("fx.Provide(%v) from:\n%+vFailed: %v", constructor, p.Stack, err) + m.app.err = fmt.Errorf("fx.Provide(%v) from:\n%+vFailed: %v", constructor, p.Stack, err) return } - if err := app.container.Provide(c, opts...); err != nil { - app.err = fmt.Errorf("fx.Provide(%v) from:\n%+vFailed: %v", constructor, p.Stack, err) + if err := c.Provide(ctor, opts...); err != nil { + m.app.err = fmt.Errorf("fx.Provide(%v) from:\n%+vFailed: %v", constructor, p.Stack, err) + return } case Annotated: ann := constructor switch { case len(ann.Group) > 0 && len(ann.Name) > 0: - app.err = fmt.Errorf( + m.app.err = fmt.Errorf( "fx.Annotated may specify only one of Name or Group: received %v from:\n%+v", ann, p.Stack) return @@ -903,8 +938,9 @@ func (app *App) provide(p provide) { opts = append(opts, dig.Group(ann.Group)) } - if err := app.container.Provide(ann.Target, opts...); err != nil { - app.err = fmt.Errorf("fx.Provide(%v) from:\n%+vFailed: %v", ann, p.Stack, err) + if err := c.Provide(ann.Target, opts...); err != nil { + m.app.err = fmt.Errorf("fx.Provide(%v) from:\n%+vFailed: %v", ann, p.Stack, err) + return } default: @@ -915,7 +951,7 @@ func (app *App) provide(p provide) { t := ft.Out(i) if t == reflect.TypeOf(Annotated{}) { - app.err = fmt.Errorf( + m.app.err = fmt.Errorf( "fx.Annotated should be passed to fx.Provide directly, "+ "it should not be returned by the constructor: "+ "fx.Provide received %v from:\n%+v", @@ -925,40 +961,42 @@ func (app *App) provide(p provide) { } } - if err := app.container.Provide(constructor, opts...); err != nil { - app.err = fmt.Errorf("fx.Provide(%v) from:\n%+vFailed: %v", fxreflect.FuncName(constructor), p.Stack, err) + if err := c.Provide(constructor, opts...); err != nil { + m.app.err = fmt.Errorf("fx.Provide(%v) from:\n%+vFailed: %v", fxreflect.FuncName(constructor), p.Stack, err) + return } } - } -// Execute invokes in order supplied to New, returning the first error -// encountered. -func (app *App) executeInvokes() error { - // TODO: consider taking a context to limit the time spent running invocations. - - for _, i := range app.invokes { - if err := app.executeInvoke(i); err != nil { +func (m *module) executeInvokes() error { + for _, invoke := range m.invokes { + if err := m.executeInvoke(invoke); err != nil { return err } } + for _, m := range m.modules { + if err := m.executeInvokes(); err != nil { + return err + } + } return nil } -func (app *App) executeInvoke(i invoke) (err error) { +func (m *module) executeInvoke(i invoke) (err error) { fn := i.Target - fnName := fxreflect.FuncName(fn) + fnName := fxreflect.FuncName(i.Target) - app.log.LogEvent(&fxevent.Invoking{FunctionName: fnName}) + m.app.log.LogEvent(&fxevent.Invoking{FunctionName: fnName}) defer func() { - app.log.LogEvent(&fxevent.Invoked{ + m.app.log.LogEvent(&fxevent.Invoked{ FunctionName: fnName, Err: err, Trace: fmt.Sprintf("%+v", i.Stack), // format stack trace as multi-line }) }() + c := m.scope switch fn := fn.(type) { case Option: return fmt.Errorf("fx.Option should be passed to fx.New directly, "+ @@ -966,14 +1004,14 @@ func (app *App) executeInvoke(i invoke) (err error) { fn, i.Stack) case annotated: - c, err := fn.Build() + af, err := fn.Build() if err != nil { return err } - return app.container.Invoke(c) + return c.Invoke(af) default: - return app.container.Invoke(fn) + return c.Invoke(fn) } } diff --git a/app_internal_test.go b/app_internal_test.go index 197e4f73c..82de85e2c 100644 --- a/app_internal_test.go +++ b/app_internal_test.go @@ -85,8 +85,8 @@ func (o withExitOption) String() string { return fmt.Sprintf("WithExit(%v)", fxreflect.FuncName(o)) } -func (o withExitOption) apply(app *App) { - app.osExit = o +func (o withExitOption) apply(m *module) { + m.app.osExit = o } // WithClock specifies how Fx accesses time operations. @@ -98,8 +98,8 @@ func WithClock(clock fxclock.Clock) Option { type withClockOption struct{ clock fxclock.Clock } -func (o withClockOption) apply(app *App) { - app.clock = o.clock +func (o withClockOption) apply(m *module) { + m.app.clock = o.clock } func (o withClockOption) String() string { diff --git a/app_test.go b/app_test.go index 97615c0f5..128c53542 100644 --- a/app_test.go +++ b/app_test.go @@ -130,8 +130,8 @@ func TestNewApp(t *testing.T) { errMsg := err.Error() assert.Contains(t, errMsg, "cycle detected in dependency graph") - assert.Contains(t, errMsg, "depends on func(fx_test.A) fx_test.B") assert.Contains(t, errMsg, "depends on func(fx_test.B) fx_test.A") + assert.Contains(t, errMsg, "depends on func(fx_test.A) fx_test.B") }) t.Run("ProvidesDotGraph", func(t *testing.T) { @@ -1580,6 +1580,31 @@ func TestErrorHook(t *testing.T) { assert.Contains(t, graphStr, `"fx_test.B" [color=red];`) assert.Contains(t, graphStr, `"fx_test.A" [color=orange];`) }) + + t.Run("GraphWithErrorInModule", func(t *testing.T) { + t.Parallel() + + type A struct{} + type B struct{} + + var errStr, graphStr string + h := errHandlerFunc(func(err error) { + errStr = err.Error() + graphStr, _ = VisualizeError(err) + }) + NewForTest(t, + Module("module", + Provide(func() (B, error) { return B{}, fmt.Errorf("great sadness") }), + Provide(func(B) A { return A{} }), + Invoke(func(A) {}), + ErrorHook(&h), + ), + ) + assert.Contains(t, errStr, "great sadness") + assert.Contains(t, graphStr, `"fx_test.B" [color=red];`) + assert.Contains(t, graphStr, `"fx_test.A" [color=orange];`) + }) + } func TestOptionString(t *testing.T) { diff --git a/module.go b/module.go new file mode 100644 index 000000000..d73f483ed --- /dev/null +++ b/module.go @@ -0,0 +1,103 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package fx + +import ( + "fmt" + + "go.uber.org/dig" +) + +// Module is a named group of zero or more fx.Options. +func Module(name string, opts ...Option) Option { + mo := moduleOption{ + name: name, + options: opts, + } + return mo +} + +type moduleOption struct { + name string + options []Option +} + +func (o moduleOption) String() string { + return fmt.Sprintf("fx.Module(%q, %v)", o.name, o.options) +} + +func (o moduleOption) apply(mod *module) { + // This get called on any submodules' that are declared + // as part of another module. + + // 1. Create a new module with the parent being the specified + // module. + // 2. Apply child Options on the new module. + // 3. Append it to the parent module. + newModule := &module{ + name: o.name, + parent: mod, + app: mod.app, + } + for _, opt := range o.options { + opt.apply(newModule) + } + mod.modules = append(mod.modules, newModule) +} + +type module struct { + parent *module + name string + scope *dig.Scope + provides []provide + invokes []invoke + modules []*module + app *App +} + +// builds the Scopes using the App's Container. Note that this happens +// after applyModules' are called because the App's Container needs to +// be built for any Scopes to be initialized, and applys' should be called +// before the Container can get initialized. +func (m *module) build(app *App, root *dig.Container) { + if m.parent == nil { + m.scope = root.Scope(m.name) + // TODO: Once fx.Decorate is in-place, + // use the root container instead of subscope. + } else { + parentScope := m.parent.scope + m.scope = parentScope.Scope(m.name) + } + + for _, mod := range m.modules { + mod.build(app, root) + } +} + +func (m *module) provideAll() { + for _, p := range m.provides { + m.provide(p) + } + + for _, m := range m.modules { + m.provideAll() + } +} diff --git a/module_test.go b/module_test.go new file mode 100644 index 000000000..66606dac5 --- /dev/null +++ b/module_test.go @@ -0,0 +1,340 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package fx_test + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/fx" + "go.uber.org/fx/fxevent" + "go.uber.org/fx/fxtest" + "go.uber.org/fx/internal/fxlog" +) + +func TestModuleSuccess(t *testing.T) { + t.Parallel() + + type Logger struct { + Name string + } + + t.Run("provide a dependency from a submodule", func(t *testing.T) { + t.Parallel() + + redis := fx.Module("redis", + fx.Provide(func() *Logger { + return &Logger{Name: "redis"} + }), + ) + + app := fxtest.New(t, + redis, + fx.Invoke(func(l *Logger) { + assert.Equal(t, "redis", l.Name) + }), + ) + + defer app.RequireStart().RequireStop() + }) + + t.Run("provide a dependency from nested modules", func(t *testing.T) { + t.Parallel() + app := fxtest.New(t, + fx.Module("child", + fx.Module("grandchild", + fx.Provide(func() *Logger { + return &Logger{Name: "redis"} + }), + ), + ), + fx.Invoke(func(l *Logger) { + assert.Equal(t, "redis", l.Name) + }), + ) + defer app.RequireStart().RequireStop() + }) + + t.Run("invoke from nested module", func(t *testing.T) { + t.Parallel() + invokeRan := false + app := fxtest.New(t, + fx.Provide(func() *Logger { + return &Logger{ + Name: "redis", + } + }), + fx.Module("child", + fx.Module("grandchild", + fx.Invoke(func(l *Logger) { + assert.Equal(t, "redis", l.Name) + invokeRan = true + }), + ), + ), + ) + require.True(t, invokeRan) + require.NoError(t, app.Err()) + defer app.RequireStart().RequireStop() + }) + + t.Run("invoke in module with dep from top module", func(t *testing.T) { + t.Parallel() + child := fx.Module("child", + fx.Invoke(func(l *Logger) { + assert.Equal(t, "my logger", l.Name) + }), + ) + app := fxtest.New(t, + child, + fx.Provide(func() *Logger { + return &Logger{Name: "my logger"} + }), + ) + defer app.RequireStart().RequireStop() + }) + + t.Run("provide in module with annotate", func(t *testing.T) { + t.Parallel() + child := fx.Module("child", + fx.Provide(fx.Annotate(func() *Logger { + return &Logger{Name: "good logger"} + }, fx.ResultTags(`name:"goodLogger"`))), + ) + app := fxtest.New(t, + child, + fx.Invoke(fx.Annotate(func(l *Logger) { + assert.Equal(t, "good logger", l.Name) + }, fx.ParamTags(`name:"goodLogger"`))), + ) + defer app.RequireStart().RequireStop() + }) + + t.Run("invoke in module with annotate", func(t *testing.T) { + t.Parallel() + ranInvoke := false + child := fx.Module("child", + // use something provided by the root module. + fx.Invoke(fx.Annotate(func(l *Logger) { + assert.Equal(t, "good logger", l.Name) + ranInvoke = true + })), + ) + app := fxtest.New(t, + child, + fx.Provide(fx.Annotate(func() *Logger { + return &Logger{Name: "good logger"} + })), + ) + defer app.RequireStart().RequireStop() + assert.True(t, ranInvoke) + }) + + t.Run("use Options in Module", func(t *testing.T) { + t.Parallel() + + opts := fx.Options( + fx.Provide(fx.Annotate(func() string { + return "dog" + }, fx.ResultTags(`group:"pets"`))), + fx.Provide(fx.Annotate(func() string { + return "cat" + }, fx.ResultTags(`group:"pets"`))), + ) + + petModule := fx.Module("pets", opts) + + app := fxtest.New(t, + petModule, + fx.Invoke(fx.Annotate(func(pets []string) { + assert.ElementsMatch(t, []string{"dog", "cat"}, pets) + }, fx.ParamTags(`group:"pets"`))), + ) + + defer app.RequireStart().RequireStop() + }) +} + +func TestModuleFailures(t *testing.T) { + t.Parallel() + + t.Run("invoke from submodule failed", func(t *testing.T) { + t.Parallel() + + type A struct{} + type B struct{} + + sub := fx.Module("sub", + fx.Provide(func() *A { return &A{} }), + fx.Invoke(func(*A, *B) { // missing dependency. + require.Fail(t, "this should not be called") + }), + ) + + app := NewForTest(t, + sub, + fx.Invoke(func(a *A) { + assert.NotNil(t, a) + }), + ) + + err := app.Err() + require.Error(t, err) + assert.Contains(t, err.Error(), "missing type: *fx_test.B") + }) + + t.Run("provide the same dependency from multiple modules", func(t *testing.T) { + t.Parallel() + + type A struct{} + + app := NewForTest(t, + fx.Module("mod1", fx.Provide(func() A { return A{} })), + fx.Module("mod2", fx.Provide(func() A { return A{} })), + fx.Invoke(func(a A) {}), + ) + + err := app.Err() + require.Error(t, err) + assert.Contains(t, err.Error(), "already provided by ") + }) + + t.Run("providing Modules should fail", func(t *testing.T) { + t.Parallel() + app := NewForTest(t, + fx.Module("module", + fx.Provide( + fx.Module("module"), + ), + ), + ) + err := app.Err() + require.Error(t, err) + assert.Contains(t, err.Error(), "fx.Option should be passed to fx.New directly, not to fx.Provide") + }) + + t.Run("invoking Modules should fail", func(t *testing.T) { + t.Parallel() + app := NewForTest(t, + fx.Module("module", + fx.Invoke( + fx.Invoke("module"), + ), + ), + ) + err := app.Err() + require.Error(t, err) + assert.Contains(t, err.Error(), "fx.Option should be passed to fx.New directly, not to fx.Invoke") + }) + + t.Run("annotate failure in Module", func(t *testing.T) { + t.Parallel() + + type A struct{} + newA := func() A { + return A{} + } + + app := NewForTest(t, + fx.Module("module", + fx.Provide(fx.Annotate(newA, + fx.ParamTags(`"name:"A1"`), + fx.ParamTags(`"name:"A2"`), + )), + ), + ) + err := app.Err() + require.Error(t, err) + + assert.Contains(t, err.Error(), "encountered error while applying annotation") + assert.Contains(t, err.Error(), "cannot apply more than one line of ParamTags") + }) + + t.Run("provider in Module fails", func(t *testing.T) { + t.Parallel() + + type A struct{} + type B struct{} + + newA := func() (A, error) { + return A{}, nil + } + newB := func() (B, error) { + return B{}, errors.New("minor sadness") + } + + app := NewForTest(t, + fx.Module("module", + fx.Provide( + newA, + newB, + ), + ), + fx.Invoke(func(A, B) { + assert.Fail(t, "this should never run") + }), + ) + + err := app.Err() + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to build fx_test.B") + assert.Contains(t, err.Error(), "minor sadness") + }) + + t.Run("invalid Options in Module", func(t *testing.T) { + t.Parallel() + + tests := []struct { + desc string + opt fx.Option + }{ + { + desc: "StartTimeout Option", + opt: fx.StartTimeout(time.Second), + }, + { + desc: "StopTimeout Option", + opt: fx.StopTimeout(time.Second), + }, + { + desc: "WithLogger Option", + opt: fx.WithLogger(func() fxevent.Logger { return new(fxlog.Spy) }), + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.desc, func(t *testing.T) { + t.Parallel() + + app := NewForTest(t, + fx.Module("module", + tt.opt, + ), + ) + require.Error(t, app.Err()) + }) + } + }) +} diff --git a/supply.go b/supply.go index 232a49357..5813833cf 100644 --- a/supply.go +++ b/supply.go @@ -86,9 +86,9 @@ type supplyOption struct { Stack fxreflect.Stack } -func (o supplyOption) apply(app *App) { +func (o supplyOption) apply(m *module) { for i, target := range o.Targets { - app.provides = append(app.provides, provide{ + m.provides = append(m.provides, provide{ Target: target, Stack: o.Stack, IsSupply: true, diff --git a/supply_test.go b/supply_test.go index d3dd622c6..3c0bf7cca 100644 --- a/supply_test.go +++ b/supply_test.go @@ -64,6 +64,25 @@ func TestSupply(t *testing.T) { require.Same(t, bIn, bOut) }) + t.Run("SupplyInModule", func(t *testing.T) { + t.Parallel() + + aIn, bIn := &A{}, &B{} + var aOut *A + var bOut *B + + app := fxtest.New( + t, + fx.Module("module", + fx.Supply(aIn, bIn), + ), + fx.Populate(&aOut, &bOut), + ) + defer app.RequireStart().RequireStop() + require.Same(t, aIn, aOut) + require.Same(t, bIn, bOut) + }) + t.Run("AnnotateIsSupplied", func(t *testing.T) { t.Parallel() @@ -104,8 +123,7 @@ func TestSupply(t *testing.T) { require.NotPanicsf( t, func() { fx.Supply(A{}, (*B)(nil)) }, - "a wrapped nil should not panic", - ) + "a wrapped nil should not panic") require.PanicsWithValuef( t, @@ -134,4 +152,5 @@ func TestSupply(t *testing.T) { require.NoError(t, supplied[0].(*fxevent.Supplied).Err) require.Error(t, supplied[1].(*fxevent.Supplied).Err) }) + }