|
20 | 20 | import hmac |
21 | 21 | import logging |
22 | 22 | import sys |
23 | | -from typing import Callable, Optional |
| 23 | +from typing import Any, Callable, Dict, Optional |
24 | 24 |
|
25 | 25 | import requests |
26 | 26 | import yaml |
27 | 27 |
|
| 28 | +_DEFAULT_SERVER_URL = "http://localhost:8008" |
| 29 | + |
28 | 30 |
|
29 | 31 | def request_registration( |
30 | 32 | user: str, |
@@ -203,31 +205,76 @@ def main() -> None: |
203 | 205 |
|
204 | 206 | parser.add_argument( |
205 | 207 | "server_url", |
206 | | - default="https://localhost:8448", |
207 | 208 | nargs="?", |
208 | | - help="URL to use to talk to the homeserver. Defaults to " |
209 | | - " 'https://localhost:8448'.", |
| 209 | + help="URL to use to talk to the homeserver. By default, tries to find a " |
| 210 | + "suitable URL from the configuration file. Otherwise, defaults to " |
| 211 | + f"'{_DEFAULT_SERVER_URL}'.", |
210 | 212 | ) |
211 | 213 |
|
212 | 214 | args = parser.parse_args() |
213 | 215 |
|
214 | 216 | if "config" in args and args.config: |
215 | 217 | config = yaml.safe_load(args.config) |
| 218 | + |
| 219 | + if args.shared_secret: |
| 220 | + secret = args.shared_secret |
| 221 | + else: |
| 222 | + # argparse should check that we have either config or shared secret |
| 223 | + assert config |
| 224 | + |
216 | 225 | secret = config.get("registration_shared_secret", None) |
217 | 226 | if not secret: |
218 | 227 | print("No 'registration_shared_secret' defined in config.") |
219 | 228 | sys.exit(1) |
| 229 | + |
| 230 | + if args.server_url: |
| 231 | + server_url = args.server_url |
| 232 | + elif config: |
| 233 | + listening_port = _find_client_listener(config) |
| 234 | + if listening_port: |
| 235 | + server_url = f"http://{listening_port}" |
| 236 | + else: |
| 237 | + server_url = _DEFAULT_SERVER_URL |
| 238 | + print( |
| 239 | + "Unable to find a suitable HTTP listener in the configuration file. " |
| 240 | + f"Trying {server_url} as a last resort.", |
| 241 | + file=sys.stderr, |
| 242 | + ) |
220 | 243 | else: |
221 | | - secret = args.shared_secret |
| 244 | + server_url = _DEFAULT_SERVER_URL |
| 245 | + print( |
| 246 | + f"No server url or configuration file given. Defaulting to {server_url}.", |
| 247 | + file=sys.stderr, |
| 248 | + ) |
222 | 249 |
|
223 | 250 | admin = None |
224 | 251 | if args.admin or args.no_admin: |
225 | 252 | admin = args.admin |
226 | 253 |
|
227 | 254 | register_new_user( |
228 | | - args.user, args.password, args.server_url, secret, admin, args.user_type |
| 255 | + args.user, args.password, server_url, secret, admin, args.user_type |
229 | 256 | ) |
230 | 257 |
|
231 | 258 |
|
| 259 | +def _find_client_listener(config: Dict[str, Any]) -> Optional[str]: |
| 260 | + # try to find a listener in the config. Returns a host:port pair |
| 261 | + for listener in config.get("listeners", []): |
| 262 | + if listener.get("type") != "http" or listener.get("tls", False): |
| 263 | + continue |
| 264 | + |
| 265 | + if not any( |
| 266 | + name == "client" |
| 267 | + for resource in listener.get("resources", []) |
| 268 | + for name in resource.get("names", []) |
| 269 | + ): |
| 270 | + continue |
| 271 | + |
| 272 | + # TODO: consider bind_addresses |
| 273 | + return f"localhost:{listener['port']}" |
| 274 | + |
| 275 | + # no suitable listeners? |
| 276 | + return None |
| 277 | + |
| 278 | + |
232 | 279 | if __name__ == "__main__": |
233 | 280 | main() |
0 commit comments