From 7eb5ce940c8145ef57920ef90b52857e9716ffc9 Mon Sep 17 00:00:00 2001 From: Matthew Heon Date: Wed, 29 Nov 2017 14:37:35 -0500 Subject: [PATCH] Add schema validation to DB This ensures we don't open a DB with an earlier schema or a config that differs from ours Signed-off-by: Matthew Heon Closes: #86 Approved by: rhatdan --- libpod/errors.go | 3 + libpod/sql_state.go | 9 +++ libpod/sql_state_internal.go | 131 +++++++++++++++++++++++++++++++++++ libpod/sql_state_test.go | 7 +- 4 files changed, 149 insertions(+), 1 deletion(-) diff --git a/libpod/errors.go b/libpod/errors.go index 180ca51dbfda..782104cf011c 100644 --- a/libpod/errors.go +++ b/libpod/errors.go @@ -59,6 +59,9 @@ var ( // ErrDBClosed indicates that the connection to the state database has // already been closed ErrDBClosed = errors.New("database connection already closed") + // ErrDBBadConfig indicates that the database has a different schema or + // was created by a libpod with a different config + ErrDBBadConfig = errors.New("database configuration mismatch") // ErrNotImplemented indicates that the requested functionality is not // yet present diff --git a/libpod/sql_state.go b/libpod/sql_state.go index 8b18a89590a0..7c2061fcac85 100644 --- a/libpod/sql_state.go +++ b/libpod/sql_state.go @@ -14,6 +14,10 @@ import ( _ "github.com/mattn/go-sqlite3" ) +// DBSchema is the current DB schema version +// Increments every time a change is made to the database's tables +const DBSchema = 1 + // SQLState is a state implementation backed by a persistent SQLite3 database type SQLState struct { db *sql.DB @@ -69,6 +73,11 @@ func NewSQLState(dbPath, lockPath, specsDir string, runtime *Runtime) (State, er return nil, err } + // Ensure that the database matches our config + if err := checkDB(db, runtime); err != nil { + return nil, err + } + state.db = db state.valid = true diff --git a/libpod/sql_state_internal.go b/libpod/sql_state_internal.go index 58a6daa58a31..6e0142b9b887 100644 --- a/libpod/sql_state_internal.go +++ b/libpod/sql_state_internal.go @@ -15,6 +15,137 @@ import ( _ "github.com/mattn/go-sqlite3" ) +// Checks that the DB configuration matches the runtime's configuration +func checkDB(db *sql.DB, r *Runtime) (err error) { + // Create a table to hold runtime information + // TODO: Include UID/GID mappings + const runtimeTable = ` + CREATE TABLE runtime( + Id INTEGER NOT NULL PRIMARY KEY, + SchemaVersion INTEGER NOT NULL, + StaticDir TEXT NOT NULL, + TmpDir TEXT NOT NULL, + RunRoot TEXT NOT NULL, + GraphRoot TEXT NOT NULL, + GraphDriverName TEXT NOT NULL, + CHECK (Id=0) + ); + ` + const fillRuntimeTable = `INSERT INTO runtime VALUES ( + ?, ?, ?, ?, ?, ?, ? + );` + + const selectRuntimeTable = `SELECT SchemaVersion, + StaticDir, + TmpDir, + RunRoot, + GraphRoot, + GraphDriverName + FROM runtime WHERE id=0;` + + const checkRuntimeExists = "SELECT name FROM sqlite_master WHERE type='table' AND name='runtime';" + + tx, err := db.Begin() + if err != nil { + return errors.Wrapf(err, "error beginning database transaction") + } + defer func() { + if err != nil { + if err2 := tx.Rollback(); err2 != nil { + logrus.Errorf("Error rolling back transaction to check runtime table: %v", err2) + } + } + + }() + + row := tx.QueryRow(checkRuntimeExists) + var table string + if err := row.Scan(&table); err != nil { + // There is no runtime table + // Create and populate the runtime table + if err == sql.ErrNoRows { + if _, err := tx.Exec(runtimeTable); err != nil { + return errors.Wrapf(err, "error creating runtime table in database") + } + + _, err := tx.Exec(fillRuntimeTable, + 0, + DBSchema, + r.config.StaticDir, + r.config.TmpDir, + r.config.StorageConfig.RunRoot, + r.config.StorageConfig.GraphRoot, + r.config.StorageConfig.GraphDriverName) + if err != nil { + return errors.Wrapf(err, "error populating runtime table in database") + } + + if err := tx.Commit(); err != nil { + return errors.Wrapf(err, "error committing runtime table transaction in database") + } + + return nil + } + + return errors.Wrapf(err, "error checking for presence of runtime table in database") + } + + // There is a runtime table + // Retrieve its contents + var ( + schemaVersion int + staticDir string + tmpDir string + runRoot string + graphRoot string + graphDriverName string + ) + + row = tx.QueryRow(selectRuntimeTable) + err = row.Scan( + &schemaVersion, + &staticDir, + &tmpDir, + &runRoot, + &graphRoot, + &graphDriverName) + if err != nil { + return errors.Wrapf(err, "error retrieving runtime information from database") + } + + // Compare the information in the database against our runtime config + if schemaVersion != DBSchema { + return errors.Wrapf(ErrDBBadConfig, "database schema version %d does not match our schema version %d", + schemaVersion, DBSchema) + } + if staticDir != r.config.StaticDir { + return errors.Wrapf(ErrDBBadConfig, "database static directory %s does not match our static directory %s", + staticDir, r.config.StaticDir) + } + if tmpDir != r.config.TmpDir { + return errors.Wrapf(ErrDBBadConfig, "database temp directory %s does not match our temp directory %s", + tmpDir, r.config.TmpDir) + } + if runRoot != r.config.StorageConfig.RunRoot { + return errors.Wrapf(ErrDBBadConfig, "database runroot directory %s does not match our runroot directory %s", + runRoot, r.config.StorageConfig.RunRoot) + } + if graphRoot != r.config.StorageConfig.GraphRoot { + return errors.Wrapf(ErrDBBadConfig, "database graph root directory %s does not match our graph root directory %s", + graphRoot, r.config.StorageConfig.GraphRoot) + } + if graphDriverName != r.config.StorageConfig.GraphDriverName { + return errors.Wrapf(ErrDBBadConfig, "database runroot directory %s does not match our runroot directory %s", + graphDriverName, r.config.StorageConfig.GraphDriverName) + } + + if err := tx.Commit(); err != nil { + return errors.Wrapf(err, "error committing runtime table transaction in database") + } + + return nil +} + // Performs database setup including by not limited to initializing tables in // the database func prepareDB(db *sql.DB) (err error) { diff --git a/libpod/sql_state_test.go b/libpod/sql_state_test.go index 9f6b5d0783f4..124959544487 100644 --- a/libpod/sql_state_test.go +++ b/libpod/sql_state_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/containers/storage" "github.com/opencontainers/runtime-tools/generate" "github.com/stretchr/testify/assert" ) @@ -102,7 +103,11 @@ func getEmptyState() (s State, p string, err error) { dbPath := filepath.Join(tmpDir, "db.sql") lockPath := filepath.Join(tmpDir, "db.lck") - state, err := NewSQLState(dbPath, lockPath, tmpDir, nil) + runtime := new(Runtime) + runtime.config = new(RuntimeConfig) + runtime.config.StorageConfig = storage.StoreOptions{} + + state, err := NewSQLState(dbPath, lockPath, tmpDir, runtime) if err != nil { return nil, "", err }