Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 27 additions & 46 deletions pkg/config/auto_test.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
package config

import (
"context"
"testing"

"github.com/stretchr/testify/assert"

"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/environment"
)

type mockEnvProvider struct {
envVars map[string]string
}

func (m *mockEnvProvider) Get(_ context.Context, name string) (string, bool) {
val, found := m.envVars[name]
return val, found
}

func TestAvailableProviders_NoGateway(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -96,7 +87,7 @@ func TestAvailableProviders_NoGateway(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

providers := AvailableProviders(t.Context(), "", &mockEnvProvider{envVars: tt.envVars})
providers := AvailableProviders(t.Context(), "", environment.NewMapEnvProvider(tt.envVars))

assert.NotEmpty(t, providers)
assert.Equal(t, tt.expectedProvider, providers[0])
Expand Down Expand Up @@ -152,7 +143,7 @@ func TestAvailableProviders_WithGateway(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

providers := AvailableProviders(t.Context(), tt.gateway, &mockEnvProvider{envVars: tt.envVars})
providers := AvailableProviders(t.Context(), tt.gateway, environment.NewMapEnvProvider(tt.envVars))

assert.Len(t, providers, 1)
assert.Equal(t, tt.expectedProvider, providers[0])
Expand Down Expand Up @@ -228,7 +219,7 @@ func TestAutoModelConfig(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

modelConfig := AutoModelConfig(t.Context(), tt.gateway, &mockEnvProvider{envVars: tt.envVars}, nil)
modelConfig := AutoModelConfig(t.Context(), tt.gateway, environment.NewMapEnvProvider(tt.envVars), nil)

assert.Equal(t, tt.expectedProvider, modelConfig.Provider)
assert.Equal(t, tt.expectedModel, modelConfig.Model)
Expand Down Expand Up @@ -328,7 +319,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) {
envVars["MISTRAL_API_KEY"] = "test-key"
}

modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: envVars}, nil)
modelConfig := AutoModelConfig(t.Context(), "", environment.NewMapEnvProvider(envVars), nil)

// Verify the returned model matches the DefaultModels entry
expectedModel := DefaultModels[provider]
Expand All @@ -341,7 +332,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) {
t.Run("dmr", func(t *testing.T) {
t.Parallel()

modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}}, nil)
modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), nil)

assert.Equal(t, "dmr", modelConfig.Provider)
assert.Equal(t, DefaultModels["dmr"], modelConfig.Model)
Expand All @@ -353,51 +344,41 @@ func TestAvailableProviders_PrecedenceOrder(t *testing.T) {
t.Parallel()

// All keys present - anthropic should win
env := &mockEnvProvider{
envVars: map[string]string{
"ANTHROPIC_API_KEY": "test-key",
"OPENAI_API_KEY": "test-key",
"GOOGLE_API_KEY": "test-key",
"MISTRAL_API_KEY": "test-key",
},
}
var env environment.Provider = environment.NewMapEnvProvider(map[string]string{
"ANTHROPIC_API_KEY": "test-key",
"OPENAI_API_KEY": "test-key",
"GOOGLE_API_KEY": "test-key",
"MISTRAL_API_KEY": "test-key",
})
providers := AvailableProviders(t.Context(), "", env)
assert.Equal(t, "anthropic", providers[0])

// No anthropic - openai should win
env = &mockEnvProvider{
envVars: map[string]string{
"OPENAI_API_KEY": "test-key",
"GOOGLE_API_KEY": "test-key",
"MISTRAL_API_KEY": "test-key",
},
}
env = environment.NewMapEnvProvider(map[string]string{
"OPENAI_API_KEY": "test-key",
"GOOGLE_API_KEY": "test-key",
"MISTRAL_API_KEY": "test-key",
})
providers = AvailableProviders(t.Context(), "", env)
assert.Equal(t, "openai", providers[0])

// No anthropic or openai - google should win
env = &mockEnvProvider{
envVars: map[string]string{
"GOOGLE_API_KEY": "test-key",
"MISTRAL_API_KEY": "test-key",
},
}
env = environment.NewMapEnvProvider(map[string]string{
"GOOGLE_API_KEY": "test-key",
"MISTRAL_API_KEY": "test-key",
})
providers = AvailableProviders(t.Context(), "", env)
assert.Equal(t, "google", providers[0])

// No anthropic, openai, or google - mistral should win
env = &mockEnvProvider{
envVars: map[string]string{
"MISTRAL_API_KEY": "test-key",
},
}
env = environment.NewMapEnvProvider(map[string]string{
"MISTRAL_API_KEY": "test-key",
})
providers = AvailableProviders(t.Context(), "", env)
assert.Equal(t, "mistral", providers[0])

// No keys at all - dmr should be selected
env = &mockEnvProvider{
envVars: map[string]string{},
}
env = environment.NewNoEnvProvider()
providers = AvailableProviders(t.Context(), "", env)
assert.Equal(t, "dmr", providers[0])
}
Expand Down Expand Up @@ -467,7 +448,7 @@ func TestAutoModelConfig_UserDefaultModel(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: tt.envVars}, tt.defaultModel)
modelConfig := AutoModelConfig(t.Context(), "", environment.NewMapEnvProvider(tt.envVars), tt.defaultModel)

