Skip to content

Commit 8661e22

Browse files
committed
feat: add support for optional cookie in API requests and enhance command context handling
1 parent 8b9f32d commit 8661e22

File tree

1 file changed

+70
-40
lines changed

1 file changed

+70
-40
lines changed

CSOJ-cli.py

Lines changed: 70 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import sys
77
from pathlib import Path
88
import base64
9-
import fnmatch # 导入用于文件名模式匹配的库
9+
import fnmatch
1010

1111
import click
1212
import requests
@@ -79,10 +79,13 @@ def load_config():
7979
sys.exit(1)
8080

8181

82-
def get_api_headers():
82+
def get_api_headers(cookie=None):
8383
"""Gets the necessary headers for authenticated API calls."""
8484
config = load_config()
85-
return {"Authorization": f"Bearer {config['jwt']}"}
85+
headers = {"Authorization": f"Bearer {config['jwt']}"}
86+
if cookie:
87+
headers['Cookie'] = cookie
88+
return headers
8689

8790

8891
def handle_api_error(response):
@@ -98,13 +101,17 @@ def handle_api_error(response):
98101

99102
# --- CLI Command Group ---
100103
@click.group(cls=AliasedGroup)
101-
def cli():
104+
@click.option('--cookie', help='Optional cookie string to include in all requests.')
105+
@click.pass_context
106+
def cli(ctx, cookie):
102107
"""
103108
CSOJ-cli: A command-line tool for the CSOJ Online Judge platform.
104109
105110
Supports both full command names and short aliases (e.g., 'leaderboard' or 'lb').
111+
The --cookie option can be used with any command.
106112
"""
107-
pass
113+
ctx.ensure_object(dict)
114+
ctx.obj['cookie'] = cookie
108115

109116

110117
# --- CLI Commands ---
@@ -124,14 +131,17 @@ def login(domain, jwt):
124131

125132

126133
@cli.command("ls-contests", help="List all available contests. (alias: lsc)")
127-
def list_contests():
134+
@click.pass_context
135+
def list_contests(ctx):
128136
config = load_config()
129137
domain = config["domain"]
138+
cookie = ctx.obj.get('cookie')
139+
headers = get_api_headers(cookie)
130140
url = f"{domain}/api/v1/contests"
131141

132142
try:
133143
with console.status("[bold green]Fetching contests...[/]"):
134-
response = requests.get(url, timeout=10)
144+
response = requests.get(url, headers=headers, timeout=10)
135145

136146
if response.status_code != 200:
137147
handle_api_error(response)
@@ -165,14 +175,17 @@ def list_contests():
165175

166176
@cli.command("ls-problems", help="List all problems for a given contest. (alias: lsp)")
167177
@click.argument("contest_id")
168-
def list_problems(contest_id):
178+
@click.pass_context
179+
def list_problems(ctx, contest_id):
169180
config = load_config()
170181
domain = config["domain"]
182+
cookie = ctx.obj.get('cookie')
183+
headers = get_api_headers(cookie)
171184
contest_url = f"{domain}/api/v1/contests/{contest_id}"
172185

173186
try:
174187
with console.status("[bold green]Fetching contest details...[/]"):
175-
contest_response = requests.get(contest_url, timeout=10)
188+
contest_response = requests.get(contest_url, headers=headers, timeout=10)
176189

177190
if contest_response.status_code != 200:
178191
handle_api_error(contest_response)
@@ -200,7 +213,7 @@ def list_problems(contest_id):
200213
task = progress.add_task("[green]Fetching problems...", total=len(problem_ids))
201214
for prob_id in problem_ids:
202215
problem_url = f"{domain}/api/v1/problems/{prob_id}"
203-
problem_response = requests.get(problem_url, timeout=10)
216+
problem_response = requests.get(problem_url, headers=headers, timeout=10)
204217
if problem_response.status_code == 200:
205218
problem_data = problem_response.json().get("data", {})
206219
table.add_row(problem_data.get("id", "N/A"), problem_data.get("name", "N/A"))
@@ -216,10 +229,12 @@ def list_problems(contest_id):
216229

217230
@cli.command(help="Register for a specific contest. (alias: reg)")
218231
@click.argument("contest_id")
219-
def register(contest_id):
232+
@click.pass_context
233+
def register(ctx, contest_id):
220234
config = load_config()
221235
domain = config["domain"]
222-
headers = get_api_headers()
236+
cookie = ctx.obj.get('cookie')
237+
headers = get_api_headers(cookie)
223238
url = f"{domain}/api/v1/contests/{contest_id}/register"
224239

