Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

build Delete() for delete a node and all its descendant #12

Merged
merged 11 commits into from
Dec 27, 2020
137 changes: 84 additions & 53 deletions nested_set.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package nestedset

import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"sync"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

// MoveDirection means where the node is going to be located
Expand All @@ -25,7 +25,7 @@ const (
MoveDirectionInner MoveDirection = 0
)

type nodeItem struct {
type nestedItem struct {
ID int64
ParentID sql.NullInt64
Depth int
Expand All @@ -36,31 +36,25 @@ type nodeItem struct {
DbNames map[string]string
}

// parseNode parse a gorm structure into an internal source structure
// for bring in all required data attribute like scope, left, righ etc.
func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nodeItem, err error) {
tx = db
stmt := &gorm.Statement{
DB: tx,
ConnPool: tx.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}

err = stmt.Parse(source)
// parseNode parse a gorm struct into an internal nested item struct
// bring in all required data attribute like scope, left, righ etc.
func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nestedItem, err error) {
scm, err := schema.Parse(source, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
err = fmt.Errorf("Invalid source, must be a valid Gorm Model instance, %v", source)
return
}

item = nodeItem{TableName: stmt.Table, DbNames: map[string]string{}}
tx = db.Table(scm.Table)

item = nestedItem{TableName: scm.Table, DbNames: map[string]string{}}
sourceValue := reflect.Indirect(reflect.ValueOf(source))
sourceType := sourceValue.Type()
for i := 0; i < sourceType.NumField(); i++ {
t := sourceType.Field(i)
v := sourceValue.Field(i)

schemaField := stmt.Schema.LookUpField(t.Name)
schemaField := scm.LookUpField(t.Name)
dbName := schemaField.DBName

switch t.Tag.Get("nestedset") {
Expand Down Expand Up @@ -98,27 +92,26 @@ func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nodeItem, err
return
}

// Create a new node by parent with Gorm original Create()
// ```nestedset.Create(db, &Category{...}, nil)```` will create a new category in root level
// Create a new node within its parent by Gorm original Create() method
// ```nestedset.Create(db, &Category{...}, nil)``` will create a new category in root level
// ```nestedset.Create(db, &Category{...}, &parent)``` will create a new category under parent node as its last child
func Create(db *gorm.DB, source, parent interface{}) error {
return db.Transaction(func(db *gorm.DB) (err error) {
tx, target, err := parseNode(db, source)
if err != nil {
return err
}
tx, target, err := parseNode(db, source)
if err != nil {
return err
}

// for totally blank table / scope default init root would be [1 - 2]
setDepth, setToLft, setToRgt := 0, 1, 2
tableName, dbNames := target.TableName, target.DbNames
// for totally blank table / scope default init root would be [1 - 2]
setToDepth, setToLft, setToRgt := 0, 1, 2
dbNames := target.DbNames

// put node into root level when parent is nil
return tx.Transaction(func(tx *gorm.DB) (err error) {
// create node in root level when parent is nil
if parent == nil {
lastNode := make(map[string]interface{})
orderSQL := formatSQL(":rgt desc", target)
rst := tx.Model(source).Select(dbNames["rgt"]).Order(orderSQL).First(&lastNode)
rst := tx.Select(dbNames["rgt"]).Order(formatSQL(":rgt DESC", target)).Take(&lastNode)
if rst.Error == nil {
setToLft = lastNode[dbNames["rgt"]].(int) + 1
setToLft = int(lastNode[dbNames["rgt"]].(int64) + 1)
setToRgt = setToLft + 1
}
} else {
Expand All @@ -129,36 +122,31 @@ func Create(db *gorm.DB, source, parent interface{}) error {

setToLft = targetParent.Rgt
setToRgt = targetParent.Rgt + 1
setDepth = targetParent.Depth + 1
setToDepth = targetParent.Depth + 1

// UPDATE tree SET rgt = rgt + 2 WHERE rgt >= new_lft;
err = tx.Table(tableName).
Where(formatSQL(":rgt >= ?", target), setToLft).
UpdateColumn(dbNames["rgt"], gorm.Expr(formatSQL(":rgt + 2", target))).
Error
err = tx.Where(formatSQL(":rgt >= ?", target), setToLft).
UpdateColumn(dbNames["rgt"], gorm.Expr(formatSQL(":rgt + 2", target))).Error
if err != nil {
return err
}

// UPDATE tree SET lft = lft + 2 WHERE lft > new_lft;
err = tx.Table(tableName).
Where(formatSQL(":lft > ?", target), setToLft).
UpdateColumn(dbNames["lft"], gorm.Expr(formatSQL(":lft + 2", target))).
Error
err = tx.Where(formatSQL(":lft > ?", target), setToLft).
UpdateColumn(dbNames["lft"], gorm.Expr(formatSQL(":lft + 2", target))).Error
if err != nil {
return err
}

// UPDATE tree SET children_count = children_count + 1 WHERE is = parent.id;
err = db.Model(parent).Update(
dbNames["children_count"], gorm.Expr(formatSQL(":children_count + 1", target)),
).Error
// UPDATE tree SET children_count = children_count + 1 WHERE id = parent.id;
err = tx.Model(parent).Update(
dbNames["children_count"], gorm.Expr(formatSQL(":children_count + 1", target))).Error
if err != nil {
return err
}
}

// Set Lft, Rgt, Depth dynamically by refect
// Set Lft, Rgt, Depth dynamically
v := reflect.Indirect(reflect.ValueOf(source))
t := v.Type()
for i := 0; i < t.NumField(); i++ {
Expand All @@ -174,7 +162,7 @@ func Create(db *gorm.DB, source, parent interface{}) error {
break
case "depth":
f := v.FieldByName(f.Name)
f.SetInt(int64(setDepth))
f.SetInt(int64(setToDepth))
break
}
}
Expand All @@ -184,6 +172,51 @@ func Create(db *gorm.DB, source, parent interface{}) error {
})
}

// Delete a node from scoped list and its all descendent
// ```nestedset.Delete(db, &Category{...})```
func Delete(db *gorm.DB, source interface{}) error {
tx, target, err := parseNode(db, source)
if err != nil {
return err
}

// Batch Delete Method in GORM requires an instance of current source type without ID
// to avoid GORM style Delete interface, we hacked here by set source ID to 0
dbNames := target.DbNames
v := reflect.Indirect(reflect.ValueOf(source))
t := v.Type()
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Tag.Get("nestedset") == "id" {
f := v.FieldByName(f.Name)
f.SetInt(0)
break
}
}

return tx.Transaction(func(tx *gorm.DB) (err error) {
err = tx.Where(formatSQL(":lft >= ? AND :rgt <= ?", target), target.Lft, target.Rgt).
Delete(source).Error
if err != nil {
return err
}

// UPDATE tree SET rgt = rgt - width WHERE rgt > target_rgt;
// UPDATE tree SET lft = lft - width WHERE lft > target_rgt;
width := target.Rgt - target.Lft + 1
for _, d := range []string{"rgt", "lft"} {
err = tx.Where(formatSQL(":"+d+" > ?", target), target.Rgt).
Update(dbNames[d], gorm.Expr(formatSQL(":"+d+" - ?", target), width)).
Error
if err != nil {
return err
}
}

return nil
})
}

// MoveTo move node to a position which is related a target node
// ```nestedset.MoveTo(db, &node, &to, nestedset.MoveDirectionInner)``` will move [&node] to [&to] node's child_list as its first child
func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
Expand All @@ -197,8 +230,6 @@ func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
return err
}

tx = db.Table(targetNode.TableName)

var right, depthChange int
var newParentID sql.NullInt64
if direction == MoveDirectionLeft || direction == MoveDirectionRight {
Expand All @@ -218,7 +249,7 @@ func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
return moveToRightOfPosition(tx, targetNode, right, depthChange, newParentID)
}

func moveToRightOfPosition(tx *gorm.DB, targetNode nodeItem, position, depthChange int, newParentID sql.NullInt64) error {
func moveToRightOfPosition(tx *gorm.DB, targetNode nestedItem, position, depthChange int, newParentID sql.NullInt64) error {
return tx.Transaction(func(tx *gorm.DB) (err error) {
oldParentID := targetNode.ParentID
targetRight := targetNode.Rgt
Expand Down Expand Up @@ -261,7 +292,7 @@ func moveToRightOfPosition(tx *gorm.DB, targetNode nodeItem, position, depthChan
})
}

func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentID sql.NullInt64) (err error) {
func syncChildrenCount(tx *gorm.DB, targetNode nestedItem, oldParentID, newParentID sql.NullInt64) (err error) {
var oldParentCount, newParentCount int64

if oldParentID.Valid {
Expand Down Expand Up @@ -289,7 +320,7 @@ func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentI
return nil
}

func moveTarget(tx *gorm.DB, targetNode nodeItem, targetID int64, targetIds []int64, step, depthChange int, newParentID sql.NullInt64) (err error) {
func moveTarget(tx *gorm.DB, targetNode nestedItem, targetID int64, targetIds []int64, step, depthChange int, newParentID sql.NullInt64) (err error) {
dbNames := targetNode.DbNames

if len(targetIds) > 0 {
Expand All @@ -307,7 +338,7 @@ func moveTarget(tx *gorm.DB, targetNode nodeItem, targetID int64, targetIds []in
return tx.Where(formatSQL(":id = ?", targetNode), targetID).Update(dbNames["parent_id"], newParentID).Error
}

func moveAffected(tx *gorm.DB, targetNode nodeItem, gte, lte, step int) (err error) {
func moveAffected(tx *gorm.DB, targetNode nestedItem, gte, lte, step int) (err error) {
dbNames := targetNode.DbNames

return tx.Where(formatSQL("(:lft BETWEEN ? AND ?) OR (:rgt BETWEEN ? AND ?)", targetNode), gte, lte, gte, lte).
Expand All @@ -317,7 +348,7 @@ func moveAffected(tx *gorm.DB, targetNode nodeItem, gte, lte, step int) (err err
}).Error
}

func formatSQL(placeHolderSQL string, node nodeItem) (out string) {
func formatSQL(placeHolderSQL string, node nestedItem) (out string) {
out = placeHolderSQL

out = strings.ReplaceAll(out, ":table_name", node.TableName)
Expand Down
20 changes: 20 additions & 0 deletions nested_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,26 @@ func TestCreateSource(t *testing.T) {
assert.Equal(t, c2.ChildrenCount, 1)
}

func TestDeleteSource(t *testing.T) {
initData()

c1 := Category{Title: "c1s"}
Create(db, &c1, nil)

cp := Category{Title: "cp"}
Create(db, &cp, c1)

c2 := Category{Title: "c2s"}
Create(db, &c2, nil)

db.First(&c1)
Delete(db, &c1)

db.Model(&c2).First(&c2)
assert.Equal(t, c2.Lft, 1)
assert.Equal(t, c2.Rgt, 2)
}

func TestMoveToRight(t *testing.T) {
// case 1
initData()
Expand Down