diff --git a/CHANGELOG.md b/CHANGELOG.md index e5967871..97c71423 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added +- `rivertest.WorkContext`, a test function that can be used to initialize a context to test a `JobArgs.Work` implementation that will have a client set to context for use with `river.ClientFromContext`. [PR #526](https://github.com/riverqueue/river/pull/526). - A new `river migrate-list` command is available which lists available migrations and which version a target database is migrated to. [PR #534](https://github.com/riverqueue/river/pull/534). - `river version` or `river --version` now prints River version information. [PR #537](https://github.com/riverqueue/river/pull/537). diff --git a/context.go b/context.go index c1d24d16..61786d29 100644 --- a/context.go +++ b/context.go @@ -3,18 +3,14 @@ package river import ( "context" "errors" -) - -type ctxKey int -const ( - ctxKeyClient ctxKey = iota + "github.com/riverqueue/river/internal/rivercommon" ) var errClientNotInContext = errors.New("river: client not found in context, can only be used in a Worker") func withClient[TTx any](ctx context.Context, client *Client[TTx]) context.Context { - return context.WithValue(ctx, ctxKeyClient, client) + return context.WithValue(ctx, rivercommon.ContextKeyClient{}, client) } // ClientFromContext returns the Client from the context. This function can @@ -23,6 +19,9 @@ func withClient[TTx any](ctx context.Context, client *Client[TTx]) context.Conte // // It panics if the context does not contain a Client, which will never happen // from the context provided to a Worker's Work() method. +// +// When testing JobArgs.Work implementations, it might be useful to use +// rivertest.WorkContext to initialize a context that has an available client. func ClientFromContext[TTx any](ctx context.Context) *Client[TTx] { client, err := ClientFromContextSafely[TTx](ctx) if err != nil { @@ -37,8 +36,11 @@ func ClientFromContext[TTx any](ctx context.Context) *Client[TTx] { // // It returns an error if the context does not contain a Client, which will // never happen from the context provided to a Worker's Work() method. +// +// When testing JobArgs.Work implementations, it might be useful to use +// rivertest.WorkContext to initialize a context that has an available client. func ClientFromContextSafely[TTx any](ctx context.Context) (*Client[TTx], error) { - client, exists := ctx.Value(ctxKeyClient).(*Client[TTx]) + client, exists := ctx.Value(rivercommon.ContextKeyClient{}).(*Client[TTx]) if !exists || client == nil { return nil, errClientNotInContext } diff --git a/internal/rivercommon/river_common.go b/internal/rivercommon/river_common.go index 6e8ead1f..80c91336 100644 --- a/internal/rivercommon/river_common.go +++ b/internal/rivercommon/river_common.go @@ -16,6 +16,8 @@ const ( QueueDefault = "default" ) +type ContextKeyClient struct{} + // ErrShutdown is a special error injected by the client into its fetch and work // CancelCauseFuncs when it's stopping. It may be used by components for such // cases like avoiding logging an error during a normal shutdown procedure. This diff --git a/rivertest/rivertest.go b/rivertest/rivertest.go index 5c51975a..7ab025ff 100644 --- a/rivertest/rivertest.go +++ b/rivertest/rivertest.go @@ -12,6 +12,7 @@ import ( "time" "github.com/riverqueue/river" + "github.com/riverqueue/river/internal/rivercommon" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivertype" @@ -535,3 +536,12 @@ func failure(t testingT, format string, a ...any) { func failureString(format string, a ...any) string { return "\n River assertion failure:\n " + fmt.Sprintf(format, a...) + "\n" } + +// WorkContext returns a realistic context that can be used to test JobArgs.Work +// implementations. +// +// In particual, adds a client to the context so that river.ClientFromContext is +// usable in the test suite. +func WorkContext[TTx any](ctx context.Context, client *river.Client[TTx]) context.Context { + return context.WithValue(ctx, rivercommon.ContextKeyClient{}, client) +} diff --git a/rivertest/rivertest_test.go b/rivertest/rivertest_test.go index a79619ef..c51a59db 100644 --- a/rivertest/rivertest_test.go +++ b/rivertest/rivertest_test.go @@ -1059,6 +1059,32 @@ func TestRequireManyInsertedTx(t *testing.T) { }) } +func TestWorkContext(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct{} + + setup := func(ctx context.Context, t *testing.T) (context.Context, *testBundle) { + t.Helper() + + client, err := river.NewClient(riverpgxv5.New(nil), &river.Config{}) + require.NoError(t, err) + + return WorkContext(ctx, client), &testBundle{} + } + + t.Run("ClientFromContext", func(t *testing.T) { + t.Parallel() + + ctx, _ := setup(ctx, t) + + client := river.ClientFromContext[pgx.Tx](ctx) + require.NotNil(t, client) + }) +} + // MockT mocks testingT (or *testing.T). It's used to let us verify our test // helpers. type MockT struct {