Skip to content

Commit f274982

Browse files
committed
Fix metrics timer stats deadlock
1 parent c03159e commit f274982

File tree

2 files changed

+166
-138
lines changed

2 files changed

+166
-138
lines changed

src/core/services/metrics_service.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,28 @@
1414
_counters: dict[str, int] = defaultdict(int)
1515
_timers: dict[str, list[float]] = defaultdict(list)
1616

17+
def _calculate_timer_stats(durations: list[float]) -> dict[str, Any]:
18+
"""Calculate statistics for the provided list of durations."""
19+
if not durations:
20+
return {
21+
"count": 0,
22+
"total": 0.0,
23+
"average": 0.0,
24+
"min": 0.0,
25+
"max": 0.0,
26+
}
27+
28+
total_duration = sum(durations)
29+
count = len(durations)
30+
return {
31+
"count": count,
32+
"total": total_duration,
33+
"average": total_duration / count,
34+
"min": min(durations),
35+
"max": max(durations),
36+
}
37+
38+
1739

1840
def inc(name: str, by: int = 1) -> None:
1941
"""Increment a counter metric by the specified amount.
@@ -90,23 +112,9 @@ def get_timer_stats(name: str) -> dict[str, Any]:
90112
A dictionary containing count, total, average, min, and max durations
91113
"""
92114
with _lock:
93-
durations = _timers.get(name, [])
94-
if not durations:
95-
return {
96-
"count": 0,
97-
"total": 0.0,
98-
"average": 0.0,
99-
"min": 0.0,
100-
"max": 0.0,
101-
}
115+
durations = list(_timers.get(name, []))
102116

103-
return {
104-
"count": len(durations),
105-
"total": sum(durations),
106-
"average": sum(durations) / len(durations),
107-
"min": min(durations),
108-
"max": max(durations),
109-
}
117+
return _calculate_timer_stats(durations)
110118

111119

