| author    | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit    | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree      | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/retrieval | |
| parent    | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download  | gn-ai-master.tar.gz | |
Diffstat (limited to 'R2R/r2r/pipes/retrieval')
| -rwxr-xr-x | R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py | 103 |
| -rwxr-xr-x | R2R/r2r/pipes/retrieval/multi_search.py | 79 |
| -rwxr-xr-x | R2R/r2r/pipes/retrieval/query_transform_pipe.py | 101 |
| -rwxr-xr-x | R2R/r2r/pipes/retrieval/search_rag_pipe.py | 130 |
| -rwxr-xr-x | R2R/r2r/pipes/retrieval/streaming_rag_pipe.py | 131 |
| -rwxr-xr-x | R2R/r2r/pipes/retrieval/vector_search_pipe.py | 123 |
6 files changed, 667 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
new file mode 100755
index 00000000..60935265
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
@@ -0,0 +1,103 @@
+import logging
+import uuid
+from typing import Any, Optional
+
+from r2r.base import (
+    AsyncState,
+    KGProvider,
+    KGSearchSettings,
+    KVLoggingSingleton,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class KGAgentSearchPipe(GeneratorPipe):
+    """
+    Embeds and stores documents using a specified embedding model and database.
+    """
+
+    def __init__(
+        self,
+        kg_provider: KGProvider,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        type: PipeType = PipeType.INGESTOR,
+        config: Optional[GeneratorPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initializes the embedding pipe with necessary components and configurations.
+        """
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config
+            or GeneratorPipe.Config(
+                name="kg_rag_pipe", task_prompt="kg_agent"
+            ),
+            pipe_logger=pipe_logger,
+            *args,
+            **kwargs,
+        )
+        self.kg_provider = kg_provider
+        self.llm_provider = llm_provider
+        self.prompt_provider = prompt_provider
+        self.pipe_run_info = None
+
+    async def _run_logic(
+        self,
+        input: GeneratorPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        kg_search_settings: KGSearchSettings,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        async for message in input.message:
+            # TODO - Remove hard code
+            formatted_prompt = self.prompt_provider.get_prompt(
+                "kg_agent", {"input": message}
+            )
+            messages = self._get_message_payload(formatted_prompt)
+
+            result = await self.llm_provider.aget_completion(
+                messages=messages,
+                generation_config=kg_search_settings.agent_generation_config,
+            )
+
+            extraction = result.choices[0].message.content
+            query = extraction.split("```cypher")[1].split("```")[0]
+            result = self.kg_provider.structured_query(query)
+            yield (query, result)
+
+            await self.enqueue_log(
+                run_id=run_id,
+                key="kg_agent_response",
+                value=extraction,
+            )
+
+            await self.enqueue_log(
+                run_id=run_id,
+                key="kg_agent_execution_result",
+                value=result,
+            )
+
+    def _get_message_payload(self, message: str) -> dict:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt,
+                ),
+            },
+            {"role": "user", "content": message},
+        ]
diff --git a/R2R/r2r/pipes/retrieval/multi_search.py b/R2R/r2r/pipes/retrieval/multi_search.py
new file mode 100755
index 00000000..6da2c34b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/multi_search.py
@@ -0,0 +1,79 @@
+import uuid
+from copy import copy
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base.abstractions.llm import GenerationConfig
+from r2r.base.abstractions.search import VectorSearchResult
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+from ..abstractions.search_pipe import SearchPipe
+from .query_transform_pipe import QueryTransformPipe
+
+
+class MultiSearchPipe(AsyncPipe):
+    class PipeConfig(AsyncPipe.PipeConfig):
+        name: str = "multi_search_pipe"
+
+    def __init__(
+        self,
+        query_transform_pipe: QueryTransformPipe,
+        inner_search_pipe: SearchPipe,
+        config: Optional[PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        self.query_transform_pipe = query_transform_pipe
+        self.vector_search_pipe = inner_search_pipe
+        if (
+            not query_transform_pipe.config.name
+            == inner_search_pipe.config.name
+        ):
+            raise ValueError(
+                "The query transform pipe and search pipe must have the same name."
+            )
+        if config and not config.name == query_transform_pipe.config.name:
+            raise ValueError(
+                "The pipe config name must match the query transform pipe name."
+            )
+
+        super().__init__(
+            config=config
+            or MultiSearchPipe.PipeConfig(
+                name=query_transform_pipe.config.name
+            ),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: Any,
+        state: Any,
+        run_id: uuid.UUID,
+        query_transform_generation_config: Optional[GenerationConfig] = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[VectorSearchResult, None]:
+        query_transform_generation_config = (
+            query_transform_generation_config
+            or copy(kwargs.get("rag_generation_config", None))
+            or GenerationConfig(model="gpt-4o")
+        )
+        query_transform_generation_config.stream = False
+
+        query_generator = await self.query_transform_pipe.run(
+            input,
+            state,
+            query_transform_generation_config=query_transform_generation_config,
+            num_query_xf_outputs=3,
+            *args,
+            **kwargs,
+        )
+
+        async for search_result in await self.vector_search_pipe.run(
+            self.vector_search_pipe.Input(message=query_generator),
+            state,
+            *args,
+            **kwargs,
+        ):
+            yield search_result
diff --git a/R2R/r2r/pipes/retrieval/query_transform_pipe.py b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
new file mode 100755
index 00000000..99df6b5b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
@@ -0,0 +1,101 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncPipe,
+    AsyncState,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class QueryTransformPipe(GeneratorPipe):
+    class QueryTransformConfig(GeneratorPipe.PipeConfig):
+        name: str = "default_query_transform"
+        system_prompt: str = "default_system"
+        task_prompt: str = "hyde"
+
+    class Input(GeneratorPipe.Input):
+        message: AsyncGenerator[str, None]
+
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        type: PipeType = PipeType.TRANSFORM,
+        config: Optional[QueryTransformConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        logger.info(f"Initalizing an `QueryTransformPipe` pipe.")
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config or QueryTransformPipe.QueryTransformConfig(),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        query_transform_generation_config: GenerationConfig,
+        num_query_xf_outputs: int = 3,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[str, None]:
+        async for query in input.message:
+            logger.info(
+                f"Transforming query: {query} into {num_query_xf_outputs} outputs with {self.config.task_prompt}."
+            )
+
+            query_transform_request = self._get_message_payload(
+                query, num_outputs=num_query_xf_outputs
+            )
+
+            response = await self.llm_provider.aget_completion(
+                messages=query_transform_request,
+                generation_config=query_transform_generation_config,
+            )
+            content = self.llm_provider.extract_content(response)
+            outputs = content.split("\n")
+            outputs = [
+                output.strip() for output in outputs if output.strip() != ""
+            ]
+            await state.update(
+                self.config.name, {"output": {"outputs": outputs}}
+            )
+
+            for output in outputs:
+                logger.info(f"Yielding transformed output: {output}")
+                yield output
+
+    def _get_message_payload(self, input: str, num_outputs: int) -> dict:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt,
+                ),
+            },
+            {
+                "role": "user",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.task_prompt,
+                    inputs={
+                        "message": input,
+                        "num_outputs": num_outputs,
+                    },
+                ),
+            },
+        ]
diff --git a/R2R/r2r/pipes/retrieval/search_rag_pipe.py b/R2R/r2r/pipes/retrieval/search_rag_pipe.py
new file mode 100755
index 00000000..4d01d2df
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/search_rag_pipe.py
@@ -0,0 +1,130 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional, Tuple
+
+from r2r.base import (
+    AggregateSearchResult,
+    AsyncPipe,
+    AsyncState,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig, RAGCompletion
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class SearchRAGPipe(GeneratorPipe):
+    class Input(AsyncPipe.Input):
+        message: AsyncGenerator[Tuple[str, AggregateSearchResult], None]
+
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        type: PipeType = PipeType.GENERATOR,
+        config: Optional[GeneratorPipe] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config
+            or GeneratorPipe.Config(
+                name="default_rag_pipe", task_prompt="default_rag"
+            ),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        rag_generation_config: GenerationConfig,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[RAGCompletion, None]:
+        context = ""
+        search_iteration = 1
+        total_results = 0
+        # must select a query if there are multiple
+        sel_query = None
+        async for query, search_results in input.message:
+            if search_iteration == 1:
+                sel_query = query
+            context_piece, total_results = await self._collect_context(
+                query, search_results, search_iteration, total_results
+            )
+            context += context_piece
+            search_iteration += 1
+
+        messages = self._get_message_payload(sel_query, context)
+
+        response = await self.llm_provider.aget_completion(
+            messages=messages, generation_config=rag_generation_config
+        )
+        yield RAGCompletion(completion=response, search_results=search_results)
+
+        await self.enqueue_log(
+            run_id=run_id,
+            key="llm_response",
+            value=response.choices[0].message.content,
+        )
+
+    def _get_message_payload(self, query: str, context: str) -> dict:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt,
+                ),
+            },
+            {
+                "role": "user",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.task_prompt,
+                    inputs={
+                        "query": query,
+                        "context": context,
+                    },
+                ),
+            },
+        ]
+
+    async def _collect_context(
+        self,
+        query: str,
+        results: AggregateSearchResult,
+        iteration: int,
+        total_results: int,
+    ) -> Tuple[str, int]:
+        context = f"Query:\n{query}\n\n"
+        if results.vector_search_results:
+            context += f"Vector Search Results({iteration}):\n"
+            it = total_results + 1
+            for result in results.vector_search_results:
+                context += f"[{it}]: {result.metadata['text']}\n\n"
+                it += 1
+            total_results = (
+                it - 1
+            )  # Update total_results based on the last index used
+        if results.kg_search_results:
+            context += f"Knowledge Graph ({iteration}):\n"
+            it = total_results + 1
+            for query, search_results in results.kg_search_results:  # [1]:
+                context += f"Query: {query}\n\n"
+                context += f"Results:\n"
+                for search_result in search_results:
+                    context += f"[{it}]: {search_result}\n\n"
+                    it += 1
+            total_results = (
+                it - 1
+            )  # Update total_results based on the last index used
+        return context, total_results
diff --git a/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
new file mode 100755
index 00000000..b01f6445
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
@@ -0,0 +1,131 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from r2r.base import (
+    AsyncState,
+    LLMChatCompletionChunk,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+from .search_rag_pipe import SearchRAGPipe
+
+logger = logging.getLogger(__name__)
+
+
+class StreamingSearchRAGPipe(SearchRAGPipe):
+    SEARCH_STREAM_MARKER = "search"
+    COMPLETION_STREAM_MARKER = "completion"
+
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        type: PipeType = PipeType.GENERATOR,
+        config: Optional[GeneratorPipe] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config
+            or GeneratorPipe.Config(
+                name="default_streaming_rag_pipe", task_prompt="default_rag"
+            ),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: SearchRAGPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        rag_generation_config: GenerationConfig,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[str, None]:
+        iteration = 0
+        context = ""
+        # dump the search results and construct the context
+        async for query, search_results in input.message:
+            yield f"<{self.SEARCH_STREAM_MARKER}>"
+            if search_results.vector_search_results:
+                context += "Vector Search Results:\n"
+                for result in search_results.vector_search_results:
+                    if iteration >= 1:
+                        yield ","
+                    yield json.dumps(result.json())
+                    context += (
+                        f"{iteration + 1}:\n{result.metadata['text']}\n\n"
+                    )
+                    iteration += 1
+
+            # if search_results.kg_search_results:
+            #     for result in search_results.kg_search_results:
+            #         if iteration >= 1:
+            #             yield ","
+            #         yield json.dumps(result.json())
+            #         context += f"Result {iteration+1}:\n{result.metadata['text']}\n\n"
+            #         iteration += 1
+
+            yield f"</{self.SEARCH_STREAM_MARKER}>"
+
+            messages = self._get_message_payload(query, context)
+            yield f"<{self.COMPLETION_STREAM_MARKER}>"
+            response = ""
+            for chunk in self.llm_provider.get_completion_stream(
+                messages=messages, generation_config=rag_generation_config
+            ):
+                chunk = StreamingSearchRAGPipe._process_chunk(chunk)
+                response += chunk
+                yield chunk
+
+            yield f"</{self.COMPLETION_STREAM_MARKER}>"
+
+            await self.enqueue_log(
+                run_id=run_id,
+                key="llm_response",
+                value=response,
+            )
+
+    async def _yield_chunks(
+        self,
+        start_marker: str,
+        chunks: Generator[str, None, None],
+        end_marker: str,
+    ) -> str:
+        yield start_marker
+        for chunk in chunks:
+            yield chunk
+        yield end_marker
+
+    def _get_message_payload(
+        self, query: str, context: str
+    ) -> list[dict[str, str]]:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt
+                ),
+            },
+            {
+                "role": "user",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.task_prompt,
+                    inputs={"query": query, "context": context},
+                ),
+            },
+        ]
+
+    @staticmethod
+    def _process_chunk(chunk: LLMChatCompletionChunk) -> str:
+        return chunk.choices[0].delta.content or ""
diff --git a/R2R/r2r/pipes/retrieval/vector_search_pipe.py b/R2R/r2r/pipes/retrieval/vector_search_pipe.py
new file mode 100755
index 00000000..742de16b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/vector_search_pipe.py
@@ -0,0 +1,123 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncPipe,
+    AsyncState,
+    EmbeddingProvider,
+    PipeType,
+    VectorDBProvider,
+    VectorSearchResult,
+    VectorSearchSettings,
+)
+
+from ..abstractions.search_pipe import SearchPipe
+
+logger = logging.getLogger(__name__)
+
+
+class VectorSearchPipe(SearchPipe):
+    def __init__(
+        self,
+        vector_db_provider: VectorDBProvider,
+        embedding_provider: EmbeddingProvider,
+        type: PipeType = PipeType.SEARCH,
+        config: Optional[SearchPipe.SearchConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            type=type,
+            config=config or SearchPipe.SearchConfig(),
+            *args,
+            **kwargs,
+        )
+        self.embedding_provider = embedding_provider
+        self.vector_db_provider = vector_db_provider
+
+    async def search(
+        self,
+        message: str,
+        run_id: uuid.UUID,
+        vector_search_settings: VectorSearchSettings,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[VectorSearchResult, None]:
+        await self.enqueue_log(
+            run_id=run_id, key="search_query", value=message
+        )
+        search_filters = (
+            vector_search_settings.search_filters or self.config.search_filters
+        )
+        search_limit = (
+            vector_search_settings.search_limit or self.config.search_limit
+        )
+        results = []
+        query_vector = self.embedding_provider.get_embedding(
+            message,
+        )
+        search_results = (
+            self.vector_db_provider.hybrid_search(
+                query_vector=query_vector,
+                query_text=message,
+                filters=search_filters,
+                limit=search_limit,
+            )
+            if vector_search_settings.do_hybrid_search
+            else self.vector_db_provider.search(
+                query_vector=query_vector,
+                filters=search_filters,
+                limit=search_limit,
+            )
+        )
+        reranked_results = self.embedding_provider.rerank(
+            query=message, results=search_results, limit=search_limit
+        )
+        for result in reranked_results:
+            result.metadata["associatedQuery"] = message
+            results.append(result)
+            yield result
+        await self.enqueue_log(
+            run_id=run_id,
+            key="search_results",
+            value=json.dumps([ele.json() for ele in results]),
+        )
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[VectorSearchResult, None]:
+        search_queries = []
+        search_results = []
+        async for search_request in input.message:
+            search_queries.append(search_request)
+            async for result in self.search(
+                message=search_request,
+                run_id=run_id,
+                vector_search_settings=vector_search_settings,
+                *args,
+                **kwargs,
+            ):
+                search_results.append(result)
+                yield result
+
+        await state.update(
+            self.config.name, {"output": {"search_results": search_results}}
+        )
+
+        await state.update(
+            self.config.name,
+            {
+                "output": {
+                    "search_queries": search_queries,
+                    "search_results": search_results,
+                }
+            },
+        )
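For orientation, a minimal wiring sketch follows. It is not part of the commit: the provider objects (`llm_provider`, `prompt_provider`, `vector_db_provider`, `embedding_provider`) are assumed to be constructed elsewhere, the import paths are inferred from the file layout, and the shared config name `"hyde"` plus the assumption that `SearchPipe.SearchConfig` accepts a `name` field (it is read as `inner_search_pipe.config.name` in `MultiSearchPipe.__init__`) are illustrative only.

```python
# Hypothetical wiring sketch -- illustrative assumptions, not part of this diff.
# MultiSearchPipe requires the transform pipe and the inner search pipe to share
# a config name, otherwise its __init__ raises ValueError.
from r2r.pipes.abstractions.search_pipe import SearchPipe  # assumed module path
from r2r.pipes.retrieval.multi_search import MultiSearchPipe  # assumed module path
from r2r.pipes.retrieval.query_transform_pipe import QueryTransformPipe
from r2r.pipes.retrieval.vector_search_pipe import VectorSearchPipe


def build_multi_search_pipe(
    llm_provider,        # assumed: an LLMProvider instance
    prompt_provider,     # assumed: a PromptProvider instance
    vector_db_provider,  # assumed: a VectorDBProvider instance
    embedding_provider,  # assumed: an EmbeddingProvider instance
) -> MultiSearchPipe:
    # HyDE-style query expansion: one user query -> several transformed queries.
    query_transform_pipe = QueryTransformPipe(
        llm_provider=llm_provider,
        prompt_provider=prompt_provider,
        config=QueryTransformPipe.QueryTransformConfig(name="hyde"),
    )

    # Inner vector search pipe; its config name must match the transform pipe's.
    vector_search_pipe = VectorSearchPipe(
        vector_db_provider=vector_db_provider,
        embedding_provider=embedding_provider,
        config=SearchPipe.SearchConfig(name="hyde"),  # assumed: SearchConfig exposes `name`
    )

    # Fans the transformed queries out to the inner search pipe and yields the
    # combined VectorSearchResult stream.
    return MultiSearchPipe(
        query_transform_pipe=query_transform_pipe,
        inner_search_pipe=vector_search_pipe,
    )
```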