Skip to content

Commit

Permalink
#96, allows node-set numeric operator on +, -, *, MOD(), DIV()
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengchun committed Apr 9, 2024
1 parent 5116a24 commit 4b4638b
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 19 deletions.
2 changes: 1 addition & 1 deletion build.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ func (b *builder) processOperator(root *operatorNode, props *builderProp) (query
var qyOutput query
switch root.Op {
case "+", "-", "*", "div", "mod": // Numeric operator
var exprFunc func(interface{}, interface{}) interface{}
var exprFunc func(iterator, interface{}, interface{}) interface{}
switch root.Op {
case "+":
exprFunc = plusFunc
Expand Down
30 changes: 14 additions & 16 deletions operator.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package xpath

import (
"reflect"
"strconv"
)

Expand Down Expand Up @@ -247,44 +246,43 @@ var orFunc = func(t iterator, m, n interface{}) interface{} {
return logicalFuncs[t1][t2](t, "or", m, n)
}

func numericExpr(m, n interface{}, cb func(float64, float64) float64) float64 {
typ := reflect.TypeOf(float64(0))
a := reflect.ValueOf(m).Convert(typ)
b := reflect.ValueOf(n).Convert(typ)
return cb(a.Float(), b.Float())
func numericExpr(t iterator, m, n interface{}, cb func(float64, float64) float64) float64 {
a := asNumber(t, m)
b := asNumber(t, n)
return cb(a, b)
}

// plusFunc is an `+` operator.
var plusFunc = func(m, n interface{}) interface{} {
return numericExpr(m, n, func(a, b float64) float64 {
var plusFunc = func(t iterator, m, n interface{}) interface{} {
return numericExpr(t, m, n, func(a, b float64) float64 {
return a + b
})
}

// minusFunc is an `-` operator.
var minusFunc = func(m, n interface{}) interface{} {
return numericExpr(m, n, func(a, b float64) float64 {
var minusFunc = func(t iterator, m, n interface{}) interface{} {
return numericExpr(t, m, n, func(a, b float64) float64 {
return a - b
})
}

// mulFunc is an `*` operator.
var mulFunc = func(m, n interface{}) interface{} {
return numericExpr(m, n, func(a, b float64) float64 {
var mulFunc = func(t iterator, m, n interface{}) interface{} {
return numericExpr(t, m, n, func(a, b float64) float64 {
return a * b
})
}

// divFunc is an `DIV` operator.
var divFunc = func(m, n interface{}) interface{} {
return numericExpr(m, n, func(a, b float64) float64 {
var divFunc = func(t iterator, m, n interface{}) interface{} {
return numericExpr(t, m, n, func(a, b float64) float64 {
return a / b
})
}

// modFunc is an 'MOD' operator.
var modFunc = func(m, n interface{}) interface{} {
return numericExpr(m, n, func(a, b float64) float64 {
var modFunc = func(t iterator, m, n interface{}) interface{} {
return numericExpr(t, m, n, func(a, b float64) float64 {
return float64(int(a) % int(b))
})
}
4 changes: 2 additions & 2 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ func (l *logicalQuery) Properties() queryProp {
type numericQuery struct {
Left, Right query

Do func(interface{}, interface{}) interface{}
Do func(iterator, interface{}, interface{}) interface{}
}

func (n *numericQuery) Select(t iterator) NodeNavigator {
Expand All @@ -1009,7 +1009,7 @@ func (n *numericQuery) Select(t iterator) NodeNavigator {
func (n *numericQuery) Evaluate(t iterator) interface{} {
m := n.Left.Evaluate(t)
k := n.Right.Evaluate(t)
return n.Do(m, k)
return n.Do(t, m, k)
}

func (n *numericQuery) Clone() query {
Expand Down
23 changes: 23 additions & 0 deletions xpath_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package xpath
import (
"bytes"
"fmt"
"math"
"sort"
"strings"
"testing"
Expand Down Expand Up @@ -235,6 +236,28 @@ func TestMustCompile(t *testing.T) {
}
}

func Test_plusFunc(t *testing.T) {
// 1+1
assertEqual(t, float64(2), plusFunc(nil, float64(1), float64(1)))
// string +
assertEqual(t, float64(2), plusFunc(nil, "1", "1"))
// invalid string
v := plusFunc(nil, "a", 1)
assertTrue(t, math.IsNaN(v.(float64)))
// Nodeset
// TODO
}

func Test_minusFunc(t *testing.T) {
// 1 - 1
assertEqual(t, float64(0), minusFunc(nil, float64(1), float64(1)))
// string
assertEqual(t, float64(0), minusFunc(nil, "1", "1"))
// invalid string
v := minusFunc(nil, "a", 1)
assertTrue(t, math.IsNaN(v.(float64)))
}

func TestNodeType(t *testing.T) {
tests := []struct {
expr string
Expand Down

0 comments on commit 4b4638b

Please sign in to comment.