about summary refs log tree commit diff
path: root/R2R/r2r/main/assembly/builder.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/main/assembly/builder.py')
-rwxr-xr-xR2R/r2r/main/assembly/builder.py207
1 files changed, 207 insertions, 0 deletions
diff --git a/R2R/r2r/main/assembly/builder.py b/R2R/r2r/main/assembly/builder.py
new file mode 100755
index 00000000..863fc6d0
--- /dev/null
+++ b/R2R/r2r/main/assembly/builder.py
@@ -0,0 +1,207 @@
+import os
+from typing import Optional, Type
+
+from r2r.base import (
+    AsyncPipe,
+    EmbeddingProvider,
+    EvalProvider,
+    LLMProvider,
+    PromptProvider,
+    VectorDBProvider,
+)
+from r2r.pipelines import (
+    EvalPipeline,
+    IngestionPipeline,
+    RAGPipeline,
+    SearchPipeline,
+)
+
+from ..app import R2RApp
+from ..engine import R2REngine
+from ..r2r import R2R
+from .config import R2RConfig
+from .factory import R2RPipeFactory, R2RPipelineFactory, R2RProviderFactory
+
+
+class R2RBuilder:
+    current_file_path = os.path.dirname(__file__)
+    config_root = os.path.join(
+        current_file_path, "..", "..", "examples", "configs"
+    )
+    CONFIG_OPTIONS = {
+        "default": None,
+        "local_ollama": os.path.join(config_root, "local_ollama.json"),
+        "local_ollama_rerank": os.path.join(
+            config_root, "local_ollama_rerank.json"
+        ),
+        "neo4j_kg": os.path.join(config_root, "neo4j_kg.json"),
+        "local_neo4j_kg": os.path.join(config_root, "local_neo4j_kg.json"),
+        "postgres_logging": os.path.join(config_root, "postgres_logging.json"),
+    }
+
+    @staticmethod
+    def _get_config(config_name):
+        if config_name is None:
+            return R2RConfig.from_json()
+        if config_name in R2RBuilder.CONFIG_OPTIONS:
+            return R2RConfig.from_json(R2RBuilder.CONFIG_OPTIONS[config_name])
+        raise ValueError(f"Invalid config name: {config_name}")
+
+    def __init__(
+        self,
+        config: Optional[R2RConfig] = None,
+        from_config: Optional[str] = None,
+    ):
+        if config and from_config:
+            raise ValueError("Cannot specify both config and config_name")
+        self.config = config or R2RBuilder._get_config(from_config)
+        self.r2r_app_override: Optional[Type[R2REngine]] = None
+        self.provider_factory_override: Optional[Type[R2RProviderFactory]] = (
+            None
+        )
+        self.pipe_factory_override: Optional[R2RPipeFactory] = None
+        self.pipeline_factory_override: Optional[R2RPipelineFactory] = None
+        self.vector_db_provider_override: Optional[VectorDBProvider] = None
+        self.embedding_provider_override: Optional[EmbeddingProvider] = None
+        self.eval_provider_override: Optional[EvalProvider] = None
+        self.llm_provider_override: Optional[LLMProvider] = None
+        self.prompt_provider_override: Optional[PromptProvider] = None
+        self.parsing_pipe_override: Optional[AsyncPipe] = None
+        self.embedding_pipe_override: Optional[AsyncPipe] = None
+        self.vector_storage_pipe_override: Optional[AsyncPipe] = None
+        self.vector_search_pipe_override: Optional[AsyncPipe] = None
+        self.rag_pipe_override: Optional[AsyncPipe] = None
+        self.streaming_rag_pipe_override: Optional[AsyncPipe] = None
+        self.eval_pipe_override: Optional[AsyncPipe] = None
+        self.ingestion_pipeline: Optional[IngestionPipeline] = None
+        self.search_pipeline: Optional[SearchPipeline] = None
+        self.rag_pipeline: Optional[RAGPipeline] = None
+        self.streaming_rag_pipeline: Optional[RAGPipeline] = None
+        self.eval_pipeline: Optional[EvalPipeline] = None
+
+    def with_app(self, app: Type[R2REngine]):
+        self.r2r_app_override = app
+        return self
+
+    def with_provider_factory(self, factory: Type[R2RProviderFactory]):
+        self.provider_factory_override = factory
+        return self
+
+    def with_pipe_factory(self, factory: R2RPipeFactory):
+        self.pipe_factory_override = factory
+        return self
+
+    def with_pipeline_factory(self, factory: R2RPipelineFactory):
+        self.pipeline_factory_override = factory
+        return self
+
+    def with_vector_db_provider(self, provider: VectorDBProvider):
+        self.vector_db_provider_override = provider
+        return self
+
+    def with_embedding_provider(self, provider: EmbeddingProvider):
+        self.embedding_provider_override = provider
+        return self
+
+    def with_eval_provider(self, provider: EvalProvider):
+        self.eval_provider_override = provider
+        return self
+
+    def with_llm_provider(self, provider: LLMProvider):
+        self.llm_provider_override = provider
+        return self
+
+    def with_prompt_provider(self, provider: PromptProvider):
+        self.prompt_provider_override = provider
+        return self
+
+    def with_parsing_pipe(self, pipe: AsyncPipe):
+        self.parsing_pipe_override = pipe
+        return self
+
+    def with_embedding_pipe(self, pipe: AsyncPipe):
+        self.embedding_pipe_override = pipe
+        return self
+
+    def with_vector_storage_pipe(self, pipe: AsyncPipe):
+        self.vector_storage_pipe_override = pipe
+        return self
+
+    def with_vector_search_pipe(self, pipe: AsyncPipe):
+        self.vector_search_pipe_override = pipe
+        return self
+
+    def with_rag_pipe(self, pipe: AsyncPipe):
+        self.rag_pipe_override = pipe
+        return self
+
+    def with_streaming_rag_pipe(self, pipe: AsyncPipe):
+        self.streaming_rag_pipe_override = pipe
+        return self
+
+    def with_eval_pipe(self, pipe: AsyncPipe):
+        self.eval_pipe_override = pipe
+        return self
+
+    def with_ingestion_pipeline(self, pipeline: IngestionPipeline):
+        self.ingestion_pipeline = pipeline
+        return self
+
+    def with_vector_search_pipeline(self, pipeline: SearchPipeline):
+        self.search_pipeline = pipeline
+        return self
+
+    def with_rag_pipeline(self, pipeline: RAGPipeline):
+        self.rag_pipeline = pipeline
+        return self
+
+    def with_streaming_rag_pipeline(self, pipeline: RAGPipeline):
+        self.streaming_rag_pipeline = pipeline
+        return self
+
+    def with_eval_pipeline(self, pipeline: EvalPipeline):
+        self.eval_pipeline = pipeline
+        return self
+
+    def build(self, *args, **kwargs) -> R2R:
+        provider_factory = self.provider_factory_override or R2RProviderFactory
+        pipe_factory = self.pipe_factory_override or R2RPipeFactory
+        pipeline_factory = self.pipeline_factory_override or R2RPipelineFactory
+
+        providers = provider_factory(self.config).create_providers(
+            vector_db_provider_override=self.vector_db_provider_override,
+            embedding_provider_override=self.embedding_provider_override,
+            eval_provider_override=self.eval_provider_override,
+            llm_provider_override=self.llm_provider_override,
+            prompt_provider_override=self.prompt_provider_override,
+            *args,
+            **kwargs,
+        )
+
+        pipes = pipe_factory(self.config, providers).create_pipes(
+            parsing_pipe_override=self.parsing_pipe_override,
+            embedding_pipe_override=self.embedding_pipe_override,
+            vector_storage_pipe_override=self.vector_storage_pipe_override,
+            vector_search_pipe_override=self.vector_search_pipe_override,
+            rag_pipe_override=self.rag_pipe_override,
+            streaming_rag_pipe_override=self.streaming_rag_pipe_override,
+            eval_pipe_override=self.eval_pipe_override,
+            *args,
+            **kwargs,
+        )
+
+        pipelines = pipeline_factory(self.config, pipes).create_pipelines(
+            ingestion_pipeline=self.ingestion_pipeline,
+            search_pipeline=self.search_pipeline,
+            rag_pipeline=self.rag_pipeline,
+            streaming_rag_pipeline=self.streaming_rag_pipeline,
+            eval_pipeline=self.eval_pipeline,
+            *args,
+            **kwargs,
+        )
+
+        engine = (self.r2r_app_override or R2REngine)(
+            self.config, providers, pipelines
+        )
+        r2r_app = R2RApp(engine)
+        return R2R(engine=engine, app=r2r_app)