| field | value |
|---|---|
| author | S. Solomon Darnell, 2025-03-28 21:52:21 -0500 |
| committer | S. Solomon Darnell, 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/main/orchestration/simple |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) |
| download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/main/orchestration/simple')
3 files changed, 820 insertions, 0 deletions
```diff
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py
new file mode 100644
index 00000000..9e043263
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py
@@ -0,0 +1,222 @@
+import json
+import logging
+import math
+import uuid
+
+from core import GenerationConfig, R2RException
+from core.base.abstractions import (
+    GraphConstructionStatus,
+    GraphExtractionStatus,
+)
+
+from ...services import GraphService
+
+logger = logging.getLogger()
+
+
+def simple_graph_search_results_factory(service: GraphService):
+    def get_input_data_dict(input_data):
+        for key, value in input_data.items():
+            if value is None:
+                continue
+
+            if key == "document_id":
+                input_data[key] = (
+                    uuid.UUID(value)
+                    if not isinstance(value, uuid.UUID)
+                    else value
+                )
+
+            if key == "collection_id":
+                input_data[key] = (
+                    uuid.UUID(value)
+                    if not isinstance(value, uuid.UUID)
+                    else value
+                )
+
+            if key == "graph_id":
+                input_data[key] = (
+                    uuid.UUID(value)
+                    if not isinstance(value, uuid.UUID)
+                    else value
+                )
+
+            if key in ["graph_creation_settings", "graph_enrichment_settings"]:
+                # Ensure we have a dict (if not already)
+                input_data[key] = (
+                    json.loads(value) if not isinstance(value, dict) else value
+                )
+
+                if "generation_config" in input_data[key]:
+                    if isinstance(input_data[key]["generation_config"], dict):
+                        input_data[key]["generation_config"] = (
+                            GenerationConfig(
+                                **input_data[key]["generation_config"]
+                            )
+                        )
+                    elif not isinstance(
+                        input_data[key]["generation_config"], GenerationConfig
+                    ):
+                        input_data[key]["generation_config"] = (
+                            GenerationConfig()
+                        )
+
+                    input_data[key]["generation_config"].model = (
+                        input_data[key]["generation_config"].model
+                        or service.config.app.fast_llm
+                    )
+
+        return input_data
+
+    async def graph_extraction(input_data):
+        input_data = get_input_data_dict(input_data)
+
+        if input_data.get("document_id"):
+            document_ids = [input_data.get("document_id")]
+        else:
+            documents = []
+            collection_id = input_data.get("collection_id")
+            batch_size = 100
+            offset = 0
+            while True:
+                # Fetch current batch
+                batch = (
+                    await service.providers.database.collections_handler.documents_in_collection(
+                        collection_id=collection_id,
+                        offset=offset,
+                        limit=batch_size,
+                    )
+                )["results"]
+
+                # If no documents returned, we've reached the end
+                if not batch:
+                    break
+
+                # Add current batch to results
+                documents.extend(batch)
+
+                # Update offset for next batch
+                offset += batch_size
+
+                # Optional: If batch is smaller than batch_size, we've reached the end
+                if len(batch) < batch_size:
+                    break
+
+            document_ids = [document.id for document in documents]
+
+        logger.info(
+            f"Creating graph for {len(document_ids)} documents with IDs: {document_ids}"
+        )
+
+        for _, document_id in enumerate(document_ids):
+            await service.providers.database.documents_handler.set_workflow_status(
+                id=document_id,
+                status_type="extraction_status",
+                status=GraphExtractionStatus.PROCESSING,
+            )
+
+            # Extract relationships from the document
+            try:
+                extractions = []
+                async for (
+                    extraction
+                ) in service.graph_search_results_extraction(
+                    document_id=document_id,
+                    **input_data["graph_creation_settings"],
+                ):
+                    extractions.append(extraction)
+                await service.store_graph_search_results_extractions(
+                    extractions
+                )
+
+                # Describe the entities in the graph
+                await service.graph_search_results_entity_description(
+                    document_id=document_id,
+                    **input_data["graph_creation_settings"],
+                )
+
+                if service.providers.database.config.graph_creation_settings.automatic_deduplication:
+                    logger.warning(
+                        "Automatic deduplication is not yet implemented for `simple` workflows."
+                    )
+
+            except Exception as e:
+                logger.error(
+                    f"Error in creating graph for document {document_id}: {e}"
+                )
+                raise e
+
+    async def graph_clustering(input_data):
+        input_data = get_input_data_dict(input_data)
+        workflow_status = await service.providers.database.documents_handler.get_workflow_status(
+            id=input_data.get("collection_id", None),
+            status_type="graph_cluster_status",
+        )
+        if workflow_status == GraphConstructionStatus.SUCCESS:
+            raise R2RException(
+                "Communities have already been built for this collection. To build communities again, first submit a POST request to `graphs/{collection_id}/reset` to erase the previously built communities.",
+                400,
+            )
+
+        try:
+            num_communities = await service.graph_search_results_clustering(
+                collection_id=input_data.get("collection_id", None),
+                # graph_id=input_data.get("graph_id", None),
+                **input_data["graph_enrichment_settings"],
+            )
+            num_communities = num_communities["num_communities"][0]
+            # TODO - Do not hardcode the number of parallel communities,
+            # make it a configurable parameter at runtime & add server-side defaults
+
+            if num_communities == 0:
+                raise R2RException("No communities found", 400)
+
+            parallel_communities = min(100, num_communities)
+
+            total_workflows = math.ceil(num_communities / parallel_communities)
+            for i in range(total_workflows):
+                input_data_copy = input_data.copy()
+                input_data_copy["offset"] = i * parallel_communities
+                input_data_copy["limit"] = min(
+                    parallel_communities,
+                    num_communities - i * parallel_communities,
+                )
+
+                logger.info(
+                    f"Running graph_search_results community summary for workflow {i + 1} of {total_workflows}"
+                )
+
+                await service.graph_search_results_community_summary(
+                    offset=input_data_copy["offset"],
+                    limit=input_data_copy["limit"],
+                    collection_id=input_data_copy.get("collection_id", None),
+                    # graph_id=input_data_copy.get("graph_id", None),
+                    **input_data_copy["graph_enrichment_settings"],
+                )
+
+            await service.providers.database.documents_handler.set_workflow_status(
+                id=input_data.get("collection_id", None),
+                status_type="graph_cluster_status",
+                status=GraphConstructionStatus.SUCCESS,
+            )
+
+        except Exception as e:
+            await service.providers.database.documents_handler.set_workflow_status(
+                id=input_data.get("collection_id", None),
+                status_type="graph_cluster_status",
+                status=GraphConstructionStatus.FAILED,
+            )
+
+            raise e
+
+    async def graph_deduplication(input_data):
+        input_data = get_input_data_dict(input_data)
+        await service.deduplicate_document_entities(
+            document_id=input_data.get("document_id", None),
+        )
+
+    return {
+        "graph-extraction": graph_extraction,
+        "graph-clustering": graph_clustering,
+        "graph-deduplication": graph_deduplication,
+    }
```
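The factory above returns a plain mapping from workflow names to async handlers rather than registering tasks with an external orchestrator. Below is a minimal sketch of how a caller might dispatch on that mapping; the stub handler, the payload fields, and the example UUID are illustrative assumptions and not part of the diff — in R2R the dict would come from `simple_graph_search_results_factory(service)`.

```python
# Hypothetical dispatch sketch for the name -> coroutine mapping returned by
# simple_graph_search_results_factory. The stub handler and payload are
# assumptions used only to keep the example self-contained and runnable.
import asyncio


async def run_graph_workflow(workflows: dict, name: str, payload: dict):
    """Look up a workflow coroutine by name and await it with its input dict."""
    handler = workflows.get(name)
    if handler is None:
        raise ValueError(f"Unknown workflow: {name}")
    return await handler(payload)


async def main():
    # In R2R this would be:
    #   workflows = simple_graph_search_results_factory(graph_service)
    # Here a stub stands in for the real handler.
    async def fake_graph_extraction(input_data: dict):
        print("extracting graph for document", input_data.get("document_id"))

    workflows = {"graph-extraction": fake_graph_extraction}
    await run_graph_workflow(
        workflows,
        "graph-extraction",
        {"document_id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1"},  # example UUID
    )


if __name__ == "__main__":
    asyncio.run(main())
```

Because each real handler takes a single `input_data` dict and normalizes it through `get_input_data_dict`, callers can pass string UUIDs and JSON-encoded settings and let the workflow coerce them into `uuid.UUID` and `GenerationConfig` objects.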
```diff
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py
new file mode 100644
index 00000000..60a696c1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py
@@ -0,0 +1,598 @@
+import asyncio
+import logging
+from uuid import UUID
+
+import tiktoken
+from fastapi import HTTPException
+from litellm import AuthenticationError
+
+from core.base import (
+    DocumentChunk,
+    GraphConstructionStatus,
+    R2RException,
+    increment_version,
+)
+from core.utils import (
+    generate_default_user_collection_id,
+    generate_extraction_id,
+    update_settings_from_dict,
+)
+
+from ...services import IngestionService
+
+logger = logging.getLogger()
+
+
+# FIXME: No need to duplicate this function between the workflows, consolidate it into a shared module
+def count_tokens_for_text(text: str, model: str = "gpt-4o") -> int:
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        # Fallback to a known encoding if model not recognized
+        encoding = tiktoken.get_encoding("cl100k_base")
+
+    return len(encoding.encode(text, disallowed_special=()))
+
+
+def simple_ingestion_factory(service: IngestionService):
+    async def ingest_files(input_data):
+        document_info = None
+        try:
+            from core.base import IngestionStatus
+            from core.main import IngestionServiceAdapter
+
+            parsed_data = IngestionServiceAdapter.parse_ingest_file_input(
+                input_data
+            )
+
+            document_info = service.create_document_info_from_file(
+                parsed_data["document_id"],
+                parsed_data["user"],
+                parsed_data["file_data"]["filename"],
+                parsed_data["metadata"],
+                parsed_data["version"],
+                parsed_data["size_in_bytes"],
+            )
+
+            await service.update_document_status(
+                document_info, status=IngestionStatus.PARSING
+            )
+
+            ingestion_config = parsed_data["ingestion_config"]
+            extractions_generator = service.parse_file(
+                document_info=document_info,
+                ingestion_config=ingestion_config,
+            )
+            extractions = [
+                extraction.model_dump()
+                async for extraction in extractions_generator
+            ]
+
+            # 2) Sum tokens
+            total_tokens = 0
+            for chunk_dict in extractions:
+                text_data = chunk_dict["data"]
+                if not isinstance(text_data, str):
+                    text_data = text_data.decode("utf-8", errors="ignore")
+                total_tokens += count_tokens_for_text(text_data)
+            document_info.total_tokens = total_tokens
+
+            if not ingestion_config.get("skip_document_summary", False):
+                await service.update_document_status(
+                    document_info=document_info,
+                    status=IngestionStatus.AUGMENTING,
+                )
+                await service.augment_document_info(document_info, extractions)
+
+            await service.update_document_status(
+                document_info, status=IngestionStatus.EMBEDDING
+            )
+            embedding_generator = service.embed_document(extractions)
+            embeddings = [
+                embedding.model_dump()
+                async for embedding in embedding_generator
+            ]
+
+            await service.update_document_status(
+                document_info, status=IngestionStatus.STORING
+            )
+            storage_generator = service.store_embeddings(embeddings)
+            async for _ in storage_generator:
+                pass
+
+            await service.finalize_ingestion(document_info)
+
+            await service.update_document_status(
+                document_info, status=IngestionStatus.SUCCESS
+            )
+
+            collection_ids = parsed_data.get("collection_ids")
+
+            try:
+                if not collection_ids:
+                    # TODO: Move logic onto the `management service`
+                    collection_id = generate_default_user_collection_id(
+                        document_info.owner_id
+                    )
+                    await service.providers.database.collections_handler.assign_document_to_collection_relational(
+                        document_id=document_info.id,
+                        collection_id=collection_id,
+                    )
+                    await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+                        document_id=document_info.id,
+                        collection_id=collection_id,
+                    )
+                    await service.providers.database.documents_handler.set_workflow_status(
+                        id=collection_id,
+                        status_type="graph_sync_status",
+                        status=GraphConstructionStatus.OUTDATED,
+                    )
+                    await service.providers.database.documents_handler.set_workflow_status(
+                        id=collection_id,
+                        status_type="graph_cluster_status",
+                        status=GraphConstructionStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+                    )
+                else:
+                    for collection_id in collection_ids:
+                        try:
+                            # FIXME: Right now we just throw a warning if the collection already exists, but we should probably handle this more gracefully
+                            name = "My Collection"
+                            description = f"A collection started during {document_info.title} ingestion"
+
+                            await service.providers.database.collections_handler.create_collection(
+                                owner_id=document_info.owner_id,
+                                name=name,
+                                description=description,
+                                collection_id=collection_id,
+                            )
+                            await service.providers.database.graphs_handler.create(
+                                collection_id=collection_id,
+                                name=name,
+                                description=description,
+                                graph_id=collection_id,
+                            )
+                        except Exception as e:
+                            logger.warning(
+                                f"Warning, could not create collection with error: {str(e)}"
+                            )
+
+                        await service.providers.database.collections_handler.assign_document_to_collection_relational(
+                            document_id=document_info.id,
+                            collection_id=collection_id,
+                        )
+
+                        await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+                            document_id=document_info.id,
+                            collection_id=collection_id,
+                        )
+                        await service.providers.database.documents_handler.set_workflow_status(
+                            id=collection_id,
+                            status_type="graph_sync_status",
+                            status=GraphConstructionStatus.OUTDATED,
+                        )
+                        await service.providers.database.documents_handler.set_workflow_status(
+                            id=collection_id,
+                            status_type="graph_cluster_status",
+                            status=GraphConstructionStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+                        )
+            except Exception as e:
+                logger.error(
+                    f"Error during assigning document to collection: {str(e)}"
+                )
+
+            # Chunk enrichment
+            if server_chunk_enrichment_settings := getattr(
+                service.providers.ingestion.config,
+                "chunk_enrichment_settings",
+                None,
+            ):
+                chunk_enrichment_settings = update_settings_from_dict(
+                    server_chunk_enrichment_settings,
+                    ingestion_config.get("chunk_enrichment_settings", {})
+                    or {},
+                )
+
+                if chunk_enrichment_settings.enable_chunk_enrichment:
+                    logger.info("Enriching document with contextual chunks")
+
+                    # Get updated document info with collection IDs
+                    document_info = (
+                        await service.providers.database.documents_handler.get_documents_overview(
+                            offset=0,
+                            limit=100,
+                            filter_user_ids=[document_info.owner_id],
+                            filter_document_ids=[document_info.id],
+                        )
+                    )["results"][0]
+
+                    await service.update_document_status(
+                        document_info,
+                        status=IngestionStatus.ENRICHING,
+                    )
+
+                    await service.chunk_enrichment(
+                        document_id=document_info.id,
+                        document_summary=document_info.summary,
+                        chunk_enrichment_settings=chunk_enrichment_settings,
+                    )
+
+                    await service.update_document_status(
+                        document_info,
+                        status=IngestionStatus.SUCCESS,
+                    )
+
+            # Automatic extraction
+            if service.providers.ingestion.config.automatic_extraction:
+                logger.warning(
+                    "Automatic extraction not yet implemented for `simple` ingestion workflows."
+                )
+
+        except AuthenticationError as e:
+            if document_info is not None:
+                await service.update_document_status(
+                    document_info,
+                    status=IngestionStatus.FAILED,
+                    metadata={"failure": f"{str(e)}"},
+                )
+            raise R2RException(
+                status_code=401,
+                message="Authentication error: Invalid API key or credentials.",
+            ) from e
+        except Exception as e:
+            if document_info is not None:
+                await service.update_document_status(
+                    document_info,
+                    status=IngestionStatus.FAILED,
+                    metadata={"failure": f"{str(e)}"},
+                )
+            if isinstance(e, R2RException):
+                raise
+            raise HTTPException(
+                status_code=500, detail=f"Error during ingestion: {str(e)}"
+            ) from e
+
+    async def update_files(input_data):
+        from core.main import IngestionServiceAdapter
+
+        parsed_data = IngestionServiceAdapter.parse_update_files_input(
+            input_data
+        )
+
+        file_datas = parsed_data["file_datas"]
+        user = parsed_data["user"]
+        document_ids = parsed_data["document_ids"]
+        metadatas = parsed_data["metadatas"]
+        ingestion_config = parsed_data["ingestion_config"]
+        file_sizes_in_bytes = parsed_data["file_sizes_in_bytes"]
+
+        if not file_datas:
+            raise R2RException(
+                status_code=400, message="No files provided for update."
+            ) from None
+        if len(document_ids) != len(file_datas):
+            raise R2RException(
+                status_code=400,
+                message="Number of ids does not match number of files.",
+            ) from None
+
+        documents_overview = (
+            await service.providers.database.documents_handler.get_documents_overview(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
+                offset=0,
+                limit=100,
+                filter_user_ids=None if user.is_superuser else [user.id],
+                filter_document_ids=document_ids,
+            )
+        )["results"]
+
+        if len(documents_overview) != len(document_ids):
+            raise R2RException(
+                status_code=404,
+                message="One or more documents not found.",
+            ) from None
+
+        results = []
+
+        for idx, (
+            file_data,
+            doc_id,
+            doc_info,
+            file_size_in_bytes,
+        ) in enumerate(
+            zip(
+                file_datas,
+                document_ids,
+                documents_overview,
+                file_sizes_in_bytes,
+                strict=False,
+            )
+        ):
+            new_version = increment_version(doc_info.version)
+
+            updated_metadata = (
+                metadatas[idx] if metadatas else doc_info.metadata
+            )
+            updated_metadata["title"] = (
+                updated_metadata.get("title")
+                or file_data["filename"].split("/")[-1]
+            )
+
+            ingest_input = {
+                "file_data": file_data,
+                "user": user.model_dump(),
+                "metadata": updated_metadata,
+                "document_id": str(doc_id),
+                "version": new_version,
+                "ingestion_config": ingestion_config,
+                "size_in_bytes": file_size_in_bytes,
+            }
+
+            result = ingest_files(ingest_input)
+            results.append(result)
+
+        await asyncio.gather(*results)
+        if service.providers.ingestion.config.automatic_extraction:
+            raise R2RException(
+                status_code=501,
+                message="Automatic extraction not yet implemented for `simple` ingestion workflows.",
+            ) from None
+
+    async def ingest_chunks(input_data):
+        document_info = None
+        try:
+            from core.base import IngestionStatus
+            from core.main import IngestionServiceAdapter
+
+            parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input(
+                input_data
+            )
+
+            document_info = await service.ingest_chunks_ingress(**parsed_data)
+
+            await service.update_document_status(
+                document_info, status=IngestionStatus.EMBEDDING
+            )
+            document_id = document_info.id
+
+            extractions = [
+                DocumentChunk(
+                    id=(
+                        generate_extraction_id(document_id, i)
+                        if chunk.id is None
+                        else chunk.id
+                    ),
+                    document_id=document_id,
+                    collection_ids=[],
+                    owner_id=document_info.owner_id,
+                    data=chunk.text,
+                    metadata=parsed_data["metadata"],
+                ).model_dump()
+                for i, chunk in enumerate(parsed_data["chunks"])
+            ]
+
+            embedding_generator = service.embed_document(extractions)
+            embeddings = [
+                embedding.model_dump()
+                async for embedding in embedding_generator
+            ]
+
+            await service.update_document_status(
+                document_info, status=IngestionStatus.STORING
+            )
+            storage_generator = service.store_embeddings(embeddings)
+            async for _ in storage_generator:
+                pass
+
+            await service.finalize_ingestion(document_info)
+
+            await service.update_document_status(
+                document_info, status=IngestionStatus.SUCCESS
+            )
+
+            collection_ids = parsed_data.get("collection_ids")
+
+            try:
+                # TODO - Move logic onto management service
+                if not collection_ids:
+                    collection_id = generate_default_user_collection_id(
+                        document_info.owner_id
+                    )
+
+                    await service.providers.database.collections_handler.assign_document_to_collection_relational(
+                        document_id=document_info.id,
+                        collection_id=collection_id,
+                    )
+
+                    await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+                        document_id=document_info.id,
+                        collection_id=collection_id,
+                    )
+
+                    await service.providers.database.documents_handler.set_workflow_status(
+                        id=collection_id,
+                        status_type="graph_sync_status",
+                        status=GraphConstructionStatus.OUTDATED,
+                    )
+                    await service.providers.database.documents_handler.set_workflow_status(
+                        id=collection_id,
+                        status_type="graph_cluster_status",
+                        status=GraphConstructionStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+                    )
+
+                else:
+                    for collection_id in collection_ids:
+                        try:
+                            name = document_info.title or "N/A"
+                            description = ""
+                            result = await service.providers.database.collections_handler.create_collection(
+                                owner_id=document_info.owner_id,
+                                name=name,
+                                description=description,
+                                collection_id=collection_id,
+                            )
+                            await service.providers.database.graphs_handler.create(
+                                collection_id=collection_id,
+                                name=name,
+                                description=description,
+                                graph_id=collection_id,
+                            )
+                        except Exception as e:
+                            logger.warning(
+                                f"Warning, could not create collection with error: {str(e)}"
+                            )
+                        await service.providers.database.collections_handler.assign_document_to_collection_relational(
+                            document_id=document_info.id,
+                            collection_id=collection_id,
+                        )
+                        await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+                            document_id=document_info.id,
+                            collection_id=collection_id,
+                        )
+                        await service.providers.database.documents_handler.set_workflow_status(
+                            id=collection_id,
+                            status_type="graph_sync_status",
+                            status=GraphConstructionStatus.OUTDATED,
+                        )
+                        await service.providers.database.documents_handler.set_workflow_status(
+                            id=collection_id,
+                            status_type="graph_cluster_status",
+                            status=GraphConstructionStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+                        )
+
+                if service.providers.ingestion.config.automatic_extraction:
+                    raise R2RException(
+                        status_code=501,
+                        message="Automatic extraction not yet implemented for `simple` ingestion workflows.",
+                    ) from None
+
+            except Exception as e:
+                logger.error(
+                    f"Error during assigning document to collection: {str(e)}"
+                )
+
+        except Exception as e:
+            if document_info is not None:
+                await service.update_document_status(
+                    document_info,
+                    status=IngestionStatus.FAILED,
+                    metadata={"failure": f"{str(e)}"},
+                )
+            raise HTTPException(
+                status_code=500,
+                detail=f"Error during chunk ingestion: {str(e)}",
+            ) from e
+
+    async def update_chunk(input_data):
+        from core.main import IngestionServiceAdapter
+
+        try:
+            parsed_data = IngestionServiceAdapter.parse_update_chunk_input(
+                input_data
+            )
+            document_uuid = (
+                UUID(parsed_data["document_id"])
+                if isinstance(parsed_data["document_id"], str)
+                else parsed_data["document_id"]
+            )
+            extraction_uuid = (
+                UUID(parsed_data["id"])
+                if isinstance(parsed_data["id"], str)
+                else parsed_data["id"]
+            )
+
+            await service.update_chunk_ingress(
+                document_id=document_uuid,
+                chunk_id=extraction_uuid,
+                text=parsed_data.get("text"),
+                user=parsed_data["user"],
+                metadata=parsed_data.get("metadata"),
+                collection_ids=parsed_data.get("collection_ids"),
+            )
+
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"Error during chunk update: {str(e)}",
+            ) from e
+
+    async def create_vector_index(input_data):
+        try:
+            from core.main import IngestionServiceAdapter
+
+            parsed_data = (
+                IngestionServiceAdapter.parse_create_vector_index_input(
+                    input_data
+                )
+            )
+
+            await service.providers.database.chunks_handler.create_index(
+                **parsed_data
+            )
+
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"Error during vector index creation: {str(e)}",
+            ) from e
+
+    async def delete_vector_index(input_data):
+        try:
+            from core.main import IngestionServiceAdapter
+
+            parsed_data = (
+                IngestionServiceAdapter.parse_delete_vector_index_input(
+                    input_data
+                )
+            )
+
+            await service.providers.database.chunks_handler.delete_index(
+                **parsed_data
+            )
+
+            return {"status": "Vector index deleted successfully."}
+
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"Error during vector index deletion: {str(e)}",
+            ) from e
+
+    async def update_document_metadata(input_data):
+        try:
+            from core.main import IngestionServiceAdapter
+
+            parsed_data = (
+                IngestionServiceAdapter.parse_update_document_metadata_input(
+                    input_data
+                )
+            )
+
+            document_id = parsed_data["document_id"]
+            metadata = parsed_data["metadata"]
+            user = parsed_data["user"]
+
+            await service.update_document_metadata(
+                document_id=document_id,
+                metadata=metadata,
+                user=user,
+            )
+
+            return {
+                "message": "Document metadata update completed successfully.",
+                "document_id": str(document_id),
+                "task_id": None,
+            }
+
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"Error during document metadata update: {str(e)}",
+            ) from e
+
+    return {
+        "ingest-files": ingest_files,
+        "ingest-chunks": ingest_chunks,
+        "update-chunk": update_chunk,
+        "update-document-metadata": update_document_metadata,
+        "create-vector-index": create_vector_index,
+        "delete-vector-index": delete_vector_index,
+    }
```
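`ingestion_workflow.py` carries its own copy of `count_tokens_for_text` (note the FIXME about consolidating it into a shared module). The short sketch below exercises that helper in isolation; the function body is duplicated from the diff only so the example runs on its own, and the unknown model name is a deliberate trigger for the `cl100k_base` fallback.

```python
# Standalone exercise of count_tokens_for_text as defined in the diff above.
# Requires the `tiktoken` package; the body is copied here only to keep the
# sketch self-contained.
import tiktoken


def count_tokens_for_text(text: str, model: str = "gpt-4o") -> int:
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unrecognized model names fall back to a known encoding.
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text, disallowed_special=()))


# Uses whatever encoding tiktoken maps to "gpt-4o" (or the fallback on older
# tiktoken releases that do not recognize the model name).
print(count_tokens_for_text("hello world"))

# An unknown model name always takes the cl100k_base fallback path.
print(count_tokens_for_text("hello world", model="definitely-not-a-model"))
```

Summing these per-chunk counts is how `ingest_files` populates `document_info.total_tokens` before the embedding step.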
