fix PR #200 review issues

This commit is contained in:
Codex 2026-02-15 21:55:54 +03:00 committed by Chummy
parent 39087a446d
commit e8aa63822a
4 changed files with 87 additions and 48 deletions

View file

@ -3,7 +3,9 @@ pub mod openai_oauth;
pub mod profiles;
use crate::auth::openai_oauth::refresh_access_token;
use crate::auth::profiles::{profile_id, AuthProfile, AuthProfileKind, AuthProfilesData, AuthProfilesStore};
use crate::auth::profiles::{
profile_id, AuthProfile, AuthProfileKind, AuthProfilesData, AuthProfilesStore,
};
use crate::config::Config;
use anyhow::Result;
use std::collections::HashMap;
@ -131,7 +133,12 @@ impl AuthService {
&self,
profile_override: Option<&str>,
) -> Result<Option<String>> {
let data = self.store.load()?;
let data = tokio::task::spawn_blocking({
let store = self.store.clone();
move || store.load()
})
.await
.map_err(|err| anyhow::anyhow!("Auth profile load task failed: {err}"))??;
let Some(profile_id) = select_profile_id(&data, OPENAI_CODEX_PROVIDER, profile_override)
else {
return Ok(None);
@ -157,7 +164,12 @@ impl AuthService {
let _guard = refresh_lock.lock().await;
// Re-load after waiting for lock to avoid duplicate refreshes.
let data = self.store.load()?;
let data = tokio::task::spawn_blocking({
let store = self.store.clone();
move || store.load()
})
.await
.map_err(|err| anyhow::anyhow!("Auth profile load task failed: {err}"))??;
let Some(latest_profile) = data.profiles.get(&profile_id) else {
return Ok(None);
};
@ -170,10 +182,7 @@ impl AuthService {
return Ok(Some(latest_tokens.access_token.clone()));
}
let refresh_token = latest_tokens
.refresh_token
.clone()
.unwrap_or(refresh_token);
let refresh_token = latest_tokens.refresh_token.clone().unwrap_or(refresh_token);
if let Some(remaining) = refresh_backoff_remaining(&profile_id) {
anyhow::bail!(
@ -200,16 +209,25 @@ impl AuthService {
.clone_from(&latest_tokens.refresh_token);
}
let refreshed_clone = refreshed.clone();
let account_id = openai_oauth::extract_account_id_from_jwt(&refreshed.access_token)
.or_else(|| latest_profile.account_id.clone());
let updated = self.store.update_profile(&profile_id, |profile| {
profile.kind = AuthProfileKind::OAuth;
profile.token_set = Some(refreshed_clone.clone());
profile.account_id.clone_from(&account_id);
Ok(())
})?;
let updated = tokio::task::spawn_blocking({
let store = self.store.clone();
let profile_id = profile_id.clone();
let refreshed = refreshed.clone();
let account_id = account_id.clone();
move || {
store.update_profile(&profile_id, |profile| {
profile.kind = AuthProfileKind::OAuth;
profile.token_set = Some(refreshed.clone());
profile.account_id.clone_from(&account_id);
Ok(())
})
}
})
.await
.map_err(|err| anyhow::anyhow!("Auth profile update task failed: {err}"))??;
Ok(updated.token_set.map(|t| t.access_token))
}

View file

@ -96,20 +96,17 @@ pub fn build_authorize_url(pkce: &PkceState) -> String {
let mut encoded: Vec<String> = Vec::with_capacity(params.len());
for (k, v) in params {
encoded.push(format!(
"{}={}",
url_encode(k),
url_encode(v)
));
encoded.push(format!("{}={}", url_encode(k), url_encode(v)));
}
format!(
"{OPENAI_OAUTH_AUTHORIZE_URL}?{}",
encoded.join("&")
)
format!("{OPENAI_OAUTH_AUTHORIZE_URL}?{}", encoded.join("&"))
}
pub async fn exchange_code_for_tokens(client: &Client, code: &str, pkce: &PkceState) -> Result<TokenSet> {
pub async fn exchange_code_for_tokens(
client: &Client,
code: &str,
pkce: &PkceState,
) -> Result<TokenSet> {
let form = [
("grant_type", "authorization_code"),
("code", code),
@ -180,7 +177,10 @@ pub async fn start_device_code_flow(client: &Client) -> Result<DeviceCodeStart>
})
}
pub async fn poll_device_code_tokens(client: &Client, device: &DeviceCodeStart) -> Result<TokenSet> {
pub async fn poll_device_code_tokens(
client: &Client,
device: &DeviceCodeStart,
) -> Result<TokenSet> {
let started = Instant::now();
let mut interval_secs = device.interval.max(1);
@ -269,7 +269,8 @@ pub async fn receive_loopback_code(expected_state: &str, timeout: Duration) -> R
let code = parse_code_from_redirect(path, Some(expected_state))?;
let body = "<html><body><h2>ZeroClaw login complete</h2><p>You can close this tab.</p></body></html>";
let body =
"<html><body><h2>ZeroClaw login complete</h2><p>You can close this tab.</p></body></html>";
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
body.len(),
@ -282,12 +283,8 @@ pub async fn receive_loopback_code(expected_state: &str, timeout: Duration) -> R
pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Result<String> {
let trimmed = input.trim();
if !trimmed.contains("code=") {
if trimmed.is_empty() {
anyhow::bail!("No OAuth code provided");
}
return Ok(trimmed.to_string());
if trimmed.is_empty() {
anyhow::bail!("No OAuth code provided");
}
let query = if let Some((_, right)) = trimmed.split_once('?') {
@ -297,6 +294,10 @@ pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Re
};
let params = parse_query_params(query);
let is_callback_payload = trimmed.contains('?')
|| params.contains_key("code")
|| params.contains_key("state")
|| params.contains_key("error");
if let Some(err) = params.get("error") {
let desc = params
@ -307,18 +308,24 @@ pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Re
}
if let Some(expected_state) = expected_state {
let got = params
.get("state")
.ok_or_else(|| anyhow::anyhow!("Missing OAuth state in callback"))?;
if got != expected_state {
anyhow::bail!("OAuth state mismatch");
if let Some(got) = params.get("state") {
if got != expected_state {
anyhow::bail!("OAuth state mismatch");
}
} else if is_callback_payload {
anyhow::bail!("Missing OAuth state in callback");
}
}
params
.get("code")
.cloned()
.ok_or_else(|| anyhow::anyhow!("Missing OAuth code in callback"))
if let Some(code) = params.get("code").cloned() {
return Ok(code);
}
if !is_callback_payload {
return Ok(trimmed.to_string());
}
anyhow::bail!("Missing OAuth code in callback")
}
pub fn extract_account_id_from_jwt(token: &str) -> Option<String> {
@ -478,6 +485,18 @@ mod tests {
assert!(err.to_string().contains("state mismatch"));
}
#[test]
fn parse_redirect_rejects_error_without_code() {
let err = parse_code_from_redirect(
"/auth/callback?error=access_denied&error_description=user+cancelled",
Some("xyz"),
)
.unwrap_err();
assert!(err
.to_string()
.contains("OpenAI OAuth error: access_denied"));
}
#[test]
fn extract_account_id_from_jwt_payload() {
let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode("{}");