aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/base/abstractions/document.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/base/abstractions/document.py')
-rwxr-xr-xR2R/r2r/base/abstractions/document.py242
1 files changed, 242 insertions, 0 deletions
diff --git a/R2R/r2r/base/abstractions/document.py b/R2R/r2r/base/abstractions/document.py
new file mode 100755
index 00000000..117db7b9
--- /dev/null
+++ b/R2R/r2r/base/abstractions/document.py
@@ -0,0 +1,242 @@
+"""Abstractions for documents and their extractions."""
+
+import base64
+import json
+import logging
+import uuid
+from datetime import datetime
+from enum import Enum
+from typing import Optional, Union
+
+from pydantic import BaseModel, Field
+
+logger = logging.getLogger(__name__)
+
+DataType = Union[str, bytes]
+
+
+class DocumentType(str, Enum):
+ """Types of documents that can be stored."""
+
+ CSV = "csv"
+ DOCX = "docx"
+ HTML = "html"
+ JSON = "json"
+ MD = "md"
+ PDF = "pdf"
+ PPTX = "pptx"
+ TXT = "txt"
+ XLSX = "xlsx"
+ GIF = "gif"
+ PNG = "png"
+ JPG = "jpg"
+ JPEG = "jpeg"
+ SVG = "svg"
+ MP3 = "mp3"
+ MP4 = "mp4"
+
+
+class Document(BaseModel):
+ id: uuid.UUID = Field(default_factory=uuid.uuid4)
+ type: DocumentType
+ data: Union[str, bytes]
+ metadata: dict
+
+ def __init__(self, *args, **kwargs):
+ data = kwargs.get("data")
+ if data and isinstance(data, str):
+ try:
+ # Try to decode if it's already base64 encoded
+ kwargs["data"] = base64.b64decode(data)
+ except:
+ # If it's not base64, encode it to bytes
+ kwargs["data"] = data.encode("utf-8")
+
+ doc_type = kwargs.get("type")
+ if isinstance(doc_type, str):
+ kwargs["type"] = DocumentType(doc_type)
+
+ # Generate UUID based on the hash of the data
+ if "id" not in kwargs:
+ if isinstance(kwargs["data"], bytes):
+ data_hash = uuid.uuid5(
+ uuid.NAMESPACE_DNS, kwargs["data"].decode("utf-8")
+ )
+ else:
+ data_hash = uuid.uuid5(uuid.NAMESPACE_DNS, kwargs["data"])
+
+ kwargs["id"] = data_hash # Set the id based on the data hash
+
+ super().__init__(*args, **kwargs)
+
+ class Config:
+ arbitrary_types_allowed = True
+ json_encoders = {
+ uuid.UUID: str,
+ bytes: lambda v: base64.b64encode(v).decode("utf-8"),
+ }
+
+
+class DocumentStatus(str, Enum):
+ """Status of document processing."""
+
+ PROCESSING = "processing"
+ # TODO - Extend support for `partial-failure`
+ # PARTIAL_FAILURE = "partial-failure"
+ FAILURE = "failure"
+ SUCCESS = "success"
+
+
+class DocumentInfo(BaseModel):
+ """Base class for document information handling."""
+
+ document_id: uuid.UUID
+ version: str
+ size_in_bytes: int
+ metadata: dict
+ status: DocumentStatus = DocumentStatus.PROCESSING
+
+ user_id: Optional[uuid.UUID] = None
+ title: Optional[str] = None
+ created_at: Optional[datetime] = None
+ updated_at: Optional[datetime] = None
+
+ def convert_to_db_entry(self):
+ """Prepare the document info for database entry, extracting certain fields from metadata."""
+ now = datetime.now()
+ metadata = self.metadata
+ if "user_id" in metadata:
+ metadata["user_id"] = str(metadata["user_id"])
+
+ metadata["title"] = metadata.get("title", "N/A")
+ return {
+ "document_id": str(self.document_id),
+ "title": metadata.get("title", "N/A"),
+ "user_id": metadata.get("user_id", None),
+ "version": self.version,
+ "size_in_bytes": self.size_in_bytes,
+ "metadata": json.dumps(self.metadata),
+ "created_at": self.created_at or now,
+ "updated_at": self.updated_at or now,
+ "status": self.status,
+ }
+
+
+class ExtractionType(Enum):
+ """Types of extractions that can be performed."""
+
+ TXT = "txt"
+ IMG = "img"
+ MOV = "mov"
+
+
+class Extraction(BaseModel):
+ """An extraction from a document."""
+
+ id: uuid.UUID
+ type: ExtractionType = ExtractionType.TXT
+ data: DataType
+ metadata: dict
+ document_id: uuid.UUID
+
+
+class FragmentType(Enum):
+ """A type of fragment that can be extracted from a document."""
+
+ TEXT = "text"
+ IMAGE = "image"
+
+
+class Fragment(BaseModel):
+ """A fragment extracted from a document."""
+
+ id: uuid.UUID
+ type: FragmentType
+ data: DataType
+ metadata: dict
+ document_id: uuid.UUID
+ extraction_id: uuid.UUID
+
+
+class Entity(BaseModel):
+ """An entity extracted from a document."""
+
+ category: str
+ subcategory: Optional[str] = None
+ value: str
+
+ def __str__(self):
+ return (
+ f"{self.category}:{self.subcategory}:{self.value}"
+ if self.subcategory
+ else f"{self.category}:{self.value}"
+ )
+
+
+class Triple(BaseModel):
+ """A triple extracted from a document."""
+
+ subject: str
+ predicate: str
+ object: str
+
+
+def extract_entities(llm_payload: list[str]) -> dict[str, Entity]:
+ entities = {}
+ for entry in llm_payload:
+ try:
+ if "], " in entry: # Check if the entry is an entity
+ entry_val = entry.split("], ")[0] + "]"
+ entry = entry.split("], ")[1]
+ colon_count = entry.count(":")
+
+ if colon_count == 1:
+ category, value = entry.split(":")
+ subcategory = None
+ elif colon_count >= 2:
+ parts = entry.split(":", 2)
+ category, subcategory, value = (
+ parts[0],
+ parts[1],
+ parts[2],
+ )
+ else:
+ raise ValueError("Unexpected entry format")
+
+ entities[entry_val] = Entity(
+ category=category, subcategory=subcategory, value=value
+ )
+ except Exception as e:
+ logger.error(f"Error processing entity {entry}: {e}")
+ continue
+ return entities
+
+
+def extract_triples(
+ llm_payload: list[str], entities: dict[str, Entity]
+) -> list[Triple]:
+ triples = []
+ for entry in llm_payload:
+ try:
+ if "], " not in entry: # Check if the entry is an entity
+ elements = entry.split(" ")
+ subject = elements[0]
+ predicate = elements[1]
+ object = " ".join(elements[2:])
+ subject = entities[subject].value # Use entity.value
+ if "[" in object and "]" in object:
+ object = entities[object].value # Use entity.value
+ triples.append(
+ Triple(subject=subject, predicate=predicate, object=object)
+ )
+ except Exception as e:
+ logger.error(f"Error processing triplet {entry}: {e}")
+ continue
+ return triples
+
+
+class KGExtraction(BaseModel):
+ """An extraction from a document that is part of a knowledge graph."""
+
+ entities: dict[str, Entity]
+ triples: list[Triple]