Skip to content

Commit 8a44955

Browse files
author
AJ Roetker
committed
(bug) Nested bucket aggregations and Clone pattern
Fixes a bug in nested bucket aggregations where metric values were duplicated because the same field was registered more than once in SubAggregationFields(). Also fixes the StartDoc/EndDoc lifecycle for bucket sub-aggregations and the min/max comparison logic in optimized aggregations. Adds a Clone() method to the AggregationBuilder interface for proper deep copying of nested aggregation hierarchies. Adopts a setter pattern for aggregation filters (SetPrefixFilter, SetRegexFilter).
1 parent 392fc7a commit 8a44955

File tree

7 files changed

+355
-70
lines changed

7 files changed

+355
-70
lines changed

bucket_aggregation_test.go

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,159 @@ func ExampleAggregationsRequest_filteredTerms() {
243243

244244
searchRequest.Aggregations["product_codes"] = productCodes
245245
}
246+
247+
// TestNestedBucketAggregations checks that a terms aggregation nested inside
// another terms aggregation yields the expected buckets and metrics.
//
// It indexes nine documents (region, category, price), runs a match-all search
// with the aggregation tree region -> category -> {avg(price), sum(price)},
// and then verifies:
//   - three region buckets are returned,
//   - the "us" bucket has count 4 and carries a "by_category" sub-aggregation,
//   - that sub-aggregation has two category buckets,
//   - the "electronics" bucket has count 2, avg_price ~899 and total_revenue 1798.
248+
func TestNestedBucketAggregations(t *testing.T) {
249+
tmpIndexPath := createTmpIndexPath(t)
250+
defer cleanupTmpIndexPath(t, tmpIndexPath)
251+
252+
indexMapping := NewIndexMapping()
253+
index, err := New(tmpIndexPath, indexMapping)
254+
if err != nil {
255+
t.Fatal(err)
256+
}
// Close the index when the test returns; a close error fails the test.
257+
defer func() {
258+
err := index.Close()
259+
if err != nil {
260+
t.Fatal(err)
261+
}
262+
}()
263+
264+
// Fixture: documents with region, category, and price.
// Per region: US has 4 docs, EU has 3, APAC has 2.
265+
docs := []struct {
266+
ID string
267+
Region string
268+
Category string
269+
Price float64
270+
}{
271+
{"doc1", "US", "Electronics", 999.00},
272+
{"doc2", "US", "Electronics", 799.00},
273+
{"doc3", "US", "Books", 29.99},
274+
{"doc4", "US", "Books", 19.99},
275+
{"doc5", "EU", "Electronics", 899.00},
276+
{"doc6", "EU", "Electronics", 699.00},
277+
{"doc7", "EU", "Books", 24.99},
278+
{"doc8", "APAC", "Electronics", 1099.00},
279+
{"doc9", "APAC", "Books", 34.99},
280+
}
281+
282+
// Index all fixture documents in a single batch.
batch := index.NewBatch()
283+
for _, doc := range docs {
284+
data := map[string]interface{}{
285+
"region": doc.Region,
286+
"category": doc.Category,
287+
"price": doc.Price,
288+
}
289+
err := batch.Index(doc.ID, data)
290+
if err != nil {
291+
t.Fatal(err)
292+
}
293+
}
294+
err = index.Batch(batch)
295+
if err != nil {
296+
t.Fatal(err)
297+
}
298+
299+
// Test nested bucket aggregation: group by region, then by category within each region.
300+
query := NewMatchAllQuery()
301+
searchRequest := NewSearchRequest(query)
302+
303+
// Build the aggregation tree bottom-up: region -> category -> avg/sum of price.
304+
byCategory := NewTermsAggregation("category", 10)
305+
byCategory.AddSubAggregation("avg_price", NewAggregationRequest("avg", "price"))
306+
byCategory.AddSubAggregation("total_revenue", NewAggregationRequest("sum", "price"))
307+
308+
byRegion := NewTermsAggregation("region", 10)
309+
byRegion.AddSubAggregation("by_category", byCategory)
310+
311+
searchRequest.Aggregations = AggregationsRequest{
312+
"by_region": byRegion,
313+
}
314+
searchRequest.Size = 0 // Don't need hits
315+
316+
results, err := index.Search(searchRequest)
317+
if err != nil {
318+
t.Fatal(err)
319+
}
320+
321+
regionAgg, ok := results.Aggregations["by_region"]
322+
if !ok {
323+
t.Fatal("Expected by_region aggregation")
324+
}
325+
326+
// The fixture contains three distinct regions: US, EU, APAC.
if len(regionAgg.Buckets) != 3 {
327+
t.Fatalf("Expected 3 region buckets, got %d", len(regionAgg.Buckets))
328+
}
329+
330+
// Find US region bucket
331+
var usBucket *search.Bucket
332+
for _, bucket := range regionAgg.Buckets {
333+
if bucket.Key == "us" { // lowercase due to text analysis
334+
usBucket = bucket
335+
break
336+
}
337+
}
338+
339+
if usBucket == nil {
340+
t.Fatal("US region bucket not found")
341+
}
342+
343+
// doc1-doc4 are the US documents.
if usBucket.Count != 4 {
344+
t.Fatalf("Expected US count 4, got %d", usBucket.Count)
345+
}
346+
347+
// Check nested category aggregation within US region
348+
if usBucket.Aggregations == nil {
349+
t.Fatal("Expected sub-aggregations in US bucket")
350+
}
351+
352+
categoryAgg, ok := usBucket.Aggregations["by_category"]
353+
if !ok {
354+
t.Fatal("Expected by_category sub-aggregation in US bucket")
355+
}
356+
357+
// US documents span two categories: Electronics and Books.
if len(categoryAgg.Buckets) != 2 {
358+
t.Fatalf("Expected 2 category buckets in US region, got %d", len(categoryAgg.Buckets))
359+
}
360+
361+
// Find Electronics category in US region
362+
var electronicsCategory *search.Bucket
363+
for _, bucket := range categoryAgg.Buckets {
364+
if bucket.Key == "electronics" {
365+
electronicsCategory = bucket
366+
break
367+
}
368+
}
369+
370+
if electronicsCategory == nil {
371+
t.Fatal("Electronics category not found in US region")
372+
}
373+
374+
// doc1 and doc2 are the US electronics documents.
if electronicsCategory.Count != 2 {
375+
t.Fatalf("Expected 2 electronics items in US, got %d", electronicsCategory.Count)
376+
}
377+
378+
// Check metric sub-aggregations within category
379+
avgPrice := electronicsCategory.Aggregations["avg_price"]
380+
if avgPrice == nil {
381+
t.Fatal("Expected avg_price in electronics category")
382+
}
383+
384+
expectedAvg := 899.0 // (999 + 799) / 2
// NOTE(review): unchecked type assertion; this panics instead of failing the
// test cleanly if Value is not a float64.
385+
actualAvg := avgPrice.Value.(float64)
// Accept any average within +/-1 of the expected value.
386+
if actualAvg < expectedAvg-1 || actualAvg > expectedAvg+1 {
387+
t.Fatalf("Expected US electronics avg price around %f, got %f (note: if sum is doubled, count must also be doubled to get correct avg)", expectedAvg, actualAvg)
388+
}
389+
390+
totalRevenue := electronicsCategory.Aggregations["total_revenue"]
391+
if totalRevenue == nil {
392+
t.Fatal("Expected total_revenue in electronics category")
393+
}
394+
395+
// Verify total revenue. Unlike the average, the sum is compared exactly, so a
// doubled sum (the duplication bug this commit describes) fails here.
396+
expectedTotal := 1798.0 // 999 + 799
// NOTE(review): unchecked type assertion, same caveat as for avg_price above.
397+
actualTotal := totalRevenue.Value.(float64)
398+
if actualTotal != expectedTotal {
399+
t.Fatalf("Expected US electronics total %f, got %f", expectedTotal, actualTotal)
400+
}
401+
}

