33import  json 
44import  urllib .parse 
55import  asyncio 
6- from  datetime  import  datetime   
6+ from  datetime  import  datetime 
77from  typing  import  Dict , Any , Callable 
88from  asgiref .sync  import  sync_to_async 
99
1010from  .types  import  RemoteFlagsConfig , SelectedVariant , RemoteFlagsResponse 
11- from  concurrent .futures  import  ThreadPoolExecutor 
1211from  .utils  import  REQUEST_HEADERS , EXPOSURE_EVENT , prepare_common_query_params 
1312
1413logger  =  logging .getLogger (__name__ )
1514logging .getLogger ("httpx" ).setLevel (logging .ERROR )
1615
16+ 
1717class  RemoteFeatureFlagsProvider :
1818    FLAGS_URL_PATH  =  "/flags" 
1919
20-     def  __init__ (self , token : str , config : RemoteFlagsConfig , version : str , tracker : Callable ) ->  None :
20+     def  __init__ (
21+         self , token : str , config : RemoteFlagsConfig , version : str , tracker : Callable 
22+     ) ->  None :
2123        self ._token : str  =  token 
2224        self ._config : RemoteFlagsConfig  =  config 
2325        self ._version : str  =  version 
2426        self ._tracker : Callable  =  tracker 
25-         self ._executor : ThreadPoolExecutor  =  config .custom_executor  or  ThreadPoolExecutor (max_workers = 5 )
2627
2728        httpx_client_parameters  =  {
2829            "base_url" : f"https://{ config .api_host }  " ,
@@ -31,103 +32,179 @@ def __init__(self, token: str, config: RemoteFlagsConfig, version: str, tracker:
3132            "timeout" : httpx .Timeout (config .request_timeout_in_seconds ),
3233        }
3334
34-         self ._async_client : httpx .AsyncClient  =  httpx .AsyncClient (** httpx_client_parameters )
35+         self ._async_client : httpx .AsyncClient  =  httpx .AsyncClient (
36+             ** httpx_client_parameters 
37+         )
3538        self ._sync_client : httpx .Client  =  httpx .Client (** httpx_client_parameters )
3639        self ._request_params_base  =  prepare_common_query_params (self ._token , version )
3740
38-     async  def  aget_variant_value (self , flag_key : str , fallback_value : Any , context : Dict [str , Any ]) ->  Any :
39-         variant  =  await  self .aget_variant (flag_key , SelectedVariant (variant_value = fallback_value ), context )
41+     async  def  aget_variant_value (
42+         self , flag_key : str , fallback_value : Any , context : Dict [str , Any ]
43+     ) ->  Any :
44+         """ 
45+         Gets the selected variant value of a feature flag variant for the current user context from remote server. 
46+ 
47+         :param str flag_key: The key of the feature flag to evaluate 
48+         :param Any fallback_value: The default value to return if the flag is not found or evaluation fails 
49+         :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context 
50+         """ 
51+         variant  =  await  self .aget_variant (
52+             flag_key , SelectedVariant (variant_value = fallback_value ), context 
53+         )
4054        return  variant .variant_value 
4155
42-     async  def  aget_variant (self , flag_key : str , fallback_value : SelectedVariant , context : Dict [str , Any ]) ->  SelectedVariant :
56+     async  def  aget_variant (
57+         self , flag_key : str , fallback_value : SelectedVariant , context : Dict [str , Any ]
58+     ) ->  SelectedVariant :
59+         """ 
60+         Asynchronously gets the selected variant  of a feature flag variant for the current user context from remote server. 
61+ 
62+         :param str flag_key: The key of the feature flag to evaluate 
63+         :param SelectedVariant fallback_value: The default variant to return if evaluation fails 
64+         :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context 
65+         """ 
4366        try :
4467            params  =  self ._prepare_query_params (flag_key , context )
4568            start_time  =  datetime .now ()
4669            response  =  await  self ._async_client .get (self .FLAGS_URL_PATH , params = params )
4770            end_time  =  datetime .now ()
4871            self ._instrument_call (start_time , end_time )
49-             selected_variant , is_fallback  =  self ._handle_response (flag_key , fallback_value , response )
72+             selected_variant , is_fallback  =  self ._handle_response (
73+                 flag_key , fallback_value , response 
74+             )
5075
5176            if  not  is_fallback  and  (distinct_id  :=  context .get ("distinct_id" )):
52-                 properties  =  self ._build_tracking_properties (flag_key , selected_variant , start_time , end_time )
77+                 properties  =  self ._build_tracking_properties (
78+                     flag_key , selected_variant , start_time , end_time 
79+                 )
5380                asyncio .create_task (
54-                     sync_to_async (self ._tracker , executor = self ._executor , thread_sensitive = False )(distinct_id , EXPOSURE_EVENT , properties ))
81+                     sync_to_async (self ._tracker , thread_sensitive = False )(
82+                         distinct_id , EXPOSURE_EVENT , properties 
83+                     )
84+                 )
5585
5686            return  selected_variant 
5787        except  Exception :
5888            logging .exception (f"Failed to get remote variant for flag '{ flag_key }  '" )
5989            return  fallback_value 
6090
6191    async  def  ais_enabled (self , flag_key : str , context : Dict [str , Any ]) ->  bool :
92+         """ 
93+         Asynchronously checks if a feature flag is enabled for the given context. 
94+ 
95+         :param str flag_key: The key of the feature flag to check 
96+         :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context 
97+         """ 
6298        variant_value  =  await  self .aget_variant_value (flag_key , False , context )
6399        return  bool (variant_value )
64100
65-     def  get_variant_value (self , flag_key : str , fallback_value : Any , context : Dict [str , Any ]) ->  Any :
66-         variant  =  self .get_variant (flag_key , SelectedVariant (variant_value = fallback_value ), context )
101+     def  get_variant_value (
102+         self , flag_key : str , fallback_value : Any , context : Dict [str , Any ]
103+     ) ->  Any :
104+         """ 
105+         Synchronously gets the value of a feature flag variant from remote server. 
106+ 
107+         :param str flag_key: The key of the feature flag to evaluate 
108+         :param Any fallback_value: The default value to return if the flag is not found or evaluation fails 
109+         :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context 
110+         """ 
111+         variant  =  self .get_variant (
112+             flag_key , SelectedVariant (variant_value = fallback_value ), context 
113+         )
67114        return  variant .variant_value 
68115
69-     def  get_variant (self , flag_key : str , fallback_value : SelectedVariant , context : Dict [str , Any ]) ->  SelectedVariant :
116+     def  get_variant (
117+         self , flag_key : str , fallback_value : SelectedVariant , context : Dict [str , Any ]
118+     ) ->  SelectedVariant :
119+         """ 
120+         Synchronously gets the selected variant for a feature flag from remote server. 
121+ 
122+         :param str flag_key: The key of the feature flag to evaluate 
123+         :param SelectedVariant fallback_value: The default variant to return if evaluation fails 
124+         :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context 
125+         """ 
70126        try :
71127            params  =  self ._prepare_query_params (flag_key , context )
72128            start_time  =  datetime .now ()
73129            response  =  self ._sync_client .get (self .FLAGS_URL_PATH , params = params )
74130            end_time  =  datetime .now ()
75131            self ._instrument_call (start_time , end_time )
76-             selected_variant , is_fallback  =  self ._handle_response (flag_key , fallback_value , response )
132+             selected_variant , is_fallback  =  self ._handle_response (
133+                 flag_key , fallback_value , response 
134+             )
77135
78136            if  not  is_fallback  and  (distinct_id  :=  context .get ("distinct_id" )):
79-                 properties  =  self ._build_tracking_properties (flag_key , selected_variant , start_time , end_time )
80-                 self ._executor .submit (self ._tracker , distinct_id , EXPOSURE_EVENT , properties )
137+                 properties  =  self ._build_tracking_properties (
138+                     flag_key , selected_variant , start_time , end_time 
139+                 )
140+                 self ._tracker (distinct_id , EXPOSURE_EVENT , properties )
81141
82142            return  selected_variant 
83143        except  Exception :
84144            logging .exception (f"Failed to get remote variant for flag '{ flag_key }  '" )
85145            return  fallback_value 
86146
87147    def  is_enabled (self , flag_key : str , context : Dict [str , Any ]) ->  bool :
148+         """ 
149+         Synchronously checks if a feature flag is enabled for the given context. 
150+ 
151+         :param str flag_key: The key of the feature flag to check 
152+         :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context 
153+         """ 
88154        variant_value  =  self .get_variant_value (flag_key , False , context )
89155        return  bool (variant_value )
90156
91-     def  _prepare_query_params (self , flag_key : str , context : Dict [str , Any ]) ->  Dict [str , str ]:
157+     def  _prepare_query_params (
158+         self , flag_key : str , context : Dict [str , Any ]
159+     ) ->  Dict [str , str ]:
92160        params  =  self ._request_params_base .copy ()
93-         context_json  =  json .dumps (context ).encode (' utf-8'  )
161+         context_json  =  json .dumps (context ).encode (" utf-8"  )
94162        url_encoded_context  =  urllib .parse .quote (context_json )
95-         params .update ({
96-             'flag_key' : flag_key ,
97-             'context' : url_encoded_context 
98-         })
163+         params .update ({"flag_key" : flag_key , "context" : url_encoded_context })
99164        return  params 
100165
101166    def  _instrument_call (self , start_time : datetime , end_time : datetime ) ->  None :
102167        request_duration  =  end_time  -  start_time 
103168        formatted_start_time  =  start_time .isoformat ()
104169        formatted_end_time  =  end_time .isoformat ()
105-         logging .info (f"Request started at '{ formatted_start_time }  ', completed at '{ formatted_end_time }  ', duration: '{ request_duration .total_seconds ():.3f}  s'" )
106- 
107-     def  _build_tracking_properties (self , flag_key : str , variant : SelectedVariant , start_time : datetime , end_time : datetime ) ->  Dict [str , Any ]:
170+         logging .info (
171+             f"Request started at '{ formatted_start_time }  ', completed at '{ formatted_end_time }  ', duration: '{ request_duration .total_seconds ():.3f}  s'" 
172+         )
173+ 
174+     def  _build_tracking_properties (
175+         self ,
176+         flag_key : str ,
177+         variant : SelectedVariant ,
178+         start_time : datetime ,
179+         end_time : datetime ,
180+     ) ->  Dict [str , Any ]:
108181        request_duration  =  end_time  -  start_time 
109182        formatted_start_time  =  start_time .isoformat ()
110183        formatted_end_time  =  end_time .isoformat ()
111184
112185        return  {
113-             ' Experiment name'  : flag_key ,
114-             ' Variant name'  : variant .variant_key ,
115-             ' $experiment_type' :  ' feature_flag'  ,
186+             " Experiment name"  : flag_key ,
187+             " Variant name"  : variant .variant_key ,
188+             " $experiment_type" :  " feature_flag"  ,
116189            "Flag evaluation mode" : "remote" ,
117190            "Variant fetch start time" : formatted_start_time ,
118191            "Variant fetch complete time" : formatted_end_time ,
119192            "Variant fetch latency (ms)" : request_duration .total_seconds () *  1000 ,
120193        }
121194
122-     def  _handle_response (self , flag_key : str , fallback_value : SelectedVariant , response : httpx .Response ) ->  tuple [SelectedVariant , bool ]:
195+     def  _handle_response (
196+         self , flag_key : str , fallback_value : SelectedVariant , response : httpx .Response 
197+     ) ->  tuple [SelectedVariant , bool ]:
123198        response .raise_for_status ()
124199
125200        flags_response  =  RemoteFlagsResponse .model_validate (response .json ())
126201
127202        if  flag_key  in  flags_response .flags :
128203            return  flags_response .flags [flag_key ], False 
129204        else :
130-             logging .warning (f"Flag '{ flag_key }  ' not found in remote response. Returning fallback, '{ fallback_value }  '" )
205+             logging .warning (
206+                 f"Flag '{ flag_key }  ' not found in remote response. Returning fallback, '{ fallback_value }  '" 
207+             )
131208            return  fallback_value , True 
132209
133210    def  __enter__ (self ):
0 commit comments