Skip to content

Commit

Permalink
feat: automatically translate failed requests to localhost to docker.…
Browse files Browse the repository at this point in the history
…host.internal (#224)

Co-authored-by: Dustin Deus <deusdustin@gmail.com>
  • Loading branch information
fiam and StarpTech authored Nov 7, 2023
1 parent d7fe7da commit 936006d
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 65 deletions.
1 change: 1 addition & 0 deletions router/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ func Main() {
core.WithMetrics(metricsConfig(&cfg.Telemetry)),
core.WithEngineExecutionConfig(cfg.EngineExecutionConfiguration),
core.WithAccessController(core.NewAccessController(authenticators, cfg.Authorization.RequireAuthentication)),
core.WithLocalhostFallbackInsideDocker(cfg.LocalhostFallbackInsideDocker),
)

if err != nil {
Expand Down
31 changes: 16 additions & 15 deletions router/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,22 @@ type Config struct {
Headers HeaderRules `yaml:"headers"`
TrafficShaping TrafficShapingRules `yaml:"traffic_shaping"`

ListenAddr string `yaml:"listen_addr" default:"localhost:3002" validate:"hostname_port" envconfig:"LISTEN_ADDR"`
ControlplaneURL string `yaml:"controlplane_url" default:"https://cosmo-cp.wundergraph.com" envconfig:"CONTROLPLANE_URL" validate:"required,uri"`
PlaygroundEnabled bool `yaml:"playground_enabled" default:"true" envconfig:"PLAYGROUND_ENABLED"`
IntrospectionEnabled bool `yaml:"introspection_enabled" default:"true" envconfig:"INTROSPECTION_ENABLED"`
LogLevel string `yaml:"log_level" default:"info" envconfig:"LOG_LEVEL" validate:"oneof=debug info warning error fatal panic"`
JSONLog bool `yaml:"json_log" default:"true" envconfig:"JSON_LOG"`
ShutdownDelay time.Duration `yaml:"shutdown_delay" default:"60s" validate:"required,min=15s" envconfig:"SHUTDOWN_DELAY"`
GracePeriod time.Duration `yaml:"grace_period" default:"20s" validate:"required" envconfig:"GRACE_PERIOD"`
PollInterval time.Duration `yaml:"poll_interval" default:"10s" validate:"required,min=5s" envconfig:"POLL_INTERVAL"`
HealthCheckPath string `yaml:"health_check_path" default:"/health" envconfig:"HEALTH_CHECK_PATH" validate:"uri"`
ReadinessCheckPath string `yaml:"readiness_check_path" default:"/health/ready" envconfig:"READINESS_CHECK_PATH" validate:"uri"`
LivenessCheckPath string `yaml:"liveness_check_path" default:"/health/live" envconfig:"LIVENESS_CHECK_PATH" validate:"uri"`
GraphQLPath string `yaml:"graphql_path" default:"/graphql" envconfig:"GRAPHQL_PATH"`
Authentication AuthenticationConfiguration `yaml:"authentication"`
Authorization AuthorizationConfiguration `yaml:"authorization"`
ListenAddr string `yaml:"listen_addr" default:"localhost:3002" validate:"hostname_port" envconfig:"LISTEN_ADDR"`
ControlplaneURL string `yaml:"controlplane_url" default:"https://cosmo-cp.wundergraph.com" envconfig:"CONTROLPLANE_URL" validate:"required,uri"`
PlaygroundEnabled bool `yaml:"playground_enabled" default:"true" envconfig:"PLAYGROUND_ENABLED"`
IntrospectionEnabled bool `yaml:"introspection_enabled" default:"true" envconfig:"INTROSPECTION_ENABLED"`
LogLevel string `yaml:"log_level" default:"info" envconfig:"LOG_LEVEL" validate:"oneof=debug info warning error fatal panic"`
JSONLog bool `yaml:"json_log" default:"true" envconfig:"JSON_LOG"`
ShutdownDelay time.Duration `yaml:"shutdown_delay" default:"60s" validate:"required,min=15s" envconfig:"SHUTDOWN_DELAY"`
GracePeriod time.Duration `yaml:"grace_period" default:"20s" validate:"required" envconfig:"GRACE_PERIOD"`
PollInterval time.Duration `yaml:"poll_interval" default:"10s" validate:"required,min=5s" envconfig:"POLL_INTERVAL"`
HealthCheckPath string `yaml:"health_check_path" default:"/health" envconfig:"HEALTH_CHECK_PATH" validate:"uri"`
ReadinessCheckPath string `yaml:"readiness_check_path" default:"/health/ready" envconfig:"READINESS_CHECK_PATH" validate:"uri"`
LivenessCheckPath string `yaml:"liveness_check_path" default:"/health/live" envconfig:"LIVENESS_CHECK_PATH" validate:"uri"`
GraphQLPath string `yaml:"graphql_path" default:"/graphql" envconfig:"GRAPHQL_PATH"`
Authentication AuthenticationConfiguration `yaml:"authentication"`
Authorization AuthorizationConfiguration `yaml:"authorization"`
LocalhostFallbackInsideDocker bool `yaml:"localhost_fallback_inside_docker" default:"true" envconfig:"LOCALHOST_FALLBACK_INSIDE_DOCKER"`

ConfigPath string `envconfig:"CONFIG_PATH" validate:"omitempty,filepath"`
RouterConfigPath string `yaml:"router_config_path" envconfig:"ROUTER_CONFIG_PATH" validate:"omitempty,filepath"`
Expand Down
24 changes: 19 additions & 5 deletions router/core/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"connectrpc.com/connect"
"github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/graphqlmetrics/v1/graphqlmetricsv1connect"
"github.com/wundergraph/cosmo/router/internal/docker"
"github.com/wundergraph/cosmo/router/internal/graphqlmetrics"
brotli "go.withmatt.com/connect-brotli"

Expand Down Expand Up @@ -105,6 +106,8 @@ type (
routerTrafficConfig *config.RouterTrafficConfiguration
accessController *AccessController
retryOptions retrytransport.RetryOptions
// If connecting to localhost inside Docker fails, fallback to the docker internal address for the host
localhostFallbackInsideDocker bool

engineExecutionConfiguration config.EngineExecutionConfiguration

Expand Down Expand Up @@ -621,17 +624,21 @@ func (r *Router) newServer(ctx context.Context, routerConfig *nodev1.RouterConfi
return nil, fmt.Errorf("failed to create planner cache: %w", err)
}

if r.localhostFallbackInsideDocker && docker.Inside() {
r.logger.Info("localhost fallback enabled, connections that fail to connect to localhost will be retried using host.docker.internal")
}

ecb := &ExecutorConfigurationBuilder{
introspection: r.introspection,
baseURL: r.baseURL,
transport: r.transport,
logger: r.logger,
includeInfo: r.graphqlMetricsConfig.Enabled,
transportOptions: &TransportOptions{
requestTimeout: r.subgraphTransportOptions.RequestTimeout,
preHandlers: r.preOriginHandlers,
postHandlers: r.postOriginHandlers,
retryOptions: retrytransport.RetryOptions{
RequestTimeout: r.subgraphTransportOptions.RequestTimeout,
PreHandlers: r.preOriginHandlers,
PostHandlers: r.postOriginHandlers,
RetryOptions: retrytransport.RetryOptions{
Enabled: r.retryOptions.Enabled,
MaxRetryCount: r.retryOptions.MaxRetryCount,
MaxDuration: r.retryOptions.MaxDuration,
Expand All @@ -640,7 +647,8 @@ func (r *Router) newServer(ctx context.Context, routerConfig *nodev1.RouterConfi
return retrytransport.IsRetryableError(err, resp) && !isMutationRequest(req.Context())
},
},
logger: r.logger,
LocalhostFallbackInsideDocker: r.localhostFallbackInsideDocker,
Logger: r.logger,
},
}

Expand Down Expand Up @@ -1051,6 +1059,12 @@ func WithAccessController(controller *AccessController) Option {
}
}

func WithLocalhostFallbackInsideDocker(fallback bool) Option {
return func(r *Router) {
r.localhostFallbackInsideDocker = fallback
}
}

func DefaultRouterTrafficConfig() *config.RouterTrafficConfiguration {
return &config.RouterTrafficConfiguration{
MaxRequestBodyBytes: 1000 * 1000 * 5, // 5 MB
Expand Down
98 changes: 53 additions & 45 deletions router/core/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strconv"
"time"

"github.com/wundergraph/cosmo/router/internal/docker"
"github.com/wundergraph/cosmo/router/internal/otel"
"github.com/wundergraph/cosmo/router/internal/retrytransport"
"github.com/wundergraph/cosmo/router/internal/trace"
Expand Down Expand Up @@ -85,67 +86,74 @@ func (ct *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error)
}

type TransportFactory struct {
preHandlers []TransportPreHandler
postHandlers []TransportPostHandler
retryOptions retrytransport.RetryOptions
requestTimeout time.Duration
logger *zap.Logger
preHandlers []TransportPreHandler
postHandlers []TransportPostHandler
retryOptions retrytransport.RetryOptions
requestTimeout time.Duration
localhostFallbackInsideDocker bool
logger *zap.Logger
}

var _ ApiTransportFactory = TransportFactory{}

type TransportOptions struct {
preHandlers []TransportPreHandler
postHandlers []TransportPostHandler
retryOptions retrytransport.RetryOptions
requestTimeout time.Duration
logger *zap.Logger
PreHandlers []TransportPreHandler
PostHandlers []TransportPostHandler
RetryOptions retrytransport.RetryOptions
RequestTimeout time.Duration
LocalhostFallbackInsideDocker bool
Logger *zap.Logger
}

func NewTransport(opts *TransportOptions) *TransportFactory {
return &TransportFactory{
preHandlers: opts.preHandlers,
postHandlers: opts.postHandlers,
logger: opts.logger,
retryOptions: opts.retryOptions,
requestTimeout: opts.requestTimeout,
preHandlers: opts.PreHandlers,
postHandlers: opts.PostHandlers,
retryOptions: opts.RetryOptions,
requestTimeout: opts.RequestTimeout,
localhostFallbackInsideDocker: opts.LocalhostFallbackInsideDocker,
logger: opts.Logger,
}
}

func (t TransportFactory) RoundTripper(transport http.RoundTripper, enableStreamingMode bool) http.RoundTripper {
tp := NewCustomTransport(
t.logger,
trace.NewTransport(
transport,
[]otelhttp.Option{
otelhttp.WithSpanNameFormatter(SpanNameFormatter),
otelhttp.WithSpanOptions(otrace.WithAttributes(otel.EngineTransportAttribute)),
},
trace.WithPreHandler(func(r *http.Request) {
span := otrace.SpanFromContext(r.Context())
reqContext := getRequestContext(r.Context())
operation := reqContext.operation

if operation != nil {
if operation.name != "" {
span.SetAttributes(otel.WgOperationName.String(operation.name))
}
if operation.opType != "" {
span.SetAttributes(otel.WgOperationType.String(operation.opType))
}
if operation.hash != 0 {
span.SetAttributes(otel.WgOperationHash.String(strconv.FormatUint(operation.hash, 10)))
}
if t.localhostFallbackInsideDocker && docker.Inside() {
transport = docker.NewLocalhostFallbackRoundTripper(transport)
}
traceTransport := trace.NewTransport(
transport,
[]otelhttp.Option{
otelhttp.WithSpanNameFormatter(SpanNameFormatter),
otelhttp.WithSpanOptions(otrace.WithAttributes(otel.EngineTransportAttribute)),
},
trace.WithPreHandler(func(r *http.Request) {
span := otrace.SpanFromContext(r.Context())
reqContext := getRequestContext(r.Context())
operation := reqContext.operation

if operation != nil {
if operation.name != "" {
span.SetAttributes(otel.WgOperationName.String(operation.name))
}

subgraph := reqContext.ActiveSubgraph(r)
if subgraph != nil {
span.SetAttributes(otel.WgSubgraphID.String(subgraph.Id))
span.SetAttributes(otel.WgSubgraphName.String(subgraph.Name))
if operation.opType != "" {
span.SetAttributes(otel.WgOperationType.String(operation.opType))
}
if operation.hash != 0 {
span.SetAttributes(otel.WgOperationHash.String(strconv.FormatUint(operation.hash, 10)))
}
}

}),
),
subgraph := reqContext.ActiveSubgraph(r)
if subgraph != nil {
span.SetAttributes(otel.WgSubgraphID.String(subgraph.Id))
span.SetAttributes(otel.WgSubgraphName.String(subgraph.Name))
}

}),
)
tp := NewCustomTransport(
t.logger,
traceTransport,
t.retryOptions,
)

Expand Down
23 changes: 23 additions & 0 deletions router/internal/docker/docker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Package docker implements helper functions we use while running under Docker.
// This should only be used for development purposes.
package docker

import (
"os"
)

const (
// dockerInternalHost is the hostname used by docker to access the host machine
// with bridge networking. We use it for automatic fallbacks when requests to localhost fail.
dockerInternalHost = "host.docker.internal"
)

func Inside() bool {
// Check if we are running inside docker by
// testing by checking if /.dockerenv exists
//
// This is not documented by Docker themselves, but it's the only
// method that has been working reliably for several years.
st, err := os.Stat("/.dockerenv")
return err == nil && !st.IsDir()
}
110 changes: 110 additions & 0 deletions router/internal/docker/roundtripper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package docker

import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func findLocalNonLocalhostInterface() (net.IP, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("could not list network interfaces: %w", err)
}
for _, iface := range ifaces {
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
switch x := addr.(type) {
case *net.IPNet:
if x.IP.IsPrivate() && !x.IP.IsLoopback() {
return x.IP, nil
}
}
}
}
return nil, errors.New("could not find a suitable IP address")
}

func TestLocalhostFallbackRoundTripper(t *testing.T) {
t.Parallel()

localIP, err := findLocalNonLocalhostInterface()
if err != nil {
// If we can't find a suitable address to run the test, skip it
t.Skip(err)
}
t.Log("using local IP", localIP)
// Find a random free TCP port
l, err := net.Listen("tcp", fmt.Sprintf("[%s]:0", localIP.String()))
require.NoError(t, err)
port := l.Addr().(*net.TCPAddr).Port

server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, err := io.ReadAll(r.Body)
require.NoError(t, err)
response := map[string]any{
"method": r.Method,
"host": r.Host,
"path": r.URL.Path,
"body": string(data),
}
resp, err := json.Marshal(response)
require.NoError(t, err)
w.Write(resp)
}))
server.Listener = l
server.Start()
t.Cleanup(server.Close)

transport := &localhostFallbackRoundTripper{
transport: http.DefaultTransport,
targetHost: localIP.String(),
}
client := http.Client{
Transport: transport,
}

t.Run("GET", func(t *testing.T) {
t.Parallel()
resp, err := client.Get(fmt.Sprintf("http://localhost:%d/hello", port))
require.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
require.NoError(t, err)
var response map[string]any
err = json.Unmarshal(data, &response)
require.NoError(t, err)
assert.Equal(t, "GET", response["method"])
assert.Equal(t, fmt.Sprintf("%s:%d", localIP.String(), port), response["host"])
assert.Equal(t, "", response["body"])
assert.Equal(t, "/hello", response["path"])
})

t.Run("POST", func(t *testing.T) {
t.Parallel()
const hello = "hello world"
resp, err := client.Post(fmt.Sprintf("http://localhost:%d", port), "text/plain", strings.NewReader(hello))
require.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
require.NoError(t, err)
var response map[string]any
err = json.Unmarshal(data, &response)
require.NoError(t, err)
assert.Equal(t, "POST", response["method"])
assert.Equal(t, hello, response["body"])
assert.Equal(t, "/", response["path"])
})
}
Loading

0 comments on commit 936006d

Please sign in to comment.