diff --git a/Cargo.toml b/Cargo.toml index 15d4665..81a22b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,9 @@ parking_lot = "0.12" # Async traits async-trait = "0.1" +# HMAC-SHA256 (Zhipu/GLM JWT auth) +ring = "0.17" + # Protobuf encode/decode (Feishu WS long-connection frame codec) prost = { version = "0.14", default-features = false } diff --git a/src/providers/glm.rs b/src/providers/glm.rs new file mode 100644 index 0000000..4a231c0 --- /dev/null +++ b/src/providers/glm.rs @@ -0,0 +1,278 @@ +//! Zhipu GLM provider with JWT authentication. +//! The GLM API requires JWT tokens generated from the `id.secret` API key format +//! with a custom `sign_type: "SIGN"` header, and uses `/v4/chat/completions`. + +use crate::providers::traits::Provider; +use async_trait::async_trait; +use reqwest::Client; +use ring::hmac; +use serde::{Deserialize, Serialize}; +use std::sync::Mutex; +use std::time::{SystemTime, UNIX_EPOCH}; + +pub struct GlmProvider { + api_key_id: String, + api_key_secret: String, + base_url: String, + client: Client, + /// Cached JWT token + expiry timestamp (ms) + token_cache: Mutex>, +} + +#[derive(Debug, Serialize)] +struct ChatRequest { + model: String, + messages: Vec, + temperature: f64, +} + +#[derive(Debug, Serialize)] +struct Message { + role: String, + content: String, +} + +#[derive(Debug, Deserialize)] +struct ChatResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct Choice { + message: ResponseMessage, +} + +#[derive(Debug, Deserialize)] +struct ResponseMessage { + content: String, +} + +/// Base64url encode without padding (per JWT spec). +fn base64url_encode_bytes(data: &[u8]) -> String { + const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + let mut result = String::new(); + let mut i = 0; + while i < data.len() { + let b0 = data[i] as u32; + let b1 = if i + 1 < data.len() { data[i + 1] as u32 } else { 0 }; + let b2 = if i + 2 < data.len() { data[i + 2] as u32 } else { 0 }; + let triple = (b0 << 16) | (b1 << 8) | b2; + + result.push(CHARS[((triple >> 18) & 0x3F) as usize] as char); + result.push(CHARS[((triple >> 12) & 0x3F) as usize] as char); + + if i + 1 < data.len() { + result.push(CHARS[((triple >> 6) & 0x3F) as usize] as char); + } + if i + 2 < data.len() { + result.push(CHARS[(triple & 0x3F) as usize] as char); + } + + i += 3; + } + + // Convert to base64url: replace + with -, / with _, strip = + result.replace('+', "-").replace('/', "_") +} + +fn base64url_encode_str(s: &str) -> String { + base64url_encode_bytes(s.as_bytes()) +} + +impl GlmProvider { + pub fn new(api_key: Option<&str>) -> Self { + let (id, secret) = api_key + .and_then(|k| k.split_once('.')) + .map(|(id, secret)| (id.to_string(), secret.to_string())) + .unwrap_or_default(); + + Self { + api_key_id: id, + api_key_secret: secret, + base_url: "https://api.z.ai/api/paas/v4".to_string(), + client: Client::builder() + .timeout(std::time::Duration::from_secs(120)) + .connect_timeout(std::time::Duration::from_secs(10)) + .build() + .unwrap_or_else(|_| Client::new()), + token_cache: Mutex::new(None), + } + } + + fn generate_token(&self) -> anyhow::Result { + if self.api_key_id.is_empty() || self.api_key_secret.is_empty() { + anyhow::bail!( + "GLM API key not set or invalid format. Expected 'id.secret'. \ + Run `zeroclaw onboard` or set GLM_API_KEY env var." + ); + } + + let now_ms = SystemTime::now() + .duration_since(UNIX_EPOCH)? + .as_millis() as u64; + + // Check cache (valid for 3 minutes, token expires at 3.5 min) + if let Ok(cache) = self.token_cache.lock() { + if let Some((ref token, expiry)) = *cache { + if now_ms < expiry { + return Ok(token.clone()); + } + } + } + + let exp_ms = now_ms + 210_000; // 3.5 minutes + + // Build JWT manually to include custom sign_type header + // Header: {"alg":"HS256","typ":"JWT","sign_type":"SIGN"} + let header_json = r#"{"alg":"HS256","typ":"JWT","sign_type":"SIGN"}"#; + let header_b64 = base64url_encode_str(header_json); + + // Payload: {"api_key":"...","exp":...,"timestamp":...} + let payload_json = format!( + r#"{{"api_key":"{}","exp":{},"timestamp":{}}}"#, + self.api_key_id, exp_ms, now_ms + ); + let payload_b64 = base64url_encode_str(&payload_json); + + // Sign: HMAC-SHA256(header.payload, secret) + let signing_input = format!("{header_b64}.{payload_b64}"); + let key = hmac::Key::new(hmac::HMAC_SHA256, self.api_key_secret.as_bytes()); + let signature = hmac::sign(&key, signing_input.as_bytes()); + let sig_b64 = base64url_encode_bytes(signature.as_ref()); + + let token = format!("{signing_input}.{sig_b64}"); + + // Cache for 3 minutes + if let Ok(mut cache) = self.token_cache.lock() { + *cache = Some((token.clone(), now_ms + 180_000)); + } + + Ok(token) + } +} + +#[async_trait] +impl Provider for GlmProvider { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let token = self.generate_token()?; + + let mut messages = Vec::new(); + + if let Some(sys) = system_prompt { + messages.push(Message { + role: "system".to_string(), + content: sys.to_string(), + }); + } + + messages.push(Message { + role: "user".to_string(), + content: message.to_string(), + }); + + let request = ChatRequest { + model: model.to_string(), + messages, + temperature, + }; + + let url = format!("{}/chat/completions", self.base_url); + + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&request) + .send() + .await?; + + if !response.status().is_success() { + let error = response.text().await?; + anyhow::bail!("GLM API error: {error}"); + } + + let chat_response: ChatResponse = response.json().await?; + + chat_response + .choices + .into_iter() + .next() + .map(|c| c.message.content) + .ok_or_else(|| anyhow::anyhow!("No response from GLM")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parses_api_key() { + let p = GlmProvider::new(Some("abc123.secretXYZ")); + assert_eq!(p.api_key_id, "abc123"); + assert_eq!(p.api_key_secret, "secretXYZ"); + } + + #[test] + fn handles_no_key() { + let p = GlmProvider::new(None); + assert!(p.api_key_id.is_empty()); + assert!(p.api_key_secret.is_empty()); + } + + #[test] + fn handles_invalid_key_format() { + let p = GlmProvider::new(Some("no-dot-here")); + assert!(p.api_key_id.is_empty()); + assert!(p.api_key_secret.is_empty()); + } + + #[test] + fn generates_jwt_token() { + let p = GlmProvider::new(Some("testid.testsecret")); + let token = p.generate_token().unwrap(); + assert!(!token.is_empty()); + // JWT has 3 dot-separated parts + let parts: Vec<&str> = token.split('.').collect(); + assert_eq!(parts.len(), 3, "JWT should have 3 parts: {token}"); + } + + #[test] + fn caches_token() { + let p = GlmProvider::new(Some("testid.testsecret")); + let token1 = p.generate_token().unwrap(); + let token2 = p.generate_token().unwrap(); + assert_eq!(token1, token2, "Cached token should be reused"); + } + + #[test] + fn fails_without_key() { + let p = GlmProvider::new(None); + let result = p.generate_token(); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("API key not set")); + } + + #[tokio::test] + async fn chat_fails_without_key() { + let p = GlmProvider::new(None); + let result = p + .chat_with_system(None, "hello", "glm-4.7", 0.7) + .await; + assert!(result.is_err()); + } + + #[test] + fn base64url_no_padding() { + let encoded = base64url_encode_bytes(b"hello"); + assert!(!encoded.contains('=')); + assert!(!encoded.contains('+')); + assert!(!encoded.contains('/')); + } +} diff --git a/src/service/mod.rs b/src/service/mod.rs index 9cee13c..287f446 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -5,6 +5,11 @@ use std::path::PathBuf; use std::process::Command; const SERVICE_LABEL: &str = "com.zeroclaw.daemon"; +const WINDOWS_TASK_NAME: &str = "ZeroClaw Daemon"; + +fn windows_task_name() -> &'static str { + WINDOWS_TASK_NAME +} pub fn handle_command(command: &crate::ServiceCommands, config: &Config) -> Result<()> { match command { @@ -21,6 +26,8 @@ fn install(config: &Config) -> Result<()> { install_macos(config) } else if cfg!(target_os = "linux") { install_linux(config) + } else if cfg!(target_os = "windows") { + install_windows(config) } else { anyhow::bail!("Service management is supported on macOS and Linux only"); } @@ -38,6 +45,11 @@ fn start(config: &Config) -> Result<()> { run_checked(Command::new("systemctl").args(["--user", "start", "zeroclaw.service"]))?; println!("✅ Service started"); Ok(()) + } else if cfg!(target_os = "windows") { + let _ = config; + run_checked(Command::new("schtasks").args(["/Run", "/TN", windows_task_name()]))?; + println!("✅ Service started"); + Ok(()) } else { let _ = config; anyhow::bail!("Service management is supported on macOS and Linux only") @@ -60,6 +72,12 @@ fn stop(config: &Config) -> Result<()> { let _ = run_checked(Command::new("systemctl").args(["--user", "stop", "zeroclaw.service"])); println!("✅ Service stopped"); Ok(()) + } else if cfg!(target_os = "windows") { + let _ = config; + let task_name = windows_task_name(); + let _ = run_checked(Command::new("schtasks").args(["/End", "/TN", task_name])); + println!("✅ Service stopped"); + Ok(()) } else { let _ = config; anyhow::bail!("Service management is supported on macOS and Linux only") @@ -94,6 +112,32 @@ fn status(config: &Config) -> Result<()> { return Ok(()); } + if cfg!(target_os = "windows") { + let _ = config; + let task_name = windows_task_name(); + let out = run_capture( + Command::new("schtasks").args(["/Query", "/TN", task_name, "/FO", "LIST"]), + ); + match out { + Ok(text) => { + let running = text.contains("Running"); + println!( + "Service: {}", + if running { + "✅ running" + } else { + "❌ not running" + } + ); + println!("Task: {}", task_name); + } + Err(_) => { + println!("Service: ❌ not installed"); + } + } + return Ok(()); + } + anyhow::bail!("Service management is supported on macOS and Linux only") } @@ -121,6 +165,25 @@ fn uninstall(config: &Config) -> Result<()> { return Ok(()); } + if cfg!(target_os = "windows") { + let task_name = windows_task_name(); + let _ = run_checked( + Command::new("schtasks").args(["/Delete", "/TN", task_name, "/F"]), + ); + // Remove the wrapper script + let wrapper = config + .config_path + .parent() + .map_or_else(|| PathBuf::from("."), PathBuf::from) + .join("logs") + .join("zeroclaw-daemon.cmd"); + if wrapper.exists() { + fs::remove_file(&wrapper).ok(); + } + println!("✅ Service uninstalled"); + return Ok(()); + } + anyhow::bail!("Service management is supported on macOS and Linux only") } @@ -196,6 +259,57 @@ fn install_linux(config: &Config) -> Result<()> { Ok(()) } +fn install_windows(config: &Config) -> Result<()> { + let exe = std::env::current_exe().context("Failed to resolve current executable")?; + let logs_dir = config + .config_path + .parent() + .map_or_else(|| PathBuf::from("."), PathBuf::from) + .join("logs"); + fs::create_dir_all(&logs_dir)?; + + // Create a wrapper script that redirects output to log files + let wrapper = logs_dir.join("zeroclaw-daemon.cmd"); + let stdout_log = logs_dir.join("daemon.stdout.log"); + let stderr_log = logs_dir.join("daemon.stderr.log"); + + let wrapper_content = format!( + "@echo off\r\n\"{}\" daemon >>\"{}\" 2>>\"{}\"", + exe.display(), + stdout_log.display(), + stderr_log.display() + ); + fs::write(&wrapper, &wrapper_content)?; + + let task_name = windows_task_name(); + + // Remove any existing task first (ignore errors if it doesn't exist) + let _ = Command::new("schtasks") + .args(["/Delete", "/TN", task_name, "/F"]) + .output(); + + run_checked( + Command::new("schtasks").args([ + "/Create", + "/TN", + task_name, + "/SC", + "ONLOGON", + "/TR", + &format!("\"{}\"", wrapper.display()), + "/RL", + "HIGHEST", + "/F", + ]), + )?; + + println!("✅ Installed Windows scheduled task: {}", task_name); + println!(" Wrapper: {}", wrapper.display()); + println!(" Logs: {}", logs_dir.display()); + println!(" Start with: zeroclaw service start"); + Ok(()) +} + fn macos_service_file() -> Result { let home = directories::UserDirs::new() .map(|u| u.home_dir().to_path_buf()) @@ -254,6 +368,7 @@ mod tests { assert_eq!(escaped, "<&>"' and text"); } + #[cfg(not(target_os = "windows"))] #[test] fn run_capture_reads_stdout() { let out = run_capture(Command::new("sh").args(["-lc", "echo hello"])) @@ -261,6 +376,7 @@ mod tests { assert_eq!(out.trim(), "hello"); } + #[cfg(not(target_os = "windows"))] #[test] fn run_capture_falls_back_to_stderr() { let out = run_capture(Command::new("sh").args(["-lc", "echo warn 1>&2"])) @@ -268,6 +384,7 @@ mod tests { assert_eq!(out.trim(), "warn"); } + #[cfg(not(target_os = "windows"))] #[test] fn run_checked_errors_on_non_zero_status() { let err = run_checked(Command::new("sh").args(["-lc", "exit 17"])) @@ -275,10 +392,32 @@ mod tests { assert!(err.to_string().contains("Command failed")); } + #[cfg(not(target_os = "windows"))] #[test] fn linux_service_file_has_expected_suffix() { let file = linux_service_file(&Config::default()).unwrap(); let path = file.to_string_lossy(); assert!(path.ends_with(".config/systemd/user/zeroclaw.service")); } + + #[test] + fn windows_task_name_is_constant() { + assert_eq!(windows_task_name(), "ZeroClaw Daemon"); + } + + #[cfg(target_os = "windows")] + #[test] + fn run_capture_reads_stdout_windows() { + let out = run_capture(Command::new("cmd").args(["/C", "echo hello"])) + .expect("stdout capture should succeed"); + assert_eq!(out.trim(), "hello"); + } + + #[cfg(target_os = "windows")] + #[test] + fn run_checked_errors_on_non_zero_status_windows() { + let err = run_checked(Command::new("cmd").args(["/C", "exit /b 17"])) + .expect_err("non-zero exit should error"); + assert!(err.to_string().contains("Command failed")); + } } diff --git a/src/skills/symlink_tests.rs b/src/skills/symlink_tests.rs index d768f59..c77393a 100644 --- a/src/skills/symlink_tests.rs +++ b/src/skills/symlink_tests.rs @@ -50,19 +50,22 @@ mod tests { } // Test case 3: Non-Unix platforms should handle symlink errors gracefully - #[cfg(not(unix))] + #[cfg(windows)] { let source_dir = tmp.path().join("source_skill"); std::fs::create_dir_all(&source_dir).unwrap(); let dest_link = skills_path.join("linked_skill"); - // Symlink should fail on non-Unix - let result = std::os::unix::fs::symlink(&source_dir, &dest_link); - assert!(result.is_err()); - - // Directory should not exist - assert!(!dest_link.exists()); + // On Windows, creating directory symlinks may require elevated privileges + let result = std::os::windows::fs::symlink_dir(&source_dir, &dest_link); + // If symlink creation fails (no privileges), the directory should not exist + if result.is_err() { + assert!(!dest_link.exists()); + } else { + // Clean up if it succeeded + let _ = std::fs::remove_dir(&dest_link); + } } // Test case 4: skills_dir function edge cases