feat(runtime): add reasoning toggle for ollama
This commit is contained in:
parent
8f13fee4a6
commit
a5d7911923
10 changed files with 289 additions and 31 deletions
|
|
@ -7,6 +7,7 @@ use std::collections::HashMap;
|
|||
/// Client-side configuration for talking to an Ollama server.
///
/// Holds the connection settings plus an optional reasoning preference that is
/// copied into every chat request built by this provider.
pub struct OllamaProvider {
|
||||
/// Server base URL; the chat endpoint is formed as `{base_url}/api/chat`.
base_url: String,
|
||||
/// Optional API key. The constructor trims it and maps a blank value to `None`.
api_key: Option<String>,
|
||||
/// Optional reasoning toggle, forwarded as the request's `think` field.
/// `None` means the field is left out of the serialized request entirely.
reasoning_enabled: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Request Structures ───────────────────────────────────────────────────────
|
||||
|
|
@ -18,6 +19,8 @@ struct ChatRequest {
|
|||
stream: bool,
|
||||
options: Options,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
think: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
|
|
@ -85,6 +88,14 @@ struct OllamaFunction {
|
|||
|
||||
impl OllamaProvider {
|
||||
/// Creates a provider with no explicit reasoning preference.
///
/// Delegates to `new_with_reasoning` with `reasoning_enabled = None`, so requests
/// built by the returned provider carry `think: None` (omitted when serialized).
pub fn new(base_url: Option<&str>, api_key: Option<&str>) -> Self {
|
||||
Self::new_with_reasoning(base_url, api_key, None)
|
||||
}
|
||||
|
||||
pub fn new_with_reasoning(
|
||||
base_url: Option<&str>,
|
||||
api_key: Option<&str>,
|
||||
reasoning_enabled: Option<bool>,
|
||||
) -> Self {
|
||||
let api_key = api_key.and_then(|value| {
|
||||
let trimmed = value.trim();
|
||||
(!trimmed.is_empty()).then(|| trimmed.to_string())
|
||||
|
|
@ -96,6 +107,7 @@ impl OllamaProvider {
|
|||
.trim_end_matches('/')
|
||||
.to_string(),
|
||||
api_key,
|
||||
reasoning_enabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -137,6 +149,23 @@ impl OllamaProvider {
|
|||
serde_json::from_str(arguments).unwrap_or_else(|_| serde_json::json!({}))
|
||||
}
|
||||
|
||||
/// Assembles the `ChatRequest` payload for a single chat call.
///
/// * `messages` - conversation messages, moved into the request as-is.
/// * `model` - model name, copied into an owned `String`.
/// * `temperature` - sampling temperature, placed under `options`.
/// * `tools` - optional tool definitions; the slice is cloned into an owned `Vec`.
///
/// The provider's `reasoning_enabled` setting is forwarded unchanged as `think`.
fn build_chat_request(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
) -> ChatRequest {
|
||||
ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
// Streaming is always disabled for requests built here.
stream: false,
|
||||
options: Options { temperature },
|
||||
// `None` here means the `think` key is skipped during serialization.
think: self.reasoning_enabled,
|
||||
tools: tools.map(|t| t.to_vec()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert internal chat history format to Ollama's native tool-call message schema.
|
||||
///
|
||||
/// `run_tool_call_loop` stores native assistant/tool entries as JSON strings in
|
||||
|
|
@ -235,22 +264,17 @@ impl OllamaProvider {
|
|||
should_auth: bool,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
) -> anyhow::Result<ApiChatResponse> {
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
tools: tools.map(|t| t.to_vec()),
|
||||
};
|
||||
let request = self.build_chat_request(messages, model, temperature, tools);
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
tracing::debug!(
|
||||
"Ollama request: url={} model={} message_count={} temperature={} tool_count={}",
|
||||
"Ollama request: url={} model={} message_count={} temperature={} think={:?} tool_count={}",
|
||||
url,
|
||||
model,
|
||||
request.messages.len(),
|
||||
temperature,
|
||||
request.think,
|
||||
request.tools.as_ref().map_or(0, |t| t.len()),
|
||||
);
|
||||
|
||||
|
|
@ -645,6 +669,44 @@ mod tests {
|
|||
assert!(!should_auth);
|
||||
}
|
||||
|
||||
// A provider built via `new` has no reasoning preference, so the serialized
// request must not contain a `think` key at all.
#[test]
|
||||
fn request_omits_think_when_reasoning_not_configured() {
|
||||
let provider = OllamaProvider::new(None, None);
|
||||
let request = provider.build_chat_request(
|
||||
vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("hello".to_string()),
|
||||
tool_calls: None,
|
||||
tool_name: None,
|
||||
}],
|
||||
"llama3",
|
||||
0.7,
|
||||
None,
|
||||
);
|
||||
|
||||
// Serialize to a JSON value so the absence of the key can be checked directly.
let json = serde_json::to_value(request).unwrap();
|
||||
assert!(json.get("think").is_none());
|
||||
}
|
||||
|
||||
// An explicit reasoning preference must be serialized. `Some(false)` is used so
// the test also shows that a disabled toggle is sent rather than dropped.
#[test]
|
||||
fn request_includes_think_when_reasoning_configured() {
|
||||
let provider = OllamaProvider::new_with_reasoning(None, None, Some(false));
|
||||
let request = provider.build_chat_request(
|
||||
vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("hello".to_string()),
|
||||
tool_calls: None,
|
||||
tool_name: None,
|
||||
}],
|
||||
"llama3",
|
||||
0.7,
|
||||
None,
|
||||
);
|
||||
|
||||
// Serialize to a JSON value so the emitted `think` key can be compared exactly.
let json = serde_json::to_value(request).unwrap();
|
||||
assert_eq!(json.get("think"), Some(&serde_json::json!(false)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue