author    | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit    | 4a52a71956a8d46fcb7294ac71734504bb09bcc2
tree      | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/ingestion/r2r
parent    | cc961e04ba734dd72309fb548a2f97d67d578813
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/ingestion/r2r')
-rw-r--r-- | .venv/lib/python3.12/site-packages/core/providers/ingestion/r2r/base.py | 355 |
1 file changed, 355 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/core/providers/ingestion/r2r/base.py b/.venv/lib/python3.12/site-packages/core/providers/ingestion/r2r/base.py
new file mode 100644
index 00000000..7d452245
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/ingestion/r2r/base.py
@@ -0,0 +1,355 @@

# type: ignore
import logging
import time
from typing import Any, AsyncGenerator, Generator, Optional

from core import parsers
from core.base import (
    AsyncParser,
    ChunkingStrategy,
    Document,
    DocumentChunk,
    DocumentType,
    IngestionConfig,
    IngestionProvider,
    R2RDocumentProcessingError,
    RecursiveCharacterTextSplitter,
    TextSplitter,
)
from core.utils import generate_extraction_id

from ...database import PostgresDatabaseProvider
from ...llm import (
    LiteLLMCompletionProvider,
    OpenAICompletionProvider,
    R2RCompletionProvider,
)

logger = logging.getLogger()


class R2RIngestionConfig(IngestionConfig):
    chunk_size: int = 1024
    chunk_overlap: int = 512
    chunking_strategy: ChunkingStrategy = ChunkingStrategy.RECURSIVE
    extra_fields: dict[str, Any] = {}
    separator: Optional[str] = None
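These defaults can also be supplied per request: `_build_text_splitter` and `parse` below read an `ingestion_config_override` dict whose keys mirror this config. A sketch of such a dict, with `ChunkingStrategy` imported as above (the `"pdf"` key assumes `DocumentType.PDF.value == "pdf"`, which the `parser_overrides` lookups in `parse` imply):

ingestion_config_override = {
    "chunking_strategy": ChunkingStrategy.CHARACTER,  # falls back to config.chunking_strategy
    "chunk_size": 512,       # falls back to config.chunk_size (1024)
    "chunk_overlap": 128,    # falls back to config.chunk_overlap (512)
    "separator": "\n\n",     # only consulted for the CHARACTER strategy
    "parser_overrides": {"pdf": "zerox"},  # consulted by parse()
}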
f"R2RIngestionProvider initialized with config: {self.config}" + ) + + def _initialize_parsers(self): + for doc_type, parser in self.DEFAULT_PARSERS.items(): + # will choose the first parser in the list + if doc_type not in self.config.excluded_parsers: + self.parsers[doc_type] = parser( + config=self.config, + database_provider=self.database_provider, + llm_provider=self.llm_provider, + ) + for doc_type, doc_parser_name in self.config.extra_parsers.items(): + self.parsers[f"{doc_parser_name}_{str(doc_type)}"] = ( + R2RIngestionProvider.EXTRA_PARSERS[doc_type][doc_parser_name]( + config=self.config, + database_provider=self.database_provider, + llm_provider=self.llm_provider, + ) + ) + + def _build_text_splitter( + self, ingestion_config_override: Optional[dict] = None + ) -> TextSplitter: + logger.info( + f"Initializing text splitter with method: {self.config.chunking_strategy}" + ) + + if not ingestion_config_override: + ingestion_config_override = {} + + chunking_strategy = ( + ingestion_config_override.get("chunking_strategy") + or self.config.chunking_strategy + ) + + chunk_size = ( + ingestion_config_override.get("chunk_size") + or self.config.chunk_size + ) + chunk_overlap = ( + ingestion_config_override.get("chunk_overlap") + or self.config.chunk_overlap + ) + + if chunking_strategy == ChunkingStrategy.RECURSIVE: + return RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + elif chunking_strategy == ChunkingStrategy.CHARACTER: + from core.base.utils.splitter.text import CharacterTextSplitter + + separator = ( + ingestion_config_override.get("separator") + or self.config.separator + or CharacterTextSplitter.DEFAULT_SEPARATOR + ) + + return CharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + separator=separator, + keep_separator=False, + strip_whitespace=True, + ) + elif chunking_strategy == ChunkingStrategy.BASIC: + raise NotImplementedError( + "Basic chunking method not implemented. Please use Recursive." 
    def chunk(
        self,
        parsed_document: str | DocumentChunk,
        ingestion_config_override: dict,
    ) -> Generator[Any, None, None]:
        text_splitter = self.text_splitter
        if ingestion_config_override:
            text_splitter = self._build_text_splitter(
                ingestion_config_override
            )
        if isinstance(parsed_document, DocumentChunk):
            parsed_document = parsed_document.data

        if isinstance(parsed_document, str):
            chunks = text_splitter.create_documents([parsed_document])
        else:
            # Assuming parsed_document is already a list of text chunks
            chunks = parsed_document

        for chunk in chunks:
            yield (
                chunk.page_content if hasattr(chunk, "page_content") else chunk
            )
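Despite living in an async pipeline, `chunk` is a plain synchronous generator and can be driven directly. A sketch, assuming `provider` is an already-constructed `R2RIngestionProvider`:

pieces = provider.chunk(
    "one long parsed document ...",
    ingestion_config_override={"chunk_size": 256, "chunk_overlap": 32},
)
for piece in pieces:
    print(piece[:40])  # plain strings: page_content is unwrapped by chunk()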
    async def parse(
        self,
        file_content: bytes,
        document: Document,
        ingestion_config_override: dict,
    ) -> AsyncGenerator[DocumentChunk, None]:
        if document.document_type not in self.parsers:
            raise R2RDocumentProcessingError(
                document_id=document.id,
                error_message=f"Parser for {document.document_type} not found in `R2RIngestionProvider`.",
            )
        else:
            t0 = time.time()
            contents = []
            parser_overrides = ingestion_config_override.get(
                "parser_overrides", {}
            )
            if document.document_type.value in parser_overrides:
                logger.info(
                    f"Using parser_override for {document.document_type} with input value {parser_overrides[document.document_type.value]}"
                )
                # TODO - Cleanup this approach to be less hardcoded
                if (
                    document.document_type != DocumentType.PDF
                    or parser_overrides[DocumentType.PDF.value] != "zerox"
                ):
                    raise ValueError(
                        "Only Zerox PDF parser override is available."
                    )

                # Collect content from VLMPDFParser
                async for chunk in self.parsers[
                    f"zerox_{DocumentType.PDF.value}"
                ].ingest(file_content, **ingestion_config_override):
                    if isinstance(chunk, dict) and chunk.get("content"):
                        contents.append(chunk)
                    elif chunk:  # Handle string output for backward compatibility
                        contents.append({"content": chunk})

                if (
                    contents
                    and document.document_type == DocumentType.PDF
                    and parser_overrides.get(DocumentType.PDF.value) == "zerox"
                ):
                    text_splitter = self._build_text_splitter(
                        ingestion_config_override
                    )

                    iteration = 0

                    sorted_contents = [
                        item
                        for item in sorted(
                            contents, key=lambda x: x.get("page_number", 0)
                        )
                        if isinstance(item.get("content"), str)
                    ]

                    for content_item in sorted_contents:
                        page_num = content_item.get("page_number", 0)
                        page_content = content_item["content"]

                        page_chunks = text_splitter.create_documents(
                            [page_content]
                        )

                        # Create document chunks for each split piece
                        for chunk in page_chunks:
                            metadata = {
                                **document.metadata,
                                "chunk_order": iteration,
                                "page_number": page_num,
                            }

                            extraction = DocumentChunk(
                                id=generate_extraction_id(
                                    document.id, iteration
                                ),
                                document_id=document.id,
                                owner_id=document.owner_id,
                                collection_ids=document.collection_ids,
                                data=chunk.page_content,
                                metadata=metadata,
                            )
                            iteration += 1
                            yield extraction

                    logger.debug(
                        f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
                        f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
                        f"into {iteration} extractions in t={time.time() - t0:.2f} seconds using page-by-page splitting."
                    )
                    return

            else:
                # Standard parsing for non-override cases
                async for text in self.parsers[document.document_type].ingest(
                    file_content, **ingestion_config_override
                ):
                    if text is not None:
                        contents.append({"content": text})

            if not contents:
                logger.warning(
                    "No valid text content was extracted during parsing"
                )
                return

            iteration = 0
            for content_item in contents:
                chunk_text = content_item["content"]
                chunks = self.chunk(chunk_text, ingestion_config_override)

                for chunk in chunks:
                    metadata = {**document.metadata, "chunk_order": iteration}
                    if "page_number" in content_item:
                        metadata["page_number"] = content_item["page_number"]

                    extraction = DocumentChunk(
                        id=generate_extraction_id(document.id, iteration),
                        document_id=document.id,
                        owner_id=document.owner_id,
                        collection_ids=document.collection_ids,
                        data=chunk,
                        metadata=metadata,
                    )
                    iteration += 1
                    yield extraction

            logger.debug(
                f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
                f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
                f"into {iteration} extractions in t={time.time() - t0:.2f} seconds."
            )

    def get_parser_for_document_type(self, doc_type: DocumentType) -> Any:
        return self.parsers.get(doc_type)
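End to end, `parse` is an async generator of `DocumentChunk`s. A hedged sketch of driving it, where `provider`, `document`, and `file_bytes` are assumed to be wired up elsewhere (the database and LLM providers come from the surrounding package):

import asyncio

async def ingest(provider, document, file_bytes):
    # Route PDFs through the VLM-backed "zerox" parser; all other types fall
    # back to DEFAULT_PARSERS. "pdf" assumes DocumentType.PDF.value == "pdf".
    override = {"parser_overrides": {"pdf": "zerox"}}
    async for extraction in provider.parse(file_bytes, document, override):
        print(extraction.id, extraction.metadata["chunk_order"])

# asyncio.run(ingest(provider, document, file_bytes))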