author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/pipes/ingestion/kg_extraction_pipe.py')
-rwxr-xr-x  R2R/r2r/pipes/ingestion/kg_extraction_pipe.py  226
1 file changed, 226 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py b/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
new file mode 100755
index 00000000..13025e39
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
@@ -0,0 +1,226 @@
+import asyncio
+import copy
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncState,
+    Extraction,
+    Fragment,
+    FragmentType,
+    KGExtraction,
+    KGProvider,
+    KVLoggingSingleton,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+    TextSplitter,
+    extract_entities,
+    extract_triples,
+    generate_id_from_label,
+)
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class ClientError(Exception):
+    """Base class for client connection errors."""
+
+    pass
+
+
+class KGExtractionPipe(AsyncPipe):
+    """
+    Extracts knowledge graph entities and triples from document fragments
+    using the configured LLM, prompt, and knowledge graph providers.
+    """
+
+    def __init__(
+        self,
+        kg_provider: KGProvider,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        text_splitter: TextSplitter,
+        kg_batch_size: int = 1,
+        id_prefix: str = "demo",
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        type: PipeType = PipeType.INGESTOR,
+        config: Optional[AsyncPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initializes the KG extraction pipe with necessary components and configurations.
+        """
+        super().__init__(
+            pipe_logger=pipe_logger,
+            type=type,
+            config=config
+            or AsyncPipe.PipeConfig(name="default_embedding_pipe"),
+        )
+
+        self.kg_provider = kg_provider
+        self.prompt_provider = prompt_provider
+        self.llm_provider = llm_provider
+        self.text_splitter = text_splitter
+        self.kg_batch_size = kg_batch_size
+        self.id_prefix = id_prefix
+        self.pipe_run_info = None
+
+    async def fragment(
+        self, extraction: Extraction, run_id: uuid.UUID
+    ) -> AsyncGenerator[Fragment, None]:
+        """
+        Splits extraction text into manageable fragments for knowledge graph extraction.
+        """
+        if not isinstance(extraction, Extraction):
+            raise ValueError(
+                f"Expected an Extraction, but received {type(extraction)}."
+            )
+        if not isinstance(extraction.data, str):
+            raise ValueError(
+                f"Expected a string, but received {type(extraction.data)}."
+            )
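+        # Split the extraction text into chunks using the configured text splitter.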
+        text_chunks = [
+            ele.page_content
+            for ele in self.text_splitter.create_documents([extraction.data])
+        ]
+        for iteration, chunk in enumerate(text_chunks):
+            fragment = Fragment(
+                id=generate_id_from_label(f"{extraction.id}-{iteration}"),
+                type=FragmentType.TEXT,
+                data=chunk,
+                metadata=copy.deepcopy(extraction.metadata),
+                extraction_id=extraction.id,
+                document_id=extraction.document_id,
+            )
+            yield fragment
+
+    async def transform_fragments(
+        self, fragments: AsyncGenerator[Fragment, None]
+    ) -> AsyncGenerator[Fragment, None]:
+        """
+        Transforms text chunks based on their metadata, e.g., adding prefixes.
+        """
+        async for fragment in fragments:
+            if "chunk_prefix" in fragment.metadata:
+                prefix = fragment.metadata.pop("chunk_prefix")
+                fragment.data = f"{prefix}\n{fragment.data}"
+            yield fragment
+
+    async def extract_kg(
+        self,
+        fragment: Fragment,
+        retries: int = 3,
+        delay: int = 2,
+    ) -> KGExtraction:
+        """
+        Extracts entities and triples from a single fragment, retrying on transient failures.
+        """
+        task_prompt = self.prompt_provider.get_prompt(
+            self.kg_provider.config.kg_extraction_prompt,
+            inputs={"input": fragment.data},
+        )
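+        # Combine the default system prompt with the KG extraction task prompt
+        # into a chat message payload.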
+        messages = self.prompt_provider._get_message_payload(
+            self.prompt_provider.get_prompt("default_system"), task_prompt
+        )
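+        # Attempt the extraction up to `retries` times, sleeping `delay` seconds
+        # between failed attempts.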
+        for attempt in range(retries):
+            try:
+                response = await self.llm_provider.aget_completion(
+                    messages, self.kg_provider.config.kg_extraction_config
+                )
+
+                kg_extraction = response.choices[0].message.content
+
+                # Parsing JSON from the response
+                kg_json = (
+                    json.loads(
+                        kg_extraction.split("```json")[1].split("```")[0]
+                    )
+                    if "```json" in kg_extraction
+                    else json.loads(kg_extraction)
+                )
+                llm_payload = kg_json.get("entities_and_triples", {})
+
+                # Extract triples with detailed logging
+                entities = extract_entities(llm_payload)
+                triples = extract_triples(llm_payload, entities)
+
+                # Create KG extraction object
+                return KGExtraction(entities=entities, triples=triples)
+            except (
+                ClientError,
+                json.JSONDecodeError,
+                KeyError,
+                IndexError,
+            ) as e:
+                logger.error(f"Error in extract_kg: {e}")
+                if attempt < retries - 1:
+                    await asyncio.sleep(delay)
+                else:
+                    logger.error(f"Failed after {retries} attempts: {e}")
+                    # raise e  # Ensure the exception is raised after the final attempt
+
+        return KGExtraction(entities={}, triples=[])
+
+    async def _process_batch(
+        self,
+        fragment_batch: list[Fragment],
+    ) -> list[KGExtraction]:
+        """
+        Extracts knowledge graph data from a batch of fragments concurrently and returns the results.
+        """
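+        # Run KG extraction for every fragment in the batch concurrently.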
+        tasks = [
+            asyncio.create_task(self.extract_kg(fragment))
+            for fragment in fragment_batch
+        ]
+        return await asyncio.gather(*tasks)
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[KGExtraction, None]:
+        """
+        Executes the KG extraction pipe: fragments incoming extractions, transforms
+        the fragments, and extracts entities and triples in batches.
+        """
+        batch_tasks = []
+        fragment_batch = []
+
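+        # Track how many fragments have been produced for each document id.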
+        fragment_info = {}
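+        # Consume extractions from the input stream, fragment and transform them,
+        # and group the resulting fragments into batches.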
+        async for extraction in input.message:
+            async for fragment in self.transform_fragments(
+                self.fragment(extraction, run_id)
+            ):
+                if extraction.document_id in fragment_info:
+                    fragment_info[extraction.document_id] += 1
+                else:
+                    fragment_info[extraction.document_id] = 1
+                # Record the fragment's position within its document on the
+                # fragment itself rather than on the shared extraction metadata.
+                fragment.metadata["chunk_order"] = fragment_info[
+                    extraction.document_id
+                ]
+                fragment_batch.append(fragment)
+                if len(fragment_batch) >= self.kg_batch_size:
+                    # Calling `_process_batch` returns a coroutine that is awaited
+                    # later via `asyncio.as_completed`; pass a copy of the batch
+                    # because the list is cleared immediately afterwards.
+                    batch_tasks.append(
+                        self._process_batch(fragment_batch.copy())
+                    )
+                    fragment_batch.clear()  # Reset the batch for new fragments
+
+        logger.debug(
+            f"Fragment counts per input document id: {fragment_info}"
+        )
+
+        if fragment_batch:  # Process any remaining fragments
+            batch_tasks.append(self._process_batch(fragment_batch.copy()))
+
+        # Process tasks as they complete
+        for task in asyncio.as_completed(batch_tasks):
+            batch_result = await task  # Wait for the next task to complete
+            for kg_extraction in batch_result:
+                yield kg_extraction