Skip to content

Commit

Permalink
Add schema validation to DB
Browse files Browse the repository at this point in the history
This ensures we don't open a DB with an earlier schema or a
config that differs from ours

Signed-off-by: Matthew Heon <matthew.heon@gmail.com>

Closes: #86
Approved by: rhatdan
  • Loading branch information
mheon authored and rh-atomic-bot committed Nov 30, 2017
1 parent ed5d686 commit 7eb5ce9
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 1 deletion.
3 changes: 3 additions & 0 deletions libpod/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ var (
// ErrDBClosed indicates that the connection to the state database has
// already been closed
ErrDBClosed = errors.New("database connection already closed")
// ErrDBBadConfig indicates that the database has a different schema or
// was created by a libpod with a different config
ErrDBBadConfig = errors.New("database configuration mismatch")

// ErrNotImplemented indicates that the requested functionality is not
// yet present
Expand Down
9 changes: 9 additions & 0 deletions libpod/sql_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ import (
_ "github.com/mattn/go-sqlite3"
)

// DBSchema is the current DB schema version
// Increments every time a change is made to the database's tables
const DBSchema = 1

// SQLState is a state implementation backed by a persistent SQLite3 database
type SQLState struct {
db *sql.DB
Expand Down Expand Up @@ -69,6 +73,11 @@ func NewSQLState(dbPath, lockPath, specsDir string, runtime *Runtime) (State, er
return nil, err
}

// Ensure that the database matches our config
if err := checkDB(db, runtime); err != nil {
return nil, err
}

state.db = db

state.valid = true
Expand Down
131 changes: 131 additions & 0 deletions libpod/sql_state_internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,137 @@ import (
_ "github.com/mattn/go-sqlite3"
)

// Checks that the DB configuration matches the runtime's configuration
func checkDB(db *sql.DB, r *Runtime) (err error) {
// Create a table to hold runtime information
// TODO: Include UID/GID mappings
const runtimeTable = `
CREATE TABLE runtime(
Id INTEGER NOT NULL PRIMARY KEY,
SchemaVersion INTEGER NOT NULL,
StaticDir TEXT NOT NULL,
TmpDir TEXT NOT NULL,
RunRoot TEXT NOT NULL,
GraphRoot TEXT NOT NULL,
GraphDriverName TEXT NOT NULL,
CHECK (Id=0)
);
`
const fillRuntimeTable = `INSERT INTO runtime VALUES (
?, ?, ?, ?, ?, ?, ?
);`

const selectRuntimeTable = `SELECT SchemaVersion,
StaticDir,
TmpDir,
RunRoot,
GraphRoot,
GraphDriverName
FROM runtime WHERE id=0;`

const checkRuntimeExists = "SELECT name FROM sqlite_master WHERE type='table' AND name='runtime';"

tx, err := db.Begin()
if err != nil {
return errors.Wrapf(err, "error beginning database transaction")
}
defer func() {
if err != nil {
if err2 := tx.Rollback(); err2 != nil {
logrus.Errorf("Error rolling back transaction to check runtime table: %v", err2)
}
}

}()

row := tx.QueryRow(checkRuntimeExists)
var table string
if err := row.Scan(&table); err != nil {
// There is no runtime table
// Create and populate the runtime table
if err == sql.ErrNoRows {
if _, err := tx.Exec(runtimeTable); err != nil {
return errors.Wrapf(err, "error creating runtime table in database")
}

_, err := tx.Exec(fillRuntimeTable,
0,
DBSchema,
r.config.StaticDir,
r.config.TmpDir,
r.config.StorageConfig.RunRoot,
r.config.StorageConfig.GraphRoot,
r.config.StorageConfig.GraphDriverName)
if err != nil {
return errors.Wrapf(err, "error populating runtime table in database")
}

if err := tx.Commit(); err != nil {
return errors.Wrapf(err, "error committing runtime table transaction in database")
}

return nil
}

return errors.Wrapf(err, "error checking for presence of runtime table in database")
}

// There is a runtime table
// Retrieve its contents
var (
schemaVersion int
staticDir string
tmpDir string
runRoot string
graphRoot string
graphDriverName string
)

row = tx.QueryRow(selectRuntimeTable)
err = row.Scan(
&schemaVersion,
&staticDir,
&tmpDir,
&runRoot,
&graphRoot,
&graphDriverName)
if err != nil {
return errors.Wrapf(err, "error retrieving runtime information from database")
}

// Compare the information in the database against our runtime config
if schemaVersion != DBSchema {
return errors.Wrapf(ErrDBBadConfig, "database schema version %d does not match our schema version %d",
schemaVersion, DBSchema)
}
if staticDir != r.config.StaticDir {
return errors.Wrapf(ErrDBBadConfig, "database static directory %s does not match our static directory %s",
staticDir, r.config.StaticDir)
}
if tmpDir != r.config.TmpDir {
return errors.Wrapf(ErrDBBadConfig, "database temp directory %s does not match our temp directory %s",
tmpDir, r.config.TmpDir)
}
if runRoot != r.config.StorageConfig.RunRoot {
return errors.Wrapf(ErrDBBadConfig, "database runroot directory %s does not match our runroot directory %s",
runRoot, r.config.StorageConfig.RunRoot)
}
if graphRoot != r.config.StorageConfig.GraphRoot {
return errors.Wrapf(ErrDBBadConfig, "database graph root directory %s does not match our graph root directory %s",
graphRoot, r.config.StorageConfig.GraphRoot)
}
if graphDriverName != r.config.StorageConfig.GraphDriverName {
return errors.Wrapf(ErrDBBadConfig, "database runroot directory %s does not match our runroot directory %s",
graphDriverName, r.config.StorageConfig.GraphDriverName)
}

if err := tx.Commit(); err != nil {
return errors.Wrapf(err, "error committing runtime table transaction in database")
}

return nil
}

// Performs database setup including by not limited to initializing tables in
// the database
func prepareDB(db *sql.DB) (err error) {
Expand Down
7 changes: 6 additions & 1 deletion libpod/sql_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/containers/storage"
"github.com/opencontainers/runtime-tools/generate"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -102,7 +103,11 @@ func getEmptyState() (s State, p string, err error) {
dbPath := filepath.Join(tmpDir, "db.sql")
lockPath := filepath.Join(tmpDir, "db.lck")

state, err := NewSQLState(dbPath, lockPath, tmpDir, nil)
runtime := new(Runtime)
runtime.config = new(RuntimeConfig)
runtime.config.StorageConfig = storage.StoreOptions{}

state, err := NewSQLState(dbPath, lockPath, tmpDir, runtime)
if err != nil {
return nil, "", err
}
Expand Down

0 comments on commit 7eb5ce9

Please sign in to comment.