diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py | 243 |
1 file changed, 243 insertions, 0 deletions
import logging
import os
from typing import Any

import tiktoken
from openai import AsyncOpenAI, AuthenticationError, OpenAI
from openai._types import NOT_GIVEN

from core.base import (
    ChunkSearchResult,
    EmbeddingConfig,
    EmbeddingProvider,
    EmbeddingPurpose,
)

logger = logging.getLogger()


class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by the OpenAI embeddings API.

    Provides sync and async embedding generation through the official
    ``openai`` clients (which read ``OPENAI_API_KEY`` from the environment).
    Separate reranking is not supported: ``rerank``/``arerank`` simply
    truncate the given results to ``limit``.
    """

    # Tokenizer encoding for each supported model (all use cl100k_base).
    MODEL_TO_TOKENIZER = {
        "text-embedding-ada-002": "cl100k_base",
        "text-embedding-3-small": "cl100k_base",
        "text-embedding-3-large": "cl100k_base",
    }
    # Output dimensionalities each model may be asked to produce.
    MODEL_TO_DIMENSIONS = {
        "text-embedding-ada-002": [1536],
        "text-embedding-3-small": [512, 1536],
        "text-embedding-3-large": [256, 1024, 3072],
    }

    def __init__(self, config: EmbeddingConfig):
        """Validate ``config`` and construct sync/async OpenAI clients.

        Raises:
            ValueError: if the provider is not ``openai``, ``OPENAI_API_KEY``
                is unset, a rerank model is configured, the base model is
                missing or unsupported, or the requested dimension is not
                valid for the chosen model.
        """
        super().__init__(config)
        if not config.provider:
            raise ValueError(
                "Must set provider in order to initialize OpenAIEmbeddingProvider."
            )

        if config.provider != "openai":
            raise ValueError(
                "OpenAIEmbeddingProvider must be initialized with provider `openai`."
            )
        if not os.getenv("OPENAI_API_KEY"):
            raise ValueError(
                "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
            )
        # Both clients pick up OPENAI_API_KEY from the environment.
        self.client = OpenAI()
        self.async_client = AsyncOpenAI()

        if config.rerank_model:
            raise ValueError(
                "OpenAIEmbeddingProvider does not support separate reranking."
            )

        # Accept either a "openai/<model>" prefixed name or a bare model name.
        if config.base_model and "openai/" in config.base_model:
            self.base_model = config.base_model.split("/")[-1]
        else:
            self.base_model = config.base_model
        self.base_dimension = config.base_dimension

        if not self.base_model:
            raise ValueError(
                "Must set base_model in order to initialize OpenAIEmbeddingProvider."
            )

        if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
            raise ValueError(
                f"OpenAI embedding model {self.base_model} not supported."
            )

        if self.base_dimension:
            if (
                self.base_dimension
                not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
                    self.base_model
                ]
            ):
                raise ValueError(
                    f"Dimensions {self.base_dimension} for {self.base_model} are not supported"
                )
        else:
            # If base_dimension is not set, use the largest available dimension for the model
            self.base_dimension = max(
                OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model]
            )

    def _get_dimensions(self):
        """Return the `dimensions` value for the API call.

        ada-002 does not accept a `dimensions` parameter, so NOT_GIVEN is
        passed through; otherwise use the configured dimension, falling back
        to the largest supported dimension for the model.
        """
        return (
            NOT_GIVEN
            if self.base_model == "text-embedding-ada-002"
            else self.base_dimension
            or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model][-1]
        )

    def _get_embedding_kwargs(self, **kwargs):
        """Build keyword args for embeddings.create; caller kwargs win on conflict."""
        return {
            "model": self.base_model,
            "dimensions": self._get_dimensions(),
        } | kwargs

    async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
        """Async worker: embed ``task["texts"]`` and return one vector per text.

        Raises:
            ValueError: on authentication failure or any other API error
                (chained from the underlying exception).
        """
        texts = task["texts"]
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))

        try:
            response = await self.async_client.embeddings.create(
                input=texts,
                **kwargs,
            )
            return [data.embedding for data in response.data]
        except AuthenticationError as e:
            raise ValueError(
                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
            ) from e
        except Exception as e:
            error_msg = f"Error getting embeddings: {str(e)}"
            logger.error(error_msg)
            raise ValueError(error_msg) from e

    def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
        """Sync worker: embed ``task["texts"]`` and return one vector per text.

        Raises:
            ValueError: on authentication failure or any other API error
                (chained from the underlying exception).
        """
        texts = task["texts"]
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
        try:
            response = self.client.embeddings.create(
                input=texts,
                **kwargs,
            )
            return [data.embedding for data in response.data]
        except AuthenticationError as e:
            raise ValueError(
                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
            ) from e
        except Exception as e:
            error_msg = f"Error getting embeddings: {str(e)}"
            logger.error(error_msg)
            raise ValueError(error_msg) from e

    async def async_get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Asynchronously embed a single ``text`` and return its vector."""
        if stage != EmbeddingProvider.Step.BASE:
            # FIX: the guard enforces the BASE stage; the old message
            # incorrectly said "search stage".
            raise ValueError(
                "OpenAIEmbeddingProvider only supports the base embedding stage."
            )

        task = {
            "texts": [text],
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        result = await self._execute_with_backoff_async(task)
        return result[0]

    def get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Embed a single ``text`` and return its vector."""
        if stage != EmbeddingProvider.Step.BASE:
            # FIX: the guard enforces the BASE stage; the old message
            # incorrectly said "search stage".
            raise ValueError(
                "OpenAIEmbeddingProvider only supports the base embedding stage."
            )

        task = {
            "texts": [text],
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        result = self._execute_with_backoff_sync(task)
        return result[0]

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Asynchronously embed ``texts`` and return one vector per input."""
        if stage != EmbeddingProvider.Step.BASE:
            # FIX: the guard enforces the BASE stage; the old message
            # incorrectly said "search stage".
            raise ValueError(
                "OpenAIEmbeddingProvider only supports the base embedding stage."
            )

        task = {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        return await self._execute_with_backoff_async(task)

    def get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Embed ``texts`` and return one vector per input."""
        if stage != EmbeddingProvider.Step.BASE:
            # FIX: the guard enforces the BASE stage; the old message
            # incorrectly said "search stage".
            raise ValueError(
                "OpenAIEmbeddingProvider only supports the base embedding stage."
            )

        task = {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        return self._execute_with_backoff_sync(task)

    def rerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
        limit: int = 10,
    ):
        """No-op rerank: return the first ``limit`` results unchanged."""
        return results[:limit]

    async def arerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
        limit: int = 10,
    ):
        """Async no-op rerank: return the first ``limit`` results unchanged."""
        return results[:limit]

    def tokenize_string(self, text: str, model: str) -> list[int]:
        """Tokenize ``text`` with the tiktoken encoding that matches ``model``.

        Raises:
            ValueError: if ``model`` is not a supported embedding model.
        """
        if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
            raise ValueError(f"OpenAI embedding model {model} not supported.")
        encoding = tiktoken.get_encoding(
            OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
        )
        return encoding.encode(text)