Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,30 @@ const (
// DefaultListenIPv4 is the default interface used by the HTTP server.
DefaultListenIPv4 = "127.0.0.1"
// DefaultListenPort is the default port used by the HTTP server.
DefaultListenPort = "3000"
defaultListenAddr = DefaultListenIPv4 + ":" + DefaultListenPort
defaultRoutedMode = false
defaultUnifiedMode = false
defaultEnvFile = ""
defaultEnableDIFC = false
defaultLogDir = "/tmp/gh-aw/mcp-logs"
DefaultListenPort = "3000"
defaultListenAddr = DefaultListenIPv4 + ":" + DefaultListenPort
defaultRoutedMode = false
defaultUnifiedMode = false
defaultEnvFile = ""
defaultEnableDIFC = false
defaultLogDir = "/tmp/gh-aw/mcp-logs"
defaultParallelLaunch = true
)

var (
configFile string
configStdin bool
listenAddr string
routedMode bool
unifiedMode bool
envFile string
enableDIFC bool
logDir string
validateEnv bool
verbosity int // Verbosity level: 0 (default), 1 (-v info), 2 (-vv debug), 3 (-vvv trace)
debugLog = logger.New("cmd:root")
version = "dev" // Default version, overridden by SetVersion
configFile string
configStdin bool
listenAddr string
routedMode bool
unifiedMode bool
envFile string
enableDIFC bool
logDir string
validateEnv bool
parallelLaunch bool
verbosity int // Verbosity level: 0 (default), 1 (-v info), 2 (-vv debug), 3 (-vvv trace)
debugLog = logger.New("cmd:root")
version = "dev" // Default version, overridden by SetVersion
)

