diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/voyage/embedding/transformation.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/voyage/embedding/transformation.py | 146 |
1 file changed, 146 insertions, 0 deletions
from typing import List, Optional, Union

import httpx

from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.utils import EmbeddingResponse, Usage

# Default Voyage AI embeddings endpoint.
VOYAGE_EMBEDDING_ENDPOINT = "https://api.voyageai.com/v1/embeddings"


class VoyageError(BaseLLMException):
    """Exception raised for errors returned by the Voyage AI embeddings API."""

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Union[dict, httpx.Headers] = {},
    ):
        self.status_code = status_code
        self.message = message
        # Synthesize a request/response pair so downstream error handling
        # always has an httpx context to inspect.
        self.request = httpx.Request(method="POST", url=VOYAGE_EMBEDDING_ENDPOINT)
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(status_code=status_code, message=message, headers=headers)


class VoyageEmbeddingConfig(BaseEmbeddingConfig):
    """
    Request/response transformation for the Voyage AI embeddings provider.

    Reference: https://docs.voyageai.com/reference/embeddings-api
    """

    def __init__(self) -> None:
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Return the full embeddings URL, appending `/embeddings` when missing."""
        if not api_base:
            return VOYAGE_EMBEDDING_ENDPOINT
        if api_base.endswith("/embeddings"):
            return api_base
        return f"{api_base}/embeddings"

    def get_supported_openai_params(self, model: str) -> list:
        """OpenAI-compatible parameters that Voyage embeddings understand."""
        return ["encoding_format", "dimensions"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Translate OpenAI embedding params into their Voyage equivalents.

        Reference: https://docs.voyageai.com/reference/embeddings-api
        """
        # OpenAI parameter name -> Voyage parameter name.
        param_map = {
            "encoding_format": "encoding_format",
            "dimensions": "output_dimension",
        }
        for openai_name, voyage_name in param_map.items():
            if openai_name in non_default_params:
                optional_params[voyage_name] = non_default_params[openai_name]
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Build the auth header, resolving the key from the environment if absent."""
        if api_key is None:
            # First non-empty secret wins, mirroring the original `or` chain.
            for secret_name in ("VOYAGE_API_KEY", "VOYAGE_AI_API_KEY", "VOYAGE_AI_TOKEN"):
                api_key = get_secret_str(secret_name)
                if api_key:
                    break
        return {"Authorization": f"Bearer {api_key}"}

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Assemble the JSON body for a Voyage embeddings request."""
        # optional_params is applied last so its keys take precedence,
        # matching the original `{**optional_params}` spread order.
        request_body = {"input": input, "model": model}
        request_body.update(optional_params)
        return request_body

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        """
        Copy the raw Voyage JSON payload onto the litellm EmbeddingResponse.

        Raises:
            VoyageError: if the response body is not valid JSON.
        """
        try:
            payload = raw_response.json()
        except Exception:
            raise VoyageError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        model_response.model = payload.get("model")
        model_response.data = payload.get("data")
        model_response.object = payload.get("object")

        # Voyage reports only total token usage; it is mirrored into
        # prompt_tokens as well, as the original implementation did.
        total_tokens = payload.get("usage", {}).get("total_tokens", 0)
        model_response.usage = Usage(
            prompt_tokens=total_tokens, total_tokens=total_tokens
        )
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Return the provider-specific exception type for HTTP errors."""
        return VoyageError(
            message=error_message, status_code=status_code, headers=headers
        )