Diffstat (limited to 'R2R/r2r/main/services/ingestion_service.py')
-rwxr-xr-x  R2R/r2r/main/services/ingestion_service.py  505
1 files changed, 505 insertions, 0 deletions
diff --git a/R2R/r2r/main/services/ingestion_service.py b/R2R/r2r/main/services/ingestion_service.py
new file mode 100755
index 00000000..5677807a
--- /dev/null
+++ b/R2R/r2r/main/services/ingestion_service.py
@@ -0,0 +1,505 @@
+import json
+import logging
+import uuid
+from collections import defaultdict
+from datetime import datetime
+from typing import Any, Optional, Union
+
+from fastapi import Form, UploadFile
+
+from r2r.base import (
+    Document,
+    DocumentInfo,
+    DocumentType,
+    KVLoggingSingleton,
+    R2RDocumentProcessingError,
+    R2RException,
+    RunManager,
+    generate_id_from_label,
+    increment_version,
+    to_async_generator,
+)
+from r2r.telemetry.telemetry_decorator import telemetry_event
+
+from ..abstractions import R2RPipelines, R2RProviders
+from ..api.requests import R2RIngestFilesRequest, R2RUpdateFilesRequest
+from ..assembly.config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger(__name__)
+MB_CONVERSION_FACTOR = 1024 * 1024
+
+
+class IngestionService(Service):
+    def __init__(
+        self,
+        config: R2RConfig,
+        providers: R2RProviders,
+        pipelines: R2RPipelines,
+        run_manager: RunManager,
+        logging_connection: KVLoggingSingleton,
+    ):
+        super().__init__(
+            config, providers, pipelines, run_manager, logging_connection
+        )
+
+    def _file_to_document(
+        self, file: UploadFile, document_id: uuid.UUID, metadata: dict
+    ) -> Document:
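+        """Build a `Document` from an uploaded file, validating its extension
+        against `DocumentType` and defaulting the title to the file name."""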
+        file_extension = file.filename.split(".")[-1].lower()
+        if file_extension.upper() not in DocumentType.__members__:
+            raise R2RException(
+                status_code=415,
+                message=f"'{file_extension}' is not a valid DocumentType.",
+            )
+
+        document_title = (
+            metadata.get("title", None) or file.filename.split("/")[-1]
+        )
+        metadata["title"] = document_title
+
+        return Document(
+            id=document_id,
+            type=DocumentType[file_extension.upper()],
+            data=file.file.read(),
+            metadata=metadata,
+        )
+
+    @telemetry_event("IngestDocuments")
+    async def ingest_documents(
+        self,
+        documents: list[Document],
+        versions: Optional[list[str]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
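+        """Run a batch of documents through the ingestion pipeline.
+
+        Documents already ingested successfully at the same version are
+        skipped, duplicate IDs within the batch are rejected, and per-document
+        status is tracked in the documents overview.
+        """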
+        if len(documents) == 0:
+            raise R2RException(
+                status_code=400, message="No documents provided for ingestion."
+            )
+
+        document_infos = []
+        skipped_documents = []
+        processed_documents = {}
+        duplicate_documents = defaultdict(list)
+
+        existing_document_info = {
+            doc_info.document_id: doc_info
+            for doc_info in self.providers.vector_db.get_documents_overview()
+        }
+
+        for iteration, document in enumerate(documents):
+            version = versions[iteration] if versions else "v0"
+
+            # Check for duplicates within the current batch
+            if document.id in processed_documents:
+                duplicate_documents[document.id].append(
+                    document.metadata.get("title", str(document.id))
+                )
+                continue
+
+            if (
+                document.id in existing_document_info
+                and existing_document_info[document.id].version == version
+                and existing_document_info[document.id].status == "success"
+            ):
+                logger.error(
+                    f"Document with ID {document.id} was already successfully processed."
+                )
+                if len(documents) == 1:
+                    raise R2RException(
+                        status_code=409,
+                        message=f"Document with ID {document.id} was already successfully processed.",
+                    )
+                skipped_documents.append(
+                    (
+                        document.id,
+                        document.metadata.get("title", None)
+                        or str(document.id),
+                    )
+                )
+                continue
+
+            now = datetime.now()
+            document_infos.append(
+                DocumentInfo(
+                    document_id=document.id,
+                    version=version,
+                    size_in_bytes=len(document.data),
+                    metadata=document.metadata.copy(),
+                    title=document.metadata.get("title", str(document.id)),
+                    user_id=document.metadata.get("user_id", None),
+                    created_at=now,
+                    updated_at=now,
+                    status="processing",  # Set initial status to `processing`
+                )
+            )
+
+            processed_documents[document.id] = document.metadata.get(
+                "title", str(document.id)
+            )
+
+        if duplicate_documents:
+            duplicate_details = [
+                f"{doc_id}: {', '.join(titles)}"
+                for doc_id, titles in duplicate_documents.items()
+            ]
+            warning_message = f"Duplicate documents detected: {'; '.join(duplicate_details)}. These duplicates were skipped."
+            raise R2RException(status_code=418, message=warning_message)
+
+        if skipped_documents and len(skipped_documents) == len(documents):
+            logger.error("All provided documents already exist.")
+            raise R2RException(
+                status_code=409,
+                message="All provided documents already exist. Use the `update_documents` endpoint instead to update these documents.",
+            )
+
+        # Insert pending document infos
+        self.providers.vector_db.upsert_documents_overview(document_infos)
+        ingestion_results = await self.pipelines.ingestion_pipeline.run(
+            input=to_async_generator(
+                [
+                    doc
+                    for doc in documents
+                    if doc.id
+                    not in [skipped[0] for skipped in skipped_documents]
+                ]
+            ),
+            versions=[info.version for info in document_infos],
+            run_manager=self.run_manager,
+            *args,
+            **kwargs,
+        )
+
+        return await self._process_ingestion_results(
+            ingestion_results,
+            document_infos,
+            skipped_documents,
+            processed_documents,
+        )
+
+    @telemetry_event("IngestFiles")
+    async def ingest_files(
+        self,
+        files: list[UploadFile],
+        metadatas: Optional[list[dict]] = None,
+        document_ids: Optional[list[uuid.UUID]] = None,
+        versions: Optional[list[str]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
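+        """Convert uploaded files into documents and ingest them.
+
+        Enforces the configured maximum file size, derives a document ID from
+        the file name when none is supplied, and always closes the uploaded
+        file handles when finished.
+        """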
+        if not files:
+            raise R2RException(
+                status_code=400, message="No files provided for ingestion."
+            )
+
+        try:
+            documents = []
+            for iteration, file in enumerate(files):
+                logger.info(f"Processing file: {file.filename}")
+                if (
+                    file.size
+                    > self.config.app.get("max_file_size_in_mb", 32)
+                    * MB_CONVERSION_FACTOR
+                ):
+                    raise R2RException(
+                        status_code=413,
+                        message=f"File size exceeds maximum allowed size: {file.filename}",
+                    )
+                if not file.filename:
+                    raise R2RException(
+                        status_code=400, message="File name not provided."
+                    )
+
+                document_metadata = metadatas[iteration] if metadatas else {}
+                document_id = (
+                    document_ids[iteration]
+                    if document_ids
+                    else generate_id_from_label(file.filename.split("/")[-1])
+                )
+
+                document = self._file_to_document(
+                    file, document_id, document_metadata
+                )
+                documents.append(document)
+
+            return await self.ingest_documents(
+                documents, versions, *args, **kwargs
+            )
+
+        finally:
+            for file in files:
+                file.file.close()
+
+    @telemetry_event("UpdateFiles")
+    async def update_files(
+        self,
+        files: list[UploadFile],
+        document_ids: list[uuid.UUID],
+        metadatas: Optional[list[dict]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
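+        """Ingest new versions of existing documents.
+
+        Each file is re-ingested under an incremented version, after which the
+        previous version's entries are deleted.
+        """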
+        if not files:
+            raise R2RException(
+                status_code=400, message="No files provided for update."
+            )
+
+        try:
+            if len(document_ids) != len(files):
+                raise R2RException(
+                    status_code=400,
+                    message="Number of ids does not match number of files.",
+                )
+
+            documents_overview = await self._documents_overview(
+                document_ids=document_ids
+            )
+            if len(documents_overview) != len(files):
+                raise R2RException(
+                    status_code=404,
+                    message="One or more documents were not found.",
+                )
+
+            documents = []
+            new_versions = []
+
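+            # NOTE: the zip below assumes the overview rows are returned in
+            # the same order as `document_ids`; if the provider does not
+            # guarantee that ordering, entries should be matched by
+            # `document_id` instead.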
+            for it, (file, doc_id, doc_info) in enumerate(
+                zip(files, document_ids, documents_overview)
+            ):
+                if not doc_info:
+                    raise R2RException(
+                        status_code=404,
+                        message=f"Document with id {doc_id} not found.",
+                    )
+
+                new_version = increment_version(doc_info.version)
+                new_versions.append(new_version)
+
+                updated_metadata = (
+                    metadatas[it] if metadatas else doc_info.metadata
+                )
+                updated_metadata["title"] = (
+                    updated_metadata.get("title", None)
+                    or file.filename.split("/")[-1]
+                )
+
+                document = self._file_to_document(
+                    file, doc_id, updated_metadata
+                )
+                documents.append(document)
+
+            ingestion_results = await self.ingest_documents(
+                documents, versions=new_versions, *args, **kwargs
+            )
+
+            for doc_id, old_version in zip(
+                document_ids,
+                [doc_info.version for doc_info in documents_overview],
+            ):
+                await self._delete(
+                    ["document_id", "version"], [str(doc_id), old_version]
+                )
+                self.providers.vector_db.delete_from_documents_overview(
+                    doc_id, old_version
+                )
+
+            return ingestion_results
+
+        finally:
+            for file in files:
+                file.file.close()
+
+    async def _process_ingestion_results(
+        self,
+        ingestion_results: dict,
+        document_infos: list[DocumentInfo],
+        skipped_documents: list[tuple[uuid.UUID, str]],
+        processed_documents: dict,
+    ):
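+        """Reconcile pipeline output with the pending document infos and build
+        the response payload of processed, failed, and skipped documents."""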
+        skipped_ids = [ele[0] for ele in skipped_documents]
+        failed_ids = []
+        successful_ids = []
+
+        results = {}
+        if ingestion_results["embedding_pipeline_output"]:
+            results = {
+                k: v for k, v in ingestion_results["embedding_pipeline_output"]
+            }
+            for doc_id, error in results.items():
+                if isinstance(error, R2RDocumentProcessingError):
+                    logger.error(
+                        f"Error processing document with ID {error.document_id}: {error.message}"
+                    )
+                    failed_ids.append(error.document_id)
+                elif isinstance(error, Exception):
+                    logger.error(f"Error processing document: {error}")
+                    failed_ids.append(doc_id)
+                else:
+                    successful_ids.append(doc_id)
+
+        documents_to_upsert = []
+        for document_info in document_infos:
+            if document_info.document_id not in skipped_ids:
+                if document_info.document_id in failed_ids:
+                    document_info.status = "failure"
+                elif document_info.document_id in successful_ids:
+                    document_info.status = "success"
+                documents_to_upsert.append(document_info)
+
+        if documents_to_upsert:
+            self.providers.vector_db.upsert_documents_overview(
+                documents_to_upsert
+            )
+
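+        # The comprehensions below still read the per-document `results`
+        # mapping built above; the name is only rebound to the response
+        # payload once the new dict literal has been fully evaluated.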
+        results = {
+            "processed_documents": [
+                f"Document '{processed_documents[document_id]}' processed successfully."
+                for document_id in successful_ids
+            ],
+            "failed_documents": [
+                f"Document '{processed_documents[document_id]}': {results[document_id]}"
+                for document_id in failed_ids
+            ],
+            "skipped_documents": [
+                f"Document '{filename}' skipped since it already exists."
+                for _, filename in skipped_documents
+            ],
+        }
+
+        # TODO - Clean up logging for document parse results
+        run_ids = list(self.run_manager.run_info.keys())
+        if run_ids:
+            run_id = run_ids[0]
+            for key in results:
+                if key in ["processed_documents", "failed_documents"]:
+                    for value in results[key]:
+                        await self.logging_connection.log(
+                            log_id=run_id,
+                            key="document_parse_result",
+                            value=value,
+                        )
+        return results
+
+    @staticmethod
+    def parse_ingest_files_form_data(
+        metadatas: Optional[str] = Form(None),
+        document_ids: Optional[str] = Form(None),
+        versions: Optional[str] = Form(None),
+    ) -> R2RIngestFilesRequest:
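+        """Parse multipart form fields into an `R2RIngestFilesRequest`."""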
+        try:
+            parsed_metadatas = (
+                json.loads(metadatas)
+                if metadatas and metadatas != "null"
+                else None
+            )
+            if parsed_metadatas is not None and not isinstance(
+                parsed_metadatas, list
+            ):
+                raise ValueError("metadatas must be a list of dictionaries")
+
+            parsed_document_ids = (
+                json.loads(document_ids)
+                if document_ids and document_ids != "null"
+                else None
+            )
+            if parsed_document_ids is not None:
+                parsed_document_ids = [
+                    uuid.UUID(doc_id) for doc_id in parsed_document_ids
+                ]
+
+            parsed_versions = (
+                json.loads(versions)
+                if versions and versions != "null"
+                else None
+            )
+
+            request_data = {
+                "metadatas": parsed_metadatas,
+                "document_ids": parsed_document_ids,
+                "versions": parsed_versions,
+            }
+            return R2RIngestFilesRequest(**request_data)
+        except json.JSONDecodeError as e:
+            raise R2RException(
+                status_code=400, message=f"Invalid JSON in form data: {e}"
+            )
+        except ValueError as e:
+            raise R2RException(status_code=400, message=str(e))
+        except Exception as e:
+            raise R2RException(
+                status_code=400, message=f"Error processing form data: {e}"
+            )
+
+    @staticmethod
+    def parse_update_files_form_data(
+        metadatas: Optional[str] = Form(None),
+        document_ids: str = Form(...),
+    ) -> R2RUpdateFilesRequest:
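+        """Parse multipart form fields into an `R2RUpdateFilesRequest`."""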
+        try:
+            parsed_metadatas = (
+                json.loads(metadatas)
+                if metadatas and metadatas != "null"
+                else None
+            )
+            if parsed_metadatas is not None and not isinstance(
+                parsed_metadatas, list
+            ):
+                raise ValueError("metadatas must be a list of dictionaries")
+
+            if not document_ids or document_ids == "null":
+                raise ValueError("document_ids is required and cannot be null")
+
+            parsed_document_ids = json.loads(document_ids)
+            if not isinstance(parsed_document_ids, list):
+                raise ValueError("document_ids must be a list")
+            parsed_document_ids = [
+                uuid.UUID(doc_id) for doc_id in parsed_document_ids
+            ]
+
+            request_data = {
+                "metadatas": parsed_metadatas,
+                "document_ids": parsed_document_ids,
+            }
+            return R2RUpdateFilesRequest(**request_data)
+        except json.JSONDecodeError as e:
+            raise R2RException(
+                status_code=400, message=f"Invalid JSON in form data: {e}"
+            )
+        except ValueError as e:
+            raise R2RException(status_code=400, message=str(e))
+        except Exception as e:
+            raise R2RException(
+                status_code=400, message=f"Error processing form data: {e}"
+            )
+
+    # TODO - Move to mgmt service for document info, delete, post orchestration buildout
+    async def _documents_overview(
+        self,
+        document_ids: Optional[list[uuid.UUID]] = None,
+        user_ids: Optional[list[uuid.UUID]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
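+        """Fetch document overview entries, optionally filtered by document or user IDs."""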
+        return self.providers.vector_db.get_documents_overview(
+            filter_document_ids=(
+                [str(ele) for ele in document_ids] if document_ids else None
+            ),
+            filter_user_ids=(
+                [str(ele) for ele in user_ids] if user_ids else None
+            ),
+        )
+
+    async def _delete(
+        self, keys: list[str], values: list[Union[bool, int, str]]
+    ):
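+        """Delete vector entries whose metadata matches the given keys and values."""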
+        logger.info(
+            f"Deleting documents which match on these keys and values: ({keys}, {values})"
+        )
+
+        ids = self.providers.vector_db.delete_by_metadata(keys, values)
+        if not ids:
+            raise R2RException(
+                status_code=404, message="No entries found for deletion."
+            )
+        return "Entries deleted successfully."