feat: add prompt caching support to Anthropic provider

Implements Anthropic's prompt caching API to enable significant cost
reduction (up to 90%) and latency improvements (up to 85%) for
requests with repeated content.

Key features:
- Auto-caching heuristics: large system prompts (>3KB), tool
  definitions, and long conversations (>4 messages)
- Full backward compatibility: cache_control fields are optional
- Supports both string and block-array system prompt formats
- Cache control on all content types (text, tool_use, tool_result)

Implementation details:
- Added CacheControl, SystemPrompt, and SystemBlock structures
- Updated NativeContentOut and NativeToolSpec with cache_control
- Strategic cache breakpoint placement (last tool, last message)
- Comprehensive test coverage for serialization and heuristics

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
(cherry picked from commit fff04f4edb5e4cb7e581b1b16035da8cc2e55cef)
This commit is contained in:
tercerapersona 2026-02-17 14:30:43 -03:00 committed by Chummy
parent 63bc4721e3
commit 455eb3b847

View file

@ -47,7 +47,7 @@ struct NativeChatRequest {
model: String,
max_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
system: Option<SystemPrompt>,
messages: Vec<NativeMessage>,
temperature: f64,
#[serde(skip_serializing_if = "Option::is_none")]
@ -64,17 +64,25 @@ struct NativeMessage {
#[serde(tag = "type")]
enum NativeContentOut {
    /// Plain text content block.
    #[serde(rename = "text")]
    Text {
        text: String,
        /// Optional cache breakpoint; omitted from JSON when `None` so
        /// requests without caching serialize exactly as before.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Assistant-side tool invocation block.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// User-side tool execution result block.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
@ -83,6 +91,38 @@ struct NativeToolSpec {
name: String,
description: String,
// JSON Schema describing the tool's expected input.
input_schema: serde_json::Value,
// Optional cache breakpoint; skipped during serialization when `None`
// so tool specs without caching keep the pre-caching wire format.
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
}
/// Anthropic prompt-caching marker, serialized as `{"type": "..."}`.
/// Attached to content blocks, tool specs, and system blocks to place a
/// cache breakpoint.
#[derive(Debug, Clone, Serialize)]
struct CacheControl {
    // Serialized under the JSON key "type" (rename needed because `type`
    // is a Rust keyword). Only "ephemeral" is produced by this module.
    #[serde(rename = "type")]
    cache_type: String,
}
impl CacheControl {
fn ephemeral() -> Self {
Self {
cache_type: "ephemeral".to_string(),
}
}
}
/// The `system` field of a native request: either the legacy plain-string
/// form or a block array (the form required to attach `cache_control`).
/// `untagged` makes each variant serialize as its bare contents.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum SystemPrompt {
    /// Plain string system prompt (backward-compatible format).
    String(String),
    /// Block-array system prompt, enabling per-block cache control.
    Blocks(Vec<SystemBlock>),
}
/// One entry of a block-array system prompt.
#[derive(Debug, Serialize)]
struct SystemBlock {
    // Serialized under the JSON key "type"; this module emits "text".
    #[serde(rename = "type")]
    block_type: String,
    text: String,
    // Optional cache breakpoint; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_control: Option<CacheControl>,
}
#[derive(Debug, Deserialize)]
@ -147,21 +187,54 @@ impl AnthropicProvider {
}
}
/// Heuristic: cache system prompts larger than roughly 1024 tokens,
/// approximated here as more than 3 KB (3072 bytes) of text.
fn should_cache_system(text: &str) -> bool {
    // Byte length is a cheap stand-in for token count.
    const SYSTEM_CACHE_THRESHOLD_BYTES: usize = 3072;
    text.len() > SYSTEM_CACHE_THRESHOLD_BYTES
}
/// Heuristic: a conversation is worth caching once it carries more than
/// four messages, not counting any system messages.
fn should_cache_conversation(messages: &[ChatMessage]) -> bool {
    let turn_count = messages
        .iter()
        .filter(|msg| msg.role != "system")
        .count();
    turn_count > 4
}
/// Place an ephemeral cache breakpoint on the final content block of the
/// final message. Only `Text` and `ToolResult` blocks are marked; a
/// trailing `ToolUse` block is deliberately left untouched. No-op when
/// there are no messages or the last message has no content.
fn apply_cache_to_last_message(messages: &mut [NativeMessage]) {
    let tail_block = messages
        .last_mut()
        .and_then(|message| message.content.last_mut());
    if let Some(block) = tail_block {
        match block {
            NativeContentOut::Text { cache_control, .. }
            | NativeContentOut::ToolResult { cache_control, .. } => {
                *cache_control = Some(CacheControl::ephemeral());
            }
            NativeContentOut::ToolUse { .. } => {}
        }
    }
}
/// Convert generic tool specs into Anthropic's native format, marking the
/// last tool definition as an ephemeral cache breakpoint. Anthropic caches
/// the request prefix up to a breakpoint, so marking the final tool caches
/// the entire tool block.
///
/// Returns `None` when no tools are supplied (or the slice is empty) so the
/// `tools` field is omitted from the request entirely.
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
    let items = tools?;
    if items.is_empty() {
        return None;
    }
    let mut native_tools: Vec<NativeToolSpec> = items
        .iter()
        .map(|tool| NativeToolSpec {
            name: tool.name.clone(),
            description: tool.description.clone(),
            input_schema: tool.parameters.clone(),
            cache_control: None,
        })
        .collect();
    // Cache the last tool definition (caches all tools).
    if let Some(last_tool) = native_tools.last_mut() {
        last_tool.cache_control = Some(CacheControl::ephemeral());
    }
    Some(native_tools)
}
fn parse_assistant_tool_call_message(content: &str) -> Option<Vec<NativeContentOut>> {
@ -179,6 +252,7 @@ impl AnthropicProvider {
{
blocks.push(NativeContentOut::Text {
text: text.to_string(),
cache_control: None,
});
}
for call in tool_calls {
@ -188,6 +262,7 @@ impl AnthropicProvider {
id: call.id,
name: call.name,
input,
cache_control: None,
});
}
Some(blocks)
@ -209,19 +284,20 @@ impl AnthropicProvider {
content: vec![NativeContentOut::ToolResult {
tool_use_id,
content: result,
cache_control: None,
}],
})
}
fn convert_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<NativeMessage>) {
let mut system_prompt = None;
fn convert_messages(messages: &[ChatMessage]) -> (Option<SystemPrompt>, Vec<NativeMessage>) {
let mut system_text = None;
let mut native_messages = Vec::new();
for msg in messages {
match msg.role.as_str() {
"system" => {
if system_prompt.is_none() {
system_prompt = Some(msg.content.clone());
if system_text.is_none() {
system_text = Some(msg.content.clone());
}
}
"assistant" => {
@ -235,6 +311,7 @@ impl AnthropicProvider {
role: "assistant".to_string(),
content: vec![NativeContentOut::Text {
text: msg.content.clone(),
cache_control: None,
}],
});
}
@ -247,6 +324,7 @@ impl AnthropicProvider {
role: "user".to_string(),
content: vec![NativeContentOut::Text {
text: msg.content.clone(),
cache_control: None,
}],
});
}
@ -256,12 +334,26 @@ impl AnthropicProvider {
role: "user".to_string(),
content: vec![NativeContentOut::Text {
text: msg.content.clone(),
cache_control: None,
}],
});
}
}
}
// Convert system text to SystemPrompt with cache control if large
let system_prompt = system_text.map(|text| {
if Self::should_cache_system(&text) {
SystemPrompt::Blocks(vec![SystemBlock {
block_type: "text".to_string(),
text,
cache_control: Some(CacheControl::ephemeral()),
}])
} else {
SystemPrompt::String(text)
}
});
(system_prompt, native_messages)
}
@ -373,7 +465,13 @@ impl Provider for AnthropicProvider {
)
})?;
let (system_prompt, messages) = Self::convert_messages(request.messages);
let (system_prompt, mut messages) = Self::convert_messages(request.messages);
// Auto-cache last message if conversation is long
if Self::should_cache_conversation(request.messages) {
Self::apply_cache_to_last_message(&mut messages);
}
let native_request = NativeChatRequest {
model: model.to_string(),
max_tokens: 4096,
@ -621,4 +719,369 @@ mod tests {
let kind = detect_auth_kind("a.b.c", None);
assert_eq!(kind, AnthropicAuthKind::Authorization);
}
// Pin the exact wire format of a cache breakpoint marker.
#[test]
fn cache_control_serializes_correctly() {
let cache = CacheControl::ephemeral();
let json = serde_json::to_string(&cache).unwrap();
assert_eq!(json, r#"{"type":"ephemeral"}"#);
}
// The untagged String variant must serialize as a bare JSON string,
// matching the legacy (pre-caching) request format.
#[test]
fn system_prompt_string_variant_serializes() {
let prompt = SystemPrompt::String("You are a helpful assistant".to_string());
let json = serde_json::to_string(&prompt).unwrap();
assert_eq!(json, r#""You are a helpful assistant""#);
}
// The Blocks variant must emit block objects carrying type, text, and the
// ephemeral cache_control marker.
#[test]
fn system_prompt_blocks_variant_serializes() {
let prompt = SystemPrompt::Blocks(vec![SystemBlock {
block_type: "text".to_string(),
text: "You are a helpful assistant".to_string(),
cache_control: Some(CacheControl::ephemeral()),
}]);
let json = serde_json::to_string(&prompt).unwrap();
assert!(json.contains(r#""type":"text""#));
assert!(json.contains("You are a helpful assistant"));
assert!(json.contains(r#""type":"ephemeral""#));
}
// A None cache_control must be skipped entirely (backward compatibility).
#[test]
fn system_prompt_blocks_without_cache_control() {
let prompt = SystemPrompt::Blocks(vec![SystemBlock {
block_type: "text".to_string(),
text: "Short prompt".to_string(),
cache_control: None,
}]);
let json = serde_json::to_string(&prompt).unwrap();
assert!(json.contains("Short prompt"));
assert!(!json.contains("cache_control"));
}
// A Text block with no caching must serialize without a cache_control key.
#[test]
fn native_content_text_without_cache_control() {
let content = NativeContentOut::Text {
text: "Hello".to_string(),
cache_control: None,
};
let json = serde_json::to_string(&content).unwrap();
assert!(json.contains(r#""type":"text""#));
assert!(json.contains("Hello"));
assert!(!json.contains("cache_control"));
}
// A Text block with caching must embed the full ephemeral marker object.
#[test]
fn native_content_text_with_cache_control() {
let content = NativeContentOut::Text {
text: "Hello".to_string(),
cache_control: Some(CacheControl::ephemeral()),
};
let json = serde_json::to_string(&content).unwrap();
assert!(json.contains(r#""type":"text""#));
assert!(json.contains("Hello"));
assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
}
// A ToolUse block with no caching must serialize without a cache_control key.
#[test]
fn native_content_tool_use_without_cache_control() {
let content = NativeContentOut::ToolUse {
id: "tool_123".to_string(),
name: "get_weather".to_string(),
input: serde_json::json!({"location": "San Francisco"}),
cache_control: None,
};
let json = serde_json::to_string(&content).unwrap();
assert!(json.contains(r#""type":"tool_use""#));
assert!(json.contains("tool_123"));
assert!(json.contains("get_weather"));
assert!(!json.contains("cache_control"));
}
// A ToolResult block with caching must embed the ephemeral marker object.
#[test]
fn native_content_tool_result_with_cache_control() {
let content = NativeContentOut::ToolResult {
tool_use_id: "tool_123".to_string(),
content: "Result data".to_string(),
cache_control: Some(CacheControl::ephemeral()),
};
let json = serde_json::to_string(&content).unwrap();
assert!(json.contains(r#""type":"tool_result""#));
assert!(json.contains("tool_123"));
assert!(json.contains("Result data"));
assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
}
// A tool spec with no caching keeps the pre-caching wire format.
#[test]
fn native_tool_spec_without_cache_control() {
let tool = NativeToolSpec {
name: "get_weather".to_string(),
description: "Get weather info".to_string(),
input_schema: serde_json::json!({"type": "object"}),
cache_control: None,
};
let json = serde_json::to_string(&tool).unwrap();
assert!(json.contains("get_weather"));
assert!(!json.contains("cache_control"));
}
// A tool spec with caching must embed the ephemeral marker object.
#[test]
fn native_tool_spec_with_cache_control() {
let tool = NativeToolSpec {
name: "get_weather".to_string(),
description: "Get weather info".to_string(),
input_schema: serde_json::json!({"type": "object"}),
cache_control: Some(CacheControl::ephemeral()),
};
let json = serde_json::to_string(&tool).unwrap();
assert!(json.contains("get_weather"));
assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));
}
// Small system prompts fall below the 3072-byte threshold.
#[test]
fn should_cache_system_small_prompt() {
let small_prompt = "You are a helpful assistant.";
assert!(!AnthropicProvider::should_cache_system(small_prompt));
}
// Prompts over the 3072-byte threshold are flagged for caching.
#[test]
fn should_cache_system_large_prompt() {
let large_prompt = "a".repeat(3073); // Just over 3072 bytes
assert!(AnthropicProvider::should_cache_system(&large_prompt));
}
// The threshold is strictly greater-than: exactly 3072 bytes is NOT cached.
#[test]
fn should_cache_system_boundary() {
let boundary_prompt = "a".repeat(3072); // Exactly 3072 bytes
assert!(!AnthropicProvider::should_cache_system(&boundary_prompt));
let over_boundary = "a".repeat(3073);
assert!(AnthropicProvider::should_cache_system(&over_boundary));
}
// System messages are excluded from the count; two non-system messages
// stay below the >4 threshold.
#[test]
fn should_cache_conversation_short() {
let messages = vec![
ChatMessage {
role: "system".to_string(),
content: "System prompt".to_string(),
},
ChatMessage {
role: "user".to_string(),
content: "Hello".to_string(),
},
ChatMessage {
role: "assistant".to_string(),
content: "Hi".to_string(),
},
];
// Only 2 non-system messages
assert!(!AnthropicProvider::should_cache_conversation(&messages));
}
// Five non-system messages exceed the >4 threshold and trigger caching.
#[test]
fn should_cache_conversation_long() {
let mut messages = vec![
ChatMessage {
role: "system".to_string(),
content: "System prompt".to_string(),
},
];
// Add 5 non-system messages
for i in 0..5 {
messages.push(ChatMessage {
role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(),
content: format!("Message {i}"),
});
}
assert!(AnthropicProvider::should_cache_conversation(&messages));
}
// The threshold is strictly greater-than: exactly 4 messages is not
// cached, 5 is.
#[test]
fn should_cache_conversation_boundary() {
let mut messages = vec![];
// Add exactly 4 non-system messages
for i in 0..4 {
messages.push(ChatMessage {
role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(),
content: format!("Message {i}"),
});
}
assert!(!AnthropicProvider::should_cache_conversation(&messages));
// Add one more to cross boundary
messages.push(ChatMessage {
role: "user".to_string(),
content: "One more".to_string(),
});
assert!(AnthropicProvider::should_cache_conversation(&messages));
}
// A trailing Text block is eligible for the cache breakpoint.
#[test]
fn apply_cache_to_last_message_text() {
let mut messages = vec![NativeMessage {
role: "user".to_string(),
content: vec![NativeContentOut::Text {
text: "Hello".to_string(),
cache_control: None,
}],
}];
AnthropicProvider::apply_cache_to_last_message(&mut messages);
match &messages[0].content[0] {
NativeContentOut::Text { cache_control, .. } => {
assert!(cache_control.is_some());
}
_ => panic!("Expected Text variant"),
}
}
// A trailing ToolResult block is also eligible for the cache breakpoint.
#[test]
fn apply_cache_to_last_message_tool_result() {
let mut messages = vec![NativeMessage {
role: "user".to_string(),
content: vec![NativeContentOut::ToolResult {
tool_use_id: "tool_123".to_string(),
content: "Result".to_string(),
cache_control: None,
}],
}];
AnthropicProvider::apply_cache_to_last_message(&mut messages);
match &messages[0].content[0] {
NativeContentOut::ToolResult { cache_control, .. } => {
assert!(cache_control.is_some());
}
_ => panic!("Expected ToolResult variant"),
}
}
// A trailing ToolUse block is intentionally never marked for caching.
#[test]
fn apply_cache_to_last_message_does_not_affect_tool_use() {
let mut messages = vec![NativeMessage {
role: "assistant".to_string(),
content: vec![NativeContentOut::ToolUse {
id: "tool_123".to_string(),
name: "get_weather".to_string(),
input: serde_json::json!({}),
cache_control: None,
}],
}];
AnthropicProvider::apply_cache_to_last_message(&mut messages);
// ToolUse should not be affected
match &messages[0].content[0] {
NativeContentOut::ToolUse { cache_control, .. } => {
assert!(cache_control.is_none());
}
_ => panic!("Expected ToolUse variant"),
}
}
// Applying a cache breakpoint to an empty message list is a safe no-op.
#[test]
fn apply_cache_empty_messages() {
let mut messages = vec![];
AnthropicProvider::apply_cache_to_last_message(&mut messages);
// Should not panic
assert!(messages.is_empty());
}
// Only the LAST tool gets the breakpoint; Anthropic caches the prefix up
// to it, which covers all earlier tools as well.
#[test]
fn convert_tools_adds_cache_to_last_tool() {
let tools = vec![
ToolSpec {
name: "tool1".to_string(),
description: "First tool".to_string(),
parameters: serde_json::json!({"type": "object"}),
},
ToolSpec {
name: "tool2".to_string(),
description: "Second tool".to_string(),
parameters: serde_json::json!({"type": "object"}),
},
];
let native_tools = AnthropicProvider::convert_tools(Some(&tools)).unwrap();
assert_eq!(native_tools.len(), 2);
assert!(native_tools[0].cache_control.is_none());
assert!(native_tools[1].cache_control.is_some());
}
// With a single tool, that tool is the last one and gets the breakpoint.
#[test]
fn convert_tools_single_tool_gets_cache() {
let tools = vec![ToolSpec {
name: "tool1".to_string(),
description: "Only tool".to_string(),
parameters: serde_json::json!({"type": "object"}),
}];
let native_tools = AnthropicProvider::convert_tools(Some(&tools)).unwrap();
assert_eq!(native_tools.len(), 1);
assert!(native_tools[0].cache_control.is_some());
}
// Small system prompts stay in the legacy plain-string form.
#[test]
fn convert_messages_small_system_prompt() {
let messages = vec![ChatMessage {
role: "system".to_string(),
content: "Short system prompt".to_string(),
}];
let (system_prompt, _) = AnthropicProvider::convert_messages(&messages);
match system_prompt.unwrap() {
SystemPrompt::String(s) => {
assert_eq!(s, "Short system prompt");
}
SystemPrompt::Blocks(_) => panic!("Expected String variant for small prompt"),
}
}
// Large system prompts (>3072 bytes) are promoted to the block-array form
// with an ephemeral cache breakpoint attached.
#[test]
fn convert_messages_large_system_prompt() {
let large_content = "a".repeat(3073);
let messages = vec![ChatMessage {
role: "system".to_string(),
content: large_content.clone(),
}];
let (system_prompt, _) = AnthropicProvider::convert_messages(&messages);
match system_prompt.unwrap() {
SystemPrompt::Blocks(blocks) => {
assert_eq!(blocks.len(), 1);
assert_eq!(blocks[0].text, large_content);
assert!(blocks[0].cache_control.is_some());
}
SystemPrompt::String(_) => panic!("Expected Blocks variant for large prompt"),
}
}
// End-to-end: a request built without any cache_control must serialize
// byte-compatibly with the pre-caching request format.
#[test]
fn backward_compatibility_native_chat_request() {
// Test that requests without cache_control serialize identically to old format
let req = NativeChatRequest {
model: "claude-3-opus".to_string(),
max_tokens: 4096,
system: Some(SystemPrompt::String("System".to_string())),
messages: vec![NativeMessage {
role: "user".to_string(),
content: vec![NativeContentOut::Text {
text: "Hello".to_string(),
cache_control: None,
}],
}],
temperature: 0.7,
tools: None,
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("cache_control"));
assert!(json.contains(r#""system":"System""#));
}
}