Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change sqlite driver to github.com/glebarez/go-sqlite #91

Merged
merged 9 commits into from
Apr 18, 2023
Merged
Prev Previous commit
fix: add missing file
  • Loading branch information
devhaozi committed Apr 17, 2023
commit d21cd800459192f3aca6e7b5640aeb5b1c45fe0f
268 changes: 268 additions & 0 deletions database/console/driver/sqlite.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
package sqlite

import (
"database/sql"
"fmt"
"io"
nurl "net/url"
"strconv"
"strings"

_ "github.com/glebarez/go-sqlite"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/hashicorp/go-multierror"
"go.uber.org/atomic"
)

func init() {
database.Register("sqlite", &Sqlite{})
}

var DefaultMigrationsTable = "schema_migrations"
devhaozi marked this conversation as resolved.
Show resolved Hide resolved
var (
ErrDatabaseDirty = fmt.Errorf("database is dirty")
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
)

type Config struct {
MigrationsTable string
DatabaseName string
NoTxWrap bool
}

type Sqlite struct {
db *sql.DB
isLocked atomic.Bool

config *Config
}

func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}

if err := instance.Ping(); err != nil {
return nil, err
}

if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}

mx := &Sqlite{
db: instance,
config: config,
}
if err := mx.ensureVersionTable(); err != nil {
return nil, err
}
return mx, nil
}

// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Sqlite type.
func (m *Sqlite) ensureVersionTable() (err error) {
if err = m.Lock(); err != nil {
return err
}

defer func() {
if e := m.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()

query := fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool);
CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
`, m.config.MigrationsTable, m.config.MigrationsTable)

if _, err := m.db.Exec(query); err != nil {
return err
}
return nil
}

func (m *Sqlite) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}
dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite://", "", 1)
db, err := sql.Open("sqlite", dbfile)
if err != nil {
return nil, err
}

qv := purl.Query()

migrationsTable := qv.Get("x-migrations-table")
if len(migrationsTable) == 0 {
migrationsTable = DefaultMigrationsTable
}

noTxWrap := false
if v := qv.Get("x-no-tx-wrap"); v != "" {
noTxWrap, err = strconv.ParseBool(v)
if err != nil {
return nil, fmt.Errorf("x-no-tx-wrap: %s", err)
}
}

mx, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
NoTxWrap: noTxWrap,
})
if err != nil {
return nil, err
}
return mx, nil
}

func (m *Sqlite) Close() error {
return m.db.Close()
}

func (m *Sqlite) Drop() (err error) {
query := `SELECT name FROM sqlite_master WHERE type = 'table';`
tables, err := m.db.Query(query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()

tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if err := tables.Err(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

if len(tableNames) > 0 {
for _, t := range tableNames {
query := "DROP TABLE " + t
err = m.executeQuery(query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
query := "VACUUM"
_, err = m.db.Query(query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}

return nil
}

func (m *Sqlite) Lock() error {
if !m.isLocked.CompareAndSwap(false, true) {
return database.ErrLocked
}
return nil
}

func (m *Sqlite) Unlock() error {
if !m.isLocked.CompareAndSwap(true, false) {
return database.ErrNotLocked
}
return nil
}

func (m *Sqlite) Run(migration io.Reader) error {
migr, err := io.ReadAll(migration)
if err != nil {
return err
}
query := string(migr[:])

if m.config.NoTxWrap {
return m.executeQueryNoTx(query)
}
return m.executeQuery(query)
}

func (m *Sqlite) executeQuery(query string) error {
tx, err := m.db.Begin()
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if err := tx.Commit(); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}
return nil
}

func (m *Sqlite) executeQueryNoTx(query string) error {
if _, err := m.db.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}

func (m *Sqlite) SetVersion(version int, dirty bool) error {
tx, err := m.db.Begin()
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}

query := "DELETE FROM " + m.config.MigrationsTable
if _, err := tx.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

// Also re-write the schema version for nil dirty versions to prevent
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, m.config.MigrationsTable)
if _, err := tx.Exec(query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}

if err := tx.Commit(); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}

return nil
}

func (m *Sqlite) Version() (version int, dirty bool, err error) {
query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
err = m.db.QueryRow(query).Scan(&version, &dirty)
if err != nil {
return database.NilVersion, false, nil
}
return version, dirty, nil
}