diff --git a/nested_set.go b/nested_set.go index 4377171..d4563c5 100644 --- a/nested_set.go +++ b/nested_set.go @@ -1,14 +1,14 @@ package nestedset import ( - "context" "database/sql" "fmt" "reflect" "strings" + "sync" "gorm.io/gorm" - "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) // MoveDirection means where the node is going to be located @@ -25,7 +25,7 @@ const ( MoveDirectionInner MoveDirection = 0 ) -type nodeItem struct { +type nestedItem struct { ID int64 ParentID sql.NullInt64 Depth int @@ -36,31 +36,25 @@ type nodeItem struct { DbNames map[string]string } -// parseNode parse a gorm structure into an internal source structure -// for bring in all required data attribute like scope, left, righ etc. -func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nodeItem, err error) { - tx = db - stmt := &gorm.Statement{ - DB: tx, - ConnPool: tx.ConnPool, - Context: context.Background(), - Clauses: map[string]clause.Clause{}, - } - - err = stmt.Parse(source) +// parseNode parse a gorm struct into an internal nested item struct +// bring in all required data attribute like scope, left, righ etc. +func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nestedItem, err error) { + scm, err := schema.Parse(source, &sync.Map{}, schema.NamingStrategy{}) if err != nil { err = fmt.Errorf("Invalid source, must be a valid Gorm Model instance, %v", source) return } - item = nodeItem{TableName: stmt.Table, DbNames: map[string]string{}} + tx = db.Table(scm.Table) + + item = nestedItem{TableName: scm.Table, DbNames: map[string]string{}} sourceValue := reflect.Indirect(reflect.ValueOf(source)) sourceType := sourceValue.Type() for i := 0; i < sourceType.NumField(); i++ { t := sourceType.Field(i) v := sourceValue.Field(i) - schemaField := stmt.Schema.LookUpField(t.Name) + schemaField := scm.LookUpField(t.Name) dbName := schemaField.DBName switch t.Tag.Get("nestedset") { @@ -98,27 +92,26 @@ func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nodeItem, err return } -// Create a new node by parent with Gorm original Create() -// ```nestedset.Create(db, &Category{...}, nil)```` will create a new category in root level +// Create a new node within its parent by Gorm original Create() method +// ```nestedset.Create(db, &Category{...}, nil)``` will create a new category in root level // ```nestedset.Create(db, &Category{...}, &parent)``` will create a new category under parent node as its last child func Create(db *gorm.DB, source, parent interface{}) error { - return db.Transaction(func(db *gorm.DB) (err error) { - tx, target, err := parseNode(db, source) - if err != nil { - return err - } + tx, target, err := parseNode(db, source) + if err != nil { + return err + } - // for totally blank table / scope default init root would be [1 - 2] - setDepth, setToLft, setToRgt := 0, 1, 2 - tableName, dbNames := target.TableName, target.DbNames + // for totally blank table / scope default init root would be [1 - 2] + setToDepth, setToLft, setToRgt := 0, 1, 2 + dbNames := target.DbNames - // put node into root level when parent is nil + return tx.Transaction(func(tx *gorm.DB) (err error) { + // create node in root level when parent is nil if parent == nil { lastNode := make(map[string]interface{}) - orderSQL := formatSQL(":rgt desc", target) - rst := tx.Model(source).Select(dbNames["rgt"]).Order(orderSQL).First(&lastNode) + rst := tx.Select(dbNames["rgt"]).Order(formatSQL(":rgt DESC", target)).Take(&lastNode) if rst.Error == nil { - setToLft = lastNode[dbNames["rgt"]].(int) + 1 + setToLft = int(lastNode[dbNames["rgt"]].(int64) + 1) setToRgt = setToLft + 1 } } else { @@ -129,36 +122,31 @@ func Create(db *gorm.DB, source, parent interface{}) error { setToLft = targetParent.Rgt setToRgt = targetParent.Rgt + 1 - setDepth = targetParent.Depth + 1 + setToDepth = targetParent.Depth + 1 // UPDATE tree SET rgt = rgt + 2 WHERE rgt >= new_lft; - err = tx.Table(tableName). - Where(formatSQL(":rgt >= ?", target), setToLft). - UpdateColumn(dbNames["rgt"], gorm.Expr(formatSQL(":rgt + 2", target))). - Error + err = tx.Where(formatSQL(":rgt >= ?", target), setToLft). + UpdateColumn(dbNames["rgt"], gorm.Expr(formatSQL(":rgt + 2", target))).Error if err != nil { return err } // UPDATE tree SET lft = lft + 2 WHERE lft > new_lft; - err = tx.Table(tableName). - Where(formatSQL(":lft > ?", target), setToLft). - UpdateColumn(dbNames["lft"], gorm.Expr(formatSQL(":lft + 2", target))). - Error + err = tx.Where(formatSQL(":lft > ?", target), setToLft). + UpdateColumn(dbNames["lft"], gorm.Expr(formatSQL(":lft + 2", target))).Error if err != nil { return err } - // UPDATE tree SET children_count = children_count + 1 WHERE is = parent.id; - err = db.Model(parent).Update( - dbNames["children_count"], gorm.Expr(formatSQL(":children_count + 1", target)), - ).Error + // UPDATE tree SET children_count = children_count + 1 WHERE id = parent.id; + err = tx.Model(parent).Update( + dbNames["children_count"], gorm.Expr(formatSQL(":children_count + 1", target))).Error if err != nil { return err } } - // Set Lft, Rgt, Depth dynamically by refect + // Set Lft, Rgt, Depth dynamically v := reflect.Indirect(reflect.ValueOf(source)) t := v.Type() for i := 0; i < t.NumField(); i++ { @@ -174,7 +162,7 @@ func Create(db *gorm.DB, source, parent interface{}) error { break case "depth": f := v.FieldByName(f.Name) - f.SetInt(int64(setDepth)) + f.SetInt(int64(setToDepth)) break } } @@ -184,6 +172,51 @@ func Create(db *gorm.DB, source, parent interface{}) error { }) } +// Delete a node from scoped list and its all descendent +// ```nestedset.Delete(db, &Category{...})``` +func Delete(db *gorm.DB, source interface{}) error { + tx, target, err := parseNode(db, source) + if err != nil { + return err + } + + // Batch Delete Method in GORM requires an instance of current source type without ID + // to avoid GORM style Delete interface, we hacked here by set source ID to 0 + dbNames := target.DbNames + v := reflect.Indirect(reflect.ValueOf(source)) + t := v.Type() + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Tag.Get("nestedset") == "id" { + f := v.FieldByName(f.Name) + f.SetInt(0) + break + } + } + + return tx.Transaction(func(tx *gorm.DB) (err error) { + err = tx.Where(formatSQL(":lft >= ? AND :rgt <= ?", target), target.Lft, target.Rgt). + Delete(source).Error + if err != nil { + return err + } + + // UPDATE tree SET rgt = rgt - width WHERE rgt > target_rgt; + // UPDATE tree SET lft = lft - width WHERE lft > target_rgt; + width := target.Rgt - target.Lft + 1 + for _, d := range []string{"rgt", "lft"} { + err = tx.Where(formatSQL(":"+d+" > ?", target), target.Rgt). + Update(dbNames[d], gorm.Expr(formatSQL(":"+d+" - ?", target), width)). + Error + if err != nil { + return err + } + } + + return nil + }) +} + // MoveTo move node to a position which is related a target node // ```nestedset.MoveTo(db, &node, &to, nestedset.MoveDirectionInner)``` will move [&node] to [&to] node's child_list as its first child func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error { @@ -197,8 +230,6 @@ func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error { return err } - tx = db.Table(targetNode.TableName) - var right, depthChange int var newParentID sql.NullInt64 if direction == MoveDirectionLeft || direction == MoveDirectionRight { @@ -218,7 +249,7 @@ func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error { return moveToRightOfPosition(tx, targetNode, right, depthChange, newParentID) } -func moveToRightOfPosition(tx *gorm.DB, targetNode nodeItem, position, depthChange int, newParentID sql.NullInt64) error { +func moveToRightOfPosition(tx *gorm.DB, targetNode nestedItem, position, depthChange int, newParentID sql.NullInt64) error { return tx.Transaction(func(tx *gorm.DB) (err error) { oldParentID := targetNode.ParentID targetRight := targetNode.Rgt @@ -261,7 +292,7 @@ func moveToRightOfPosition(tx *gorm.DB, targetNode nodeItem, position, depthChan }) } -func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentID sql.NullInt64) (err error) { +func syncChildrenCount(tx *gorm.DB, targetNode nestedItem, oldParentID, newParentID sql.NullInt64) (err error) { var oldParentCount, newParentCount int64 if oldParentID.Valid { @@ -289,7 +320,7 @@ func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentI return nil } -func moveTarget(tx *gorm.DB, targetNode nodeItem, targetID int64, targetIds []int64, step, depthChange int, newParentID sql.NullInt64) (err error) { +func moveTarget(tx *gorm.DB, targetNode nestedItem, targetID int64, targetIds []int64, step, depthChange int, newParentID sql.NullInt64) (err error) { dbNames := targetNode.DbNames if len(targetIds) > 0 { @@ -307,7 +338,7 @@ func moveTarget(tx *gorm.DB, targetNode nodeItem, targetID int64, targetIds []in return tx.Where(formatSQL(":id = ?", targetNode), targetID).Update(dbNames["parent_id"], newParentID).Error } -func moveAffected(tx *gorm.DB, targetNode nodeItem, gte, lte, step int) (err error) { +func moveAffected(tx *gorm.DB, targetNode nestedItem, gte, lte, step int) (err error) { dbNames := targetNode.DbNames return tx.Where(formatSQL("(:lft BETWEEN ? AND ?) OR (:rgt BETWEEN ? AND ?)", targetNode), gte, lte, gte, lte). @@ -317,7 +348,7 @@ func moveAffected(tx *gorm.DB, targetNode nodeItem, gte, lte, step int) (err err }).Error } -func formatSQL(placeHolderSQL string, node nodeItem) (out string) { +func formatSQL(placeHolderSQL string, node nestedItem) (out string) { out = placeHolderSQL out = strings.ReplaceAll(out, ":table_name", node.TableName) diff --git a/nested_set_test.go b/nested_set_test.go index bfe27af..3116266 100644 --- a/nested_set_test.go +++ b/nested_set_test.go @@ -130,6 +130,26 @@ func TestCreateSource(t *testing.T) { assert.Equal(t, c2.ChildrenCount, 1) } +func TestDeleteSource(t *testing.T) { + initData() + + c1 := Category{Title: "c1s"} + Create(db, &c1, nil) + + cp := Category{Title: "cp"} + Create(db, &cp, c1) + + c2 := Category{Title: "c2s"} + Create(db, &c2, nil) + + db.First(&c1) + Delete(db, &c1) + + db.Model(&c2).First(&c2) + assert.Equal(t, c2.Lft, 1) + assert.Equal(t, c2.Rgt, 2) +} + func TestMoveToRight(t *testing.T) { // case 1 initData()