Diffstat (limited to '.venv/lib/python3.12/site-packages/core/main/orchestration')
 .venv/lib/python3.12/site-packages/core/main/orchestration/__init__.py                   |  16
 .venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py           |   0
 .venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py     | 539
 .venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py | 721
 .venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py            |   0
 .venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py      | 222
 .venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py  | 598
7 files changed, 2096 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/__init__.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/__init__.py new file mode 100644 index 00000000..19cb0428 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/__init__.py @@ -0,0 +1,16 @@ +# FIXME: Once the Hatchet workflows are type annotated, remove the type: ignore comments +from .hatchet.graph_workflow import ( # type: ignore + hatchet_graph_search_results_factory, +) +from .hatchet.ingestion_workflow import ( # type: ignore + hatchet_ingestion_factory, +) +from .simple.graph_workflow import simple_graph_search_results_factory +from .simple.ingestion_workflow import simple_ingestion_factory + +__all__ = [ + "hatchet_ingestion_factory", + "hatchet_graph_search_results_factory", + "simple_ingestion_factory", + "simple_graph_search_results_factory", +] diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py new file mode 100644 index 00000000..cc128b0f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py @@ -0,0 +1,539 @@ +# type: ignore +import asyncio +import contextlib +import json +import logging +import math +import time +import uuid +from typing import TYPE_CHECKING + +from hatchet_sdk import ConcurrencyLimitStrategy, Context + +from core import GenerationConfig +from core.base import OrchestrationProvider, R2RException +from core.base.abstractions import ( + GraphConstructionStatus, + GraphExtractionStatus, +) + +from ...services import GraphService + +if TYPE_CHECKING: + from hatchet_sdk import Hatchet + +logger = logging.getLogger() + + +def hatchet_graph_search_results_factory( + orchestration_provider: OrchestrationProvider, service: GraphService +) -> dict[str, "Hatchet.Workflow"]: + def convert_to_dict(input_data): + """Converts input data back to a plain dictionary format, handling + special cases like UUID and GenerationConfig. This is the inverse of + get_input_data_dict. 
+ + Args: + input_data: Dictionary containing the input data with potentially special types + + Returns: + Dictionary with all values converted to basic Python types + """ + output_data = {} + + for key, value in input_data.items(): + if value is None: + output_data[key] = None + continue + + # Convert UUID to string + if isinstance(value, uuid.UUID): + output_data[key] = str(value) + + try: + output_data[key] = value.model_dump() + except Exception: + # Handle nested dictionaries that might contain settings + if isinstance(value, dict): + output_data[key] = convert_to_dict(value) + + # Handle lists that might contain dictionaries + elif isinstance(value, list): + output_data[key] = [ + ( + convert_to_dict(item) + if isinstance(item, dict) + else item + ) + for item in value + ] + + # All other types can be directly assigned + else: + output_data[key] = value + + return output_data + + def get_input_data_dict(input_data): + for key, value in input_data.items(): + if value is None: + continue + + if key == "document_id": + input_data[key] = ( + uuid.UUID(value) + if not isinstance(value, uuid.UUID) + else value + ) + + if key == "collection_id": + input_data[key] = ( + uuid.UUID(value) + if not isinstance(value, uuid.UUID) + else value + ) + + if key == "graph_id": + input_data[key] = ( + uuid.UUID(value) + if not isinstance(value, uuid.UUID) + else value + ) + + if key in ["graph_creation_settings", "graph_enrichment_settings"]: + # Ensure we have a dict (if not already) + input_data[key] = ( + json.loads(value) if not isinstance(value, dict) else value + ) + + if "generation_config" in input_data[key]: + gen_cfg = input_data[key]["generation_config"] + # If it's a dict, convert it + if isinstance(gen_cfg, dict): + input_data[key]["generation_config"] = ( + GenerationConfig(**gen_cfg) + ) + # If it's not already a GenerationConfig, default it + elif not isinstance(gen_cfg, GenerationConfig): + input_data[key]["generation_config"] = ( + GenerationConfig() + ) + + input_data[key]["generation_config"].model = ( + input_data[key]["generation_config"].model + or service.config.app.fast_llm + ) + + return input_data + + @orchestration_provider.workflow(name="graph-extraction", timeout="360m") + class GraphExtractionWorkflow: + @orchestration_provider.concurrency( # type: ignore + max_runs=orchestration_provider.config.graph_search_results_concurrency_limit, # type: ignore + limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, + ) + def concurrency(self, context: Context) -> str: + # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun + with contextlib.suppress(Exception): + return str( + context.workflow_input()["request"]["collection_id"] + ) + + def __init__(self, graph_search_results_service: GraphService): + self.graph_search_results_service = graph_search_results_service + + @orchestration_provider.step(retries=1, timeout="360m") + async def graph_search_results_extraction( + self, context: Context + ) -> dict: + request = context.workflow_input()["request"] + + input_data = get_input_data_dict(request) + document_id = input_data.get("document_id", None) + collection_id = input_data.get("collection_id", None) + + await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status( + id=document_id, + status_type="extraction_status", + status=GraphExtractionStatus.PROCESSING, + ) + + if collection_id and not document_id: + document_ids = await self.graph_search_results_service.get_document_ids_for_create_graph( + 
collection_id=collection_id, + **input_data["graph_creation_settings"], + ) + workflows = [] + + for document_id in document_ids: + input_data_copy = input_data.copy() + input_data_copy["collection_id"] = str( + input_data_copy["collection_id"] + ) + input_data_copy["document_id"] = str(document_id) + + workflows.append( + context.aio.spawn_workflow( + "graph-extraction", + { + "request": { + **convert_to_dict(input_data_copy), + } + }, + key=str(document_id), + ) + ) + # Wait for all workflows to complete + results = await asyncio.gather(*workflows) + return { + "result": f"successfully submitted graph_search_results relationships extraction for document {document_id}", + "document_id": str(collection_id), + } + + else: + # Extract relationships and store them + extractions = [] + async for extraction in self.graph_search_results_service.graph_search_results_extraction( + document_id=document_id, + **input_data["graph_creation_settings"], + ): + logger.info( + f"Found extraction with {len(extraction.entities)} entities" + ) + extractions.append(extraction) + + await self.graph_search_results_service.store_graph_search_results_extractions( + extractions + ) + + logger.info( + f"Successfully ran graph_search_results relationships extraction for document {document_id}" + ) + + return { + "result": f"successfully ran graph_search_results relationships extraction for document {document_id}", + "document_id": str(document_id), + } + + @orchestration_provider.step( + retries=1, + timeout="360m", + parents=["graph_search_results_extraction"], + ) + async def graph_search_results_entity_description( + self, context: Context + ) -> dict: + input_data = get_input_data_dict( + context.workflow_input()["request"] + ) + document_id = input_data.get("document_id", None) + + # Describe the entities in the graph + await self.graph_search_results_service.graph_search_results_entity_description( + document_id=document_id, + **input_data["graph_creation_settings"], + ) + + logger.info( + f"Successfully ran graph_search_results entity description for document {document_id}" + ) + + if service.providers.database.config.graph_creation_settings.automatic_deduplication: + extract_input = { + "document_id": str(document_id), + } + + extract_result = ( + await context.aio.spawn_workflow( + "graph-deduplication", + {"request": extract_input}, + ) + ).result() + + await asyncio.gather(extract_result) + + return { + "result": f"successfully ran graph_search_results entity description for document {document_id}" + } + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + request = context.workflow_input().get("request", {}) + document_id = request.get("document_id") + + if not document_id: + logger.info( + "No document id was found in workflow input to mark a failure." 
+ ) + return + + try: + await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status( + id=uuid.UUID(document_id), + status_type="extraction_status", + status=GraphExtractionStatus.FAILED, + ) + logger.info( + f"Updated Graph extraction status for {document_id} to FAILED" + ) + except Exception as e: + logger.error( + f"Failed to update document status for {document_id}: {e}" + ) + + @orchestration_provider.workflow(name="graph-clustering", timeout="360m") + class GraphClusteringWorkflow: + def __init__(self, graph_search_results_service: GraphService): + self.graph_search_results_service = graph_search_results_service + + @orchestration_provider.step(retries=1, timeout="360m", parents=[]) + async def graph_search_results_clustering( + self, context: Context + ) -> dict: + logger.info("Running Graph Clustering") + + input_data = get_input_data_dict( + context.workflow_input()["request"] + ) + + # Get the collection_id and graph_id + collection_id = input_data.get("collection_id", None) + graph_id = input_data.get("graph_id", None) + + # Check current workflow status + workflow_status = await self.graph_search_results_service.providers.database.documents_handler.get_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + ) + + if workflow_status == GraphConstructionStatus.SUCCESS: + raise R2RException( + "Communities have already been built for this collection. To build communities again, first reset the graph.", + 400, + ) + + # Run clustering + try: + graph_search_results_clustering_results = await self.graph_search_results_service.graph_search_results_clustering( + collection_id=collection_id, + graph_id=graph_id, + **input_data["graph_enrichment_settings"], + ) + + num_communities = graph_search_results_clustering_results[ + "num_communities" + ][0] + + if num_communities == 0: + raise R2RException("No communities found", 400) + + return { + "result": graph_search_results_clustering_results, + } + except Exception as e: + await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.FAILED, + ) + raise e + + @orchestration_provider.step( + retries=1, + timeout="360m", + parents=["graph_search_results_clustering"], + ) + async def graph_search_results_community_summary( + self, context: Context + ) -> dict: + input_data = get_input_data_dict( + context.workflow_input()["request"] + ) + collection_id = input_data.get("collection_id", None) + graph_id = input_data.get("graph_id", None) + # Get number of communities from previous step + num_communities = context.step_output( + "graph_search_results_clustering" + )["result"]["num_communities"][0] + + # Calculate batching + parallel_communities = min(100, num_communities) + total_workflows = math.ceil(num_communities / parallel_communities) + workflows = [] + + logger.info( + f"Running Graph Community Summary for {num_communities} communities, spawning {total_workflows} workflows" + ) + + # Spawn summary workflows + for i in range(total_workflows): + offset = i * parallel_communities + limit = min(parallel_communities, num_communities - offset) + + workflows.append( + ( + await context.aio.spawn_workflow( + "graph-community-summarization", + { + "request": { + "offset": offset, + "limit": limit, + "graph_id": ( + str(graph_id) if graph_id else None + ), + "collection_id": ( + str(collection_id) + if collection_id + else None + ), + "graph_enrichment_settings": 
convert_to_dict( + input_data["graph_enrichment_settings"] + ), + } + }, + key=f"{i}/{total_workflows}_community_summary", + ) + ).result() + ) + + results = await asyncio.gather(*workflows) + logger.info( + f"Completed {len(results)} community summary workflows" + ) + + # Update statuses + document_ids = await self.graph_search_results_service.providers.database.documents_handler.get_document_ids_by_status( + status_type="extraction_status", + status=GraphExtractionStatus.SUCCESS, + collection_id=collection_id, + ) + + await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status( + id=document_ids, + status_type="extraction_status", + status=GraphExtractionStatus.ENRICHED, + ) + + await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.SUCCESS, + ) + + return { + "result": f"Successfully completed enrichment with {len(results)} summary workflows" + } + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + collection_id = context.workflow_input()["request"].get( + "collection_id", None + ) + if collection_id: + await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status( + id=uuid.UUID(collection_id), + status_type="graph_cluster_status", + status=GraphConstructionStatus.FAILED, + ) + + @orchestration_provider.workflow( + name="graph-community-summarization", timeout="360m" + ) + class GraphCommunitySummarizerWorkflow: + def __init__(self, graph_search_results_service: GraphService): + self.graph_search_results_service = graph_search_results_service + + @orchestration_provider.concurrency( # type: ignore + max_runs=orchestration_provider.config.graph_search_results_concurrency_limit, # type: ignore + limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, + ) + def concurrency(self, context: Context) -> str: + # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun + try: + return str( + context.workflow_input()["request"]["collection_id"] + ) + except Exception: + return str(uuid.uuid4()) + + @orchestration_provider.step(retries=1, timeout="360m") + async def graph_search_results_community_summary( + self, context: Context + ) -> dict: + start_time = time.time() + + input_data = get_input_data_dict( + context.workflow_input()["request"] + ) + + base_args = { + k: v + for k, v in input_data.items() + if k != "graph_enrichment_settings" + } + enrichment_args = input_data.get("graph_enrichment_settings", {}) + + # Merge them together. + # Note: if there is any key overlap, values from enrichment_args will override those from base_args. + merged_args = {**base_args, **enrichment_args} + + # Now call the service method with all arguments at the top level. + # This ensures that keys like "max_summary_input_length" and "generation_config" are present. 
+ community_summary = await self.graph_search_results_service.graph_search_results_community_summary( + **merged_args + ) + logger.info( + f"Successfully ran graph_search_results community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)} in {time.time() - start_time:.2f} seconds " + ) + return { + "result": f"successfully ran graph_search_results community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)}" + } + + @orchestration_provider.workflow( + name="graph-deduplication", timeout="360m" + ) + class GraphDeduplicationWorkflow: + def __init__(self, graph_search_results_service: GraphService): + self.graph_search_results_service = graph_search_results_service + + @orchestration_provider.concurrency( # type: ignore + max_runs=orchestration_provider.config.graph_search_results_concurrency_limit, # type: ignore + limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, + ) + def concurrency(self, context: Context) -> str: + # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun + try: + return str(context.workflow_input()["request"]["document_id"]) + except Exception: + return str(uuid.uuid4()) + + @orchestration_provider.step(retries=1, timeout="360m") + async def deduplicate_document_entities( + self, context: Context + ) -> dict: + start_time = time.time() + + input_data = get_input_data_dict( + context.workflow_input()["request"] + ) + + document_id = input_data.get("document_id", None) + + await service.deduplicate_document_entities( + document_id=document_id, + ) + logger.info( + f"Successfully ran deduplication for document {document_id} in {time.time() - start_time:.2f} seconds " + ) + return { + "result": f"Successfully ran deduplication for document {document_id}" + } + + return { + "graph-extraction": GraphExtractionWorkflow(service), + "graph-clustering": GraphClusteringWorkflow(service), + "graph-community-summarization": GraphCommunitySummarizerWorkflow( + service + ), + "graph-deduplication": GraphDeduplicationWorkflow(service), + } diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py new file mode 100644 index 00000000..96d7aebb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py @@ -0,0 +1,721 @@ +# type: ignore +import asyncio +import logging +import uuid +from typing import TYPE_CHECKING +from uuid import UUID + +import tiktoken +from fastapi import HTTPException +from hatchet_sdk import ConcurrencyLimitStrategy, Context +from litellm import AuthenticationError + +from core.base import ( + DocumentChunk, + GraphConstructionStatus, + IngestionStatus, + OrchestrationProvider, + generate_extraction_id, +) +from core.base.abstractions import DocumentResponse, R2RException +from core.utils import ( + generate_default_user_collection_id, + update_settings_from_dict, +) + +from ...services import IngestionService, IngestionServiceAdapter + +if TYPE_CHECKING: + from hatchet_sdk import Hatchet + +logger = logging.getLogger() + + +# FIXME: No need to duplicate this function between the workflows, consolidate it into a shared module +def count_tokens_for_text(text: str, model: str = "gpt-4o") -> int: + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + # Fallback to a known encoding if model not recognized + encoding = 
tiktoken.get_encoding("cl100k_base") + + return len(encoding.encode(text, disallowed_special=())) + + +def hatchet_ingestion_factory( + orchestration_provider: OrchestrationProvider, service: IngestionService +) -> dict[str, "Hatchet.Workflow"]: + @orchestration_provider.workflow( + name="ingest-files", + timeout="60m", + ) + class HatchetIngestFilesWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.concurrency( # type: ignore + max_runs=orchestration_provider.config.ingestion_concurrency_limit, # type: ignore + limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, + ) + def concurrency(self, context: Context) -> str: + # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun + try: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_ingest_file_input( + input_data + ) + return str(parsed_data["user"].id) + except Exception: + return str(uuid.uuid4()) + + @orchestration_provider.step(retries=0, timeout="60m") + async def parse(self, context: Context) -> dict: + try: + logger.info("Initiating ingestion workflow, step: parse") + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_ingest_file_input( + input_data + ) + + # ingestion_result = ( + # await self.ingestion_service.ingest_file_ingress( + # **parsed_data + # ) + # ) + + # document_info = ingestion_result["info"] + document_info = ( + self.ingestion_service.create_document_info_from_file( + parsed_data["document_id"], + parsed_data["user"], + parsed_data["file_data"]["filename"], + parsed_data["metadata"], + parsed_data["version"], + parsed_data["size_in_bytes"], + ) + ) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.PARSING, + ) + + ingestion_config = parsed_data["ingestion_config"] or {} + extractions_generator = self.ingestion_service.parse_file( + document_info, ingestion_config + ) + + extractions = [] + async for extraction in extractions_generator: + extractions.append(extraction) + + # 2) Sum tokens + total_tokens = 0 + for chunk in extractions: + text_data = chunk.data + if not isinstance(text_data, str): + text_data = text_data.decode("utf-8", errors="ignore") + total_tokens += count_tokens_for_text(text_data) + document_info.total_tokens = total_tokens + + if not ingestion_config.get("skip_document_summary", False): + await service.update_document_status( + document_info, status=IngestionStatus.AUGMENTING + ) + await service.augment_document_info( + document_info, + [extraction.to_dict() for extraction in extractions], + ) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.EMBEDDING, + ) + + # extractions = context.step_output("parse")["extractions"] + + embedding_generator = self.ingestion_service.embed_document( + [extraction.to_dict() for extraction in extractions] + ) + + embeddings = [] + async for embedding in embedding_generator: + embeddings.append(embedding) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.STORING, + ) + + storage_generator = self.ingestion_service.store_embeddings( # type: ignore + embeddings + ) + + async for _ in storage_generator: + pass + + await self.ingestion_service.finalize_ingestion(document_info) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.SUCCESS, + ) + + collection_ids = 
context.workflow_input()["request"].get( + "collection_ids" + ) + if not collection_ids: + # TODO: Move logic onto the `management service` + collection_id = generate_default_user_collection_id( + document_info.owner_id + ) + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + status=GraphConstructionStatus.OUTDATED, + ) + else: + for collection_id_str in collection_ids: + collection_id = UUID(collection_id_str) + try: + name = document_info.title or "N/A" + description = "" + await service.providers.database.collections_handler.create_collection( + owner_id=document_info.owner_id, + name=name, + description=description, + collection_id=collection_id, + ) + await ( + self.providers.database.graphs_handler.create( + collection_id=collection_id, + name=name, + description=description, + graph_id=collection_id, + ) + ) + + except Exception as e: + logger.warning( + f"Warning, could not create collection with error: {str(e)}" + ) + + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + status=GraphConstructionStatus.OUTDATED, + ) + + # get server chunk enrichment settings and override parts of it if provided in the ingestion config + if server_chunk_enrichment_settings := getattr( + service.providers.ingestion.config, + "chunk_enrichment_settings", + None, + ): + chunk_enrichment_settings = update_settings_from_dict( + server_chunk_enrichment_settings, + ingestion_config.get("chunk_enrichment_settings", {}) + or {}, + ) + + if chunk_enrichment_settings.enable_chunk_enrichment: + logger.info("Enriching document with contextual chunks") + + document_info: DocumentResponse = ( + await self.ingestion_service.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=1, + filter_user_ids=[document_info.owner_id], + filter_document_ids=[document_info.id], + ) + )["results"][0] + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.ENRICHING, + ) + + await self.ingestion_service.chunk_enrichment( + document_id=document_info.id, + document_summary=document_info.summary, + chunk_enrichment_settings=chunk_enrichment_settings, + ) + + await self.ingestion_service.update_document_status( + document_info, + 
status=IngestionStatus.SUCCESS, + ) + # ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ + + if service.providers.ingestion.config.automatic_extraction: + extract_input = { + "document_id": str(document_info.id), + "graph_creation_settings": self.ingestion_service.providers.database.config.graph_creation_settings.model_dump_json(), + "user": input_data["user"], + } + + extract_result = ( + await context.aio.spawn_workflow( + "graph-extraction", + {"request": extract_input}, + ) + ).result() + + await asyncio.gather(extract_result) + + return { + "status": "Successfully finalized ingestion", + "document_info": document_info.to_dict(), + } + + except AuthenticationError: + raise R2RException( + status_code=401, + message="Authentication error: Invalid API key or credentials.", + ) from None + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during ingestion: {str(e)}", + ) from e + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + request = context.workflow_input().get("request", {}) + document_id = request.get("document_id") + + if not document_id: + logger.error( + "No document id was found in workflow input to mark a failure." + ) + return + + try: + documents_overview = ( + await self.ingestion_service.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=1, + filter_document_ids=[document_id], + ) + )["results"] + + if not documents_overview: + logger.error( + f"Document with id {document_id} not found in database to mark failure." + ) + return + + document_info = documents_overview[0] + + # Update the document status to FAILED + if document_info.ingestion_status != IngestionStatus.SUCCESS: + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.FAILED, + metadata={"failure": f"{context.step_run_errors()}"}, + ) + + except Exception as e: + logger.error( + f"Failed to update document status for {document_id}: {e}" + ) + + @orchestration_provider.workflow( + name="ingest-chunks", + timeout="60m", + ) + class HatchetIngestChunksWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="60m") + async def ingest(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input( + input_data + ) + + document_info = await self.ingestion_service.ingest_chunks_ingress( + **parsed_data + ) + + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.EMBEDDING + ) + document_id = document_info.id + + extractions = [ + DocumentChunk( + id=generate_extraction_id(document_id, i), + document_id=document_id, + collection_ids=[], + owner_id=document_info.owner_id, + data=chunk.text, + metadata=parsed_data["metadata"], + ).to_dict() + for i, chunk in enumerate(parsed_data["chunks"]) + ] + + # 2) Sum tokens + total_tokens = 0 + for chunk in extractions: + text_data = chunk["data"] + if not isinstance(text_data, str): + text_data = text_data.decode("utf-8", errors="ignore") + total_tokens += count_tokens_for_text(text_data) + document_info.total_tokens = total_tokens + + return { + "status": "Successfully ingested chunks", + "extractions": extractions, + 
"document_info": document_info.to_dict(), + } + + @orchestration_provider.step(parents=["ingest"], timeout="60m") + async def embed(self, context: Context) -> dict: + document_info_dict = context.step_output("ingest")["document_info"] + document_info = DocumentResponse(**document_info_dict) + + extractions = context.step_output("ingest")["extractions"] + + embedding_generator = self.ingestion_service.embed_document( + extractions + ) + embeddings = [ + embedding.model_dump() + async for embedding in embedding_generator + ] + + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.STORING + ) + + storage_generator = self.ingestion_service.store_embeddings( + embeddings + ) + async for _ in storage_generator: + pass + + return { + "status": "Successfully embedded and stored chunks", + "document_info": document_info.to_dict(), + } + + @orchestration_provider.step(parents=["embed"], timeout="60m") + async def finalize(self, context: Context) -> dict: + document_info_dict = context.step_output("embed")["document_info"] + document_info = DocumentResponse(**document_info_dict) + + await self.ingestion_service.finalize_ingestion(document_info) + + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.SUCCESS + ) + + try: + # TODO - Move logic onto the `management service` + collection_ids = context.workflow_input()["request"].get( + "collection_ids" + ) + if not collection_ids: + # TODO: Move logic onto the `management service` + collection_id = generate_default_user_collection_id( + document_info.owner_id + ) + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + status=GraphConstructionStatus.OUTDATED, + ) + else: + for collection_id_str in collection_ids: + collection_id = UUID(collection_id_str) + try: + name = document_info.title or "N/A" + description = "" + await service.providers.database.collections_handler.create_collection( + owner_id=document_info.owner_id, + name=name, + description=description, + collection_id=collection_id, + ) + await ( + self.providers.database.graphs_handler.create( + collection_id=collection_id, + name=name, + description=description, + graph_id=collection_id, + ) + ) + + except Exception as e: + logger.warning( + f"Warning, could not create collection with error: {str(e)}" + ) + + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + + await 
service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + ) + except Exception as e: + logger.error( + f"Error during assigning document to collection: {str(e)}" + ) + + return { + "status": "Successfully finalized ingestion", + "document_info": document_info.to_dict(), + } + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + request = context.workflow_input().get("request", {}) + document_id = request.get("document_id") + + if not document_id: + logger.error( + "No document id was found in workflow input to mark a failure." + ) + return + + try: + documents_overview = ( + await self.ingestion_service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. + offset=0, + limit=100, + filter_document_ids=[document_id], + ) + )["results"] + + if not documents_overview: + logger.error( + f"Document with id {document_id} not found in database to mark failure." + ) + return + + document_info = documents_overview[0] + + if document_info.ingestion_status != IngestionStatus.SUCCESS: + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.FAILED + ) + + except Exception as e: + logger.error( + f"Failed to update document status for {document_id}: {e}" + ) + + @orchestration_provider.workflow( + name="update-chunk", + timeout="60m", + ) + class HatchetUpdateChunkWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="60m") + async def update_chunk(self, context: Context) -> dict: + try: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_update_chunk_input( + input_data + ) + + document_uuid = ( + UUID(parsed_data["document_id"]) + if isinstance(parsed_data["document_id"], str) + else parsed_data["document_id"] + ) + extraction_uuid = ( + UUID(parsed_data["id"]) + if isinstance(parsed_data["id"], str) + else parsed_data["id"] + ) + + await self.ingestion_service.update_chunk_ingress( + document_id=document_uuid, + chunk_id=extraction_uuid, + text=parsed_data.get("text"), + user=parsed_data["user"], + metadata=parsed_data.get("metadata"), + collection_ids=parsed_data.get("collection_ids"), + ) + + return { + "message": "Chunk update completed successfully.", + "task_id": context.workflow_run_id(), + "document_ids": [str(document_uuid)], + } + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during chunk update: {str(e)}", + ) from e + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + # Handle failure case if necessary + pass + + @orchestration_provider.workflow( + name="create-vector-index", timeout="360m" + ) + class HatchetCreateVectorIndexWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="60m") + async def create_vector_index(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + parsed_data = ( + IngestionServiceAdapter.parse_create_vector_index_input( + input_data + ) + ) + + await 
self.ingestion_service.providers.database.chunks_handler.create_index( + **parsed_data + ) + + return { + "status": "Vector index creation queued successfully.", + } + + @orchestration_provider.workflow(name="delete-vector-index", timeout="30m") + class HatchetDeleteVectorIndexWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="10m") + async def delete_vector_index(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + parsed_data = ( + IngestionServiceAdapter.parse_delete_vector_index_input( + input_data + ) + ) + + await self.ingestion_service.providers.database.chunks_handler.delete_index( + **parsed_data + ) + + return {"status": "Vector index deleted successfully."} + + @orchestration_provider.workflow( + name="update-document-metadata", + timeout="30m", + ) + class HatchetUpdateDocumentMetadataWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="30m") + async def update_document_metadata(self, context: Context) -> dict: + try: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_update_document_metadata_input( + input_data + ) + + document_id = UUID(parsed_data["document_id"]) + metadata = parsed_data["metadata"] + user = parsed_data["user"] + + await self.ingestion_service.update_document_metadata( + document_id=document_id, + metadata=metadata, + user=user, + ) + + return { + "message": "Document metadata update completed successfully.", + "document_id": str(document_id), + "task_id": context.workflow_run_id(), + } + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during document metadata update: {str(e)}", + ) from e + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + # Handle failure case if necessary + pass + + # Add this to the workflows dictionary in hatchet_ingestion_factory + ingest_files_workflow = HatchetIngestFilesWorkflow(service) + ingest_chunks_workflow = HatchetIngestChunksWorkflow(service) + update_chunks_workflow = HatchetUpdateChunkWorkflow(service) + update_document_metadata_workflow = HatchetUpdateDocumentMetadataWorkflow( + service + ) + create_vector_index_workflow = HatchetCreateVectorIndexWorkflow(service) + delete_vector_index_workflow = HatchetDeleteVectorIndexWorkflow(service) + + return { + "ingest_files": ingest_files_workflow, + "ingest_chunks": ingest_chunks_workflow, + "update_chunk": update_chunks_workflow, + "update_document_metadata": update_document_metadata_workflow, + "create_vector_index": create_vector_index_workflow, + "delete_vector_index": delete_vector_index_workflow, + } diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py new file mode 100644 index 00000000..9e043263 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py @@ -0,0 +1,222 @@ +import json +import logging +import math +import uuid + 
+from core import GenerationConfig, R2RException +from core.base.abstractions import ( + GraphConstructionStatus, + GraphExtractionStatus, +) + +from ...services import GraphService + +logger = logging.getLogger() + + +def simple_graph_search_results_factory(service: GraphService): + def get_input_data_dict(input_data): + for key, value in input_data.items(): + if value is None: + continue + + if key == "document_id": + input_data[key] = ( + uuid.UUID(value) + if not isinstance(value, uuid.UUID) + else value + ) + + if key == "collection_id": + input_data[key] = ( + uuid.UUID(value) + if not isinstance(value, uuid.UUID) + else value + ) + + if key == "graph_id": + input_data[key] = ( + uuid.UUID(value) + if not isinstance(value, uuid.UUID) + else value + ) + + if key in ["graph_creation_settings", "graph_enrichment_settings"]: + # Ensure we have a dict (if not already) + input_data[key] = ( + json.loads(value) if not isinstance(value, dict) else value + ) + + if "generation_config" in input_data[key]: + if isinstance(input_data[key]["generation_config"], dict): + input_data[key]["generation_config"] = ( + GenerationConfig( + **input_data[key]["generation_config"] + ) + ) + elif not isinstance( + input_data[key]["generation_config"], GenerationConfig + ): + input_data[key]["generation_config"] = ( + GenerationConfig() + ) + + input_data[key]["generation_config"].model = ( + input_data[key]["generation_config"].model + or service.config.app.fast_llm + ) + + return input_data + + async def graph_extraction(input_data): + input_data = get_input_data_dict(input_data) + + if input_data.get("document_id"): + document_ids = [input_data.get("document_id")] + else: + documents = [] + collection_id = input_data.get("collection_id") + batch_size = 100 + offset = 0 + while True: + # Fetch current batch + batch = ( + await service.providers.database.collections_handler.documents_in_collection( + collection_id=collection_id, + offset=offset, + limit=batch_size, + ) + )["results"] + + # If no documents returned, we've reached the end + if not batch: + break + + # Add current batch to results + documents.extend(batch) + + # Update offset for next batch + offset += batch_size + + # Optional: If batch is smaller than batch_size, we've reached the end + if len(batch) < batch_size: + break + + document_ids = [document.id for document in documents] + + logger.info( + f"Creating graph for {len(document_ids)} documents with IDs: {document_ids}" + ) + + for _, document_id in enumerate(document_ids): + await service.providers.database.documents_handler.set_workflow_status( + id=document_id, + status_type="extraction_status", + status=GraphExtractionStatus.PROCESSING, + ) + + # Extract relationships from the document + try: + extractions = [] + async for ( + extraction + ) in service.graph_search_results_extraction( + document_id=document_id, + **input_data["graph_creation_settings"], + ): + extractions.append(extraction) + await service.store_graph_search_results_extractions( + extractions + ) + + # Describe the entities in the graph + await service.graph_search_results_entity_description( + document_id=document_id, + **input_data["graph_creation_settings"], + ) + + if service.providers.database.config.graph_creation_settings.automatic_deduplication: + logger.warning( + "Automatic deduplication is not yet implemented for `simple` workflows." 
+ ) + + except Exception as e: + logger.error( + f"Error in creating graph for document {document_id}: {e}" + ) + raise e + + async def graph_clustering(input_data): + input_data = get_input_data_dict(input_data) + workflow_status = await service.providers.database.documents_handler.get_workflow_status( + id=input_data.get("collection_id", None), + status_type="graph_cluster_status", + ) + if workflow_status == GraphConstructionStatus.SUCCESS: + raise R2RException( + "Communities have already been built for this collection. To build communities again, first submit a POST request to `graphs/{collection_id}/reset` to erase the previously built communities.", + 400, + ) + + try: + num_communities = await service.graph_search_results_clustering( + collection_id=input_data.get("collection_id", None), + # graph_id=input_data.get("graph_id", None), + **input_data["graph_enrichment_settings"], + ) + num_communities = num_communities["num_communities"][0] + # TODO - Do not hardcode the number of parallel communities, + # make it a configurable parameter at runtime & add server-side defaults + + if num_communities == 0: + raise R2RException("No communities found", 400) + + parallel_communities = min(100, num_communities) + + total_workflows = math.ceil(num_communities / parallel_communities) + for i in range(total_workflows): + input_data_copy = input_data.copy() + input_data_copy["offset"] = i * parallel_communities + input_data_copy["limit"] = min( + parallel_communities, + num_communities - i * parallel_communities, + ) + + logger.info( + f"Running graph_search_results community summary for workflow {i + 1} of {total_workflows}" + ) + + await service.graph_search_results_community_summary( + offset=input_data_copy["offset"], + limit=input_data_copy["limit"], + collection_id=input_data_copy.get("collection_id", None), + # graph_id=input_data_copy.get("graph_id", None), + **input_data_copy["graph_enrichment_settings"], + ) + + await service.providers.database.documents_handler.set_workflow_status( + id=input_data.get("collection_id", None), + status_type="graph_cluster_status", + status=GraphConstructionStatus.SUCCESS, + ) + + except Exception as e: + await service.providers.database.documents_handler.set_workflow_status( + id=input_data.get("collection_id", None), + status_type="graph_cluster_status", + status=GraphConstructionStatus.FAILED, + ) + + raise e + + async def graph_deduplication(input_data): + input_data = get_input_data_dict(input_data) + await service.deduplicate_document_entities( + document_id=input_data.get("document_id", None), + ) + + return { + "graph-extraction": graph_extraction, + "graph-clustering": graph_clustering, + "graph-deduplication": graph_deduplication, + } diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py new file mode 100644 index 00000000..60a696c1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py @@ -0,0 +1,598 @@ +import asyncio +import logging +from uuid import UUID + +import tiktoken +from fastapi import HTTPException +from litellm import AuthenticationError + +from core.base import ( + DocumentChunk, + GraphConstructionStatus, + R2RException, + increment_version, +) +from core.utils import ( + generate_default_user_collection_id, + generate_extraction_id, + update_settings_from_dict, +) + +from ...services import IngestionService + +logger = logging.getLogger() 
+ + +# FIXME: No need to duplicate this function between the workflows, consolidate it into a shared module +def count_tokens_for_text(text: str, model: str = "gpt-4o") -> int: + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + # Fallback to a known encoding if model not recognized + encoding = tiktoken.get_encoding("cl100k_base") + + return len(encoding.encode(text, disallowed_special=())) + + +def simple_ingestion_factory(service: IngestionService): + async def ingest_files(input_data): + document_info = None + try: + from core.base import IngestionStatus + from core.main import IngestionServiceAdapter + + parsed_data = IngestionServiceAdapter.parse_ingest_file_input( + input_data + ) + + document_info = service.create_document_info_from_file( + parsed_data["document_id"], + parsed_data["user"], + parsed_data["file_data"]["filename"], + parsed_data["metadata"], + parsed_data["version"], + parsed_data["size_in_bytes"], + ) + + await service.update_document_status( + document_info, status=IngestionStatus.PARSING + ) + + ingestion_config = parsed_data["ingestion_config"] + extractions_generator = service.parse_file( + document_info=document_info, + ingestion_config=ingestion_config, + ) + extractions = [ + extraction.model_dump() + async for extraction in extractions_generator + ] + + # 2) Sum tokens + total_tokens = 0 + for chunk_dict in extractions: + text_data = chunk_dict["data"] + if not isinstance(text_data, str): + text_data = text_data.decode("utf-8", errors="ignore") + total_tokens += count_tokens_for_text(text_data) + document_info.total_tokens = total_tokens + + if not ingestion_config.get("skip_document_summary", False): + await service.update_document_status( + document_info=document_info, + status=IngestionStatus.AUGMENTING, + ) + await service.augment_document_info(document_info, extractions) + + await service.update_document_status( + document_info, status=IngestionStatus.EMBEDDING + ) + embedding_generator = service.embed_document(extractions) + embeddings = [ + embedding.model_dump() + async for embedding in embedding_generator + ] + + await service.update_document_status( + document_info, status=IngestionStatus.STORING + ) + storage_generator = service.store_embeddings(embeddings) + async for _ in storage_generator: + pass + + await service.finalize_ingestion(document_info) + + await service.update_document_status( + document_info, status=IngestionStatus.SUCCESS + ) + + collection_ids = parsed_data.get("collection_ids") + + try: + if not collection_ids: + # TODO: Move logic onto the `management service` + collection_id = generate_default_user_collection_id( + document_info.owner_id + ) + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + ) + else: + for collection_id in collection_ids: + try: + # FIXME: Right now we just throw a warning if 
the collection already exists, but we should probably handle this more gracefully + name = "My Collection" + description = f"A collection started during {document_info.title} ingestion" + + await service.providers.database.collections_handler.create_collection( + owner_id=document_info.owner_id, + name=name, + description=description, + collection_id=collection_id, + ) + await service.providers.database.graphs_handler.create( + collection_id=collection_id, + name=name, + description=description, + graph_id=collection_id, + ) + except Exception as e: + logger.warning( + f"Warning, could not create collection with error: {str(e)}" + ) + + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + ) + except Exception as e: + logger.error( + f"Error during assigning document to collection: {str(e)}" + ) + + # Chunk enrichment + if server_chunk_enrichment_settings := getattr( + service.providers.ingestion.config, + "chunk_enrichment_settings", + None, + ): + chunk_enrichment_settings = update_settings_from_dict( + server_chunk_enrichment_settings, + ingestion_config.get("chunk_enrichment_settings", {}) + or {}, + ) + + if chunk_enrichment_settings.enable_chunk_enrichment: + logger.info("Enriching document with contextual chunks") + + # Get updated document info with collection IDs + document_info = ( + await service.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=100, + filter_user_ids=[document_info.owner_id], + filter_document_ids=[document_info.id], + ) + )["results"][0] + + await service.update_document_status( + document_info, + status=IngestionStatus.ENRICHING, + ) + + await service.chunk_enrichment( + document_id=document_info.id, + document_summary=document_info.summary, + chunk_enrichment_settings=chunk_enrichment_settings, + ) + + await service.update_document_status( + document_info, + status=IngestionStatus.SUCCESS, + ) + + # Automatic extraction + if service.providers.ingestion.config.automatic_extraction: + logger.warning( + "Automatic extraction not yet implemented for `simple` ingestion workflows." 
+ ) + + except AuthenticationError as e: + if document_info is not None: + await service.update_document_status( + document_info, + status=IngestionStatus.FAILED, + metadata={"failure": f"{str(e)}"}, + ) + raise R2RException( + status_code=401, + message="Authentication error: Invalid API key or credentials.", + ) from e + except Exception as e: + if document_info is not None: + await service.update_document_status( + document_info, + status=IngestionStatus.FAILED, + metadata={"failure": f"{str(e)}"}, + ) + if isinstance(e, R2RException): + raise + raise HTTPException( + status_code=500, detail=f"Error during ingestion: {str(e)}" + ) from e + + async def update_files(input_data): + from core.main import IngestionServiceAdapter + + parsed_data = IngestionServiceAdapter.parse_update_files_input( + input_data + ) + + file_datas = parsed_data["file_datas"] + user = parsed_data["user"] + document_ids = parsed_data["document_ids"] + metadatas = parsed_data["metadatas"] + ingestion_config = parsed_data["ingestion_config"] + file_sizes_in_bytes = parsed_data["file_sizes_in_bytes"] + + if not file_datas: + raise R2RException( + status_code=400, message="No files provided for update." + ) from None + if len(document_ids) != len(file_datas): + raise R2RException( + status_code=400, + message="Number of ids does not match number of files.", + ) from None + + documents_overview = ( + await service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. + offset=0, + limit=100, + filter_user_ids=None if user.is_superuser else [user.id], + filter_document_ids=document_ids, + ) + )["results"] + + if len(documents_overview) != len(document_ids): + raise R2RException( + status_code=404, + message="One or more documents not found.", + ) from None + + results = [] + + for idx, ( + file_data, + doc_id, + doc_info, + file_size_in_bytes, + ) in enumerate( + zip( + file_datas, + document_ids, + documents_overview, + file_sizes_in_bytes, + strict=False, + ) + ): + new_version = increment_version(doc_info.version) + + updated_metadata = ( + metadatas[idx] if metadatas else doc_info.metadata + ) + updated_metadata["title"] = ( + updated_metadata.get("title") + or file_data["filename"].split("/")[-1] + ) + + ingest_input = { + "file_data": file_data, + "user": user.model_dump(), + "metadata": updated_metadata, + "document_id": str(doc_id), + "version": new_version, + "ingestion_config": ingestion_config, + "size_in_bytes": file_size_in_bytes, + } + + result = ingest_files(ingest_input) + results.append(result) + + await asyncio.gather(*results) + if service.providers.ingestion.config.automatic_extraction: + raise R2RException( + status_code=501, + message="Automatic extraction not yet implemented for `simple` ingestion workflows.", + ) from None + + async def ingest_chunks(input_data): + document_info = None + try: + from core.base import IngestionStatus + from core.main import IngestionServiceAdapter + + parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input( + input_data + ) + + document_info = await service.ingest_chunks_ingress(**parsed_data) + + await service.update_document_status( + document_info, status=IngestionStatus.EMBEDDING + ) + document_id = document_info.id + + extractions = [ + DocumentChunk( + id=( + generate_extraction_id(document_id, i) + if chunk.id is None + else chunk.id + ), + document_id=document_id, + collection_ids=[], + owner_id=document_info.owner_id, + data=chunk.text, + 
metadata=parsed_data["metadata"], + ).model_dump() + for i, chunk in enumerate(parsed_data["chunks"]) + ] + + embedding_generator = service.embed_document(extractions) + embeddings = [ + embedding.model_dump() + async for embedding in embedding_generator + ] + + await service.update_document_status( + document_info, status=IngestionStatus.STORING + ) + storage_generator = service.store_embeddings(embeddings) + async for _ in storage_generator: + pass + + await service.finalize_ingestion(document_info) + + await service.update_document_status( + document_info, status=IngestionStatus.SUCCESS + ) + + collection_ids = parsed_data.get("collection_ids") + + try: + # TODO - Move logic onto management service + if not collection_ids: + collection_id = generate_default_user_collection_id( + document_info.owner_id + ) + + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + ) + + else: + for collection_id in collection_ids: + try: + name = document_info.title or "N/A" + description = "" + result = await service.providers.database.collections_handler.create_collection( + owner_id=document_info.owner_id, + name=name, + description=description, + collection_id=collection_id, + ) + await service.providers.database.graphs_handler.create( + collection_id=collection_id, + name=name, + description=description, + graph_id=collection_id, + ) + except Exception as e: + logger.warning( + f"Warning, could not create collection with error: {str(e)}" + ) + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + ) + + if service.providers.ingestion.config.automatic_extraction: + raise R2RException( + status_code=501, + message="Automatic extraction not yet implemented for `simple` ingestion workflows.", + ) from None + + except Exception as e: + logger.error( + f"Error during assigning document to collection: {str(e)}" + ) + + except Exception as e: + if document_info is not None: + await service.update_document_status( + document_info, + status=IngestionStatus.FAILED, + metadata={"failure": f"{str(e)}"}, + ) + raise HTTPException( + status_code=500, + detail=f"Error during chunk ingestion: {str(e)}", + 
) from e + + async def update_chunk(input_data): + from core.main import IngestionServiceAdapter + + try: + parsed_data = IngestionServiceAdapter.parse_update_chunk_input( + input_data + ) + document_uuid = ( + UUID(parsed_data["document_id"]) + if isinstance(parsed_data["document_id"], str) + else parsed_data["document_id"] + ) + extraction_uuid = ( + UUID(parsed_data["id"]) + if isinstance(parsed_data["id"], str) + else parsed_data["id"] + ) + + await service.update_chunk_ingress( + document_id=document_uuid, + chunk_id=extraction_uuid, + text=parsed_data.get("text"), + user=parsed_data["user"], + metadata=parsed_data.get("metadata"), + collection_ids=parsed_data.get("collection_ids"), + ) + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during chunk update: {str(e)}", + ) from e + + async def create_vector_index(input_data): + try: + from core.main import IngestionServiceAdapter + + parsed_data = ( + IngestionServiceAdapter.parse_create_vector_index_input( + input_data + ) + ) + + await service.providers.database.chunks_handler.create_index( + **parsed_data + ) + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during vector index creation: {str(e)}", + ) from e + + async def delete_vector_index(input_data): + try: + from core.main import IngestionServiceAdapter + + parsed_data = ( + IngestionServiceAdapter.parse_delete_vector_index_input( + input_data + ) + ) + + await service.providers.database.chunks_handler.delete_index( + **parsed_data + ) + + return {"status": "Vector index deleted successfully."} + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during vector index deletion: {str(e)}", + ) from e + + async def update_document_metadata(input_data): + try: + from core.main import IngestionServiceAdapter + + parsed_data = ( + IngestionServiceAdapter.parse_update_document_metadata_input( + input_data + ) + ) + + document_id = parsed_data["document_id"] + metadata = parsed_data["metadata"] + user = parsed_data["user"] + + await service.update_document_metadata( + document_id=document_id, + metadata=metadata, + user=user, + ) + + return { + "message": "Document metadata update completed successfully.", + "document_id": str(document_id), + "task_id": None, + } + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during document metadata update: {str(e)}", + ) from e + + return { + "ingest-files": ingest_files, + "ingest-chunks": ingest_chunks, + "update-chunk": update_chunk, + "update-document-metadata": update_document_metadata, + "create-vector-index": create_vector_index, + "delete-vector-index": delete_vector_index, + } |
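Taken together, the two backends expose the same contract: each factory receives a service instance and returns a name-to-workflow mapping, with the Hatchet factories returning workflow classes for registration against the orchestration provider and the simple factories returning bare async callables. Below is a minimal sketch of driving the simple variants directly. It assumes `ingestion_service` and `graph_service` instances are constructed elsewhere (they are not built in this diff) and that `ingest_payload` is the dict normally produced by the API layer for `IngestionServiceAdapter.parse_ingest_file_input`; the exact payload schema belongs to that adapter and is only partially visible here.

# Sketch only: `ingestion_service`, `graph_service`, and `ingest_payload` are
# assumed to be provided by the application; the payload schema is defined by
# IngestionServiceAdapter, not by this module.
from core.main.orchestration import (
    simple_graph_search_results_factory,
    simple_ingestion_factory,
)


async def run_simple_pipeline(ingestion_service, graph_service, ingest_payload):
    ingestion_workflows = simple_ingestion_factory(ingestion_service)
    graph_workflows = simple_graph_search_results_factory(graph_service)

    # Simple workflows are plain coroutines, so each step is just an await.
    await ingestion_workflows["ingest-files"](ingest_payload)

    # graph-extraction accepts a document_id (str or UUID) and settings as
    # either a dict or a JSON string; get_input_data_dict normalizes both.
    await graph_workflows["graph-extraction"](
        {
            "document_id": ingest_payload["document_id"],
            "graph_creation_settings": {},
        }
    )

The coroutine above would be awaited from the caller's event loop. The Hatchet counterparts are not invoked this way: they are registered with the orchestration provider under the returned keys and spawn child workflows themselves (for example, graph-extraction spawning graph-deduplication when automatic_deduplication is enabled), which is why those factories return workflow instances rather than callables.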
