Skip to content

Commit

Permalink
Merge pull request #12 from calebx/main
Browse files Browse the repository at this point in the history
build Delete() for delete a node and all its descendant
  • Loading branch information
huacnlee authored Dec 27, 2020
2 parents 43484f0 + 9e32fb7 commit 1033f63
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 53 deletions.
137 changes: 84 additions & 53 deletions nested_set.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package nestedset

import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"sync"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

// MoveDirection means where the node is going to be located
Expand All @@ -25,7 +25,7 @@ const (
MoveDirectionInner MoveDirection = 0
)

type nodeItem struct {
type nestedItem struct {
ID int64
ParentID sql.NullInt64
Depth int
Expand All @@ -36,31 +36,25 @@ type nodeItem struct {
DbNames map[string]string
}

// parseNode parse a gorm structure into an internal source structure
// for bring in all required data attribute like scope, left, righ etc.
func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nodeItem, err error) {
tx = db
stmt := &gorm.Statement{
DB: tx,
ConnPool: tx.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}

err = stmt.Parse(source)
// parseNode parse a gorm struct into an internal nested item struct
// bring in all required data attribute like scope, left, righ etc.
func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nestedItem, err error) {
scm, err := schema.Parse(source, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
err = fmt.Errorf("Invalid source, must be a valid Gorm Model instance, %v", source)
return
}

item = nodeItem{TableName: stmt.Table, DbNames: map[string]string{}}
tx = db.Table(scm.Table)

item = nestedItem{TableName: scm.Table, DbNames: map[string]string{}}
sourceValue := reflect.Indirect(reflect.ValueOf(source))
sourceType := sourceValue.Type()
for i := 0; i < sourceType.NumField(); i++ {
t := sourceType.Field(i)
v := sourceValue.Field(i)

schemaField := stmt.Schema.LookUpField(t.Name)
schemaField := scm.LookUpField(t.Name)
dbName := schemaField.DBName

switch t.Tag.Get("nestedset") {
Expand Down Expand Up @@ -98,27 +92,26 @@ func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nodeItem, err
return
}

// Create a new node by parent with Gorm original Create()
// ```nestedset.Create(db, &Category{...}, nil)```` will create a new category in root level
// Create a new node within its parent by Gorm original Create() method
// ```nestedset.Create(db, &Category{...}, nil)``` will create a new category in root level
// ```nestedset.Create(db, &Category{...}, &parent)``` will create a new category under parent node as its last child
func Create(db *gorm.DB, source, parent interface{}) error {
return db.Transaction(func(db *gorm.DB) (err error) {
tx, target, err := parseNode(db, source)
if err != nil {
return err
}
tx, target, err := parseNode(db, source)
if err != nil {
return err
}

// for totally blank table / scope default init root would be [1 - 2]
setDepth, setToLft, setToRgt := 0, 1, 2
tableName, dbNames := target.TableName, target.DbNames
// for totally blank table / scope default init root would be [1 - 2]
setToDepth, setToLft, setToRgt := 0, 1, 2
dbNames := target.DbNames

// put node into root level when parent is nil
return tx.Transaction(func(tx *gorm.DB) (err error) {
// create node in root level when parent is nil
if parent == nil {
lastNode := make(map[string]interface{})
orderSQL := formatSQL(":rgt desc", target)
rst := tx.Model(source).Select(dbNames["rgt"]).Order(orderSQL).First(&lastNode)
rst := tx.Select(dbNames["rgt"]).Order(formatSQL(":rgt DESC", target)).Take(&lastNode)
if rst.Error == nil {
setToLft = lastNode[dbNames["rgt"]].(int) + 1
setToLft = int(lastNode[dbNames["rgt"]].(int64) + 1)
setToRgt = setToLft + 1
}
} else {
Expand All @@ -129,36 +122,31 @@ func Create(db *gorm.DB, source, parent interface{}) error {

setToLft = targetParent.Rgt
setToRgt = targetParent.Rgt + 1
setDepth = targetParent.Depth + 1
setToDepth = targetParent.Depth + 1

// UPDATE tree SET rgt = rgt + 2 WHERE rgt >= new_lft;
err = tx.Table(tableName).
Where(formatSQL(":rgt >= ?", target), setToLft).
UpdateColumn(dbNames["rgt"], gorm.Expr(formatSQL(":rgt + 2", target))).
Error
err = tx.Where(formatSQL(":rgt >= ?", target), setToLft).
UpdateColumn(dbNames["rgt"], gorm.Expr(formatSQL(":rgt + 2", target))).Error
if err != nil {
return err
}

// UPDATE tree SET lft = lft + 2 WHERE lft > new_lft;
err = tx.Table(tableName).
Where(formatSQL(":lft > ?", target), setToLft).
UpdateColumn(dbNames["lft"], gorm.Expr(formatSQL(":lft + 2", target))).
Error
err = tx.Where(formatSQL(":lft > ?", target), setToLft).
UpdateColumn(dbNames["lft"], gorm.Expr(formatSQL(":lft + 2", target))).Error
if err != nil {
return err
}

// UPDATE tree SET children_count = children_count + 1 WHERE is = parent.id;
err = db.Model(parent).Update(
dbNames["children_count"], gorm.Expr(formatSQL(":children_count + 1", target)),
).Error
// UPDATE tree SET children_count = children_count + 1 WHERE id = parent.id;
err = tx.Model(parent).Update(
dbNames["children_count"], gorm.Expr(formatSQL(":children_count + 1", target))).Error
if err != nil {
return err
}
}

// Set Lft, Rgt, Depth dynamically by refect
// Set Lft, Rgt, Depth dynamically
v := reflect.Indirect(reflect.ValueOf(source))
t := v.Type()
for i := 0; i < t.NumField(); i++ {
Expand All @@ -174,7 +162,7 @@ func Create(db *gorm.DB, source, parent interface{}) error {
break
case "depth":
f := v.FieldByName(f.Name)
f.SetInt(int64(setDepth))
f.SetInt(int64(setToDepth))
break
}
}
Expand All @@ -184,6 +172,51 @@ func Create(db *gorm.DB, source, parent interface{}) error {
})
}

// Delete a node from scoped list and its all descendent
// ```nestedset.Delete(db, &Category{...})```
func Delete(db *gorm.DB, source interface{}) error {
tx, target, err := parseNode(db, source)
if err != nil {
return err
}

// Batch Delete Method in GORM requires an instance of current source type without ID
// to avoid GORM style Delete interface, we hacked here by set source ID to 0
dbNames := target.DbNames
v := reflect.Indirect(reflect.ValueOf(source))
t := v.Type()
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Tag.Get("nestedset") == "id" {
f := v.FieldByName(f.Name)
f.SetInt(0)
break
}
}

return tx.Transaction(func(tx *gorm.DB) (err error) {
err = tx.Where(formatSQL(":lft >= ? AND :rgt <= ?", target), target.Lft, target.Rgt).
Delete(source).Error
if err != nil {
return err
}

// UPDATE tree SET rgt = rgt - width WHERE rgt > target_rgt;
// UPDATE tree SET lft = lft - width WHERE lft > target_rgt;
width := target.Rgt - target.Lft + 1
for _, d := range []string{"rgt", "lft"} {
err = tx.Where(formatSQL(":"+d+" > ?", target), target.Rgt).
Update(dbNames[d], gorm.Expr(formatSQL(":"+d+" - ?", target), width)).
Error
if err != nil {
return err
}
}

return nil
})
}

// MoveTo move node to a position which is related a target node
// ```nestedset.MoveTo(db, &node, &to, nestedset.MoveDirectionInner)``` will move [&node] to [&to] node's child_list as its first child
func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
Expand All @@ -197,8 +230,6 @@ func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
return err
}

tx = db.Table(targetNode.TableName)

var right, depthChange int
var newParentID sql.NullInt64
if direction == MoveDirectionLeft || direction == MoveDirectionRight {
Expand All @@ -218,7 +249,7 @@ func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
return moveToRightOfPosition(tx, targetNode, right, depthChange, newParentID)
}

func moveToRightOfPosition(tx *gorm.DB, targetNode nodeItem, position, depthChange int, newParentID sql.NullInt64) error {
func moveToRightOfPosition(tx *gorm.DB, targetNode nestedItem, position, depthChange int, newParentID sql.NullInt64) error {
return tx.Transaction(func(tx *gorm.DB) (err error) {
oldParentID := targetNode.ParentID
targetRight := targetNode.Rgt
Expand Down Expand Up @@ -261,7 +292,7 @@ func moveToRightOfPosition(tx *gorm.DB, targetNode nodeItem, position, depthChan
})
}

func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentID sql.NullInt64) (err error) {
func syncChildrenCount(tx *gorm.DB, targetNode nestedItem, oldParentID, newParentID sql.NullInt64) (err error) {
var oldParentCount, newParentCount int64

if oldParentID.Valid {
Expand Down Expand Up @@ -289,7 +320,7 @@ func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentI
return nil
}

func moveTarget(tx *gorm.DB, targetNode nodeItem, targetID int64, targetIds []int64, step, depthChange int, newParentID sql.NullInt64) (err error) {
func moveTarget(tx *gorm.DB, targetNode nestedItem, targetID int64, targetIds []int64, step, depthChange int, newParentID sql.NullInt64) (err error) {
dbNames := targetNode.DbNames

if len(targetIds) > 0 {
Expand All @@ -307,7 +338,7 @@ func moveTarget(tx *gorm.DB, targetNode nodeItem, targetID int64, targetIds []in
return tx.Where(formatSQL(":id = ?", targetNode), targetID).Update(dbNames["parent_id"], newParentID).Error
}

func moveAffected(tx *gorm.DB, targetNode nodeItem, gte, lte, step int) (err error) {
func moveAffected(tx *gorm.DB, targetNode nestedItem, gte, lte, step int) (err error) {
dbNames := targetNode.DbNames

return tx.Where(formatSQL("(:lft BETWEEN ? AND ?) OR (:rgt BETWEEN ? AND ?)", targetNode), gte, lte, gte, lte).
Expand All @@ -317,7 +348,7 @@ func moveAffected(tx *gorm.DB, targetNode nodeItem, gte, lte, step int) (err err
}).Error
}

func formatSQL(placeHolderSQL string, node nodeItem) (out string) {
func formatSQL(placeHolderSQL string, node nestedItem) (out string) {
out = placeHolderSQL

out = strings.ReplaceAll(out, ":table_name", node.TableName)
Expand Down
20 changes: 20 additions & 0 deletions nested_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,26 @@ func TestCreateSource(t *testing.T) {
assert.Equal(t, c2.ChildrenCount, 1)
}

func TestDeleteSource(t *testing.T) {
initData()

c1 := Category{Title: "c1s"}
Create(db, &c1, nil)

cp := Category{Title: "cp"}
Create(db, &cp, c1)

c2 := Category{Title: "c2s"}
Create(db, &c2, nil)

db.First(&c1)
Delete(db, &c1)

db.Model(&c2).First(&c2)
assert.Equal(t, c2.Lft, 1)
assert.Equal(t, c2.Rgt, 2)
}

func TestMoveToRight(t *testing.T) {
// case 1
initData()
Expand Down

0 comments on commit 1033f63

Please sign in to comment.