author    | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit    | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
parent    | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to 'R2R/r2r/pipes/ingestion/kg_extraction_pipe.py')
-rwxr-xr-x | R2R/r2r/pipes/ingestion/kg_extraction_pipe.py | 226
1 file changed, 226 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py b/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
new file mode 100755
index 00000000..13025e39
--- /dev/null
+++ b/R2R/r2r/pipes/ingestion/kg_extraction_pipe.py
@@ -0,0 +1,226 @@
+import asyncio
+import copy
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Optional
+
+from r2r.base import (
+    AsyncState,
+    Extraction,
+    Fragment,
+    FragmentType,
+    KGExtraction,
+    KGProvider,
+    KVLoggingSingleton,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+    TextSplitter,
+    extract_entities,
+    extract_triples,
+    generate_id_from_label,
+)
+from r2r.base.pipes.base_pipe import AsyncPipe
+
+logger = logging.getLogger(__name__)
+
+
+class ClientError(Exception):
+    """Base class for client connection errors."""
+
+    pass
+
+
+class KGExtractionPipe(AsyncPipe):
+    """
+    Extracts knowledge-graph entities and triples from document fragments.
+    """
+
+    def __init__(
+        self,
+        kg_provider: KGProvider,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        text_splitter: TextSplitter,
+        kg_batch_size: int = 1,
+        id_prefix: str = "demo",
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        type: PipeType = PipeType.INGESTOR,
+        config: Optional[AsyncPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initializes the KG extraction pipe with necessary components and configurations.
+        """
+        super().__init__(
+            pipe_logger=pipe_logger,
+            type=type,
+            config=config
+            or AsyncPipe.PipeConfig(name="default_embedding_pipe"),
+        )
+
+        self.kg_provider = kg_provider
+        self.prompt_provider = prompt_provider
+        self.llm_provider = llm_provider
+        self.text_splitter = text_splitter
+        self.kg_batch_size = kg_batch_size
+        self.id_prefix = id_prefix
+        self.pipe_run_info = None
+
+    async def fragment(
+        self, extraction: Extraction, run_id: uuid.UUID
+    ) -> AsyncGenerator[Fragment, None]:
+        """
+        Splits an extraction's text into manageable fragments.
+        """
+        if not isinstance(extraction, Extraction):
+            raise ValueError(
+                f"Expected an Extraction, but received {type(extraction)}."
+            )
+        if not isinstance(extraction.data, str):
+            raise ValueError(
+                f"Expected a string, but received {type(extraction.data)}."
+            )
+        text_chunks = [
+            ele.page_content
+            for ele in self.text_splitter.create_documents([extraction.data])
+        ]
+        for iteration, chunk in enumerate(text_chunks):
+            fragment = Fragment(
+                id=generate_id_from_label(f"{extraction.id}-{iteration}"),
+                type=FragmentType.TEXT,
+                data=chunk,
+                metadata=copy.deepcopy(extraction.metadata),
+                extraction_id=extraction.id,
+                document_id=extraction.document_id,
+            )
+            yield fragment
+
+    async def transform_fragments(
+        self, fragments: AsyncGenerator[Fragment, None]
+    ) -> AsyncGenerator[Fragment, None]:
+        """
+        Transforms text chunks based on their metadata, e.g., adding prefixes.
+        """
+        async for fragment in fragments:
+            if "chunk_prefix" in fragment.metadata:
+                prefix = fragment.metadata.pop("chunk_prefix")
+                fragment.data = f"{prefix}\n{fragment.data}"
+            yield fragment
+
+    async def extract_kg(
+        self,
+        fragment: Fragment,
+        retries: int = 3,
+        delay: int = 2,
+    ) -> KGExtraction:
+        """
+        Extracts entities and triples from a single fragment, with retries.
+        """
+        task_prompt = self.prompt_provider.get_prompt(
+            self.kg_provider.config.kg_extraction_prompt,
+            inputs={"input": fragment.data},
+        )
+        messages = self.prompt_provider._get_message_payload(
+            self.prompt_provider.get_prompt("default_system"), task_prompt
+        )
+        for attempt in range(retries):
+            try:
+                response = await self.llm_provider.aget_completion(
+                    messages, self.kg_provider.config.kg_extraction_config
+                )
+
+                kg_extraction = response.choices[0].message.content
+
+                # Parsing JSON from the response
+                kg_json = (
+                    json.loads(
+                        kg_extraction.split("```json")[1].split("```")[0]
+                    )
+                    if """```json""" in kg_extraction
+                    else json.loads(kg_extraction)
+                )
+                llm_payload = kg_json.get("entities_and_triples", {})
+
+                # Extract triples with detailed logging
+                entities = extract_entities(llm_payload)
+                triples = extract_triples(llm_payload, entities)
+
+                # Create KG extraction object
+                return KGExtraction(entities=entities, triples=triples)
+            except (
+                ClientError,
+                json.JSONDecodeError,
+                KeyError,
+                IndexError,
+            ) as e:
+                logger.error(f"Error in extract_kg: {e}")
+                if attempt < retries - 1:
+                    await asyncio.sleep(delay)
+                else:
+                    logger.error(f"Failed after retries with {e}")
+                    # raise e  # Ensure the exception is raised after the final attempt
+
+        return KGExtraction(entities={}, triples=[])
+
+    async def _process_batch(
+        self,
+        fragment_batch: list[Fragment],
+    ) -> list[KGExtraction]:
+        """
+        Runs KG extraction concurrently over a batch of fragments.
+        """
+        tasks = [
+            asyncio.create_task(self.extract_kg(fragment))
+            for fragment in fragment_batch
+        ]
+        return await asyncio.gather(*tasks)
+
+    async def _run_logic(
+        self,
+        input: AsyncPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[KGExtraction, None]:
+        """
+        Executes the KG extraction pipe: fragmenting, transforming, and extracting triples from documents.
+        """
+        batch_tasks = []
+        fragment_batch = []
+
+        fragment_info = {}
+        async for extraction in input.message:
+            async for fragment in self.transform_fragments(
+                self.fragment(extraction, run_id)
+            ):
+                if extraction.document_id in fragment_info:
+                    fragment_info[extraction.document_id] += 1
+                else:
+                    fragment_info[extraction.document_id] = 1
+                extraction.metadata["chunk_order"] = fragment_info[
+                    extraction.document_id
+                ]
+                fragment_batch.append(fragment)
+                if len(fragment_batch) >= self.kg_batch_size:
+                    # Here, ensure `_process_batch` is scheduled as a coroutine, not called directly
+                    batch_tasks.append(
+                        self._process_batch(fragment_batch.copy())
+                    )  # pass a copy if necessary
+                    fragment_batch.clear()  # Clear the batch for new fragments
+
+        logger.debug(
+            f"Fragmented the input document ids into counts as shown: {fragment_info}"
+        )
+
+        if fragment_batch:  # Process any remaining fragments
+            batch_tasks.append(self._process_batch(fragment_batch.copy()))
+
+        # Process tasks as they complete
+        for task in asyncio.as_completed(batch_tasks):
+            batch_result = await task  # Wait for the next task to complete
+            for kg_extraction in batch_result:
+                yield kg_extraction