diff --git a/build.go b/build.go index 10a5440..0a002ad 100644 --- a/build.go +++ b/build.go @@ -521,6 +521,9 @@ func (b *builder) processNode(root node) (q query, err error) { return } q = &groupQuery{Input: q} + // fix https://github.com/antchfx/xpath/issues/76 + q = &cacheQuery{Input: q} + b.firstInput = q } return } diff --git a/func.go b/func.go index afe5988..1c1c93b 100644 --- a/func.go +++ b/func.go @@ -53,6 +53,14 @@ func positionFunc(q query, t iterator) interface{} { // lastFunc is a XPath Node Set functions last(). func lastFunc(q query, t iterator) interface{} { + // + type Counter interface { + count() int + } + if p, ok := q.(Counter); ok { + return float64(p.count()) + } + var ( count = 0 node = t.Current().Copy() @@ -158,7 +166,8 @@ func nameFunc(arg query) func(query, iterator) interface{} { if arg == nil { v = t.Current() } else { - v = arg.Clone().Select(t) + arg.Reset() + v = arg.Select(t) if v == nil { return "" } @@ -178,7 +187,8 @@ func localNameFunc(arg query) func(query, iterator) interface{} { if arg == nil { v = t.Current() } else { - v = arg.Clone().Select(t) + arg.Reset() + v = arg.Select(t) if v == nil { return "" } @@ -195,7 +205,8 @@ func namespaceFunc(arg query) func(query, iterator) interface{} { v = t.Current() } else { // Get the first node in the node-set if specified. - v = arg.Clone().Select(t) + arg.Reset() + v = arg.Select(t) if v == nil { return "" } @@ -592,7 +603,8 @@ func functionArgs(q query) query { if _, ok := q.(*functionQuery); ok { return q } - return q.Clone() + q.Reset() + return q } func reverseFunc(q query, t iterator) func() NodeNavigator { diff --git a/func_test.go b/func_test.go index 2ee13fe..1e05a11 100644 --- a/func_test.go +++ b/func_test.go @@ -8,8 +8,7 @@ func (t testQuery) Select(_ iterator) NodeNavigator { panic("implement me") } -func (t testQuery) Clone() query { - return t +func (t testQuery) Reset() { } func (t testQuery) Evaluate(_ iterator) interface{} { diff --git a/query.go b/query.go index 6ceb033..08509ea 100644 --- a/query.go +++ b/query.go @@ -19,7 +19,7 @@ type query interface { // Evaluate evaluates query and returns values of the current query. Evaluate(iterator) interface{} - Clone() query + Reset() } // nopQuery is an empty query that always return nil for any query. @@ -31,7 +31,7 @@ func (nopQuery) Select(iterator) NodeNavigator { return nil } func (nopQuery) Evaluate(iterator) interface{} { return nil } -func (nopQuery) Clone() query { return nopQuery{} } +func (nopQuery) Reset() {} // contextQuery is returns current node on the iterator object query. type contextQuery struct { @@ -55,8 +55,8 @@ func (c *contextQuery) Evaluate(iterator) interface{} { return c } -func (c *contextQuery) Clone() query { - return &contextQuery{count: 0, Root: c.Root} +func (c *contextQuery) Reset() { + c.count = 0 } // ancestorQuery is an XPath ancestor node query.(ancestor::*|ancestor-self::*) @@ -111,8 +111,9 @@ func (a *ancestorQuery) Test(n NodeNavigator) bool { return a.Predicate(n) } -func (a *ancestorQuery) Clone() query { - return &ancestorQuery{Self: a.Self, Input: a.Input.Clone(), Predicate: a.Predicate} +func (a *ancestorQuery) Reset() { + a.iterator = nil + a.Input.Reset() } // attributeQuery is an XPath attribute node query.(@*) @@ -161,8 +162,9 @@ func (a *attributeQuery) Test(n NodeNavigator) bool { return a.Predicate(n) } -func (a *attributeQuery) Clone() query { - return &attributeQuery{Input: a.Input.Clone(), Predicate: a.Predicate} +func (a *attributeQuery) Reset() { + a.Input.Reset() + a.iterator = nil } // childQuery is an XPath child node query.(child::*) @@ -215,8 +217,10 @@ func (c *childQuery) Test(n NodeNavigator) bool { return c.Predicate(n) } -func (c *childQuery) Clone() query { - return &childQuery{Input: c.Input.Clone(), Predicate: c.Predicate} +func (c *childQuery) Reset() { + c.posit = 0 + c.iterator = nil + c.Input.Reset() } // position returns a position of current NodeNavigator. @@ -308,8 +312,11 @@ func (d *descendantQuery) depth() int { return d.level } -func (d *descendantQuery) Clone() query { - return &descendantQuery{Self: d.Self, Input: d.Input.Clone(), Predicate: d.Predicate} +func (d *descendantQuery) Reset() { + d.posit = 0 + d.iterator = nil + d.level = 0 + d.Input.Reset() } // followingQuery is an XPath following node query.(following::*|following-sibling::*) @@ -386,8 +393,10 @@ func (f *followingQuery) Test(n NodeNavigator) bool { return f.Predicate(n) } -func (f *followingQuery) Clone() query { - return &followingQuery{Input: f.Input.Clone(), Sibling: f.Sibling, Predicate: f.Predicate} +func (f *followingQuery) Reset() { + f.posit = 0 + f.iterator = nil + f.Input.Reset() } func (f *followingQuery) position() int { @@ -467,8 +476,10 @@ func (p *precedingQuery) Test(n NodeNavigator) bool { return p.Predicate(n) } -func (p *precedingQuery) Clone() query { - return &precedingQuery{Input: p.Input.Clone(), Sibling: p.Sibling, Predicate: p.Predicate} +func (p *precedingQuery) Reset() { + p.posit = 0 + p.iterator = nil + p.Input.Reset() } func (p *precedingQuery) position() int { @@ -499,8 +510,8 @@ func (p *parentQuery) Evaluate(t iterator) interface{} { return p } -func (p *parentQuery) Clone() query { - return &parentQuery{Input: p.Input.Clone(), Predicate: p.Predicate} +func (p *parentQuery) Reset() { + p.Input.Reset() } func (p *parentQuery) Test(n NodeNavigator) bool { @@ -535,8 +546,8 @@ func (s *selfQuery) Test(n NodeNavigator) bool { return s.Predicate(n) } -func (s *selfQuery) Clone() query { - return &selfQuery{Input: s.Input.Clone(), Predicate: s.Predicate} +func (s *selfQuery) Reset() { + s.Input.Reset() } // filterQuery is an XPath query for predicate filter. @@ -558,8 +569,8 @@ func (f *filterQuery) do(t iterator) bool { pt := getNodePosition(f.Input) return int(val.Float()) == pt default: - if q, ok := f.Predicate.(query); ok { - return q.Select(t) != nil + if f.Predicate != nil { + return f.Predicate.Select(t) != nil } } return false @@ -577,7 +588,7 @@ func (f *filterQuery) Select(t iterator) NodeNavigator { node := f.Input.Select(t) if node == nil { - return node + return nil } node = node.Copy() @@ -598,8 +609,11 @@ func (f *filterQuery) Evaluate(t iterator) interface{} { return f } -func (f *filterQuery) Clone() query { - return &filterQuery{Input: f.Input.Clone(), Predicate: f.Predicate.Clone()} +func (f *filterQuery) Reset() { + f.posit = 0 + f.positmap = nil + f.Input.Reset() + f.Predicate.Reset() } // functionQuery is an XPath function that returns a computed value for @@ -620,8 +634,8 @@ func (f *functionQuery) Evaluate(t iterator) interface{} { return f.Func(f.Input, t) } -func (f *functionQuery) Clone() query { - return &functionQuery{Input: f.Input.Clone(), Func: f.Func} +func (f *functionQuery) Reset() { + f.Input.Reset() } // transformFunctionQuery diffs from functionQuery where the latter computes a scalar @@ -648,8 +662,9 @@ func (f *transformFunctionQuery) Evaluate(t iterator) interface{} { return f } -func (f *transformFunctionQuery) Clone() query { - return &transformFunctionQuery{Input: f.Input.Clone(), Func: f.Func} +func (f *transformFunctionQuery) Reset() { + f.Input.Reset() + f.iterator = nil } // constantQuery is an XPath constant operand. @@ -665,8 +680,7 @@ func (c *constantQuery) Evaluate(t iterator) interface{} { return c.Val } -func (c *constantQuery) Clone() query { - return c +func (c *constantQuery) Reset() { } type groupQuery struct { @@ -676,22 +690,21 @@ type groupQuery struct { } func (g *groupQuery) Select(t iterator) NodeNavigator { - for { - node := g.Input.Select(t) - if node == nil { - return nil - } - g.posit++ - return node.Copy() + node := g.Input.Select(t) + if node == nil { + return nil } + g.posit++ + return node } func (g *groupQuery) Evaluate(t iterator) interface{} { return g.Input.Evaluate(t) } -func (g *groupQuery) Clone() query { - return &groupQuery{Input: g.Input} +func (g *groupQuery) Reset() { + g.posit = 0 + g.Input.Reset() } func (g *groupQuery) position() int { @@ -724,8 +737,9 @@ func (l *logicalQuery) Evaluate(t iterator) interface{} { return l.Do(t, m, n) } -func (l *logicalQuery) Clone() query { - return &logicalQuery{Left: l.Left.Clone(), Right: l.Right.Clone(), Do: l.Do} +func (l *logicalQuery) Reset() { + l.Left.Reset() + l.Right.Reset() } // numericQuery is an XPath numeric operator expression. @@ -745,8 +759,9 @@ func (n *numericQuery) Evaluate(t iterator) interface{} { return n.Do(m, k) } -func (n *numericQuery) Clone() query { - return &numericQuery{Left: n.Left.Clone(), Right: n.Right.Clone(), Do: n.Do} +func (n *numericQuery) Reset() { + n.Left.Reset() + n.Right.Reset() } type booleanQuery struct { @@ -835,8 +850,10 @@ func (b *booleanQuery) Evaluate(t iterator) interface{} { return asBool(t, m) } -func (b *booleanQuery) Clone() query { - return &booleanQuery{IsOr: b.IsOr, Left: b.Left.Clone(), Right: b.Right.Clone()} +func (b *booleanQuery) Reset() { + b.iterator = nil + b.Left.Reset() + b.Right.Reset() } type unionQuery struct { @@ -892,8 +909,60 @@ func (u *unionQuery) Evaluate(t iterator) interface{} { return u } -func (u *unionQuery) Clone() query { - return &unionQuery{Left: u.Left.Clone(), Right: u.Right.Clone()} +func (u *unionQuery) Reset() { + u.Left.Reset() + u.Right.Reset() + u.iterator = nil +} + +type cacheQuery struct { + posit int + buffer []NodeNavigator + iterator func() NodeNavigator + + Input query +} + +func (c *cacheQuery) Select(t iterator) NodeNavigator { + if c.iterator == nil { + for { + node := c.Input.Select(t) + if node == nil { + break + } + c.buffer = append(c.buffer, node.Copy()) + } + c.iterator = func() NodeNavigator { + if c.posit >= len(c.buffer) { + return nil + } + node := c.buffer[c.posit] + c.posit++ + return node + } + } + return c.iterator() +} + +func (c *cacheQuery) Evaluate(t iterator) interface{} { + c.posit = 0 + c.buffer = nil + return c.Input.Evaluate(t) +} + +func (c *cacheQuery) Reset() { + c.buffer = nil + c.posit = 0 + c.iterator = nil + c.Input.Reset() +} + +func (c *cacheQuery) position() int { + return c.posit +} + +func (c *cacheQuery) count() int { + return len(c.buffer) } func getHashCode(n NodeNavigator) uint64 { diff --git a/xpath.go b/xpath.go index 5f6aa89..3ef6f49 100644 --- a/xpath.go +++ b/xpath.go @@ -121,14 +121,16 @@ func (expr *Expr) Evaluate(root NodeNavigator) interface{} { val := expr.q.Evaluate(iteratorFunc(func() NodeNavigator { return root })) switch val.(type) { case query: - return &NodeIterator{query: expr.q.Clone(), node: root} + expr.q.Reset() + return &NodeIterator{query: expr.q, node: root} } return val } // Select selects a node set using the specified XPath expression. func (expr *Expr) Select(root NodeNavigator) *NodeIterator { - return &NodeIterator{query: expr.q.Clone(), node: root} + expr.q.Reset() + return &NodeIterator{query: expr.q, node: root} } // String returns XPath expression string. diff --git a/xpath_test.go b/xpath_test.go index c997753..5e13b4b 100644 --- a/xpath_test.go +++ b/xpath_test.go @@ -46,6 +46,27 @@ func TestMustCompile(t *testing.T) { } } +func TestSubQuery(t *testing.T) { + testXPath2(t, html, "(//li)", 4) + testXPath4(t, html, "(//li)[2]", `about`) + testXPath4(t, html, "(//li)[last()]", ``) + testXPath4(t, html, "(//li/a[@id])[last()]", `login`) + testXPath4(t, html, "(//li/a[@id])[last()]/@id", `login`) // This test case shoud fix. Skip. + // test cached + expr := MustCompile("(//li/a)[last()]") + for i := 0; i < 10; i++ { + iter := expr.Select(createNavigator(html)) + if iter.MoveNext() { + node := iter.Current().(*TNodeNavigator) + if e, g := "login", node.Value(); e != g { + t.Fatalf("expected %s, but got %s", e, g) + } + } else { + t.Fatal("expected one but got nil.") + } + } +} + func TestSelf(t *testing.T) { testXPath(t, html, ".", "html") testXPath(t, html.FirstChild, ".", "head") @@ -403,6 +424,16 @@ func testXPath3(t *testing.T, root *TNode, expr string, expected *TNode) { } } +func testXPath4(t *testing.T, root *TNode, expr string, expected string) { + node := selectNode(root, expr) + if node == nil { + t.Fatalf("`%s` returns node is nil", expr) + } + if got := node.Value(); got != expected { + t.Fatalf("`%s` expected \n%s,but got\n%s", expr, expected, got) + } +} + func iterateNavs(t *NodeIterator) []*TNodeNavigator { var nodes []*TNodeNavigator for t.MoveNext() {