Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions cmd/dbos/cli_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ func getDatabaseURL(dbRole string) string {
return dsn.String()
}

// getDatabaseURLKeyValue returns the same connection string in libpq key-value format
// This is useful for testing that commands handle both URL and key-value formats
func getDatabaseURLKeyValue(dbRole string) string {
password := os.Getenv("PGPASSWORD")
if password == "" {
password = "dbos"
}
return fmt.Sprintf("user='%s' password='%s' database=dbos host=localhost port=5432 sslmode=disable", dbRole, password)
}

// TestCLIWorkflow provides comprehensive integration testing of the DBOS CLI
func TestCLIWorkflow(t *testing.T) {
defer goleak.VerifyNone(t,
Expand Down Expand Up @@ -147,6 +157,27 @@ func TestCLIWorkflow(t *testing.T) {
}
})

t.Run("ResetDatabaseWithKeyValueFormat", func(t *testing.T) {
// Test reset command with key-value format connection string
args := append([]string{"reset", "-y", "--db-url", getDatabaseURLKeyValue("postgres")}, config.args...)
cmd := exec.Command(cliPath, args...)

output, err := cmd.CombinedOutput()
require.NoError(t, err, "Reset database command with key-value format failed: %s", string(output))

assert.Contains(t, string(output), "System database has been reset successfully", "Output should confirm database reset")

// Verify the database was reset by checking schema doesn't exist
db, err := sql.Open("pgx", getDatabaseURL("postgres"))
require.NoError(t, err)
defer db.Close()

var exists bool
err = db.QueryRow("SELECT EXISTS(SELECT 1 FROM information_schema.schemata WHERE schema_name = $1)", config.schemaName).Scan(&exists)
require.NoError(t, err)
assert.False(t, exists, fmt.Sprintf("Schema %s should not exist after reset", config.schemaName))
})

t.Run("ProjectInitialization", func(t *testing.T) {
testProjectInitialization(t, cliPath)
})
Expand Down Expand Up @@ -507,6 +538,22 @@ func testListWorkflows(t *testing.T, cliPath string, baseArgs []string, dbRole s
}
})
}

// Test list command with key-value format connection string
t.Run("ListWithKeyValueFormat", func(t *testing.T) {
args := append([]string{"workflow", "list", "--db-url", getDatabaseURLKeyValue(dbRole)}, baseArgs...)
fmt.Println(args)
cmd := exec.Command(cliPath, args...)

output, err := cmd.CombinedOutput()
require.NoError(t, err, "List command with key-value format failed: %s", string(output))

// Parse JSON output
var workflows []dbos.WorkflowStatus
err = json.Unmarshal(output, &workflows)
require.NoError(t, err, "JSON output should be valid")
assert.Greater(t, len(workflows), 0, "Should have workflows when using key-value format")
})
}