assert.Equal(t, tt.expectedProvider, modelConfig.Provider)
assert.Equal(t, tt.expectedModel, modelConfig.Model)
Expand All @@ -490,7 +471,7 @@ func TestAutoModelConfig_UserDefaultModelWithOptions(t *testing.T) {
ThinkingBudget: thinkingBudget,
}

modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}}, defaultModel)
modelConfig := AutoModelConfig(t.Context(), "", environment.NewNoEnvProvider(), defaultModel)

assert.Equal(t, "anthropic", modelConfig.Provider)
assert.Equal(t, "claude-sonnet-4-5", modelConfig.Model)
Expand Down
38 changes: 15 additions & 23 deletions pkg/config/sources_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/environment"
)

func TestURLSource_Read(t *testing.T) {
Expand Down Expand Up @@ -308,11 +310,9 @@ func TestURLSource_Read_WithGitHubAuth(t *testing.T) {
t.Cleanup(server.Close)

// Create a mock env provider that returns a GitHub token
envProvider := &mockEnvProvider{
envVars: map[string]string{
"GITHUB_TOKEN": "test-token-123",
},
}
envProvider := environment.NewMapEnvProvider(map[string]string{
"GITHUB_TOKEN": "test-token-123",
})

// For non-GitHub URLs, auth should not be added even with token available
source := NewURLSource(server.URL, envProvider)
Expand Down Expand Up @@ -340,11 +340,9 @@ func TestURLSource_Read_WithGitHubAuth_GitHubURL(t *testing.T) {
}))
t.Cleanup(server.Close)

envProvider := &mockEnvProvider{
envVars: map[string]string{
"GITHUB_TOKEN": "test-token-456",
},
}
envProvider := environment.NewMapEnvProvider(map[string]string{
"GITHUB_TOKEN": "test-token-456",
})

