Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ proxy/ui_dist/placeholder.txt:
touch $@

test: proxy/ui_dist/placeholder.txt
go test -short -v -count=1 ./proxy
go test -short -count=1 ./proxy/...

# for CI - full test (takes longer)
test-all: proxy/ui_dist/placeholder.txt
go test -v -count=1 ./proxy
go test -count=1 ./proxy/...

ui/node_modules:
cd ui && npm install
Expand Down Expand Up @@ -81,4 +82,4 @@ release:
git tag "$$new_tag";

# Phony targets
.PHONY: all clean ui mac linux windows simple-responder
.PHONY: all clean ui mac linux windows simple-responder test test-all
13 changes: 7 additions & 6 deletions llama-swap.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy"
"github.com/mostlygeek/llama-swap/proxy/config"
)

var (
Expand All @@ -38,13 +39,13 @@ func main() {
os.Exit(0)
}

config, err := proxy.LoadConfig(*configPath)
conf, err := config.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
os.Exit(1)
}

if len(config.Profiles) > 0 {
if len(conf.Profiles) > 0 {
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
}

Expand All @@ -67,15 +68,15 @@ func main() {
// Support for watching config and reloading when it changes
reloadProxyManager := func() {
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
config, err = proxy.LoadConfig(*configPath)
conf, err = config.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Warning, unable to reload configuration: %v\n", err)
return
}

fmt.Println("Configuration Changed")
currentPM.Shutdown()
srv.Handler = proxy.New(config)
srv.Handler = proxy.New(conf)
fmt.Println("Configuration Reloaded")

// wait a few seconds and tell any UI to reload
Expand All @@ -85,12 +86,12 @@ func main() {
})
})
} else {
config, err = proxy.LoadConfig(*configPath)
conf, err = config.LoadConfig(*configPath)
if err != nil {
fmt.Printf("Error, unable to load configuration: %v\n", err)
os.Exit(1)
}
srv.Handler = proxy.New(config)
srv.Handler = proxy.New(conf)
}
}

Expand Down
2 changes: 1 addition & 1 deletion proxy/config.go → proxy/config/config.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package proxy
package config

import (
"fmt"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//go:build !windows

package proxy
package config

import (
"os"
Expand Down
2 changes: 1 addition & 1 deletion proxy/config_test.go → proxy/config/config_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package proxy
package config

import (
"slices"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//go:build windows

package proxy
package config

import (
"os"
Expand Down
7 changes: 4 additions & 3 deletions proxy/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy/config"
"gopkg.in/yaml.v3"
)

Expand Down Expand Up @@ -65,18 +66,18 @@ func getTestPort() int {
return port
}

func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
}

func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
// Create a YAML string with just the values we want to set
yamlStr := fmt.Sprintf(`
cmd: '%s --port %d --silent --respond %s'
proxy: "http://127.0.0.1:%d"
`, simpleResponderPath, port, expectedMessage, port)

var cfg ModelConfig
var cfg config.ModelConfig
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
}
Expand Down
3 changes: 2 additions & 1 deletion proxy/metrics_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
)

// TokenMetrics represents parsed token statistics from llama-server logs
Expand Down Expand Up @@ -38,7 +39,7 @@ type MetricsMonitor struct {
nextID int
}