112120
def get_all_timer_stats() -> dict[str, dict[str, Any]]:
@@ -116,7 +124,9 @@ def get_all_timer_stats() -> dict[str, dict[str, Any]]:
116124
A dictionary mapping timer names to their statistics
117125
"""
118126
with _lock:
119-
return {name: get_timer_stats(name) for name in _timers}
127+
timers_snapshot = {name: list(durations) for name, durations in _timers.items()}
128+
129+
return {name: _calculate_timer_stats(durations) for name, durations in timers_snapshot.items()}
120130

121131

122132
def log_performance_stats() -> None:
Lines changed: 139 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,121 +1,139 @@
1-
"""
2-
Unit tests for the metrics service.
3-
"""
4-
5-
from __future__ import annotations
6-
7-
import time
8-
9-
from src.core.services import metrics_service
10-
11-
12-
class TestMetricsService:
13-
"""Test the metrics service functionality."""
14-
15-
def setup_method(self):
16-
"""Reset metrics before each test."""
17-
# Clear counters and timers
18-
with metrics_service._lock:
19-
metrics_service._counters.clear()
20-
metrics_service._timers.clear()
21-
22-
def test_counter_increment(self):
23-
"""Test basic counter increment functionality."""
24-
metrics_service.inc("test.counter")
25-
assert metrics_service.get("test.counter") == 1
26-
27-
metrics_service.inc("test.counter", by=5)
28-
assert metrics_service.get("test.counter") == 6
29-
30-
def test_counter_get_nonexistent(self):
31-
"""Test getting a counter that doesn't exist returns 0."""
32-
assert metrics_service.get("nonexistent.counter") == 0
33-
34-
def test_counter_snapshot(self):
35-
"""Test getting a snapshot of all counters."""
36-
metrics_service.inc("counter1")
37-
metrics_service.inc("counter2", by=3)
38-
metrics_service.inc("counter3", by=10)
39-
40-
snapshot = metrics_service.snapshot()
41-
assert snapshot["counter1"] == 1
42-
assert snapshot["counter2"] == 3
43-
assert snapshot["counter3"] == 10
44-
45-
def test_record_duration(self):
46-
"""Test recording duration measurements."""
47-
metrics_service.record_duration("test.timer", 0.5)
48-
metrics_service.record_duration("test.timer", 1.0)
49-
metrics_service.record_duration("test.timer", 0.75)
50-
51-
stats = metrics_service.get_timer_stats("test.timer")
52-
assert stats["count"] == 3
53-
assert stats["total"] == 2.25
54-
assert stats["average"] == 0.75
55-
assert stats["min"] == 0.5
56-
assert stats["max"] == 1.0
57-
58-
def test_timer_context_manager(self):
59-
"""Test the timer context manager."""
60-
with metrics_service.timer("test.operation"):
61-
time.sleep(0.01) # Sleep for 10ms
62-
63-
stats = metrics_service.get_timer_stats("test.operation")
64-
assert stats["count"] == 1
65-
assert stats["total"] >= 0.01 # Should be at least 10ms
66-
assert stats["average"] >= 0.01
67-
68-
def test_timer_stats_empty(self):
69-
"""Test getting stats for a timer with no measurements."""
70-
stats = metrics_service.get_timer_stats("nonexistent.timer")
71-
assert stats["count"] == 0
72-
assert stats["total"] == 0.0
73-
assert stats["average"] == 0.0
74-
assert stats["min"] == 0.0
75-
assert stats["max"] == 0.0
76-
77-
def test_get_all_timer_stats(self):
78-
"""Test getting stats for all timers."""
79-
metrics_service.record_duration("timer1", 0.5)
80-
metrics_service.record_duration("timer2", 1.0)
81-
82-
all_stats = metrics_service.get_all_timer_stats()
83-
assert "timer1" in all_stats
84-
assert "timer2" in all_stats
85-
assert all_stats["timer1"]["count"] == 1
86-
assert all_stats["timer2"]["count"] == 1
87-
88-
def test_tool_call_processing_metrics(self):
89-
"""Test metrics specific to tool call processing."""
90-
# Simulate processing and skipping messages
91-
metrics_service.inc("tool_call.messages.processed", by=5)
92-
metrics_service.inc("tool_call.messages.skipped", by=45)
93-
94-
assert metrics_service.get("tool_call.messages.processed") == 5
95-
assert metrics_service.get("tool_call.messages.skipped") == 45
96-
97-
# Calculate skip rate
98-
total = 5 + 45
99-
skip_rate = (45 / total) * 100
100-
assert skip_rate == 90.0
101-
102-
def test_log_performance_stats_with_data(self, caplog):
103-
"""Test logging performance statistics with data."""
104-
metrics_service.inc("tool_call.messages.processed", by=10)
105-
metrics_service.inc("tool_call.messages.skipped", by=90)
106-
metrics_service.record_duration("tool_call.processing.duration", 0.05)
107-
metrics_service.record_duration("tool_call.processing.duration", 0.03)
108-
109-
metrics_service.log_performance_stats()
110-
111-
# Check that log messages were generated
112-
assert any("processed=10" in record.message for record in caplog.records)
113-
assert any("skipped=90" in record.message for record in caplog.records)
114-
assert any("skip_rate=90.0%" in record.message for record in caplog.records)
115-
116-
def test_log_performance_stats_no_data(self, caplog):
117-
"""Test logging performance statistics with no data."""
118-
metrics_service.log_performance_stats()
119-
120-
# Should not log anything when there's no data
121-
assert len(caplog.records) == 0
1+
"""
2+
Unit tests for the metrics service.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
import threading
8+
import time
9+
10+
from src.core.services import metrics_service
11+
12+
13+
class TestMetricsService:
    """Unit tests covering counters, timers, and performance logging."""

    def setup_method(self):
        """Start every test from a clean slate: wipe all recorded metrics."""
        with metrics_service._lock:
            metrics_service._counters.clear()
            metrics_service._timers.clear()

    def test_counter_increment(self):
        """A counter grows by 1 by default and by an explicit amount."""
        metrics_service.inc("test.counter")
        assert metrics_service.get("test.counter") == 1

        metrics_service.inc("test.counter", by=5)
        assert metrics_service.get("test.counter") == 6

    def test_counter_get_nonexistent(self):
        """An unknown counter reads as zero rather than raising."""
        assert metrics_service.get("nonexistent.counter") == 0

    def test_counter_snapshot(self):
        """snapshot() reflects every counter that was incremented."""
        metrics_service.inc("counter1")
        metrics_service.inc("counter2", by=3)
        metrics_service.inc("counter3", by=10)

        counters = metrics_service.snapshot()
        assert counters["counter1"] == 1
        assert counters["counter2"] == 3
        assert counters["counter3"] == 10

    def test_record_duration(self):
        """Recorded durations produce correct aggregate statistics."""
        for duration in (0.5, 1.0, 0.75):
            metrics_service.record_duration("test.timer", duration)

        stats = metrics_service.get_timer_stats("test.timer")
        assert stats["count"] == 3
        assert stats["total"] == 2.25
        assert stats["average"] == 0.75
        assert stats["min"] == 0.5
        assert stats["max"] == 1.0

    def test_timer_context_manager(self):
        """The timer() context manager records the elapsed wall time."""
        with metrics_service.timer("test.operation"):
            time.sleep(0.01)  # 10ms of measurable work

        stats = metrics_service.get_timer_stats("test.operation")
        assert stats["count"] == 1
        assert stats["total"] >= 0.01  # at least the slept 10ms
        assert stats["average"] >= 0.01

    def test_timer_stats_empty(self):
        """A timer with no samples reports all-zero statistics."""
        stats = metrics_service.get_timer_stats("nonexistent.timer")
        assert stats["count"] == 0
        assert stats["total"] == 0.0
        assert stats["average"] == 0.0
        assert stats["min"] == 0.0
        assert stats["max"] == 0.0

    def test_get_all_timer_stats(self):
        """get_all_timer_stats() includes an entry per recorded timer."""
        metrics_service.record_duration("timer1", 0.5)
        metrics_service.record_duration("timer2", 1.0)

        all_stats = metrics_service.get_all_timer_stats()
        assert "timer1" in all_stats
        assert "timer2" in all_stats
        assert all_stats["timer1"]["count"] == 1
        assert all_stats["timer2"]["count"] == 1

    def test_get_all_timer_stats_thread_safe(self):
        """Regression test: get_all_timer_stats must not deadlock when invoked off-thread."""
        metrics_service.record_duration("timer1", 0.1)
        metrics_service.record_duration("timer2", 0.2)

        collected: dict[str, dict[str, float]] = {}

        def fetch_stats() -> None:
            collected.update(metrics_service.get_all_timer_stats())

        worker = threading.Thread(target=fetch_stats)
        worker.start()
        # If the call deadlocks, join() times out and the thread stays alive.
        worker.join(timeout=1)

        assert not worker.is_alive(), "get_all_timer_stats deadlocked when called from another thread"
        assert collected, "Expected timer stats to be populated after thread execution"

    def test_tool_call_processing_metrics(self):
        """Counters for tool-call processing track processed/skipped messages."""
        metrics_service.inc("tool_call.messages.processed", by=5)
        metrics_service.inc("tool_call.messages.skipped", by=45)

        assert metrics_service.get("tool_call.messages.processed") == 5
        assert metrics_service.get("tool_call.messages.skipped") == 45

        # Derived skip rate: 45 of 50 messages skipped -> 90%.
        total = 5 + 45
        skip_rate = (45 / total) * 100
        assert skip_rate == 90.0

    def test_log_performance_stats_with_data(self, caplog):
        """With data present, log_performance_stats emits summary lines."""
        metrics_service.inc("tool_call.messages.processed", by=10)
        metrics_service.inc("tool_call.messages.skipped", by=90)
        metrics_service.record_duration("tool_call.processing.duration", 0.05)
        metrics_service.record_duration("tool_call.processing.duration", 0.03)

        metrics_service.log_performance_stats()

        messages = [record.message for record in caplog.records]
        assert any("processed=10" in message for message in messages)
        assert any("skipped=90" in message for message in messages)
        assert any("skip_rate=90.0%" in message for message in messages)

    def test_log_performance_stats_no_data(self, caplog):
        """With no recorded metrics, log_performance_stats stays silent."""
        metrics_service.log_performance_stats()

        # Nothing should be logged when no counters or timers exist.
        assert len(caplog.records) == 0

0 commit comments

Comments
 (0)