fix(embeddings): normalize custom endpoint path resolution (#276)
This commit is contained in:
parent
13f6ed7871
commit
89f689c67a
1 changed files with 70 additions and 1 deletions
|
|
@ -60,6 +60,35 @@ impl OpenAiEmbedding {
|
|||
dims,
|
||||
}
|
||||
}
|
||||
|
||||
fn has_explicit_api_path(&self) -> bool {
|
||||
let Ok(url) = reqwest::Url::parse(&self.base_url) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let path = url.path().trim_end_matches('/');
|
||||
!path.is_empty() && path != "/"
|
||||
}
|
||||
|
||||
fn has_embeddings_endpoint(&self) -> bool {
|
||||
let Ok(url) = reqwest::Url::parse(&self.base_url) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
url.path().trim_end_matches('/').ends_with("/embeddings")
|
||||
}
|
||||
|
||||
fn embeddings_url(&self) -> String {
|
||||
if self.has_embeddings_endpoint() {
|
||||
return self.base_url.clone();
|
||||
}
|
||||
|
||||
if self.has_explicit_api_path() {
|
||||
format!("{}/embeddings", self.base_url)
|
||||
} else {
|
||||
format!("{}/v1/embeddings", self.base_url)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -84,7 +113,7 @@ impl EmbeddingProvider for OpenAiEmbedding {
|
|||
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/v1/embeddings", self.base_url))
|
||||
.post(self.embeddings_url())
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
|
|
@ -249,4 +278,44 @@ mod tests {
|
|||
let p = OpenAiEmbedding::new("http://localhost", "k", "m", 384);
|
||||
assert_eq!(p.dimensions(), 384);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn embeddings_url_standard_openai() {
|
||||
let p = OpenAiEmbedding::new("https://api.openai.com", "key", "model", 1536);
|
||||
assert_eq!(p.embeddings_url(), "https://api.openai.com/v1/embeddings");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn embeddings_url_base_with_v1_no_duplicate() {
|
||||
let p = OpenAiEmbedding::new("https://api.example.com/v1", "key", "model", 1536);
|
||||
assert_eq!(p.embeddings_url(), "https://api.example.com/v1/embeddings");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn embeddings_url_non_v1_api_path_uses_raw_suffix() {
|
||||
let p = OpenAiEmbedding::new(
|
||||
"https://api.example.com/api/coding/v3",
|
||||
"key",
|
||||
"model",
|
||||
1536,
|
||||
);
|
||||
assert_eq!(
|
||||
p.embeddings_url(),
|
||||
"https://api.example.com/api/coding/v3/embeddings"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn embeddings_url_custom_full_endpoint() {
|
||||
let p = OpenAiEmbedding::new(
|
||||
"https://my-api.example.com/api/v2/embeddings",
|
||||
"key",
|
||||
"model",
|
||||
1536,
|
||||
);
|
||||
assert_eq!(
|
||||
p.embeddings_url(),
|
||||
"https://my-api.example.com/api/v2/embeddings"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue