"""Abstractions for documents and their extractions."""
import base64
import hashlib
import json
import logging
import uuid
from datetime import datetime
from enum import Enum
from typing import Optional, Union

from pydantic import BaseModel, Field
# Module-level logger, named after this module; used by the extract_* helpers
# below to report payload entries that could not be parsed.
logger = logging.getLogger(__name__)
# Payload type shared by Extraction.data and Fragment.data: text or raw bytes.
DataType = Union[str, bytes]
class DocumentType(str, Enum):
    """Formats a stored document may have.

    Members are also ``str`` instances, so they compare equal to their
    lowercase value (``DocumentType.PDF == "pdf"``) and can be looked up from
    it (``DocumentType("pdf")``).
    """

    # Text, markup and office formats
    CSV = "csv"
    DOCX = "docx"
    HTML = "html"
    JSON = "json"
    MD = "md"
    PDF = "pdf"
    PPTX = "pptx"
    TXT = "txt"
    XLSX = "xlsx"

    # Image formats
    GIF = "gif"
    PNG = "png"
    JPG = "jpg"
    JPEG = "jpeg"
    SVG = "svg"

    # Audio / video formats
    MP3 = "mp3"
    MP4 = "mp4"
class Document(BaseModel):
    """A document payload with its type, metadata and a content-derived id.

    ``data`` is normalised to ``bytes`` on construction and, unless an ``id``
    is supplied, the id is a UUIDv5 (DNS namespace) of those bytes, so the
    same content always yields the same id.
    """

    id: uuid.UUID = Field(default_factory=uuid.uuid4)
    type: DocumentType
    data: Union[str, bytes]
    metadata: dict

    def __init__(self, *args, **kwargs):
        """Normalise ``data`` and ``type`` and derive ``id`` before validation.

        Args:
            *args: Forwarded unchanged to ``BaseModel.__init__``.
            **kwargs: Field values. A non-empty ``str`` ``data`` is
                base64-decoded when possible, otherwise UTF-8 encoded. A
                ``str`` ``type`` is converted to ``DocumentType``. If ``id``
                is absent it is derived from the payload bytes.

        Raises:
            ValueError: If ``type`` is a string that is not a valid
                ``DocumentType`` value. Missing or invalid fields are
                reported by pydantic's own validation.
        """
        data = kwargs.get("data")
        if data and isinstance(data, str):
            try:
                # Try to decode if it's already base64 encoded.
                # NOTE(review): without validate=True, b64decode discards
                # characters outside the base64 alphabet, so some plain text
                # is also accepted here. Kept as-is because callers may rely
                # on lenient decoding (e.g. line-wrapped base64).
                kwargs["data"] = base64.b64decode(data)
            except ValueError:
                # binascii.Error (bad padding) and the non-ASCII error are
                # both ValueErrors: not base64, so store the text as UTF-8.
                kwargs["data"] = data.encode("utf-8")

        doc_type = kwargs.get("type")
        if isinstance(doc_type, str):
            kwargs["type"] = DocumentType(doc_type)

        # Generate the id from a hash of the data.
        if "id" not in kwargs:
            payload = kwargs.get("data")
            if isinstance(payload, str):
                # Only reachable for the empty string (non-empty strings were
                # converted to bytes above).
                payload = payload.encode("utf-8")
            if isinstance(payload, bytes):
                # Same construction as uuid.uuid5(NAMESPACE_DNS, text), but
                # applied to the raw bytes so binary payloads that are not
                # valid UTF-8 no longer raise UnicodeDecodeError. For UTF-8
                # payloads the resulting id is unchanged.
                digest = hashlib.sha1(
                    uuid.NAMESPACE_DNS.bytes + payload,
                    usedforsecurity=False,
                ).digest()
                kwargs["id"] = uuid.UUID(bytes=digest[:16], version=5)
            # Otherwise leave ``id`` unset and let pydantic report the
            # missing or invalid ``data`` field.

        super().__init__(*args, **kwargs)

    class Config:
        arbitrary_types_allowed = True
        json_encoders = {
            uuid.UUID: str,
            # Bytes are emitted as base64 text.
            bytes: lambda v: base64.b64encode(v).decode("utf-8"),
        }
class DocumentStatus(str, Enum):
    """Processing state of a document.

    Members are also ``str`` instances and compare equal to their value.
    """

    PROCESSING = "processing"
    # TODO - Extend support for `partial-failure`
    # PARTIAL_FAILURE = "partial-failure"
    FAILURE = "failure"
    SUCCESS = "success"
class DocumentInfo(BaseModel):
    """Base class for document information handling."""

    document_id: uuid.UUID
    version: str
    size_in_bytes: int
    metadata: dict
    status: DocumentStatus = DocumentStatus.PROCESSING

    user_id: Optional[uuid.UUID] = None
    title: Optional[str] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    def convert_to_db_entry(self):
        """Build the database row for this document info.

        ``title`` and ``user_id`` are taken from ``metadata`` (not from the
        model fields of the same name). A missing title becomes ``"N/A"`` and
        a present ``user_id`` is stringified so the metadata can be
        JSON-encoded.

        Returns:
            dict: Row with keys ``document_id``, ``title``, ``user_id``,
            ``version``, ``size_in_bytes``, ``metadata`` (a JSON string),
            ``created_at``, ``updated_at`` and ``status``. Unset timestamps
            default to the current local time.
        """
        now = datetime.now()

        # Work on a shallow copy: preparing a row must not rewrite the
        # model's own metadata as a side effect.
        metadata = dict(self.metadata)
        if "user_id" in metadata:
            metadata["user_id"] = str(metadata["user_id"])
        metadata.setdefault("title", "N/A")

        return {
            "document_id": str(self.document_id),
            "title": metadata["title"],
            "user_id": metadata.get("user_id"),
            "version": self.version,
            "size_in_bytes": self.size_in_bytes,
            "metadata": json.dumps(metadata),
            "created_at": self.created_at or now,
            "updated_at": self.updated_at or now,
            "status": self.status,
        }