225240
try:
@@ -241,16 +256,18 @@ def register(contest_id):
241256
@cli.command(help="Submit a file or folder to a problem. (alias: sub)")
242257
@click.argument("problem_id")
243258
@click.argument("path", type=click.Path(exists=True, resolve_path=True, file_okay=True, dir_okay=True))
244-
def submit(problem_id, path):
259+
@click.pass_context
260+
def submit(ctx, problem_id, path):
245261
config = load_config()
246262
domain = config["domain"]
247-
headers = get_api_headers()
263+
cookie = ctx.obj.get('cookie')
264+
headers = get_api_headers(cookie)
248265
submit_url = f"{domain}/api/v1/problems/{problem_id}/submit"
249266
problem_url = f"{domain}/api/v1/problems/{problem_id}"
250267
submit_path = Path(path)
251-
files_to_upload = []
252268
skipped_files = []
253-
prepared_files = []
269+
files_to_upload_info = []
270+
prepared_files_for_request = []
254271

255272
try:
256273
# Step 1: Fetch problem details to get upload rules
@@ -275,7 +292,7 @@ def submit(problem_id, path):
275292

276293
for file_path in all_files_to_check:
277294
if submit_path.is_dir():
278-
relative_path = file_path.relative_to(submit_path)
295+
relative_path = file_path.relative_to(submit_path.parent)
279296
else:
280297
relative_path = Path(file_path.name)
281298

@@ -286,7 +303,7 @@ def submit(problem_id, path):
286303
is_allowed = any(fnmatch.fnmatch(str(relative_path), pattern) for pattern in upload_rules)
287304

288305
if is_allowed:
289-
files_to_upload.append({'absolute_path': file_path, 'relative_path': relative_path})
306+
files_to_upload_info.append({'absolute_path': file_path, 'relative_path': relative_path})
290307
else:
291308
skipped_files.append(str(relative_path))
292309

@@ -296,27 +313,26 @@ def submit(problem_id, path):
296313
for skipped in skipped_files:
297314
console.print(f" - [dim]{skipped}[/dim]")
298315

299-
if not files_to_upload:
316+
if not files_to_upload_info:
300317
console.print("[bold red]Error:[/bold red] No valid files found for submission after filtering. Aborting.")
301318
sys.exit(1)
302319

303320
# Step 4: Prepare files for upload by opening them
304-
for file_info in files_to_upload:
321+
for file_info in files_to_upload_info:
305322
file_obj = open(file_info['absolute_path'], 'rb')
306-
b64_name = base64.b64encode(str(file_info['relative_path']).encode()).decode()
307-
prepared_files.append(('files', (b64_name, file_obj)))
323+
b64_name = base64.b64encode(file_info['relative_path'].as_posix().encode()).decode()
324+
prepared_files_for_request.append(('files', (b64_name, file_obj)))
308325

309-
console.print(f"Uploading {len(prepared_files)} file(s) to problem [cyan]{problem_id}[/cyan]...")
310-
response = requests.post(submit_url, headers=headers, files=prepared_files, timeout=60)
326+
console.print(f"Uploading {len(prepared_files_for_request)} file(s) to problem [cyan]{problem_id}[/cyan]...")
327+
response = requests.post(submit_url, headers=headers, files=prepared_files_for_request, timeout=60)
311328

312329
if response.status_code == 403:
313-
# Check for ban message first
314330
try:
315331
error_data = response.json()
316332
if "banned" in error_data.get("message", "").lower():
317-
handle_api_error(response) # Let the handler print the specific ban message
333+
handle_api_error(response)
318334
except json.JSONDecodeError:
319-
pass # Fallback to generic registration message
335+
pass
320336

321337
console.print("[bold red]Submission Forbidden (HTTP 403).[/bold red]")
322338
console.print("This may be because you have not registered for the contest.")
@@ -340,12 +356,12 @@ def submit(problem_id, path):
340356
sys.exit(1)
341357
finally:
342358
# Step 5: Ensure all opened file objects are closed
343-
for _, (_, file_obj) in prepared_files:
359+
for _, (_, file_obj) in prepared_files_for_request:
344360
if file_obj:
345361
file_obj.close()
346362

347363

348-
async def stream_logs(domain, jwt, submission_id, container_id, step_name):
364+
async def stream_logs(domain, jwt, submission_id, container_id, step_name, cookie=None):
349365
"""Coroutine to connect to WebSocket and stream logs."""
350366
ws_protocol = "wss" if domain.startswith("https") else "ws"
351367
http_protocol = "https" if ws_protocol == "wss" else "http"
@@ -358,8 +374,12 @@ async def stream_logs(domain, jwt, submission_id, container_id, step_name):
358374
connecting_text.overflow = "fold"
359375
console.print(Panel(connecting_text, title=panel_title, border_style="green", title_align="left"))
360376

