aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/voyage/embedding/transformation.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/voyage/embedding/transformation.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/voyage/embedding/transformation.py146
1 files changed, 146 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/voyage/embedding/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/voyage/embedding/transformation.py
new file mode 100644
index 00000000..9507ac89
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/voyage/embedding/transformation.py
@@ -0,0 +1,146 @@
+from typing import List, Optional, Union
+
+import httpx
+
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
+from litellm.types.utils import EmbeddingResponse, Usage
+
+
class VoyageError(BaseLLMException):
    """Error raised for failures from the Voyage AI embeddings API.

    A synthetic httpx ``Request``/``Response`` pair is attached so callers
    that inspect ``.request`` / ``.response`` on the exception always have
    objects to work with, even when the failure occurred before a real
    response existed.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        # NOTE: was a shared mutable default `{}`; use None as the sentinel
        # and normalize below to avoid the mutable-default-argument pitfall.
        headers: Optional[Union[dict, httpx.Headers]] = None,
    ):
        self.status_code = status_code
        self.message = message
        # Synthetic request targeting the public Voyage embeddings endpoint.
        self.request = httpx.Request(
            method="POST", url="https://api.voyageai.com/v1/embeddings"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            # Pass {} when no headers were given, matching the old default.
            headers=headers if headers is not None else {},
        )
+
+
class VoyageEmbeddingConfig(BaseEmbeddingConfig):
    """
    Request/response transformation config for the Voyage AI embeddings API.

    Reference: https://docs.voyageai.com/reference/embeddings-api
    """

    def __init__(self) -> None:
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Return the full embeddings endpoint URL.

        If a custom ``api_base`` is supplied, normalize any trailing
        slash(es) and append ``/embeddings`` unless it is already present —
        this avoids emitting ``https://host/v1//embeddings`` for an
        ``api_base`` like ``https://host/v1/``. Otherwise fall back to the
        public Voyage endpoint.
        """
        if api_base:
            base = api_base.rstrip("/")
            if not base.endswith("/embeddings"):
                base = f"{base}/embeddings"
            return base
        return "https://api.voyageai.com/v1/embeddings"

    def get_supported_openai_params(self, model: str) -> list:
        """List the OpenAI-compatible params this config can translate."""
        return [
            "encoding_format",
            "dimensions",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI params to Voyage params.

        ``encoding_format`` passes through unchanged; OpenAI's
        ``dimensions`` is renamed to Voyage's ``output_dimension``.

        Reference: https://docs.voyageai.com/reference/embeddings-api
        """
        if "encoding_format" in non_default_params:
            optional_params["encoding_format"] = non_default_params["encoding_format"]
        if "dimensions" in non_default_params:
            optional_params["output_dimension"] = non_default_params["dimensions"]
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Build the auth headers, resolving the API key from the
        environment (``VOYAGE_API_KEY`` / ``VOYAGE_AI_API_KEY`` /
        ``VOYAGE_AI_TOKEN``, checked in that order) when not passed in.

        Note: only the Authorization header is returned; any incoming
        ``headers`` are intentionally not merged here.
        """
        if api_key is None:
            api_key = (
                get_secret_str("VOYAGE_API_KEY")
                or get_secret_str("VOYAGE_AI_API_KEY")
                or get_secret_str("VOYAGE_AI_TOKEN")
            )
        return {
            "Authorization": f"Bearer {api_key}",
        }

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Assemble the JSON request body: input + model, with any mapped
        optional params (e.g. ``output_dimension``) merged on top."""
        return {
            "input": input,
            "model": model,
            **optional_params,
        }

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        """Copy the Voyage JSON response into ``model_response``.

        Raises:
            VoyageError: if the response body is not valid JSON (the raw
                text and HTTP status code are propagated).
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise VoyageError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        model_response.model = raw_response_json.get("model")
        model_response.data = raw_response_json.get("data")
        model_response.object = raw_response_json.get("object")

        # `or {}` also guards against an explicit `"usage": null` in the
        # payload, which `.get("usage", {})` would pass through as None.
        usage_block = raw_response_json.get("usage") or {}
        total_tokens = usage_block.get("total_tokens", 0)
        # Voyage reports only total_tokens for embeddings (no completion
        # tokens), so prompt_tokens is deliberately set to the same value.
        model_response.usage = Usage(
            prompt_tokens=total_tokens,
            total_tokens=total_tokens,
        )
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Map a raw HTTP error onto the provider-specific exception type."""
        return VoyageError(
            message=error_message, status_code=status_code, headers=headers
        )