diff --git a/src/providers/glm.rs b/src/providers/glm.rs index 4a231c0..43a259b 100644 --- a/src/providers/glm.rs +++ b/src/providers/glm.rs @@ -2,7 +2,7 @@ //! The GLM API requires JWT tokens generated from the `id.secret` API key format //! with a custom `sign_type: "SIGN"` header, and uses `/v4/chat/completions`. -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatMessage, Provider}; use async_trait::async_trait; use reqwest::Client; use ring::hmac; @@ -206,6 +206,53 @@ impl Provider for GlmProvider { .map(|c| c.message.content) .ok_or_else(|| anyhow::anyhow!("No response from GLM")) } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let token = self.generate_token()?; + + let api_messages: Vec = messages + .iter() + .map(|m| Message { + role: m.role.clone(), + content: m.content.clone(), + }) + .collect(); + + let request = ChatRequest { + model: model.to_string(), + messages: api_messages, + temperature, + }; + + let url = format!("{}/chat/completions", self.base_url); + + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&request) + .send() + .await?; + + if !response.status().is_success() { + let error = response.text().await?; + anyhow::bail!("GLM API error: {error}"); + } + + let chat_response: ChatResponse = response.json().await?; + + chat_response + .choices + .into_iter() + .next() + .map(|c| c.message.content) + .ok_or_else(|| anyhow::anyhow!("No response from GLM")) + } } #[cfg(test)] @@ -268,6 +315,19 @@ mod tests { assert!(result.is_err()); } + #[tokio::test] + async fn chat_with_history_fails_without_key() { + let p = GlmProvider::new(None); + let messages = vec![ + ChatMessage::system("You are helpful."), + ChatMessage::user("Hello"), + ChatMessage::assistant("Hi there!"), + ChatMessage::user("What did I say?"), + ]; + let result = p.chat_with_history(&messages, "glm-4.7", 0.7).await; + assert!(result.is_err()); + } + #[test] fn base64url_no_padding() { let encoded = base64url_encode_bytes(b"hello");