Skip to content

Commit 1cbc9c7

Browse files
efahkKwame Efah
authored andcommitted
Update exposure event tracking and add docs
1 parent 0f8c84e commit 1cbc9c7

File tree

7 files changed

+1655
-30
lines changed

7 files changed

+1655
-30
lines changed

demo/local_flags.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
logging.basicConfig(level=logging.INFO)
77

88
# Configure your project token, the feature flag to test, and user context to evaluate.
9-
PROJECT_TOKEN = ""
9+
PROJECT_TOKEN = "044781e247eabd8e9b73fad8a8c093d2"
1010
FLAG_KEY = "sample-flag"
1111
FLAG_FALLBACK_VARIANT = "control"
1212
USER_CONTEXT = { "distinct_id": "sample-distinct-id" }
@@ -18,14 +18,14 @@
1818
# Use the correct data residency endpoint for your project.
1919
API_HOST = "api-eu.mixpanel.com"
2020

21-
async def main():
21+
def main():
2222
local_config = mixpanel.LocalFlagsConfig(api_host=API_HOST, enable_polling=SHOULD_POLL_CONTINOUSLY, polling_interval_in_seconds=POLLING_INTERVAL_IN_SECONDS)
2323

2424
# Optionally use mixpanel client as a context manager, that will ensure shutdown of resources used by feature flagging
25-
async with mixpanel.Mixpanel(PROJECT_TOKEN, local_flags_config=local_config) as mp:
26-
await mp.local_flags.astart_polling_for_definitions()
25+
with mixpanel.Mixpanel(PROJECT_TOKEN, local_flags_config=local_config) as mp:
26+
mp.local_flags.start_polling_for_definitions()
2727
variant_value = mp.local_flags.get_variant_value(FLAG_KEY, FLAG_FALLBACK_VARIANT, USER_CONTEXT)
2828
print(f"Variant value: {variant_value}")
2929

3030
if __name__ == '__main__':
31-
asyncio.run(main())
31+
main()

mixpanel/flags/local_feature_flags.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import threading
66
from datetime import datetime, timedelta
77
from typing import Dict, Any, Callable, Optional
8-
from concurrent.futures import Future, ThreadPoolExecutor
98
from .types import ExperimentationFlag, ExperimentationFlags, SelectedVariant, LocalFlagsConfig, Rollout
109
from .utils import REQUEST_HEADERS, normalized_hash, prepare_common_query_params, EXPOSURE_EVENT
1110

@@ -16,13 +15,20 @@ class LocalFeatureFlagsProvider:
1615
FLAGS_DEFINITIONS_URL_PATH = "/flags/definitions"
1716

1817
def __init__(self, token: str, config: LocalFlagsConfig, version: str, tracker: Callable) -> None:
18+
"""
19+
Initializes the LocalFeatureFlagsProvider
20+
:param str token: your project's Mixpanel token
21+
:param LocalFlagsConfig config: configuration options for the local feature flags provider
22+
:param str version: the version of the Mixpanel library being used, just for tracking
23+
:param str tracker: A function used to track flags exposure events to mixpanel
24+
"""
1925
self._token: str = token
2026
self._config: LocalFlagsConfig = config
2127
self._version = version
2228
self._tracker: Callable = tracker
23-
self._executor: ThreadPoolExecutor = config.custom_executor or ThreadPoolExecutor(max_workers=5)
2429

2530
self._flag_definitions: Dict[str, ExperimentationFlag] = dict()
31+
self._are_flags_ready = False
2632

