about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py243
1 files changed, 243 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
new file mode 100644
index 00000000..907cebd9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
@@ -0,0 +1,243 @@
+import logging
+import os
+from typing import Any
+
+import tiktoken
+from openai import AsyncOpenAI, AuthenticationError, OpenAI
+from openai._types import NOT_GIVEN
+
+from core.base import (
+    ChunkSearchResult,
+    EmbeddingConfig,
+    EmbeddingProvider,
+    EmbeddingPurpose,
+)
+
+logger = logging.getLogger()
+
+
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by the OpenAI embeddings API.

    Supports ``text-embedding-ada-002`` and the ``text-embedding-3-*``
    models. Only the BASE embedding step is supported; ``rerank`` /
    ``arerank`` are no-ops that simply truncate the given results.
    """

    # tiktoken encoding name used by each supported model.
    MODEL_TO_TOKENIZER = {
        "text-embedding-ada-002": "cl100k_base",
        "text-embedding-3-small": "cl100k_base",
        "text-embedding-3-large": "cl100k_base",
    }
    # Output dimensionalities the API accepts per model; the
    # text-embedding-3-* models support shortened embeddings.
    MODEL_TO_DIMENSIONS = {
        "text-embedding-ada-002": [1536],
        "text-embedding-3-small": [512, 1536],
        "text-embedding-3-large": [256, 1024, 3072],
    }

    def __init__(self, config: EmbeddingConfig):
        """Validate *config* and construct sync/async OpenAI clients.

        Raises:
            ValueError: if the provider is not ``openai``, the
                ``OPENAI_API_KEY`` environment variable is unset, a rerank
                model is configured, the base model is missing/unsupported,
                or the requested dimension is invalid for the model.
        """
        super().__init__(config)
        if not config.provider:
            raise ValueError(
                "Must set provider in order to initialize OpenAIEmbeddingProvider."
            )

        if config.provider != "openai":
            raise ValueError(
                "OpenAIEmbeddingProvider must be initialized with provider `openai`."
            )
        if not os.getenv("OPENAI_API_KEY"):
            raise ValueError(
                "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
            )
        # Clients read OPENAI_API_KEY from the environment.
        self.client = OpenAI()
        self.async_client = AsyncOpenAI()

        if config.rerank_model:
            raise ValueError(
                "OpenAIEmbeddingProvider does not support separate reranking."
            )

        # Accept both bare model names and prefixed "openai/<model>" ids.
        if config.base_model and "openai/" in config.base_model:
            self.base_model = config.base_model.split("/")[-1]
        else:
            self.base_model = config.base_model
        self.base_dimension = config.base_dimension

        if not self.base_model:
            raise ValueError(
                "Must set base_model in order to initialize OpenAIEmbeddingProvider."
            )

        if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
            raise ValueError(
                f"OpenAI embedding model {self.base_model} not supported."
            )

        supported_dims = OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
            self.base_model
        ]
        if self.base_dimension:
            if self.base_dimension not in supported_dims:
                raise ValueError(
                    f"Dimensions {self.base_dimension} for {self.base_model} are not supported"
                )
        else:
            # If base_dimension is not set, use the largest available
            # dimension for the model.
            self.base_dimension = max(supported_dims)

    def _get_dimensions(self):
        """Return the `dimensions` API argument for the configured model.

        ada-002 does not accept a `dimensions` parameter, so NOT_GIVEN is
        used; otherwise the configured (or largest supported) dimension.
        """
        return (
            NOT_GIVEN
            if self.base_model == "text-embedding-ada-002"
            else self.base_dimension
            or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model][-1]
        )

    def _get_embedding_kwargs(self, **kwargs):
        """Merge provider defaults with per-call kwargs (caller wins)."""
        return {
            "model": self.base_model,
            "dimensions": self._get_dimensions(),
        } | kwargs

    @staticmethod
    def _check_stage(stage: EmbeddingProvider.Step) -> None:
        """Reject any step other than BASE (the only supported one)."""
        if stage != EmbeddingProvider.Step.BASE:
            raise ValueError(
                "OpenAIEmbeddingProvider only supports the base embedding stage."
            )

    @staticmethod
    def _make_task(
        texts: list[str],
        stage: EmbeddingProvider.Step,
        purpose: EmbeddingPurpose,
        kwargs: dict[str, Any],
    ) -> dict[str, Any]:
        """Build the task dict consumed by the backoff executors."""
        return {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }

    @staticmethod
    def _raise_embedding_error(e: Exception) -> None:
        """Translate OpenAI client errors into ValueError.

        AuthenticationError gets a dedicated hint about OPENAI_API_KEY;
        everything else is logged and wrapped. Always raises.
        """
        if isinstance(e, AuthenticationError):
            raise ValueError(
                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
            ) from e
        error_msg = f"Error getting embeddings: {str(e)}"
        logger.error(error_msg)
        raise ValueError(error_msg) from e

    async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
        """Async: embed ``task["texts"]`` and return one vector per text."""
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
        try:
            response = await self.async_client.embeddings.create(
                input=task["texts"],
                **kwargs,
            )
            return [data.embedding for data in response.data]
        except Exception as e:
            self._raise_embedding_error(e)

    def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
        """Sync: embed ``task["texts"]`` and return one vector per text."""
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
        try:
            response = self.client.embeddings.create(
                input=task["texts"],
                **kwargs,
            )
            return [data.embedding for data in response.data]
        except Exception as e:
            self._raise_embedding_error(e)

    async def async_get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Asynchronously embed a single string and return its vector."""
        self._check_stage(stage)
        task = self._make_task([text], stage, purpose, kwargs)
        result = await self._execute_with_backoff_async(task)
        return result[0]

    def get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Embed a single string and return its vector."""
        self._check_stage(stage)
        task = self._make_task([text], stage, purpose, kwargs)
        result = self._execute_with_backoff_sync(task)
        return result[0]

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Asynchronously embed a batch of strings, preserving order."""
        self._check_stage(stage)
        task = self._make_task(texts, stage, purpose, kwargs)
        return await self._execute_with_backoff_async(task)

    def get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Embed a batch of strings, preserving order."""
        self._check_stage(stage)
        task = self._make_task(texts, stage, purpose, kwargs)
        return self._execute_with_backoff_sync(task)

    def rerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
        limit: int = 10,
    ):
        """No-op rerank: return the first *limit* results unchanged."""
        return results[:limit]

    async def arerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
        limit: int = 10,
    ):
        """Async no-op rerank: return the first *limit* results unchanged."""
        return results[:limit]

    def tokenize_string(self, text: str, model: str) -> list[int]:
        """Tokenize *text* with the tiktoken encoding for *model*.

        Raises:
            ValueError: if *model* is not one of the supported models.
        """
        if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
            raise ValueError(f"OpenAI embedding model {model} not supported.")
        encoding = tiktoken.get_encoding(
            OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
        )
        return encoding.encode(text)