path: root/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
author  S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit  4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree  ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
parent  cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/pipes/ingestion/kg_extraction_pipe.py')
-rwxr-xr-x  R2R/r2r/pipes/ingestion/kg_extraction_pipe.py  226
1 file changed, 226 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py b/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
new file mode 100755
index 00000000..13025e39
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
@@ -0,0 +1,226 @@
+import asyncio
+import copy
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+ AsyncState,
+ Extraction,
+ Fragment,
+ FragmentType,
+ KGExtraction,
+ KGProvider,
+ KVLoggingSingleton,
+ LLMProvider,
+ PipeType,
+ PromptProvider,
+ TextSplitter,
+ extract_entities,
+ extract_triples,
+ generate_id_from_label,
+)
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class ClientError(Exception):
+ """Base class for client connection errors."""
+
+ pass
+
+
+class KGExtractionPipe(AsyncPipe):
+ """
+    Extracts knowledge graph entities and triples from document fragments using the configured LLM and prompt providers.
+ """
+
+ def __init__(
+ self,
+ kg_provider: KGProvider,
+ llm_provider: LLMProvider,
+ prompt_provider: PromptProvider,
+ text_splitter: TextSplitter,
+ kg_batch_size: int = 1,
+ id_prefix: str = "demo",
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ type: PipeType = PipeType.INGESTOR,
+ config: Optional[AsyncPipe.PipeConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ """
+        Initializes the knowledge graph extraction pipe with its providers, text splitter, and batch configuration.
+ """
+ super().__init__(
+ pipe_logger=pipe_logger,
+ type=type,
+ config=config
+            or AsyncPipe.PipeConfig(name="default_kg_extraction_pipe"),
+ )
+
+ self.kg_provider = kg_provider
+ self.prompt_provider = prompt_provider
+ self.llm_provider = llm_provider
+ self.text_splitter = text_splitter
+ self.kg_batch_size = kg_batch_size
+ self.id_prefix = id_prefix
+ self.pipe_run_info = None
+
+ async def fragment(
+ self, extraction: Extraction, run_id: uuid.UUID
+ ) -> AsyncGenerator[Fragment, None]:
+ """
+        Splits an extraction's text into manageable fragments for knowledge graph extraction.
+ """
+ if not isinstance(extraction, Extraction):
+ raise ValueError(
+ f"Expected an Extraction, but received {type(extraction)}."
+ )
+ if not isinstance(extraction.data, str):
+ raise ValueError(
+ f"Expected a string, but received {type(extraction.data)}."
+ )
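+        # Split the extraction text into chunks using the injected text splitter.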
+ text_chunks = [
+ ele.page_content
+ for ele in self.text_splitter.create_documents([extraction.data])
+ ]
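+        # Wrap each chunk in a Fragment with a deterministic id and a deep copy of the extraction metadata.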
+ for iteration, chunk in enumerate(text_chunks):
+ fragment = Fragment(
+ id=generate_id_from_label(f"{extraction.id}-{iteration}"),
+ type=FragmentType.TEXT,
+ data=chunk,
+ metadata=copy.deepcopy(extraction.metadata),
+ extraction_id=extraction.id,
+ document_id=extraction.document_id,
+ )
+ yield fragment
+
+ async def transform_fragments(
+        self, fragments: AsyncGenerator[Fragment, None]
+ ) -> AsyncGenerator[Fragment, None]:
+ """
+ Transforms text chunks based on their metadata, e.g., adding prefixes.
+ """
+ async for fragment in fragments:
+ if "chunk_prefix" in fragment.metadata:
+ prefix = fragment.metadata.pop("chunk_prefix")
+ fragment.data = f"{prefix}\n{fragment.data}"
+ yield fragment
+
+ async def extract_kg(
+ self,
+ fragment: Fragment,
+ retries: int = 3,
+ delay: int = 2,
+ ) -> KGExtraction:
+ """
+        Extracts entities and triples from a single fragment, retrying the LLM call on failure.
+ """
+ task_prompt = self.prompt_provider.get_prompt(
+ self.kg_provider.config.kg_extraction_prompt,
+ inputs={"input": fragment.data},
+ )
+ messages = self.prompt_provider._get_message_payload(
+ self.prompt_provider.get_prompt("default_system"), task_prompt
+ )
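+        # Try the LLM call up to `retries` times, sleeping `delay` seconds between attempts.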
+ for attempt in range(retries):
+ try:
+ response = await self.llm_provider.aget_completion(
+ messages, self.kg_provider.config.kg_extraction_config
+ )
+
+ kg_extraction = response.choices[0].message.content
+
+ # Parsing JSON from the response
+ kg_json = (
+ json.loads(
+ kg_extraction.split("```json")[1].split("```")[0]
+ )
+ if """```json""" in kg_extraction
+ else json.loads(kg_extraction)
+ )
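+                # The parsed JSON is expected to carry its payload under the "entities_and_triples" key.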
+ llm_payload = kg_json.get("entities_and_triples", {})
+
+ # Extract triples with detailed logging
+ entities = extract_entities(llm_payload)
+ triples = extract_triples(llm_payload, entities)
+
+ # Create KG extraction object
+ return KGExtraction(entities=entities, triples=triples)
+ except (
+ ClientError,
+ json.JSONDecodeError,
+ KeyError,
+ IndexError,
+ ) as e:
+ logger.error(f"Error in extract_kg: {e}")
+ if attempt < retries - 1:
+ await asyncio.sleep(delay)
+ else:
+ logger.error(f"Failed after retries with {e}")
+ # raise e # Ensure the exception is raised after the final attempt
+
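+        # All retries failed: return an empty extraction so the rest of the pipeline can continue.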
+ return KGExtraction(entities={}, triples=[])
+
+ async def _process_batch(
+ self,
+ fragment_batch: list[Fragment],
+ ) -> list[KGExtraction]:
+ """
+        Runs knowledge graph extraction concurrently over a batch of fragments and returns the results.
+ """
+ tasks = [
+ asyncio.create_task(self.extract_kg(fragment))
+ for fragment in fragment_batch
+ ]
+ return await asyncio.gather(*tasks)
+
+ async def _run_logic(
+ self,
+ input: AsyncPipe.Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[KGExtraction, None]:
+ """
+        Executes the knowledge graph extraction pipe: fragments each incoming extraction, transforms the fragments, and yields the extracted entities and triples.
+ """
+ batch_tasks = []
+ fragment_batch = []
+
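+        # Track how many fragments each document has produced; used to set chunk_order.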
+ fragment_info = {}
+ async for extraction in input.message:
+ async for fragment in self.transform_fragments(
+ self.fragment(extraction, run_id)
+ ):
+ if extraction.document_id in fragment_info:
+ fragment_info[extraction.document_id] += 1
+ else:
+ fragment_info[extraction.document_id] = 1
+                # Record the fragment's position within its document on the fragment itself;
+                # its metadata was deep-copied when the fragment was created.
+                fragment.metadata["chunk_order"] = fragment_info[
+                    extraction.document_id
+                ]
+ fragment_batch.append(fragment)
+ if len(fragment_batch) >= self.kg_batch_size:
+                    # _process_batch returns a coroutine; it is awaited later via asyncio.as_completed.
+                    batch_tasks.append(
+                        self._process_batch(fragment_batch.copy())
+                    )  # pass a copy so clearing the batch below does not affect the scheduled work
+ fragment_batch.clear() # Clear the batch for new fragments
+
+ logger.debug(
+ f"Fragmented the input document ids into counts as shown: {fragment_info}"
+ )
+
+ if fragment_batch: # Process any remaining fragments
+ batch_tasks.append(self._process_batch(fragment_batch.copy()))
+
+ # Process tasks as they complete
+ for task in asyncio.as_completed(batch_tasks):
+ batch_result = await task # Wait for the next task to complete
+ for kg_extraction in batch_result:
+ yield kg_extraction