Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 53 additions & 22 deletions codex-rs/rmcp-client/src/perform_oauth_login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,21 +103,32 @@ fn spawn_callback_server(server: Arc<Server>, tx: oneshot::Sender<(String, Strin
tokio::task::spawn_blocking(move || {
while let Ok(request) = server.recv() {
let path = request.url().to_string();
if let Some(OauthCallbackResult { code, state }) = parse_oauth_callback(&path) {
let response =
Response::from_string("Authentication complete. You may close this window.");
if let Err(err) = request.respond(response) {
eprintln!("Failed to respond to OAuth callback: {err}");
match parse_oauth_callback(&path) {
CallbackOutcome::Success(OauthCallbackResult { code, state }) => {
let response = Response::from_string(
"Authentication complete. You may close this window.",
);
if let Err(err) = request.respond(response) {
eprintln!("Failed to respond to OAuth callback: {err}");
}
if let Err(err) = tx.send((code, state)) {
eprintln!("Failed to send OAuth callback: {err:?}");
}
break;
}
if let Err(err) = tx.send((code, state)) {
eprintln!("Failed to send OAuth callback: {err:?}");
CallbackOutcome::Error(description) => {
let response = Response::from_string(format!("OAuth error: {description}"))
.with_status_code(400);
if let Err(err) = request.respond(response) {
eprintln!("Failed to respond to OAuth callback: {err}");
}
}
break;
} else {
let response =
Response::from_string("Invalid OAuth callback").with_status_code(400);
if let Err(err) = request.respond(response) {
eprintln!("Failed to respond to OAuth callback: {err}");
CallbackOutcome::Invalid => {
let response =
Response::from_string("Invalid OAuth callback").with_status_code(400);
if let Err(err) = request.respond(response) {
eprintln!("Failed to respond to OAuth callback: {err}");
}
}
}
}
Expand All @@ -129,29 +140,49 @@ struct OauthCallbackResult {
state: String,
}

fn parse_oauth_callback(path: &str) -> Option<OauthCallbackResult> {
let (route, query) = path.split_once('?')?;
enum CallbackOutcome {
Success(OauthCallbackResult),
Error(String),
Invalid,
}

fn parse_oauth_callback(path: &str) -> CallbackOutcome {
let Some((route, query)) = path.split_once('?') else {
return CallbackOutcome::Invalid;
};
if route != "/callback" {
return None;
return CallbackOutcome::Invalid;
}

let mut code = None;
let mut state = None;
let mut error_description = None;

for pair in query.split('&') {
let (key, value) = pair.split_once('=')?;
let decoded = decode(value).ok()?.into_owned();
let Some((key, value)) = pair.split_once('=') else {
continue;
};
let Ok(decoded) = decode(value) else {
continue;
};
let decoded = decoded.into_owned();
match key {
"code" => code = Some(decoded),
"state" => state = Some(decoded),
"error_description" => error_description = Some(decoded),
_ => {}
}
}

Some(OauthCallbackResult {
code: code?,
state: state?,
})
if let (Some(code), Some(state)) = (code, state) {
return CallbackOutcome::Success(OauthCallbackResult { code, state });
}

if let Some(description) = error_description {
return CallbackOutcome::Error(description);
}

CallbackOutcome::Invalid
}

pub struct OauthLoginHandle {
Expand Down
Loading