|
24 | 24 | import argparse |
25 | 25 |
|
26 | 26 | def options(): |
| 27 | + def resolved_path(path): |
| 28 | + return pathlib.Path(path).expanduser().resolve() |
27 | 29 | p = argparse.ArgumentParser(description=__doc__) |
28 | | - p.add_argument("--hash-only", action="store_true") |
29 | | - p.add_argument("sources", type=pathlib.Path, nargs="+") |
30 | | - return p.parse_args() |
31 | | - |
32 | | - |
33 | | -TIMEOUT = 20 |
34 | | - |
35 | | -def warn(message: str) -> None: |
36 | | - print(f"WARNING: {message}", file=sys.stderr) |
37 | | - |
38 | | - |
39 | | -@dataclass |
40 | | -class Endpoint: |
41 | | - name: str |
42 | | - href: str |
43 | | - ssh: typing.Optional[str] = None |
44 | | - headers: typing.Dict[str, str] = dataclasses.field(default_factory=dict) |
45 | | - |
46 | | - def update_headers(self, d: typing.Iterable[typing.Tuple[str, str]]): |
47 | | - self.headers.update((k.capitalize(), v) for k, v in d) |
48 | | - |
49 | | - |
50 | | -class NoEndpointsFound(Exception): |
51 | | - pass |
52 | | - |
53 | | - |
54 | | -opts = options() |
55 | | -sources = [p.resolve() for p in opts.sources] |
56 | | -source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources)) |
57 | | -source_dir = subprocess.check_output( |
58 | | - ["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True |
59 | | -).strip() |
| 30 | + excl = p.add_mutually_exclusive_group(required=True) |
| 31 | + excl.add_argument("--hash-only", action="store_true") |
| 32 | + excl.add_argument("--git-lfs", type=resolved_path) |
| 33 | + p.add_argument("sources", type=resolved_path, nargs="+") |
| 34 | + opts = p.parse_args() |
| 35 | + source_dir = pathlib.Path(os.path.commonpath(src.parent for src in opts.sources)) |
| 36 | + opts.source_dir = subprocess.check_output( |
| 37 | + ["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True |
| 38 | + ).strip() |
| 39 | + return opts |
60 | 40 |
|
61 | 41 |
|
62 | 42 | def get_env(s: str, sep: str = "=") -> typing.Iterable[typing.Tuple[str, str]]: |
63 | 43 | for m in re.finditer(rf"(.*?){sep}(.*)", s, re.M): |
64 | 44 | yield m.groups() |
65 | 45 |
|
66 | 46 |
|
67 | | -def git(*args, **kwargs): |
68 | | - proc = subprocess.run( |
69 | | - ("git",) + args, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs |
70 | | - ) |
71 | | - return proc.stdout.strip() if proc.returncode == 0 else None |
72 | | - |
73 | | - |
74 | | -endpoint_re = re.compile(r"^Endpoint(?: \((.*)\))?$") |
75 | | - |
76 | | - |
77 | | -def get_endpoint_addresses() -> typing.Iterable[Endpoint]: |
78 | | - """Get all lfs endpoints, including SSH if present""" |
79 | | - lfs_env_items = get_env( |
80 | | - subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir) |
81 | | - ) |
82 | | - current_endpoint = None |
83 | | - for k, v in lfs_env_items: |
84 | | - m = endpoint_re.match(k) |
85 | | - if m: |
86 | | - if current_endpoint: |
87 | | - yield current_endpoint |
88 | | - href, _, _ = v.partition(" ") |
89 | | - current_endpoint = Endpoint(name=m[1] or "default", href=href) |
90 | | - elif k == " SSH" and current_endpoint: |
91 | | - current_endpoint.ssh = v |
92 | | - if current_endpoint: |
93 | | - yield current_endpoint |
94 | | - |
95 | | - |
96 | | -def get_endpoints() -> typing.Iterable[Endpoint]: |
97 | | - for endpoint in get_endpoint_addresses(): |
98 | | - endpoint.headers = { |
99 | | - "Content-Type": "application/vnd.git-lfs+json", |
100 | | - "Accept": "application/vnd.git-lfs+json", |
101 | | - } |
102 | | - if endpoint.ssh: |
103 | | - # see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md |
104 | | - server, _, path = endpoint.ssh.partition(":") |
105 | | - ssh_command = shutil.which( |
106 | | - os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")) |
107 | | - ) |
108 | | - assert ssh_command, "no ssh command found" |
109 | | - cmd = [ |
110 | | - ssh_command, |
111 | | - "-oStrictHostKeyChecking=accept-new", |
112 | | - server, |
113 | | - "git-lfs-authenticate", |
114 | | - path, |
115 | | - "download", |
116 | | - ] |
117 | | - try: |
118 | | - res = subprocess.run(cmd, stdout=subprocess.PIPE, timeout=TIMEOUT) |
119 | | - except subprocess.TimeoutExpired: |
120 | | - warn(f"ssh timed out when connecting to {server}, ignoring {endpoint.name} endpoint") |
121 | | - continue |
122 | | - if res.returncode != 0: |
123 | | - warn(f"ssh failed when connecting to {server}, ignoring {endpoint.name} endpoint") |
124 | | - continue |
125 | | - ssh_resp = json.loads(res.stdout) |
126 | | - endpoint.href = ssh_resp.get("href", endpoint) |
127 | | - endpoint.update_headers(ssh_resp.get("header", {}).items()) |
128 | | - url = urlparse(endpoint.href) |
129 | | - # this is how actions/checkout persist credentials |
130 | | - # see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63 |
131 | | - auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader") or "" |
132 | | - endpoint.update_headers(get_env(auth, sep=": ")) |
133 | | - if os.environ.get("GITHUB_TOKEN"): |
134 | | - endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}" |
135 | | - if "Authorization" not in endpoint.headers: |
136 | | - # last chance: use git credentials (possibly backed by a credential helper like the one installed by gh) |
137 | | - # see https://git-scm.com/docs/git-credential |
138 | | - credentials = git( |
139 | | - "credential", |
140 | | - "fill", |
141 | | - check=True, |
142 | | - # drop leading / from url.path |
143 | | - input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n", |
144 | | - ) |
145 | | - if credentials is None: |
146 | | - warn(f"no authorization method found, ignoring {endpoint.name} endpoint") |
147 | | - continue |
148 | | - credentials = dict(get_env(credentials)) |
149 | | - auth = base64.b64encode( |
150 | | - f'{credentials["username"]}:{credentials["password"]}'.encode() |
151 | | - ).decode("ascii") |
152 | | - endpoint.headers["Authorization"] = f"Basic {auth}" |
153 | | - yield endpoint |
154 | | - |
155 | | - |
156 | | -# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md |
157 | | -def get_locations(objects): |
| 47 | +def get_locations(objects, opts): |
158 | 48 | ret = ["local" for _ in objects] |
159 | 49 | indexes = [i for i, o in enumerate(objects) if o] |
160 | | - if not indexes: |
161 | | - # all objects are local, do not send an empty request as that would be an error |
162 | | - return ret |
163 | 50 | if opts.hash_only: |
164 | 51 | for i in indexes: |
165 | 52 | ret[i] = objects[i]["oid"] |
166 | | - return ret |
167 | | - data = { |
168 | | - "operation": "download", |
169 | | - "transfers": ["basic"], |
170 | | - "objects": [objects[i] for i in indexes], |
171 | | - "hash_algo": "sha256", |
172 | | - } |
173 | | - for endpoint in get_endpoints(): |
174 | | - req = urllib.request.Request( |
175 | | - f"{endpoint.href}/objects/batch", |
176 | | - headers=endpoint.headers, |
177 | | - data=json.dumps(data).encode("ascii"), |
178 | | - ) |
179 | | - try: |
180 | | - with urllib.request.urlopen(req, timeout=TIMEOUT) as resp: |
181 | | - data = json.load(resp) |
182 | | - assert len(data["objects"]) == len( |
183 | | - indexes |
184 | | - ), f"received {len(data)} objects, expected {len(indexes)}" |
185 | | - for i, resp in zip(indexes, data["objects"]): |
186 | | - ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}' |
187 | | - return ret |
188 | | - except urllib.error.URLError as e: |
189 | | - warn(f"encountered {type(e).__name__} {e}, ignoring endpoint {endpoint.name}") |
190 | | - continue |
191 | | - except KeyError: |
192 | | - warn(f"encountered malformed response, ignoring endpoint {endpoint.name}:\n{json.dumps(data, indent=2)}") |
193 | | - continue |
194 | | - raise NoEndpointsFound |
195 | | - |
| 53 | + else: |
| 54 | + cmd = [opts.git_lfs, "ls-urls", "--json"] |
| 55 | + cmd.extend(objects[i]["path"] for i in indexes) |
| 56 | + data = json.loads(subprocess.check_output(cmd, cwd=opts.source_dir)) |
| 57 | + for i, f in zip(indexes, data["files"]): |
| 58 | + ret[i] = f'{f["oid"]} {f["url"]}' |
| 59 | + return ret |
196 | 60 |
|
197 | 61 | def get_lfs_object(path): |
198 | 62 | with open(path, "rb") as fileobj: |
199 | 63 | lfs_header = "version https://git-lfs.github.com/spec".encode() |
200 | 64 | actual_header = fileobj.read(len(lfs_header)) |
201 | | - sha256 = size = None |
202 | 65 | if lfs_header != actual_header: |
203 | 66 | return None |
204 | 67 | data = dict(get_env(fileobj.read().decode("ascii"), sep=" ")) |
205 | 68 | assert data["oid"].startswith("sha256:"), f"unknown oid type: {data['oid']}" |
206 | 69 | _, _, sha256 = data["oid"].partition(":") |
207 | | - size = int(data["size"]) |
208 | | - return {"oid": sha256, "size": size} |
| 70 | + return {"path": path, "oid": sha256} |
209 | 71 |
|
210 | 72 |
|
211 | | -try: |
212 | | - objects = [get_lfs_object(src) for src in sources] |
213 | | - for resp in get_locations(objects): |
| 73 | +def main(): |
| 74 | + opts = options() |
| 75 | + objects = [get_lfs_object(src) for src in opts.sources] |
| 76 | + for resp in get_locations(objects, opts): |
214 | 77 | print(resp) |
215 | | -except NoEndpointsFound as e: |
216 | | - print("""\ |
217 | | -ERROR: no valid endpoints found, your git authentication method might be currently unsupported by this script. |
218 | | -You can bypass this error by running from semmle-code (this might take a while): |
219 | | - git config lfs.fetchexclude "" |
220 | | - git -C ql config lfs.fetchinclude \\* |
221 | | - git lfs fetch && git lfs checkout |
222 | | - cd ql |
223 | | - git lfs fetch && git lfs checkout""", file=sys.stderr) |
224 | | - sys.exit(1) |
| 78 | + |
| 79 | +if __name__ == "__main__": |
| 80 | + main() |
0 commit comments