From 455eb3b847601afcf1b1ea833538a4aa0142948f Mon Sep 17 00:00:00 2001 From: tercerapersona Date: Tue, 17 Feb 2026 14:30:43 -0300 Subject: [PATCH] feat: add prompt caching support to Anthropic provider Implements Anthropic's prompt caching API to enable significant cost reduction (up to 90%) and latency improvements (up to 85%) for requests with repeated content. Key features: - Auto-caching heuristics: large system prompts (>3KB), tool definitions, and long conversations (>4 messages) - Full backward compatibility: cache_control fields are optional - Supports both string and block-array system prompt formats - Cache control on all content types (text, tool_use, tool_result) Implementation details: - Added CacheControl, SystemPrompt, and SystemBlock structures - Updated NativeContentOut and NativeToolSpec with cache_control - Strategic cache breakpoint placement (last tool, last message) - Comprehensive test coverage for serialization and heuristics Co-Authored-By: Claude Sonnet 4.5 (cherry picked from commit fff04f4edb5e4cb7e581b1b16035da8cc2e55cef) --- src/providers/anthropic.rs | 497 +++++++++++++++++++++++++++++++++++-- 1 file changed, 480 insertions(+), 17 deletions(-) diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index 9b3a75f..7db8f2e 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -47,7 +47,7 @@ struct NativeChatRequest { model: String, max_tokens: u32, #[serde(skip_serializing_if = "Option::is_none")] - system: Option<String>, + system: Option<SystemPrompt>, messages: Vec<NativeMessage>, temperature: f64, #[serde(skip_serializing_if = "Option::is_none")] @@ -64,17 +64,25 @@ struct NativeMessage { #[serde(tag = "type")] enum NativeContentOut { #[serde(rename = "text")] - Text { text: String }, + Text { + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option<CacheControl>, + }, #[serde(rename = "tool_use")] ToolUse { id: String, name: String, input: serde_json::Value, + #[serde(skip_serializing_if = "Option::is_none")] + 
cache_control: Option<CacheControl>, }, #[serde(rename = "tool_result")] ToolResult { tool_use_id: String, content: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option<CacheControl>, }, } @@ -83,6 +91,38 @@ struct NativeToolSpec { name: String, description: String, input_schema: serde_json::Value, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option<CacheControl>, +} + +#[derive(Debug, Clone, Serialize)] +struct CacheControl { + #[serde(rename = "type")] + cache_type: String, +} + +impl CacheControl { + fn ephemeral() -> Self { + Self { + cache_type: "ephemeral".to_string(), + } + } +} + +#[derive(Debug, Serialize)] +#[serde(untagged)] +enum SystemPrompt { + String(String), + Blocks(Vec<SystemBlock>), +} + +#[derive(Debug, Serialize)] +struct SystemBlock { + #[serde(rename = "type")] + block_type: String, + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option<CacheControl>, } #[derive(Debug, Deserialize)] @@ -147,21 +187,54 @@ impl AnthropicProvider { } } + /// Cache system prompts larger than ~1024 tokens (3KB of text) + fn should_cache_system(text: &str) -> bool { + text.len() > 3072 + } + + /// Cache conversations with more than 4 messages (excluding system) + fn should_cache_conversation(messages: &[ChatMessage]) -> bool { + messages.iter().filter(|m| m.role != "system").count() > 4 + } + + /// Apply cache control to the last message content block + fn apply_cache_to_last_message(messages: &mut [NativeMessage]) { + if let Some(last_msg) = messages.last_mut() { + if let Some(last_content) = last_msg.content.last_mut() { + match last_content { + NativeContentOut::Text { cache_control, .. } => { + *cache_control = Some(CacheControl::ephemeral()); + } + NativeContentOut::ToolResult { cache_control, .. 
} => { + *cache_control = Some(CacheControl::ephemeral()); + } + _ => {} + } + } + } + fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> { let items = tools?; if items.is_empty() { return None; } - Some( - items - .iter() - .map(|tool| NativeToolSpec { - name: tool.name.clone(), - description: tool.description.clone(), - input_schema: tool.parameters.clone(), - }) - .collect(), - ) + let mut native_tools: Vec<NativeToolSpec> = items + .iter() + .map(|tool| NativeToolSpec { + name: tool.name.clone(), + description: tool.description.clone(), + input_schema: tool.parameters.clone(), + cache_control: None, + }) + .collect(); + + // Cache the last tool definition (caches all tools) + if let Some(last_tool) = native_tools.last_mut() { + last_tool.cache_control = Some(CacheControl::ephemeral()); + } + + Some(native_tools) } fn parse_assistant_tool_call_message(content: &str) -> Option<Vec<NativeContentOut>> { @@ -179,6 +252,7 @@ impl AnthropicProvider { { blocks.push(NativeContentOut::Text { text: text.to_string(), + cache_control: None, }); } for call in tool_calls { @@ -188,6 +262,7 @@ impl AnthropicProvider { id: call.id, name: call.name, input, + cache_control: None, }); } Some(blocks) @@ -209,19 +284,20 @@ impl AnthropicProvider { content: vec![NativeContentOut::ToolResult { tool_use_id, content: result, + cache_control: None, }], }) } - fn convert_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<NativeMessage>) { - let mut system_prompt = None; + fn convert_messages(messages: &[ChatMessage]) -> (Option<SystemPrompt>, Vec<NativeMessage>) { + let mut system_text = None; let mut native_messages = Vec::new(); for msg in messages { match msg.role.as_str() { "system" => { - if system_prompt.is_none() { - system_prompt = Some(msg.content.clone()); + if system_text.is_none() { + system_text = Some(msg.content.clone()); } } "assistant" => { @@ -235,6 +311,7 @@ impl AnthropicProvider { role: "assistant".to_string(), content: vec![NativeContentOut::Text { text: msg.content.clone(), + cache_control: None, }], }); } @@ -247,6 +324,7 @@ impl 
AnthropicProvider { role: "user".to_string(), content: vec![NativeContentOut::Text { text: msg.content.clone(), + cache_control: None, }], }); } @@ -256,12 +334,26 @@ impl AnthropicProvider { role: "user".to_string(), content: vec![NativeContentOut::Text { text: msg.content.clone(), + cache_control: None, }], }); } } } + // Convert system text to SystemPrompt with cache control if large + let system_prompt = system_text.map(|text| { + if Self::should_cache_system(&text) { + SystemPrompt::Blocks(vec![SystemBlock { + block_type: "text".to_string(), + text, + cache_control: Some(CacheControl::ephemeral()), + }]) + } else { + SystemPrompt::String(text) + } + }); + (system_prompt, native_messages) } @@ -373,7 +465,13 @@ impl Provider for AnthropicProvider { ) })?; - let (system_prompt, messages) = Self::convert_messages(request.messages); + let (system_prompt, mut messages) = Self::convert_messages(request.messages); + + // Auto-cache last message if conversation is long + if Self::should_cache_conversation(request.messages) { + Self::apply_cache_to_last_message(&mut messages); + } + let native_request = NativeChatRequest { model: model.to_string(), max_tokens: 4096, @@ -621,4 +719,369 @@ mod tests { let kind = detect_auth_kind("a.b.c", None); assert_eq!(kind, AnthropicAuthKind::Authorization); } + + #[test] + fn cache_control_serializes_correctly() { + let cache = CacheControl::ephemeral(); + let json = serde_json::to_string(&cache).unwrap(); + assert_eq!(json, r#"{"type":"ephemeral"}"#); + } + + #[test] + fn system_prompt_string_variant_serializes() { + let prompt = SystemPrompt::String("You are a helpful assistant".to_string()); + let json = serde_json::to_string(&prompt).unwrap(); + assert_eq!(json, r#""You are a helpful assistant""#); + } + + #[test] + fn system_prompt_blocks_variant_serializes() { + let prompt = SystemPrompt::Blocks(vec![SystemBlock { + block_type: "text".to_string(), + text: "You are a helpful assistant".to_string(), + cache_control: 
Some(CacheControl::ephemeral()), + }]); + let json = serde_json::to_string(&prompt).unwrap(); + assert!(json.contains(r#""type":"text""#)); + assert!(json.contains("You are a helpful assistant")); + assert!(json.contains(r#""type":"ephemeral""#)); + } + + #[test] + fn system_prompt_blocks_without_cache_control() { + let prompt = SystemPrompt::Blocks(vec![SystemBlock { + block_type: "text".to_string(), + text: "Short prompt".to_string(), + cache_control: None, + }]); + let json = serde_json::to_string(&prompt).unwrap(); + assert!(json.contains("Short prompt")); + assert!(!json.contains("cache_control")); + } + + #[test] + fn native_content_text_without_cache_control() { + let content = NativeContentOut::Text { + text: "Hello".to_string(), + cache_control: None, + }; + let json = serde_json::to_string(&content).unwrap(); + assert!(json.contains(r#""type":"text""#)); + assert!(json.contains("Hello")); + assert!(!json.contains("cache_control")); + } + + #[test] + fn native_content_text_with_cache_control() { + let content = NativeContentOut::Text { + text: "Hello".to_string(), + cache_control: Some(CacheControl::ephemeral()), + }; + let json = serde_json::to_string(&content).unwrap(); + assert!(json.contains(r#""type":"text""#)); + assert!(json.contains("Hello")); + assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#)); + } + + #[test] + fn native_content_tool_use_without_cache_control() { + let content = NativeContentOut::ToolUse { + id: "tool_123".to_string(), + name: "get_weather".to_string(), + input: serde_json::json!({"location": "San Francisco"}), + cache_control: None, + }; + let json = serde_json::to_string(&content).unwrap(); + assert!(json.contains(r#""type":"tool_use""#)); + assert!(json.contains("tool_123")); + assert!(json.contains("get_weather")); + assert!(!json.contains("cache_control")); + } + + #[test] + fn native_content_tool_result_with_cache_control() { + let content = NativeContentOut::ToolResult { + tool_use_id: 
"tool_123".to_string(), + content: "Result data".to_string(), + cache_control: Some(CacheControl::ephemeral()), + }; + let json = serde_json::to_string(&content).unwrap(); + assert!(json.contains(r#""type":"tool_result""#)); + assert!(json.contains("tool_123")); + assert!(json.contains("Result data")); + assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#)); + } + + #[test] + fn native_tool_spec_without_cache_control() { + let tool = NativeToolSpec { + name: "get_weather".to_string(), + description: "Get weather info".to_string(), + input_schema: serde_json::json!({"type": "object"}), + cache_control: None, + }; + let json = serde_json::to_string(&tool).unwrap(); + assert!(json.contains("get_weather")); + assert!(!json.contains("cache_control")); + } + + #[test] + fn native_tool_spec_with_cache_control() { + let tool = NativeToolSpec { + name: "get_weather".to_string(), + description: "Get weather info".to_string(), + input_schema: serde_json::json!({"type": "object"}), + cache_control: Some(CacheControl::ephemeral()), + }; + let json = serde_json::to_string(&tool).unwrap(); + assert!(json.contains("get_weather")); + assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#)); + } + + #[test] + fn should_cache_system_small_prompt() { + let small_prompt = "You are a helpful assistant."; + assert!(!AnthropicProvider::should_cache_system(small_prompt)); + } + + #[test] + fn should_cache_system_large_prompt() { + let large_prompt = "a".repeat(3073); // Just over 3072 bytes + assert!(AnthropicProvider::should_cache_system(&large_prompt)); + } + + #[test] + fn should_cache_system_boundary() { + let boundary_prompt = "a".repeat(3072); // Exactly 3072 bytes + assert!(!AnthropicProvider::should_cache_system(&boundary_prompt)); + + let over_boundary = "a".repeat(3073); + assert!(AnthropicProvider::should_cache_system(&over_boundary)); + } + + #[test] + fn should_cache_conversation_short() { + let messages = vec![ + ChatMessage { + role: 
"system".to_string(), + content: "System prompt".to_string(), + }, + ChatMessage { + role: "user".to_string(), + content: "Hello".to_string(), + }, + ChatMessage { + role: "assistant".to_string(), + content: "Hi".to_string(), + }, + ]; + // Only 2 non-system messages + assert!(!AnthropicProvider::should_cache_conversation(&messages)); + } + + #[test] + fn should_cache_conversation_long() { + let mut messages = vec![ + ChatMessage { + role: "system".to_string(), + content: "System prompt".to_string(), + }, + ]; + // Add 5 non-system messages + for i in 0..5 { + messages.push(ChatMessage { + role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(), + content: format!("Message {i}"), + }); + } + assert!(AnthropicProvider::should_cache_conversation(&messages)); + } + + #[test] + fn should_cache_conversation_boundary() { + let mut messages = vec![]; + // Add exactly 4 non-system messages + for i in 0..4 { + messages.push(ChatMessage { + role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(), + content: format!("Message {i}"), + }); + } + assert!(!AnthropicProvider::should_cache_conversation(&messages)); + + // Add one more to cross boundary + messages.push(ChatMessage { + role: "user".to_string(), + content: "One more".to_string(), + }); + assert!(AnthropicProvider::should_cache_conversation(&messages)); + } + + #[test] + fn apply_cache_to_last_message_text() { + let mut messages = vec![NativeMessage { + role: "user".to_string(), + content: vec![NativeContentOut::Text { + text: "Hello".to_string(), + cache_control: None, + }], + }]; + + AnthropicProvider::apply_cache_to_last_message(&mut messages); + + match &messages[0].content[0] { + NativeContentOut::Text { cache_control, .. 
} => { + assert!(cache_control.is_some()); + } + _ => panic!("Expected Text variant"), + } + } + + #[test] + fn apply_cache_to_last_message_tool_result() { + let mut messages = vec![NativeMessage { + role: "user".to_string(), + content: vec![NativeContentOut::ToolResult { + tool_use_id: "tool_123".to_string(), + content: "Result".to_string(), + cache_control: None, + }], + }]; + + AnthropicProvider::apply_cache_to_last_message(&mut messages); + + match &messages[0].content[0] { + NativeContentOut::ToolResult { cache_control, .. } => { + assert!(cache_control.is_some()); + } + _ => panic!("Expected ToolResult variant"), + } + } + + #[test] + fn apply_cache_to_last_message_does_not_affect_tool_use() { + let mut messages = vec![NativeMessage { + role: "assistant".to_string(), + content: vec![NativeContentOut::ToolUse { + id: "tool_123".to_string(), + name: "get_weather".to_string(), + input: serde_json::json!({}), + cache_control: None, + }], + }]; + + AnthropicProvider::apply_cache_to_last_message(&mut messages); + + // ToolUse should not be affected + match &messages[0].content[0] { + NativeContentOut::ToolUse { cache_control, .. 
} => { + assert!(cache_control.is_none()); + } + _ => panic!("Expected ToolUse variant"), + } + } + + #[test] + fn apply_cache_empty_messages() { + let mut messages = vec![]; + AnthropicProvider::apply_cache_to_last_message(&mut messages); + // Should not panic + assert!(messages.is_empty()); + } + + #[test] + fn convert_tools_adds_cache_to_last_tool() { + let tools = vec![ + ToolSpec { + name: "tool1".to_string(), + description: "First tool".to_string(), + parameters: serde_json::json!({"type": "object"}), + }, + ToolSpec { + name: "tool2".to_string(), + description: "Second tool".to_string(), + parameters: serde_json::json!({"type": "object"}), + }, + ]; + + let native_tools = AnthropicProvider::convert_tools(Some(&tools)).unwrap(); + + assert_eq!(native_tools.len(), 2); + assert!(native_tools[0].cache_control.is_none()); + assert!(native_tools[1].cache_control.is_some()); + } + + #[test] + fn convert_tools_single_tool_gets_cache() { + let tools = vec![ToolSpec { + name: "tool1".to_string(), + description: "Only tool".to_string(), + parameters: serde_json::json!({"type": "object"}), + }]; + + let native_tools = AnthropicProvider::convert_tools(Some(&tools)).unwrap(); + + assert_eq!(native_tools.len(), 1); + assert!(native_tools[0].cache_control.is_some()); + } + + #[test] + fn convert_messages_small_system_prompt() { + let messages = vec![ChatMessage { + role: "system".to_string(), + content: "Short system prompt".to_string(), + }]; + + let (system_prompt, _) = AnthropicProvider::convert_messages(&messages); + + match system_prompt.unwrap() { + SystemPrompt::String(s) => { + assert_eq!(s, "Short system prompt"); + } + SystemPrompt::Blocks(_) => panic!("Expected String variant for small prompt"), + } + } + + #[test] + fn convert_messages_large_system_prompt() { + let large_content = "a".repeat(3073); + let messages = vec![ChatMessage { + role: "system".to_string(), + content: large_content.clone(), + }]; + + let (system_prompt, _) = 
AnthropicProvider::convert_messages(&messages); + + match system_prompt.unwrap() { + SystemPrompt::Blocks(blocks) => { + assert_eq!(blocks.len(), 1); + assert_eq!(blocks[0].text, large_content); + assert!(blocks[0].cache_control.is_some()); + } + SystemPrompt::String(_) => panic!("Expected Blocks variant for large prompt"), + } + } + + #[test] + fn backward_compatibility_native_chat_request() { + // Test that requests without cache_control serialize identically to old format + let req = NativeChatRequest { + model: "claude-3-opus".to_string(), + max_tokens: 4096, + system: Some(SystemPrompt::String("System".to_string())), + messages: vec![NativeMessage { + role: "user".to_string(), + content: vec![NativeContentOut::Text { + text: "Hello".to_string(), + cache_control: None, + }], + }], + temperature: 0.7, + tools: None, + }; + + let json = serde_json::to_string(&req).unwrap(); + assert!(!json.contains("cache_control")); + assert!(json.contains(r#""system":"System""#)); + } }