Skip to content

Commit

Permalink
Add a BinariesPath configuration option, for prefetching Postgres bin…
Browse files Browse the repository at this point in the history
…aries before the test (#38)

* Refactor usage of paths throught embedded_postgres.go
* Separate binaryExtractLocation from runtimePath in prepare_database.go
* Add BinariesPath to the configuration; reuse unarchived binaries.
* Added readme information about BinariesPath.
* Fixed CI Lint comments.
* Move Mkdir to after the unarchive, to make the Alpine tests happy.
* Fixed typo
* Fixed tests that failed on Alpine.
  • Loading branch information
mishas authored Sep 16, 2021
1 parent 05b3cc2 commit c1bd5f3
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 60 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ This library aims to require as little configuration as possible, favouring over
| Version | 12.1.0 |
| RuntimePath | $USER_HOME/.embedded-postgres-go/extracted |
| DataPath | $USER_HOME/.embedded-postgres-go/extracted/data |
| BinariesPath | $USER_HOME/.embedded-postgres-go/extracted |
| Port | 5432 |
| StartTimeout | 15 Seconds |

Expand All @@ -54,6 +55,12 @@ If a persistent data location is required, set *DataPath* to a directory outside
If the *RuntimePath* directory is empty or already initialized but with an incompatible postgres version, it will be
removed and Postgres reinitialized.

Postgres binaries will be downloaded and placed in *BinaryPath* if `BinaryPath/bin` doesn't exist.
If the directory does exist, whatever binary version is placed there will be used (no version check
is done).
If your test need to run multiple different versions of Postgres for different tests, make sure
*BinaryPath* is a subdirectory of *RuntimePath*.

A single Postgres instance can be created, started and stopped as follows

```go
Expand Down
8 changes: 8 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type Config struct {
password string
runtimePath string
dataPath string
binariesPath string
locale string
startTimeout time.Duration
logger io.Writer
Expand Down Expand Up @@ -84,6 +85,13 @@ func (c Config) DataPath(path string) Config {
return c
}

// BinariesPath sets the path of the pre-downloaded postgres binaries.
// If this option is left unset, the binaries will be downloaded.
func (c Config) BinariesPath(path string) Config {
c.binariesPath = path
return c
}

// Locale sets the default locale for initdb
func (c Config) Locale(locale string) Config {
c.locale = locale
Expand Down
87 changes: 43 additions & 44 deletions embedded_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,45 +68,62 @@ func (ep *EmbeddedPostgres) Start() error {
return err
}

cacheLocation, exists := ep.cacheLocator()
if !exists {
if err := ep.remoteFetchStrategy(); err != nil {
return err
}
cacheLocation, cacheExists := ep.cacheLocator()

if ep.config.runtimePath == "" {
ep.config.runtimePath = filepath.Join(filepath.Dir(cacheLocation), "extracted")
}

if ep.config.dataPath == "" {
ep.config.dataPath = filepath.Join(ep.config.runtimePath, "data")
}

binaryExtractLocation := userRuntimePathOrDefault(ep.config.runtimePath, cacheLocation)
if err := os.RemoveAll(binaryExtractLocation); err != nil {
return fmt.Errorf("unable to clean up runtime directory %s with error: %s", binaryExtractLocation, err)
if err := os.RemoveAll(ep.config.runtimePath); err != nil {
return fmt.Errorf("unable to clean up runtime directory %s with error: %s", ep.config.runtimePath, err)
}

if err := archiver.NewTarXz().Unarchive(cacheLocation, binaryExtractLocation); err != nil {
return fmt.Errorf("unable to extract postgres archive %s to %s", cacheLocation, binaryExtractLocation)
if ep.config.binariesPath == "" {
ep.config.binariesPath = ep.config.runtimePath
}

_, binDirErr := os.Stat(filepath.Join(ep.config.binariesPath, "bin"))
if os.IsNotExist(binDirErr) {
if !cacheExists {
if err := ep.remoteFetchStrategy(); err != nil {
return err
}
}

if err := archiver.NewTarXz().Unarchive(cacheLocation, ep.config.binariesPath); err != nil {
return fmt.Errorf("unable to extract postgres archive %s to %s", cacheLocation, ep.config.binariesPath)
}
}

dataLocation := userDataPathOrDefault(ep.config.dataPath, binaryExtractLocation)
if err := os.MkdirAll(ep.config.runtimePath, 0755); err != nil {
return fmt.Errorf("unable to create runtime directory %s with error: %s", ep.config.runtimePath, err)
}

reuseData := ep.config.dataPath != "" && dataDirIsValid(dataLocation, ep.config.version)
reuseData := dataDirIsValid(ep.config.dataPath, ep.config.version)

if !reuseData {
if err := os.RemoveAll(dataLocation); err != nil {
return fmt.Errorf("unable to clean up data directory %s with error: %s", dataLocation, err)
if err := os.RemoveAll(ep.config.dataPath); err != nil {
return fmt.Errorf("unable to clean up data directory %s with error: %s", ep.config.dataPath, err)
}

if err := ep.initDatabase(binaryExtractLocation, dataLocation, ep.config.username, ep.config.password, ep.config.locale, ep.config.logger); err != nil {
if err := ep.initDatabase(ep.config.binariesPath, ep.config.runtimePath, ep.config.dataPath, ep.config.username, ep.config.password, ep.config.locale, ep.config.logger); err != nil {
return err
}
}

if err := startPostgres(binaryExtractLocation, ep.config); err != nil {
if err := startPostgres(ep.config); err != nil {
return err
}

ep.started = true

if !reuseData {
if err := ep.createDatabase(ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil {
if stopErr := stopPostgres(binaryExtractLocation, ep.config); stopErr != nil {
if stopErr := stopPostgres(ep.config); stopErr != nil {
return fmt.Errorf("unable to stop database casused by error %s", err)
}

Expand All @@ -115,7 +132,7 @@ func (ep *EmbeddedPostgres) Start() error {
}

if err := healthCheckDatabaseOrTimeout(ep.config); err != nil {
if stopErr := stopPostgres(binaryExtractLocation, ep.config); stopErr != nil {
if stopErr := stopPostgres(ep.config); stopErr != nil {
return fmt.Errorf("unable to stop database casused by error %s", err)
}

Expand All @@ -127,13 +144,11 @@ func (ep *EmbeddedPostgres) Start() error {

// Stop will try to stop the Postgres process gracefully returning an error when there were any problems.
func (ep *EmbeddedPostgres) Stop() error {
cacheLocation, exists := ep.cacheLocator()
if !exists || !ep.started {
if !ep.started {
return errors.New("server has not been started")
}

binaryExtractLocation := userRuntimePathOrDefault(ep.config.runtimePath, cacheLocation)
if err := stopPostgres(binaryExtractLocation, ep.config); err != nil {
if err := stopPostgres(ep.config); err != nil {
return err
}

Expand All @@ -142,10 +157,10 @@ func (ep *EmbeddedPostgres) Stop() error {
return nil
}

func startPostgres(binaryExtractLocation string, config Config) error {
postgresBinary := filepath.Join(binaryExtractLocation, "bin/pg_ctl")
func startPostgres(config Config) error {
postgresBinary := filepath.Join(config.binariesPath, "bin/pg_ctl")
postgresProcess := exec.Command(postgresBinary, "start", "-w",
"-D", userDataPathOrDefault(config.dataPath, binaryExtractLocation),
"-D", config.dataPath,
"-o", fmt.Sprintf(`"-p %d"`, config.port))
postgresProcess.Stderr = config.logger
postgresProcess.Stdout = config.logger
Expand All @@ -157,10 +172,10 @@ func startPostgres(binaryExtractLocation string, config Config) error {
return nil
}

func stopPostgres(binaryExtractLocation string, config Config) error {
postgresBinary := filepath.Join(binaryExtractLocation, "bin/pg_ctl")
func stopPostgres(config Config) error {
postgresBinary := filepath.Join(config.binariesPath, "bin/pg_ctl")
postgresProcess := exec.Command(postgresBinary, "stop", "-w",
"-D", userDataPathOrDefault(config.dataPath, binaryExtractLocation))
"-D", config.dataPath)
postgresProcess.Stderr = config.logger
postgresProcess.Stdout = config.logger

Expand All @@ -180,22 +195,6 @@ func ensurePortAvailable(port uint32) error {
return nil
}

func userRuntimePathOrDefault(userLocation, cacheLocation string) string {
if userLocation != "" {
return userLocation
}

return filepath.Join(filepath.Dir(cacheLocation), "extracted")
}

func userDataPathOrDefault(userLocation, runtimeLocation string) string {
if userLocation != "" {
return userLocation
}

return filepath.Join(runtimeLocation, "data")
}

func dataDirIsValid(dataDir string, version PostgresVersion) bool {
pgVersion := filepath.Join(dataDir, "PG_VERSION")

Expand Down
95 changes: 93 additions & 2 deletions embedded_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"
"time"

"github.com/mholt/archiver/v3"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -118,7 +119,7 @@ func Test_ErrorWhenUnableToInitDatabase(t *testing.T) {
return jarFile, true
}

database.initDatabase = func(binaryExtractLocation, dataLocation, username, password, locale string, logger io.Writer) error {
database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, logger io.Writer) error {
return errors.New("ah it did not work")
}

Expand Down Expand Up @@ -221,7 +222,7 @@ func Test_ErrorWhenCannotStartPostgresProcess(t *testing.T) {
return jarFile, true
}

database.initDatabase = func(binaryExtractLocation, dataLocation, username, password, locale string, logger io.Writer) error {
database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, logger io.Writer) error {
return nil
}

Expand Down Expand Up @@ -424,3 +425,93 @@ func Test_ReuseData(t *testing.T) {
shutdownDBAndFail(t, err, database)
}
}

func Test_CustomBinariesLocation(t *testing.T) {
tempDir, err := ioutil.TempDir("", "prepare_database_test")
if err != nil {
panic(err)
}

defer func() {
if err := os.RemoveAll(tempDir); err != nil {
panic(err)
}
}()

database := NewDatabase(DefaultConfig().
BinariesPath(tempDir))

if err := database.Start(); err != nil {
shutdownDBAndFail(t, err, database)
}

if err := database.Stop(); err != nil {
shutdownDBAndFail(t, err, database)
}

// Delete cache to make sure unarchive doesn't happen again.
cacheLocation, _ := database.cacheLocator()
if err := os.RemoveAll(cacheLocation); err != nil {
panic(err)
}

if err := database.Start(); err != nil {
shutdownDBAndFail(t, err, database)
}

if err := database.Stop(); err != nil {
shutdownDBAndFail(t, err, database)
}
}

func Test_PrefetchedBinaries(t *testing.T) {
binTempDir, err := ioutil.TempDir("", "prepare_database_test_bin")
if err != nil {
panic(err)
}

runtimeTempDir, err := ioutil.TempDir("", "prepare_database_test_runtime")
if err != nil {
panic(err)
}

defer func() {
if err := os.RemoveAll(binTempDir); err != nil {
panic(err)
}

if err := os.RemoveAll(runtimeTempDir); err != nil {
panic(err)
}
}()

database := NewDatabase(DefaultConfig().
BinariesPath(binTempDir).
RuntimePath(runtimeTempDir))

// Download and unarchive postgres into the bindir.
if err := database.remoteFetchStrategy(); err != nil {
panic(err)
}

cacheLocation, _ := database.cacheLocator()
if err := archiver.NewTarXz().Unarchive(cacheLocation, binTempDir); err != nil {
panic(err)
}

// Expect everything to work without cacheLocator and/or remoteFetch abilities.
database.cacheLocator = func() (string, bool) {
return "", false
}
database.remoteFetchStrategy = func() error {
return errors.New("did not work")
}

if err := database.Start(); err != nil {
shutdownDBAndFail(t, err, database)
}

if err := database.Stop(); err != nil {
shutdownDBAndFail(t, err, database)
}
}
10 changes: 5 additions & 5 deletions prepare_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ import (
"github.com/lib/pq"
)

type initDatabase func(binaryExtractLocation, pgDataDir, username, password, locale string, logger io.Writer) error
type initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, logger io.Writer) error
type createDatabase func(port uint32, username, password, database string) error

func defaultInitDatabase(binaryExtractLocation, pgDataDir, username, password, locale string, logger io.Writer) error {
passwordFile, err := createPasswordFile(binaryExtractLocation, password)
func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, logger io.Writer) error {
passwordFile, err := createPasswordFile(runtimePath, password)
if err != nil {
return err
}
Expand Down Expand Up @@ -50,8 +50,8 @@ func defaultInitDatabase(binaryExtractLocation, pgDataDir, username, password, l
return nil
}

func createPasswordFile(binaryExtractLocation, password string) (string, error) {
passwordFileLocation := filepath.Join(binaryExtractLocation, "pwfile")
func createPasswordFile(runtimePath, password string) (string, error) {
passwordFileLocation := filepath.Join(runtimePath, "pwfile")
if err := ioutil.WriteFile(passwordFileLocation, []byte(password), 0600); err != nil {
return "", fmt.Errorf("unable to write password file to %s", passwordFileLocation)
}
Expand Down
Loading

0 comments on commit c1bd5f3

Please sign in to comment.