Diffstat (limited to 'R2R/r2r/pipes/retrieval')
-rwxr-xr-x  R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py   103
-rwxr-xr-x  R2R/r2r/pipes/retrieval/multi_search.py            79
-rwxr-xr-x  R2R/r2r/pipes/retrieval/query_transform_pipe.py   101
-rwxr-xr-x  R2R/r2r/pipes/retrieval/search_rag_pipe.py        130
-rwxr-xr-x  R2R/r2r/pipes/retrieval/streaming_rag_pipe.py     131
-rwxr-xr-x  R2R/r2r/pipes/retrieval/vector_search_pipe.py     123
6 files changed, 667 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
new file mode 100755
index 00000000..60935265
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
@@ -0,0 +1,103 @@
+import logging
+import uuid
+from typing import Any, Optional
+
+from r2r.base import (
+ AsyncState,
+ KGProvider,
+ KGSearchSettings,
+ KVLoggingSingleton,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+)
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class KGAgentSearchPipe(GeneratorPipe):
+ """
+ Embeds and stores documents using a specified embedding model and database.
+ """
+
+ def __init__(
+ self,
+ kg_provider: KGProvider,
+ llm_provider: LLMProvider,
+ prompt_provider: PromptProvider,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ type: PipeType = PipeType.INGESTOR,
+ config: Optional[GeneratorPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ """
+ Initializes the embedding pipe with necessary components and configurations.
+ """
+ super().__init__(
+ llm_provider=llm_provider,
+ prompt_provider=prompt_provider,
+ type=type,
+ config=config
+ or GeneratorPipe.Config(
+ name="kg_rag_pipe", task_prompt="kg_agent"
+ ),
+ pipe_logger=pipe_logger,
+ *args,
+ **kwargs,
+ )
+ self.kg_provider = kg_provider
+ self.llm_provider = llm_provider
+ self.prompt_provider = prompt_provider
+ self.pipe_run_info = None
+
+ async def _run_logic(
+ self,
+ input: GeneratorPipe.Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ kg_search_settings: KGSearchSettings,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ async for message in input.message:
+ # TODO - Remove hard code
+ formatted_prompt = self.prompt_provider.get_prompt(
+ "kg_agent", {"input": message}
+ )
+ messages = self._get_message_payload(formatted_prompt)
+
+ result = await self.llm_provider.aget_completion(
+ messages=messages,
+ generation_config=kg_search_settings.agent_generation_config,
+ )
+
+ extraction = result.choices[0].message.content
+ query = extraction.split("```cypher")[1].split("```")[0]
+ result = self.kg_provider.structured_query(query)
+ yield (query, result)
+
+ await self.enqueue_log(
+ run_id=run_id,
+ key="kg_agent_response",
+ value=extraction,
+ )
+
+ await self.enqueue_log(
+ run_id=run_id,
+ key="kg_agent_execution_result",
+ value=result,
+ )
+
+ def _get_message_payload(self, message: str) -> list[dict]:
+ return [
+ {
+ "role": "system",
+ "content": self.prompt_provider.get_prompt(
+ self.config.system_prompt,
+ ),
+ },
+ {"role": "user", "content": message},
+ ]
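
Note (commentary, not part of the patch): the extraction above assumes the `kg_agent` prompt makes the model wrap its query in a ```cypher fenced block, which is what the hard-coded string splits rely on. A minimal sketch of that parsing step with a made-up reply:

# Hypothetical LLM reply in the shape the split logic expects.
extraction = (
    "To answer this, I will query the graph:\n"
    "```cypher\n"
    "MATCH (p:Person)-[:WORKS_AT]->(c:Company) RETURN p.name, c.name\n"
    "```"
)

# Same parsing as the pipe: take the text between ```cypher and the closing fence.
query = extraction.split("```cypher")[1].split("```")[0]
print(query.strip())
# MATCH (p:Person)-[:WORKS_AT]->(c:Company) RETURN p.name, c.name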
diff --git a/R2R/r2r/pipes/retrieval/multi_search.py b/R2R/r2r/pipes/retrieval/multi_search.py
new file mode 100755
index 00000000..6da2c34b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/multi_search.py
@@ -0,0 +1,79 @@
+import uuid
+from copy import copy
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base.abstractions.llm import GenerationConfig
+from r2r.base.abstractions.search import VectorSearchResult
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+from ..abstractions.search_pipe import SearchPipe
+from .query_transform_pipe import QueryTransformPipe
+
+
+class MultiSearchPipe(AsyncPipe):
+ class PipeConfig(AsyncPipe.PipeConfig):
+ name: str = "multi_search_pipe"
+
+ def __init__(
+ self,
+ query_transform_pipe: QueryTransformPipe,
+ inner_search_pipe: SearchPipe,
+ config: Optional[PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ self.query_transform_pipe = query_transform_pipe
+ self.vector_search_pipe = inner_search_pipe
+ if (
+ query_transform_pipe.config.name
+ != inner_search_pipe.config.name
+ ):
+ raise ValueError(
+ "The query transform pipe and search pipe must have the same name."
+ )
+ if config and config.name != query_transform_pipe.config.name:
+ raise ValueError(
+ "The pipe config name must match the query transform pipe name."
+ )
+
+ super().__init__(
+ config=config
+ or MultiSearchPipe.PipeConfig(
+ name=query_transform_pipe.config.name
+ ),
+ *args,
+ **kwargs,
+ )
+
+ async def _run_logic(
+ self,
+ input: Any,
+ state: Any,
+ run_id: uuid.UUID,
+ query_transform_generation_config: Optional[GenerationConfig] = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[VectorSearchResult, None]:
+ query_transform_generation_config = (
+ query_transform_generation_config
+ or copy(kwargs.get("rag_generation_config", None))
+ or GenerationConfig(model="gpt-4o")
+ )
+ query_transform_generation_config.stream = False
+
+ query_generator = await self.query_transform_pipe.run(
+ input,
+ state,
+ query_transform_generation_config=query_transform_generation_config,
+ num_query_xf_outputs=3,
+ *args,
+ **kwargs,
+ )
+
+ async for search_result in await self.vector_search_pipe.run(
+ self.vector_search_pipe.Input(message=query_generator),
+ state,
+ *args,
+ **kwargs,
+ ):
+ yield search_result
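
For orientation, a minimal wiring sketch (commentary, not part of the patch): the transform pipe, the inner search pipe, and the resulting MultiSearchPipe must all share one config name, the provider arguments are placeholders, and imports of the pipe classes from the modules in this directory are elided. Passing name=... to SearchPipe.SearchConfig is assumed to work the same way as for the other pydantic pipe configs.

from r2r.base.abstractions.llm import GenerationConfig


def build_hyde_multi_search(llm, prompts, vector_db, embeddings) -> MultiSearchPipe:
    # All three configs must carry the same name, or MultiSearchPipe raises ValueError.
    transform_pipe = QueryTransformPipe(
        llm_provider=llm,
        prompt_provider=prompts,
        config=QueryTransformPipe.QueryTransformConfig(name="hyde_search"),
    )
    search_pipe = VectorSearchPipe(
        vector_db_provider=vector_db,
        embedding_provider=embeddings,
        config=SearchPipe.SearchConfig(name="hyde_search"),
    )
    return MultiSearchPipe(transform_pipe, search_pipe)


# Usage inside an async pipeline: one query fans out into several
# (num_query_xf_outputs=3) and the merged VectorSearchResult stream comes back:
#   async for result in await multi_search.run(
#       input, state,
#       query_transform_generation_config=GenerationConfig(model="gpt-4o"),
#   ):
#       ...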
diff --git a/R2R/r2r/pipes/retrieval/query_transform_pipe.py b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
new file mode 100755
index 00000000..99df6b5b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
@@ -0,0 +1,101 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+ AsyncPipe,
+ AsyncState,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class QueryTransformPipe(GeneratorPipe):
+ class QueryTransformConfig(GeneratorPipe.PipeConfig):
+ name: str = "default_query_transform"
+ system_prompt: str = "default_system"
+ task_prompt: str = "hyde"
+
+ class Input(GeneratorPipe.Input):
+ message: AsyncGenerator[str, None]
+
+ def __init__(
+ self,
+ llm_provider: LLMProvider,
+ prompt_provider: PromptProvider,
+ type: PipeType = PipeType.TRANSFORM,
+ config: Optional[QueryTransformConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ logger.info(f"Initalizing an `QueryTransformPipe` pipe.")
+ super().__init__(
+ llm_provider=llm_provider,
+ prompt_provider=prompt_provider,
+ type=type,
+ config=config or QueryTransformPipe.QueryTransformConfig(),
+ *args,
+ **kwargs,
+ )
+
+ async def _run_logic(
+ self,
+ input: AsyncPipe.Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ query_transform_generation_config: GenerationConfig,
+ num_query_xf_outputs: int = 3,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[str, None]:
+ async for query in input.message:
+ logger.info(
+ f"Transforming query: {query} into {num_query_xf_outputs} outputs with {self.config.task_prompt}."
+ )
+
+ query_transform_request = self._get_message_payload(
+ query, num_outputs=num_query_xf_outputs
+ )
+
+ response = await self.llm_provider.aget_completion(
+ messages=query_transform_request,
+ generation_config=query_transform_generation_config,
+ )
+ content = self.llm_provider.extract_content(response)
+ outputs = content.split("\n")
+ outputs = [
+ output.strip() for output in outputs if output.strip() != ""
+ ]
+ await state.update(
+ self.config.name, {"output": {"outputs": outputs}}
+ )
+
+ for output in outputs:
+ logger.info(f"Yielding transformed output: {output}")
+ yield output
+
+ def _get_message_payload(self, input: str, num_outputs: int) -> list[dict]:
+ return [
+ {
+ "role": "system",
+ "content": self.prompt_provider.get_prompt(
+ self.config.system_prompt,
+ ),
+ },
+ {
+ "role": "user",
+ "content": self.prompt_provider.get_prompt(
+ self.config.task_prompt,
+ inputs={
+ "message": input,
+ "num_outputs": num_outputs,
+ },
+ ),
+ },
+ ]
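
A small sketch of the post-processing above (commentary, not part of the patch), assuming the `hyde` task prompt asks the model for one rewritten query per line:

# Hypothetical completion for num_outputs=3; blank lines are dropped.
content = (
    "What does the R2R ingestion pipeline do?\n"
    "How are documents embedded in R2R?\n"
    "\n"
    "How does R2R store vectors?\n"
)
outputs = [line.strip() for line in content.split("\n") if line.strip() != ""]
assert outputs == [
    "What does the R2R ingestion pipeline do?",
    "How are documents embedded in R2R?",
    "How does R2R store vectors?",
]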
diff --git a/R2R/r2r/pipes/retrieval/search_rag_pipe.py b/R2R/r2r/pipes/retrieval/search_rag_pipe.py
new file mode 100755
index 00000000..4d01d2df
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/search_rag_pipe.py
@@ -0,0 +1,130 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional, Tuple
+
+from r2r.base import (
+ AggregateSearchResult,
+ AsyncPipe,
+ AsyncState,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig, RAGCompletion
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class SearchRAGPipe(GeneratorPipe):
+ class Input(AsyncPipe.Input):
+ message: AsyncGenerator[Tuple[str, AggregateSearchResult], None]
+
+ def __init__(
+ self,
+ llm_provider: LLMProvider,
+ prompt_provider: PromptProvider,
+ type: PipeType = PipeType.GENERATOR,
+ config: Optional[GeneratorPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(
+ llm_provider=llm_provider,
+ prompt_provider=prompt_provider,
+ type=type,
+ config=config
+ or GeneratorPipe.Config(
+ name="default_rag_pipe", task_prompt="default_rag"
+ ),
+ *args,
+ **kwargs,
+ )
+
+ async def _run_logic(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ rag_generation_config: GenerationConfig,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[RAGCompletion, None]:
+ context = ""
+ search_iteration = 1
+ total_results = 0
+ # must select a query if there are multiple
+ sel_query = None
+ async for query, search_results in input.message:
+ if search_iteration == 1:
+ sel_query = query
+ context_piece, total_results = await self._collect_context(
+ query, search_results, search_iteration, total_results
+ )
+ context += context_piece
+ search_iteration += 1
+
+ messages = self._get_message_payload(sel_query, context)
+
+ response = await self.llm_provider.aget_completion(
+ messages=messages, generation_config=rag_generation_config
+ )
+ yield RAGCompletion(completion=response, search_results=search_results)
+
+ await self.enqueue_log(
+ run_id=run_id,
+ key="llm_response",
+ value=response.choices[0].message.content,
+ )
+
+ def _get_message_payload(self, query: str, context: str) -> list[dict]:
+ return [
+ {
+ "role": "system",
+ "content": self.prompt_provider.get_prompt(
+ self.config.system_prompt,
+ ),
+ },
+ {
+ "role": "user",
+ "content": self.prompt_provider.get_prompt(
+ self.config.task_prompt,
+ inputs={
+ "query": query,
+ "context": context,
+ },
+ ),
+ },
+ ]
+
+ async def _collect_context(
+ self,
+ query: str,
+ results: AggregateSearchResult,
+ iteration: int,
+ total_results: int,
+ ) -> Tuple[str, int]:
+ context = f"Query:\n{query}\n\n"
+ if results.vector_search_results:
+ context += f"Vector Search Results({iteration}):\n"
+ it = total_results + 1
+ for result in results.vector_search_results:
+ context += f"[{it}]: {result.metadata['text']}\n\n"
+ it += 1
+ total_results = (
+ it - 1
+ ) # Update total_results based on the last index used
+ if results.kg_search_results:
+ context += f"Knowledge Graph ({iteration}):\n"
+ it = total_results + 1
+ for kg_query, kg_results in results.kg_search_results:
+ context += f"Query: {kg_query}\n\n"
+ context += "Results:\n"
+ for search_result in kg_results:
+ context += f"[{it}]: {search_result}\n\n"
+ it += 1
+ total_results = (
+ it - 1
+ ) # Update total_results based on the last index used
+ return context, total_results
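
Illustrative only (made-up texts, commentary rather than part of the patch): the shape of the context string that _collect_context accumulates over two vector-search iterations, with the bracketed result numbering carried across iterations via total_results:

expected_context = (
    "Query:\nWhat is R2R?\n\n"
    "Vector Search Results(1):\n"
    "[1]: R2R is a framework for building RAG applications.\n\n"
    "[2]: R2R ships ingestion and retrieval pipes.\n\n"
    "Query:\nHow does R2R search?\n\n"
    "Vector Search Results(2):\n"
    "[3]: Vector search is handled by a dedicated search pipe.\n\n"
)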
diff --git a/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
new file mode 100755
index 00000000..b01f6445
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
@@ -0,0 +1,131 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from r2r.base import (
+ AsyncState,
+ LLMChatCompletionChunk,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+from .search_rag_pipe import SearchRAGPipe
+
+logger = logging.getLogger(__name__)
+
+
+class StreamingSearchRAGPipe(SearchRAGPipe):
+ SEARCH_STREAM_MARKER = "search"
+ COMPLETION_STREAM_MARKER = "completion"
+
+ def __init__(
+ self,
+ llm_provider: LLMProvider,
+ prompt_provider: PromptProvider,
+ type: PipeType = PipeType.GENERATOR,
+ config: Optional[GeneratorPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(
+ llm_provider=llm_provider,
+ prompt_provider=prompt_provider,
+ type=type,
+ config=config
+ or GeneratorPipe.Config(
+ name="default_streaming_rag_pipe", task_prompt="default_rag"
+ ),
+ *args,
+ **kwargs,
+ )
+
+ async def _run_logic(
+ self,
+ input: SearchRAGPipe.Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ rag_generation_config: GenerationConfig,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[str, None]:
+ iteration = 0
+ context = ""
+ # dump the search results and construct the context
+ async for query, search_results in input.message:
+ yield f"<{self.SEARCH_STREAM_MARKER}>"
+ if search_results.vector_search_results:
+ context += "Vector Search Results:\n"
+ for result in search_results.vector_search_results:
+ if iteration >= 1:
+ yield ","
+ yield json.dumps(result.json())
+ context += (
+ f"{iteration + 1}:\n{result.metadata['text']}\n\n"
+ )
+ iteration += 1
+
+ # if search_results.kg_search_results:
+ # for result in search_results.kg_search_results:
+ # if iteration >= 1:
+ # yield ","
+ # yield json.dumps(result.json())
+ # context += f"Result {iteration+1}:\n{result.metadata['text']}\n\n"
+ # iteration += 1
+
+ yield f"</{self.SEARCH_STREAM_MARKER}>"
+
+ messages = self._get_message_payload(query, context)
+ yield f"<{self.COMPLETION_STREAM_MARKER}>"
+ response = ""
+ for chunk in self.llm_provider.get_completion_stream(
+ messages=messages, generation_config=rag_generation_config
+ ):
+ chunk = StreamingSearchRAGPipe._process_chunk(chunk)
+ response += chunk
+ yield chunk
+
+ yield f"</{self.COMPLETION_STREAM_MARKER}>"
+
+ await self.enqueue_log(
+ run_id=run_id,
+ key="llm_response",
+ value=response,
+ )
+
+ async def _yield_chunks(
+ self,
+ start_marker: str,
+ chunks: Generator[str, None, None],
+ end_marker: str,
+ ) -> AsyncGenerator[str, None]:
+ yield start_marker
+ for chunk in chunks:
+ yield chunk
+ yield end_marker
+
+ def _get_message_payload(
+ self, query: str, context: str
+ ) -> list[dict[str, str]]:
+ return [
+ {
+ "role": "system",
+ "content": self.prompt_provider.get_prompt(
+ self.config.system_prompt
+ ),
+ },
+ {
+ "role": "user",
+ "content": self.prompt_provider.get_prompt(
+ self.config.task_prompt,
+ inputs={"query": query, "context": context},
+ ),
+ },
+ ]
+
+ @staticmethod
+ def _process_chunk(chunk: LLMChatCompletionChunk) -> str:
+ return chunk.choices[0].delta.content or ""
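
A minimal consumer sketch (commentary, not part of the patch), assuming the chunks for a single query are concatenated into one string: search results arrive as comma-separated JSON payloads between <search>...</search>, followed by the generated answer between <completion>...</completion>.

import json


def parse_streamed_rag(buffer: str) -> tuple[list, str]:
    # Payloads yielded between the search markers, re-wrapped as a JSON array.
    search_blob = buffer.split("<search>")[1].split("</search>")[0]
    results = json.loads(f"[{search_blob}]") if search_blob else []
    # The generated answer streamed between the completion markers.
    completion = buffer.split("<completion>")[1].split("</completion>")[0]
    return results, completion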
diff --git a/R2R/r2r/pipes/retrieval/vector_search_pipe.py b/R2R/r2r/pipes/retrieval/vector_search_pipe.py
new file mode 100755
index 00000000..742de16b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/vector_search_pipe.py
@@ -0,0 +1,123 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+ AsyncPipe,
+ AsyncState,
+ EmbeddingProvider,
+ PipeType,
+ VectorDBProvider,
+ VectorSearchResult,
+ VectorSearchSettings,
+)
+
+from ..abstractions.search_pipe import SearchPipe
+
+logger = logging.getLogger(__name__)
+
+
+class VectorSearchPipe(SearchPipe):
+ def __init__(
+ self,
+ vector_db_provider: VectorDBProvider,
+ embedding_provider: EmbeddingProvider,
+ type: PipeType = PipeType.SEARCH,
+ config: Optional[SearchPipe.SearchConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(
+ type=type,
+ config=config or SearchPipe.SearchConfig(),
+ *args,
+ **kwargs,
+ )
+ self.embedding_provider = embedding_provider
+ self.vector_db_provider = vector_db_provider
+
+ async def search(
+ self,
+ message: str,
+ run_id: uuid.UUID,
+ vector_search_settings: VectorSearchSettings,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[VectorSearchResult, None]:
+ await self.enqueue_log(
+ run_id=run_id, key="search_query", value=message
+ )
+ search_filters = (
+ vector_search_settings.search_filters or self.config.search_filters
+ )
+ search_limit = (
+ vector_search_settings.search_limit or self.config.search_limit
+ )
+ results = []
+ query_vector = self.embedding_provider.get_embedding(
+ message,
+ )
+ search_results = (
+ self.vector_db_provider.hybrid_search(
+ query_vector=query_vector,
+ query_text=message,
+ filters=search_filters,
+ limit=search_limit,
+ )
+ if vector_search_settings.do_hybrid_search
+ else self.vector_db_provider.search(
+ query_vector=query_vector,
+ filters=search_filters,
+ limit=search_limit,
+ )
+ )
+ reranked_results = self.embedding_provider.rerank(
+ query=message, results=search_results, limit=search_limit
+ )
+ for result in reranked_results:
+ result.metadata["associatedQuery"] = message
+ results.append(result)
+ yield result
+ await self.enqueue_log(
+ run_id=run_id,
+ key="search_results",
+ value=json.dumps([ele.json() for ele in results]),
+ )
+
+ async def _run_logic(
+ self,
+ input: AsyncPipe.Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[VectorSearchResult, None]:
+ search_queries = []
+ search_results = []
+ async for search_request in input.message:
+ search_queries.append(search_request)
+ async for result in self.search(
+ message=search_request,
+ run_id=run_id,
+ vector_search_settings=vector_search_settings,
+ *args,
+ **kwargs,
+ ):
+ search_results.append(result)
+ yield result
+
+ await state.update(
+ self.config.name, {"output": {"search_results": search_results}}
+ )
+
+ await state.update(
+ self.config.name,
+ {
+ "output": {
+ "search_queries": search_queries,
+ "search_results": search_results,
+ }
+ },
+ )
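
A minimal usage sketch under stated assumptions (commentary, not part of the patch): `pipe` is a constructed VectorSearchPipe, `state` an AsyncState, and run() forwards extra keyword arguments to _run_logic the way MultiSearchPipe invokes its inner search pipe. With do_hybrid_search=True the query goes through vector_db_provider.hybrid_search; otherwise plain vector search is used, and the results are reranked either way.

from r2r.base import VectorSearchSettings


async def queries():
    yield "What is R2R?"


async def demo(pipe, state):
    settings = VectorSearchSettings(do_hybrid_search=True, search_limit=5)
    async for result in await pipe.run(
        pipe.Input(message=queries()),
        state,
        vector_search_settings=settings,
    ):
        # Each result carries the originating query in its metadata.
        print(result.metadata["associatedQuery"], result.metadata.get("text"))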