377+
ws_headers = {}
378+
if cookie:
379+
ws_headers['Cookie'] = cookie
380+
361381
try:
362-
async with websockets.connect(uri) as websocket:
382+
async with websockets.connect(uri, extra_headers=ws_headers) as websocket:
363383
async for message in websocket:
364384
try:
365385
log_data = json.loads(message)
@@ -388,10 +408,12 @@ async def stream_logs(domain, jwt, submission_id, container_id, step_name):
388408
@cli.command(help="Get the status and optionally logs of a submission. (alias: st)")
389409
@click.argument("submission_id")
390410
@click.option("--logs", "-l", is_flag=True, help="Stream logs for all workflow steps.")
391-
def status(submission_id, logs):
411+
@click.pass_context
412+
def status(ctx, submission_id, logs):
392413
config = load_config()
393414
domain = config["domain"]
394-
headers = get_api_headers()
415+
cookie = ctx.obj.get('cookie')
416+
headers = get_api_headers(cookie)
395417
sub_url = f"{domain}/api/v1/submissions/{submission_id}"
396418

397419
try:
@@ -444,7 +466,7 @@ def status(submission_id, logs):
444466
console.print(Panel(f"[dim]Logs for step '[blue]{step_name}[/blue]' are hidden by the problem setter.[/dim]", border_style="dim"))
445467
continue
446468

447-
asyncio.run(stream_logs(domain, config['jwt'], submission_id, container.get("id"), step_name))
469+
asyncio.run(stream_logs(domain, config['jwt'], submission_id, container.get("id"), step_name, cookie))
448470

449471
except requests.exceptions.RequestException as e:
450472
console.print(f"[bold red]Network Error:[/bold red] {e}")
@@ -453,14 +475,17 @@ def status(submission_id, logs):
453475

454476
@cli.command(help="Display the leaderboard for a contest. (alias: lb)")
455477
@click.argument("contest_id")
456-
def leaderboard(contest_id):
478+
@click.pass_context
479+
def leaderboard(ctx, contest_id):
457480
config = load_config()
458481
domain = config["domain"]
482+
cookie = ctx.obj.get('cookie')
483+
headers = get_api_headers(cookie)
459484

460485
try:
461486
with console.status("[bold green]Fetching contest details...[/]"):
462487
contest_url = f"{domain}/api/v1/contests/{contest_id}"
463-
contest_res = requests.get(contest_url, timeout=10)
488+
contest_res = requests.get(contest_url, headers=headers, timeout=10)
464489
if contest_res.status_code != 200:
465490
handle_api_error(contest_res)
466491
contest_data = contest_res.json().get("data", {})
@@ -469,7 +494,7 @@ def leaderboard(contest_id):
469494

470495
with console.status(f"[bold green]Fetching leaderboard for {contest_name}...[/]"):
471496
leaderboard_url = f"{domain}/api/v1/contests/{contest_id}/leaderboard"
472-
leaderboard_res = requests.get(leaderboard_url, timeout=15)
497+
leaderboard_res = requests.get(leaderboard_url, headers=headers, timeout=15)
473498
if leaderboard_res.status_code != 200:
474499
handle_api_error(leaderboard_res)
475500
leaderboard_data = leaderboard_res.json().get("data", [])
@@ -507,14 +532,17 @@ def leaderboard(contest_id):
507532

508533
@cli.command("show-problem", help="Show the description and details of a problem. (alias: show)")
509534
@click.argument("problem_id")
510-
def show_problem(problem_id):
535+
@click.pass_context
536+
def show_problem(ctx, problem_id):
511537
config = load_config()
512538
domain = config["domain"]
539+
cookie = ctx.obj.get('cookie')
540+
headers = get_api_headers(cookie)
513541
url = f"{domain}/api/v1/problems/{problem_id}"
514542

515543
try:
516544
with console.status(f"[bold green]Fetching problem {problem_id}...[/]"):
517-
response = requests.get(url, timeout=10)
545+
response = requests.get(url, headers=headers, timeout=10)
518546

519547
if response.status_code != 200:
520548
handle_api_error(response)
@@ -538,10 +566,12 @@ def show_problem(problem_id):
538566

539567

540568
@cli.command("ls-submissions", help="List your recent submissions. (alias: lss)")
541-
def list_submissions():
569+
@click.pass_context
570+
def list_submissions(ctx):
542571
config = load_config()
543572
domain = config["domain"]
544-
headers = get_api_headers()
573+
cookie = ctx.obj.get('cookie')
574+
headers = get_api_headers(cookie)
545575
url = f"{domain}/api/v1/submissions"
546576

547577
try:

0 commit comments

Comments
 (0)