From dd0b5dceda525c681732a639abf10c2367d96ff9 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Fri, 2 Feb 2024 11:20:32 +0000 Subject: [PATCH] Make plugin-specific env take precedence over sys env (#25128) * Make plugin-specific env take precedence over sys env * Expand the existing plugin env integration test --------- Co-authored-by: Austin Gebauer <34121980+austingebauer@users.noreply.github.com> --- changelog/25128.txt | 6 ++ sdk/helper/pluginutil/env.go | 6 ++ sdk/helper/pluginutil/run_config.go | 44 ++++++-- sdk/helper/pluginutil/run_config_test.go | 54 ++++++---- sdk/plugin/mock/backend.go | 1 + sdk/plugin/mock/path_env.go | 38 +++++++ vault/external_tests/plugin/plugin_test.go | 103 ++++++++----------- vault/plugincatalog/plugin_catalog.go | 111 +++++++++++++------- vault/plugincatalog/plugin_catalog_test.go | 114 +++++++++++++++++++++ 9 files changed, 349 insertions(+), 128 deletions(-) create mode 100644 changelog/25128.txt create mode 100644 sdk/plugin/mock/path_env.go diff --git a/changelog/25128.txt b/changelog/25128.txt new file mode 100644 index 000000000000..c9003ffea88d --- /dev/null +++ b/changelog/25128.txt @@ -0,0 +1,6 @@ +```release-note:change +plugins: By default, environment variables provided during plugin registration will now take precedence over system environment variables. +Use the environment variable `VAULT_PLUGIN_USE_LEGACY_ENV_LAYERING=true` to opt out and keep higher preference for system environment +variables. When this flag is set, Vault will check during unseal for conflicts and print warnings for any plugins with environment +variables that conflict with system environment variables. +``` diff --git a/sdk/helper/pluginutil/env.go b/sdk/helper/pluginutil/env.go index 1b45ef32dca2..515baa1f7619 100644 --- a/sdk/helper/pluginutil/env.go +++ b/sdk/helper/pluginutil/env.go @@ -38,6 +38,12 @@ const ( // PluginMultiplexingOptOut is an ENV name used to define a comma separated list of plugin names // opted-out of the multiplexing feature; for emergencies if multiplexing ever causes issues PluginMultiplexingOptOut = "VAULT_PLUGIN_MULTIPLEXING_OPT_OUT" + + // PluginUseLegacyEnvLayering opts out of new environment variable precedence. + // If set to true, Vault process environment variables take precedence over any + // colliding plugin-specific environment variables. Otherwise, plugin-specific + // environment variables take precedence over Vault process environment variables. + PluginUseLegacyEnvLayering = "VAULT_PLUGIN_USE_LEGACY_ENV_LAYERING" ) // OptionallyEnableMlock determines if mlock should be called, and if so enables diff --git a/sdk/helper/pluginutil/run_config.go b/sdk/helper/pluginutil/run_config.go index 9b44e9c4f8e2..1af71d09b75c 100644 --- a/sdk/helper/pluginutil/run_config.go +++ b/sdk/helper/pluginutil/run_config.go @@ -65,26 +65,26 @@ func (rc runConfig) mlockEnabled() bool { func (rc runConfig) generateCmd(ctx context.Context) (cmd *exec.Cmd, clientTLSConfig *tls.Config, err error) { cmd = exec.Command(rc.command, rc.args...) - cmd.Env = append(cmd.Env, rc.env...) + env := rc.env // Add the mlock setting to the ENV of the plugin if rc.mlockEnabled() { - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true")) + env = append(env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true")) } version, err := rc.Wrapper.VaultVersion(ctx) if err != nil { return nil, nil, err } - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version)) + env = append(env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version)) if rc.IsMetadataMode { rc.Logger = rc.Logger.With("metadata", "true") } metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.IsMetadataMode) - cmd.Env = append(cmd.Env, metadataEnv) + env = append(env, metadataEnv) automtlsEnv := fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, rc.AutoMTLS) - cmd.Env = append(cmd.Env, automtlsEnv) + env = append(env, automtlsEnv) if !rc.AutoMTLS && !rc.IsMetadataMode { // Get a CA TLS Certificate @@ -107,7 +107,35 @@ func (rc runConfig) generateCmd(ctx context.Context) (cmd *exec.Cmd, clientTLSCo } // Add the response wrap token to the ENV of the plugin - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken)) + env = append(env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken)) + } + + if rc.image == "" { + // go-plugin has always overridden user-provided env vars with the OS + // (Vault process) env vars, but we want plugins to be able to override + // the Vault process env. We don't want to make a breaking change in + // go-plugin so always set SkipHostEnv and replicate the legacy behavior + // ourselves if user opts in. + if legacy, _ := strconv.ParseBool(os.Getenv(PluginUseLegacyEnvLayering)); legacy { + // Env vars are layered as follows, with later entries overriding + // earlier entries if there are duplicate keys: + // 1. Env specified at plugin registration + // 2. Env from Vault SDK + // 3. Env from Vault process (OS) + // 4. Env from go-plugin + cmd.Env = append(env, os.Environ()...) + } else { + // Env vars are layered as follows, with later entries overriding + // earlier entries if there are duplicate keys: + // 1. Env from Vault process (OS) + // 2. Env specified at plugin registration + // 3. Env from Vault SDK + // 4. Env from go-plugin + cmd.Env = append(os.Environ(), env...) + } + } else { + // Containerized plugins do not inherit any env vars from Vault. + cmd.Env = env } return cmd, clientTLSConfig, nil @@ -128,7 +156,8 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error plugin.ProtocolNetRPC, plugin.ProtocolGRPC, }, - AutoMTLS: rc.AutoMTLS, + AutoMTLS: rc.AutoMTLS, + SkipHostEnv: true, } if rc.image == "" { clientConfig.Cmd = cmd @@ -141,7 +170,6 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error if err != nil { return nil, err } - clientConfig.SkipHostEnv = true clientConfig.RunnerFunc = containerCfg.NewContainerRunner clientConfig.UnixSocketConfig = &plugin.UnixSocketConfig{ Group: strconv.Itoa(containerCfg.GroupAdd), diff --git a/sdk/helper/pluginutil/run_config_test.go b/sdk/helper/pluginutil/run_config_test.go index dcd59dfa7b5f..6bb840f462d9 100644 --- a/sdk/helper/pluginutil/run_config_test.go +++ b/sdk/helper/pluginutil/run_config_test.go @@ -34,10 +34,11 @@ func TestMakeConfig(t *testing.T) { mlockEnabled bool mlockEnabledTimes int - expectedConfig *plugin.ClientConfig - expectTLSConfig bool - expectRunnerFunc bool - skipSecureConfig bool + expectedConfig *plugin.ClientConfig + expectTLSConfig bool + expectRunnerFunc bool + skipSecureConfig bool + useLegacyEnvLayering bool } tests := map[string]testCase{ @@ -66,8 +67,9 @@ func TestMakeConfig(t *testing.T) { responseWrapInfoTimes: 0, - mlockEnabled: false, - mlockEnabledTimes: 1, + mlockEnabled: false, + mlockEnabledTimes: 1, + useLegacyEnvLayering: true, expectedConfig: &plugin.ClientConfig{ HandshakeConfig: plugin.HandshakeConfig{ @@ -83,12 +85,12 @@ func TestMakeConfig(t *testing.T) { Cmd: commandWithEnv( "echo", []string{"foo", "bar"}, - []string{ + append(append([]string{ "initial=true", fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"), fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true), fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false), - }, + }, os.Environ()...), PluginUseLegacyEnvLayering+"=true"), ), SecureConfig: &plugin.SecureConfig{ Checksum: []byte("some_sha256"), @@ -98,8 +100,9 @@ func TestMakeConfig(t *testing.T) { plugin.ProtocolNetRPC, plugin.ProtocolGRPC, }, - Logger: hclog.NewNullLogger(), - AutoMTLS: false, + Logger: hclog.NewNullLogger(), + AutoMTLS: false, + SkipHostEnv: true, }, expectTLSConfig: false, }, @@ -148,14 +151,14 @@ func TestMakeConfig(t *testing.T) { Cmd: commandWithEnv( "echo", []string{"foo", "bar"}, - []string{ + append(os.Environ(), []string{ "initial=true", fmt.Sprintf("%s=%t", PluginMlockEnabled, true), fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"), fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false), fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false), fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, "testtoken"), - }, + }...), ), SecureConfig: &plugin.SecureConfig{ Checksum: []byte("some_sha256"), @@ -165,8 +168,9 @@ func TestMakeConfig(t *testing.T) { plugin.ProtocolNetRPC, plugin.ProtocolGRPC, }, - Logger: hclog.NewNullLogger(), - AutoMTLS: false, + Logger: hclog.NewNullLogger(), + AutoMTLS: false, + SkipHostEnv: true, }, expectTLSConfig: true, }, @@ -212,12 +216,12 @@ func TestMakeConfig(t *testing.T) { Cmd: commandWithEnv( "echo", []string{"foo", "bar"}, - []string{ + append(os.Environ(), []string{ "initial=true", fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"), fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true), fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true), - }, + }...), ), SecureConfig: &plugin.SecureConfig{ Checksum: []byte("some_sha256"), @@ -227,8 +231,9 @@ func TestMakeConfig(t *testing.T) { plugin.ProtocolNetRPC, plugin.ProtocolGRPC, }, - Logger: hclog.NewNullLogger(), - AutoMTLS: true, + Logger: hclog.NewNullLogger(), + AutoMTLS: true, + SkipHostEnv: true, }, expectTLSConfig: false, }, @@ -274,12 +279,12 @@ func TestMakeConfig(t *testing.T) { Cmd: commandWithEnv( "echo", []string{"foo", "bar"}, - []string{ + append(os.Environ(), []string{ "initial=true", fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"), fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false), fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true), - }, + }...), ), SecureConfig: &plugin.SecureConfig{ Checksum: []byte("some_sha256"), @@ -289,8 +294,9 @@ func TestMakeConfig(t *testing.T) { plugin.ProtocolNetRPC, plugin.ProtocolGRPC, }, - Logger: hclog.NewNullLogger(), - AutoMTLS: true, + Logger: hclog.NewNullLogger(), + AutoMTLS: true, + SkipHostEnv: true, }, expectTLSConfig: false, }, @@ -369,6 +375,10 @@ func TestMakeConfig(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() + if test.useLegacyEnvLayering { + t.Setenv(PluginUseLegacyEnvLayering, "true") + } + config, err := test.rc.makeConfig(ctx) if err != nil { t.Fatalf("no error expected, got: %s", err) diff --git a/sdk/plugin/mock/backend.go b/sdk/plugin/mock/backend.go index 6ca6421830fb..b34191b938a3 100644 --- a/sdk/plugin/mock/backend.go +++ b/sdk/plugin/mock/backend.go @@ -59,6 +59,7 @@ func Backend() *backend { pathInternal(&b), pathSpecial(&b), pathRaw(&b), + pathEnv(&b), }, ), PathsSpecial: &logical.Paths{ diff --git a/sdk/plugin/mock/path_env.go b/sdk/plugin/mock/path_env.go new file mode 100644 index 000000000000..18b4b71ccc32 --- /dev/null +++ b/sdk/plugin/mock/path_env.go @@ -0,0 +1,38 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package mock + +import ( + "context" + "os" + + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/logical" +) + +// pathEnv is used to interrogate plugin env vars. +func pathEnv(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "env/" + framework.GenericNameRegex("key"), + Fields: map[string]*framework.FieldSchema{ + "key": { + Type: framework.TypeString, + Required: true, + Description: "The name of the environment variable to read.", + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathEnvRead, + }, + } +} + +func (b *backend) pathEnvRead(_ context.Context, _ *logical.Request, data *framework.FieldData) (*logical.Response, error) { + // Return the secret + return &logical.Response{ + Data: map[string]interface{}{ + "key": os.Getenv(data.Get("key").(string)), + }, + }, nil +} diff --git a/vault/external_tests/plugin/plugin_test.go b/vault/external_tests/plugin/plugin_test.go index a35dbe2ea525..e38853fbcc81 100644 --- a/vault/external_tests/plugin/plugin_test.go +++ b/vault/external_tests/plugin/plugin_test.go @@ -23,11 +23,6 @@ import ( "github.com/hashicorp/vault/vault" ) -const ( - expectedEnvKey = "FOO" - expectedEnvValue = "BAR" -) - // logicalVersionMap is a map of version to test plugin var logicalVersionMap = map[string]string{ "v4": "TestBackend_PluginMain_V4_Logical", @@ -699,50 +694,69 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo return cluster } +// TestSystemBackend_Plugin_Env ensures we use env vars specified during plugin +// registration, and get the priority between OS and plugin env vars correct. func TestSystemBackend_Plugin_Env(t *testing.T) { - kvPair := fmt.Sprintf("%s=%s", expectedEnvKey, expectedEnvValue) - cluster := testSystemBackend_SingleCluster_Env(t, []string{kvPair}) - defer cluster.Cleanup() -} - -// testSystemBackend_SingleCluster_Env is a helper func that returns a single -// cluster and a single mounted plugin logical backend. -func testSystemBackend_SingleCluster_Env(t *testing.T, env []string) *vault.TestCluster { pluginDir := corehelpers.MakeTestPluginDir(t) coreConfig := &vault.CoreConfig{ - LogicalBackends: map[string]logical.Factory{ - "test": plugin.Factory, - }, PluginDirectory: pluginDir, } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ - HandlerFunc: vaulthttp.Handler, - KeepStandbysSealed: true, - NumCores: 1, - TempDir: pluginDir, + HandlerFunc: vaulthttp.Handler, + NumCores: 1, + TempDir: pluginDir, }) cluster.Start() + t.Cleanup(cluster.Cleanup) core := cluster.Cores[0] vault.TestWaitActive(t, core.Core) client := core.Client - env = append([]string{pluginutil.PluginCACertPEMEnv + "=" + cluster.CACertPEMFile}, env...) - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestBackend_PluginMainEnv", env) - options := map[string]interface{}{ - "type": "mock-plugin", + key := t.Name() + "_FOO" + osValue := "bar" + pluginValue := "baz" + t.Setenv(key, osValue) + env := []string{ + fmt.Sprintf("%s=%s", key, pluginValue), + pluginutil.PluginCACertPEMEnv + "=" + cluster.CACertPEMFile, } + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestBackend_PluginMainLogical", env) - resp, err := client.Logical().Write("sys/mounts/mock", options) + err := client.Sys().Mount("mock", &api.MountInput{ + Type: "mock-plugin", + }) if err != nil { t.Fatalf("err: %v", err) } - if resp != nil { - t.Fatalf("bad: %v", resp) + + // Plugin env should take precedence by default. + resp, err := client.Logical().Read("mock/env/" + key) + if err != nil { + t.Fatal(err) + } + if resp == nil || resp.Data["key"] != pluginValue { + t.Fatal(resp) } - return cluster + // Now set the flag that reverts to legacy behavior and reload the plugin. + t.Setenv(pluginutil.PluginUseLegacyEnvLayering, "true") + _, err = client.Sys().RootReloadPlugin(context.Background(), &api.RootReloadPluginInput{ + Plugin: "mock-plugin", + }) + if err != nil { + t.Fatal(err) + } + + // Now the OS value should take precedence. + resp, err = client.Logical().Read("mock/env/" + key) + if err != nil { + t.Fatal(err) + } + if resp == nil || resp.Data["key"] != osValue { + t.Fatal(resp) + } } func TestBackend_PluginMain_V4_Logical(t *testing.T) { @@ -922,36 +936,3 @@ func TestBackend_PluginMainCredentials(t *testing.T) { t.Fatal(err) } } - -// TestBackend_PluginMainEnv is a mock plugin that simply checks for the existence of FOO env var. -func TestBackend_PluginMainEnv(t *testing.T) { - args := []string{} - if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadataModeEnv) != "true" { - return - } - - // Check on actual vs expected env var - actual := os.Getenv(expectedEnvKey) - if actual != expectedEnvValue { - t.Fatalf("expected: %q, got: %q", expectedEnvValue, actual) - } - - caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv) - if caPEM == "" { - t.Fatal("CA cert not passed in") - } - args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM)) - - apiClientMeta := &api.PluginAPIClientMeta{} - flags := apiClientMeta.FlagSet() - flags.Parse(args) - - factoryFunc := mock.FactoryType(logical.TypeLogical) - - err := lplugin.Serve(&lplugin.ServeOpts{ - BackendFactoryFunc: factoryFunc, - }) - if err != nil { - t.Fatal(err) - } -} diff --git a/vault/plugincatalog/plugin_catalog.go b/vault/plugincatalog/plugin_catalog.go index 03c958b68cb5..f622fb17b1f6 100644 --- a/vault/plugincatalog/plugin_catalog.go +++ b/vault/plugincatalog/plugin_catalog.go @@ -9,9 +9,11 @@ import ( "encoding/json" "errors" "fmt" + "os" "path" "path/filepath" "sort" + "strconv" "strings" "sync" @@ -169,6 +171,31 @@ func SetupPluginCatalog(ctx context.Context, in *PluginCatalogInput) (*PluginCat return nil, err } + if legacy, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUseLegacyEnvLayering)); legacy { + conflicts := false + osKeys := envKeys(os.Environ()) + plugins, err := catalog.collectAllPlugins(ctx) + if err != nil { + return nil, err + } + for _, plugin := range plugins { + pluginKeys := envKeys(plugin.Env) + for k := range pluginKeys { + if _, ok := osKeys[k]; ok { + conflicts = true + logger.Warn("conflict between system and plugin environment variable", "type", plugin.Type, "name", plugin.Name, "version", plugin.Version, "variable", k) + } + } + } + if conflicts { + logger.Warn("conflicts found between system and plugin environment variables, "+ + "system environment variables will take precedence until flag is disabled", + pluginutil.PluginUseLegacyEnvLayering, os.Getenv(pluginutil.PluginUseLegacyEnvLayering)) + } else { + logger.Info("no conflicts found between system and plugin environment variables") + } + } + logger.Info("successfully setup plugin catalog", "plugin-directory", catalog.directory) if catalog.tmpdir != "" { logger.Debug("plugin temporary directory configured", "tmpdir", catalog.tmpdir) @@ -177,6 +204,18 @@ func SetupPluginCatalog(ctx context.Context, in *PluginCatalogInput) (*PluginCat return catalog, nil } +func envKeys(env []string) map[string]struct{} { + keys := make(map[string]struct{}, len(env)) + for _, env := range env { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 0 { + continue + } + keys[parts[0]] = struct{}{} + } + return keys +} + type pluginClientConn struct { *grpc.ClientConn id string @@ -1071,28 +1110,13 @@ func (c *PluginCatalog) List(ctx context.Context, pluginType consts.PluginType) // ListPluginsWithRuntime lists the plugins that are registered with a given runtime func (c *PluginCatalog) ListPluginsWithRuntime(ctx context.Context, runtime string) ([]string, error) { - // Collect keys for external plugins in the barrier. - keys, err := logical.CollectKeys(ctx, c.catalogView) + plugins, err := c.collectAllPlugins(ctx) if err != nil { return nil, err } var ret []string - for _, key := range keys { - // Skip: pinned version entry. - if strings.HasPrefix(key, pinnedVersionStoragePrefix) { - continue - } - entry, err := c.catalogView.Get(ctx, key) - if err != nil || entry == nil { - continue - } - - plugin := new(pluginutil.PluginRunner) - if err := jsonutil.DecodeJSON(entry.Value, plugin); err != nil { - return nil, fmt.Errorf("failed to decode plugin entry: %w", err) - } - + for _, plugin := range plugins { if plugin.Runtime == runtime { ret = append(ret, plugin.Name) } @@ -1113,31 +1137,14 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug var result []pluginutil.VersionedPlugin - // Collect keys for external plugins in the barrier. - keys, err := logical.CollectKeys(ctx, c.catalogView) + plugins, err := c.collectAllPlugins(ctx) if err != nil { return nil, err } unversionedPlugins := make(map[string]struct{}) - for _, key := range keys { - // Skip: pinned version entry. - if strings.HasPrefix(key, pinnedVersionStoragePrefix) { - continue - } - + for _, plugin := range plugins { var semanticVersion *semver.Version - - entry, err := c.catalogView.Get(ctx, key) - if err != nil || entry == nil { - continue - } - - plugin := new(pluginutil.PluginRunner) - if err := jsonutil.DecodeJSON(entry.Value, plugin); err != nil { - return nil, fmt.Errorf("failed to decode plugin entry: %w", err) - } - if plugin.Version == "" { semanticVersion, err = semver.NewVersion("0.0.0") if err != nil { @@ -1150,7 +1157,7 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug semanticVersion, err = semver.NewVersion(plugin.Version) if err != nil { - return nil, fmt.Errorf("unexpected error parsing version from plugin catalog entry %q: %w", key, err) + return nil, fmt.Errorf("unexpected error parsing %s %s plugin version: %w", plugin.Type.String(), plugin.Name, err) } } @@ -1208,6 +1215,36 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug return result, nil } +func (c *PluginCatalog) collectAllPlugins(ctx context.Context) ([]*pluginutil.PluginRunner, error) { + // Collect keys for external plugins in the barrier. + keys, err := logical.CollectKeys(ctx, c.catalogView) + if err != nil { + return nil, err + } + + var plugins []*pluginutil.PluginRunner + for _, key := range keys { + // Skip: pinned version entry. + if strings.HasPrefix(key, pinnedVersionStoragePrefix) { + continue + } + + entry, err := c.catalogView.Get(ctx, key) + if err != nil || entry == nil { + continue + } + + plugin := new(pluginutil.PluginRunner) + if err := jsonutil.DecodeJSON(entry.Value, plugin); err != nil { + return nil, fmt.Errorf("failed to decode plugin entry: %w", err) + } + + plugins = append(plugins, plugin) + } + + return plugins, nil +} + func sortVersionedPlugins(versionedPlugins []pluginutil.VersionedPlugin) { sort.SliceStable(versionedPlugins, func(i, j int) bool { left, right := versionedPlugins[i], versionedPlugins[j] diff --git a/vault/plugincatalog/plugin_catalog_test.go b/vault/plugincatalog/plugin_catalog_test.go index ac889f7c105e..ff829cefc1ca 100644 --- a/vault/plugincatalog/plugin_catalog_test.go +++ b/vault/plugincatalog/plugin_catalog_test.go @@ -13,6 +13,7 @@ import ( "io/ioutil" "os" "os/exec" + "path" "path/filepath" "reflect" "runtime" @@ -69,6 +70,119 @@ func testPluginCatalog(t *testing.T) *PluginCatalog { return pluginCatalog } +type warningCountLogger struct { + log.Logger + warnings int +} + +func (l *warningCountLogger) Warn(msg string, args ...interface{}) { + l.warnings++ + l.Logger.Warn(msg, args...) +} + +func (l *warningCountLogger) reset() { + l.warnings = 0 +} + +// TestPluginCatalog_SetupPluginCatalog_WarningsWithLegacyEnvSetting ensures we +// log the correct number of warnings during plugin catalog setup (which is run +// during unseal) if users have set the flag to keep old behavior. This is to +// help users migrate safely to the new default behavior. +func TestPluginCatalog_SetupPluginCatalog_WarningsWithLegacyEnvSetting(t *testing.T) { + logger := &warningCountLogger{ + Logger: log.New(&hclog.LoggerOptions{ + Level: hclog.Trace, + }), + } + storage, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatal(err) + } + logicalStorage := logical.NewLogicalStorage(storage) + + // prefix to avoid collisions with other tests. + const prefix = "TEST_PLUGIN_CATALOG_ENV_" + plugin := &pluginutil.PluginRunner{ + Name: "mysql-database-plugin", + Type: consts.PluginTypeDatabase, + Version: "1.0.0", + Env: []string{ + prefix + "A=1", + prefix + "VALUE_WITH_EQUALS=1=2", + prefix + "EMPTY_VALUE=", + }, + } + + // Insert a plugin into storage before the catalog is setup. + buf, err := json.Marshal(plugin) + if err != nil { + t.Fatal(err) + } + logicalEntry := logical.StorageEntry{ + Key: path.Join(plugin.Type.String(), plugin.Name, plugin.Version), + Value: buf, + } + if err := logicalStorage.Put(context.Background(), &logicalEntry); err != nil { + t.Fatal(err) + } + + for name, tc := range map[string]struct { + sysEnv map[string]string + expectedWarnings int + }{ + "no env": {}, + "colliding env, no flag": { + sysEnv: map[string]string{ + prefix + "A": "10", + }, + expectedWarnings: 0, + }, + "colliding env, with flag": { + sysEnv: map[string]string{ + pluginutil.PluginUseLegacyEnvLayering: "true", + prefix + "A": "10", + }, + expectedWarnings: 2, + }, + "all colliding env, with flag": { + sysEnv: map[string]string{ + pluginutil.PluginUseLegacyEnvLayering: "true", + prefix + "A": "10", + prefix + "VALUE_WITH_EQUALS": "1=2", + prefix + "EMPTY_VALUE": "", + }, + expectedWarnings: 4, + }, + } { + t.Run(name, func(t *testing.T) { + logger.reset() + for k, v := range tc.sysEnv { + t.Setenv(k, v) + } + + _, err := SetupPluginCatalog( + context.Background(), + &PluginCatalogInput{ + Logger: logger, + BuiltinRegistry: corehelpers.NewMockBuiltinRegistry(), + CatalogView: logicalStorage, + PluginDirectory: "", + Tmpdir: "", + EnableMlock: false, + PluginRuntimeCatalog: nil, + }, + ) + if err != nil { + t.Fatal(err) + } + + if tc.expectedWarnings != logger.warnings { + t.Fatalf("expected %d warnings, got %d", tc.expectedWarnings, logger.warnings) + } + }) + } +} + func TestPluginCatalog_CRUD(t *testing.T) { const pluginName = "mysql-database-plugin"