func NewMetricsMonitor(config *Config) *MetricsMonitor {
func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
maxMetrics := config.MetricsMaxInMemory
if maxMetrics <= 0 {
maxMetrics = 1000 // Default fallback
Expand Down
7 changes: 4 additions & 3 deletions proxy/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"

"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
)

type ProcessState string
Expand All @@ -39,7 +40,7 @@ const (

type Process struct {
ID string
config ModelConfig
config config.ModelConfig
cmd *exec.Cmd

// PR #155 called to cancel the upstream process
Expand Down Expand Up @@ -74,7 +75,7 @@ type Process struct {
failedStartCount int
}

func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
concurrentLimit := 10
if config.ConcurrencyLimit > 0 {
concurrentLimit = config.ConcurrencyLimit
Expand Down Expand Up @@ -539,7 +540,7 @@ func (p *Process) cmdStopUpstreamProcess() error {

if p.config.CmdStop != "" {
// replace ${PID} with the pid of the process
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
if err != nil {
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
return err
Expand Down
23 changes: 12 additions & 11 deletions proxy/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"testing"
"time"

"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -90,7 +91,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) {
// test that the automatic start returns the expected error type
func TestProcess_BrokenModelConfig(t *testing.T) {
// Create a process configuration
config := ModelConfig{
config := config.ModelConfig{
Cmd: "nonexistent-command",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
Expand Down Expand Up @@ -325,7 +326,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {

// should run and exit but interrupt the long checkHealthTimeout
checkHealthTimeout := 5
config := ModelConfig{
config := config.ModelConfig{
Cmd: "sleep 1",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
Expand Down Expand Up @@ -402,15 +403,15 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
binaryPath := getSimpleResponderPath()
port := getTestPort()

config := ModelConfig{
conf := config.ModelConfig{
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
// to force the process to exit
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health",
}

process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger)
process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
defer process.Stop()

// reduce to make testing go faster
Expand Down Expand Up @@ -450,15 +451,15 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
}

func TestProcess_StopCmd(t *testing.T) {
config := getTestSimpleResponderConfig("test_stop_cmd")
conf := getTestSimpleResponderConfig("test_stop_cmd")

if runtime.GOOS == "windows" {
config.CmdStop = "taskkill /f /t /pid ${PID}"
conf.CmdStop = "taskkill /f /t /pid ${PID}"
} else {
config.CmdStop = "kill -TERM ${PID}"
conf.CmdStop = "kill -TERM ${PID}"
}

process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
defer process.Stop()

err := process.start()
Expand All @@ -470,15 +471,15 @@ func TestProcess_StopCmd(t *testing.T) {

func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
expectedMessage := "test_env_not_emptied"
config := getTestSimpleResponderConfig(expectedMessage)
conf := getTestSimpleResponderConfig(expectedMessage)

// ensure that the the default config does not blank out the inherited environment
configWEnv := config
configWEnv := conf

// ensure the additiona variables are appended to the process' environment
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")

process1 := NewProcess("env_test", 2, config, debugLogger, debugLogger)
process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)

process1.start()
Expand Down
6 changes: 4 additions & 2 deletions proxy/processgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"net/http"
"slices"
"sync"

"github.com/mostlygeek/llama-swap/proxy/config"
)

type ProcessGroup struct {
sync.Mutex

config Config
config config.Config
id string
swap bool
exclusive bool
Expand All @@ -24,7 +26,7 @@ type ProcessGroup struct {
lastUsedProcess string
}

func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
groupConfig, ok := config.Groups[id]
if !ok {
panic("Unable to find configuration for group id: " + id)
Expand Down
15 changes: 8 additions & 7 deletions proxy/processgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@ import (
"sync"
"testing"

"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
)

var processGroupTestConfig = AddDefaultGroupToConfig(Config{
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
"model4": getTestSimpleResponderConfig("model4"),
"model5": getTestSimpleResponderConfig("model5"),
},
Groups: map[string]GroupConfig{
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Exclusive: true,
Expand All @@ -34,7 +35,7 @@ var processGroupTestConfig = AddDefaultGroupToConfig(Config{
})

func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model5"))
}

Expand All @@ -48,9 +49,9 @@ func TestProcessGroup_HasMember(t *testing.T) {
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
// and multiple requests are made in parallel, only one process is running at a time.
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
// use the same listening so if a model is already running, it will fail
// this is a way to test that swap isolation is working
// properly when there are parallel requests made at the
Expand All @@ -61,7 +62,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
"model4": getTestSimpleResponderConfigPort("model4", 9832),
"model5": getTestSimpleResponderConfigPort("model5", 9832),
},
Groups: map[string]GroupConfig{
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Members: []string{"model1", "model2", "model3", "model4", "model5"},
Expand Down
5 changes: 3 additions & 2 deletions proxy/proxymanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
Expand All @@ -27,7 +28,7 @@ const (
type ProxyManager struct {
sync.Mutex

config Config
config config.Config
ginEngine *gin.Engine

// logging
Expand All @@ -44,7 +45,7 @@ type ProxyManager struct {
shutdownCancel context.CancelFunc
}

func New(config Config) *ProxyManager {
func New(config config.Config) *ProxyManager {
// set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
Expand Down
Loading
Loading