Skip to content

Commit 557ad85

Browse files
authored
[MergeQueues] Delete branch if PR is merged (#193)
### Summary Github's delete branch logic doesn't work if SGTM added the pull request to the merge queue. This will make SGTM delete branches if they haven't been cleaned up yet. Asana tasks: https://app.asana.com/0/1125126232217429/1207661000934486/f Relevant deployment: CC: @suzyng83209 @prebeta @vn6 @michael-huang87 ### Test Plan Run on test sgtm deployment ### Risks Pull Request: #193 Pull Request synchronized with [Asana task](https://app.asana.com/0/0/1207680708784404)
1 parent a1c08c8 commit 557ad85

File tree

8 files changed

+109
-9
lines changed

8 files changed

+109
-9
lines changed

src/github/client.py

+37-9
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,41 @@
11
import requests
22
from requests.auth import HTTPBasicAuth
33

4-
from github import PullRequest # type: ignore
4+
from github.PullRequest import PullRequest
5+
from github.GithubException import GithubException
56
from src.github.get_app_token import sgtm_github_auth
67
from src.logger import logger
78

89
gh_client = sgtm_github_auth.get_rest_client()
910

1011

11-
def _get_pull_request(owner: str, repository: str, number: int) -> PullRequest: # type: ignore
12+
def _get_pull_request(owner: str, repository: str, number: int) -> PullRequest:
1213
repo = gh_client.get_repo(f"{owner}/{repository}")
1314
pr = repo.get_pull(number)
14-
return pr # type: ignore
15+
return pr
1516

1617

1718
def edit_pr_description(owner: str, repository: str, number: int, description: str):
1819
pr = _get_pull_request(owner, repository, number)
19-
pr.edit(body=description) # type: ignore
20+
pr.edit(body=description)
2021

2122

2223
def edit_pr_title(owner: str, repository: str, number: int, title: str):
2324
pr = _get_pull_request(owner, repository, number)
24-
pr.edit(title=title) # type: ignore
25+
pr.edit(title=title)
2526

2627

2728
def add_pr_comment(owner: str, repository: str, number: int, comment: str):
2829
pr = _get_pull_request(owner, repository, number)
29-
pr.create_issue_comment(comment) # type: ignore
30+
pr.create_issue_comment(comment)
3031

3132

3233
def set_pull_request_assignee(owner: str, repository: str, number: int, assignee: str):
3334
repo = gh_client.get_repo(f"{owner}/{repository}")
3435
# Using get_issue here because get_pull returns a pull request which only
3536
# allows you to *add* an assignee, not set the assignee.
3637
pr = repo.get_issue(number)
37-
pr.edit(assignee=assignee) # type: ignore
38+
pr.edit(assignee=assignee)
3839

3940

4041
def merge_pull_request(owner: str, repository: str, number: int, title: str, body: str):
@@ -44,13 +45,15 @@ def merge_pull_request(owner: str, repository: str, number: int, title: str, bod
4445
# which we rely on for code review tests.
4546
title_with_number = f"{title} (#{number})"
4647
try:
47-
pr.enable_automerge(commit_headline=title_with_number, commit_body=body) # type: ignore
48+
pr.enable_automerge(commit_headline=title_with_number, commit_body=body)
4849
except Exception as e:
4950
logger.info(
5051
f"Failed to enable automerge for PR {title_with_number}, with error {e}"
5152
)
5253
logger.info("Merging PR manually")
53-
pr.merge(commit_title=title_with_number, commit_message=body, merge_method="squash") # type: ignore
54+
pr.merge(
55+
commit_title=title_with_number, commit_message=body, merge_method="squash"
56+
)
5457

5558

5659
def rerequest_check_run(owner: str, repository: str, check_run_id: int):
@@ -60,3 +63,28 @@ def rerequest_check_run(owner: str, repository: str, check_run_id: int):
6063
)
6164
# Some check runs cannot be rerequested. See https://docs.github.com/en/rest/checks/runs?apiVersion=2022-11-28#rerequest-a-check-run--status-codes
6265
return requests.post(url, auth=auth).status_code == 201
66+
67+
68+
def delete_branch_if_exists(owner: str, repo_name: str, branch_name: str):
69+
"""
70+
Deletes a branch from a GitHub repository if it exists.
71+
72+
Args:
73+
owner (str): The owner of the repository.
74+
repo_name (str): The name of the repository.
75+
branch_name (str): The name of the branch to delete.
76+
"""
77+
try:
78+
repo = gh_client.get_repo(f"{owner}/{repo_name}")
79+
# Attempt to get the branch, will raise a 404 error if not found
80+
repo.get_branch(branch_name)
81+
ref = f"heads/{branch_name}"
82+
git_ref = repo.get_git_ref(ref)
83+
git_ref.delete()
84+
logger.info(f"Branch '{branch_name}' deleted successfully.")
85+
except GithubException as e:
86+
if e.status == 404:
87+
logger.info(f"Branch '{branch_name}' does not exist or is already deleted.")
88+
else:
89+
logger.error(f"Error deleting branch: {e}")
90+
raise

src/github/graphql/fragments/FullPullRequest.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
fragment FullPullRequest on PullRequest {
77
id
88
baseRefName
9+
headRefName
910
body
1011
bodyHTML
1112
title

src/github/logic.py

+8
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,14 @@ def maybe_automerge_pull_request(pull_request: PullRequest) -> bool:
296296
return False
297297

298298

299+
def maybe_delete_branch_if_merged(pull_request: PullRequest):
300+
if pull_request.merged():
301+
owner = pull_request.repository_owner_handle()
302+
repo_name = pull_request.repository_name()
303+
branch_name = pull_request.head_ref_name()
304+
github_client.delete_branch_if_exists(owner, repo_name, branch_name)
305+
306+
299307
# ----------------------------------------------------------------------------------
300308
# Automerge helpers
301309
# ----------------------------------------------------------------------------------

src/github/models/pull_request.py

+3
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,6 @@ def labels(self) -> List[Label]:
219219

220220
def base_ref_name(self) -> str:
221221
return self._raw["baseRefName"]
222+
223+
def head_ref_name(self) -> str:
224+
return self._raw["headRefName"]

src/github/webhook.py

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def _handle_pull_request_webhook(payload: dict) -> HttpResponse:
1818
pull_request = graphql_client.get_pull_request(pull_request_id)
1919
# a label change will trigger this webhook, so it may trigger automerge
2020
github_logic.maybe_automerge_pull_request(pull_request)
21+
if payload["action"] == "closed":
22+
github_logic.maybe_delete_branch_if_merged(pull_request)
2123
github_logic.maybe_add_automerge_warning_comment(pull_request)
2224
github_controller.upsert_pull_request(pull_request)
2325
return HttpResponse("200")

test/github/test_logic.py

+21
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,27 @@ def test_does_not_add_warning_comment_if_pr_is_approved(
602602
add_pr_comment_mock.assert_not_called()
603603

604604

605+
@patch.object(github_client, "delete_branch_if_exists")
606+
class TestMaybeDeleteBranchIfMerged(unittest.TestCase):
607+
def test_maybe_delete_branch_if_merged(self, mock_delete_branch):
608+
pull_request = build(
609+
builder.pull_request().merged(True).head_ref_name("feature-branch")
610+
)
611+
github_logic.maybe_delete_branch_if_merged(pull_request)
612+
mock_delete_branch.assert_called_once_with(
613+
pull_request.repository_owner_handle(),
614+
pull_request.repository_name(),
615+
"feature-branch",
616+
)
617+
618+
def test_do_not_delete_branch_if_closed(self, mock_delete_branch):
619+
pull_request = build(
620+
builder.pull_request().merged(False).head_ref_name("feature-branch")
621+
)
622+
github_logic.maybe_delete_branch_if_merged(pull_request)
623+
mock_delete_branch.assert_not_called()
624+
625+
605626
if __name__ == "__main__":
606627
from unittest import main as run_tests
607628

test/github/test_webhook.py

+32
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,38 @@ def test_comment_deletion_when_review_not_found(
123123
delete_comment.assert_called_once_with(self.COMMENT_NODE_ID)
124124

125125

126+
@patch("src.github.controller.upsert_pull_request")
127+
@patch("src.github.logic.maybe_automerge_pull_request")
128+
@patch("src.github.logic.maybe_delete_branch_if_merged")
129+
@patch("src.github.graphql.client.get_pull_request")
130+
class TestHandlePullRequestWebhookClosed(MockDynamoDbTestCase):
131+
PULL_REQUEST_NODE_ID = "abcdef"
132+
133+
def test_handle_pull_request_webhook_when_closed(
134+
self,
135+
get_pull_request,
136+
maybe_delete_branch_if_merged,
137+
maybe_automerge_pull_request,
138+
upsert_pull_request,
139+
):
140+
payload = {
141+
"action": "closed",
142+
"pull_request": {
143+
"node_id": self.PULL_REQUEST_NODE_ID,
144+
},
145+
}
146+
147+
pull_request = MagicMock(spec=PullRequest)
148+
get_pull_request.return_value = pull_request
149+
150+
response = webhook._handle_pull_request_webhook(payload)
151+
self.assertEqual(response.status_code, "200")
152+
153+
maybe_automerge_pull_request.assert_called_once()
154+
maybe_delete_branch_if_merged.assert_called_once()
155+
upsert_pull_request.assert_called_once()
156+
157+
126158
if __name__ == "__main__":
127159
from unittest import main as run_tests
128160

test/impl/builders/pull_request_builder.py

+5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(self, body: str = ""):
2727
"number": pr_number,
2828
"body": body,
2929
"baseRefName": create_uuid(),
30+
"headRefName": create_uuid(),
3031
"title": create_uuid(),
3132
"url": "https://www.github.com/foo/pulls/" + str(pr_number),
3233
"assignees": {"nodes": []},
@@ -168,6 +169,10 @@ def base_ref_name(self, base_ref_name: str):
168169
self.raw_pr["baseRefName"] = base_ref_name
169170
return self
170171

172+
def head_ref_name(self, head_ref_name: str):
173+
self.raw_pr["headRefName"] = head_ref_name
174+
return self
175+
171176
def build(self) -> PullRequest:
172177
return PullRequest(self.raw_pr)
173178

0 commit comments

Comments
 (0)