index_impl.go

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"io"
2121
"os"
2222
"path/filepath"
23+
"regexp"
2324
"strconv"
2425
"sync"
2526
"sync/atomic"
@@ -643,16 +644,26 @@ func buildAggregation(aggRequest *AggregationRequest) (search.AggregationBuilder
643644
if aggRequest.Size != nil {
644645
size = *aggRequest.Size
645646
}
646-
termsAgg, err := aggregation.NewTermsAggregation(
647+
termsAgg := aggregation.NewTermsAggregation(
647648
aggRequest.Field,
648649
size,
649-
aggRequest.TermPrefix,
650-
aggRequest.TermPattern,
651650
subAggBuilders,
652651
)
653-
if err != nil {
654-
return nil, fmt.Errorf("error creating terms aggregation: %w", err)
652+
653+
// Set prefix filter if provided
654+
if aggRequest.TermPrefix != "" {
655+
termsAgg.SetPrefixFilter(aggRequest.TermPrefix)
655656
}
657+
658+
// Compile and set regex filter if provided
659+
if aggRequest.TermPattern != "" {
660+
regex, err := regexp.Compile(aggRequest.TermPattern)
661+
if err != nil {
662+
return nil, fmt.Errorf("error compiling regex pattern for aggregation: %v", err)
663+
}
664+
termsAgg.SetRegexFilter(regex)
665+
}
666+
656667
return termsAgg, nil
657668

658669
case "range":

search.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package bleve
1717
import (
1818
"fmt"
1919
"reflect"
20+
"regexp"
2021
"sort"
2122
"strconv"
2223
"time"
@@ -331,6 +332,16 @@ func (ar *AggregationRequest) AddSubAggregation(name string, subAgg *Aggregation
331332
ar.Aggregations[name] = subAgg
332333
}
333334

335+
// SetPrefixFilter sets the prefix filter for terms aggregations by storing
// prefix in ar.TermPrefix. The value is only recorded on the request; no
// validation happens here. An empty prefix leaves the aggregation unfiltered,
// since buildAggregation applies the filter only when TermPrefix is non-empty.
336+
func (ar *AggregationRequest) SetPrefixFilter(prefix string) {
337+
ar.TermPrefix = prefix
338+
}
339+
340+
// SetRegexFilter sets the regex pattern filter for terms aggregations by
// storing pattern in ar.TermPattern. The pattern is kept as an uncompiled
// string and is not checked here; Validate rejects a non-empty pattern that
// regexp.Compile cannot parse.
341+
func (ar *AggregationRequest) SetRegexFilter(pattern string) {
342+
ar.TermPattern = pattern
343+
}
344+
334345
// Validate validates the aggregation request
335346
func (ar *AggregationRequest) Validate() error {
336347
validTypes := map[string]bool{
@@ -347,6 +358,14 @@ func (ar *AggregationRequest) Validate() error {
347358
return fmt.Errorf("aggregation field cannot be empty")
348359
}
349360

361+
// Validate regex pattern if provided
362+
if ar.TermPattern != "" {
363+
_, err := regexp.Compile(ar.TermPattern)
364+
if err != nil {
365+
return fmt.Errorf("invalid term pattern: %v", err)
366+
}
367+
}
368+
350369
// Validate bucket-specific configuration
351370
if ar.Type == "terms" {
352371
if ar.Size != nil && *ar.Size < 0 {

0 commit comments

Comments
 (0)