path: root/R2R/r2r/pipes
author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/pipes')
-rwxr-xr-x  R2R/r2r/pipes/__init__.py                         31
-rwxr-xr-x  R2R/r2r/pipes/abstractions/__init__.py             0
-rwxr-xr-x  R2R/r2r/pipes/abstractions/generator_pipe.py      58
-rwxr-xr-x  R2R/r2r/pipes/abstractions/search_pipe.py         62
-rwxr-xr-x  R2R/r2r/pipes/ingestion/__init__.py                 0
-rwxr-xr-x  R2R/r2r/pipes/ingestion/embedding_pipe.py         218
-rwxr-xr-x  R2R/r2r/pipes/ingestion/kg_extraction_pipe.py     226
-rwxr-xr-x  R2R/r2r/pipes/ingestion/kg_storage_pipe.py        133
-rwxr-xr-x  R2R/r2r/pipes/ingestion/parsing_pipe.py           211
-rwxr-xr-x  R2R/r2r/pipes/ingestion/vector_storage_pipe.py    128
-rwxr-xr-x  R2R/r2r/pipes/other/eval_pipe.py                   54
-rwxr-xr-x  R2R/r2r/pipes/other/web_search_pipe.py            105
-rwxr-xr-x  R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py   103
-rwxr-xr-x  R2R/r2r/pipes/retrieval/multi_search.py            79
-rwxr-xr-x  R2R/r2r/pipes/retrieval/query_transform_pipe.py   101
-rwxr-xr-x  R2R/r2r/pipes/retrieval/search_rag_pipe.py        130
-rwxr-xr-x  R2R/r2r/pipes/retrieval/streaming_rag_pipe.py     131
-rwxr-xr-x  R2R/r2r/pipes/retrieval/vector_search_pipe.py     123
18 files changed, 1893 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/__init__.py b/R2R/r2r/pipes/__init__.py
new file mode 100755
index 00000000..b86c31c0
--- /dev/null
+++ b/R2R/r2r/pipes/__init__.py
@@ -0,0 +1,31 @@
+from .abstractions.search_pipe import SearchPipe
+from .ingestion.embedding_pipe import EmbeddingPipe
+from .ingestion.kg_extraction_pipe import KGExtractionPipe
+from .ingestion.kg_storage_pipe import KGStoragePipe
+from .ingestion.parsing_pipe import ParsingPipe
+from .ingestion.vector_storage_pipe import VectorStoragePipe
+from .other.eval_pipe import EvalPipe
+from .other.web_search_pipe import WebSearchPipe
+from .retrieval.kg_agent_search_pipe import KGAgentSearchPipe
+from .retrieval.multi_search import MultiSearchPipe
+from .retrieval.query_transform_pipe import QueryTransformPipe
+from .retrieval.search_rag_pipe import SearchRAGPipe
+from .retrieval.streaming_rag_pipe import StreamingSearchRAGPipe
+from .retrieval.vector_search_pipe import VectorSearchPipe
+
+__all__ = [
+ "SearchPipe",
+ "EmbeddingPipe",
+ "EvalPipe",
+ "KGExtractionPipe",
+ "ParsingPipe",
+ "QueryTransformPipe",
+ "SearchRAGPipe",
+ "StreamingSearchRAGPipe",
+ "VectorSearchPipe",
+ "VectorStoragePipe",
+ "WebSearchPipe",
+ "KGAgentSearchPipe",
+ "KGStoragePipe",
+ "MultiSearchPipe",
+]
diff --git a/R2R/r2r/pipes/abstractions/__init__.py b/R2R/r2r/pipes/abstractions/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/pipes/abstractions/__init__.py
diff --git a/R2R/r2r/pipes/abstractions/generator_pipe.py b/R2R/r2r/pipes/abstractions/generator_pipe.py
new file mode 100755
index 00000000..002ebd23
--- /dev/null
+++ b/R2R/r2r/pipes/abstractions/generator_pipe.py
@@ -0,0 +1,58 @@
+import uuid
+from abc import abstractmethod
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+ AsyncState,
+ KVLoggingSingleton,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+
+class GeneratorPipe(AsyncPipe):
+ class Config(AsyncPipe.PipeConfig):
+ name: str
+ task_prompt: str
+ system_prompt: str = "default_system"
+
+ def __init__(
+ self,
+ llm_provider: LLMProvider,
+ prompt_provider: PromptProvider,
+ type: PipeType = PipeType.GENERATOR,
+ config: Optional[Config] = None,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(
+ type=type,
+ config=config or self.Config(),
+ pipe_logger=pipe_logger,
+ *args,
+ **kwargs,
+ )
+ self.llm_provider = llm_provider
+ self.prompt_provider = prompt_provider
+
+ @abstractmethod
+ async def _run_logic(
+ self,
+ input: AsyncPipe.Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ rag_generation_config: GenerationConfig,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[Any, None]:
+ pass
+
+ @abstractmethod
+ def _get_message_payload(
+ self, message: str, *args: Any, **kwargs: Any
+ ) -> list:
+ pass
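For orientation, a minimal concrete subclass might look like the sketch below. It uses only the imports already present in generator_pipe.py and only calls that appear elsewhere in this commit (`aget_completion`, `get_prompt`); the `EchoAnswerPipe` name is illustrative and not part of the commit.

    class EchoAnswerPipe(GeneratorPipe):
        async def _run_logic(
            self,
            input: AsyncPipe.Input,
            state: AsyncState,
            run_id: uuid.UUID,
            rag_generation_config: GenerationConfig,
            *args: Any,
            **kwargs: Any,
        ) -> AsyncGenerator[Any, None]:
            # Consume the upstream async generator and answer each message.
            async for message in input.message:
                response = await self.llm_provider.aget_completion(
                    messages=self._get_message_payload(message),
                    generation_config=rag_generation_config,
                )
                yield response.choices[0].message.content

        def _get_message_payload(
            self, message: str, *args: Any, **kwargs: Any
        ) -> list:
            # System prompt comes from the pipe config; the user turn is the raw message.
            return [
                {
                    "role": "system",
                    "content": self.prompt_provider.get_prompt(
                        self.config.system_prompt
                    ),
                },
                {"role": "user", "content": message},
            ]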
diff --git a/R2R/r2r/pipes/abstractions/search_pipe.py b/R2R/r2r/pipes/abstractions/search_pipe.py
new file mode 100755
index 00000000..bb0303e0
--- /dev/null
+++ b/R2R/r2r/pipes/abstractions/search_pipe.py
@@ -0,0 +1,62 @@
+import logging
+import uuid
+from abc import abstractmethod
+from typing import Any, AsyncGenerator, Optional, Union
+
+from r2r.base import (
+ AsyncPipe,
+ AsyncState,
+ KVLoggingSingleton,
+ PipeType,
+ VectorSearchResult,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SearchPipe(AsyncPipe):
+ class SearchConfig(AsyncPipe.PipeConfig):
+ name: str = "default_vector_search"
+ search_filters: dict = {}
+ search_limit: int = 10
+
+ class Input(AsyncPipe.Input):
+ message: Union[AsyncGenerator[str, None], str]
+
+ def __init__(
+ self,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ type: PipeType = PipeType.SEARCH,
+ config: Optional[AsyncPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(
+ pipe_logger=pipe_logger,
+ type=type,
+ config=config,
+ *args,
+ **kwargs,
+ )
+
+ @abstractmethod
+ async def search(
+ self,
+ query: str,
+ filters: dict[str, Any] = {},
+ limit: int = 10,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[VectorSearchResult, None]:
+ pass
+
+ @abstractmethod
+ async def _run_logic(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ *args: Any,
+ **kwargs,
+ ) -> AsyncGenerator[VectorSearchResult, None]:
+ pass
diff --git a/R2R/r2r/pipes/ingestion/__init__.py b/R2R/r2r/pipes/ingestion/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/__init__.py
diff --git a/R2R/r2r/pipes/ingestion/embedding_pipe.py b/R2R/r2r/pipes/ingestion/embedding_pipe.py
new file mode 100755
index 00000000..971ccc9d
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/embedding_pipe.py
@@ -0,0 +1,218 @@
+import asyncio
+import copy
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional, Union
+
+from r2r.base import (
+ AsyncState,
+ EmbeddingProvider,
+ Extraction,
+ Fragment,
+ FragmentType,
+ KVLoggingSingleton,
+ PipeType,
+ R2RDocumentProcessingError,
+ TextSplitter,
+ Vector,
+ VectorEntry,
+ generate_id_from_label,
+)
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingPipe(AsyncPipe):
+ """
+    Embeds extracted document fragments using the configured embedding provider.
+ """
+
+ class Input(AsyncPipe.Input):
+ message: AsyncGenerator[
+ Union[Extraction, R2RDocumentProcessingError], None
+ ]
+
+ def __init__(
+ self,
+ embedding_provider: EmbeddingProvider,
+ text_splitter: TextSplitter,
+ embedding_batch_size: int = 1,
+ id_prefix: str = "demo",
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ type: PipeType = PipeType.INGESTOR,
+ config: Optional[AsyncPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ """
+ Initializes the embedding pipe with necessary components and configurations.
+ """
+ super().__init__(
+ pipe_logger=pipe_logger,
+ type=type,
+ config=config
+ or AsyncPipe.PipeConfig(name="default_embedding_pipe"),
+ )
+ self.embedding_provider = embedding_provider
+ self.text_splitter = text_splitter
+ self.embedding_batch_size = embedding_batch_size
+ self.id_prefix = id_prefix
+ self.pipe_run_info = None
+
+ async def fragment(
+ self, extraction: Extraction, run_id: uuid.UUID
+ ) -> AsyncGenerator[Fragment, None]:
+ """
+ Splits text into manageable chunks for embedding.
+ """
+ if not isinstance(extraction, Extraction):
+ raise ValueError(
+ f"Expected an Extraction, but received {type(extraction)}."
+ )
+ if not isinstance(extraction.data, str):
+ raise ValueError(
+ f"Expected a string, but received {type(extraction.data)}."
+ )
+ text_chunks = [
+ ele.page_content
+ for ele in self.text_splitter.create_documents([extraction.data])
+ ]
+ for iteration, chunk in enumerate(text_chunks):
+ fragment = Fragment(
+ id=generate_id_from_label(f"{extraction.id}-{iteration}"),
+ type=FragmentType.TEXT,
+ data=chunk,
+ metadata=copy.deepcopy(extraction.metadata),
+ extraction_id=extraction.id,
+ document_id=extraction.document_id,
+ )
+ yield fragment
+ iteration += 1
+
+ async def transform_fragments(
+ self, fragments: list[Fragment], metadatas: list[dict]
+ ) -> AsyncGenerator[Fragment, None]:
+ """
+ Transforms text chunks based on their metadata, e.g., adding prefixes.
+ """
+        for fragment, metadata in zip(fragments, metadatas):
+ if "chunk_prefix" in metadata:
+ prefix = metadata.pop("chunk_prefix")
+ fragment.data = f"{prefix}\n{fragment.data}"
+ yield fragment
+
+ async def embed(self, fragments: list[Fragment]) -> list[float]:
+ return await self.embedding_provider.async_get_embeddings(
+ [fragment.data for fragment in fragments],
+ EmbeddingProvider.PipeStage.BASE,
+ )
+
+ async def _process_batch(
+ self, fragment_batch: list[Fragment]
+ ) -> list[VectorEntry]:
+ """
+        Embeds a batch of fragments and returns the corresponding vector entries.
+ """
+ vectors = await self.embed(fragment_batch)
+ return [
+ VectorEntry(
+ id=fragment.id,
+ vector=Vector(data=raw_vector),
+ metadata={
+ "document_id": fragment.document_id,
+ "extraction_id": fragment.extraction_id,
+ "text": fragment.data,
+ **fragment.metadata,
+ },
+ )
+ for raw_vector, fragment in zip(vectors, fragment_batch)
+ ]
+
+ async def _process_and_enqueue_batch(
+ self, fragment_batch: list[Fragment], vector_entry_queue: asyncio.Queue
+ ):
+ try:
+ batch_result = await self._process_batch(fragment_batch)
+ for vector_entry in batch_result:
+ await vector_entry_queue.put(vector_entry)
+ except Exception as e:
+ logger.error(f"Error processing batch: {e}")
+ await vector_entry_queue.put(
+ R2RDocumentProcessingError(
+ error_message=str(e),
+ document_id=fragment_batch[0].document_id,
+ )
+ )
+ finally:
+ await vector_entry_queue.put(None) # Signal completion
+
+ async def _run_logic(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[Union[R2RDocumentProcessingError, VectorEntry], None]:
+ """
+        Executes the embedding pipe: chunking, transforming, and embedding document extractions.
+ """
+ vector_entry_queue = asyncio.Queue()
+ fragment_batch = []
+ active_tasks = 0
+
+ fragment_info = {}
+ async for extraction in input.message:
+ if isinstance(extraction, R2RDocumentProcessingError):
+ yield extraction
+ continue
+
+ async for fragment in self.fragment(extraction, run_id):
+ if extraction.document_id in fragment_info:
+ fragment_info[extraction.document_id] += 1
+ else:
+ fragment_info[extraction.document_id] = 0 # Start with 0
+ fragment.metadata["chunk_order"] = fragment_info[
+ extraction.document_id
+ ]
+
+ version = fragment.metadata.get("version", "v0")
+
+ # Ensure fragment ID is set correctly
+ if not fragment.id:
+ fragment.id = generate_id_from_label(
+ f"{extraction.id}-{fragment_info[extraction.document_id]}-{version}"
+ )
+
+ fragment_batch.append(fragment)
+ if len(fragment_batch) >= self.embedding_batch_size:
+ asyncio.create_task(
+ self._process_and_enqueue_batch(
+ fragment_batch.copy(), vector_entry_queue
+ )
+ )
+ active_tasks += 1
+ fragment_batch.clear()
+
+ logger.debug(
+ f"Fragmented the input document ids into counts as shown: {fragment_info}"
+ )
+
+ if fragment_batch:
+ asyncio.create_task(
+ self._process_and_enqueue_batch(
+ fragment_batch.copy(), vector_entry_queue
+ )
+ )
+ active_tasks += 1
+
+ while active_tasks > 0:
+ vector_entry = await vector_entry_queue.get()
+ if vector_entry is None: # Check for termination signal
+ active_tasks -= 1
+ elif isinstance(vector_entry, Exception):
+ yield vector_entry # Propagate the exception
+ active_tasks -= 1
+ else:
+ yield vector_entry
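A hedged driving sketch for the pipe above, modeled on how `MultiSearchPipe` later in this commit consumes inner pipes (`await pipe.run(...)`, then iterating the resulting async generator). The `state` object and the `extractions` async generator are assumed to be supplied by the surrounding pipeline and are not defined in this commit; `logger` is the module-level logger above.

    async def collect_vector_entries(
        pipe: EmbeddingPipe,
        extractions,  # async generator of Extraction objects, e.g. ParsingPipe output
        state: AsyncState,
    ) -> list[VectorEntry]:
        entries = []
        async for item in await pipe.run(
            EmbeddingPipe.Input(message=extractions), state
        ):
            # Per-document failures are yielded as R2RDocumentProcessingError
            # rather than raised, so the caller decides how to handle them.
            if isinstance(item, R2RDocumentProcessingError):
                logger.error(f"Embedding failed: {item.error_message}")
                continue
            entries.append(item)
        return entries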
diff --git a/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py b/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
new file mode 100755
index 00000000..13025e39
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
@@ -0,0 +1,226 @@
+import asyncio
+import copy
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+ AsyncState,
+ Extraction,
+ Fragment,
+ FragmentType,
+ KGExtraction,
+ KGProvider,
+ KVLoggingSingleton,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+ TextSplitter,
+ extract_entities,
+ extract_triples,
+ generate_id_from_label,
+)
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class ClientError(Exception):
+ """Base class for client connection errors."""
+
+ pass
+
+
+class KGExtractionPipe(AsyncPipe):
+ """
+    Extracts knowledge graph entities and triples from document fragments using the configured LLM.
+ """
+
+ def __init__(
+ self,
+ kg_provider: KGProvider,
+ llm_provider: LLMProvider,
+ prompt_provider: PromptProvider,
+ text_splitter: TextSplitter,
+ kg_batch_size: int = 1,
+ id_prefix: str = "demo",
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ type: PipeType = PipeType.INGESTOR,
+ config: Optional[AsyncPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ """
+        Initializes the knowledge graph extraction pipe with necessary components and configurations.
+ """
+ super().__init__(
+ pipe_logger=pipe_logger,
+ type=type,
+ config=config
+ or AsyncPipe.PipeConfig(name="default_embedding_pipe"),
+ )
+
+ self.kg_provider = kg_provider
+ self.prompt_provider = prompt_provider
+ self.llm_provider = llm_provider
+ self.text_splitter = text_splitter
+ self.kg_batch_size = kg_batch_size
+ self.id_prefix = id_prefix
+ self.pipe_run_info = None
+
+ async def fragment(
+ self, extraction: Extraction, run_id: uuid.UUID
+ ) -> AsyncGenerator[Fragment, None]:
+ """
+ Splits text into manageable chunks for embedding.
+ """
+ if not isinstance(extraction, Extraction):
+ raise ValueError(
+ f"Expected an Extraction, but received {type(extraction)}."
+ )
+ if not isinstance(extraction.data, str):
+ raise ValueError(
+ f"Expected a string, but received {type(extraction.data)}."
+ )
+ text_chunks = [
+ ele.page_content
+ for ele in self.text_splitter.create_documents([extraction.data])
+ ]
+ for iteration, chunk in enumerate(text_chunks):
+ fragment = Fragment(
+ id=generate_id_from_label(f"{extraction.id}-{iteration}"),
+ type=FragmentType.TEXT,
+ data=chunk,
+ metadata=copy.deepcopy(extraction.metadata),
+ extraction_id=extraction.id,
+ document_id=extraction.document_id,
+ )
+ yield fragment
+
+ async def transform_fragments(
+ self, fragments: list[Fragment]
+ ) -> AsyncGenerator[Fragment, None]:
+ """
+ Transforms text chunks based on their metadata, e.g., adding prefixes.
+ """
+ async for fragment in fragments:
+ if "chunk_prefix" in fragment.metadata:
+ prefix = fragment.metadata.pop("chunk_prefix")
+ fragment.data = f"{prefix}\n{fragment.data}"
+ yield fragment
+
+ async def extract_kg(
+ self,
+ fragment: Fragment,
+ retries: int = 3,
+ delay: int = 2,
+ ) -> KGExtraction:
+ """
+        Extracts entities and triples from a single fragment, retrying on transient failures.
+ """
+ task_prompt = self.prompt_provider.get_prompt(
+ self.kg_provider.config.kg_extraction_prompt,
+ inputs={"input": fragment.data},
+ )
+ messages = self.prompt_provider._get_message_payload(
+ self.prompt_provider.get_prompt("default_system"), task_prompt
+ )
+ for attempt in range(retries):
+ try:
+ response = await self.llm_provider.aget_completion(
+ messages, self.kg_provider.config.kg_extraction_config
+ )
+
+ kg_extraction = response.choices[0].message.content
+
+ # Parsing JSON from the response
+ kg_json = (
+ json.loads(
+ kg_extraction.split("```json")[1].split("```")[0]
+ )
+ if """```json""" in kg_extraction
+ else json.loads(kg_extraction)
+ )
+ llm_payload = kg_json.get("entities_and_triples", {})
+
+ # Extract triples with detailed logging
+ entities = extract_entities(llm_payload)
+ triples = extract_triples(llm_payload, entities)
+
+ # Create KG extraction object
+ return KGExtraction(entities=entities, triples=triples)
+ except (
+ ClientError,
+ json.JSONDecodeError,
+ KeyError,
+ IndexError,
+ ) as e:
+ logger.error(f"Error in extract_kg: {e}")
+ if attempt < retries - 1:
+ await asyncio.sleep(delay)
+ else:
+ logger.error(f"Failed after retries with {e}")
+ # raise e # Ensure the exception is raised after the final attempt
+
+ return KGExtraction(entities={}, triples=[])
+
+ async def _process_batch(
+ self,
+ fragment_batch: list[Fragment],
+ ) -> list[KGExtraction]:
+ """
+        Runs knowledge graph extraction concurrently over a batch of fragments.
+ """
+ tasks = [
+ asyncio.create_task(self.extract_kg(fragment))
+ for fragment in fragment_batch
+ ]
+ return await asyncio.gather(*tasks)
+
+ async def _run_logic(
+ self,
+ input: AsyncPipe.Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[KGExtraction, None]:
+ """
+        Executes the knowledge graph extraction pipe: chunking, transforming, and extracting entities and triples.
+ """
+ batch_tasks = []
+ fragment_batch = []
+
+ fragment_info = {}
+ async for extraction in input.message:
+ async for fragment in self.transform_fragments(
+ self.fragment(extraction, run_id)
+ ):
+ if extraction.document_id in fragment_info:
+ fragment_info[extraction.document_id] += 1
+ else:
+ fragment_info[extraction.document_id] = 1
+ extraction.metadata["chunk_order"] = fragment_info[
+ extraction.document_id
+ ]
+ fragment_batch.append(fragment)
+ if len(fragment_batch) >= self.kg_batch_size:
+ # Here, ensure `_process_batch` is scheduled as a coroutine, not called directly
+ batch_tasks.append(
+ self._process_batch(fragment_batch.copy())
+ ) # pass a copy if necessary
+ fragment_batch.clear() # Clear the batch for new fragments
+
+ logger.debug(
+ f"Fragmented the input document ids into counts as shown: {fragment_info}"
+ )
+
+ if fragment_batch: # Process any remaining fragments
+ batch_tasks.append(self._process_batch(fragment_batch.copy()))
+
+ # Process tasks as they complete
+ for task in asyncio.as_completed(batch_tasks):
+ batch_result = await task # Wait for the next task to complete
+ for kg_extraction in batch_result:
+ yield kg_extraction
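The JSON handling in `extract_kg` accepts either a bare JSON object or one wrapped in a ```json fence. Below is a small, standalone illustration of that branch; the `entities_and_triples` key comes from the code above, but its inner structure is consumed by `extract_entities` / `extract_triples` in r2r.base and is not spelled out in this commit, so the sample payload is hypothetical.

    import json

    def parse_kg_response(kg_extraction: str) -> dict:
        # Mirror the fence handling above: prefer the ```json block if present.
        if "```json" in kg_extraction:
            kg_extraction = kg_extraction.split("```json")[1].split("```")[0]
        kg_json = json.loads(kg_extraction)
        # Top-level key expected by the extraction prompt; the inner schema is
        # interpreted by extract_entities / extract_triples.
        return kg_json.get("entities_and_triples", {})

    # Example of a fenced completion the pipe would accept (content hypothetical):
    sample = 'Here is the graph:\n```json\n{"entities_and_triples": {}}\n```'
    assert parse_kg_response(sample) == {}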
diff --git a/R2R/r2r/pipes/ingestion/kg_storage_pipe.py b/R2R/r2r/pipes/ingestion/kg_storage_pipe.py
new file mode 100755
index 00000000..9ac63479
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/kg_storage_pipe.py
@@ -0,0 +1,133 @@
+import asyncio
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+ AsyncState,
+ EmbeddingProvider,
+ KGExtraction,
+ KGProvider,
+ KVLoggingSingleton,
+ PipeType,
+)
+from r2r.base.abstractions.llama_abstractions import EntityNode, Relation
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class KGStoragePipe(AsyncPipe):
+ class Input(AsyncPipe.Input):
+ message: AsyncGenerator[KGExtraction, None]
+
+ def __init__(
+ self,
+ kg_provider: KGProvider,
+ embedding_provider: Optional[EmbeddingProvider] = None,
+ storage_batch_size: int = 1,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ type: PipeType = PipeType.INGESTOR,
+ config: Optional[AsyncPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ """
+ Initializes the async knowledge graph storage pipe with necessary components and configurations.
+ """
+ logger.info(
+ f"Initializing an `KGStoragePipe` to store knowledge graph extractions in a graph database."
+ )
+
+ super().__init__(
+ pipe_logger=pipe_logger,
+ type=type,
+ config=config,
+ *args,
+ **kwargs,
+ )
+ self.kg_provider = kg_provider
+ self.embedding_provider = embedding_provider
+ self.storage_batch_size = storage_batch_size
+
+ async def store(
+ self,
+ kg_extractions: list[KGExtraction],
+ ) -> None:
+ """
+ Stores a batch of knowledge graph extractions in the graph database.
+ """
+ try:
+ nodes = []
+ relations = []
+ for extraction in kg_extractions:
+ for entity in extraction.entities.values():
+ embedding = None
+ if self.embedding_provider:
+ embedding = self.embedding_provider.get_embedding(
+ "Entity:\n{entity.value}\nLabel:\n{entity.category}\nSubcategory:\n{entity.subcategory}"
+ )
+ nodes.append(
+ EntityNode(
+ name=entity.value,
+ label=entity.category,
+ embedding=embedding,
+ properties=(
+ {"subcategory": entity.subcategory}
+ if entity.subcategory
+ else {}
+ ),
+ )
+ )
+ for triple in extraction.triples:
+ relations.append(
+ Relation(
+ source_id=triple.subject,
+ target_id=triple.object,
+ label=triple.predicate,
+ )
+ )
+ self.kg_provider.upsert_nodes(nodes)
+ self.kg_provider.upsert_relations(relations)
+ except Exception as e:
+ error_message = f"Failed to store knowledge graph extractions in the database: {e}"
+ logger.error(error_message)
+ raise ValueError(error_message)
+
+ async def _run_logic(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[None, None]:
+ """
+ Executes the async knowledge graph storage pipe: storing knowledge graph extractions in the graph database.
+ """
+ batch_tasks = []
+ kg_batch = []
+
+ async for kg_extraction in input.message:
+ kg_batch.append(kg_extraction)
+ if len(kg_batch) >= self.storage_batch_size:
+ # Schedule the storage task
+ batch_tasks.append(
+ asyncio.create_task(
+ self.store(kg_batch.copy()),
+ name=f"kg-store-{self.config.name}",
+ )
+ )
+ kg_batch.clear()
+
+ if kg_batch: # Process any remaining extractions
+ batch_tasks.append(
+ asyncio.create_task(
+ self.store(kg_batch.copy()),
+ name=f"kg-store-{self.config.name}",
+ )
+ )
+
+ # Wait for all storage tasks to complete
+ await asyncio.gather(*batch_tasks)
+ yield None
diff --git a/R2R/r2r/pipes/ingestion/parsing_pipe.py b/R2R/r2r/pipes/ingestion/parsing_pipe.py
new file mode 100755
index 00000000..f3c81ca0
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/parsing_pipe.py
@@ -0,0 +1,211 @@
+"""
+This module contains the `ParsingPipe` class, which is responsible for parsing incoming documents into plaintext.
+"""
+
+import asyncio
+import logging
+import time
+import uuid
+from typing import AsyncGenerator, Optional, Union
+
+from r2r.base import (
+ AsyncParser,
+ AsyncState,
+ Document,
+ DocumentType,
+ Extraction,
+ ExtractionType,
+ KVLoggingSingleton,
+ PipeType,
+ generate_id_from_label,
+)
+from r2r.base.abstractions.exception import R2RDocumentProcessingError
+from r2r.base.pipes.base_pipe import AsyncPipe
+from r2r.parsers.media.audio_parser import AudioParser
+from r2r.parsers.media.docx_parser import DOCXParser
+from r2r.parsers.media.img_parser import ImageParser
+from r2r.parsers.media.movie_parser import MovieParser
+from r2r.parsers.media.pdf_parser import PDFParser
+from r2r.parsers.media.ppt_parser import PPTParser
+from r2r.parsers.structured.csv_parser import CSVParser
+from r2r.parsers.structured.json_parser import JSONParser
+from r2r.parsers.structured.xlsx_parser import XLSXParser
+from r2r.parsers.text.html_parser import HTMLParser
+from r2r.parsers.text.md_parser import MDParser
+from r2r.parsers.text.text_parser import TextParser
+
+logger = logging.getLogger(__name__)
+
+
+class ParsingPipe(AsyncPipe):
+ """
+ Processes incoming documents into plaintext based on their data type.
+    Supports plaintext, Markdown, HTML, PDF, DOCX, PPTX, CSV, JSON, XLSX, image (PNG, JPG, JPEG, GIF, SVG), audio (MP3), and video (MP4) formats.
+ """
+
+ class Input(AsyncPipe.Input):
+ message: AsyncGenerator[Document, None]
+
+ AVAILABLE_PARSERS = {
+ DocumentType.CSV: CSVParser,
+ DocumentType.DOCX: DOCXParser,
+ DocumentType.HTML: HTMLParser,
+ DocumentType.JSON: JSONParser,
+ DocumentType.MD: MDParser,
+ DocumentType.PDF: PDFParser,
+ DocumentType.PPTX: PPTParser,
+ DocumentType.TXT: TextParser,
+ DocumentType.XLSX: XLSXParser,
+ DocumentType.GIF: ImageParser,
+ DocumentType.JPEG: ImageParser,
+ DocumentType.JPG: ImageParser,
+ DocumentType.PNG: ImageParser,
+ DocumentType.SVG: ImageParser,
+ DocumentType.MP3: AudioParser,
+ DocumentType.MP4: MovieParser,
+ }
+
+ IMAGE_TYPES = {
+ DocumentType.GIF,
+ DocumentType.JPG,
+ DocumentType.JPEG,
+ DocumentType.PNG,
+ DocumentType.SVG,
+ }
+
+ def __init__(
+ self,
+ excluded_parsers: list[DocumentType],
+ override_parsers: Optional[dict[DocumentType, AsyncParser]] = None,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ type: PipeType = PipeType.INGESTOR,
+ config: Optional[AsyncPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(
+ pipe_logger=pipe_logger,
+ type=type,
+ config=config
+ or AsyncPipe.PipeConfig(name="default_document_parsing_pipe"),
+ *args,
+ **kwargs,
+ )
+
+ self.parsers = {}
+
+ if not override_parsers:
+ override_parsers = {}
+
+ # Apply overrides if specified
+ for doc_type, parser in override_parsers.items():
+ self.parsers[doc_type] = parser
+
+ for doc_type, parser_info in self.AVAILABLE_PARSERS.items():
+ if (
+ doc_type not in excluded_parsers
+ and doc_type not in self.parsers
+ ):
+ self.parsers[doc_type] = parser_info()
+
+ @property
+ def supported_types(self) -> list[str]:
+ """
+ Lists the data types supported by the pipe.
+ """
+ return [entry_type for entry_type in DocumentType]
+
+ async def _parse(
+ self,
+ document: Document,
+ run_id: uuid.UUID,
+ version: str,
+ ) -> AsyncGenerator[Union[R2RDocumentProcessingError, Extraction], None]:
+ if document.type not in self.parsers:
+ yield R2RDocumentProcessingError(
+ document_id=document.id,
+ error_message=f"Parser for {document.type} not found in `ParsingPipe`.",
+ )
+ return
+ parser = self.parsers[document.type]
+ texts = parser.ingest(document.data)
+ extraction_type = ExtractionType.TXT
+ t0 = time.time()
+ if document.type in self.IMAGE_TYPES:
+ extraction_type = ExtractionType.IMG
+ document.metadata["image_type"] = document.type.value
+ # SAVE IMAGE DATA
+ # try:
+ # import base64
+ # sanitized_data = base64.b64encode(document.data).decode('utf-8')
+ # except Exception as e:
+ # sanitized_data = document.data
+
+ # document.metadata["image_data"] = sanitized_data
+ elif document.type == DocumentType.MP4:
+ extraction_type = ExtractionType.MOV
+ document.metadata["audio_type"] = document.type.value
+
+ iteration = 0
+ async for text in texts:
+ extraction_id = generate_id_from_label(
+ f"{document.id}-{iteration}-{version}"
+ )
+ document.metadata["version"] = version
+ extraction = Extraction(
+ id=extraction_id,
+ data=text,
+ metadata=document.metadata,
+ document_id=document.id,
+ type=extraction_type,
+ )
+ yield extraction
+ # TODO - Add settings to enable extraction logging
+ # extraction_dict = extraction.dict()
+ # await self.enqueue_log(
+ # run_id=run_id,
+ # key="extraction",
+ # value=json.dumps(
+ # {
+ # "data": extraction_dict["data"],
+ # "document_id": str(extraction_dict["document_id"]),
+ # "extraction_id": str(extraction_dict["id"]),
+ # }
+ # ),
+ # )
+ iteration += 1
+ logger.debug(
+ f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} into {iteration} extractions in t={time.time() - t0:.2f} seconds."
+ )
+
+ async def _run_logic(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ versions: Optional[list[str]] = None,
+ *args,
+ **kwargs,
+ ) -> AsyncGenerator[Extraction, None]:
+ parse_tasks = []
+
+ iteration = 0
+ async for document in input.message:
+ version = versions[iteration] if versions else "v0"
+ iteration += 1
+ parse_tasks.append(
+ self._handle_parse_task(document, version, run_id)
+ )
+
+ # Await all tasks and yield results concurrently
+ for parse_task in asyncio.as_completed(parse_tasks):
+ for extraction in await parse_task:
+ yield extraction
+
+ async def _handle_parse_task(
+ self, document: Document, version: str, run_id: uuid.UUID
+ ) -> AsyncGenerator[Extraction, None]:
+ extractions = []
+ async for extraction in self._parse(document, run_id, version):
+ extractions.append(extraction)
+ return extractions
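A hedged construction sketch: parsers can be excluded wholesale or replaced per document type. The custom parser below assumes, based only on how `_parse` consumes parsers above, that an `AsyncParser` needs an async `ingest` method yielding text; the class and its behavior are illustrative and not part of the commit.

    class PlainUppercaseParser(AsyncParser):
        # Illustrative parser: yields the document body upper-cased.
        async def ingest(self, data) -> AsyncGenerator[str, None]:
            yield str(data).upper()

    pipe = ParsingPipe(
        # Skip image ingestion entirely and override the TXT parser.
        excluded_parsers=[DocumentType.PNG, DocumentType.JPG, DocumentType.JPEG],
        override_parsers={DocumentType.TXT: PlainUppercaseParser()},
    )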
diff --git a/R2R/r2r/pipes/ingestion/vector_storage_pipe.py b/R2R/r2r/pipes/ingestion/vector_storage_pipe.py
new file mode 100755
index 00000000..9564fd22
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/vector_storage_pipe.py
@@ -0,0 +1,128 @@
+import asyncio
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional, Tuple, Union
+
+from r2r.base import (
+ AsyncState,
+ KVLoggingSingleton,
+ PipeType,
+ VectorDBProvider,
+ VectorEntry,
+)
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+from ...base.abstractions.exception import R2RDocumentProcessingError
+
+logger = logging.getLogger(__name__)
+
+
+class VectorStoragePipe(AsyncPipe):
+ class Input(AsyncPipe.Input):
+ message: AsyncGenerator[
+ Union[R2RDocumentProcessingError, VectorEntry], None
+ ]
+ do_upsert: bool = True
+
+ def __init__(
+ self,
+ vector_db_provider: VectorDBProvider,
+ storage_batch_size: int = 128,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ type: PipeType = PipeType.INGESTOR,
+ config: Optional[AsyncPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ """
+ Initializes the async vector storage pipe with necessary components and configurations.
+ """
+ super().__init__(
+ pipe_logger=pipe_logger,
+ type=type,
+ config=config,
+ *args,
+ **kwargs,
+ )
+ self.vector_db_provider = vector_db_provider
+ self.storage_batch_size = storage_batch_size
+
+ async def store(
+ self,
+ vector_entries: list[VectorEntry],
+ do_upsert: bool = True,
+ ) -> None:
+ """
+ Stores a batch of vector entries in the database.
+ """
+
+ try:
+ if do_upsert:
+ self.vector_db_provider.upsert_entries(vector_entries)
+ else:
+ self.vector_db_provider.copy_entries(vector_entries)
+ except Exception as e:
+ error_message = (
+ f"Failed to store vector entries in the database: {e}"
+ )
+ logger.error(error_message)
+ raise ValueError(error_message)
+
+ async def _run_logic(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[
+ Tuple[uuid.UUID, Union[str, R2RDocumentProcessingError]], None
+ ]:
+ """
+ Executes the async vector storage pipe: storing embeddings in the vector database.
+ """
+ batch_tasks = []
+ vector_batch = []
+ document_counts = {}
+ i = 0
+ async for msg in input.message:
+ i += 1
+ if isinstance(msg, R2RDocumentProcessingError):
+ yield (msg.document_id, msg)
+ continue
+
+ document_id = msg.metadata.get("document_id", None)
+ if not document_id:
+ raise ValueError("Document ID not found in the metadata.")
+ if document_id not in document_counts:
+ document_counts[document_id] = 1
+ else:
+ document_counts[document_id] += 1
+
+ vector_batch.append(msg)
+ if len(vector_batch) >= self.storage_batch_size:
+ # Schedule the storage task
+ batch_tasks.append(
+ asyncio.create_task(
+ self.store(vector_batch.copy(), input.do_upsert),
+ name=f"vector-store-{self.config.name}",
+ )
+ )
+ vector_batch.clear()
+
+ if vector_batch: # Process any remaining vectors
+ batch_tasks.append(
+ asyncio.create_task(
+ self.store(vector_batch.copy(), input.do_upsert),
+ name=f"vector-store-{self.config.name}",
+ )
+ )
+
+ # Wait for all storage tasks to complete
+ await asyncio.gather(*batch_tasks)
+
+ for document_id, count in document_counts.items():
+ yield (
+ document_id,
+ f"Processed {count} vectors for document {document_id}.",
+ )
diff --git a/R2R/r2r/pipes/other/eval_pipe.py b/R2R/r2r/pipes/other/eval_pipe.py
new file mode 100755
index 00000000..b1c60343
--- /dev/null
+++ b/R2R/r2r/pipes/other/eval_pipe.py
@@ -0,0 +1,54 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from pydantic import BaseModel
+
+from r2r import AsyncState, EvalProvider, LLMChatCompletion, PipeType
+from r2r.base.abstractions.llm import GenerationConfig
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class EvalPipe(AsyncPipe):
+ class EvalPayload(BaseModel):
+ query: str
+ context: str
+ completion: str
+
+ class Input(AsyncPipe.Input):
+ message: AsyncGenerator["EvalPipe.EvalPayload", None]
+
+ def __init__(
+ self,
+ eval_provider: EvalProvider,
+ type: PipeType = PipeType.EVAL,
+ config: Optional[AsyncPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ self.eval_provider = eval_provider
+ super().__init__(
+ type=type,
+ config=config or AsyncPipe.PipeConfig(name="default_eval_pipe"),
+ *args,
+ **kwargs,
+ )
+
+ async def _run_logic(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ eval_generation_config: GenerationConfig,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[LLMChatCompletion, None]:
+ async for item in input.message:
+ yield self.eval_provider.evaluate(
+ item.query,
+ item.context,
+ item.completion,
+ eval_generation_config,
+ )
diff --git a/R2R/r2r/pipes/other/web_search_pipe.py b/R2R/r2r/pipes/other/web_search_pipe.py
new file mode 100755
index 00000000..92e3feee
--- /dev/null
+++ b/R2R/r2r/pipes/other/web_search_pipe.py
@@ -0,0 +1,105 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+ AsyncPipe,
+ AsyncState,
+ PipeType,
+ VectorSearchResult,
+ generate_id_from_label,
+)
+from r2r.integrations import SerperClient
+
+from ..abstractions.search_pipe import SearchPipe
+
+logger = logging.getLogger(__name__)
+
+
+class WebSearchPipe(SearchPipe):
+ def __init__(
+ self,
+ serper_client: SerperClient,
+ type: PipeType = PipeType.SEARCH,
+ config: Optional[SearchPipe.SearchConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(
+ type=type,
+ config=config or SearchPipe.SearchConfig(),
+ *args,
+ **kwargs,
+ )
+ self.serper_client = serper_client
+
+ async def search(
+ self,
+ message: str,
+ run_id: uuid.UUID,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[VectorSearchResult, None]:
+ search_limit_override = kwargs.get("search_limit", None)
+ await self.enqueue_log(
+ run_id=run_id, key="search_query", value=message
+ )
+ # TODO - Make more general in the future by creating a SearchProvider interface
+ results = self.serper_client.get_raw(
+ query=message,
+ limit=search_limit_override or self.config.search_limit,
+ )
+
+ search_results = []
+ for result in results:
+ if result.get("snippet") is None:
+ continue
+ result["text"] = result.pop("snippet")
+ search_result = VectorSearchResult(
+ id=generate_id_from_label(str(result)),
+ score=result.get(
+ "score", 0
+ ), # TODO - Consider dynamically generating scores based on similarity
+ metadata=result,
+ )
+ search_results.append(search_result)
+ yield search_result
+
+ await self.enqueue_log(
+ run_id=run_id,
+ key="search_results",
+ value=json.dumps([ele.json() for ele in search_results]),
+ )
+
+ async def _run_logic(
+ self,
+ input: AsyncPipe.Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ *args: Any,
+ **kwargs,
+ ) -> AsyncGenerator[VectorSearchResult, None]:
+ search_queries = []
+ search_results = []
+ async for search_request in input.message:
+ search_queries.append(search_request)
+ async for result in self.search(
+ message=search_request, run_id=run_id, *args, **kwargs
+ ):
+ search_results.append(result)
+ yield result
+
+ await state.update(
+ self.config.name, {"output": {"search_results": search_results}}
+ )
+
+ await state.update(
+ self.config.name,
+ {
+ "output": {
+ "search_queries": search_queries,
+ "search_results": search_results,
+ }
+ },
+ )
diff --git a/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
new file mode 100755
index 00000000..60935265
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
@@ -0,0 +1,103 @@
+import logging
+import uuid
+from typing import Any, Optional
+
+from r2r.base import (
+ AsyncState,
+ KGProvider,
+ KGSearchSettings,
+ KVLoggingSingleton,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+)
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class KGAgentSearchPipe(GeneratorPipe):
+ """
+    Queries the knowledge graph by prompting an LLM agent to produce a structured graph query and executing it.
+ """
+
+ def __init__(
+ self,
+ kg_provider: KGProvider,
+ llm_provider: LLMProvider,
+ prompt_provider: PromptProvider,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ type: PipeType = PipeType.INGESTOR,
+ config: Optional[GeneratorPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ """
+        Initializes the knowledge graph agent search pipe with necessary components and configurations.
+ """
+ super().__init__(
+ llm_provider=llm_provider,
+ prompt_provider=prompt_provider,
+ type=type,
+ config=config
+ or GeneratorPipe.Config(
+ name="kg_rag_pipe", task_prompt="kg_agent"
+ ),
+ pipe_logger=pipe_logger,
+ *args,
+ **kwargs,
+ )
+ self.kg_provider = kg_provider
+ self.llm_provider = llm_provider
+ self.prompt_provider = prompt_provider
+ self.pipe_run_info = None
+
+ async def _run_logic(
+ self,
+ input: GeneratorPipe.Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ kg_search_settings: KGSearchSettings,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ async for message in input.message:
+ # TODO - Remove hard code
+ formatted_prompt = self.prompt_provider.get_prompt(
+ "kg_agent", {"input": message}
+ )
+ messages = self._get_message_payload(formatted_prompt)
+
+ result = await self.llm_provider.aget_completion(
+ messages=messages,
+ generation_config=kg_search_settings.agent_generation_config,
+ )
+
+ extraction = result.choices[0].message.content
+ query = extraction.split("```cypher")[1].split("```")[0]
+ result = self.kg_provider.structured_query(query)
+ yield (query, result)
+
+ await self.enqueue_log(
+ run_id=run_id,
+ key="kg_agent_response",
+ value=extraction,
+ )
+
+ await self.enqueue_log(
+ run_id=run_id,
+ key="kg_agent_execution_result",
+ value=result,
+ )
+
+    def _get_message_payload(self, message: str) -> list:
+ return [
+ {
+ "role": "system",
+ "content": self.prompt_provider.get_prompt(
+ self.config.system_prompt,
+ ),
+ },
+ {"role": "user", "content": message},
+ ]
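The pipe assumes the agent wraps its graph query in a ```cypher fence; a response without that fence would raise an IndexError at the split. A short illustration of the extraction step (the Cypher text is hypothetical):

    extraction = (
        "Here is the query:\n"
        "```cypher\n"
        "MATCH (n)-[r]->(m) RETURN n, r, m LIMIT 10\n"
        "```"
    )
    query = extraction.split("```cypher")[1].split("```")[0]
    # query == "\nMATCH (n)-[r]->(m) RETURN n, r, m LIMIT 10\n"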
diff --git a/R2R/r2r/pipes/retrieval/multi_search.py b/R2R/r2r/pipes/retrieval/multi_search.py
new file mode 100755
index 00000000..6da2c34b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/multi_search.py
@@ -0,0 +1,79 @@
+import uuid
+from copy import copy
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base.abstractions.llm import GenerationConfig
+from r2r.base.abstractions.search import VectorSearchResult
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+from ..abstractions.search_pipe import SearchPipe
+from .query_transform_pipe import QueryTransformPipe
+
+
+class MultiSearchPipe(AsyncPipe):
+ class PipeConfig(AsyncPipe.PipeConfig):
+ name: str = "multi_search_pipe"
+
+ def __init__(
+ self,
+ query_transform_pipe: QueryTransformPipe,
+ inner_search_pipe: SearchPipe,
+ config: Optional[PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ self.query_transform_pipe = query_transform_pipe
+ self.vector_search_pipe = inner_search_pipe
+ if (
+ not query_transform_pipe.config.name
+ == inner_search_pipe.config.name
+ ):
+ raise ValueError(
+ "The query transform pipe and search pipe must have the same name."
+ )
+ if config and not config.name == query_transform_pipe.config.name:
+ raise ValueError(
+ "The pipe config name must match the query transform pipe name."
+ )
+
+ super().__init__(
+ config=config
+ or MultiSearchPipe.PipeConfig(
+ name=query_transform_pipe.config.name
+ ),
+ *args,
+ **kwargs,
+ )
+
+ async def _run_logic(
+ self,
+ input: Any,
+ state: Any,
+ run_id: uuid.UUID,
+ query_transform_generation_config: Optional[GenerationConfig] = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[VectorSearchResult, None]:
+ query_transform_generation_config = (
+ query_transform_generation_config
+ or copy(kwargs.get("rag_generation_config", None))
+ or GenerationConfig(model="gpt-4o")
+ )
+ query_transform_generation_config.stream = False
+
+ query_generator = await self.query_transform_pipe.run(
+ input,
+ state,
+ query_transform_generation_config=query_transform_generation_config,
+ num_query_xf_outputs=3,
+ *args,
+ **kwargs,
+ )
+
+ async for search_result in await self.vector_search_pipe.run(
+ self.vector_search_pipe.Input(message=query_generator),
+ state,
+ *args,
+ **kwargs,
+ ):
+ yield search_result
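A hedged wiring sketch: the name checks above require the transform pipe, the inner search pipe, and the MultiSearchPipe config to share one name. The providers are assumed to be configured elsewhere; only the class and config names come from this commit.

    shared_name = "hyde_multi_search"

    query_transform_pipe = QueryTransformPipe(
        llm_provider=llm_provider,
        prompt_provider=prompt_provider,
        config=QueryTransformPipe.QueryTransformConfig(name=shared_name),
    )
    inner_search_pipe = VectorSearchPipe(
        vector_db_provider=vector_db_provider,
        embedding_provider=embedding_provider,
        config=SearchPipe.SearchConfig(name=shared_name),
    )
    # MultiSearchPipe inherits the shared name when no config is passed.
    multi_search_pipe = MultiSearchPipe(query_transform_pipe, inner_search_pipe)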
diff --git a/R2R/r2r/pipes/retrieval/query_transform_pipe.py b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
new file mode 100755
index 00000000..99df6b5b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
@@ -0,0 +1,101 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+ AsyncPipe,
+ AsyncState,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class QueryTransformPipe(GeneratorPipe):
+ class QueryTransformConfig(GeneratorPipe.PipeConfig):
+ name: str = "default_query_transform"
+ system_prompt: str = "default_system"
+ task_prompt: str = "hyde"
+
+ class Input(GeneratorPipe.Input):
+ message: AsyncGenerator[str, None]
+
+ def __init__(
+ self,
+ llm_provider: LLMProvider,
+ prompt_provider: PromptProvider,
+ type: PipeType = PipeType.TRANSFORM,
+ config: Optional[QueryTransformConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ logger.info(f"Initalizing an `QueryTransformPipe` pipe.")
+ super().__init__(
+ llm_provider=llm_provider,
+ prompt_provider=prompt_provider,
+ type=type,
+ config=config or QueryTransformPipe.QueryTransformConfig(),
+ *args,
+ **kwargs,
+ )
+
+ async def _run_logic(
+ self,
+ input: AsyncPipe.Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ query_transform_generation_config: GenerationConfig,
+ num_query_xf_outputs: int = 3,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[str, None]:
+ async for query in input.message:
+ logger.info(
+ f"Transforming query: {query} into {num_query_xf_outputs} outputs with {self.config.task_prompt}."
+ )
+
+ query_transform_request = self._get_message_payload(
+ query, num_outputs=num_query_xf_outputs
+ )
+
+ response = await self.llm_provider.aget_completion(
+ messages=query_transform_request,
+ generation_config=query_transform_generation_config,
+ )
+ content = self.llm_provider.extract_content(response)
+ outputs = content.split("\n")
+ outputs = [
+ output.strip() for output in outputs if output.strip() != ""
+ ]
+ await state.update(
+ self.config.name, {"output": {"outputs": outputs}}
+ )
+
+ for output in outputs:
+ logger.info(f"Yielding transformed output: {output}")
+ yield output
+
+    def _get_message_payload(self, input: str, num_outputs: int) -> list:
+ return [
+ {
+ "role": "system",
+ "content": self.prompt_provider.get_prompt(
+ self.config.system_prompt,
+ ),
+ },
+ {
+ "role": "user",
+ "content": self.prompt_provider.get_prompt(
+ self.config.task_prompt,
+ inputs={
+ "message": input,
+ "num_outputs": num_outputs,
+ },
+ ),
+ },
+ ]
diff --git a/R2R/r2r/pipes/retrieval/search_rag_pipe.py b/R2R/r2r/pipes/retrieval/search_rag_pipe.py
new file mode 100755
index 00000000..4d01d2df
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/search_rag_pipe.py
@@ -0,0 +1,130 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional, Tuple
+
+from r2r.base import (
+ AggregateSearchResult,
+ AsyncPipe,
+ AsyncState,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig, RAGCompletion
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class SearchRAGPipe(GeneratorPipe):
+ class Input(AsyncPipe.Input):
+ message: AsyncGenerator[Tuple[str, AggregateSearchResult], None]
+
+ def __init__(
+ self,
+ llm_provider: LLMProvider,
+ prompt_provider: PromptProvider,
+ type: PipeType = PipeType.GENERATOR,
+        config: Optional[GeneratorPipe.Config] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(
+ llm_provider=llm_provider,
+ prompt_provider=prompt_provider,
+ type=type,
+ config=config
+ or GeneratorPipe.Config(
+ name="default_rag_pipe", task_prompt="default_rag"
+ ),
+ *args,
+ **kwargs,
+ )
+
+ async def _run_logic(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ rag_generation_config: GenerationConfig,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[RAGCompletion, None]:
+ context = ""
+ search_iteration = 1
+ total_results = 0
+ # must select a query if there are multiple
+ sel_query = None
+ async for query, search_results in input.message:
+ if search_iteration == 1:
+ sel_query = query
+ context_piece, total_results = await self._collect_context(
+ query, search_results, search_iteration, total_results
+ )
+ context += context_piece
+ search_iteration += 1
+
+ messages = self._get_message_payload(sel_query, context)
+
+ response = await self.llm_provider.aget_completion(
+ messages=messages, generation_config=rag_generation_config
+ )
+ yield RAGCompletion(completion=response, search_results=search_results)
+
+ await self.enqueue_log(
+ run_id=run_id,
+ key="llm_response",
+ value=response.choices[0].message.content,
+ )
+
+    def _get_message_payload(self, query: str, context: str) -> list:
+ return [
+ {
+ "role": "system",
+ "content": self.prompt_provider.get_prompt(
+ self.config.system_prompt,
+ ),
+ },
+ {
+ "role": "user",
+ "content": self.prompt_provider.get_prompt(
+ self.config.task_prompt,
+ inputs={
+ "query": query,
+ "context": context,
+ },
+ ),
+ },
+ ]
+
+ async def _collect_context(
+ self,
+ query: str,
+ results: AggregateSearchResult,
+ iteration: int,
+ total_results: int,
+ ) -> Tuple[str, int]:
+ context = f"Query:\n{query}\n\n"
+ if results.vector_search_results:
+ context += f"Vector Search Results({iteration}):\n"
+ it = total_results + 1
+ for result in results.vector_search_results:
+ context += f"[{it}]: {result.metadata['text']}\n\n"
+ it += 1
+ total_results = (
+ it - 1
+ ) # Update total_results based on the last index used
+ if results.kg_search_results:
+ context += f"Knowledge Graph ({iteration}):\n"
+ it = total_results + 1
+ for query, search_results in results.kg_search_results: # [1]:
+ context += f"Query: {query}\n\n"
+ context += f"Results:\n"
+ for search_result in search_results:
+ context += f"[{it}]: {search_result}\n\n"
+ it += 1
+ total_results = (
+ it - 1
+ ) # Update total_results based on the last index used
+ return context, total_results
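For reference, `_collect_context` builds a flat, numbered context block that the `default_rag` task prompt consumes. The standalone helper below mirrors the same numbering scheme for vector results only; it is a sketch for illustration, the chunk texts are placeholders, and it is not part of the commit.

    # Mirrors the numbering logic in _collect_context for a single vector-search pass.
    def render_vector_context(query: str, texts: list[str], iteration: int = 1) -> str:
        context = f"Query:\n{query}\n\n"
        context += f"Vector Search Results({iteration}):\n"
        for it, text in enumerate(texts, start=1):
            context += f"[{it}]: {text}\n\n"
        return context

    print(render_vector_context("What is R2R?", ["chunk one", "chunk two"]))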
diff --git a/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
new file mode 100755
index 00000000..b01f6445
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
@@ -0,0 +1,131 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from r2r.base import (
+ AsyncState,
+ LLMChatCompletionChunk,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+from .search_rag_pipe import SearchRAGPipe
+
+logger = logging.getLogger(__name__)
+
+
+class StreamingSearchRAGPipe(SearchRAGPipe):
+ SEARCH_STREAM_MARKER = "search"
+ COMPLETION_STREAM_MARKER = "completion"
+
+ def __init__(
+ self,
+ llm_provider: LLMProvider,
+ prompt_provider: PromptProvider,
+ type: PipeType = PipeType.GENERATOR,
+        config: Optional[GeneratorPipe.Config] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(
+ llm_provider=llm_provider,
+ prompt_provider=prompt_provider,
+ type=type,
+ config=config
+ or GeneratorPipe.Config(
+ name="default_streaming_rag_pipe", task_prompt="default_rag"
+ ),
+ *args,
+ **kwargs,
+ )
+
+ async def _run_logic(
+ self,
+ input: SearchRAGPipe.Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ rag_generation_config: GenerationConfig,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[str, None]:
+ iteration = 0
+ context = ""
+ # dump the search results and construct the context
+ async for query, search_results in input.message:
+ yield f"<{self.SEARCH_STREAM_MARKER}>"
+ if search_results.vector_search_results:
+ context += "Vector Search Results:\n"
+ for result in search_results.vector_search_results:
+ if iteration >= 1:
+ yield ","
+ yield json.dumps(result.json())
+ context += (
+ f"{iteration + 1}:\n{result.metadata['text']}\n\n"
+ )
+ iteration += 1
+
+ # if search_results.kg_search_results:
+ # for result in search_results.kg_search_results:
+ # if iteration >= 1:
+ # yield ","
+ # yield json.dumps(result.json())
+ # context += f"Result {iteration+1}:\n{result.metadata['text']}\n\n"
+ # iteration += 1
+
+ yield f"</{self.SEARCH_STREAM_MARKER}>"
+
+ messages = self._get_message_payload(query, context)
+ yield f"<{self.COMPLETION_STREAM_MARKER}>"
+ response = ""
+ for chunk in self.llm_provider.get_completion_stream(
+ messages=messages, generation_config=rag_generation_config
+ ):
+ chunk = StreamingSearchRAGPipe._process_chunk(chunk)
+ response += chunk
+ yield chunk
+
+ yield f"</{self.COMPLETION_STREAM_MARKER}>"
+
+ await self.enqueue_log(
+ run_id=run_id,
+ key="llm_response",
+ value=response,
+ )
+
+ async def _yield_chunks(
+ self,
+ start_marker: str,
+ chunks: Generator[str, None, None],
+ end_marker: str,
+ ) -> str:
+ yield start_marker
+ for chunk in chunks:
+ yield chunk
+ yield end_marker
+
+ def _get_message_payload(
+ self, query: str, context: str
+ ) -> list[dict[str, str]]:
+ return [
+ {
+ "role": "system",
+ "content": self.prompt_provider.get_prompt(
+ self.config.system_prompt
+ ),
+ },
+ {
+ "role": "user",
+ "content": self.prompt_provider.get_prompt(
+ self.config.task_prompt,
+ inputs={"query": query, "context": context},
+ ),
+ },
+ ]
+
+ @staticmethod
+ def _process_chunk(chunk: LLMChatCompletionChunk) -> str:
+ return chunk.choices[0].delta.content or ""
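A hedged consumption sketch: the streaming pipe yields plain string chunks, with serialized search results and completion text wrapped in the marker tags defined above. It assumes, as in `MultiSearchPipe`'s use of `run()`, that keyword arguments are forwarded to `_run_logic`; the input and `state` come from the surrounding pipeline.

    async def collect_streamed_answer(
        pipe: StreamingSearchRAGPipe,
        input: SearchRAGPipe.Input,
        state: AsyncState,
        rag_generation_config: GenerationConfig,
    ) -> str:
        buffer = ""
        async for chunk in await pipe.run(
            input, state, rag_generation_config=rag_generation_config
        ):
            buffer += chunk
        # buffer now looks like:
        # <search>{...result json...},{...}</search><completion>...answer text...</completion>
        return buffer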
diff --git a/R2R/r2r/pipes/retrieval/vector_search_pipe.py b/R2R/r2r/pipes/retrieval/vector_search_pipe.py
new file mode 100755
index 00000000..742de16b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/vector_search_pipe.py
@@ -0,0 +1,123 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+ AsyncPipe,
+ AsyncState,
+ EmbeddingProvider,
+ PipeType,
+ VectorDBProvider,
+ VectorSearchResult,
+ VectorSearchSettings,
+)
+
+from ..abstractions.search_pipe import SearchPipe
+
+logger = logging.getLogger(__name__)
+
+
+class VectorSearchPipe(SearchPipe):
+ def __init__(
+ self,
+ vector_db_provider: VectorDBProvider,
+ embedding_provider: EmbeddingProvider,
+ type: PipeType = PipeType.SEARCH,
+ config: Optional[SearchPipe.SearchConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(
+ type=type,
+ config=config or SearchPipe.SearchConfig(),
+ *args,
+ **kwargs,
+ )
+ self.embedding_provider = embedding_provider
+ self.vector_db_provider = vector_db_provider
+
+ async def search(
+ self,
+ message: str,
+ run_id: uuid.UUID,
+ vector_search_settings: VectorSearchSettings,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[VectorSearchResult, None]:
+ await self.enqueue_log(
+ run_id=run_id, key="search_query", value=message
+ )
+ search_filters = (
+ vector_search_settings.search_filters or self.config.search_filters
+ )
+ search_limit = (
+ vector_search_settings.search_limit or self.config.search_limit
+ )
+ results = []
+ query_vector = self.embedding_provider.get_embedding(
+ message,
+ )
+ search_results = (
+ self.vector_db_provider.hybrid_search(
+ query_vector=query_vector,
+ query_text=message,
+ filters=search_filters,
+ limit=search_limit,
+ )
+ if vector_search_settings.do_hybrid_search
+ else self.vector_db_provider.search(
+ query_vector=query_vector,
+ filters=search_filters,
+ limit=search_limit,
+ )
+ )
+ reranked_results = self.embedding_provider.rerank(
+ query=message, results=search_results, limit=search_limit
+ )
+ for result in reranked_results:
+ result.metadata["associatedQuery"] = message
+ results.append(result)
+ yield result
+ await self.enqueue_log(
+ run_id=run_id,
+ key="search_results",
+ value=json.dumps([ele.json() for ele in results]),
+ )
+
+ async def _run_logic(
+ self,
+ input: AsyncPipe.Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[VectorSearchResult, None]:
+ search_queries = []
+ search_results = []
+ async for search_request in input.message:
+ search_queries.append(search_request)
+ async for result in self.search(
+ message=search_request,
+ run_id=run_id,
+ vector_search_settings=vector_search_settings,
+ *args,
+ **kwargs,
+ ):
+ search_results.append(result)
+ yield result
+
+ await state.update(
+ self.config.name, {"output": {"search_results": search_results}}
+ )
+
+ await state.update(
+ self.config.name,
+ {
+ "output": {
+ "search_queries": search_queries,
+ "search_results": search_results,
+ }
+ },
+ )