class ExtractionType(Enum):
    """Kinds of extraction that can be produced from a document.

    This is a plain ``Enum``: members do not compare equal to their string
    values.
    """

    TXT = "txt"
    IMG = "img"
    MOV = "mov"
class Extraction(BaseModel):
    """A single piece of content extracted from a document."""

    id: uuid.UUID
    # Defaults to a text extraction when no type is given.
    type: ExtractionType = ExtractionType.TXT
    # Extracted payload, as text or raw bytes.
    data: DataType
    metadata: dict
    document_id: uuid.UUID
class FragmentType(Enum):
    """Kinds of fragment that can be extracted from a document.

    This is a plain ``Enum``: members do not compare equal to their string
    values.
    """

    TEXT = "text"
    IMAGE = "image"
class Fragment(BaseModel):
    """A fragment of content extracted from a document."""

    id: uuid.UUID
    type: FragmentType
    # Fragment payload, as text or raw bytes.
    data: DataType
    metadata: dict
    # References to the document and extraction ids carried by this fragment.
    document_id: uuid.UUID
    extraction_id: uuid.UUID
class Entity(BaseModel):
    """A categorised entity extracted from a document."""

    category: str
    subcategory: Optional[str] = None
    value: str

    def __str__(self):
        """Render as ``category:subcategory:value``.

        The subcategory segment is omitted when it is ``None`` or empty.
        """
        if not self.subcategory:
            return f"{self.category}:{self.value}"
        return f"{self.category}:{self.subcategory}:{self.value}"
class Triple(BaseModel):
    """A (subject, predicate, object) statement extracted from a document."""

    subject: str
    predicate: str
    object: str
def extract_entities(llm_payload: list[str]) -> dict[str, Entity]:
    """Parse the entity lines of an LLM payload.

    An entity line has the form ``<key>], <category>:<value>`` or
    ``<key>], <category>:<subcategory>:<value>``. Lines that do not contain
    ``"], "`` are ignored here. Any colon after the second one is kept as
    part of the value.

    Args:
        llm_payload: Raw payload lines.

    Returns:
        Mapping from the entity key (the text before ``"], "`` with the
        closing ``"]"`` restored) to the parsed ``Entity``. Lines that fail
        to parse are logged and skipped.
    """
    entities = {}
    for entry in llm_payload:
        try:
            if "], " not in entry:  # Not an entity line
                continue

            # Split on the FIRST separator only, so a value that itself
            # contains "], " is kept whole instead of being truncated.
            key_part, _, body = entry.partition("], ")
            entry_key = key_part + "]"

            parts = body.split(":", 2)
            if len(parts) == 2:
                category, value = parts
                subcategory = None
            elif len(parts) == 3:
                category, subcategory, value = parts
            else:
                raise ValueError("Unexpected entry format")

            entities[entry_key] = Entity(
                category=category, subcategory=subcategory, value=value
            )
        except Exception as e:
            # Best-effort parsing: one malformed line must not abort the
            # rest. Log the full original entry for easier debugging.
            logger.error("Error processing entity %s: %s", entry, e)
            continue
    return entities
def extract_triples(
    llm_payload: list[str], entities: dict[str, Entity]
) -> list[Triple]:
    """Parse the triple lines of an LLM payload.

    Lines containing ``"], "`` are skipped. Every other line is split on
    single spaces into a subject key, a one-token predicate and an object
    (the remaining tokens re-joined). The subject key is replaced by the
    matching entity's value. The object is replaced too when it contains
    both ``"["`` and ``"]"``, and is used literally otherwise.

    Args:
        llm_payload: Raw payload lines.
        entities: Entities keyed by the keys used in the triple lines.

    Returns:
        Triples in payload order. Lines that fail to parse (too few tokens,
        unknown entity key, ...) are logged and skipped.
    """
    triples = []
    for entry in llm_payload:
        try:
            if "], " in entry:  # Entity line, not a triple
                continue

            tokens = entry.split(" ")
            subject_key, predicate = tokens[0], tokens[1]
            object_text = " ".join(tokens[2:])

            subject_value = entities[subject_key].value  # Use entity.value
            if "[" in object_text and "]" in object_text:
                object_value = entities[object_text].value  # Use entity.value
            else:
                object_value = object_text

            triples.append(
                Triple(
                    subject=subject_value,
                    predicate=predicate,
                    object=object_value,
                )
            )
        except Exception as e:
            logger.error(f"Error processing triplet {entry}: {e}")
            continue
    return triples
class KGExtraction(BaseModel):
    """Knowledge-graph content extracted from a document."""

    # Entities keyed by string.
    entities: dict[str, Entity]
    triples: list[Triple]