fix(embeddings): normalize custom endpoint path resolution (#276)
This commit is contained in:
parent
13f6ed7871
commit
89f689c67a
1 changed files with 70 additions and 1 deletions
|
|
@ -60,6 +60,35 @@ impl OpenAiEmbedding {
|
||||||
dims,
|
dims,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn has_explicit_api_path(&self) -> bool {
|
||||||
|
let Ok(url) = reqwest::Url::parse(&self.base_url) else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
let path = url.path().trim_end_matches('/');
|
||||||
|
!path.is_empty() && path != "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn has_embeddings_endpoint(&self) -> bool {
|
||||||
|
let Ok(url) = reqwest::Url::parse(&self.base_url) else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
url.path().trim_end_matches('/').ends_with("/embeddings")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn embeddings_url(&self) -> String {
|
||||||
|
if self.has_embeddings_endpoint() {
|
||||||
|
return self.base_url.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.has_explicit_api_path() {
|
||||||
|
format!("{}/embeddings", self.base_url)
|
||||||
|
} else {
|
||||||
|
format!("{}/v1/embeddings", self.base_url)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -84,7 +113,7 @@ impl EmbeddingProvider for OpenAiEmbedding {
|
||||||
|
|
||||||
let resp = self
|
let resp = self
|
||||||
.client
|
.client
|
||||||
.post(format!("{}/v1/embeddings", self.base_url))
|
.post(self.embeddings_url())
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.json(&body)
|
.json(&body)
|
||||||
|
|
@ -249,4 +278,44 @@ mod tests {
|
||||||
let p = OpenAiEmbedding::new("http://localhost", "k", "m", 384);
|
let p = OpenAiEmbedding::new("http://localhost", "k", "m", 384);
|
||||||
assert_eq!(p.dimensions(), 384);
|
assert_eq!(p.dimensions(), 384);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn embeddings_url_standard_openai() {
|
||||||
|
let p = OpenAiEmbedding::new("https://api.openai.com", "key", "model", 1536);
|
||||||
|
assert_eq!(p.embeddings_url(), "https://api.openai.com/v1/embeddings");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn embeddings_url_base_with_v1_no_duplicate() {
|
||||||
|
let p = OpenAiEmbedding::new("https://api.example.com/v1", "key", "model", 1536);
|
||||||
|
assert_eq!(p.embeddings_url(), "https://api.example.com/v1/embeddings");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn embeddings_url_non_v1_api_path_uses_raw_suffix() {
|
||||||
|
let p = OpenAiEmbedding::new(
|
||||||
|
"https://api.example.com/api/coding/v3",
|
||||||
|
"key",
|
||||||
|
"model",
|
||||||
|
1536,
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
p.embeddings_url(),
|
||||||
|
"https://api.example.com/api/coding/v3/embeddings"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn embeddings_url_custom_full_endpoint() {
|
||||||
|
let p = OpenAiEmbedding::new(
|
||||||
|
"https://my-api.example.com/api/v2/embeddings",
|
||||||
|
"key",
|
||||||
|
"model",
|
||||||
|
1536,
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
p.embeddings_url(),
|
||||||
|
"https://my-api.example.com/api/v2/embeddings"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue