path: root/R2R/r2r/main/assembly/factory.py
author    S. Solomon Darnell 2025-03-28 21:52:21 -0500
committer S. Solomon Darnell 2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/assembly/factory.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here
Diffstat (limited to 'R2R/r2r/main/assembly/factory.py')
-rwxr-xr-x  R2R/r2r/main/assembly/factory.py  484
1 file changed, 484 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/main/assembly/factory.py b/R2R/r2r/main/assembly/factory.py
new file mode 100755
index 00000000..4e147337
--- /dev/null
+++ b/R2R/r2r/main/assembly/factory.py
@@ -0,0 +1,484 @@
+import logging
+import os
+from typing import Any, Optional
+
+from r2r.base import (
+    AsyncPipe,
+    EmbeddingConfig,
+    EmbeddingProvider,
+    EvalProvider,
+    KGProvider,
+    KVLoggingSingleton,
+    LLMConfig,
+    LLMProvider,
+    PromptProvider,
+    VectorDBConfig,
+    VectorDBProvider,
+)
+from r2r.pipelines import (
+    EvalPipeline,
+    IngestionPipeline,
+    RAGPipeline,
+    SearchPipeline,
+)
+
+from ..abstractions import R2RPipelines, R2RPipes, R2RProviders
+from .config import R2RConfig
+
+logger = logging.getLogger(__name__)
+
+
+class R2RProviderFactory:
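+    """Constructs R2R providers (vector DB, embedding, eval, LLM, prompt, KG) from config."""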
+    def __init__(self, config: R2RConfig):
+        self.config = config
+
+    def create_vector_db_provider(
+        self, vector_db_config: VectorDBConfig, *args, **kwargs
+    ) -> VectorDBProvider:
+        if vector_db_config.provider == "pgvector":
+            from r2r.providers.vector_dbs import PGVectorDB
+
+            vector_db_provider = PGVectorDB(vector_db_config)
+        else:
+            raise ValueError(
+                f"Vector database provider {vector_db_config.provider} not supported"
+            )
+
+        if not self.config.embedding.base_dimension:
+            raise ValueError("Search dimension not found in embedding config")
+
+        vector_db_provider.initialize_collection(
+            self.config.embedding.base_dimension
+        )
+        return vector_db_provider
+
+    def create_embedding_provider(
+        self, embedding: EmbeddingConfig, *args, **kwargs
+    ) -> Optional[EmbeddingProvider]:
+        embedding_provider: Optional[EmbeddingProvider] = None
+
+        if embedding.provider == "openai":
+            if not os.getenv("OPENAI_API_KEY"):
+                raise ValueError(
+                    "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+                )
+            from r2r.providers.embeddings import OpenAIEmbeddingProvider
+
+            embedding_provider = OpenAIEmbeddingProvider(embedding)
+        elif embedding.provider == "ollama":
+            from r2r.providers.embeddings import OllamaEmbeddingProvider
+
+            embedding_provider = OllamaEmbeddingProvider(embedding)
+
+        elif embedding.provider == "sentence-transformers":
+            from r2r.providers.embeddings import (
+                SentenceTransformerEmbeddingProvider,
+            )
+
+            embedding_provider = SentenceTransformerEmbeddingProvider(
+                embedding
+            )
+        elif embedding.provider is None:
+            embedding_provider = None
+        else:
+            raise ValueError(
+                f"Embedding provider {embedding.provider} not supported"
+            )
+
+        return embedding_provider
+
+    def create_eval_provider(
+        self, eval_config, prompt_provider, *args, **kwargs
+    ) -> Optional[EvalProvider]:
+        if eval_config.provider == "local":
+            from r2r.providers.eval import LLMEvalProvider
+
+            llm_provider = self.create_llm_provider(eval_config.llm)
+            eval_provider = LLMEvalProvider(
+                eval_config,
+                llm_provider=llm_provider,
+                prompt_provider=prompt_provider,
+            )
+        elif eval_config.provider is None:
+            eval_provider = None
+        else:
+            raise ValueError(
+                f"Eval provider {eval_config.provider} not supported."
+            )
+
+        return eval_provider
+
+    def create_llm_provider(
+        self, llm_config: LLMConfig, *args, **kwargs
+    ) -> LLMProvider:
+        if llm_config.provider == "openai":
+            from r2r.providers.llms import OpenAILLM
+
+            llm_provider = OpenAILLM(llm_config)
+        elif llm_config.provider == "litellm":
+            from r2r.providers.llms import LiteLLM
+
+            llm_provider = LiteLLM(llm_config)
+        else:
+            raise ValueError(
+                f"Language model provider {llm_config.provider} not supported"
+            )
+        return llm_provider
+
+    def create_prompt_provider(
+        self, prompt_config, *args, **kwargs
+    ) -> PromptProvider:
+        if prompt_config.provider == "local":
+            from r2r.prompts import R2RPromptProvider
+
+            prompt_provider = R2RPromptProvider()
+        else:
+            raise ValueError(
+                f"Prompt provider {prompt_config.provider} not supported"
+            )
+        return prompt_provider
+
+    def create_kg_provider(
+        self, kg_config, *args, **kwargs
+    ) -> Optional[KGProvider]:
+        if kg_config.provider == "neo4j":
+            from r2r.providers.kg import Neo4jKGProvider
+
+            return Neo4jKGProvider(kg_config)
+        elif kg_config.provider is None:
+            return None
+        else:
+            raise ValueError(
+                f"KG provider {kg_config.provider} not supported."
+            )
+
+    def create_providers(
+        self,
+        vector_db_provider_override: Optional[VectorDBProvider] = None,
+        embedding_provider_override: Optional[EmbeddingProvider] = None,
+        eval_provider_override: Optional[EvalProvider] = None,
+        llm_provider_override: Optional[LLMProvider] = None,
+        prompt_provider_override: Optional[PromptProvider] = None,
+        kg_provider_override: Optional[KGProvider] = None,
+        *args,
+        **kwargs,
+    ) -> R2RProviders:
+        prompt_provider = (
+            prompt_provider_override
+            or self.create_prompt_provider(self.config.prompt, *args, **kwargs)
+        )
+        return R2RProviders(
+            vector_db=vector_db_provider_override
+            or self.create_vector_db_provider(
+                self.config.vector_database, *args, **kwargs
+            ),
+            embedding=embedding_provider_override
+            or self.create_embedding_provider(
+                self.config.embedding, *args, **kwargs
+            ),
+            eval=eval_provider_override
+            or self.create_eval_provider(
+                self.config.eval,
+                *args,
+                prompt_provider=prompt_provider,
+                **kwargs,
+            ),
+            llm=llm_provider_override
+            or self.create_llm_provider(
+                self.config.completions, *args, **kwargs
+            ),
+            prompt=prompt_provider_override
+            or self.create_prompt_provider(
+                self.config.prompt, *args, **kwargs
+            ),
+            kg=kg_provider_override
+            or self.create_kg_provider(self.config.kg, *args, **kwargs),
+        )
+
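+# Override hooks: create_providers accepts pre-built instances in place of any
+# factory-built provider. A sketch, assuming a hypothetical custom LLMProvider
+# subclass named MyLLM:
+#
+#   providers = R2RProviderFactory(config).create_providers(
+#       llm_provider_override=MyLLM(config.completions)
+#   )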
+
+class R2RPipeFactory:
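+    """Builds the individual AsyncPipe instances that the pipelines compose."""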
+    def __init__(self, config: R2RConfig, providers: R2RProviders):
+        self.config = config
+        self.providers = providers
+
+    def create_pipes(
+        self,
+        parsing_pipe_override: Optional[AsyncPipe] = None,
+        embedding_pipe_override: Optional[AsyncPipe] = None,
+        kg_pipe_override: Optional[AsyncPipe] = None,
+        kg_storage_pipe_override: Optional[AsyncPipe] = None,
+        kg_agent_pipe_override: Optional[AsyncPipe] = None,
+        vector_storage_pipe_override: Optional[AsyncPipe] = None,
+        vector_search_pipe_override: Optional[AsyncPipe] = None,
+        rag_pipe_override: Optional[AsyncPipe] = None,
+        streaming_rag_pipe_override: Optional[AsyncPipe] = None,
+        eval_pipe_override: Optional[AsyncPipe] = None,
+        *args,
+        **kwargs,
+    ) -> R2RPipes:
+        return R2RPipes(
+            parsing_pipe=parsing_pipe_override
+            or self.create_parsing_pipe(
+                self.config.ingestion.get("excluded_parsers"), *args, **kwargs
+            ),
+            embedding_pipe=embedding_pipe_override
+            or self.create_embedding_pipe(*args, **kwargs),
+            kg_pipe=kg_pipe_override or self.create_kg_pipe(*args, **kwargs),
+            kg_storage_pipe=kg_storage_pipe_override
+            or self.create_kg_storage_pipe(*args, **kwargs),
+            kg_agent_search_pipe=kg_agent_pipe_override
+            or self.create_kg_agent_pipe(*args, **kwargs),
+            vector_storage_pipe=vector_storage_pipe_override
+            or self.create_vector_storage_pipe(*args, **kwargs),
+            vector_search_pipe=vector_search_pipe_override
+            or self.create_vector_search_pipe(*args, **kwargs),
+            rag_pipe=rag_pipe_override
+            or self.create_rag_pipe(*args, **kwargs),
+            streaming_rag_pipe=streaming_rag_pipe_override
+            or self.create_rag_pipe(*args, stream=True, **kwargs),
+            eval_pipe=eval_pipe_override
+            or self.create_eval_pipe(*args, **kwargs),
+        )
+
+    def create_parsing_pipe(
+        self, excluded_parsers: Optional[list] = None, *args, **kwargs
+    ) -> Any:
+        from r2r.pipes import ParsingPipe
+
+        return ParsingPipe(excluded_parsers=excluded_parsers or [])
+
+    def create_embedding_pipe(self, *args, **kwargs) -> Any:
+        if self.config.embedding.provider is None:
+            return None
+
+        from r2r.base import RecursiveCharacterTextSplitter
+        from r2r.pipes import EmbeddingPipe
+
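+        # The embedding config is expected to carry an extra "text_splitter"
+        # mapping, e.g. {"chunk_size": 512, "chunk_overlap": 20} (illustrative
+        # values, not defaults taken from the source).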
+        text_splitter_config = self.config.embedding.extra_fields.get(
+            "text_splitter"
+        )
+        if not text_splitter_config:
+            raise ValueError(
+                "Text splitter config not found in embedding config"
+            )
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=text_splitter_config["chunk_size"],
+            chunk_overlap=text_splitter_config["chunk_overlap"],
+            length_function=len,
+            is_separator_regex=False,
+        )
+        return EmbeddingPipe(
+            embedding_provider=self.providers.embedding,
+            vector_db_provider=self.providers.vector_db,
+            text_splitter=text_splitter,
+            embedding_batch_size=self.config.embedding.batch_size,
+        )
+
+    def create_vector_storage_pipe(self, *args, **kwargs) -> Any:
+        if self.config.embedding.provider is None:
+            return None
+
+        from r2r.pipes import VectorStoragePipe
+
+        return VectorStoragePipe(vector_db_provider=self.providers.vector_db)
+
+    def create_vector_search_pipe(self, *args, **kwargs) -> Any:
+        if self.config.embedding.provider is None:
+            return None
+
+        from r2r.pipes import VectorSearchPipe
+
+        return VectorSearchPipe(
+            vector_db_provider=self.providers.vector_db,
+            embedding_provider=self.providers.embedding,
+        )
+
+    def create_kg_pipe(self, *args, **kwargs) -> Any:
+        if self.config.kg.provider is None:
+            return None
+
+        from r2r.base import RecursiveCharacterTextSplitter
+        from r2r.pipes import KGExtractionPipe
+
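+        # Expects the same "text_splitter" extra shape as the embedding
+        # config above, e.g. {"chunk_size": 512, "chunk_overlap": 20}
+        # (illustrative values).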
+        text_splitter_config = self.config.kg.extra_fields.get("text_splitter")
+        if not text_splitter_config:
+            raise ValueError("Text splitter config not found in kg config.")
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=text_splitter_config["chunk_size"],
+            chunk_overlap=text_splitter_config["chunk_overlap"],
+            length_function=len,
+            is_separator_regex=False,
+        )
+        return KGExtractionPipe(
+            kg_provider=self.providers.kg,
+            llm_provider=self.providers.llm,
+            prompt_provider=self.providers.prompt,
+            vector_db_provider=self.providers.vector_db,
+            text_splitter=text_splitter,
+            kg_batch_size=self.config.kg.batch_size,
+        )
+
+    def create_kg_storage_pipe(self, *args, **kwargs) -> Any:
+        if self.config.kg.provider is None:
+            return None
+
+        from r2r.pipes import KGStoragePipe
+
+        return KGStoragePipe(
+            kg_provider=self.providers.kg,
+            embedding_provider=self.providers.embedding,
+        )
+
+    def create_kg_agent_pipe(self, *args, **kwargs) -> Any:
+        if self.config.kg.provider is None:
+            return None
+
+        from r2r.pipes import KGAgentSearchPipe
+
+        return KGAgentSearchPipe(
+            kg_provider=self.providers.kg,
+            llm_provider=self.providers.llm,
+            prompt_provider=self.providers.prompt,
+        )
+
+    def create_rag_pipe(self, stream: bool = False, *args, **kwargs) -> Any:
+        if stream:
+            from r2r.pipes import StreamingSearchRAGPipe
+
+            return StreamingSearchRAGPipe(
+                llm_provider=self.providers.llm,
+                prompt_provider=self.providers.prompt,
+            )
+        else:
+            from r2r.pipes import SearchRAGPipe
+
+            return SearchRAGPipe(
+                llm_provider=self.providers.llm,
+                prompt_provider=self.providers.prompt,
+            )
+
+    def create_eval_pipe(self, *args, **kwargs) -> Any:
+        from r2r.pipes import EvalPipe
+
+        return EvalPipe(eval_provider=self.providers.eval)
+
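+# Pipe overrides follow the same pattern as provider overrides, e.g. (sketch,
+# with a hypothetical pre-built AsyncPipe instance):
+#
+#   pipes = R2RPipeFactory(config, providers).create_pipes(
+#       rag_pipe_override=my_custom_rag_pipe
+#   )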
+
+class R2RPipelineFactory:
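+    """Assembles pipes into ingestion, search, RAG, and eval pipelines."""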
+    def __init__(self, config: R2RConfig, pipes: R2RPipes):
+        self.config = config
+        self.pipes = pipes
+
+    def create_ingestion_pipeline(self, *args, **kwargs) -> IngestionPipeline:
+        """factory method to create an ingestion pipeline."""
+        ingestion_pipeline = IngestionPipeline()
+
+        ingestion_pipeline.add_pipe(
+            pipe=self.pipes.parsing_pipe, parsing_pipe=True
+        )
+        # Add embedding pipes if provider is set
+        if self.config.embedding.provider is not None:
+            ingestion_pipeline.add_pipe(
+                self.pipes.embedding_pipe, embedding_pipe=True
+            )
+            ingestion_pipeline.add_pipe(
+                self.pipes.vector_storage_pipe, embedding_pipe=True
+            )
+        # Add KG pipes if provider is set
+        if self.config.kg.provider is not None:
+            ingestion_pipeline.add_pipe(self.pipes.kg_pipe, kg_pipe=True)
+            ingestion_pipeline.add_pipe(
+                self.pipes.kg_storage_pipe, kg_pipe=True
+            )
+
+        return ingestion_pipeline
+
+    def create_search_pipeline(self, *args, **kwargs) -> SearchPipeline:
+        """factory method to create an ingestion pipeline."""
+        search_pipeline = SearchPipeline()
+
+        # Add vector search pipe if both embedding and vector DB providers are set
+        if (
+            self.config.embedding.provider is not None
+            and self.config.vector_database.provider is not None
+        ):
+            search_pipeline.add_pipe(
+                self.pipes.vector_search_pipe, vector_search_pipe=True
+            )
+
+        # Add KG pipes if provider is set
+        if self.config.kg.provider is not None:
+            search_pipeline.add_pipe(
+                self.pipes.kg_agent_search_pipe, kg_pipe=True
+            )
+
+        return search_pipeline
+
+    def create_rag_pipeline(
+        self,
+        search_pipeline: SearchPipeline,
+        stream: bool = False,
+        *args,
+        **kwargs,
+    ) -> RAGPipeline:
+        rag_pipe = (
+            self.pipes.streaming_rag_pipe if stream else self.pipes.rag_pipe
+        )
+
+        rag_pipeline = RAGPipeline()
+        rag_pipeline.set_search_pipeline(search_pipeline)
+        rag_pipeline.add_pipe(rag_pipe)
+        return rag_pipeline
+
+    def create_eval_pipeline(self, *args, **kwargs) -> EvalPipeline:
+        eval_pipeline = EvalPipeline()
+        eval_pipeline.add_pipe(self.pipes.eval_pipe)
+        return eval_pipeline
+
+    def create_pipelines(
+        self,
+        ingestion_pipeline: Optional[IngestionPipeline] = None,
+        search_pipeline: Optional[SearchPipeline] = None,
+        rag_pipeline: Optional[RAGPipeline] = None,
+        streaming_rag_pipeline: Optional[RAGPipeline] = None,
+        eval_pipeline: Optional[EvalPipeline] = None,
+        *args,
+        **kwargs,
+    ) -> R2RPipelines:
+        try:
+            self.configure_logging()
+        except Exception as e:
+            logger.warning(f"Error configuring logging: {e}")
+        search_pipeline = search_pipeline or self.create_search_pipeline(
+            *args, **kwargs
+        )
+        return R2RPipelines(
+            ingestion_pipeline=ingestion_pipeline
+            or self.create_ingestion_pipeline(*args, **kwargs),
+            search_pipeline=search_pipeline,
+            rag_pipeline=rag_pipeline
+            or self.create_rag_pipeline(
+                *args,
+                search_pipeline=search_pipeline,
+                stream=False,
+                **kwargs,
+            ),
+            streaming_rag_pipeline=streaming_rag_pipeline
+            or self.create_rag_pipeline(
+                *args,
+                search_pipeline=search_pipeline,
+                stream=True,
+                **kwargs,
+            ),
+            eval_pipeline=eval_pipeline
+            or self.create_eval_pipeline(*args, **kwargs),
+        )
+
+    def configure_logging(self):
+        KVLoggingSingleton.configure(self.config.logging)
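+
+
+# End-to-end assembly sketch (hypothetical; chains the three factories in this
+# module, with the config path and from_json-style loader as assumed examples):
+#
+#   config = R2RConfig.from_json("r2r.json")
+#   providers = R2RProviderFactory(config).create_providers()
+#   pipes = R2RPipeFactory(config, providers).create_pipes()
+#   pipelines = R2RPipelineFactory(config, pipes).create_pipelines()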