6
6
import subprocess
7
7
import sys
8
8
import tempfile
9
- from functools import cached_property , partial
9
+ from functools import cached_property
10
10
from pathlib import Path
11
11
from types import TracebackType
12
12
from typing import Any , Optional
@@ -50,18 +50,11 @@ def clone_path(self) -> Path:
50
50
["git" , "clone" , self .repo , "repo" ], cwd = self .tempdir , check = True
51
51
)
52
52
clone_path = self .tempdir / "repo"
53
- run = partial (subprocess .run , cwd = clone_path , check = True )
54
- if (
55
- run (
56
- ["git" , "ls-remote" , "origin" , self .branch ],
57
- capture_output = True ,
58
- )
59
- .stdout .decode () # type: ignore
60
- .strip ()
53
+ for cmd in (
54
+ ["git" , "checkout" , "-b" , self .branch ],
55
+ ["git" , "reset" , "--hard" , "origin/main" ],
61
56
):
62
- raise Exception (f'Branch "{ self .branch } " already exists on remote' )
63
- run (["git" , "checkout" , "-b" , self .branch ])
64
- run (["git" , "reset" , "--hard" , "origin/main" ])
57
+ subprocess .run (cmd , cwd = clone_path , check = True )
65
58
return clone_path
66
59
67
60
def cruft_attr (self , attr : str ) -> str :
@@ -114,13 +107,60 @@ def lint_test(self) -> None:
114
107
self .logger .error ("Resolve errors and exit shell to continue" )
115
108
self .shell ()
116
109
110
+ def find_existing_pr (self ) -> Optional [str ]:
111
+ with contextlib .suppress (
112
+ subprocess .CalledProcessError , json .JSONDecodeError , TypeError
113
+ ):
114
+ for pr in json .loads (
115
+ self .run (
116
+ [
117
+ "gh" ,
118
+ "pr" ,
119
+ "list" ,
120
+ "-H" ,
121
+ self .branch ,
122
+ "-B" ,
123
+ "main" ,
124
+ "--json" ,
125
+ "," .join (("url" , "headRefName" , "baseRefName" )),
126
+ ],
127
+ capture_output = True ,
128
+ check = True ,
129
+ ).stdout .decode ()
130
+ ):
131
+ pr_url = str (pr .pop ("url" ))
132
+ if pr == {"headRefName" : self .branch , "baseRefName" : "main" }:
133
+ return pr_url
134
+ return None
135
+
136
+ def close_existing_pr (self ) -> None :
137
+ # Locate existing PR
138
+ pr_url = self .find_existing_pr ()
139
+ if pr_url :
140
+ if self .dry_run :
141
+ self .logger .info (f"Would close existing PR { pr_url } " )
142
+ else :
143
+ self .run (["gh" , "pr" , "close" , pr_url ])
144
+ self .logger .info (f"Closed existing PR { pr_url } " )
145
+ if self .dry_run :
146
+ return
147
+ # Delete existing branch
148
+ delete_result = self .run (
149
+ ["git" , "push" , "origin" , f":{ self .branch } " ],
150
+ capture_output = True ,
151
+ check = False ,
152
+ )
153
+ if delete_result .returncode == 0 :
154
+ self .logger .info (f"Deleted existing remote branch { self .branch } " )
155
+
117
156
def open_pr (self , message : str ) -> None :
157
+ self .close_existing_pr ()
118
158
if self .dry_run :
119
159
self .logger .success ("Would open PR" )
120
160
return
121
161
self .run (["git" , "push" , "origin" , self .branch ])
122
162
commit_title , _ , * commit_body = message .splitlines ()
123
- self .run (
163
+ pr_url = self .run (
124
164
[
125
165
"gh" ,
126
166
"pr" ,
@@ -135,5 +175,6 @@ def open_pr(self, message: str) -> None:
135
175
self .branch ,
136
176
],
137
177
input = os .linesep .join (commit_body ).encode ("utf-8" ),
138
- )
139
- self .logger .success ("Opened PR" )
178
+ capture_output = True ,
179
+ ).stdout .decode ()
180
+ self .logger .success (f"Opened PR { pr_url } " )
0 commit comments