2733
httpx_client_parameters = {
2834
"base_url": f"https://{config.api_host}",
@@ -37,29 +43,41 @@ def __init__(self, token: str, config: LocalFlagsConfig, version: str, tracker:
3743
self._sync_client: httpx.Client = httpx.Client(**httpx_client_parameters)
3844

3945
self._async_polling_task: Optional[asyncio.Task] = None
40-
self._sync_polling_task: Optional[Future] = None
46+
self._sync_polling_task: Optional[threading.Thread] = None
4147

4248
self._sync_stop_event = threading.Event()
4349

4450
def start_polling_for_definitions(self):
51+
"""
52+
Fetches flag definitions for the current project.
53+
If configured by the caller, starts a background thread to poll for updates at regular intervals, if one does not already exist.
54+
"""
4555
self._fetch_flag_definitions()
4656

4757
if self._config.enable_polling:
4858
if not self._sync_polling_task and not self._async_polling_task:
4959
self._sync_stop_event.clear()
50-
self._sync_polling_task = self._executor.submit(self._start_continuous_polling)
60+
self._sync_polling_task = threading.Thread(target=self._start_continuous_polling, daemon=True)
61+
self._sync_polling_task.start()
5162
else:
52-
logging.error("A polling task is already running")
63+
logging.warning("A polling task is already running")
5364

5465
def stop_polling_for_definitions(self):
66+
"""
67+
If there exists a reference to a background thread polling for flag definition updates, signal it to stop and clear the reference.
68+
Once stopped, the polling thread cannot be restarted.
69+
"""
5570
if self._sync_polling_task:
5671
self._sync_stop_event.set()
57-
self._sync_polling_task.cancel()
5872
self._sync_polling_task = None
5973
else:
6074
logging.info("There is no polling task to cancel.")
6175

6276
async def astart_polling_for_definitions(self):
77+
"""
78+
Fetches flag definitions for the current project.
79+
If configured by the caller, starts an async task on the event loop to poll for updates at regular intervals, if one does not already exist.
80+
"""
6381
await self._afetch_flag_definitions()
6482

6583
if self._config.enable_polling:
@@ -69,6 +87,9 @@ async def astart_polling_for_definitions(self):
6987
logging.error("A polling task is already running")
7088

7189
async def astop_polling_for_definitions(self):
90+
"""
91+
If there exists an async task to poll for flag definition updates, cancel the task and clear the reference to it.
92+
"""
7293
if self._async_polling_task:
7394
self._async_polling_task.cancel()
7495
self._async_polling_task = None
@@ -94,20 +115,39 @@ def _start_continuous_polling(self):
94115

95116
def are_flags_ready(self) -> bool:
96117
"""
97-
Check if flag definitions have been loaded and are ready for use.
98-
:return: True if flag definitions are populated, False otherwise.
118+
Check if the call to fetch flag definitions has been made successfully.
99119
"""
100-
return bool(self._flag_definitions)
120+
return self._are_flags_ready
101121

102122
def get_variant_value(self, flag_key: str, fallback_value: Any, context: Dict[str, Any]) -> Any:
123+
"""
124+
Get the value of a feature flag variant.
125+
126+
:param str flag_key: The key of the feature flag to evaluate
127+
:param Any fallback_value: The default value to return if the flag is not found or evaluation fails
128+
:param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation
129+
"""
103130
variant = self.get_variant(flag_key, SelectedVariant(variant_value=fallback_value), context)
104131
return variant.variant_value
105132

106133
def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
134+
"""
135+
Check if a feature flag is enabled for the given context.
136+
137+
:param str flag_key: The key of the feature flag to check
138+
:param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation
139+
"""
107140
variant_value = self.get_variant_value(flag_key, False, context)
108141
return bool(variant_value)
109142

110143
def get_variant(self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any]) -> SelectedVariant:
144+
"""
145+
Gets the selected variant for a feature flag
146+
147+
:param str flag_key: The key of the feature flag to evaluate
148+
:param SelectedVariant fallback_value: The default variant to return if evaluation fails
149+
:param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation
150+
"""
111151
start_time = time.perf_counter()
112152
flag_definition = self._flag_definitions.get(flag_key)
113153

@@ -125,7 +165,7 @@ def get_variant(self, flag_key: str, fallback_value: SelectedVariant, context: D
125165
if rollout := self._get_assigned_rollout(flag_definition, context_value, context):
126166
variant = self._get_assigned_variant(flag_definition, context_value, flag_key, rollout)
127167
end_time = time.perf_counter()
128-
self.track_exposure(flag_key, variant, end_time - start_time, context)
168+
self._track_exposure(flag_key, variant, end_time - start_time, context)
129169
return variant
130170

131171
logger.info(f"{flag_definition.context} context {context_value} not eligible for any rollout for flag: {flag_key}")
@@ -237,10 +277,11 @@ def _handle_response(self, response: httpx.Response, start_time: datetime, end_t
237277
logger.exception("Failed to parse flag definitions")
238278

239279
self._flag_definitions = flags
280+
self._are_flags_ready = True
240281
logger.info(f"Successfully fetched {len(self._flag_definitions)} flag definitions")
241282

242283

243-
def track_exposure(self, flag_key: str, variant: SelectedVariant, latency_in_seconds: float, context: Dict[str, Any]):
284+
def _track_exposure(self, flag_key: str, variant: SelectedVariant, latency_in_seconds: float, context: Dict[str, Any]):
244285
if distinct_id := context.get("distinct_id"):
245286
properties = {
246287
'Experiment name': flag_key,
@@ -249,7 +290,8 @@ def track_exposure(self, flag_key: str, variant: SelectedVariant, latency_in_sec
249290
"Flag evaluation mode": "local",
250291
"Variant fetch latency (ms)": latency_in_seconds * 1000
251292
}
252-
self._executor.submit(self._tracker, distinct_id, EXPOSURE_EVENT, properties)
293+
294+
self._tracker(distinct_id, EXPOSURE_EVENT, properties)
253295
else:
254296
logging.error("Cannot track exposure event without a distinct_id in the context")
255297

mixpanel/flags/remote_feature_flags.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from asgiref.sync import sync_to_async
99

1010
from .types import RemoteFlagsConfig, SelectedVariant, RemoteFlagsResponse
11-
from concurrent.futures import ThreadPoolExecutor
1211
from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params
1312

1413
logger = logging.getLogger(__name__)
@@ -22,7 +21,6 @@ def __init__(self, token: str, config: RemoteFlagsConfig, version: str, tracker:
2221
self._config: RemoteFlagsConfig = config
2322
self._version: str = version
2423
self._tracker: Callable = tracker
25-
self._executor: ThreadPoolExecutor = config.custom_executor or ThreadPoolExecutor(max_workers=5)
2624

2725
httpx_client_parameters = {
2826
"base_url": f"https://{config.api_host}",
@@ -36,10 +34,24 @@ def __init__(self, token: str, config: RemoteFlagsConfig, version: str, tracker:
3634
self._request_params_base = prepare_common_query_params(self._token, version)
3735

3836
async def aget_variant_value(self, flag_key: str, fallback_value: Any, context: Dict[str, Any]) -> Any:
37+
"""
38+
Gets the selected variant value of a feature flag variant for the current user context from remote server.
39+
40+
:param str flag_key: The key of the feature flag to evaluate
41+
:param Any fallback_value: The default value to return if the flag is not found or evaluation fails
42+
:param Dict[str, Any] context: Context dictionary containing user attributes and rollout context
43+
"""
3944
variant = await self.aget_variant(flag_key, SelectedVariant(variant_value=fallback_value), context)
4045
return variant.variant_value
4146

4247
async def aget_variant(self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any]) -> SelectedVariant:
48+
"""
49+
Asynchronously gets the selected variant of a feature flag variant for the current user context from remote server.
50+
51+
:param str flag_key: The key of the feature flag to evaluate
52+
:param SelectedVariant fallback_value: The default variant to return if evaluation fails
53+
:param Dict[str, Any] context: Context dictionary containing user attributes and rollout context
54+
"""
4355
try:
4456
params = self._prepare_query_params(flag_key, context)
4557
start_time = datetime.now()
@@ -51,22 +63,42 @@ async def aget_variant(self, flag_key: str, fallback_value: SelectedVariant, con
5163
if not is_fallback and (distinct_id := context.get("distinct_id")):
5264
properties = self._build_tracking_properties(flag_key, selected_variant, start_time, end_time)
5365
asyncio.create_task(
54-
sync_to_async(self._tracker, executor=self._executor, thread_sensitive=False)(distinct_id, EXPOSURE_EVENT, properties))
66+
sync_to_async(self._tracker, thread_sensitive=False)(distinct_id, EXPOSURE_EVENT, properties))
5567

5668
return selected_variant
5769
except Exception:
5870
logging.exception(f"Failed to get remote variant for flag '{flag_key}'")
5971
return fallback_value
6072

6173
async def ais_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
74+
"""
75+
Asynchronously checks if a feature flag is enabled for the given context.
76+
77+
:param str flag_key: The key of the feature flag to check
78+
:param Dict[str, Any] context: Context dictionary containing user attributes and rollout context
79+
"""
6280
variant_value = await self.aget_variant_value(flag_key, False, context)
6381
return bool(variant_value)
6482

6583
def get_variant_value(self, flag_key: str, fallback_value: Any, context: Dict[str, Any]) -> Any:
84+
"""
85+
Synchronously gets the value of a feature flag variant from remote server.
86+
87+
:param str flag_key: The key of the feature flag to evaluate
88+
:param Any fallback_value: The default value to return if the flag is not found or evaluation fails
89+
:param Dict[str, Any] context: Context dictionary containing user attributes and rollout context
90+
"""
6691
variant = self.get_variant(flag_key, SelectedVariant(variant_value=fallback_value), context)
6792
return variant.variant_value
6893

6994
def get_variant(self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any]) -> SelectedVariant:
95+
"""
96+
Synchronously getsthe selected variant for a feature flag from remote server.
97+
98+
:param str flag_key: The key of the feature flag to evaluate
99+
:param SelectedVariant fallback_value: The default variant to return if evaluation fails
100+
:param Dict[str, Any] context: Context dictionary containing user attributes and rollout context
101+
"""
70102
try:
71103
params = self._prepare_query_params(flag_key, context)
72104
start_time = datetime.now()
@@ -77,14 +109,20 @@ def get_variant(self, flag_key: str, fallback_value: SelectedVariant, context: D
77109

78110
if not is_fallback and (distinct_id := context.get("distinct_id")):
79111
properties = self._build_tracking_properties(flag_key, selected_variant, start_time, end_time)
80-
self._executor.submit(self._tracker, distinct_id, EXPOSURE_EVENT, properties)
112+
self._tracker(distinct_id, EXPOSURE_EVENT, properties)
81113

82114
return selected_variant
83115
except Exception:
84116
logging.exception(f"Failed to get remote variant for flag '{flag_key}'")
85117
return fallback_value
86118

87119
def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
120+
"""
121+
Synchronously checks if a feature flag is enabled for the given context.
122+
123+
:param str flag_key: The key of the feature flag to check
124+
:param Dict[str, Any] context: Context dictionary containing user attributes and rollout context
125+
"""
88126
variant_value = self.get_variant_value(flag_key, False, context)
89127
return bool(variant_value)
90128

mixpanel/flags/test_local_feature_flags.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,22 +227,19 @@ async def test_get_variant_value_tracks_exposure_when_variant_selected(self):
227227
with patch('mixpanel.flags.utils.normalized_hash') as mock_hash:
228228
mock_hash.return_value = 0.5
229229
_ = flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
230-
flags._executor.shutdown()
231230
flags._tracker.assert_called_once()
232231

233232
@respx.mock
234233
async def test_get_variant_value_does_not_track_exposure_on_fallback(self):
235234
flags = await self.setup_flags([])
236235
_ = flags.get_variant_value("nonexistent_flag", "fallback", {"distinct_id": "user123"})
237-
flags._executor.shutdown()
238236
flags._tracker.assert_not_called()
239237

240238
@respx.mock
241239
async def test_get_variant_value_does_not_track_exposure_without_distinct_id(self):
242240
flag = create_test_flag(context="company")
243241
flags = await self.setup_flags([flag])
244242
_ = flags.get_variant_value("nonexistent_flag", "fallback", {"company_id": "company123"})
245-
flags._executor.shutdown()
246243
flags._tracker.assert_not_called()
247244

248245
@respx.mock

mixpanel/flags/test_remote_feature_flags.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,12 @@ async def test_get_variant_value_tracks_exposure_event_if_variant_selected(self)
6464
if pending:
6565
await asyncio.gather(*pending, return_exceptions=True)
6666

67-
self._flags._executor.shutdown()
6867
self.mock_tracker.assert_called_once()
6968

7069
@respx.mock
7170
async def test_get_variant_value_does_not_track_exposure_event_if_fallback(self):
7271
respx.get(ENDPOINT).mock(side_effect=httpx.RequestError("Network error"))
7372
await self._flags.aget_variant_value("test_flag", "control", {"distinct_id": "user123"})
74-
self._flags._executor.shutdown()
7573
self.mock_tracker.assert_not_called()
7674

7775
@respx.mock
@@ -135,14 +133,12 @@ def test_get_variant_value_tracks_exposure_event_if_variant_selected(self):
135133
return_value=create_success_response({"test_flag": SelectedVariant(variant_key="treatment", variant_value="treatment")}))
136134

137135
self._flags.get_variant_value("test_flag", "control", {"distinct_id": "user123"})
138-
self._flags._executor.shutdown()
139136
self.mock_tracker.assert_called_once()
140137

141138
@respx.mock
142139
def test_get_variant_value_does_not_track_exposure_event_if_fallback(self):
143140
respx.get(ENDPOINT).mock(side_effect=httpx.RequestError("Network error"))
144141
self._flags.get_variant_value("test_flag", "control", {"distinct_id": "user123"})
145-
self._flags._executor.shutdown()
146142
self.mock_tracker.assert_not_called()
147143

148144
@respx.mock

mixpanel/flags/types.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from typing import Optional, List, Dict, Any
2-
from concurrent.futures import ThreadPoolExecutor
32
from pydantic import BaseModel, ConfigDict
43

54
MIXPANEL_DEFAULT_API_ENDPOINT = "api.mixpanel.com"
@@ -9,7 +8,6 @@ class FlagsConfig(BaseModel):
98

109
api_host: str = "api.mixpanel.com"
1110
request_timeout_in_seconds: int = 10
12-
custom_executor: Optional[ThreadPoolExecutor] = None
1311

1412
class LocalFlagsConfig(FlagsConfig):
1513
enable_polling: bool = True

0 commit comments

Comments
 (0)