aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/main/assembly/builder.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/main/assembly/builder.py')
-rwxr-xr-xR2R/r2r/main/assembly/builder.py207
1 files changed, 207 insertions, 0 deletions
diff --git a/R2R/r2r/main/assembly/builder.py b/R2R/r2r/main/assembly/builder.py
new file mode 100755
index 00000000..863fc6d0
--- /dev/null
+++ b/R2R/r2r/main/assembly/builder.py
@@ -0,0 +1,207 @@
+import os
+from typing import Optional, Type
+
+from r2r.base import (
+ AsyncPipe,
+ EmbeddingProvider,
+ EvalProvider,
+ LLMProvider,
+ PromptProvider,
+ VectorDBProvider,
+)
+from r2r.pipelines import (
+ EvalPipeline,
+ IngestionPipeline,
+ RAGPipeline,
+ SearchPipeline,
+)
+
+from ..app import R2RApp
+from ..engine import R2REngine
+from ..r2r import R2R
+from .config import R2RConfig
+from .factory import R2RPipeFactory, R2RPipelineFactory, R2RProviderFactory
+
+
+class R2RBuilder:
+ current_file_path = os.path.dirname(__file__)
+ config_root = os.path.join(
+ current_file_path, "..", "..", "examples", "configs"
+ )
+ CONFIG_OPTIONS = {
+ "default": None,
+ "local_ollama": os.path.join(config_root, "local_ollama.json"),
+ "local_ollama_rerank": os.path.join(
+ config_root, "local_ollama_rerank.json"
+ ),
+ "neo4j_kg": os.path.join(config_root, "neo4j_kg.json"),
+ "local_neo4j_kg": os.path.join(config_root, "local_neo4j_kg.json"),
+ "postgres_logging": os.path.join(config_root, "postgres_logging.json"),
+ }
+
+ @staticmethod
+ def _get_config(config_name):
+ if config_name is None:
+ return R2RConfig.from_json()
+ if config_name in R2RBuilder.CONFIG_OPTIONS:
+ return R2RConfig.from_json(R2RBuilder.CONFIG_OPTIONS[config_name])
+ raise ValueError(f"Invalid config name: {config_name}")
+
+ def __init__(
+ self,
+ config: Optional[R2RConfig] = None,
+ from_config: Optional[str] = None,
+ ):
+ if config and from_config:
+ raise ValueError("Cannot specify both config and config_name")
+ self.config = config or R2RBuilder._get_config(from_config)
+ self.r2r_app_override: Optional[Type[R2REngine]] = None
+ self.provider_factory_override: Optional[Type[R2RProviderFactory]] = (
+ None
+ )
+ self.pipe_factory_override: Optional[R2RPipeFactory] = None
+ self.pipeline_factory_override: Optional[R2RPipelineFactory] = None
+ self.vector_db_provider_override: Optional[VectorDBProvider] = None
+ self.embedding_provider_override: Optional[EmbeddingProvider] = None
+ self.eval_provider_override: Optional[EvalProvider] = None
+ self.llm_provider_override: Optional[LLMProvider] = None
+ self.prompt_provider_override: Optional[PromptProvider] = None
+ self.parsing_pipe_override: Optional[AsyncPipe] = None
+ self.embedding_pipe_override: Optional[AsyncPipe] = None
+ self.vector_storage_pipe_override: Optional[AsyncPipe] = None
+ self.vector_search_pipe_override: Optional[AsyncPipe] = None
+ self.rag_pipe_override: Optional[AsyncPipe] = None
+ self.streaming_rag_pipe_override: Optional[AsyncPipe] = None
+ self.eval_pipe_override: Optional[AsyncPipe] = None
+ self.ingestion_pipeline: Optional[IngestionPipeline] = None
+ self.search_pipeline: Optional[SearchPipeline] = None
+ self.rag_pipeline: Optional[RAGPipeline] = None
+ self.streaming_rag_pipeline: Optional[RAGPipeline] = None
+ self.eval_pipeline: Optional[EvalPipeline] = None
+
+ def with_app(self, app: Type[R2REngine]):
+ self.r2r_app_override = app
+ return self
+
+ def with_provider_factory(self, factory: Type[R2RProviderFactory]):
+ self.provider_factory_override = factory
+ return self
+
+ def with_pipe_factory(self, factory: R2RPipeFactory):
+ self.pipe_factory_override = factory
+ return self
+
+ def with_pipeline_factory(self, factory: R2RPipelineFactory):
+ self.pipeline_factory_override = factory
+ return self
+
+ def with_vector_db_provider(self, provider: VectorDBProvider):
+ self.vector_db_provider_override = provider
+ return self
+
+ def with_embedding_provider(self, provider: EmbeddingProvider):
+ self.embedding_provider_override = provider
+ return self
+
+ def with_eval_provider(self, provider: EvalProvider):
+ self.eval_provider_override = provider
+ return self
+
+ def with_llm_provider(self, provider: LLMProvider):
+ self.llm_provider_override = provider
+ return self
+
+ def with_prompt_provider(self, provider: PromptProvider):
+ self.prompt_provider_override = provider
+ return self
+
+ def with_parsing_pipe(self, pipe: AsyncPipe):
+ self.parsing_pipe_override = pipe
+ return self
+
+ def with_embedding_pipe(self, pipe: AsyncPipe):
+ self.embedding_pipe_override = pipe
+ return self
+
+ def with_vector_storage_pipe(self, pipe: AsyncPipe):
+ self.vector_storage_pipe_override = pipe
+ return self
+
+ def with_vector_search_pipe(self, pipe: AsyncPipe):
+ self.vector_search_pipe_override = pipe
+ return self
+
+ def with_rag_pipe(self, pipe: AsyncPipe):
+ self.rag_pipe_override = pipe
+ return self
+
+ def with_streaming_rag_pipe(self, pipe: AsyncPipe):
+ self.streaming_rag_pipe_override = pipe
+ return self
+
+ def with_eval_pipe(self, pipe: AsyncPipe):
+ self.eval_pipe_override = pipe
+ return self
+
+ def with_ingestion_pipeline(self, pipeline: IngestionPipeline):
+ self.ingestion_pipeline = pipeline
+ return self
+
+ def with_vector_search_pipeline(self, pipeline: SearchPipeline):
+ self.search_pipeline = pipeline
+ return self
+
+ def with_rag_pipeline(self, pipeline: RAGPipeline):
+ self.rag_pipeline = pipeline
+ return self
+
+ def with_streaming_rag_pipeline(self, pipeline: RAGPipeline):
+ self.streaming_rag_pipeline = pipeline
+ return self
+
+ def with_eval_pipeline(self, pipeline: EvalPipeline):
+ self.eval_pipeline = pipeline
+ return self
+
+ def build(self, *args, **kwargs) -> R2R:
+ provider_factory = self.provider_factory_override or R2RProviderFactory
+ pipe_factory = self.pipe_factory_override or R2RPipeFactory
+ pipeline_factory = self.pipeline_factory_override or R2RPipelineFactory
+
+ providers = provider_factory(self.config).create_providers(
+ vector_db_provider_override=self.vector_db_provider_override,
+ embedding_provider_override=self.embedding_provider_override,
+ eval_provider_override=self.eval_provider_override,
+ llm_provider_override=self.llm_provider_override,
+ prompt_provider_override=self.prompt_provider_override,
+ *args,
+ **kwargs,
+ )
+
+ pipes = pipe_factory(self.config, providers).create_pipes(
+ parsing_pipe_override=self.parsing_pipe_override,
+ embedding_pipe_override=self.embedding_pipe_override,
+ vector_storage_pipe_override=self.vector_storage_pipe_override,
+ vector_search_pipe_override=self.vector_search_pipe_override,
+ rag_pipe_override=self.rag_pipe_override,
+ streaming_rag_pipe_override=self.streaming_rag_pipe_override,
+ eval_pipe_override=self.eval_pipe_override,
+ *args,
+ **kwargs,
+ )
+
+ pipelines = pipeline_factory(self.config, pipes).create_pipelines(
+ ingestion_pipeline=self.ingestion_pipeline,
+ search_pipeline=self.search_pipeline,
+ rag_pipeline=self.rag_pipeline,
+ streaming_rag_pipeline=self.streaming_rag_pipeline,
+ eval_pipeline=self.eval_pipeline,
+ *args,
+ **kwargs,
+ )
+
+ engine = (self.r2r_app_override or R2REngine)(
+ self.config, providers, pipelines
+ )
+ r2r_app = R2RApp(engine)
+ return R2R(engine=engine, app=r2r_app)