From 808450c48ef461e211f826f388edf783b7bce38f Mon Sep 17 00:00:00 2001 From: Kieran Date: Mon, 16 Feb 2026 22:25:23 +0000 Subject: [PATCH] feat: custom global api_url --- src/agent/agent.rs | 1 + src/agent/loop_.rs | 2 ++ src/channels/mod.rs | 1 + src/config/schema.rs | 5 +++++ src/gateway/mod.rs | 1 + src/onboard/wizard.rs | 2 ++ src/providers/mod.rs | 41 +++++++++++++++++++++++++++++++---------- 7 files changed, 43 insertions(+), 10 deletions(-) diff --git a/src/agent/agent.rs b/src/agent/agent.rs index 23c0cbf..44e40b6 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -251,6 +251,7 @@ impl Agent { let provider: Box = providers::create_routed_provider( provider_name, config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, &config.model_routes, &model_name, diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 8356d33..4f4d84c 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -749,6 +749,7 @@ pub async fn run( let provider: Box = providers::create_routed_provider( provider_name, config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, &config.model_routes, model_name, @@ -1105,6 +1106,7 @@ pub async fn process_message(config: Config, message: &str) -> Result { let provider: Box = providers::create_routed_provider( provider_name, config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, &config.model_routes, &model_name, diff --git a/src/channels/mod.rs b/src/channels/mod.rs index a132eae..d46a998 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -762,6 +762,7 @@ pub async fn start_channels(config: Config) -> Result<()> { let provider: Arc = Arc::from(providers::create_resilient_provider( &provider_name, config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, )?); diff --git a/src/config/schema.rs b/src/config/schema.rs index dbb6a78..d78e53f 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -18,6 +18,8 @@ pub struct Config { #[serde(skip)] pub config_path: PathBuf, pub api_key: Option, + /// Base URL override for provider API (e.g. "http://10.0.0.1:11434" for remote Ollama) + pub api_url: Option, pub default_provider: Option, pub default_model: Option, pub default_temperature: f64, @@ -1594,6 +1596,7 @@ impl Default for Config { workspace_dir: zeroclaw_dir.join("workspace"), config_path: zeroclaw_dir.join("config.toml"), api_key: None, + api_url: None, default_provider: Some("openrouter".to_string()), default_model: Some("anthropic/claude-sonnet-4".to_string()), default_temperature: 0.7, @@ -1984,6 +1987,7 @@ default_temperature = 0.7 workspace_dir: PathBuf::from("/tmp/test/workspace"), config_path: PathBuf::from("/tmp/test/config.toml"), api_key: Some("sk-test-key".into()), + api_url: None, default_provider: Some("openrouter".into()), default_model: Some("gpt-4o".into()), default_temperature: 0.5, @@ -2126,6 +2130,7 @@ tool_dispatcher = "xml" workspace_dir: dir.join("workspace"), config_path: config_path.clone(), api_key: Some("sk-roundtrip".into()), + api_url: None, default_provider: Some("openrouter".into()), default_model: Some("test-model".into()), default_temperature: 0.9, diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index c5d4da3..132aed1 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -209,6 +209,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { let provider: Arc = Arc::from(providers::create_resilient_provider( config.default_provider.as_deref().unwrap_or("openrouter"), config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, )?); let model = config diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 20c3baa..8355c1e 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -106,6 +106,7 @@ pub fn run_wizard() -> Result { } else { Some(api_key) }, + api_url: None, default_provider: Some(provider), default_model: Some(model), default_temperature: 0.7, @@ -319,6 +320,7 @@ pub fn run_quick_setup( workspace_dir: workspace_dir.clone(), config_path: config_path.clone(), api_key: api_key.map(String::from), + api_url: None, default_provider: Some(provider_name.clone()), default_model: Some(model.clone()), default_temperature: 0.7, diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 86517d6..7ee24b0 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -182,9 +182,18 @@ fn parse_custom_provider_url( } } -/// Factory: create the right provider from config -#[allow(clippy::too_many_lines)] +/// Factory: create the right provider from config (without custom URL) pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result> { + create_provider_with_url(name, api_key, None) +} + +/// Factory: create the right provider from config with optional custom base URL +#[allow(clippy::too_many_lines)] +pub fn create_provider_with_url( + name: &str, + api_key: Option<&str>, + api_url: Option<&str>, +) -> anyhow::Result> { let resolved_key = resolve_api_key(name, api_key); let key = resolved_key.as_deref(); match name { @@ -192,9 +201,8 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result Ok(Box::new(openrouter::OpenRouterProvider::new(key))), "anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))), "openai" => Ok(Box::new(openai::OpenAiProvider::new(key))), - // Ollama is a local service that doesn't use API keys. - // The api_key parameter is ignored to avoid it being misinterpreted as a base_url. - "ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))), + // Ollama uses api_url for custom base URL (e.g. remote Ollama instance) + "ollama" => Ok(Box::new(ollama::OllamaProvider::new(api_url))), "gemini" | "google" | "google-gemini" => { Ok(Box::new(gemini::GeminiProvider::new(key))) } @@ -326,13 +334,14 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result, + api_url: Option<&str>, reliability: &crate::config::ReliabilityConfig, ) -> anyhow::Result> { let mut providers: Vec<(String, Box)> = Vec::new(); providers.push(( primary_name.to_string(), - create_provider(primary_name, api_key)?, + create_provider_with_url(primary_name, api_key, api_url)?, )); for fallback in &reliability.fallback_providers { @@ -349,6 +358,7 @@ pub fn create_resilient_provider( ); } + // Fallback providers don't use the custom api_url (it's specific to primary) match create_provider(fallback, api_key) { Ok(provider) => providers.push((fallback.clone(), provider)), Err(e) => { @@ -377,12 +387,13 @@ pub fn create_resilient_provider( pub fn create_routed_provider( primary_name: &str, api_key: Option<&str>, + api_url: Option<&str>, reliability: &crate::config::ReliabilityConfig, model_routes: &[crate::config::ModelRouteConfig], default_model: &str, ) -> anyhow::Result> { if model_routes.is_empty() { - return create_resilient_provider(primary_name, api_key, reliability); + return create_resilient_provider(primary_name, api_key, api_url, reliability); } // Collect unique provider names needed @@ -401,7 +412,9 @@ pub fn create_routed_provider( .find(|r| &r.provider == name) .and_then(|r| r.api_key.as_deref()) .or(api_key); - match create_resilient_provider(name, key, reliability) { + // Only use api_url for the primary provider + let url = if name == primary_name { api_url } else { None }; + match create_resilient_provider(name, key, url, reliability) { Ok(provider) => providers.push((name.clone(), provider)), Err(e) => { if name == primary_name { @@ -761,17 +774,25 @@ mod tests { scheduler_retries: 2, }; - let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability); + let provider = create_resilient_provider("openrouter", Some("sk-test"), None, &reliability); assert!(provider.is_ok()); } #[test] fn resilient_provider_errors_for_invalid_primary() { let reliability = crate::config::ReliabilityConfig::default(); - let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability); + let provider = + create_resilient_provider("totally-invalid", Some("sk-test"), None, &reliability); assert!(provider.is_err()); } + #[test] + fn ollama_with_custom_url() { + let provider = + create_provider_with_url("ollama", None, Some("http://10.100.2.32:11434")); + assert!(provider.is_ok()); + } + #[test] fn factory_all_providers_create_successfully() { let providers = [