author    | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit    | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes
parent    | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  | gn-ai-master.tar.gz
Diffstat (limited to 'R2R/r2r/pipes')

-rwxr-xr-x | R2R/r2r/pipes/__init__.py                       |  31
-rwxr-xr-x | R2R/r2r/pipes/abstractions/__init__.py          |   0
-rwxr-xr-x | R2R/r2r/pipes/abstractions/generator_pipe.py    |  58
-rwxr-xr-x | R2R/r2r/pipes/abstractions/search_pipe.py       |  62
-rwxr-xr-x | R2R/r2r/pipes/ingestion/__init__.py             |   0
-rwxr-xr-x | R2R/r2r/pipes/ingestion/embedding_pipe.py       | 218
-rwxr-xr-x | R2R/r2r/pipes/ingestion/kg_extraction_pipe.py   | 226
-rwxr-xr-x | R2R/r2r/pipes/ingestion/kg_storage_pipe.py      | 133
-rwxr-xr-x | R2R/r2r/pipes/ingestion/parsing_pipe.py         | 211
-rwxr-xr-x | R2R/r2r/pipes/ingestion/vector_storage_pipe.py  | 128
-rwxr-xr-x | R2R/r2r/pipes/other/eval_pipe.py                |  54
-rwxr-xr-x | R2R/r2r/pipes/other/web_search_pipe.py          | 105
-rwxr-xr-x | R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py | 103
-rwxr-xr-x | R2R/r2r/pipes/retrieval/multi_search.py         |  79
-rwxr-xr-x | R2R/r2r/pipes/retrieval/query_transform_pipe.py | 101
-rwxr-xr-x | R2R/r2r/pipes/retrieval/search_rag_pipe.py      | 130
-rwxr-xr-x | R2R/r2r/pipes/retrieval/streaming_rag_pipe.py   | 131
-rwxr-xr-x | R2R/r2r/pipes/retrieval/vector_search_pipe.py   | 123

18 files changed, 1893 insertions, 0 deletions
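The commit adds the ingestion, retrieval, and abstraction pipes that `r2r/pipes/__init__.py` re-exports. For orientation, the sketch below shows how the `SearchPipe` abstraction introduced in `abstractions/search_pipe.py` could be specialized. The class `InMemorySearchPipe`, its toy corpus, and the substring "scoring" are illustrative assumptions and not part of this commit; only the `SearchPipe`, `VectorSearchResult`, `AsyncState`, and `generate_id_from_label` imports follow the code in the diff.

```python
# Hedged sketch: a minimal custom search pipe on top of the SearchPipe
# abstraction added here. The corpus and substring matching are placeholders
# for a real vector-store lookup.
import uuid
from typing import Any, AsyncGenerator

from r2r.base import AsyncState, VectorSearchResult, generate_id_from_label
from r2r.pipes import SearchPipe


class InMemorySearchPipe(SearchPipe):
    def __init__(self, corpus: list[str], *args, **kwargs):
        super().__init__(*args, config=SearchPipe.SearchConfig(), **kwargs)
        self.corpus = corpus

    async def search(
        self,
        query: str,
        filters: dict[str, Any] = {},
        limit: int = 10,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[VectorSearchResult, None]:
        # Naive substring match standing in for a real similarity search.
        hits = 0
        for text in self.corpus:
            if query.lower() not in text.lower():
                continue
            yield VectorSearchResult(
                id=generate_id_from_label(text),
                score=1.0,
                metadata={"text": text, "associatedQuery": query},
            )
            hits += 1
            if hits >= limit:
                break

    async def _run_logic(
        self,
        input: SearchPipe.Input,
        state: AsyncState,
        run_id: uuid.UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[VectorSearchResult, None]:
        # Input.message may be a plain string or an async generator of queries.
        queries = (
            [input.message]
            if isinstance(input.message, str)
            else [q async for q in input.message]
        )
        for query in queries:
            async for result in self.search(query, limit=self.config.search_limit):
                yield result
```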
diff --git a/R2R/r2r/pipes/__init__.py b/R2R/r2r/pipes/__init__.py new file mode 100755 index 00000000..b86c31c0 --- /dev/null +++ b/R2R/r2r/pipes/__init__.py @@ -0,0 +1,31 @@ +from .abstractions.search_pipe import SearchPipe +from .ingestion.embedding_pipe import EmbeddingPipe +from .ingestion.kg_extraction_pipe import KGExtractionPipe +from .ingestion.kg_storage_pipe import KGStoragePipe +from .ingestion.parsing_pipe import ParsingPipe +from .ingestion.vector_storage_pipe import VectorStoragePipe +from .other.eval_pipe import EvalPipe +from .other.web_search_pipe import WebSearchPipe +from .retrieval.kg_agent_search_pipe import KGAgentSearchPipe +from .retrieval.multi_search import MultiSearchPipe +from .retrieval.query_transform_pipe import QueryTransformPipe +from .retrieval.search_rag_pipe import SearchRAGPipe +from .retrieval.streaming_rag_pipe import StreamingSearchRAGPipe +from .retrieval.vector_search_pipe import VectorSearchPipe + +__all__ = [ + "SearchPipe", + "EmbeddingPipe", + "EvalPipe", + "KGExtractionPipe", + "ParsingPipe", + "QueryTransformPipe", + "SearchRAGPipe", + "StreamingSearchRAGPipe", + "VectorSearchPipe", + "VectorStoragePipe", + "WebSearchPipe", + "KGAgentSearchPipe", + "KGStoragePipe", + "MultiSearchPipe", +] diff --git a/R2R/r2r/pipes/abstractions/__init__.py b/R2R/r2r/pipes/abstractions/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/pipes/abstractions/__init__.py diff --git a/R2R/r2r/pipes/abstractions/generator_pipe.py b/R2R/r2r/pipes/abstractions/generator_pipe.py new file mode 100755 index 00000000..002ebd23 --- /dev/null +++ b/R2R/r2r/pipes/abstractions/generator_pipe.py @@ -0,0 +1,58 @@ +import uuid +from abc import abstractmethod +from typing import Any, AsyncGenerator, Optional + +from r2r.base import ( + AsyncState, + KVLoggingSingleton, + LLMProvider, + PipeType, + PromptProvider, +) +from r2r.base.abstractions.llm import GenerationConfig +from r2r.base.pipes.base_pipe import AsyncPipe + + +class GeneratorPipe(AsyncPipe): + class Config(AsyncPipe.PipeConfig): + name: str + task_prompt: str + system_prompt: str = "default_system" + + def __init__( + self, + llm_provider: LLMProvider, + prompt_provider: PromptProvider, + type: PipeType = PipeType.GENERATOR, + config: Optional[Config] = None, + pipe_logger: Optional[KVLoggingSingleton] = None, + *args, + **kwargs, + ): + super().__init__( + type=type, + config=config or self.Config(), + pipe_logger=pipe_logger, + *args, + **kwargs, + ) + self.llm_provider = llm_provider + self.prompt_provider = prompt_provider + + @abstractmethod + async def _run_logic( + self, + input: AsyncPipe.Input, + state: AsyncState, + run_id: uuid.UUID, + rag_generation_config: GenerationConfig, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[Any, None]: + pass + + @abstractmethod + def _get_message_payload( + self, message: str, *args: Any, **kwargs: Any + ) -> list: + pass diff --git a/R2R/r2r/pipes/abstractions/search_pipe.py b/R2R/r2r/pipes/abstractions/search_pipe.py new file mode 100755 index 00000000..bb0303e0 --- /dev/null +++ b/R2R/r2r/pipes/abstractions/search_pipe.py @@ -0,0 +1,62 @@ +import logging +import uuid +from abc import abstractmethod +from typing import Any, AsyncGenerator, Optional, Union + +from r2r.base import ( + AsyncPipe, + AsyncState, + KVLoggingSingleton, + PipeType, + VectorSearchResult, +) + +logger = logging.getLogger(__name__) + + +class SearchPipe(AsyncPipe): + class SearchConfig(AsyncPipe.PipeConfig): + name: str = "default_vector_search" + 
search_filters: dict = {} + search_limit: int = 10 + + class Input(AsyncPipe.Input): + message: Union[AsyncGenerator[str, None], str] + + def __init__( + self, + pipe_logger: Optional[KVLoggingSingleton] = None, + type: PipeType = PipeType.SEARCH, + config: Optional[AsyncPipe.PipeConfig] = None, + *args, + **kwargs, + ): + super().__init__( + pipe_logger=pipe_logger, + type=type, + config=config, + *args, + **kwargs, + ) + + @abstractmethod + async def search( + self, + query: str, + filters: dict[str, Any] = {}, + limit: int = 10, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[VectorSearchResult, None]: + pass + + @abstractmethod + async def _run_logic( + self, + input: Input, + state: AsyncState, + run_id: uuid.UUID, + *args: Any, + **kwargs, + ) -> AsyncGenerator[VectorSearchResult, None]: + pass diff --git a/R2R/r2r/pipes/ingestion/__init__.py b/R2R/r2r/pipes/ingestion/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/pipes/ingestion/__init__.py diff --git a/R2R/r2r/pipes/ingestion/embedding_pipe.py b/R2R/r2r/pipes/ingestion/embedding_pipe.py new file mode 100755 index 00000000..971ccc9d --- /dev/null +++ b/R2R/r2r/pipes/ingestion/embedding_pipe.py @@ -0,0 +1,218 @@ +import asyncio +import copy +import logging +import uuid +from typing import Any, AsyncGenerator, Optional, Union + +from r2r.base import ( + AsyncState, + EmbeddingProvider, + Extraction, + Fragment, + FragmentType, + KVLoggingSingleton, + PipeType, + R2RDocumentProcessingError, + TextSplitter, + Vector, + VectorEntry, + generate_id_from_label, +) +from r2r.base.pipes.base_pipe import AsyncPipe + +logger = logging.getLogger(__name__) + + +class EmbeddingPipe(AsyncPipe): + """ + Embeds and stores documents using a specified embedding model and database. + """ + + class Input(AsyncPipe.Input): + message: AsyncGenerator[ + Union[Extraction, R2RDocumentProcessingError], None + ] + + def __init__( + self, + embedding_provider: EmbeddingProvider, + text_splitter: TextSplitter, + embedding_batch_size: int = 1, + id_prefix: str = "demo", + pipe_logger: Optional[KVLoggingSingleton] = None, + type: PipeType = PipeType.INGESTOR, + config: Optional[AsyncPipe.PipeConfig] = None, + *args, + **kwargs, + ): + """ + Initializes the embedding pipe with necessary components and configurations. + """ + super().__init__( + pipe_logger=pipe_logger, + type=type, + config=config + or AsyncPipe.PipeConfig(name="default_embedding_pipe"), + ) + self.embedding_provider = embedding_provider + self.text_splitter = text_splitter + self.embedding_batch_size = embedding_batch_size + self.id_prefix = id_prefix + self.pipe_run_info = None + + async def fragment( + self, extraction: Extraction, run_id: uuid.UUID + ) -> AsyncGenerator[Fragment, None]: + """ + Splits text into manageable chunks for embedding. + """ + if not isinstance(extraction, Extraction): + raise ValueError( + f"Expected an Extraction, but received {type(extraction)}." + ) + if not isinstance(extraction.data, str): + raise ValueError( + f"Expected a string, but received {type(extraction.data)}." 
+ ) + text_chunks = [ + ele.page_content + for ele in self.text_splitter.create_documents([extraction.data]) + ] + for iteration, chunk in enumerate(text_chunks): + fragment = Fragment( + id=generate_id_from_label(f"{extraction.id}-{iteration}"), + type=FragmentType.TEXT, + data=chunk, + metadata=copy.deepcopy(extraction.metadata), + extraction_id=extraction.id, + document_id=extraction.document_id, + ) + yield fragment + iteration += 1 + + async def transform_fragments( + self, fragments: list[Fragment], metadatas: list[dict] + ) -> AsyncGenerator[Fragment, None]: + """ + Transforms text chunks based on their metadata, e.g., adding prefixes. + """ + async for fragment, metadata in zip(fragments, metadatas): + if "chunk_prefix" in metadata: + prefix = metadata.pop("chunk_prefix") + fragment.data = f"{prefix}\n{fragment.data}" + yield fragment + + async def embed(self, fragments: list[Fragment]) -> list[float]: + return await self.embedding_provider.async_get_embeddings( + [fragment.data for fragment in fragments], + EmbeddingProvider.PipeStage.BASE, + ) + + async def _process_batch( + self, fragment_batch: list[Fragment] + ) -> list[VectorEntry]: + """ + Embeds a batch of fragments and yields vector entries. + """ + vectors = await self.embed(fragment_batch) + return [ + VectorEntry( + id=fragment.id, + vector=Vector(data=raw_vector), + metadata={ + "document_id": fragment.document_id, + "extraction_id": fragment.extraction_id, + "text": fragment.data, + **fragment.metadata, + }, + ) + for raw_vector, fragment in zip(vectors, fragment_batch) + ] + + async def _process_and_enqueue_batch( + self, fragment_batch: list[Fragment], vector_entry_queue: asyncio.Queue + ): + try: + batch_result = await self._process_batch(fragment_batch) + for vector_entry in batch_result: + await vector_entry_queue.put(vector_entry) + except Exception as e: + logger.error(f"Error processing batch: {e}") + await vector_entry_queue.put( + R2RDocumentProcessingError( + error_message=str(e), + document_id=fragment_batch[0].document_id, + ) + ) + finally: + await vector_entry_queue.put(None) # Signal completion + + async def _run_logic( + self, + input: Input, + state: AsyncState, + run_id: uuid.UUID, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[Union[R2RDocumentProcessingError, VectorEntry], None]: + """ + Executes the embedding pipe: chunking, transforming, embedding, and storing documents. 
+ """ + vector_entry_queue = asyncio.Queue() + fragment_batch = [] + active_tasks = 0 + + fragment_info = {} + async for extraction in input.message: + if isinstance(extraction, R2RDocumentProcessingError): + yield extraction + continue + + async for fragment in self.fragment(extraction, run_id): + if extraction.document_id in fragment_info: + fragment_info[extraction.document_id] += 1 + else: + fragment_info[extraction.document_id] = 0 # Start with 0 + fragment.metadata["chunk_order"] = fragment_info[ + extraction.document_id + ] + + version = fragment.metadata.get("version", "v0") + + # Ensure fragment ID is set correctly + if not fragment.id: + fragment.id = generate_id_from_label( + f"{extraction.id}-{fragment_info[extraction.document_id]}-{version}" + ) + + fragment_batch.append(fragment) + if len(fragment_batch) >= self.embedding_batch_size: + asyncio.create_task( + self._process_and_enqueue_batch( + fragment_batch.copy(), vector_entry_queue + ) + ) + active_tasks += 1 + fragment_batch.clear() + + logger.debug( + f"Fragmented the input document ids into counts as shown: {fragment_info}" + ) + + if fragment_batch: + asyncio.create_task( + self._process_and_enqueue_batch( + fragment_batch.copy(), vector_entry_queue + ) + ) + active_tasks += 1 + + while active_tasks > 0: + vector_entry = await vector_entry_queue.get() + if vector_entry is None: # Check for termination signal + active_tasks -= 1 + elif isinstance(vector_entry, Exception): + yield vector_entry # Propagate the exception + active_tasks -= 1 + else: + yield vector_entry diff --git a/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py b/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py new file mode 100755 index 00000000..13025e39 --- /dev/null +++ b/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py @@ -0,0 +1,226 @@ +import asyncio +import copy +import json +import logging +import uuid +from typing import Any, AsyncGenerator, Optional + +from r2r.base import ( + AsyncState, + Extraction, + Fragment, + FragmentType, + KGExtraction, + KGProvider, + KVLoggingSingleton, + LLMProvider, + PipeType, + PromptProvider, + TextSplitter, + extract_entities, + extract_triples, + generate_id_from_label, +) +from r2r.base.pipes.base_pipe import AsyncPipe + +logger = logging.getLogger(__name__) + + +class ClientError(Exception): + """Base class for client connection errors.""" + + pass + + +class KGExtractionPipe(AsyncPipe): + """ + Embeds and stores documents using a specified embedding model and database. + """ + + def __init__( + self, + kg_provider: KGProvider, + llm_provider: LLMProvider, + prompt_provider: PromptProvider, + text_splitter: TextSplitter, + kg_batch_size: int = 1, + id_prefix: str = "demo", + pipe_logger: Optional[KVLoggingSingleton] = None, + type: PipeType = PipeType.INGESTOR, + config: Optional[AsyncPipe.PipeConfig] = None, + *args, + **kwargs, + ): + """ + Initializes the embedding pipe with necessary components and configurations. + """ + super().__init__( + pipe_logger=pipe_logger, + type=type, + config=config + or AsyncPipe.PipeConfig(name="default_embedding_pipe"), + ) + + self.kg_provider = kg_provider + self.prompt_provider = prompt_provider + self.llm_provider = llm_provider + self.text_splitter = text_splitter + self.kg_batch_size = kg_batch_size + self.id_prefix = id_prefix + self.pipe_run_info = None + + async def fragment( + self, extraction: Extraction, run_id: uuid.UUID + ) -> AsyncGenerator[Fragment, None]: + """ + Splits text into manageable chunks for embedding. 
+ """ + if not isinstance(extraction, Extraction): + raise ValueError( + f"Expected an Extraction, but received {type(extraction)}." + ) + if not isinstance(extraction.data, str): + raise ValueError( + f"Expected a string, but received {type(extraction.data)}." + ) + text_chunks = [ + ele.page_content + for ele in self.text_splitter.create_documents([extraction.data]) + ] + for iteration, chunk in enumerate(text_chunks): + fragment = Fragment( + id=generate_id_from_label(f"{extraction.id}-{iteration}"), + type=FragmentType.TEXT, + data=chunk, + metadata=copy.deepcopy(extraction.metadata), + extraction_id=extraction.id, + document_id=extraction.document_id, + ) + yield fragment + + async def transform_fragments( + self, fragments: list[Fragment] + ) -> AsyncGenerator[Fragment, None]: + """ + Transforms text chunks based on their metadata, e.g., adding prefixes. + """ + async for fragment in fragments: + if "chunk_prefix" in fragment.metadata: + prefix = fragment.metadata.pop("chunk_prefix") + fragment.data = f"{prefix}\n{fragment.data}" + yield fragment + + async def extract_kg( + self, + fragment: Fragment, + retries: int = 3, + delay: int = 2, + ) -> KGExtraction: + """ + Extracts NER triples from a list of fragments with retries. + """ + task_prompt = self.prompt_provider.get_prompt( + self.kg_provider.config.kg_extraction_prompt, + inputs={"input": fragment.data}, + ) + messages = self.prompt_provider._get_message_payload( + self.prompt_provider.get_prompt("default_system"), task_prompt + ) + for attempt in range(retries): + try: + response = await self.llm_provider.aget_completion( + messages, self.kg_provider.config.kg_extraction_config + ) + + kg_extraction = response.choices[0].message.content + + # Parsing JSON from the response + kg_json = ( + json.loads( + kg_extraction.split("```json")[1].split("```")[0] + ) + if """```json""" in kg_extraction + else json.loads(kg_extraction) + ) + llm_payload = kg_json.get("entities_and_triples", {}) + + # Extract triples with detailed logging + entities = extract_entities(llm_payload) + triples = extract_triples(llm_payload, entities) + + # Create KG extraction object + return KGExtraction(entities=entities, triples=triples) + except ( + ClientError, + json.JSONDecodeError, + KeyError, + IndexError, + ) as e: + logger.error(f"Error in extract_kg: {e}") + if attempt < retries - 1: + await asyncio.sleep(delay) + else: + logger.error(f"Failed after retries with {e}") + # raise e # Ensure the exception is raised after the final attempt + + return KGExtraction(entities={}, triples=[]) + + async def _process_batch( + self, + fragment_batch: list[Fragment], + ) -> list[KGExtraction]: + """ + Embeds a batch of fragments and yields vector entries. + """ + tasks = [ + asyncio.create_task(self.extract_kg(fragment)) + for fragment in fragment_batch + ] + return await asyncio.gather(*tasks) + + async def _run_logic( + self, + input: AsyncPipe.Input, + state: AsyncState, + run_id: uuid.UUID, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[KGExtraction, None]: + """ + Executes the embedding pipe: chunking, transforming, embedding, and storing documents. 
+ """ + batch_tasks = [] + fragment_batch = [] + + fragment_info = {} + async for extraction in input.message: + async for fragment in self.transform_fragments( + self.fragment(extraction, run_id) + ): + if extraction.document_id in fragment_info: + fragment_info[extraction.document_id] += 1 + else: + fragment_info[extraction.document_id] = 1 + extraction.metadata["chunk_order"] = fragment_info[ + extraction.document_id + ] + fragment_batch.append(fragment) + if len(fragment_batch) >= self.kg_batch_size: + # Here, ensure `_process_batch` is scheduled as a coroutine, not called directly + batch_tasks.append( + self._process_batch(fragment_batch.copy()) + ) # pass a copy if necessary + fragment_batch.clear() # Clear the batch for new fragments + + logger.debug( + f"Fragmented the input document ids into counts as shown: {fragment_info}" + ) + + if fragment_batch: # Process any remaining fragments + batch_tasks.append(self._process_batch(fragment_batch.copy())) + + # Process tasks as they complete + for task in asyncio.as_completed(batch_tasks): + batch_result = await task # Wait for the next task to complete + for kg_extraction in batch_result: + yield kg_extraction diff --git a/R2R/r2r/pipes/ingestion/kg_storage_pipe.py b/R2R/r2r/pipes/ingestion/kg_storage_pipe.py new file mode 100755 index 00000000..9ac63479 --- /dev/null +++ b/R2R/r2r/pipes/ingestion/kg_storage_pipe.py @@ -0,0 +1,133 @@ +import asyncio +import logging +import uuid +from typing import Any, AsyncGenerator, Optional + +from r2r.base import ( + AsyncState, + EmbeddingProvider, + KGExtraction, + KGProvider, + KVLoggingSingleton, + PipeType, +) +from r2r.base.abstractions.llama_abstractions import EntityNode, Relation +from r2r.base.pipes.base_pipe import AsyncPipe + +logger = logging.getLogger(__name__) + + +class KGStoragePipe(AsyncPipe): + class Input(AsyncPipe.Input): + message: AsyncGenerator[KGExtraction, None] + + def __init__( + self, + kg_provider: KGProvider, + embedding_provider: Optional[EmbeddingProvider] = None, + storage_batch_size: int = 1, + pipe_logger: Optional[KVLoggingSingleton] = None, + type: PipeType = PipeType.INGESTOR, + config: Optional[AsyncPipe.PipeConfig] = None, + *args, + **kwargs, + ): + """ + Initializes the async knowledge graph storage pipe with necessary components and configurations. + """ + logger.info( + f"Initializing an `KGStoragePipe` to store knowledge graph extractions in a graph database." + ) + + super().__init__( + pipe_logger=pipe_logger, + type=type, + config=config, + *args, + **kwargs, + ) + self.kg_provider = kg_provider + self.embedding_provider = embedding_provider + self.storage_batch_size = storage_batch_size + + async def store( + self, + kg_extractions: list[KGExtraction], + ) -> None: + """ + Stores a batch of knowledge graph extractions in the graph database. 
+ """ + try: + nodes = [] + relations = [] + for extraction in kg_extractions: + for entity in extraction.entities.values(): + embedding = None + if self.embedding_provider: + embedding = self.embedding_provider.get_embedding( + "Entity:\n{entity.value}\nLabel:\n{entity.category}\nSubcategory:\n{entity.subcategory}" + ) + nodes.append( + EntityNode( + name=entity.value, + label=entity.category, + embedding=embedding, + properties=( + {"subcategory": entity.subcategory} + if entity.subcategory + else {} + ), + ) + ) + for triple in extraction.triples: + relations.append( + Relation( + source_id=triple.subject, + target_id=triple.object, + label=triple.predicate, + ) + ) + self.kg_provider.upsert_nodes(nodes) + self.kg_provider.upsert_relations(relations) + except Exception as e: + error_message = f"Failed to store knowledge graph extractions in the database: {e}" + logger.error(error_message) + raise ValueError(error_message) + + async def _run_logic( + self, + input: Input, + state: AsyncState, + run_id: uuid.UUID, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[None, None]: + """ + Executes the async knowledge graph storage pipe: storing knowledge graph extractions in the graph database. + """ + batch_tasks = [] + kg_batch = [] + + async for kg_extraction in input.message: + kg_batch.append(kg_extraction) + if len(kg_batch) >= self.storage_batch_size: + # Schedule the storage task + batch_tasks.append( + asyncio.create_task( + self.store(kg_batch.copy()), + name=f"kg-store-{self.config.name}", + ) + ) + kg_batch.clear() + + if kg_batch: # Process any remaining extractions + batch_tasks.append( + asyncio.create_task( + self.store(kg_batch.copy()), + name=f"kg-store-{self.config.name}", + ) + ) + + # Wait for all storage tasks to complete + await asyncio.gather(*batch_tasks) + yield None diff --git a/R2R/r2r/pipes/ingestion/parsing_pipe.py b/R2R/r2r/pipes/ingestion/parsing_pipe.py new file mode 100755 index 00000000..f3c81ca0 --- /dev/null +++ b/R2R/r2r/pipes/ingestion/parsing_pipe.py @@ -0,0 +1,211 @@ +""" +This module contains the `DocumentParsingPipe` class, which is responsible for parsing incoming documents into plaintext. +""" + +import asyncio +import logging +import time +import uuid +from typing import AsyncGenerator, Optional, Union + +from r2r.base import ( + AsyncParser, + AsyncState, + Document, + DocumentType, + Extraction, + ExtractionType, + KVLoggingSingleton, + PipeType, + generate_id_from_label, +) +from r2r.base.abstractions.exception import R2RDocumentProcessingError +from r2r.base.pipes.base_pipe import AsyncPipe +from r2r.parsers.media.audio_parser import AudioParser +from r2r.parsers.media.docx_parser import DOCXParser +from r2r.parsers.media.img_parser import ImageParser +from r2r.parsers.media.movie_parser import MovieParser +from r2r.parsers.media.pdf_parser import PDFParser +from r2r.parsers.media.ppt_parser import PPTParser +from r2r.parsers.structured.csv_parser import CSVParser +from r2r.parsers.structured.json_parser import JSONParser +from r2r.parsers.structured.xlsx_parser import XLSXParser +from r2r.parsers.text.html_parser import HTMLParser +from r2r.parsers.text.md_parser import MDParser +from r2r.parsers.text.text_parser import TextParser + +logger = logging.getLogger(__name__) + + +class ParsingPipe(AsyncPipe): + """ + Processes incoming documents into plaintext based on their data type. + Supports TXT, JSON, HTML, and PDF formats. 
+ """ + + class Input(AsyncPipe.Input): + message: AsyncGenerator[Document, None] + + AVAILABLE_PARSERS = { + DocumentType.CSV: CSVParser, + DocumentType.DOCX: DOCXParser, + DocumentType.HTML: HTMLParser, + DocumentType.JSON: JSONParser, + DocumentType.MD: MDParser, + DocumentType.PDF: PDFParser, + DocumentType.PPTX: PPTParser, + DocumentType.TXT: TextParser, + DocumentType.XLSX: XLSXParser, + DocumentType.GIF: ImageParser, + DocumentType.JPEG: ImageParser, + DocumentType.JPG: ImageParser, + DocumentType.PNG: ImageParser, + DocumentType.SVG: ImageParser, + DocumentType.MP3: AudioParser, + DocumentType.MP4: MovieParser, + } + + IMAGE_TYPES = { + DocumentType.GIF, + DocumentType.JPG, + DocumentType.JPEG, + DocumentType.PNG, + DocumentType.SVG, + } + + def __init__( + self, + excluded_parsers: list[DocumentType], + override_parsers: Optional[dict[DocumentType, AsyncParser]] = None, + pipe_logger: Optional[KVLoggingSingleton] = None, + type: PipeType = PipeType.INGESTOR, + config: Optional[AsyncPipe.PipeConfig] = None, + *args, + **kwargs, + ): + super().__init__( + pipe_logger=pipe_logger, + type=type, + config=config + or AsyncPipe.PipeConfig(name="default_document_parsing_pipe"), + *args, + **kwargs, + ) + + self.parsers = {} + + if not override_parsers: + override_parsers = {} + + # Apply overrides if specified + for doc_type, parser in override_parsers.items(): + self.parsers[doc_type] = parser + + for doc_type, parser_info in self.AVAILABLE_PARSERS.items(): + if ( + doc_type not in excluded_parsers + and doc_type not in self.parsers + ): + self.parsers[doc_type] = parser_info() + + @property + def supported_types(self) -> list[str]: + """ + Lists the data types supported by the pipe. + """ + return [entry_type for entry_type in DocumentType] + + async def _parse( + self, + document: Document, + run_id: uuid.UUID, + version: str, + ) -> AsyncGenerator[Union[R2RDocumentProcessingError, Extraction], None]: + if document.type not in self.parsers: + yield R2RDocumentProcessingError( + document_id=document.id, + error_message=f"Parser for {document.type} not found in `ParsingPipe`.", + ) + return + parser = self.parsers[document.type] + texts = parser.ingest(document.data) + extraction_type = ExtractionType.TXT + t0 = time.time() + if document.type in self.IMAGE_TYPES: + extraction_type = ExtractionType.IMG + document.metadata["image_type"] = document.type.value + # SAVE IMAGE DATA + # try: + # import base64 + # sanitized_data = base64.b64encode(document.data).decode('utf-8') + # except Exception as e: + # sanitized_data = document.data + + # document.metadata["image_data"] = sanitized_data + elif document.type == DocumentType.MP4: + extraction_type = ExtractionType.MOV + document.metadata["audio_type"] = document.type.value + + iteration = 0 + async for text in texts: + extraction_id = generate_id_from_label( + f"{document.id}-{iteration}-{version}" + ) + document.metadata["version"] = version + extraction = Extraction( + id=extraction_id, + data=text, + metadata=document.metadata, + document_id=document.id, + type=extraction_type, + ) + yield extraction + # TODO - Add settings to enable extraction logging + # extraction_dict = extraction.dict() + # await self.enqueue_log( + # run_id=run_id, + # key="extraction", + # value=json.dumps( + # { + # "data": extraction_dict["data"], + # "document_id": str(extraction_dict["document_id"]), + # "extraction_id": str(extraction_dict["id"]), + # } + # ), + # ) + iteration += 1 + logger.debug( + f"Parsed document with id={document.id}, 
title={document.metadata.get('title', None)}, user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} into {iteration} extractions in t={time.time() - t0:.2f} seconds." + ) + + async def _run_logic( + self, + input: Input, + state: AsyncState, + run_id: uuid.UUID, + versions: Optional[list[str]] = None, + *args, + **kwargs, + ) -> AsyncGenerator[Extraction, None]: + parse_tasks = [] + + iteration = 0 + async for document in input.message: + version = versions[iteration] if versions else "v0" + iteration += 1 + parse_tasks.append( + self._handle_parse_task(document, version, run_id) + ) + + # Await all tasks and yield results concurrently + for parse_task in asyncio.as_completed(parse_tasks): + for extraction in await parse_task: + yield extraction + + async def _handle_parse_task( + self, document: Document, version: str, run_id: uuid.UUID + ) -> AsyncGenerator[Extraction, None]: + extractions = [] + async for extraction in self._parse(document, run_id, version): + extractions.append(extraction) + return extractions diff --git a/R2R/r2r/pipes/ingestion/vector_storage_pipe.py b/R2R/r2r/pipes/ingestion/vector_storage_pipe.py new file mode 100755 index 00000000..9564fd22 --- /dev/null +++ b/R2R/r2r/pipes/ingestion/vector_storage_pipe.py @@ -0,0 +1,128 @@ +import asyncio +import logging +import uuid +from typing import Any, AsyncGenerator, Optional, Tuple, Union + +from r2r.base import ( + AsyncState, + KVLoggingSingleton, + PipeType, + VectorDBProvider, + VectorEntry, +) +from r2r.base.pipes.base_pipe import AsyncPipe + +from ...base.abstractions.exception import R2RDocumentProcessingError + +logger = logging.getLogger(__name__) + + +class VectorStoragePipe(AsyncPipe): + class Input(AsyncPipe.Input): + message: AsyncGenerator[ + Union[R2RDocumentProcessingError, VectorEntry], None + ] + do_upsert: bool = True + + def __init__( + self, + vector_db_provider: VectorDBProvider, + storage_batch_size: int = 128, + pipe_logger: Optional[KVLoggingSingleton] = None, + type: PipeType = PipeType.INGESTOR, + config: Optional[AsyncPipe.PipeConfig] = None, + *args, + **kwargs, + ): + """ + Initializes the async vector storage pipe with necessary components and configurations. + """ + super().__init__( + pipe_logger=pipe_logger, + type=type, + config=config, + *args, + **kwargs, + ) + self.vector_db_provider = vector_db_provider + self.storage_batch_size = storage_batch_size + + async def store( + self, + vector_entries: list[VectorEntry], + do_upsert: bool = True, + ) -> None: + """ + Stores a batch of vector entries in the database. + """ + + try: + if do_upsert: + self.vector_db_provider.upsert_entries(vector_entries) + else: + self.vector_db_provider.copy_entries(vector_entries) + except Exception as e: + error_message = ( + f"Failed to store vector entries in the database: {e}" + ) + logger.error(error_message) + raise ValueError(error_message) + + async def _run_logic( + self, + input: Input, + state: AsyncState, + run_id: uuid.UUID, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[ + Tuple[uuid.UUID, Union[str, R2RDocumentProcessingError]], None + ]: + """ + Executes the async vector storage pipe: storing embeddings in the vector database. 
+ """ + batch_tasks = [] + vector_batch = [] + document_counts = {} + i = 0 + async for msg in input.message: + i += 1 + if isinstance(msg, R2RDocumentProcessingError): + yield (msg.document_id, msg) + continue + + document_id = msg.metadata.get("document_id", None) + if not document_id: + raise ValueError("Document ID not found in the metadata.") + if document_id not in document_counts: + document_counts[document_id] = 1 + else: + document_counts[document_id] += 1 + + vector_batch.append(msg) + if len(vector_batch) >= self.storage_batch_size: + # Schedule the storage task + batch_tasks.append( + asyncio.create_task( + self.store(vector_batch.copy(), input.do_upsert), + name=f"vector-store-{self.config.name}", + ) + ) + vector_batch.clear() + + if vector_batch: # Process any remaining vectors + batch_tasks.append( + asyncio.create_task( + self.store(vector_batch.copy(), input.do_upsert), + name=f"vector-store-{self.config.name}", + ) + ) + + # Wait for all storage tasks to complete + await asyncio.gather(*batch_tasks) + + for document_id, count in document_counts.items(): + yield ( + document_id, + f"Processed {count} vectors for document {document_id}.", + ) diff --git a/R2R/r2r/pipes/other/eval_pipe.py b/R2R/r2r/pipes/other/eval_pipe.py new file mode 100755 index 00000000..b1c60343 --- /dev/null +++ b/R2R/r2r/pipes/other/eval_pipe.py @@ -0,0 +1,54 @@ +import logging +import uuid +from typing import Any, AsyncGenerator, Optional + +from pydantic import BaseModel + +from r2r import AsyncState, EvalProvider, LLMChatCompletion, PipeType +from r2r.base.abstractions.llm import GenerationConfig +from r2r.base.pipes.base_pipe import AsyncPipe + +logger = logging.getLogger(__name__) + + +class EvalPipe(AsyncPipe): + class EvalPayload(BaseModel): + query: str + context: str + completion: str + + class Input(AsyncPipe.Input): + message: AsyncGenerator["EvalPipe.EvalPayload", None] + + def __init__( + self, + eval_provider: EvalProvider, + type: PipeType = PipeType.EVAL, + config: Optional[AsyncPipe.PipeConfig] = None, + *args, + **kwargs, + ): + self.eval_provider = eval_provider + super().__init__( + type=type, + config=config or AsyncPipe.PipeConfig(name="default_eval_pipe"), + *args, + **kwargs, + ) + + async def _run_logic( + self, + input: Input, + state: AsyncState, + run_id: uuid.UUID, + eval_generation_config: GenerationConfig, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[LLMChatCompletion, None]: + async for item in input.message: + yield self.eval_provider.evaluate( + item.query, + item.context, + item.completion, + eval_generation_config, + ) diff --git a/R2R/r2r/pipes/other/web_search_pipe.py b/R2R/r2r/pipes/other/web_search_pipe.py new file mode 100755 index 00000000..92e3feee --- /dev/null +++ b/R2R/r2r/pipes/other/web_search_pipe.py @@ -0,0 +1,105 @@ +import json +import logging +import uuid +from typing import Any, AsyncGenerator, Optional + +from r2r.base import ( + AsyncPipe, + AsyncState, + PipeType, + VectorSearchResult, + generate_id_from_label, +) +from r2r.integrations import SerperClient + +from ..abstractions.search_pipe import SearchPipe + +logger = logging.getLogger(__name__) + + +class WebSearchPipe(SearchPipe): + def __init__( + self, + serper_client: SerperClient, + type: PipeType = PipeType.SEARCH, + config: Optional[SearchPipe.SearchConfig] = None, + *args, + **kwargs, + ): + super().__init__( + type=type, + config=config or SearchPipe.SearchConfig(), + *args, + **kwargs, + ) + self.serper_client = serper_client + + async def search( + self, + message: str, + 
run_id: uuid.UUID, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[VectorSearchResult, None]: + search_limit_override = kwargs.get("search_limit", None) + await self.enqueue_log( + run_id=run_id, key="search_query", value=message + ) + # TODO - Make more general in the future by creating a SearchProvider interface + results = self.serper_client.get_raw( + query=message, + limit=search_limit_override or self.config.search_limit, + ) + + search_results = [] + for result in results: + if result.get("snippet") is None: + continue + result["text"] = result.pop("snippet") + search_result = VectorSearchResult( + id=generate_id_from_label(str(result)), + score=result.get( + "score", 0 + ), # TODO - Consider dynamically generating scores based on similarity + metadata=result, + ) + search_results.append(search_result) + yield search_result + + await self.enqueue_log( + run_id=run_id, + key="search_results", + value=json.dumps([ele.json() for ele in search_results]), + ) + + async def _run_logic( + self, + input: AsyncPipe.Input, + state: AsyncState, + run_id: uuid.UUID, + *args: Any, + **kwargs, + ) -> AsyncGenerator[VectorSearchResult, None]: + search_queries = [] + search_results = [] + async for search_request in input.message: + search_queries.append(search_request) + async for result in self.search( + message=search_request, run_id=run_id, *args, **kwargs + ): + search_results.append(result) + yield result + + await state.update( + self.config.name, {"output": {"search_results": search_results}} + ) + + await state.update( + self.config.name, + { + "output": { + "search_queries": search_queries, + "search_results": search_results, + } + }, + ) diff --git a/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py new file mode 100755 index 00000000..60935265 --- /dev/null +++ b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py @@ -0,0 +1,103 @@ +import logging +import uuid +from typing import Any, Optional + +from r2r.base import ( + AsyncState, + KGProvider, + KGSearchSettings, + KVLoggingSingleton, + LLMProvider, + PipeType, + PromptProvider, +) + +from ..abstractions.generator_pipe import GeneratorPipe + +logger = logging.getLogger(__name__) + + +class KGAgentSearchPipe(GeneratorPipe): + """ + Embeds and stores documents using a specified embedding model and database. + """ + + def __init__( + self, + kg_provider: KGProvider, + llm_provider: LLMProvider, + prompt_provider: PromptProvider, + pipe_logger: Optional[KVLoggingSingleton] = None, + type: PipeType = PipeType.INGESTOR, + config: Optional[GeneratorPipe.PipeConfig] = None, + *args, + **kwargs, + ): + """ + Initializes the embedding pipe with necessary components and configurations. 
+ """ + super().__init__( + llm_provider=llm_provider, + prompt_provider=prompt_provider, + type=type, + config=config + or GeneratorPipe.Config( + name="kg_rag_pipe", task_prompt="kg_agent" + ), + pipe_logger=pipe_logger, + *args, + **kwargs, + ) + self.kg_provider = kg_provider + self.llm_provider = llm_provider + self.prompt_provider = prompt_provider + self.pipe_run_info = None + + async def _run_logic( + self, + input: GeneratorPipe.Input, + state: AsyncState, + run_id: uuid.UUID, + kg_search_settings: KGSearchSettings, + *args: Any, + **kwargs: Any, + ): + async for message in input.message: + # TODO - Remove hard code + formatted_prompt = self.prompt_provider.get_prompt( + "kg_agent", {"input": message} + ) + messages = self._get_message_payload(formatted_prompt) + + result = await self.llm_provider.aget_completion( + messages=messages, + generation_config=kg_search_settings.agent_generation_config, + ) + + extraction = result.choices[0].message.content + query = extraction.split("```cypher")[1].split("```")[0] + result = self.kg_provider.structured_query(query) + yield (query, result) + + await self.enqueue_log( + run_id=run_id, + key="kg_agent_response", + value=extraction, + ) + + await self.enqueue_log( + run_id=run_id, + key="kg_agent_execution_result", + value=result, + ) + + def _get_message_payload(self, message: str) -> dict: + return [ + { + "role": "system", + "content": self.prompt_provider.get_prompt( + self.config.system_prompt, + ), + }, + {"role": "user", "content": message}, + ] diff --git a/R2R/r2r/pipes/retrieval/multi_search.py b/R2R/r2r/pipes/retrieval/multi_search.py new file mode 100755 index 00000000..6da2c34b --- /dev/null +++ b/R2R/r2r/pipes/retrieval/multi_search.py @@ -0,0 +1,79 @@ +import uuid +from copy import copy +from typing import Any, AsyncGenerator, Optional + +from r2r.base.abstractions.llm import GenerationConfig +from r2r.base.abstractions.search import VectorSearchResult +from r2r.base.pipes.base_pipe import AsyncPipe + +from ..abstractions.search_pipe import SearchPipe +from .query_transform_pipe import QueryTransformPipe + + +class MultiSearchPipe(AsyncPipe): + class PipeConfig(AsyncPipe.PipeConfig): + name: str = "multi_search_pipe" + + def __init__( + self, + query_transform_pipe: QueryTransformPipe, + inner_search_pipe: SearchPipe, + config: Optional[PipeConfig] = None, + *args, + **kwargs, + ): + self.query_transform_pipe = query_transform_pipe + self.vector_search_pipe = inner_search_pipe + if ( + not query_transform_pipe.config.name + == inner_search_pipe.config.name + ): + raise ValueError( + "The query transform pipe and search pipe must have the same name." + ) + if config and not config.name == query_transform_pipe.config.name: + raise ValueError( + "The pipe config name must match the query transform pipe name." 
+ ) + + super().__init__( + config=config + or MultiSearchPipe.PipeConfig( + name=query_transform_pipe.config.name + ), + *args, + **kwargs, + ) + + async def _run_logic( + self, + input: Any, + state: Any, + run_id: uuid.UUID, + query_transform_generation_config: Optional[GenerationConfig] = None, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[VectorSearchResult, None]: + query_transform_generation_config = ( + query_transform_generation_config + or copy(kwargs.get("rag_generation_config", None)) + or GenerationConfig(model="gpt-4o") + ) + query_transform_generation_config.stream = False + + query_generator = await self.query_transform_pipe.run( + input, + state, + query_transform_generation_config=query_transform_generation_config, + num_query_xf_outputs=3, + *args, + **kwargs, + ) + + async for search_result in await self.vector_search_pipe.run( + self.vector_search_pipe.Input(message=query_generator), + state, + *args, + **kwargs, + ): + yield search_result diff --git a/R2R/r2r/pipes/retrieval/query_transform_pipe.py b/R2R/r2r/pipes/retrieval/query_transform_pipe.py new file mode 100755 index 00000000..99df6b5b --- /dev/null +++ b/R2R/r2r/pipes/retrieval/query_transform_pipe.py @@ -0,0 +1,101 @@ +import logging +import uuid +from typing import Any, AsyncGenerator, Optional + +from r2r.base import ( + AsyncPipe, + AsyncState, + LLMProvider, + PipeType, + PromptProvider, +) +from r2r.base.abstractions.llm import GenerationConfig + +from ..abstractions.generator_pipe import GeneratorPipe + +logger = logging.getLogger(__name__) + + +class QueryTransformPipe(GeneratorPipe): + class QueryTransformConfig(GeneratorPipe.PipeConfig): + name: str = "default_query_transform" + system_prompt: str = "default_system" + task_prompt: str = "hyde" + + class Input(GeneratorPipe.Input): + message: AsyncGenerator[str, None] + + def __init__( + self, + llm_provider: LLMProvider, + prompt_provider: PromptProvider, + type: PipeType = PipeType.TRANSFORM, + config: Optional[QueryTransformConfig] = None, + *args, + **kwargs, + ): + logger.info(f"Initalizing an `QueryTransformPipe` pipe.") + super().__init__( + llm_provider=llm_provider, + prompt_provider=prompt_provider, + type=type, + config=config or QueryTransformPipe.QueryTransformConfig(), + *args, + **kwargs, + ) + + async def _run_logic( + self, + input: AsyncPipe.Input, + state: AsyncState, + run_id: uuid.UUID, + query_transform_generation_config: GenerationConfig, + num_query_xf_outputs: int = 3, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[str, None]: + async for query in input.message: + logger.info( + f"Transforming query: {query} into {num_query_xf_outputs} outputs with {self.config.task_prompt}." 
+ ) + + query_transform_request = self._get_message_payload( + query, num_outputs=num_query_xf_outputs + ) + + response = await self.llm_provider.aget_completion( + messages=query_transform_request, + generation_config=query_transform_generation_config, + ) + content = self.llm_provider.extract_content(response) + outputs = content.split("\n") + outputs = [ + output.strip() for output in outputs if output.strip() != "" + ] + await state.update( + self.config.name, {"output": {"outputs": outputs}} + ) + + for output in outputs: + logger.info(f"Yielding transformed output: {output}") + yield output + + def _get_message_payload(self, input: str, num_outputs: int) -> dict: + return [ + { + "role": "system", + "content": self.prompt_provider.get_prompt( + self.config.system_prompt, + ), + }, + { + "role": "user", + "content": self.prompt_provider.get_prompt( + self.config.task_prompt, + inputs={ + "message": input, + "num_outputs": num_outputs, + }, + ), + }, + ] diff --git a/R2R/r2r/pipes/retrieval/search_rag_pipe.py b/R2R/r2r/pipes/retrieval/search_rag_pipe.py new file mode 100755 index 00000000..4d01d2df --- /dev/null +++ b/R2R/r2r/pipes/retrieval/search_rag_pipe.py @@ -0,0 +1,130 @@ +import logging +import uuid +from typing import Any, AsyncGenerator, Optional, Tuple + +from r2r.base import ( + AggregateSearchResult, + AsyncPipe, + AsyncState, + LLMProvider, + PipeType, + PromptProvider, +) +from r2r.base.abstractions.llm import GenerationConfig, RAGCompletion + +from ..abstractions.generator_pipe import GeneratorPipe + +logger = logging.getLogger(__name__) + + +class SearchRAGPipe(GeneratorPipe): + class Input(AsyncPipe.Input): + message: AsyncGenerator[Tuple[str, AggregateSearchResult], None] + + def __init__( + self, + llm_provider: LLMProvider, + prompt_provider: PromptProvider, + type: PipeType = PipeType.GENERATOR, + config: Optional[GeneratorPipe] = None, + *args, + **kwargs, + ): + super().__init__( + llm_provider=llm_provider, + prompt_provider=prompt_provider, + type=type, + config=config + or GeneratorPipe.Config( + name="default_rag_pipe", task_prompt="default_rag" + ), + *args, + **kwargs, + ) + + async def _run_logic( + self, + input: Input, + state: AsyncState, + run_id: uuid.UUID, + rag_generation_config: GenerationConfig, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[RAGCompletion, None]: + context = "" + search_iteration = 1 + total_results = 0 + # must select a query if there are multiple + sel_query = None + async for query, search_results in input.message: + if search_iteration == 1: + sel_query = query + context_piece, total_results = await self._collect_context( + query, search_results, search_iteration, total_results + ) + context += context_piece + search_iteration += 1 + + messages = self._get_message_payload(sel_query, context) + + response = await self.llm_provider.aget_completion( + messages=messages, generation_config=rag_generation_config + ) + yield RAGCompletion(completion=response, search_results=search_results) + + await self.enqueue_log( + run_id=run_id, + key="llm_response", + value=response.choices[0].message.content, + ) + + def _get_message_payload(self, query: str, context: str) -> dict: + return [ + { + "role": "system", + "content": self.prompt_provider.get_prompt( + self.config.system_prompt, + ), + }, + { + "role": "user", + "content": self.prompt_provider.get_prompt( + self.config.task_prompt, + inputs={ + "query": query, + "context": context, + }, + ), + }, + ] + + async def _collect_context( + self, + query: str, + results: 
AggregateSearchResult, + iteration: int, + total_results: int, + ) -> Tuple[str, int]: + context = f"Query:\n{query}\n\n" + if results.vector_search_results: + context += f"Vector Search Results({iteration}):\n" + it = total_results + 1 + for result in results.vector_search_results: + context += f"[{it}]: {result.metadata['text']}\n\n" + it += 1 + total_results = ( + it - 1 + ) # Update total_results based on the last index used + if results.kg_search_results: + context += f"Knowledge Graph ({iteration}):\n" + it = total_results + 1 + for query, search_results in results.kg_search_results: # [1]: + context += f"Query: {query}\n\n" + context += f"Results:\n" + for search_result in search_results: + context += f"[{it}]: {search_result}\n\n" + it += 1 + total_results = ( + it - 1 + ) # Update total_results based on the last index used + return context, total_results diff --git a/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py new file mode 100755 index 00000000..b01f6445 --- /dev/null +++ b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py @@ -0,0 +1,131 @@ +import json +import logging +import uuid +from typing import Any, AsyncGenerator, Generator, Optional + +from r2r.base import ( + AsyncState, + LLMChatCompletionChunk, + LLMProvider, + PipeType, + PromptProvider, +) +from r2r.base.abstractions.llm import GenerationConfig + +from ..abstractions.generator_pipe import GeneratorPipe +from .search_rag_pipe import SearchRAGPipe + +logger = logging.getLogger(__name__) + + +class StreamingSearchRAGPipe(SearchRAGPipe): + SEARCH_STREAM_MARKER = "search" + COMPLETION_STREAM_MARKER = "completion" + + def __init__( + self, + llm_provider: LLMProvider, + prompt_provider: PromptProvider, + type: PipeType = PipeType.GENERATOR, + config: Optional[GeneratorPipe] = None, + *args, + **kwargs, + ): + super().__init__( + llm_provider=llm_provider, + prompt_provider=prompt_provider, + type=type, + config=config + or GeneratorPipe.Config( + name="default_streaming_rag_pipe", task_prompt="default_rag" + ), + *args, + **kwargs, + ) + + async def _run_logic( + self, + input: SearchRAGPipe.Input, + state: AsyncState, + run_id: uuid.UUID, + rag_generation_config: GenerationConfig, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[str, None]: + iteration = 0 + context = "" + # dump the search results and construct the context + async for query, search_results in input.message: + yield f"<{self.SEARCH_STREAM_MARKER}>" + if search_results.vector_search_results: + context += "Vector Search Results:\n" + for result in search_results.vector_search_results: + if iteration >= 1: + yield "," + yield json.dumps(result.json()) + context += ( + f"{iteration + 1}:\n{result.metadata['text']}\n\n" + ) + iteration += 1 + + # if search_results.kg_search_results: + # for result in search_results.kg_search_results: + # if iteration >= 1: + # yield "," + # yield json.dumps(result.json()) + # context += f"Result {iteration+1}:\n{result.metadata['text']}\n\n" + # iteration += 1 + + yield f"</{self.SEARCH_STREAM_MARKER}>" + + messages = self._get_message_payload(query, context) + yield f"<{self.COMPLETION_STREAM_MARKER}>" + response = "" + for chunk in self.llm_provider.get_completion_stream( + messages=messages, generation_config=rag_generation_config + ): + chunk = StreamingSearchRAGPipe._process_chunk(chunk) + response += chunk + yield chunk + + yield f"</{self.COMPLETION_STREAM_MARKER}>" + + await self.enqueue_log( + run_id=run_id, + key="llm_response", + value=response, + ) + + async def 
_yield_chunks( + self, + start_marker: str, + chunks: Generator[str, None, None], + end_marker: str, + ) -> str: + yield start_marker + for chunk in chunks: + yield chunk + yield end_marker + + def _get_message_payload( + self, query: str, context: str + ) -> list[dict[str, str]]: + return [ + { + "role": "system", + "content": self.prompt_provider.get_prompt( + self.config.system_prompt + ), + }, + { + "role": "user", + "content": self.prompt_provider.get_prompt( + self.config.task_prompt, + inputs={"query": query, "context": context}, + ), + }, + ] + + @staticmethod + def _process_chunk(chunk: LLMChatCompletionChunk) -> str: + return chunk.choices[0].delta.content or "" diff --git a/R2R/r2r/pipes/retrieval/vector_search_pipe.py b/R2R/r2r/pipes/retrieval/vector_search_pipe.py new file mode 100755 index 00000000..742de16b --- /dev/null +++ b/R2R/r2r/pipes/retrieval/vector_search_pipe.py @@ -0,0 +1,123 @@ +import json +import logging +import uuid +from typing import Any, AsyncGenerator, Optional + +from r2r.base import ( + AsyncPipe, + AsyncState, + EmbeddingProvider, + PipeType, + VectorDBProvider, + VectorSearchResult, + VectorSearchSettings, +) + +from ..abstractions.search_pipe import SearchPipe + +logger = logging.getLogger(__name__) + + +class VectorSearchPipe(SearchPipe): + def __init__( + self, + vector_db_provider: VectorDBProvider, + embedding_provider: EmbeddingProvider, + type: PipeType = PipeType.SEARCH, + config: Optional[SearchPipe.SearchConfig] = None, + *args, + **kwargs, + ): + super().__init__( + type=type, + config=config or SearchPipe.SearchConfig(), + *args, + **kwargs, + ) + self.embedding_provider = embedding_provider + self.vector_db_provider = vector_db_provider + + async def search( + self, + message: str, + run_id: uuid.UUID, + vector_search_settings: VectorSearchSettings, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[VectorSearchResult, None]: + await self.enqueue_log( + run_id=run_id, key="search_query", value=message + ) + search_filters = ( + vector_search_settings.search_filters or self.config.search_filters + ) + search_limit = ( + vector_search_settings.search_limit or self.config.search_limit + ) + results = [] + query_vector = self.embedding_provider.get_embedding( + message, + ) + search_results = ( + self.vector_db_provider.hybrid_search( + query_vector=query_vector, + query_text=message, + filters=search_filters, + limit=search_limit, + ) + if vector_search_settings.do_hybrid_search + else self.vector_db_provider.search( + query_vector=query_vector, + filters=search_filters, + limit=search_limit, + ) + ) + reranked_results = self.embedding_provider.rerank( + query=message, results=search_results, limit=search_limit + ) + for result in reranked_results: + result.metadata["associatedQuery"] = message + results.append(result) + yield result + await self.enqueue_log( + run_id=run_id, + key="search_results", + value=json.dumps([ele.json() for ele in results]), + ) + + async def _run_logic( + self, + input: AsyncPipe.Input, + state: AsyncState, + run_id: uuid.UUID, + vector_search_settings: VectorSearchSettings = VectorSearchSettings(), + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[VectorSearchResult, None]: + search_queries = [] + search_results = [] + async for search_request in input.message: + search_queries.append(search_request) + async for result in self.search( + message=search_request, + run_id=run_id, + vector_search_settings=vector_search_settings, + *args, + **kwargs, + ): + search_results.append(result) + yield result + + await 
state.update( + self.config.name, {"output": {"search_results": search_results}} + ) + + await state.update( + self.config.name, + { + "output": { + "search_queries": search_queries, + "search_results": search_results, + } + }, + ) |
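The `StreamingSearchRAGPipe` above frames its output with `<search>`/`</search>` and `<completion>`/`</completion>` markers, emitting the search results as comma-separated JSON between the first pair. A client might reassemble that stream roughly as sketched below; `split_rag_stream` and its `chunks` argument are hypothetical names, and the double `json.loads` reflects the pipe serializing each `result.json()` string with `json.dumps`.

```python
# Hedged sketch of consuming the StreamingSearchRAGPipe token stream.
import json
from typing import Iterable, Tuple


def split_rag_stream(chunks: Iterable[str]) -> Tuple[list, str]:
    """Split a streamed RAG response into search results and completion text.

    Assumes the marker framing used by StreamingSearchRAGPipe:
    <search>{json},{json},...</search><completion>...</completion>
    """
    buffer = "".join(chunks)
    search_payload = buffer.split("<search>")[1].split("</search>")[0]
    completion_text = buffer.split("<completion>")[1].split("</completion>")[0]
    # Each element between the search markers is itself a JSON-encoded string
    # (the pipe calls json.dumps(result.json())), so decode twice.
    search_results = [
        json.loads(item) for item in json.loads(f"[{search_payload}]")
    ]
    return search_results, completion_text


# Purely illustrative usage with a hand-rolled stream:
# results, answer = split_rag_stream([
#     "<search>", json.dumps('{"id": "1", "score": 0.9}'), "</search>",
#     "<completion>", "The answer is 42.", "</completion>",
# ])
```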