about summary refs log tree commit diff
path: root/R2R/r2r/base/abstractions/document.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/base/abstractions/document.py')
-rwxr-xr-xR2R/r2r/base/abstractions/document.py242
1 files changed, 242 insertions, 0 deletions
diff --git a/R2R/r2r/base/abstractions/document.py b/R2R/r2r/base/abstractions/document.py
new file mode 100755
index 00000000..117db7b9
--- /dev/null
+++ b/R2R/r2r/base/abstractions/document.py
@@ -0,0 +1,242 @@
+"""Abstractions for documents and their extractions."""
+
+import base64
+import json
+import logging
+import uuid
+from datetime import datetime
+from enum import Enum
+from typing import Optional, Union
+
+from pydantic import BaseModel, Field
+
+logger = logging.getLogger(__name__)
+
+DataType = Union[str, bytes]
+
+
class DocumentType(str, Enum):
    """Types of documents that can be stored.

    Inherits from ``str`` so members compare equal to their plain
    string values (e.g. ``DocumentType.PDF == "pdf"``) and serialize
    cleanly as strings.
    """

    # Text / structured-data formats
    CSV = "csv"
    DOCX = "docx"
    HTML = "html"
    JSON = "json"
    MD = "md"
    PDF = "pdf"
    PPTX = "pptx"
    TXT = "txt"
    XLSX = "xlsx"
    # Image formats
    GIF = "gif"
    PNG = "png"
    JPG = "jpg"
    JPEG = "jpeg"
    SVG = "svg"
    # Audio / video formats
    MP3 = "mp3"
    MP4 = "mp4"
+
+
class Document(BaseModel):
    """A stored document: payload bytes plus type and metadata.

    String ``data`` is normalized to ``bytes``: it is first treated as
    base64 and decoded; if that fails it is UTF-8 encoded as-is. When no
    ``id`` is supplied, a deterministic UUIDv5 is derived from the
    payload so identical content always maps to the same document id.
    """

    id: uuid.UUID = Field(default_factory=uuid.uuid4)
    type: DocumentType
    data: Union[str, bytes]
    metadata: dict

    def __init__(self, *args, **kwargs):
        data = kwargs.get("data")
        if data and isinstance(data, str):
            try:
                # Interpret the string as base64 if it is strictly valid.
                # validate=True prevents ordinary prose (whitespace /
                # punctuation silently discarded by default) from being
                # mis-decoded and corrupted.
                kwargs["data"] = base64.b64decode(data, validate=True)
            except ValueError:
                # Not base64 (binascii.Error subclasses ValueError):
                # store the raw text as UTF-8 bytes. Narrow except —
                # the old bare `except:` also caught SystemExit etc.
                kwargs["data"] = data.encode("utf-8")

        doc_type = kwargs.get("type")
        if isinstance(doc_type, str):
            kwargs["type"] = DocumentType(doc_type)

        # Derive a content-based UUID so the same payload gets the same id.
        # Guard on "data" being present: if it is missing we fall through
        # and let pydantic raise a proper validation error instead of a
        # KeyError here.
        if "id" not in kwargs and "data" in kwargs:
            payload = kwargs["data"]
            if isinstance(payload, bytes):
                # errors="ignore" keeps binary payloads (e.g. PDFs) from
                # crashing with UnicodeDecodeError; UTF-8 payloads hash
                # exactly as before.
                payload = payload.decode("utf-8", errors="ignore")
            kwargs["id"] = uuid.uuid5(uuid.NAMESPACE_DNS, payload)

        super().__init__(*args, **kwargs)

    class Config:
        arbitrary_types_allowed = True
        # bytes payloads are emitted as base64 text in JSON output.
        json_encoders = {
            uuid.UUID: str,
            bytes: lambda v: base64.b64encode(v).decode("utf-8"),
        }
+
+
class DocumentStatus(str, Enum):
    """Status of document processing.

    A document starts in PROCESSING and ends in either SUCCESS or
    FAILURE; there is currently no intermediate/partial state.
    """

    PROCESSING = "processing"
    # TODO - Extend support for `partial-failure`
    # PARTIAL_FAILURE = "partial-failure"
    FAILURE = "failure"
    SUCCESS = "success"
+
+
class DocumentInfo(BaseModel):
    """Bookkeeping record describing a stored document.

    Tracks identity, version, size, processing status, free-form
    metadata, and optional ownership/timestamps.
    """

    document_id: uuid.UUID
    version: str
    size_in_bytes: int
    metadata: dict
    status: DocumentStatus = DocumentStatus.PROCESSING

    user_id: Optional[uuid.UUID] = None
    title: Optional[str] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    def convert_to_db_entry(self):
        """Prepare the document info for database entry, extracting certain
        fields from metadata.

        Returns a plain dict with stringified ids and JSON-encoded
        metadata. Works on a copy of ``self.metadata`` — the original
        implementation mutated the model's metadata in place as a
        hidden side effect.
        """
        now = datetime.now()
        # Copy so callers' DocumentInfo instances are left untouched.
        metadata = dict(self.metadata)
        if "user_id" in metadata:
            metadata["user_id"] = str(metadata["user_id"])
        metadata["title"] = metadata.get("title", "N/A")
        return {
            "document_id": str(self.document_id),
            "title": metadata["title"],
            "user_id": metadata.get("user_id", None),
            "version": self.version,
            "size_in_bytes": self.size_in_bytes,
            # Dump the adjusted copy: same content the old in-place
            # mutation produced, without altering self.metadata.
            "metadata": json.dumps(metadata),
            "created_at": self.created_at or now,
            "updated_at": self.updated_at or now,
            "status": self.status,
        }
