author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/providers/embeddings/openai
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/providers/embeddings/openai')
-rwxr-xr-x  R2R/r2r/providers/embeddings/openai/openai_base.py  203
1 file changed, 203 insertions, 0 deletions
diff --git a/R2R/r2r/providers/embeddings/openai/openai_base.py b/R2R/r2r/providers/embeddings/openai/openai_base.py
new file mode 100755
index 00000000..7e7d32aa
--- /dev/null
+++ b/R2R/r2r/providers/embeddings/openai/openai_base.py
@@ -0,0 +1,203 @@
+import logging
+import os
+
+from openai import AsyncOpenAI, AuthenticationError, OpenAI
+
+from r2r.base import EmbeddingConfig, EmbeddingProvider, VectorSearchResult
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIEmbeddingProvider(EmbeddingProvider):
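+    """Embedding provider backed by the OpenAI embeddings API.
+
+    Provides synchronous and asynchronous single- and batch-text embedding;
+    `rerank` is a no-op that simply truncates the result list.
+    """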
+    MODEL_TO_TOKENIZER = {
+        "text-embedding-ada-002": "cl100k_base",
+        "text-embedding-3-small": "cl100k_base",
+        "text-embedding-3-large": "cl100k_base",
+    }
+    MODEL_TO_DIMENSIONS = {
+        "text-embedding-ada-002": [1536],
+        "text-embedding-3-small": [512, 1536],
+        "text-embedding-3-large": [256, 1024, 3072],
+    }
+
+    def __init__(self, config: EmbeddingConfig):
+        super().__init__(config)
+        provider = config.provider
+        if not provider:
+            raise ValueError(
+                "Must set provider in order to initialize OpenAIEmbeddingProvider."
+            )
+
+        if provider != "openai":
+            raise ValueError(
+                "OpenAIEmbeddingProvider must be initialized with provider `openai`."
+            )
+        if not os.getenv("OPENAI_API_KEY"):
+            raise ValueError(
+                "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+            )
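+        # OpenAI() and AsyncOpenAI() read OPENAI_API_KEY from the environment by default.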
+        self.client = OpenAI()
+        self.async_client = AsyncOpenAI()
+
+        if config.rerank_model:
+            raise ValueError(
+                "OpenAIEmbeddingProvider does not support separate reranking."
+            )
+        self.base_model = config.base_model
+        self.base_dimension = config.base_dimension
+
+        if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+            raise ValueError(
+                f"OpenAI embedding model {self.base_model} not supported."
+            )
+        if (
+            self.base_dimension
+            and self.base_dimension
+            not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model]
+        ):
+            raise ValueError(
+                f"Dimensions {self.base_dimension} for {self.base_model} are not supported"
+            )
+
+        if not self.base_model or not self.base_dimension:
+            raise ValueError(
+                "Must set base_model and base_dimension in order to initialize OpenAIEmbeddingProvider."
+            )
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
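+        # If base_dimension is unset, fall back to the model's largest supported dimension.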
+        try:
+            return (
+                self.client.embeddings.create(
+                    input=[text],
+                    model=self.base_model,
+                    dimensions=self.base_dimension
+                    or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                        self.base_model
+                    ][-1],
+                )
+                .data[0]
+                .embedding
+            )
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+
+    async def async_get_embedding(
+        self,
+        text: str,
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[float]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            response = await self.async_client.embeddings.create(
+                input=[text],
+                model=self.base_model,
+                dimensions=self.base_dimension
+                or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                    self.base_model
+                ][-1],
+            )
+            return response.data[0].embedding
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            return [
+                ele.embedding
+                for ele in self.client.embeddings.create(
+                    input=texts,
+                    model=self.base_model,
+                    dimensions=self.base_dimension
+                    or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                        self.base_model
+                    ][-1],
+                ).data
+            ]
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+
+    async def async_get_embeddings(
+        self,
+        texts: list[str],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
+    ) -> list[list[float]]:
+        if stage != EmbeddingProvider.PipeStage.BASE:
+            raise ValueError(
+                "OpenAIEmbeddingProvider only supports search stage."
+            )
+
+        try:
+            response = await self.async_client.embeddings.create(
+                input=texts,
+                model=self.base_model,
+                dimensions=self.base_dimension
+                or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+                    self.base_model
+                ][-1],
+            )
+            return [ele.embedding for ele in response.data]
+        except AuthenticationError as e:
+            raise ValueError(
+                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+            ) from e
+
+    def rerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
+        limit: int = 10,
+    ):
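+        # No separate reranking model is supported; return the top `limit` results unchanged.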
+        return results[:limit]
+
+    def tokenize_string(self, text: str, model: str) -> list[int]:
+        try:
+            import tiktoken
+        except ImportError:
+            raise ValueError(
+                "Must install the `tiktoken` library to run `tokenize_string`."
+            )
+        # tiktoken encodings:
+        # cl100k_base - gpt-4, gpt-3.5-turbo, text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large
+        if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+            raise ValueError(f"OpenAI embedding model {model} not supported.")
+        encoding = tiktoken.get_encoding(
+            OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
+        )
+        return encoding.encode(text)
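
A minimal usage sketch of the provider added above. It assumes the module is importable as r2r.providers.embeddings.openai.openai_base (inferred from the file path), that EmbeddingConfig accepts provider, base_model, and base_dimension as keyword arguments (the only fields __init__ reads besides rerank_model), and that OPENAI_API_KEY is exported in the environment:

import asyncio

from r2r.base import EmbeddingConfig
from r2r.providers.embeddings.openai.openai_base import OpenAIEmbeddingProvider

# Hypothetical configuration: 512 is a valid dimension for text-embedding-3-small
# per MODEL_TO_DIMENSIONS above (text-embedding-ada-002 only allows 1536).
config = EmbeddingConfig(
    provider="openai",
    base_model="text-embedding-3-small",
    base_dimension=512,
)
provider = OpenAIEmbeddingProvider(config)

vector = provider.get_embedding("What is R2R?")                 # list[float], length 512
vectors = provider.get_embeddings(["first document", "second document"])

async def main() -> list[float]:
    return await provider.async_get_embedding("an async query")

asyncio.run(main())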