fix PR #200 review issues
This commit is contained in:
parent
39087a446d
commit
e8aa63822a
4 changed files with 87 additions and 48 deletions
|
|
@ -3,7 +3,9 @@ pub mod openai_oauth;
|
||||||
pub mod profiles;
|
pub mod profiles;
|
||||||
|
|
||||||
use crate::auth::openai_oauth::refresh_access_token;
|
use crate::auth::openai_oauth::refresh_access_token;
|
||||||
use crate::auth::profiles::{profile_id, AuthProfile, AuthProfileKind, AuthProfilesData, AuthProfilesStore};
|
use crate::auth::profiles::{
|
||||||
|
profile_id, AuthProfile, AuthProfileKind, AuthProfilesData, AuthProfilesStore,
|
||||||
|
};
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
@ -131,7 +133,12 @@ impl AuthService {
|
||||||
&self,
|
&self,
|
||||||
profile_override: Option<&str>,
|
profile_override: Option<&str>,
|
||||||
) -> Result<Option<String>> {
|
) -> Result<Option<String>> {
|
||||||
let data = self.store.load()?;
|
let data = tokio::task::spawn_blocking({
|
||||||
|
let store = self.store.clone();
|
||||||
|
move || store.load()
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|err| anyhow::anyhow!("Auth profile load task failed: {err}"))??;
|
||||||
let Some(profile_id) = select_profile_id(&data, OPENAI_CODEX_PROVIDER, profile_override)
|
let Some(profile_id) = select_profile_id(&data, OPENAI_CODEX_PROVIDER, profile_override)
|
||||||
else {
|
else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
|
|
@ -157,7 +164,12 @@ impl AuthService {
|
||||||
let _guard = refresh_lock.lock().await;
|
let _guard = refresh_lock.lock().await;
|
||||||
|
|
||||||
// Re-load after waiting for lock to avoid duplicate refreshes.
|
// Re-load after waiting for lock to avoid duplicate refreshes.
|
||||||
let data = self.store.load()?;
|
let data = tokio::task::spawn_blocking({
|
||||||
|
let store = self.store.clone();
|
||||||
|
move || store.load()
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|err| anyhow::anyhow!("Auth profile load task failed: {err}"))??;
|
||||||
let Some(latest_profile) = data.profiles.get(&profile_id) else {
|
let Some(latest_profile) = data.profiles.get(&profile_id) else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
@ -170,10 +182,7 @@ impl AuthService {
|
||||||
return Ok(Some(latest_tokens.access_token.clone()));
|
return Ok(Some(latest_tokens.access_token.clone()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let refresh_token = latest_tokens
|
let refresh_token = latest_tokens.refresh_token.clone().unwrap_or(refresh_token);
|
||||||
.refresh_token
|
|
||||||
.clone()
|
|
||||||
.unwrap_or(refresh_token);
|
|
||||||
|
|
||||||
if let Some(remaining) = refresh_backoff_remaining(&profile_id) {
|
if let Some(remaining) = refresh_backoff_remaining(&profile_id) {
|
||||||
anyhow::bail!(
|
anyhow::bail!(
|
||||||
|
|
@ -200,16 +209,25 @@ impl AuthService {
|
||||||
.clone_from(&latest_tokens.refresh_token);
|
.clone_from(&latest_tokens.refresh_token);
|
||||||
}
|
}
|
||||||
|
|
||||||
let refreshed_clone = refreshed.clone();
|
|
||||||
let account_id = openai_oauth::extract_account_id_from_jwt(&refreshed.access_token)
|
let account_id = openai_oauth::extract_account_id_from_jwt(&refreshed.access_token)
|
||||||
.or_else(|| latest_profile.account_id.clone());
|
.or_else(|| latest_profile.account_id.clone());
|
||||||
|
|
||||||
let updated = self.store.update_profile(&profile_id, |profile| {
|
let updated = tokio::task::spawn_blocking({
|
||||||
|
let store = self.store.clone();
|
||||||
|
let profile_id = profile_id.clone();
|
||||||
|
let refreshed = refreshed.clone();
|
||||||
|
let account_id = account_id.clone();
|
||||||
|
move || {
|
||||||
|
store.update_profile(&profile_id, |profile| {
|
||||||
profile.kind = AuthProfileKind::OAuth;
|
profile.kind = AuthProfileKind::OAuth;
|
||||||
profile.token_set = Some(refreshed_clone.clone());
|
profile.token_set = Some(refreshed.clone());
|
||||||
profile.account_id.clone_from(&account_id);
|
profile.account_id.clone_from(&account_id);
|
||||||
Ok(())
|
Ok(())
|
||||||
})?;
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|err| anyhow::anyhow!("Auth profile update task failed: {err}"))??;
|
||||||
|
|
||||||
Ok(updated.token_set.map(|t| t.access_token))
|
Ok(updated.token_set.map(|t| t.access_token))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -96,20 +96,17 @@ pub fn build_authorize_url(pkce: &PkceState) -> String {
|
||||||
|
|
||||||
let mut encoded: Vec<String> = Vec::with_capacity(params.len());
|
let mut encoded: Vec<String> = Vec::with_capacity(params.len());
|
||||||
for (k, v) in params {
|
for (k, v) in params {
|
||||||
encoded.push(format!(
|
encoded.push(format!("{}={}", url_encode(k), url_encode(v)));
|
||||||
"{}={}",
|
|
||||||
url_encode(k),
|
|
||||||
url_encode(v)
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
format!(
|
format!("{OPENAI_OAUTH_AUTHORIZE_URL}?{}", encoded.join("&"))
|
||||||
"{OPENAI_OAUTH_AUTHORIZE_URL}?{}",
|
|
||||||
encoded.join("&")
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn exchange_code_for_tokens(client: &Client, code: &str, pkce: &PkceState) -> Result<TokenSet> {
|
pub async fn exchange_code_for_tokens(
|
||||||
|
client: &Client,
|
||||||
|
code: &str,
|
||||||
|
pkce: &PkceState,
|
||||||
|
) -> Result<TokenSet> {
|
||||||
let form = [
|
let form = [
|
||||||
("grant_type", "authorization_code"),
|
("grant_type", "authorization_code"),
|
||||||
("code", code),
|
("code", code),
|
||||||
|
|
@ -180,7 +177,10 @@ pub async fn start_device_code_flow(client: &Client) -> Result<DeviceCodeStart>
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn poll_device_code_tokens(client: &Client, device: &DeviceCodeStart) -> Result<TokenSet> {
|
pub async fn poll_device_code_tokens(
|
||||||
|
client: &Client,
|
||||||
|
device: &DeviceCodeStart,
|
||||||
|
) -> Result<TokenSet> {
|
||||||
let started = Instant::now();
|
let started = Instant::now();
|
||||||
let mut interval_secs = device.interval.max(1);
|
let mut interval_secs = device.interval.max(1);
|
||||||
|
|
||||||
|
|
@ -269,7 +269,8 @@ pub async fn receive_loopback_code(expected_state: &str, timeout: Duration) -> R
|
||||||
|
|
||||||
let code = parse_code_from_redirect(path, Some(expected_state))?;
|
let code = parse_code_from_redirect(path, Some(expected_state))?;
|
||||||
|
|
||||||
let body = "<html><body><h2>ZeroClaw login complete</h2><p>You can close this tab.</p></body></html>";
|
let body =
|
||||||
|
"<html><body><h2>ZeroClaw login complete</h2><p>You can close this tab.</p></body></html>";
|
||||||
let response = format!(
|
let response = format!(
|
||||||
"HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
"HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
||||||
body.len(),
|
body.len(),
|
||||||
|
|
@ -282,13 +283,9 @@ pub async fn receive_loopback_code(expected_state: &str, timeout: Duration) -> R
|
||||||
|
|
||||||
pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Result<String> {
|
pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Result<String> {
|
||||||
let trimmed = input.trim();
|
let trimmed = input.trim();
|
||||||
|
|
||||||
if !trimmed.contains("code=") {
|
|
||||||
if trimmed.is_empty() {
|
if trimmed.is_empty() {
|
||||||
anyhow::bail!("No OAuth code provided");
|
anyhow::bail!("No OAuth code provided");
|
||||||
}
|
}
|
||||||
return Ok(trimmed.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
let query = if let Some((_, right)) = trimmed.split_once('?') {
|
let query = if let Some((_, right)) = trimmed.split_once('?') {
|
||||||
right
|
right
|
||||||
|
|
@ -297,6 +294,10 @@ pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Re
|
||||||
};
|
};
|
||||||
|
|
||||||
let params = parse_query_params(query);
|
let params = parse_query_params(query);
|
||||||
|
let is_callback_payload = trimmed.contains('?')
|
||||||
|
|| params.contains_key("code")
|
||||||
|
|| params.contains_key("state")
|
||||||
|
|| params.contains_key("error");
|
||||||
|
|
||||||
if let Some(err) = params.get("error") {
|
if let Some(err) = params.get("error") {
|
||||||
let desc = params
|
let desc = params
|
||||||
|
|
@ -307,18 +308,24 @@ pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Re
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(expected_state) = expected_state {
|
if let Some(expected_state) = expected_state {
|
||||||
let got = params
|
if let Some(got) = params.get("state") {
|
||||||
.get("state")
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing OAuth state in callback"))?;
|
|
||||||
if got != expected_state {
|
if got != expected_state {
|
||||||
anyhow::bail!("OAuth state mismatch");
|
anyhow::bail!("OAuth state mismatch");
|
||||||
}
|
}
|
||||||
|
} else if is_callback_payload {
|
||||||
|
anyhow::bail!("Missing OAuth state in callback");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
params
|
if let Some(code) = params.get("code").cloned() {
|
||||||
.get("code")
|
return Ok(code);
|
||||||
.cloned()
|
}
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing OAuth code in callback"))
|
|
||||||
|
if !is_callback_payload {
|
||||||
|
return Ok(trimmed.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::bail!("Missing OAuth code in callback")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn extract_account_id_from_jwt(token: &str) -> Option<String> {
|
pub fn extract_account_id_from_jwt(token: &str) -> Option<String> {
|
||||||
|
|
@ -478,6 +485,18 @@ mod tests {
|
||||||
assert!(err.to_string().contains("state mismatch"));
|
assert!(err.to_string().contains("state mismatch"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_redirect_rejects_error_without_code() {
|
||||||
|
let err = parse_code_from_redirect(
|
||||||
|
"/auth/callback?error=access_denied&error_description=user+cancelled",
|
||||||
|
Some("xyz"),
|
||||||
|
)
|
||||||
|
.unwrap_err();
|
||||||
|
assert!(err
|
||||||
|
.to_string()
|
||||||
|
.contains("OpenAI OAuth error: access_denied"));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn extract_account_id_from_jwt_payload() {
|
fn extract_account_id_from_jwt_payload() {
|
||||||
let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode("{}");
|
let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode("{}");
|
||||||
|
|
|
||||||
|
|
@ -455,7 +455,7 @@ impl Default for PeripheralBoardConfig {
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct GatewayConfig {
|
pub struct GatewayConfig {
|
||||||
/// Gateway port (default: 8080)
|
/// Gateway port (default: 3000)
|
||||||
#[serde(default = "default_gateway_port")]
|
#[serde(default = "default_gateway_port")]
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
/// Gateway host (default: 127.0.0.1)
|
/// Gateway host (default: 127.0.0.1)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::auth::AuthService;
|
|
||||||
use crate::auth::openai_oauth::extract_account_id_from_jwt;
|
use crate::auth::openai_oauth::extract_account_id_from_jwt;
|
||||||
|
use crate::auth::AuthService;
|
||||||
use crate::providers::traits::Provider;
|
use crate::providers::traits::Provider;
|
||||||
use crate::providers::ProviderRuntimeOptions;
|
use crate::providers::ProviderRuntimeOptions;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
@ -9,7 +9,8 @@ use serde_json::Value;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
const CODEX_RESPONSES_URL: &str = "https://chatgpt.com/backend-api/codex/responses";
|
const CODEX_RESPONSES_URL: &str = "https://chatgpt.com/backend-api/codex/responses";
|
||||||
const DEFAULT_CODEX_INSTRUCTIONS: &str = "You are ZeroClaw, a concise and helpful coding assistant.";
|
const DEFAULT_CODEX_INSTRUCTIONS: &str =
|
||||||
|
"You are ZeroClaw, a concise and helpful coding assistant.";
|
||||||
|
|
||||||
pub struct OpenAiCodexProvider {
|
pub struct OpenAiCodexProvider {
|
||||||
auth: AuthService,
|
auth: AuthService,
|
||||||
|
|
@ -140,13 +141,13 @@ fn clamp_reasoning_effort(model: &str, effort: &str) -> String {
|
||||||
effort.to_string()
|
effort.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn resolve_reasoning_effort(model: &str) -> String {
|
fn resolve_reasoning_effort(model_id: &str) -> String {
|
||||||
let raw = std::env::var("ZEROCLAW_CODEX_REASONING_EFFORT")
|
let raw = std::env::var("ZEROCLAW_CODEX_REASONING_EFFORT")
|
||||||
.ok()
|
.ok()
|
||||||
.and_then(|value| first_nonempty(Some(&value)))
|
.and_then(|value| first_nonempty(Some(&value)))
|
||||||
.unwrap_or_else(|| "xhigh".to_string())
|
.unwrap_or_else(|| "xhigh".to_string())
|
||||||
.to_ascii_lowercase();
|
.to_ascii_lowercase();
|
||||||
clamp_reasoning_effort(model, &raw)
|
clamp_reasoning_effort(model_id, &raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn nonempty_preserve(text: Option<&str>) -> Option<String> {
|
fn nonempty_preserve(text: Option<&str>) -> Option<String> {
|
||||||
|
|
@ -363,9 +364,10 @@ impl Provider for OpenAiCodexProvider {
|
||||||
"OpenAI Codex account id not found in auth profile/token. Run `zeroclaw auth login --provider openai-codex` again."
|
"OpenAI Codex account id not found in auth profile/token. Run `zeroclaw auth login --provider openai-codex` again."
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
let normalized_model = normalize_model_id(model);
|
||||||
|
|
||||||
let request = ResponsesRequest {
|
let request = ResponsesRequest {
|
||||||
model: model.to_string(),
|
model: normalized_model.to_string(),
|
||||||
input: vec![ResponsesInput {
|
input: vec![ResponsesInput {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: vec![ResponsesInputContent {
|
content: vec![ResponsesInputContent {
|
||||||
|
|
@ -380,7 +382,7 @@ impl Provider for OpenAiCodexProvider {
|
||||||
verbosity: "medium".to_string(),
|
verbosity: "medium".to_string(),
|
||||||
},
|
},
|
||||||
reasoning: ResponsesReasoningOptions {
|
reasoning: ResponsesReasoningOptions {
|
||||||
effort: resolve_reasoning_effort(model),
|
effort: resolve_reasoning_effort(normalized_model),
|
||||||
summary: "auto".to_string(),
|
summary: "auto".to_string(),
|
||||||
},
|
},
|
||||||
include: vec!["reasoning.encrypted_content".to_string()],
|
include: vec!["reasoning.encrypted_content".to_string()],
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue