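"""Ingestion service for R2R (r2r/main/services/ingestion_service.py).

Turns uploaded files or pre-built `Document` objects into ingestion pipeline
runs: it validates inputs, skips or rejects documents that were already
ingested, tracks per-document status in the vector DB's documents overview,
and exposes helpers for parsing the ingest/update multipart form endpoints.
"""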
import json
import logging
import uuid
from collections import defaultdict
from datetime import datetime
from typing import Any, Optional, Union

from fastapi import Form, UploadFile

from r2r.base import (
    Document,
    DocumentInfo,
    DocumentType,
    KVLoggingSingleton,
    R2RDocumentProcessingError,
    R2RException,
    RunManager,
    generate_id_from_label,
    increment_version,
    to_async_generator,
)
from r2r.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RPipelines, R2RProviders
from ..api.requests import R2RIngestFilesRequest, R2RUpdateFilesRequest
from ..assembly.config import R2RConfig
from .base import Service

logger = logging.getLogger(__name__)
MB_CONVERSION_FACTOR = 1024 * 1024


class IngestionService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipelines: R2RPipelines,
        run_manager: RunManager,
        logging_connection: KVLoggingSingleton,
    ):
        super().__init__(
            config, providers, pipelines, run_manager, logging_connection
        )

    def _file_to_document(
        self, file: UploadFile, document_id: uuid.UUID, metadata: dict
    ) -> Document:
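        """Wrap an uploaded file in a `Document`.

        Validates the file extension against `DocumentType` and defaults the
        `title` metadata field to the file name.
        """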
        file_extension = file.filename.split(".")[-1].lower()
        if file_extension.upper() not in DocumentType.__members__:
            raise R2RException(
                status_code=415,
                message=f"'{file_extension}' is not a valid DocumentType.",
            )

        document_title = (
            metadata.get("title", None) or file.filename.split("/")[-1]
        )
        metadata["title"] = document_title

        return Document(
            id=document_id,
            type=DocumentType[file_extension.upper()],
            data=file.file.read(),
            metadata=metadata,
        )

    @telemetry_event("IngestDocuments")
    async def ingest_documents(
        self,
        documents: list[Document],
        versions: Optional[list[str]] = None,
        *args: Any,
        **kwargs: Any,
    ):
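        """Ingest a batch of `Document` objects through the ingestion pipeline.

        Documents already ingested successfully at the requested version are
        skipped, in-batch duplicates are rejected, the remaining documents are
        recorded with a `processing` status, and a summary of processed,
        failed, and skipped documents is returned.
        """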
        if len(documents) == 0:
            raise R2RException(
                status_code=400, message="No documents provided for ingestion."
            )

        document_infos = []
        skipped_documents = []
        processed_documents = {}
        duplicate_documents = defaultdict(list)

        existing_document_info = {
            doc_info.document_id: doc_info
            for doc_info in self.providers.vector_db.get_documents_overview()
        }

        for iteration, document in enumerate(documents):
            version = versions[iteration] if versions else "v0"

            # Check for duplicates within the current batch
            if document.id in processed_documents:
                duplicate_documents[document.id].append(
                    document.metadata.get("title", str(document.id))
                )
                continue

            if (
                document.id in existing_document_info
                and existing_document_info[document.id].version == version
                and existing_document_info[document.id].status == "success"
            ):
                logger.error(
                    f"Document with ID {document.id} was already successfully processed."
                )
                if len(documents) == 1:
                    raise R2RException(
                        status_code=409,
                        message=f"Document with ID {document.id} was already successfully processed.",
                    )
                skipped_documents.append(
                    (
                        document.id,
                        document.metadata.get("title", None)
                        or str(document.id),
                    )
                )
                continue

            now = datetime.now()
            document_infos.append(
                DocumentInfo(
                    document_id=document.id,
                    version=version,
                    size_in_bytes=len(document.data),
                    metadata=document.metadata.copy(),
                    title=document.metadata.get("title", str(document.id)),
                    user_id=document.metadata.get("user_id", None),
                    created_at=now,
                    updated_at=now,
                    status="processing",  # Set initial status to `processing`
                )
            )

            processed_documents[document.id] = document.metadata.get(
                "title", str(document.id)
            )

        if duplicate_documents:
            duplicate_details = [
                f"{doc_id}: {', '.join(titles)}"
                for doc_id, titles in duplicate_documents.items()
            ]
            raise R2RException(
                status_code=418,
                message=f"Duplicate documents detected in request: {'; '.join(duplicate_details)}. Remove the duplicate entries and resubmit.",
            )

        if skipped_documents and len(skipped_documents) == len(documents):
            logger.error("All provided documents already exist.")
            raise R2RException(
                status_code=409,
                message="All provided documents already exist. Use the `update_documents` endpoint to update them instead.",
            )

        # Insert pending document infos
        self.providers.vector_db.upsert_documents_overview(document_infos)
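        # Run the pipeline only over documents that were not skipped above;
        # the versions list is aligned with the pending document info records.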
        ingestion_results = await self.pipelines.ingestion_pipeline.run(
            input=to_async_generator(
                [
                    doc
                    for doc in documents
                    if doc.id
                    not in [skipped[0] for skipped in skipped_documents]
                ]
            ),
            versions=[info.version for info in document_infos],
            run_manager=self.run_manager,
            *args,
            **kwargs,
        )

        return await self._process_ingestion_results(
            ingestion_results,
            document_infos,
            skipped_documents,
            processed_documents,
        )

    @telemetry_event("IngestFiles")
    async def ingest_files(
        self,
        files: list[UploadFile],
        metadatas: Optional[list[dict]] = None,
        document_ids: Optional[list[uuid.UUID]] = None,
        versions: Optional[list[str]] = None,
        *args: Any,
        **kwargs: Any,
    ):
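        """Convert uploaded files into `Document` objects and ingest them.

        Enforces file name and size checks, delegates to `ingest_documents`,
        and always closes the uploaded files when done.
        """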
        if not files:
            raise R2RException(
                status_code=400, message="No files provided for ingestion."
            )

        try:
            documents = []
            for iteration, file in enumerate(files):
                # Validate the file name before it is used anywhere else.
                if not file.filename:
                    raise R2RException(
                        status_code=400, message="File name not provided."
                    )
                logger.info(f"Processing file: {file.filename}")
                # `UploadFile.size` may be None; only enforce the limit when the size is known.
                if (
                    file.size is not None
                    and file.size
                    > self.config.app.get("max_file_size_in_mb", 32)
                    * MB_CONVERSION_FACTOR
                ):
                    raise R2RException(
                        status_code=413,
                        message=f"File size exceeds maximum allowed size: {file.filename}",
                    )

                document_metadata = metadatas[iteration] if metadatas else {}
                document_id = (
                    document_ids[iteration]
                    if document_ids
                    else generate_id_from_label(file.filename.split("/")[-1])
                )

                document = self._file_to_document(
                    file, document_id, document_metadata
                )
                documents.append(document)

            return await self.ingest_documents(
                documents, versions, *args, **kwargs
            )

        finally:
            for file in files:
                file.file.close()

    @telemetry_event("UpdateFiles")
    async def update_files(
        self,
        files: list[UploadFile],
        document_ids: list[uuid.UUID],
        metadatas: Optional[list[dict]] = None,
        *args: Any,
        **kwargs: Any,
    ):
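        """Ingest new versions of existing documents and remove the old ones.

        Each file replaces the document with the matching ID at an incremented
        version; after re-ingestion, the superseded version's chunks and
        overview entry are deleted.
        """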
        if not files:
            raise R2RException(
                status_code=400, message="No files provided for update."
            )

        try:
            if len(document_ids) != len(files):
                raise R2RException(
                    status_code=400,
                    message="Number of ids does not match number of files.",
                )

            documents_overview = await self._documents_overview(
                document_ids=document_ids
            )
            if len(documents_overview) != len(files):
                raise R2RException(
                    status_code=404,
                    message="One or more documents was not found.",
                )

            documents = []
            new_versions = []

            for it, (file, doc_id, doc_info) in enumerate(
                zip(files, document_ids, documents_overview)
            ):
                if not doc_info:
                    raise R2RException(
                        status_code=404,
                        message=f"Document with id {doc_id} not found.",
                    )

                new_version = increment_version(doc_info.version)
                new_versions.append(new_version)

                updated_metadata = (
                    metadatas[it] if metadatas else doc_info.metadata
                )
                updated_metadata["title"] = (
                    updated_metadata.get("title", None)
                    or file.filename.split("/")[-1]
                )

                document = self._file_to_document(
                    file, doc_id, updated_metadata
                )
                documents.append(document)

            ingestion_results = await self.ingest_documents(
                documents, versions=new_versions, *args, **kwargs
            )

            for doc_id, old_version in zip(
                document_ids,
                [doc_info.version for doc_info in documents_overview],
            ):
                await self._delete(
                    ["document_id", "version"], [str(doc_id), old_version]
                )
                self.providers.vector_db.delete_from_documents_overview(
                    doc_id, old_version
                )

            return ingestion_results

        finally:
            for file in files:
                file.file.close()

    async def _process_ingestion_results(
        self,
        ingestion_results: dict,
        document_infos: list[DocumentInfo],
        skipped_documents: list[tuple[uuid.UUID, str]],
        processed_documents: dict,
    ):
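        """Split pipeline output into successes and failures, persist updated document statuses, and build the response payload."""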
        skipped_ids = [ele[0] for ele in skipped_documents]
        failed_ids = []
        successful_ids = []

        results = {}
        if ingestion_results["embedding_pipeline_output"]:
            # The embedding pipeline output is an iterable of (document_id, result) pairs,
            # where a result is either an exception or a success value.
            results = dict(ingestion_results["embedding_pipeline_output"])
            for doc_id, result in results.items():
                if isinstance(result, R2RDocumentProcessingError):
                    logger.error(
                        f"Error processing document with ID {result.document_id}: {result.message}"
                    )
                    failed_ids.append(result.document_id)
                elif isinstance(result, Exception):
                    logger.error(f"Error processing document: {result}")
                    failed_ids.append(doc_id)
                else:
                    successful_ids.append(doc_id)

        documents_to_upsert = []
        for document_info in document_infos:
            if document_info.document_id not in skipped_ids:
                if document_info.document_id in failed_ids:
                    document_info.status = "failure"
                elif document_info.document_id in successful_ids:
                    document_info.status = "success"
                documents_to_upsert.append(document_info)

        if documents_to_upsert:
            self.providers.vector_db.upsert_documents_overview(
                documents_to_upsert
            )

        response = {
            "processed_documents": [
                f"Document '{processed_documents[document_id]}' processed successfully."
                for document_id in successful_ids
            ],
            "failed_documents": [
                f"Document '{processed_documents[document_id]}': {results[document_id]}"
                for document_id in failed_ids
            ],
            "skipped_documents": [
                f"Document '{filename}' skipped since it already exists."
                for _, filename in skipped_documents
            ],
        }

        # TODO - Clean up logging for document parse results
        run_ids = list(self.run_manager.run_info.keys())
        if run_ids:
            run_id = run_ids[0]
            for key in ("processed_documents", "failed_documents"):
                for value in response[key]:
                    await self.logging_connection.log(
                        log_id=run_id,
                        key="document_parse_result",
                        value=value,
                    )
        return response

    @staticmethod
    def parse_ingest_files_form_data(
        metadatas: Optional[str] = Form(None),
        document_ids: Optional[str] = Form(None),
        versions: Optional[str] = Form(None),
    ) -> R2RIngestFilesRequest:
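        """Parse the JSON-encoded multipart form fields of the ingest-files endpoint into an `R2RIngestFilesRequest`."""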
        try:
            parsed_metadatas = (
                json.loads(metadatas)
                if metadatas and metadatas != "null"
                else None
            )
            if parsed_metadatas is not None and not isinstance(
                parsed_metadatas, list
            ):
                raise ValueError("metadatas must be a list of dictionaries")

            parsed_document_ids = (
                json.loads(document_ids)
                if document_ids and document_ids != "null"
                else None
            )
            if parsed_document_ids is not None:
                parsed_document_ids = [
                    uuid.UUID(doc_id) for doc_id in parsed_document_ids
                ]

            parsed_versions = (
                json.loads(versions)
                if versions and versions != "null"
                else None
            )

            request_data = {
                "metadatas": parsed_metadatas,
                "document_ids": parsed_document_ids,
                "versions": parsed_versions,
            }
            return R2RIngestFilesRequest(**request_data)
        except json.JSONDecodeError as e:
            raise R2RException(
                status_code=400, message=f"Invalid JSON in form data: {e}"
            )
        except ValueError as e:
            raise R2RException(status_code=400, message=str(e))
        except Exception as e:
            raise R2RException(
                status_code=400, message=f"Error processing form data: {e}"
            )

    @staticmethod
    def parse_update_files_form_data(
        metadatas: Optional[str] = Form(None),
        document_ids: str = Form(...),
    ) -> R2RUpdateFilesRequest:
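        """Parse the JSON-encoded multipart form fields of the update-files endpoint into an `R2RUpdateFilesRequest`; `document_ids` is required."""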
        try:
            parsed_metadatas = (
                json.loads(metadatas)
                if metadatas and metadatas != "null"
                else None
            )
            if parsed_metadatas is not None and not isinstance(
                parsed_metadatas, list
            ):
                raise ValueError("metadatas must be a list of dictionaries")

            if not document_ids or document_ids == "null":
                raise ValueError("document_ids is required and cannot be null")

            parsed_document_ids = json.loads(document_ids)
            if not isinstance(parsed_document_ids, list):
                raise ValueError("document_ids must be a list")
            parsed_document_ids = [
                uuid.UUID(doc_id) for doc_id in parsed_document_ids
            ]

            request_data = {
                "metadatas": parsed_metadatas,
                "document_ids": parsed_document_ids,
            }
            return R2RUpdateFilesRequest(**request_data)
        except json.JSONDecodeError as e:
            raise R2RException(
                status_code=400, message=f"Invalid JSON in form data: {e}"
            )
        except ValueError as e:
            raise R2RException(status_code=400, message=str(e))
        except Exception as e:
            raise R2RException(
                status_code=400, message=f"Error processing form data: {e}"
            )

    # TODO - Move to mgmt service for document info, delete, post orchestration buildout
    async def _documents_overview(
        self,
        document_ids: Optional[list[uuid.UUID]] = None,
        user_ids: Optional[list[uuid.UUID]] = None,
        *args: Any,
        **kwargs: Any,
    ):
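        """Fetch document overview records, optionally filtered by document and/or user IDs."""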
        return self.providers.vector_db.get_documents_overview(
            filter_document_ids=(
                [str(ele) for ele in document_ids] if document_ids else None
            ),
            filter_user_ids=(
                [str(ele) for ele in user_ids] if user_ids else None
            ),
        )

    async def _delete(
        self, keys: list[str], values: list[Union[bool, int, str]]
    ):
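        """Delete vector DB entries whose metadata matches the given keys and values."""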
        logger.info(
            f"Deleting documents that match these keys and values: ({keys}, {values})"
        )

        ids = self.providers.vector_db.delete_by_metadata(keys, values)
        if not ids:
            raise R2RException(
                status_code=404, message="No entries found for deletion."
            )
        return "Entries deleted successfully."