Skip to content

Commit

Permalink
In some cases (Error not checked, missed data) one can perform very h…
Browse files Browse the repository at this point in the history
…armful operation - global update or delete (all records)

This is to prevent it.
  • Loading branch information
slockij committed Nov 4, 2016
1 parent d5d3e3a commit e26cb8d
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 2 deletions.
9 changes: 8 additions & 1 deletion callback_delete.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package gorm

import "fmt"
import (
"errors"
"fmt"
)

// Define callbacks for deleting
func init() {
Expand All @@ -13,6 +16,10 @@ func init() {

// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
func beforeDeleteCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while deleting"))
return
}
if !scope.HasError() {
scope.CallMethod("BeforeDelete")
}
Expand Down
5 changes: 5 additions & 0 deletions callback_update.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gorm

import (
"errors"
"fmt"
"strings"
)
Expand Down Expand Up @@ -31,6 +32,10 @@ func assignUpdatingAttributesCallback(scope *Scope) {

// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
func beforeUpdateCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while updating"))
return
}
if _, ok := scope.Get("gorm:update_column"); !ok {
if !scope.HasError() {
scope.CallMethod("BeforeSave")
Expand Down
24 changes: 23 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type DB struct {
source string
values map[string]interface{}
joinTableHandlers map[string]JoinTableHandler
blockGlobalUpdate bool
}

// Open initialize a new db connection, need to import driver first, e.g:
Expand Down Expand Up @@ -142,6 +143,18 @@ func (s *DB) LogMode(enable bool) *DB {
return s
}

// BlockGlobalUpdate if true, generates an error on update/delete without where clause.
// This is to prevent eventual error with empty objects updates/deletions
func (s *DB) BlockGlobalUpdate(enable bool) *DB {
s.blockGlobalUpdate = enable
return s
}

// HasBlockGlobalUpdate return state of block
func (s *DB) HasBlockGlobalUpdate() bool {
return s.blockGlobalUpdate
}

// SingularTable use singular table by default
func (s *DB) SingularTable(enable bool) {
modelStructsMap = newModelStructsMap()
Expand Down Expand Up @@ -682,7 +695,16 @@ func (s *DB) GetErrors() (errors []error) {
////////////////////////////////////////////////////////////////////////////////

func (s *DB) clone() *DB {
db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
db := DB{
db: s.db,
parent: s.parent,
logger: s.logger,
logMode: s.logMode,
values: map[string]interface{}{},
Value: s.Value,
Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate,
}

for key, value := range s.values {
db.values[key] = value
Expand Down
38 changes: 38 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,44 @@ func TestOpenWithOneParameter(t *testing.T) {
}
}

func TestBlockGlobalUpdate(t *testing.T) {
db := DB.New()
db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})

err := db.Model(&Toy{}).Update("OwnerType", "Human").Error
if err != nil {
t.Error("Unexpected error on global update")
}

err = db.Delete(&Toy{}).Error
if err != nil {
t.Error("Unexpected error on global delete")
}

db.BlockGlobalUpdate(true)

db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})

err = db.Model(&Toy{}).Update("OwnerType", "Human").Error
if err == nil {
t.Error("Expected error on global update")
}

err = db.Model(&Toy{}).Where(&Toy{OwnerType: "Martian"}).Update("OwnerType", "Astronaut").Error
if err != nil {
t.Error("Unxpected error on conditional update")
}

err = db.Delete(&Toy{}).Error
if err == nil {
t.Error("Expected error on global delete")
}
err = db.Where(&Toy{OwnerType: "Martian"}).Delete(&Toy{}).Error
if err != nil {
t.Error("Unexpected error on conditional delete")
}
}

func BenchmarkGorm(b *testing.B) {
b.N = 2000
for x := 0; x < b.N; x++ {
Expand Down
7 changes: 7 additions & 0 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -1280,3 +1280,10 @@ func (scope *Scope) getColumnAsScope(column string) *Scope {
}
return nil
}

func (scope *Scope) hasConditions() bool {
return !scope.PrimaryKeyZero() ||
len(scope.Search.whereConditions) > 0 ||
len(scope.Search.orConditions) > 0 ||
len(scope.Search.notConditions) > 0
}

0 comments on commit e26cb8d

Please sign in to comment.