Diffstat (limited to 'R2R/r2r/base/abstractions/document.py')
-rwxr-xr-x   R2R/r2r/base/abstractions/document.py   242
1 file changed, 242 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/base/abstractions/document.py b/R2R/r2r/base/abstractions/document.py
new file mode 100755
index 00000000..117db7b9
--- /dev/null
+++ b/R2R/r2r/base/abstractions/document.py
@@ -0,0 +1,242 @@
+"""Abstractions for documents and their extractions."""
+
+import base64
+import json
+import logging
+import uuid
+from datetime import datetime
+from enum import Enum
+from typing import Optional, Union
+
+from pydantic import BaseModel, Field
+
+logger = logging.getLogger(__name__)
+
+DataType = Union[str, bytes]
+
+
+class DocumentType(str, Enum):
+    """Types of documents that can be stored."""
+
+    CSV = "csv"
+    DOCX = "docx"
+    HTML = "html"
+    JSON = "json"
+    MD = "md"
+    PDF = "pdf"
+    PPTX = "pptx"
+    TXT = "txt"
+    XLSX = "xlsx"
+    GIF = "gif"
+    PNG = "png"
+    JPG = "jpg"
+    JPEG = "jpeg"
+    SVG = "svg"
+    MP3 = "mp3"
+    MP4 = "mp4"
+
+
+class Document(BaseModel):
+    id: uuid.UUID = Field(default_factory=uuid.uuid4)
+    type: DocumentType
+    data: Union[str, bytes]
+    metadata: dict
+
+    def __init__(self, *args, **kwargs):
+        data = kwargs.get("data")
+        if data and isinstance(data, str):
+            try:
+                # Try to decode if it's already base64 encoded
+                kwargs["data"] = base64.b64decode(data)
+            except Exception:
+                # If it's not base64, encode it to bytes
+                kwargs["data"] = data.encode("utf-8")
+
+        doc_type = kwargs.get("type")
+        if isinstance(doc_type, str):
+            kwargs["type"] = DocumentType(doc_type)
+
+        # Generate UUID based on the hash of the data
+        if "id" not in kwargs:
+            if isinstance(kwargs["data"], bytes):
+                data_hash = uuid.uuid5(
+                    uuid.NAMESPACE_DNS, kwargs["data"].decode("utf-8")
+                )
+            else:
+                data_hash = uuid.uuid5(uuid.NAMESPACE_DNS, kwargs["data"])
+
+            kwargs["id"] = data_hash  # Set the id based on the data hash
+
+        super().__init__(*args, **kwargs)
+
+    class Config:
+        arbitrary_types_allowed = True
+        json_encoders = {
+            uuid.UUID: str,
+            bytes: lambda v: base64.b64encode(v).decode("utf-8"),
+        }
+
+
+class DocumentStatus(str, Enum):
+    """Status of document processing."""
+
+    PROCESSING = "processing"
+    # TODO - Extend support for `partial-failure`
+    # PARTIAL_FAILURE = "partial-failure"
+    FAILURE = "failure"
+    SUCCESS = "success"
+
+
+class DocumentInfo(BaseModel):
+    """Base class for document information handling."""
+
+    document_id: uuid.UUID
+    version: str
+    size_in_bytes: int
+    metadata: dict
+    status: DocumentStatus = DocumentStatus.PROCESSING
+
+    user_id: Optional[uuid.UUID] = None
+    title: Optional[str] = None
+    created_at: Optional[datetime] = None
+    updated_at: Optional[datetime] = None
+
+    def convert_to_db_entry(self):
+        """Prepare the document info for database entry, extracting certain fields from metadata."""
+        now = datetime.now()
+        metadata = self.metadata
+        if "user_id" in metadata:
+            metadata["user_id"] = str(metadata["user_id"])
+
+        metadata["title"] = metadata.get("title", "N/A")
+        return {
+            "document_id": str(self.document_id),
+            "title": metadata.get("title", "N/A"),
+            "user_id": metadata.get("user_id", None),
+            "version": self.version,
+            "size_in_bytes": self.size_in_bytes,
+            "metadata": json.dumps(self.metadata),
+            "created_at": self.created_at or now,
+            "updated_at": self.updated_at or now,
+            "status": self.status,
+        }
+
+
+class ExtractionType(Enum):
+    """Types of extractions that can be performed."""
+
+    TXT = "txt"
+    IMG = "img"
+    MOV = "mov"
+
+
+class Extraction(BaseModel):
+    """An extraction from a document."""
+
+    id: uuid.UUID
+    type: ExtractionType = ExtractionType.TXT
+    data: DataType
+    metadata: dict
+    document_id: uuid.UUID
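
# A hypothetical usage sketch, not part of the patch. It assumes pydantic v1
# semantics (implied by the Config/json_encoders pattern above) and an import
# path inferred from the file location. Because Document.__init__ derives the
# id with uuid.uuid5 from the content itself, identical data always maps to
# the same UUID, which makes re-ingestion of the same document idempotent:

from r2r.base.abstractions.document import Document, DocumentType  # assumed path

doc_a = Document(type="txt", data="hello world", metadata={"source": "demo"})
doc_b = Document(type=DocumentType.TXT, data="hello world", metadata={})

assert doc_a.id == doc_b.id  # same content, same deterministic uuid5 id
print(doc_a.json())          # UUID (and bytes) fields serialize via json_encoders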
"""A type of fragment that can be extracted from a document.""" + + TEXT = "text" + IMAGE = "image" + + +class Fragment(BaseModel): + """A fragment extracted from a document.""" + + id: uuid.UUID + type: FragmentType + data: DataType + metadata: dict + document_id: uuid.UUID + extraction_id: uuid.UUID + + +class Entity(BaseModel): + """An entity extracted from a document.""" + + category: str + subcategory: Optional[str] = None + value: str + + def __str__(self): + return ( + f"{self.category}:{self.subcategory}:{self.value}" + if self.subcategory + else f"{self.category}:{self.value}" + ) + + +class Triple(BaseModel): + """A triple extracted from a document.""" + + subject: str + predicate: str + object: str + + +def extract_entities(llm_payload: list[str]) -> dict[str, Entity]: + entities = {} + for entry in llm_payload: + try: + if "], " in entry: # Check if the entry is an entity + entry_val = entry.split("], ")[0] + "]" + entry = entry.split("], ")[1] + colon_count = entry.count(":") + + if colon_count == 1: + category, value = entry.split(":") + subcategory = None + elif colon_count >= 2: + parts = entry.split(":", 2) + category, subcategory, value = ( + parts[0], + parts[1], + parts[2], + ) + else: + raise ValueError("Unexpected entry format") + + entities[entry_val] = Entity( + category=category, subcategory=subcategory, value=value + ) + except Exception as e: + logger.error(f"Error processing entity {entry}: {e}") + continue + return entities + + +def extract_triples( + llm_payload: list[str], entities: dict[str, Entity] +) -> list[Triple]: + triples = [] + for entry in llm_payload: + try: + if "], " not in entry: # Check if the entry is an entity + elements = entry.split(" ") + subject = elements[0] + predicate = elements[1] + object = " ".join(elements[2:]) + subject = entities[subject].value # Use entity.value + if "[" in object and "]" in object: + object = entities[object].value # Use entity.value + triples.append( + Triple(subject=subject, predicate=predicate, object=object) + ) + except Exception as e: + logger.error(f"Error processing triplet {entry}: {e}") + continue + return triples + + +class KGExtraction(BaseModel): + """An extraction from a document that is part of a knowledge graph.""" + + entities: dict[str, Entity] + triples: list[Triple] |