+
+
class ExtractionType(Enum):
    """Types of extractions that can be performed."""

    TXT = "txt"  # textual extraction
    IMG = "img"  # image extraction
    MOV = "mov"  # video/motion extraction
+
+
class Extraction(BaseModel):
    """An extraction from a document.

    Links a piece of extracted content back to its source document via
    ``document_id``.
    """

    id: uuid.UUID
    type: ExtractionType = ExtractionType.TXT  # defaults to text
    data: DataType  # extracted payload (str or bytes)
    metadata: dict
    document_id: uuid.UUID  # source document this was extracted from
+
+
class FragmentType(Enum):
    """A type of fragment that can be extracted from a document."""

    TEXT = "text"
    IMAGE = "image"
+
+
class Fragment(BaseModel):
    """A fragment extracted from a document.

    A fragment is a sub-unit of an Extraction and carries ids linking it
    to both its parent extraction and the original document.
    """

    id: uuid.UUID
    type: FragmentType
    data: DataType  # fragment payload (str or bytes)
    metadata: dict
    document_id: uuid.UUID  # original source document
    extraction_id: uuid.UUID  # parent extraction this fragment belongs to
+
+
class Entity(BaseModel):
    """An entity extracted from a document.

    Renders as ``category:subcategory:value``, or ``category:value``
    when no subcategory is set.
    """

    category: str
    subcategory: Optional[str] = None
    value: str

    def __str__(self):
        # Build the colon-joined form, inserting the subcategory only
        # when it is present.
        parts = [self.category, self.value]
        if self.subcategory:
            parts.insert(1, self.subcategory)
        return ":".join(parts)
+
+
class Triple(BaseModel):
    """A (subject, predicate, object) triple extracted from a document."""

    subject: str
    predicate: str
    object: str
+
+
def extract_entities(llm_payload: list[str]) -> dict[str, Entity]:
    """Parse entity lines from an LLM payload.

    Entity lines have the shape ``"<key>], <category>:<value>"`` or
    ``"<key>], <category>:<subcategory>:<value>"`` (lines without
    ``"], "`` are assumed to be triples and are skipped).

    Returns a mapping from ``"<key>]"`` to the parsed :class:`Entity`.
    Malformed lines are logged and skipped (best-effort parsing).
    """
    entities: dict[str, Entity] = {}
    for entry in llm_payload:
        try:
            if "], " not in entry:  # not an entity line; skip
                continue
            # Split on the FIRST "], " only, so a value that itself
            # contains "], " is not truncated (the old code split the
            # whole line and kept just the second piece).
            key_part, _, body = entry.partition("], ")
            entry_key = key_part + "]"

            # One split handles both the 2-part and 3-part forms; any
            # extra colons stay inside the value.
            parts = body.split(":", 2)
            if len(parts) == 2:
                category, value = parts
                subcategory = None
            elif len(parts) == 3:
                category, subcategory, value = parts
            else:
                raise ValueError("Unexpected entry format")

            entities[entry_key] = Entity(
                category=category, subcategory=subcategory, value=value
            )
        except Exception as e:
            # Best-effort: log the bad line and keep processing.
            logger.error(f"Error processing entity {entry}: {e}")
            continue
    return entities
+
+
def extract_triples(
    llm_payload: list[str], entities: dict[str, Entity]
) -> list[Triple]:
    """Parse triple lines from an LLM payload.

    Triple lines are ``"<subject> <predicate> <object...>"`` (lines
    containing ``"], "`` are entity lines and are skipped). Subjects —
    and objects that look like entity keys (contain ``[`` and ``]``) —
    are resolved to their entity's ``value`` via *entities*.

    Malformed lines (missing tokens, unknown entity keys) are logged
    and skipped (best-effort parsing).
    """
    triples: list[Triple] = []
    for entry in llm_payload:
        try:
            if "], " in entry:  # entity line, not a triple; skip
                continue
            elements = entry.split(" ")
            subject = elements[0]
            predicate = elements[1]  # IndexError here is caught below
            # `obj`, not `object`: avoid shadowing the builtin.
            obj = " ".join(elements[2:])
            subject = entities[subject].value  # resolve to entity value
            if "[" in obj and "]" in obj:
                obj = entities[obj].value  # resolve entity-keyed objects
            triples.append(
                Triple(subject=subject, predicate=predicate, object=obj)
            )
        except Exception as e:
            logger.error(f"Error processing triplet {entry}: {e}")
            continue
    return triples
+
+
class KGExtraction(BaseModel):
    """An extraction from a document that is part of a knowledge graph.

    Bundles the parsed entities (keyed as produced by
    ``extract_entities``) with the triples that reference them.
    """

    entities: dict[str, Entity]
    triples: list[Triple]