Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SOURCE ?= file go_bindata github github_ee bitbucket aws_s3 google_cloud_storage godoc_vfs gitlab
DATABASE ?= postgres mysql redshift cassandra spanner cockroachdb clickhouse mongodb sqlserver firebird neo4j pgx
DATABASE ?= postgres mysql redshift cassandra spanner cockroachdb clickhouse mongodb sqlserver firebird neo4j pgx snowflake
DATABASE_TEST ?= $(DATABASE) sqlite sqlite3 sqlcipher
VERSION ?= $(shell git describe --tags 2>/dev/null | cut -c 2-)
TEST_FLAGS ?=
Expand Down
26 changes: 13 additions & 13 deletions database/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,22 @@ var drivers = make(map[string]Driver)
// Driver is the interface every database driver must implement.
//
// How to implement a database driver?
// 1. Implement this interface.
// 2. Optionally, add a function named `WithInstance`.
// This function should accept an existing DB instance and a Config{} struct
// and return a driver instance.
// 3. Add a test that calls database/testing.go:Test()
// 4. Add own tests for Open(), WithInstance() (when provided) and Close().
// All other functions are tested by tests in database/testing.
// Saves you some time and makes sure all database drivers behave the same way.
// 5. Call Register in init().
// 6. Create a internal/cli/build_<driver-name>.go file
// 7. Add driver name in 'DATABASE' variable in Makefile
// 1. Implement this interface.
// 2. Optionally, add a function named `WithInstance`.
// This function should accept an existing DB instance and a Config{} struct
// and return a driver instance.
// 3. Add a test that calls database/testing.go:Test()
// 4. Add own tests for Open(), WithInstance() (when provided) and Close().
// All other functions are tested by tests in database/testing.
// Saves you some time and makes sure all database drivers behave the same way.
// 5. Call Register in init().
// 6. Create a internal/cli/build_<driver-name>.go file
// 7. Add driver name in 'DATABASE' variable in Makefile
//
// Guidelines:
// * Don't try to correct user input. Don't assume things.
// - Don't try to correct user input. Don't assume things.
// When in doubt, return an error and explain the situation to the user.
// * All configuration input must come from the URL string in func Open()
// - All configuration input must come from the URL string in func Open()
// or the Config{} struct in WithInstance. Don't os.Getenv().
type Driver interface {
// Open returns a new driver instance configured with parameters
Expand Down
7 changes: 6 additions & 1 deletion database/snowflake/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@

| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-warehouse`| n/a | Name of the warehouse to use when connecting |
| `x-role` | n/a | Name of the role to use when connecting |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multiple statements to be run in a single migration. Defaults to `false` |
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-timeout` | n/a | Request timeout. Defaults to 5 minutes |
| `x-connect-timeout` | `ConnectTimeout` | Initial connection timeout to the cluster. Defaults to 30 seconds |

Snowflake is PostgreSQL compatible but has some specific features (or lack thereof) that require slightly different behavior.

## Status
This driver is not officially supported as there are no tests for it.
This driver is not officially supported.
173 changes: 92 additions & 81 deletions database/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,40 @@ import (
nurl "net/url"
"strconv"
"strings"

"go.uber.org/atomic"
"time"

"github.com/golang-migrate/migrate/v4/database"
"github.com/hashicorp/go-multierror"
"github.com/lib/pq"
sf "github.com/snowflakedb/gosnowflake"
"go.uber.org/atomic"
)

func init() {
db := Snowflake{}
database.Register("snowflake", &db)
}

var DefaultMigrationsTable = "schema_migrations"
const (
DefaultMigrationsTable = "schema_migrations"
DefaultRequestTimeout = 5 * time.Minute
DefaultConnectTimeout = 30 * time.Second
)

var (
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrNoPassword = fmt.Errorf("no password")
ErrNoSchema = fmt.Errorf("no schema")
ErrNoSchemaOrDatabase = fmt.Errorf("no schema/database name")
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrNoPassword = fmt.Errorf("no password")
ErrNoSchema = fmt.Errorf("no schema")
ErrNoSchemaOrDatabase = fmt.Errorf("no schema/database name")
ErrInvalidParameterFormat = fmt.Errorf("invalid parameter format")
)

type Config struct {
MigrationsTable string
DatabaseName string
MigrationsTable string
DatabaseName string
MultiStatementEnabled bool
ConnectTimeout time.Duration
dsn string
}

type Snowflake struct {
Expand All @@ -50,8 +57,10 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
ctx, cancel := context.WithTimeout(context.Background(), config.ConnectTimeout)
defer cancel()

if err := instance.Ping(); err != nil {
if err := instance.PingContext(ctx); err != nil {
return nil, err
}

Expand All @@ -73,7 +82,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
config.MigrationsTable = DefaultMigrationsTable
}

conn, err := instance.Conn(context.Background())
conn, err := instance.Conn(ctx)

if err != nil {
return nil, err
Expand All @@ -92,7 +101,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return px, nil
}

func (p *Snowflake) Open(url string) (database.Driver, error) {
func configForURL(url string) (*Config, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
Expand All @@ -119,29 +128,74 @@ func (p *Snowflake) Open(url string) (database.Driver, error) {
}

cfg := &sf.Config{
Account: purl.Host,
User: purl.User.Username(),
Password: password,
Database: database,
Schema: schema,
Account: purl.Host,
User: purl.User.Username(),
Password: password,
Database: database,
Schema: schema,
RequestTimeout: DefaultRequestTimeout,
}

if warehouse := purl.Query().Get("x-warehouse"); len(warehouse) > 0 {
cfg.Warehouse = warehouse
}
if role := purl.Query().Get("x-role"); len(role) > 0 {
cfg.Role = role
}
if timeout := purl.Query().Get("x-timeout"); len(timeout) > 0 {
timeoutSeconds, err := strconv.ParseInt(timeout, 10, 64)
if err != nil {
return nil, ErrInvalidParameterFormat
}
cfg.RequestTimeout = time.Duration(timeoutSeconds) * time.Second
}

dsn, err := sf.DSN(cfg)
if err != nil {
return nil, err
}

db, err := sql.Open("snowflake", dsn)
migrationsTable := purl.Query().Get("x-migrations-table")

multiStatement := false
if multi := purl.Query().Get("x-multi-statement"); len(multi) > 0 {
multiStatement, err = strconv.ParseBool(multi)
if err != nil {
return nil, ErrInvalidParameterFormat
}
}

config := &Config{
DatabaseName: database,
MigrationsTable: migrationsTable,
MultiStatementEnabled: multiStatement,
ConnectTimeout: DefaultConnectTimeout,
dsn: dsn,
}

if connectTimeout := purl.Query().Get("x-connect-timeout"); len(connectTimeout) > 0 {
connectTimeoutSeconds, err := strconv.ParseInt(connectTimeout, 10, 64)
if err != nil {
return nil, ErrInvalidParameterFormat
}
config.ConnectTimeout = time.Duration(connectTimeoutSeconds) * time.Second
}

return config, nil
}

func (p *Snowflake) Open(url string) (database.Driver, error) {
cfg, err := configForURL(url)
if err != nil {
return nil, err
}

migrationsTable := purl.Query().Get("x-migrations-table")
db, err := sql.Open("snowflake", cfg.dsn)
if err != nil {
return nil, err
}

px, err := WithInstance(db, &Config{
DatabaseName: database,
MigrationsTable: migrationsTable,
})
px, err := WithInstance(db, cfg)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -178,68 +232,25 @@ func (p *Snowflake) Run(migration io.Reader) error {
return err
}

ctx := context.Background()
if p.config.MultiStatementEnabled {
// allow variable number of statements in the request by setting MULTI_STATEMENT_COUNT to 0
// https://docs.snowflake.com/en/developer-guide/sql-api/submitting-multiple-statements.html#specifying-multiple-sql-statements-in-the-request
if ctx, err = sf.WithMultiStatement(ctx, 0); err != nil {
return err
}
}

// run migration
query := string(migr[:])
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
if pgErr, ok := err.(*pq.Error); ok {
var line uint
var col uint
var lineColOK bool
if pgErr.Position != "" {
if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
line, col, lineColOK = computeLineFromPos(query, int(pos))
}
}
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
if lineColOK {
message = fmt.Sprintf("%s (column %d)", message, col)
}
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
}
if _, err := p.conn.ExecContext(ctx, query); err != nil {
// gosnowflake.SnowflakeError does not return a line number, so there's no need to parse it
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}

return nil
}

func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
// replace crlf with lf
s = strings.Replace(s, "\r\n", "\n", -1)
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
runes := []rune(s)
if pos > len(runes) {
return 0, 0, false
}
sel := runes[:pos]
line = uint(runesCount(sel, newLine) + 1)
col = uint(pos - 1 - runesLastIndex(sel, newLine))
return line, col, true
}

const newLine = '\n'

func runesCount(input []rune, target rune) int {
var count int
for _, r := range input {
if r == target {
count++
}
}
return count
}

func runesLastIndex(input []rune, target rune) int {
for i := len(input) - 1; i >= 0; i-- {
if input[i] == target {
return i
}
}
return -1
}

func (p *Snowflake) SetVersion(version int, dirty bool) error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
Expand Down Expand Up @@ -284,8 +295,8 @@ func (p *Snowflake) Version() (version int, dirty bool, err error) {
return database.NilVersion, false, nil

case err != nil:
if e, ok := err.(*pq.Error); ok {
if e.Code.Name() == "undefined_table" {
if e, ok := err.(*sf.SnowflakeError); ok {
if e.Number == sf.ErrObjectNotExistOrAuthorized {
return database.NilVersion, false, nil
}
}
Expand Down
Loading