// testGetWorkflow tests retrieving individual workflow details
Expand Down Expand Up @@ -560,6 +607,21 @@ func testGetWorkflow(t *testing.T, cliPath string, baseArgs []string, dbRole str
assert.NotEmpty(t, status2.Status, "Should have workflow status")
assert.NotEmpty(t, status2.Name, "Should have workflow name")

// Test with key-value format connection string (libpq format)
argsKeyValue := append([]string{"workflow", "get", workflowID, "--db-url", getDatabaseURLKeyValue(dbRole)}, baseArgs...)
cmdKeyValue := exec.Command(cliPath, argsKeyValue...)

outputKeyValue, errKeyValue := cmdKeyValue.CombinedOutput()
require.NoError(t, errKeyValue, "Get workflow JSON command with key-value format failed: %s", string(outputKeyValue))

// Verify valid JSON
var statusKeyValue dbos.WorkflowStatus
err = json.Unmarshal(outputKeyValue, &statusKeyValue)
require.NoError(t, err, "JSON output should be valid")
assert.Equal(t, workflowID, statusKeyValue.ID, "JSON should contain correct workflow ID")
assert.NotEmpty(t, statusKeyValue.Status, "Should have workflow status")
assert.NotEmpty(t, statusKeyValue.Name, "Should have workflow name")

// Test with config file containing environment variable
configPath := "dbos-config.yaml"

Expand Down
47 changes: 24 additions & 23 deletions cmd/dbos/reset.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package main

import (
"database/sql"
"context"
"fmt"
"net/url"

_ "github.com/jackc/pgx/v5/stdlib"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -39,42 +39,43 @@ func runReset(cmd *cobra.Command, args []string) error {
return err
}

// Parse the URL to get database name
parsedURL, err := url.Parse(dbURL)
if err != nil {
return fmt.Errorf("invalid database URL: %w", err)
}
ctx := context.Background()

// Extract database name from path
dbName := parsedURL.Path
if len(dbName) > 0 && dbName[0] == '/' {
dbName = dbName[1:] // Remove leading slash
// Parse the connection string using pgxpool.ParseConfig which handles both URL and key-value formats
config, err := pgxpool.ParseConfig(dbURL)
if err != nil {
return fmt.Errorf("failed to parse database URL: %w", err)
}

// Get the database name from the config
dbName := config.ConnConfig.Database
if dbName == "" {
return fmt.Errorf("database name is required in URL")
return fmt.Errorf("database name not found in connection string")
}

// Connect to postgres database to drop and recreate the system database
parsedURL.Path = "/postgres"
postgresURL := parsedURL.String()
// Create a connection configuration pointing to the postgres database
postgresConfig := config.ConnConfig.Copy()
postgresConfig.Database = "postgres"

db, err := sql.Open("pgx", postgresURL)
// Connect to the postgres database
conn, err := pgx.ConnectConfig(ctx, postgresConfig)
if err != nil {
return fmt.Errorf("failed to connect to postgres database: %w", err)
return fmt.Errorf("failed to connect to PostgreSQL server: %w", err)
}
defer db.Close()
defer conn.Close(ctx)

// Drop the system database if it exists
logger.Info("Resetting system database", "database", dbName)
dropQuery := fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)
if _, err := db.Exec(dropQuery); err != nil {
dropSQL := fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", pgx.Identifier{dbName}.Sanitize())
_, err = conn.Exec(ctx, dropSQL)
if err != nil {
return fmt.Errorf("failed to drop system database: %w", err)
}

// Create the database
createQuery := fmt.Sprintf("CREATE DATABASE %s", dbName)
if _, err := db.Exec(createQuery); err != nil {
createSQL := fmt.Sprintf("CREATE DATABASE %s", pgx.Identifier{dbName}.Sanitize())
_, err = conn.Exec(ctx, createSQL)
if err != nil {
return fmt.Errorf("failed to create system database: %w", err)
}

Expand Down
84 changes: 63 additions & 21 deletions cmd/dbos/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,76 @@ import (
"log/slog"
"net/url"
"os"
"strings"

"github.com/dbos-inc/dbos-transact-golang/dbos"
"github.com/spf13/viper"
)

// maskPassword replaces the password in a database URL with asterisks
func maskPassword(dbURL string) string {
func maskPassword(dbURL string) (string, error) {
parsedURL, err := url.Parse(dbURL)
if err != nil {
// If we can't parse it, return the original (shouldn't happen with valid URLs)
logger.Warn("Failed to parse database URL", "error", err)
return dbURL
}
if err == nil && parsedURL.Scheme != "" {

// Check if there is user info with a password
if parsedURL.User != nil {
username := parsedURL.User.Username()
_, hasPassword := parsedURL.User.Password()
if hasPassword {
// Manually construct the URL with masked password to avoid encoding
maskedURL := parsedURL.Scheme + "://" + username + ":********@" + parsedURL.Host + parsedURL.Path
if parsedURL.RawQuery != "" {
maskedURL += "?" + parsedURL.RawQuery
}
if parsedURL.Fragment != "" {
maskedURL += "#" + parsedURL.Fragment
// Check if there is user info with a password
if parsedURL.User != nil {
username := parsedURL.User.Username()
_, hasPassword := parsedURL.User.Password()
if hasPassword {
// Manually construct the URL with masked password to avoid encoding
maskedURL := parsedURL.Scheme + "://" + username + ":***@" + parsedURL.Host + parsedURL.Path
if parsedURL.RawQuery != "" {
maskedURL += "?" + parsedURL.RawQuery
}
if parsedURL.Fragment != "" {
maskedURL += "#" + parsedURL.Fragment
}
return maskedURL, nil
}
return maskedURL
}

return parsedURL.String(), nil
}

return parsedURL.String()
// If URL parsing failed or no scheme, try key-value format (libpq connection string)
return maskPasswordInKeyValueFormat(dbURL), nil
}

// maskPasswordInKeyValueFormat masks password in libpq-style key-value connection strings
// Format: "user=foo password=bar database=db host=localhost"
// Supports all spacing variations: password=value, password =value, password= value, password = value
func maskPasswordInKeyValueFormat(connStr string) string {
// Find "password" key (case insensitive)
lowerStr := strings.ToLower(connStr)
passwordKey := "password"
passwordIdx := strings.Index(lowerStr, passwordKey)
if passwordIdx == -1 {
return connStr // No password found
}

// Find the = sign after "password" (skip optional spaces before =)
afterKey := passwordIdx + len(passwordKey)
for afterKey < len(connStr) && connStr[afterKey] == ' ' {
afterKey++
}
if afterKey >= len(connStr) || connStr[afterKey] != '=' {
return connStr // No = sign found
}

// Find the start of the password value (skip = and optional spaces after =)
valueStart := afterKey + 1
for valueStart < len(connStr) && connStr[valueStart] == ' ' {
valueStart++
}

// Find the end of the password value (next space or end of string)
valueEnd := valueStart
for valueEnd < len(connStr) && connStr[valueEnd] != ' ' {
valueEnd++
}

// Replace password value with ***
return connStr[:valueStart] + "***" + connStr[valueEnd:]
}

// getDBURL resolves the database URL from flag, config, or environment variable
Expand All @@ -63,7 +101,11 @@ func getDBURL() (string, error) {
}

// Log the database URL in verbose mode with masked password
maskedURL := maskPassword(resolvedURL)
maskedURL, err := maskPassword(resolvedURL)
if err != nil {
logger.Debug("Failed to mask database URL", "error", err)
maskedURL = resolvedURL
}
logger.Debug("Using database URL", "source", source, "url", maskedURL)

return resolvedURL, nil
Expand Down
94 changes: 94 additions & 0 deletions dbos/dbos_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -278,6 +279,99 @@ func TestConfig(t *testing.T) {

require.NotNil(t, ctx2)
})

t.Run("KeyValueFormatConnectionString", func(t *testing.T) {
t.Setenv("DBOS__APPVERSION", "v1.0.0")
t.Setenv("DBOS__APPID", "test-keyvalue-format")
t.Setenv("DBOS__VMID", "test-executor-id")

// Get base connection parameters
originalURL := databaseURL
parsedURL, err := pgxpool.ParseConfig(originalURL)
require.NoError(t, err)

user := parsedURL.ConnConfig.User
database := parsedURL.ConnConfig.Database
host := parsedURL.ConnConfig.Host
port := parsedURL.ConnConfig.Port

// Use a unique test password that won't match other connection parameters
testPassword := "TEST_PASSWORD_UNIQUE_12345!@#$%"

// Test password masking with various spacing formats
maskingTestCases := []struct {
name string
connStr string
}{
{"NoSpaces", fmt.Sprintf("user=%s password=%s database=%s host=%s", user, testPassword, database, host)},
{"SpaceBeforeEquals", fmt.Sprintf("user=%s password =%s database=%s host=%s", user, testPassword, database, host)},
{"SpaceAfterEquals", fmt.Sprintf("user=%s password= %s database=%s host=%s", user, testPassword, database, host)},
{"SpacesBothSides", fmt.Sprintf("user=%s password = %s database=%s host=%s", user, testPassword, database, host)},
{"UppercaseKey", fmt.Sprintf("user=%s PASSWORD=%s database=%s host=%s", user, testPassword, database, host)},
{"MixedCaseKey", fmt.Sprintf("user=%s Password=%s database=%s host=%s", user, testPassword, database, host)},
}

// Add port and sslmode if needed
portSSL := ""
if port != 0 {
portSSL += fmt.Sprintf(" port=%d", port)
}
if strings.Contains(originalURL, "sslmode=disable") {
portSSL += " sslmode=disable"
}
for i := range maskingTestCases {
maskingTestCases[i].connStr += portSSL
}

for _, tc := range maskingTestCases {
t.Run("Masking_"+tc.name, func(t *testing.T) {
masked, err := maskPassword(tc.connStr)
require.NoError(t, err)
assert.Contains(t, masked, "***", "password should be masked")
passwordPattern := fmt.Sprintf("password=%s", testPassword)
assert.NotContains(t, strings.ToLower(masked), strings.ToLower(passwordPattern), "password should not appear in plaintext")
})
}

// Integration test: verify DBOS context works with key-value format
t.Run("DBOSContextCreation", func(t *testing.T) {
// Use the actual password from config for integration test
actualPassword := parsedURL.ConnConfig.Password
keyValueConnStr := fmt.Sprintf("user='%s' password='%s' database=%s host=%s%s", user, actualPassword, database, host, portSSL)

ctx, err := NewDBOSContext(context.Background(), Config{
DatabaseURL: keyValueConnStr,
AppName: "test-keyvalue-format",
})
require.NoError(t, err)
defer func() {
if ctx != nil {
Shutdown(ctx, 1*time.Minute)
}
}()

require.NotNil(t, ctx)

// Verify system DB is functional
dbosCtx, ok := ctx.(*dbosContext)
require.True(t, ok)
sysDB, ok := dbosCtx.systemDB.(*sysDB)
require.True(t, ok)

var exists bool
err = sysDB.pool.QueryRow(context.Background(), "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'dbos' AND table_name = 'workflow_status')").Scan(&exists)
require.NoError(t, err)
assert.True(t, exists)

// Verify masking works
poolConnStr := sysDB.pool.Config().ConnString()
maskedConnStr, err := maskPassword(poolConnStr)
require.NoError(t, err)
assert.Contains(t, maskedConnStr, "password=***")
assert.NotContains(t, maskedConnStr, fmt.Sprintf("password=%s", actualPassword))
})
})

}

func TestCustomSystemDBSchema(t *testing.T) {
Expand Down
Loading