Skip to content

Commit

Permalink
add option to truncate tables with an identity/auto-incr reset
Browse files Browse the repository at this point in the history
  • Loading branch information
s-mang committed Feb 12, 2014
1 parent f2c24b2 commit 50e3857
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 1 deletion.
16 changes: 15 additions & 1 deletion dbmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,24 @@ func (m *DbMap) TableForType(t reflect.Type) *TableMap {

// Truncate all tables in the DbMap
func (m *DbMap) TruncateTables() error {
return m.truncateTables(false)
}

// Truncate all tables in the DbMap and reset identity counter
func (m *DbMap) TruncateTablesIdentityRestart() error {
return m.truncateTables(true)
}

func (m *DbMap) truncateTables(restartIdentity bool) error {
var err error
var restartClause string
for i := range m.tables {
table := m.tables[i]
_, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuoteField(table.TableName)))
if restartIdentity {
restartClause = m.Dialect.RestartIdentityClause(table.TableName)
}

_, e := m.Exec(fmt.Sprintf("%s %s %s;", m.Dialect.TruncateClause(), m.Dialect.QuoteField(table.TableName), restartClause))
if e != nil {
err = e
}
Expand Down
15 changes: 15 additions & 0 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ type Dialect interface {
// string used to truncate tables
TruncateClause() string

// string used to reset identity counter when truncating tables
RestartIdentityClause(table string) string

// Get the driver name from a dialect
DriverName() string
}
Expand Down Expand Up @@ -139,6 +142,10 @@ func (d SqliteDialect) TruncateClause() string {
return "delete from"
}

func (d SqliteDialect) RestartIdentityClause(table string) string {
return ""
}

///////////////////////////////////////////////////////
// PostgreSQL //
////////////////
Expand Down Expand Up @@ -241,6 +248,10 @@ func (d PostgresDialect) TruncateClause() string {
return "truncate"
}

func (d PostgresDialect) RestartIdentityClause(table string) string {
return "restart identity"
}

///////////////////////////////////////////////////////
// MySQL //
///////////
Expand Down Expand Up @@ -344,3 +355,7 @@ func ReBind(query string, dialect Dialect) string {
func (m MySQLDialect) TruncateClause() string {
return "truncate"
}

func (d MySQLDialect) RestartIdentityClause(table string) string {
return "; alter table " + table + " AUTO_INCREMENT = 1"
}
38 changes: 38 additions & 0 deletions modl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,44 @@ func TestTruncateTables(t *testing.T) {
}
}

func TestTruncateTablesIdentityRestart(t *testing.T) {
dbmap := initDbMap()
defer dbmap.DropTables()
err := dbmap.CreateTablesIfNotExists()
if err != nil {
t.Error(err)
}

// Insert some data
p1 := &Person{0, 0, 0, "Bob", "Smith", 0}
dbmap.Insert(p1)
inv := &Invoice{0, 0, 1, "my invoice", 0, true}
dbmap.Insert(inv)

err = dbmap.TruncateTablesIdentityRestart()
if err != nil {
t.Error(err)
}

// Make sure all rows are deleted
people := []Person{}
invoices := []Invoice{}
dbmap.Select(&people, "SELECT * FROM person_test")
if len(people) != 0 {
t.Errorf("Expected 0 person rows, got %d", len(people))
}
dbmap.Select(&invoices, "SELECT * FROM invoice_test")
if len(invoices) != 0 {
t.Errorf("Expected 0 invoice rows, got %d", len(invoices))
}

p2 := &Person{0, 0, 0, "Other", "Person", 0}
dbmap.Insert(p2)
if p2.Id != int64(1) {
t.Errorf("Expected new person Id to be equal to 1, was %d", p2.Id)
}
}

func TestQuoteTableNames(t *testing.T) {
dbmap := initDbMap()
defer dbmap.DropTables()
Expand Down

0 comments on commit 50e3857

Please sign in to comment.