# Knowledge-graph agent search pipe.
import logging
import uuid
from typing import Any, Optional

from r2r.base import (
    AsyncState,
    KGProvider,
    KGSearchSettings,
    KVLoggingSingleton,
    LLMProvider,
    PipeType,
    PromptProvider,
)

from ..abstractions.generator_pipe import GeneratorPipe

logger = logging.getLogger(__name__)


class KGAgentSearchPipe(GeneratorPipe):
    """
    Queries a knowledge graph through an LLM-written Cypher statement.

    For every incoming message the pipe renders the "kg_agent" prompt, asks
    the LLM provider for a completion, extracts the fenced ```cypher block
    from the reply and runs it via the KG provider's ``structured_query``.
    """

    def __init__(
        self,
        kg_provider: KGProvider,
        llm_provider: LLMProvider,
        prompt_provider: PromptProvider,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        type: PipeType = PipeType.INGESTOR,
        config: Optional[GeneratorPipe.PipeConfig] = None,
        *args,
        **kwargs,
    ):
        """
        Initializes the KG agent search pipe.

        Args:
            kg_provider: Provider whose ``structured_query`` executes the
                generated Cypher query.
            llm_provider: Provider used to obtain the completion that
                contains the Cypher query.
            prompt_provider: Provider used to render the task and system
                prompts.
            pipe_logger: Optional logger forwarded to the base pipe.
            type: Pipe type forwarded to the base pipe.
            config: Optional pipe configuration; when falsy, a
                ``GeneratorPipe.Config`` named "kg_rag_pipe" with task
                prompt "kg_agent" is used.
            *args: Extra positional arguments forwarded to the base pipe.
            **kwargs: Extra keyword arguments forwarded to the base pipe.
        """
        super().__init__(
            llm_provider=llm_provider,
            prompt_provider=prompt_provider,
            type=type,
            config=config
            or GeneratorPipe.Config(
                name="kg_rag_pipe", task_prompt="kg_agent"
            ),
            pipe_logger=pipe_logger,
            *args,
            **kwargs,
        )
        self.kg_provider = kg_provider
        self.llm_provider = llm_provider
        self.prompt_provider = prompt_provider
        self.pipe_run_info = None

    async def _run_logic(
        self,
        input: GeneratorPipe.Input,
        state: AsyncState,
        run_id: uuid.UUID,
        kg_search_settings: KGSearchSettings,
        *args: Any,
        **kwargs: Any,
    ):
        """
        Turns each input message into a Cypher query and executes it.

        Args:
            input: Pipe input whose ``message`` attribute is asynchronously
                iterated.
            state: Shared pipeline state (not used by this pipe).
            run_id: Identifier under which the log entries are enqueued.
            kg_search_settings: Settings supplying
                ``agent_generation_config`` for the LLM call.
            *args: Unused extra positional arguments.
            **kwargs: Unused extra keyword arguments.

        Yields:
            A ``(query, result)`` tuple per message, where ``query`` is the
            extracted Cypher text and ``result`` is whatever
            ``kg_provider.structured_query`` returned for it.

        Raises:
            ValueError: If the LLM reply contains no ```cypher block.
        """
        async for message in input.message:
            # TODO - Remove hard code
            formatted_prompt = self.prompt_provider.get_prompt(
                "kg_agent", {"input": message}
            )
            messages = self._get_message_payload(formatted_prompt)

            completion = await self.llm_provider.aget_completion(
                messages=messages,
                generation_config=kg_search_settings.agent_generation_config,
            )

            extraction = completion.choices[0].message.content
            query = self._extract_cypher(extraction)
            query_result = self.kg_provider.structured_query(query)

            # Log before yielding: code placed after a ``yield`` in an async
            # generator only runs if the consumer asks for the next item, so
            # the entries would be lost whenever iteration stops early.
            await self.enqueue_log(
                run_id=run_id,
                key="kg_agent_response",
                value=extraction,
            )

            await self.enqueue_log(
                run_id=run_id,
                key="kg_agent_execution_result",
                value=query_result,
            )

            yield (query, query_result)

    @staticmethod
    def _extract_cypher(extraction: Optional[str]) -> str:
        """
        Returns the text between the first ```cypher fence and the next ```.

        The text is returned verbatim (surrounding whitespace included).

        Raises:
            ValueError: If ``extraction`` is empty or has no ```cypher fence.
        """
        fence_open = "```cypher"
        if not extraction or fence_open not in extraction:
            raise ValueError(
                "LLM response does not contain a ```cypher code block: "
                f"{extraction!r}"
            )
        return extraction.split(fence_open)[1].split("```")[0]

    def _get_message_payload(self, message: str) -> list:
        """
        Builds the chat payload: the rendered system prompt followed by the
        given text as the user message.
        """
        return [
            {
                "role": "system",
                "content": self.prompt_provider.get_prompt(
                    self.config.system_prompt,
                ),
            },
            {"role": "user", "content": message},
        ]