// URL with GitHub host in path (not hostname) should NOT receive auth
// This prevents token leakage to attacker-controlled domains
Expand All @@ -369,9 +367,7 @@ func TestURLSource_Read_WithGitHubAuth_NoToken(t *testing.T) {
t.Cleanup(server.Close)

// Create a mock env provider without a GitHub token
envProvider := &mockEnvProvider{
envVars: map[string]string{},
}
envProvider := environment.NewNoEnvProvider()

source := NewURLSource(server.URL, envProvider)
_, err := source.Read(t.Context())
Expand Down Expand Up @@ -436,11 +432,9 @@ func TestIsGitHubURL(t *testing.T) {
func TestResolve_URLReference_WithEnvProvider(t *testing.T) {
t.Parallel()

envProvider := &mockEnvProvider{
envVars: map[string]string{
"GITHUB_TOKEN": "test-token",
},
}
envProvider := environment.NewMapEnvProvider(map[string]string{
"GITHUB_TOKEN": "test-token",
})

source, err := Resolve("https://github.com/owner/repo/raw/main/agent.yaml", envProvider)
require.NoError(t, err)
Expand All @@ -455,11 +449,9 @@ func TestResolve_URLReference_WithEnvProvider(t *testing.T) {
func TestResolveSources_URLReference_WithEnvProvider(t *testing.T) {
t.Parallel()

envProvider := &mockEnvProvider{
envVars: map[string]string{
"GITHUB_TOKEN": "test-token",
},
}
envProvider := environment.NewMapEnvProvider(map[string]string{
"GITHUB_TOKEN": "test-token",
})

url := "https://github.com/owner/repo/raw/main/agent.yaml"
sources, err := ResolveSources(url, envProvider)
Expand Down
25 changes: 25 additions & 0 deletions pkg/environment/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,31 @@ func (p *EnvListProvider) Get(_ context.Context, name string) (string, bool) {
return "", false
}

// MapEnvProvider provides access to a static map of environment variables.
type MapEnvProvider struct {
vars map[string]string
}

func NewMapEnvProvider(vars map[string]string) *MapEnvProvider {
return &MapEnvProvider{vars: vars}
}

func (p *MapEnvProvider) Get(_ context.Context, name string) (string, bool) {
v, ok := p.vars[name]
return v, ok
}

// NoEnvProvider is a provider that never finds any variable.
type NoEnvProvider struct{}

func NewNoEnvProvider() *NoEnvProvider {
return &NoEnvProvider{}
}

func (p *NoEnvProvider) Get(context.Context, string) (string, bool) {
return "", false
}

// EnvFilesProvider provides access env files.
type EnvFilesProvider struct {
values []KeyValuePair
Expand Down
47 changes: 13 additions & 34 deletions pkg/model/provider/bedrock/client_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package bedrock

import (
"context"
"encoding/base64"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -342,24 +341,10 @@ func TestConvertImageURL_ValidImage(t *testing.T) {

// NewClient validation tests

type mockEnvProvider struct {
values map[string]string
}

func (m *mockEnvProvider) Get(_ context.Context, key string) (string, bool) {
if m.values == nil {
return "", false
}
v, ok := m.values[key]
return v, ok
}

var _ environment.Provider = (*mockEnvProvider)(nil)

func TestNewClient_NilConfig(t *testing.T) {
t.Parallel()

_, err := NewClient(t.Context(), nil, &mockEnvProvider{})
_, err := NewClient(t.Context(), nil, environment.NewNoEnvProvider())
require.Error(t, err)
assert.Contains(t, err.Error(), "model configuration is required")
}
Expand All @@ -371,7 +356,7 @@ func TestNewClient_WrongProvider(t *testing.T) {
Provider: "openai",
Model: "gpt-4",
}
_, err := NewClient(t.Context(), cfg, &mockEnvProvider{})
_, err := NewClient(t.Context(), cfg, environment.NewNoEnvProvider())
require.Error(t, err)
assert.Contains(t, err.Error(), "model type must be 'amazon-bedrock'")
}
Expand Down Expand Up @@ -422,7 +407,7 @@ func TestBuildAWSConfig_DefaultRegion(t *testing.T) {
ProviderOpts: map[string]any{},
}

env := &mockEnvProvider{values: map[string]string{}}
env := environment.NewNoEnvProvider()

awsCfg, err := buildAWSConfig(t.Context(), cfg, env)
require.NoError(t, err)
Expand All @@ -442,7 +427,7 @@ func TestBuildAWSConfig_RegionFromProviderOpts(t *testing.T) {
},
}

env := &mockEnvProvider{values: map[string]string{}}
env := environment.NewNoEnvProvider()

awsCfg, err := buildAWSConfig(t.Context(), cfg, env)
require.NoError(t, err)
Expand All @@ -459,9 +444,9 @@ func TestBuildAWSConfig_RegionFromEnv(t *testing.T) {
ProviderOpts: map[string]any{},
}

env := &mockEnvProvider{values: map[string]string{
env := environment.NewMapEnvProvider(map[string]string{
"AWS_REGION": "ap-northeast-1",
}}
})

awsCfg, err := buildAWSConfig(t.Context(), cfg, env)
require.NoError(t, err)
Expand All @@ -480,9 +465,9 @@ func TestBuildAWSConfig_ProviderOptsOverridesEnv(t *testing.T) {
},
}

env := &mockEnvProvider{values: map[string]string{
env := environment.NewMapEnvProvider(map[string]string{
"AWS_REGION": "us-west-2",
}}
})

awsCfg, err := buildAWSConfig(t.Context(), cfg, env)
require.NoError(t, err)
Expand All @@ -504,9 +489,7 @@ func TestNewClient_ValidConfig(t *testing.T) {
},
}

env := &mockEnvProvider{values: map[string]string{}}

client, err := NewClient(t.Context(), cfg, env)
client, err := NewClient(t.Context(), cfg, environment.NewNoEnvProvider())
require.NoError(t, err)
require.NotNil(t, client)

Expand All @@ -527,11 +510,9 @@ func TestNewClient_WithBearerToken(t *testing.T) {
},
}

env := &mockEnvProvider{values: map[string]string{
client, err := NewClient(t.Context(), cfg, environment.NewMapEnvProvider(map[string]string{
"MY_BEDROCK_TOKEN": "test-bearer-token",
}}

client, err := NewClient(t.Context(), cfg, env)
}))
require.NoError(t, err)
require.NotNil(t, client)
}
Expand All @@ -547,11 +528,9 @@ func TestNewClient_WithBearerTokenFromEnv(t *testing.T) {
},
}

env := &mockEnvProvider{values: map[string]string{
client, err := NewClient(t.Context(), cfg, environment.NewMapEnvProvider(map[string]string{
"AWS_BEARER_TOKEN_BEDROCK": "env-bearer-token",
}}

client, err := NewClient(t.Context(), cfg, env)
}))
require.NoError(t, err)
require.NotNil(t, client)
}
Expand Down
Loading