import asyncio
import copy
import json
import logging
import uuid
from typing import Any, AsyncGenerator, Optional

from r2r.base import (
    AsyncState,
    Extraction,
    Fragment,
    FragmentType,
    KGExtraction,
    KGProvider,
    KVLoggingSingleton,
    LLMProvider,
    PipeType,
    PromptProvider,
    TextSplitter,
    extract_entities,
    extract_triples,
    generate_id_from_label,
)
from r2r.base.pipes.base_pipe import AsyncPipe

logger = logging.getLogger(__name__)


class ClientError(Exception):
    """Base class for client connection errors."""

    pass


class KGExtractionPipe(AsyncPipe):
    """
    Embeds and stores documents using a specified embedding model and database.
    """

    def __init__(
        self,
        kg_provider: KGProvider,
        llm_provider: LLMProvider,
        prompt_provider: PromptProvider,
        text_splitter: TextSplitter,
        kg_batch_size: int = 1,
        id_prefix: str = "demo",
        pipe_logger: Optional[KVLoggingSingleton] = None,
        type: PipeType = PipeType.INGESTOR,
        config: Optional[AsyncPipe.PipeConfig] = None,
        *args,
        **kwargs,
    ):
        """
        Initializes the embedding pipe with necessary components and configurations.
        """
        super().__init__(
            pipe_logger=pipe_logger,
            type=type,
            config=config
            or AsyncPipe.PipeConfig(name="default_kg_extraction_pipe"),
        )

        self.kg_provider = kg_provider
        self.prompt_provider = prompt_provider
        self.llm_provider = llm_provider
        self.text_splitter = text_splitter
        self.kg_batch_size = kg_batch_size
        self.id_prefix = id_prefix
        self.pipe_run_info = None

    async def fragment(
        self, extraction: Extraction, run_id: uuid.UUID
    ) -> AsyncGenerator[Fragment, None]:
        """
        Splits text into manageable chunks for embedding.
        """
        if not isinstance(extraction, Extraction):
            raise ValueError(
                f"Expected an Extraction, but received {type(extraction)}."
            )
        if not isinstance(extraction.data, str):
            raise ValueError(
                f"Expected a string, but received {type(extraction.data)}."
            )
        text_chunks = [
            ele.page_content
            for ele in self.text_splitter.create_documents([extraction.data])
        ]
        for iteration, chunk in enumerate(text_chunks):
            fragment = Fragment(
                id=generate_id_from_label(f"{extraction.id}-{iteration}"),
                type=FragmentType.TEXT,
                data=chunk,
                metadata=copy.deepcopy(extraction.metadata),
                extraction_id=extraction.id,
                document_id=extraction.document_id,
            )
            yield fragment
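
    # Illustrative note: fragment ids are derived from the parent extraction
    # id and chunk index via generate_id_from_label(f"{extraction.id}-{iteration}"),
    # so, assuming generate_id_from_label is deterministic (as its name
    # suggests), re-ingesting the same document reproduces the same fragment ids.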

    async def transform_fragments(
        self, fragments: AsyncGenerator[Fragment, None]
    ) -> AsyncGenerator[Fragment, None]:
        """
        Transforms text chunks based on their metadata, e.g., adding prefixes.
        """
        async for fragment in fragments:
            if "chunk_prefix" in fragment.metadata:
                prefix = fragment.metadata.pop("chunk_prefix")
                fragment.data = f"{prefix}\n{fragment.data}"
            yield fragment
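
    # Illustrative example: a fragment whose metadata contains
    #   {"chunk_prefix": "Title: Q1 Report"}
    # and whose data is "Revenue grew 10%." is yielded with data
    # "Title: Q1 Report\nRevenue grew 10%." and the prefix key removed.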

    async def extract_kg(
        self,
        fragment: Fragment,
        retries: int = 3,
        delay: int = 2,
    ) -> KGExtraction:
        """
        Extracts NER triples from a list of fragments with retries.
        """
        task_prompt = self.prompt_provider.get_prompt(
            self.kg_provider.config.kg_extraction_prompt,
            inputs={"input": fragment.data},
        )
        messages = self.prompt_provider._get_message_payload(
            self.prompt_provider.get_prompt("default_system"), task_prompt
        )
        for attempt in range(retries):
            try:
                response = await self.llm_provider.aget_completion(
                    messages, self.kg_provider.config.kg_extraction_config
                )

                kg_extraction = response.choices[0].message.content

                # The model may wrap its JSON in a ```json fence; strip the
                # fence if present, otherwise parse the raw response.
                kg_json = (
                    json.loads(
                        kg_extraction.split("```json")[1].split("```")[0]
                    )
                    if "```json" in kg_extraction
                    else json.loads(kg_extraction)
                )
                llm_payload = kg_json.get("entities_and_triples", {})

                # Extract entities and triples from the parsed payload
                entities = extract_entities(llm_payload)
                triples = extract_triples(llm_payload, entities)

                # Create KG extraction object
                return KGExtraction(entities=entities, triples=triples)
            except (
                ClientError,
                json.JSONDecodeError,
                KeyError,
                IndexError,
            ) as e:
                logger.error(f"Error in extract_kg: {e}")
                if attempt < retries - 1:
                    await asyncio.sleep(delay)
                else:
                    logger.error(f"Failed after retries with {e}")
                    # raise e  # Ensure the exception is raised after the final attempt

        return KGExtraction(entities={}, triples=[])
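
    # The parsing above expects the completion to contain JSON, optionally
    # wrapped in a ```json fence, with an "entities_and_triples" key. A
    # hypothetical (illustrative, not normative) payload:
    #
    #   ```json
    #   {"entities_and_triples": ["[1], ORGANIZATION:Acme", "[1] sells [2]"]}
    #   ```
    #
    # The concrete encoding of entities and triples is whatever
    # `extract_entities` / `extract_triples` expect from the extraction prompt.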

    async def _process_batch(
        self,
        fragment_batch: list[Fragment],
    ) -> list[KGExtraction]:
        """
        Embeds a batch of fragments and yields vector entries.
        """
        tasks = [
            asyncio.create_task(self.extract_kg(fragment))
            for fragment in fragment_batch
        ]
        return await asyncio.gather(*tasks)
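
    # Note: `asyncio.gather` preserves input order, so the returned list of
    # KGExtractions aligns one-to-one with `fragment_batch`.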

    async def _run_logic(
        self,
        input: AsyncPipe.Input,
        state: AsyncState,
        run_id: uuid.UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[KGExtraction, None]:
        """
        Executes the embedding pipe: chunking, transforming, embedding, and storing documents.
        """
        batch_tasks = []
        fragment_batch = []

        fragment_info = {}
        async for extraction in input.message:
            async for fragment in self.transform_fragments(
                self.fragment(extraction, run_id)
            ):
                if extraction.document_id in fragment_info:
                    fragment_info[extraction.document_id] += 1
                else:
                    fragment_info[extraction.document_id] = 1
                # Record the chunk order on the fragment itself; its metadata
                # was deep-copied in `fragment`, so mutating the extraction's
                # metadata here would not reach fragments already created.
                fragment.metadata["chunk_order"] = fragment_info[
                    extraction.document_id
                ]
                fragment_batch.append(fragment)
                if len(fragment_batch) >= self.kg_batch_size:
                    # `_process_batch` returns a coroutine; collect it here
                    # and let `asyncio.as_completed` schedule it below.
                    batch_tasks.append(
                        self._process_batch(fragment_batch.copy())
                    )
                    fragment_batch.clear()  # Start a fresh batch

        logger.debug(
            f"Fragment counts per document id: {fragment_info}"
        )

        if fragment_batch:  # Process any remaining fragments
            batch_tasks.append(self._process_batch(fragment_batch.copy()))

        # Process tasks as they complete
        for task in asyncio.as_completed(batch_tasks):
            batch_result = await task  # Wait for the next task to complete
            for kg_extraction in batch_result:
                yield kg_extraction
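

# Illustrative usage sketch, not part of the pipe. It assumes already-configured
# provider instances, that `AsyncPipe.Input` wraps an async generator of
# Extractions in its `message` field (as `_run_logic` consumes `input.message`),
# and that `AsyncState()` can be constructed with no arguments.
async def example_run_kg_extraction(
    kg_provider: KGProvider,
    llm_provider: LLMProvider,
    prompt_provider: PromptProvider,
    text_splitter: TextSplitter,
    extractions: AsyncGenerator[Extraction, None],
) -> list[KGExtraction]:
    pipe = KGExtractionPipe(
        kg_provider=kg_provider,
        llm_provider=llm_provider,
        prompt_provider=prompt_provider,
        text_splitter=text_splitter,
        kg_batch_size=4,
    )
    # Drain the pipe and collect each extracted knowledge graph.
    return [
        kg_extraction
        async for kg_extraction in pipe._run_logic(
            input=AsyncPipe.Input(message=extractions),
            state=AsyncState(),
            run_id=uuid.uuid4(),
        )
    ]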