Skip to content

Add Features #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib

# IDE
.idea/
.idea

# Mac
.DS_Store
*/.DS_Store
!sqldialects.xml

# Dependencies
vendor

# local files
/tmp

# Test binary, built with `go test -c`
*.test

# Output of the go coverage tool, specifically when used with LiteIDE
*.out
out/*
.editorconfig
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module github.com/bitcomplete/sqltestutil
module github.com/mohammad-ahmadi-de/sqltestutil

go 1.18

Expand Down
19 changes: 15 additions & 4 deletions migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ package sqltestutil

import (
"context"
"database/sql/driver"
"io/ioutil"
"path/filepath"
"sort"

"github.com/jmoiron/sqlx"
)

// RunMigrations reads all of the files matching *.up.sql in migrationDir and
Expand All @@ -19,18 +18,30 @@ import (
//
// Note that this function does not check whether the migration has already been
// run. Its primary purpose is to initialize a test database.
func RunMigrations(ctx context.Context, db sqlx.ExecerContext, migrationDir string) error {
func RunMigrations(ctx context.Context, db driver.ExecerContext, migrationDir string, files ...string) error {
filenames, err := filepath.Glob(filepath.Join(migrationDir, "*.up.sql"))
if err != nil {
return err
}
var filter map[string]struct{} = nil
if len(files) > 0 {
filter = make(map[string]struct{})
for i := range files {
filter[files[i]] = struct{}{}
}
}
sort.Strings(filenames)
for _, filename := range filenames {
if len(files) > 0 {
if _, exist := filter[filepath.Base(filename)]; !exist {
continue
}
}
data, err := ioutil.ReadFile(filename)
if err != nil {
return err
}
_, err = db.ExecContext(ctx, string(data))
_, err = db.ExecContext(ctx, string(data), nil)
if err != nil {
return err
}
Expand Down
55 changes: 55 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package sqltestutil

import "fmt"

type Option func(*PostgresContainer)

func WithPassword(password string) Option {
return func(container *PostgresContainer) {
if len(password) == 0 {
panic("sqltestutil: password option can not be empty")
}
container.password = password
}
}
func WithUser(user string) Option {
return func(container *PostgresContainer) {
if len(user) == 0 {
panic("sqltestutil: user option can not be empty")
}
container.user = user
}
}
func WithPort(port uint16) Option {
return func(container *PostgresContainer) {
if port <= 1000 {
panic("sqltestutil: port option can not be less than 1000")
}
container.port = fmt.Sprint(port)
}
}

func WithVersion(version string) Option {
return func(container *PostgresContainer) {
if len(version) == 0 {
panic("sqltestutil: version option can not be empty")
}
container.version = version
}
}
func WithDBName(dbName string) Option {
return func(container *PostgresContainer) {
if len(dbName) == 0 {
panic("sqltestutil: dbName option can not be empty")
}
container.dbName = dbName
}
}
func WithContainerName(containerName string) Option {
return func(container *PostgresContainer) {
if len(containerName) == 0 {
panic("sqltestutil: containerName option can not be empty")
}
container.containerName = containerName
}
}
113 changes: 88 additions & 25 deletions postgres_container.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rand"
"errors"
"fmt"
"github.com/docker/docker/api/types/filters"
"io"
"math/big"
"net"
Expand All @@ -19,9 +20,13 @@ import (
// PostgresContainer is a Docker container running Postgres. It can be used to
// cheaply start a throwaway Postgres instance for testing.
type PostgresContainer struct {
id string
password string
port string
id string
password string
user string
port string
dbName string
version string
containerName string
}

// StartPostgresContainer starts a new Postgres Docker container. The version
Expand Down Expand Up @@ -58,20 +63,65 @@ type PostgresContainer struct {
// }
//
// func TestExampleTestSuite(t *testing.T) {
// pg, _ := sqltestutil.StartPostgresContainer(context.Background(), "12")
// pg, _ := sqltestutil.StartPostgresContainer(context.Background(), WithVersion("12"))
// defer pg.Shutdown(ctx)
// suite.Run(t, &ExampleTestSuite{})
// }
//
// [1]: https://github.com/golang/go/issues/37206
// [2]: https://github.com/stretchr/testify
func StartPostgresContainer(ctx context.Context, version string) (*PostgresContainer, error) {
func StartPostgresContainer(ctx context.Context, options ...Option) (*PostgresContainer, error) {
cli, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
panic(err)
}
defer cli.Close()
image := "postgres:" + version

containerObj := &PostgresContainer{}
//
// apply options, if any.
//
for i := range options {
options[i](containerObj)
}
//
// set default values
//
if len(containerObj.password) == 0 {
password, err := randomPassword()
if err != nil {
return nil, err
}
containerObj.password = password
}
if len(containerObj.port) == 0 {
port, err := randomPort()
if err != nil {
return nil, err
}
containerObj.port = port
}
if len(containerObj.user) == 0 {
containerObj.user = "pgtest"
}
if len(containerObj.dbName) == 0 {
containerObj.dbName = "pgtest"
}
if len(containerObj.version) == 0 {
containerObj.version = "12"
}
if len(containerObj.containerName) == 0 {
containerObj.containerName = "sqltestutil"
}
//
// remove leaked containers
//
err = containerObj.fixContainerLeak(ctx)
if err != nil {
return nil, err
}

image := "postgres:" + containerObj.version
_, _, err = cli.ImageInspectWithRaw(ctx, image)
if err != nil {
_, notFound := err.(interface {
Expand All @@ -91,20 +141,12 @@ func StartPostgresContainer(ctx context.Context, version string) (*PostgresConta
}
}

password, err := randomPassword()
if err != nil {
return nil, err
}
port, err := randomPort()
if err != nil {
return nil, err
}
createResp, err := cli.ContainerCreate(ctx, &container.Config{
Image: image,
Env: []string{
"POSTGRES_DB=pgtest",
"POSTGRES_PASSWORD=" + password,
"POSTGRES_USER=pgtest",
"POSTGRES_DB=" + containerObj.dbName,
"POSTGRES_PASSWORD=" + containerObj.password,
"POSTGRES_USER=" + containerObj.user,
},
Healthcheck: &container.HealthConfig{
Test: []string{"CMD-SHELL", "pg_isready -U pgtest"},
Expand All @@ -115,10 +157,10 @@ func StartPostgresContainer(ctx context.Context, version string) (*PostgresConta
}, &container.HostConfig{
PortBindings: nat.PortMap{
"5432/tcp": []nat.PortBinding{
{HostPort: port},
{HostPort: containerObj.port},
},
},
}, nil, nil, "")
}, nil, nil, containerObj.containerName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -160,17 +202,38 @@ HealthCheck:
time.Sleep(500 * time.Millisecond)
}
}
return &PostgresContainer{
id: createResp.ID,
password: password,
port: port,
}, nil
containerObj.id = createResp.ID

return containerObj, nil
}
func (c *PostgresContainer) fixContainerLeak(ctx context.Context) error {
cli, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
return err
}
defer cli.Close()

data, err := cli.ContainerList(ctx, types.ContainerListOptions{All: true, Filters: filters.NewArgs(filters.Arg("name", c.containerName))})
if err != nil {
return err
}
for i := range data {
err = cli.ContainerStop(ctx, data[i].ID, nil)
if err != nil {
return err
}
err = cli.ContainerRemove(ctx, data[i].ID, types.ContainerRemoveOptions{})
if err != nil {
return err
}
}
return nil
}

// ConnectionString returns a connection URL string that can be used to connect
// to the running Postgres container.
func (c *PostgresContainer) ConnectionString() string {
return fmt.Sprintf("postgres://pgtest:%s@127.0.0.1:%s/pgtest", c.password, c.port)
return fmt.Sprintf("postgres://%s:%s@127.0.0.1:%s/%s", c.user, c.password, c.port, c.dbName)
}

// Shutdown cleans up the Postgres container by stopping and removing it. This
Expand Down
15 changes: 15 additions & 0 deletions postgres_container_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package sqltestutil_test

import (
"context"
"github.com/mohammad-ahmadi-de/sqltestutil"
"testing"
)

func TestStartPostgresContainer(t *testing.T) {
c, err := sqltestutil.StartPostgresContainer(context.Background(), sqltestutil.WithPort(5321))
if err != nil {
t.Fatal(err)
}
c.Shutdown(context.Background())
}