From 3c62b59a7264f684115c68e3cf051345deee1b4c Mon Sep 17 00:00:00 2001 From: Khoi Tran Date: Mon, 16 Feb 2026 08:42:20 -0800 Subject: [PATCH] fix(copilot): add proper OAuth device-flow authentication The existing Copilot provider passes a static Bearer token, but the Copilot API requires short-lived session tokens obtained via GitHub's OAuth device code flow, plus mandatory editor headers. This replaces the stub with a dedicated CopilotProvider that: - Runs the OAuth device code flow on first use (same client ID as VS Code) - Exchanges the OAuth token for a Copilot API key via api.github.com/copilot_internal/v2/token - Sends required Editor-Version/Editor-Plugin-Version headers - Caches tokens to disk (~/.config/zeroclaw/copilot/) with auto-refresh - Uses Mutex to prevent concurrent refresh races / duplicate device prompts - Writes token files with 0600 permissions (owner-only) - Respects GitHub's polling interval and code expiry from device flow - Sanitizes error messages to prevent token leakage - Uses async filesystem I/O (tokio::fs) throughout - Optionally accepts a pre-supplied GitHub token via config api_key Fixes: 403 'Access to this endpoint is forbidden' Fixes: 400 'missing Editor-Version header for IDE auth' --- src/providers/copilot.rs | 705 +++++++++++++++++++++++++++++++++++++++ src/providers/mod.rs | 48 ++- 2 files changed, 748 insertions(+), 5 deletions(-) create mode 100644 src/providers/copilot.rs diff --git a/src/providers/copilot.rs b/src/providers/copilot.rs new file mode 100644 index 0000000..ab8eb3b --- /dev/null +++ b/src/providers/copilot.rs @@ -0,0 +1,705 @@ +//! GitHub Copilot provider with OAuth device-flow authentication. +//! +//! Authenticates via GitHub's device code flow (same as VS Code Copilot), +//! then exchanges the OAuth token for short-lived Copilot API keys. +//! Tokens are cached to disk and auto-refreshed. +//! +//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and +//! editor headers. This is the same approach used by LiteLLM, Codex CLI, +//! and other third-party Copilot integrations. The Copilot token endpoint is +//! private; there is no public OAuth scope or app registration for it. +//! GitHub could change or revoke this at any time, which would break all +//! third-party integrations simultaneously. + +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ToolCall as ProviderToolCall, +}; +use crate::tools::ToolSpec; +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; +use tracing::warn; + +/// GitHub OAuth client ID for Copilot (VS Code extension). +const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98"; +const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code"; +const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token"; +const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token"; +const DEFAULT_API: &str = "https://api.githubcopilot.com"; + +// ── Token types ────────────────────────────────────────────────── + +#[derive(Debug, Deserialize)] +struct DeviceCodeResponse { + device_code: String, + user_code: String, + verification_uri: String, + #[serde(default = "default_interval")] + interval: u64, + #[serde(default = "default_expires_in")] + expires_in: u64, +} + +fn default_interval() -> u64 { + 5 +} + +fn default_expires_in() -> u64 { + 900 +} + +#[derive(Debug, Deserialize)] +struct AccessTokenResponse { + access_token: Option, + error: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ApiKeyInfo { + token: String, + expires_at: i64, + #[serde(default)] + endpoints: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ApiEndpoints { + api: Option, +} + +struct CachedApiKey { + token: String, + api_endpoint: String, + expires_at: i64, +} + +// ── Chat completions types ─────────────────────────────────────── + +#[derive(Debug, Serialize)] +struct ApiChatRequest { + model: String, + messages: Vec, + temperature: f64, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, +} + +#[derive(Debug, Serialize)] +struct ApiMessage { + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, +} + +#[derive(Debug, Serialize)] +struct NativeToolSpec { + #[serde(rename = "type")] + kind: String, + function: NativeToolFunctionSpec, +} + +#[derive(Debug, Serialize)] +struct NativeToolFunctionSpec { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + kind: Option, + function: NativeFunctionCall, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeFunctionCall { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct ApiChatResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct Choice { + message: ResponseMessage, +} + +#[derive(Debug, Deserialize)] +struct ResponseMessage { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + +// ── Provider ───────────────────────────────────────────────────── + +/// GitHub Copilot provider with automatic OAuth and token refresh. +/// +/// On first use, prompts the user to visit github.com/login/device. +/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed +/// automatically. +pub struct CopilotProvider { + github_token: Option, + /// Mutex ensures only one caller refreshes tokens at a time, + /// preventing duplicate device flow prompts or redundant API calls. + refresh_lock: Arc>>, + http: Client, + token_dir: PathBuf, +} + +impl CopilotProvider { + pub fn new(github_token: Option<&str>) -> Self { + let token_dir = directories::ProjectDirs::from("", "", "zeroclaw") + .map(|dir| dir.config_dir().join("copilot")) + .unwrap_or_else(|| { + // Fall back to a user-specific temp directory to avoid + // shared-directory symlink attacks. + let user = std::env::var("USER") + .or_else(|_| std::env::var("USERNAME")) + .unwrap_or_else(|_| "unknown".to_string()); + std::env::temp_dir().join(format!("zeroclaw-copilot-{user}")) + }); + + if let Err(err) = std::fs::create_dir_all(&token_dir) { + warn!( + "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.", + token_dir + ); + } else { + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + + if let Err(err) = + std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700)) + { + warn!( + "Failed to set Copilot token directory permissions on {:?}: {err}", + token_dir + ); + } + } + } + + Self { + github_token: github_token + .filter(|token| !token.is_empty()) + .map(String::from), + refresh_lock: Arc::new(Mutex::new(None)), + http: Client::builder() + .timeout(Duration::from_secs(120)) + .connect_timeout(Duration::from_secs(10)) + .build() + .unwrap_or_else(|_| Client::new()), + token_dir, + } + } + + /// Required headers for Copilot API requests (editor identification). + const COPILOT_HEADERS: [(&str, &str); 4] = [ + ("Editor-Version", "vscode/1.85.1"), + ("Editor-Plugin-Version", "copilot/1.155.0"), + ("User-Agent", "GithubCopilot/1.155.0"), + ("Accept", "application/json"), + ]; + + fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + tools.map(|items| { + items + .iter() + .map(|tool| NativeToolSpec { + kind: "function".to_string(), + function: NativeToolFunctionSpec { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.parameters.clone(), + }, + }) + .collect() + }) + } + + fn convert_messages(messages: &[ChatMessage]) -> Vec { + messages + .iter() + .map(|message| { + if message.role == "assistant" { + if let Ok(value) = serde_json::from_str::(&message.content) { + if let Some(tool_calls_value) = value.get("tool_calls") { + if let Ok(parsed_calls) = + serde_json::from_value::>(tool_calls_value.clone()) + { + let tool_calls = parsed_calls + .into_iter() + .map(|tool_call| NativeToolCall { + id: Some(tool_call.id), + kind: Some("function".to_string()), + function: NativeFunctionCall { + name: tool_call.name, + arguments: tool_call.arguments, + }, + }) + .collect::>(); + + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + + return ApiMessage { + role: "assistant".to_string(), + content, + tool_call_id: None, + tool_calls: Some(tool_calls), + }; + } + } + } + } + + if message.role == "tool" { + if let Ok(value) = serde_json::from_str::(&message.content) { + let tool_call_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + + return ApiMessage { + role: "tool".to_string(), + content, + tool_call_id, + tool_calls: None, + }; + } + } + + ApiMessage { + role: message.role.clone(), + content: Some(message.content.clone()), + tool_call_id: None, + tool_calls: None, + } + }) + .collect() + } + + /// Send a chat completions request with required Copilot headers. + async fn send_chat_request( + &self, + messages: Vec, + tools: Option<&[ToolSpec]>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let (token, endpoint) = self.get_api_key().await?; + let url = format!("{}/chat/completions", endpoint.trim_end_matches('/')); + + let native_tools = Self::convert_tools(tools); + let request = ApiChatRequest { + model: model.to_string(), + messages, + temperature, + tool_choice: native_tools.as_ref().map(|_| "auto".to_string()), + tools: native_tools, + }; + + let mut req = self + .http + .post(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&request); + + for (header, value) in &Self::COPILOT_HEADERS { + req = req.header(*header, *value); + } + + let response = req.send().await?; + + if !response.status().is_success() { + return Err(super::api_error("GitHub Copilot", response).await); + } + + let api_response: ApiChatResponse = response.json().await?; + let choice = api_response + .choices + .into_iter() + .next() + .ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?; + + let tool_calls = choice + .message + .tool_calls + .unwrap_or_default() + .into_iter() + .map(|tool_call| ProviderToolCall { + id: tool_call + .id + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + name: tool_call.function.name, + arguments: tool_call.function.arguments, + }) + .collect(); + + Ok(ProviderChatResponse { + text: choice.message.content, + tool_calls, + }) + } + + /// Get a valid Copilot API key, refreshing or re-authenticating as needed. + /// Uses a Mutex to ensure only one caller refreshes at a time. + async fn get_api_key(&self) -> anyhow::Result<(String, String)> { + let mut cached = self.refresh_lock.lock().await; + + if let Some(cached_key) = cached.as_ref() { + if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at { + return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone())); + } + } + + if let Some(info) = self.load_api_key_from_disk().await { + if chrono::Utc::now().timestamp() + 120 < info.expires_at { + let endpoint = info + .endpoints + .as_ref() + .and_then(|e| e.api.clone()) + .unwrap_or_else(|| DEFAULT_API.to_string()); + let token = info.token; + + *cached = Some(CachedApiKey { + token: token.clone(), + api_endpoint: endpoint.clone(), + expires_at: info.expires_at, + }); + return Ok((token, endpoint)); + } + } + + let access_token = self.get_github_access_token().await?; + let api_key_info = self.exchange_for_api_key(&access_token).await?; + self.save_api_key_to_disk(&api_key_info).await; + + let endpoint = api_key_info + .endpoints + .as_ref() + .and_then(|e| e.api.clone()) + .unwrap_or_else(|| DEFAULT_API.to_string()); + + *cached = Some(CachedApiKey { + token: api_key_info.token.clone(), + api_endpoint: endpoint.clone(), + expires_at: api_key_info.expires_at, + }); + + Ok((api_key_info.token, endpoint)) + } + + /// Get a GitHub access token from config, cache, or device flow. + async fn get_github_access_token(&self) -> anyhow::Result { + if let Some(token) = &self.github_token { + return Ok(token.clone()); + } + + let access_token_path = self.token_dir.join("access-token"); + if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await { + let token = cached.trim(); + if !token.is_empty() { + return Ok(token.to_string()); + } + } + + let token = self.device_code_login().await?; + write_file_secure(&access_token_path, &token).await; + Ok(token) + } + + /// Run GitHub OAuth device code flow. + async fn device_code_login(&self) -> anyhow::Result { + let response: DeviceCodeResponse = self + .http + .post(GITHUB_DEVICE_CODE_URL) + .header("Accept", "application/json") + .json(&serde_json::json!({ + "client_id": GITHUB_CLIENT_ID, + "scope": "read:user" + })) + .send() + .await? + .error_for_status()? + .json() + .await?; + + let mut poll_interval = Duration::from_secs(response.interval.max(5)); + let expires_in = response.expires_in.max(1); + let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in); + + eprintln!( + "\nGitHub Copilot authentication is required.\n\ + Visit: {}\n\ + Code: {}\n\ + Waiting for authorization...\n", + response.verification_uri, response.user_code + ); + + while tokio::time::Instant::now() < expires_at { + tokio::time::sleep(poll_interval).await; + + let token_response: AccessTokenResponse = self + .http + .post(GITHUB_ACCESS_TOKEN_URL) + .header("Accept", "application/json") + .json(&serde_json::json!({ + "client_id": GITHUB_CLIENT_ID, + "device_code": response.device_code, + "grant_type": "urn:ietf:params:oauth:grant-type:device_code" + })) + .send() + .await? + .json() + .await?; + + if let Some(token) = token_response.access_token { + eprintln!("Authentication succeeded.\n"); + return Ok(token); + } + + match token_response.error.as_deref() { + Some("slow_down") => { + poll_interval += Duration::from_secs(5); + } + Some("authorization_pending") | None => {} + Some("expired_token") => { + anyhow::bail!("GitHub device authorization expired") + } + Some(error) => anyhow::bail!("GitHub auth failed: {error}"), + } + } + + anyhow::bail!("Timed out waiting for GitHub authorization") + } + + /// Exchange a GitHub access token for a Copilot API key. + async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result { + let mut request = self.http.get(GITHUB_API_KEY_URL); + for (header, value) in &Self::COPILOT_HEADERS { + request = request.header(*header, *value); + } + request = request.header("Authorization", format!("token {access_token}")); + + let response = request.send().await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + let sanitized = super::sanitize_api_error(&body); + + if status.as_u16() == 401 || status.as_u16() == 403 { + let access_token_path = self.token_dir.join("access-token"); + tokio::fs::remove_file(&access_token_path).await.ok(); + } + + anyhow::bail!( + "Failed to get Copilot API key ({status}): {sanitized}. \ + Ensure your GitHub account has an active Copilot subscription." + ); + } + + let info: ApiKeyInfo = response.json().await?; + Ok(info) + } + + async fn load_api_key_from_disk(&self) -> Option { + let path = self.token_dir.join("api-key.json"); + let data = tokio::fs::read_to_string(&path).await.ok()?; + serde_json::from_str(&data).ok() + } + + async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) { + let path = self.token_dir.join("api-key.json"); + if let Ok(json) = serde_json::to_string_pretty(info) { + write_file_secure(&path, &json).await; + } + } +} + +/// Write a file with 0600 permissions (owner read/write only). +/// Uses `spawn_blocking` to avoid blocking the async runtime. +async fn write_file_secure(path: &Path, content: &str) { + let path = path.to_path_buf(); + let content = content.to_string(); + + let result = tokio::task::spawn_blocking(move || { + #[cfg(unix)] + { + use std::io::Write; + use std::os::unix::fs::{OpenOptionsExt, PermissionsExt}; + + let mut file = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .mode(0o600) + .open(&path)?; + file.write_all(content.as_bytes())?; + + std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?; + Ok::<(), std::io::Error>(()) + } + #[cfg(not(unix))] + { + std::fs::write(&path, &content)?; + Ok::<(), std::io::Error>(()) + } + }) + .await; + + match result { + Ok(Ok(())) => {} + Ok(Err(err)) => warn!("Failed to write secure file: {err}"), + Err(err) => warn!("Failed to spawn blocking write: {err}"), + } +} + +#[async_trait] +impl Provider for CopilotProvider { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let mut messages = Vec::new(); + if let Some(system) = system_prompt { + messages.push(ApiMessage { + role: "system".to_string(), + content: Some(system.to_string()), + tool_call_id: None, + tool_calls: None, + }); + } + messages.push(ApiMessage { + role: "user".to_string(), + content: Some(message.to_string()), + tool_call_id: None, + tool_calls: None, + }); + + let response = self + .send_chat_request(messages, None, model, temperature) + .await?; + Ok(response.text.unwrap_or_default()) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let response = self + .send_chat_request(Self::convert_messages(messages), None, model, temperature) + .await?; + Ok(response.text.unwrap_or_default()) + } + + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + self.send_chat_request( + Self::convert_messages(request.messages), + request.tools, + model, + temperature, + ) + .await + } + + fn supports_native_tools(&self) -> bool { + true + } + + async fn warmup(&self) -> anyhow::Result<()> { + let _ = self.get_api_key().await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_without_token() { + let provider = CopilotProvider::new(None); + assert!(provider.github_token.is_none()); + } + + #[test] + fn new_with_token() { + let provider = CopilotProvider::new(Some("ghp_test")); + assert_eq!(provider.github_token.as_deref(), Some("ghp_test")); + } + + #[test] + fn empty_token_treated_as_none() { + let provider = CopilotProvider::new(Some("")); + assert!(provider.github_token.is_none()); + } + + #[tokio::test] + async fn cache_starts_empty() { + let provider = CopilotProvider::new(None); + let cached = provider.refresh_lock.lock().await; + assert!(cached.is_none()); + } + + #[test] + fn copilot_headers_include_required_fields() { + let headers = CopilotProvider::COPILOT_HEADERS; + assert!(headers + .iter() + .any(|(header, _)| *header == "Editor-Version")); + assert!(headers + .iter() + .any(|(header, _)| *header == "Editor-Plugin-Version")); + assert!(headers.iter().any(|(header, _)| *header == "User-Agent")); + } + + #[test] + fn default_interval_and_expiry() { + assert_eq!(default_interval(), 5); + assert_eq!(default_expires_in(), 900); + } + + #[test] + fn supports_native_tools() { + let provider = CopilotProvider::new(None); + assert!(provider.supports_native_tools()); + } +} diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 07c427d..1622280 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -1,5 +1,6 @@ pub mod anthropic; pub mod compatible; +pub mod copilot; pub mod gemini; pub mod ollama; pub mod openai; @@ -37,9 +38,18 @@ fn token_end(input: &str, from: usize) -> usize { /// Scrub known secret-like token prefixes from provider error strings. /// -/// Redacts tokens with prefixes like `sk-`, `xoxb-`, and `xoxp-`. +/// Redacts tokens with prefixes like `sk-`, `xoxb-`, `xoxp-`, `ghp_`, `gho_`, +/// `ghu_`, and `github_pat_`. pub fn scrub_secret_patterns(input: &str) -> String { - const PREFIXES: [&str; 3] = ["sk-", "xoxb-", "xoxp-"]; + const PREFIXES: [&str; 7] = [ + "sk-", + "xoxb-", + "xoxp-", + "ghp_", + "gho_", + "ghu_", + "github_pat_", + ]; let mut scrubbed = input.to_string(); @@ -290,9 +300,9 @@ pub fn create_provider_with_url( "cohere" => Ok(Box::new(OpenAiCompatibleProvider::new( "Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer, ))), - "copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new( - "GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer, - ))), + "copilot" | "github-copilot" => { + Ok(Box::new(copilot::CopilotProvider::new(api_key))) + }, "lmstudio" | "lm-studio" => { let lm_studio_key = api_key .map(str::trim) @@ -967,4 +977,32 @@ mod tests { let result = sanitize_api_error(input); assert_eq!(result, input); } + + #[test] + fn scrub_github_personal_access_token() { + let input = "auth failed with token ghp_abc123def456"; + let result = scrub_secret_patterns(input); + assert_eq!(result, "auth failed with token [REDACTED]"); + } + + #[test] + fn scrub_github_oauth_token() { + let input = "Bearer gho_1234567890abcdef"; + let result = scrub_secret_patterns(input); + assert_eq!(result, "Bearer [REDACTED]"); + } + + #[test] + fn scrub_github_user_token() { + let input = "token ghu_sessiontoken123"; + let result = scrub_secret_patterns(input); + assert_eq!(result, "token [REDACTED]"); + } + + #[test] + fn scrub_github_fine_grained_pat() { + let input = "failed: github_pat_11AABBC_xyzzy789"; + let result = scrub_secret_patterns(input); + assert_eq!(result, "failed: [REDACTED]"); + } }