Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 21 additions & 23 deletions internal/postgres/embedded.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

embeddedpostgres "github.com/fergusstrange/embedded-postgres"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/pgschema/pgschema/cmd/util"
)

// PostgresVersion is an alias for the embedded-postgres version type.
Expand Down Expand Up @@ -45,23 +46,23 @@ type EmbeddedPostgresConfig struct {
// DetectPostgresVersionFromDB connects to a database and detects its version
// This is a convenience function that opens a connection, detects the version, and closes it
func DetectPostgresVersionFromDB(host string, port int, database, user, password string) (PostgresVersion, error) {
// Build connection string
dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=prefer",
user, password, host, port, database)
// Build connection config
config := &util.ConnectionConfig{
Host: host,
Port: port,
Database: database,
User: user,
Password: password,
SSLMode: "prefer",
}

// Connect to database
db, err := sql.Open("pgx", dsn)
db, err := util.Connect(config)
if err != nil {
return "", fmt.Errorf("failed to connect to database: %w", err)
}
defer db.Close()

// Test the connection
ctx := context.Background()
if err := db.PingContext(ctx); err != nil {
return "", fmt.Errorf("failed to ping database: %w", err)
}

// Detect version
return detectPostgresVersion(db)
}
Expand Down Expand Up @@ -102,28 +103,25 @@ func StartEmbeddedPostgres(config *EmbeddedPostgresConfig) (*EmbeddedPostgres, e
return nil, fmt.Errorf("failed to start embedded PostgreSQL: %w", err)
}

// Build connection string
// Build connection config
host := "localhost"
dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable",
config.Username, config.Password, host, port, config.Database)
connConfig := &util.ConnectionConfig{
Host: host,
Port: port,
Database: config.Database,
User: config.Username,
Password: config.Password,
SSLMode: "disable",
}

// Connect to database
db, err := sql.Open("pgx", dsn)
db, err := util.Connect(connConfig)
if err != nil {
instance.Stop()
os.RemoveAll(runtimePath)
return nil, fmt.Errorf("failed to connect to embedded PostgreSQL: %w", err)
}

// Test the connection
ctx := context.Background()
if err := db.PingContext(ctx); err != nil {
db.Close()
instance.Stop()
os.RemoveAll(runtimePath)
return nil, fmt.Errorf("failed to ping embedded PostgreSQL: %w", err)
}

return &EmbeddedPostgres{
instance: instance,
db: db,
Expand Down
22 changes: 11 additions & 11 deletions internal/postgres/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"

_ "github.com/jackc/pgx/v5/stdlib"
"github.com/pgschema/pgschema/cmd/util"
)

// ExternalDatabase manages an external PostgreSQL database for desired state validation.
Expand Down Expand Up @@ -35,23 +36,22 @@ type ExternalDatabaseConfig struct {
// NewExternalDatabase creates a new external database connection for desired state validation.
// It validates the connection, checks version compatibility, and generates a temporary schema name.
func NewExternalDatabase(config *ExternalDatabaseConfig) (*ExternalDatabase, error) {
// Build connection string
dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=prefer",
config.Username, config.Password, config.Host, config.Port, config.Database)
// Build connection config
connConfig := &util.ConnectionConfig{
Host: config.Host,
Port: config.Port,
Database: config.Database,
User: config.Username,
Password: config.Password,
SSLMode: "prefer",
}

// Connect to database
db, err := sql.Open("pgx", dsn)
db, err := util.Connect(connConfig)
if err != nil {
return nil, fmt.Errorf("failed to connect to external database: %w", err)
}

// Test the connection
ctx := context.Background()
if err := db.PingContext(ctx); err != nil {
db.Close()
return nil, fmt.Errorf("failed to ping external database: %w", err)
}

// Detect version and validate compatibility
majorVersion, err := detectMajorVersion(db)
if err != nil {
Expand Down