Diffstat (limited to 'R2R/r2r/pipes/retrieval')
-rwxr-xr-x  R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py  | 103
-rwxr-xr-x  R2R/r2r/pipes/retrieval/multi_search.py          |  79
-rwxr-xr-x  R2R/r2r/pipes/retrieval/query_transform_pipe.py  | 101
-rwxr-xr-x  R2R/r2r/pipes/retrieval/search_rag_pipe.py       | 130
-rwxr-xr-x  R2R/r2r/pipes/retrieval/streaming_rag_pipe.py    | 131
-rwxr-xr-x  R2R/r2r/pipes/retrieval/vector_search_pipe.py    | 123
6 files changed, 667 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
new file mode 100755
index 00000000..60935265
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/kg_agent_search_pipe.py
@@ -0,0 +1,103 @@
+import logging
+import uuid
+from typing import Any, Optional
+
+from r2r.base import (
+    AsyncState,
+    KGProvider,
+    KGSearchSettings,
+    KVLoggingSingleton,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class KGAgentSearchPipe(GeneratorPipe):
+    """
+    Performs knowledge-graph search by prompting an LLM agent to write a graph query and executing it against the configured KG provider.
+    """
+
+    def __init__(
+        self,
+        kg_provider: KGProvider,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        type: PipeType = PipeType.INGESTOR,
+        config: Optional[GeneratorPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initializes the KG agent search pipe with its providers and configuration.
+        """
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config
+            or GeneratorPipe.Config(
+                name="kg_rag_pipe", task_prompt="kg_agent"
+            ),
+            pipe_logger=pipe_logger,
+            *args,
+            **kwargs,
+        )
+        self.kg_provider = kg_provider
+        self.llm_provider = llm_provider
+        self.prompt_provider = prompt_provider
+        self.pipe_run_info = None
+
+    async def _run_logic(
+        self,
+        input: GeneratorPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        kg_search_settings: KGSearchSettings,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        async for message in input.message:
+            # TODO - Remove the hard-coded prompt name
+            formatted_prompt = self.prompt_provider.get_prompt(
+                "kg_agent", {"input": message}
+            )
+            messages = self._get_message_payload(formatted_prompt)
+
+            result = await self.llm_provider.aget_completion(
+                messages=messages,
+                generation_config=kg_search_settings.agent_generation_config,
+            )
+
+            extraction = result.choices[0].message.content
+            query = extraction.split("```cypher")[1].split("```")[0]
+            result = self.kg_provider.structured_query(query)
+            yield (query, result)
+
+            await self.enqueue_log(
+                run_id=run_id,
+                key="kg_agent_response",
+                value=extraction,
+            )
+
+            await self.enqueue_log(
+                run_id=run_id,
+                key="kg_agent_execution_result",
+                value=result,
+            )
+
+    def _get_message_payload(self, message: str) -> list[dict]:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt,
+                ),
+            },
+            {"role": "user", "content": message},
+        ]
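The agent is expected to wrap its generated query in a ```cypher fenced block; here is a minimal standalone sketch of the extraction step performed in _run_logic above, with an invented sample response for illustration:

    # Standalone sketch of the Cypher extraction; the sample response text is invented.
    sample_response = (
        "Here is the query:\n"
        "```cypher\n"
        "MATCH (n:Person)-[:KNOWS]->(m) RETURN m.name LIMIT 5\n"
        "```"
    )

    def extract_cypher(llm_content: str) -> str:
        # Mirrors extraction.split("```cypher")[1].split("```")[0] in the pipe;
        # raises IndexError if the model omits the fenced block.
        return llm_content.split("```cypher")[1].split("```")[0].strip()

    print(extract_cypher(sample_response))
    # MATCH (n:Person)-[:KNOWS]->(m) RETURN m.name LIMIT 5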
diff --git a/R2R/r2r/pipes/retrieval/multi_search.py b/R2R/r2r/pipes/retrieval/multi_search.py
new file mode 100755
index 00000000..6da2c34b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/multi_search.py
@@ -0,0 +1,79 @@
+import uuid
+from copy import copy
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base.abstractions.llm import GenerationConfig
+from r2r.base.abstractions.search import VectorSearchResult
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+from ..abstractions.search_pipe import SearchPipe
+from .query_transform_pipe import QueryTransformPipe
+
+
+class MultiSearchPipe(AsyncPipe):
+    class PipeConfig(AsyncPipe.PipeConfig):
+        name: str = "multi_search_pipe"
+
+    def __init__(
+        self,
+        query_transform_pipe: QueryTransformPipe,
+        inner_search_pipe: SearchPipe,
+        config: Optional[PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        self.query_transform_pipe = query_transform_pipe
+        self.vector_search_pipe = inner_search_pipe
+        if (
+            query_transform_pipe.config.name
+            != inner_search_pipe.config.name
+        ):
+            raise ValueError(
+                "The query transform pipe and search pipe must have the same name."
+            )
+        if config and config.name != query_transform_pipe.config.name:
+            raise ValueError(
+                "The pipe config name must match the query transform pipe name."
+            )
+
+        super().__init__(
+            config=config
+            or MultiSearchPipe.PipeConfig(
+                name=query_transform_pipe.config.name
+            ),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: Any,
+        state: Any,
+        run_id: uuid.UUID,
+        query_transform_generation_config: Optional[GenerationConfig] = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[VectorSearchResult, None]:
+        query_transform_generation_config = (
+            query_transform_generation_config
+            or copy(kwargs.get("rag_generation_config", None))
+            or GenerationConfig(model="gpt-4o")
+        )
+        query_transform_generation_config.stream = False
+
+        query_generator = await self.query_transform_pipe.run(
+            input,
+            state,
+            query_transform_generation_config=query_transform_generation_config,
+            num_query_xf_outputs=3,
+            *args,
+            **kwargs,
+        )
+
+        async for search_result in await self.vector_search_pipe.run(
+            self.vector_search_pipe.Input(message=query_generator),
+            state,
+            *args,
+            **kwargs,
+        ):
+            yield search_result
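MultiSearchPipe is a fan-out: one incoming query is expanded into several sub-queries, each sub-query is searched, and every hit is yielded on a single stream. Below is a minimal standalone asyncio sketch of that pattern; the transform and search functions are stand-ins for the r2r pipes, not the real implementations:

    import asyncio
    from typing import AsyncGenerator

    async def transform(query: str) -> AsyncGenerator[str, None]:
        # Stand-in for QueryTransformPipe: expand one query into three variants.
        for i in range(3):
            yield f"{query} (variant {i + 1})"

    async def search(query: str) -> AsyncGenerator[str, None]:
        # Stand-in for VectorSearchPipe: return a fake hit for a sub-query.
        yield f"hit for: {query}"

    async def multi_search(query: str) -> AsyncGenerator[str, None]:
        # Chain the two generators, as MultiSearchPipe does with its inner pipes.
        async for sub_query in transform(query):
            async for result in search(sub_query):
                yield result

    async def main() -> None:
        async for result in multi_search("what is a knowledge graph?"):
            print(result)

    asyncio.run(main())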
diff --git a/R2R/r2r/pipes/retrieval/query_transform_pipe.py b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
new file mode 100755
index 00000000..99df6b5b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/query_transform_pipe.py
@@ -0,0 +1,101 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncPipe,
+    AsyncState,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class QueryTransformPipe(GeneratorPipe):
+    class QueryTransformConfig(GeneratorPipe.PipeConfig):
+        name: str = "default_query_transform"
+        system_prompt: str = "default_system"
+        task_prompt: str = "hyde"
+
+    class Input(GeneratorPipe.Input):
+        message: AsyncGenerator[str, None]
+
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        type: PipeType = PipeType.TRANSFORM,
+        config: Optional[QueryTransformConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        logger.info(f"Initalizing an `QueryTransformPipe` pipe.")
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config or QueryTransformPipe.QueryTransformConfig(),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        query_transform_generation_config: GenerationConfig,
+        num_query_xf_outputs: int = 3,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[str, None]:
+        async for query in input.message:
+            logger.info(
+                f"Transforming query: {query} into {num_query_xf_outputs} outputs with {self.config.task_prompt}."
+            )
+
+            query_transform_request = self._get_message_payload(
+                query, num_outputs=num_query_xf_outputs
+            )
+
+            response = await self.llm_provider.aget_completion(
+                messages=query_transform_request,
+                generation_config=query_transform_generation_config,
+            )
+            content = self.llm_provider.extract_content(response)
+            outputs = content.split("\n")
+            outputs = [
+                output.strip() for output in outputs if output.strip() != ""
+            ]
+            await state.update(
+                self.config.name, {"output": {"outputs": outputs}}
+            )
+
+            for output in outputs:
+                logger.info(f"Yielding transformed output: {output}")
+                yield output
+
+    def _get_message_payload(self, input: str, num_outputs: int) -> list[dict]:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt,
+                ),
+            },
+            {
+                "role": "user",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.task_prompt,
+                    inputs={
+                        "message": input,
+                        "num_outputs": num_outputs,
+                    },
+                ),
+            },
+        ]
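The pipe expects the model to return one transformed query per line; a standalone sketch of the post-processing step in _run_logic, with an invented sample completion:

    # The sample completion text is invented for illustration.
    sample_completion = """
    What datasets were used to train the model?
    Which benchmarks does the paper report?

    How does the model compare to prior work?
    """

    # Mirrors the splitting in _run_logic: one query per line, blanks dropped.
    outputs = [
        line.strip() for line in sample_completion.split("\n") if line.strip() != ""
    ]
    print(outputs)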
diff --git a/R2R/r2r/pipes/retrieval/search_rag_pipe.py b/R2R/r2r/pipes/retrieval/search_rag_pipe.py
new file mode 100755
index 00000000..4d01d2df
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/search_rag_pipe.py
@@ -0,0 +1,130 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional, Tuple
+
+from r2r.base import (
+    AggregateSearchResult,
+    AsyncPipe,
+    AsyncState,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig, RAGCompletion
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class SearchRAGPipe(GeneratorPipe):
+    class Input(AsyncPipe.Input):
+        message: AsyncGenerator[Tuple[str, AggregateSearchResult], None]
+
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        type: PipeType = PipeType.GENERATOR,
+        config: Optional[GeneratorPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config
+            or GeneratorPipe.Config(
+                name="default_rag_pipe", task_prompt="default_rag"
+            ),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        rag_generation_config: GenerationConfig,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[RAGCompletion, None]:
+        context = ""
+        search_iteration = 1
+        total_results = 0
+        # must select a query if there are multiple
+        sel_query = None
+        async for query, search_results in input.message:
+            if search_iteration == 1:
+                sel_query = query
+            context_piece, total_results = await self._collect_context(
+                query, search_results, search_iteration, total_results
+            )
+            context += context_piece
+            search_iteration += 1
+
+        messages = self._get_message_payload(sel_query, context)
+
+        response = await self.llm_provider.aget_completion(
+            messages=messages, generation_config=rag_generation_config
+        )
+        yield RAGCompletion(completion=response, search_results=search_results)
+
+        await self.enqueue_log(
+            run_id=run_id,
+            key="llm_response",
+            value=response.choices[0].message.content,
+        )
+
+    def _get_message_payload(self, query: str, context: str) -> list[dict]:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt,
+                ),
+            },
+            {
+                "role": "user",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.task_prompt,
+                    inputs={
+                        "query": query,
+                        "context": context,
+                    },
+                ),
+            },
+        ]
+
+    async def _collect_context(
+        self,
+        query: str,
+        results: AggregateSearchResult,
+        iteration: int,
+        total_results: int,
+    ) -> Tuple[str, int]:
+        context = f"Query:\n{query}\n\n"
+        if results.vector_search_results:
+            context += f"Vector Search Results({iteration}):\n"
+            it = total_results + 1
+            for result in results.vector_search_results:
+                context += f"[{it}]: {result.metadata['text']}\n\n"
+                it += 1
+            total_results = (
+                it - 1
+            )  # Update total_results based on the last index used
+        if results.kg_search_results:
+            context += f"Knowledge Graph ({iteration}):\n"
+            it = total_results + 1
+            for query, search_results in results.kg_search_results:  # [1]:
+                context += f"Query: {query}\n\n"
+                context += f"Results:\n"
+                for search_result in search_results:
+                    context += f"[{it}]: {search_result}\n\n"
+                    it += 1
+            total_results = (
+                it - 1
+            )  # Update total_results based on the last index used
+        return context, total_results
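_collect_context numbers every result with a single running index across all search iterations, so later iterations continue where earlier ones stopped. A simplified standalone sketch of that numbering scheme, covering only the vector-search branch and using invented result texts:

    def collect_context(query, texts, iteration, total_results):
        # Simplified mirror of the vector-search branch of _collect_context.
        context = f"Query:\n{query}\n\n"
        context += f"Vector Search Results({iteration}):\n"
        it = total_results + 1
        for text in texts:
            context += f"[{it}]: {text}\n\n"
            it += 1
        return context, it - 1

    first, total = collect_context("first query", ["result A", "result B"], 1, 0)
    second, total = collect_context("second query", ["result C"], 2, total)
    print(first + second)  # "result C" is labelled [3], continuing the running count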
diff --git a/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
new file mode 100755
index 00000000..b01f6445
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
@@ -0,0 +1,131 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from r2r.base import (
+    AsyncState,
+    LLMChatCompletionChunk,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+from .search_rag_pipe import SearchRAGPipe
+
+logger = logging.getLogger(__name__)
+
+
+class StreamingSearchRAGPipe(SearchRAGPipe):
+    SEARCH_STREAM_MARKER = "search"
+    COMPLETION_STREAM_MARKER = "completion"
+
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        type: PipeType = PipeType.GENERATOR,
+        config: Optional[GeneratorPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config
+            or GeneratorPipe.Config(
+                name="default_streaming_rag_pipe", task_prompt="default_rag"
+            ),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: SearchRAGPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        rag_generation_config: GenerationConfig,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[str, None]:
+        iteration = 0
+        context = ""
+        # dump the search results and construct the context
+        async for query, search_results in input.message:
+            yield f"<{self.SEARCH_STREAM_MARKER}>"
+            if search_results.vector_search_results:
+                context += "Vector Search Results:\n"
+                for result in search_results.vector_search_results:
+                    if iteration >= 1:
+                        yield ","
+                    yield json.dumps(result.json())
+                    context += (
+                        f"{iteration + 1}:\n{result.metadata['text']}\n\n"
+                    )
+                    iteration += 1
+
+            # if search_results.kg_search_results:
+            #     for result in search_results.kg_search_results:
+            #         if iteration >= 1:
+            #             yield ","
+            #         yield json.dumps(result.json())
+            #         context += f"Result {iteration+1}:\n{result.metadata['text']}\n\n"
+            #         iteration += 1
+
+            yield f"</{self.SEARCH_STREAM_MARKER}>"
+
+            messages = self._get_message_payload(query, context)
+            yield f"<{self.COMPLETION_STREAM_MARKER}>"
+            response = ""
+            for chunk in self.llm_provider.get_completion_stream(
+                messages=messages, generation_config=rag_generation_config
+            ):
+                chunk = StreamingSearchRAGPipe._process_chunk(chunk)
+                response += chunk
+                yield chunk
+
+            yield f"</{self.COMPLETION_STREAM_MARKER}>"
+
+            await self.enqueue_log(
+                run_id=run_id,
+                key="llm_response",
+                value=response,
+            )
+
+    async def _yield_chunks(
+        self,
+        start_marker: str,
+        chunks: Generator[str, None, None],
+        end_marker: str,
+    ) -> AsyncGenerator[str, None]:
+        yield start_marker
+        for chunk in chunks:
+            yield chunk
+        yield end_marker
+
+    def _get_message_payload(
+        self, query: str, context: str
+    ) -> list[dict[str, str]]:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt
+                ),
+            },
+            {
+                "role": "user",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.task_prompt,
+                    inputs={"query": query, "context": context},
+                ),
+            },
+        ]
+
+    @staticmethod
+    def _process_chunk(chunk: LLMChatCompletionChunk) -> str:
+        return chunk.choices[0].delta.content or ""
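A downstream consumer can recover the two phases of the stream by watching for the <search> and <completion> markers. A minimal client-side sketch under simplifying assumptions: the chunk list is invented, and real responses arrive incrementally over HTTP rather than as a Python list:

    def split_stream(chunks):
        # Reassemble the stream and pull out the two marked sections.
        text = "".join(chunks)
        search_part = text.split("<search>")[1].split("</search>")[0]
        completion_part = text.split("<completion>")[1].split("</completion>")[0]
        return search_part, completion_part

    chunks = [
        "<search>", '{"id": 1, "text": "..."}', "</search>",
        "<completion>", "The answer, based on result [1], is 42.", "</completion>",
    ]
    search_json, answer = split_stream(chunks)
    print(search_json)  # {"id": 1, "text": "..."}
    print(answer)       # The answer, based on result [1], is 42.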
diff --git a/R2R/r2r/pipes/retrieval/vector_search_pipe.py b/R2R/r2r/pipes/retrieval/vector_search_pipe.py
new file mode 100755
index 00000000..742de16b
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/vector_search_pipe.py
@@ -0,0 +1,123 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncPipe,
+    AsyncState,
+    EmbeddingProvider,
+    PipeType,
+    VectorDBProvider,
+    VectorSearchResult,
+    VectorSearchSettings,
+)
+
+from ..abstractions.search_pipe import SearchPipe
+
+logger = logging.getLogger(__name__)
+
+
+class VectorSearchPipe(SearchPipe):
+    def __init__(
+        self,
+        vector_db_provider: VectorDBProvider,
+        embedding_provider: EmbeddingProvider,
+        type: PipeType = PipeType.SEARCH,
+        config: Optional[SearchPipe.SearchConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            type=type,
+            config=config or SearchPipe.SearchConfig(),
+            *args,
+            **kwargs,
+        )
+        self.embedding_provider = embedding_provider
+        self.vector_db_provider = vector_db_provider
+
+    async def search(
+        self,
+        message: str,
+        run_id: uuid.UUID,
+        vector_search_settings: VectorSearchSettings,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[VectorSearchResult, None]:
+        await self.enqueue_log(
+            run_id=run_id, key="search_query", value=message
+        )
+        search_filters = (
+            vector_search_settings.search_filters or self.config.search_filters
+        )
+        search_limit = (
+            vector_search_settings.search_limit or self.config.search_limit
+        )
+        results = []
+        query_vector = self.embedding_provider.get_embedding(
+            message,
+        )
+        search_results = (
+            self.vector_db_provider.hybrid_search(
+                query_vector=query_vector,
+                query_text=message,
+                filters=search_filters,
+                limit=search_limit,
+            )
+            if vector_search_settings.do_hybrid_search
+            else self.vector_db_provider.search(
+                query_vector=query_vector,
+                filters=search_filters,
+                limit=search_limit,
+            )
+        )
+        reranked_results = self.embedding_provider.rerank(
+            query=message, results=search_results, limit=search_limit
+        )
+        for result in reranked_results:
+            result.metadata["associatedQuery"] = message
+            results.append(result)
+            yield result
+        await self.enqueue_log(
+            run_id=run_id,
+            key="search_results",
+            value=json.dumps([ele.json() for ele in results]),
+        )
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[VectorSearchResult, None]:
+        search_queries = []
+        search_results = []
+        async for search_request in input.message:
+            search_queries.append(search_request)
+            async for result in self.search(
+                message=search_request,
+                run_id=run_id,
+                vector_search_settings=vector_search_settings,
+                *args,
+                **kwargs,
+            ):
+                search_results.append(result)
+                yield result
+
+        await state.update(
+            self.config.name, {"output": {"search_results": search_results}}
+        )
+
+        await state.update(
+            self.config.name,
+            {
+                "output": {
+                    "search_queries": search_queries,
+                    "search_results": search_results,
+                }
+            },
+        )
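In search(), per-request settings override the pipe's configured defaults only when they are set. A standalone sketch of that fallback pattern; the two dataclasses are simplified stand-ins for VectorSearchSettings and SearchPipe.SearchConfig, not the real r2r types:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Settings:  # simplified stand-in for VectorSearchSettings
        search_limit: Optional[int] = None

    @dataclass
    class Config:  # simplified stand-in for SearchPipe.SearchConfig
        search_limit: int = 10

    settings, config = Settings(), Config()
    limit = settings.search_limit or config.search_limit  # falls back to the pipe config
    print(limit)  # 10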