Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions lib/migration/ImplementationDB.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func NewSQLiteDatabase(dsn string) (*SQLDatabase, error) {
return nil, fmt.Errorf("failed to ping sqlite database: %w", err)
}

return NewSQLDatabase(db, DriverSQLite), nil
return NewSQLDatabase(db, DriverSQLite)
}

// NewPostgresDatabase creates a PostgreSQL-backed database
Expand All @@ -42,7 +42,7 @@ func NewPostgresDatabase(dsn string) (*SQLDatabase, error) {
return nil, fmt.Errorf("failed to ping postgres database: %w", err)
}

return NewSQLDatabase(db, DriverPostgres), nil
return NewSQLDatabase(db, DriverPostgres)
}

// NewMySQLDatabase creates a MySQL-backed database
Expand All @@ -56,5 +56,5 @@ func NewMySQLDatabase(dsn string) (*SQLDatabase, error) {
return nil, fmt.Errorf("failed to ping mysql database: %w", err)
}

return NewSQLDatabase(db, DriverMySQL), nil
return NewSQLDatabase(db, DriverMySQL)
}
126 changes: 79 additions & 47 deletions lib/migration/SQLDatabase.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ import (
// SQLDatabase implements the Database interface for SQL-based Etherpad stores
type SQLDatabase struct {
db *sql.DB
driver DriverType
tableName string
keyColumn string
valueColumn string
placeholder func(n int) string // Returns placeholder like $1 or ? depending on driver
}

type DriverType int
Expand All @@ -27,27 +27,65 @@ const (
DriverMySQL
)

// NewSQLDatabase creates a new SQLDatabase with the appropriate placeholder style
func NewSQLDatabase(db *sql.DB, driver DriverType) *SQLDatabase {
// validIdentifierRegex matches names made of ASCII letters, digits and
// underscores that do not start with a digit. Restricting identifiers to
// this alphabet keeps quote characters and other SQL metacharacters out of
// any query text built from them.
var validIdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

// validateIdentifier reports whether name is acceptable as a SQL identifier.
// It returns nil when name matches validIdentifierRegex and a descriptive
// error (with the offending name quoted) otherwise.
func validateIdentifier(name string) error {
	if validIdentifierRegex.MatchString(name) {
		return nil
	}
	return fmt.Errorf("invalid SQL identifier: %q", name)
}

// quoteIdentifier returns name wrapped in the identifier quote character for
// the receiver's driver, with any embedded quote character doubled.
// MySQL identifiers are wrapped in backticks; every other driver (PostgreSQL,
// SQLite, and any unrecognised value) gets double quotes.
func (s *SQLDatabase) quoteIdentifier(name string) string {
	quote := `"`
	if s.driver == DriverMySQL {
		quote = "`"
	}
	// Doubling the quote character is how an embedded quote is escaped
	// inside a quoted identifier.
	return quote + strings.ReplaceAll(name, quote, quote+quote) + quote
}

// placeholder returns the bind-parameter marker for the n-th query argument.
// PostgreSQL gets a numbered marker ("$1", "$2", ...); every other driver
// gets the positional "?" marker, for which n is ignored.
func (s *SQLDatabase) placeholder(n int) string {
	if s.driver != DriverPostgres {
		return "?"
	}
	return fmt.Sprintf("$%d", n)
}

// NewSQLDatabase creates a new SQLDatabase with the appropriate settings
func NewSQLDatabase(db *sql.DB, driver DriverType) (*SQLDatabase, error) {
s := &SQLDatabase{
db: db,
driver: driver,
tableName: "store",
keyColumn: "key",
valueColumn: "value",
}

switch driver {
case DriverPostgres:
s.placeholder = func(n int) string { return fmt.Sprintf("$%d", n) }
case DriverSQLite:
s.placeholder = func(n int) string { return "?" }
case DriverMySQL:
s.placeholder = func(n int) string { return "?" }
s.keyColumn = "`key`"
s.valueColumn = "`value`"
if err := validateIdentifier(s.tableName); err != nil {
return nil, fmt.Errorf("invalid table name: %w", err)
}
if err := validateIdentifier(s.keyColumn); err != nil {
return nil, fmt.Errorf("invalid key column: %w", err)
}
if err := validateIdentifier(s.valueColumn); err != nil {
return nil, fmt.Errorf("invalid value column: %w", err)
}

return s
return s, nil
}

func (s *SQLDatabase) Close() error {
Expand All @@ -61,7 +99,10 @@ func (s *SQLDatabase) Close() error {
func (s *SQLDatabase) getValue(key string) (string, error) {
query := fmt.Sprintf(
"SELECT %s FROM %s WHERE %s = %s",
s.valueColumn, s.tableName, s.keyColumn, s.placeholder(1),
s.quoteIdentifier(s.valueColumn),
s.quoteIdentifier(s.tableName),
s.quoteIdentifier(s.keyColumn),
s.placeholder(1),
)

var value string
Expand All @@ -76,19 +117,22 @@ func (s *SQLDatabase) getKeysByPrefix(prefix string, lastKey string, limit int)
var query string
var args []interface{}

quotedKey := s.quoteIdentifier(s.keyColumn)
quotedTable := s.quoteIdentifier(s.tableName)

if lastKey == "" {
query = fmt.Sprintf(
"SELECT %s FROM %s WHERE %s LIKE %s ORDER BY %s ASC LIMIT %s",
s.keyColumn, s.tableName, s.keyColumn,
s.placeholder(1), s.keyColumn, s.placeholder(2),
quotedKey, quotedTable, quotedKey,
s.placeholder(1), quotedKey, s.placeholder(2),
)
args = []interface{}{prefix + "%", limit}
} else {
query = fmt.Sprintf(
"SELECT %s FROM %s WHERE %s LIKE %s AND %s > %s ORDER BY %s ASC LIMIT %s",
s.keyColumn, s.tableName, s.keyColumn,
s.placeholder(1), s.keyColumn, s.placeholder(2),
s.keyColumn, s.placeholder(3),
quotedKey, quotedTable, quotedKey,
s.placeholder(1), quotedKey, s.placeholder(2),
quotedKey, s.placeholder(3),
)
args = []interface{}{prefix + "%", lastKey, limit}
}
Expand Down Expand Up @@ -119,19 +163,23 @@ func (s *SQLDatabase) getKeysAndValuesByPrefix(
var query string
var args []interface{}

quotedKey := s.quoteIdentifier(s.keyColumn)
quotedValue := s.quoteIdentifier(s.valueColumn)
quotedTable := s.quoteIdentifier(s.tableName)

if lastKey == "" {
query = fmt.Sprintf(
"SELECT * FROM %s WHERE %s LIKE %s ORDER BY %s ASC LIMIT %s",
s.tableName, s.keyColumn,
s.placeholder(1), s.keyColumn, s.placeholder(2),
"SELECT %s, %s FROM %s WHERE %s LIKE %s ORDER BY %s ASC LIMIT %s",
quotedKey, quotedValue, quotedTable, quotedKey,
s.placeholder(1), quotedKey, s.placeholder(2),
)
args = []interface{}{prefix + "%", limit}
} else {
query = fmt.Sprintf(
"SELECT * FROM %s WHERE %s LIKE %s AND %s > %s ORDER BY %s ASC LIMIT %s",
s.tableName, s.keyColumn,
s.placeholder(1), s.keyColumn, s.placeholder(2),
s.keyColumn, s.placeholder(3),
"SELECT %s, %s FROM %s WHERE %s LIKE %s AND %s > %s ORDER BY %s ASC LIMIT %s",
quotedKey, quotedValue, quotedTable, quotedKey,
s.placeholder(1), quotedKey, s.placeholder(2),
quotedKey, s.placeholder(3),
)
args = []interface{}{prefix + "%", lastKey, limit}
}
Expand Down Expand Up @@ -167,7 +215,7 @@ func (s *SQLDatabase) GetNextPads(lastPadId string, limit int) ([]Pad, error) {
lastKey = "pad:" + lastPadId
}

data, err := s.getKeysAndValuesByPrefix("pad:", lastKey, limit*10) // Get extra to filter
data, err := s.getKeysAndValuesByPrefix("pad:", lastKey, limit*10)
if err != nil {
return nil, err
}
Expand All @@ -176,7 +224,7 @@ func (s *SQLDatabase) GetNextPads(lastPadId string, limit int) ([]Pad, error) {
for key, value := range data {
matches := padKeyRegex.FindStringSubmatch(key)
if matches == nil {
continue // Skip revision keys like pad:xxx:revs:123
continue
}

padId := matches[1]
Expand All @@ -196,7 +244,6 @@ func (s *SQLDatabase) GetNextPads(lastPadId string, limit int) ([]Pad, error) {
}
}

// Sort by PadId for consistent ordering
sort.Slice(pads, func(i, j int) bool {
return pads[i].PadId < pads[j].PadId
})
Expand All @@ -212,22 +259,18 @@ func (s *SQLDatabase) GetNextPads(lastPadId string, limit int) ([]Pad, error) {
// Pad Revisions
// ============================================================================

// Key pattern: pad:<padId>:revs:<revNum>
func (s *SQLDatabase) GetPadRevisions(
padId string,
lastRev int,
limit int,
) ([]PadRevision, error) {
prefix := fmt.Sprintf("pad:%s:revs:", padId)

// For revisions, we need to handle numeric ordering
// Get all revision keys first, then filter and sort numerically
keys, err := s.getKeysByPrefix(prefix, "", 100000) // Get all revisions
keys, err := s.getKeysByPrefix(prefix, "", 100000)
if err != nil {
return nil, err
}

// Parse and filter revision numbers
type revKey struct {
num int
key string
Expand All @@ -245,17 +288,14 @@ func (s *SQLDatabase) GetPadRevisions(
}
}

// Sort by revision number
sort.Slice(revKeys, func(i, j int) bool {
return revKeys[i].num < revKeys[j].num
})

// Limit results
if len(revKeys) > limit {
revKeys = revKeys[:limit]
}

// Fetch values for selected keys
var revisions []PadRevision
for _, rk := range revKeys {
value, err := s.getValue(rk.key)
Expand All @@ -279,7 +319,6 @@ func (s *SQLDatabase) GetPadRevisions(
// Authors
// ============================================================================

// Key pattern: globalAuthor:<authorId>
func (s *SQLDatabase) GetNextAuthors(lastAuthorId string, limit int) ([]Author, error) {
lastKey := ""
if lastAuthorId != "" {
Expand All @@ -303,7 +342,6 @@ func (s *SQLDatabase) GetNextAuthors(lastAuthorId string, limit int) ([]Author,
authors = append(authors, author)
}

// Sort for consistent ordering
sort.Slice(authors, func(i, j int) bool {
return authors[i].Id < authors[j].Id
})
Expand All @@ -315,7 +353,6 @@ func (s *SQLDatabase) GetNextAuthors(lastAuthorId string, limit int) ([]Author,
// Readonly Mappings
// ============================================================================

// GetNextReadonly2Pad Key pattern: readonly2pad:<readonlyId>
func (s *SQLDatabase) GetNextReadonly2Pad(
lastReadonlyId string,
limit int,
Expand Down Expand Up @@ -352,7 +389,6 @@ func (s *SQLDatabase) GetNextReadonly2Pad(
return mappings, nil
}

// Key pattern: pad2readonly:<padId>
func (s *SQLDatabase) GetNextPad2Readonly(lastPadId string, limit int) ([]Pad2Readonly, error) {
lastKey := ""
if lastPadId != "" {
Expand Down Expand Up @@ -390,7 +426,6 @@ func (s *SQLDatabase) GetNextPad2Readonly(lastPadId string, limit int) ([]Pad2Re
// Token to Author
// ============================================================================

// Key pattern: token2author:<token>
func (s *SQLDatabase) GetNextToken2Author(lastToken string, limit int) ([]Token2Author, error) {
lastKey := ""
if lastToken != "" {
Expand Down Expand Up @@ -428,7 +463,6 @@ func (s *SQLDatabase) GetNextToken2Author(lastToken string, limit int) ([]Token2
// Chat Messages
// ============================================================================

// Key pattern: pad:<padId>:chat:<chatNum>
func (s *SQLDatabase) GetPadChatMessages(
padId string,
lastChatNum int,
Expand Down Expand Up @@ -489,7 +523,6 @@ func (s *SQLDatabase) GetPadChatMessages(
// Groups
// ============================================================================

// Key pattern: group:<groupId>
func (s *SQLDatabase) GetNextGroups(lastGroupId string, limit int) ([]Group, error) {
lastKey := ""
if lastGroupId != "" {
Expand Down Expand Up @@ -520,7 +553,6 @@ func (s *SQLDatabase) GetNextGroups(lastGroupId string, limit int) ([]Group, err
return groups, nil
}

// Key pattern: group2sessions:<groupId>
func (s *SQLDatabase) GetNextGroup2Sessions(
lastGroupId string,
limit int,
Expand Down Expand Up @@ -557,7 +589,6 @@ func (s *SQLDatabase) GetNextGroup2Sessions(
return mappings, nil
}

// Key pattern: author2sessions:<authorId>
func (s *SQLDatabase) GetNextAuthor2Sessions(
lastAuthorId string,
limit int,
Expand Down Expand Up @@ -594,7 +625,6 @@ func (s *SQLDatabase) GetNextAuthor2Sessions(
return mappings, nil
}

// Key pattern: session:<sessionId>
func (s *SQLDatabase) GetNextSessions(lastSessionId string, limit int) ([]Session, error) {
lastKey := ""
if lastSessionId != "" {
Expand Down Expand Up @@ -624,3 +654,5 @@ func (s *SQLDatabase) GetNextSessions(lastSessionId string, limit int) ([]Sessio

return sessions, nil
}

var _ Database = (*SQLDatabase)(nil)
20 changes: 20 additions & 0 deletions lib/migration/migrator_mysql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package migration

import (
"testing"

"github.com/stretchr/testify/assert"
)

// TestMigrator_MySQL_To_SQLite builds a MySQL-driver SQLDatabase over the
// connection returned by startMySQL, loads data through insertData, and runs
// startMigratorPipeline against the target from setupSQLiteTarget.
func TestMigrator_MySQL_To_SQLite(t *testing.T) {
	sqlDB, cleanup := startMySQL(t)
	defer cleanup()

	source, err := NewSQLDatabase(sqlDB, DriverMySQL)
	// assert.NoError only records the failure and lets the test continue;
	// stop here so a nil source is never handed to the pipeline.
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	insertData(t, sqlDB, insertKV)

	target := setupSQLiteTarget(t)
	startMigratorPipeline(t, source, target)
}
20 changes: 20 additions & 0 deletions lib/migration/migrator_postgres_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package migration

import (
"testing"

"github.com/stretchr/testify/assert"
)

// TestMigrator_Postgres_To_SQLite builds a Postgres-driver SQLDatabase over
// the connection returned by startPostgres, loads data through insertData,
// and runs startMigratorPipeline against the target from setupSQLiteTarget.
func TestMigrator_Postgres_To_SQLite(t *testing.T) {
	pgDB, cleanup := startPostgres(t)
	defer cleanup()

	source, err := NewSQLDatabase(pgDB, DriverPostgres)
	// assert.NoError only records the failure and lets the test continue;
	// stop here so a nil source is never handed to the pipeline.
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	insertData(t, pgDB, insertKVPostgres)

	target := setupSQLiteTarget(t)
	startMigratorPipeline(t, source, target)
}
Loading
Loading