Skip to content

Commit 5e32205

Browse files
committed
modify core logic for efficiency and add tests for incident prs
1 parent a7fb8bf commit 5e32205

File tree

3 files changed

+288
-27
lines changed

3 files changed

+288
-27
lines changed

backend/analytics_server/mhq/service/incidents/incidents.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def get_resolved_team_incidents(
6262
resolved_pr_incidents = self.get_team_pr_incidents(team_id, interval, pr_filter)
6363

6464
total_incidents = resolved_incidents + resolved_pr_incidents
65+
total_incidents = sorted(total_incidents, key=lambda x: x.creation_date)
6566

6667
return {incident.key: incident for incident in total_incidents}.values()
6768

@@ -85,6 +86,7 @@ def get_team_incidents(
8586
)
8687

8788
total_incidents = incidents + pr_incidents
89+
total_incidents = sorted(total_incidents, key=lambda x: x.creation_date)
8890

8991
return {incident.key: incident for incident in total_incidents}.values()
9092

@@ -105,58 +107,62 @@ def get_team_pr_incidents(
105107
tr.org_repo_id
106108
for tr in self._code_repo_service.get_active_team_repos_by_team_id(team_id)
107109
)
108-
109-
prs = self._code_repo_service.get_prs_merged_in_interval(
110-
team_repo_ids, interval, pr_filter
111-
)
112-
113110
resolution_prs_interval = Interval(
114111
from_time=interval.from_time, to_time=time_now()
115112
)
116-
pr_filter: PRFilter = apply_pr_filter(
113+
resolution_prs_filter: PRFilter = apply_pr_filter(
117114
asdict(pr_filter),
118115
EntityType.TEAM,
119116
team_id,
120117
[SettingType.EXCLUDED_PRS_SETTING, SettingType.INCIDENT_PRS_SETTING],
121118
)
119+
122120
resolution_prs = self._code_repo_service.get_prs_merged_in_interval(
123-
team_repo_ids, resolution_prs_interval, pr_filter
121+
team_repo_ids, resolution_prs_interval, resolution_prs_filter
124122
)
125123

124+
pr_numbers: List[str] = []
126125
pr_incidents: List[Incident] = []
127126
repo_id_to_pr_number_to_pr_map: Dict[str, Dict[str, PullRequest]] = {}
128127

129-
for pr in prs:
130-
if str(pr.repo_id) not in repo_id_to_pr_number_to_pr_map:
131-
repo_id_to_pr_number_to_pr_map[str(pr.repo_id)] = {}
132-
repo_id_to_pr_number_to_pr_map[str(pr.repo_id)][pr.number] = pr
133-
134128
for pr in resolution_prs:
135129
for filter in incident_prs_setting.filters:
136-
pr_number = self._extract_pr_number_from_regex(
130+
incident_pr_number = self._extract_pr_number_from_regex(
137131
getattr(pr, filter["field"]), filter["value"]
138132
)
139-
if (
140-
pr_number
141-
and str(pr.repo_id) in repo_id_to_pr_number_to_pr_map
142-
and pr_number in repo_id_to_pr_number_to_pr_map[str(pr.repo_id)]
143-
):
144-
original_pr = repo_id_to_pr_number_to_pr_map[str(pr.repo_id)][
145-
pr_number
146-
]
147-
adapted_incident_pr = IncidentPRAdapter.adapt(original_pr, pr)
148-
pr_incidents.append(adapted_incident_pr)
149-
repo_id_to_pr_number_to_pr_map[str(pr.repo_id)].pop(
150-
original_pr.number
151-
)
133+
134+
if incident_pr_number:
135+
pr_numbers.append(incident_pr_number)
136+
if str(pr.repo_id) not in repo_id_to_pr_number_to_pr_map:
137+
repo_id_to_pr_number_to_pr_map[str(pr.repo_id)] = {}
138+
repo_id_to_pr_number_to_pr_map[str(pr.repo_id)][
139+
incident_pr_number
140+
] = pr
152141
break
153142

143+
prs = self._code_repo_service.get_prs_merged_in_interval_by_numbers(
144+
list(repo_id_to_pr_number_to_pr_map.keys()), interval, pr_numbers, pr_filter
145+
)
146+
147+
for pr in prs:
148+
if (
149+
str(pr.repo_id) not in repo_id_to_pr_number_to_pr_map
150+
or pr.number not in repo_id_to_pr_number_to_pr_map[str(pr.repo_id)]
151+
):
152+
continue
153+
154+
resolution_pr = repo_id_to_pr_number_to_pr_map[str(pr.repo_id)][pr.number]
155+
adapted_pr_incident = IncidentPRAdapter.adapt(pr, resolution_pr)
156+
pr_incidents.append(adapted_pr_incident)
157+
154158
return pr_incidents
155159

156160
def _extract_pr_number_from_regex(
157161
self, text: str, regex_pattern: str
158162
) -> Optional[str]:
159-
if regex_pattern and check_regex(regex_pattern):
163+
if not text or not regex_pattern:
164+
return None
165+
if check_regex(regex_pattern):
160166
match = re.search(regex_pattern, text)
161167
if match and len(match.groups()) >= 1:
162168
return match.group(1)

backend/analytics_server/mhq/store/repos/code.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,25 @@ def get_prs_merged_in_interval(
307307

308308
return query.all()
309309

310+
@rollback_on_exc
311+
def get_prs_merged_in_interval_by_numbers(
312+
self,
313+
repo_ids: List[str],
314+
interval: Interval,
315+
numbers: List[str],
316+
pr_filter: PRFilter = None,
317+
) -> List[PullRequest]:
318+
query = self._db.session.query(PullRequest).options(defer(PullRequest.data))
319+
320+
query = self._filter_prs_by_repo_ids(query, repo_ids)
321+
query = self._filter_prs_merged_in_interval(query, interval)
322+
query = self._filter_prs(query, pr_filter)
323+
query = query.filter(PullRequest.number.in_(numbers))
324+
325+
query = query.order_by(PullRequest.state_changed_at.asc())
326+
327+
return query.all()
328+
310329
@rollback_on_exc
311330
def get_pull_request_by_id(self, pr_id: str) -> PullRequest:
312331
return (
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
from datetime import datetime
2+
from unittest.mock import patch, Mock
3+
4+
import pytz
5+
from mhq.service.incidents.incidents import IncidentService, get_incident_service
6+
from mhq.utils.time import Interval
7+
from mhq.service.settings.models import IncidentPRsSetting
8+
from mhq.store.models.settings.configuration_settings import SettingType
9+
from mhq.store.models.code.filter import PRFilter
10+
from mhq.store.models.code import TeamRepos, PullRequest
11+
from mhq.service.settings.models import ConfigurationSettings
12+
from mhq.store.models import EntityType
13+
import pytest
14+
15+
16+
def mock_interval():
17+
start_time = datetime(2025, 1, 1, 0, 0, 0, tzinfo=pytz.UTC)
18+
end_time = datetime(2025, 3, 31, 0, 0, 0, tzinfo=pytz.UTC)
19+
return Interval(start_time, end_time)
20+
21+
22+
@pytest.fixture(autouse=True)
23+
def mock_apply_pr_filter():
24+
with patch("mhq.service.incidents.incidents.apply_pr_filter") as mock:
25+
mock.return_value = PRFilter()
26+
yield mock
27+
28+
29+
class FakeSettingsService:
30+
def get_settings(self, *args, **kwargs):
31+
filters = [
32+
{
33+
"field": "head_branch",
34+
"value": "^revert-(\\d+)$",
35+
},
36+
{
37+
"field": "title",
38+
"value": "^Revert PR #(\\d+).*",
39+
},
40+
]
41+
return ConfigurationSettings(
42+
entity_id="team_1",
43+
entity_type=EntityType.TEAM,
44+
specific_settings=IncidentPRsSetting(
45+
include_revert_prs=True, filters=filters
46+
),
47+
updated_by="user_1",
48+
created_at=datetime(2025, 1, 1, 0, 0, 0, tzinfo=pytz.UTC),
49+
updated_at=datetime(2025, 1, 1, 0, 0, 0, tzinfo=pytz.UTC),
50+
)
51+
52+
def get_settings_map(self, *args, **kwargs):
53+
return {
54+
SettingType.INCIDENT_PRS_SETTING: self.get_settings(*args, **kwargs),
55+
}
56+
57+
58+
class FakeCodeRepoService:
59+
def __init__(self, prs_using_filters, prs_using_numbers):
60+
self._prs_using_filters = prs_using_filters
61+
self._prs_using_numbers = prs_using_numbers
62+
63+
def get_active_team_repos_by_team_id(self, *args, **kwargs):
64+
return [
65+
TeamRepos(
66+
team_id="team_1",
67+
org_repo_id="repo_1",
68+
),
69+
TeamRepos(
70+
team_id="team_1",
71+
org_repo_id="repo_2",
72+
),
73+
]
74+
75+
def get_prs_merged_in_interval(self, *args, **kwargs):
76+
return self._prs_using_filters
77+
78+
def get_prs_merged_in_interval_by_numbers(self, *args, **kwargs):
79+
return self._prs_using_numbers
80+
81+
82+
class FakeIncidentsRepoService:
83+
pass
84+
85+
86+
def test_get_team_pr_incidents_no_filters():
87+
incident_service = get_incident_service()
88+
89+
mock_settings_service = Mock()
90+
mock_settings_service.get_settings.return_value.specific_settings = (
91+
IncidentPRsSetting(include_revert_prs=True, filters=[])
92+
)
93+
incident_service._settings_service = mock_settings_service
94+
95+
result = incident_service.get_team_pr_incidents(
96+
"team_1",
97+
mock_interval(),
98+
PRFilter(),
99+
)
100+
101+
assert result == []
102+
103+
104+
def test_get_team_pr_incidents_with_filters():
105+
106+
prs_using_filters = [
107+
PullRequest(
108+
id="pr_2_of_repo_1",
109+
repo_id="repo_1",
110+
number="2",
111+
head_branch="revert-1",
112+
),
113+
PullRequest(
114+
id="pr_4_of_repo_1",
115+
repo_id="repo_1",
116+
number="4",
117+
head_branch="branch_4",
118+
title="Revert PR #3 due to some reason",
119+
),
120+
]
121+
122+
prs_using_numbers = [
123+
PullRequest(
124+
id="pr_1_of_repo_1", repo_id="repo_1", number="1", head_branch="branch_1"
125+
),
126+
PullRequest(
127+
id="pr_3_of_repo_1",
128+
repo_id="repo_1",
129+
number="3",
130+
head_branch="branch_3",
131+
),
132+
]
133+
134+
incident_service = IncidentService(
135+
FakeIncidentsRepoService(),
136+
FakeSettingsService(),
137+
FakeCodeRepoService(prs_using_filters, prs_using_numbers),
138+
)
139+
140+
expected_result_keys = [
141+
"pr_1_of_repo_1",
142+
"pr_3_of_repo_1",
143+
]
144+
145+
result = incident_service.get_team_pr_incidents(
146+
"team_1",
147+
mock_interval(),
148+
PRFilter(),
149+
)
150+
151+
assert expected_result_keys == [incident.key for incident in result]
152+
153+
154+
def test_get_team_pr_incidents_with_multiple_repos_but_no_incidents():
155+
156+
prs_using_filters = [
157+
PullRequest(
158+
id="pr_2_of_repo_1",
159+
repo_id="repo_1",
160+
number="2",
161+
head_branch="revert-1",
162+
),
163+
PullRequest(
164+
id="pr_2_of_repo_2",
165+
repo_id="repo_2",
166+
number="2",
167+
head_branch="revert-1",
168+
),
169+
]
170+
171+
prs_using_numbers = []
172+
173+
incident_service = IncidentService(
174+
FakeIncidentsRepoService(),
175+
FakeSettingsService(),
176+
FakeCodeRepoService(prs_using_filters, prs_using_numbers),
177+
)
178+
179+
expected_result_keys = []
180+
181+
result = incident_service.get_team_pr_incidents(
182+
"team_1",
183+
mock_interval(),
184+
PRFilter(),
185+
)
186+
187+
assert expected_result_keys == [incident.key for incident in result]
188+
189+
190+
def test_get_team_pr_incidents_with_multiple_repos_and_incidents():
191+
192+
prs_using_filters = [
193+
PullRequest(
194+
id="pr_2_of_repo_1", repo_id="repo_1", number="2", head_branch="revert-1"
195+
),
196+
PullRequest(
197+
id="pr_2_of_repo_2",
198+
repo_id="repo_2",
199+
number="2",
200+
head_branch="revert-1",
201+
),
202+
]
203+
204+
prs_using_numbers = [
205+
PullRequest(
206+
id="pr_1_of_repo_1",
207+
repo_id="repo_1",
208+
number="1",
209+
head_branch="branch_1",
210+
),
211+
PullRequest(
212+
id="pr_1_of_repo_2",
213+
repo_id="repo_2",
214+
number="1",
215+
head_branch="branch_1",
216+
),
217+
]
218+
219+
incident_service = IncidentService(
220+
FakeIncidentsRepoService(),
221+
FakeSettingsService(),
222+
FakeCodeRepoService(prs_using_filters, prs_using_numbers),
223+
)
224+
225+
expected_result_keys = [
226+
"pr_1_of_repo_1",
227+
"pr_1_of_repo_2",
228+
]
229+
230+
result = incident_service.get_team_pr_incidents(
231+
"team_1",
232+
mock_interval(),
233+
PRFilter(),
234+
)
235+
236+
assert expected_result_keys == [incident.key for incident in result]

0 commit comments

Comments
 (0)