aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two versions of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py243
1 files changed, 243 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
new file mode 100644
index 00000000..907cebd9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
@@ -0,0 +1,243 @@
+import logging
+import os
+from typing import Any
+
+import tiktoken
+from openai import AsyncOpenAI, AuthenticationError, OpenAI
+from openai._types import NOT_GIVEN
+
+from core.base import (
+ ChunkSearchResult,
+ EmbeddingConfig,
+ EmbeddingProvider,
+ EmbeddingPurpose,
+)
+
# Module-scoped logger: getLogger(__name__) (rather than the bare root
# logger) lets deployments filter/route this provider's records by module.
logger = logging.getLogger(__name__)
+
+
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by the OpenAI embeddings API.

    Supports the ada-002 and text-embedding-3 model families through both a
    sync (``OpenAI``) and an async (``AsyncOpenAI``) client, delegating
    retry behaviour to the base class's ``_execute_with_backoff_*`` helpers.
    Separate reranking is not supported: ``rerank``/``arerank`` merely
    truncate the given result list.
    """

    # tiktoken encoding used by each supported model (all cl100k_base today).
    MODEL_TO_TOKENIZER = {
        "text-embedding-ada-002": "cl100k_base",
        "text-embedding-3-small": "cl100k_base",
        "text-embedding-3-large": "cl100k_base",
    }
    # Output sizes each model accepts via the API's "dimensions" parameter.
    MODEL_TO_DIMENSIONS = {
        "text-embedding-ada-002": [1536],
        "text-embedding-3-small": [512, 1536],
        "text-embedding-3-large": [256, 1024, 3072],
    }

    def __init__(self, config: EmbeddingConfig):
        """Validate *config* and build the sync/async OpenAI clients.

        Raises:
            ValueError: if the provider is not ``openai``, OPENAI_API_KEY is
                unset, a rerank model is configured, the base model is
                missing/unsupported, or the requested dimension is not
                offered by the chosen model.
        """
        super().__init__(config)
        if not config.provider:
            raise ValueError(
                "Must set provider in order to initialize OpenAIEmbeddingProvider."
            )
        if config.provider != "openai":
            raise ValueError(
                "OpenAIEmbeddingProvider must be initialized with provider `openai`."
            )
        # The OpenAI clients read OPENAI_API_KEY themselves; check up front
        # so a misconfiguration fails fast with a clear message.
        if not os.getenv("OPENAI_API_KEY"):
            raise ValueError(
                "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
            )
        self.client = OpenAI()
        self.async_client = AsyncOpenAI()

        if config.rerank_model:
            raise ValueError(
                "OpenAIEmbeddingProvider does not support separate reranking."
            )

        # Accept both bare model names and "openai/<model>"-prefixed names.
        if config.base_model and "openai/" in config.base_model:
            self.base_model = config.base_model.split("/")[-1]
        else:
            self.base_model = config.base_model
        self.base_dimension = config.base_dimension

        if not self.base_model:
            raise ValueError(
                "Must set base_model in order to initialize OpenAIEmbeddingProvider."
            )
        if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
            raise ValueError(
                f"OpenAI embedding model {self.base_model} not supported."
            )

        supported_dims = OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
            self.base_model
        ]
        if self.base_dimension:
            if self.base_dimension not in supported_dims:
                raise ValueError(
                    f"Dimensions {self.base_dimension} for {self.base_model} are not supported"
                )
        else:
            # If base_dimension is not set, use the largest available
            # dimension for the model.
            self.base_dimension = max(supported_dims)

    @staticmethod
    def _require_base_stage(stage: EmbeddingProvider.Step) -> None:
        """Raise unless *stage* is the base embedding stage.

        Centralizes the check that was previously copy-pasted (with a
        misleading "search stage" message) across all four embedding entry
        points.
        """
        if stage != EmbeddingProvider.Step.BASE:
            raise ValueError(
                "OpenAIEmbeddingProvider only supports the base embedding stage."
            )

    @staticmethod
    def _build_task(
        texts: list[str],
        stage: EmbeddingProvider.Step,
        purpose: EmbeddingPurpose,
        kwargs: dict[str, Any],
    ) -> dict[str, Any]:
        """Package one embedding request for the backoff executors."""
        return {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }

    def _get_dimensions(self):
        """Return the value for the API's "dimensions" parameter.

        ada-002 does not accept the parameter at all, so NOT_GIVEN is
        returned for it; otherwise the configured dimension (falling back to
        the model's largest) is used.
        """
        return (
            NOT_GIVEN
            if self.base_model == "text-embedding-ada-002"
            else self.base_dimension
            or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model][-1]
        )

    def _get_embedding_kwargs(self, **kwargs):
        """Merge per-call *kwargs* over the provider defaults (model, dims)."""
        return {
            "model": self.base_model,
            "dimensions": self._get_dimensions(),
        } | kwargs

    async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
        """Async worker: embed ``task["texts"]``, one vector per input text.

        Raises:
            ValueError: on authentication failure or any other API error
                (the original exception is chained as the cause).
        """
        texts = task["texts"]
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
        try:
            response = await self.async_client.embeddings.create(
                input=texts,
                **kwargs,
            )
        except AuthenticationError as e:
            raise ValueError(
                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
            ) from e
        except Exception as e:
            error_msg = f"Error getting embeddings: {str(e)}"
            logger.error(error_msg)
            raise ValueError(error_msg) from e
        return [data.embedding for data in response.data]

    def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
        """Sync worker: embed ``task["texts"]``, one vector per input text.

        Raises:
            ValueError: on authentication failure or any other API error
                (the original exception is chained as the cause).
        """
        texts = task["texts"]
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
        try:
            response = self.client.embeddings.create(
                input=texts,
                **kwargs,
            )
        except AuthenticationError as e:
            raise ValueError(
                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
            ) from e
        except Exception as e:
            error_msg = f"Error getting embeddings: {str(e)}"
            logger.error(error_msg)
            raise ValueError(error_msg) from e
        return [data.embedding for data in response.data]

    async def async_get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Asynchronously embed a single *text* and return its vector.

        Raises:
            ValueError: if *stage* is not the base stage, or the API fails.
        """
        self._require_base_stage(stage)
        task = self._build_task([text], stage, purpose, kwargs)
        result = await self._execute_with_backoff_async(task)
        return result[0]

    def get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Embed a single *text* and return its vector.

        Raises:
            ValueError: if *stage* is not the base stage, or the API fails.
        """
        self._require_base_stage(stage)
        task = self._build_task([text], stage, purpose, kwargs)
        result = self._execute_with_backoff_sync(task)
        return result[0]

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Asynchronously embed a batch of *texts* (one vector per text).

        Raises:
            ValueError: if *stage* is not the base stage, or the API fails.
        """
        self._require_base_stage(stage)
        task = self._build_task(texts, stage, purpose, kwargs)
        return await self._execute_with_backoff_async(task)

    def get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Embed a batch of *texts* (one vector per text).

        Raises:
            ValueError: if *stage* is not the base stage, or the API fails.
        """
        self._require_base_stage(stage)
        task = self._build_task(texts, stage, purpose, kwargs)
        return self._execute_with_backoff_sync(task)

    def rerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
        limit: int = 10,
    ):
        """No-op rerank: OpenAI has no rerank model, so just truncate."""
        return results[:limit]

    async def arerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
        limit: int = 10,
    ):
        """Async no-op rerank: OpenAI has no rerank model, so just truncate."""
        return results[:limit]

    def tokenize_string(self, text: str, model: str) -> list[int]:
        """Encode *text* with the tiktoken encoding used by *model*.

        Raises:
            ValueError: if *model* is not a supported embedding model.
        """
        if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
            raise ValueError(f"OpenAI embedding model {model} not supported.")
        encoding = tiktoken.get_encoding(
            OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
        )
        return encoding.encode(text)