1616# under the License.
1717
1818import collections .abc
19+ from copy import deepcopy
20+ from typing import (
21+ TYPE_CHECKING ,
22+ Any ,
23+ ClassVar ,
24+ Dict ,
25+ Iterable ,
26+ MutableMapping ,
27+ Optional ,
28+ Union ,
29+ cast ,
30+ )
1931
2032from .response .aggs import AggResponse , BucketData , FieldBucketData , TopHitsData
21- from .utils import DslBase
33+ from .utils import AttrDict , DslBase , JSONType
2234
35+ if TYPE_CHECKING :
36+ from .query import Query
37+ from .search_base import SearchBase
2338
24- def A (name_or_agg , filter = None , ** params ):
39+
40+ def A (
41+ name_or_agg : Union [MutableMapping [str , Any ], "Agg" , str ],
42+ filter : Optional [Union [str , "Query" ]] = None ,
43+ ** params : Any ,
44+ ) -> "Agg" :
2545 if filter is not None :
2646 if name_or_agg != "filter" :
2747 raise ValueError (
@@ -31,11 +51,11 @@ def A(name_or_agg, filter=None, **params):
3151 params ["filter" ] = filter
3252
3353 # {"terms": {"field": "tags"}, "aggs": {...}}
34- if isinstance (name_or_agg , collections .abc .Mapping ):
54+ if isinstance (name_or_agg , collections .abc .MutableMapping ):
3555 if params :
3656 raise ValueError ("A() cannot accept parameters when passing in a dict." )
3757 # copy to avoid modifying in-place
38- agg = name_or_agg . copy ( )
58+ agg = deepcopy ( name_or_agg )
3959 # pop out nested aggs
4060 aggs = agg .pop ("aggs" , None )
4161 # pop out meta data
@@ -70,48 +90,57 @@ def A(name_or_agg, filter=None, **params):
7090class Agg (DslBase ):
7191 _type_name = "agg"
7292 _type_shortcut = staticmethod (A )
73- name = None
93+ name = ""
7494
75- def __contains__ (self , key ) :
95+ def __contains__ (self , key : str ) -> bool :
7696 return False
7797
78- def to_dict (self ):
98+ def to_dict (self ) -> Dict [ str , JSONType ] :
7999 d = super ().to_dict ()
80- if "meta" in d [self .name ]:
81- d ["meta" ] = d [self .name ].pop ("meta" )
100+ if isinstance (d [self .name ], dict ):
101+ n = cast (Dict [str , JSONType ], d [self .name ])
102+ if "meta" in n :
103+ d ["meta" ] = n .pop ("meta" )
82104 return d
83105
84- def result (self , search , data ) :
106+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
85107 return AggResponse (self , search , data )
86108
87109
88110class AggBase :
89- _param_defs = {
111+ aggs : Dict [str , Agg ]
112+ _base : Agg
113+ _params : Dict [str , Any ]
114+ _param_defs : ClassVar [Dict [str , Any ]] = {
90115 "aggs" : {"type" : "agg" , "hash" : True },
91116 }
92117
93- def __contains__ (self , key ) :
118+ def __contains__ (self , key : str ) -> bool :
94119 return key in self ._params .get ("aggs" , {})
95120
96- def __getitem__ (self , agg_name ):
97- agg = self ._params .setdefault ("aggs" , {})[agg_name ] # propagate KeyError
121+ def __getitem__ (self , agg_name : str ) -> Agg :
122+ agg = cast (
123+ Agg , self ._params .setdefault ("aggs" , {})[agg_name ]
124+ ) # propagate KeyError
98125
99126 # make sure we're not mutating a shared state - whenever accessing a
100127 # bucket, return a shallow copy of it to be safe
101128 if isinstance (agg , Bucket ):
102- agg = A (agg .name , ** agg ._params )
129+ agg = A (agg .name , filter = None , ** agg ._params )
103130 # be sure to store the copy so any modifications to it will affect us
104131 self ._params ["aggs" ][agg_name ] = agg
105132
106133 return agg
107134
108- def __setitem__ (self , agg_name , agg ) :
135+ def __setitem__ (self , agg_name : str , agg : Agg ) -> None :
109136 self .aggs [agg_name ] = A (agg )
110137
111- def __iter__ (self ):
138+ def __iter__ (self ) -> Iterable [ str ] :
112139 return iter (self .aggs )
113140
114- def _agg (self , bucket , name , agg_type , * args , ** params ):
141+ def _agg (
142+ self , bucket : bool , name : str , agg_type : str , * args : Any , ** params : Any
143+ ) -> Agg :
115144 agg = self [name ] = A (agg_type , * args , ** params )
116145
117146 # For chaining - when creating new buckets return them...
@@ -121,29 +150,31 @@ def _agg(self, bucket, name, agg_type, *args, **params):
121150 else :
122151 return self ._base
123152
124- def metric (self , name , agg_type , * args , ** params ) :
153+ def metric (self , name : str , agg_type : str , * args : Any , ** params : Any ) -> Agg :
125154 return self ._agg (False , name , agg_type , * args , ** params )
126155
127- def bucket (self , name , agg_type , * args , ** params ) :
156+ def bucket (self , name : str , agg_type : str , * args : Any , ** params : Any ) -> Agg :
128157 return self ._agg (True , name , agg_type , * args , ** params )
129158
130- def pipeline (self , name , agg_type , * args , ** params ) :
159+ def pipeline (self , name : str , agg_type : str , * args : Any , ** params : Any ) -> Agg :
131160 return self ._agg (False , name , agg_type , * args , ** params )
132161
133- def result (self , search , data ) :
162+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
134163 return BucketData (self , search , data )
135164
136165
137166class Bucket (AggBase , Agg ):
138- def __init__ (self , ** params ):
167+ def __init__ (self , ** params : Any ):
139168 super ().__init__ (** params )
140169 # remember self for chaining
141170 self ._base = self
142171
143- def to_dict (self ):
172+ def to_dict (self ) -> Dict [ str , JSONType ] :
144173 d = super (AggBase , self ).to_dict ()
145- if "aggs" in d [self .name ]:
146- d ["aggs" ] = d [self .name ].pop ("aggs" )
174+ if isinstance (d [self .name ], dict ):
175+ n = cast (AttrDict [str , Any ], d [self .name ])
176+ if "aggs" in n :
177+ d ["aggs" ] = n .pop ("aggs" )
147178 return d
148179
149180
@@ -154,14 +185,16 @@ class Filter(Bucket):
154185 "aggs" : {"type" : "agg" , "hash" : True },
155186 }
156187
157- def __init__ (self , filter = None , ** params ):
188+ def __init__ (self , filter : Optional [ Union [ str , "Query" ]] = None , ** params : Any ):
158189 if filter is not None :
159190 params ["filter" ] = filter
160191 super ().__init__ (** params )
161192
162- def to_dict (self ):
193+ def to_dict (self ) -> Dict [ str , JSONType ] :
163194 d = super ().to_dict ()
164- d [self .name ].update (d [self .name ].pop ("filter" , {}))
195+ if isinstance (d [self .name ], dict ):
196+ n = cast (AttrDict [str , Any ], d [self .name ])
197+ n .update (n .pop ("filter" , {}))
165198 return d
166199
167200
@@ -189,7 +222,7 @@ class Parent(Bucket):
189222class DateHistogram (Bucket ):
190223 name = "date_histogram"
191224
192- def result (self , search , data ) :
225+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
193226 return FieldBucketData (self , search , data )
194227
195228
@@ -232,7 +265,7 @@ class Global(Bucket):
232265class Histogram (Bucket ):
233266 name = "histogram"
234267
235- def result (self , search , data ) :
268+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
236269 return FieldBucketData (self , search , data )
237270
238271
@@ -259,7 +292,7 @@ class Range(Bucket):
259292class RareTerms (Bucket ):
260293 name = "rare_terms"
261294
262- def result (self , search , data ) :
295+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
263296 return FieldBucketData (self , search , data )
264297
265298
@@ -278,7 +311,7 @@ class SignificantText(Bucket):
278311class Terms (Bucket ):
279312 name = "terms"
280313
281- def result (self , search , data ) :
314+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
282315 return FieldBucketData (self , search , data )
283316
284317
@@ -305,7 +338,7 @@ class Composite(Bucket):
305338class VariableWidthHistogram (Bucket ):
306339 name = "variable_width_histogram"
307340
308- def result (self , search , data ) :
341+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
309342 return FieldBucketData (self , search , data )
310343
311344
@@ -321,7 +354,7 @@ class CategorizeText(Bucket):
321354class TopHits (Agg ):
322355 name = "top_hits"
323356
324- def result (self , search , data ) :
357+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
325358 return TopHitsData (self , search , data )
326359
327360
0 commit comments