Diffstat (limited to '.venv/lib/python3.12/site-packages/core/main/services')
7 files changed, 5856 insertions, 0 deletions
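For orientation before the diff body: every service added below inherits from the small `Service` base class, which only stores an `R2RConfig` and an `R2RProviders` bundle, and each method then delegates to the matching provider or database handler. A minimal usage sketch, assuming an already-constructed `config` (R2RConfig) and `providers` (R2RProviders) — their construction is not part of this diff — and using only the signatures visible below:

    # Sketch only: `config` and `providers` are assumed to exist already;
    # building them is outside the scope of the files in this diff.
    import asyncio

    from core.main.services import AuthService


    async def demo(config, providers) -> None:
        auth = AuthService(config=config, providers=providers)

        # Register a user, then log in; both calls mirror the method
        # signatures defined in auth_service.py below.
        user = await auth.register(email="alice@example.com", password="s3cret!")
        tokens = await auth.login(email="alice@example.com", password="s3cret!")
        print(user.email, tokens)


    # asyncio.run(demo(config, providers))  # run once config/providers exist

The same construction pattern (pass `config` and `providers` to the constructor, await the async methods) applies to the other services exported from `__init__.py`, such as `GraphService` and `IngestionService`.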
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/__init__.py b/.venv/lib/python3.12/site-packages/core/main/services/__init__.py new file mode 100644 index 00000000..e6a6dec0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/services/__init__.py @@ -0,0 +1,14 @@ +from .auth_service import AuthService +from .graph_service import GraphService +from .ingestion_service import IngestionService, IngestionServiceAdapter +from .management_service import ManagementService +from .retrieval_service import RetrievalService # type: ignore + +__all__ = [ + "AuthService", + "IngestionService", + "IngestionServiceAdapter", + "ManagementService", + "GraphService", + "RetrievalService", +] diff --git a/.venv/lib/python3.12/site-packages/core/main/services/auth_service.py b/.venv/lib/python3.12/site-packages/core/main/services/auth_service.py new file mode 100644 index 00000000..c04dd78c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/services/auth_service.py @@ -0,0 +1,316 @@ +import logging +from datetime import datetime +from typing import Optional +from uuid import UUID + +from core.base import R2RException, Token +from core.base.api.models import User +from core.utils import generate_default_user_collection_id + +from ..abstractions import R2RProviders +from ..config import R2RConfig +from .base import Service + +logger = logging.getLogger() + + +class AuthService(Service): + def __init__( + self, + config: R2RConfig, + providers: R2RProviders, + ): + super().__init__( + config, + providers, + ) + + async def register( + self, + email: str, + password: str, + name: Optional[str] = None, + bio: Optional[str] = None, + profile_picture: Optional[str] = None, + ) -> User: + return await self.providers.auth.register( + email=email, + password=password, + name=name, + bio=bio, + profile_picture=profile_picture, + ) + + async def send_verification_email( + self, email: str + ) -> tuple[str, datetime]: + return await self.providers.auth.send_verification_email(email=email) + + async def verify_email( + self, email: str, verification_code: str + ) -> dict[str, str]: + if not self.config.auth.require_email_verification: + raise R2RException( + status_code=400, message="Email verification is not required" + ) + + user_id = await self.providers.database.users_handler.get_user_id_by_verification_code( + verification_code + ) + user = await self.providers.database.users_handler.get_user_by_id( + user_id + ) + if not user or user.email != email: + raise R2RException( + status_code=400, message="Invalid or expired verification code" + ) + + await self.providers.database.users_handler.mark_user_as_verified( + user_id + ) + await self.providers.database.users_handler.remove_verification_code( + verification_code + ) + return {"message": f"User account {user_id} verified successfully."} + + async def login(self, email: str, password: str) -> dict[str, Token]: + return await self.providers.auth.login(email, password) + + async def user(self, token: str) -> User: + token_data = await self.providers.auth.decode_token(token) + if not token_data.email: + raise R2RException( + status_code=401, message="Invalid authentication credentials" + ) + user = await self.providers.database.users_handler.get_user_by_email( + token_data.email + ) + if user is None: + raise R2RException( + status_code=401, message="Invalid authentication credentials" + ) + return user + + async def refresh_access_token( + self, refresh_token: str + ) -> dict[str, Token]: + return await 
self.providers.auth.refresh_access_token(refresh_token) + + async def change_password( + self, user: User, current_password: str, new_password: str + ) -> dict[str, str]: + if not user: + raise R2RException(status_code=404, message="User not found") + return await self.providers.auth.change_password( + user, current_password, new_password + ) + + async def request_password_reset(self, email: str) -> dict[str, str]: + return await self.providers.auth.request_password_reset(email) + + async def confirm_password_reset( + self, reset_token: str, new_password: str + ) -> dict[str, str]: + return await self.providers.auth.confirm_password_reset( + reset_token, new_password + ) + + async def logout(self, token: str) -> dict[str, str]: + return await self.providers.auth.logout(token) + + async def update_user( + self, + user_id: UUID, + email: Optional[str] = None, + is_superuser: Optional[bool] = None, + name: Optional[str] = None, + bio: Optional[str] = None, + profile_picture: Optional[str] = None, + limits_overrides: Optional[dict] = None, + merge_limits: bool = False, + new_metadata: Optional[dict] = None, + ) -> User: + user: User = ( + await self.providers.database.users_handler.get_user_by_id(user_id) + ) + if not user: + raise R2RException(status_code=404, message="User not found") + if email is not None: + user.email = email + if is_superuser is not None: + user.is_superuser = is_superuser + if name is not None: + user.name = name + if bio is not None: + user.bio = bio + if profile_picture is not None: + user.profile_picture = profile_picture + if limits_overrides is not None: + user.limits_overrides = limits_overrides + return await self.providers.database.users_handler.update_user( + user, merge_limits=merge_limits, new_metadata=new_metadata + ) + + async def delete_user( + self, + user_id: UUID, + password: Optional[str] = None, + delete_vector_data: bool = False, + is_superuser: bool = False, + ) -> dict[str, str]: + user = await self.providers.database.users_handler.get_user_by_id( + user_id + ) + if not user: + raise R2RException(status_code=404, message="User not found") + if not is_superuser and not password: + raise R2RException( + status_code=422, message="Password is required for deletion" + ) + if not ( + is_superuser + or ( + user.hashed_password is not None + and password is not None + and self.providers.auth.crypto_provider.verify_password( + plain_password=password, + hashed_password=user.hashed_password, + ) + ) + ): + raise R2RException(status_code=400, message="Incorrect password") + await self.providers.database.users_handler.delete_user_relational( + user_id + ) + + # Delete user's default collection + # TODO: We need to better define what happens to the user's data when they are deleted + collection_id = generate_default_user_collection_id(user_id) + await self.providers.database.collections_handler.delete_collection_relational( + collection_id + ) + + try: + await self.providers.database.graphs_handler.delete( + collection_id=collection_id, + ) + except Exception as e: + logger.warning( + f"Error deleting graph for collection {collection_id}: {e}" + ) + + if delete_vector_data: + await self.providers.database.chunks_handler.delete_user_vector( + user_id + ) + await self.providers.database.chunks_handler.delete_collection_vector( + collection_id + ) + + return {"message": f"User account {user_id} deleted successfully."} + + async def clean_expired_blacklisted_tokens( + self, + max_age_hours: int = 7 * 24, + current_time: Optional[datetime] = None, + ): + await 
self.providers.database.token_handler.clean_expired_blacklisted_tokens( + max_age_hours, current_time + ) + + async def get_user_verification_code( + self, + user_id: UUID, + ) -> dict: + """Get only the verification code data for a specific user. + + This method should be called after superuser authorization has been + verified. + """ + verification_data = await self.providers.database.users_handler.get_user_validation_data( + user_id=user_id + ) + return { + "verification_code": verification_data["verification_data"][ + "verification_code" + ], + "expiry": verification_data["verification_data"][ + "verification_code_expiry" + ], + } + + async def get_user_reset_token( + self, + user_id: UUID, + ) -> dict: + """Get only the verification code data for a specific user. + + This method should be called after superuser authorization has been + verified. + """ + verification_data = await self.providers.database.users_handler.get_user_validation_data( + user_id=user_id + ) + return { + "reset_token": verification_data["verification_data"][ + "reset_token" + ], + "expiry": verification_data["verification_data"][ + "reset_token_expiry" + ], + } + + async def send_reset_email(self, email: str) -> dict: + """Generate a new verification code and send a reset email to the user. + Returns the verification code for testing/sandbox environments. + + Args: + email (str): The email address of the user + + Returns: + dict: Contains verification_code and message + """ + return await self.providers.auth.send_reset_email(email) + + async def create_user_api_key( + self, user_id: UUID, name: Optional[str], description: Optional[str] + ) -> dict: + """Generate a new API key for the user with optional name and + description. + + Args: + user_id (UUID): The ID of the user + name (Optional[str]): Name of the API key + description (Optional[str]): Description of the API key + + Returns: + dict: Contains the API key and message + """ + return await self.providers.auth.create_user_api_key( + user_id=user_id, name=name, description=description + ) + + async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool: + """Delete the API key for the user. + + Args: + user_id (UUID): The ID of the user + key_id (str): The ID of the API key + + Returns: + bool: True if the API key was deleted successfully + """ + return await self.providers.auth.delete_user_api_key( + user_id=user_id, key_id=key_id + ) + + async def list_user_api_keys(self, user_id: UUID) -> list[dict]: + """List all API keys for the user. 
+ + Args: + user_id (UUID): The ID of the user + + Returns: + dict: Contains the list of API keys + """ + return await self.providers.auth.list_user_api_keys(user_id) diff --git a/.venv/lib/python3.12/site-packages/core/main/services/base.py b/.venv/lib/python3.12/site-packages/core/main/services/base.py new file mode 100644 index 00000000..dcd98fd5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/services/base.py @@ -0,0 +1,14 @@ +from abc import ABC + +from ..abstractions import R2RProviders +from ..config import R2RConfig + + +class Service(ABC): + def __init__( + self, + config: R2RConfig, + providers: R2RProviders, + ): + self.config = config + self.providers = providers diff --git a/.venv/lib/python3.12/site-packages/core/main/services/graph_service.py b/.venv/lib/python3.12/site-packages/core/main/services/graph_service.py new file mode 100644 index 00000000..56f32cf8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/services/graph_service.py @@ -0,0 +1,1358 @@ +import asyncio +import logging +import math +import random +import re +import time +import uuid +import xml.etree.ElementTree as ET +from typing import Any, AsyncGenerator, Coroutine, Optional +from uuid import UUID +from xml.etree.ElementTree import Element + +from core.base import ( + DocumentChunk, + GraphExtraction, + GraphExtractionStatus, + R2RDocumentProcessingError, +) +from core.base.abstractions import ( + Community, + Entity, + GenerationConfig, + GraphConstructionStatus, + R2RException, + Relationship, + StoreType, +) +from core.base.api.models import GraphResponse + +from ..abstractions import R2RProviders +from ..config import R2RConfig +from .base import Service + +logger = logging.getLogger() + +MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH = 128 + + +async def _collect_async_results(result_gen: AsyncGenerator) -> list[Any]: + """Collects all results from an async generator into a list.""" + results = [] + async for res in result_gen: + results.append(res) + return results + + +class GraphService(Service): + def __init__( + self, + config: R2RConfig, + providers: R2RProviders, + ): + super().__init__( + config, + providers, + ) + + async def create_entity( + self, + name: str, + description: str, + parent_id: UUID, + category: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> Entity: + description_embedding = str( + await self.providers.embedding.async_get_embedding(description) + ) + + return await self.providers.database.graphs_handler.entities.create( + name=name, + parent_id=parent_id, + store_type=StoreType.GRAPHS, + category=category, + description=description, + description_embedding=description_embedding, + metadata=metadata, + ) + + async def update_entity( + self, + entity_id: UUID, + name: Optional[str] = None, + description: Optional[str] = None, + category: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> Entity: + description_embedding = None + if description is not None: + description_embedding = str( + await self.providers.embedding.async_get_embedding(description) + ) + + return await self.providers.database.graphs_handler.entities.update( + entity_id=entity_id, + store_type=StoreType.GRAPHS, + name=name, + description=description, + description_embedding=description_embedding, + category=category, + metadata=metadata, + ) + + async def delete_entity( + self, + parent_id: UUID, + entity_id: UUID, + ): + return await self.providers.database.graphs_handler.entities.delete( + parent_id=parent_id, + entity_ids=[entity_id], + 
store_type=StoreType.GRAPHS, + ) + + async def get_entities( + self, + parent_id: UUID, + offset: int, + limit: int, + entity_ids: Optional[list[UUID]] = None, + entity_names: Optional[list[str]] = None, + include_embeddings: bool = False, + ): + return await self.providers.database.graphs_handler.get_entities( + parent_id=parent_id, + offset=offset, + limit=limit, + entity_ids=entity_ids, + entity_names=entity_names, + include_embeddings=include_embeddings, + ) + + async def create_relationship( + self, + subject: str, + subject_id: UUID, + predicate: str, + object: str, + object_id: UUID, + parent_id: UUID, + description: str | None = None, + weight: float | None = 1.0, + metadata: Optional[dict[str, Any] | str] = None, + ) -> Relationship: + description_embedding = None + if description: + description_embedding = str( + await self.providers.embedding.async_get_embedding(description) + ) + + return ( + await self.providers.database.graphs_handler.relationships.create( + subject=subject, + subject_id=subject_id, + predicate=predicate, + object=object, + object_id=object_id, + parent_id=parent_id, + description=description, + description_embedding=description_embedding, + weight=weight, + metadata=metadata, + store_type=StoreType.GRAPHS, + ) + ) + + async def delete_relationship( + self, + parent_id: UUID, + relationship_id: UUID, + ): + return ( + await self.providers.database.graphs_handler.relationships.delete( + parent_id=parent_id, + relationship_ids=[relationship_id], + store_type=StoreType.GRAPHS, + ) + ) + + async def update_relationship( + self, + relationship_id: UUID, + subject: Optional[str] = None, + subject_id: Optional[UUID] = None, + predicate: Optional[str] = None, + object: Optional[str] = None, + object_id: Optional[UUID] = None, + description: Optional[str] = None, + weight: Optional[float] = None, + metadata: Optional[dict[str, Any] | str] = None, + ) -> Relationship: + description_embedding = None + if description is not None: + description_embedding = str( + await self.providers.embedding.async_get_embedding(description) + ) + + return ( + await self.providers.database.graphs_handler.relationships.update( + relationship_id=relationship_id, + subject=subject, + subject_id=subject_id, + predicate=predicate, + object=object, + object_id=object_id, + description=description, + description_embedding=description_embedding, + weight=weight, + metadata=metadata, + store_type=StoreType.GRAPHS, + ) + ) + + async def get_relationships( + self, + parent_id: UUID, + offset: int, + limit: int, + relationship_ids: Optional[list[UUID]] = None, + entity_names: Optional[list[str]] = None, + ): + return await self.providers.database.graphs_handler.relationships.get( + parent_id=parent_id, + store_type=StoreType.GRAPHS, + offset=offset, + limit=limit, + relationship_ids=relationship_ids, + entity_names=entity_names, + ) + + async def create_community( + self, + parent_id: UUID, + name: str, + summary: str, + findings: Optional[list[str]], + rating: Optional[float], + rating_explanation: Optional[str], + ) -> Community: + description_embedding = str( + await self.providers.embedding.async_get_embedding(summary) + ) + return await self.providers.database.graphs_handler.communities.create( + parent_id=parent_id, + store_type=StoreType.GRAPHS, + name=name, + summary=summary, + description_embedding=description_embedding, + findings=findings, + rating=rating, + rating_explanation=rating_explanation, + ) + + async def update_community( + self, + community_id: UUID, + name: Optional[str], + 
summary: Optional[str], + findings: Optional[list[str]], + rating: Optional[float], + rating_explanation: Optional[str], + ) -> Community: + summary_embedding = None + if summary is not None: + summary_embedding = str( + await self.providers.embedding.async_get_embedding(summary) + ) + + return await self.providers.database.graphs_handler.communities.update( + community_id=community_id, + store_type=StoreType.GRAPHS, + name=name, + summary=summary, + summary_embedding=summary_embedding, + findings=findings, + rating=rating, + rating_explanation=rating_explanation, + ) + + async def delete_community( + self, + parent_id: UUID, + community_id: UUID, + ) -> None: + await self.providers.database.graphs_handler.communities.delete( + parent_id=parent_id, + community_id=community_id, + ) + + async def get_communities( + self, + parent_id: UUID, + offset: int, + limit: int, + community_ids: Optional[list[UUID]] = None, + community_names: Optional[list[str]] = None, + include_embeddings: bool = False, + ): + return await self.providers.database.graphs_handler.get_communities( + parent_id=parent_id, + offset=offset, + limit=limit, + community_ids=community_ids, + include_embeddings=include_embeddings, + ) + + async def list_graphs( + self, + offset: int, + limit: int, + graph_ids: Optional[list[UUID]] = None, + collection_id: Optional[UUID] = None, + ) -> dict[str, list[GraphResponse] | int]: + return await self.providers.database.graphs_handler.list_graphs( + offset=offset, + limit=limit, + filter_graph_ids=graph_ids, + filter_collection_id=collection_id, + ) + + async def update_graph( + self, + collection_id: UUID, + name: Optional[str] = None, + description: Optional[str] = None, + ) -> GraphResponse: + return await self.providers.database.graphs_handler.update( + collection_id=collection_id, + name=name, + description=description, + ) + + async def reset_graph(self, id: UUID) -> bool: + await self.providers.database.graphs_handler.reset( + parent_id=id, + ) + await self.providers.database.documents_handler.set_workflow_status( + id=id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.PENDING, + ) + return True + + async def get_document_ids_for_create_graph( + self, + collection_id: UUID, + **kwargs, + ): + document_status_filter = [ + GraphExtractionStatus.PENDING, + GraphExtractionStatus.FAILED, + ] + + return await self.providers.database.documents_handler.get_document_ids_by_status( + status_type="extraction_status", + status=[str(ele) for ele in document_status_filter], + collection_id=collection_id, + ) + + async def graph_search_results_entity_description( + self, + document_id: UUID, + max_description_input_length: int, + batch_size: int = 256, + **kwargs, + ): + """A new implementation of the old GraphDescriptionPipe logic inline. + No references to pipe objects. + + We: + 1) Count how many entities are in the document + 2) Process them in batches of `batch_size` + 3) For each batch, we retrieve the entity map and possibly call LLM for missing descriptions + """ + start_time = time.time() + logger.info( + f"GraphService: Running graph_search_results_entity_description for doc={document_id}" + ) + + # Count how many doc-entities exist + entity_count = ( + await self.providers.database.graphs_handler.get_entity_count( + document_id=document_id, + distinct=True, + entity_table_name="documents_entities", # or whichever table + ) + ) + logger.info( + f"GraphService: Found {entity_count} doc-entities to describe." 
+ ) + + all_results = [] + num_batches = math.ceil(entity_count / batch_size) + + for i in range(num_batches): + offset = i * batch_size + limit = batch_size + + logger.info( + f"GraphService: describing batch {i + 1}/{num_batches}, offset={offset}, limit={limit}" + ) + + # Actually handle describing the entities in the batch + # We'll collect them into a list via an async generator + gen = self._describe_entities_in_document_batch( + document_id=document_id, + offset=offset, + limit=limit, + max_description_input_length=max_description_input_length, + ) + batch_results = await _collect_async_results(gen) + all_results.append(batch_results) + + # Mark the doc's extraction status as success + await self.providers.database.documents_handler.set_workflow_status( + id=document_id, + status_type="extraction_status", + status=GraphExtractionStatus.SUCCESS, + ) + logger.info( + f"GraphService: Completed graph_search_results_entity_description for doc {document_id} in {time.time() - start_time:.2f}s." + ) + return all_results + + async def _describe_entities_in_document_batch( + self, + document_id: UUID, + offset: int, + limit: int, + max_description_input_length: int, + ) -> AsyncGenerator[str, None]: + """Core logic that replaces GraphDescriptionPipe._run_logic for a + particular document/batch. + + Yields entity-names or some textual result as each entity is updated. + """ + start_time = time.time() + logger.info( + f"Started describing doc={document_id}, offset={offset}, limit={limit}" + ) + + # 1) Get the "entity map" from the DB + entity_map = ( + await self.providers.database.graphs_handler.get_entity_map( + offset=offset, limit=limit, document_id=document_id + ) + ) + total_entities = len(entity_map) + logger.info( + f"_describe_entities_in_document_batch: got {total_entities} items in entity_map for doc={document_id}." + ) + + # 2) For each entity name in the map, we gather sub-entities and relationships + tasks: list[Coroutine[Any, Any, str]] = [] + tasks.extend( + self._process_entity_for_description( + entities=[ + entity if isinstance(entity, Entity) else Entity(**entity) + for entity in entity_info["entities"] + ], + relationships=[ + rel + if isinstance(rel, Relationship) + else Relationship(**rel) + for rel in entity_info["relationships"] + ], + document_id=document_id, + max_description_input_length=max_description_input_length, + ) + for entity_name, entity_info in entity_map.items() + ) + + # 3) Wait for all tasks, yield as they complete + idx = 0 + for coro in asyncio.as_completed(tasks): + result = await coro + idx += 1 + if idx % 100 == 0: + logger.info( + f"_describe_entities_in_document_batch: {idx}/{total_entities} described for doc={document_id}" + ) + yield result + + logger.info( + f"Finished describing doc={document_id} batch offset={offset} in {time.time() - start_time:.2f}s." + ) + + async def _process_entity_for_description( + self, + entities: list[Entity], + relationships: list[Relationship], + document_id: UUID, + max_description_input_length: int, + ) -> str: + """Adapted from the old process_entity function in + GraphDescriptionPipe. + + If entity has no description, call an LLM to create one, then store it. + Returns the name of the top entity (or could store more details). 
+ """ + + def truncate_info(info_list: list[str], max_length: int) -> str: + """Shuffles lines of info to try to keep them distinct, then + accumulates until hitting max_length.""" + random.shuffle(info_list) + truncated_info = "" + current_length = 0 + for info in info_list: + if current_length + len(info) > max_length: + break + truncated_info += info + "\n" + current_length += len(info) + return truncated_info + + # Grab a doc-level summary (optional) to feed into the prompt + response = await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=1, + filter_document_ids=[document_id], + ) + document_summary = ( + response["results"][0].summary if response["results"] else None + ) + + # Synthesize a minimal “entity info” string + relationship summary + entity_info = [ + f"{e.name}, {e.description or 'NONE'}" for e in entities + ] + relationships_txt = [ + f"{i + 1}: {r.subject}, {r.object}, {r.predicate} - Summary: {r.description or ''}" + for i, r in enumerate(relationships) + ] + + # We'll describe only the first entity for simplicity + # or you could do them all if needed + main_entity = entities[0] + + if not main_entity.description: + # We only call LLM if the entity is missing a description + messages = await self.providers.database.prompts_handler.get_message_payload( + task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt, + task_inputs={ + "document_summary": document_summary, + "entity_info": truncate_info( + entity_info, max_description_input_length + ), + "relationships_txt": truncate_info( + relationships_txt, max_description_input_length + ), + }, + ) + + # Call the LLM + gen_config = ( + self.providers.database.config.graph_creation_settings.generation_config + or GenerationConfig(model=self.config.app.fast_llm) + ) + llm_resp = await self.providers.llm.aget_completion( + messages=messages, + generation_config=gen_config, + ) + new_description = llm_resp.choices[0].message.content + + if not new_description: + logger.error( + f"No LLM description returned for entity={main_entity.name}" + ) + return main_entity.name + + # create embedding + embed = ( + await self.providers.embedding.async_get_embeddings( + [new_description] + ) + )[0] + + # update DB + main_entity.description = new_description + main_entity.description_embedding = embed + + # Use a method to upsert entity in `documents_entities` or your table + await self.providers.database.graphs_handler.add_entities( + [main_entity], + table_name="documents_entities", + ) + + return main_entity.name + + async def graph_search_results_clustering( + self, + collection_id: UUID, + generation_config: GenerationConfig, + leiden_params: dict, + **kwargs, + ): + """ + Replacement for the old GraphClusteringPipe logic: + 1) call perform_graph_clustering on the DB + 2) return the result + """ + logger.info( + f"Running inline clustering for collection={collection_id} with params={leiden_params}" + ) + return await self._perform_graph_clustering( + collection_id=collection_id, + generation_config=generation_config, + leiden_params=leiden_params, + ) + + async def _perform_graph_clustering( + self, + collection_id: UUID, + generation_config: GenerationConfig, + leiden_params: dict, + ) -> dict: + """The actual clustering logic (previously in + GraphClusteringPipe.cluster_graph_search_results).""" + num_communities = await self.providers.database.graphs_handler.perform_graph_clustering( + collection_id=collection_id, + leiden_params=leiden_params, + ) 
+ return {"num_communities": num_communities} + + async def graph_search_results_community_summary( + self, + offset: int, + limit: int, + max_summary_input_length: int, + generation_config: GenerationConfig, + collection_id: UUID, + leiden_params: Optional[dict] = None, + **kwargs, + ): + """Replacement for the old GraphCommunitySummaryPipe logic. + + Summarizes communities after clustering. Returns an async generator or + you can collect into a list. + """ + logger.info( + f"Running inline community summaries for coll={collection_id}, offset={offset}, limit={limit}" + ) + # We call an internal function that yields summaries + gen = self._summarize_communities( + offset=offset, + limit=limit, + max_summary_input_length=max_summary_input_length, + generation_config=generation_config, + collection_id=collection_id, + leiden_params=leiden_params or {}, + ) + return await _collect_async_results(gen) + + async def _summarize_communities( + self, + offset: int, + limit: int, + max_summary_input_length: int, + generation_config: GenerationConfig, + collection_id: UUID, + leiden_params: dict, + ) -> AsyncGenerator[dict, None]: + """Does the community summary logic from + GraphCommunitySummaryPipe._run_logic. + + Yields each summary dictionary as it completes. + """ + start_time = time.time() + logger.info( + f"Starting community summarization for collection={collection_id}" + ) + + # get all entities & relationships + ( + all_entities, + _, + ) = await self.providers.database.graphs_handler.get_entities( + parent_id=collection_id, + offset=0, + limit=-1, + include_embeddings=False, + ) + ( + all_relationships, + _, + ) = await self.providers.database.graphs_handler.get_relationships( + parent_id=collection_id, + offset=0, + limit=-1, + include_embeddings=False, + ) + + # We can optionally re-run the clustering to produce fresh community assignments + ( + _, + community_clusters, + ) = await self.providers.database.graphs_handler._cluster_and_add_community_info( + relationships=all_relationships, + leiden_params=leiden_params, + collection_id=collection_id, + ) + + # Group clusters + clusters: dict[Any, list[str]] = {} + for item in community_clusters: + cluster_id = item["cluster"] + node_name = item["node"] + clusters.setdefault(cluster_id, []).append(node_name) + + # create an async job for each cluster + tasks: list[Coroutine[Any, Any, dict]] = [] + + tasks.extend( + self._process_community_summary( + community_id=uuid.uuid4(), + nodes=nodes, + all_entities=all_entities, + all_relationships=all_relationships, + max_summary_input_length=max_summary_input_length, + generation_config=generation_config, + collection_id=collection_id, + ) + for nodes in clusters.values() + ) + + total_jobs = len(tasks) + results_returned = 0 + total_errors = 0 + + for coro in asyncio.as_completed(tasks): + summary = await coro + results_returned += 1 + if results_returned % 50 == 0: + logger.info( + f"Community summaries: {results_returned}/{total_jobs} done in {time.time() - start_time:.2f}s" + ) + if "error" in summary: + total_errors += 1 + yield summary + + if total_errors > 0: + logger.warning( + f"{total_errors} communities failed summarization out of {total_jobs}" + ) + + async def _process_community_summary( + self, + community_id: UUID, + nodes: list[str], + all_entities: list[Entity], + all_relationships: list[Relationship], + max_summary_input_length: int, + generation_config: GenerationConfig, + collection_id: UUID, + ) -> dict: + """ + Summarize a single community: gather all relevant 
entities/relationships, call LLM to generate an XML block, + parse it, store the result as a community in DB. + """ + # (Equivalent to process_community in old code) + # fetch the collection description (optional) + response = await self.providers.database.collections_handler.get_collections_overview( + offset=0, + limit=1, + filter_collection_ids=[collection_id], + ) + collection_description = ( + response["results"][0].description if response["results"] else None # type: ignore + ) + + # filter out relevant entities / relationships + entities = [e for e in all_entities if e.name in nodes] + relationships = [ + r + for r in all_relationships + if r.subject in nodes and r.object in nodes + ] + if not entities and not relationships: + return { + "community_id": community_id, + "error": f"No data in this community (nodes={nodes})", + } + + # Create the big input text for the LLM + input_text = await self._community_summary_prompt( + entities, + relationships, + max_summary_input_length, + ) + + # Attempt up to 3 times to parse + for attempt in range(3): + try: + # Build the prompt + messages = await self.providers.database.prompts_handler.get_message_payload( + task_prompt_name=self.providers.database.config.graph_enrichment_settings.graph_communities_prompt, + task_inputs={ + "collection_description": collection_description, + "input_text": input_text, + }, + ) + llm_resp = await self.providers.llm.aget_completion( + messages=messages, + generation_config=generation_config, + ) + llm_text = llm_resp.choices[0].message.content or "" + + # find <community>...</community> XML + match = re.search( + r"<community>.*?</community>", llm_text, re.DOTALL + ) + if not match: + raise ValueError( + "No <community> XML found in LLM response" + ) + + xml_content = match.group(0) + root = ET.fromstring(xml_content) + + # extract fields + name_elem = root.find("name") + summary_elem = root.find("summary") + rating_elem = root.find("rating") + rating_expl_elem = root.find("rating_explanation") + findings_elem = root.find("findings") + + name = name_elem.text if name_elem is not None else "" + summary = summary_elem.text if summary_elem is not None else "" + rating = ( + float(rating_elem.text) + if isinstance(rating_elem, Element) and rating_elem.text + else "" + ) + rating_explanation = ( + rating_expl_elem.text + if rating_expl_elem is not None + else None + ) + findings = ( + [f.text for f in findings_elem.findall("finding")] + if findings_elem is not None + else [] + ) + + # build embedding + embed_text = ( + "Summary:\n" + + (summary or "") + + "\n\nFindings:\n" + + "\n".join( + finding for finding in findings if finding is not None + ) + ) + embedding = await self.providers.embedding.async_get_embedding( + embed_text + ) + + # build Community object + community = Community( + community_id=community_id, + collection_id=collection_id, + name=name, + summary=summary, + rating=rating, + rating_explanation=rating_explanation, + findings=findings, + description_embedding=embedding, + ) + + # store it + await self.providers.database.graphs_handler.add_community( + community + ) + + return { + "community_id": community_id, + "name": name, + } + + except Exception as e: + logger.error( + f"Error summarizing community {community_id}: {e}" + ) + if attempt == 2: + return {"community_id": community_id, "error": str(e)} + await asyncio.sleep(1) + + # fallback + return {"community_id": community_id, "error": "Failed after retries"} + + async def _community_summary_prompt( + self, + entities: list[Entity], + 
relationships: list[Relationship], + max_summary_input_length: int, + ) -> str: + """Gathers the entity/relationship text, tries not to exceed + `max_summary_input_length`.""" + # Group them by entity.name + entity_map: dict[str, dict] = {} + for e in entities: + entity_map.setdefault( + e.name, {"entities": [], "relationships": []} + ) + entity_map[e.name]["entities"].append(e) + + for r in relationships: + # subject + entity_map.setdefault( + r.subject, {"entities": [], "relationships": []} + ) + entity_map[r.subject]["relationships"].append(r) + + # sort by # of relationships + sorted_entries = sorted( + entity_map.items(), + key=lambda x: len(x[1]["relationships"]), + reverse=True, + ) + + # build up the prompt text + prompt_chunks = [] + cur_len = 0 + for entity_name, data in sorted_entries: + block = f"\nEntity: {entity_name}\nDescriptions:\n" + block += "\n".join( + f"{e.id},{(e.description or '')}" for e in data["entities"] + ) + block += "\nRelationships:\n" + block += "\n".join( + f"{r.id},{r.subject},{r.object},{r.predicate},{r.description or ''}" + for r in data["relationships"] + ) + # check length + if cur_len + len(block) > max_summary_input_length: + prompt_chunks.append( + block[: max_summary_input_length - cur_len] + ) + break + else: + prompt_chunks.append(block) + cur_len += len(block) + + return "".join(prompt_chunks) + + async def delete( + self, + collection_id: UUID, + **kwargs, + ): + return await self.providers.database.graphs_handler.delete( + collection_id=collection_id, + ) + + async def graph_search_results_extraction( + self, + document_id: UUID, + generation_config: GenerationConfig, + entity_types: list[str], + relation_types: list[str], + chunk_merge_count: int, + filter_out_existing_chunks: bool = True, + total_tasks: Optional[int] = None, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[GraphExtraction | R2RDocumentProcessingError, None]: + """The original “extract Graph from doc” logic, but inlined instead of + referencing a pipe.""" + start_time = time.time() + + logger.info( + f"Graph Extraction: Processing document {document_id} for graph extraction" + ) + + # Retrieve chunks from DB + chunks = [] + limit = 100 + offset = 0 + while True: + chunk_req = await self.providers.database.chunks_handler.list_document_chunks( + document_id=document_id, + offset=offset, + limit=limit, + ) + new_chunk_objs = [ + DocumentChunk( + id=chunk["id"], + document_id=chunk["document_id"], + owner_id=chunk["owner_id"], + collection_ids=chunk["collection_ids"], + data=chunk["text"], + metadata=chunk["metadata"], + ) + for chunk in chunk_req["results"] + ] + chunks.extend(new_chunk_objs) + if len(chunk_req["results"]) < limit: + break + offset += limit + + if not chunks: + logger.info(f"No chunks found for document {document_id}") + raise R2RException( + message="No chunks found for document", + status_code=404, + ) + + # Possibly filter out any chunks that have already been processed + if filter_out_existing_chunks: + existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids( + document_id=document_id + ) + before_count = len(chunks) + chunks = [c for c in chunks if c.id not in existing_chunk_ids] + logger.info( + f"Filtered out {len(existing_chunk_ids)} existing chunk-IDs. {before_count}->{len(chunks)} remain." 
+ ) + if not chunks: + return # nothing left to yield + + # sort by chunk_order if present + chunks = sorted( + chunks, + key=lambda x: x.metadata.get("chunk_order", float("inf")), + ) + + # group them + grouped_chunks = [ + chunks[i : i + chunk_merge_count] + for i in range(0, len(chunks), chunk_merge_count) + ] + + logger.info( + f"Graph Extraction: Created {len(grouped_chunks)} tasks for doc={document_id}" + ) + tasks = [ + asyncio.create_task( + self._extract_graph_search_results_from_chunk_group( + chunk_group, + generation_config, + entity_types, + relation_types, + ) + ) + for chunk_group in grouped_chunks + ] + + completed_tasks = 0 + for t in asyncio.as_completed(tasks): + try: + yield await t + completed_tasks += 1 + if completed_tasks % 100 == 0: + logger.info( + f"Graph Extraction: completed {completed_tasks}/{len(tasks)} tasks" + ) + except Exception as e: + logger.error(f"Error extracting from chunk group: {e}") + yield R2RDocumentProcessingError( + document_id=document_id, + error_message=str(e), + ) + + logger.info( + f"Graph Extraction: done with {document_id}, time={time.time() - start_time:.2f}s" + ) + + async def _extract_graph_search_results_from_chunk_group( + self, + chunks: list[DocumentChunk], + generation_config: GenerationConfig, + entity_types: list[str], + relation_types: list[str], + retries: int = 5, + delay: int = 2, + ) -> GraphExtraction: + """(Equivalent to _extract_graph_search_results in old code.) Merges + chunk data, calls LLM, parses XML, returns GraphExtraction object.""" + combined_extraction: str = " ".join( + [ + c.data.decode("utf-8") if isinstance(c.data, bytes) else c.data + for c in chunks + if c.data + ] + ) + + # Possibly get doc-level summary + doc_id = chunks[0].document_id + response = await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=1, + filter_document_ids=[doc_id], + ) + document_summary = ( + response["results"][0].summary if response["results"] else None + ) + + # Build messages/prompt + prompt_name = self.providers.database.config.graph_creation_settings.graph_extraction_prompt + messages = ( + await self.providers.database.prompts_handler.get_message_payload( + task_prompt_name=prompt_name, + task_inputs={ + "document_summary": document_summary or "", + "input": combined_extraction, + "entity_types": "\n".join(entity_types), + "relation_types": "\n".join(relation_types), + }, + ) + ) + + for attempt in range(retries): + try: + resp = await self.providers.llm.aget_completion( + messages, generation_config=generation_config + ) + graph_search_results_str = resp.choices[0].message.content + + if not graph_search_results_str: + raise R2RException( + "No extraction found in LLM response.", + 400, + ) + + # parse the XML + ( + entities, + relationships, + ) = await self._parse_graph_search_results_extraction_xml( + graph_search_results_str, chunks + ) + return GraphExtraction( + entities=entities, relationships=relationships + ) + + except Exception as e: + if attempt < retries - 1: + await asyncio.sleep(delay) + continue + else: + logger.error( + f"All extraction attempts for doc={doc_id} and chunks{[chunk.id for chunk in chunks]} failed with error:\n{e}" + ) + return GraphExtraction(entities=[], relationships=[]) + + return GraphExtraction(entities=[], relationships=[]) + + async def _parse_graph_search_results_extraction_xml( + self, response_str: str, chunks: list[DocumentChunk] + ) -> tuple[list[Entity], list[Relationship]]: + """Helper to parse the LLM's XML format, handle edge 
cases/cleanup, + produce Entities/Relationships.""" + + def sanitize_xml(r: str) -> str: + # Remove markdown fences + r = re.sub(r"```xml|```", "", r) + # Remove xml instructions or userStyle + r = re.sub(r"<\?.*?\?>", "", r) + r = re.sub(r"<userStyle>.*?</userStyle>", "", r) + # Replace bare `&` with `&` + r = re.sub(r"&(?!amp;|quot;|apos;|lt;|gt;)", "&", r) + # Also remove <root> if it appears + r = r.replace("<root>", "").replace("</root>", "") + return r.strip() + + cleaned_xml = sanitize_xml(response_str) + wrapped = f"<root>{cleaned_xml}</root>" + try: + root = ET.fromstring(wrapped) + except ET.ParseError: + raise R2RException( + f"Failed to parse XML:\nData: {wrapped[:1000]}...", 400 + ) from None + + entities_elems = root.findall(".//entity") + if ( + len(response_str) > MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH + and len(entities_elems) == 0 + ): + raise R2RException( + f"No <entity> found in LLM XML, possibly malformed. Response excerpt: {response_str[:300]}", + 400, + ) + + # build entity objects + doc_id = chunks[0].document_id + chunk_ids = [c.id for c in chunks] + entities_list: list[Entity] = [] + for element in entities_elems: + name_attr = element.get("name") + type_elem = element.find("type") + desc_elem = element.find("description") + category = type_elem.text if type_elem is not None else None + desc = desc_elem.text if desc_elem is not None else None + desc_embed = await self.providers.embedding.async_get_embedding( + desc or "" + ) + ent = Entity( + category=category, + description=desc, + name=name_attr, + parent_id=doc_id, + chunk_ids=chunk_ids, + description_embedding=desc_embed, + attributes={}, + ) + entities_list.append(ent) + + # build relationship objects + relationships_list: list[Relationship] = [] + rel_elems = root.findall(".//relationship") + for r_elem in rel_elems: + source_elem = r_elem.find("source") + target_elem = r_elem.find("target") + type_elem = r_elem.find("type") + desc_elem = r_elem.find("description") + weight_elem = r_elem.find("weight") + try: + subject = source_elem.text if source_elem is not None else "" + object_ = target_elem.text if target_elem is not None else "" + predicate = type_elem.text if type_elem is not None else "" + desc = desc_elem.text if desc_elem is not None else "" + weight = ( + float(weight_elem.text) + if isinstance(weight_elem, Element) and weight_elem.text + else "" + ) + embed = await self.providers.embedding.async_get_embedding( + desc or "" + ) + + rel = Relationship( + subject=subject, + predicate=predicate, + object=object_, + description=desc, + weight=weight, + parent_id=doc_id, + chunk_ids=chunk_ids, + attributes={}, + description_embedding=embed, + ) + relationships_list.append(rel) + except Exception: + continue + return entities_list, relationships_list + + async def store_graph_search_results_extractions( + self, + graph_search_results_extractions: list[GraphExtraction], + ): + """Stores a batch of knowledge graph extractions in the DB.""" + for extraction in graph_search_results_extractions: + # Map name->id after creation + entities_id_map = {} + for e in extraction.entities: + if e.parent_id is not None: + result = await self.providers.database.graphs_handler.entities.create( + name=e.name, + parent_id=e.parent_id, + store_type=StoreType.DOCUMENTS, + category=e.category, + description=e.description, + description_embedding=e.description_embedding, + chunk_ids=e.chunk_ids, + metadata=e.metadata, + ) + entities_id_map[e.name] = result.id + else: + logger.warning(f"Skipping entity with None 
parent_id: {e}") + + # Insert relationships + for rel in extraction.relationships: + subject_id = entities_id_map.get(rel.subject) + object_id = entities_id_map.get(rel.object) + parent_id = rel.parent_id + + if any( + id is None for id in (subject_id, object_id, parent_id) + ): + logger.warning(f"Missing ID for relationship: {rel}") + continue + + assert isinstance(subject_id, UUID) + assert isinstance(object_id, UUID) + assert isinstance(parent_id, UUID) + + await self.providers.database.graphs_handler.relationships.create( + subject=rel.subject, + subject_id=subject_id, + predicate=rel.predicate, + object=rel.object, + object_id=object_id, + parent_id=parent_id, + description=rel.description, + description_embedding=rel.description_embedding, + weight=rel.weight, + metadata=rel.metadata, + store_type=StoreType.DOCUMENTS, + ) + + async def deduplicate_document_entities( + self, + document_id: UUID, + ): + """ + Inlined from old code: merges duplicates by name, calls LLM for a new consolidated description, updates the record. + """ + merged_results = await self.providers.database.entities_handler.merge_duplicate_name_blocks( + parent_id=document_id, + store_type=StoreType.DOCUMENTS, + ) + + # Grab doc summary + response = await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=1, + filter_document_ids=[document_id], + ) + document_summary = ( + response["results"][0].summary if response["results"] else None + ) + + # For each merged entity + for original_entities, merged_entity in merged_results: + # Summarize them with LLM + entity_info = "\n".join( + e.description for e in original_entities if e.description + ) + messages = await self.providers.database.prompts_handler.get_message_payload( + task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt, + task_inputs={ + "document_summary": document_summary, + "entity_info": f"{merged_entity.name}\n{entity_info}", + "relationships_txt": "", + }, + ) + gen_config = ( + self.config.database.graph_creation_settings.generation_config + or GenerationConfig(model=self.config.app.fast_llm) + ) + resp = await self.providers.llm.aget_completion( + messages, generation_config=gen_config + ) + new_description = resp.choices[0].message.content + + new_embedding = await self.providers.embedding.async_get_embedding( + new_description or "" + ) + + if merged_entity.id is not None: + await self.providers.database.graphs_handler.entities.update( + entity_id=merged_entity.id, + store_type=StoreType.DOCUMENTS, + description=new_description, + description_embedding=str(new_embedding), + ) + else: + logger.warning("Skipping update for entity with None id") diff --git a/.venv/lib/python3.12/site-packages/core/main/services/ingestion_service.py b/.venv/lib/python3.12/site-packages/core/main/services/ingestion_service.py new file mode 100644 index 00000000..55b06911 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/services/ingestion_service.py @@ -0,0 +1,983 @@ +import asyncio +import json +import logging +from datetime import datetime +from typing import Any, AsyncGenerator, Optional, Sequence +from uuid import UUID + +from fastapi import HTTPException + +from core.base import ( + Document, + DocumentChunk, + DocumentResponse, + DocumentType, + GenerationConfig, + IngestionStatus, + R2RException, + RawChunk, + UnprocessedChunk, + Vector, + VectorEntry, + VectorType, + generate_id, +) +from core.base.abstractions import ( + ChunkEnrichmentSettings, + IndexMeasure, + 
IndexMethod, + R2RDocumentProcessingError, + VectorTableName, +) +from core.base.api.models import User +from shared.abstractions import PDFParsingError, PopplerNotFoundError + +from ..abstractions import R2RProviders +from ..config import R2RConfig + +logger = logging.getLogger() +STARTING_VERSION = "v0" + + +class IngestionService: + """A refactored IngestionService that inlines all pipe logic for parsing, + embedding, and vector storage directly in its methods.""" + + def __init__( + self, + config: R2RConfig, + providers: R2RProviders, + ) -> None: + self.config = config + self.providers = providers + + async def ingest_file_ingress( + self, + file_data: dict, + user: User, + document_id: UUID, + size_in_bytes, + metadata: Optional[dict] = None, + version: Optional[str] = None, + *args: Any, + **kwargs: Any, + ) -> dict: + """Pre-ingests a file by creating or validating the DocumentResponse + entry. + + Does not actually parse/ingest the content. (See parse_file() for that + step.) + """ + try: + if not file_data: + raise R2RException( + status_code=400, message="No files provided for ingestion." + ) + if not file_data.get("filename"): + raise R2RException( + status_code=400, message="File name not provided." + ) + + metadata = metadata or {} + version = version or STARTING_VERSION + + document_info = self.create_document_info_from_file( + document_id, + user, + file_data["filename"], + metadata, + version, + size_in_bytes, + ) + + existing_document_info = ( + await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=100, + filter_user_ids=[user.id], + filter_document_ids=[document_id], + ) + )["results"] + + # Validate ingestion status for re-ingestion + if len(existing_document_info) > 0: + existing_doc = existing_document_info[0] + if existing_doc.ingestion_status == IngestionStatus.SUCCESS: + raise R2RException( + status_code=409, + message=( + f"Document {document_id} already exists. " + "Submit a DELETE request to `/documents/{document_id}` " + "to delete this document and allow for re-ingestion." + ), + ) + elif existing_doc.ingestion_status != IngestionStatus.FAILED: + raise R2RException( + status_code=409, + message=( + f"Document {document_id} is currently ingesting " + f"with status {existing_doc.ingestion_status}." 
+ ), + ) + + # Set to PARSING until we actually parse + document_info.ingestion_status = IngestionStatus.PARSING + await self.providers.database.documents_handler.upsert_documents_overview( + document_info + ) + + return { + "info": document_info, + } + except R2RException as e: + logger.error(f"R2RException in ingest_file_ingress: {str(e)}") + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error during ingestion: {str(e)}" + ) from e + + def create_document_info_from_file( + self, + document_id: UUID, + user: User, + file_name: str, + metadata: dict, + version: str, + size_in_bytes: int, + ) -> DocumentResponse: + file_extension = ( + file_name.split(".")[-1].lower() if file_name != "N/A" else "txt" + ) + if file_extension.upper() not in DocumentType.__members__: + raise R2RException( + status_code=415, + message=f"'{file_extension}' is not a valid DocumentType.", + ) + + metadata = metadata or {} + metadata["version"] = version + + return DocumentResponse( + id=document_id, + owner_id=user.id, + collection_ids=metadata.get("collection_ids", []), + document_type=DocumentType[file_extension.upper()], + title=( + metadata.get("title", file_name.split("/")[-1]) + if file_name != "N/A" + else "N/A" + ), + metadata=metadata, + version=version, + size_in_bytes=size_in_bytes, + ingestion_status=IngestionStatus.PENDING, + created_at=datetime.now(), + updated_at=datetime.now(), + ) + + def _create_document_info_from_chunks( + self, + document_id: UUID, + user: User, + chunks: list[RawChunk], + metadata: dict, + version: str, + ) -> DocumentResponse: + metadata = metadata or {} + metadata["version"] = version + + return DocumentResponse( + id=document_id, + owner_id=user.id, + collection_ids=metadata.get("collection_ids", []), + document_type=DocumentType.TXT, + title=metadata.get("title", f"Ingested Chunks - {document_id}"), + metadata=metadata, + version=version, + size_in_bytes=sum( + len(chunk.text.encode("utf-8")) for chunk in chunks + ), + ingestion_status=IngestionStatus.PENDING, + created_at=datetime.now(), + updated_at=datetime.now(), + ) + + async def parse_file( + self, + document_info: DocumentResponse, + ingestion_config: dict | None, + ) -> AsyncGenerator[DocumentChunk, None]: + """Reads the file content from the DB, calls the ingestion + provider to parse, and yields DocumentChunk objects.""" + version = document_info.version or "v0" + ingestion_config_override = ingestion_config or {} + + # The ingestion config might specify a different provider, etc. + override_provider = ingestion_config_override.pop("provider", None) + if ( + override_provider + and override_provider != self.providers.ingestion.config.provider + ): + raise ValueError( + f"Provider '{override_provider}' does not match ingestion provider " + f"'{self.providers.ingestion.config.provider}'." 
+ ) + + try: + # Pull file from DB + retrieved = ( + await self.providers.database.files_handler.retrieve_file( + document_info.id + ) + ) + if not retrieved: + # No file found in the DB, can't parse + raise R2RDocumentProcessingError( + document_id=document_info.id, + error_message="No file content found in DB for this document.", + ) + + file_name, file_wrapper, file_size = retrieved + + # Read the content + with file_wrapper as file_content_stream: + file_content = file_content_stream.read() + + # Build a barebones Document object + doc = Document( + id=document_info.id, + collection_ids=document_info.collection_ids, + owner_id=document_info.owner_id, + metadata={ + "document_type": document_info.document_type.value, + **document_info.metadata, + }, + document_type=document_info.document_type, + ) + + # Delegate to the ingestion provider to parse + async for extraction in self.providers.ingestion.parse( + file_content, # raw bytes + doc, + ingestion_config_override, + ): + # Adjust chunk ID to incorporate version + # or any other needed transformations + extraction.id = generate_id(f"{extraction.id}_{version}") + extraction.metadata["version"] = version + yield extraction + + except (PopplerNotFoundError, PDFParsingError) as e: + raise R2RDocumentProcessingError( + error_message=e.message, + document_id=document_info.id, + status_code=e.status_code, + ) from None + except Exception as e: + if isinstance(e, R2RException): + raise + raise R2RDocumentProcessingError( + document_id=document_info.id, + error_message=f"Error parsing document: {str(e)}", + ) from e + + async def augment_document_info( + self, + document_info: DocumentResponse, + chunked_documents: list[dict], + ) -> None: + if not self.config.ingestion.skip_document_summary: + document = f"Document Title: {document_info.title}\n" + if document_info.metadata != {}: + document += f"Document Metadata: {json.dumps(document_info.metadata)}\n" + + document += "Document Text:\n" + for chunk in chunked_documents[ + : self.config.ingestion.chunks_for_document_summary + ]: + document += chunk["data"] + + messages = await self.providers.database.prompts_handler.get_message_payload( + system_prompt_name=self.config.ingestion.document_summary_system_prompt, + task_prompt_name=self.config.ingestion.document_summary_task_prompt, + task_inputs={ + "document": document[ + : self.config.ingestion.document_summary_max_length + ] + }, + ) + + response = await self.providers.llm.aget_completion( + messages=messages, + generation_config=GenerationConfig( + model=self.config.ingestion.document_summary_model + or self.config.app.fast_llm + ), + ) + + document_info.summary = response.choices[0].message.content # type: ignore + + if not document_info.summary: + raise ValueError("Expected a generated response.") + + embedding = await self.providers.embedding.async_get_embedding( + text=document_info.summary, + ) + document_info.summary_embedding = embedding + return + + async def embed_document( + self, + chunked_documents: list[dict], + embedding_batch_size: int = 8, + ) -> AsyncGenerator[VectorEntry, None]: + """Inline replacement for the old embedding_pipe.run(...). + + Batches the embedding calls and yields VectorEntry objects. 
+ """ + if not chunked_documents: + return + + concurrency_limit = ( + self.providers.embedding.config.concurrent_request_limit or 5 + ) + extraction_batch: list[DocumentChunk] = [] + tasks: set[asyncio.Task] = set() + + async def process_batch( + batch: list[DocumentChunk], + ) -> list[VectorEntry]: + # All text from the batch + texts = [ + ( + ex.data.decode("utf-8") + if isinstance(ex.data, bytes) + else ex.data + ) + for ex in batch + ] + # Retrieve embeddings in bulk + vectors = await self.providers.embedding.async_get_embeddings( + texts, # list of strings + ) + # Zip them back together + results = [] + for raw_vector, extraction in zip(vectors, batch, strict=False): + results.append( + VectorEntry( + id=extraction.id, + document_id=extraction.document_id, + owner_id=extraction.owner_id, + collection_ids=extraction.collection_ids, + vector=Vector(data=raw_vector, type=VectorType.FIXED), + text=( + extraction.data.decode("utf-8") + if isinstance(extraction.data, bytes) + else str(extraction.data) + ), + metadata={**extraction.metadata}, + ) + ) + return results + + async def run_process_batch(batch: list[DocumentChunk]): + return await process_batch(batch) + + # Convert each chunk dict to a DocumentChunk + for chunk_dict in chunked_documents: + extraction = DocumentChunk.from_dict(chunk_dict) + extraction_batch.append(extraction) + + # If we hit a batch threshold, spawn a task + if len(extraction_batch) >= embedding_batch_size: + tasks.add( + asyncio.create_task(run_process_batch(extraction_batch)) + ) + extraction_batch = [] + + # If tasks are at concurrency limit, wait for the first to finish + while len(tasks) >= concurrency_limit: + done, tasks = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) + for t in done: + for vector_entry in await t: + yield vector_entry + + # Handle any leftover items + if extraction_batch: + tasks.add(asyncio.create_task(run_process_batch(extraction_batch))) + + # Gather remaining tasks + for future_task in asyncio.as_completed(tasks): + for vector_entry in await future_task: + yield vector_entry + + async def store_embeddings( + self, + embeddings: Sequence[dict | VectorEntry], + storage_batch_size: int = 128, + ) -> AsyncGenerator[str, None]: + """Inline replacement for the old vector_storage_pipe.run(...). + + Batches up the vector entries, enforces usage limits, stores them, and + yields a success/error string (or you could yield a StorageResult). + """ + if not embeddings: + return + + vector_entries: list[VectorEntry] = [] + for item in embeddings: + if isinstance(item, VectorEntry): + vector_entries.append(item) + else: + vector_entries.append(VectorEntry.from_dict(item)) + + vector_batch: list[VectorEntry] = [] + document_counts: dict[UUID, int] = {} + + # We'll track usage from the first user we see; if your scenario allows + # multiple user owners in a single ingestion, you'd need to refine usage checks. 
+ current_usage = None + user_id_for_usage_check: UUID | None = None + + count = 0 + + for msg in vector_entries: + # If we haven't set usage yet, do so on the first chunk + if current_usage is None: + user_id_for_usage_check = msg.owner_id + usage_data = ( + await self.providers.database.chunks_handler.list_chunks( + limit=1, + offset=0, + filters={"owner_id": msg.owner_id}, + ) + ) + current_usage = usage_data["total_entries"] + + # Figure out the user's limit + user = await self.providers.database.users_handler.get_user_by_id( + msg.owner_id + ) + max_chunks = ( + self.providers.database.config.app.default_max_chunks_per_user + ) + if user.limits_overrides and "max_chunks" in user.limits_overrides: + max_chunks = user.limits_overrides["max_chunks"] + + # Add to our local batch + vector_batch.append(msg) + document_counts[msg.document_id] = ( + document_counts.get(msg.document_id, 0) + 1 + ) + count += 1 + + # Check usage + if ( + current_usage is not None + and (current_usage + len(vector_batch) + count) > max_chunks + ): + error_message = f"User {msg.owner_id} has exceeded the maximum number of allowed chunks: {max_chunks}" + logger.error(error_message) + yield error_message + continue + + # Once we hit our batch size, store them + if len(vector_batch) >= storage_batch_size: + try: + await ( + self.providers.database.chunks_handler.upsert_entries( + vector_batch + ) + ) + except Exception as e: + logger.error(f"Failed to store vector batch: {e}") + yield f"Error: {e}" + vector_batch.clear() + + # Store any leftover items + if vector_batch: + try: + await self.providers.database.chunks_handler.upsert_entries( + vector_batch + ) + except Exception as e: + logger.error(f"Failed to store final vector batch: {e}") + yield f"Error: {e}" + + # Summaries + for doc_id, cnt in document_counts.items(): + info_msg = f"Successful ingestion for document_id: {doc_id}, with vector count: {cnt}" + logger.info(info_msg) + yield info_msg + + async def finalize_ingestion( + self, document_info: DocumentResponse + ) -> None: + """Called at the end of a successful ingestion pipeline to set the + document status to SUCCESS or similar final steps.""" + + async def empty_generator(): + yield document_info + + await self.update_document_status( + document_info, IngestionStatus.SUCCESS + ) + return empty_generator() + + async def update_document_status( + self, + document_info: DocumentResponse, + status: IngestionStatus, + metadata: Optional[dict] = None, + ) -> None: + document_info.ingestion_status = status + if metadata: + document_info.metadata = {**document_info.metadata, **metadata} + await self._update_document_status_in_db(document_info) + + async def _update_document_status_in_db( + self, document_info: DocumentResponse + ): + try: + await self.providers.database.documents_handler.upsert_documents_overview( + document_info + ) + except Exception as e: + logger.error( + f"Failed to update document status: {document_info.id}. Error: {str(e)}" + ) + + async def ingest_chunks_ingress( + self, + document_id: UUID, + metadata: Optional[dict], + chunks: list[RawChunk], + user: User, + *args: Any, + **kwargs: Any, + ) -> DocumentResponse: + """Directly ingest user-provided text chunks (rather than from a + file).""" + if not chunks: + raise R2RException( + status_code=400, message="No chunks provided for ingestion." 
+ ) + metadata = metadata or {} + version = STARTING_VERSION + + document_info = self._create_document_info_from_chunks( + document_id, + user, + chunks, + metadata, + version, + ) + + existing_document_info = ( + await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=100, + filter_user_ids=[user.id], + filter_document_ids=[document_id], + ) + )["results"] + if len(existing_document_info) > 0: + existing_doc = existing_document_info[0] + if existing_doc.ingestion_status != IngestionStatus.FAILED: + raise R2RException( + status_code=409, + message=( + f"Document {document_id} was already ingested " + "and is not in a failed state." + ), + ) + + await self.providers.database.documents_handler.upsert_documents_overview( + document_info + ) + return document_info + + async def update_chunk_ingress( + self, + document_id: UUID, + chunk_id: UUID, + text: str, + user: User, + metadata: Optional[dict] = None, + *args: Any, + **kwargs: Any, + ) -> dict: + """Update an individual chunk's text and metadata, re-embed, and re- + store it.""" + # Verify chunk exists and user has access + existing_chunks = ( + await self.providers.database.chunks_handler.list_document_chunks( + document_id=document_id, + offset=0, + limit=1, + ) + ) + if not existing_chunks["results"]: + raise R2RException( + status_code=404, + message=f"Chunk with chunk_id {chunk_id} not found.", + ) + + existing_chunk = ( + await self.providers.database.chunks_handler.get_chunk(chunk_id) + ) + if not existing_chunk: + raise R2RException( + status_code=404, + message=f"Chunk with id {chunk_id} not found", + ) + + if ( + str(existing_chunk["owner_id"]) != str(user.id) + and not user.is_superuser + ): + raise R2RException( + status_code=403, + message="You don't have permission to modify this chunk.", + ) + + # Merge metadata + merged_metadata = {**existing_chunk["metadata"]} + if metadata is not None: + merged_metadata |= metadata + + # Create updated chunk + extraction_data = { + "id": chunk_id, + "document_id": document_id, + "collection_ids": kwargs.get( + "collection_ids", existing_chunk["collection_ids"] + ), + "owner_id": existing_chunk["owner_id"], + "data": text or existing_chunk["text"], + "metadata": merged_metadata, + } + extraction = DocumentChunk(**extraction_data).model_dump() + + # Re-embed + embeddings_generator = self.embed_document( + [extraction], embedding_batch_size=1 + ) + embeddings = [] + async for embedding in embeddings_generator: + embeddings.append(embedding) + + # Re-store + store_gen = self.store_embeddings(embeddings, storage_batch_size=1) + async for _ in store_gen: + pass + + return extraction + + async def _get_enriched_chunk_text( + self, + chunk_idx: int, + chunk: dict, + document_id: UUID, + document_summary: str | None, + chunk_enrichment_settings: ChunkEnrichmentSettings, + list_document_chunks: list[dict], + ) -> VectorEntry: + """Helper for chunk_enrichment. + + Leverages an LLM to rewrite or expand chunk text, then re-embeds it. 
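+ Up to `n_chunks` preceding and succeeding chunks are included in the prompt
+ as context; if the LLM call fails or returns nothing, the original chunk
+ text is kept and chunk_enrichment_status is set to "failed".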
+ """ + preceding_chunks = [ + list_document_chunks[idx]["text"] + for idx in range( + max(0, chunk_idx - chunk_enrichment_settings.n_chunks), + chunk_idx, + ) + ] + succeeding_chunks = [ + list_document_chunks[idx]["text"] + for idx in range( + chunk_idx + 1, + min( + len(list_document_chunks), + chunk_idx + chunk_enrichment_settings.n_chunks + 1, + ), + ) + ] + try: + # Obtain the updated text from the LLM + updated_chunk_text = ( + ( + await self.providers.llm.aget_completion( + messages=await self.providers.database.prompts_handler.get_message_payload( + task_prompt_name=chunk_enrichment_settings.chunk_enrichment_prompt, + task_inputs={ + "document_summary": document_summary or "None", + "chunk": chunk["text"], + "preceding_chunks": ( + "\n".join(preceding_chunks) + if preceding_chunks + else "None" + ), + "succeeding_chunks": ( + "\n".join(succeeding_chunks) + if succeeding_chunks + else "None" + ), + "chunk_size": self.config.ingestion.chunk_size + or 1024, + }, + ), + generation_config=chunk_enrichment_settings.generation_config + or GenerationConfig(model=self.config.app.fast_llm), + ) + ) + .choices[0] + .message.content + ) + except Exception: + updated_chunk_text = chunk["text"] + chunk["metadata"]["chunk_enrichment_status"] = "failed" + else: + chunk["metadata"]["chunk_enrichment_status"] = ( + "success" if updated_chunk_text else "failed" + ) + + if not updated_chunk_text or not isinstance(updated_chunk_text, str): + updated_chunk_text = str(chunk["text"]) + chunk["metadata"]["chunk_enrichment_status"] = "failed" + + # Re-embed + data = await self.providers.embedding.async_get_embedding( + updated_chunk_text + ) + chunk["metadata"]["original_text"] = chunk["text"] + + return VectorEntry( + id=generate_id(str(chunk["id"])), + vector=Vector(data=data, type=VectorType.FIXED, length=len(data)), + document_id=document_id, + owner_id=chunk["owner_id"], + collection_ids=chunk["collection_ids"], + text=updated_chunk_text, + metadata=chunk["metadata"], + ) + + async def chunk_enrichment( + self, + document_id: UUID, + document_summary: str | None, + chunk_enrichment_settings: ChunkEnrichmentSettings, + ) -> int: + """Example function that modifies chunk text via an LLM then re-embeds + and re-stores all chunks for the given document.""" + list_document_chunks = ( + await self.providers.database.chunks_handler.list_document_chunks( + document_id=document_id, + offset=0, + limit=-1, + ) + )["results"] + + new_vector_entries: list[VectorEntry] = [] + tasks = [] + total_completed = 0 + + for chunk_idx, chunk in enumerate(list_document_chunks): + tasks.append( + self._get_enriched_chunk_text( + chunk_idx=chunk_idx, + chunk=chunk, + document_id=document_id, + document_summary=document_summary, + chunk_enrichment_settings=chunk_enrichment_settings, + list_document_chunks=list_document_chunks, + ) + ) + + # Process in batches of e.g. 
128 concurrency + if len(tasks) == 128: + new_vector_entries.extend(await asyncio.gather(*tasks)) + total_completed += 128 + logger.info( + f"Completed {total_completed} out of {len(list_document_chunks)} chunks for document {document_id}" + ) + tasks = [] + + # Finish any remaining tasks + new_vector_entries.extend(await asyncio.gather(*tasks)) + logger.info( + f"Completed enrichment of {len(list_document_chunks)} chunks for document {document_id}" + ) + + # Delete old chunks from vector db + await self.providers.database.chunks_handler.delete( + filters={"document_id": document_id} + ) + + # Insert the newly enriched entries + await self.providers.database.chunks_handler.upsert_entries( + new_vector_entries + ) + return len(new_vector_entries) + + async def list_chunks( + self, + offset: int, + limit: int, + filters: Optional[dict[str, Any]] = None, + include_vectors: bool = False, + *args: Any, + **kwargs: Any, + ) -> dict: + return await self.providers.database.chunks_handler.list_chunks( + offset=offset, + limit=limit, + filters=filters, + include_vectors=include_vectors, + ) + + async def get_chunk( + self, + chunk_id: UUID, + *args: Any, + **kwargs: Any, + ) -> dict: + return await self.providers.database.chunks_handler.get_chunk(chunk_id) + + async def update_document_metadata( + self, + document_id: UUID, + metadata: dict, + user: User, + ) -> None: + # Verify document exists and user has access + existing_document = await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=100, + filter_document_ids=[document_id], + filter_user_ids=[user.id], + ) + if not existing_document["results"]: + raise R2RException( + status_code=404, + message=( + f"Document with id {document_id} not found " + "or you don't have access." 
+ ), + ) + + existing_document = existing_document["results"][0] + + # Merge metadata + merged_metadata = {**existing_document.metadata, **metadata} # type: ignore + + # Update document metadata + existing_document.metadata = merged_metadata # type: ignore + await self.providers.database.documents_handler.upsert_documents_overview( + existing_document # type: ignore + ) + + +class IngestionServiceAdapter: + @staticmethod + def _parse_user_data(user_data) -> User: + if isinstance(user_data, str): + try: + user_data = json.loads(user_data) + except json.JSONDecodeError as e: + raise ValueError( + f"Invalid user data format: {user_data}" + ) from e + return User.from_dict(user_data) + + @staticmethod + def parse_ingest_file_input(data: dict) -> dict: + return { + "user": IngestionServiceAdapter._parse_user_data(data["user"]), + "metadata": data["metadata"], + "document_id": ( + UUID(data["document_id"]) if data["document_id"] else None + ), + "version": data.get("version"), + "ingestion_config": data["ingestion_config"] or {}, + "file_data": data["file_data"], + "size_in_bytes": data["size_in_bytes"], + "collection_ids": data.get("collection_ids", []), + } + + @staticmethod + def parse_ingest_chunks_input(data: dict) -> dict: + return { + "user": IngestionServiceAdapter._parse_user_data(data["user"]), + "metadata": data["metadata"], + "document_id": data["document_id"], + "chunks": [ + UnprocessedChunk.from_dict(chunk) for chunk in data["chunks"] + ], + "id": data.get("id"), + } + + @staticmethod + def parse_update_chunk_input(data: dict) -> dict: + return { + "user": IngestionServiceAdapter._parse_user_data(data["user"]), + "document_id": UUID(data["document_id"]), + "id": UUID(data["id"]), + "text": data["text"], + "metadata": data.get("metadata"), + "collection_ids": data.get("collection_ids", []), + } + + @staticmethod + def parse_update_files_input(data: dict) -> dict: + return { + "user": IngestionServiceAdapter._parse_user_data(data["user"]), + "document_ids": [UUID(doc_id) for doc_id in data["document_ids"]], + "metadatas": data["metadatas"], + "ingestion_config": data["ingestion_config"], + "file_sizes_in_bytes": data["file_sizes_in_bytes"], + "file_datas": data["file_datas"], + } + + @staticmethod + def parse_create_vector_index_input(data: dict) -> dict: + return { + "table_name": VectorTableName(data["table_name"]), + "index_method": IndexMethod(data["index_method"]), + "index_measure": IndexMeasure(data["index_measure"]), + "index_name": data["index_name"], + "index_column": data["index_column"], + "index_arguments": data["index_arguments"], + "concurrently": data["concurrently"], + } + + @staticmethod + def parse_list_vector_indices_input(input_data: dict) -> dict: + return {"table_name": input_data["table_name"]} + + @staticmethod + def parse_delete_vector_index_input(input_data: dict) -> dict: + return { + "index_name": input_data["index_name"], + "table_name": input_data.get("table_name"), + "concurrently": input_data.get("concurrently", True), + } + + @staticmethod + def parse_select_vector_index_input(input_data: dict) -> dict: + return { + "index_name": input_data["index_name"], + "table_name": input_data.get("table_name"), + } + + @staticmethod + def parse_update_document_metadata_input(data: dict) -> dict: + return { + "document_id": data["document_id"], + "metadata": data["metadata"], + "user": IngestionServiceAdapter._parse_user_data(data["user"]), + } diff --git a/.venv/lib/python3.12/site-packages/core/main/services/management_service.py 
b/.venv/lib/python3.12/site-packages/core/main/services/management_service.py new file mode 100644 index 00000000..62b4ca0b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/services/management_service.py @@ -0,0 +1,1084 @@ +import logging +import os +from collections import defaultdict +from datetime import datetime, timedelta, timezone +from typing import IO, Any, BinaryIO, Optional, Tuple +from uuid import UUID + +import toml + +from core.base import ( + CollectionResponse, + ConversationResponse, + DocumentResponse, + GenerationConfig, + GraphConstructionStatus, + Message, + MessageResponse, + Prompt, + R2RException, + StoreType, + User, +) + +from ..abstractions import R2RProviders +from ..config import R2RConfig +from .base import Service + +logger = logging.getLogger() + + +class ManagementService(Service): + def __init__( + self, + config: R2RConfig, + providers: R2RProviders, + ): + super().__init__( + config, + providers, + ) + + async def app_settings(self): + prompts = ( + await self.providers.database.prompts_handler.get_all_prompts() + ) + config_toml = self.config.to_toml() + config_dict = toml.loads(config_toml) + try: + project_name = os.environ["R2R_PROJECT_NAME"] + except KeyError: + project_name = "" + return { + "config": config_dict, + "prompts": prompts, + "r2r_project_name": project_name, + } + + async def users_overview( + self, + offset: int, + limit: int, + user_ids: Optional[list[UUID]] = None, + ): + return await self.providers.database.users_handler.get_users_overview( + offset=offset, + limit=limit, + user_ids=user_ids, + ) + + async def delete_documents_and_chunks_by_filter( + self, + filters: dict[str, Any], + ): + """Delete chunks matching the given filters. If any documents are now + empty (i.e., have no remaining chunks), delete those documents as well. + + Args: + filters (dict[str, Any]): Filters specifying which chunks to delete. + chunks_handler (PostgresChunksHandler): The handler for chunk operations. + documents_handler (PostgresDocumentsHandler): The handler for document operations. + graphs_handler: Handler for entity and relationship operations in the Graph. + + Returns: + dict: A summary of what was deleted. + """ + + def transform_chunk_id_to_id( + filters: dict[str, Any], + ) -> dict[str, Any]: + """Example transformation function if your filters use `chunk_id` + instead of `id`. + + Recursively transform `chunk_id` to `id`. + """ + if isinstance(filters, dict): + transformed = {} + for key, value in filters.items(): + if key == "chunk_id": + transformed["id"] = value + elif key in ["$and", "$or"]: + transformed[key] = [ + transform_chunk_id_to_id(item) for item in value + ] + else: + transformed[key] = transform_chunk_id_to_id(value) + return transformed + return filters + + # Transform filters if needed. 
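+ # e.g. {"chunk_id": {"$eq": "<id>"}} becomes {"id": {"$eq": "<id>"}}, with the
+ # rewrite applied recursively inside $and / $or clauses.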
+ transformed_filters = transform_chunk_id_to_id(filters) + + # Find chunks that match the filters before deleting + interim_results = ( + await self.providers.database.chunks_handler.list_chunks( + filters=transformed_filters, + offset=0, + limit=1_000, + include_vectors=False, + ) + ) + + results = interim_results["results"] + while interim_results["total_entries"] == 1_000: + # If we hit the limit, we need to paginate to get all results + + interim_results = ( + await self.providers.database.chunks_handler.list_chunks( + filters=transformed_filters, + offset=interim_results["offset"] + 1_000, + limit=1_000, + include_vectors=False, + ) + ) + results.extend(interim_results["results"]) + + document_ids = set() + owner_id = None + + if "$and" in filters: + for condition in filters["$and"]: + if "owner_id" in condition and "$eq" in condition["owner_id"]: + owner_id = condition["owner_id"]["$eq"] + elif ( + "document_id" in condition + and "$eq" in condition["document_id"] + ): + document_ids.add(UUID(condition["document_id"]["$eq"])) + elif "document_id" in filters: + doc_id = filters["document_id"] + if isinstance(doc_id, str): + document_ids.add(UUID(doc_id)) + elif isinstance(doc_id, UUID): + document_ids.add(doc_id) + elif isinstance(doc_id, dict) and "$eq" in doc_id: + value = doc_id["$eq"] + document_ids.add( + UUID(value) if isinstance(value, str) else value + ) + + # Delete matching chunks from the database + delete_results = await self.providers.database.chunks_handler.delete( + transformed_filters + ) + + # Extract the document_ids that were affected. + affected_doc_ids = { + UUID(info["document_id"]) + for info in delete_results.values() + if info.get("document_id") + } + document_ids.update(affected_doc_ids) + + # Check if the document still has any chunks left + docs_to_delete = [] + for doc_id in document_ids: + documents_overview_response = await self.providers.database.documents_handler.get_documents_overview( + offset=0, limit=1, filter_document_ids=[doc_id] + ) + if not documents_overview_response["results"]: + raise R2RException( + status_code=404, message="Document not found" + ) + + document = documents_overview_response["results"][0] + + for collection_id in document.collection_ids: + await self.providers.database.collections_handler.decrement_collection_document_count( + collection_id=collection_id + ) + + if owner_id and str(document.owner_id) != owner_id: + raise R2RException( + status_code=404, + message="Document not found or insufficient permissions", + ) + docs_to_delete.append(doc_id) + + # Delete documents that no longer have associated chunks + for doc_id in docs_to_delete: + # Delete related entities & relationships if needed: + await self.providers.database.graphs_handler.entities.delete( + parent_id=doc_id, + store_type=StoreType.DOCUMENTS, + ) + await self.providers.database.graphs_handler.relationships.delete( + parent_id=doc_id, + store_type=StoreType.DOCUMENTS, + ) + + # Finally, delete the document from documents_overview: + await self.providers.database.documents_handler.delete( + document_id=doc_id + ) + + return { + "success": True, + "deleted_chunks_count": len(delete_results), + "deleted_documents_count": len(docs_to_delete), + "deleted_document_ids": [str(d) for d in docs_to_delete], + } + + async def download_file( + self, document_id: UUID + ) -> Optional[Tuple[str, BinaryIO, int]]: + if result := await self.providers.database.files_handler.retrieve_file( + document_id + ): + return result + return None + + async def export_files( + self, + 
document_ids: Optional[list[UUID]] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ) -> tuple[str, BinaryIO, int]: + return ( + await self.providers.database.files_handler.retrieve_files_as_zip( + document_ids=document_ids, + start_date=start_date, + end_date=end_date, + ) + ) + + async def export_collections( + self, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.collections_handler.export_to_csv( + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_documents( + self, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.documents_handler.export_to_csv( + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_document_entities( + self, + id: UUID, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.graphs_handler.entities.export_to_csv( + parent_id=id, + store_type=StoreType.DOCUMENTS, + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_document_relationships( + self, + id: UUID, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.graphs_handler.relationships.export_to_csv( + parent_id=id, + store_type=StoreType.DOCUMENTS, + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_conversations( + self, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.conversations_handler.export_conversations_to_csv( + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_graph_entities( + self, + id: UUID, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.graphs_handler.entities.export_to_csv( + parent_id=id, + store_type=StoreType.GRAPHS, + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_graph_relationships( + self, + id: UUID, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.graphs_handler.relationships.export_to_csv( + parent_id=id, + store_type=StoreType.GRAPHS, + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_graph_communities( + self, + id: UUID, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.graphs_handler.communities.export_to_csv( + parent_id=id, + store_type=StoreType.GRAPHS, + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_messages( + self, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.conversations_handler.export_messages_to_csv( + columns=columns, + filters=filters, + 
include_header=include_header, + ) + + async def export_users( + self, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.users_handler.export_to_csv( + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def documents_overview( + self, + offset: int, + limit: int, + user_ids: Optional[list[UUID]] = None, + collection_ids: Optional[list[UUID]] = None, + document_ids: Optional[list[UUID]] = None, + ): + return await self.providers.database.documents_handler.get_documents_overview( + offset=offset, + limit=limit, + filter_document_ids=document_ids, + filter_user_ids=user_ids, + filter_collection_ids=collection_ids, + ) + + async def update_document_metadata( + self, + document_id: UUID, + metadata: list[dict], + overwrite: bool = False, + ): + return await self.providers.database.documents_handler.update_document_metadata( + document_id=document_id, + metadata=metadata, + overwrite=overwrite, + ) + + async def list_document_chunks( + self, + document_id: UUID, + offset: int, + limit: int, + include_vectors: bool = False, + ): + return ( + await self.providers.database.chunks_handler.list_document_chunks( + document_id=document_id, + offset=offset, + limit=limit, + include_vectors=include_vectors, + ) + ) + + async def assign_document_to_collection( + self, document_id: UUID, collection_id: UUID + ): + await self.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id, collection_id + ) + await self.providers.database.collections_handler.assign_document_to_collection_relational( + document_id, collection_id + ) + await self.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await self.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.OUTDATED, + ) + + return {"message": "Document assigned to collection successfully"} + + async def remove_document_from_collection( + self, document_id: UUID, collection_id: UUID + ): + await self.providers.database.collections_handler.remove_document_from_collection_relational( + document_id, collection_id + ) + await self.providers.database.chunks_handler.remove_document_from_collection_vector( + document_id, collection_id + ) + # await self.providers.database.graphs_handler.delete_node_via_document_id( + # document_id, collection_id + # ) + return None + + def _process_relationships( + self, relationships: list[Tuple[str, str, str]] + ) -> Tuple[dict[str, list[str]], dict[str, dict[str, list[str]]]]: + graph = defaultdict(list) + grouped: dict[str, dict[str, list[str]]] = defaultdict( + lambda: defaultdict(list) + ) + for subject, relation, obj in relationships: + graph[subject].append(obj) + grouped[subject][relation].append(obj) + if obj not in graph: + graph[obj] = [] + return dict(graph), dict(grouped) + + def generate_output( + self, + grouped_relationships: dict[str, dict[str, list[str]]], + graph: dict[str, list[str]], + descriptions_dict: dict[str, str], + print_descriptions: bool = True, + ) -> list[str]: + output = [] + # Print grouped relationships + for subject, relations in grouped_relationships.items(): + output.append(f"\n== {subject} ==") + if print_descriptions and subject in descriptions_dict: + output.append(f"\tDescription: {descriptions_dict[subject]}") + for 
relation, objects in relations.items(): + output.append(f" {relation}:") + for obj in objects: + output.append(f" - {obj}") + if print_descriptions and obj in descriptions_dict: + output.append( + f" Description: {descriptions_dict[obj]}" + ) + + # Print basic graph statistics + output.extend( + [ + "\n== Graph Statistics ==", + f"Number of nodes: {len(graph)}", + f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}", + f"Number of connected components: {self._count_connected_components(graph)}", + ] + ) + + # Find central nodes + central_nodes = self._get_central_nodes(graph) + output.extend( + [ + "\n== Most Central Nodes ==", + *( + f" {node}: {centrality:.4f}" + for node, centrality in central_nodes + ), + ] + ) + + return output + + def _count_connected_components(self, graph: dict[str, list[str]]) -> int: + visited = set() + components = 0 + + def dfs(node): + visited.add(node) + for neighbor in graph[node]: + if neighbor not in visited: + dfs(neighbor) + + for node in graph: + if node not in visited: + dfs(node) + components += 1 + + return components + + def _get_central_nodes( + self, graph: dict[str, list[str]] + ) -> list[Tuple[str, float]]: + degree = {node: len(neighbors) for node, neighbors in graph.items()} + total_nodes = len(graph) + centrality = { + node: deg / (total_nodes - 1) for node, deg in degree.items() + } + return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5] + + async def create_collection( + self, + owner_id: UUID, + name: Optional[str] = None, + description: str | None = None, + ) -> CollectionResponse: + result = await self.providers.database.collections_handler.create_collection( + owner_id=owner_id, + name=name, + description=description, + ) + await self.providers.database.graphs_handler.create( + collection_id=result.id, + name=name, + description=description, + ) + return result + + async def update_collection( + self, + collection_id: UUID, + name: Optional[str] = None, + description: Optional[str] = None, + generate_description: bool = False, + ) -> CollectionResponse: + if generate_description: + description = await self.summarize_collection( + id=collection_id, offset=0, limit=100 + ) + return await self.providers.database.collections_handler.update_collection( + collection_id=collection_id, + name=name, + description=description, + ) + + async def delete_collection(self, collection_id: UUID) -> bool: + await self.providers.database.collections_handler.delete_collection_relational( + collection_id + ) + await self.providers.database.chunks_handler.delete_collection_vector( + collection_id + ) + try: + await self.providers.database.graphs_handler.delete( + collection_id=collection_id, + ) + except Exception as e: + logger.warning( + f"Error deleting graph for collection {collection_id}: {e}" + ) + return True + + async def collections_overview( + self, + offset: int, + limit: int, + user_ids: Optional[list[UUID]] = None, + document_ids: Optional[list[UUID]] = None, + collection_ids: Optional[list[UUID]] = None, + ) -> dict[str, list[CollectionResponse] | int]: + return await self.providers.database.collections_handler.get_collections_overview( + offset=offset, + limit=limit, + filter_user_ids=user_ids, + filter_document_ids=document_ids, + filter_collection_ids=collection_ids, + ) + + async def add_user_to_collection( + self, user_id: UUID, collection_id: UUID + ) -> bool: + return ( + await self.providers.database.users_handler.add_user_to_collection( + user_id, collection_id + ) + ) + + async def 
remove_user_from_collection( + self, user_id: UUID, collection_id: UUID + ) -> bool: + return await self.providers.database.users_handler.remove_user_from_collection( + user_id, collection_id + ) + + async def get_users_in_collection( + self, collection_id: UUID, offset: int = 0, limit: int = 100 + ) -> dict[str, list[User] | int]: + return await self.providers.database.users_handler.get_users_in_collection( + collection_id, offset=offset, limit=limit + ) + + async def documents_in_collection( + self, collection_id: UUID, offset: int = 0, limit: int = 100 + ) -> dict[str, list[DocumentResponse] | int]: + return await self.providers.database.collections_handler.documents_in_collection( + collection_id, offset=offset, limit=limit + ) + + async def summarize_collection( + self, id: UUID, offset: int, limit: int + ) -> str: + documents_in_collection_response = await self.documents_in_collection( + collection_id=id, + offset=offset, + limit=limit, + ) + + document_summaries = [ + document.summary + for document in documents_in_collection_response["results"] # type: ignore + ] + + logger.info( + f"Summarizing collection {id} with {len(document_summaries)} of {documents_in_collection_response['total_entries']} documents." + ) + + formatted_summaries = "\n\n".join(document_summaries) # type: ignore + + messages = await self.providers.database.prompts_handler.get_message_payload( + system_prompt_name=self.config.database.collection_summary_system_prompt, + task_prompt_name=self.config.database.collection_summary_prompt, + task_inputs={"document_summaries": formatted_summaries}, + ) + + response = await self.providers.llm.aget_completion( + messages=messages, + generation_config=GenerationConfig( + model=self.config.ingestion.document_summary_model + or self.config.app.fast_llm + ), + ) + + if collection_summary := response.choices[0].message.content: + return collection_summary + else: + raise ValueError("Expected a generated response.") + + async def add_prompt( + self, name: str, template: str, input_types: dict[str, str] + ) -> dict: + try: + await self.providers.database.prompts_handler.add_prompt( + name, template, input_types + ) + return f"Prompt '{name}' added successfully." 
# type: ignore + except ValueError as e: + raise R2RException(status_code=400, message=str(e)) from e + + async def get_cached_prompt( + self, + prompt_name: str, + inputs: Optional[dict[str, Any]] = None, + prompt_override: Optional[str] = None, + ) -> dict: + try: + return { + "message": ( + await self.providers.database.prompts_handler.get_cached_prompt( + prompt_name=prompt_name, + inputs=inputs, + prompt_override=prompt_override, + ) + ) + } + except ValueError as e: + raise R2RException(status_code=404, message=str(e)) from e + + async def get_prompt( + self, + prompt_name: str, + inputs: Optional[dict[str, Any]] = None, + prompt_override: Optional[str] = None, + ) -> dict: + try: + return await self.providers.database.prompts_handler.get_prompt( # type: ignore + name=prompt_name, + inputs=inputs, + prompt_override=prompt_override, + ) + except ValueError as e: + raise R2RException(status_code=404, message=str(e)) from e + + async def get_all_prompts(self) -> dict[str, Prompt]: + return await self.providers.database.prompts_handler.get_all_prompts() + + async def update_prompt( + self, + name: str, + template: Optional[str] = None, + input_types: Optional[dict[str, str]] = None, + ) -> dict: + try: + await self.providers.database.prompts_handler.update_prompt( + name, template, input_types + ) + return f"Prompt '{name}' updated successfully." # type: ignore + except ValueError as e: + raise R2RException(status_code=404, message=str(e)) from e + + async def delete_prompt(self, name: str) -> dict: + try: + await self.providers.database.prompts_handler.delete_prompt(name) + return {"message": f"Prompt '{name}' deleted successfully."} + except ValueError as e: + raise R2RException(status_code=404, message=str(e)) from e + + async def get_conversation( + self, + conversation_id: UUID, + user_ids: Optional[list[UUID]] = None, + ) -> list[MessageResponse]: + return await self.providers.database.conversations_handler.get_conversation( + conversation_id=conversation_id, + filter_user_ids=user_ids, + ) + + async def create_conversation( + self, + user_id: Optional[UUID] = None, + name: Optional[str] = None, + ) -> ConversationResponse: + return await self.providers.database.conversations_handler.create_conversation( + user_id=user_id, + name=name, + ) + + async def conversations_overview( + self, + offset: int, + limit: int, + conversation_ids: Optional[list[UUID]] = None, + user_ids: Optional[list[UUID]] = None, + ) -> dict[str, list[dict] | int]: + return await self.providers.database.conversations_handler.get_conversations_overview( + offset=offset, + limit=limit, + filter_user_ids=user_ids, + conversation_ids=conversation_ids, + ) + + async def add_message( + self, + conversation_id: UUID, + content: Message, + parent_id: Optional[UUID] = None, + metadata: Optional[dict] = None, + ) -> MessageResponse: + return await self.providers.database.conversations_handler.add_message( + conversation_id=conversation_id, + content=content, + parent_id=parent_id, + metadata=metadata, + ) + + async def edit_message( + self, + message_id: UUID, + new_content: Optional[str] = None, + additional_metadata: Optional[dict] = None, + ) -> dict[str, Any]: + return ( + await self.providers.database.conversations_handler.edit_message( + message_id=message_id, + new_content=new_content, + additional_metadata=additional_metadata or {}, + ) + ) + + async def update_conversation( + self, conversation_id: UUID, name: str + ) -> ConversationResponse: + return await 
self.providers.database.conversations_handler.update_conversation( + conversation_id=conversation_id, name=name + ) + + async def delete_conversation( + self, + conversation_id: UUID, + user_ids: Optional[list[UUID]] = None, + ) -> None: + await ( + self.providers.database.conversations_handler.delete_conversation( + conversation_id=conversation_id, + filter_user_ids=user_ids, + ) + ) + + async def get_user_max_documents(self, user_id: UUID) -> int | None: + # Fetch the user to see if they have any overrides stored + user = await self.providers.database.users_handler.get_user_by_id( + user_id + ) + if user.limits_overrides and "max_documents" in user.limits_overrides: + return user.limits_overrides["max_documents"] + return self.config.app.default_max_documents_per_user + + async def get_user_max_chunks(self, user_id: UUID) -> int | None: + user = await self.providers.database.users_handler.get_user_by_id( + user_id + ) + if user.limits_overrides and "max_chunks" in user.limits_overrides: + return user.limits_overrides["max_chunks"] + return self.config.app.default_max_chunks_per_user + + async def get_user_max_collections(self, user_id: UUID) -> int | None: + user = await self.providers.database.users_handler.get_user_by_id( + user_id + ) + if ( + user.limits_overrides + and "max_collections" in user.limits_overrides + ): + return user.limits_overrides["max_collections"] + return self.config.app.default_max_collections_per_user + + async def get_max_upload_size_by_type( + self, user_id: UUID, file_type_or_ext: str + ) -> int: + """Return the maximum allowed upload size (in bytes) for the given + user's file type/extension. Respects user-level overrides if present, + falling back to the system config. + + ```json + { + "limits_overrides": { + "max_file_size": 20_000_000, + "max_file_size_by_type": + { + "pdf": 50_000_000, + "docx": 30_000_000 + }, + ... + } + } + ``` + """ + # 1. Normalize extension + ext = file_type_or_ext.lower().lstrip(".") + + # 2. Fetch user from DB to see if we have any overrides + user = await self.providers.database.users_handler.get_user_by_id( + user_id + ) + user_overrides = user.limits_overrides or {} + + # 3. Check if there's a user-level override for "max_file_size_by_type" + user_file_type_limits = user_overrides.get("max_file_size_by_type", {}) + if ext in user_file_type_limits: + return user_file_type_limits[ext] + + # 4. If not, check if there's a user-level fallback "max_file_size" + if "max_file_size" in user_overrides: + return user_overrides["max_file_size"] + + # 5. If none exist at user level, use system config + # Example config paths: + system_type_limits = self.config.app.max_upload_size_by_type + if ext in system_type_limits: + return system_type_limits[ext] + + # 6. Otherwise, return the global default + return self.config.app.default_max_upload_size + + async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]: + """ + Return a dictionary containing: + - The system default limits (from self.config.limits) + - The user's overrides (from user.limits_overrides) + - The final 'effective' set of limits after merging (overall) + - The usage for each relevant limit (per-route usage, etc.) 
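+ Per-minute usage is counted over the trailing minute via _count_requests,
+ and monthly usage via _count_monthly_requests, both globally and per route.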
+ """ + # 1) Fetch the user + user = await self.providers.database.users_handler.get_user_by_id( + user_id + ) + user_overrides = user.limits_overrides or {} + + # 2) Grab system defaults + system_defaults = { + "global_per_min": self.config.database.limits.global_per_min, + "route_per_min": self.config.database.limits.route_per_min, + "monthly_limit": self.config.database.limits.monthly_limit, + # Add additional fields if your LimitSettings has them + } + + # 3) Build the overall (global) "effective limits" ignoring any specific route + overall_effective = ( + self.providers.database.limits_handler.determine_effective_limits( + user, route="" + ) + ) + + # 4) Build usage data. We'll do top-level usage for global_per_min/monthly, + # then do route-by-route usage in a loop. + usage: dict[str, Any] = {} + now = datetime.now(timezone.utc) + one_min_ago = now - timedelta(minutes=1) + + # (a) Global usage (per-minute) + global_per_min_used = ( + await self.providers.database.limits_handler._count_requests( + user_id, route=None, since=one_min_ago + ) + ) + # (a2) Global usage (monthly) - i.e. usage across ALL routes + global_monthly_used = await self.providers.database.limits_handler._count_monthly_requests( + user_id, route=None + ) + + usage["global_per_min"] = { + "used": global_per_min_used, + "limit": overall_effective.global_per_min, + "remaining": ( + overall_effective.global_per_min - global_per_min_used + if overall_effective.global_per_min is not None + else None + ), + } + usage["monthly_limit"] = { + "used": global_monthly_used, + "limit": overall_effective.monthly_limit, + "remaining": ( + overall_effective.monthly_limit - global_monthly_used + if overall_effective.monthly_limit is not None + else None + ), + } + + # (b) Route-level usage. We'll gather all routes from system + user overrides + system_route_limits = ( + self.config.database.route_limits + ) # dict[str, LimitSettings] + user_route_overrides = user_overrides.get("route_overrides", {}) + route_keys = set(system_route_limits.keys()) | set( + user_route_overrides.keys() + ) + + usage["routes"] = {} + for route in route_keys: + # 1) Get the final merged limits for this specific route + route_effective = self.providers.database.limits_handler.determine_effective_limits( + user, route + ) + + # 2) Count requests for the last minute on this route + route_per_min_used = ( + await self.providers.database.limits_handler._count_requests( + user_id, route, one_min_ago + ) + ) + + # 3) Count route-specific monthly usage + route_monthly_used = await self.providers.database.limits_handler._count_monthly_requests( + user_id, route + ) + + usage["routes"][route] = { + "route_per_min": { + "used": route_per_min_used, + "limit": route_effective.route_per_min, + "remaining": ( + route_effective.route_per_min - route_per_min_used + if route_effective.route_per_min is not None + else None + ), + }, + "monthly_limit": { + "used": route_monthly_used, + "limit": route_effective.monthly_limit, + "remaining": ( + route_effective.monthly_limit - route_monthly_used + if route_effective.monthly_limit is not None + else None + ), + }, + } + + max_documents = await self.get_user_max_documents(user_id) + used_documents = ( + await self.providers.database.documents_handler.get_documents_overview( + limit=1, offset=0, filter_user_ids=[user_id] + ) + )["total_entries"] + max_chunks = await self.get_user_max_chunks(user_id) + used_chunks = ( + await self.providers.database.chunks_handler.list_chunks( + limit=1, offset=0, filters={"owner_id": user_id} + ) 
+ )["total_entries"] + + max_collections = await self.get_user_max_collections(user_id) + used_collections: int = ( # type: ignore + await self.providers.database.collections_handler.get_collections_overview( + limit=1, offset=0, filter_user_ids=[user_id] + ) + )["total_entries"] + + storage_limits = { + "chunks": { + "limit": max_chunks, + "used": used_chunks, + "remaining": ( + max_chunks - used_chunks + if max_chunks is not None + else None + ), + }, + "documents": { + "limit": max_documents, + "used": used_documents, + "remaining": ( + max_documents - used_documents + if max_documents is not None + else None + ), + }, + "collections": { + "limit": max_collections, + "used": used_collections, + "remaining": ( + max_collections - used_collections + if max_collections is not None + else None + ), + }, + } + # 5) Return a structured response + return { + "storage_limits": storage_limits, + "system_defaults": system_defaults, + "user_overrides": user_overrides, + "effective_limits": { + "global_per_min": overall_effective.global_per_min, + "route_per_min": overall_effective.route_per_min, + "monthly_limit": overall_effective.monthly_limit, + }, + "usage": usage, + } diff --git a/.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py b/.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py new file mode 100644 index 00000000..2ae4af31 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py @@ -0,0 +1,2087 @@ +import asyncio +import json +import logging +from copy import deepcopy +from datetime import datetime +from typing import Any, AsyncGenerator, Literal, Optional +from uuid import UUID + +from fastapi import HTTPException + +from core import ( + Citation, + R2RRAGAgent, + R2RStreamingRAGAgent, + R2RStreamingResearchAgent, + R2RXMLToolsRAGAgent, + R2RXMLToolsResearchAgent, + R2RXMLToolsStreamingRAGAgent, + R2RXMLToolsStreamingResearchAgent, +) +from core.agent.research import R2RResearchAgent +from core.base import ( + AggregateSearchResult, + ChunkSearchResult, + DocumentResponse, + GenerationConfig, + GraphCommunityResult, + GraphEntityResult, + GraphRelationshipResult, + GraphSearchResult, + GraphSearchResultType, + IngestionStatus, + Message, + R2RException, + SearchSettings, + WebSearchResult, + format_search_results_for_llm, +) +from core.base.api.models import RAGResponse, User +from core.utils import ( + CitationTracker, + SearchResultsCollector, + SSEFormatter, + dump_collector, + dump_obj, + extract_citations, + find_new_citation_spans, + num_tokens_from_messages, +) +from shared.api.models.management.responses import MessageResponse + +from ..abstractions import R2RProviders +from ..config import R2RConfig +from .base import Service + +logger = logging.getLogger() + + +class AgentFactory: + """ + Factory class that creates appropriate agent instances based on mode, + model type, and streaming preferences. + """ + + @staticmethod + def create_agent( + mode: Literal["rag", "research"], + database_provider, + llm_provider, + config, # : AgentConfig + search_settings, # : SearchSettings + generation_config, #: GenerationConfig + app_config, #: AppConfig + knowledge_search_method, + content_method, + file_search_method, + max_tool_context_length: int = 32_768, + rag_tools: Optional[list[str]] = None, + research_tools: Optional[list[str]] = None, + tools: Optional[list[str]] = None, # For backward compatibility + ): + """ + Creates and returns the appropriate agent based on provided parameters. 
+ + Args: + mode: Either "rag" or "research" to determine agent type + database_provider: Provider for database operations + llm_provider: Provider for LLM operations + config: Agent configuration + search_settings: Search settings for retrieval + generation_config: Generation configuration with LLM parameters + app_config: Application configuration + knowledge_search_method: Method for knowledge search + content_method: Method for content retrieval + file_search_method: Method for file search + max_tool_context_length: Maximum context length for tools + rag_tools: Tools specifically for RAG mode + research_tools: Tools specifically for Research mode + tools: Deprecated backward compatibility parameter + + Returns: + An appropriate agent instance + """ + # Create a deep copy of the config to avoid modifying the original + agent_config = deepcopy(config) + + # Handle tool specifications based on mode + if mode == "rag": + # For RAG mode, prioritize explicitly passed rag_tools, then tools, then config defaults + if rag_tools: + agent_config.rag_tools = rag_tools + elif tools: # Backward compatibility + agent_config.rag_tools = tools + # If neither was provided, the config's default rag_tools will be used + elif mode == "research": + # For Research mode, prioritize explicitly passed research_tools, then tools, then config defaults + if research_tools: + agent_config.research_tools = research_tools + elif tools: # Backward compatibility + agent_config.research_tools = tools + # If neither was provided, the config's default research_tools will be used + + # Determine if we need XML-based tools based on model + use_xml_format = False + # if generation_config.model: + # model_str = generation_config.model.lower() + # use_xml_format = "deepseek" in model_str or "gemini" in model_str + + # Set streaming mode based on generation config + is_streaming = generation_config.stream + + # Create the appropriate agent based on all factors + if mode == "rag": + # RAG mode agents + if is_streaming: + if use_xml_format: + return R2RXMLToolsStreamingRAGAgent( + database_provider=database_provider, + llm_provider=llm_provider, + config=agent_config, + search_settings=search_settings, + rag_generation_config=generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + ) + else: + return R2RStreamingRAGAgent( + database_provider=database_provider, + llm_provider=llm_provider, + config=agent_config, + search_settings=search_settings, + rag_generation_config=generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + ) + else: + if use_xml_format: + return R2RXMLToolsRAGAgent( + database_provider=database_provider, + llm_provider=llm_provider, + config=agent_config, + search_settings=search_settings, + rag_generation_config=generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + ) + else: + return R2RRAGAgent( + database_provider=database_provider, + llm_provider=llm_provider, + config=agent_config, + search_settings=search_settings, + rag_generation_config=generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + content_method=content_method, 
+ file_search_method=file_search_method, + ) + else: + # Research mode agents + if is_streaming: + if use_xml_format: + return R2RXMLToolsStreamingResearchAgent( + app_config=app_config, + database_provider=database_provider, + llm_provider=llm_provider, + config=agent_config, + search_settings=search_settings, + rag_generation_config=generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + ) + else: + return R2RStreamingResearchAgent( + app_config=app_config, + database_provider=database_provider, + llm_provider=llm_provider, + config=agent_config, + search_settings=search_settings, + rag_generation_config=generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + ) + else: + if use_xml_format: + return R2RXMLToolsResearchAgent( + app_config=app_config, + database_provider=database_provider, + llm_provider=llm_provider, + config=agent_config, + search_settings=search_settings, + rag_generation_config=generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + ) + else: + return R2RResearchAgent( + app_config=app_config, + database_provider=database_provider, + llm_provider=llm_provider, + config=agent_config, + search_settings=search_settings, + rag_generation_config=generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + ) + + +class RetrievalService(Service): + def __init__( + self, + config: R2RConfig, + providers: R2RProviders, + ): + super().__init__( + config, + providers, + ) + + async def search( + self, + query: str, + search_settings: SearchSettings = SearchSettings(), + *args, + **kwargs, + ) -> AggregateSearchResult: + """ + Depending on search_settings.search_strategy, fan out + to basic, hyde, or rag_fusion method. Each returns + an AggregateSearchResult that includes chunk + graph results. + """ + strategy = search_settings.search_strategy.lower() + + if strategy == "hyde": + return await self._hyde_search(query, search_settings) + elif strategy == "rag_fusion": + return await self._rag_fusion_search(query, search_settings) + else: + # 'vanilla', 'basic', or anything else... + return await self._basic_search(query, search_settings) + + async def _basic_search( + self, query: str, search_settings: SearchSettings + ) -> AggregateSearchResult: + """ + 1) Possibly embed the query (if semantic or hybrid). + 2) Chunk search. + 3) Graph search. + 4) Combine into an AggregateSearchResult. 
+ """ + # -- 1) Possibly embed the query + query_vector = None + if ( + search_settings.use_semantic_search + or search_settings.use_hybrid_search + ): + query_vector = ( + await self.providers.completion_embedding.async_get_embedding( + query # , EmbeddingPurpose.QUERY + ) + ) + + # -- 2) Chunk search + chunk_results = [] + if search_settings.chunk_settings.enabled: + chunk_results = await self._vector_search_logic( + query_text=query, + search_settings=search_settings, + precomputed_vector=query_vector, # Pass in the vector we just computed (if any) + ) + + # -- 3) Graph search + graph_results = [] + if search_settings.graph_settings.enabled: + graph_results = await self._graph_search_logic( + query_text=query, + search_settings=search_settings, + precomputed_vector=query_vector, # same idea + ) + + # -- 4) Combine + return AggregateSearchResult( + chunk_search_results=chunk_results, + graph_search_results=graph_results, + ) + + async def _rag_fusion_search( + self, query: str, search_settings: SearchSettings + ) -> AggregateSearchResult: + """ + Implements 'RAG Fusion': + 1) Generate N sub-queries from the user query + 2) For each sub-query => do chunk & graph search + 3) Combine / fuse all retrieved results using Reciprocal Rank Fusion + 4) Return an AggregateSearchResult + """ + + # 1) Generate sub-queries from the user’s original query + # Typically you want the original query to remain in the set as well, + # so that we do not lose the exact user intent. + sub_queries = [query] + if search_settings.num_sub_queries > 1: + # Generate (num_sub_queries - 1) rephrasings + # (Or just generate exactly search_settings.num_sub_queries, + # and remove the first if you prefer.) + extra = await self._generate_similar_queries( + query=query, + num_sub_queries=search_settings.num_sub_queries - 1, + ) + sub_queries.extend(extra) + + # 2) For each sub-query => do chunk + graph search + # We’ll store them in a structure so we can fuse them. + # chunk_results_list is a list of lists of ChunkSearchResult + # graph_results_list is a list of lists of GraphSearchResult + chunk_results_list = [] + graph_results_list = [] + + for sq in sub_queries: + # Recompute or reuse the embedding if desired + # (You could do so, but not mandatory if you have a local approach) + # chunk + graph search + aggr = await self._basic_search(sq, search_settings) + chunk_results_list.append(aggr.chunk_search_results) + graph_results_list.append(aggr.graph_search_results) + + # 3) Fuse the chunk results and fuse the graph results. + # We'll use a simple RRF approach: each sub-query's result list + # is a ranking from best to worst. + fused_chunk_results = self._reciprocal_rank_fusion_chunks( # type: ignore + chunk_results_list # type: ignore + ) + filtered_graph_results = [ + results for results in graph_results_list if results is not None + ] + fused_graph_results = self._reciprocal_rank_fusion_graphs( + filtered_graph_results + ) + + # Optionally, after the RRF, you may want to do a final semantic re-rank + # of the fused results by the user’s original query. 
+ # E.g.: + if fused_chunk_results: + fused_chunk_results = ( + await self.providers.completion_embedding.arerank( + query=query, + results=fused_chunk_results, + limit=search_settings.limit, + ) + ) + + # Sort or slice the graph results if needed: + if fused_graph_results and search_settings.include_scores: + fused_graph_results.sort( + key=lambda g: g.score if g.score is not None else 0.0, + reverse=True, + ) + fused_graph_results = fused_graph_results[: search_settings.limit] + + # 4) Return final AggregateSearchResult + return AggregateSearchResult( + chunk_search_results=fused_chunk_results, + graph_search_results=fused_graph_results, + ) + + async def _generate_similar_queries( + self, query: str, num_sub_queries: int = 2 + ) -> list[str]: + """ + Use your LLM to produce 'similar' queries or rephrasings + that might retrieve different but relevant documents. + + You can prompt your model with something like: + "Given the user query, produce N alternative short queries that + capture possible interpretations or expansions. + Keep them relevant to the user's intent." + """ + if num_sub_queries < 1: + return [] + + # In production, you'd fetch a prompt from your prompts DB: + # Something like: + prompt = f""" + You are a helpful assistant. The user query is: "{query}" + Generate {num_sub_queries} alternative search queries that capture + slightly different phrasings or expansions while preserving the core meaning. + Return each alternative on its own line. + """ + + # For a short generation, we can set minimal tokens + gen_config = GenerationConfig( + model=self.config.app.fast_llm, + max_tokens=128, + temperature=0.8, + stream=False, + ) + response = await self.providers.llm.aget_completion( + messages=[{"role": "system", "content": prompt}], + generation_config=gen_config, + ) + raw_text = ( + response.choices[0].message.content.strip() + if response.choices[0].message.content is not None + else "" + ) + + # Suppose each line is a sub-query + lines = [line.strip() for line in raw_text.split("\n") if line.strip()] + return lines[:num_sub_queries] + + def _reciprocal_rank_fusion_chunks( + self, list_of_rankings: list[list[ChunkSearchResult]], k: float = 60.0 + ) -> list[ChunkSearchResult]: + """ + Simple RRF for chunk results. + list_of_rankings is something like: + [ + [chunkA, chunkB, chunkC], # sub-query #1, in order + [chunkC, chunkD], # sub-query #2, in order + ... + ] + + We'll produce a dictionary mapping chunk.id -> aggregated_score, + then sort descending. + """ + if not list_of_rankings: + return [] + + # Build a map of chunk_id => final_rff_score + score_map: dict[str, float] = {} + + # We also need to store a reference to the chunk object + # (the "first" or "best" instance), so we can reconstruct them later + chunk_map: dict[str, Any] = {} + + for ranking_list in list_of_rankings: + for rank, chunk_result in enumerate(ranking_list, start=1): + if not chunk_result.id: + # fallback if no chunk_id is present + continue + + c_id = chunk_result.id + # RRF scoring + # score = sum(1 / (k + rank)) for each sub-query ranking + # We'll accumulate it. 
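+ # Worked example (illustrative numbers): with k = 60, a chunk ranked
+ # 1st by one sub-query and 3rd by another accumulates
+ # 1/(60 + 1) + 1/(60 + 3) ≈ 0.0164 + 0.0159 ≈ 0.0323, so chunks that
+ # rank well across several sub-queries float to the top of the fused list.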
+ existing_score = score_map.get(str(c_id), 0.0) + new_score = existing_score + 1.0 / (k + rank) + score_map[str(c_id)] = new_score + + # Keep a reference to chunk + if c_id not in chunk_map: + chunk_map[str(c_id)] = chunk_result + + # Now sort by final score + fused_items = sorted( + score_map.items(), key=lambda x: x[1], reverse=True + ) + + # Rebuild the final list of chunk results with new 'score' + fused_chunks = [] + for c_id, agg_score in fused_items: # type: ignore + # copy the chunk + c = chunk_map[str(c_id)] + # Optionally store the RRF score if you want + c.score = agg_score + fused_chunks.append(c) + + return fused_chunks + + def _reciprocal_rank_fusion_graphs( + self, list_of_rankings: list[list[GraphSearchResult]], k: float = 60.0 + ) -> list[GraphSearchResult]: + """ + Similar RRF logic but for graph results. + """ + if not list_of_rankings: + return [] + + score_map: dict[str, float] = {} + graph_map = {} + + for ranking_list in list_of_rankings: + for rank, g_result in enumerate(ranking_list, start=1): + # We'll do a naive ID approach: + # If your GraphSearchResult has a unique ID in g_result.content.id or so + # we can use that as a key. + # If not, you might have to build a key from the content. + g_id = None + if hasattr(g_result.content, "id"): + g_id = str(g_result.content.id) + else: + # fallback + g_id = f"graph_{hash(g_result.content.json())}" + + existing_score = score_map.get(g_id, 0.0) + new_score = existing_score + 1.0 / (k + rank) + score_map[g_id] = new_score + + if g_id not in graph_map: + graph_map[g_id] = g_result + + # Sort descending by aggregated RRF score + fused_items = sorted( + score_map.items(), key=lambda x: x[1], reverse=True + ) + + fused_graphs = [] + for g_id, agg_score in fused_items: + g = graph_map[g_id] + g.score = agg_score + fused_graphs.append(g) + + return fused_graphs + + async def _hyde_search( + self, query: str, search_settings: SearchSettings + ) -> AggregateSearchResult: + """ + 1) Generate N hypothetical docs via LLM + 2) For each doc => embed => parallel chunk search & graph search + 3) Merge chunk results => optional re-rank => top K + 4) Merge graph results => (optionally re-rank or keep them distinct) + """ + # 1) Generate hypothetical docs + hyde_docs = await self._run_hyde_generation( + query=query, num_sub_queries=search_settings.num_sub_queries + ) + + chunk_all = [] + graph_all = [] + + # We'll gather the per-doc searches in parallel + tasks = [] + for hypothetical_text in hyde_docs: + tasks.append( + asyncio.create_task( + self._fanout_chunk_and_graph_search( + user_text=query, # The user’s original query + alt_text=hypothetical_text, # The hypothetical doc + search_settings=search_settings, + ) + ) + ) + + # 2) Wait for them all + results_list = await asyncio.gather(*tasks) + # each item in results_list is a tuple: (chunks, graphs) + + # Flatten chunk+graph results + for c_results, g_results in results_list: + chunk_all.extend(c_results) + graph_all.extend(g_results) + + # 3) Re-rank chunk results with the original query + if chunk_all: + chunk_all = await self.providers.completion_embedding.arerank( + query=query, # final user query + results=chunk_all, + limit=int( + search_settings.limit * search_settings.num_sub_queries + ), + # no limit on results - limit=search_settings.limit, + ) + + # 4) If needed, re-rank graph results or just slice top-K by score + if search_settings.include_scores and graph_all: + graph_all.sort(key=lambda g: g.score or 0.0, reverse=True) + graph_all = ( + graph_all # no limit on results - 
[: search_settings.limit] + ) + + return AggregateSearchResult( + chunk_search_results=chunk_all, + graph_search_results=graph_all, + ) + + async def _fanout_chunk_and_graph_search( + self, + user_text: str, + alt_text: str, + search_settings: SearchSettings, + ) -> tuple[list[ChunkSearchResult], list[GraphSearchResult]]: + """ + 1) embed alt_text (HyDE doc or sub-query, etc.) + 2) chunk search + graph search with that embedding + """ + # Precompute the embedding of alt_text + vec = await self.providers.completion_embedding.async_get_embedding( + alt_text # , EmbeddingPurpose.QUERY + ) + + # chunk search + chunk_results = [] + if search_settings.chunk_settings.enabled: + chunk_results = await self._vector_search_logic( + query_text=user_text, # used for text-based stuff & re-ranking + search_settings=search_settings, + precomputed_vector=vec, # use the alt_text vector for semantic/hybrid + ) + + # graph search + graph_results = [] + if search_settings.graph_settings.enabled: + graph_results = await self._graph_search_logic( + query_text=user_text, # or alt_text if you prefer + search_settings=search_settings, + precomputed_vector=vec, + ) + + return (chunk_results, graph_results) + + async def _vector_search_logic( + self, + query_text: str, + search_settings: SearchSettings, + precomputed_vector: Optional[list[float]] = None, + ) -> list[ChunkSearchResult]: + """ + • If precomputed_vector is given, use it for semantic/hybrid search. + Otherwise embed query_text ourselves. + • Then do fulltext, semantic, or hybrid search. + • Optionally re-rank and return results. + """ + if not search_settings.chunk_settings.enabled: + return [] + + # 1) Possibly embed + query_vector = precomputed_vector + if query_vector is None and ( + search_settings.use_semantic_search + or search_settings.use_hybrid_search + ): + query_vector = ( + await self.providers.completion_embedding.async_get_embedding( + query_text # , EmbeddingPurpose.QUERY + ) + ) + + # 2) Choose which search to run + if ( + search_settings.use_fulltext_search + and search_settings.use_semantic_search + ) or search_settings.use_hybrid_search: + if query_vector is None: + raise ValueError("Hybrid search requires a precomputed vector") + raw_results = ( + await self.providers.database.chunks_handler.hybrid_search( + query_vector=query_vector, + query_text=query_text, + search_settings=search_settings, + ) + ) + elif search_settings.use_fulltext_search: + raw_results = ( + await self.providers.database.chunks_handler.full_text_search( + query_text=query_text, + search_settings=search_settings, + ) + ) + elif search_settings.use_semantic_search: + if query_vector is None: + raise ValueError( + "Semantic search requires a precomputed vector" + ) + raw_results = ( + await self.providers.database.chunks_handler.semantic_search( + query_vector=query_vector, + search_settings=search_settings, + ) + ) + else: + raise ValueError( + "At least one of use_fulltext_search or use_semantic_search must be True" + ) + + # 3) Re-rank + reranked = await self.providers.completion_embedding.arerank( + query=query_text, results=raw_results, limit=search_settings.limit + ) + + # 4) Possibly augment text or metadata + final_results = [] + for r in reranked: + if "title" in r.metadata and search_settings.include_metadatas: + title = r.metadata["title"] + r.text = f"Document Title: {title}\n\nText: {r.text}" + r.metadata["associated_query"] = query_text + final_results.append(r) + + return final_results + + async def _graph_search_logic( + self, + query_text: str, + 
search_settings: SearchSettings, + precomputed_vector: Optional[list[float]] = None, + ) -> list[GraphSearchResult]: + """ + Mirrors your previous GraphSearch approach: + • if precomputed_vector is supplied, use that + • otherwise embed query_text + • search entities, relationships, communities + • return results + """ + results: list[GraphSearchResult] = [] + + if not search_settings.graph_settings.enabled: + return results + + # 1) Possibly embed + query_embedding = precomputed_vector + if query_embedding is None: + query_embedding = ( + await self.providers.completion_embedding.async_get_embedding( + query_text + ) + ) + + base_limit = search_settings.limit + graph_limits = search_settings.graph_settings.limits or {} + + # Entity search + entity_limit = graph_limits.get("entities", base_limit) + entity_cursor = self.providers.database.graphs_handler.graph_search( + query_text, + search_type="entities", + limit=entity_limit, + query_embedding=query_embedding, + property_names=["name", "description", "id"], + filters=search_settings.filters, + ) + async for ent in entity_cursor: + score = ent.get("similarity_score") + metadata = ent.get("metadata", {}) + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except Exception as e: + pass + + results.append( + GraphSearchResult( + id=ent.get("id", None), + content=GraphEntityResult( + name=ent.get("name", ""), + description=ent.get("description", ""), + id=ent.get("id", None), + ), + result_type=GraphSearchResultType.ENTITY, + score=score if search_settings.include_scores else None, + metadata=( + { + **(metadata or {}), + "associated_query": query_text, + } + if search_settings.include_metadatas + else {} + ), + ) + ) + + # Relationship search + rel_limit = graph_limits.get("relationships", base_limit) + rel_cursor = self.providers.database.graphs_handler.graph_search( + query_text, + search_type="relationships", + limit=rel_limit, + query_embedding=query_embedding, + property_names=[ + "id", + "subject", + "predicate", + "object", + "description", + "subject_id", + "object_id", + ], + filters=search_settings.filters, + ) + async for rel in rel_cursor: + score = rel.get("similarity_score") + metadata = rel.get("metadata", {}) + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except Exception as e: + pass + + results.append( + GraphSearchResult( + id=ent.get("id", None), + content=GraphRelationshipResult( + id=rel.get("id", None), + subject=rel.get("subject", ""), + predicate=rel.get("predicate", ""), + object=rel.get("object", ""), + subject_id=rel.get("subject_id", None), + object_id=rel.get("object_id", None), + description=rel.get("description", ""), + ), + result_type=GraphSearchResultType.RELATIONSHIP, + score=score if search_settings.include_scores else None, + metadata=( + { + **(metadata or {}), + "associated_query": query_text, + } + if search_settings.include_metadatas + else {} + ), + ) + ) + + # Community search + comm_limit = graph_limits.get("communities", base_limit) + comm_cursor = self.providers.database.graphs_handler.graph_search( + query_text, + search_type="communities", + limit=comm_limit, + query_embedding=query_embedding, + property_names=[ + "id", + "name", + "summary", + ], + filters=search_settings.filters, + ) + async for comm in comm_cursor: + score = comm.get("similarity_score") + metadata = comm.get("metadata", {}) + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except Exception as e: + pass + + results.append( + GraphSearchResult( + 
id=ent.get("id", None), + content=GraphCommunityResult( + id=comm.get("id", None), + name=comm.get("name", ""), + summary=comm.get("summary", ""), + ), + result_type=GraphSearchResultType.COMMUNITY, + score=score if search_settings.include_scores else None, + metadata=( + { + **(metadata or {}), + "associated_query": query_text, + } + if search_settings.include_metadatas + else {} + ), + ) + ) + + return results + + async def _run_hyde_generation( + self, + query: str, + num_sub_queries: int = 2, + ) -> list[str]: + """ + Calls the LLM with a 'HyDE' style prompt to produce multiple + hypothetical documents/answers, one per line or separated by blank lines. + """ + # Retrieve the prompt template from your database or config: + # e.g. your "hyde" prompt has placeholders: {message}, {num_outputs} + hyde_template = ( + await self.providers.database.prompts_handler.get_cached_prompt( + prompt_name="hyde", + inputs={"message": query, "num_outputs": num_sub_queries}, + ) + ) + + # Now call the LLM with that as the system or user prompt: + completion_config = GenerationConfig( + model=self.config.app.fast_llm, # or whichever short/cheap model + max_tokens=512, + temperature=0.7, + stream=False, + ) + + response = await self.providers.llm.aget_completion( + messages=[{"role": "system", "content": hyde_template}], + generation_config=completion_config, + ) + + # Suppose the LLM returns something like: + # + # "Doc1. Some made up text.\n\nDoc2. Another made up text.\n\n" + # + # So we split by double-newline or some pattern: + raw_text = response.choices[0].message.content + return [ + chunk.strip() + for chunk in (raw_text or "").split("\n\n") + if chunk.strip() + ] + + async def search_documents( + self, + query: str, + settings: SearchSettings, + query_embedding: Optional[list[float]] = None, + ) -> list[DocumentResponse]: + if query_embedding is None: + query_embedding = ( + await self.providers.completion_embedding.async_get_embedding( + query + ) + ) + result = ( + await self.providers.database.documents_handler.search_documents( + query_text=query, + settings=settings, + query_embedding=query_embedding, + ) + ) + return result + + async def completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + *args, + **kwargs, + ): + return await self.providers.llm.aget_completion( + [message.to_dict() for message in messages], # type: ignore + generation_config, + *args, + **kwargs, + ) + + async def embedding( + self, + text: str, + ): + return await self.providers.completion_embedding.async_get_embedding( + text=text + ) + + async def rag( + self, + query: str, + rag_generation_config: GenerationConfig, + search_settings: SearchSettings = SearchSettings(), + system_prompt_name: str | None = None, + task_prompt_name: str | None = None, + include_web_search: bool = False, + **kwargs, + ) -> Any: + """ + A single RAG method that can do EITHER a one-shot synchronous RAG or + streaming SSE-based RAG, depending on rag_generation_config.stream. 
+ + 1) Perform aggregator search => context + 2) Build system+task prompts => messages + 3) If not streaming => normal LLM call => return RAGResponse + 4) If streaming => return an async generator of SSE lines + """ + # 1) Possibly fix up any UUID filters in search_settings + for f, val in list(search_settings.filters.items()): + if isinstance(val, UUID): + search_settings.filters[f] = str(val) + + try: + # 2) Perform search => aggregated_results + aggregated_results = await self.search(query, search_settings) + # 3) Optionally add web search results if flag is enabled + if include_web_search: + web_results = await self._perform_web_search(query) + # Merge web search results with existing aggregated results + if web_results and web_results.web_search_results: + if not aggregated_results.web_search_results: + aggregated_results.web_search_results = ( + web_results.web_search_results + ) + else: + aggregated_results.web_search_results.extend( + web_results.web_search_results + ) + # 3) Build context from aggregator + collector = SearchResultsCollector() + collector.add_aggregate_result(aggregated_results) + context_str = format_search_results_for_llm( + aggregated_results, collector + ) + + # 4) Prepare system+task messages + system_prompt_name = system_prompt_name or "system" + task_prompt_name = task_prompt_name or "rag" + task_prompt = kwargs.get("task_prompt") + + messages = await self.providers.database.prompts_handler.get_message_payload( + system_prompt_name=system_prompt_name, + task_prompt_name=task_prompt_name, + task_inputs={"query": query, "context": context_str}, + task_prompt=task_prompt, + ) + + # 5) Check streaming vs. non-streaming + if not rag_generation_config.stream: + # ========== Non-Streaming Logic ========== + response = await self.providers.llm.aget_completion( + messages=messages, + generation_config=rag_generation_config, + ) + llm_text = response.choices[0].message.content + + # (a) Extract short-ID references from final text + raw_sids = extract_citations(llm_text or "") + + # (b) Possibly prune large content out of metadata + metadata = response.dict() + if "choices" in metadata and len(metadata["choices"]) > 0: + metadata["choices"][0]["message"].pop("content", None) + + # (c) Build final RAGResponse + rag_resp = RAGResponse( + generated_answer=llm_text or "", + search_results=aggregated_results, + citations=[ + Citation( + id=f"{sid}", + object="citation", + payload=dump_obj( # type: ignore + self._find_item_by_shortid(sid, collector) + ), + ) + for sid in raw_sids + ], + metadata=metadata, + completion=llm_text or "", + ) + return rag_resp + + else: + # ========== Streaming SSE Logic ========== + async def sse_generator() -> AsyncGenerator[str, None]: + # 1) Emit search results via SSEFormatter + async for line in SSEFormatter.yield_search_results_event( + aggregated_results + ): + yield line + + # Initialize citation tracker to manage citation state + citation_tracker = CitationTracker() + + # Store citation payloads by ID for reuse + citation_payloads = {} + + partial_text_buffer = "" + + # Begin streaming from the LLM + msg_stream = self.providers.llm.aget_completion_stream( + messages=messages, + generation_config=rag_generation_config, + ) + + try: + async for chunk in msg_stream: + delta = chunk.choices[0].delta + finish_reason = chunk.choices[0].finish_reason + # if delta.thinking: + # check if delta has `thinking` attribute + + if hasattr(delta, "thinking") and delta.thinking: + # Emit SSE "thinking" event + async for ( + line + ) in 
SSEFormatter.yield_thinking_event( + delta.thinking + ): + yield line + + if delta.content: + # (b) Emit SSE "message" event for this chunk of text + async for ( + line + ) in SSEFormatter.yield_message_event( + delta.content + ): + yield line + + # Accumulate new text + partial_text_buffer += delta.content + + # (a) Extract citations from updated buffer + # For each *new* short ID, emit an SSE "citation" event + # Find new citation spans in the accumulated text + new_citation_spans = find_new_citation_spans( + partial_text_buffer, citation_tracker + ) + + # Process each new citation span + for cid, spans in new_citation_spans.items(): + for span in spans: + # Check if this is the first time we've seen this citation ID + is_new_citation = ( + citation_tracker.is_new_citation( + cid + ) + ) + + # Get payload if it's a new citation + payload = None + if is_new_citation: + source_obj = ( + self._find_item_by_shortid( + cid, collector + ) + ) + if source_obj: + # Store payload for reuse + payload = dump_obj(source_obj) + citation_payloads[cid] = ( + payload + ) + + # Create citation event payload + citation_data = { + "id": cid, + "object": "citation", + "is_new": is_new_citation, + "span": { + "start": span[0], + "end": span[1], + }, + } + + # Only include full payload for new citations + if is_new_citation and payload: + citation_data["payload"] = payload + + # Emit the citation event + async for ( + line + ) in SSEFormatter.yield_citation_event( + citation_data + ): + yield line + + # If the LLM signals it’s done + if finish_reason == "stop": + # Prepare consolidated citations for final answer event + consolidated_citations = [] + # Group citations by ID with all their spans + for ( + cid, + spans, + ) in citation_tracker.get_all_spans().items(): + if cid in citation_payloads: + consolidated_citations.append( + { + "id": cid, + "object": "citation", + "spans": [ + { + "start": s[0], + "end": s[1], + } + for s in spans + ], + "payload": citation_payloads[ + cid + ], + } + ) + + # (c) Emit final answer + all collected citations + final_answer_evt = { + "id": "msg_final", + "object": "rag.final_answer", + "generated_answer": partial_text_buffer, + "citations": consolidated_citations, + } + async for ( + line + ) in SSEFormatter.yield_final_answer_event( + final_answer_evt + ): + yield line + + # (d) Signal the end of the SSE stream + yield SSEFormatter.yield_done_event() + break + + except Exception as e: + logger.error(f"Error streaming LLM in rag: {e}") + # Optionally yield an SSE "error" event or handle differently + raise + + return sse_generator() + + except Exception as e: + logger.exception(f"Error in RAG pipeline: {e}") + if "NoneType" in str(e): + raise HTTPException( + status_code=502, + detail="Server not reachable or returned an invalid response", + ) from e + raise HTTPException( + status_code=500, + detail=f"Internal RAG Error - {str(e)}", + ) from e + + def _find_item_by_shortid( + self, sid: str, collector: SearchResultsCollector + ) -> Optional[tuple[str, Any, int]]: + """ + Example helper that tries to match aggregator items by short ID, + meaning result_obj.id starts with sid. 
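+ For example (illustrative value), a short ID such as "7f3a9c1b" would
+ match a result whose full id string starts with "7f3a9c1b".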
+ """ + for source_type, result_obj in collector.get_all_results(): + # if the aggregator item has an 'id' attribute + if getattr(result_obj, "id", None) is not None: + full_id_str = str(result_obj.id) + if full_id_str.startswith(sid): + if source_type == "chunk": + return ( + result_obj.as_dict() + ) # (source_type, result_obj.as_dict()) + else: + return result_obj # (source_type, result_obj) + return None + + async def agent( + self, + rag_generation_config: GenerationConfig, + rag_tools: Optional[list[str]] = None, + tools: Optional[list[str]] = None, # backward compatibility + search_settings: SearchSettings = SearchSettings(), + task_prompt: Optional[str] = None, + include_title_if_available: Optional[bool] = False, + conversation_id: Optional[UUID] = None, + message: Optional[Message] = None, + messages: Optional[list[Message]] = None, + use_system_context: bool = False, + max_tool_context_length: int = 32_768, + research_tools: Optional[list[str]] = None, + research_generation_config: Optional[GenerationConfig] = None, + needs_initial_conversation_name: Optional[bool] = None, + mode: Optional[Literal["rag", "research"]] = "rag", + ): + """ + Engage with an intelligent agent for information retrieval, analysis, and research. + + Args: + rag_generation_config: Configuration for RAG mode generation + search_settings: Search configuration for retrieving context + task_prompt: Optional custom prompt override + include_title_if_available: Whether to include document titles + conversation_id: Optional conversation ID for continuity + message: Current message to process + messages: List of messages (deprecated) + use_system_context: Whether to use extended prompt + max_tool_context_length: Maximum context length for tools + rag_tools: List of tools for RAG mode + research_tools: List of tools for Research mode + research_generation_config: Configuration for Research mode generation + mode: Either "rag" or "research" + + Returns: + Agent response with messages and conversation ID + """ + try: + # Validate message inputs + if message and messages: + raise R2RException( + status_code=400, + message="Only one of message or messages should be provided", + ) + + if not message and not messages: + raise R2RException( + status_code=400, + message="Either message or messages should be provided", + ) + + # Ensure 'message' is a Message instance + if message and not isinstance(message, Message): + if isinstance(message, dict): + message = Message.from_dict(message) + else: + raise R2RException( + status_code=400, + message=""" + Invalid message format. The expected format contains: + role: MessageType | 'system' | 'user' | 'assistant' | 'function' + content: Optional[str] + name: Optional[str] + function_call: Optional[dict[str, Any]] + tool_calls: Optional[list[dict[str, Any]]] + """, + ) + + # Ensure 'messages' is a list of Message instances + if messages: + processed_messages = [] + for msg in messages: + if isinstance(msg, Message): + processed_messages.append(msg) + elif hasattr(msg, "dict"): + processed_messages.append( + Message.from_dict(msg.dict()) + ) + elif isinstance(msg, dict): + processed_messages.append(Message.from_dict(msg)) + else: + processed_messages.append(Message.from_dict(str(msg))) + messages = processed_messages + else: + messages = [] + + # Validate and process mode-specific configurations + if mode == "rag" and research_tools: + logger.warning( + "research_tools provided but mode is 'rag'. These tools will be ignored." 
+ ) + research_tools = None + + # Determine effective generation config based on mode + effective_generation_config = rag_generation_config + if mode == "research" and research_generation_config: + effective_generation_config = research_generation_config + + # Set appropriate LLM model based on mode if not explicitly specified + if "model" not in effective_generation_config.__fields_set__: + if mode == "rag": + effective_generation_config.model = ( + self.config.app.quality_llm + ) + elif mode == "research": + effective_generation_config.model = ( + self.config.app.planning_llm + ) + + # Transform UUID filters to strings + for filter_key, value in search_settings.filters.items(): + if isinstance(value, UUID): + search_settings.filters[filter_key] = str(value) + + # Process conversation data + ids = [] + if conversation_id: # Fetch the existing conversation + try: + conversation_messages = await self.providers.database.conversations_handler.get_conversation( + conversation_id=conversation_id, + ) + if needs_initial_conversation_name is None: + overview = await self.providers.database.conversations_handler.get_conversations_overview( + offset=0, + limit=1, + conversation_ids=[conversation_id], + ) + if overview.get("total_entries", 0) > 0: + needs_initial_conversation_name = ( + overview.get("results")[0].get("name") is None # type: ignore + ) + except Exception as e: + logger.error(f"Error fetching conversation: {str(e)}") + + if conversation_messages is not None: + messages_from_conversation: list[Message] = [] + for message_response in conversation_messages: + if isinstance(message_response, MessageResponse): + messages_from_conversation.append( + message_response.message + ) + ids.append(message_response.id) + else: + logger.warning( + f"Unexpected type in conversation found: {type(message_response)}\n{message_response}" + ) + messages = messages_from_conversation + messages + else: # Create new conversation + conversation_response = await self.providers.database.conversations_handler.create_conversation() + conversation_id = conversation_response.id + needs_initial_conversation_name = True + + if message: + messages.append(message) + + if not messages: + raise R2RException( + status_code=400, + message="No messages to process", + ) + + current_message = messages[-1] + logger.debug( + f"Running the agent with conversation_id = {conversation_id} and message = {current_message}" + ) + + # Save the new message to the conversation + parent_id = ids[-1] if ids else None + message_response = await self.providers.database.conversations_handler.add_message( + conversation_id=conversation_id, + content=current_message, + parent_id=parent_id, + ) + + message_id = ( + message_response.id if message_response is not None else None + ) + + # Extract filter information from search settings + filter_user_id, filter_collection_ids = ( + self._parse_user_and_collection_filters( + search_settings.filters + ) + ) + + # Validate system instruction configuration + if use_system_context and task_prompt: + raise R2RException( + status_code=400, + message="Both use_system_context and task_prompt cannot be True at the same time", + ) + + # Build the system instruction + if task_prompt: + system_instruction = task_prompt + else: + system_instruction = ( + await self._build_aware_system_instruction( + max_tool_context_length=max_tool_context_length, + filter_user_id=filter_user_id, + filter_collection_ids=filter_collection_ids, + model=effective_generation_config.model, + use_system_context=use_system_context, + 
mode=mode, + ) + ) + + # Configure agent with appropriate tools + agent_config = deepcopy(self.config.agent) + if mode == "rag": + # Use provided RAG tools or default from config + agent_config.rag_tools = ( + rag_tools or tools or self.config.agent.rag_tools + ) + else: # research mode + # Use provided Research tools or default from config + agent_config.research_tools = ( + research_tools or tools or self.config.agent.research_tools + ) + + # Create the agent using our factory + mode = mode or "rag" + + for msg in messages: + if msg.content is None: + msg.content = "" + + agent = AgentFactory.create_agent( + mode=mode, + database_provider=self.providers.database, + llm_provider=self.providers.llm, + config=agent_config, + search_settings=search_settings, + generation_config=effective_generation_config, + app_config=self.config.app, + knowledge_search_method=self.search, + content_method=self.get_context, + file_search_method=self.search_documents, + max_tool_context_length=max_tool_context_length, + rag_tools=rag_tools, + research_tools=research_tools, + tools=tools, # Backward compatibility + ) + + # Handle streaming vs. non-streaming response + if effective_generation_config.stream: + + async def stream_response(): + try: + async for chunk in agent.arun( + messages=messages, + system_instruction=system_instruction, + include_title_if_available=include_title_if_available, + ): + yield chunk + except Exception as e: + logger.error(f"Error streaming agent output: {e}") + raise e + finally: + # Persist conversation data + msgs = [ + msg.to_dict() + for msg in agent.conversation.messages + ] + input_tokens = num_tokens_from_messages(msgs[:-1]) + output_tokens = num_tokens_from_messages([msgs[-1]]) + await self.providers.database.conversations_handler.add_message( + conversation_id=conversation_id, + content=agent.conversation.messages[-1], + parent_id=message_id, + metadata={ + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, + ) + + # Generate conversation name if needed + if needs_initial_conversation_name: + try: + prompt = f"Generate a succinct name (3-6 words) for this conversation, given the first input mesasge here = {str(message.to_dict())}" + conversation_name = ( + ( + await self.providers.llm.aget_completion( + [ + { + "role": "system", + "content": prompt, + } + ], + GenerationConfig( + model=self.config.app.fast_llm + ), + ) + ) + .choices[0] + .message.content + ) + await self.providers.database.conversations_handler.update_conversation( + conversation_id=conversation_id, + name=conversation_name, + ) + except Exception as e: + logger.error( + f"Error generating conversation name: {e}" + ) + + return stream_response() + else: + for idx, msg in enumerate(messages): + if msg.content is None: + if ( + hasattr(msg, "structured_content") + and msg.structured_content + ): + messages[idx].content = "" + else: + messages[idx].content = "" + + # Non-streaming path + results = await agent.arun( + messages=messages, + system_instruction=system_instruction, + include_title_if_available=include_title_if_available, + ) + + # Process the agent results + if isinstance(results[-1], dict): + if results[-1].get("content") is None: + results[-1]["content"] = "" + assistant_message = Message(**results[-1]) + elif isinstance(results[-1], Message): + assistant_message = results[-1] + if assistant_message.content is None: + assistant_message.content = "" + else: + assistant_message = Message( + role="assistant", content=str(results[-1]) + ) + + # Get search results collector for 
citations + if hasattr(agent, "search_results_collector"): + collector = agent.search_results_collector + else: + collector = SearchResultsCollector() + + # Extract content from the message + structured_content = assistant_message.structured_content + structured_content = ( + structured_content[-1].get("text") + if structured_content + else None + ) + raw_text = ( + assistant_message.content or structured_content or "" + ) + # Process citations + short_ids = extract_citations(raw_text or "") + final_citations = [] + for sid in short_ids: + obj = collector.find_by_short_id(sid) + final_citations.append( + { + "id": sid, + "object": "citation", + "payload": dump_obj(obj) if obj else None, + } + ) + + # Persist in conversation DB + await ( + self.providers.database.conversations_handler.add_message( + conversation_id=conversation_id, + content=assistant_message, + parent_id=message_id, + metadata={ + "citations": final_citations, + "aggregated_search_result": json.dumps( + dump_collector(collector) + ), + }, + ) + ) + + # Generate conversation name if needed + if needs_initial_conversation_name: + conversation_name = None + try: + prompt = f"Generate a succinct name (3-6 words) for this conversation, given the first input mesasge here = {str(message.to_dict() if message else {})}" + conversation_name = ( + ( + await self.providers.llm.aget_completion( + [{"role": "system", "content": prompt}], + GenerationConfig( + model=self.config.app.fast_llm + ), + ) + ) + .choices[0] + .message.content + ) + except Exception as e: + pass + finally: + await self.providers.database.conversations_handler.update_conversation( + conversation_id=conversation_id, + name=conversation_name or "", + ) + + tool_calls = [] + if hasattr(agent, "tool_calls"): + if agent.tool_calls is not None: + tool_calls = agent.tool_calls + else: + logger.warning( + "agent.tool_calls is None, using empty list instead" + ) + # Return the final response + return { + "messages": [ + Message( + role="assistant", + content=assistant_message.content + or structured_content + or "", + metadata={ + "citations": final_citations, + "tool_calls": tool_calls, + "aggregated_search_result": json.dumps( + dump_collector(collector) + ), + }, + ) + ], + "conversation_id": str(conversation_id), + } + + except Exception as e: + logger.error(f"Error in agent response: {str(e)}") + if "NoneType" in str(e): + raise HTTPException( + status_code=502, + detail="Server not reachable or returned an invalid response", + ) from e + raise HTTPException( + status_code=500, + detail=f"Internal Server Error - {str(e)}", + ) from e + + async def get_context( + self, + filters: dict[str, Any], + options: dict[str, Any], + ) -> list[dict[str, Any]]: + """ + Return an ordered list of documents (with minimal overview fields), + plus all associated chunks in ascending chunk order. + + Only the filters: owner_id, collection_ids, and document_id + are supported. If any other filter or operator is passed in, + we raise an error. + + Args: + filters: A dictionary describing the allowed filters + (owner_id, collection_ids, document_id). + options: A dictionary with extra options, e.g. include_summary_embedding + or any custom flags for additional logic. + + Returns: + A list of dicts, where each dict has: + { + "document": <DocumentResponse>, + "chunks": [ <chunk0>, <chunk1>, ... ] + } + """ + # 2. 
Fetch matching documents + matching_docs = await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=-1, + filters=filters, + include_summary_embedding=options.get( + "include_summary_embedding", False + ), + ) + + if not matching_docs["results"]: + return [] + + # 3. For each document, fetch associated chunks in ascending chunk order + results = [] + for doc_response in matching_docs["results"]: + doc_id = doc_response.id + chunk_data = await self.providers.database.chunks_handler.list_document_chunks( + document_id=doc_id, + offset=0, + limit=-1, # get all chunks + include_vectors=False, + ) + chunks = chunk_data["results"] # already sorted by chunk_order + doc_response.chunks = chunks + # 4. Build a returned structure that includes doc + chunks + results.append(doc_response.model_dump()) + + return results + + def _parse_user_and_collection_filters( + self, + filters: dict[str, Any], + ): + ### TODO - Come up with smarter way to extract owner / collection ids for non-admin + filter_starts_with_and = filters.get("$and") + filter_starts_with_or = filters.get("$or") + if filter_starts_with_and: + try: + filter_starts_with_and_then_or = filter_starts_with_and[0][ + "$or" + ] + + user_id = filter_starts_with_and_then_or[0]["owner_id"]["$eq"] + collection_ids = [ + UUID(ele) + for ele in filter_starts_with_and_then_or[1][ + "collection_ids" + ]["$overlap"] + ] + return user_id, [str(ele) for ele in collection_ids] + except Exception as e: + logger.error( + f"Error: {e}.\n\n While" + + """ parsing filters: expected format {'$or': [{'owner_id': {'$eq': 'uuid-string-here'}, 'collection_ids': {'$overlap': ['uuid-of-some-collection']}}]}, if you are a superuser then this error can be ignored.""" + ) + return None, [] + elif filter_starts_with_or: + try: + user_id = filter_starts_with_or[0]["owner_id"]["$eq"] + collection_ids = [ + UUID(ele) + for ele in filter_starts_with_or[1]["collection_ids"][ + "$overlap" + ] + ] + return user_id, [str(ele) for ele in collection_ids] + except Exception as e: + logger.error( + """Error parsing filters: expected format {'$or': [{'owner_id': {'$eq': 'uuid-string-here'}, 'collection_ids': {'$overlap': ['uuid-of-some-collection']}}]}, if you are a superuser then this error can be ignored.""" + ) + return None, [] + else: + # Admin user + return None, [] + + async def _build_documents_context( + self, + filter_user_id: Optional[UUID] = None, + max_summary_length: int = 128, + limit: int = 25, + reverse_order: bool = True, + ) -> str: + """ + Fetches documents matching the given filters and returns a formatted string + enumerating them. + """ + # We only want up to `limit` documents for brevity + docs_data = await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=limit, + filter_user_ids=[filter_user_id] if filter_user_id else None, + include_summary_embedding=False, + sort_order="DESC" if reverse_order else "ASC", + ) + + found_max = False + if len(docs_data["results"]) == limit: + found_max = True + + docs = docs_data["results"] + if not docs: + return "No documents found." 
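+
+ # Illustrative output lines (format mirrors the f-strings below; the
+ # titles, token counts, statuses and IDs are made-up values):
+ # [1] Title: Q3 Report, Summary: Revenue grew..., Total Tokens: 1842, ID: 9c1d...
+ # [2] Title: Draft Notes, Summary: (Summary not available), Status:parsing ID: 77ab...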
+ + lines = [] + for i, doc in enumerate(docs, start=1): + if ( + not doc.summary + or doc.ingestion_status != IngestionStatus.SUCCESS + ): + lines.append( + f"[{i}] Title: {doc.title}, Summary: (Summary not available), Status:{doc.ingestion_status} ID: {doc.id}" + ) + continue + + # Build a line referencing the doc + title = doc.title or "(Untitled Document)" + lines.append( + f"[{i}] Title: {title}, Summary: {(doc.summary[0:max_summary_length] + ('...' if len(doc.summary) > max_summary_length else ''),)}, Total Tokens: {doc.total_tokens}, ID: {doc.id}" + ) + if found_max: + lines.append( + f"Note: Displaying only the first {limit} documents. Use a filter to narrow down the search if more documents are required." + ) + + return "\n".join(lines) + + async def _build_aware_system_instruction( + self, + max_tool_context_length: int = 10_000, + filter_user_id: Optional[UUID] = None, + filter_collection_ids: Optional[list[UUID]] = None, + model: Optional[str] = None, + use_system_context: bool = False, + mode: Optional[str] = "rag", + ) -> str: + """ + High-level method that: + 1) builds the documents context + 2) builds the collections context + 3) loads the new `dynamic_reasoning_rag_agent` prompt + """ + date_str = str(datetime.now().strftime("%m/%d/%Y")) + + # "dynamic_rag_agent" // "static_rag_agent" + + if mode == "rag": + prompt_name = ( + self.config.agent.rag_agent_dynamic_prompt + if use_system_context + else self.config.agent.rag_rag_agent_static_prompt + ) + else: + prompt_name = "static_research_agent" + return await self.providers.database.prompts_handler.get_cached_prompt( + # We use custom tooling and a custom agent to handle gemini models + prompt_name, + inputs={ + "date": date_str, + }, + ) + + if model is not None and ("deepseek" in model): + prompt_name = f"{prompt_name}_xml_tooling" + + if use_system_context: + doc_context_str = await self._build_documents_context( + filter_user_id=filter_user_id, + ) + logger.debug(f"Loading prompt {prompt_name}") + # Now fetch the prompt from the database prompts handler + # This relies on your "rag_agent_extended" existing with + # placeholders: date, document_context + system_prompt = await self.providers.database.prompts_handler.get_cached_prompt( + # We use custom tooling and a custom agent to handle gemini models + prompt_name, + inputs={ + "date": date_str, + "max_tool_context_length": max_tool_context_length, + "document_context": doc_context_str, + }, + ) + else: + system_prompt = await self.providers.database.prompts_handler.get_cached_prompt( + prompt_name, + inputs={ + "date": date_str, + }, + ) + logger.debug(f"Running agent with system prompt = {system_prompt}") + return system_prompt + + async def _perform_web_search( + self, + query: str, + search_settings: SearchSettings = SearchSettings(), + ) -> AggregateSearchResult: + """ + Perform a web search using an external search engine API (Serper). 
+ + Args: + query: The search query string + search_settings: Optional search settings to customize the search + + Returns: + AggregateSearchResult containing web search results + """ + try: + # Import the Serper client here to avoid circular imports + from core.utils.serper import SerperClient + + # Initialize the Serper client + serper_client = SerperClient() + + # Perform the raw search using Serper API + raw_results = serper_client.get_raw(query) + + # Process the raw results into a WebSearchResult object + web_response = WebSearchResult.from_serper_results(raw_results) + + # Create an AggregateSearchResult with the web search results + agg_result = AggregateSearchResult( + chunk_search_results=None, + graph_search_results=None, + web_search_results=web_response.organic_results, + ) + + # Log the search for monitoring purposes + logger.debug(f"Web search completed for query: {query}") + logger.debug( + f"Found {len(web_response.organic_results)} web results" + ) + + return agg_result + + except Exception as e: + logger.error(f"Error performing web search: {str(e)}") + # Return empty results rather than failing completely + return AggregateSearchResult( + chunk_search_results=None, + graph_search_results=None, + web_search_results=[], + ) + + +class RetrievalServiceAdapter: + @staticmethod + def _parse_user_data(user_data): + if isinstance(user_data, str): + try: + user_data = json.loads(user_data) + except json.JSONDecodeError as e: + raise ValueError( + f"Invalid user data format: {user_data}" + ) from e + return User.from_dict(user_data) + + @staticmethod + def prepare_search_input( + query: str, + search_settings: SearchSettings, + user: User, + ) -> dict: + return { + "query": query, + "search_settings": search_settings.to_dict(), + "user": user.to_dict(), + } + + @staticmethod + def parse_search_input(data: dict): + return { + "query": data["query"], + "search_settings": SearchSettings.from_dict( + data["search_settings"] + ), + "user": RetrievalServiceAdapter._parse_user_data(data["user"]), + } + + @staticmethod + def prepare_rag_input( + query: str, + search_settings: SearchSettings, + rag_generation_config: GenerationConfig, + task_prompt: Optional[str], + include_web_search: bool, + user: User, + ) -> dict: + return { + "query": query, + "search_settings": search_settings.to_dict(), + "rag_generation_config": rag_generation_config.to_dict(), + "task_prompt": task_prompt, + "include_web_search": include_web_search, + "user": user.to_dict(), + } + + @staticmethod + def parse_rag_input(data: dict): + return { + "query": data["query"], + "search_settings": SearchSettings.from_dict( + data["search_settings"] + ), + "rag_generation_config": GenerationConfig.from_dict( + data["rag_generation_config"] + ), + "task_prompt": data["task_prompt"], + "include_web_search": data["include_web_search"], + "user": RetrievalServiceAdapter._parse_user_data(data["user"]), + } + + @staticmethod + def prepare_agent_input( + message: Message, + search_settings: SearchSettings, + rag_generation_config: GenerationConfig, + task_prompt: Optional[str], + include_title_if_available: bool, + user: User, + conversation_id: Optional[str] = None, + ) -> dict: + return { + "message": message.to_dict(), + "search_settings": search_settings.to_dict(), + "rag_generation_config": rag_generation_config.to_dict(), + "task_prompt": task_prompt, + "include_title_if_available": include_title_if_available, + "user": user.to_dict(), + "conversation_id": conversation_id, + } + + @staticmethod + def parse_agent_input(data: 
dict): + return { + "message": Message.from_dict(data["message"]), + "search_settings": SearchSettings.from_dict( + data["search_settings"] + ), + "rag_generation_config": GenerationConfig.from_dict( + data["rag_generation_config"] + ), + "task_prompt": data["task_prompt"], + "include_title_if_available": data["include_title_if_available"], + "user": RetrievalServiceAdapter._parse_user_data(data["user"]), + "conversation_id": data.get("conversation_id"), + }
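+
+ # Round-trip usage sketch (illustrative; `settings`, `user` and
+ # `retrieval_service` are assumed to be pre-built objects):
+ # payload = RetrievalServiceAdapter.prepare_search_input(query, settings, user)
+ # kwargs = RetrievalServiceAdapter.parse_search_input(payload)
+ # results = await retrieval_service.search(kwargs["query"], kwargs["search_settings"])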
