Skip to content

Commit

Permalink
rename some variables, replace schema loading by schema
Browse files Browse the repository at this point in the history
  • Loading branch information
calebx committed Dec 27, 2020
1 parent 442690d commit 9189b56
Showing 1 changed file with 29 additions and 40 deletions.
69 changes: 29 additions & 40 deletions nested_set.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package nestedset

import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"sync"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

// MoveDirection means where the node is going to be located
Expand All @@ -25,7 +25,7 @@ const (
MoveDirectionInner MoveDirection = 0
)

type nodeItem struct {
type nestedItem struct {
ID int64
ParentID sql.NullInt64
Depth int
Expand All @@ -36,31 +36,25 @@ type nodeItem struct {
DbNames map[string]string
}

// parseNode parse a gorm structure into an internal source structure
// for bring in all required data attribute like scope, left, righ etc.
func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nodeItem, err error) {
tx = db
stmt := &gorm.Statement{
DB: tx,
ConnPool: tx.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}

err = stmt.Parse(source)
// parseNode parse a gorm struct into an internal nested item struct
// bring in all required data attribute like scope, left, righ etc.
func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nestedItem, err error) {
scm, err := schema.Parse(source, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
err = fmt.Errorf("Invalid source, must be a valid Gorm Model instance, %v", source)
return
}

item = nodeItem{TableName: stmt.Table, DbNames: map[string]string{}}
tx = db.Table(scm.Table)

item = nestedItem{TableName: scm.Table, DbNames: map[string]string{}}
sourceValue := reflect.Indirect(reflect.ValueOf(source))
sourceType := sourceValue.Type()
for i := 0; i < sourceType.NumField(); i++ {
t := sourceType.Field(i)
v := sourceValue.Field(i)

schemaField := stmt.Schema.LookUpField(t.Name)
schemaField := scm.LookUpField(t.Name)
dbName := schemaField.DBName

switch t.Tag.Get("nestedset") {
Expand Down Expand Up @@ -98,7 +92,7 @@ func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nodeItem, err
return
}

// Create a new node by parent with Gorm original Create()
// Create a new node within its parent by Gorm original Create() method
// ```nestedset.Create(db, &Category{...}, nil)``` will create a new category in root level
// ```nestedset.Create(db, &Category{...}, &parent)``` will create a new category under parent node as its last child
func Create(db *gorm.DB, source, parent interface{}) error {
Expand All @@ -108,17 +102,16 @@ func Create(db *gorm.DB, source, parent interface{}) error {
}

// for totally blank table / scope default init root would be [1 - 2]
setDepth, setToLft, setToRgt := 0, 1, 2
tableName, dbNames := target.TableName, target.DbNames
var setToDepth, setToLft, setToRgt = 0, 1, 2
dbNames := target.DbNames

return tx.Transaction(func(tx *gorm.DB) (err error) {
// put node into root level when parent is nil
// create node in root level when parent is nil
if parent == nil {
lastNode := make(map[string]interface{})
orderSQL := formatSQL(":rgt desc", target)
rst := tx.Model(source).Select(dbNames["rgt"]).Order(orderSQL).First(&lastNode)
rst := tx.Select(dbNames["rgt"]).Order(formatSQL(":rgt desc", target)).Take(&lastNode)
if rst.Error == nil {
setToLft = lastNode[dbNames["rgt"]].(int) + 1
setToLft = int(lastNode[dbNames["rgt"]].(int64) + 1)
setToRgt = setToLft + 1
}
} else {
Expand All @@ -129,22 +122,18 @@ func Create(db *gorm.DB, source, parent interface{}) error {

setToLft = targetParent.Rgt
setToRgt = targetParent.Rgt + 1
setDepth = targetParent.Depth + 1
setToDepth = targetParent.Depth + 1

// UPDATE tree SET rgt = rgt + 2 WHERE rgt >= new_lft;
err = tx.Table(tableName).
Where(formatSQL(":rgt >= ?", target), setToLft).
UpdateColumn(dbNames["rgt"], gorm.Expr(formatSQL(":rgt + 2", target))).
Error
err = tx.Where(formatSQL(":rgt >= ?", target), setToLft).
UpdateColumn(dbNames["rgt"], gorm.Expr(formatSQL(":rgt + 2", target))).Error
if err != nil {
return err
}

// UPDATE tree SET lft = lft + 2 WHERE lft > new_lft;
err = tx.Table(tableName).
Where(formatSQL(":lft > ?", target), setToLft).
UpdateColumn(dbNames["lft"], gorm.Expr(formatSQL(":lft + 2", target))).
Error
err = tx.Where(formatSQL(":lft > ?", target), setToLft).
UpdateColumn(dbNames["lft"], gorm.Expr(formatSQL(":lft + 2", target))).Error
if err != nil {
return err
}
Expand All @@ -158,7 +147,7 @@ func Create(db *gorm.DB, source, parent interface{}) error {
}
}

// Set Lft, Rgt, Depth dynamically by refect
// Set Lft, Rgt, Depth dynamically
v := reflect.Indirect(reflect.ValueOf(source))
t := v.Type()
for i := 0; i < t.NumField(); i++ {
Expand All @@ -174,7 +163,7 @@ func Create(db *gorm.DB, source, parent interface{}) error {
break
case "depth":
f := v.FieldByName(f.Name)
f.SetInt(int64(setDepth))
f.SetInt(int64(setToDepth))
break
}
}
Expand Down Expand Up @@ -265,7 +254,7 @@ func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
return moveToRightOfPosition(tx, targetNode, right, depthChange, newParentID)
}

func moveToRightOfPosition(tx *gorm.DB, targetNode nodeItem, position, depthChange int, newParentID sql.NullInt64) error {
func moveToRightOfPosition(tx *gorm.DB, targetNode nestedItem, position, depthChange int, newParentID sql.NullInt64) error {
return tx.Transaction(func(tx *gorm.DB) (err error) {
oldParentID := targetNode.ParentID
targetRight := targetNode.Rgt
Expand Down Expand Up @@ -308,7 +297,7 @@ func moveToRightOfPosition(tx *gorm.DB, targetNode nodeItem, position, depthChan
})
}

func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentID sql.NullInt64) (err error) {
func syncChildrenCount(tx *gorm.DB, targetNode nestedItem, oldParentID, newParentID sql.NullInt64) (err error) {
var oldParentCount, newParentCount int64

if oldParentID.Valid {
Expand Down Expand Up @@ -336,7 +325,7 @@ func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentI
return nil
}

func moveTarget(tx *gorm.DB, targetNode nodeItem, targetID int64, targetIds []int64, step, depthChange int, newParentID sql.NullInt64) (err error) {
func moveTarget(tx *gorm.DB, targetNode nestedItem, targetID int64, targetIds []int64, step, depthChange int, newParentID sql.NullInt64) (err error) {
dbNames := targetNode.DbNames

if len(targetIds) > 0 {
Expand All @@ -354,7 +343,7 @@ func moveTarget(tx *gorm.DB, targetNode nodeItem, targetID int64, targetIds []in
return tx.Where(formatSQL(":id = ?", targetNode), targetID).Update(dbNames["parent_id"], newParentID).Error
}

func moveAffected(tx *gorm.DB, targetNode nodeItem, gte, lte, step int) (err error) {
func moveAffected(tx *gorm.DB, targetNode nestedItem, gte, lte, step int) (err error) {
dbNames := targetNode.DbNames

return tx.Where(formatSQL("(:lft BETWEEN ? AND ?) OR (:rgt BETWEEN ? AND ?)", targetNode), gte, lte, gte, lte).
Expand All @@ -364,7 +353,7 @@ func moveAffected(tx *gorm.DB, targetNode nodeItem, gte, lte, step int) (err err
}).Error
}

func formatSQL(placeHolderSQL string, node nodeItem) (out string) {
func formatSQL(placeHolderSQL string, node nestedItem) (out string) {
out = placeHolderSQL

out = strings.ReplaceAll(out, ":table_name", node.TableName)
Expand Down

0 comments on commit 9189b56

Please sign in to comment.