Skip to content

Commit

Permalink
implemented new node logic
Browse files Browse the repository at this point in the history
  • Loading branch information
cxiang03 committed Dec 24, 2020
1 parent 57a86d3 commit cba47c2
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 2 deletions.
95 changes: 93 additions & 2 deletions nested_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,97 @@ func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nodeItem, err
return
}

// Create a new node by parent
func Create(db *gorm.DB, source, parent interface{}) error {
return db.Transaction(func(db *gorm.DB) (err error) {
tx, targetNode, err := parseNode(db, source)
if err != nil {
return err
}

// for totally blank table / scope default init root would be [1 - 2]
setToLft := 1
setToRgt := 2
setDepth := 0

// put node into root level when parent is nil
if parent == nil {
lastOne := make(map[string]interface{})
result := tx.Model(source).Select(targetNode.DbNames["rgt"]).Order(formatSQL(":rgt desc", targetNode)).First(&lastOne)
if result.Error == nil {
setToLft = lastOne[targetNode.DbNames["rgt"]].(int) + 1
setToRgt = setToLft + 1
}
} else {
_, parentNode, err := parseNode(db, parent)
if err != nil {
return err
}

setToLft = parentNode.Rgt
setToRgt = parentNode.Rgt + 1
setDepth = parentNode.Depth + 1

dbNames := targetNode.DbNames

// UPDATE tree SET rgt = rgt + 2 WHERE rgt >= new_lft;
updateErr := tx.Table(targetNode.TableName).Where(
formatSQL(":rgt >= ?", targetNode), setToLft,
).UpdateColumn(
dbNames["rgt"], gorm.Expr(formatSQL(":rgt + 2", targetNode)),
).Error

if updateErr != nil {
return updateErr
}

// UPDATE tree SET lft = lft + 2 WHERE lft > new_lft;
updateErr = tx.Table(targetNode.TableName).Where(
formatSQL(":lft > ?", targetNode), setToLft,
).UpdateColumn(
dbNames["lft"], gorm.Expr(formatSQL(":lft + 2", targetNode)),
).Error
if updateErr != nil {
return updateErr
}

// UPDATE tree SET children_count = children_count + 1 WHERE is = parent.id;
updateErr = db.Model(parent).Update(
dbNames["children_count"], gorm.Expr(formatSQL(":children_count + 1", targetNode)),
).Error
if updateErr != nil {
return updateErr
}
}

sourceValue := reflect.Indirect(reflect.ValueOf(source))
sourceType := sourceValue.Type()
for i := 0; i < sourceType.NumField(); i++ {
t := sourceType.Field(i)
switch t.Tag.Get("nestedset") {
case "lft":
f := sourceValue.FieldByName(t.Name)
f.SetInt(int64(setToLft))
break
case "rgt":
f := sourceValue.FieldByName(t.Name)
f.SetInt(int64(setToRgt))
break
case "depth":
f := sourceValue.FieldByName(t.Name)
f.SetInt(int64(setDepth))
break
}
}

// skip the table & scope, since they should be all setup by caller
return db.Create(source).Error
})
}

// MoveTo move node to a position which is related a target node
func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
_, targetNode, err := parseNode(db, node)
tx, targetNode, err := parseNode(db, node)
if err != nil {
return err
}
Expand All @@ -108,7 +196,7 @@ func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
return err
}

tx := db.Table(targetNode.TableName)
tx = db.Table(targetNode.TableName)

var right, depthChange int
var newParentID sql.NullInt64
Expand Down Expand Up @@ -174,6 +262,7 @@ func moveToRightOfPosition(tx *gorm.DB, targetNode nodeItem, position, depthChan

func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentID sql.NullInt64) (err error) {
var oldParentCount, newParentCount int64

if oldParentID.Valid {
err = tx.Where(formatSQL(":parent_id = ?", targetNode), oldParentID).Count(&oldParentCount).Error
if err != nil {
Expand All @@ -195,6 +284,7 @@ func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentI
return
}
}

return nil
}

Expand All @@ -212,6 +302,7 @@ func moveTarget(tx *gorm.DB, targetNode nodeItem, targetID int64, targetIds []in
return
}
}

return tx.Where(formatSQL(":id = ?", targetNode), targetID).Update(dbNames["parent_id"], newParentID).Error
}

Expand Down
24 changes: 24 additions & 0 deletions nested_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,30 @@ func TestNewNodeItem(t *testing.T) {
assert.Equal(t, "item_id = ? AND left > right AND pid = ?, nodes_count = 1, depth1 = 0", formatSQL(":id = ? AND :lft > :rgt AND :parent_id = ?, :children_count = 1, :depth = 0", node))
}

func TestCreateSource(t *testing.T) {
initData()

c1 := Category{Title: "c1s"}
Create(db, &c1, nil)
assert.Equal(t, c1.Lft, 1)
assert.Equal(t, c1.Rgt, 2)

c2 := Category{Title: "c2s", UserType: "user"}
Create(db, &c2, nil)
assert.Equal(t, c2.Lft, 1)
assert.Equal(t, c2.Rgt, 2)

c3 := Category{Title: "c3s", UserType: "user"}
Create(db, &c3, nil)
assert.Equal(t, c3.Lft, 3)
assert.Equal(t, c3.Rgt, 4)

c4 := Category{Title: "c4s", UserType: "user"}
Create(db, &c4, &c2)
assert.Equal(t, c4.Lft, 2)
assert.Equal(t, c4.Rgt, 3)
}

func TestMoveToRight(t *testing.T) {
// case 1
initData()
Expand Down

0 comments on commit cba47c2

Please sign in to comment.