Diffstat (limited to 'R2R/r2r/main/assembly')
-rwxr-xr-x  R2R/r2r/main/assembly/__init__.py           |   0
-rwxr-xr-x  R2R/r2r/main/assembly/builder.py            | 207
-rwxr-xr-x  R2R/r2r/main/assembly/config.py             | 167
-rwxr-xr-x  R2R/r2r/main/assembly/factory.py            | 484
-rwxr-xr-x  R2R/r2r/main/assembly/factory_extensions.py |  69
5 files changed, 927 insertions, 0 deletions
diff --git a/R2R/r2r/main/assembly/__init__.py b/R2R/r2r/main/assembly/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/main/assembly/__init__.py
diff --git a/R2R/r2r/main/assembly/builder.py b/R2R/r2r/main/assembly/builder.py
new file mode 100755
index 00000000..863fc6d0
--- /dev/null
+++ b/R2R/r2r/main/assembly/builder.py
@@ -0,0 +1,207 @@
+import os
+from typing import Optional, Type
+
+from r2r.base import (
+    AsyncPipe,
+    EmbeddingProvider,
+    EvalProvider,
+    LLMProvider,
+    PromptProvider,
+    VectorDBProvider,
+)
+from r2r.pipelines import (
+    EvalPipeline,
+    IngestionPipeline,
+    RAGPipeline,
+    SearchPipeline,
+)
+
+from ..app import R2RApp
+from ..engine import R2REngine
+from ..r2r import R2R
+from .config import R2RConfig
+from .factory import R2RPipeFactory, R2RPipelineFactory, R2RProviderFactory
+
+
+class R2RBuilder:
+    current_file_path = os.path.dirname(__file__)
+    config_root = os.path.join(
+        current_file_path, "..", "..", "examples", "configs"
+    )
+    CONFIG_OPTIONS = {
+        "default": None,
+        "local_ollama": os.path.join(config_root, "local_ollama.json"),
+        "local_ollama_rerank": os.path.join(
+            config_root, "local_ollama_rerank.json"
+        ),
+        "neo4j_kg": os.path.join(config_root, "neo4j_kg.json"),
+        "local_neo4j_kg": os.path.join(config_root, "local_neo4j_kg.json"),
+        "postgres_logging": os.path.join(config_root, "postgres_logging.json"),
+    }
+
+    @staticmethod
+    def _get_config(config_name):
+        if config_name is None:
+            return R2RConfig.from_json()
+        if config_name in R2RBuilder.CONFIG_OPTIONS:
+            return R2RConfig.from_json(R2RBuilder.CONFIG_OPTIONS[config_name])
+        raise ValueError(f"Invalid config name: {config_name}")
+
+    def __init__(
+        self,
+        config: Optional[R2RConfig] = None,
+        from_config: Optional[str] = None,
+    ):
+        if config and from_config:
+            raise ValueError("Cannot specify both `config` and `from_config`")
+        self.config = config or R2RBuilder._get_config(from_config)
+        self.r2r_app_override: Optional[Type[R2REngine]] = None
+        self.provider_factory_override: Optional[Type[R2RProviderFactory]] = (
+            None
+        )
+        self.pipe_factory_override: Optional[R2RPipeFactory] = None
+        self.pipeline_factory_override: Optional[R2RPipelineFactory] = None
+        self.vector_db_provider_override: Optional[VectorDBProvider] = None
+        self.embedding_provider_override: Optional[EmbeddingProvider] = None
+        self.eval_provider_override: Optional[EvalProvider] = None
+        self.llm_provider_override: Optional[LLMProvider] = None
+        self.prompt_provider_override: Optional[PromptProvider] = None
+        self.parsing_pipe_override: Optional[AsyncPipe] = None
+        self.embedding_pipe_override: Optional[AsyncPipe] = None
+        self.vector_storage_pipe_override: Optional[AsyncPipe] = None
+        self.vector_search_pipe_override: Optional[AsyncPipe] = None
+        self.rag_pipe_override: Optional[AsyncPipe] = None
+        self.streaming_rag_pipe_override: Optional[AsyncPipe] = None
+        self.eval_pipe_override: Optional[AsyncPipe] = None
+        self.ingestion_pipeline: Optional[IngestionPipeline] = None
+        self.search_pipeline: Optional[SearchPipeline] = None
+        self.rag_pipeline: Optional[RAGPipeline] = None
+        self.streaming_rag_pipeline: Optional[RAGPipeline] = None
+        self.eval_pipeline: Optional[EvalPipeline] = None
+
+    def with_app(self, app: Type[R2REngine]):
+        self.r2r_app_override = app
+        return self
+
+    def with_provider_factory(self, factory: Type[R2RProviderFactory]):
+        self.provider_factory_override = factory
+        return self
+
+    def with_pipe_factory(self, factory: R2RPipeFactory):
+        self.pipe_factory_override = factory
+        return self
+
+    def with_pipeline_factory(self, factory: R2RPipelineFactory):
+        self.pipeline_factory_override = factory
+        return self
+
+    def with_vector_db_provider(self, provider: VectorDBProvider):
+        self.vector_db_provider_override = provider
+        return self
+
+    def with_embedding_provider(self, provider: EmbeddingProvider):
+        self.embedding_provider_override = provider
+        return self
+
+    def with_eval_provider(self, provider: EvalProvider):
+        self.eval_provider_override = provider
+        return self
+
+    def with_llm_provider(self, provider: LLMProvider):
+        self.llm_provider_override = provider
+        return self
+
+    def with_prompt_provider(self, provider: PromptProvider):
+        self.prompt_provider_override = provider
+        return self
+
+    def with_parsing_pipe(self, pipe: AsyncPipe):
+        self.parsing_pipe_override = pipe
+        return self
+
+    def with_embedding_pipe(self, pipe: AsyncPipe):
+        self.embedding_pipe_override = pipe
+        return self
+
+    def with_vector_storage_pipe(self, pipe: AsyncPipe):
+        self.vector_storage_pipe_override = pipe
+        return self
+
+    def with_vector_search_pipe(self, pipe: AsyncPipe):
+        self.vector_search_pipe_override = pipe
+        return self
+
+    def with_rag_pipe(self, pipe: AsyncPipe):
+        self.rag_pipe_override = pipe
+        return self
+
+    def with_streaming_rag_pipe(self, pipe: AsyncPipe):
+        self.streaming_rag_pipe_override = pipe
+        return self
+
+    def with_eval_pipe(self, pipe: AsyncPipe):
+        self.eval_pipe_override = pipe
+        return self
+
+    def with_ingestion_pipeline(self, pipeline: IngestionPipeline):
+        self.ingestion_pipeline = pipeline
+        return self
+
+    def with_vector_search_pipeline(self, pipeline: SearchPipeline):
+        self.search_pipeline = pipeline
+        return self
+
+    def with_rag_pipeline(self, pipeline: RAGPipeline):
+        self.rag_pipeline = pipeline
+        return self
+
+    def with_streaming_rag_pipeline(self, pipeline: RAGPipeline):
+        self.streaming_rag_pipeline = pipeline
+        return self
+
+    def with_eval_pipeline(self, pipeline: EvalPipeline):
+        self.eval_pipeline = pipeline
+        return self
+
+    def build(self, *args, **kwargs) -> R2R:
+        provider_factory = self.provider_factory_override or R2RProviderFactory
+        pipe_factory = self.pipe_factory_override or R2RPipeFactory
+        pipeline_factory = self.pipeline_factory_override or R2RPipelineFactory
+
+        providers = provider_factory(self.config).create_providers(
+            vector_db_provider_override=self.vector_db_provider_override,
+            embedding_provider_override=self.embedding_provider_override,
+            eval_provider_override=self.eval_provider_override,
+            llm_provider_override=self.llm_provider_override,
+            prompt_provider_override=self.prompt_provider_override,
+            *args,
+            **kwargs,
+        )
+
+        pipes = pipe_factory(self.config, providers).create_pipes(
+            parsing_pipe_override=self.parsing_pipe_override,
+            embedding_pipe_override=self.embedding_pipe_override,
+            vector_storage_pipe_override=self.vector_storage_pipe_override,
+            vector_search_pipe_override=self.vector_search_pipe_override,
+            rag_pipe_override=self.rag_pipe_override,
+            streaming_rag_pipe_override=self.streaming_rag_pipe_override,
+            eval_pipe_override=self.eval_pipe_override,
+            *args,
+            **kwargs,
+        )
+
+        pipelines = pipeline_factory(self.config, pipes).create_pipelines(
+            ingestion_pipeline=self.ingestion_pipeline,
+            search_pipeline=self.search_pipeline,
+            rag_pipeline=self.rag_pipeline,
+            streaming_rag_pipeline=self.streaming_rag_pipeline,
+            eval_pipeline=self.eval_pipeline,
+            *args,
+            **kwargs,
+        )
+
+        engine = (self.r2r_app_override or R2REngine)(
+            self.config, providers, pipelines
+        )
+        r2r_app = R2RApp(engine)
+        return R2R(engine=engine, app=r2r_app)
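A note on the shape of builder.py above: every `with_*` method records an override and returns `self`, so construction reads as a fluent chain, and `build()` resolves the overrides in factory order (providers, then pipes, then pipelines). A minimal usage sketch, assuming the module layout in this diff and the bundled example configs; `MyLLMProvider` is a hypothetical stand-in for any `LLMProvider` implementation:

    from r2r.main.assembly.builder import R2RBuilder

    # Simplest path: pick a named config from R2RBuilder.CONFIG_OPTIONS.
    r2r = R2RBuilder(from_config="local_ollama").build()

    # Or chain overrides; anything not overridden falls back to the factories.
    custom_llm = MyLLMProvider()  # hypothetical LLMProvider implementation
    r2r = R2RBuilder().with_llm_provider(custom_llm).build()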
diff --git a/R2R/r2r/main/assembly/config.py b/R2R/r2r/main/assembly/config.py
new file mode 100755
index 00000000..d52c4561
--- /dev/null
+++ b/R2R/r2r/main/assembly/config.py
@@ -0,0 +1,167 @@
+import json
+import logging
+import os
+from enum import Enum
+from typing import Any
+
+from ...base.abstractions.document import DocumentType
+from ...base.abstractions.llm import GenerationConfig
+from ...base.logging.kv_logger import LoggingConfig
+from ...base.providers.embedding_provider import EmbeddingConfig
+from ...base.providers.eval_provider import EvalConfig
+from ...base.providers.kg_provider import KGConfig
+from ...base.providers.llm_provider import LLMConfig
+from ...base.providers.prompt_provider import PromptConfig
+from ...base.providers.vector_db_provider import ProviderConfig, VectorDBConfig
+
+logger = logging.getLogger(__name__)
+
+
+class R2RConfig:
+    REQUIRED_KEYS: dict[str, list] = {
+        "app": ["max_file_size_in_mb"],
+        "embedding": [
+            "provider",
+            "base_model",
+            "base_dimension",
+            "batch_size",
+            "text_splitter",
+        ],
+        "eval": ["llm"],
+        "kg": [
+            "provider",
+            "batch_size",
+            "kg_extraction_config",
+            "text_splitter",
+        ],
+        "ingestion": ["excluded_parsers"],
+        "completions": ["provider"],
+        "logging": ["provider", "log_table"],
+        "prompt": ["provider"],
+        "vector_database": ["provider"],
+    }
+    app: dict[str, Any]
+    embedding: EmbeddingConfig
+    completions: LLMConfig
+    logging: LoggingConfig
+    prompt: PromptConfig
+    vector_database: VectorDBConfig
+
+    def __init__(self, config_data: dict[str, Any]):
+        # Load the default configuration
+        default_config = self.load_default_config()
+
+        # Override the default configuration with the passed configuration
+        for key in config_data:
+            if key in default_config:
+                default_config[key].update(config_data[key])
+            else:
+                default_config[key] = config_data[key]
+
+        # Validate and set the configuration
+        for section, keys in R2RConfig.REQUIRED_KEYS.items():
+            # Check the keys when provider is set
+            # TODO - Clean up robust null checks
+            if "provider" in default_config[section] and (
+                default_config[section]["provider"] is not None
+                and default_config[section]["provider"] != "None"
+                and default_config[section]["provider"] != "null"
+            ):
+                self._validate_config_section(default_config, section, keys)
+            setattr(self, section, default_config[section])
+
+        self.app = self.app  # for type hinting
+        self.ingestion = self.ingestion  # for type hinting
+        self.ingestion["excluded_parsers"] = [
+            DocumentType(k) for k in self.ingestion["excluded_parsers"]
+        ]
+        # override GenerationConfig defaults
+        GenerationConfig.set_default(
+            **self.completions.get("generation_config", {})
+        )
+        self.embedding = EmbeddingConfig.create(**self.embedding)
+        self.kg = KGConfig.create(**self.kg)
+        eval_llm = self.eval.pop("llm", None)
+        self.eval = EvalConfig.create(
+            **self.eval, llm=LLMConfig.create(**eval_llm) if eval_llm else None
+        )
+        self.completions = LLMConfig.create(**self.completions)
+        self.logging = LoggingConfig.create(**self.logging)
+        self.prompt = PromptConfig.create(**self.prompt)
+        self.vector_database = VectorDBConfig.create(**self.vector_database)
+
+    def _validate_config_section(
+        self, config_data: dict[str, Any], section: str, keys: list
+    ):
+        if section not in config_data:
+            raise ValueError(f"Missing '{section}' section in config")
+        if not all(key in config_data[section] for key in keys):
+            raise ValueError(f"Missing required keys in '{section}' config")
+
+    @classmethod
+    def from_json(cls, config_path: str = None) -> "R2RConfig":
+        if config_path is None:
+            # Get the root directory of the project
+            file_dir = os.path.dirname(os.path.abspath(__file__))
+            config_path = os.path.join(
+                file_dir, "..", "..", "..", "config.json"
+            )
+
+        # Load configuration from JSON file
+        with open(config_path) as f:
+            config_data = json.load(f)
+
+        return cls(config_data)
+
+    def to_json(self):
+        config_data = {
+            section: self._serialize_config(getattr(self, section))
+            for section in R2RConfig.REQUIRED_KEYS.keys()
+        }
+        return json.dumps(config_data)
+
+    def save_to_redis(self, redis_client: Any, key: str):
+        redis_client.set(f"R2RConfig:{key}", self.to_json())
+
+    @classmethod
+    def load_from_redis(cls, redis_client: Any, key: str) -> "R2RConfig":
+        config_data = redis_client.get(f"R2RConfig:{key}")
+        if config_data is None:
+            raise ValueError(
+                f"Configuration not found in Redis with key '{key}'"
+            )
+        config_data = json.loads(config_data)
+        # config_data["ingestion"]["selected_parsers"] = {
+        #     DocumentType(k): v
+        #     for k, v in config_data["ingestion"]["selected_parsers"].items()
+        # }
+        return cls(config_data)
+
+    @classmethod
+    def load_default_config(cls) -> dict:
+        # Get the root directory of the project
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        default_config_path = os.path.join(
+            file_dir, "..", "..", "..", "config.json"
+        )
+        # Load default configuration from JSON file
+        with open(default_config_path) as f:
+            return json.load(f)
+
+    @staticmethod
+    def _serialize_config(config_section: Any) -> dict:
+        # TODO - Make this approach cleaner
+        if isinstance(config_section, ProviderConfig):
+            config_section = config_section.dict()
+        filtered_result = {}
+        for k, v in config_section.items():
+            if isinstance(k, Enum):
+                k = k.value
+            if isinstance(v, dict):
+                formatted_v = {
+                    k2.value if isinstance(k2, Enum) else k2: v2
+                    for k2, v2 in v.items()
+                }
+                v = formatted_v
+            filtered_result[k] = v
+        return filtered_result
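For orientation, the config lifecycle in config.py is: load the repo-level config.json defaults, merge the caller's sections over them, validate the required keys per section, then freeze each section into its typed `*Config` object. A short sketch of the Redis round-trip, assuming a reachable Redis server and the third-party `redis` package (neither is part of this diff):

    import redis

    from r2r.main.assembly.config import R2RConfig

    config = R2RConfig.from_json()  # defaults from config.json at the repo root
    client = redis.Redis()          # assumes localhost:6379

    config.save_to_redis(client, "prod")  # stored under the key "R2RConfig:prod"
    restored = R2RConfig.load_from_redis(client, "prod")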
diff --git a/R2R/r2r/main/assembly/factory.py b/R2R/r2r/main/assembly/factory.py
new file mode 100755
index 00000000..4e147337
--- /dev/null
+++ b/R2R/r2r/main/assembly/factory.py
@@ -0,0 +1,484 @@
+import logging
+import os
+from typing import Any, Optional
+
+from r2r.base import (
+    AsyncPipe,
+    EmbeddingConfig,
+    EmbeddingProvider,
+    EvalProvider,
+    KGProvider,
+    KVLoggingSingleton,
+    LLMConfig,
+    LLMProvider,
+    PromptProvider,
+    VectorDBConfig,
+    VectorDBProvider,
+)
+from r2r.pipelines import (
+    EvalPipeline,
+    IngestionPipeline,
+    RAGPipeline,
+    SearchPipeline,
+)
+
+from ..abstractions import R2RPipelines, R2RPipes, R2RProviders
+from .config import R2RConfig
+
+logger = logging.getLogger(__name__)
+
+
+class R2RProviderFactory:
+    def __init__(self, config: R2RConfig):
+        self.config = config
+
+    def create_vector_db_provider(
+        self, vector_db_config: VectorDBConfig, *args, **kwargs
+    ) -> VectorDBProvider:
+        vector_db_provider: Optional[VectorDBProvider] = None
+        if vector_db_config.provider == "pgvector":
+            from r2r.providers.vector_dbs import PGVectorDB
+
+            vector_db_provider = PGVectorDB(vector_db_config)
+        else:
+            raise ValueError(
+                f"Vector database provider {vector_db_config.provider} not supported"
+            )
+        if not vector_db_provider:
+            raise ValueError("Vector database provider not found")
+
+        if not self.config.embedding.base_dimension:
+            raise ValueError("Search dimension not found in embedding config")
+
+        vector_db_provider.initialize_collection(
+            self.config.embedding.base_dimension
+        )
+        return vector_db_provider
+
+    def create_embedding_provider(
+        self, embedding: EmbeddingConfig, *args, **kwargs
+    ) -> EmbeddingProvider:
+        embedding_provider: Optional[EmbeddingProvider] = None
+
+        if embedding.provider == "openai":
+            if not os.getenv("OPENAI_API_KEY"):
+                raise ValueError(
+                    "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+                )
+            from r2r.providers.embeddings import OpenAIEmbeddingProvider
+
+            embedding_provider = OpenAIEmbeddingProvider(embedding)
+        elif embedding.provider == "ollama":
+            from r2r.providers.embeddings import OllamaEmbeddingProvider
+
+            embedding_provider = OllamaEmbeddingProvider(embedding)
+
+        elif embedding.provider == "sentence-transformers":
+            from r2r.providers.embeddings import (
+                SentenceTransformerEmbeddingProvider,
+            )
+
+            embedding_provider = SentenceTransformerEmbeddingProvider(
+                embedding
+            )
+        elif embedding is None:
+            embedding_provider = None
+        else:
+            raise ValueError(
+                f"Embedding provider {embedding.provider} not supported"
+            )
+
+        return embedding_provider
+
+    def create_eval_provider(
+        self, eval_config, prompt_provider, *args, **kwargs
+    ) -> Optional[EvalProvider]:
+        if eval_config.provider == "local":
+            from r2r.providers.eval import LLMEvalProvider
+
+            llm_provider = self.create_llm_provider(eval_config.llm)
+            eval_provider = LLMEvalProvider(
+                eval_config,
+                llm_provider=llm_provider,
+                prompt_provider=prompt_provider,
+            )
+        elif eval_config.provider is None:
+            eval_provider = None
+        else:
+            raise ValueError(
+                f"Eval provider {eval_config.provider} not supported."
+            )
+
+        return eval_provider
+
+    def create_llm_provider(
+        self, llm_config: LLMConfig, *args, **kwargs
+    ) -> LLMProvider:
+        llm_provider: Optional[LLMProvider] = None
+        if llm_config.provider == "openai":
+            from r2r.providers.llms import OpenAILLM
+
+            llm_provider = OpenAILLM(llm_config)
+        elif llm_config.provider == "litellm":
+            from r2r.providers.llms import LiteLLM
+
+            llm_provider = LiteLLM(llm_config)
+        else:
+            raise ValueError(
+                f"Language model provider {llm_config.provider} not supported"
+            )
+        if not llm_provider:
+            raise ValueError("Language model provider not found")
+        return llm_provider
+
+    def create_prompt_provider(
+        self, prompt_config, *args, **kwargs
+    ) -> PromptProvider:
+        prompt_provider = None
+        if prompt_config.provider == "local":
+            from r2r.prompts import R2RPromptProvider
+
+            prompt_provider = R2RPromptProvider()
+        else:
+            raise ValueError(
+                f"Prompt provider {prompt_config.provider} not supported"
+            )
+        return prompt_provider
+
+    def create_kg_provider(self, kg_config, *args, **kwargs):
+        if kg_config.provider == "neo4j":
+            from r2r.providers.kg import Neo4jKGProvider
+
+            return Neo4jKGProvider(kg_config)
+        elif kg_config.provider is None:
+            return None
+        else:
+            raise ValueError(
+                f"KG provider {kg_config.provider} not supported."
+            )
+
+    def create_providers(
+        self,
+        vector_db_provider_override: Optional[VectorDBProvider] = None,
+        embedding_provider_override: Optional[EmbeddingProvider] = None,
+        eval_provider_override: Optional[EvalProvider] = None,
+        llm_provider_override: Optional[LLMProvider] = None,
+        prompt_provider_override: Optional[PromptProvider] = None,
+        kg_provider_override: Optional[KGProvider] = None,
+        *args,
+        **kwargs,
+    ) -> R2RProviders:
+        prompt_provider = (
+            prompt_provider_override
+            or self.create_prompt_provider(self.config.prompt, *args, **kwargs)
+        )
+        return R2RProviders(
+            vector_db=vector_db_provider_override
+            or self.create_vector_db_provider(
+                self.config.vector_database, *args, **kwargs
+            ),
+            embedding=embedding_provider_override
+            or self.create_embedding_provider(
+                self.config.embedding, *args, **kwargs
+            ),
+            eval=eval_provider_override
+            or self.create_eval_provider(
+                self.config.eval,
+                prompt_provider=prompt_provider,
+                *args,
+                **kwargs,
+            ),
+            llm=llm_provider_override
+            or self.create_llm_provider(
+                self.config.completions, *args, **kwargs
+            ),
+            prompt=prompt_provider_override
+            or self.create_prompt_provider(
+                self.config.prompt, *args, **kwargs
+            ),
+            kg=kg_provider_override
+            or self.create_kg_provider(self.config.kg, *args, **kwargs),
+        )
+
+
+class R2RPipeFactory:
+    def __init__(self, config: R2RConfig, providers: R2RProviders):
+        self.config = config
+        self.providers = providers
+
+    def create_pipes(
+        self,
+        parsing_pipe_override: Optional[AsyncPipe] = None,
+        embedding_pipe_override: Optional[AsyncPipe] = None,
+        kg_pipe_override: Optional[AsyncPipe] = None,
+        kg_storage_pipe_override: Optional[AsyncPipe] = None,
+        kg_agent_pipe_override: Optional[AsyncPipe] = None,
+        vector_storage_pipe_override: Optional[AsyncPipe] = None,
+        vector_search_pipe_override: Optional[AsyncPipe] = None,
+        rag_pipe_override: Optional[AsyncPipe] = None,
+        streaming_rag_pipe_override: Optional[AsyncPipe] = None,
+        eval_pipe_override: Optional[AsyncPipe] = None,
+        *args,
+        **kwargs,
+    ) -> R2RPipes:
+        return R2RPipes(
+            parsing_pipe=parsing_pipe_override
+            or self.create_parsing_pipe(
+                self.config.ingestion.get("excluded_parsers"), *args, **kwargs
+            ),
+            embedding_pipe=embedding_pipe_override
+            or self.create_embedding_pipe(*args, **kwargs),
+            kg_pipe=kg_pipe_override or self.create_kg_pipe(*args, **kwargs),
+            kg_storage_pipe=kg_storage_pipe_override
+            or self.create_kg_storage_pipe(*args, **kwargs),
+            kg_agent_search_pipe=kg_agent_pipe_override
+            or self.create_kg_agent_pipe(*args, **kwargs),
+            vector_storage_pipe=vector_storage_pipe_override
+            or self.create_vector_storage_pipe(*args, **kwargs),
+            vector_search_pipe=vector_search_pipe_override
+            or self.create_vector_search_pipe(*args, **kwargs),
+            rag_pipe=rag_pipe_override
+            or self.create_rag_pipe(*args, **kwargs),
+            streaming_rag_pipe=streaming_rag_pipe_override
+            or self.create_rag_pipe(stream=True, *args, **kwargs),
+            eval_pipe=eval_pipe_override
+            or self.create_eval_pipe(*args, **kwargs),
+        )
+
+    def create_parsing_pipe(
+        self, excluded_parsers: Optional[list] = None, *args, **kwargs
+    ) -> Any:
+        from r2r.pipes import ParsingPipe
+
+        return ParsingPipe(excluded_parsers=excluded_parsers or [])
+
+    def create_embedding_pipe(self, *args, **kwargs) -> Any:
+        if self.config.embedding.provider is None:
+            return None
+
+        from r2r.base import RecursiveCharacterTextSplitter
+        from r2r.pipes import EmbeddingPipe
+
+        text_splitter_config = self.config.embedding.extra_fields.get(
+            "text_splitter"
+        )
+        if not text_splitter_config:
+            raise ValueError(
+                "Text splitter config not found in embedding config"
+            )
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=text_splitter_config["chunk_size"],
+            chunk_overlap=text_splitter_config["chunk_overlap"],
+            length_function=len,
+            is_separator_regex=False,
+        )
+        return EmbeddingPipe(
+            embedding_provider=self.providers.embedding,
+            vector_db_provider=self.providers.vector_db,
+            text_splitter=text_splitter,
+            embedding_batch_size=self.config.embedding.batch_size,
+        )
+
+    def create_vector_storage_pipe(self, *args, **kwargs) -> Any:
+        if self.config.embedding.provider is None:
+            return None
+
+        from r2r.pipes import VectorStoragePipe
+
+        return VectorStoragePipe(vector_db_provider=self.providers.vector_db)
+
+    def create_vector_search_pipe(self, *args, **kwargs) -> Any:
+        if self.config.embedding.provider is None:
+            return None
+
+        from r2r.pipes import VectorSearchPipe
+
+        return VectorSearchPipe(
+            vector_db_provider=self.providers.vector_db,
+            embedding_provider=self.providers.embedding,
+        )
+
+    def create_kg_pipe(self, *args, **kwargs) -> Any:
+        if self.config.kg.provider is None:
+            return None
+
+        from r2r.base import RecursiveCharacterTextSplitter
+        from r2r.pipes import KGExtractionPipe
+
+        text_splitter_config = self.config.kg.extra_fields.get("text_splitter")
+        if not text_splitter_config:
+            raise ValueError("Text splitter config not found in kg config.")
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=text_splitter_config["chunk_size"],
+            chunk_overlap=text_splitter_config["chunk_overlap"],
+            length_function=len,
+            is_separator_regex=False,
+        )
+        return KGExtractionPipe(
+            kg_provider=self.providers.kg,
+            llm_provider=self.providers.llm,
+            prompt_provider=self.providers.prompt,
+            vector_db_provider=self.providers.vector_db,
+            text_splitter=text_splitter,
+            kg_batch_size=self.config.kg.batch_size,
+        )
+
+    def create_kg_storage_pipe(self, *args, **kwargs) -> Any:
+        if self.config.kg.provider is None:
+            return None
+
+        from r2r.pipes import KGStoragePipe
+
+        return KGStoragePipe(
+            kg_provider=self.providers.kg,
+            embedding_provider=self.providers.embedding,
+        )
+
+    def create_kg_agent_pipe(self, *args, **kwargs) -> Any:
+        if self.config.kg.provider is None:
+            return None
+
+        from r2r.pipes import KGAgentSearchPipe
+
+        return KGAgentSearchPipe(
+            kg_provider=self.providers.kg,
+            llm_provider=self.providers.llm,
+            prompt_provider=self.providers.prompt,
+        )
+
+    def create_rag_pipe(self, stream: bool = False, *args, **kwargs) -> Any:
+        if stream:
+            from r2r.pipes import StreamingSearchRAGPipe
+
+            return StreamingSearchRAGPipe(
+                llm_provider=self.providers.llm,
+                prompt_provider=self.providers.prompt,
+            )
+        else:
+            from r2r.pipes import SearchRAGPipe
+
+            return SearchRAGPipe(
+                llm_provider=self.providers.llm,
+                prompt_provider=self.providers.prompt,
+            )
+
+    def create_eval_pipe(self, *args, **kwargs) -> Any:
+        from r2r.pipes import EvalPipe
+
+        return EvalPipe(eval_provider=self.providers.eval)
+
+
+class R2RPipelineFactory:
+    def __init__(self, config: R2RConfig, pipes: R2RPipes):
+        self.config = config
+        self.pipes = pipes
+
+    def create_ingestion_pipeline(self, *args, **kwargs) -> IngestionPipeline:
+        """Factory method to create an ingestion pipeline."""
+        ingestion_pipeline = IngestionPipeline()
+
+        ingestion_pipeline.add_pipe(
+            pipe=self.pipes.parsing_pipe, parsing_pipe=True
+        )
+        # Add embedding pipes if provider is set
+        if self.config.embedding.provider is not None:
+            ingestion_pipeline.add_pipe(
+                self.pipes.embedding_pipe, embedding_pipe=True
+            )
+            ingestion_pipeline.add_pipe(
+                self.pipes.vector_storage_pipe, embedding_pipe=True
+            )
+        # Add KG pipes if provider is set
+        if self.config.kg.provider is not None:
+            ingestion_pipeline.add_pipe(self.pipes.kg_pipe, kg_pipe=True)
+            ingestion_pipeline.add_pipe(
+                self.pipes.kg_storage_pipe, kg_pipe=True
+            )
+
+        return ingestion_pipeline
+
+    def create_search_pipeline(self, *args, **kwargs) -> SearchPipeline:
+        """Factory method to create a search pipeline."""
+        search_pipeline = SearchPipeline()
+
+        # Add vector search pipes if embedding provider and vector provider is set
+        if (
+            self.config.embedding.provider is not None
+            and self.config.vector_database.provider is not None
+        ):
+            search_pipeline.add_pipe(
+                self.pipes.vector_search_pipe, vector_search_pipe=True
+            )
+
+        # Add KG pipes if provider is set
+        if self.config.kg.provider is not None:
+            search_pipeline.add_pipe(
+                self.pipes.kg_agent_search_pipe, kg_pipe=True
+            )
+
+        return search_pipeline
+
+    def create_rag_pipeline(
+        self,
+        search_pipeline: SearchPipeline,
+        stream: bool = False,
+        *args,
+        **kwargs,
+    ) -> RAGPipeline:
+        rag_pipe = (
+            self.pipes.streaming_rag_pipe if stream else self.pipes.rag_pipe
+        )
+
+        rag_pipeline = RAGPipeline()
+        rag_pipeline.set_search_pipeline(search_pipeline)
+        rag_pipeline.add_pipe(rag_pipe)
+        return rag_pipeline
+
+    def create_eval_pipeline(self, *args, **kwargs) -> EvalPipeline:
+        eval_pipeline = EvalPipeline()
+        eval_pipeline.add_pipe(self.pipes.eval_pipe)
+        return eval_pipeline
+
+    def create_pipelines(
+        self,
+        ingestion_pipeline: Optional[IngestionPipeline] = None,
+        search_pipeline: Optional[SearchPipeline] = None,
+        rag_pipeline: Optional[RAGPipeline] = None,
+        streaming_rag_pipeline: Optional[RAGPipeline] = None,
+        eval_pipeline: Optional[EvalPipeline] = None,
+        *args,
+        **kwargs,
+    ) -> R2RPipelines:
+        try:
+            self.configure_logging()
+        except Exception as e:
+            logger.warning(f"Error configuring logging: {e}")
+        search_pipeline = search_pipeline or self.create_search_pipeline(
+            *args, **kwargs
+        )
+        return R2RPipelines(
+            ingestion_pipeline=ingestion_pipeline
+            or self.create_ingestion_pipeline(*args, **kwargs),
+            search_pipeline=search_pipeline,
+            rag_pipeline=rag_pipeline
+            or self.create_rag_pipeline(
+                search_pipeline=search_pipeline,
+                stream=False,
+                *args,
+                **kwargs,
+            ),
+            streaming_rag_pipeline=streaming_rag_pipeline
+            or self.create_rag_pipeline(
+                search_pipeline=search_pipeline,
+                stream=True,
+                *args,
+                **kwargs,
+            ),
+            eval_pipeline=eval_pipeline
+            or self.create_eval_pipeline(*args, **kwargs),
+        )
+
+    def configure_logging(self):
+        KVLoggingSingleton.configure(self.config.logging)
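Since each `create_*_provider` method above dispatches on a plain provider string, supporting another backend means subclassing R2RProviderFactory and handing the subclass to the builder via `with_provider_factory`. A sketch under stated assumptions: `MyVectorDB` and the "my_vector_db" provider name are hypothetical, and everything else falls through to the stock pgvector branch:

    from r2r.main.assembly.builder import R2RBuilder
    from r2r.main.assembly.factory import R2RProviderFactory

    class ExtendedProviderFactory(R2RProviderFactory):
        def create_vector_db_provider(self, vector_db_config, *args, **kwargs):
            if vector_db_config.provider == "my_vector_db":  # hypothetical name
                # Hypothetical provider class; a real one would also
                # initialize its collection, as the base method does.
                return MyVectorDB(vector_db_config)
            return super().create_vector_db_provider(
                vector_db_config, *args, **kwargs
            )

    r2r = R2RBuilder().with_provider_factory(ExtendedProviderFactory).build()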
diff --git a/R2R/r2r/main/assembly/factory_extensions.py b/R2R/r2r/main/assembly/factory_extensions.py
new file mode 100755
index 00000000..56e82ef7
--- /dev/null
+++ b/R2R/r2r/main/assembly/factory_extensions.py
@@ -0,0 +1,69 @@
+from r2r.main import R2RPipeFactory
+from r2r.pipes.retrieval.multi_search import MultiSearchPipe
+from r2r.pipes.retrieval.query_transform_pipe import QueryTransformPipe
+
+
+class R2RPipeFactoryWithMultiSearch(R2RPipeFactory):
+    QUERY_GENERATION_TEMPLATE: dict = (
+        {  # TODO - Can we have stricter typing like so? `: {"template": str, "input_types": dict[str, str]} = {`
+            "template": "### Instruction:\n\nGiven the query that follows, write a double newline separated list of up to {num_outputs} queries meant to help answer the original query. \nDO NOT generate any single query which is likely to require information from multiple distinct documents, \nEACH single query will be used to carry out a cosine similarity semantic search over distinct indexed documents, such as varied medical documents. \nFOR EXAMPLE if asked `how do the key themes of Great Gatsby compare with 1984`, the two queries would be \n`What are the key themes of Great Gatsby?` and `What are the key themes of 1984?`.\nHere is the original user query to be transformed:\n\n### Query:\n{message}\n\n### Response:\n",
+            "input_types": {"num_outputs": "int", "message": "str"},
+        }
+    )
+
+    def create_vector_search_pipe(self, *args, **kwargs):
+        """
+        A factory method to create a search pipe.
+
+        Overrides include
+            task_prompt_name: str
+            multi_query_transform_pipe_override: QueryTransformPipe
+            multi_inner_search_pipe_override: SearchPipe
+            query_generation_template_override: {'template': str, 'input_types': dict[str, str]}
+        """
+        multi_search_config = MultiSearchPipe.PipeConfig()
+        if kwargs.get("task_prompt_name") and kwargs.get(
+            "query_generation_template_override"
+        ):
+            raise ValueError(
+                "Cannot provide both `task_prompt_name` and `query_generation_template_override`"
+            )
+        task_prompt_name = (
+            kwargs.get("task_prompt_name")
+            or f"{multi_search_config.name}_task_prompt"
+        )
+        if kwargs.get("query_generation_template_override"):
+            # Add a prompt for transforming the user query
+            # (the override dict is expected to carry a "name" key
+            # alongside "template" and "input_types"; it is read below)
+            template = kwargs.get("query_generation_template_override")
+            self.providers.prompt.add_prompt(
+                **(
+                    kwargs.get("query_generation_template_override")
+                    or self.QUERY_GENERATION_TEMPLATE
+                ),
+            )
+            task_prompt_name = template["name"]
+
+        # Initialize the new query transform pipe
+        query_transform_pipe = kwargs.get(
+            "multi_query_transform_pipe_override", None
+        ) or QueryTransformPipe(
+            llm_provider=self.providers.llm,
+            prompt_provider=self.providers.prompt,
+            config=QueryTransformPipe.QueryTransformConfig(
+                name=multi_search_config.name,
+                task_prompt=task_prompt_name,
+            ),
+        )
+        # Create search pipe override and pipes
+        inner_search_pipe = kwargs.get(
+            "multi_inner_search_pipe_override", None
+        ) or super().create_vector_search_pipe(*args, **kwargs)
+
+        # TODO - modify `create_..._pipe` to allow naming the pipe
+        inner_search_pipe.config.name = multi_search_config.name
+
+        return MultiSearchPipe(
+            query_transform_pipe=query_transform_pipe,
+            inner_search_pipe=inner_search_pipe,
+            config=multi_search_config,
+        )
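Tying the last file back to the builder: because `build()` instantiates whichever pipe factory class it is given, the multi-search variant can be swapped in wholesale. A minimal sketch, assuming R2RBuilder and R2RPipeFactoryWithMultiSearch are importable as laid out in this diff:

    from r2r.main.assembly.builder import R2RBuilder
    from r2r.main.assembly.factory_extensions import R2RPipeFactoryWithMultiSearch

    # Every vector search now fans out through MultiSearchPipe: the query
    # transform pipe generates sub-queries from QUERY_GENERATION_TEMPLATE,
    # and each sub-query runs through the inner vector search pipe.
    r2r = (
        R2RBuilder()
        .with_pipe_factory(R2RPipeFactoryWithMultiSearch)
        .build()
    )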