var rootCmd = &cobra.Command{
Expand All @@ -75,6 +77,7 @@ func init() {
rootCmd.Flags().BoolVar(&enableDIFC, "enable-difc", defaultEnableDIFC, "Enable DIFC enforcement and session requirement (requires sys___init call before tool access)")
rootCmd.Flags().StringVar(&logDir, "log-dir", getDefaultLogDir(), "Directory for log files (falls back to stdout if directory cannot be created)")
rootCmd.Flags().BoolVar(&validateEnv, "validate-env", false, "Validate execution environment (Docker, env vars) before starting")
rootCmd.Flags().BoolVar(&parallelLaunch, "parallel-launch", defaultParallelLaunch, "Launch MCP servers in parallel during startup (enabled by default)")
rootCmd.Flags().CountVarP(&verbosity, "verbose", "v", "Increase verbosity level (use -v for info, -vv for debug, -vvv for trace)")

// Mark mutually exclusive flags
Expand Down Expand Up @@ -238,12 +241,19 @@ func run(cmd *cobra.Command, args []string) error {

// Apply command-line flags to config
cfg.EnableDIFC = enableDIFC
cfg.ParallelLaunch = parallelLaunch
if enableDIFC {
log.Println("DIFC enforcement and session requirement enabled")
} else {
log.Println("DIFC enforcement disabled (sessions auto-created for standard MCP client compatibility)")
}

if parallelLaunch {
log.Println("Parallel server launching enabled")
} else {
log.Println("Sequential server launching enabled")
}

// Determine mode (default to unified if neither flag is set)
mode := "unified"
if routedMode {
Expand Down
7 changes: 4 additions & 3 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ const (

// Config represents the MCPG configuration
type Config struct {
Servers map[string]*ServerConfig `toml:"servers"`
EnableDIFC bool `toml:"enable_difc"` // When true, enables DIFC enforcement and requires sys___init call before tool access. Default is false for standard MCP client compatibility.
Gateway *GatewayConfig `toml:"gateway"` // Gateway configuration (port, API key, etc.)
Servers map[string]*ServerConfig `toml:"servers"`
EnableDIFC bool `toml:"enable_difc"` // When true, enables DIFC enforcement and requires sys___init call before tool access. Default is false for standard MCP client compatibility.
ParallelLaunch bool `toml:"parallel_launch"` // When true (default), launches MCP servers in parallel during startup.
Gateway *GatewayConfig `toml:"gateway"` // Gateway configuration (port, API key, etc.)
}

// GatewayConfig represents gateway-level configuration
Expand Down
104 changes: 88 additions & 16 deletions internal/server/unified.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,15 @@ type ToolInfo struct {

// UnifiedServer implements a unified MCP server that aggregates multiple backend servers
type UnifiedServer struct {
launcher *launcher.Launcher
sysServer *sys.SysServer
ctx context.Context
server *sdk.Server
sessions map[string]*Session // mcp-session-id -> Session
sessionMu sync.RWMutex
tools map[string]*ToolInfo // prefixed tool name -> tool info
toolsMu sync.RWMutex
launcher *launcher.Launcher
sysServer *sys.SysServer
ctx context.Context
server *sdk.Server
sessions map[string]*Session // mcp-session-id -> Session
sessionMu sync.RWMutex
tools map[string]*ToolInfo // prefixed tool name -> tool info
toolsMu sync.RWMutex
parallelLaunch bool // When true (default), launches MCP servers in parallel during startup

// DIFC components
guardRegistry *guard.Registry
Expand All @@ -101,15 +102,16 @@ type UnifiedServer struct {

// NewUnified creates a new unified MCP server
func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error) {
logUnified.Printf("Creating new unified server: enableDIFC=%v, servers=%d", cfg.EnableDIFC, len(cfg.Servers))
logUnified.Printf("Creating new unified server: enableDIFC=%v, parallelLaunch=%v, servers=%d", cfg.EnableDIFC, cfg.ParallelLaunch, len(cfg.Servers))
l := launcher.New(ctx, cfg)

us := &UnifiedServer{
launcher: l,
sysServer: sys.NewSysServer(l.ServerIDs()),
ctx: ctx,
sessions: make(map[string]*Session),
tools: make(map[string]*ToolInfo),
launcher: l,
sysServer: sys.NewSysServer(l.ServerIDs()),
ctx: ctx,
sessions: make(map[string]*Session),
tools: make(map[string]*ToolInfo),
parallelLaunch: cfg.ParallelLaunch,

// Initialize DIFC components
guardRegistry: guard.NewRegistry(),
Expand Down Expand Up @@ -141,6 +143,13 @@ func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error)
return us, nil
}

// launchResult stores the result of a backend server launch
type launchResult struct {
serverID string
err error
duration time.Duration
}

// registerAllTools fetches and registers tools from all backend servers
func (us *UnifiedServer) registerAllTools() error {
log.Println("Registering tools from all backends...")
Expand All @@ -157,8 +166,22 @@ func (us *UnifiedServer) registerAllTools() error {
log.Println("DIFC disabled: skipping sys tools registration")
}

// Register tools from each backend server
for _, serverID := range us.launcher.ServerIDs() {
serverIDs := us.launcher.ServerIDs()

if us.parallelLaunch {
// Launch servers in parallel
return us.registerAllToolsParallel(serverIDs)
} else {
// Launch servers sequentially (original behavior)
return us.registerAllToolsSequential(serverIDs)
}
}

// registerAllToolsSequential registers tools from backend servers sequentially
func (us *UnifiedServer) registerAllToolsSequential(serverIDs []string) error {
logUnified.Printf("Registering tools sequentially from %d backends", len(serverIDs))

for _, serverID := range serverIDs {
logUnified.Printf("Registering tools from backend: %s", serverID)
if err := us.registerToolsFromBackend(serverID); err != nil {
log.Printf("Warning: failed to register tools from %s: %v", serverID, err)
Expand All @@ -170,6 +193,55 @@ func (us *UnifiedServer) registerAllTools() error {
return nil
}

// registerAllToolsParallel registers tools from backend servers in parallel
func (us *UnifiedServer) registerAllToolsParallel(serverIDs []string) error {
logUnified.Printf("Registering tools in parallel from %d backends", len(serverIDs))

var wg sync.WaitGroup
results := make(chan launchResult, len(serverIDs))

// Launch each server in its own goroutine
for _, serverID := range serverIDs {
wg.Add(1)
go func(sid string) {
defer wg.Done()

startTime := time.Now()
err := us.registerToolsFromBackend(sid)
duration := time.Since(startTime)

results <- launchResult{
serverID: sid,
err: err,
duration: duration,
}
}(serverID)
}

// Wait for all goroutines to complete
wg.Wait()
close(results)

// Collect and log results
successCount := 0
failureCount := 0
for result := range results {
if result.err != nil {
log.Printf("Warning: failed to register tools from %s (took %v): %v", result.serverID, result.duration, result.err)
logger.LogWarn("backend", "Failed to register tools from %s (took %v): %v", result.serverID, result.duration, result.err)
failureCount++
} else {
logUnified.Printf("Successfully registered tools from %s (took %v)", result.serverID, result.duration)
logger.LogInfo("backend", "Successfully registered tools from %s (took %v)", result.serverID, result.duration)
successCount++
}
}

log.Printf("Parallel tool registration complete: %d succeeded, %d failed, total tools=%d", successCount, failureCount, len(us.tools))
logUnified.Printf("Tool registration complete: %d succeeded, %d failed, total tools=%d", successCount, failureCount, len(us.tools))
return nil
}

// registerToolsFromBackend registers tools from a specific backend with <server>___<tool> naming
func (us *UnifiedServer) registerToolsFromBackend(serverID string) error {
log.Printf("Registering tools from backend: %s", serverID)
Expand Down
28 changes: 28 additions & 0 deletions internal/server/unified_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,3 +492,31 @@ func TestRequireSession_EdgeCases(t *testing.T) {
})
}
}

func TestUnifiedServer_ParallelLaunch_Enabled(t *testing.T) {
cfg := &config.Config{
Servers: map[string]*config.ServerConfig{},
ParallelLaunch: true,
}

ctx := context.Background()
us, err := NewUnified(ctx, cfg)
require.NoError(t, err, "NewUnified() failed")
defer us.Close()

assert.True(t, us.parallelLaunch, "ParallelLaunch should be enabled when configured")
}

func TestUnifiedServer_ParallelLaunch_Disabled(t *testing.T) {
cfg := &config.Config{
Servers: map[string]*config.ServerConfig{},
ParallelLaunch: false,
}

ctx := context.Background()
us, err := NewUnified(ctx, cfg)
require.NoError(t, err, "NewUnified() failed")
defer us.Close()

assert.False(t, us.parallelLaunch, "ParallelLaunch should be disabled when configured")
}