Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/robfig/cron/v3"
)

Expand All @@ -37,6 +38,7 @@ type Config struct {
ApplicationVersion string // Application version (optional, overridden by DBOS__APPVERSION env var)
ExecutorID string // Executor ID (optional, overridden by DBOS__VMID env var)
Context context.Context // User Context
SystemDBPool *pgxpool.Pool // Custom System Database Pool
}

func processConfig(inputConfig *Config) (*Config, error) {
Expand All @@ -61,6 +63,7 @@ func processConfig(inputConfig *Config) (*Config, error) {
ConductorAPIKey: inputConfig.ConductorAPIKey,
ApplicationVersion: inputConfig.ApplicationVersion,
ExecutorID: inputConfig.ExecutorID,
SystemDBPool: inputConfig.SystemDBPool,
}

// Load defaults
Expand Down Expand Up @@ -321,8 +324,14 @@ func NewDBOSContext(ctx context.Context, inputConfig Config) (DBOSContext, error

initExecutor.applicationID = os.Getenv("DBOS__APPID")

newSystemDatabaseInputs := newSystemDatabaseInput{
databaseURL: config.DatabaseURL,
customPool: config.SystemDBPool,
logger: initExecutor.logger,
}

// Create the system database
systemDB, err := newSystemDatabase(initExecutor, config.DatabaseURL, initExecutor.logger)
systemDB, err := newSystemDatabase(initExecutor, newSystemDatabaseInputs)
if err != nil {
return nil, newInitializationError(fmt.Sprintf("failed to create system database: %v", err))
}
Expand Down
57 changes: 57 additions & 0 deletions dbos/dbos_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package dbos

import (
"bytes"
"context"
"log/slog"
"testing"
"time"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -58,6 +61,60 @@ func TestConfig(t *testing.T) {
assert.Equal(t, expectedMsg, dbosErr.Message)
})

t.Run("NewSystemDatabaseWithCustomPool", func(t *testing.T) {

// Logger
var buf bytes.Buffer
slogLogger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{
Level: slog.LevelDebug,
}))

slogLogger = slogLogger.With("service", "dbos-test", "environment", "test")

// Custom Pool
poolConfig, err := pgxpool.ParseConfig(databaseURL)
require.NoError(t, err)

poolConfig.MaxConns = 10
poolConfig.MinConns = 5
poolConfig.MaxConnLifetime = 2 * time.Hour
poolConfig.MaxConnIdleTime = time.Minute * 2

poolConfig.ConnConfig.ConnectTimeout = 10 * time.Second

pool, err := pgxpool.NewWithConfig(context.Background(), poolConfig)
require.NoError(t, err)

config := Config{
DatabaseURL: databaseURL,
AppName: "test-custom-pool",
Logger: slogLogger,
SystemDBPool: pool,
}

customdbosContext, err := NewDBOSContext(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, customdbosContext)

dbosCtx, ok := customdbosContext.(*dbosContext)
defer dbosCtx.Shutdown(10 * time.Second)
require.True(t, ok)

sysDB, ok := dbosCtx.systemDB.(*sysDB)
require.True(t, ok)
assert.Same(t, pool, sysDB.pool, "The pool in dbosContext should be the same as the custom pool provided")

stats := sysDB.pool.Stat()
assert.Equal(t, int32(10), stats.MaxConns(), "MaxConns should match custom pool config")

sysdbConfig := sysDB.pool.Config()
assert.Equal(t, int32(10), sysdbConfig.MaxConns)
assert.Equal(t, int32(5), sysdbConfig.MinConns)
assert.Equal(t, 2*time.Hour, sysdbConfig.MaxConnLifetime)
assert.Equal(t, 2*time.Minute, sysdbConfig.MaxConnIdleTime)
assert.Equal(t, 10*time.Second, sysdbConfig.ConnConfig.ConnectTimeout)
})

t.Run("FailsWithoutDatabaseURL", func(t *testing.T) {
config := Config{
AppName: "test-app",
Expand Down
58 changes: 41 additions & 17 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,19 @@ func runMigrations(databaseURL string) error {
return nil
}

type newSystemDatabaseInput struct {
databaseURL string
customPool *pgxpool.Pool
logger *slog.Logger
}

// New creates a new SystemDatabase instance and runs migrations
func newSystemDatabase(ctx context.Context, databaseURL string, logger *slog.Logger) (systemDatabase, error) {
func newSystemDatabase(ctx context.Context, inputs newSystemDatabaseInput) (systemDatabase, error) {
// Dereference fields from inputs
databaseURL := inputs.databaseURL
customPool := inputs.customPool
logger := inputs.logger

// Create the database if it doesn't exist
if err := createDatabaseIfNotExists(ctx, databaseURL, logger); err != nil {
return nil, fmt.Errorf("failed to create database: %v", err)
Expand All @@ -241,24 +252,37 @@ func newSystemDatabase(ctx context.Context, databaseURL string, logger *slog.Log
return nil, fmt.Errorf("failed to run migrations: %v", err)
}

// Parse the connection string to get a config
config, err := pgxpool.ParseConfig(databaseURL)
if err != nil {
return nil, fmt.Errorf("failed to parse database URL: %v", err)
}
// Set pool configuration
config.MaxConns = 20
config.MinConns = 0
config.MaxConnLifetime = time.Hour
config.MaxConnIdleTime = time.Minute * 5
// pool
var pool *pgxpool.Pool

// Add acquire timeout to prevent indefinite blocking
config.ConnConfig.ConnectTimeout = 10 * time.Second
if customPool != nil {

pool = customPool

} else {

// Parse the connection string to get a config
config, err := pgxpool.ParseConfig(databaseURL)
if err != nil {
return nil, fmt.Errorf("failed to parse database URL: %v", err)
}

// Set pool configuration
config.MaxConns = 20
config.MinConns = 0
config.MaxConnLifetime = time.Hour
config.MaxConnIdleTime = time.Minute * 5

// Add acquire timeout to prevent indefinite blocking
config.ConnConfig.ConnectTimeout = 10 * time.Second

// Create pool with configuration
newPool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %v", err)
}
pool = newPool

// Create pool with configuration
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %v", err)
}

// Test the connection
Expand Down
Loading