feat(ollama): unify local and remote endpoint routing
Integrate cloud endpoint behavior into existing ollama provider flow, avoid a separate standalone doc, and keep configuration minimal via api_url/api_key. Also align reply_target and memory trait call sites needed for current baseline compatibility.
This commit is contained in:
parent
85de9b5625
commit
d94d7baa14
4 changed files with 195 additions and 24 deletions
|
|
@ -5,6 +5,7 @@ use serde::{Deserialize, Serialize};
|
|||
|
||||
pub struct OllamaProvider {
|
||||
base_url: String,
|
||||
api_key: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
|
|
@ -63,12 +64,18 @@ struct OllamaFunction {
|
|||
// ─── Implementation ───────────────────────────────────────────────────────────
|
||||
|
||||
impl OllamaProvider {
|
||||
pub fn new(base_url: Option<&str>) -> Self {
|
||||
pub fn new(base_url: Option<&str>, api_key: Option<&str>) -> Self {
|
||||
let api_key = api_key.and_then(|value| {
|
||||
let trimmed = value.trim();
|
||||
(!trimmed.is_empty()).then(|| trimmed.to_string())
|
||||
});
|
||||
|
||||
Self {
|
||||
base_url: base_url
|
||||
.unwrap_or("http://localhost:11434")
|
||||
.trim_end_matches('/')
|
||||
.to_string(),
|
||||
api_key,
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
|
|
@ -77,12 +84,43 @@ impl OllamaProvider {
|
|||
}
|
||||
}
|
||||
|
||||
fn is_local_endpoint(&self) -> bool {
|
||||
reqwest::Url::parse(&self.base_url)
|
||||
.ok()
|
||||
.and_then(|url| url.host_str().map(|host| host.to_string()))
|
||||
.is_some_and(|host| matches!(host.as_str(), "localhost" | "127.0.0.1" | "::1"))
|
||||
}
|
||||
|
||||
fn resolve_request_details(&self, model: &str) -> anyhow::Result<(String, bool)> {
|
||||
let requests_cloud = model.ends_with(":cloud");
|
||||
let normalized_model = model.strip_suffix(":cloud").unwrap_or(model).to_string();
|
||||
|
||||
if requests_cloud && self.is_local_endpoint() {
|
||||
anyhow::bail!(
|
||||
"Model '{}' requested cloud routing, but Ollama endpoint is local. Configure api_url with a remote Ollama endpoint.",
|
||||
model
|
||||
);
|
||||
}
|
||||
|
||||
if requests_cloud && self.api_key.is_none() {
|
||||
anyhow::bail!(
|
||||
"Model '{}' requested cloud routing, but no API key is configured. Set OLLAMA_API_KEY or config api_key.",
|
||||
model
|
||||
);
|
||||
}
|
||||
|
||||
let should_auth = self.api_key.is_some() && !self.is_local_endpoint();
|
||||
|
||||
Ok((normalized_model, should_auth))
|
||||
}
|
||||
|
||||
/// Send a request to Ollama and get the parsed response
|
||||
async fn send_request(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
should_auth: bool,
|
||||
) -> anyhow::Result<ApiChatResponse> {
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
|
|
@ -101,7 +139,15 @@ impl OllamaProvider {
|
|||
temperature
|
||||
);
|
||||
|
||||
let response = self.client.post(&url).json(&request).send().await?;
|
||||
let mut request_builder = self.client.post(&url).json(&request);
|
||||
|
||||
if should_auth {
|
||||
if let Some(key) = self.api_key.as_ref() {
|
||||
request_builder = request_builder.bearer_auth(key);
|
||||
}
|
||||
}
|
||||
|
||||
let response = request_builder.send().await?;
|
||||
let status = response.status();
|
||||
tracing::debug!("Ollama response status: {}", status);
|
||||
|
||||
|
|
@ -220,6 +266,8 @@ impl Provider for OllamaProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
|
|
@ -234,7 +282,9 @@ impl Provider for OllamaProvider {
|
|||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let response = self.send_request(messages, model, temperature).await?;
|
||||
let response = self
|
||||
.send_request(messages, &normalized_model, temperature, should_auth)
|
||||
.await?;
|
||||
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
|
|
@ -272,6 +322,8 @@ impl Provider for OllamaProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
|
||||
|
||||
let api_messages: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
|
|
@ -280,7 +332,9 @@ impl Provider for OllamaProvider {
|
|||
})
|
||||
.collect();
|
||||
|
||||
let response = self.send_request(api_messages, model, temperature).await?;
|
||||
let response = self
|
||||
.send_request(api_messages, &normalized_model, temperature, should_auth)
|
||||
.await?;
|
||||
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
|
|
@ -330,28 +384,72 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn default_url() {
|
||||
let p = OllamaProvider::new(None);
|
||||
let p = OllamaProvider::new(None, None);
|
||||
assert_eq!(p.base_url, "http://localhost:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_url_trailing_slash() {
|
||||
let p = OllamaProvider::new(Some("http://192.168.1.100:11434/"));
|
||||
let p = OllamaProvider::new(Some("http://192.168.1.100:11434/"), None);
|
||||
assert_eq!(p.base_url, "http://192.168.1.100:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_url_no_trailing_slash() {
|
||||
let p = OllamaProvider::new(Some("http://myserver:11434"));
|
||||
let p = OllamaProvider::new(Some("http://myserver:11434"), None);
|
||||
assert_eq!(p.base_url, "http://myserver:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_url_uses_empty() {
|
||||
let p = OllamaProvider::new(Some(""));
|
||||
let p = OllamaProvider::new(Some(""), None);
|
||||
assert_eq!(p.base_url, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cloud_suffix_strips_model_name() {
|
||||
let p = OllamaProvider::new(Some("https://ollama.com"), Some("ollama-key"));
|
||||
let (model, should_auth) = p.resolve_request_details("qwen3:cloud").unwrap();
|
||||
assert_eq!(model, "qwen3");
|
||||
assert!(should_auth);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cloud_suffix_with_local_endpoint_errors() {
|
||||
let p = OllamaProvider::new(None, Some("ollama-key"));
|
||||
let error = p
|
||||
.resolve_request_details("qwen3:cloud")
|
||||
.expect_err("cloud suffix should fail on local endpoint");
|
||||
assert!(error
|
||||
.to_string()
|
||||
.contains("requested cloud routing, but Ollama endpoint is local"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cloud_suffix_without_api_key_errors() {
|
||||
let p = OllamaProvider::new(Some("https://ollama.com"), None);
|
||||
let error = p
|
||||
.resolve_request_details("qwen3:cloud")
|
||||
.expect_err("cloud suffix should require API key");
|
||||
assert!(error
|
||||
.to_string()
|
||||
.contains("requested cloud routing, but no API key is configured"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remote_endpoint_auth_enabled_when_key_present() {
|
||||
let p = OllamaProvider::new(Some("https://ollama.com"), Some("ollama-key"));
|
||||
let (_model, should_auth) = p.resolve_request_details("qwen3").unwrap();
|
||||
assert!(should_auth);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_endpoint_auth_disabled_even_with_key() {
|
||||
let p = OllamaProvider::new(None, Some("ollama-key"));
|
||||
let (_model, should_auth) = p.resolve_request_details("llama3").unwrap();
|
||||
assert!(!should_auth);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
|
||||
|
|
@ -392,7 +490,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_nested_tool_call() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let provider = OllamaProvider::new(None, None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
|
|
@ -410,7 +508,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_prefixed_name() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let provider = OllamaProvider::new(None, None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
|
|
@ -425,7 +523,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_normal_call() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let provider = OllamaProvider::new(None, None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
|
|
@ -440,7 +538,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn format_tool_calls_produces_valid_json() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let provider = OllamaProvider::new(None, None);
|
||||
let tool_calls = vec![OllamaToolCall {
|
||||
id: Some("call_abc".into()),
|
||||
function: OllamaFunction {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue