diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/main')
34 files changed, 20557 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/main/__init__.py b/.venv/lib/python3.12/site-packages/core/main/__init__.py new file mode 100644 index 00000000..7043d029 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/__init__.py @@ -0,0 +1,24 @@ +from .abstractions import R2RProviders +from .api import * +from .app import * + +# from .app_entry import r2r_app +from .assembly import * +from .orchestration import * +from .services import * + +__all__ = [ + # R2R Primary + "R2RProviders", + "R2RApp", + "R2RBuilder", + "R2RConfig", + # Factory + "R2RProviderFactory", + ## R2R SERVICES + "AuthService", + "IngestionService", + "ManagementService", + "RetrievalService", + "GraphService", +] diff --git a/.venv/lib/python3.12/site-packages/core/main/abstractions.py b/.venv/lib/python3.12/site-packages/core/main/abstractions.py new file mode 100644 index 00000000..3aaf2dbf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/abstractions.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +from core.providers import ( + AnthropicCompletionProvider, + AsyncSMTPEmailProvider, + ClerkAuthProvider, + ConsoleMockEmailProvider, + HatchetOrchestrationProvider, + JwtAuthProvider, + LiteLLMCompletionProvider, + LiteLLMEmbeddingProvider, + MailerSendEmailProvider, + OllamaEmbeddingProvider, + OpenAICompletionProvider, + OpenAIEmbeddingProvider, + PostgresDatabaseProvider, + R2RAuthProvider, + R2RCompletionProvider, + R2RIngestionProvider, + SendGridEmailProvider, + SimpleOrchestrationProvider, + SupabaseAuthProvider, + UnstructuredIngestionProvider, +) + +if TYPE_CHECKING: + from core.main.services.auth_service import AuthService + from core.main.services.graph_service import GraphService + from core.main.services.ingestion_service import IngestionService + from core.main.services.management_service import ManagementService + from core.main.services.retrieval_service import ( # 
type: ignore + RetrievalService, # type: ignore + ) + + +class R2RProviders(BaseModel): + auth: ( + R2RAuthProvider + | SupabaseAuthProvider + | JwtAuthProvider + | ClerkAuthProvider + ) + database: PostgresDatabaseProvider + ingestion: R2RIngestionProvider | UnstructuredIngestionProvider + embedding: ( + LiteLLMEmbeddingProvider + | OpenAIEmbeddingProvider + | OllamaEmbeddingProvider + ) + completion_embedding: ( + LiteLLMEmbeddingProvider + | OpenAIEmbeddingProvider + | OllamaEmbeddingProvider + ) + llm: ( + AnthropicCompletionProvider + | LiteLLMCompletionProvider + | OpenAICompletionProvider + | R2RCompletionProvider + ) + orchestration: HatchetOrchestrationProvider | SimpleOrchestrationProvider + email: ( + AsyncSMTPEmailProvider + | ConsoleMockEmailProvider + | SendGridEmailProvider + | MailerSendEmailProvider + ) + + class Config: + arbitrary_types_allowed = True + + +@dataclass +class R2RServices: + auth: "AuthService" + ingestion: "IngestionService" + management: "ManagementService" + retrieval: "RetrievalService" + graph: "GraphService" diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/base_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/base_router.py new file mode 100644 index 00000000..ef432420 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/base_router.py @@ -0,0 +1,151 @@ +import functools +import logging +from abc import abstractmethod +from typing import Callable + +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import FileResponse, StreamingResponse + +from core.base import R2RException + +from ...abstractions import R2RProviders, R2RServices +from ...config import R2RConfig + +logger = logging.getLogger() + + +class BaseRouterV3: + def __init__( + self, providers: R2RProviders, services: R2RServices, config: R2RConfig + ): + """ + :param providers: Typically includes auth, database, etc. + :param services: Additional service references (ingestion, etc). 
+ """ + self.providers = providers + self.services = services + self.config = config + self.router = APIRouter() + self.openapi_extras = self._load_openapi_extras() + + # Add the rate-limiting dependency + self.set_rate_limiting() + + # Initialize any routes + self._setup_routes() + self._register_workflows() + + def get_router(self): + return self.router + + def base_endpoint(self, func: Callable): + """ + A decorator to wrap endpoints in a standard pattern: + - error handling + - response shaping + """ + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + try: + func_result = await func(*args, **kwargs) + if isinstance(func_result, tuple) and len(func_result) == 2: + results, outer_kwargs = func_result + else: + results, outer_kwargs = func_result, {} + + if isinstance(results, (StreamingResponse, FileResponse)): + return results + return {"results": results, **outer_kwargs} + + except R2RException: + raise + except Exception as e: + logger.error( + f"Error in base endpoint {func.__name__}() - {str(e)}", + exc_info=True, + ) + raise HTTPException( + status_code=500, + detail={ + "message": f"An error '{e}' occurred during {func.__name__}", + "error": str(e), + "error_type": type(e).__name__, + }, + ) from e + + wrapper._is_base_endpoint = True # type: ignore + return wrapper + + @classmethod + def build_router(cls, engine): + """Class method for building a router instance (if you have a standard + pattern).""" + return cls(engine).router + + def _register_workflows(self): + pass + + def _load_openapi_extras(self): + return {} + + @abstractmethod + def _setup_routes(self): + """Subclasses override this to define actual endpoints.""" + pass + + def set_rate_limiting(self): + """Adds a yield-based dependency for rate limiting each request. + + Checks the limits, then logs the request if the check passes. 
+ """ + + async def rate_limit_dependency( + request: Request, + auth_user=Depends(self.providers.auth.auth_wrapper()), + ): + """1) Fetch the user from the DB (including .limits_overrides). + + 2) Pass it to limits_handler.check_limits. 3) After the endpoint + completes, call limits_handler.log_request. + """ + # If the user is superuser, skip checks + if auth_user.is_superuser: + yield + return + + user_id = auth_user.id + route = request.scope["path"] + + # 1) Fetch the user from DB + user = await self.providers.database.users_handler.get_user_by_id( + user_id + ) + if not user: + raise HTTPException(status_code=404, detail="User not found.") + + # 2) Rate-limit check + try: + await self.providers.database.limits_handler.check_limits( + user=user, + route=route, # Pass the User object + ) + except ValueError as e: + # If check_limits raises ValueError -> 429 Too Many Requests + raise HTTPException(status_code=429, detail=str(e)) from e + + request.state.user_id = user_id + request.state.route = route + + # 3) Execute the route + try: + yield + finally: + # 4) Log only POST and DELETE requests + if request.method in ["POST", "DELETE"]: + await self.providers.database.limits_handler.log_request( + user_id, route + ) + + # Attach the dependencies so you can use them in your endpoints + self.rate_limit_dependency = rate_limit_dependency diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/chunks_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/chunks_router.py new file mode 100644 index 00000000..ab0a62cb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/chunks_router.py @@ -0,0 +1,422 @@ +import json +import logging +import textwrap +from typing import Optional +from uuid import UUID + +from fastapi import Body, Depends, Path, Query + +from core.base import ( + ChunkResponse, + GraphSearchSettings, + R2RException, + SearchSettings, + UpdateChunk, + select_search_filters, +) +from core.base.api.models import ( + 
GenericBooleanResponse, + WrappedBooleanResponse, + WrappedChunkResponse, + WrappedChunksResponse, + WrappedVectorSearchResponse, +) + +from ...abstractions import R2RProviders, R2RServices +from ...config import R2RConfig +from .base_router import BaseRouterV3 + +logger = logging.getLogger() + +MAX_CHUNKS_PER_REQUEST = 1024 * 100 + + +class ChunksRouter(BaseRouterV3): + def __init__( + self, providers: R2RProviders, services: R2RServices, config: R2RConfig + ): + logging.info("Initializing ChunksRouter") + super().__init__(providers, services, config) + + def _setup_routes(self): + @self.router.post( + "/chunks/search", + summary="Search Chunks", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + response = client.chunks.search( + query="search query", + search_settings={ + "limit": 10 + } + ) + """), + } + ] + }, + ) + @self.base_endpoint + async def search_chunks( + query: str = Body(...), + search_settings: SearchSettings = Body( + default_factory=SearchSettings, + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedVectorSearchResponse: # type: ignore + # TODO - Deduplicate this code by sharing the code on the retrieval router + """Perform a semantic search query over all stored chunks. + + This endpoint allows for complex filtering of search results using PostgreSQL-based queries. + Filters can be applied to various fields such as document_id, and internal metadata values. + + Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`. 
+ """ + + search_settings.filters = select_search_filters( + auth_user, search_settings + ) + + search_settings.graph_settings = GraphSearchSettings(enabled=False) + + results = await self.services.retrieval.search( + query=query, + search_settings=search_settings, + ) + return results.chunk_search_results # type: ignore + + @self.router.get( + "/chunks/{id}", + summary="Retrieve Chunk", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + response = client.chunks.retrieve( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.chunks.retrieve({ + id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def retrieve_chunk( + id: UUID = Path(...), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedChunkResponse: + """Get a specific chunk by its ID. + + Returns the chunk's content, metadata, and associated + document/collection information. Users can only retrieve chunks + they own or have access to through collections. 
+ """ + chunk = await self.services.ingestion.get_chunk(id) + if not chunk: + raise R2RException("Chunk not found", 404) + + # TODO - Add collection ID check + if not auth_user.is_superuser and str(auth_user.id) != str( + chunk["owner_id"] + ): + raise R2RException("Not authorized to access this chunk", 403) + + return ChunkResponse( # type: ignore + id=chunk["id"], + document_id=chunk["document_id"], + owner_id=chunk["owner_id"], + collection_ids=chunk["collection_ids"], + text=chunk["text"], + metadata=chunk["metadata"], + # vector = chunk["vector"] # TODO - Add include vector flag + ) + + @self.router.post( + "/chunks/{id}", + summary="Update Chunk", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + response = client.chunks.update( + { + "id": "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + "text": "Updated content", + "metadata": {"key": "new value"} + } + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.chunks.update({ + id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + text: "Updated content", + metadata: {key: "new value"} + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def update_chunk( + id: UUID = Path(...), + chunk_update: UpdateChunk = Body(...), + # TODO: Run with orchestration? + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedChunkResponse: + """Update an existing chunk's content and/or metadata. + + The chunk's vectors will be automatically recomputed based on the + new content. Users can only update chunks they own unless they are + superusers. 
+ """ + # Get the existing chunk to get its chunk_id + existing_chunk = await self.services.ingestion.get_chunk( + chunk_update.id + ) + if existing_chunk is None: + raise R2RException(f"Chunk {chunk_update.id} not found", 404) + + workflow_input = { + "document_id": str(existing_chunk["document_id"]), + "id": str(chunk_update.id), + "text": chunk_update.text, + "metadata": chunk_update.metadata + or existing_chunk["metadata"], + "user": auth_user.model_dump_json(), + } + + logger.info("Running chunk ingestion without orchestration.") + from core.main.orchestration import simple_ingestion_factory + + # TODO - CLEAN THIS UP + + simple_ingestor = simple_ingestion_factory(self.services.ingestion) + await simple_ingestor["update-chunk"](workflow_input) + + return ChunkResponse( # type: ignore + id=chunk_update.id, + document_id=existing_chunk["document_id"], + owner_id=existing_chunk["owner_id"], + collection_ids=existing_chunk["collection_ids"], + text=chunk_update.text, + metadata=chunk_update.metadata or existing_chunk["metadata"], + # vector = existing_chunk.get('vector') + ) + + @self.router.delete( + "/chunks/{id}", + summary="Delete Chunk", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + response = client.chunks.delete( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.chunks.delete({ + id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def delete_chunk( + id: UUID = Path(...), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Delete a specific chunk by ID. 
+ + This permanently removes the chunk and its associated vector + embeddings. The parent document remains unchanged. Users can only + delete chunks they own unless they are superusers. + """ + # Get the existing chunk to get its chunk_id + existing_chunk = await self.services.ingestion.get_chunk(id) + + if existing_chunk is None: + raise R2RException( + message=f"Chunk {id} not found", status_code=404 + ) + + filters = { + "$and": [ + {"owner_id": {"$eq": str(auth_user.id)}}, + {"chunk_id": {"$eq": str(id)}}, + ] + } + await ( + self.services.management.delete_documents_and_chunks_by_filter( + filters=filters + ) + ) + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.get( + "/chunks", + dependencies=[Depends(self.rate_limit_dependency)], + summary="List Chunks", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + response = client.chunks.list( + metadata_filter={"key": "value"}, + include_vectors=False, + offset=0, + limit=10, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.chunks.list({ + metadataFilter: {key: "value"}, + includeVectors: false, + offset: 0, + limit: 10, + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def list_chunks( + metadata_filter: Optional[str] = Query( + None, description="Filter by metadata" + ), + include_vectors: bool = Query( + False, description="Include vector data in response" + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. 
Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedChunksResponse: + """List chunks with pagination support. + + Returns a paginated list of chunks that the user has access to. + Results can be filtered and sorted based on various parameters. + Vector embeddings are only included if specifically requested. + + Regular users can only list chunks they own or have access to + through collections. Superusers can list all chunks in the system. + """ # Build filters + filters = {} + + # Add user access control filter + if not auth_user.is_superuser: + filters["owner_id"] = {"$eq": str(auth_user.id)} + + # Add metadata filters if provided + if metadata_filter: + metadata_filter = json.loads(metadata_filter) + + # Get chunks using the vector handler's list_chunks method + results = await self.services.ingestion.list_chunks( + filters=filters, + include_vectors=include_vectors, + offset=offset, + limit=limit, + ) + + # Convert to response format + chunks = [ + ChunkResponse( + id=chunk["id"], + document_id=chunk["document_id"], + owner_id=chunk["owner_id"], + collection_ids=chunk["collection_ids"], + text=chunk["text"], + metadata=chunk["metadata"], + vector=chunk.get("vector") if include_vectors else None, + ) + for chunk in results["results"] + ] + + return (chunks, {"total_entries": results["total_entries"]}) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/collections_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/collections_router.py new file mode 100644 index 00000000..462f5ca3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/collections_router.py @@ -0,0 +1,1207 @@ +import logging +import textwrap +from enum import Enum +from typing import Optional +from uuid import UUID + +from fastapi import Body, Depends, Path, Query +from fastapi.background import BackgroundTasks +from fastapi.responses import FileResponse + +from core.base import R2RException +from 
core.base.abstractions import GraphCreationSettings +from core.base.api.models import ( + GenericBooleanResponse, + WrappedBooleanResponse, + WrappedCollectionResponse, + WrappedCollectionsResponse, + WrappedDocumentsResponse, + WrappedGenericMessageResponse, + WrappedUsersResponse, +) +from core.utils import ( + generate_default_user_collection_id, + update_settings_from_dict, +) + +from ...abstractions import R2RProviders, R2RServices +from ...config import R2RConfig +from .base_router import BaseRouterV3 + +logger = logging.getLogger() + + +class CollectionAction(str, Enum): + VIEW = "view" + EDIT = "edit" + DELETE = "delete" + MANAGE_USERS = "manage_users" + ADD_DOCUMENT = "add_document" + REMOVE_DOCUMENT = "remove_document" + + +async def authorize_collection_action( + auth_user, collection_id: UUID, action: CollectionAction, services +) -> bool: + """Authorize a user's action on a given collection based on: + + - If user is superuser (admin): Full access. + - If user is owner of the collection: Full access. + - If user is a member of the collection (in `collection_ids`): VIEW only. + - Otherwise: No access. 
+ """ + + # Superusers have complete access + if auth_user.is_superuser: + return True + + # Fetch collection details: owner_id and members + results = ( + await services.management.collections_overview( + 0, 1, collection_ids=[collection_id] + ) + )["results"] + if len(results) == 0: + raise R2RException("The specified collection does not exist.", 404) + details = results[0] + owner_id = details.owner_id + + # Check if user is owner + if auth_user.id == owner_id: + # Owner can do all actions + return True + + # Check if user is a member (non-owner) + if collection_id in auth_user.collection_ids: + # Members can only view + if action == CollectionAction.VIEW: + return True + else: + raise R2RException( + "Insufficient permissions for this action.", 403 + ) + + # User is neither owner nor member + raise R2RException("You do not have access to this collection.", 403) + + +class CollectionsRouter(BaseRouterV3): + def __init__( + self, providers: R2RProviders, services: R2RServices, config: R2RConfig + ): + logging.info("Initializing CollectionsRouter") + super().__init__(providers, services, config) + + def _setup_routes(self): + @self.router.post( + "/collections", + summary="Create a new collection", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + result = client.collections.create( + name="My New Collection", + description="This is a sample collection" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.collections.create({ + name: "My New Collection", + description: "This is a sample collection" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/collections" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{"name": "My New Collection", "description": "This is a sample collection"}' + """), + }, + ] + }, + ) + @self.base_endpoint + async def create_collection( + name: str = Body(..., description="The name of the collection"), + description: Optional[str] = Body( + None, description="An optional description of the collection" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedCollectionResponse: + """Create a new collection and automatically add the creating user + to it. + + This endpoint allows authenticated users to create a new collection + with a specified name and optional description. The user creating + the collection is automatically added as a member. 
+ """ + user_collections_count = ( + await self.services.management.collections_overview( + user_ids=[auth_user.id], limit=1, offset=0 + ) + )["total_entries"] + user_max_collections = ( + await self.services.management.get_user_max_collections( + auth_user.id + ) + ) + if (user_collections_count + 1) >= user_max_collections: # type: ignore + raise R2RException( + f"User has reached the maximum number of collections allowed ({user_max_collections}).", + 400, + ) + collection = await self.services.management.create_collection( + owner_id=auth_user.id, + name=name, + description=description, + ) + # Add the creating user to the collection + await self.services.management.add_user_to_collection( + auth_user.id, collection.id + ) + return collection # type: ignore + + @self.router.post( + "/collections/export", + summary="Export collections to CSV", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
+ + response = client.collections.export( + output_path="export.csv", + columns=["id", "name", "created_at"], + include_header=True, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + await client.collections.export({ + outputPath: "export.csv", + columns: ["id", "name", "created_at"], + includeHeader: true, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "http://127.0.0.1:7272/v3/collections/export" \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -H "Accept: text/csv" \ + -d '{ "columns": ["id", "name", "created_at"], "include_header": true }' \ + --output export.csv + """), + }, + ] + }, + ) + @self.base_endpoint + async def export_collections( + background_tasks: BackgroundTasks, + columns: Optional[list[str]] = Body( + None, description="Specific columns to export" + ), + filters: Optional[dict] = Body( + None, description="Filters to apply to the export" + ), + include_header: Optional[bool] = Body( + True, description="Whether to include column headers" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> FileResponse: + """Export collections as a CSV file.""" + + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can export data.", + 403, + ) + + ( + csv_file_path, + temp_file, + ) = await self.services.management.export_collections( + columns=columns, + filters=filters, + include_header=include_header + if include_header is not None + else True, + ) + + background_tasks.add_task(temp_file.close) + + return FileResponse( + path=csv_file_path, + media_type="text/csv", + filename="collections_export.csv", + ) + + @self.router.get( + "/collections", + summary="List collections", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + 
"lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.collections.list( + offset=0, + limit=10, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.collections.list(); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/v3/collections?offset=0&limit=10&name=Sample" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def list_collections( + ids: list[str] = Query( + [], + description="A list of collection IDs to retrieve. If not provided, all collections will be returned.", + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedCollectionsResponse: + """Returns a paginated list of collections the authenticated user + has access to. + + Results can be filtered by providing specific collection IDs. + Regular users will only see collections they own or have access to. + Superusers can see all collections. + + The collections are returned in order of last modification, with + most recent first. 
+ """ + requesting_user_id = ( + None if auth_user.is_superuser else [auth_user.id] + ) + + collection_uuids = [UUID(collection_id) for collection_id in ids] + + collections_overview_response = ( + await self.services.management.collections_overview( + user_ids=requesting_user_id, + collection_ids=collection_uuids, + offset=offset, + limit=limit, + ) + ) + + return ( # type: ignore + collections_overview_response["results"], + { + "total_entries": collections_overview_response[ + "total_entries" + ] + }, + ) + + @self.router.get( + "/collections/{id}", + summary="Get collection details", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.collections.retrieve("123e4567-e89b-12d3-a456-426614174000") + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.collections.retrieve({id: "123e4567-e89b-12d3-a456-426614174000"}); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_collection( + id: UUID = Path( + ..., description="The unique identifier of the collection" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedCollectionResponse: + """Get details of a specific collection. + + This endpoint retrieves detailed information about a single + collection identified by its UUID. The user must have access to the + collection to view its details. 
+ """ + await authorize_collection_action( + auth_user, id, CollectionAction.VIEW, self.services + ) + + collections_overview_response = ( + await self.services.management.collections_overview( + user_ids=None, + collection_ids=[id], + offset=0, + limit=1, + ) + ) + overview = collections_overview_response["results"] + + if len(overview) == 0: # type: ignore + raise R2RException( + "The specified collection does not exist.", + 404, + ) + return overview[0] # type: ignore + + @self.router.post( + "/collections/{id}", + summary="Update collection", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.collections.update( + "123e4567-e89b-12d3-a456-426614174000", + name="Updated Collection Name", + description="Updated description" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.collections.update({ + id: "123e4567-e89b-12d3-a456-426614174000", + name: "Updated Collection Name", + description: "Updated description" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{"name": "Updated Collection Name", "description": "Updated description"}' + """), + }, + ] + }, + ) + @self.base_endpoint + async def update_collection( + id: UUID = Path( + ..., + description="The unique identifier of the collection to update", + ), + name: Optional[str] = Body( + None, description="The name of the collection" + ), + description: Optional[str] = Body( + None, description="An optional description of the 
collection" + ), + generate_description: Optional[bool] = Body( + False, + description="Whether to generate a new synthetic description for the collection", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedCollectionResponse: + """Update an existing collection's configuration. + + This endpoint allows updating the name and description of an + existing collection. The user must have appropriate permissions to + modify the collection. + """ + await authorize_collection_action( + auth_user, id, CollectionAction.EDIT, self.services + ) + + if generate_description and description is not None: + raise R2RException( + "Cannot provide both a description and request to synthetically generate a new one.", + 400, + ) + + return await self.services.management.update_collection( # type: ignore + id, + name=name, + description=description, + generate_description=generate_description or False, + ) + + @self.router.delete( + "/collections/{id}", + summary="Delete collection", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + result = client.collections.delete("123e4567-e89b-12d3-a456-426614174000") + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.collections.delete({id: "123e4567-e89b-12d3-a456-426614174000"}); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X DELETE "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def delete_collection( + id: UUID = Path( + ..., + description="The unique identifier of the collection to delete", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Delete an existing collection. + + This endpoint allows deletion of a collection identified by its + UUID. The user must have appropriate permissions to delete the + collection. Deleting a collection removes all associations but does + not delete the documents within it. + """ + if id == generate_default_user_collection_id(auth_user.id): + raise R2RException( + "Cannot delete the default user collection.", + 400, + ) + await authorize_collection_action( + auth_user, id, CollectionAction.DELETE, self.services + ) + + await self.services.management.delete_collection(collection_id=id) + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.post( + "/collections/{id}/documents/{document_id}", + summary="Add document to collection", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + result = client.collections.add_document( + "123e4567-e89b-12d3-a456-426614174000", + "456e789a-b12c-34d5-e678-901234567890" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.collections.addDocument({ + id: "123e4567-e89b-12d3-a456-426614174000" + documentId: "456e789a-b12c-34d5-e678-901234567890" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/documents/456e789a-b12c-34d5-e678-901234567890" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def add_document_to_collection( + id: UUID = Path(...), + document_id: UUID = Path(...), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGenericMessageResponse: + """Add a document to a collection.""" + await authorize_collection_action( + auth_user, id, CollectionAction.ADD_DOCUMENT, self.services + ) + + return ( + await self.services.management.assign_document_to_collection( + document_id, id + ) + ) + + @self.router.get( + "/collections/{id}/documents", + summary="List documents in collection", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + result = client.collections.list_documents( + "123e4567-e89b-12d3-a456-426614174000", + offset=0, + limit=10, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.collections.listDocuments({id: "123e4567-e89b-12d3-a456-426614174000"}); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/documents?offset=0&limit=10" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_collection_documents( + id: UUID = Path( + ..., description="The unique identifier of the collection" + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedDocumentsResponse: + """Get all documents in a collection with pagination and sorting + options. + + This endpoint retrieves a paginated list of documents associated + with a specific collection. It supports sorting options to + customize the order of returned documents. 
+ """ + await authorize_collection_action( + auth_user, id, CollectionAction.VIEW, self.services + ) + + documents_in_collection_response = ( + await self.services.management.documents_in_collection( + id, offset, limit + ) + ) + + return documents_in_collection_response["results"], { # type: ignore + "total_entries": documents_in_collection_response[ + "total_entries" + ] + } + + @self.router.delete( + "/collections/{id}/documents/{document_id}", + summary="Remove document from collection", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.collections.remove_document( + "123e4567-e89b-12d3-a456-426614174000", + "456e789a-b12c-34d5-e678-901234567890" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.collections.removeDocument({ + id: "123e4567-e89b-12d3-a456-426614174000" + documentId: "456e789a-b12c-34d5-e678-901234567890" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X DELETE "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/documents/456e789a-b12c-34d5-e678-901234567890" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def remove_document_from_collection( + id: UUID = Path( + ..., description="The unique identifier of the collection" + ), + document_id: UUID = Path( + ..., + description="The unique identifier of the document to remove", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Remove a document from a collection. + + This endpoint removes the association between a document and a + collection. 
It does not delete the document itself. The user must + have permissions to modify the collection. + """ + await authorize_collection_action( + auth_user, id, CollectionAction.REMOVE_DOCUMENT, self.services + ) + await self.services.management.remove_document_from_collection( + document_id, id + ) + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.get( + "/collections/{id}/users", + summary="List users in collection", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.collections.list_users( + "123e4567-e89b-12d3-a456-426614174000", + offset=0, + limit=10, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.collections.listUsers({ + id: "123e4567-e89b-12d3-a456-426614174000" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/users?offset=0&limit=10" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_collection_users( + id: UUID = Path( + ..., description="The unique identifier of the collection" + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedUsersResponse: + """Get all users in a collection with pagination and sorting + options. 
+ + This endpoint retrieves a paginated list of users who have access + to a specific collection. It supports sorting options to customize + the order of returned users. + """ + await authorize_collection_action( + auth_user, id, CollectionAction.VIEW, self.services + ) + + users_in_collection_response = ( + await self.services.management.get_users_in_collection( + collection_id=id, + offset=offset, + limit=min(max(limit, 1), 1000), + ) + ) + + return users_in_collection_response["results"], { # type: ignore + "total_entries": users_in_collection_response["total_entries"] + } + + @self.router.post( + "/collections/{id}/users/{user_id}", + summary="Add user to collection", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.collections.add_user( + "123e4567-e89b-12d3-a456-426614174000", + "789a012b-c34d-5e6f-g789-012345678901" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.collections.addUser({ + id: "123e4567-e89b-12d3-a456-426614174000" + userId: "789a012b-c34d-5e6f-g789-012345678901" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/users/789a012b-c34d-5e6f-g789-012345678901" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def add_user_to_collection( + id: UUID = Path( + ..., description="The unique identifier of the collection" + ), + user_id: UUID = Path( + ..., description="The unique identifier of the user to add" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + 
"""Add a user to a collection. + + This endpoint grants a user access to a specific collection. The + authenticated user must have admin permissions for the collection + to add new users. + """ + await authorize_collection_action( + auth_user, id, CollectionAction.MANAGE_USERS, self.services + ) + + result = await self.services.management.add_user_to_collection( + user_id, id + ) + return GenericBooleanResponse(success=result) # type: ignore + + @self.router.delete( + "/collections/{id}/users/{user_id}", + summary="Remove user from collection", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.collections.remove_user( + "123e4567-e89b-12d3-a456-426614174000", + "789a012b-c34d-5e6f-g789-012345678901" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.collections.removeUser({ + id: "123e4567-e89b-12d3-a456-426614174000" + userId: "789a012b-c34d-5e6f-g789-012345678901" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X DELETE "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/users/789a012b-c34d-5e6f-g789-012345678901" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def remove_user_from_collection( + id: UUID = Path( + ..., description="The unique identifier of the collection" + ), + user_id: UUID = Path( + ..., description="The unique identifier of the user to remove" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Remove a user from a collection. + + This endpoint revokes a user's access to a specific collection. 
The + authenticated user must have admin permissions for the collection + to remove users. + """ + await authorize_collection_action( + auth_user, id, CollectionAction.MANAGE_USERS, self.services + ) + + result = ( + await self.services.management.remove_user_from_collection( + user_id, id + ) + ) + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.post( + "/collections/{id}/extract", + summary="Extract entities and relationships", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.documents.extract( + id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1" + ) + """), + }, + ], + }, + ) + @self.base_endpoint + async def extract( + id: UUID = Path( + ..., + description="The ID of the document to extract entities and relationships from.", + ), + settings: Optional[GraphCreationSettings] = Body( + default=None, + description="Settings for the entities and relationships extraction process.", + ), + run_with_orchestration: Optional[bool] = Query( + default=True, + description="Whether to run the entities and relationships extraction process with orchestration.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGenericMessageResponse: + """Extracts entities and relationships from a document. + + The entities and relationships extraction process involves: + 1. Parsing documents into semantic chunks + 2. 
Extracting entities and relationships using LLMs + """ + await authorize_collection_action( + auth_user, id, CollectionAction.EDIT, self.services + ) + + settings = settings.dict() if settings else None # type: ignore + if not auth_user.is_superuser: + logger.warning("Implement permission checks here.") + + # Apply runtime settings overrides + server_graph_creation_settings = ( + self.providers.database.config.graph_creation_settings + ) + + if settings: + server_graph_creation_settings = update_settings_from_dict( + server_settings=server_graph_creation_settings, + settings_dict=settings, # type: ignore + ) + if run_with_orchestration: + try: + workflow_input = { + "collection_id": str(id), + "graph_creation_settings": server_graph_creation_settings.model_dump_json(), + "user": auth_user.json(), + } + + return await self.providers.orchestration.run_workflow( # type: ignore + "graph-extraction", {"request": workflow_input}, {} + ) + except Exception as e: # TODO: Need to find specific error (gRPC most likely?) + logger.error( + f"Error running orchestrated extraction: {e} \n\nAttempting to run without orchestration." 
+ ) + + from core.main.orchestration import ( + simple_graph_search_results_factory, + ) + + logger.info("Running extract-triples without orchestration.") + simple_graph_search_results = simple_graph_search_results_factory( + self.services.graph + ) + await simple_graph_search_results["graph-extraction"]( + workflow_input + ) # type: ignore + return { # type: ignore + "message": "Graph created successfully.", + "task_id": None, + } + + @self.router.get( + "/collections/name/{collection_name}", + summary="Get a collection by name", + dependencies=[Depends(self.rate_limit_dependency)], + ) + @self.base_endpoint + async def get_collection_by_name( + collection_name: str = Path( + ..., description="The name of the collection" + ), + owner_id: Optional[UUID] = Query( + None, + description="(Superuser only) Specify the owner_id to retrieve a collection by name", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedCollectionResponse: + """Retrieve a collection by its (owner_id, name) combination. + + The authenticated user can only fetch collections they own, or, if + superuser, from anyone. + """ + if auth_user.is_superuser: + if not owner_id: + owner_id = auth_user.id + else: + owner_id = auth_user.id + + # If not superuser, fetch by (owner_id, name). Otherwise, maybe pass `owner_id=None`. + # Decide on the logic for superusers. + if not owner_id: # is_superuser + # If you want superusers to do /collections/name/<string>?owner_id=... + # just parse it from the query. For now, let's say it's not implemented. + raise R2RException( + "Superuser must specify an owner_id to fetch by name.", 400 + ) + + collection = await self.providers.database.collections_handler.get_collection_by_name( + owner_id, collection_name + ) + if not collection: + raise R2RException("Collection not found.", 404) + + # Now, authorize the 'view' action just in case: + # e.g. 
await authorize_collection_action(auth_user, collection.id, CollectionAction.VIEW, self.services) + + return collection # type: ignore diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/conversations_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/conversations_router.py new file mode 100644 index 00000000..d1b6d645 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/conversations_router.py @@ -0,0 +1,737 @@ +import logging +import textwrap +from typing import Optional +from uuid import UUID + +from fastapi import Body, Depends, Path, Query +from fastapi.background import BackgroundTasks +from fastapi.responses import FileResponse + +from core.base import Message, R2RException +from core.base.api.models import ( + GenericBooleanResponse, + WrappedBooleanResponse, + WrappedConversationMessagesResponse, + WrappedConversationResponse, + WrappedConversationsResponse, + WrappedMessageResponse, +) + +from ...abstractions import R2RProviders, R2RServices +from ...config import R2RConfig +from .base_router import BaseRouterV3 + +logger = logging.getLogger() + + +class ConversationsRouter(BaseRouterV3): + def __init__( + self, providers: R2RProviders, services: R2RServices, config: R2RConfig + ): + logging.info("Initializing ConversationsRouter") + super().__init__(providers, services, config) + + def _setup_routes(self): + @self.router.post( + "/conversations", + summary="Create a new conversation", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + result = client.conversations.create() + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.conversations.create(); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/conversations" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def create_conversation( + name: Optional[str] = Body( + None, description="The name of the conversation", embed=True + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedConversationResponse: + """Create a new conversation. + + This endpoint initializes a new conversation for the authenticated + user. + """ + user_id = auth_user.id + + return await self.services.management.create_conversation( # type: ignore + user_id=user_id, + name=name, + ) + + @self.router.get( + "/conversations", + summary="List conversations", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.conversations.list( + offset=0, + limit=10, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.conversations.list(); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/v3/conversations?offset=0&limit=10" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def list_conversations( + ids: list[str] = Query( + [], + description="A list of conversation IDs to retrieve. 
If not provided, all conversations will be returned.", + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedConversationsResponse: + """List conversations with pagination and sorting options. + + This endpoint returns a paginated list of conversations for the + authenticated user. + """ + requesting_user_id = ( + None if auth_user.is_superuser else [auth_user.id] + ) + + conversation_uuids = [ + UUID(conversation_id) for conversation_id in ids + ] + + conversations_response = ( + await self.services.management.conversations_overview( + offset=offset, + limit=limit, + conversation_ids=conversation_uuids, + user_ids=requesting_user_id, + ) + ) + return conversations_response["results"], { # type: ignore + "total_entries": conversations_response["total_entries"] + } + + @self.router.post( + "/conversations/export", + summary="Export conversations to CSV", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
+ + response = client.conversations.export( + output_path="export.csv", + columns=["id", "created_at"], + include_header=True, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + await client.conversations.export({ + outputPath: "export.csv", + columns: ["id", "created_at"], + includeHeader: true, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "http://127.0.0.1:7272/v3/conversations/export" \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -H "Accept: text/csv" \ + -d '{ "columns": ["id", "created_at"], "include_header": true }' \ + --output export.csv + """), + }, + ] + }, + ) + @self.base_endpoint + async def export_conversations( + background_tasks: BackgroundTasks, + columns: Optional[list[str]] = Body( + None, description="Specific columns to export" + ), + filters: Optional[dict] = Body( + None, description="Filters to apply to the export" + ), + include_header: Optional[bool] = Body( + True, description="Whether to include column headers" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> FileResponse: + """Export conversations as a downloadable CSV file.""" + + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can export data.", + 403, + ) + + ( + csv_file_path, + temp_file, + ) = await self.services.management.export_conversations( + columns=columns, + filters=filters, + include_header=include_header + if include_header is not None + else True, + ) + + background_tasks.add_task(temp_file.close) + + return FileResponse( + path=csv_file_path, + media_type="text/csv", + filename="documents_export.csv", + ) + + @self.router.post( + "/conversations/export_messages", + summary="Export messages to CSV", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + 
"x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + response = client.conversations.export_messages( + output_path="export.csv", + columns=["id", "created_at"], + include_header=True, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + await client.conversations.exportMessages({ + outputPath: "export.csv", + columns: ["id", "created_at"], + includeHeader: true, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "http://127.0.0.1:7272/v3/conversations/export_messages" \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -H "Accept: text/csv" \ + -d '{ "columns": ["id", "created_at"], "include_header": true }' \ + --output export.csv + """), + }, + ] + }, + ) + @self.base_endpoint + async def export_messages( + background_tasks: BackgroundTasks, + columns: Optional[list[str]] = Body( + None, description="Specific columns to export" + ), + filters: Optional[dict] = Body( + None, description="Filters to apply to the export" + ), + include_header: Optional[bool] = Body( + True, description="Whether to include column headers" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> FileResponse: + """Export conversations as a downloadable CSV file.""" + + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can export data.", + 403, + ) + + ( + csv_file_path, + temp_file, + ) = await self.services.management.export_messages( + columns=columns, + filters=filters, + include_header=include_header + if include_header is not None + else True, + ) + + background_tasks.add_task(temp_file.close) + + return FileResponse( + path=csv_file_path, + media_type="text/csv", + 
filename="messages_export.csv",
+ + result = client.conversations.update("123e4567-e89b-12d3-a456-426614174000", "new_name") + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.conversations.update({ + id: "123e4567-e89b-12d3-a456-426614174000", + name: "new_name", + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{"name": "new_name"}' + """), + }, + ] + }, + ) + @self.base_endpoint + async def update_conversation( + id: UUID = Path( + ..., + description="The unique identifier of the conversation to delete", + ), + name: str = Body( + ..., + description="The updated name for the conversation", + embed=True, + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedConversationResponse: + """Update an existing conversation. + + This endpoint updates the name of an existing conversation + identified by its UUID. + """ + return await self.services.management.update_conversation( # type: ignore + conversation_id=id, + name=name, + ) + + @self.router.delete( + "/conversations/{id}", + summary="Delete conversation", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + result = client.conversations.delete("123e4567-e89b-12d3-a456-426614174000") + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.conversations.delete({ + id: "123e4567-e89b-12d3-a456-426614174000", + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X DELETE "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def delete_conversation( + id: UUID = Path( + ..., + description="The unique identifier of the conversation to delete", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Delete an existing conversation. + + This endpoint deletes a conversation identified by its UUID. + """ + requesting_user_id = ( + None if auth_user.is_superuser else [auth_user.id] + ) + + await self.services.management.delete_conversation( + conversation_id=id, + user_ids=requesting_user_id, + ) + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.post( + "/conversations/{id}/messages", + summary="Add message to conversation", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + result = client.conversations.add_message( + "123e4567-e89b-12d3-a456-426614174000", + content="Hello, world!", + role="user", + parent_id="parent_message_id", + metadata={"key": "value"} + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.conversations.addMessage({ + id: "123e4567-e89b-12d3-a456-426614174000", + content: "Hello, world!", + role: "user", + parentId: "parent_message_id", + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000/messages" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -H "Content-Type: application/json" \\ + -d '{"content": "Hello, world!", "parent_id": "parent_message_id", "metadata": {"key": "value"}}' + """), + }, + ] + }, + ) + @self.base_endpoint + async def add_message( + id: UUID = Path( + ..., description="The unique identifier of the conversation" + ), + content: str = Body( + ..., description="The content of the message to add" + ), + role: str = Body( + ..., description="The role of the message to add" + ), + parent_id: Optional[UUID] = Body( + None, description="The ID of the parent message, if any" + ), + metadata: Optional[dict[str, str]] = Body( + None, description="Additional metadata for the message" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedMessageResponse: + """Add a new message to a conversation. + + This endpoint adds a new message to an existing conversation. 
                """
                # Validate before touching the conversation store: empty
                # content and unknown roles are rejected with 400.
                if content == "":
                    raise R2RException("Content cannot be empty", status_code=400)
                if role not in ["user", "assistant", "system"]:
                    raise R2RException("Invalid role", status_code=400)
                message = Message(role=role, content=content)
                return await self.services.management.add_message(  # type: ignore
                    conversation_id=id,
                    content=message,
                    parent_id=parent_id,
                    metadata=metadata,
                )

        @self.router.post(
            "/conversations/{id}/messages/{message_id}",
            summary="Update message in conversation",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            result = client.conversations.update_message(
                                "123e4567-e89b-12d3-a456-426614174000",
                                "message_id_to_update",
                                content="Updated content"
                            )
                            """),
                    },
                    {
                        # NOTE(review): sample uses `await` inside a plain
                        # `function main()` — should be `async function`;
                        # same pattern appears in sibling endpoints. Confirm
                        # before changing the published docs.
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.conversations.updateMessage({
                                    id: "123e4567-e89b-12d3-a456-426614174000",
                                    messageId: "message_id_to_update",
                                    content: "Updated content",
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000/messages/message_id_to_update" \\
                                -H "Authorization: Bearer YOUR_API_KEY" \\
                                -H "Content-Type: application/json" \\
                                -d '{"content": "Updated content"}'
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def update_message(
            id: UUID = Path(
                ..., description="The unique identifier of the conversation"
            ),
            message_id: UUID = Path(
                ..., description="The ID of the message to update"
            ),
            content: Optional[str] = Body(
                None, description="The new content for the message"
            ),
            metadata: Optional[dict[str, str]] = Body(
                None, description="Additional metadata for the message"
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedMessageResponse:
            """Update an existing message in a conversation.

            This endpoint updates the content of an existing message in a
            conversation.
            """
            # NOTE(review): unlike delete_conversation, no per-user scoping
            # is applied here — presumably edit_message enforces ownership;
            # verify in the management service.
            return await self.services.management.edit_message(  # type: ignore
                message_id=message_id,
                new_content=content,
                additional_metadata=metadata,
            )
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/documents_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/documents_router.py
new file mode 100644
index 00000000..fe152b8b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/documents_router.py
@@ -0,0 +1,2342 @@
import base64
import logging
import mimetypes
import textwrap
from datetime import datetime
from io import BytesIO
from typing import Any, Optional
from urllib.parse import quote
from uuid import UUID

from fastapi import Body, Depends, File, Form, Path, Query, UploadFile
from fastapi.background import BackgroundTasks
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import Json

from core.base import (
    IngestionConfig,
    IngestionMode,
    R2RException,
    SearchMode,
    SearchSettings,
    UnprocessedChunk,
    Workflow,
    generate_document_id,
    generate_id,
    select_search_filters,
)
from core.base.abstractions import GraphCreationSettings, StoreType
from core.base.api.models import (
    GenericBooleanResponse,
    WrappedBooleanResponse,
    WrappedChunksResponse,
    WrappedCollectionsResponse,
    WrappedDocumentResponse,
    WrappedDocumentSearchResponse,
    WrappedDocumentsResponse,
    WrappedEntitiesResponse,
    WrappedGenericMessageResponse,
    WrappedIngestionResponse,
    WrappedRelationshipsResponse,
)
from core.utils import update_settings_from_dict

from ...abstractions import R2RProviders, R2RServices
from ...config import R2RConfig
from .base_router import BaseRouterV3

# NOTE(review): getLogger() with no name returns the *root* logger; the
# usual convention is logging.getLogger(__name__). Changing it alters log
# routing, so flagged rather than fixed here.
logger = logging.getLogger()
# Hard cap on the number of pre-processed chunks accepted per request.
MAX_CHUNKS_PER_REQUEST = 1024 * 100


def merge_search_settings(
    base: SearchSettings, overrides: SearchSettings
) -> SearchSettings:
    """Overlay explicitly-set fields of *overrides* onto *base*.

    Only fields the caller actually set on *overrides* (``exclude_unset``)
    take precedence; all other fields keep the value from *base*. A new
    SearchSettings instance is returned; neither argument is mutated.
    """
    # Convert both to dict
    base_dict = base.model_dump()
    overrides_dict = overrides.model_dump(exclude_unset=True)

    # Update base_dict with values from overrides_dict
    # This ensures that any field set in overrides takes precedence
    for k, v in overrides_dict.items():
        base_dict[k] = v

    # Construct a new SearchSettings from the merged dict
    return SearchSettings(**base_dict)


def merge_ingestion_config(
    base: IngestionConfig, overrides: IngestionConfig
) -> IngestionConfig:
    """Merge *overrides* onto *base* with the same explicitly-set-fields-win
    semantics as :func:`merge_search_settings`."""
    base_dict = base.model_dump()
    overrides_dict = overrides.model_dump(exclude_unset=True)

    for k, v in overrides_dict.items():
        base_dict[k] = v

    return IngestionConfig(**base_dict)


class DocumentsRouter(BaseRouterV3):
    """Router for the /v3/documents endpoints: ingestion, metadata
    management, listing/retrieval, chunk access, and export."""

    def __init__(
        self,
        providers: R2RProviders,
        services: R2RServices,
        config: R2RConfig,
    ):
        # NOTE(review): logs via the root `logging` module rather than the
        # module-level `logger` — presumably unintentional inconsistency.
        logging.info("Initializing DocumentsRouter")
        super().__init__(providers, services, config)
        # Register ingestion workflows with the orchestration provider so
        # the routes below can dispatch them by name.
        self._register_workflows()

    def _prepare_search_settings(
        self,
        auth_user: Any,
        search_mode: SearchMode,
        search_settings: Optional[SearchSettings],
    ) -> SearchSettings:
        """Prepare the effective search settings based on the provided
        search_mode, optional user-overrides in search_settings, and applied
        filters."""

        if search_mode != SearchMode.custom:
            # Start from mode defaults
            effective_settings = SearchSettings.get_default(search_mode.value)
            if search_settings:
                # Merge user-provided overrides
                effective_settings = merge_search_settings(
                    effective_settings, search_settings
                )
        else:
            # Custom mode: use provided settings or defaults
            effective_settings = search_settings or SearchSettings()

        # Apply user-specific filters (always applied, regardless of mode,
        # so users cannot search outside their visibility).
        effective_settings.filters = select_search_filters(
            auth_user, effective_settings
        )

        return effective_settings

    # TODO - Remove
    # this legacy method
    def _register_workflows(self):
        """Register the ingestion workflows with the orchestration provider.

        Maps each workflow name to the user-facing status message returned
        when it is dispatched. The message depends on the configured
        orchestrator: the "simple" provider runs work inline, so its
        messages report completion; any other provider queues a task, so
        its messages report queueing.
        """
        self.providers.orchestration.register_workflows(
            Workflow.INGESTION,
            self.services.ingestion,
            {
                "ingest-files": (
                    "Ingest files task queued successfully."
                    if self.providers.orchestration.config.provider != "simple"
                    else "Document created and ingested successfully."
                ),
                "ingest-chunks": (
                    "Ingest chunks task queued successfully."
                    if self.providers.orchestration.config.provider != "simple"
                    else "Document created and ingested successfully."
                ),
                "update-chunk": (
                    "Update chunk task queued successfully."
                    if self.providers.orchestration.config.provider != "simple"
                    else "Chunk update completed successfully."
                ),
                "update-document-metadata": (
                    "Update document metadata task queued successfully."
                    if self.providers.orchestration.config.provider != "simple"
                    else "Document metadata update completed successfully."
                ),
                "create-vector-index": (
                    "Vector index creation task queued successfully."
                    if self.providers.orchestration.config.provider != "simple"
                    else "Vector index creation task completed successfully."
                ),
                "delete-vector-index": (
                    "Vector index deletion task queued successfully."
                    if self.providers.orchestration.config.provider != "simple"
                    else "Vector index deletion task completed successfully."
                ),
                "select-vector-index": (
                    "Vector index selection task queued successfully."
                    if self.providers.orchestration.config.provider != "simple"
                    else "Vector index selection task completed successfully."
                ),
            },
        )

    def _prepare_ingestion_config(
        self,
        ingestion_mode: IngestionMode,
        ingestion_config: Optional[IngestionConfig],
    ) -> IngestionConfig:
        """Resolve the effective IngestionConfig for a request.

        Mirrors _prepare_search_settings: non-custom modes start from the
        mode's defaults and overlay any explicitly-set user overrides;
        custom mode uses the provided config (or bare defaults) directly.
        The result is validated before use.
        """
        # If not custom, start from defaults
        if ingestion_mode != IngestionMode.custom:
            effective_config = IngestionConfig.get_default(
                ingestion_mode.value, app=self.providers.auth.config.app
            )
            if ingestion_config:
                effective_config = merge_ingestion_config(
                    effective_config, ingestion_config
                )
        else:
            # custom mode
            effective_config = ingestion_config or IngestionConfig(
                app=self.providers.auth.config.app
            )

        # Raises if the merged configuration is internally inconsistent.
        effective_config.validate_config()
        return effective_config

    def _setup_routes(self):
        @self.router.post(
            "/documents",
            dependencies=[Depends(self.rate_limit_dependency)],
            status_code=202,
            summary="Create a new document",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.documents.create(
                                file_path="pg_essay_1.html",
                                metadata={"metadata_1":"some random metadata"},
                                id=None
                            )
                            """),
                    },
                    {
                        # NOTE(review): `await` inside non-async `function
                        # main()` — sample as shipped would not run; confirm
                        # before correcting published docs.
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.documents.create({
                                    file: { path: "examples/data/marmeladov.txt", name: "marmeladov.txt" },
                                    metadata: { title: "marmeladov.txt" },
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "https://api.example.com/v3/documents" \\
                                -H "Content-Type: multipart/form-data" \\
                                -H "Authorization: Bearer YOUR_API_KEY" \\
                                -F "file=@pg_essay_1.html;type=text/html" \\
                                -F 'metadata={}' \\
                                -F 'id=null'
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def create_document(
            file: Optional[UploadFile] = File(
                None,
                description="The file to ingest.
 Exactly one of file, raw_text, or chunks must be provided.",
            ),
            raw_text: Optional[str] = Form(
                None,
                description="Raw text content to ingest. Exactly one of file, raw_text, or chunks must be provided.",
            ),
            chunks: Optional[Json[list[str]]] = Form(
                None,
                description="Pre-processed text chunks to ingest. Exactly one of file, raw_text, or chunks must be provided.",
            ),
            id: Optional[UUID] = Form(
                None,
                description="The ID of the document. If not provided, a new ID will be generated.",
            ),
            collection_ids: Optional[Json[list[UUID]]] = Form(
                None,
                description="Collection IDs to associate with the document. If none are provided, the document will be assigned to the user's default collection.",
            ),
            metadata: Optional[Json[dict]] = Form(
                None,
                description="Metadata to associate with the document, such as title, description, or custom fields.",
            ),
            ingestion_mode: IngestionMode = Form(
                default=IngestionMode.custom,
                description=(
                    "Ingestion modes:\n"
                    "- `hi-res`: Thorough ingestion with full summaries and enrichment.\n"
                    "- `fast`: Quick ingestion with minimal enrichment and no summaries.\n"
                    "- `custom`: Full control via `ingestion_config`.\n\n"
                    "If `filters` or `limit` (in `ingestion_config`) are provided alongside `hi-res` or `fast`, "
                    "they will override the default settings for that mode."
                ),
            ),
            ingestion_config: Optional[Json[IngestionConfig]] = Form(
                None,
                description="An optional dictionary to override the default chunking configuration for the ingestion process. If not provided, the system will use the default server-side chunking configuration.",
            ),
            run_with_orchestration: Optional[bool] = Form(
                True,
                description="Whether or not ingestion runs with orchestration, default is `True`. When set to `False`, the ingestion process will run synchronous and directly return the result.",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedIngestionResponse:
            """
            Creates a new Document object from an input file, text content, or chunks. The chosen `ingestion_mode` determines
            how the ingestion process is configured:

            **Ingestion Modes:**
            - `hi-res`: Comprehensive parsing and enrichment, including summaries and possibly more thorough parsing.
            - `fast`: Speed-focused ingestion that skips certain enrichment steps like summaries.
            - `custom`: Provide a full `ingestion_config` to customize the entire ingestion process.

            Either a file or text content must be provided, but not both. Documents are shared through `Collections` which allow for tightly specified cross-user interactions.

            The ingestion process runs asynchronously and its progress can be tracked using the returned
            task_id.
            """
            # Quota enforcement (documents, chunks, collections) applies to
            # non-superusers only; each check costs one overview query.
            if not auth_user.is_superuser:
                user_document_count = (
                    await self.services.management.documents_overview(
                        user_ids=[auth_user.id],
                        offset=0,
                        limit=1,
                    )
                )["total_entries"]
                user_max_documents = (
                    await self.services.management.get_user_max_documents(
                        auth_user.id
                    )
                )

                if user_document_count >= user_max_documents:
                    raise R2RException(
                        status_code=403,
                        message=f"User has reached the maximum number of documents allowed ({user_max_documents}).",
                    )

                # Get chunks using the vector handler's list_chunks method
                user_chunk_count = (
                    await self.services.ingestion.list_chunks(
                        filters={"owner_id": {"$eq": str(auth_user.id)}},
                        offset=0,
                        limit=1,
                    )
                )["total_entries"]
                user_max_chunks = (
                    await self.services.management.get_user_max_chunks(
                        auth_user.id
                    )
                )
                if user_chunk_count >= user_max_chunks:
                    raise R2RException(
                        status_code=403,
                        message=f"User has reached the maximum number of chunks allowed ({user_max_chunks}).",
                    )

                user_collections_count = (
                    await self.services.management.collections_overview(
                        user_ids=[auth_user.id],
                        offset=0,
                        limit=1,
                    )
                )["total_entries"]
                user_max_collections = (
                    await self.services.management.get_user_max_collections(
                        auth_user.id
                    )
                )
                if user_collections_count >= user_max_collections:  # type: ignore
                    raise R2RException(
                        status_code=403,
                        message=f"User has reached the maximum number of collections allowed ({user_max_collections}).",
                    )

            effective_ingestion_config = self._prepare_ingestion_config(
                ingestion_mode=ingestion_mode,
                ingestion_config=ingestion_config,
            )
            # Exactly one input source is required.
            if not file and not raw_text and not chunks:
                raise R2RException(
                    status_code=422,
                    message="Either a `file`, `raw_text`, or `chunks` must be provided.",
                )
            if (
                (file and raw_text)
                or (file and chunks)
                or (raw_text and chunks)
            ):
                raise R2RException(
                    status_code=422,
                    message="Only one of `file`, `raw_text`, or `chunks` may be provided.",
                )
            # Check if the user is a superuser
            metadata = metadata or {}

            if chunks:
                # NOTE(review): this branch is only entered when `chunks` is
                # truthy, so the empty-list check below is unreachable (an
                # empty list would have failed the "either/or" check above).
                if len(chunks) == 0:
                    raise R2RException("Empty list of chunks provided", 400)

                if len(chunks) > MAX_CHUNKS_PER_REQUEST:
                    raise R2RException(
                        f"Maximum of {MAX_CHUNKS_PER_REQUEST} chunks per request",
                        400,
                    )

                # Deterministic ID derived from content + owner unless the
                # caller supplied one.
                document_id = id or generate_document_id(
                    "".join(chunks), auth_user.id
                )

                # FIXME: Metadata doesn't seem to be getting passed through
                # NOTE(review): every chunk shares the same `metadata` dict
                # object — mutation by a downstream consumer would affect
                # all chunks; verify this is intended.
                raw_chunks_for_doc = [
                    UnprocessedChunk(
                        text=chunk,
                        metadata=metadata,
                        id=generate_id(),
                    )
                    for chunk in chunks
                ]

                # Prepare workflow input
                workflow_input = {
                    "document_id": str(document_id),
                    "chunks": [
                        chunk.model_dump(mode="json")
                        for chunk in raw_chunks_for_doc
                    ],
                    "metadata": metadata,  # Base metadata for the document
                    "user": auth_user.model_dump_json(),
                    "ingestion_config": effective_ingestion_config.model_dump(
                        mode="json"
                    ),
                }

                if run_with_orchestration:
                    try:
                        # Run ingestion with orchestration
                        raw_message = (
                            await self.providers.orchestration.run_workflow(
                                "ingest-chunks",
                                {"request": workflow_input},
                                options={
                                    "additional_metadata": {
                                        "document_id": str(document_id),
                                    }
                                },
                            )
                        )
                        raw_message["document_id"] = str(document_id)
                        return raw_message  # type: ignore
                    except Exception as e:  # TODO: Need to find specific errors that we should be excepting (gRPC most likely?)
                        # Orchestrator failure falls through to the
                        # synchronous path below rather than failing the
                        # request.
                        logger.error(
                            f"Error running orchestrated ingestion: {e} \n\nAttempting to run without orchestration."
                        )

                logger.info("Running chunk ingestion without orchestration.")
                from core.main.orchestration import simple_ingestion_factory

                simple_ingestor = simple_ingestion_factory(
                    self.services.ingestion
                )
                await simple_ingestor["ingest-chunks"](workflow_input)

                return {  # type: ignore
                    "message": "Document created and ingested successfully.",
                    "document_id": str(document_id),
                    "task_id": None,
                }

            else:
                if file:
                    file_data = await self._process_file(file)

                    if not file.filename:
                        raise R2RException(
                            status_code=422,
                            message="Uploaded file must have a filename.",
                        )

                    file_ext = file.filename.split(".")[
                        -1
                    ]  # e.g. "pdf", "txt"
                    # Per-user, per-extension upload-size limit.
                    max_allowed_size = await self.services.management.get_max_upload_size_by_type(
                        user_id=auth_user.id, file_type_or_ext=file_ext
                    )

                    content_length = file_data["content_length"]

                    if content_length > max_allowed_size:
                        raise R2RException(
                            status_code=413,  # HTTP 413: Payload Too Large
                            message=(
                                f"File size exceeds maximum of {max_allowed_size} bytes "
                                f"for extension '{file_ext}'."
                            ),
                        )

                    # _process_file returned base64 content; decode it into
                    # a buffer and drop the (large) encoded copy.
                    file_content = BytesIO(
                        base64.b64decode(file_data["content"])
                    )

                    file_data.pop("content", None)
                    document_id = id or generate_document_id(
                        file_data["filename"], auth_user.id
                    )
                elif raw_text:
                    content_length = len(raw_text)
                    file_content = BytesIO(raw_text.encode("utf-8"))
                    document_id = id or generate_document_id(
                        raw_text, auth_user.id
                    )
                    # Synthesize a .txt filename from metadata["title"] when
                    # one was supplied.
                    title = metadata.get("title", None)
                    title = title + ".txt" if title else None
                    file_data = {
                        "filename": title or "N/A",
                        "content_type": "text/plain",
                    }
                else:
                    # NOTE(review): unreachable — the either/or validation
                    # above already rejected the no-input case.
                    raise R2RException(
                        status_code=422,
                        message="Either a file or content must be provided.",
                    )

                workflow_input = {
                    "file_data": file_data,
                    "document_id": str(document_id),
                    "collection_ids": (
                        [str(cid) for cid in collection_ids]
                        if collection_ids
                        else None
                    ),
                    "metadata": metadata,
                    "ingestion_config": effective_ingestion_config.model_dump(
                        mode="json"
                    ),
                    "user": auth_user.model_dump_json(),
                    "size_in_bytes": content_length,
                    "version": "v0",
                }

                # Persist the raw file before kicking off ingestion so the
                # workflow can retrieve it.
                file_name = file_data["filename"]
                await self.providers.database.files_handler.store_file(
                    document_id,
                    file_name,
                    file_content,
                    file_data["content_type"],
                )

                await self.services.ingestion.ingest_file_ingress(
                    file_data=workflow_input["file_data"],
                    user=auth_user,
                    document_id=workflow_input["document_id"],
                    size_in_bytes=workflow_input["size_in_bytes"],
                    metadata=workflow_input["metadata"],
                    version=workflow_input["version"],
                )

                if run_with_orchestration:
                    try:
                        # TODO - Modify create_chunks so that we can add chunks to existing document

                        workflow_result: dict[
                            str, str | None
                        ] = await self.providers.orchestration.run_workflow(  # type: ignore
                            "ingest-files",
                            {"request": workflow_input},
                            options={
                                "additional_metadata": {
                                    "document_id": str(document_id),
                                }
                            },
                        )
                        workflow_result["document_id"] = str(document_id)
                        return workflow_result  # type: ignore
                    except Exception as e:  # TODO: Need to find specific error (gRPC most likely?)
                        # Same fall-through-to-synchronous pattern as the
                        # chunks branch above.
                        logger.error(
                            f"Error running orchestrated ingestion: {e} \n\nAttempting to run without orchestration."
                        )
                logger.info(
                    f"Running ingestion without orchestration for file {file_name} and document_id {document_id}."
                )
                # TODO - Clean up implementation logic here to be more explicitly `synchronous`
                from core.main.orchestration import simple_ingestion_factory

                simple_ingestor = simple_ingestion_factory(self.services.ingestion)
                await simple_ingestor["ingest-files"](workflow_input)
                return {  # type: ignore
                    "message": "Document created and ingested successfully.",
                    "document_id": str(document_id),
                    "task_id": None,
                }

        @self.router.patch(
            "/documents/{id}/metadata",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Append metadata to a document",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.documents.append_metadata(
                                id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
                                metadata=[{"key": "new_key", "value": "new_value"}]
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.documents.appendMetadata({
                                    id: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
                                    metadata: [{ key: "new_key", value: "new_value" }],
                                });
                            }

                            main();
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def patch_metadata(
            id: UUID = Path(
                ...,
                description="The ID of the document to append metadata to.",
            ),
            metadata: list[dict] = Body(
                ...,
                description="Metadata to append to the document.",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedDocumentResponse:
            """Appends metadata to a document.
This endpoint allows adding new metadata fields or updating existing ones.""" + request_user_ids = ( + None if auth_user.is_superuser else [auth_user.id] + ) + + documents_overview_response = ( + await self.services.management.documents_overview( + user_ids=request_user_ids, + document_ids=[id], + offset=0, + limit=1, + ) + ) + results = documents_overview_response["results"] + if len(results) == 0: + raise R2RException("Document not found.", 404) + + return await self.services.management.update_document_metadata( + document_id=id, + metadata=metadata, + overwrite=False, + ) + + @self.router.put( + "/documents/{id}/metadata", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Replace metadata of a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.documents.replace_metadata( + id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + metadata=[{"key": "new_key", "value": "new_value"}] + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.documents.replaceMetadata({ + id: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + metadata: [{ key: "new_key", value: "new_value" }], + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def put_metadata( + id: UUID = Path( + ..., + description="The ID of the document to append metadata to.", + ), + metadata: list[dict] = Body( + ..., + description="Metadata to append to the document.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedDocumentResponse: + """Replaces metadata in a document. 
This endpoint allows overwriting existing metadata fields.""" + request_user_ids = ( + None if auth_user.is_superuser else [auth_user.id] + ) + + documents_overview_response = ( + await self.services.management.documents_overview( + user_ids=request_user_ids, + document_ids=[id], + offset=0, + limit=1, + ) + ) + results = documents_overview_response["results"] + if len(results) == 0: + raise R2RException("Document not found.", 404) + + return await self.services.management.update_document_metadata( + document_id=id, + metadata=metadata, + overwrite=True, + ) + + @self.router.post( + "/documents/export", + summary="Export documents to CSV", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + response = client.documents.export( + output_path="export.csv", + columns=["id", "title", "created_at"], + include_header=True, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + await client.documents.export({ + outputPath: "export.csv", + columns: ["id", "title", "created_at"], + includeHeader: true, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "http://127.0.0.1:7272/v3/documents/export" \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -H "Accept: text/csv" \ + -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \ + --output export.csv + """), + }, + ] + }, + ) + @self.base_endpoint + async def export_documents( + background_tasks: BackgroundTasks, + columns: Optional[list[str]] = Body( + None, description="Specific columns to export" + ), + filters: Optional[dict] = Body( + None, 
description="Filters to apply to the export" + ), + include_header: Optional[bool] = Body( + True, description="Whether to include column headers" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> FileResponse: + """Export documents as a downloadable CSV file.""" + + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can export data.", + 403, + ) + + ( + csv_file_path, + temp_file, + ) = await self.services.management.export_documents( + columns=columns, + filters=filters, + include_header=include_header + if include_header is not None + else True, + ) + + background_tasks.add_task(temp_file.close) + + return FileResponse( + path=csv_file_path, + media_type="text/csv", + filename="documents_export.csv", + ) + + @self.router.get( + "/documents/download_zip", + dependencies=[Depends(self.rate_limit_dependency)], + response_class=StreamingResponse, + summary="Export multiple documents as zip", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + client.documents.download_zip( + document_ids=["uuid1", "uuid2"], + start_date="2024-01-01", + end_date="2024-12-31" + ) + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/v3/documents/download_zip?document_ids=uuid1,uuid2&start_date=2024-01-01&end_date=2024-12-31" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def export_files( + document_ids: Optional[list[UUID]] = Query( + None, + description="List of document IDs to include in the export. 
If not provided, all accessible documents will be included.", + ), + start_date: Optional[datetime] = Query( + None, + description="Filter documents created on or after this date.", + ), + end_date: Optional[datetime] = Query( + None, + description="Filter documents created before this date.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> StreamingResponse: + """Export multiple documents as a zip file. Documents can be + filtered by IDs and/or date range. + + The endpoint allows downloading: + - Specific documents by providing their IDs + - Documents within a date range + - All accessible documents if no filters are provided + + Files are streamed as a zip archive to handle potentially large downloads efficiently. + """ + if not auth_user.is_superuser: + # For non-superusers, verify access to requested documents + if document_ids: + documents_overview = ( + await self.services.management.documents_overview( + user_ids=[auth_user.id], + document_ids=document_ids, + offset=0, + limit=len(document_ids), + ) + ) + if len(documents_overview["results"]) != len(document_ids): + raise R2RException( + status_code=403, + message="You don't have access to one or more requested documents.", + ) + if not document_ids: + raise R2RException( + status_code=403, + message="Non-superusers must provide document IDs to export.", + ) + + ( + zip_name, + zip_content, + zip_size, + ) = await self.services.management.export_files( + document_ids=document_ids, + start_date=start_date, + end_date=end_date, + ) + encoded_filename = quote(zip_name) + + async def stream_file(): + yield zip_content.getvalue() + + return StreamingResponse( + stream_file(), + media_type="application/zip", + headers={ + "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}", + "Content-Length": str(zip_size), + }, + ) + + @self.router.get( + "/documents", + dependencies=[Depends(self.rate_limit_dependency)], + summary="List documents", + openapi_extra={ + "x-codeSamples": 
[
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.documents.list(
                                limit=10,
                                offset=0
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.documents.list({
                                    limit: 10,
                                    offset: 0,
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X GET "https://api.example.com/v3/documents" \\
                                -H "Authorization: Bearer YOUR_API_KEY"
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def get_documents(
            ids: list[str] = Query(
                [],
                description="A list of document IDs to retrieve. If not provided, all documents will be returned.",
            ),
            offset: int = Query(
                0,
                ge=0,
                description="Specifies the number of objects to skip. Defaults to 0.",
            ),
            # NOTE(review): le=1000 but the description says "between 1 and
            # 100" — one of the two is wrong; confirm the intended cap.
            limit: int = Query(
                100,
                ge=1,
                le=1000,
                description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
            ),
            include_summary_embeddings: bool = Query(
                False,
                description="Specifies whether or not to include embeddings of each document summary.",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedDocumentsResponse:
            """Returns a paginated list of documents the authenticated user has
            access to.

            Results can be filtered by providing specific document IDs. Regular
            users will only see documents they own or have access to through
            collections. Superusers can see all documents.

            The documents are returned in order of last modification, with most
            recent first.
            """
            # Superusers see everything; others are scoped to their own
            # documents plus their collections' documents.
            requesting_user_id = (
                None if auth_user.is_superuser else [auth_user.id]
            )
            filter_collection_ids = (
                None if auth_user.is_superuser else auth_user.collection_ids
            )

            # NOTE(review): a malformed id raises ValueError from UUID()
            # here and surfaces as a 500 — a 422 would be friendlier.
            document_uuids = [UUID(document_id) for document_id in ids]
            documents_overview_response = (
                await self.services.management.documents_overview(
                    user_ids=requesting_user_id,
                    collection_ids=filter_collection_ids,
                    document_ids=document_uuids,
                    offset=offset,
                    limit=limit,
                )
            )
            # Strip summary embeddings unless explicitly requested (they are
            # large).
            if not include_summary_embeddings:
                for document in documents_overview_response["results"]:
                    document.summary_embedding = None

            return (  # type: ignore
                documents_overview_response["results"],
                {
                    "total_entries": documents_overview_response[
                        "total_entries"
                    ]
                },
            )

        @self.router.get(
            "/documents/{id}",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Retrieve a document",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.documents.retrieve(
                                id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.documents.retrieve({
                                    id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" \\
                                -H "Authorization: Bearer YOUR_API_KEY"
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def get_document(
            id: UUID = Path(
                ...,
                description="The ID of the document to retrieve.",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedDocumentResponse:
            """Retrieves detailed information about a specific document by its
            ID.
+ + This endpoint returns the document's metadata, status, and system information. It does not + return the document's content - use the `/documents/{id}/download` endpoint for that. + + Users can only retrieve documents they own or have access to through collections. + Superusers can retrieve any document. + """ + request_user_ids = ( + None if auth_user.is_superuser else [auth_user.id] + ) + filter_collection_ids = ( + None if auth_user.is_superuser else auth_user.collection_ids + ) + + documents_overview_response = await self.services.management.documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. + user_ids=request_user_ids, + collection_ids=filter_collection_ids, + document_ids=[id], + offset=0, + limit=100, + ) + results = documents_overview_response["results"] + if len(results) == 0: + raise R2RException("Document not found.", 404) + + return results[0] + + @self.router.get( + "/documents/{id}/chunks", + dependencies=[Depends(self.rate_limit_dependency)], + summary="List document chunks", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.documents.list_chunks( + id="32b6a70f-a995-5c51-85d2-834f06283a1e" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.documents.listChunks({ + id: "32b6a70f-a995-5c51-85d2-834f06283a1e", + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/chunks" \\ + -H "Authorization: Bearer YOUR_API_KEY"\ + """), + }, + ] + }, + ) + @self.base_endpoint + async def list_chunks( + id: UUID = Path( + ..., + description="The ID of the document to retrieve chunks for.", + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + include_vectors: Optional[bool] = Query( + False, + description="Whether to include vector embeddings in the response.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedChunksResponse: + """Retrieves the text chunks that were generated from a document + during ingestion. Chunks represent semantic sections of the + document and are used for retrieval and analysis. + + Users can only access chunks from documents they own or have access + to through collections. Vector embeddings are only included if + specifically requested. + + Results are returned in chunk sequence order, representing their + position in the original document. 
+ """ + list_document_chunks = ( + await self.services.management.list_document_chunks( + document_id=id, + offset=offset, + limit=limit, + include_vectors=include_vectors or False, + ) + ) + + if not list_document_chunks["results"]: + raise R2RException( + "No chunks found for the given document ID.", 404 + ) + + is_owner = str( + list_document_chunks["results"][0].get("owner_id") + ) == str(auth_user.id) + document_collections = ( + await self.services.management.collections_overview( + offset=0, + limit=-1, + document_ids=[id], + ) + ) + + user_has_access = ( + is_owner + or set(auth_user.collection_ids).intersection( + {ele.id for ele in document_collections["results"]} # type: ignore + ) + != set() + ) + + if not user_has_access and not auth_user.is_superuser: + raise R2RException( + "Not authorized to access this document's chunks.", 403 + ) + + return ( # type: ignore + list_document_chunks["results"], + {"total_entries": list_document_chunks["total_entries"]}, + ) + + @self.router.get( + "/documents/{id}/download", + dependencies=[Depends(self.rate_limit_dependency)], + response_class=StreamingResponse, + summary="Download document content", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.documents.download( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.documents.download({ + id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/download" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_document_file( + id: str = Path(..., description="Document ID"), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> StreamingResponse: + """Downloads the original file content of a document. + + For uploaded files, returns the original file with its proper MIME + type. For text-only documents, returns the content as plain text. + + Users can only download documents they own or have access to + through collections. + """ + try: + document_uuid = UUID(id) + except ValueError: + raise R2RException( + status_code=422, message="Invalid document ID format." 
+ ) from None + + # Retrieve the document's information + documents_overview_response = ( + await self.services.management.documents_overview( + user_ids=None, + collection_ids=None, + document_ids=[document_uuid], + offset=0, + limit=1, + ) + ) + + if not documents_overview_response["results"]: + raise R2RException("Document not found.", 404) + + document = documents_overview_response["results"][0] + + is_owner = str(document.owner_id) == str(auth_user.id) + + if not auth_user.is_superuser and not is_owner: + document_collections = ( + await self.services.management.collections_overview( + offset=0, + limit=-1, + document_ids=[document_uuid], + ) + ) + + document_collection_ids = { + str(ele.id) + for ele in document_collections["results"] # type: ignore + } + + user_collection_ids = { + str(cid) for cid in auth_user.collection_ids + } + + has_collection_access = user_collection_ids.intersection( + document_collection_ids + ) + + if not has_collection_access: + raise R2RException( + "Not authorized to access this document.", 403 + ) + + file_tuple = await self.services.management.download_file( + document_uuid + ) + if not file_tuple: + raise R2RException(status_code=404, message="File not found.") + + file_name, file_content, file_size = file_tuple + encoded_filename = quote(file_name) + + mime_type, _ = mimetypes.guess_type(file_name) + if not mime_type: + mime_type = "application/octet-stream" + + async def file_stream(): + chunk_size = 1024 * 1024 # 1MB + while True: + data = file_content.read(chunk_size) + if not data: + break + yield data + + return StreamingResponse( + file_stream(), + media_type=mime_type, + headers={ + "Content-Disposition": f"inline; filename*=UTF-8''{encoded_filename}", + "Content-Length": str(file_size), + }, + ) + + @self.router.delete( + "/documents/by-filter", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Delete documents by filter", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": 
textwrap.dedent(""" + from r2r import R2RClient + client = R2RClient() + # when using auth, do client.login(...) + response = client.documents.delete_by_filter( + filters={"document_type": {"$eq": "txt"}} + ) + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X DELETE "https://api.example.com/v3/documents/by-filter?filters=%7B%22document_type%22%3A%7B%22%24eq%22%3A%22text%22%7D%2C%22created_at%22%3A%7B%22%24lt%22%3A%222023-01-01T00%3A00%3A00Z%22%7D%7D" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def delete_document_by_filter( + filters: Json[dict] = Body( + ..., description="JSON-encoded filters" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Delete documents based on provided filters. + + Allowed operators + include: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, + `ilike`, `in`, and `nin`. Deletion requests are limited to a + user's own documents. + """ + + filters_dict = { + "$and": [{"owner_id": {"$eq": str(auth_user.id)}}, filters] + } + await ( + self.services.management.delete_documents_and_chunks_by_filter( + filters=filters_dict + ) + ) + + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.delete( + "/documents/{id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Delete a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.documents.delete( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.documents.delete({ + id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X DELETE "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def delete_document_by_id( + id: UUID = Path(..., description="Document ID"), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Delete a specific document. All chunks corresponding to the + document are deleted, and all other references to the document are + removed. + + NOTE - Deletions do not yet impact the knowledge graph or other derived data. This feature is planned for a future release. + """ + + filters: dict[str, Any] = {"document_id": {"$eq": str(id)}} + if not auth_user.is_superuser: + filters = { + "$and": [ + {"owner_id": {"$eq": str(auth_user.id)}}, + {"document_id": {"$eq": str(id)}}, + ] + } + + await ( + self.services.management.delete_documents_and_chunks_by_filter( + filters=filters + ) + ) + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.get( + "/documents/{id}/collections", + dependencies=[Depends(self.rate_limit_dependency)], + summary="List document collections", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.documents.list_collections( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", offset=0, limit=10 + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.documents.listCollections({ + id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/collections" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_document_collections( + id: str = Path(..., description="Document ID"), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedCollectionsResponse: + """Retrieves all collections that contain the specified document. + This endpoint is restricted to superusers only and provides a + system-wide view of document organization. + + Collections are used to organize documents and manage access control. A document can belong + to multiple collections, and users can access documents through collection membership. + + The results are paginated and ordered by collection creation date, with the most recently + created collections appearing first. + + NOTE - This endpoint is only available to superusers, it will be extended to regular users in a future release. 
+ """ + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can get the collections belonging to a document.", + 403, + ) + + collections_response = ( + await self.services.management.collections_overview( + offset=offset, + limit=limit, + document_ids=[UUID(id)], # Convert string ID to UUID + ) + ) + + return collections_response["results"], { # type: ignore + "total_entries": collections_response["total_entries"] + } + + @self.router.post( + "/documents/{id}/extract", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Extract entities and relationships", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.documents.extract( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" + ) + """), + }, + ], + }, + ) + @self.base_endpoint + async def extract( + id: UUID = Path( + ..., + description="The ID of the document to extract entities and relationships from.", + ), + settings: Optional[GraphCreationSettings] = Body( + default=None, + description="Settings for the entities and relationships extraction process.", + ), + run_with_orchestration: Optional[bool] = Body( + default=True, + description="Whether to run the entities and relationships extraction process with orchestration.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGenericMessageResponse: + """Extracts entities and relationships from a document. + + The entities and relationships extraction process involves: + + 1. Parsing documents into semantic chunks + + 2. Extracting entities and relationships using LLMs + + 3. Storing the created entities and relationships in the knowledge graph + + 4. 
Preserving the document's metadata and content, and associating the elements with collections the document belongs to + """ + + settings = settings.dict() if settings else None # type: ignore + documents_overview_response = ( + await self.services.management.documents_overview( + user_ids=( + None if auth_user.is_superuser else [auth_user.id] + ), + collection_ids=( + None + if auth_user.is_superuser + else auth_user.collection_ids + ), + document_ids=[id], + offset=0, + limit=1, + ) + )["results"] + if len(documents_overview_response) == 0: + raise R2RException("Document not found.", 404) + + if ( + not auth_user.is_superuser + and auth_user.id != documents_overview_response[0].owner_id + ): + raise R2RException( + "Only a superuser can extract entities and relationships from a document they do not own.", + 403, + ) + + # Apply runtime settings overrides + server_graph_creation_settings = ( + self.providers.database.config.graph_creation_settings + ) + + if settings: + server_graph_creation_settings = update_settings_from_dict( + server_settings=server_graph_creation_settings, + settings_dict=settings, # type: ignore + ) + + if run_with_orchestration: + try: + workflow_input = { + "document_id": str(id), + "graph_creation_settings": server_graph_creation_settings.model_dump_json(), + "user": auth_user.json(), + } + + return await self.providers.orchestration.run_workflow( # type: ignore + "graph-extraction", {"request": workflow_input}, {} + ) + except Exception as e: # TODO: Need to find specific errors that we should be excepting (gRPC most likely?) + logger.error( + f"Error running orchestrated extraction: {e} \n\nAttempting to run without orchestration." 
+ ) + + from core.main.orchestration import ( + simple_graph_search_results_factory, + ) + + logger.info("Running extract-triples without orchestration.") + simple_graph_search_results = simple_graph_search_results_factory( + self.services.graph + ) + await simple_graph_search_results["graph-extraction"]( + workflow_input + ) + return { # type: ignore + "message": "Graph created successfully.", + "task_id": None, + } + + @self.router.post( + "/documents/{id}/deduplicate", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Deduplicate entities", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + + response = client.documents.deduplicate( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.documents.deduplicate({ + id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/deduplicate" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ], + }, + ) + @self.base_endpoint + async def deduplicate( + id: UUID = Path( + ..., + description="The ID of the document to extract entities and relationships from.", + ), + settings: Optional[GraphCreationSettings] = Body( + default=None, + description="Settings for the entities and relationships extraction process.", + ), + run_with_orchestration: Optional[bool] = Body( + default=True, + description="Whether to run the entities and relationships extraction process with orchestration.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGenericMessageResponse: + """Deduplicates entities from a document.""" + + settings = 
settings.model_dump() if settings else None # type: ignore + documents_overview_response = ( + await self.services.management.documents_overview( + user_ids=( + None if auth_user.is_superuser else [auth_user.id] + ), + collection_ids=( + None + if auth_user.is_superuser + else auth_user.collection_ids + ), + document_ids=[id], + offset=0, + limit=1, + ) + )["results"] + if len(documents_overview_response) == 0: + raise R2RException("Document not found.", 404) + + if ( + not auth_user.is_superuser + and auth_user.id != documents_overview_response[0].owner_id + ): + raise R2RException( + "Only a superuser can run deduplication on a document they do not own.", + 403, + ) + + # Apply runtime settings overrides + server_graph_creation_settings = ( + self.providers.database.config.graph_creation_settings + ) + + if settings: + server_graph_creation_settings = update_settings_from_dict( + server_settings=server_graph_creation_settings, + settings_dict=settings, # type: ignore + ) + + if run_with_orchestration: + try: + workflow_input = { + "document_id": str(id), + } + + return await self.providers.orchestration.run_workflow( # type: ignore + "graph-deduplication", + {"request": workflow_input}, + {}, + ) + except Exception as e: # TODO: Need to find specific errors that we should be excepting (gRPC most likely?) + logger.error( + f"Error running orchestrated deduplication: {e} \n\nAttempting to run without orchestration." + ) + + from core.main.orchestration import ( + simple_graph_search_results_factory, + ) + + logger.info( + "Running deduplicate-document-entities without orchestration." 
+ ) + simple_graph_search_results = simple_graph_search_results_factory( + self.services.graph + ) + await simple_graph_search_results["graph-deduplication"]( + workflow_input + ) + return { # type: ignore + "message": "Graph created successfully.", + "task_id": None, + } + + @self.router.get( + "/documents/{id}/entities", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Lists the entities from the document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.documents.extract( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" + ) + """), + }, + ], + }, + ) + @self.base_endpoint + async def get_entities( + id: UUID = Path( + ..., + description="The ID of the document to retrieve entities from.", + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + include_embeddings: Optional[bool] = Query( + False, + description="Whether to include vector embeddings in the response.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedEntitiesResponse: + """Retrieves the entities that were extracted from a document. + These represent important semantic elements like people, places, + organizations, concepts, etc. + + Users can only access entities from documents they own or have + access to through collections. Entity embeddings are only included + if specifically requested. + + Results are returned in the order they were extracted from the + document. 
+ """ + # if ( + # not auth_user.is_superuser + # and id not in auth_user.collection_ids + # ): + # raise R2RException( + # "The currently authenticated user does not have access to the specified collection.", + # 403, + # ) + + # First check if the document exists and user has access + documents_overview_response = ( + await self.services.management.documents_overview( + user_ids=( + None if auth_user.is_superuser else [auth_user.id] + ), + collection_ids=( + None + if auth_user.is_superuser + else auth_user.collection_ids + ), + document_ids=[id], + offset=0, + limit=1, + ) + ) + + if not documents_overview_response["results"]: + raise R2RException("Document not found.", 404) + + # Get all entities for this document from the document_entity table + ( + entities, + count, + ) = await self.providers.database.graphs_handler.entities.get( + parent_id=id, + store_type=StoreType.DOCUMENTS, + offset=offset, + limit=limit, + include_embeddings=include_embeddings or False, + ) + + return entities, {"total_entries": count} # type: ignore + + @self.router.post( + "/documents/{id}/entities/export", + summary="Export document entities to CSV", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
+ + response = client.documents.export_entities( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + output_path="export.csv", + columns=["id", "title", "created_at"], + include_header=True, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + await client.documents.exportEntities({ + id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + outputPath: "export.csv", + columns: ["id", "title", "created_at"], + includeHeader: true, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "http://127.0.0.1:7272/v3/documents/export_entities" \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -H "Accept: text/csv" \ + -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \ + --output export.csv + """), + }, + ] + }, + ) + @self.base_endpoint + async def export_entities( + background_tasks: BackgroundTasks, + id: UUID = Path( + ..., + description="The ID of the document to export entities from.", + ), + columns: Optional[list[str]] = Body( + None, description="Specific columns to export" + ), + filters: Optional[dict] = Body( + None, description="Filters to apply to the export" + ), + include_header: Optional[bool] = Body( + True, description="Whether to include column headers" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> FileResponse: + """Export documents as a downloadable CSV file.""" + + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can export data.", + 403, + ) + + ( + csv_file_path, + temp_file, + ) = await self.services.management.export_document_entities( + id=id, + columns=columns, + filters=filters, + include_header=include_header + if include_header is not None + else True, + ) + + background_tasks.add_task(temp_file.close) + + return FileResponse( + path=csv_file_path, + 
media_type="text/csv", + filename="documents_export.csv", + ) + + @self.router.get( + "/documents/{id}/relationships", + dependencies=[Depends(self.rate_limit_dependency)], + summary="List document relationships", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.documents.list_relationships( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + offset=0, + limit=100 + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.documents.listRelationships({ + id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + offset: 0, + limit: 100, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/relationships" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_relationships( + id: UUID = Path( + ..., + description="The ID of the document to retrieve relationships for.", + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + entity_names: Optional[list[str]] = Query( + None, + description="Filter relationships by specific entity names.", + ), + relationship_types: Optional[list[str]] = Query( + None, + description="Filter relationships by specific relationship types.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedRelationshipsResponse: + """Retrieves the relationships between entities that were extracted + from a document. 
These represent connections and interactions + between entities found in the text. + + Users can only access relationships from documents they own or have + access to through collections. Results can be filtered by entity + names and relationship types. + + Results are returned in the order they were extracted from the + document. + """ + # if ( + # not auth_user.is_superuser + # and id not in auth_user.collection_ids + # ): + # raise R2RException( + # "The currently authenticated user does not have access to the specified collection.", + # 403, + # ) + + # First check if the document exists and user has access + documents_overview_response = ( + await self.services.management.documents_overview( + user_ids=( + None if auth_user.is_superuser else [auth_user.id] + ), + collection_ids=( + None + if auth_user.is_superuser + else auth_user.collection_ids + ), + document_ids=[id], + offset=0, + limit=1, + ) + ) + + if not documents_overview_response["results"]: + raise R2RException("Document not found.", 404) + + # Get relationships for this document + ( + relationships, + count, + ) = await self.providers.database.graphs_handler.relationships.get( + parent_id=id, + store_type=StoreType.DOCUMENTS, + entity_names=entity_names, + relationship_types=relationship_types, + offset=offset, + limit=limit, + ) + + return relationships, {"total_entries": count} # type: ignore + + @self.router.post( + "/documents/{id}/relationships/export", + summary="Export document relationships to CSV", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
+ + response = client.documents.export_entities( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + output_path="export.csv", + columns=["id", "title", "created_at"], + include_header=True, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + await client.documents.exportEntities({ + id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + outputPath: "export.csv", + columns: ["id", "title", "created_at"], + includeHeader: true, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "http://127.0.0.1:7272/v3/documents/export_entities" \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -H "Accept: text/csv" \ + -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \ + --output export.csv + """), + }, + ] + }, + ) + @self.base_endpoint + async def export_relationships( + background_tasks: BackgroundTasks, + id: UUID = Path( + ..., + description="The ID of the document to export entities from.", + ), + columns: Optional[list[str]] = Body( + None, description="Specific columns to export" + ), + filters: Optional[dict] = Body( + None, description="Filters to apply to the export" + ), + include_header: Optional[bool] = Body( + True, description="Whether to include column headers" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> FileResponse: + """Export documents as a downloadable CSV file.""" + + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can export data.", + 403, + ) + + ( + csv_file_path, + temp_file, + ) = await self.services.management.export_document_relationships( + id=id, + columns=columns, + filters=filters, + include_header=include_header + if include_header is not None + else True, + ) + + background_tasks.add_task(temp_file.close) + + return FileResponse( + 
path=csv_file_path, + media_type="text/csv", + filename="documents_export.csv", + ) + + @self.router.post( + "/documents/search", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Search document summaries", + ) + @self.base_endpoint + async def search_documents( + query: str = Body( + ..., + description="The search query to perform.", + ), + search_mode: SearchMode = Body( + default=SearchMode.custom, + description=( + "Default value of `custom` allows full control over search settings.\n\n" + "Pre-configured search modes:\n" + "`basic`: A simple semantic-based search.\n" + "`advanced`: A more powerful hybrid search combining semantic and full-text.\n" + "`custom`: Full control via `search_settings`.\n\n" + "If `filters` or `limit` are provided alongside `basic` or `advanced`, " + "they will override the default settings for that mode." + ), + ), + search_settings: SearchSettings = Body( + default_factory=SearchSettings, + description="Settings for document search", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedDocumentSearchResponse: + """Perform a search query on the automatically generated document + summaries in the system. + + This endpoint allows for complex filtering of search results using PostgreSQL-based queries. + Filters can be applied to various fields such as document_id, and internal metadata values. + + + Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`. 
+ """ + effective_settings = self._prepare_search_settings( + auth_user, search_mode, search_settings + ) + + query_embedding = ( + await self.providers.embedding.async_get_embedding(query) + ) + results = await self.services.retrieval.search_documents( + query=query, + query_embedding=query_embedding, + settings=effective_settings, + ) + return results # type: ignore + + @staticmethod + async def _process_file(file): + import base64 + + content = await file.read() + + return { + "filename": file.filename, + "content": base64.b64encode(content).decode("utf-8"), + "content_type": file.content_type, + "content_length": len(content), + } diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/examples.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/examples.py new file mode 100644 index 00000000..ba588c3b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/examples.py @@ -0,0 +1,1065 @@ +import textwrap + +""" +This file contains updated OpenAPI examples for the RetrievalRouterV3 class. +These examples are designed to be included in the openapi_extra field for each route. +""" + +# Updated examples for search_app endpoint +search_app_examples = { + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient() + # if using auth, do client.login(...) 
+ + # Basic search + response = client.retrieval.search( + query="What is DeepSeek R1?", + ) + + # Advanced mode with specific filters + response = client.retrieval.search( + query="What is DeepSeek R1?", + search_mode="advanced", + search_settings={ + "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}}, + "limit": 5 + } + ) + + # Using hybrid search + response = client.retrieval.search( + query="What was Uber's profit in 2020?", + search_settings={ + "use_hybrid_search": True, + "hybrid_settings": { + "full_text_weight": 1.0, + "semantic_weight": 5.0, + "full_text_limit": 200, + "rrf_k": 50 + }, + "filters": {"title": {"$in": ["DeepSeek_R1.pdf"]}}, + } + ) + + # Advanced filtering + results = client.retrieval.search( + query="What are the effects of climate change?", + search_settings={ + "filters": { + "$and":[ + {"document_type": {"$eq": "pdf"}}, + {"metadata.year": {"$gt": 2020}} + ] + }, + "limit": 10 + } + ) + + # Knowledge graph enhanced search + results = client.retrieval.search( + query="What was DeepSeek R1", + graph_search_settings={ + "use_graph_search": True, + "kg_search_type": "local" + } + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + // if using auth, do client.login(...) 
+ + // Basic search + const response = await client.retrieval.search({ + query: "What is DeepSeek R1?", + }); + + // With specific filters + const filteredResponse = await client.retrieval.search({ + query: "What is DeepSeek R1?", + searchSettings: { + filters: {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}}, + limit: 5 + } + }); + + // Using hybrid search + const hybridResponse = await client.retrieval.search({ + query: "What was Uber's profit in 2020?", + searchSettings: { + indexMeasure: "l2_distance", + useHybridSearch: true, + hybridSettings: { + fullTextWeight: 1.0, + semanticWeight: 5.0, + fullTextLimit: 200, + }, + filters: {"title": {"$in": ["DeepSeek_R1.pdf"]}}, + } + }); + + // Advanced filtering + const advancedResults = await client.retrieval.search({ + query: "What are the effects of climate change?", + searchSettings: { + filters: { + $and: [ + {document_type: {$eq: "pdf"}}, + {"metadata.year": {$gt: 2020}} + ] + }, + limit: 10 + } + }); + + // Knowledge graph enhanced search + const kgResults = await client.retrieval.search({ + query: "who was aristotle?", + graphSearchSettings: { + useKgSearch: true, + kgSearchType: "local" + } + }); + """ + ), + }, + { + "lang": "Shell", + "source": textwrap.dedent( + """ + # Basic search + curl -X POST "https://api.sciphi.ai/v3/retrieval/search" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "query": "What is DeepSeek R1?" 
+ }' + + # With hybrid search and filters + curl -X POST "https://api.sciphi.ai/v3/retrieval/search" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "query": "What was Uber'\''s profit in 2020?", + "search_settings": { + "use_hybrid_search": true, + "hybrid_settings": { + "full_text_weight": 1.0, + "semantic_weight": 5.0, + "full_text_limit": 200, + "rrf_k": 50 + }, + "filters": {"title": {"$in": ["DeepSeek_R1.pdf"]}}, + "limit": 10, + "chunk_settings": { + "index_measure": "l2_distance", + "probes": 25, + "ef_search": 100 + } + } + }' + + # Knowledge graph enhanced search + curl -X POST "https://api.sciphi.ai/v3/retrieval/search" \\ + -H "Content-Type: application/json" \\ + -d '{ + "query": "who was aristotle?", + "graph_search_settings": { + "use_graph_search": true, + "kg_search_type": "local" + } + }' \\ + -H "Authorization: Bearer YOUR_API_KEY" + """ + ), + }, + ] +} + +# Updated examples for rag_app endpoint +rag_app_examples = { + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + # Basic RAG request + response = client.retrieval.rag( + query="What is DeepSeek R1?", + ) + + # Advanced RAG with custom search settings + response = client.retrieval.rag( + query="What is DeepSeek R1?", + search_settings={ + "use_semantic_search": True, + "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}}, + "limit": 10, + }, + rag_generation_config={ + "stream": False, + "temperature": 0.7, + "max_tokens": 1500 + } + ) + + # Hybrid search in RAG + results = client.retrieval.rag( + "Who is Jon Snow?", + search_settings={"use_hybrid_search": True} + ) + + # Custom model selection + response = client.retrieval.rag( + "Who was Aristotle?", + rag_generation_config={"model":"anthropic/claude-3-haiku-20240307", "stream": True} + ) + for chunk in response: + print(chunk) + + # Streaming RAG + from r2r import ( + CitationEvent, + FinalAnswerEvent, + MessageEvent, + SearchResultsEvent, + R2RClient, + ) + + result_stream = client.retrieval.rag( + query="What is DeepSeek R1?", + search_settings={"limit": 25}, + rag_generation_config={"stream": True}, + ) + + # Process different event types + for event in result_stream: + if isinstance(event, SearchResultsEvent): + print("Search results:", event.data) + elif isinstance(event, MessageEvent): + print("Partial message:", event.data.delta) + elif isinstance(event, CitationEvent): + print("New citation detected:", event.data.id) + elif isinstance(event, FinalAnswerEvent): + print("Final answer:", event.data.generated_answer) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + // when using auth, do client.login(...) 
+ + // Basic RAG request + const response = await client.retrieval.rag({ + query: "What is DeepSeek R1?", + }); + + // RAG with custom settings + const advancedResponse = await client.retrieval.rag({ + query: "What is DeepSeek R1?", + searchSettings: { + useSemanticSearch: true, + filters: {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}}, + limit: 10, + }, + ragGenerationConfig: { + stream: false, + temperature: 0.7, + maxTokens: 1500 + } + }); + + // Hybrid search in RAG + const hybridResults = await client.retrieval.rag({ + query: "Who is Jon Snow?", + searchSettings: { + useHybridSearch: true + }, + }); + + // Custom model + const customModelResponse = await client.retrieval.rag({ + query: "Who was Aristotle?", + ragGenerationConfig: { + model: 'anthropic/claude-3-haiku-20240307', + temperature: 0.7, + } + }); + + // Streaming RAG + const resultStream = await client.retrieval.rag({ + query: "What is DeepSeek R1?", + searchSettings: { limit: 25 }, + ragGenerationConfig: { stream: true }, + }); + + // Process streaming events + if (Symbol.asyncIterator in resultStream) { + for await (const event of resultStream) { + switch (event.event) { + case "search_results": + console.log("Search results:", event.data); + break; + case "message": + console.log("Partial message delta:", event.data.delta); + break; + case "citation": + console.log("New citation event:", event.data.id); + break; + case "final_answer": + console.log("Final answer:", event.data.generated_answer); + break; + default: + console.log("Unknown or unhandled event:", event); + } + } + } + """ + ), + }, + { + "lang": "Shell", + "source": textwrap.dedent( + """ + # Basic RAG request + curl -X POST "https://api.sciphi.ai/v3/retrieval/rag" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "query": "What is DeepSeek R1?" 
+ }' + + # RAG with custom settings + curl -X POST "https://api.sciphi.ai/v3/retrieval/rag" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "query": "What is DeepSeek R1?", + "search_settings": { + "use_semantic_search": true, + "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}}, + "limit": 10 + }, + "rag_generation_config": { + "stream": false, + "temperature": 0.7, + "max_tokens": 1500 + } + }' + + # Hybrid search in RAG + curl -X POST "https://api.sciphi.ai/v3/retrieval/rag" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "query": "Who is Jon Snow?", + "search_settings": { + "use_hybrid_search": true, + "filters": {}, + "limit": 10 + } + }' + + # Custom model + curl -X POST "https://api.sciphi.ai/v3/retrieval/rag" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "query": "Who is Jon Snow?", + "rag_generation_config": { + "model": "anthropic/claude-3-haiku-20240307", + "temperature": 0.7 + } + }' + """ + ), + }, + ] +} + +# Updated examples for agent_app endpoint +agent_app_examples = { + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ +from r2r import ( + R2RClient, + ThinkingEvent, + ToolCallEvent, + ToolResultEvent, + CitationEvent, + FinalAnswerEvent, + MessageEvent, +) + +client = R2RClient() +# when using auth, do client.login(...) 
+ +# Basic synchronous request +response = client.retrieval.agent( + message={ + "role": "user", + "content": "Do a deep analysis of the philosophical implications of DeepSeek R1" + }, + rag_tools=["web_search", "web_scrape", "search_file_descriptions", "search_file_knowledge", "get_file_content"], +) + +# Advanced analysis with streaming and extended thinking +streaming_response = client.retrieval.agent( + message={ + "role": "user", + "content": "Do a deep analysis of the philosophical implications of DeepSeek R1" + }, + search_settings={"limit": 20}, + rag_tools=["web_search", "web_scrape", "search_file_descriptions", "search_file_knowledge", "get_file_content"], + rag_generation_config={ + "model": "anthropic/claude-3-7-sonnet-20250219", + "extended_thinking": True, + "thinking_budget": 4096, + "temperature": 1, + "top_p": None, + "max_tokens": 16000, + "stream": True + } +) + +# Process streaming events with emoji only on type change +current_event_type = None +for event in streaming_response: + # Check if the event type has changed + event_type = type(event) + if event_type != current_event_type: + current_event_type = event_type + print() # Add newline before new event type + + # Print emoji based on the new event type + if isinstance(event, ThinkingEvent): + print(f"\n🧠 Thinking: ", end="", flush=True) + elif isinstance(event, ToolCallEvent): + print(f"\n🔧 Tool call: ", end="", flush=True) + elif isinstance(event, ToolResultEvent): + print(f"\n📊 Tool result: ", end="", flush=True) + elif isinstance(event, CitationEvent): + print(f"\n📑 Citation: ", end="", flush=True) + elif isinstance(event, MessageEvent): + print(f"\n💬 Message: ", end="", flush=True) + elif isinstance(event, FinalAnswerEvent): + print(f"\n✅ Final answer: ", end="", flush=True) + + # Print the content without the emoji + if isinstance(event, ThinkingEvent): + print(f"{event.data.delta.content[0].payload.value}", end="", flush=True) + elif isinstance(event, ToolCallEvent): + 
print(f"{event.data.name}({event.data.arguments})") + elif isinstance(event, ToolResultEvent): + print(f"{event.data.content[:60]}...") + elif isinstance(event, CitationEvent): + print(f"{event.data.id}") + elif isinstance(event, MessageEvent): + print(f"{event.data.delta.content[0].payload.value}", end="", flush=True) + elif isinstance(event, FinalAnswerEvent): + print(f"{event.data.generated_answer[:100]}...") + print(f" Citations: {len(event.data.citations)} sources referenced") + +# Conversation with multiple turns (synchronous) +conversation = client.conversations.create() + +# First message in conversation +results_1 = client.retrieval.agent( + query="What does DeepSeek R1 imply for the future of AI?", + rag_generation_config={ + "model": "anthropic/claude-3-7-sonnet-20250219", + "extended_thinking": True, + "thinking_budget": 4096, + "temperature": 1, + "top_p": None, + "max_tokens": 16000, + "stream": True + }, + conversation_id=conversation.results.id +) + +# Follow-up query in the same conversation +results_2 = client.retrieval.agent( + query="How does it compare to other reasoning models?", + rag_generation_config={ + "model": "anthropic/claude-3-7-sonnet-20250219", + "extended_thinking": True, + "thinking_budget": 4096, + "temperature": 1, + "top_p": None, + "max_tokens": 16000, + "stream": True + }, + conversation_id=conversation.results.id +) + +# Access the final results +print(f"First response: {results_1.generated_answer[:100]}...") +print(f"Follow-up response: {results_2.generated_answer[:100]}...") +""" + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + // when using auth, do client.login(...) 
+ + async function main() { + // Basic synchronous request + const ragResponse = await client.retrieval.agent({ + message: { + role: "user", + content: "Do a deep analysis of the philosophical implications of DeepSeek R1" + }, + ragTools: ["web_search", "web_scrape", "search_file_descriptions", "search_file_knowledge", "get_file_content"] + }); + + // Advanced analysis with streaming and extended thinking + const streamingResponse = await client.retrieval.agent({ + message: { + role: "user", + content: "Do a deep analysis of the philosophical implications of DeepSeek R1" + }, + searchSettings: {limit: 20}, + ragTools: ["web_search", "web_scrape", "search_file_descriptions", "search_file_knowledge", "get_file_content"], + ragGenerationConfig: { + model: "anthropic/claude-3-7-sonnet-20250219", + extendedThinking: true, + thinkingBudget: 4096, + temperature: 1, + maxTokens: 16000, + stream: true + } + }); + + // Process streaming events with emoji only on type change + if (Symbol.asyncIterator in streamingResponse) { + let currentEventType = null; + + for await (const event of streamingResponse) { + // Check if event type has changed + const eventType = event.event; + if (eventType !== currentEventType) { + currentEventType = eventType; + console.log(); // Add newline before new event type + + // Print emoji based on the new event type + switch(eventType) { + case "thinking": + process.stdout.write(`🧠 Thinking: `); + break; + case "tool_call": + process.stdout.write(`🔧 Tool call: `); + break; + case "tool_result": + process.stdout.write(`📊 Tool result: `); + break; + case "citation": + process.stdout.write(`📑 Citation: `); + break; + case "message": + process.stdout.write(`💬 Message: `); + break; + case "final_answer": + process.stdout.write(`✅ Final answer: `); + break; + } + } + + // Print content based on event type + switch(eventType) { + case "thinking": + process.stdout.write(`${event.data.delta.content[0].payload.value}`); + break; + case "tool_call": + 
console.log(`${event.data.name}(${JSON.stringify(event.data.arguments)})`); + break; + case "tool_result": + console.log(`${event.data.content.substring(0, 60)}...`); + break; + case "citation": + console.log(`${event.data.id}`); + break; + case "message": + process.stdout.write(`${event.data.delta.content[0].payload.value}`); + break; + case "final_answer": + console.log(`${event.data.generated_answer.substring(0, 100)}...`); + console.log(` Citations: ${event.data.citations.length} sources referenced`); + break; + } + } + } + + // Conversation with multiple turns (synchronous) + const conversation = await client.conversations.create(); + + // First message in conversation + const results1 = await client.retrieval.agent({ + query: "What does DeepSeek R1 imply for the future of AI?", + ragGenerationConfig: { + model: "anthropic/claude-3-7-sonnet-20250219", + extendedThinking: true, + thinkingBudget: 4096, + temperature: 1, + maxTokens: 16000, + stream: true + }, + conversationId: conversation.results.id + }); + + // Follow-up query in the same conversation + const results2 = await client.retrieval.agent({ + query: "How does it compare to other reasoning models?", + ragGenerationConfig: { + model: "anthropic/claude-3-7-sonnet-20250219", + extendedThinking: true, + thinkingBudget: 4096, + temperature: 1, + maxTokens: 16000, + stream: true + }, + conversationId: conversation.results.id + }); + + // Log the results + console.log(`First response: ${results1.generated_answer.substring(0, 100)}...`); + console.log(`Follow-up response: ${results2.generated_answer.substring(0, 100)}...`); + } + + main(); + """ + ), + }, + { + "lang": "Shell", + "source": textwrap.dedent( + """ + # Basic request + curl -X POST "https://api.sciphi.ai/v3/retrieval/agent" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "message": { + "role": "user", + "content": "What were the key contributions of Aristotle to logic?" 
+ }, + "search_settings": { + "use_semantic_search": true, + "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}} + }, + "rag_tools": ["search_file_knowledge", "content", "web_search"] + }' + + # Advanced analysis with extended thinking + curl -X POST "https://api.sciphi.ai/v3/retrieval/agent" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "message": { + "role": "user", + "content": "Do a deep analysis of the philosophical implications of DeepSeek R1" + }, + "search_settings": {"limit": 20}, + "research_tools": ["rag", "reasoning", "critique", "python_executor"], + "rag_generation_config": { + "model": "anthropic/claude-3-7-sonnet-20250219", + "extended_thinking": true, + "thinking_budget": 4096, + "temperature": 1, + "top_p": null, + "max_tokens": 16000, + "stream": true + } + }' + + # Conversation continuation + curl -X POST "https://api.sciphi.ai/v3/retrieval/agent" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "message": { + "role": "user", + "content": "How does it compare to other reasoning models?" + }, + "conversation_id": "YOUR_CONVERSATION_ID" + }' + """ + ), + }, + ] +} + +# Updated examples for completion endpoint +completion_examples = { + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.completion( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + {"role": "user", "content": "What about Italy?"} + ], + generation_config={ + "model": "openai/gpt-4o-mini", + "temperature": 0.7, + "max_tokens": 150, + "stream": False + } + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + // when using auth, do client.login(...) + + async function main() { + const response = await client.completion({ + messages: [ + { role: "system", content: "You are a helpful assistant." }, + { role: "user", content: "What is the capital of France?" }, + { role: "assistant", content: "The capital of France is Paris." }, + { role: "user", content: "What about Italy?" } + ], + generationConfig: { + model: "openai/gpt-4o-mini", + temperature: 0.7, + maxTokens: 150, + stream: false + } + }); + } + + main(); + """ + ), + }, + { + "lang": "Shell", + "source": textwrap.dedent( + """ + curl -X POST "https://api.sciphi.ai/v3/retrieval/completion" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + {"role": "user", "content": "What about Italy?"} + ], + "generation_config": { + "model": "openai/gpt-4o-mini", + "temperature": 0.7, + "max_tokens": 150, + "stream": false + } + }' + """ + ), + }, + ] +} + +# Updated examples for embedding endpoint +embedding_examples = { + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + result = client.retrieval.embedding( + text="What is DeepSeek R1?", + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + // when using auth, do client.login(...) + + async function main() { + const response = await client.retrieval.embedding({ + text: "What is DeepSeek R1?", + }); + } + + main(); + """ + ), + }, + { + "lang": "Shell", + "source": textwrap.dedent( + """ + curl -X POST "https://api.sciphi.ai/v3/retrieval/embedding" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "text": "What is DeepSeek R1?", + }' + """ + ), + }, + ] +} + +# Updated rag_app docstring +rag_app_docstring = """ +Execute a RAG (Retrieval-Augmented Generation) query. + +This endpoint combines search results with language model generation to produce accurate, +contextually-relevant responses based on your document corpus. + +**Features:** +- Combines vector search, optional knowledge graph integration, and LLM generation +- Automatically cites sources with unique citation identifiers +- Supports both streaming and non-streaming responses +- Compatible with various LLM providers (OpenAI, Anthropic, etc.) +- Web search integration for up-to-date information + +**Search Configuration:** +All search parameters from the search endpoint apply here, including filters, hybrid search, and graph-enhanced search. 
+ +**Generation Configuration:** +Fine-tune the language model's behavior with `rag_generation_config`: +```json +{ + "model": "openai/gpt-4o-mini", // Model to use + "temperature": 0.7, // Control randomness (0-1) + "max_tokens": 1500, // Maximum output length + "stream": true // Enable token streaming +} +``` + +**Model Support:** +- OpenAI models (default) +- Anthropic Claude models (requires ANTHROPIC_API_KEY) +- Local models via Ollama +- Any provider supported by LiteLLM + +**Streaming Responses:** +When `stream: true` is set, the endpoint returns Server-Sent Events with the following types: +- `search_results`: Initial search results from your documents +- `message`: Partial tokens as they're generated +- `citation`: Citation metadata when sources are referenced +- `final_answer`: Complete answer with structured citations + +**Example Response:** +```json +{ + "generated_answer": "DeepSeek-R1 is a model that demonstrates impressive performance...[1]", + "search_results": { ... }, + "citations": [ + { + "id": "cit.123456", + "object": "citation", + "payload": { ... } + } + ] +} +``` +""" + +# Updated agent_app docstring +agent_app_docstring = """ +Engage with an intelligent agent for information retrieval, analysis, and research. 
+

This endpoint offers two operating modes:
- **RAG mode**: Standard retrieval-augmented generation for answering questions based on your knowledge base
- **Research mode**: Advanced capabilities for deep analysis, reasoning, and computation

### RAG Mode (Default)

The RAG mode provides fast, knowledge-based responses using:
- Semantic and hybrid search capabilities
- Document-level and chunk-level content retrieval
- Optional web search integration
- Source citation and evidence-based responses

### Research Mode

The Research mode builds on RAG capabilities and adds:
- A dedicated reasoning system for complex problem-solving
- Critique capabilities to identify potential biases or logical fallacies
- Python execution for computational analysis
- Multi-step reasoning for deeper exploration of topics

### Available Tools

**RAG Tools:**
- `search_file_knowledge`: Semantic/hybrid search on your ingested documents
- `search_file_descriptions`: Search over file-level metadata
- `content`: Fetch entire documents or chunk structures
- `web_search`: Query external search APIs for up-to-date information
- `web_scrape`: Scrape and extract content from specific web pages

**Research Tools:**
- `rag`: Leverage the underlying RAG agent for information retrieval
- `reasoning`: Call a dedicated model for complex analytical thinking
- `critique`: Analyze conversation history to identify flaws and biases
- `python_executor`: Execute Python code for complex calculations and analysis

### Streaming Output

When streaming is enabled, the agent produces different event types:
- `thinking`: Shows the model's step-by-step reasoning (when extended_thinking=true)
- `tool_call`: Shows when the agent invokes a tool
- `tool_result`: Shows the result of a tool call
- `citation`: Indicates when a citation is added to the response
- `message`: Streams partial tokens of the response
- `final_answer`: Contains the complete generated answer and structured
citations + +### Conversations + +Maintain context across multiple turns by including `conversation_id` in each request. +After your first call, store the returned `conversation_id` and include it in subsequent calls. +""" + +# Updated completion_docstring +completion_docstring = """ +Generate completions for a list of messages. + +This endpoint uses the language model to generate completions for the provided messages. +The generation process can be customized using the generation_config parameter. + +The messages list should contain alternating user and assistant messages, with an optional +system message at the start. Each message should have a 'role' and 'content'. + +**Generation Configuration:** +Fine-tune the language model's behavior with `generation_config`: +```json +{ + "model": "openai/gpt-4o-mini", // Model to use + "temperature": 0.7, // Control randomness (0-1) + "max_tokens": 1500, // Maximum output length + "stream": true // Enable token streaming +} +``` + +**Multiple LLM Support:** +- OpenAI models (default) +- Anthropic Claude models (requires ANTHROPIC_API_KEY) +- Local models via Ollama +- Any provider supported by LiteLLM +""" + +# Updated embedding_docstring +embedding_docstring = """ +Generate embeddings for the provided text using the specified model. + +This endpoint uses the language model to generate embeddings for the provided text. +The model parameter specifies the model to use for generating embeddings. + +Embeddings are numerical representations of text that capture semantic meaning, +allowing for similarity comparisons and other vector operations. + +**Uses:** +- Semantic search +- Document clustering +- Text similarity analysis +- Content recommendation +""" + +# # Example implementation to update the routers in the RetrievalRouterV3 class +# def update_retrieval_router(router_class): +# """ +# Update the RetrievalRouterV3 class with the improved docstrings and examples. 
+ +# This function demonstrates how the updated examples and docstrings would be +# integrated into the actual router class. +# """ +# # Update search_app endpoint +# router_class.search_app.__doc__ = search_app_docstring +# router_class.search_app.openapi_extra = search_app_examples + +# # Update rag_app endpoint +# router_class.rag_app.__doc__ = rag_app_docstring +# router_class.rag_app.openapi_extra = rag_app_examples + +# # Update agent_app endpoint +# router_class.agent_app.__doc__ = agent_app_docstring +# router_class.agent_app.openapi_extra = agent_app_examples + +# # Update completion endpoint +# router_class.completion.__doc__ = completion_docstring +# router_class.completion.openapi_extra = completion_examples + +# # Update embedding endpoint +# router_class.embedding.__doc__ = embedding_docstring +# router_class.embedding.openapi_extra = embedding_examples + +# return router_class + +# Example showing how the updated router would be integrated +""" +from your_module import RetrievalRouterV3 + +# Apply the updated docstrings and examples +router = RetrievalRouterV3(providers, services, config) +router = update_retrieval_router(router) + +# Now the router has the improved docstrings and examples +""" + +EXAMPLES = { + "search": search_app_examples, + "rag": rag_app_examples, + "agent": agent_app_examples, + "completion": completion_examples, + "embedding": embedding_examples, +} diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/graph_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/graph_router.py new file mode 100644 index 00000000..244d76cf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/graph_router.py @@ -0,0 +1,2051 @@ +import logging +import textwrap +from typing import Optional, cast +from uuid import UUID + +from fastapi import Body, Depends, Path, Query +from fastapi.background import BackgroundTasks +from fastapi.responses import FileResponse + +from core.base import GraphConstructionStatus, 
R2RException, Workflow +from core.base.abstractions import DocumentResponse, StoreType +from core.base.api.models import ( + GenericBooleanResponse, + GenericMessageResponse, + WrappedBooleanResponse, + WrappedCommunitiesResponse, + WrappedCommunityResponse, + WrappedEntitiesResponse, + WrappedEntityResponse, + WrappedGenericMessageResponse, + WrappedGraphResponse, + WrappedGraphsResponse, + WrappedRelationshipResponse, + WrappedRelationshipsResponse, +) +from core.utils import ( + generate_default_user_collection_id, + update_settings_from_dict, +) + +from ...abstractions import R2RProviders, R2RServices +from ...config import R2RConfig +from .base_router import BaseRouterV3 + +logger = logging.getLogger() + + +class GraphRouter(BaseRouterV3): + def __init__( + self, + providers: R2RProviders, + services: R2RServices, + config: R2RConfig, + ): + logging.info("Initializing GraphRouter") + super().__init__(providers, services, config) + self._register_workflows() + + def _register_workflows(self): + workflow_messages = {} + if self.providers.orchestration.config.provider == "hatchet": + workflow_messages["graph-extraction"] = ( + "Document extraction task queued successfully." + ) + workflow_messages["graph-clustering"] = ( + "Graph enrichment task queued successfully." + ) + workflow_messages["graph-deduplication"] = ( + "Entity deduplication task queued successfully." + ) + else: + workflow_messages["graph-extraction"] = ( + "Document entities and relationships extracted successfully." + ) + workflow_messages["graph-clustering"] = ( + "Graph communities created successfully." + ) + workflow_messages["graph-deduplication"] = ( + "Entity deduplication completed successfully." 
+ ) + + self.providers.orchestration.register_workflows( + Workflow.GRAPH, + self.services.graph, + workflow_messages, + ) + + async def _get_collection_id( + self, collection_id: Optional[UUID], auth_user + ) -> UUID: + """Helper method to get collection ID, using default if none + provided.""" + if collection_id is None: + return generate_default_user_collection_id(auth_user.id) + return collection_id + + def _setup_routes(self): + @self.router.get( + "/graphs", + dependencies=[Depends(self.rate_limit_dependency)], + summary="List graphs", + openapi_extra={ + "x-codeSamples": [ + { # TODO: Verify + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.graphs.list() + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.list({}); + } + + main(); + """ + ), + }, + ] + }, + ) + @self.base_endpoint + async def list_graphs( + collection_ids: list[str] = Query( + [], + description="A list of graph IDs to retrieve. If not provided, all graphs will be returned.", + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGraphsResponse: + """Returns a paginated list of graphs the authenticated user has + access to. + + Results can be filtered by providing specific graph IDs. Regular + users will only see graphs they own or have access to. Superusers + can see all graphs. + + The graphs are returned in order of last modification, with most + recent first. 
+ """ + requesting_user_id = ( + None if auth_user.is_superuser else [auth_user.id] + ) + + graph_uuids = [UUID(graph_id) for graph_id in collection_ids] + + list_graphs_response = await self.services.graph.list_graphs( + # user_ids=requesting_user_id, + graph_ids=graph_uuids, + offset=offset, + limit=limit, + ) + + return ( # type: ignore + list_graphs_response["results"], + {"total_entries": list_graphs_response["total_entries"]}, + ) + + @self.router.get( + "/graphs/{collection_id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Retrieve graph details", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.graphs.get( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + )"""), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.retrieve({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7" \\ + -H "Authorization: Bearer YOUR_API_KEY" """), + }, + ] + }, + ) + @self.base_endpoint + async def get_graph( + collection_id: UUID = Path(...), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGraphResponse: + """Retrieves detailed information about a specific graph by ID.""" + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified collection associated with the given graph.", + 403, + ) + + list_graphs_response = await self.services.graph.list_graphs( + # user_ids=None, + graph_ids=[collection_id], + offset=0, + 
limit=1, + ) + return list_graphs_response["results"][0] # type: ignore + + @self.router.post( + "/graphs/{collection_id}/communities/build", + dependencies=[Depends(self.rate_limit_dependency)], + ) + @self.base_endpoint + async def build_communities( + collection_id: UUID = Path( + ..., description="The unique identifier of the collection" + ), + graph_enrichment_settings: Optional[dict] = Body( + default=None, + description="Settings for the graph enrichment process.", + ), + run_with_orchestration: Optional[bool] = Body(True), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGenericMessageResponse: + """Creates communities in the graph by analyzing entity + relationships and similarities. + + Communities are created through the following process: + 1. Analyzes entity relationships and metadata to build a similarity graph + 2. Applies advanced community detection algorithms (e.g. Leiden) to identify densely connected groups + 3. Creates hierarchical community structure with multiple granularity levels + 4. 
Generates natural language summaries and statistical insights for each community + + The resulting communities can be used to: + - Understand high-level graph structure and organization + - Identify key entity groupings and their relationships + - Navigate and explore the graph at different levels of detail + - Generate insights about entity clusters and their characteristics + + The community detection process is configurable through settings like: + - Community detection algorithm parameters + - Summary generation prompt + """ + collections_overview_response = ( + await self.services.management.collections_overview( + user_ids=[auth_user.id], + collection_ids=[collection_id], + offset=0, + limit=1, + ) + )["results"] + if len(collections_overview_response) == 0: # type: ignore + raise R2RException("Collection not found.", 404) + + # Check user permissions for graph + if ( + not auth_user.is_superuser + and collections_overview_response[0].owner_id != auth_user.id # type: ignore + ): + raise R2RException( + "Only superusers can `build communities` for a graph they do not own.", + 403, + ) + + # If no collection ID is provided, use the default user collection + # id = generate_default_user_collection_id(auth_user.id) + + # Apply runtime settings overrides + server_graph_enrichment_settings = ( + self.providers.database.config.graph_enrichment_settings + ) + if graph_enrichment_settings: + server_graph_enrichment_settings = update_settings_from_dict( + server_graph_enrichment_settings, graph_enrichment_settings + ) + + workflow_input = { + "collection_id": str(collection_id), + "graph_enrichment_settings": server_graph_enrichment_settings.model_dump_json(), + "user": auth_user.json(), + } + + if run_with_orchestration: + try: + return await self.providers.orchestration.run_workflow( # type: ignore + "graph-clustering", {"request": workflow_input}, {} + ) + return GenericMessageResponse( + message="Graph communities created successfully." 
+ ) # type: ignore + + except Exception as e: # TODO: Need to find specific error (gRPC most likely?) + logger.error( + f"Error running orchestrated community building: {e} \n\nAttempting to run without orchestration." + ) + from core.main.orchestration import ( + simple_graph_search_results_factory, + ) + + logger.info("Running build-communities without orchestration.") + simple_graph_search_results = simple_graph_search_results_factory( + self.services.graph + ) + await simple_graph_search_results["graph-clustering"]( + workflow_input + ) + return { # type: ignore + "message": "Graph communities created successfully.", + "task_id": None, + } + + @self.router.post( + "/graphs/{collection_id}/reset", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Reset a graph back to the initial state.", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.graphs.reset( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + )"""), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.reset({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/reset" \\ + -H "Authorization: Bearer YOUR_API_KEY" """), + }, + ] + }, + ) + @self.base_endpoint + async def reset( + collection_id: UUID = Path(...), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Deletes a graph and all its associated data. + + This endpoint permanently removes the specified graph along with + all entities and relationships that belong to only this graph. 
The + original source entities and relationships extracted from + underlying documents are not deleted and are managed through the + document lifecycle. + """ + if not auth_user.is_superuser: + raise R2RException("Only superusers can reset a graph", 403) + + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + await self.services.graph.reset_graph(id=collection_id) + # await _pull(collection_id, auth_user) + return GenericBooleanResponse(success=True) # type: ignore + + # update graph + @self.router.post( + "/graphs/{collection_id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Update graph", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.graphs.update( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + graph={ + "name": "New Name", + "description": "New Description" + } + )"""), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.update({ + collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + name: "New Name", + description: "New Description", + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def update_graph( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph to update", + ), + name: Optional[str] = Body( + None, description="The name of the graph" + ), + description: Optional[str] = Body( + None, description="An optional description of the graph" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGraphResponse: + """Update an existing graphs's 
configuration. + + This endpoint allows updating the name and description of an + existing collection. The user must have appropriate permissions to + modify the collection. + """ + if not auth_user.is_superuser: + raise R2RException( + "Only superusers can update graph details", 403 + ) + + if ( + not auth_user.is_superuser + and id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + return await self.services.graph.update_graph( # type: ignore + collection_id, + name=name, + description=description, + ) + + @self.router.get( + "/graphs/{collection_id}/entities", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.graphs.list_entities(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7") + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.listEntities({ + collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + }); + } + + main(); + """), + }, + ], + }, + ) + @self.base_endpoint + async def get_entities( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph to list entities from.", + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. 
Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedEntitiesResponse: + """Lists all entities in the graph with pagination support.""" + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + entities, count = await self.services.graph.get_entities( + parent_id=collection_id, + offset=offset, + limit=limit, + ) + + return entities, { # type: ignore + "total_entries": count, + } + + @self.router.post( + "/graphs/{collection_id}/entities/export", + summary="Export graph entities to CSV", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + response = client.graphs.export_entities( + collection_id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + output_path="export.csv", + columns=["id", "title", "created_at"], + include_header=True, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + await client.graphs.exportEntities({ + collectionId: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + outputPath: "export.csv", + columns: ["id", "title", "created_at"], + includeHeader: true, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "http://127.0.0.1:7272/v3/graphs/export_entities" \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -H "Accept: text/csv" \ + -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \ + --output export.csv + """), + }, + ] + }, + ) + @self.base_endpoint + async def 
export_entities( + background_tasks: BackgroundTasks, + collection_id: UUID = Path( + ..., + description="The ID of the collection to export entities from.", + ), + columns: Optional[list[str]] = Body( + None, description="Specific columns to export" + ), + filters: Optional[dict] = Body( + None, description="Filters to apply to the export" + ), + include_header: Optional[bool] = Body( + True, description="Whether to include column headers" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> FileResponse: + """Export documents as a downloadable CSV file.""" + + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can export data.", + 403, + ) + + ( + csv_file_path, + temp_file, + ) = await self.services.management.export_graph_entities( + id=collection_id, + columns=columns, + filters=filters, + include_header=include_header + if include_header is not None + else True, + ) + + background_tasks.add_task(temp_file.close) + + return FileResponse( + path=csv_file_path, + media_type="text/csv", + filename="documents_export.csv", + ) + + @self.router.post( + "/graphs/{collection_id}/entities", + dependencies=[Depends(self.rate_limit_dependency)], + ) + @self.base_endpoint + async def create_entity( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph to add the entity to.", + ), + name: str = Body( + ..., description="The name of the entity to create." + ), + description: str = Body( + ..., description="The description of the entity to create." + ), + category: Optional[str] = Body( + None, description="The category of the entity to create." + ), + metadata: Optional[dict] = Body( + None, description="The metadata of the entity to create." 
+ ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedEntityResponse: + """Creates a new entity in the graph.""" + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + return await self.services.graph.create_entity( # type: ignore + name=name, + description=description, + parent_id=collection_id, + category=category, + metadata=metadata, + ) + + @self.router.post( + "/graphs/{collection_id}/relationships", + dependencies=[Depends(self.rate_limit_dependency)], + ) + @self.base_endpoint + async def create_relationship( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph to add the relationship to.", + ), + subject: str = Body( + ..., description="The subject of the relationship to create." + ), + subject_id: UUID = Body( + ..., + description="The ID of the subject of the relationship to create.", + ), + predicate: str = Body( + ..., description="The predicate of the relationship to create." + ), + object: str = Body( + ..., description="The object of the relationship to create." + ), + object_id: UUID = Body( + ..., + description="The ID of the object of the relationship to create.", + ), + description: str = Body( + ..., + description="The description of the relationship to create.", + ), + weight: float = Body( + 1.0, description="The weight of the relationship to create." + ), + metadata: Optional[dict] = Body( + None, description="The metadata of the relationship to create." 
+ ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedRelationshipResponse: + """Creates a new relationship in the graph.""" + if not auth_user.is_superuser: + raise R2RException( + "Only superusers can create relationships.", 403 + ) + + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + return await self.services.graph.create_relationship( # type: ignore + subject=subject, + subject_id=subject_id, + predicate=predicate, + object=object, + object_id=object_id, + description=description, + weight=weight, + metadata=metadata, + parent_id=collection_id, + ) + + @self.router.post( + "/graphs/{collection_id}/relationships/export", + summary="Export graph relationships to CSV", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
+ + response = client.graphs.export_entities( + collection_id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + output_path="export.csv", + columns=["id", "title", "created_at"], + include_header=True, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + await client.graphs.exportEntities({ + collectionId: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + outputPath: "export.csv", + columns: ["id", "title", "created_at"], + includeHeader: true, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "http://127.0.0.1:7272/v3/graphs/export_relationships" \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -H "Accept: text/csv" \ + -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \ + --output export.csv + """), + }, + ] + }, + ) + @self.base_endpoint + async def export_relationships( + background_tasks: BackgroundTasks, + collection_id: UUID = Path( + ..., + description="The ID of the document to export entities from.", + ), + columns: Optional[list[str]] = Body( + None, description="Specific columns to export" + ), + filters: Optional[dict] = Body( + None, description="Filters to apply to the export" + ), + include_header: Optional[bool] = Body( + True, description="Whether to include column headers" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> FileResponse: + """Export documents as a downloadable CSV file.""" + + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can export data.", + 403, + ) + + ( + csv_file_path, + temp_file, + ) = await self.services.management.export_graph_relationships( + id=collection_id, + columns=columns, + filters=filters, + include_header=include_header + if include_header is not None + else True, + ) + + background_tasks.add_task(temp_file.close) + + 
return FileResponse( + path=csv_file_path, + media_type="text/csv", + filename="documents_export.csv", + ) + + @self.router.get( + "/graphs/{collection_id}/entities/{entity_id}", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.graphs.get_entity( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + entity_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.get_entity({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + entityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_entity( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph containing the entity.", + ), + entity_id: UUID = Path( + ..., description="The ID of the entity to retrieve." 
+ ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedEntityResponse: + """Retrieves a specific entity by its ID.""" + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + result = await self.providers.database.graphs_handler.entities.get( + parent_id=collection_id, + store_type=StoreType.GRAPHS, + offset=0, + limit=1, + entity_ids=[entity_id], + ) + if len(result) == 0 or len(result[0]) == 0: + raise R2RException("Entity not found", 404) + return result[0][0] + + @self.router.post( + "/graphs/{collection_id}/entities/{entity_id}", + dependencies=[Depends(self.rate_limit_dependency)], + ) + @self.base_endpoint + async def update_entity( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph containing the entity.", + ), + entity_id: UUID = Path( + ..., description="The ID of the entity to update." + ), + name: Optional[str] = Body( + ..., description="The updated name of the entity." + ), + description: Optional[str] = Body( + None, description="The updated description of the entity." + ), + category: Optional[str] = Body( + None, description="The updated category of the entity." + ), + metadata: Optional[dict] = Body( + None, description="The updated metadata of the entity." 
+ ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedEntityResponse: + """Updates an existing entity in the graph.""" + if not auth_user.is_superuser: + raise R2RException( + "Only superusers can update graph entities.", 403 + ) + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + return await self.services.graph.update_entity( # type: ignore + entity_id=entity_id, + name=name, + category=category, + description=description, + metadata=metadata, + ) + + @self.router.delete( + "/graphs/{collection_id}/entities/{entity_id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Remove an entity", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.graphs.remove_entity( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + entity_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.removeEntity({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + entityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def delete_entity( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph to remove the entity from.", + ), + entity_id: UUID = Path( + ..., + description="The ID of the entity to remove from the graph.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Removes an entity from the graph.""" + if not auth_user.is_superuser: + raise R2RException( + "Only superusers can delete graph details.", 403 + ) + + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + await self.services.graph.delete_entity( + parent_id=collection_id, + entity_id=entity_id, + ) + + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.get( + "/graphs/{collection_id}/relationships", + dependencies=[Depends(self.rate_limit_dependency)], + description="Lists all relationships in the graph with pagination support.", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.graphs.list_relationships(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7") + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.listRelationships({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + }); + } + + main(); + """), + }, + ], + }, + ) + @self.base_endpoint + async def get_relationships( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph to list relationships from.", + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedRelationshipsResponse: + """Lists all relationships in the graph with pagination support.""" + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + relationships, count = await self.services.graph.get_relationships( + parent_id=collection_id, + offset=offset, + limit=limit, + ) + + return relationships, { # type: ignore + "total_entries": count, + } + + @self.router.get( + "/graphs/{collection_id}/relationships/{relationship_id}", + dependencies=[Depends(self.rate_limit_dependency)], + description="Retrieves a specific relationship by its ID.", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.graphs.get_relationship( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + relationship_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.getRelationship({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + relationshipId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """), + }, + ], + }, + ) + @self.base_endpoint + async def get_relationship( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph containing the relationship.", + ), + relationship_id: UUID = Path( + ..., description="The ID of the relationship to retrieve." + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedRelationshipResponse: + """Retrieves a specific relationship by its ID.""" + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + results = ( + await self.providers.database.graphs_handler.relationships.get( + parent_id=collection_id, + store_type=StoreType.GRAPHS, + offset=0, + limit=1, + relationship_ids=[relationship_id], + ) + ) + if len(results) == 0 or len(results[0]) == 0: + raise R2RException("Relationship not found", 404) + return results[0][0] + + @self.router.post( + "/graphs/{collection_id}/relationships/{relationship_id}", + dependencies=[Depends(self.rate_limit_dependency)], + ) + @self.base_endpoint + async def update_relationship( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph containing the relationship.", + ), + relationship_id: UUID = Path( + ..., description="The ID of the relationship to update." 
+ ), + subject: Optional[str] = Body( + ..., description="The updated subject of the relationship." + ), + subject_id: Optional[UUID] = Body( + ..., description="The updated subject ID of the relationship." + ), + predicate: Optional[str] = Body( + ..., description="The updated predicate of the relationship." + ), + object: Optional[str] = Body( + ..., description="The updated object of the relationship." + ), + object_id: Optional[UUID] = Body( + ..., description="The updated object ID of the relationship." + ), + description: Optional[str] = Body( + None, + description="The updated description of the relationship.", + ), + weight: Optional[float] = Body( + None, description="The updated weight of the relationship." + ), + metadata: Optional[dict] = Body( + None, description="The updated metadata of the relationship." + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedRelationshipResponse: + """Updates an existing relationship in the graph.""" + if not auth_user.is_superuser: + raise R2RException( + "Only superusers can update graph details", 403 + ) + + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + return await self.services.graph.update_relationship( # type: ignore + relationship_id=relationship_id, + subject=subject, + subject_id=subject_id, + predicate=predicate, + object=object, + object_id=object_id, + description=description, + weight=weight, + metadata=metadata, + ) + + @self.router.delete( + "/graphs/{collection_id}/relationships/{relationship_id}", + dependencies=[Depends(self.rate_limit_dependency)], + description="Removes a relationship from the graph.", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.graphs.delete_relationship( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + relationship_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.deleteRelationship({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + relationshipId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """), + }, + ], + }, + ) + @self.base_endpoint + async def delete_relationship( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph to remove the relationship from.", + ), + relationship_id: UUID = Path( + ..., + description="The ID of the relationship to remove from the graph.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Removes a relationship from the graph.""" + if not auth_user.is_superuser: + raise R2RException( + "Only superusers can delete a relationship.", 403 + ) + + if ( + not auth_user.is_superuser + and collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + await self.services.graph.delete_relationship( + parent_id=collection_id, + relationship_id=relationship_id, + ) + + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.post( + "/graphs/{collection_id}/communities", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Create a new community", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.graphs.create_community( + collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + name="My Community", + summary="A summary of the community", + findings=["Finding 1", "Finding 2"], + rating=5, + rating_explanation="This is a rating explanation", + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.createCommunity({ + collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + name: "My Community", + summary: "A summary of the community", + findings: ["Finding 1", "Finding 2"], + rating: 5, + ratingExplanation: "This is a rating explanation", + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def create_community( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph to create the community in.", + ), + name: str = Body(..., description="The name of the community"), + summary: str = Body(..., description="A summary of the community"), + findings: Optional[list[str]] = Body( + default=[], description="Findings about the community" + ), + rating: Optional[float] = Body( + default=5, ge=1, le=10, description="Rating between 1 and 10" + ), + rating_explanation: Optional[str] = Body( + default="", description="Explanation for the rating" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedCommunityResponse: + """Creates a new community in the graph. + + While communities are typically built automatically via the /graphs/{id}/communities/build endpoint, + this endpoint allows you to manually create your own communities. 
+ + This can be useful when you want to: + - Define custom groupings of entities based on domain knowledge + - Add communities that weren't detected by the automatic process + - Create hierarchical organization structures + - Tag groups of entities with specific metadata + + The created communities will be integrated with any existing automatically detected communities + in the graph's community structure. + """ + if not auth_user.is_superuser: + raise R2RException( + "Only superusers can create a community.", 403 + ) + + if ( + not auth_user.is_superuser + and collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + return await self.services.graph.create_community( # type: ignore + parent_id=collection_id, + name=name, + summary=summary, + findings=findings, + rating=rating, + rating_explanation=rating_explanation, + ) + + @self.router.get( + "/graphs/{collection_id}/communities", + dependencies=[Depends(self.rate_limit_dependency)], + summary="List communities", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.graphs.list_communities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.listCommunities({ + collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_communities( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph to get communities for.", + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedCommunitiesResponse: + """Lists all communities in the graph with pagination support.""" + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + communities, count = await self.services.graph.get_communities( + parent_id=collection_id, + offset=offset, + limit=limit, + ) + + return communities, { # type: ignore + "total_entries": count, + } + + @self.router.get( + "/graphs/{collection_id}/communities/{community_id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Retrieve a community", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.graphs.get_community(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.getCommunity({ + collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_community( + collection_id: UUID = Path( + ..., + description="The ID of the collection to get communities for.", + ), + community_id: UUID = Path( + ..., + description="The ID of the community to get.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedCommunityResponse: + """Retrieves a specific community by its ID.""" + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + results = ( + await self.providers.database.graphs_handler.communities.get( + parent_id=collection_id, + community_ids=[community_id], + store_type=StoreType.GRAPHS, + offset=0, + limit=1, + ) + ) + if len(results) == 0 or len(results[0]) == 0: + raise R2RException("Community not found", 404) + return results[0][0] + + @self.router.delete( + "/graphs/{collection_id}/communities/{community_id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Delete a community", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + response = client.graphs.delete_community( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + community_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.graphs.deleteCommunity({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + communityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def delete_community( + collection_id: UUID = Path( + ..., + description="The collection ID corresponding to the graph to delete the community from.", + ), + community_id: UUID = Path( + ..., + description="The ID of the community to delete.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "Only superusers can delete communities", 403 + ) + + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + await self.services.graph.delete_community( + parent_id=collection_id, + community_id=community_id, + ) + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.post( + "/graphs/{collection_id}/communities/export", + summary="Export document communities to CSV", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
+ + response = client.graphs.export_communities( + collection_id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + output_path="export.csv", + columns=["id", "title", "created_at"], + include_header=True, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + await client.graphs.exportCommunities({ + collectionId: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", + outputPath: "export.csv", + columns: ["id", "title", "created_at"], + includeHeader: true, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "http://127.0.0.1:7272/v3/graphs/export_communities" \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -H "Accept: text/csv" \ + -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \ + --output export.csv + """), + }, + ] + }, + ) + @self.base_endpoint + async def export_communities( + background_tasks: BackgroundTasks, + collection_id: UUID = Path( + ..., + description="The ID of the document to export entities from.", + ), + columns: Optional[list[str]] = Body( + None, description="Specific columns to export" + ), + filters: Optional[dict] = Body( + None, description="Filters to apply to the export" + ), + include_header: Optional[bool] = Body( + True, description="Whether to include column headers" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> FileResponse: + """Export documents as a downloadable CSV file.""" + + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can export data.", + 403, + ) + + ( + csv_file_path, + temp_file, + ) = await self.services.management.export_graph_communities( + id=collection_id, + columns=columns, + filters=filters, + include_header=include_header + if include_header is not None + else True, + ) + + background_tasks.add_task(temp_file.close) + + 
return FileResponse( + path=csv_file_path, + media_type="text/csv", + filename="documents_export.csv", + ) + + @self.router.post( + "/graphs/{collection_id}/communities/{community_id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Update community", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.graphs.update_community( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + community_update={ + "metadata": { + "topic": "Technology", + "description": "Tech companies and products" + } + } + )"""), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + async function main() { + const response = await client.graphs.updateCommunity({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + communityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + communityUpdate: { + metadata: { + topic: "Technology", + description: "Tech companies and products" + } + } + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def update_community( + collection_id: UUID = Path(...), + community_id: UUID = Path(...), + name: Optional[str] = Body(None), + summary: Optional[str] = Body(None), + findings: Optional[list[str]] = Body(None), + rating: Optional[float] = Body(default=None, ge=1, le=10), + rating_explanation: Optional[str] = Body(None), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedCommunityResponse: + """Updates an existing community in the graph.""" + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "Only superusers can update communities.", 403 + ) + + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user 
does not have access to the collection associated with the given graph.", + 403, + ) + + return await self.services.graph.update_community( # type: ignore + community_id=community_id, + name=name, + summary=summary, + findings=findings, + rating=rating, + rating_explanation=rating_explanation, + ) + + @self.router.post( + "/graphs/{collection_id}/pull", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Pull latest entities to the graph", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.graphs.pull( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + )"""), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + async function main() { + const response = await client.graphs.pull({ + collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def pull( + collection_id: UUID = Path( + ..., description="The ID of the graph to initialize." + ), + force: Optional[bool] = Body( + False, + description="If true, forces a re-pull of all entities and relationships.", + ), + # document_ids: list[UUID] = Body( + # ..., description="List of document IDs to add to the graph." + # ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Adds documents to a graph by copying their entities and + relationships. + + This endpoint: + 1. Copies document entities to the graphs_entities table + 2. Copies document relationships to the graphs_relationships table + 3. 
Associates the documents with the graph + + When a document is added: + - Its entities and relationships are copied to graph-specific tables + - Existing entities/relationships are updated by merging their properties + - The document ID is recorded in the graph's document_ids array + + Documents added to a graph will contribute their knowledge to: + - Graph analysis and querying + - Community detection + - Knowledge graph enrichment + + The user must have access to both the graph and the documents being added. + """ + + collections_overview_response = ( + await self.services.management.collections_overview( + user_ids=[auth_user.id], + collection_ids=[collection_id], + offset=0, + limit=1, + ) + )["results"] + if len(collections_overview_response) == 0: # type: ignore + raise R2RException("Collection not found.", 404) + + # Check user permissions for graph + if ( + not auth_user.is_superuser + and collections_overview_response[0].owner_id != auth_user.id # type: ignore + ): + raise R2RException("Only superusers can `pull` a graph.", 403) + + if ( + # not auth_user.is_superuser + collection_id not in auth_user.collection_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + list_graphs_response = await self.services.graph.list_graphs( + # user_ids=None, + graph_ids=[collection_id], + offset=0, + limit=1, + ) + if len(list_graphs_response["results"]) == 0: # type: ignore + raise R2RException("Graph not found", 404) + collection_id = list_graphs_response["results"][0].collection_id # type: ignore + documents: list[DocumentResponse] = [] + document_req = await self.providers.database.collections_handler.documents_in_collection( + collection_id, offset=0, limit=100 + ) + results = cast(list[DocumentResponse], document_req["results"]) + documents.extend(results) + + while len(results) == 100: + document_req = await 
self.providers.database.collections_handler.documents_in_collection( + collection_id, offset=len(documents), limit=100 + ) + results = cast(list[DocumentResponse], document_req["results"]) + documents.extend(results) + + success = False + + for document in documents: + entities = ( + await self.providers.database.graphs_handler.entities.get( + parent_id=document.id, + store_type=StoreType.DOCUMENTS, + offset=0, + limit=100, + ) + ) + has_document = ( + await self.providers.database.graphs_handler.has_document( + collection_id, document.id + ) + ) + if has_document: + logger.info( + f"Document {document.id} is already in graph {collection_id}, skipping." + ) + continue + if len(entities[0]) == 0: + if not force: + logger.warning( + f"Document {document.id} has no entities, extraction may not have been called, skipping." + ) + continue + else: + logger.warning( + f"Document {document.id} has no entities, but force=True, continuing." + ) + + success = ( + await self.providers.database.graphs_handler.add_documents( + id=collection_id, + document_ids=[document.id], + ) + ) + if not success: + logger.warning( + f"No documents were added to graph {collection_id}, marking as failed." 
+ ) + + if success: + await self.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.SUCCESS, + ) + + return GenericBooleanResponse(success=success) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/indices_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/indices_router.py new file mode 100644 index 00000000..29b75226 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/indices_router.py @@ -0,0 +1,576 @@ +import logging +import textwrap +from typing import Optional + +from fastapi import Body, Depends, Path, Query + +from core.base import IndexConfig, R2RException +from core.base.abstractions import VectorTableName +from core.base.api.models import ( + VectorIndexResponse, + VectorIndicesResponse, + WrappedGenericMessageResponse, + WrappedVectorIndexResponse, + WrappedVectorIndicesResponse, +) + +from ...abstractions import R2RProviders, R2RServices +from ...config import R2RConfig +from .base_router import BaseRouterV3 + +logger = logging.getLogger() + + +class IndicesRouter(BaseRouterV3): + def __init__( + self, providers: R2RProviders, services: R2RServices, config: R2RConfig + ): + logging.info("Initializing IndicesRouter") + super().__init__(providers, services, config) + + def _setup_routes(self): + ## TODO - Allow developer to pass the index id with the request + @self.router.post( + "/indices", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Create Vector Index", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + # Create an HNSW index for efficient similarity search + result = client.indices.create( + config={ + "table_name": "chunks", # The table containing vector embeddings + "index_method": "hnsw", # Hierarchical Navigable Small World graph + "index_measure": "cosine_distance", # Similarity measure + "index_arguments": { + "m": 16, # Number of connections per layer + "ef_construction": 64,# Size of dynamic candidate list for construction + "ef": 40, # Size of dynamic candidate list for search + }, + "index_name": "my_document_embeddings_idx", + "index_column": "embedding", + "concurrently": True # Build index without blocking table writes + }, + run_with_orchestration=True # Run as orchestrated task for large indices + ) + + # Create an IVF-Flat index for balanced performance + result = client.indices.create( + config={ + "table_name": "chunks", + "index_method": "ivf_flat", # Inverted File with Flat storage + "index_measure": "l2_distance", + "index_arguments": { + "lists": 100, # Number of cluster centroids + "probe": 10, # Number of clusters to search + }, + "index_name": "my_ivf_embeddings_idx", + "index_column": "embedding", + "concurrently": True + } + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.indicies.create({ + config: { + tableName: "vectors", + indexMethod: "hnsw", + indexMeasure: "cosine_distance", + indexArguments: { + m: 16, + ef_construction: 64, + ef: 40 + }, + indexName: "my_document_embeddings_idx", + indexColumn: "embedding", + concurrently: true + }, + runWithOrchestration: true + }); + } + + main(); + """), + }, + { + "lang": "Shell", + "source": textwrap.dedent(""" + # Create HNSW Index + curl -X POST "https://api.example.com/indices" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "config": { + "table_name": "vectors", + 
"index_method": "hnsw", + "index_measure": "cosine_distance", + "index_arguments": { + "m": 16, + "ef_construction": 64, + "ef": 40 + }, + "index_name": "my_document_embeddings_idx", + "index_column": "embedding", + "concurrently": true + }, + "run_with_orchestration": true + }' + + # Create IVF-Flat Index + curl -X POST "https://api.example.com/indices" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "config": { + "table_name": "vectors", + "index_method": "ivf_flat", + "index_measure": "l2_distance", + "index_arguments": { + "lists": 100, + "probe": 10 + }, + "index_name": "my_ivf_embeddings_idx", + "index_column": "embedding", + "concurrently": true + } + }' + """), + }, + ] + }, + ) + @self.base_endpoint + async def create_index( + config: IndexConfig, + run_with_orchestration: Optional[bool] = Body( + True, + description="Whether to run index creation as an orchestrated task (recommended for large indices)", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGenericMessageResponse: + """Create a new vector similarity search index in over the target + table. Allowed tables include 'vectors', 'entity', + 'document_collections'. Vectors correspond to the chunks of text + that are indexed for similarity search, whereas entity and + document_collections are created during knowledge graph + construction. + + This endpoint creates a database index optimized for efficient similarity search over vector embeddings. + It supports two main indexing methods: + + 1. 
HNSW (Hierarchical Navigable Small World): + - Best for: High-dimensional vectors requiring fast approximate nearest neighbor search + - Pros: Very fast search, good recall, memory-resident for speed + - Cons: Slower index construction, more memory usage + - Key parameters: + * m: Number of connections per layer (higher = better recall but more memory) + * ef_construction: Build-time search width (higher = better recall but slower build) + * ef: Query-time search width (higher = better recall but slower search) + + 2. IVF-Flat (Inverted File with Flat Storage): + - Best for: Balance between build speed, search speed, and recall + - Pros: Faster index construction, less memory usage + - Cons: Slightly slower search than HNSW + - Key parameters: + * lists: Number of clusters (usually sqrt(n) where n is number of vectors) + * probe: Number of nearest clusters to search + + Supported similarity measures: + - cosine_distance: Best for comparing semantic similarity + - l2_distance: Best for comparing absolute distances + - ip_distance: Best for comparing raw dot products + + Notes: + - Index creation can be resource-intensive for large datasets + - Use run_with_orchestration=True for large indices to prevent timeouts + - The 'concurrently' option allows other operations while building + - Index names must be unique per table + """ + # TODO: Implement index creation logic + logger.info( + f"Creating vector index for {config.table_name} with method {config.index_method}, measure {config.index_measure}, concurrently {config.concurrently}" + ) + + result = await self.providers.orchestration.run_workflow( + "create-vector-index", + { + "request": { + "table_name": config.table_name, + "index_method": config.index_method, + "index_measure": config.index_measure, + "index_name": config.index_name, + "index_column": config.index_column, + "index_arguments": config.index_arguments, + "concurrently": config.concurrently, + }, + }, + options={ + "additional_metadata": {}, + }, + ) 
+ + return result # type: ignore + + @self.router.get( + "/indices", + dependencies=[Depends(self.rate_limit_dependency)], + summary="List Vector Indices", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + + # List all indices + indices = client.indices.list( + offset=0, + limit=10 + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.indicies.list({ + offset: 0, + limit: 10, + filters: { table_name: "vectors" } + } + + main(); + """), + }, + { + "lang": "Shell", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/indices?offset=0&limit=10" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -H "Content-Type: application/json" + + # With filters + curl -X GET "https://api.example.com/indices?offset=0&limit=10&filters={\"table_name\":\"vectors\"}" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -H "Content-Type: application/json" + """), + }, + ] + }, + ) + @self.base_endpoint + async def list_indices( + # filters: list[str] = Query([]), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedVectorIndicesResponse: + """List existing vector similarity search indices with pagination + support. 
+ + Returns details about each index including: + - Name and table name + - Indexing method and parameters + - Size and row count + - Creation timestamp and last updated + - Performance statistics (if available) + + The response can be filtered using the filter_by parameter to narrow down results + based on table name, index method, or other attributes. + """ + # TODO: Implement index listing logic + indices_data = ( + await self.providers.database.chunks_handler.list_indices( + offset=offset, limit=limit + ) + ) + + formatted_indices = VectorIndicesResponse( + indices=[ + VectorIndexResponse(index=index_data) + for index_data in indices_data["indices"] + ] + ) + + return ( # type: ignore + formatted_indices, + {"total_entries": indices_data["total_entries"]}, + ) + + @self.router.get( + "/indices/{table_name}/{index_name}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Get Vector Index Details", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + + # Get detailed information about a specific index + index = client.indices.retrieve("index_1") + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.indicies.retrieve({ + indexName: "index_1", + tableName: "vectors" + }); + + console.log(response); + } + + main(); + """), + }, + { + "lang": "Shell", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/indices/vectors/index_1" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_index( + table_name: VectorTableName = Path( + ..., + description="The table of vector embeddings to delete (e.g. 
`vectors`, `entity`, `document_collections`)", + ), + index_name: str = Path( + ..., description="The name of the index to delete" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedVectorIndexResponse: + """Get detailed information about a specific vector index. + + Returns comprehensive information about the index including: + - Configuration details (method, measure, parameters) + - Current size and row count + - Build progress (if still under construction) + - Performance statistics: + * Average query time + * Memory usage + * Cache hit rates + * Recent query patterns + - Maintenance information: + * Last vacuum + * Fragmentation level + * Recommended optimizations + """ + # TODO: Implement get index logic + indices = ( + await self.providers.database.chunks_handler.list_indices( + filters={ + "index_name": index_name, + "table_name": table_name, + }, + limit=1, + offset=0, + ) + ) + if len(indices["indices"]) != 1: + raise R2RException( + f"Index '{index_name}' not found", status_code=404 + ) + return {"index": indices["indices"][0]} # type: ignore + + # TODO - Implement update index + # @self.router.post( + # "/indices/{name}", + # summary="Update Vector Index", + # openapi_extra={ + # "x-codeSamples": [ + # { + # "lang": "Python", + # "source": """ + # from r2r import R2RClient + + # client = R2RClient() + + # # Update HNSW index parameters + # result = client.indices.update( + # "550e8400-e29b-41d4-a716-446655440000", + # config={ + # "index_arguments": { + # "ef": 80, # Increase search quality + # "m": 24 # Increase connections per layer + # }, + # "concurrently": True + # }, + # run_with_orchestration=True + # )""", + # }, + # { + # "lang": "Shell", + # "source": """ + # curl -X PUT "https://api.example.com/indices/550e8400-e29b-41d4-a716-446655440000" \\ + # -H "Content-Type: application/json" \\ + # -H "Authorization: Bearer YOUR_API_KEY" \\ + # -d '{ + # "config": { + # "index_arguments": { + # "ef": 80, + # "m": 24 + # }, + # 
"concurrently": true + # }, + # "run_with_orchestration": true + # }'""", + # }, + # ] + # }, + # ) + # @self.base_endpoint + # async def update_index( + # id: UUID = Path(...), + # config: IndexConfig = Body(...), + # run_with_orchestration: Optional[bool] = Body(True), + # auth_user=Depends(self.providers.auth.auth_wrapper()), + # ): # -> WrappedUpdateIndexResponse: + # """ + # Update an existing index's configuration. + # """ + # # TODO: Implement index update logic + # pass + + @self.router.delete( + "/indices/{table_name}/{index_name}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Delete Vector Index", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + + # Delete an index with orchestration for cleanup + result = client.indices.delete( + index_name="index_1", + table_name="vectors", + run_with_orchestration=True + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.indicies.delete({ + indexName: "index_1" + tableName: "vectors" + }); + + console.log(response); + } + + main(); + """), + }, + { + "lang": "Shell", + "source": textwrap.dedent(""" + curl -X DELETE "https://api.example.com/indices/index_1" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def delete_index( + table_name: VectorTableName = Path( + default=..., + description="The table of vector embeddings to delete (e.g. 
`vectors`, `entity`, `document_collections`)", + ), + index_name: str = Path( + ..., description="The name of the index to delete" + ), + # concurrently: bool = Body( + # default=True, + # description="Whether to delete the index concurrently (recommended for large indices)", + # ), + # run_with_orchestration: Optional[bool] = Body(True), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGenericMessageResponse: + """Delete an existing vector similarity search index. + + This endpoint removes the specified index from the database. Important considerations: + + - Deletion is permanent and cannot be undone + - Underlying vector data remains intact + - Queries will fall back to sequential scan + - Running queries during deletion may be slower + - Use run_with_orchestration=True for large indices to prevent timeouts + - Consider index dependencies before deletion + + The operation returns immediately but cleanup may continue in background. + """ + logger.info( + f"Deleting vector index {index_name} from table {table_name}" + ) + + return await self.providers.orchestration.run_workflow( # type: ignore + "delete-vector-index", + { + "request": { + "index_name": index_name, + "table_name": table_name, + "concurrently": True, + }, + }, + options={ + "additional_metadata": {}, + }, + ) diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/prompts_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/prompts_router.py new file mode 100644 index 00000000..55512143 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/prompts_router.py @@ -0,0 +1,387 @@ +import logging +import textwrap +from typing import Optional + +from fastapi import Body, Depends, Path, Query + +from core.base import R2RException +from core.base.api.models import ( + GenericBooleanResponse, + GenericMessageResponse, + WrappedBooleanResponse, + WrappedGenericMessageResponse, + WrappedPromptResponse, + WrappedPromptsResponse, +) + +from 
...abstractions import R2RProviders, R2RServices +from ...config import R2RConfig +from .base_router import BaseRouterV3 + + +class PromptsRouter(BaseRouterV3): + def __init__( + self, providers: R2RProviders, services: R2RServices, config: R2RConfig + ): + logging.info("Initializing PromptsRouter") + super().__init__(providers, services, config) + + def _setup_routes(self): + @self.router.post( + "/prompts", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Create a new prompt", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.prompts.create( + name="greeting_prompt", + template="Hello, {name}!", + input_types={"name": "string"} + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.prompts.create({ + name: "greeting_prompt", + template: "Hello, {name}!", + inputTypes: { name: "string" }, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/prompts" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -H "Content-Type: application/json" \\ + -d '{"name": "greeting_prompt", "template": "Hello, {name}!", "input_types": {"name": "string"}}' + """), + }, + ] + }, + ) + @self.base_endpoint + async def create_prompt( + name: str = Body(..., description="The name of the prompt"), + template: str = Body( + ..., description="The template string for the prompt" + ), + input_types: dict[str, str] = Body( + default={}, + description="A dictionary mapping input names to their types", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGenericMessageResponse: + """Create a new prompt with the given configuration. 
+ + This endpoint allows superusers to create a new prompt with a + specified name, template, and input types. + """ + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can create prompts.", + 403, + ) + result = await self.services.management.add_prompt( + name, template, input_types + ) + return GenericMessageResponse(message=result) # type: ignore + + @self.router.get( + "/prompts", + dependencies=[Depends(self.rate_limit_dependency)], + summary="List all prompts", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.prompts.list() + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.prompts.list(); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/v3/prompts" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_prompts( + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedPromptsResponse: + """List all available prompts. + + This endpoint retrieves a list of all prompts in the system. Only + superusers can access this endpoint. 
+ """ + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can list prompts.", + 403, + ) + get_prompts_response = ( + await self.services.management.get_all_prompts() + ) + + return ( # type: ignore + get_prompts_response["results"], + { + "total_entries": get_prompts_response["total_entries"], + }, + ) + + @self.router.post( + "/prompts/{name}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Get a specific prompt", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.prompts.get( + "greeting_prompt", + inputs={"name": "John"}, + prompt_override="Hi, {name}!" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.prompts.retrieve({ + name: "greeting_prompt", + inputs: { name: "John" }, + promptOverride: "Hi, {name}!", + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/prompts/greeting_prompt?inputs=%7B%22name%22%3A%22John%22%7D&prompt_override=Hi%2C%20%7Bname%7D!" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_prompt( + name: str = Path(..., description="Prompt name"), + inputs: Optional[dict[str, str]] = Body( + None, description="Prompt inputs" + ), + prompt_override: Optional[str] = Query( + None, description="Prompt override" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedPromptResponse: + """Get a specific prompt by name, optionally with inputs and + override. + + This endpoint retrieves a specific prompt and allows for optional + inputs and template override. Only superusers can access this + endpoint. 
+ """ + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can retrieve prompts.", + 403, + ) + result = await self.services.management.get_prompt( + name, inputs, prompt_override + ) + return result # type: ignore + + @self.router.put( + "/prompts/{name}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Update an existing prompt", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.prompts.update( + "greeting_prompt", + template="Greetings, {name}!", + input_types={"name": "string", "age": "integer"} + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.prompts.update({ + name: "greeting_prompt", + template: "Greetings, {name}!", + inputTypes: { name: "string", age: "integer" }, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X PUT "https://api.example.com/v3/prompts/greeting_prompt" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -H "Content-Type: application/json" \\ + -d '{"template": "Greetings, {name}!", "input_types": {"name": "string", "age": "integer"}}' + """), + }, + ] + }, + ) + @self.base_endpoint + async def update_prompt( + name: str = Path(..., description="Prompt name"), + template: Optional[str] = Body( + None, description="Updated prompt template" + ), + input_types: dict[str, str] = Body( + default={}, + description="A dictionary mapping input names to their types", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGenericMessageResponse: + """Update an existing prompt's template and/or input types. + + This endpoint allows superusers to update the template and input + types of an existing prompt. 
+ """ + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can update prompts.", + 403, + ) + result = await self.services.management.update_prompt( + name, template, input_types + ) + return GenericMessageResponse(message=result) # type: ignore + + @self.router.delete( + "/prompts/{name}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Delete a prompt", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.prompts.delete("greeting_prompt") + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.prompts.delete({ + name: "greeting_prompt", + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X DELETE "https://api.example.com/v3/prompts/greeting_prompt" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def delete_prompt( + name: str = Path(..., description="Prompt name"), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Delete a prompt by name. + + This endpoint allows superusers to delete an existing prompt. 
+ """ + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can delete prompts.", + 403, + ) + await self.services.management.delete_prompt(name) + return GenericBooleanResponse(success=True) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/retrieval_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/retrieval_router.py new file mode 100644 index 00000000..28749319 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/retrieval_router.py @@ -0,0 +1,639 @@ +import logging +from typing import Any, Literal, Optional +from uuid import UUID + +from fastapi import Body, Depends +from fastapi.responses import StreamingResponse + +from core.base import ( + GenerationConfig, + Message, + R2RException, + SearchMode, + SearchSettings, + select_search_filters, +) +from core.base.api.models import ( + WrappedAgentResponse, + WrappedCompletionResponse, + WrappedEmbeddingResponse, + WrappedLLMChatCompletion, + WrappedRAGResponse, + WrappedSearchResponse, +) + +from ...abstractions import R2RProviders, R2RServices +from ...config import R2RConfig +from .base_router import BaseRouterV3 +from .examples import EXAMPLES + +logger = logging.getLogger(__name__) + + +def merge_search_settings( + base: SearchSettings, overrides: SearchSettings +) -> SearchSettings: + # Convert both to dict + base_dict = base.model_dump() + overrides_dict = overrides.model_dump(exclude_unset=True) + + # Update base_dict with values from overrides_dict + # This ensures that any field set in overrides takes precedence + for k, v in overrides_dict.items(): + base_dict[k] = v + + # Construct a new SearchSettings from the merged dict + return SearchSettings(**base_dict) + + +class RetrievalRouter(BaseRouterV3): + def __init__( + self, providers: R2RProviders, services: R2RServices, config: R2RConfig + ): + logging.info("Initializing RetrievalRouter") + super().__init__(providers, services, config) + + def 
_register_workflows(self): + pass + + def _prepare_search_settings( + self, + auth_user: Any, + search_mode: SearchMode, + search_settings: Optional[SearchSettings], + ) -> SearchSettings: + """Prepare the effective search settings based on the provided + search_mode, optional user-overrides in search_settings, and applied + filters.""" + if search_mode != SearchMode.custom: + # Start from mode defaults + effective_settings = SearchSettings.get_default(search_mode.value) + if search_settings: + # Merge user-provided overrides + effective_settings = merge_search_settings( + effective_settings, search_settings + ) + else: + # Custom mode: use provided settings or defaults + effective_settings = search_settings or SearchSettings() + + # Apply user-specific filters + effective_settings.filters = select_search_filters( + auth_user, effective_settings + ) + return effective_settings + + def _setup_routes(self): + @self.router.post( + "/retrieval/search", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Search R2R", + openapi_extra=EXAMPLES["search"], + ) + @self.base_endpoint + async def search_app( + query: str = Body( + ..., + description="Search query to find relevant documents", + ), + search_mode: SearchMode = Body( + default=SearchMode.custom, + description=( + "Default value of `custom` allows full control over search settings.\n\n" + "Pre-configured search modes:\n" + "`basic`: A simple semantic-based search.\n" + "`advanced`: A more powerful hybrid search combining semantic and full-text.\n" + "`custom`: Full control via `search_settings`.\n\n" + "If `filters` or `limit` are provided alongside `basic` or `advanced`, " + "they will override the default settings for that mode." + ), + ), + search_settings: Optional[SearchSettings] = Body( + None, + description=( + "The search configuration object. If `search_mode` is `custom`, " + "these settings are used as-is. 
For `basic` or `advanced`, these settings will override the default mode configuration.\n\n" + "Common overrides include `filters` to narrow results and `limit` to control how many results are returned." + ), + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedSearchResponse: + """Perform a search query against vector and/or graph-based + databases. + + **Search Modes:** + - `basic`: Defaults to semantic search. Simple and easy to use. + - `advanced`: Combines semantic search with full-text search for more comprehensive results. + - `custom`: Complete control over how search is performed. Provide a full `SearchSettings` object. + + **Filters:** + Apply filters directly inside `search_settings.filters`. For example: + ```json + { + "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}} + } + ``` + Supported operators: `$eq`, `$neq`, `$gt`, `$gte`, `$lt`, `$lte`, `$like`, `$ilike`, `$in`, `$nin`. + + **Hybrid Search:** + Enable hybrid search by setting `use_hybrid_search: true` in search_settings. This combines semantic search with + keyword-based search for improved results. Configure with `hybrid_settings`: + ```json + { + "use_hybrid_search": true, + "hybrid_settings": { + "full_text_weight": 1.0, + "semantic_weight": 5.0, + "full_text_limit": 200, + "rrf_k": 50 + } + } + ``` + + **Graph-Enhanced Search:** + Knowledge graph integration is enabled by default. Control with `graph_search_settings`: + ```json + { + "graph_search_settings": { + "use_graph_search": true, + "kg_search_type": "local" + } + } + ``` + + **Advanced Filtering:** + Use complex filters to narrow down results by metadata fields or document properties: + ```json + { + "filters": { + "$and":[ + {"document_type": {"$eq": "pdf"}}, + {"metadata.year": {"$gt": 2020}} + ] + } + } + ``` + + **Results:** + The response includes vector search results and optional graph search results. + Each result contains the matched text, document ID, and relevance score. 
+ + """ + if query == "": + raise R2RException("Query cannot be empty", 400) + effective_settings = self._prepare_search_settings( + auth_user, search_mode, search_settings + ) + results = await self.services.retrieval.search( + query=query, + search_settings=effective_settings, + ) + return results # type: ignore + + @self.router.post( + "/retrieval/rag", + dependencies=[Depends(self.rate_limit_dependency)], + summary="RAG Query", + response_model=None, + openapi_extra=EXAMPLES["rag"], + ) + @self.base_endpoint + async def rag_app( + query: str = Body(...), + search_mode: SearchMode = Body( + default=SearchMode.custom, + description=( + "Default value of `custom` allows full control over search settings.\n\n" + "Pre-configured search modes:\n" + "`basic`: A simple semantic-based search.\n" + "`advanced`: A more powerful hybrid search combining semantic and full-text.\n" + "`custom`: Full control via `search_settings`.\n\n" + "If `filters` or `limit` are provided alongside `basic` or `advanced`, " + "they will override the default settings for that mode." + ), + ), + search_settings: Optional[SearchSettings] = Body( + None, + description=( + "The search configuration object. If `search_mode` is `custom`, " + "these settings are used as-is. For `basic` or `advanced`, these settings will override the default mode configuration.\n\n" + "Common overrides include `filters` to narrow results and `limit` to control how many results are returned." 
+ ), + ), + rag_generation_config: GenerationConfig = Body( + default_factory=GenerationConfig, + description="Configuration for RAG generation", + ), + task_prompt: Optional[str] = Body( + default=None, + description="Optional custom prompt to override default", + ), + include_title_if_available: bool = Body( + default=False, + description="Include document titles in responses when available", + ), + include_web_search: bool = Body( + default=False, + description="Include web search results provided to the LLM.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedRAGResponse: + """Execute a RAG (Retrieval-Augmented Generation) query. + + This endpoint combines search results with language model generation to produce accurate, + contextually-relevant responses based on your document corpus. + + **Features:** + - Combines vector search, optional knowledge graph integration, and LLM generation + - Automatically cites sources with unique citation identifiers + - Supports both streaming and non-streaming responses + - Compatible with various LLM providers (OpenAI, Anthropic, etc.) + - Web search integration for up-to-date information + + **Search Configuration:** + All search parameters from the search endpoint apply here, including filters, hybrid search, and graph-enhanced search. 
+ + **Generation Configuration:** + Fine-tune the language model's behavior with `rag_generation_config`: + ```json + { + "model": "openai/gpt-4o-mini", // Model to use + "temperature": 0.7, // Control randomness (0-1) + "max_tokens": 1500, // Maximum output length + "stream": true // Enable token streaming + } + ``` + + **Model Support:** + - OpenAI models (default) + - Anthropic Claude models (requires ANTHROPIC_API_KEY) + - Local models via Ollama + - Any provider supported by LiteLLM + + **Streaming Responses:** + When `stream: true` is set, the endpoint returns Server-Sent Events with the following types: + - `search_results`: Initial search results from your documents + - `message`: Partial tokens as they're generated + - `citation`: Citation metadata when sources are referenced + - `final_answer`: Complete answer with structured citations + + **Example Response:** + ```json + { + "generated_answer": "DeepSeek-R1 is a model that demonstrates impressive performance...[1]", + "search_results": { ... }, + "citations": [ + { + "id": "cit.123456", + "object": "citation", + "payload": { ... 
} + } + ] + } + ``` + """ + + if "model" not in rag_generation_config.__fields_set__: + rag_generation_config.model = self.config.app.quality_llm + + effective_settings = self._prepare_search_settings( + auth_user, search_mode, search_settings + ) + + response = await self.services.retrieval.rag( + query=query, + search_settings=effective_settings, + rag_generation_config=rag_generation_config, + task_prompt=task_prompt, + include_title_if_available=include_title_if_available, + include_web_search=include_web_search, + ) + + if rag_generation_config.stream: + # ========== Streaming path ========== + async def stream_generator(): + try: + async for chunk in response: + if len(chunk) > 1024: + for i in range(0, len(chunk), 1024): + yield chunk[i : i + 1024] + else: + yield chunk + except GeneratorExit: + # Clean up if needed, then return + return + + return StreamingResponse( + stream_generator(), media_type="text/event-stream" + ) # type: ignore + else: + # ========== Non-streaming path ========== + return response + + @self.router.post( + "/retrieval/agent", + dependencies=[Depends(self.rate_limit_dependency)], + summary="RAG-powered Conversational Agent", + openapi_extra=EXAMPLES["agent"], + ) + @self.base_endpoint + async def agent_app( + message: Optional[Message] = Body( + None, + description="Current message to process", + ), + messages: Optional[list[Message]] = Body( + None, + deprecated=True, + description="List of messages (deprecated, use message instead)", + ), + search_mode: SearchMode = Body( + default=SearchMode.custom, + description="Pre-configured search modes: basic, advanced, or custom.", + ), + search_settings: Optional[SearchSettings] = Body( + None, + description="The search configuration object for retrieving context.", + ), + # Generation configurations + rag_generation_config: GenerationConfig = Body( + default_factory=GenerationConfig, + description="Configuration for RAG generation in 'rag' mode", + ), + research_generation_config: 
Optional[GenerationConfig] = Body( + None, + description="Configuration for generation in 'research' mode. If not provided but mode='research', rag_generation_config will be used with appropriate model overrides.", + ), + # Tool configurations + rag_tools: Optional[ + list[ + Literal[ + "web_search", + "web_scrape", + "search_file_descriptions", + "search_file_knowledge", + "get_file_content", + ] + ] + ] = Body( + None, + description="List of tools to enable for RAG mode. Available tools: search_file_knowledge, get_file_content, web_search, web_scrape, search_file_descriptions", + ), + research_tools: Optional[ + list[ + Literal["rag", "reasoning", "critique", "python_executor"] + ] + ] = Body( + None, + description="List of tools to enable for Research mode. Available tools: rag, reasoning, critique, python_executor", + ), + # Backward compatibility + tools: Optional[list[str]] = Body( + None, + deprecated=True, + description="List of tools to execute (deprecated, use rag_tools or research_tools instead)", + ), + # Other parameters + task_prompt: Optional[str] = Body( + default=None, + description="Optional custom prompt to override default", + ), + # Backward compatibility + task_prompt_override: Optional[str] = Body( + default=None, + deprecated=True, + description="Optional custom prompt to override default", + ), + include_title_if_available: bool = Body( + default=True, + description="Pass document titles from search results into the LLM context window.", + ), + conversation_id: Optional[UUID] = Body( + default=None, + description="ID of the conversation", + ), + max_tool_context_length: Optional[int] = Body( + default=32_768, + description="Maximum length of returned tool context", + ), + use_system_context: Optional[bool] = Body( + default=True, + description="Use extended prompt for generation", + ), + mode: Optional[Literal["rag", "research"]] = Body( + default="rag", + description="Mode to use for generation: 'rag' for standard retrieval or 'research' 
for deep analysis with reasoning capabilities", + ), + needs_initial_conversation_name: Optional[bool] = Body( + default=None, + description="If true, the system will automatically assign a conversation name if not already specified previously.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedAgentResponse: + """ + Engage with an intelligent agent for information retrieval, analysis, and research. + + This endpoint offers two operating modes: + - **RAG mode**: Standard retrieval-augmented generation for answering questions based on knowledge base + - **Research mode**: Advanced capabilities for deep analysis, reasoning, and computation + + ### RAG Mode (Default) + + The RAG mode provides fast, knowledge-based responses using: + - Semantic and hybrid search capabilities + - Document-level and chunk-level content retrieval + - Optional web search integration + - Source citation and evidence-based responses + + ### Research Mode + + The Research mode builds on RAG capabilities and adds: + - A dedicated reasoning system for complex problem-solving + - Critique capabilities to identify potential biases or logical fallacies + - Python execution for computational analysis + - Multi-step reasoning for deeper exploration of topics + + ### Available Tools + + **RAG Tools:** + - `search_file_knowledge`: Semantic/hybrid search on your ingested documents + - `search_file_descriptions`: Search over file-level metadata + - `content`: Fetch entire documents or chunk structures + - `web_search`: Query external search APIs for up-to-date information + - `web_scrape`: Scrape and extract content from specific web pages + + **Research Tools:** + - `rag`: Leverage the underlying RAG agent for information retrieval + - `reasoning`: Call a dedicated model for complex analytical thinking + - `critique`: Analyze conversation history to identify flaws and biases + - `python_executor`: Execute Python code for complex calculations and analysis + + ### Streaming 
Output + + When streaming is enabled, the agent produces different event types: + - `thinking`: Shows the model's step-by-step reasoning (when extended_thinking=true) + - `tool_call`: Shows when the agent invokes a tool + - `tool_result`: Shows the result of a tool call + - `citation`: Indicates when a citation is added to the response + - `message`: Streams partial tokens of the response + - `final_answer`: Contains the complete generated answer and structured citations + + ### Conversations + + Maintain context across multiple turns by including `conversation_id` in each request. + After your first call, store the returned `conversation_id` and include it in subsequent calls. + If no conversation name has already been set for the conversation, the system will automatically assign one. + + """ + # Handle backward compatibility for task_prompt + task_prompt = task_prompt or task_prompt_override + # Handle model selection based on mode + if "model" not in rag_generation_config.__fields_set__: + if mode == "rag": + rag_generation_config.model = self.config.app.quality_llm + elif mode == "research": + rag_generation_config.model = self.config.app.planning_llm + + # Prepare search settings + effective_settings = self._prepare_search_settings( + auth_user, search_mode, search_settings + ) + + # Handle tool configuration and backward compatibility + if tools: # Handle deprecated tools parameter + logger.warning( + "The 'tools' parameter is deprecated. Use 'rag_tools' or 'research_tools' based on mode." 
+ ) + rag_tools = tools # type: ignore + + # Determine effective generation config + effective_generation_config = rag_generation_config + if mode == "research" and research_generation_config: + effective_generation_config = research_generation_config + + try: + response = await self.services.retrieval.agent( + message=message, + messages=messages, + search_settings=effective_settings, + rag_generation_config=rag_generation_config, + research_generation_config=research_generation_config, + task_prompt=task_prompt, + include_title_if_available=include_title_if_available, + max_tool_context_length=max_tool_context_length or 32_768, + conversation_id=( + str(conversation_id) if conversation_id else None # type: ignore + ), + use_system_context=use_system_context + if use_system_context is not None + else True, + rag_tools=rag_tools, # type: ignore + research_tools=research_tools, # type: ignore + mode=mode, + needs_initial_conversation_name=needs_initial_conversation_name, + ) + + if effective_generation_config.stream: + + async def stream_generator(): + try: + async for chunk in response: + if len(chunk) > 1024: + for i in range(0, len(chunk), 1024): + yield chunk[i : i + 1024] + else: + yield chunk + except GeneratorExit: + # Clean up if needed, then return + return + + return StreamingResponse( # type: ignore + stream_generator(), media_type="text/event-stream" + ) + else: + return response + except Exception as e: + logger.error(f"Error in agent_app: {e}") + raise R2RException(str(e), 500) from e + + @self.router.post( + "/retrieval/completion", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Generate Message Completions", + openapi_extra=EXAMPLES["completion"], + ) + @self.base_endpoint + async def completion( + messages: list[Message] = Body( + ..., + description="List of messages to generate completion for", + example=[ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": "What is the 
capital of France?", + }, + { + "role": "assistant", + "content": "The capital of France is Paris.", + }, + {"role": "user", "content": "What about Italy?"}, + ], + ), + generation_config: GenerationConfig = Body( + default_factory=GenerationConfig, + description="Configuration for text generation", + example={ + "model": "openai/gpt-4o-mini", + "temperature": 0.7, + "max_tokens": 150, + "stream": False, + }, + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + response_model=WrappedCompletionResponse, + ) -> WrappedLLMChatCompletion: + """Generate completions for a list of messages. + + This endpoint uses the language model to generate completions for + the provided messages. The generation process can be customized + using the generation_config parameter. + + The messages list should contain alternating user and assistant + messages, with an optional system message at the start. Each + message should have a 'role' and 'content'. + """ + + return await self.services.retrieval.completion( + messages=messages, # type: ignore + generation_config=generation_config, + ) + + @self.router.post( + "/retrieval/embedding", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Generate Embeddings", + openapi_extra=EXAMPLES["embedding"], + ) + @self.base_endpoint + async def embedding( + text: str = Body( + ..., + description="Text to generate embeddings for", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedEmbeddingResponse: + """Generate embeddings for the provided text using the specified + model. + + This endpoint uses the language model to generate embeddings for + the provided text. The model parameter specifies the model to use + for generating embeddings. 
+ """ + + return await self.services.retrieval.embedding( + text=text, + ) diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/system_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/system_router.py new file mode 100644 index 00000000..682be750 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/system_router.py @@ -0,0 +1,186 @@ +import logging +import textwrap +from datetime import datetime, timezone + +import psutil +from fastapi import Depends + +from core.base import R2RException +from core.base.api.models import ( + GenericMessageResponse, + WrappedGenericMessageResponse, + WrappedServerStatsResponse, + WrappedSettingsResponse, +) + +from ...abstractions import R2RProviders, R2RServices +from ...config import R2RConfig +from .base_router import BaseRouterV3 + + +class SystemRouter(BaseRouterV3): + def __init__( + self, + providers: R2RProviders, + services: R2RServices, + config: R2RConfig, + ): + logging.info("Initializing SystemRouter") + super().__init__(providers, services, config) + self.start_time = datetime.now(timezone.utc) + + def _setup_routes(self): + @self.router.get( + "/health", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) 
+ + result = client.system.health() + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.system.health(); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/health"\\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + """), + }, + ] + }, + ) + @self.base_endpoint + async def health_check() -> WrappedGenericMessageResponse: + return GenericMessageResponse(message="ok") # type: ignore + + @self.router.get( + "/system/settings", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + result = client.system.settings() + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.system.settings(); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/system/settings" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + """), + }, + ] + }, + ) + @self.base_endpoint + async def app_settings( + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedSettingsResponse: + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can call the `system/settings` endpoint.", + 403, + ) + return await self.services.management.app_settings() + + @self.router.get( + "/system/status", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient 
+ + client = R2RClient() + # when using auth, do client.login(...) + + result = client.system.status() + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.system.status(); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/system/status" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + """), + }, + ] + }, + ) + @self.base_endpoint + async def server_stats( + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedServerStatsResponse: + if not auth_user.is_superuser: + raise R2RException( + "Only an authorized user can call the `system/status` endpoint.", + 403, + ) + return { # type: ignore + "start_time": self.start_time.isoformat(), + "uptime_seconds": ( + datetime.now(timezone.utc) - self.start_time + ).total_seconds(), + "cpu_usage": psutil.cpu_percent(), + "memory_usage": psutil.virtual_memory().percent, + } diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/users_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/users_router.py new file mode 100644 index 00000000..686f0013 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/users_router.py @@ -0,0 +1,1721 @@ +import logging +import os +import textwrap +import urllib.parse +from typing import Optional +from uuid import UUID + +import requests +from fastapi import Body, Depends, HTTPException, Path, Query +from fastapi.background import BackgroundTasks +from fastapi.responses import FileResponse +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from google.auth.transport import requests as google_requests +from google.oauth2 import id_token +from pydantic import EmailStr + +from core.base import R2RException +from core.base.api.models import ( + 
GenericBooleanResponse, + GenericMessageResponse, + WrappedAPIKeyResponse, + WrappedAPIKeysResponse, + WrappedBooleanResponse, + WrappedCollectionsResponse, + WrappedGenericMessageResponse, + WrappedLimitsResponse, + WrappedLoginResponse, + WrappedTokenResponse, + WrappedUserResponse, + WrappedUsersResponse, +) + +from ...abstractions import R2RProviders, R2RServices +from ...config import R2RConfig +from .base_router import BaseRouterV3 + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +class UsersRouter(BaseRouterV3): + def __init__( + self, providers: R2RProviders, services: R2RServices, config: R2RConfig + ): + logging.info("Initializing UsersRouter") + super().__init__(providers, services, config) + self.google_client_id = os.environ.get("GOOGLE_CLIENT_ID") + self.google_client_secret = os.environ.get("GOOGLE_CLIENT_SECRET") + self.google_redirect_uri = os.environ.get("GOOGLE_REDIRECT_URI") + + self.github_client_id = os.environ.get("GITHUB_CLIENT_ID") + self.github_client_secret = os.environ.get("GITHUB_CLIENT_SECRET") + self.github_redirect_uri = os.environ.get("GITHUB_REDIRECT_URI") + + def _setup_routes(self): + @self.router.post( + "/users", + # dependencies=[Depends(self.rate_limit_dependency)], + response_model=WrappedUserResponse, + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + new_user = client.users.create( + email="jane.doe@example.com", + password="secure_password123" + )"""), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.create({ + email: "jane.doe@example.com", + password: "secure_password123" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/users" \\ + -H "Content-Type: application/json" \\ + -d 
'{ + "email": "jane.doe@example.com", + "password": "secure_password123" + }'"""), + }, + ] + }, + ) + @self.base_endpoint + async def register( + email: EmailStr = Body(..., description="User's email address"), + password: str = Body(..., description="User's password"), + name: str | None = Body( + None, description="The name for the new user" + ), + bio: str | None = Body( + None, description="The bio for the new user" + ), + profile_picture: str | None = Body( + None, description="Updated user profile picture" + ), + # auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedUserResponse: + """Register a new user with the given email and password.""" + + # TODO: Do we really want this validation? The default password for the superuser would not pass... + def validate_password(password: str) -> bool: + if len(password) < 10: + return False + if not any(c.isupper() for c in password): + return False + if not any(c.islower() for c in password): + return False + if not any(c.isdigit() for c in password): + return False + if not any(c in "!@#$%^&*" for c in password): + return False + return True + + # if not validate_password(password): + # raise R2RException( + # f"Password must be at least 10 characters long and contain at least one uppercase letter, one lowercase letter, one digit, and one special character from '!@#$%^&*'.", + # 400, + # ) + + registration_response = await self.services.auth.register( + email=email, + password=password, + name=name, + bio=bio, + profile_picture=profile_picture, + ) + + return registration_response # type: ignore + + @self.router.post( + "/users/export", + summary="Export users to CSV", + dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
+ + response = client.users.export( + output_path="export.csv", + columns=["id", "name", "created_at"], + include_header=True, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + await client.users.export({ + outputPath: "export.csv", + columns: ["id", "name", "created_at"], + includeHeader: true, + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "http://127.0.0.1:7272/v3/users/export" \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -H "Accept: text/csv" \ + -d '{ "columns": ["id", "name", "created_at"], "include_header": true }' \ + --output export.csv + """), + }, + ] + }, + ) + @self.base_endpoint + async def export_users( + background_tasks: BackgroundTasks, + columns: Optional[list[str]] = Body( + None, description="Specific columns to export" + ), + filters: Optional[dict] = Body( + None, description="Filters to apply to the export" + ), + include_header: Optional[bool] = Body( + True, description="Whether to include column headers" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> FileResponse: + """Export users as a CSV file.""" + + if not auth_user.is_superuser: + raise R2RException( + status_code=403, + message="Only a superuser can export data.", + ) + + ( + csv_file_path, + temp_file, + ) = await self.services.management.export_users( + columns=columns, + filters=filters, + include_header=include_header + if include_header is not None + else True, + ) + + background_tasks.add_task(temp_file.close) + + return FileResponse( + path=csv_file_path, + media_type="text/csv", + filename="users_export.csv", + ) + + @self.router.post( + "/users/verify-email", + # dependencies=[Depends(self.rate_limit_dependency)], + response_model=WrappedGenericMessageResponse, + openapi_extra={ + "x-codeSamples": [ + { 
+ "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + tokens = client.users.verify_email( + email="jane.doe@example.com", + verification_code="1lklwal!awdclm" + )"""), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.verifyEmail({ + email: jane.doe@example.com", + verificationCode: "1lklwal!awdclm" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/users/login" \\ + -H "Content-Type: application/x-www-form-urlencoded" \\ + -d "email=jane.doe@example.com&verification_code=1lklwal!awdclm" + """), + }, + ] + }, + ) + @self.base_endpoint + async def verify_email( + email: EmailStr = Body(..., description="User's email address"), + verification_code: str = Body( + ..., description="Email verification code" + ), + ) -> WrappedGenericMessageResponse: + """Verify a user's email address.""" + user = ( + await self.providers.database.users_handler.get_user_by_email( + email + ) + ) + if user and user.is_verified: + raise R2RException( + status_code=400, + message="This email is already verified. 
Please log in.", + ) + + result = await self.services.auth.verify_email( + email, verification_code + ) + return GenericMessageResponse(message=result["message"]) # type: ignore + + @self.router.post( + "/users/send-verification-email", + dependencies=[ + Depends(self.providers.auth.auth_wrapper(public=True)) + ], + response_model=WrappedGenericMessageResponse, + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + tokens = client.users.send_verification_email( + email="jane.doe@example.com", + )"""), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.sendVerificationEmail({ + email: jane.doe@example.com", + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/users/send-verification-email" \\ + -H "Content-Type: application/x-www-form-urlencoded" \\ + -d "email=jane.doe@example.com" + """), + }, + ] + }, + ) + @self.base_endpoint + async def send_verification_email( + email: EmailStr = Body(..., description="User's email address"), + ) -> WrappedGenericMessageResponse: + """Send a user's email a verification code.""" + user = ( + await self.providers.database.users_handler.get_user_by_email( + email + ) + ) + if user and user.is_verified: + raise R2RException( + status_code=400, + message="This email is already verified. Please log in.", + ) + + await self.services.auth.send_verification_email(email=email) + return GenericMessageResponse( + message="A verification email has been sent." 
+ ) # type: ignore + + @self.router.post( + "/users/login", + # dependencies=[Depends(self.rate_limit_dependency)], + response_model=WrappedTokenResponse, + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + tokens = client.users.login( + email="jane.doe@example.com", + password="secure_password123" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.login({ + email: jane.doe@example.com", + password: "secure_password123" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/users/login" \\ + -H "Content-Type: application/x-www-form-urlencoded" \\ + -d "username=jane.doe@example.com&password=secure_password123" + """), + }, + ] + }, + ) + @self.base_endpoint + async def login( + form_data: OAuth2PasswordRequestForm = Depends(), + ) -> WrappedLoginResponse: + """Authenticate a user and provide access tokens.""" + return await self.services.auth.login( # type: ignore + form_data.username, form_data.password + ) + + @self.router.post( + "/users/logout", + response_model=WrappedGenericMessageResponse, + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) 
+ result = client.users.logout() + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.logout(); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/users/logout" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def logout( + token: str = Depends(oauth2_scheme), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGenericMessageResponse: + """Log out the current user.""" + result = await self.services.auth.logout(token) + return GenericMessageResponse(message=result["message"]) # type: ignore + + @self.router.post( + "/users/refresh-token", + # dependencies=[Depends(self.rate_limit_dependency)], + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) 
+ + new_tokens = client.users.refresh_token() + # New tokens are automatically stored in the client"""), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.refreshAccessToken(); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/users/refresh-token" \\ + -H "Content-Type: application/json" \\ + -d '{ + "refresh_token": "YOUR_REFRESH_TOKEN" + }'"""), + }, + ] + }, + ) + @self.base_endpoint + async def refresh_token( + refresh_token: str = Body(..., description="Refresh token"), + ) -> WrappedTokenResponse: + """Refresh the access token using a refresh token.""" + result = await self.services.auth.refresh_access_token( + refresh_token=refresh_token + ) + return result # type: ignore + + @self.router.post( + "/users/change-password", + dependencies=[Depends(self.rate_limit_dependency)], + response_model=WrappedGenericMessageResponse, + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) 
+ + result = client.users.change_password( + current_password="old_password123", + new_password="new_secure_password456" + )"""), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.changePassword({ + currentPassword: "old_password123", + newPassword: "new_secure_password456" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/users/change-password" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -H "Content-Type: application/json" \\ + -d '{ + "current_password": "old_password123", + "new_password": "new_secure_password456" + }'"""), + }, + ] + }, + ) + @self.base_endpoint + async def change_password( + current_password: str = Body(..., description="Current password"), + new_password: str = Body(..., description="New password"), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedGenericMessageResponse: + """Change the authenticated user's password.""" + result = await self.services.auth.change_password( + auth_user, current_password, new_password + ) + return GenericMessageResponse(message=result["message"]) # type: ignore + + @self.router.post( + "/users/request-password-reset", + dependencies=[ + Depends(self.providers.auth.auth_wrapper(public=True)) + ], + response_model=WrappedGenericMessageResponse, + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + result = client.users.request_password_reset( + email="jane.doe@example.com" + )"""), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.requestPasswordReset({ + email: jane.doe@example.com", + }); + } + + main(); + 
"""), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/users/request-password-reset" \\ + -H "Content-Type: application/json" \\ + -d '{ + "email": "jane.doe@example.com" + }'"""), + }, + ] + }, + ) + @self.base_endpoint + async def request_password_reset( + email: EmailStr = Body(..., description="User's email address"), + ) -> WrappedGenericMessageResponse: + """Request a password reset for a user.""" + result = await self.services.auth.request_password_reset(email) + return GenericMessageResponse(message=result["message"]) # type: ignore + + @self.router.post( + "/users/reset-password", + dependencies=[ + Depends(self.providers.auth.auth_wrapper(public=True)) + ], + response_model=WrappedGenericMessageResponse, + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + result = client.users.reset_password( + reset_token="reset_token_received_via_email", + new_password="new_secure_password789" + )"""), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.resetPassword({ + resestToken: "reset_token_received_via_email", + newPassword: "new_secure_password789" + }); + } + + main(); + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/v3/users/reset-password" \\ + -H "Content-Type: application/json" \\ + -d '{ + "reset_token": "reset_token_received_via_email", + "new_password": "new_secure_password789" + }'"""), + }, + ] + }, + ) + @self.base_endpoint + async def reset_password( + reset_token: str = Body(..., description="Password reset token"), + new_password: str = Body(..., description="New password"), + ) -> WrappedGenericMessageResponse: + """Reset a user's password using a reset token.""" + result = await 
self.services.auth.confirm_password_reset( + reset_token, new_password + ) + return GenericMessageResponse(message=result["message"]) # type: ignore + + @self.router.get( + "/users", + dependencies=[Depends(self.rate_limit_dependency)], + summary="List Users", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) + + # List users with filters + users = client.users.list( + offset=0, + limit=100, + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.list(); + } + + main(); + """), + }, + { + "lang": "Shell", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/users?offset=0&limit=100&username=john&email=john@example.com&is_active=true&is_superuser=false" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def list_users( + ids: list[str] = Query( + [], description="List of user IDs to filter by" + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedUsersResponse: + """List all users with pagination and filtering options. + + Only accessible by superusers. 
+ """ + + if not auth_user.is_superuser: + raise R2RException( + status_code=403, + message="Only a superuser can call the `users_overview` endpoint.", + ) + + user_uuids = [UUID(user_id) for user_id in ids] + + users_overview_response = ( + await self.services.management.users_overview( + user_ids=user_uuids, offset=offset, limit=limit + ) + ) + return users_overview_response["results"], { # type: ignore + "total_entries": users_overview_response["total_entries"] + } + + @self.router.get( + "/users/me", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Get the Current User", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) + + # Get user details + users = client.users.me() + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.retrieve(); + } + + main(); + """), + }, + { + "lang": "Shell", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/users/me" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_current_user( + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedUserResponse: + """Get detailed information about the currently authenticated + user.""" + return auth_user + + @self.router.get( + "/users/{id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Get User Details", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) 
+ + # Get user details + users = client.users.retrieve( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.retrieve({ + id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" + }); + } + + main(); + """), + }, + { + "lang": "Shell", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_user( + id: UUID = Path( + ..., example="550e8400-e29b-41d4-a716-446655440000" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedUserResponse: + """Get detailed information about a specific user. + + Users can only access their own information unless they are + superusers. + """ + if not auth_user.is_superuser and auth_user.id != id: + raise R2RException( + "Only a superuser can call the get `user` endpoint for other users.", + 403, + ) + + users_overview_response = ( + await self.services.management.users_overview( + offset=0, + limit=1, + user_ids=[id], + ) + ) + + return users_overview_response["results"][0] + + @self.router.delete( + "/users/{id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Delete User", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) 
+ + # Delete user + client.users.delete(id="550e8400-e29b-41d4-a716-446655440000", password="secure_password123") + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.delete({ + id: "550e8400-e29b-41d4-a716-446655440000", + password: "secure_password123" + }); + } + + main(); + """), + }, + ] + }, + ) + @self.base_endpoint + async def delete_user( + id: UUID = Path( + ..., example="550e8400-e29b-41d4-a716-446655440000" + ), + password: Optional[str] = Body( + None, description="User's current password" + ), + delete_vector_data: Optional[bool] = Body( + False, + description="Whether to delete the user's vector data", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Delete a specific user. + + Users can only delete their own account unless they are superusers. + """ + if not auth_user.is_superuser and auth_user.id != id: + raise R2RException( + "Only a superuser can delete other users.", + 403, + ) + + await self.services.auth.delete_user( + user_id=id, + password=password, + delete_vector_data=delete_vector_data or False, + is_superuser=auth_user.is_superuser, + ) + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.get( + "/users/{id}/collections", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Get User Collections", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) 
+ + # Get user collections + collections = client.user.list_collections( + "550e8400-e29b-41d4-a716-446655440000", + offset=0, + limit=100 + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.listCollections({ + id: "550e8400-e29b-41d4-a716-446655440000", + offset: 0, + limit: 100 + }); + } + + main(); + """), + }, + { + "lang": "Shell", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/collections?offset=0&limit=100" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def get_user_collections( + id: UUID = Path( + ..., example="550e8400-e29b-41d4-a716-446655440000" + ), + offset: int = Query( + 0, + ge=0, + description="Specifies the number of objects to skip. Defaults to 0.", + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedCollectionsResponse: + """Get all collections associated with a specific user. + + Users can only access their own collections unless they are + superusers. 
+ """ + if auth_user.id != id and not auth_user.is_superuser: + raise R2RException( + "The currently authenticated user does not have access to the specified collection.", + 403, + ) + user_collection_response = ( + await self.services.management.collections_overview( + offset=offset, + limit=limit, + user_ids=[id], + ) + ) + return user_collection_response["results"], { # type: ignore + "total_entries": user_collection_response["total_entries"] + } + + @self.router.post( + "/users/{id}/collections/{collection_id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Add User to Collection", + response_model=WrappedBooleanResponse, + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) + + # Add user to collection + client.users.add_to_collection( + id="550e8400-e29b-41d4-a716-446655440000", + collection_id="750e8400-e29b-41d4-a716-446655440000" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.addToCollection({ + id: "550e8400-e29b-41d4-a716-446655440000", + collectionId: "750e8400-e29b-41d4-a716-446655440000" + }); + } + + main(); + """), + }, + { + "lang": "Shell", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/collections/750e8400-e29b-41d4-a716-446655440000" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def add_user_to_collection( + id: UUID = Path( + ..., example="550e8400-e29b-41d4-a716-446655440000" + ), + collection_id: UUID = Path( + ..., example="750e8400-e29b-41d4-a716-446655440000" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + if auth_user.id != id and not auth_user.is_superuser: + raise 
R2RException( + "The currently authenticated user does not have access to the specified collection.", + 403, + ) + + # TODO - Do we need a check on user access to the collection? + await self.services.management.add_user_to_collection( # type: ignore + id, collection_id + ) + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.delete( + "/users/{id}/collections/{collection_id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Remove User from Collection", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) + + # Remove user from collection + client.users.remove_from_collection( + id="550e8400-e29b-41d4-a716-446655440000", + collection_id="750e8400-e29b-41d4-a716-446655440000" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.removeFromCollection({ + id: "550e8400-e29b-41d4-a716-446655440000", + collectionId: "750e8400-e29b-41d4-a716-446655440000" + }); + } + + main(); + """), + }, + { + "lang": "Shell", + "source": textwrap.dedent(""" + curl -X DELETE "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/collections/750e8400-e29b-41d4-a716-446655440000" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """), + }, + ] + }, + ) + @self.base_endpoint + async def remove_user_from_collection( + id: UUID = Path( + ..., example="550e8400-e29b-41d4-a716-446655440000" + ), + collection_id: UUID = Path( + ..., example="750e8400-e29b-41d4-a716-446655440000" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Remove a user from a collection. + + Requires either superuser status or access to the collection. 
+ """ + if auth_user.id != id and not auth_user.is_superuser: + raise R2RException( + "The currently authenticated user does not have access to the specified collection.", + 403, + ) + + # TODO - Do we need a check on user access to the collection? + await self.services.management.remove_user_from_collection( # type: ignore + id, collection_id + ) + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.post( + "/users/{id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Update User", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) + + # Update user + updated_user = client.update_user( + "550e8400-e29b-41d4-a716-446655440000", + name="John Doe" + ) + """), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent(""" + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + function main() { + const response = await client.users.update({ + id: "550e8400-e29b-41d4-a716-446655440000", + name: "John Doe" + }); + } + + main(); + """), + }, + { + "lang": "Shell", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -H "Content-Type: application/json" \\ + -d '{ + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "John Doe", + }' + """), + }, + ] + }, + ) + # TODO - Modify update user to have synced params with user object + @self.base_endpoint + async def update_user( + id: UUID = Path(..., description="ID of the user to update"), + email: EmailStr | None = Body( + None, description="Updated email address" + ), + is_superuser: bool | None = Body( + None, description="Updated superuser status" + ), + name: str | None = Body(None, description="Updated user name"), + bio: str | None = Body(None, description="Updated user bio"), + profile_picture: str | None = Body( + 
None, description="Updated profile picture URL" + ), + limits_overrides: dict = Body( + None, + description="Updated limits overrides", + ), + metadata: dict[str, str | None] | None = None, + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedUserResponse: + """Update user information. + + Users can only update their own information unless they are + superusers. Superuser status can only be modified by existing + superusers. + """ + + if is_superuser is not None and not auth_user.is_superuser: + raise R2RException( + "Only superusers can update the superuser status of a user", + 403, + ) + + if not auth_user.is_superuser and auth_user.id != id: + raise R2RException( + "Only superusers can update other users' information", + 403, + ) + + if not auth_user.is_superuser and limits_overrides is not None: + raise R2RException( + "Only superusers can update other users' limits overrides", + 403, + ) + + # Pass `metadata` to our auth or management service so it can do a + # partial (Stripe-like) merge of metadata. + return await self.services.auth.update_user( # type: ignore + user_id=id, + email=email, + is_superuser=is_superuser, + name=name, + bio=bio, + profile_picture=profile_picture, + limits_overrides=limits_overrides, + new_metadata=metadata, + ) + + @self.router.post( + "/users/{id}/api-keys", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Create User API Key", + response_model=WrappedAPIKeyResponse, + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) 
+ + result = client.users.create_api_key( + id="550e8400-e29b-41d4-a716-446655440000", + name="My API Key", + description="API key for accessing the app", + ) + # result["api_key"] contains the newly created API key + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys" \\ + -H "Authorization: Bearer YOUR_API_TOKEN" \\ + -d '{"name": "My API Key", "description": "API key for accessing the app"}' + """), + }, + ] + }, + ) + @self.base_endpoint + async def create_user_api_key( + id: UUID = Path( + ..., description="ID of the user for whom to create an API key" + ), + name: Optional[str] = Body( + None, description="Name of the API key" + ), + description: Optional[str] = Body( + None, description="Description of the API key" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedAPIKeyResponse: + """Create a new API key for the specified user. + + Only superusers or the user themselves may create an API key. + """ + if auth_user.id != id and not auth_user.is_superuser: + raise R2RException( + "Only the user themselves or a superuser can create API keys for this user.", + 403, + ) + + api_key = await self.services.auth.create_user_api_key( + id, name=name, description=description + ) + return api_key # type: ignore + + @self.router.get( + "/users/{id}/api-keys", + dependencies=[Depends(self.rate_limit_dependency)], + summary="List User API Keys", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + + client = R2RClient() + # client.login(...) 
+ + keys = client.users.list_api_keys( + id="550e8400-e29b-41d4-a716-446655440000" + ) + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys" \\ + -H "Authorization: Bearer YOUR_API_TOKEN" + """), + }, + ] + }, + ) + @self.base_endpoint + async def list_user_api_keys( + id: UUID = Path( + ..., description="ID of the user whose API keys to list" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedAPIKeysResponse: + """List all API keys for the specified user. + + Only superusers or the user themselves may list the API keys. + """ + if auth_user.id != id and not auth_user.is_superuser: + raise R2RException( + "Only the user themselves or a superuser can list API keys for this user.", + 403, + ) + + keys = ( + await self.providers.database.users_handler.get_user_api_keys( + id + ) + ) + return keys, {"total_entries": len(keys)} # type: ignore + + @self.router.delete( + "/users/{id}/api-keys/{key_id}", + dependencies=[Depends(self.rate_limit_dependency)], + summary="Delete User API Key", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent(""" + from r2r import R2RClient + from uuid import UUID + + client = R2RClient() + # client.login(...) 
+ + response = client.users.delete_api_key( + id="550e8400-e29b-41d4-a716-446655440000", + key_id="d9c562d4-3aef-43e8-8f08-0cf7cd5e0a25" + ) + """), + }, + { + "lang": "cURL", + "source": textwrap.dedent(""" + curl -X DELETE "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys/d9c562d4-3aef-43e8-8f08-0cf7cd5e0a25" \\ + -H "Authorization: Bearer YOUR_API_TOKEN" + """), + }, + ] + }, + ) + @self.base_endpoint + async def delete_user_api_key( + id: UUID = Path(..., description="ID of the user"), + key_id: UUID = Path( + ..., description="ID of the API key to delete" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedBooleanResponse: + """Delete a specific API key for the specified user. + + Only superusers or the user themselves may delete the API key. + """ + if auth_user.id != id and not auth_user.is_superuser: + raise R2RException( + "Only the user themselves or a superuser can delete this API key.", + 403, + ) + + success = ( + await self.providers.database.users_handler.delete_api_key( + id, key_id + ) + ) + if not success: + raise R2RException( + "API key not found or could not be deleted", 400 + ) + return {"success": True} # type: ignore + + @self.router.get( + "/users/{id}/limits", + summary="Fetch User Limits", + responses={ + 200: { + "description": "Returns system default limits, user overrides, and final effective settings." + }, + 403: { + "description": "If the requesting user is neither the same user nor a superuser." + }, + 404: {"description": "If the user ID does not exist."}, + }, + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ + from r2r import R2RClient + + client = R2RClient() + # client.login(...) + + user_limits = client.users.get_limits("550e8400-e29b-41d4-a716-446655440000") + """, + }, + { + "lang": "JavaScript", + "source": """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + // await client.users.login(...) 
+ + async function main() { + const userLimits = await client.users.getLimits({ + id: "550e8400-e29b-41d4-a716-446655440000" + }); + console.log(userLimits); + } + + main(); + """, + }, + { + "lang": "cURL", + "source": """ + curl -X GET "https://api.example.com/v3/users/550e8400-e29b-41d4-a716-446655440000/limits" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """, + }, + ] + }, + ) + @self.base_endpoint + async def get_user_limits( + id: UUID = Path( + ..., description="ID of the user to fetch limits for" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> WrappedLimitsResponse: + """Return the system default limits, user-level overrides, and + final "effective" limit settings for the specified user. + + Only superusers or the user themself may fetch these values. + """ + if (auth_user.id != id) and (not auth_user.is_superuser): + raise R2RException( + "Only the user themselves or a superuser can view these limits.", + status_code=403, + ) + + # This calls the new helper you created in ManagementService + limits_info = await self.services.management.get_all_user_limits( + id + ) + return limits_info # type: ignore + + @self.router.get("/users/oauth/google/authorize") + @self.base_endpoint + async def google_authorize() -> WrappedGenericMessageResponse: + """Redirect user to Google's OAuth 2.0 consent screen.""" + state = "some_random_string_or_csrf_token" # Usually you store a random state in session/Redis + scope = "openid email profile" + + # Build the Google OAuth URL + params = { + "client_id": self.google_client_id, + "redirect_uri": self.google_redirect_uri, + "response_type": "code", + "scope": scope, + "state": state, + "access_type": "offline", # to get refresh token if needed + "prompt": "consent", # Force consent each time if you want + } + google_auth_url = f"https://accounts.google.com/o/oauth2/v2/auth?{urllib.parse.urlencode(params)}" + return GenericMessageResponse(message=google_auth_url) # type: ignore + + 
@self.router.get("/users/oauth/google/callback") + @self.base_endpoint + async def google_callback( + code: str = Query(...), state: str = Query(...) + ) -> WrappedLoginResponse: + """Google's callback that will receive the `code` and `state`. + + We then exchange code for tokens, verify, and log the user in. + """ + # 1. Exchange `code` for tokens + token_data = requests.post( + "https://oauth2.googleapis.com/token", + data={ + "code": code, + "client_id": self.google_client_id, + "client_secret": self.google_client_secret, + "redirect_uri": self.google_redirect_uri, + "grant_type": "authorization_code", + }, + ).json() + if "error" in token_data: + raise HTTPException( + status_code=400, + detail=f"Failed to get token: {token_data}", + ) + + # 2. Verify the ID token + id_token_str = token_data["id_token"] + try: + # google_auth.transport.requests.Request() is a session for verifying + id_info = id_token.verify_oauth2_token( + id_token_str, + google_requests.Request(), + self.google_client_id, + ) + except ValueError as e: + raise HTTPException( + status_code=400, + detail=f"Token verification failed: {str(e)}", + ) from e + + # id_info will contain "sub", "email", etc. + google_id = id_info["sub"] + email = id_info.get("email") + email = email or f"{google_id}@google_oauth.fake" + + # 3. 
Now call our R2RAuthProvider method that handles "oauth-based" user creation or login + return await self.providers.auth.oauth_callback_handler( # type: ignore + provider="google", + oauth_id=google_id, + email=email, + ) + + @self.router.get("/users/oauth/github/authorize") + @self.base_endpoint + async def github_authorize() -> WrappedGenericMessageResponse: + """Redirect user to GitHub's OAuth consent screen.""" + state = "some_random_string_or_csrf_token" + scope = "read:user user:email" + + params = { + "client_id": self.github_client_id, + "redirect_uri": self.github_redirect_uri, + "scope": scope, + "state": state, + } + github_auth_url = f"https://github.com/login/oauth/authorize?{urllib.parse.urlencode(params)}" + return GenericMessageResponse(message=github_auth_url) # type: ignore + + @self.router.get("/users/oauth/github/callback") + @self.base_endpoint + async def github_callback( + code: str = Query(...), state: str = Query(...) + ) -> WrappedLoginResponse: + """GitHub callback route to exchange code for an access_token, then + fetch user info from GitHub's API, then do the same 'oauth-based' + login or registration.""" + # 1. Exchange code for access_token + token_resp = requests.post( + "https://github.com/login/oauth/access_token", + data={ + "client_id": self.github_client_id, + "client_secret": self.github_client_secret, + "code": code, + "redirect_uri": self.github_redirect_uri, + "state": state, + }, + headers={"Accept": "application/json"}, + ) + token_data = token_resp.json() + if "error" in token_data: + raise HTTPException( + status_code=400, + detail=f"Failed to get token: {token_data}", + ) + access_token = token_data["access_token"] + + # 2. 
Use the access_token to fetch user info + user_info_resp = requests.get( + "https://api.github.com/user", + headers={"Authorization": f"Bearer {access_token}"}, + ).json() + + github_id = str( + user_info_resp["id"] + ) # GitHub user ID is typically an integer + # fetch email (sometimes you need to call /user/emails endpoint if user sets email private) + email = user_info_resp.get("email") + email = email or f"{github_id}@github_oauth.fake" + # 3. Pass to your auth provider + return await self.providers.auth.oauth_callback_handler( # type: ignore + provider="github", + oauth_id=github_id, + email=email, + ) diff --git a/.venv/lib/python3.12/site-packages/core/main/app.py b/.venv/lib/python3.12/site-packages/core/main/app.py new file mode 100644 index 00000000..ceb13cce --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/app.py @@ -0,0 +1,121 @@ +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.openapi.utils import get_openapi +from fastapi.responses import JSONResponse + +from core.base import R2RException +from core.providers import ( + HatchetOrchestrationProvider, + SimpleOrchestrationProvider, +) +from core.utils.sentry import init_sentry + +from .abstractions import R2RServices +from .api.v3.chunks_router import ChunksRouter +from .api.v3.collections_router import CollectionsRouter +from .api.v3.conversations_router import ConversationsRouter +from .api.v3.documents_router import DocumentsRouter +from .api.v3.graph_router import GraphRouter +from .api.v3.indices_router import IndicesRouter +from .api.v3.prompts_router import PromptsRouter +from .api.v3.retrieval_router import RetrievalRouter +from .api.v3.system_router import SystemRouter +from .api.v3.users_router import UsersRouter +from .config import R2RConfig + + +class R2RApp: + def __init__( + self, + config: R2RConfig, + orchestration_provider: ( + HatchetOrchestrationProvider | SimpleOrchestrationProvider + ), + services: R2RServices, 
+ chunks_router: ChunksRouter, + collections_router: CollectionsRouter, + conversations_router: ConversationsRouter, + documents_router: DocumentsRouter, + graph_router: GraphRouter, + indices_router: IndicesRouter, + prompts_router: PromptsRouter, + retrieval_router: RetrievalRouter, + system_router: SystemRouter, + users_router: UsersRouter, + ): + init_sentry() + + self.config = config + self.services = services + self.chunks_router = chunks_router + self.collections_router = collections_router + self.conversations_router = conversations_router + self.documents_router = documents_router + self.graph_router = graph_router + self.indices_router = indices_router + self.orchestration_provider = orchestration_provider + self.prompts_router = prompts_router + self.retrieval_router = retrieval_router + self.system_router = system_router + self.users_router = users_router + + self.app = FastAPI() + + @self.app.exception_handler(R2RException) + async def r2r_exception_handler(request: Request, exc: R2RException): + return JSONResponse( + status_code=exc.status_code, + content={ + "message": exc.message, + "error_type": type(exc).__name__, + }, + ) + + self._setup_routes() + self._apply_cors() + + def _setup_routes(self): + self.app.include_router(self.chunks_router, prefix="/v3") + self.app.include_router(self.collections_router, prefix="/v3") + self.app.include_router(self.conversations_router, prefix="/v3") + self.app.include_router(self.documents_router, prefix="/v3") + self.app.include_router(self.graph_router, prefix="/v3") + self.app.include_router(self.indices_router, prefix="/v3") + self.app.include_router(self.prompts_router, prefix="/v3") + self.app.include_router(self.retrieval_router, prefix="/v3") + self.app.include_router(self.system_router, prefix="/v3") + self.app.include_router(self.users_router, prefix="/v3") + + @self.app.get("/openapi_spec", include_in_schema=False) + async def openapi_spec(): + return get_openapi( + title="R2R Application API", + 
version="1.0.0", + routes=self.app.routes, + ) + + def _apply_cors(self): + origins = ["*", "http://localhost:3000", "http://localhost:7272"] + self.app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + async def serve(self, host: str = "0.0.0.0", port: int = 7272): + import uvicorn + + from core.utils.logging_config import configure_logging + + configure_logging() + + config = uvicorn.Config( + self.app, + host=host, + port=port, + log_config=None, + ) + server = uvicorn.Server(config) + await server.serve() diff --git a/.venv/lib/python3.12/site-packages/core/main/app_entry.py b/.venv/lib/python3.12/site-packages/core/main/app_entry.py new file mode 100644 index 00000000..cd3ea84d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/app_entry.py @@ -0,0 +1,125 @@ +import logging +import os +from contextlib import asynccontextmanager +from typing import Optional + +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from core.base import R2RException +from core.utils.logging_config import configure_logging + +from .assembly import R2RBuilder, R2RConfig + +log_file = configure_logging() + +# Global scheduler +scheduler = AsyncIOScheduler() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup + r2r_app = await create_r2r_app( + config_name=config_name, + config_path=config_path, + ) + + # Copy all routes from r2r_app to app + app.router.routes = r2r_app.app.routes + + # Copy middleware and exception handlers + app.middleware = r2r_app.app.middleware # type: ignore + app.exception_handlers = r2r_app.app.exception_handlers + + # Start the scheduler + scheduler.start() + + # Start the Hatchet worker + await r2r_app.orchestration_provider.start_worker() + + yield + + # # Shutdown + scheduler.shutdown() + + 
+async def create_r2r_app( + config_name: Optional[str] = "default", + config_path: Optional[str] = None, +): + config = R2RConfig.load(config_name=config_name, config_path=config_path) + + if ( + config.embedding.provider == "openai" + and "OPENAI_API_KEY" not in os.environ + ): + raise ValueError( + "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider." + ) + + # Build the R2RApp + builder = R2RBuilder(config=config) + return await builder.build() + + +config_name = os.getenv("R2R_CONFIG_NAME", None) +config_path = os.getenv("R2R_CONFIG_PATH", None) + +if not config_path and not config_name: + config_name = "default" +host = os.getenv("R2R_HOST", os.getenv("HOST", "0.0.0.0")) +port = int(os.getenv("R2R_PORT", "7272")) + +logging.info( + f"Environment R2R_IMAGE: {os.getenv('R2R_IMAGE')}", +) +logging.info( + f"Environment R2R_CONFIG_NAME: {'None' if config_name is None else config_name}" +) +logging.info( + f"Environment R2R_CONFIG_PATH: {'None' if config_path is None else config_path}" +) +logging.info(f"Environment R2R_PROJECT_NAME: {os.getenv('R2R_PROJECT_NAME')}") + +logging.info( + f"Environment R2R_POSTGRES_HOST: {os.getenv('R2R_POSTGRES_HOST')}" +) +logging.info( + f"Environment R2R_POSTGRES_DBNAME: {os.getenv('R2R_POSTGRES_DBNAME')}" +) +logging.info( + f"Environment R2R_POSTGRES_PORT: {os.getenv('R2R_POSTGRES_PORT')}" +) +logging.info( + f"Environment R2R_POSTGRES_PASSWORD: {os.getenv('R2R_POSTGRES_PASSWORD')}" +) + +# Create the FastAPI app +app = FastAPI( + lifespan=lifespan, + log_config=None, +) + + +@app.exception_handler(R2RException) +async def r2r_exception_handler(request: Request, exc: R2RException): + return JSONResponse( + status_code=exc.status_code, + content={ + "message": exc.message, + "error_type": type(exc).__name__, + }, + ) + + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) diff --git 
a/.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py b/.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py new file mode 100644 index 00000000..3d10f2b6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py @@ -0,0 +1,12 @@ +from ..config import R2RConfig +from .builder import R2RBuilder +from .factory import R2RProviderFactory + +__all__ = [ + # Builder + "R2RBuilder", + # Config + "R2RConfig", + # Factory + "R2RProviderFactory", +] diff --git a/.venv/lib/python3.12/site-packages/core/main/assembly/builder.py b/.venv/lib/python3.12/site-packages/core/main/assembly/builder.py new file mode 100644 index 00000000..f72a15c9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/assembly/builder.py @@ -0,0 +1,127 @@ +import logging +from typing import Any, Type + +from ..abstractions import R2RProviders, R2RServices +from ..api.v3.chunks_router import ChunksRouter +from ..api.v3.collections_router import CollectionsRouter +from ..api.v3.conversations_router import ConversationsRouter +from ..api.v3.documents_router import DocumentsRouter +from ..api.v3.graph_router import GraphRouter +from ..api.v3.indices_router import IndicesRouter +from ..api.v3.prompts_router import PromptsRouter +from ..api.v3.retrieval_router import RetrievalRouter +from ..api.v3.system_router import SystemRouter +from ..api.v3.users_router import UsersRouter +from ..app import R2RApp +from ..config import R2RConfig +from ..services.auth_service import AuthService # noqa: F401 +from ..services.graph_service import GraphService # noqa: F401 +from ..services.ingestion_service import IngestionService # noqa: F401 +from ..services.management_service import ManagementService # noqa: F401 +from ..services.retrieval_service import ( # type: ignore + RetrievalService, # noqa: F401 # type: ignore +) +from .factory import R2RProviderFactory + +logger = logging.getLogger() + + +class R2RBuilder: + _SERVICES = ["auth", "ingestion", 
"management", "retrieval", "graph"] + + def __init__(self, config: R2RConfig): + self.config = config + + async def build(self, *args, **kwargs) -> R2RApp: + provider_factory = R2RProviderFactory + + try: + providers = await self._create_providers( + provider_factory, *args, **kwargs + ) + except Exception as e: + logger.error(f"Error {e} while creating R2RProviders.") + raise + + service_params = { + "config": self.config, + "providers": providers, + } + + services = self._create_services(service_params) + + routers = { + "chunks_router": ChunksRouter( + providers=providers, + services=services, + config=self.config, + ).get_router(), + "collections_router": CollectionsRouter( + providers=providers, + services=services, + config=self.config, + ).get_router(), + "conversations_router": ConversationsRouter( + providers=providers, + services=services, + config=self.config, + ).get_router(), + "documents_router": DocumentsRouter( + providers=providers, + services=services, + config=self.config, + ).get_router(), + "graph_router": GraphRouter( + providers=providers, + services=services, + config=self.config, + ).get_router(), + "indices_router": IndicesRouter( + providers=providers, + services=services, + config=self.config, + ).get_router(), + "prompts_router": PromptsRouter( + providers=providers, + services=services, + config=self.config, + ).get_router(), + "retrieval_router": RetrievalRouter( + providers=providers, + services=services, + config=self.config, + ).get_router(), + "system_router": SystemRouter( + providers=providers, + services=services, + config=self.config, + ).get_router(), + "users_router": UsersRouter( + providers=providers, + services=services, + config=self.config, + ).get_router(), + } + + return R2RApp( + config=self.config, + orchestration_provider=providers.orchestration, + services=services, + **routers, + ) + + async def _create_providers( + self, provider_factory: Type[R2RProviderFactory], *args, **kwargs + ) -> R2RProviders: + factory = 
provider_factory(self.config) + return await factory.create_providers(*args, **kwargs) + + def _create_services(self, service_params: dict[str, Any]) -> R2RServices: + services = R2RBuilder._SERVICES + service_instances = {} + + for service_type in services: + service_class = globals()[f"{service_type.capitalize()}Service"] + service_instances[service_type] = service_class(**service_params) + + return R2RServices(**service_instances) diff --git a/.venv/lib/python3.12/site-packages/core/main/assembly/factory.py b/.venv/lib/python3.12/site-packages/core/main/assembly/factory.py new file mode 100644 index 00000000..b982aa18 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/assembly/factory.py @@ -0,0 +1,417 @@ +import logging +import math +import os +from typing import Any, Optional + +from core.base import ( + AuthConfig, + CompletionConfig, + CompletionProvider, + CryptoConfig, + DatabaseConfig, + EmailConfig, + EmbeddingConfig, + EmbeddingProvider, + IngestionConfig, + OrchestrationConfig, +) +from core.providers import ( + AnthropicCompletionProvider, + AsyncSMTPEmailProvider, + BcryptCryptoConfig, + BCryptCryptoProvider, + ClerkAuthProvider, + ConsoleMockEmailProvider, + HatchetOrchestrationProvider, + JwtAuthProvider, + LiteLLMCompletionProvider, + LiteLLMEmbeddingProvider, + MailerSendEmailProvider, + NaClCryptoConfig, + NaClCryptoProvider, + OllamaEmbeddingProvider, + OpenAICompletionProvider, + OpenAIEmbeddingProvider, + PostgresDatabaseProvider, + R2RAuthProvider, + R2RCompletionProvider, + R2RIngestionConfig, + R2RIngestionProvider, + SendGridEmailProvider, + SimpleOrchestrationProvider, + SupabaseAuthProvider, + UnstructuredIngestionConfig, + UnstructuredIngestionProvider, +) + +from ..abstractions import R2RProviders +from ..config import R2RConfig + +logger = logging.getLogger() + + +class R2RProviderFactory: + def __init__(self, config: R2RConfig): + self.config = config + + @staticmethod + async def create_auth_provider( + auth_config: 
AuthConfig, + crypto_provider: BCryptCryptoProvider | NaClCryptoProvider, + database_provider: PostgresDatabaseProvider, + email_provider: ( + AsyncSMTPEmailProvider + | ConsoleMockEmailProvider + | SendGridEmailProvider + | MailerSendEmailProvider + ), + *args, + **kwargs, + ) -> ( + R2RAuthProvider + | SupabaseAuthProvider + | JwtAuthProvider + | ClerkAuthProvider + ): + if auth_config.provider == "r2r": + r2r_auth = R2RAuthProvider( + auth_config, crypto_provider, database_provider, email_provider + ) + await r2r_auth.initialize() + return r2r_auth + elif auth_config.provider == "supabase": + return SupabaseAuthProvider( + auth_config, crypto_provider, database_provider, email_provider + ) + elif auth_config.provider == "jwt": + return JwtAuthProvider( + auth_config, crypto_provider, database_provider, email_provider + ) + elif auth_config.provider == "clerk": + return ClerkAuthProvider( + auth_config, crypto_provider, database_provider, email_provider + ) + else: + raise ValueError( + f"Auth provider {auth_config.provider} not supported." + ) + + @staticmethod + def create_crypto_provider( + crypto_config: CryptoConfig, *args, **kwargs + ) -> BCryptCryptoProvider | NaClCryptoProvider: + if crypto_config.provider == "bcrypt": + return BCryptCryptoProvider( + BcryptCryptoConfig(**crypto_config.model_dump()) + ) + if crypto_config.provider == "nacl": + return NaClCryptoProvider( + NaClCryptoConfig(**crypto_config.model_dump()) + ) + else: + raise ValueError( + f"Crypto provider {crypto_config.provider} not supported." 
+ ) + + @staticmethod + def create_ingestion_provider( + ingestion_config: IngestionConfig, + database_provider: PostgresDatabaseProvider, + llm_provider: ( + AnthropicCompletionProvider + | LiteLLMCompletionProvider + | OpenAICompletionProvider + | R2RCompletionProvider + ), + *args, + **kwargs, + ) -> R2RIngestionProvider | UnstructuredIngestionProvider: + config_dict = ( + ingestion_config.model_dump() + if isinstance(ingestion_config, IngestionConfig) + else ingestion_config + ) + + extra_fields = config_dict.pop("extra_fields", {}) + + if config_dict["provider"] == "r2r": + r2r_ingestion_config = R2RIngestionConfig( + **config_dict, **extra_fields + ) + return R2RIngestionProvider( + r2r_ingestion_config, database_provider, llm_provider + ) + elif config_dict["provider"] in [ + "unstructured_local", + "unstructured_api", + ]: + unstructured_ingestion_config = UnstructuredIngestionConfig( + **config_dict, **extra_fields + ) + + return UnstructuredIngestionProvider( + unstructured_ingestion_config, database_provider, llm_provider + ) + else: + raise ValueError( + f"Ingestion provider {ingestion_config.provider} not supported" + ) + + @staticmethod + def create_orchestration_provider( + config: OrchestrationConfig, *args, **kwargs + ) -> HatchetOrchestrationProvider | SimpleOrchestrationProvider: + if config.provider == "hatchet": + orchestration_provider = HatchetOrchestrationProvider(config) + orchestration_provider.get_worker("r2r-worker") + return orchestration_provider + elif config.provider == "simple": + from core.providers import SimpleOrchestrationProvider + + return SimpleOrchestrationProvider(config) + else: + raise ValueError( + f"Orchestration provider {config.provider} not supported" + ) + + async def create_database_provider( + self, + db_config: DatabaseConfig, + crypto_provider: BCryptCryptoProvider | NaClCryptoProvider, + *args, + **kwargs, + ) -> PostgresDatabaseProvider: + if not self.config.embedding.base_dimension: + raise ValueError( + 
"Embedding config must have a base dimension to initialize database." + ) + + dimension = self.config.embedding.base_dimension + quantization_type = ( + self.config.embedding.quantization_settings.quantization_type + ) + if db_config.provider == "postgres": + database_provider = PostgresDatabaseProvider( + db_config, + dimension, + crypto_provider=crypto_provider, + quantization_type=quantization_type, + ) + await database_provider.initialize() + return database_provider + else: + raise ValueError( + f"Database provider {db_config.provider} not supported" + ) + + @staticmethod + def create_embedding_provider( + embedding: EmbeddingConfig, *args, **kwargs + ) -> ( + LiteLLMEmbeddingProvider + | OllamaEmbeddingProvider + | OpenAIEmbeddingProvider + ): + embedding_provider: Optional[EmbeddingProvider] = None + + if embedding.provider == "openai": + if not os.getenv("OPENAI_API_KEY"): + raise ValueError( + "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider." + ) + from core.providers import OpenAIEmbeddingProvider + + embedding_provider = OpenAIEmbeddingProvider(embedding) + + elif embedding.provider == "litellm": + from core.providers import LiteLLMEmbeddingProvider + + embedding_provider = LiteLLMEmbeddingProvider(embedding) + + elif embedding.provider == "ollama": + from core.providers import OllamaEmbeddingProvider + + embedding_provider = OllamaEmbeddingProvider(embedding) + + else: + raise ValueError( + f"Embedding provider {embedding.provider} not supported" + ) + + return embedding_provider + + @staticmethod + def create_llm_provider( + llm_config: CompletionConfig, *args, **kwargs + ) -> ( + AnthropicCompletionProvider + | LiteLLMCompletionProvider + | OpenAICompletionProvider + | R2RCompletionProvider + ): + llm_provider: Optional[CompletionProvider] = None + if llm_config.provider == "anthropic": + llm_provider = AnthropicCompletionProvider(llm_config) + elif llm_config.provider == "litellm": + llm_provider = 
LiteLLMCompletionProvider(llm_config) + elif llm_config.provider == "openai": + llm_provider = OpenAICompletionProvider(llm_config) + elif llm_config.provider == "r2r": + llm_provider = R2RCompletionProvider(llm_config) + else: + raise ValueError( + f"Language model provider {llm_config.provider} not supported" + ) + if not llm_provider: + raise ValueError("Language model provider not found") + return llm_provider + + @staticmethod + async def create_email_provider( + email_config: Optional[EmailConfig] = None, *args, **kwargs + ) -> ( + AsyncSMTPEmailProvider + | ConsoleMockEmailProvider + | SendGridEmailProvider + | MailerSendEmailProvider + ): + """Creates an email provider based on configuration.""" + if not email_config: + raise ValueError( + "No email configuration provided for email provider, please add `[email]` to your `r2r.toml`." + ) + + if email_config.provider == "smtp": + return AsyncSMTPEmailProvider(email_config) + elif email_config.provider == "console_mock": + return ConsoleMockEmailProvider(email_config) + elif email_config.provider == "sendgrid": + return SendGridEmailProvider(email_config) + elif email_config.provider == "mailersend": + return MailerSendEmailProvider(email_config) + else: + raise ValueError( + f"Email provider {email_config.provider} not supported." 
+ ) + + async def create_providers( + self, + auth_provider_override: Optional[ + R2RAuthProvider | SupabaseAuthProvider + ] = None, + crypto_provider_override: Optional[ + BCryptCryptoProvider | NaClCryptoProvider + ] = None, + database_provider_override: Optional[PostgresDatabaseProvider] = None, + email_provider_override: Optional[ + AsyncSMTPEmailProvider + | ConsoleMockEmailProvider + | SendGridEmailProvider + | MailerSendEmailProvider + ] = None, + embedding_provider_override: Optional[ + LiteLLMEmbeddingProvider + | OpenAIEmbeddingProvider + | OllamaEmbeddingProvider + ] = None, + ingestion_provider_override: Optional[ + R2RIngestionProvider | UnstructuredIngestionProvider + ] = None, + llm_provider_override: Optional[ + AnthropicCompletionProvider + | OpenAICompletionProvider + | LiteLLMCompletionProvider + | R2RCompletionProvider + ] = None, + orchestration_provider_override: Optional[Any] = None, + *args, + **kwargs, + ) -> R2RProviders: + if ( + math.isnan(self.config.embedding.base_dimension) + != math.isnan(self.config.completion_embedding.base_dimension) + ) or ( + not math.isnan(self.config.embedding.base_dimension) + and not math.isnan(self.config.completion_embedding.base_dimension) + and self.config.embedding.base_dimension + != self.config.completion_embedding.base_dimension + ): + raise ValueError( + f"Both embedding configurations must use the same dimensions. 
Got {self.config.embedding.base_dimension} and {self.config.completion_embedding.base_dimension}" + ) + + embedding_provider = ( + embedding_provider_override + or self.create_embedding_provider( + self.config.embedding, *args, **kwargs + ) + ) + + completion_embedding_provider = ( + embedding_provider_override + or self.create_embedding_provider( + self.config.completion_embedding, *args, **kwargs + ) + ) + + llm_provider = llm_provider_override or self.create_llm_provider( + self.config.completion, *args, **kwargs + ) + + crypto_provider = ( + crypto_provider_override + or self.create_crypto_provider(self.config.crypto, *args, **kwargs) + ) + + database_provider = ( + database_provider_override + or await self.create_database_provider( + self.config.database, crypto_provider, *args, **kwargs + ) + ) + + ingestion_provider = ( + ingestion_provider_override + or self.create_ingestion_provider( + self.config.ingestion, + database_provider, + llm_provider, + *args, + **kwargs, + ) + ) + + email_provider = ( + email_provider_override + or await self.create_email_provider( + self.config.email, crypto_provider, *args, **kwargs + ) + ) + + auth_provider = ( + auth_provider_override + or await self.create_auth_provider( + self.config.auth, + crypto_provider, + database_provider, + email_provider, + *args, + **kwargs, + ) + ) + + orchestration_provider = ( + orchestration_provider_override + or self.create_orchestration_provider(self.config.orchestration) + ) + + return R2RProviders( + auth=auth_provider, + database=database_provider, + embedding=embedding_provider, + completion_embedding=completion_embedding_provider, + ingestion=ingestion_provider, + llm=llm_provider, + email=email_provider, + orchestration=orchestration_provider, + ) diff --git a/.venv/lib/python3.12/site-packages/core/main/config.py b/.venv/lib/python3.12/site-packages/core/main/config.py new file mode 100644 index 00000000..f49b4041 --- /dev/null +++ 
# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
import logging
import os
from enum import Enum
from typing import Any, Optional

import toml
from pydantic import BaseModel

from ..base.abstractions import GenerationConfig
from ..base.agent.agent import RAGAgentConfig  # type: ignore
from ..base.providers import AppConfig
from ..base.providers.auth import AuthConfig
from ..base.providers.crypto import CryptoConfig
from ..base.providers.database import DatabaseConfig
from ..base.providers.email import EmailConfig
from ..base.providers.embedding import EmbeddingConfig
from ..base.providers.ingestion import IngestionConfig
from ..base.providers.llm import CompletionConfig
from ..base.providers.orchestration import OrchestrationConfig
from ..base.utils import deep_update

logger = logging.getLogger()


class R2RConfig:
    """Typed view over the merged (default + user) `r2r.toml` configuration.

    Class-level scan: every ``*.toml`` under ``../configs`` becomes a named
    config option; "default" maps to the bundled ``r2r/r2r.toml``.
    """

    current_file_path = os.path.dirname(__file__)
    config_dir_root = os.path.join(current_file_path, "..", "configs")
    default_config_path = os.path.join(
        current_file_path, "..", "..", "r2r", "r2r.toml"
    )

    # Name -> path of each shippable config; None means "use the default file".
    CONFIG_OPTIONS: dict[str, Optional[str]] = {}
    for file_ in os.listdir(config_dir_root):
        if file_.endswith(".toml"):
            CONFIG_OPTIONS[file_.removesuffix(".toml")] = os.path.join(
                config_dir_root, file_
            )
    CONFIG_OPTIONS["default"] = None

    # Keys that must be present in each section when that section's
    # provider is enabled.
    REQUIRED_KEYS: dict[str, list] = {
        "app": [],
        "completion": ["provider"],
        "crypto": ["provider"],
        "email": ["provider"],
        "auth": ["provider"],
        "embedding": [
            "provider",
            "base_model",
            "base_dimension",
            "batch_size",
            "add_title_as_prefix",
        ],
        "completion_embedding": [
            "provider",
            "base_model",
            "base_dimension",
            "batch_size",
            "add_title_as_prefix",
        ],
        # TODO - deprecated, remove
        "ingestion": ["provider"],
        "logging": ["provider", "log_table"],
        "database": ["provider"],
        "agent": ["generation_config"],
        "orchestration": ["provider"],
    }

    app: AppConfig
    auth: AuthConfig
    completion: CompletionConfig
    crypto: CryptoConfig
    database: DatabaseConfig
    embedding: EmbeddingConfig
    completion_embedding: EmbeddingConfig
    email: EmailConfig
    ingestion: IngestionConfig
    agent: RAGAgentConfig
    orchestration: OrchestrationConfig

    def __init__(self, config_data: dict[str, Any]):
        """Merge *config_data* over the default config, validate, and build
        the typed per-section config objects.

        :param config_data: dictionary of configuration parameters
        """
        # Load the default configuration
        default_config = self.load_default_config()

        # Override the default configuration with the passed configuration
        default_config = deep_update(default_config, config_data)

        # Validate and set the configuration
        for section, keys in R2RConfig.REQUIRED_KEYS.items():
            # Check the keys when provider is set
            # TODO - remove after deprecation
            if section in ["graph", "file"] and section not in default_config:
                continue
            if "provider" in default_config[section] and (
                default_config[section]["provider"] is not None
                and default_config[section]["provider"] != "None"
                and default_config[section]["provider"] != "null"
            ):
                self._validate_config_section(default_config, section, keys)
            setattr(self, section, default_config[section])

        # Promote the raw dicts into their typed config classes; `app` is
        # threaded into every other section.
        self.app = AppConfig.create(**self.app)  # type: ignore
        self.auth = AuthConfig.create(**self.auth, app=self.app)  # type: ignore
        self.completion = CompletionConfig.create(
            **self.completion, app=self.app
        )  # type: ignore
        self.crypto = CryptoConfig.create(**self.crypto, app=self.app)  # type: ignore
        self.email = EmailConfig.create(**self.email, app=self.app)  # type: ignore
        self.database = DatabaseConfig.create(**self.database, app=self.app)  # type: ignore
        self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app)  # type: ignore
        self.completion_embedding = EmbeddingConfig.create(
            **self.completion_embedding, app=self.app
        )  # type: ignore
        self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app)  # type: ignore
        self.agent = RAGAgentConfig.create(**self.agent, app=self.app)  # type: ignore
        self.orchestration = OrchestrationConfig.create(
            **self.orchestration, app=self.app
        )  # type: ignore

        IngestionConfig.set_default(**self.ingestion.dict())

        # override GenerationConfig defaults
        if self.completion.generation_config:
            GenerationConfig.set_default(
                **self.completion.generation_config.dict()
            )

    def _validate_config_section(
        self, config_data: dict[str, Any], section: str, keys: list
    ):
        """Raise ValueError if *section* is absent or missing required *keys*."""
        if section not in config_data:
            raise ValueError(f"Missing '{section}' section in config")
        if missing_keys := [
            key for key in keys if key not in config_data[section]
        ]:
            raise ValueError(
                f"Missing required keys in '{section}' config: {', '.join(missing_keys)}"
            )

    @classmethod
    def from_toml(cls, config_path: Optional[str] = None) -> "R2RConfig":
        """Build a config from a TOML file (the default file when no path given)."""
        if config_path is None:
            config_path = R2RConfig.default_config_path

        # Load configuration from TOML file
        with open(config_path, encoding="utf-8") as f:
            config_data = toml.load(f)

        return cls(config_data)

    def to_toml(self):
        """Serialize the active configuration back to a TOML string."""
        config_data = {}
        for section in R2RConfig.REQUIRED_KEYS.keys():
            # BUG FIX: some REQUIRED_KEYS sections (e.g. "logging") may never
            # have been set on the instance; getattr(self, section) raised
            # AttributeError for them. Skip sections that were not populated.
            section_value = getattr(self, section, None)
            if section_value is None:
                continue
            section_data = self._serialize_config(section_value)
            if isinstance(section_data, dict):
                # Remove app from nested configs before serializing
                section_data.pop("app", None)
            config_data[section] = section_data
        return toml.dumps(config_data)

    @classmethod
    def load_default_config(cls) -> dict:
        """Load and return the bundled default TOML config as a dict."""
        with open(R2RConfig.default_config_path, encoding="utf-8") as f:
            return toml.load(f)

    @staticmethod
    def _serialize_config(config_section: Any):
        """Serialize config section while excluding internal state."""
        if isinstance(config_section, dict):
            return {
                R2RConfig._serialize_key(k): R2RConfig._serialize_config(v)
                for k, v in config_section.items()
                if k != "app"  # Exclude app from serialization
            }
        elif isinstance(config_section, (list, tuple)):
            return [
                R2RConfig._serialize_config(item) for item in config_section
            ]
        elif isinstance(config_section, Enum):
            return config_section.value
        elif isinstance(config_section, BaseModel):
            data = config_section.model_dump(exclude_none=True)
            data.pop("app", None)  # Remove app from the serialized data
            return R2RConfig._serialize_config(data)
        else:
            return config_section

    @staticmethod
    def _serialize_key(key: Any) -> str:
        """Render a dict key as a plain string (Enum keys use their value)."""
        return key.value if isinstance(key, Enum) else str(key)

    @classmethod
    def load(
        cls,
        config_name: Optional[str] = None,
        config_path: Optional[str] = None,
    ) -> "R2RConfig":
        """Resolve a config by explicit path, named option, or environment.

        Precedence: R2R_CONFIG_PATH env > config_path arg >
        R2R_CONFIG_NAME env > config_name arg > "default".

        Raises:
            ValueError: If both a path and a name are passed, or the name
                is unknown.
        """
        if config_path and config_name:
            raise ValueError(
                f"Cannot specify both config_path and config_name. Got: {config_path}, {config_name}"
            )

        if config_path := os.getenv("R2R_CONFIG_PATH") or config_path:
            return cls.from_toml(config_path)

        config_name = os.getenv("R2R_CONFIG_NAME") or config_name or "default"
        if config_name not in R2RConfig.CONFIG_OPTIONS:
            raise ValueError(f"Invalid config name: {config_name}")
        return cls.from_toml(R2RConfig.CONFIG_OPTIONS[config_name])
"hatchet_graph_search_results_factory", + "simple_ingestion_factory", + "simple_graph_search_results_factory", +] diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py new file mode 100644 index 00000000..cc128b0f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py @@ -0,0 +1,539 @@ +# type: ignore +import asyncio +import contextlib +import json +import logging +import math +import time +import uuid +from typing import TYPE_CHECKING + +from hatchet_sdk import ConcurrencyLimitStrategy, Context + +from core import GenerationConfig +from core.base import OrchestrationProvider, R2RException +from core.base.abstractions import ( + GraphConstructionStatus, + GraphExtractionStatus, +) + +from ...services import GraphService + +if TYPE_CHECKING: + from hatchet_sdk import Hatchet + +logger = logging.getLogger() + + +def hatchet_graph_search_results_factory( + orchestration_provider: OrchestrationProvider, service: GraphService +) -> dict[str, "Hatchet.Workflow"]: + def convert_to_dict(input_data): + """Converts input data back to a plain dictionary format, handling + special cases like UUID and GenerationConfig. This is the inverse of + get_input_data_dict. 
+ + Args: + input_data: Dictionary containing the input data with potentially special types + + Returns: + Dictionary with all values converted to basic Python types + """ + output_data = {} + + for key, value in input_data.items(): + if value is None: + output_data[key] = None + continue + + # Convert UUID to string + if isinstance(value, uuid.UUID): + output_data[key] = str(value) + + try: + output_data[key] = value.model_dump() + except Exception: + # Handle nested dictionaries that might contain settings + if isinstance(value, dict): + output_data[key] = convert_to_dict(value) + + # Handle lists that might contain dictionaries + elif isinstance(value, list): + output_data[key] = [ + ( + convert_to_dict(item) + if isinstance(item, dict) + else item + ) + for item in value + ] + + # All other types can be directly assigned + else: + output_data[key] = value + + return output_data + + def get_input_data_dict(input_data): + for key, value in input_data.items(): + if value is None: + continue + + if key == "document_id": + input_data[key] = ( + uuid.UUID(value) + if not isinstance(value, uuid.UUID) + else value + ) + + if key == "collection_id": + input_data[key] = ( + uuid.UUID(value) + if not isinstance(value, uuid.UUID) + else value + ) + + if key == "graph_id": + input_data[key] = ( + uuid.UUID(value) + if not isinstance(value, uuid.UUID) + else value + ) + + if key in ["graph_creation_settings", "graph_enrichment_settings"]: + # Ensure we have a dict (if not already) + input_data[key] = ( + json.loads(value) if not isinstance(value, dict) else value + ) + + if "generation_config" in input_data[key]: + gen_cfg = input_data[key]["generation_config"] + # If it's a dict, convert it + if isinstance(gen_cfg, dict): + input_data[key]["generation_config"] = ( + GenerationConfig(**gen_cfg) + ) + # If it's not already a GenerationConfig, default it + elif not isinstance(gen_cfg, GenerationConfig): + input_data[key]["generation_config"] = ( + GenerationConfig() + ) + + 
input_data[key]["generation_config"].model = ( + input_data[key]["generation_config"].model + or service.config.app.fast_llm + ) + + return input_data + + @orchestration_provider.workflow(name="graph-extraction", timeout="360m") + class GraphExtractionWorkflow: + @orchestration_provider.concurrency( # type: ignore + max_runs=orchestration_provider.config.graph_search_results_concurrency_limit, # type: ignore + limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, + ) + def concurrency(self, context: Context) -> str: + # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun + with contextlib.suppress(Exception): + return str( + context.workflow_input()["request"]["collection_id"] + ) + + def __init__(self, graph_search_results_service: GraphService): + self.graph_search_results_service = graph_search_results_service + + @orchestration_provider.step(retries=1, timeout="360m") + async def graph_search_results_extraction( + self, context: Context + ) -> dict: + request = context.workflow_input()["request"] + + input_data = get_input_data_dict(request) + document_id = input_data.get("document_id", None) + collection_id = input_data.get("collection_id", None) + + await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status( + id=document_id, + status_type="extraction_status", + status=GraphExtractionStatus.PROCESSING, + ) + + if collection_id and not document_id: + document_ids = await self.graph_search_results_service.get_document_ids_for_create_graph( + collection_id=collection_id, + **input_data["graph_creation_settings"], + ) + workflows = [] + + for document_id in document_ids: + input_data_copy = input_data.copy() + input_data_copy["collection_id"] = str( + input_data_copy["collection_id"] + ) + input_data_copy["document_id"] = str(document_id) + + workflows.append( + context.aio.spawn_workflow( + "graph-extraction", + { + "request": { + **convert_to_dict(input_data_copy), + } + }, + 
key=str(document_id), + ) + ) + # Wait for all workflows to complete + results = await asyncio.gather(*workflows) + return { + "result": f"successfully submitted graph_search_results relationships extraction for document {document_id}", + "document_id": str(collection_id), + } + + else: + # Extract relationships and store them + extractions = [] + async for extraction in self.graph_search_results_service.graph_search_results_extraction( + document_id=document_id, + **input_data["graph_creation_settings"], + ): + logger.info( + f"Found extraction with {len(extraction.entities)} entities" + ) + extractions.append(extraction) + + await self.graph_search_results_service.store_graph_search_results_extractions( + extractions + ) + + logger.info( + f"Successfully ran graph_search_results relationships extraction for document {document_id}" + ) + + return { + "result": f"successfully ran graph_search_results relationships extraction for document {document_id}", + "document_id": str(document_id), + } + + @orchestration_provider.step( + retries=1, + timeout="360m", + parents=["graph_search_results_extraction"], + ) + async def graph_search_results_entity_description( + self, context: Context + ) -> dict: + input_data = get_input_data_dict( + context.workflow_input()["request"] + ) + document_id = input_data.get("document_id", None) + + # Describe the entities in the graph + await self.graph_search_results_service.graph_search_results_entity_description( + document_id=document_id, + **input_data["graph_creation_settings"], + ) + + logger.info( + f"Successfully ran graph_search_results entity description for document {document_id}" + ) + + if service.providers.database.config.graph_creation_settings.automatic_deduplication: + extract_input = { + "document_id": str(document_id), + } + + extract_result = ( + await context.aio.spawn_workflow( + "graph-deduplication", + {"request": extract_input}, + ) + ).result() + + await asyncio.gather(extract_result) + + return { + "result": 
f"successfully ran graph_search_results entity description for document {document_id}" + } + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + request = context.workflow_input().get("request", {}) + document_id = request.get("document_id") + + if not document_id: + logger.info( + "No document id was found in workflow input to mark a failure." + ) + return + + try: + await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status( + id=uuid.UUID(document_id), + status_type="extraction_status", + status=GraphExtractionStatus.FAILED, + ) + logger.info( + f"Updated Graph extraction status for {document_id} to FAILED" + ) + except Exception as e: + logger.error( + f"Failed to update document status for {document_id}: {e}" + ) + + @orchestration_provider.workflow(name="graph-clustering", timeout="360m") + class GraphClusteringWorkflow: + def __init__(self, graph_search_results_service: GraphService): + self.graph_search_results_service = graph_search_results_service + + @orchestration_provider.step(retries=1, timeout="360m", parents=[]) + async def graph_search_results_clustering( + self, context: Context + ) -> dict: + logger.info("Running Graph Clustering") + + input_data = get_input_data_dict( + context.workflow_input()["request"] + ) + + # Get the collection_id and graph_id + collection_id = input_data.get("collection_id", None) + graph_id = input_data.get("graph_id", None) + + # Check current workflow status + workflow_status = await self.graph_search_results_service.providers.database.documents_handler.get_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + ) + + if workflow_status == GraphConstructionStatus.SUCCESS: + raise R2RException( + "Communities have already been built for this collection. 
To build communities again, first reset the graph.", + 400, + ) + + # Run clustering + try: + graph_search_results_clustering_results = await self.graph_search_results_service.graph_search_results_clustering( + collection_id=collection_id, + graph_id=graph_id, + **input_data["graph_enrichment_settings"], + ) + + num_communities = graph_search_results_clustering_results[ + "num_communities" + ][0] + + if num_communities == 0: + raise R2RException("No communities found", 400) + + return { + "result": graph_search_results_clustering_results, + } + except Exception as e: + await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.FAILED, + ) + raise e + + @orchestration_provider.step( + retries=1, + timeout="360m", + parents=["graph_search_results_clustering"], + ) + async def graph_search_results_community_summary( + self, context: Context + ) -> dict: + input_data = get_input_data_dict( + context.workflow_input()["request"] + ) + collection_id = input_data.get("collection_id", None) + graph_id = input_data.get("graph_id", None) + # Get number of communities from previous step + num_communities = context.step_output( + "graph_search_results_clustering" + )["result"]["num_communities"][0] + + # Calculate batching + parallel_communities = min(100, num_communities) + total_workflows = math.ceil(num_communities / parallel_communities) + workflows = [] + + logger.info( + f"Running Graph Community Summary for {num_communities} communities, spawning {total_workflows} workflows" + ) + + # Spawn summary workflows + for i in range(total_workflows): + offset = i * parallel_communities + limit = min(parallel_communities, num_communities - offset) + + workflows.append( + ( + await context.aio.spawn_workflow( + "graph-community-summarization", + { + "request": { + "offset": offset, + "limit": limit, + "graph_id": ( + str(graph_id) if graph_id else None 
+ ), + "collection_id": ( + str(collection_id) + if collection_id + else None + ), + "graph_enrichment_settings": convert_to_dict( + input_data["graph_enrichment_settings"] + ), + } + }, + key=f"{i}/{total_workflows}_community_summary", + ) + ).result() + ) + + results = await asyncio.gather(*workflows) + logger.info( + f"Completed {len(results)} community summary workflows" + ) + + # Update statuses + document_ids = await self.graph_search_results_service.providers.database.documents_handler.get_document_ids_by_status( + status_type="extraction_status", + status=GraphExtractionStatus.SUCCESS, + collection_id=collection_id, + ) + + await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status( + id=document_ids, + status_type="extraction_status", + status=GraphExtractionStatus.ENRICHED, + ) + + await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.SUCCESS, + ) + + return { + "result": f"Successfully completed enrichment with {len(results)} summary workflows" + } + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + collection_id = context.workflow_input()["request"].get( + "collection_id", None + ) + if collection_id: + await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status( + id=uuid.UUID(collection_id), + status_type="graph_cluster_status", + status=GraphConstructionStatus.FAILED, + ) + + @orchestration_provider.workflow( + name="graph-community-summarization", timeout="360m" + ) + class GraphCommunitySummarizerWorkflow: + def __init__(self, graph_search_results_service: GraphService): + self.graph_search_results_service = graph_search_results_service + + @orchestration_provider.concurrency( # type: ignore + max_runs=orchestration_provider.config.graph_search_results_concurrency_limit, # type: ignore + 
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, + ) + def concurrency(self, context: Context) -> str: + # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun + try: + return str( + context.workflow_input()["request"]["collection_id"] + ) + except Exception: + return str(uuid.uuid4()) + + @orchestration_provider.step(retries=1, timeout="360m") + async def graph_search_results_community_summary( + self, context: Context + ) -> dict: + start_time = time.time() + + input_data = get_input_data_dict( + context.workflow_input()["request"] + ) + + base_args = { + k: v + for k, v in input_data.items() + if k != "graph_enrichment_settings" + } + enrichment_args = input_data.get("graph_enrichment_settings", {}) + + # Merge them together. + # Note: if there is any key overlap, values from enrichment_args will override those from base_args. + merged_args = {**base_args, **enrichment_args} + + # Now call the service method with all arguments at the top level. + # This ensures that keys like "max_summary_input_length" and "generation_config" are present. 
+ community_summary = await self.graph_search_results_service.graph_search_results_community_summary( + **merged_args + ) + logger.info( + f"Successfully ran graph_search_results community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)} in {time.time() - start_time:.2f} seconds " + ) + return { + "result": f"successfully ran graph_search_results community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)}" + } + + @orchestration_provider.workflow( + name="graph-deduplication", timeout="360m" + ) + class GraphDeduplicationWorkflow: + def __init__(self, graph_search_results_service: GraphService): + self.graph_search_results_service = graph_search_results_service + + @orchestration_provider.concurrency( # type: ignore + max_runs=orchestration_provider.config.graph_search_results_concurrency_limit, # type: ignore + limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, + ) + def concurrency(self, context: Context) -> str: + # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun + try: + return str(context.workflow_input()["request"]["document_id"]) + except Exception: + return str(uuid.uuid4()) + + @orchestration_provider.step(retries=1, timeout="360m") + async def deduplicate_document_entities( + self, context: Context + ) -> dict: + start_time = time.time() + + input_data = get_input_data_dict( + context.workflow_input()["request"] + ) + + document_id = input_data.get("document_id", None) + + await service.deduplicate_document_entities( + document_id=document_id, + ) + logger.info( + f"Successfully ran deduplication for document {document_id} in {time.time() - start_time:.2f} seconds " + ) + return { + "result": f"Successfully ran deduplication for document {document_id}" + } + + return { + "graph-extraction": GraphExtractionWorkflow(service), + "graph-clustering": GraphClusteringWorkflow(service), + "graph-community-summarization": 
GraphCommunitySummarizerWorkflow( + service + ), + "graph-deduplication": GraphDeduplicationWorkflow(service), + } diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py new file mode 100644 index 00000000..96d7aebb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py @@ -0,0 +1,721 @@ +# type: ignore +import asyncio +import logging +import uuid +from typing import TYPE_CHECKING +from uuid import UUID + +import tiktoken +from fastapi import HTTPException +from hatchet_sdk import ConcurrencyLimitStrategy, Context +from litellm import AuthenticationError + +from core.base import ( + DocumentChunk, + GraphConstructionStatus, + IngestionStatus, + OrchestrationProvider, + generate_extraction_id, +) +from core.base.abstractions import DocumentResponse, R2RException +from core.utils import ( + generate_default_user_collection_id, + update_settings_from_dict, +) + +from ...services import IngestionService, IngestionServiceAdapter + +if TYPE_CHECKING: + from hatchet_sdk import Hatchet + +logger = logging.getLogger() + + +# FIXME: No need to duplicate this function between the workflows, consolidate it into a shared module +def count_tokens_for_text(text: str, model: str = "gpt-4o") -> int: + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + # Fallback to a known encoding if model not recognized + encoding = tiktoken.get_encoding("cl100k_base") + + return len(encoding.encode(text, disallowed_special=())) + + +def hatchet_ingestion_factory( + orchestration_provider: OrchestrationProvider, service: IngestionService +) -> dict[str, "Hatchet.Workflow"]: + @orchestration_provider.workflow( + name="ingest-files", + timeout="60m", + ) + class HatchetIngestFilesWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = 
ingestion_service + + @orchestration_provider.concurrency( # type: ignore + max_runs=orchestration_provider.config.ingestion_concurrency_limit, # type: ignore + limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, + ) + def concurrency(self, context: Context) -> str: + # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun + try: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_ingest_file_input( + input_data + ) + return str(parsed_data["user"].id) + except Exception: + return str(uuid.uuid4()) + + @orchestration_provider.step(retries=0, timeout="60m") + async def parse(self, context: Context) -> dict: + try: + logger.info("Initiating ingestion workflow, step: parse") + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_ingest_file_input( + input_data + ) + + # ingestion_result = ( + # await self.ingestion_service.ingest_file_ingress( + # **parsed_data + # ) + # ) + + # document_info = ingestion_result["info"] + document_info = ( + self.ingestion_service.create_document_info_from_file( + parsed_data["document_id"], + parsed_data["user"], + parsed_data["file_data"]["filename"], + parsed_data["metadata"], + parsed_data["version"], + parsed_data["size_in_bytes"], + ) + ) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.PARSING, + ) + + ingestion_config = parsed_data["ingestion_config"] or {} + extractions_generator = self.ingestion_service.parse_file( + document_info, ingestion_config + ) + + extractions = [] + async for extraction in extractions_generator: + extractions.append(extraction) + + # 2) Sum tokens + total_tokens = 0 + for chunk in extractions: + text_data = chunk.data + if not isinstance(text_data, str): + text_data = text_data.decode("utf-8", errors="ignore") + total_tokens += count_tokens_for_text(text_data) + document_info.total_tokens = total_tokens + + if not 
ingestion_config.get("skip_document_summary", False): + await service.update_document_status( + document_info, status=IngestionStatus.AUGMENTING + ) + await service.augment_document_info( + document_info, + [extraction.to_dict() for extraction in extractions], + ) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.EMBEDDING, + ) + + # extractions = context.step_output("parse")["extractions"] + + embedding_generator = self.ingestion_service.embed_document( + [extraction.to_dict() for extraction in extractions] + ) + + embeddings = [] + async for embedding in embedding_generator: + embeddings.append(embedding) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.STORING, + ) + + storage_generator = self.ingestion_service.store_embeddings( # type: ignore + embeddings + ) + + async for _ in storage_generator: + pass + + await self.ingestion_service.finalize_ingestion(document_info) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.SUCCESS, + ) + + collection_ids = context.workflow_input()["request"].get( + "collection_ids" + ) + if not collection_ids: + # TODO: Move logic onto the `management service` + collection_id = generate_default_user_collection_id( + document_info.owner_id + ) + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", # NOTE - we should actually 
check that cluster has been made first, if not it should be PENDING still + status=GraphConstructionStatus.OUTDATED, + ) + else: + for collection_id_str in collection_ids: + collection_id = UUID(collection_id_str) + try: + name = document_info.title or "N/A" + description = "" + await service.providers.database.collections_handler.create_collection( + owner_id=document_info.owner_id, + name=name, + description=description, + collection_id=collection_id, + ) + await ( + self.providers.database.graphs_handler.create( + collection_id=collection_id, + name=name, + description=description, + graph_id=collection_id, + ) + ) + + except Exception as e: + logger.warning( + f"Warning, could not create collection with error: {str(e)}" + ) + + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + status=GraphConstructionStatus.OUTDATED, + ) + + # get server chunk enrichment settings and override parts of it if provided in the ingestion config + if server_chunk_enrichment_settings := getattr( + service.providers.ingestion.config, + "chunk_enrichment_settings", + None, + ): + chunk_enrichment_settings = update_settings_from_dict( + server_chunk_enrichment_settings, + ingestion_config.get("chunk_enrichment_settings", {}) + or {}, + ) + + if chunk_enrichment_settings.enable_chunk_enrichment: + logger.info("Enriching 
document with contextual chunks") + + document_info: DocumentResponse = ( + await self.ingestion_service.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=1, + filter_user_ids=[document_info.owner_id], + filter_document_ids=[document_info.id], + ) + )["results"][0] + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.ENRICHING, + ) + + await self.ingestion_service.chunk_enrichment( + document_id=document_info.id, + document_summary=document_info.summary, + chunk_enrichment_settings=chunk_enrichment_settings, + ) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.SUCCESS, + ) + # ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ + + if service.providers.ingestion.config.automatic_extraction: + extract_input = { + "document_id": str(document_info.id), + "graph_creation_settings": self.ingestion_service.providers.database.config.graph_creation_settings.model_dump_json(), + "user": input_data["user"], + } + + extract_result = ( + await context.aio.spawn_workflow( + "graph-extraction", + {"request": extract_input}, + ) + ).result() + + await asyncio.gather(extract_result) + + return { + "status": "Successfully finalized ingestion", + "document_info": document_info.to_dict(), + } + + except AuthenticationError: + raise R2RException( + status_code=401, + message="Authentication error: Invalid API key or credentials.", + ) from None + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during ingestion: {str(e)}", + ) from e + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + request = context.workflow_input().get("request", {}) + document_id = request.get("document_id") + + if not document_id: + 
logger.error( + "No document id was found in workflow input to mark a failure." + ) + return + + try: + documents_overview = ( + await self.ingestion_service.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=1, + filter_document_ids=[document_id], + ) + )["results"] + + if not documents_overview: + logger.error( + f"Document with id {document_id} not found in database to mark failure." + ) + return + + document_info = documents_overview[0] + + # Update the document status to FAILED + if document_info.ingestion_status != IngestionStatus.SUCCESS: + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.FAILED, + metadata={"failure": f"{context.step_run_errors()}"}, + ) + + except Exception as e: + logger.error( + f"Failed to update document status for {document_id}: {e}" + ) + + @orchestration_provider.workflow( + name="ingest-chunks", + timeout="60m", + ) + class HatchetIngestChunksWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="60m") + async def ingest(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input( + input_data + ) + + document_info = await self.ingestion_service.ingest_chunks_ingress( + **parsed_data + ) + + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.EMBEDDING + ) + document_id = document_info.id + + extractions = [ + DocumentChunk( + id=generate_extraction_id(document_id, i), + document_id=document_id, + collection_ids=[], + owner_id=document_info.owner_id, + data=chunk.text, + metadata=parsed_data["metadata"], + ).to_dict() + for i, chunk in enumerate(parsed_data["chunks"]) + ] + + # 2) Sum tokens + total_tokens = 0 + for chunk in extractions: + text_data = chunk["data"] + if not isinstance(text_data, str): + text_data = 
text_data.decode("utf-8", errors="ignore") + total_tokens += count_tokens_for_text(text_data) + document_info.total_tokens = total_tokens + + return { + "status": "Successfully ingested chunks", + "extractions": extractions, + "document_info": document_info.to_dict(), + } + + @orchestration_provider.step(parents=["ingest"], timeout="60m") + async def embed(self, context: Context) -> dict: + document_info_dict = context.step_output("ingest")["document_info"] + document_info = DocumentResponse(**document_info_dict) + + extractions = context.step_output("ingest")["extractions"] + + embedding_generator = self.ingestion_service.embed_document( + extractions + ) + embeddings = [ + embedding.model_dump() + async for embedding in embedding_generator + ] + + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.STORING + ) + + storage_generator = self.ingestion_service.store_embeddings( + embeddings + ) + async for _ in storage_generator: + pass + + return { + "status": "Successfully embedded and stored chunks", + "document_info": document_info.to_dict(), + } + + @orchestration_provider.step(parents=["embed"], timeout="60m") + async def finalize(self, context: Context) -> dict: + document_info_dict = context.step_output("embed")["document_info"] + document_info = DocumentResponse(**document_info_dict) + + await self.ingestion_service.finalize_ingestion(document_info) + + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.SUCCESS + ) + + try: + # TODO - Move logic onto the `management service` + collection_ids = context.workflow_input()["request"].get( + "collection_ids" + ) + if not collection_ids: + # TODO: Move logic onto the `management service` + collection_id = generate_default_user_collection_id( + document_info.owner_id + ) + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + 
await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + status=GraphConstructionStatus.OUTDATED, + ) + else: + for collection_id_str in collection_ids: + collection_id = UUID(collection_id_str) + try: + name = document_info.title or "N/A" + description = "" + await service.providers.database.collections_handler.create_collection( + owner_id=document_info.owner_id, + name=name, + description=description, + collection_id=collection_id, + ) + await ( + self.providers.database.graphs_handler.create( + collection_id=collection_id, + name=name, + description=description, + graph_id=collection_id, + ) + ) + + except Exception as e: + logger.warning( + f"Warning, could not create collection with error: {str(e)}" + ) + + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be 
PENDING still + ) + except Exception as e: + logger.error( + f"Error during assigning document to collection: {str(e)}" + ) + + return { + "status": "Successfully finalized ingestion", + "document_info": document_info.to_dict(), + } + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + request = context.workflow_input().get("request", {}) + document_id = request.get("document_id") + + if not document_id: + logger.error( + "No document id was found in workflow input to mark a failure." + ) + return + + try: + documents_overview = ( + await self.ingestion_service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. + offset=0, + limit=100, + filter_document_ids=[document_id], + ) + )["results"] + + if not documents_overview: + logger.error( + f"Document with id {document_id} not found in database to mark failure." + ) + return + + document_info = documents_overview[0] + + if document_info.ingestion_status != IngestionStatus.SUCCESS: + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.FAILED + ) + + except Exception as e: + logger.error( + f"Failed to update document status for {document_id}: {e}" + ) + + @orchestration_provider.workflow( + name="update-chunk", + timeout="60m", + ) + class HatchetUpdateChunkWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="60m") + async def update_chunk(self, context: Context) -> dict: + try: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_update_chunk_input( + input_data + ) + + document_uuid = ( + UUID(parsed_data["document_id"]) + if isinstance(parsed_data["document_id"], str) + else parsed_data["document_id"] + ) + extraction_uuid = ( + UUID(parsed_data["id"]) + if 
isinstance(parsed_data["id"], str) + else parsed_data["id"] + ) + + await self.ingestion_service.update_chunk_ingress( + document_id=document_uuid, + chunk_id=extraction_uuid, + text=parsed_data.get("text"), + user=parsed_data["user"], + metadata=parsed_data.get("metadata"), + collection_ids=parsed_data.get("collection_ids"), + ) + + return { + "message": "Chunk update completed successfully.", + "task_id": context.workflow_run_id(), + "document_ids": [str(document_uuid)], + } + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during chunk update: {str(e)}", + ) from e + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + # Handle failure case if necessary + pass + + @orchestration_provider.workflow( + name="create-vector-index", timeout="360m" + ) + class HatchetCreateVectorIndexWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="60m") + async def create_vector_index(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + parsed_data = ( + IngestionServiceAdapter.parse_create_vector_index_input( + input_data + ) + ) + + await self.ingestion_service.providers.database.chunks_handler.create_index( + **parsed_data + ) + + return { + "status": "Vector index creation queued successfully.", + } + + @orchestration_provider.workflow(name="delete-vector-index", timeout="30m") + class HatchetDeleteVectorIndexWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="10m") + async def delete_vector_index(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + parsed_data = ( + IngestionServiceAdapter.parse_delete_vector_index_input( + input_data + ) + ) + + await self.ingestion_service.providers.database.chunks_handler.delete_index( + 
**parsed_data + ) + + return {"status": "Vector index deleted successfully."} + + @orchestration_provider.workflow( + name="update-document-metadata", + timeout="30m", + ) + class HatchetUpdateDocumentMetadataWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="30m") + async def update_document_metadata(self, context: Context) -> dict: + try: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_update_document_metadata_input( + input_data + ) + + document_id = UUID(parsed_data["document_id"]) + metadata = parsed_data["metadata"] + user = parsed_data["user"] + + await self.ingestion_service.update_document_metadata( + document_id=document_id, + metadata=metadata, + user=user, + ) + + return { + "message": "Document metadata update completed successfully.", + "document_id": str(document_id), + "task_id": context.workflow_run_id(), + } + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during document metadata update: {str(e)}", + ) from e + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + # Handle failure case if necessary + pass + + # Add this to the workflows dictionary in hatchet_ingestion_factory + ingest_files_workflow = HatchetIngestFilesWorkflow(service) + ingest_chunks_workflow = HatchetIngestChunksWorkflow(service) + update_chunks_workflow = HatchetUpdateChunkWorkflow(service) + update_document_metadata_workflow = HatchetUpdateDocumentMetadataWorkflow( + service + ) + create_vector_index_workflow = HatchetCreateVectorIndexWorkflow(service) + delete_vector_index_workflow = HatchetDeleteVectorIndexWorkflow(service) + + return { + "ingest_files": ingest_files_workflow, + "ingest_chunks": ingest_chunks_workflow, + "update_chunk": update_chunks_workflow, + "update_document_metadata": update_document_metadata_workflow, + 
"create_vector_index": create_vector_index_workflow, + "delete_vector_index": delete_vector_index_workflow, + } diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py new file mode 100644 index 00000000..9e043263 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py @@ -0,0 +1,222 @@ +import json +import logging +import math +import uuid + +from core import GenerationConfig, R2RException +from core.base.abstractions import ( + GraphConstructionStatus, + GraphExtractionStatus, +) + +from ...services import GraphService + +logger = logging.getLogger() + + +def simple_graph_search_results_factory(service: GraphService): + def get_input_data_dict(input_data): + for key, value in input_data.items(): + if value is None: + continue + + if key == "document_id": + input_data[key] = ( + uuid.UUID(value) + if not isinstance(value, uuid.UUID) + else value + ) + + if key == "collection_id": + input_data[key] = ( + uuid.UUID(value) + if not isinstance(value, uuid.UUID) + else value + ) + + if key == "graph_id": + input_data[key] = ( + uuid.UUID(value) + if not isinstance(value, uuid.UUID) + else value + ) + + if key in ["graph_creation_settings", "graph_enrichment_settings"]: + # Ensure we have a dict (if not already) + input_data[key] = ( + json.loads(value) if not isinstance(value, dict) else value + ) + + if "generation_config" in input_data[key]: + if isinstance(input_data[key]["generation_config"], dict): + input_data[key]["generation_config"] = ( + GenerationConfig( + 
**input_data[key]["generation_config"] + ) + ) + elif not isinstance( + input_data[key]["generation_config"], GenerationConfig + ): + input_data[key]["generation_config"] = ( + GenerationConfig() + ) + + input_data[key]["generation_config"].model = ( + input_data[key]["generation_config"].model + or service.config.app.fast_llm + ) + + return input_data + + async def graph_extraction(input_data): + input_data = get_input_data_dict(input_data) + + if input_data.get("document_id"): + document_ids = [input_data.get("document_id")] + else: + documents = [] + collection_id = input_data.get("collection_id") + batch_size = 100 + offset = 0 + while True: + # Fetch current batch + batch = ( + await service.providers.database.collections_handler.documents_in_collection( + collection_id=collection_id, + offset=offset, + limit=batch_size, + ) + )["results"] + + # If no documents returned, we've reached the end + if not batch: + break + + # Add current batch to results + documents.extend(batch) + + # Update offset for next batch + offset += batch_size + + # Optional: If batch is smaller than batch_size, we've reached the end + if len(batch) < batch_size: + break + + document_ids = [document.id for document in documents] + + logger.info( + f"Creating graph for {len(document_ids)} documents with IDs: {document_ids}" + ) + + for _, document_id in enumerate(document_ids): + await service.providers.database.documents_handler.set_workflow_status( + id=document_id, + status_type="extraction_status", + status=GraphExtractionStatus.PROCESSING, + ) + + # Extract relationships from the document + try: + extractions = [] + async for ( + extraction + ) in service.graph_search_results_extraction( + document_id=document_id, + **input_data["graph_creation_settings"], + ): + extractions.append(extraction) + await service.store_graph_search_results_extractions( + extractions + ) + + # Describe the entities in the graph + await service.graph_search_results_entity_description( + 
document_id=document_id, + **input_data["graph_creation_settings"], + ) + + if service.providers.database.config.graph_creation_settings.automatic_deduplication: + logger.warning( + "Automatic deduplication is not yet implemented for `simple` workflows." + ) + + except Exception as e: + logger.error( + f"Error in creating graph for document {document_id}: {e}" + ) + raise e + + async def graph_clustering(input_data): + input_data = get_input_data_dict(input_data) + workflow_status = await service.providers.database.documents_handler.get_workflow_status( + id=input_data.get("collection_id", None), + status_type="graph_cluster_status", + ) + if workflow_status == GraphConstructionStatus.SUCCESS: + raise R2RException( + "Communities have already been built for this collection. To build communities again, first submit a POST request to `graphs/{collection_id}/reset` to erase the previously built communities.", + 400, + ) + + try: + num_communities = await service.graph_search_results_clustering( + collection_id=input_data.get("collection_id", None), + # graph_id=input_data.get("graph_id", None), + **input_data["graph_enrichment_settings"], + ) + num_communities = num_communities["num_communities"][0] + # TODO - Do not hardcode the number of parallel communities, + # make it a configurable parameter at runtime & add server-side defaults + + if num_communities == 0: + raise R2RException("No communities found", 400) + + parallel_communities = min(100, num_communities) + + total_workflows = math.ceil(num_communities / parallel_communities) + for i in range(total_workflows): + input_data_copy = input_data.copy() + input_data_copy["offset"] = i * parallel_communities + input_data_copy["limit"] = min( + parallel_communities, + num_communities - i * parallel_communities, + ) + + logger.info( + f"Running graph_search_results community summary for workflow {i + 1} of {total_workflows}" + ) + + await service.graph_search_results_community_summary( + 
offset=input_data_copy["offset"], + limit=input_data_copy["limit"], + collection_id=input_data_copy.get("collection_id", None), + # graph_id=input_data_copy.get("graph_id", None), + **input_data_copy["graph_enrichment_settings"], + ) + + await service.providers.database.documents_handler.set_workflow_status( + id=input_data.get("collection_id", None), + status_type="graph_cluster_status", + status=GraphConstructionStatus.SUCCESS, + ) + + except Exception as e: + await service.providers.database.documents_handler.set_workflow_status( + id=input_data.get("collection_id", None), + status_type="graph_cluster_status", + status=GraphConstructionStatus.FAILED, + ) + + raise e + + async def graph_deduplication(input_data): + input_data = get_input_data_dict(input_data) + await service.deduplicate_document_entities( + document_id=input_data.get("document_id", None), + ) + + return { + "graph-extraction": graph_extraction, + "graph-clustering": graph_clustering, + "graph-deduplication": graph_deduplication, + } diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py new file mode 100644 index 00000000..60a696c1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py @@ -0,0 +1,598 @@ +import asyncio +import logging +from uuid import UUID + +import tiktoken +from fastapi import HTTPException +from litellm import AuthenticationError + +from core.base import ( + DocumentChunk, + GraphConstructionStatus, + R2RException, + increment_version, +) +from core.utils import ( + generate_default_user_collection_id, + generate_extraction_id, + update_settings_from_dict, +) + +from ...services import IngestionService + +logger = logging.getLogger() + + +# FIXME: No need to duplicate this function between the workflows, consolidate it into a shared module +def count_tokens_for_text(text: str, model: str = 
"gpt-4o") -> int: + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + # Fallback to a known encoding if model not recognized + encoding = tiktoken.get_encoding("cl100k_base") + + return len(encoding.encode(text, disallowed_special=())) + + +def simple_ingestion_factory(service: IngestionService): + async def ingest_files(input_data): + document_info = None + try: + from core.base import IngestionStatus + from core.main import IngestionServiceAdapter + + parsed_data = IngestionServiceAdapter.parse_ingest_file_input( + input_data + ) + + document_info = service.create_document_info_from_file( + parsed_data["document_id"], + parsed_data["user"], + parsed_data["file_data"]["filename"], + parsed_data["metadata"], + parsed_data["version"], + parsed_data["size_in_bytes"], + ) + + await service.update_document_status( + document_info, status=IngestionStatus.PARSING + ) + + ingestion_config = parsed_data["ingestion_config"] + extractions_generator = service.parse_file( + document_info=document_info, + ingestion_config=ingestion_config, + ) + extractions = [ + extraction.model_dump() + async for extraction in extractions_generator + ] + + # 2) Sum tokens + total_tokens = 0 + for chunk_dict in extractions: + text_data = chunk_dict["data"] + if not isinstance(text_data, str): + text_data = text_data.decode("utf-8", errors="ignore") + total_tokens += count_tokens_for_text(text_data) + document_info.total_tokens = total_tokens + + if not ingestion_config.get("skip_document_summary", False): + await service.update_document_status( + document_info=document_info, + status=IngestionStatus.AUGMENTING, + ) + await service.augment_document_info(document_info, extractions) + + await service.update_document_status( + document_info, status=IngestionStatus.EMBEDDING + ) + embedding_generator = service.embed_document(extractions) + embeddings = [ + embedding.model_dump() + async for embedding in embedding_generator + ] + + await service.update_document_status( + 
document_info, status=IngestionStatus.STORING + ) + storage_generator = service.store_embeddings(embeddings) + async for _ in storage_generator: + pass + + await service.finalize_ingestion(document_info) + + await service.update_document_status( + document_info, status=IngestionStatus.SUCCESS + ) + + collection_ids = parsed_data.get("collection_ids") + + try: + if not collection_ids: + # TODO: Move logic onto the `management service` + collection_id = generate_default_user_collection_id( + document_info.owner_id + ) + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + ) + else: + for collection_id in collection_ids: + try: + # FIXME: Right now we just throw a warning if the collection already exists, but we should probably handle this more gracefully + name = "My Collection" + description = f"A collection started during {document_info.title} ingestion" + + await service.providers.database.collections_handler.create_collection( + owner_id=document_info.owner_id, + name=name, + description=description, + collection_id=collection_id, + ) + await service.providers.database.graphs_handler.create( + collection_id=collection_id, + name=name, + description=description, + graph_id=collection_id, + ) + except Exception as e: + logger.warning( + f"Warning, could not 
create collection with error: {str(e)}" + ) + + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + ) + except Exception as e: + logger.error( + f"Error during assigning document to collection: {str(e)}" + ) + + # Chunk enrichment + if server_chunk_enrichment_settings := getattr( + service.providers.ingestion.config, + "chunk_enrichment_settings", + None, + ): + chunk_enrichment_settings = update_settings_from_dict( + server_chunk_enrichment_settings, + ingestion_config.get("chunk_enrichment_settings", {}) + or {}, + ) + + if chunk_enrichment_settings.enable_chunk_enrichment: + logger.info("Enriching document with contextual chunks") + + # Get updated document info with collection IDs + document_info = ( + await service.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=100, + filter_user_ids=[document_info.owner_id], + filter_document_ids=[document_info.id], + ) + )["results"][0] + + await service.update_document_status( + document_info, + status=IngestionStatus.ENRICHING, + ) + + await service.chunk_enrichment( + document_id=document_info.id, + document_summary=document_info.summary, + chunk_enrichment_settings=chunk_enrichment_settings, + ) + + await service.update_document_status( + document_info, + 
                    status=IngestionStatus.SUCCESS,
                )

                # Automatic extraction
                if service.providers.ingestion.config.automatic_extraction:
                    logger.warning(
                        "Automatic extraction not yet implemented for `simple` ingestion workflows."
                    )

        except AuthenticationError as e:
            # Mark the document FAILED before surfacing the auth error so its
            # status is not left stuck in an intermediate state.
            if document_info is not None:
                await service.update_document_status(
                    document_info,
                    status=IngestionStatus.FAILED,
                    metadata={"failure": f"{str(e)}"},
                )
            raise R2RException(
                status_code=401,
                message="Authentication error: Invalid API key or credentials.",
            ) from e
        except Exception as e:
            if document_info is not None:
                await service.update_document_status(
                    document_info,
                    status=IngestionStatus.FAILED,
                    metadata={"failure": f"{str(e)}"},
                )
            # Re-raise R2RExceptions unchanged; wrap anything else in a 500.
            if isinstance(e, R2RException):
                raise
            raise HTTPException(
                status_code=500, detail=f"Error during ingestion: {str(e)}"
            ) from e

    async def update_files(input_data):
        """Re-ingest new versions of existing documents.

        Validates the request, looks up the current document records, bumps
        each document's version, and delegates the actual work to
        `ingest_files` (the coroutines are gathered concurrently below).
        """
        from core.main import IngestionServiceAdapter

        parsed_data = IngestionServiceAdapter.parse_update_files_input(
            input_data
        )

        file_datas = parsed_data["file_datas"]
        user = parsed_data["user"]
        document_ids = parsed_data["document_ids"]
        metadatas = parsed_data["metadatas"]
        ingestion_config = parsed_data["ingestion_config"]
        file_sizes_in_bytes = parsed_data["file_sizes_in_bytes"]

        if not file_datas:
            raise R2RException(
                status_code=400, message="No files provided for update."
            ) from None
        if len(document_ids) != len(file_datas):
            raise R2RException(
                status_code=400,
                message="Number of ids does not match number of files.",
            ) from None

        # Non-superusers may only update documents they own.
        documents_overview = (
            await service.providers.database.documents_handler.get_documents_overview(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
                offset=0,
                limit=100,
                filter_user_ids=None if user.is_superuser else [user.id],
                filter_document_ids=document_ids,
            )
        )["results"]

        if len(documents_overview) != len(document_ids):
            raise R2RException(
                status_code=404,
                message="One or more documents not found.",
            ) from None

        results = []

        for idx, (
            file_data,
            doc_id,
            doc_info,
            file_size_in_bytes,
        ) in enumerate(
            zip(
                file_datas,
                document_ids,
                documents_overview,
                file_sizes_in_bytes,
                strict=False,
            )
        ):
            new_version = increment_version(doc_info.version)

            # NOTE(review): when `metadatas` is absent this aliases (and the
            # next statement mutates) doc_info.metadata in place — confirm
            # that is intended rather than a copy.
            updated_metadata = (
                metadatas[idx] if metadatas else doc_info.metadata
            )
            updated_metadata["title"] = (
                updated_metadata.get("title")
                or file_data["filename"].split("/")[-1]
            )

            ingest_input = {
                "file_data": file_data,
                "user": user.model_dump(),
                "metadata": updated_metadata,
                "document_id": str(doc_id),
                "version": new_version,
                "ingestion_config": ingestion_config,
                "size_in_bytes": file_size_in_bytes,
            }

            # `ingest_files` is async: this builds coroutines that are run
            # concurrently by the gather() below.
            result = ingest_files(ingest_input)
            results.append(result)

        await asyncio.gather(*results)
        if service.providers.ingestion.config.automatic_extraction:
            raise R2RException(
                status_code=501,
                message="Automatic extraction not yet implemented for `simple` ingestion workflows.",
            ) from None

    async def ingest_chunks(input_data):
        """Ingest pre-chunked text: embed, store, finalize, then assign the
        resulting document to its collection(s).

        Collection assignment is best-effort: its errors are logged, not
        raised (see the inner try/except below).
        """
        document_info = None
        try:
            from core.base import IngestionStatus
            from core.main import IngestionServiceAdapter

            parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input(
                input_data
            )

            document_info = await service.ingest_chunks_ingress(**parsed_data)

            await service.update_document_status(
                document_info, status=IngestionStatus.EMBEDDING
            )
            document_id = document_info.id

            # Wrap each incoming chunk as a DocumentChunk, generating an id
            # only when the chunk does not already carry one.
            extractions = [
                DocumentChunk(
                    id=(
                        generate_extraction_id(document_id, i)
                        if chunk.id is None
                        else chunk.id
                    ),
                    document_id=document_id,
                    collection_ids=[],
                    owner_id=document_info.owner_id,
                    data=chunk.text,
                    metadata=parsed_data["metadata"],
                ).model_dump()
                for i, chunk in enumerate(parsed_data["chunks"])
            ]

            embedding_generator = service.embed_document(extractions)
            embeddings = [
                embedding.model_dump()
                async for embedding in embedding_generator
            ]

            await service.update_document_status(
                document_info, status=IngestionStatus.STORING
            )
            # Drain the storage generator for its side effects only.
            storage_generator = service.store_embeddings(embeddings)
            async for _ in storage_generator:
                pass

            await service.finalize_ingestion(document_info)

            await service.update_document_status(
                document_info, status=IngestionStatus.SUCCESS
            )

            collection_ids = parsed_data.get("collection_ids")

            try:
                # TODO - Move logic onto management service
                if not collection_ids:
                    # No explicit collections: attach to the owner's default
                    # collection and mark the graph as needing a re-sync.
                    collection_id = generate_default_user_collection_id(
                        document_info.owner_id
                    )

                    await service.providers.database.collections_handler.assign_document_to_collection_relational(
                        document_id=document_info.id,
                        collection_id=collection_id,
                    )

                    await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
                        document_id=document_info.id,
                        collection_id=collection_id,
                    )

                    await service.providers.database.documents_handler.set_workflow_status(
                        id=collection_id,
                        status_type="graph_sync_status",
                        status=GraphConstructionStatus.OUTDATED,
                    )
                    await service.providers.database.documents_handler.set_workflow_status(
                        id=collection_id,
                        status_type="graph_cluster_status",
                        status=GraphConstructionStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
                    )

                else:
                    for collection_id in collection_ids:
                        # Create the collection + graph if missing; creation
                        # failure (e.g. already exists) is only a warning.
                        try:
                            name = document_info.title or "N/A"
                            description = ""
                            result = await service.providers.database.collections_handler.create_collection(
                                owner_id=document_info.owner_id,
                                name=name,
                                description=description,
                                collection_id=collection_id,
                            )
                            await service.providers.database.graphs_handler.create(
                                collection_id=collection_id,
                                name=name,
                                description=description,
                                graph_id=collection_id,
                            )
                        except Exception as e:
                            logger.warning(
                                f"Warning, could not create collection with error: {str(e)}"
                            )
                        await service.providers.database.collections_handler.assign_document_to_collection_relational(
                            document_id=document_info.id,
                            collection_id=collection_id,
                        )
                        await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
                            document_id=document_info.id,
                            collection_id=collection_id,
                        )
                        await service.providers.database.documents_handler.set_workflow_status(
                            id=collection_id,
                            status_type="graph_sync_status",
                            status=GraphConstructionStatus.OUTDATED,
                        )
                        await service.providers.database.documents_handler.set_workflow_status(
                            id=collection_id,
                            status_type="graph_cluster_status",
                            status=GraphConstructionStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
                        )

                # NOTE(review): this raise sits inside the try whose handler
                # below only logs, so the 501 never propagates to the caller
                # — confirm whether that is intended.
                if service.providers.ingestion.config.automatic_extraction:
                    raise R2RException(
                        status_code=501,
                        message="Automatic extraction not yet implemented for `simple` ingestion workflows.",
                    ) from None

            except Exception as e:
                logger.error(
                    f"Error during assigning document to collection: {str(e)}"
                )

        except Exception as e:
            if document_info is not None:
                await service.update_document_status(
                    document_info,
                    status=IngestionStatus.FAILED,
                    metadata={"failure": f"{str(e)}"},
                )
            raise HTTPException(
                status_code=500,
                detail=f"Error during chunk ingestion: {str(e)}",
            ) from e

    async def update_chunk(input_data):
        """Update a single stored chunk's text/metadata/collections."""
        from core.main import IngestionServiceAdapter

        try:
            parsed_data = IngestionServiceAdapter.parse_update_chunk_input(
                input_data
            )
            # Ids may arrive as strings; normalize to UUID before use.
            document_uuid = (
                UUID(parsed_data["document_id"])
                if isinstance(parsed_data["document_id"], str)
                else parsed_data["document_id"]
            )
            extraction_uuid = (
                UUID(parsed_data["id"])
                if isinstance(parsed_data["id"], str)
                else parsed_data["id"]
            )

            await service.update_chunk_ingress(
                document_id=document_uuid,
                chunk_id=extraction_uuid,
                text=parsed_data.get("text"),
                user=parsed_data["user"],
                metadata=parsed_data.get("metadata"),
                collection_ids=parsed_data.get("collection_ids"),
            )

        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error during chunk update: {str(e)}",
            ) from e

    async def create_vector_index(input_data):
        """Create a vector index over the chunks table."""
        try:
            from core.main import IngestionServiceAdapter

            parsed_data = (
                IngestionServiceAdapter.parse_create_vector_index_input(
                    input_data
                )
            )

            await service.providers.database.chunks_handler.create_index(
                **parsed_data
            )

        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error during vector index creation: {str(e)}",
            ) from e

    async def delete_vector_index(input_data):
        """Drop a vector index from the chunks table."""
        try:
            from core.main import IngestionServiceAdapter

            parsed_data = (
                IngestionServiceAdapter.parse_delete_vector_index_input(
                    input_data
                )
            )

            await service.providers.database.chunks_handler.delete_index(
                **parsed_data
            )

            return {"status": "Vector index deleted successfully."}

        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error during vector index deletion: {str(e)}",
            ) from e

    async def update_document_metadata(input_data):
        """Update a document's metadata on behalf of a user."""
        try:
            from core.main import IngestionServiceAdapter

            parsed_data = (
                IngestionServiceAdapter.parse_update_document_metadata_input(
                    input_data
                )
            )

            document_id = parsed_data["document_id"]
            metadata = parsed_data["metadata"]
            user = parsed_data["user"]

            await service.update_document_metadata(
                document_id=document_id,
                metadata=metadata,
                user=user,
            )

            return {
                "message": "Document metadata update completed successfully.",
                "document_id": str(document_id),
                "task_id": None,
            }

        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error during document metadata update: {str(e)}",
            )
 from e

    # Registry returned by the enclosing factory: maps orchestration task
    # names to the async handlers defined above.
    return {
        "ingest-files": ingest_files,
        "ingest-chunks": ingest_chunks,
        "update-chunk": update_chunk,
        "update-document-metadata": update_document_metadata,
        "create-vector-index": create_vector_index,
        "delete-vector-index": delete_vector_index,
    }
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/__init__.py b/.venv/lib/python3.12/site-packages/core/main/services/__init__.py
new file mode 100644
index 00000000..e6a6dec0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/__init__.py
@@ -0,0 +1,14 @@
from .auth_service import AuthService
from .graph_service import GraphService
from .ingestion_service import IngestionService, IngestionServiceAdapter
from .management_service import ManagementService
from .retrieval_service import RetrievalService  # type: ignore

__all__ = [
    "AuthService",
    "IngestionService",
    "IngestionServiceAdapter",
    "ManagementService",
    "GraphService",
    "RetrievalService",
]
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/auth_service.py b/.venv/lib/python3.12/site-packages/core/main/services/auth_service.py
new file mode 100644
index 00000000..c04dd78c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/auth_service.py
@@ -0,0 +1,316 @@
import logging
from datetime import datetime
from typing import Optional
from uuid import UUID

from core.base import R2RException, Token
from core.base.api.models import User
from core.utils import generate_default_user_collection_id

from ..abstractions import R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()


class AuthService(Service):
    """Thin service layer delegating auth operations to the configured
    auth/database providers."""

    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ):
        super().__init__(
            config,
            providers,
        )

    async def register(
        self,
        email: str,
        password: str,
        name: Optional[str] = None,
        bio: Optional[str] = None,
        profile_picture: Optional[str] = None,
    ) ->
 User:
        # Delegates account creation entirely to the auth provider.
        return await self.providers.auth.register(
            email=email,
            password=password,
            name=name,
            bio=bio,
            profile_picture=profile_picture,
        )

    async def send_verification_email(
        self, email: str
    ) -> tuple[str, datetime]:
        return await self.providers.auth.send_verification_email(email=email)

    async def verify_email(
        self, email: str, verification_code: str
    ) -> dict[str, str]:
        """Verify a user account given the emailed verification code.

        Raises R2RException(400) when verification is disabled, or when the
        code is unknown/expired or does not belong to `email`.
        """
        if not self.config.auth.require_email_verification:
            raise R2RException(
                status_code=400, message="Email verification is not required"
            )

        user_id = await self.providers.database.users_handler.get_user_id_by_verification_code(
            verification_code
        )
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        # The code must resolve to the same email the caller supplied.
        if not user or user.email != email:
            raise R2RException(
                status_code=400, message="Invalid or expired verification code"
            )

        await self.providers.database.users_handler.mark_user_as_verified(
            user_id
        )
        # One-shot code: remove it once consumed.
        await self.providers.database.users_handler.remove_verification_code(
            verification_code
        )
        return {"message": f"User account {user_id} verified successfully."}

    async def login(self, email: str, password: str) -> dict[str, Token]:
        return await self.providers.auth.login(email, password)

    async def user(self, token: str) -> User:
        """Resolve an access token to its User record (401 on any failure)."""
        token_data = await self.providers.auth.decode_token(token)
        if not token_data.email:
            raise R2RException(
                status_code=401, message="Invalid authentication credentials"
            )
        user = await self.providers.database.users_handler.get_user_by_email(
            token_data.email
        )
        if user is None:
            raise R2RException(
                status_code=401, message="Invalid authentication credentials"
            )
        return user

    async def refresh_access_token(
        self, refresh_token: str
    ) -> dict[str, Token]:
        return await self.providers.auth.refresh_access_token(refresh_token)

    async def change_password(
        self, user: User, current_password: str, new_password: str
    ) -> dict[str, str]:
        if not user:
            raise R2RException(status_code=404, message="User not found")
        return await self.providers.auth.change_password(
            user, current_password, new_password
        )

    async def request_password_reset(self, email: str) -> dict[str, str]:
        return await self.providers.auth.request_password_reset(email)

    async def confirm_password_reset(
        self, reset_token: str, new_password: str
    ) -> dict[str, str]:
        return await self.providers.auth.confirm_password_reset(
            reset_token, new_password
        )

    async def logout(self, token: str) -> dict[str, str]:
        return await self.providers.auth.logout(token)

    async def update_user(
        self,
        user_id: UUID,
        email: Optional[str] = None,
        is_superuser: Optional[bool] = None,
        name: Optional[str] = None,
        bio: Optional[str] = None,
        profile_picture: Optional[str] = None,
        limits_overrides: Optional[dict] = None,
        merge_limits: bool = False,
        new_metadata: Optional[dict] = None,
    ) -> User:
        """Apply only the explicitly-provided fields to the user, then
        persist via the users handler."""
        user: User = (
            await self.providers.database.users_handler.get_user_by_id(user_id)
        )
        if not user:
            raise R2RException(status_code=404, message="User not found")
        # `is not None` checks allow falsy-but-valid values (e.g. "").
        if email is not None:
            user.email = email
        if is_superuser is not None:
            user.is_superuser = is_superuser
        if name is not None:
            user.name = name
        if bio is not None:
            user.bio = bio
        if profile_picture is not None:
            user.profile_picture = profile_picture
        if limits_overrides is not None:
            user.limits_overrides = limits_overrides
        return await self.providers.database.users_handler.update_user(
            user, merge_limits=merge_limits, new_metadata=new_metadata
        )

    async def delete_user(
        self,
        user_id: UUID,
        password: Optional[str] = None,
        delete_vector_data: bool = False,
        is_superuser: bool = False,
    ) -> dict[str, str]:
        """Delete a user account, their default collection, and optionally
        their vector data.

        Non-superusers must supply the correct current password.
        """
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        if not user:
            raise R2RException(status_code=404, message="User not found")
        if not is_superuser and not password:
            raise R2RException(
                status_code=422, message="Password is required for deletion"
            )
        # Superusers bypass the password check entirely.
        if not (
            is_superuser
            or (
                user.hashed_password is not None
                and password is not None
                and self.providers.auth.crypto_provider.verify_password(
                    plain_password=password,
                    hashed_password=user.hashed_password,
                )
            )
        ):
            raise R2RException(status_code=400, message="Incorrect password")
        await self.providers.database.users_handler.delete_user_relational(
            user_id
        )

        # Delete user's default collection
        # TODO: We need to better define what happens to the user's data when they are deleted
        collection_id = generate_default_user_collection_id(user_id)
        await self.providers.database.collections_handler.delete_collection_relational(
            collection_id
        )

        # Graph deletion is best-effort; failure only logs a warning.
        try:
            await self.providers.database.graphs_handler.delete(
                collection_id=collection_id,
            )
        except Exception as e:
            logger.warning(
                f"Error deleting graph for collection {collection_id}: {e}"
            )

        if delete_vector_data:
            await self.providers.database.chunks_handler.delete_user_vector(
                user_id
            )
            await self.providers.database.chunks_handler.delete_collection_vector(
                collection_id
            )

        return {"message": f"User account {user_id} deleted successfully."}

    async def clean_expired_blacklisted_tokens(
        self,
        max_age_hours: int = 7 * 24,
        current_time: Optional[datetime] = None,
    ):
        await self.providers.database.token_handler.clean_expired_blacklisted_tokens(
            max_age_hours, current_time
        )

    async def get_user_verification_code(
        self,
        user_id: UUID,
    ) -> dict:
        """Get only the verification code data for a specific user.

        This method should be called after superuser authorization has been
        verified.
        """
        verification_data = await self.providers.database.users_handler.get_user_validation_data(
            user_id=user_id
        )
        return {
            "verification_code": verification_data["verification_data"][
                "verification_code"
            ],
            "expiry": verification_data["verification_data"][
                "verification_code_expiry"
            ],
        }

    async def get_user_reset_token(
        self,
        user_id: UUID,
    ) -> dict:
        """Get only the password-reset token data for a specific user.

        This method should be called after superuser authorization has been
        verified.
        """
        verification_data = await self.providers.database.users_handler.get_user_validation_data(
            user_id=user_id
        )
        return {
            "reset_token": verification_data["verification_data"][
                "reset_token"
            ],
            "expiry": verification_data["verification_data"][
                "reset_token_expiry"
            ],
        }

    async def send_reset_email(self, email: str) -> dict:
        """Generate a new verification code and send a reset email to the user.
        Returns the verification code for testing/sandbox environments.

        Args:
            email (str): The email address of the user

        Returns:
            dict: Contains verification_code and message
        """
        return await self.providers.auth.send_reset_email(email)

    async def create_user_api_key(
        self, user_id: UUID, name: Optional[str], description: Optional[str]
    ) -> dict:
        """Generate a new API key for the user with optional name and
        description.

        Args:
            user_id (UUID): The ID of the user
            name (Optional[str]): Name of the API key
            description (Optional[str]): Description of the API key

        Returns:
            dict: Contains the API key and message
        """
        return await self.providers.auth.create_user_api_key(
            user_id=user_id, name=name, description=description
        )

    async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
        """Delete the API key for the user.

        Args:
            user_id (UUID): The ID of the user
            key_id (str): The ID of the API key

        Returns:
            bool: True if the API key was deleted successfully
        """
        return await self.providers.auth.delete_user_api_key(
            user_id=user_id, key_id=key_id
        )

    async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
        """List all API keys for the user.

        Args:
            user_id (UUID): The ID of the user

        Returns:
            dict: Contains the list of API keys
        """
        return await self.providers.auth.list_user_api_keys(user_id)
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/base.py b/.venv/lib/python3.12/site-packages/core/main/services/base.py
new file mode 100644
index 00000000..dcd98fd5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/base.py
@@ -0,0 +1,14 @@
from abc import ABC

from ..abstractions import R2RProviders
from ..config import R2RConfig


class Service(ABC):
    """Abstract base for all services: stores the shared config and the
    provider bundle used by every subclass."""

    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ):
        self.config = config
        self.providers = providers
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/graph_service.py b/.venv/lib/python3.12/site-packages/core/main/services/graph_service.py
new file mode 100644
index 00000000..56f32cf8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/graph_service.py
@@ -0,0 +1,1358 @@
import asyncio
import logging
import math
import random
import re
import time
import uuid
import xml.etree.ElementTree as ET
from typing import Any, AsyncGenerator, Coroutine, Optional
from uuid import UUID
from xml.etree.ElementTree import Element

from core.base import (
    DocumentChunk,
    GraphExtraction,
    GraphExtractionStatus,
    R2RDocumentProcessingError,
)
from core.base.abstractions import (
    Community,
    Entity,
    GenerationConfig,
    GraphConstructionStatus,
    R2RException,
    Relationship,
    StoreType,
)
from core.base.api.models import GraphResponse

from ..abstractions import R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()

# Minimum length for an LLM graph-extraction reply to be treated as valid.
MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH = 128


async def _collect_async_results(result_gen: AsyncGenerator) -> list[Any]:
    """Collects all results from an async generator into a list."""
    results = []
    async for res in result_gen:
        results.append(res)
    return results


class GraphService(Service):
    """CRUD and construction/enrichment operations for knowledge graphs,
    delegating storage to the database provider's graphs handler."""

    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ):
        super().__init__(
            config,
            providers,
        )

    async def create_entity(
        self,
        name: str,
        description: str,
        parent_id: UUID,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        # Embedding is stringified before storage — presumably the handler
        # expects a text-serialized vector; confirm against the handler.
        description_embedding = str(
            await self.providers.embedding.async_get_embedding(description)
        )

        return await self.providers.database.graphs_handler.entities.create(
            name=name,
            parent_id=parent_id,
            store_type=StoreType.GRAPHS,
            category=category,
            description=description,
            description_embedding=description_embedding,
            metadata=metadata,
        )

    async def update_entity(
        self,
        entity_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        # Only re-embed when the description actually changed.
        description_embedding = None
        if description is not None:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )

        return await self.providers.database.graphs_handler.entities.update(
            entity_id=entity_id,
            store_type=StoreType.GRAPHS,
            name=name,
            description=description,
            description_embedding=description_embedding,
            category=category,
            metadata=metadata,
        )

    async def delete_entity(
        self,
        parent_id: UUID,
        entity_id: UUID,
    ):
        return await self.providers.database.graphs_handler.entities.delete(
            parent_id=parent_id,
            entity_ids=[entity_id],
            store_type=StoreType.GRAPHS,
        )

    async def get_entities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        return await self.providers.database.graphs_handler.get_entities(
            parent_id=parent_id,
            offset=offset,
            limit=limit,
            entity_ids=entity_ids,
            entity_names=entity_names,
            include_embeddings=include_embeddings,
        )

    async def create_relationship(
        self,
        subject: str,
        subject_id: UUID,
        predicate: str,
        object: str,
        object_id: UUID,
        parent_id: UUID,
        description: str | None = None,
        weight: float | None = 1.0,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        description_embedding = None
        if description:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )

        return (
            await self.providers.database.graphs_handler.relationships.create(
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                parent_id=parent_id,
                description=description,
                description_embedding=description_embedding,
                weight=weight,
                metadata=metadata,
                store_type=StoreType.GRAPHS,
            )
        )

    async def delete_relationship(
        self,
        parent_id: UUID,
        relationship_id: UUID,
    ):
        return (
            await self.providers.database.graphs_handler.relationships.delete(
                parent_id=parent_id,
                relationship_ids=[relationship_id],
                store_type=StoreType.GRAPHS,
            )
        )

    async def update_relationship(
        self,
        relationship_id: UUID,
        subject: Optional[str] = None,
        subject_id: Optional[UUID] = None,
        predicate: Optional[str] = None,
        object: Optional[str] = None,
        object_id: Optional[UUID] = None,
        description: Optional[str] = None,
        weight: Optional[float] = None,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        description_embedding = None
        if description is not None:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )

        return (
            await self.providers.database.graphs_handler.relationships.update(
                relationship_id=relationship_id,
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                description=description,
                description_embedding=description_embedding,
                weight=weight,
                metadata=metadata,
                store_type=StoreType.GRAPHS,
            )
        )

    async def get_relationships(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
    ):
        return await self.providers.database.graphs_handler.relationships.get(
            parent_id=parent_id,
            store_type=StoreType.GRAPHS,
            offset=offset,
            limit=limit,
            relationship_ids=relationship_ids,
            entity_names=entity_names,
        )

    async def create_community(
        self,
        parent_id: UUID,
        name: str,
        summary: str,
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
    ) -> Community:
        # Communities are embedded on their summary text.
        description_embedding = str(
            await self.providers.embedding.async_get_embedding(summary)
        )
        return await self.providers.database.graphs_handler.communities.create(
            parent_id=parent_id,
            store_type=StoreType.GRAPHS,
            name=name,
            summary=summary,
            description_embedding=description_embedding,
            findings=findings,
            rating=rating,
            rating_explanation=rating_explanation,
        )

    async def update_community(
        self,
        community_id: UUID,
        name: Optional[str],
        summary: Optional[str],
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
    ) -> Community:
        summary_embedding = None
        if summary is not None:
            summary_embedding = str(
                await self.providers.embedding.async_get_embedding(summary)
            )

        return await self.providers.database.graphs_handler.communities.update(
            community_id=community_id,
            store_type=StoreType.GRAPHS,
            name=name,
            summary=summary,
            summary_embedding=summary_embedding,
            findings=findings,
            rating=rating,
            rating_explanation=rating_explanation,
        )

    async def delete_community(
        self,
        parent_id: UUID,
        community_id: UUID,
    ) -> None:
        await self.providers.database.graphs_handler.communities.delete(
            parent_id=parent_id,
            community_id=community_id,
        )

    async def get_communities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        community_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        # NOTE(review): `community_names` is accepted but never forwarded to
        # the handler below — confirm whether name filtering was intended.
        return await self.providers.database.graphs_handler.get_communities(
            parent_id=parent_id,
            offset=offset,
            limit=limit,
            community_ids=community_ids,
            include_embeddings=include_embeddings,
        )

    async def list_graphs(
        self,
        offset: int,
        limit: int,
        graph_ids: Optional[list[UUID]] = None,
        collection_id: Optional[UUID] = None,
    ) -> dict[str, list[GraphResponse] | int]:
        return await self.providers.database.graphs_handler.list_graphs(
            offset=offset,
            limit=limit,
            filter_graph_ids=graph_ids,
            filter_collection_id=collection_id,
        )

    async def update_graph(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> GraphResponse:
        return await self.providers.database.graphs_handler.update(
            collection_id=collection_id,
            name=name,
            description=description,
        )

    async def reset_graph(self, id: UUID) -> bool:
        """Wipe the graph under `id` and reset its cluster status to PENDING."""
        await self.providers.database.graphs_handler.reset(
            parent_id=id,
        )
        await self.providers.database.documents_handler.set_workflow_status(
            id=id,
            status_type="graph_cluster_status",
            status=GraphConstructionStatus.PENDING,
        )
        return True

    async def get_document_ids_for_create_graph(
        self,
        collection_id: UUID,
        **kwargs,
    ):
        # Only documents not yet (successfully) extracted are candidates.
        document_status_filter = [
            GraphExtractionStatus.PENDING,
            GraphExtractionStatus.FAILED,
        ]

        return await self.providers.database.documents_handler.get_document_ids_by_status(
            status_type="extraction_status",
            status=[str(ele) for ele in
 document_status_filter],
            collection_id=collection_id,
        )

    async def graph_search_results_entity_description(
        self,
        document_id: UUID,
        max_description_input_length: int,
        batch_size: int = 256,
        **kwargs,
    ):
        """A new implementation of the old GraphDescriptionPipe logic inline.
        No references to pipe objects.

        We:
        1) Count how many entities are in the document
        2) Process them in batches of `batch_size`
        3) For each batch, we retrieve the entity map and possibly call LLM for missing descriptions
        """
        start_time = time.time()
        logger.info(
            f"GraphService: Running graph_search_results_entity_description for doc={document_id}"
        )

        # Count how many doc-entities exist
        entity_count = (
            await self.providers.database.graphs_handler.get_entity_count(
                document_id=document_id,
                distinct=True,
                entity_table_name="documents_entities",  # or whichever table
            )
        )
        logger.info(
            f"GraphService: Found {entity_count} doc-entities to describe."
        )

        all_results = []
        num_batches = math.ceil(entity_count / batch_size)

        for i in range(num_batches):
            offset = i * batch_size
            limit = batch_size

            logger.info(
                f"GraphService: describing batch {i + 1}/{num_batches}, offset={offset}, limit={limit}"
            )

            # Actually handle describing the entities in the batch
            # We'll collect them into a list via an async generator
            gen = self._describe_entities_in_document_batch(
                document_id=document_id,
                offset=offset,
                limit=limit,
                max_description_input_length=max_description_input_length,
            )
            batch_results = await _collect_async_results(gen)
            all_results.append(batch_results)

        # Mark the doc's extraction status as success
        await self.providers.database.documents_handler.set_workflow_status(
            id=document_id,
            status_type="extraction_status",
            status=GraphExtractionStatus.SUCCESS,
        )
        logger.info(
            f"GraphService: Completed graph_search_results_entity_description for doc {document_id} in {time.time() - start_time:.2f}s."
        )
        return all_results

    async def _describe_entities_in_document_batch(
        self,
        document_id: UUID,
        offset: int,
        limit: int,
        max_description_input_length: int,
    ) -> AsyncGenerator[str, None]:
        """Core logic that replaces GraphDescriptionPipe._run_logic for a
        particular document/batch.

        Yields entity-names or some textual result as each entity is updated.
        """
        start_time = time.time()
        logger.info(
            f"Started describing doc={document_id}, offset={offset}, limit={limit}"
        )

        # 1) Get the "entity map" from the DB
        entity_map = (
            await self.providers.database.graphs_handler.get_entity_map(
                offset=offset, limit=limit, document_id=document_id
            )
        )
        total_entities = len(entity_map)
        logger.info(
            f"_describe_entities_in_document_batch: got {total_entities} items in entity_map for doc={document_id}."
        )

        # 2) For each entity name in the map, we gather sub-entities and relationships
        tasks: list[Coroutine[Any, Any, str]] = []
        tasks.extend(
            self._process_entity_for_description(
                entities=[
                    entity if isinstance(entity, Entity) else Entity(**entity)
                    for entity in entity_info["entities"]
                ],
                relationships=[
                    rel
                    if isinstance(rel, Relationship)
                    else Relationship(**rel)
                    for rel in entity_info["relationships"]
                ],
                document_id=document_id,
                max_description_input_length=max_description_input_length,
            )
            for entity_name, entity_info in entity_map.items()
        )

        # 3) Wait for all tasks, yield as they complete
        idx = 0
        for coro in asyncio.as_completed(tasks):
            result = await coro
            idx += 1
            if idx % 100 == 0:
                logger.info(
                    f"_describe_entities_in_document_batch: {idx}/{total_entities} described for doc={document_id}"
                )
            yield result

        logger.info(
            f"Finished describing doc={document_id} batch offset={offset} in {time.time() - start_time:.2f}s."
        )

    async def _process_entity_for_description(
        self,
        entities: list[Entity],
        relationships: list[Relationship],
        document_id: UUID,
        max_description_input_length: int,
    ) -> str:
        """Adapted from the old process_entity function in
        GraphDescriptionPipe.

        If entity has no description, call an LLM to create one, then store it.
        Returns the name of the top entity (or could store more details).
        """

        def truncate_info(info_list: list[str], max_length: int) -> str:
            """Shuffles lines of info to try to keep them distinct, then
            accumulates until hitting max_length."""
            random.shuffle(info_list)
            truncated_info = ""
            current_length = 0
            for info in info_list:
                if current_length + len(info) > max_length:
                    break
                truncated_info += info + "\n"
                current_length += len(info)
            return truncated_info

        # Grab a doc-level summary (optional) to feed into the prompt
        response = await self.providers.database.documents_handler.get_documents_overview(
            offset=0,
            limit=1,
            filter_document_ids=[document_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )

        # Synthesize a minimal “entity info” string + relationship summary
        entity_info = [
            f"{e.name}, {e.description or 'NONE'}" for e in entities
        ]
        relationships_txt = [
            f"{i + 1}: {r.subject}, {r.object}, {r.predicate} - Summary: {r.description or ''}"
            for i, r in enumerate(relationships)
        ]

        # We'll describe only the first entity for simplicity
        # or you could do them all if needed
        main_entity = entities[0]

        if not main_entity.description:
            # We only call LLM if the entity is missing a description
            messages = await self.providers.database.prompts_handler.get_message_payload(
                task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt,
                task_inputs={
                    "document_summary": document_summary,
                    "entity_info": truncate_info(
                        entity_info, max_description_input_length
                    ),
                    "relationships_txt": truncate_info(
                        relationships_txt, max_description_input_length
                    ),
                },
            )

            # Call the LLM — fall back to the app's fast model when no
            # generation config is set on graph_creation_settings.
            gen_config = (
                self.providers.database.config.graph_creation_settings.generation_config
                or GenerationConfig(model=self.config.app.fast_llm)
            )
            llm_resp = await self.providers.llm.aget_completion(
                messages=messages,
                generation_config=gen_config,
            )
            new_description = llm_resp.choices[0].message.content

            if not new_description:
                logger.error(
                    f"No LLM description returned for entity={main_entity.name}"
                )
                return main_entity.name

            # create embedding
            embed = (
                await self.providers.embedding.async_get_embeddings(
                    [new_description]
                )
            )[0]

            # update DB
            main_entity.description = new_description
            main_entity.description_embedding = embed

            # Use a method to upsert entity in `documents_entities` or your table
            await self.providers.database.graphs_handler.add_entities(
                [main_entity],
                table_name="documents_entities",
            )

        return main_entity.name

    async def graph_search_results_clustering(
        self,
        collection_id: UUID,
        generation_config: GenerationConfig,
        leiden_params: dict,
        **kwargs,
    ):
        """
        Replacement for the old GraphClusteringPipe logic:
        1) call perform_graph_clustering on the DB
        2) return the result
        """
        logger.info(
            f"Running inline clustering for collection={collection_id} with params={leiden_params}"
        )
        return await self._perform_graph_clustering(
            collection_id=collection_id,
            generation_config=generation_config,
            leiden_params=leiden_params,
        )

    async def _perform_graph_clustering(
        self,
        collection_id: UUID,
        generation_config: GenerationConfig,
        leiden_params: dict,
    ) -> dict:
        """The actual clustering logic (previously in
        GraphClusteringPipe.cluster_graph_search_results)."""
        num_communities = await self.providers.database.graphs_handler.perform_graph_clustering(
            collection_id=collection_id,
            leiden_params=leiden_params,
        )
        return
    async def graph_search_results_community_summary(
        self,
        offset: int,
        limit: int,
        max_summary_input_length: int,
        generation_config: GenerationConfig,
        collection_id: UUID,
        leiden_params: Optional[dict] = None,
        **kwargs,
    ):
        """Replacement for the old GraphCommunitySummaryPipe logic.

        Summarizes communities after clustering. Returns an async generator or
        you can collect into a list.
        """
        logger.info(
            f"Running inline community summaries for coll={collection_id}, offset={offset}, limit={limit}"
        )
        # We call an internal function that yields summaries, then drain the
        # generator into a plain list before returning.
        gen = self._summarize_communities(
            offset=offset,
            limit=limit,
            max_summary_input_length=max_summary_input_length,
            generation_config=generation_config,
            collection_id=collection_id,
            leiden_params=leiden_params or {},
        )
        return await _collect_async_results(gen)

    async def _summarize_communities(
        self,
        offset: int,
        limit: int,
        max_summary_input_length: int,
        generation_config: GenerationConfig,
        collection_id: UUID,
        leiden_params: dict,
    ) -> AsyncGenerator[dict, None]:
        """Does the community summary logic from
        GraphCommunitySummaryPipe._run_logic.

        Yields each summary dictionary as it completes.

        NOTE(review): `offset` and `limit` are accepted but the entity /
        relationship fetches below use offset=0, limit=-1 (fetch all) —
        confirm whether paging was intended here.
        """
        start_time = time.time()
        logger.info(
            f"Starting community summarization for collection={collection_id}"
        )

        # get all entities & relationships
        (
            all_entities,
            _,
        ) = await self.providers.database.graphs_handler.get_entities(
            parent_id=collection_id,
            offset=0,
            limit=-1,
            include_embeddings=False,
        )
        (
            all_relationships,
            _,
        ) = await self.providers.database.graphs_handler.get_relationships(
            parent_id=collection_id,
            offset=0,
            limit=-1,
            include_embeddings=False,
        )

        # We can optionally re-run the clustering to produce fresh community assignments
        (
            _,
            community_clusters,
        ) = await self.providers.database.graphs_handler._cluster_and_add_community_info(
            relationships=all_relationships,
            leiden_params=leiden_params,
            collection_id=collection_id,
        )

        # Group clusters: cluster id -> list of member node names
        clusters: dict[Any, list[str]] = {}
        for item in community_clusters:
            cluster_id = item["cluster"]
            node_name = item["node"]
            clusters.setdefault(cluster_id, []).append(node_name)

        # create an async job for each cluster
        tasks: list[Coroutine[Any, Any, dict]] = []

        tasks.extend(
            self._process_community_summary(
                community_id=uuid.uuid4(),
                nodes=nodes,
                all_entities=all_entities,
                all_relationships=all_relationships,
                max_summary_input_length=max_summary_input_length,
                generation_config=generation_config,
                collection_id=collection_id,
            )
            for nodes in clusters.values()
        )

        total_jobs = len(tasks)
        results_returned = 0
        total_errors = 0

        # Yield summaries in completion order, not cluster order.
        for coro in asyncio.as_completed(tasks):
            summary = await coro
            results_returned += 1
            if results_returned % 50 == 0:
                logger.info(
                    f"Community summaries: {results_returned}/{total_jobs} done in {time.time() - start_time:.2f}s"
                )
            if "error" in summary:
                total_errors += 1
            yield summary

        if total_errors > 0:
            logger.warning(
                f"{total_errors} communities failed summarization out of {total_jobs}"
            )
    async def _process_community_summary(
        self,
        community_id: UUID,
        nodes: list[str],
        all_entities: list[Entity],
        all_relationships: list[Relationship],
        max_summary_input_length: int,
        generation_config: GenerationConfig,
        collection_id: UUID,
    ) -> dict:
        """
        Summarize a single community: gather all relevant entities/relationships, call LLM to generate an XML block,
        parse it, store the result as a community in DB.

        Returns {"community_id", "name"} on success, or
        {"community_id", "error"} on failure (callers count "error" keys).
        """
        # (Equivalent to process_community in old code)
        # fetch the collection description (optional)
        response = await self.providers.database.collections_handler.get_collections_overview(
            offset=0,
            limit=1,
            filter_collection_ids=[collection_id],
        )
        collection_description = (
            response["results"][0].description if response["results"] else None  # type: ignore
        )

        # filter out relevant entities / relationships; a relationship is only
        # kept when BOTH endpoints belong to this community's node set.
        entities = [e for e in all_entities if e.name in nodes]
        relationships = [
            r
            for r in all_relationships
            if r.subject in nodes and r.object in nodes
        ]
        if not entities and not relationships:
            return {
                "community_id": community_id,
                "error": f"No data in this community (nodes={nodes})",
            }

        # Create the big input text for the LLM
        input_text = await self._community_summary_prompt(
            entities,
            relationships,
            max_summary_input_length,
        )

        # Attempt up to 3 times to parse
        for attempt in range(3):
            try:
                # Build the prompt
                messages = await self.providers.database.prompts_handler.get_message_payload(
                    task_prompt_name=self.providers.database.config.graph_enrichment_settings.graph_communities_prompt,
                    task_inputs={
                        "collection_description": collection_description,
                        "input_text": input_text,
                    },
                )
                llm_resp = await self.providers.llm.aget_completion(
                    messages=messages,
                    generation_config=generation_config,
                )
                llm_text = llm_resp.choices[0].message.content or ""

                # find <community>...</community> XML
                match = re.search(
                    r"<community>.*?</community>", llm_text, re.DOTALL
                )
                if not match:
                    raise ValueError(
                        "No <community> XML found in LLM response"
                    )

                xml_content = match.group(0)
                root = ET.fromstring(xml_content)

                # extract fields
                name_elem = root.find("name")
                summary_elem = root.find("summary")
                rating_elem = root.find("rating")
                rating_expl_elem = root.find("rating_explanation")
                findings_elem = root.find("findings")

                name = name_elem.text if name_elem is not None else ""
                summary = summary_elem.text if summary_elem is not None else ""
                # NOTE(review): rating falls back to an empty *string* when
                # missing, not None/0.0 — confirm Community accepts that.
                rating = (
                    float(rating_elem.text)
                    if isinstance(rating_elem, Element) and rating_elem.text
                    else ""
                )
                rating_explanation = (
                    rating_expl_elem.text
                    if rating_expl_elem is not None
                    else None
                )
                findings = (
                    [f.text for f in findings_elem.findall("finding")]
                    if findings_elem is not None
                    else []
                )

                # build embedding; None finding texts are filtered out of the
                # embedded text but kept in the stored `findings` list.
                embed_text = (
                    "Summary:\n"
                    + (summary or "")
                    + "\n\nFindings:\n"
                    + "\n".join(
                        finding for finding in findings if finding is not None
                    )
                )
                embedding = await self.providers.embedding.async_get_embedding(
                    embed_text
                )

                # build Community object
                community = Community(
                    community_id=community_id,
                    collection_id=collection_id,
                    name=name,
                    summary=summary,
                    rating=rating,
                    rating_explanation=rating_explanation,
                    findings=findings,
                    description_embedding=embedding,
                )

                # store it
                await self.providers.database.graphs_handler.add_community(
                    community
                )

                return {
                    "community_id": community_id,
                    "name": name,
                }

            except Exception as e:
                logger.error(
                    f"Error summarizing community {community_id}: {e}"
                )
                if attempt == 2:
                    return {"community_id": community_id, "error": str(e)}
                # fixed 1-second backoff between attempts
                await asyncio.sleep(1)

        # fallback (normally unreachable: attempt==2 returns above)
        return {"community_id": community_id, "error": "Failed after retries"}
`max_summary_input_length`.""" + # Group them by entity.name + entity_map: dict[str, dict] = {} + for e in entities: + entity_map.setdefault( + e.name, {"entities": [], "relationships": []} + ) + entity_map[e.name]["entities"].append(e) + + for r in relationships: + # subject + entity_map.setdefault( + r.subject, {"entities": [], "relationships": []} + ) + entity_map[r.subject]["relationships"].append(r) + + # sort by # of relationships + sorted_entries = sorted( + entity_map.items(), + key=lambda x: len(x[1]["relationships"]), + reverse=True, + ) + + # build up the prompt text + prompt_chunks = [] + cur_len = 0 + for entity_name, data in sorted_entries: + block = f"\nEntity: {entity_name}\nDescriptions:\n" + block += "\n".join( + f"{e.id},{(e.description or '')}" for e in data["entities"] + ) + block += "\nRelationships:\n" + block += "\n".join( + f"{r.id},{r.subject},{r.object},{r.predicate},{r.description or ''}" + for r in data["relationships"] + ) + # check length + if cur_len + len(block) > max_summary_input_length: + prompt_chunks.append( + block[: max_summary_input_length - cur_len] + ) + break + else: + prompt_chunks.append(block) + cur_len += len(block) + + return "".join(prompt_chunks) + + async def delete( + self, + collection_id: UUID, + **kwargs, + ): + return await self.providers.database.graphs_handler.delete( + collection_id=collection_id, + ) + + async def graph_search_results_extraction( + self, + document_id: UUID, + generation_config: GenerationConfig, + entity_types: list[str], + relation_types: list[str], + chunk_merge_count: int, + filter_out_existing_chunks: bool = True, + total_tasks: Optional[int] = None, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[GraphExtraction | R2RDocumentProcessingError, None]: + """The original “extract Graph from doc” logic, but inlined instead of + referencing a pipe.""" + start_time = time.time() + + logger.info( + f"Graph Extraction: Processing document {document_id} for graph extraction" + ) + + # 
    async def graph_search_results_extraction(
        self,
        document_id: UUID,
        generation_config: GenerationConfig,
        entity_types: list[str],
        relation_types: list[str],
        chunk_merge_count: int,
        filter_out_existing_chunks: bool = True,
        total_tasks: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[GraphExtraction | R2RDocumentProcessingError, None]:
        """The original “extract Graph from doc” logic, but inlined instead of
        referencing a pipe.

        Yields one GraphExtraction per chunk group; per-group failures are
        yielded as R2RDocumentProcessingError instead of aborting the run.
        Raises R2RException(404) when the document has no chunks at all.
        NOTE(review): `total_tasks` is accepted but never used below.
        """
        start_time = time.time()

        logger.info(
            f"Graph Extraction: Processing document {document_id} for graph extraction"
        )

        # Retrieve chunks from DB, paging 100 at a time until exhausted.
        chunks = []
        limit = 100
        offset = 0
        while True:
            chunk_req = await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=offset,
                limit=limit,
            )
            new_chunk_objs = [
                DocumentChunk(
                    id=chunk["id"],
                    document_id=chunk["document_id"],
                    owner_id=chunk["owner_id"],
                    collection_ids=chunk["collection_ids"],
                    data=chunk["text"],
                    metadata=chunk["metadata"],
                )
                for chunk in chunk_req["results"]
            ]
            chunks.extend(new_chunk_objs)
            # A short page means we've read the last one.
            if len(chunk_req["results"]) < limit:
                break
            offset += limit

        if not chunks:
            logger.info(f"No chunks found for document {document_id}")
            raise R2RException(
                message="No chunks found for document",
                status_code=404,
            )

        # Possibly filter out any chunks that have already been processed
        if filter_out_existing_chunks:
            existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids(
                document_id=document_id
            )
            before_count = len(chunks)
            chunks = [c for c in chunks if c.id not in existing_chunk_ids]
            logger.info(
                f"Filtered out {len(existing_chunk_ids)} existing chunk-IDs. {before_count}->{len(chunks)} remain."
            )
            if not chunks:
                return  # nothing left to yield

        # sort by chunk_order if present; chunks without it sort last.
        chunks = sorted(
            chunks,
            key=lambda x: x.metadata.get("chunk_order", float("inf")),
        )

        # group them into batches of `chunk_merge_count` chunks per LLM call
        grouped_chunks = [
            chunks[i : i + chunk_merge_count]
            for i in range(0, len(chunks), chunk_merge_count)
        ]

        logger.info(
            f"Graph Extraction: Created {len(grouped_chunks)} tasks for doc={document_id}"
        )
        tasks = [
            asyncio.create_task(
                self._extract_graph_search_results_from_chunk_group(
                    chunk_group,
                    generation_config,
                    entity_types,
                    relation_types,
                )
            )
            for chunk_group in grouped_chunks
        ]

        # Yield extractions in completion order, not submission order.
        completed_tasks = 0
        for t in asyncio.as_completed(tasks):
            try:
                yield await t
                completed_tasks += 1
                if completed_tasks % 100 == 0:
                    logger.info(
                        f"Graph Extraction: completed {completed_tasks}/{len(tasks)} tasks"
                    )
            except Exception as e:
                logger.error(f"Error extracting from chunk group: {e}")
                yield R2RDocumentProcessingError(
                    document_id=document_id,
                    error_message=str(e),
                )

        logger.info(
            f"Graph Extraction: done with {document_id}, time={time.time() - start_time:.2f}s"
        )
Merges + chunk data, calls LLM, parses XML, returns GraphExtraction object.""" + combined_extraction: str = " ".join( + [ + c.data.decode("utf-8") if isinstance(c.data, bytes) else c.data + for c in chunks + if c.data + ] + ) + + # Possibly get doc-level summary + doc_id = chunks[0].document_id + response = await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=1, + filter_document_ids=[doc_id], + ) + document_summary = ( + response["results"][0].summary if response["results"] else None + ) + + # Build messages/prompt + prompt_name = self.providers.database.config.graph_creation_settings.graph_extraction_prompt + messages = ( + await self.providers.database.prompts_handler.get_message_payload( + task_prompt_name=prompt_name, + task_inputs={ + "document_summary": document_summary or "", + "input": combined_extraction, + "entity_types": "\n".join(entity_types), + "relation_types": "\n".join(relation_types), + }, + ) + ) + + for attempt in range(retries): + try: + resp = await self.providers.llm.aget_completion( + messages, generation_config=generation_config + ) + graph_search_results_str = resp.choices[0].message.content + + if not graph_search_results_str: + raise R2RException( + "No extraction found in LLM response.", + 400, + ) + + # parse the XML + ( + entities, + relationships, + ) = await self._parse_graph_search_results_extraction_xml( + graph_search_results_str, chunks + ) + return GraphExtraction( + entities=entities, relationships=relationships + ) + + except Exception as e: + if attempt < retries - 1: + await asyncio.sleep(delay) + continue + else: + logger.error( + f"All extraction attempts for doc={doc_id} and chunks{[chunk.id for chunk in chunks]} failed with error:\n{e}" + ) + return GraphExtraction(entities=[], relationships=[]) + + return GraphExtraction(entities=[], relationships=[]) + + async def _parse_graph_search_results_extraction_xml( + self, response_str: str, chunks: list[DocumentChunk] + ) -> 
tuple[list[Entity], list[Relationship]]: + """Helper to parse the LLM's XML format, handle edge cases/cleanup, + produce Entities/Relationships.""" + + def sanitize_xml(r: str) -> str: + # Remove markdown fences + r = re.sub(r"```xml|```", "", r) + # Remove xml instructions or userStyle + r = re.sub(r"<\?.*?\?>", "", r) + r = re.sub(r"<userStyle>.*?</userStyle>", "", r) + # Replace bare `&` with `&` + r = re.sub(r"&(?!amp;|quot;|apos;|lt;|gt;)", "&", r) + # Also remove <root> if it appears + r = r.replace("<root>", "").replace("</root>", "") + return r.strip() + + cleaned_xml = sanitize_xml(response_str) + wrapped = f"<root>{cleaned_xml}</root>" + try: + root = ET.fromstring(wrapped) + except ET.ParseError: + raise R2RException( + f"Failed to parse XML:\nData: {wrapped[:1000]}...", 400 + ) from None + + entities_elems = root.findall(".//entity") + if ( + len(response_str) > MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH + and len(entities_elems) == 0 + ): + raise R2RException( + f"No <entity> found in LLM XML, possibly malformed. 
Response excerpt: {response_str[:300]}", + 400, + ) + + # build entity objects + doc_id = chunks[0].document_id + chunk_ids = [c.id for c in chunks] + entities_list: list[Entity] = [] + for element in entities_elems: + name_attr = element.get("name") + type_elem = element.find("type") + desc_elem = element.find("description") + category = type_elem.text if type_elem is not None else None + desc = desc_elem.text if desc_elem is not None else None + desc_embed = await self.providers.embedding.async_get_embedding( + desc or "" + ) + ent = Entity( + category=category, + description=desc, + name=name_attr, + parent_id=doc_id, + chunk_ids=chunk_ids, + description_embedding=desc_embed, + attributes={}, + ) + entities_list.append(ent) + + # build relationship objects + relationships_list: list[Relationship] = [] + rel_elems = root.findall(".//relationship") + for r_elem in rel_elems: + source_elem = r_elem.find("source") + target_elem = r_elem.find("target") + type_elem = r_elem.find("type") + desc_elem = r_elem.find("description") + weight_elem = r_elem.find("weight") + try: + subject = source_elem.text if source_elem is not None else "" + object_ = target_elem.text if target_elem is not None else "" + predicate = type_elem.text if type_elem is not None else "" + desc = desc_elem.text if desc_elem is not None else "" + weight = ( + float(weight_elem.text) + if isinstance(weight_elem, Element) and weight_elem.text + else "" + ) + embed = await self.providers.embedding.async_get_embedding( + desc or "" + ) + + rel = Relationship( + subject=subject, + predicate=predicate, + object=object_, + description=desc, + weight=weight, + parent_id=doc_id, + chunk_ids=chunk_ids, + attributes={}, + description_embedding=embed, + ) + relationships_list.append(rel) + except Exception: + continue + return entities_list, relationships_list + + async def store_graph_search_results_extractions( + self, + graph_search_results_extractions: list[GraphExtraction], + ): + """Stores a batch of 
    async def store_graph_search_results_extractions(
        self,
        graph_search_results_extractions: list[GraphExtraction],
    ):
        """Stores a batch of knowledge graph extractions in the DB.

        Entities are created first and their new IDs collected per-name;
        relationships are then created using that map. A relationship whose
        subject or object was not created in the SAME extraction is skipped
        (its endpoint ID cannot be resolved here).
        """
        for extraction in graph_search_results_extractions:
            # Map name->id after creation
            entities_id_map = {}
            for e in extraction.entities:
                if e.parent_id is not None:
                    result = await self.providers.database.graphs_handler.entities.create(
                        name=e.name,
                        parent_id=e.parent_id,
                        store_type=StoreType.DOCUMENTS,
                        category=e.category,
                        description=e.description,
                        description_embedding=e.description_embedding,
                        chunk_ids=e.chunk_ids,
                        metadata=e.metadata,
                    )
                    entities_id_map[e.name] = result.id
                else:
                    logger.warning(f"Skipping entity with None parent_id: {e}")

            # Insert relationships
            for rel in extraction.relationships:
                subject_id = entities_id_map.get(rel.subject)
                object_id = entities_id_map.get(rel.object)
                parent_id = rel.parent_id

                if any(
                    id is None for id in (subject_id, object_id, parent_id)
                ):
                    logger.warning(f"Missing ID for relationship: {rel}")
                    continue

                # Narrow types for the create() call below.
                assert isinstance(subject_id, UUID)
                assert isinstance(object_id, UUID)
                assert isinstance(parent_id, UUID)

                await self.providers.database.graphs_handler.relationships.create(
                    subject=rel.subject,
                    subject_id=subject_id,
                    predicate=rel.predicate,
                    object=rel.object,
                    object_id=object_id,
                    parent_id=parent_id,
                    description=rel.description,
                    description_embedding=rel.description_embedding,
                    weight=rel.weight,
                    metadata=rel.metadata,
                    store_type=StoreType.DOCUMENTS,
                )

    async def deduplicate_document_entities(
        self,
        document_id: UUID,
    ):
        """
        Inlined from old code: merges duplicates by name, calls LLM for a new consolidated description, updates the record.
        """
        merged_results = await self.providers.database.entities_handler.merge_duplicate_name_blocks(
            parent_id=document_id,
            store_type=StoreType.DOCUMENTS,
        )

        # Grab doc summary
        response = await self.providers.database.documents_handler.get_documents_overview(
            offset=0,
            limit=1,
            filter_document_ids=[document_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )

        # For each merged entity
        for original_entities, merged_entity in merged_results:
            # Summarize them with LLM
            entity_info = "\n".join(
                e.description for e in original_entities if e.description
            )
            messages = await self.providers.database.prompts_handler.get_message_payload(
                task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt,
                task_inputs={
                    "document_summary": document_summary,
                    "entity_info": f"{merged_entity.name}\n{entity_info}",
                    "relationships_txt": "",
                },
            )
            # NOTE(review): reads settings from self.config.database while
            # sibling methods read self.providers.database.config — confirm
            # both paths resolve to the same graph_creation_settings.
            gen_config = (
                self.config.database.graph_creation_settings.generation_config
                or GenerationConfig(model=self.config.app.fast_llm)
            )
            resp = await self.providers.llm.aget_completion(
                messages, generation_config=gen_config
            )
            new_description = resp.choices[0].message.content

            new_embedding = await self.providers.embedding.async_get_embedding(
                new_description or ""
            )

            if merged_entity.id is not None:
                # NOTE(review): the embedding is stringified here (str(...))
                # while other call sites pass the raw vector — verify the
                # update() signature expects a string.
                await self.providers.database.graphs_handler.entities.update(
                    entity_id=merged_entity.id,
                    store_type=StoreType.DOCUMENTS,
                    description=new_description,
                    description_embedding=str(new_embedding),
                )
            else:
                logger.warning("Skipping update for entity with None id")
import asyncio
import json
import logging
from datetime import datetime
from typing import Any, AsyncGenerator, Optional, Sequence
from uuid import UUID

from fastapi import HTTPException

from core.base import (
    Document,
    DocumentChunk,
    DocumentResponse,
    DocumentType,
    GenerationConfig,
    IngestionStatus,
    R2RException,
    RawChunk,
    UnprocessedChunk,
    Vector,
    VectorEntry,
    VectorType,
    generate_id,
)
from core.base.abstractions import (
    ChunkEnrichmentSettings,
    IndexMeasure,
    IndexMethod,
    R2RDocumentProcessingError,
    VectorTableName,
)
from core.base.api.models import User
from shared.abstractions import PDFParsingError, PopplerNotFoundError

from ..abstractions import R2RProviders
from ..config import R2RConfig

logger = logging.getLogger()
# Version tag applied to documents that are ingested without an explicit one.
STARTING_VERSION = "v0"


class IngestionService:
    """A refactored IngestionService that inlines all pipe logic for parsing,
    embedding, and vector storage directly in its methods."""

    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ) -> None:
        self.config = config
        self.providers = providers

    async def ingest_file_ingress(
        self,
        file_data: dict,
        user: User,
        document_id: UUID,
        size_in_bytes,
        metadata: Optional[dict] = None,
        version: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        """Pre-ingests a file by creating or validating the DocumentResponse
        entry.

        Does not actually parse/ingest the content. (See parse_file() for that
        step.)

        Returns {"info": DocumentResponse}. Raises R2RException(400) on
        missing file/filename, 409 when the document already succeeded or is
        mid-ingestion, and HTTPException(500) for anything unexpected.
        """
        try:
            if not file_data:
                raise R2RException(
                    status_code=400, message="No files provided for ingestion."
                )
            if not file_data.get("filename"):
                raise R2RException(
                    status_code=400, message="File name not provided."
                )

            metadata = metadata or {}
            version = version or STARTING_VERSION

            document_info = self.create_document_info_from_file(
                document_id,
                user,
                file_data["filename"],
                metadata,
                version,
                size_in_bytes,
            )

            existing_document_info = (
                await self.providers.database.documents_handler.get_documents_overview(
                    offset=0,
                    limit=100,
                    filter_user_ids=[user.id],
                    filter_document_ids=[document_id],
                )
            )["results"]

            # Validate ingestion status for re-ingestion: only FAILED
            # documents may be re-ingested under the same id.
            if len(existing_document_info) > 0:
                existing_doc = existing_document_info[0]
                if existing_doc.ingestion_status == IngestionStatus.SUCCESS:
                    raise R2RException(
                        status_code=409,
                        message=(
                            f"Document {document_id} already exists. "
                            "Submit a DELETE request to `/documents/{document_id}` "
                            "to delete this document and allow for re-ingestion."
                        ),
                    )
                elif existing_doc.ingestion_status != IngestionStatus.FAILED:
                    raise R2RException(
                        status_code=409,
                        message=(
                            f"Document {document_id} is currently ingesting "
                            f"with status {existing_doc.ingestion_status}."
                        ),
                    )

            # Set to PARSING until we actually parse
            document_info.ingestion_status = IngestionStatus.PARSING
            await self.providers.database.documents_handler.upsert_documents_overview(
                document_info
            )

            return {
                "info": document_info,
            }
        except R2RException as e:
            # Re-raise domain errors untouched so callers see the right code.
            logger.error(f"R2RException in ingest_file_ingress: {str(e)}")
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Error during ingestion: {str(e)}"
            ) from e

    def create_document_info_from_file(
        self,
        document_id: UUID,
        user: User,
        file_name: str,
        metadata: dict,
        version: str,
        size_in_bytes: int,
    ) -> DocumentResponse:
        """Build a PENDING DocumentResponse from file metadata.

        A file name of "N/A" is treated as plain text with title "N/A".
        Raises R2RException(415) for unrecognized file extensions.
        """
        file_extension = (
            file_name.split(".")[-1].lower() if file_name != "N/A" else "txt"
        )
        if file_extension.upper() not in DocumentType.__members__:
            raise R2RException(
                status_code=415,
                message=f"'{file_extension}' is not a valid DocumentType.",
            )

        metadata = metadata or {}
        metadata["version"] = version

        return DocumentResponse(
            id=document_id,
            owner_id=user.id,
            collection_ids=metadata.get("collection_ids", []),
            document_type=DocumentType[file_extension.upper()],
            title=(
                metadata.get("title", file_name.split("/")[-1])
                if file_name != "N/A"
                else "N/A"
            ),
            metadata=metadata,
            version=version,
            size_in_bytes=size_in_bytes,
            # NOTE(review): naive local timestamps (datetime.now() without tz)
            # — confirm the DB layer expects naive datetimes.
            created_at=datetime.now(),
            updated_at=datetime.now(),
        )
    def _create_document_info_from_chunks(
        self,
        document_id: UUID,
        user: User,
        chunks: list[RawChunk],
        metadata: dict,
        version: str,
    ) -> DocumentResponse:
        """Build a PENDING DocumentResponse for chunk-based (no file)
        ingestion; size is the UTF-8 byte total of all chunk texts."""
        metadata = metadata or {}
        metadata["version"] = version

        return DocumentResponse(
            id=document_id,
            owner_id=user.id,
            collection_ids=metadata.get("collection_ids", []),
            document_type=DocumentType.TXT,
            title=metadata.get("title", f"Ingested Chunks - {document_id}"),
            metadata=metadata,
            version=version,
            size_in_bytes=sum(
                len(chunk.text.encode("utf-8")) for chunk in chunks
            ),
            ingestion_status=IngestionStatus.PENDING,
            created_at=datetime.now(),
            updated_at=datetime.now(),
        )

    async def parse_file(
        self,
        document_info: DocumentResponse,
        ingestion_config: dict | None,
    ) -> AsyncGenerator[DocumentChunk, None]:
        """Reads the file content from the DB, calls the ingestion
        provider to parse, and yields DocumentChunk objects.

        Raises R2RDocumentProcessingError when the file is missing, PDF
        parsing fails, or any non-R2R error occurs during parsing.
        """
        version = document_info.version or "v0"
        ingestion_config_override = ingestion_config or {}

        # The ingestion config might specify a different provider, etc.
        # Note: pop() removes "provider" so it isn't forwarded to parse().
        override_provider = ingestion_config_override.pop("provider", None)
        if (
            override_provider
            and override_provider != self.providers.ingestion.config.provider
        ):
            raise ValueError(
                f"Provider '{override_provider}' does not match ingestion provider "
                f"'{self.providers.ingestion.config.provider}'."
            )

        try:
            # Pull file from DB
            retrieved = (
                await self.providers.database.files_handler.retrieve_file(
                    document_info.id
                )
            )
            if not retrieved:
                # No file found in the DB, can't parse
                raise R2RDocumentProcessingError(
                    document_id=document_info.id,
                    error_message="No file content found in DB for this document.",
                )

            file_name, file_wrapper, file_size = retrieved

            # Read the content (context manager closes the stream)
            with file_wrapper as file_content_stream:
                file_content = file_content_stream.read()

            # Build a barebones Document object
            doc = Document(
                id=document_info.id,
                collection_ids=document_info.collection_ids,
                owner_id=document_info.owner_id,
                metadata={
                    "document_type": document_info.document_type.value,
                    **document_info.metadata,
                },
                document_type=document_info.document_type,
            )

            # Delegate to the ingestion provider to parse
            async for extraction in self.providers.ingestion.parse(
                file_content,  # raw bytes
                doc,
                ingestion_config_override,
            ):
                # Adjust chunk ID to incorporate version
                # or any other needed transformations
                extraction.id = generate_id(f"{extraction.id}_{version}")
                extraction.metadata["version"] = version
                yield extraction

        except (PopplerNotFoundError, PDFParsingError) as e:
            raise R2RDocumentProcessingError(
                error_message=e.message,
                document_id=document_info.id,
                status_code=e.status_code,
            ) from None
        except Exception as e:
            if isinstance(e, R2RException):
                raise
            raise R2RDocumentProcessingError(
                document_id=document_info.id,
                error_message=f"Error parsing document: {str(e)}",
            ) from e

    async def augment_document_info(
        self,
        document_info: DocumentResponse,
        chunked_documents: list[dict],
    ) -> None:
        """Generate and attach a document summary + summary embedding.

        No-op when config.ingestion.skip_document_summary is set. Mutates
        `document_info` in place; raises ValueError if the LLM returns an
        empty summary.
        """
        if not self.config.ingestion.skip_document_summary:
            document = f"Document Title: {document_info.title}\n"
            if document_info.metadata != {}:
                document += f"Document Metadata: {json.dumps(document_info.metadata)}\n"

            # Only the first N chunks feed the summary prompt.
            document += "Document Text:\n"
            for chunk in chunked_documents[
                : self.config.ingestion.chunks_for_document_summary
            ]:
                document += chunk["data"]

            messages = await self.providers.database.prompts_handler.get_message_payload(
                system_prompt_name=self.config.ingestion.document_summary_system_prompt,
                task_prompt_name=self.config.ingestion.document_summary_task_prompt,
                task_inputs={
                    "document": document[
                        : self.config.ingestion.document_summary_max_length
                    ]
                },
            )

            response = await self.providers.llm.aget_completion(
                messages=messages,
                generation_config=GenerationConfig(
                    model=self.config.ingestion.document_summary_model
                    or self.config.app.fast_llm
                ),
            )

            document_info.summary = response.choices[0].message.content  # type: ignore

            if not document_info.summary:
                raise ValueError("Expected a generated response.")

            embedding = await self.providers.embedding.async_get_embedding(
                text=document_info.summary,
            )
            document_info.summary_embedding = embedding
        return
old embedding_pipe.run(...). + + Batches the embedding calls and yields VectorEntry objects. + """ + if not chunked_documents: + return + + concurrency_limit = ( + self.providers.embedding.config.concurrent_request_limit or 5 + ) + extraction_batch: list[DocumentChunk] = [] + tasks: set[asyncio.Task] = set() + + async def process_batch( + batch: list[DocumentChunk], + ) -> list[VectorEntry]: + # All text from the batch + texts = [ + ( + ex.data.decode("utf-8") + if isinstance(ex.data, bytes) + else ex.data + ) + for ex in batch + ] + # Retrieve embeddings in bulk + vectors = await self.providers.embedding.async_get_embeddings( + texts, # list of strings + ) + # Zip them back together + results = [] + for raw_vector, extraction in zip(vectors, batch, strict=False): + results.append( + VectorEntry( + id=extraction.id, + document_id=extraction.document_id, + owner_id=extraction.owner_id, + collection_ids=extraction.collection_ids, + vector=Vector(data=raw_vector, type=VectorType.FIXED), + text=( + extraction.data.decode("utf-8") + if isinstance(extraction.data, bytes) + else str(extraction.data) + ), + metadata={**extraction.metadata}, + ) + ) + return results + + async def run_process_batch(batch: list[DocumentChunk]): + return await process_batch(batch) + + # Convert each chunk dict to a DocumentChunk + for chunk_dict in chunked_documents: + extraction = DocumentChunk.from_dict(chunk_dict) + extraction_batch.append(extraction) + + # If we hit a batch threshold, spawn a task + if len(extraction_batch) >= embedding_batch_size: + tasks.add( + asyncio.create_task(run_process_batch(extraction_batch)) + ) + extraction_batch = [] + + # If tasks are at concurrency limit, wait for the first to finish + while len(tasks) >= concurrency_limit: + done, tasks = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) + for t in done: + for vector_entry in await t: + yield vector_entry + + # Handle any leftover items + if extraction_batch: + 
    async def store_embeddings(
        self,
        embeddings: Sequence[dict | VectorEntry],
        storage_batch_size: int = 128,
    ) -> AsyncGenerator[str, None]:
        """Inline replacement for the old vector_storage_pipe.run(...).

        Batches up the vector entries, enforces usage limits, stores them, and
        yields a success/error string (or you could yield a StorageResult).
        """
        if not embeddings:
            return

        # Normalize input: accept raw dicts or VectorEntry instances.
        vector_entries: list[VectorEntry] = []
        for item in embeddings:
            if isinstance(item, VectorEntry):
                vector_entries.append(item)
            else:
                vector_entries.append(VectorEntry.from_dict(item))

        vector_batch: list[VectorEntry] = []
        document_counts: dict[UUID, int] = {}

        # We'll track usage from the first user we see; if your scenario allows
        # multiple user owners in a single ingestion, you'd need to refine usage checks.
        current_usage = None
        # NOTE(review): user_id_for_usage_check is assigned but never read below.
        user_id_for_usage_check: UUID | None = None

        count = 0

        for msg in vector_entries:
            # If we haven't set usage yet, do so on the first chunk
            if current_usage is None:
                user_id_for_usage_check = msg.owner_id
                usage_data = (
                    await self.providers.database.chunks_handler.list_chunks(
                        limit=1,
                        offset=0,
                        filters={"owner_id": msg.owner_id},
                    )
                )
                current_usage = usage_data["total_entries"]

            # Figure out the user's limit.
            # NOTE(review): this DB fetch runs once per entry — consider
            # whether it was meant to be hoisted outside the loop.
            user = await self.providers.database.users_handler.get_user_by_id(
                msg.owner_id
            )
            max_chunks = (
                self.providers.database.config.app.default_max_chunks_per_user
            )
            if user.limits_overrides and "max_chunks" in user.limits_overrides:
                max_chunks = user.limits_overrides["max_chunks"]

            # Add to our local batch
            vector_batch.append(msg)
            document_counts[msg.document_id] = (
                document_counts.get(msg.document_id, 0) + 1
            )
            count += 1

            # Check usage.
            # NOTE(review): entries in vector_batch are also counted in
            # `count`, so this sum double-counts the pending batch — confirm
            # whether the intended check is current_usage + count.
            if (
                current_usage is not None
                and (current_usage + len(vector_batch) + count) > max_chunks
            ):
                error_message = f"User {msg.owner_id} has exceeded the maximum number of allowed chunks: {max_chunks}"
                logger.error(error_message)
                yield error_message
                continue

            # Once we hit our batch size, store them
            if len(vector_batch) >= storage_batch_size:
                try:
                    await (
                        self.providers.database.chunks_handler.upsert_entries(
                            vector_batch
                        )
                    )
                except Exception as e:
                    logger.error(f"Failed to store vector batch: {e}")
                    yield f"Error: {e}"
                vector_batch.clear()

        # Store any leftover items
        if vector_batch:
            try:
                await self.providers.database.chunks_handler.upsert_entries(
                    vector_batch
                )
            except Exception as e:
                logger.error(f"Failed to store final vector batch: {e}")
                yield f"Error: {e}"

        # Summaries: one info line per document ingested
        for doc_id, cnt in document_counts.items():
            info_msg = f"Successful ingestion for document_id: {doc_id}, with vector count: {cnt}"
            logger.info(info_msg)
            yield info_msg
document_info: DocumentResponse + ) -> None: + """Called at the end of a successful ingestion pipeline to set the + document status to SUCCESS or similar final steps.""" + + async def empty_generator(): + yield document_info + + await self.update_document_status( + document_info, IngestionStatus.SUCCESS + ) + return empty_generator() + + async def update_document_status( + self, + document_info: DocumentResponse, + status: IngestionStatus, + metadata: Optional[dict] = None, + ) -> None: + document_info.ingestion_status = status + if metadata: + document_info.metadata = {**document_info.metadata, **metadata} + await self._update_document_status_in_db(document_info) + + async def _update_document_status_in_db( + self, document_info: DocumentResponse + ): + try: + await self.providers.database.documents_handler.upsert_documents_overview( + document_info + ) + except Exception as e: + logger.error( + f"Failed to update document status: {document_info.id}. Error: {str(e)}" + ) + + async def ingest_chunks_ingress( + self, + document_id: UUID, + metadata: Optional[dict], + chunks: list[RawChunk], + user: User, + *args: Any, + **kwargs: Any, + ) -> DocumentResponse: + """Directly ingest user-provided text chunks (rather than from a + file).""" + if not chunks: + raise R2RException( + status_code=400, message="No chunks provided for ingestion." 
+ ) + metadata = metadata or {} + version = STARTING_VERSION + + document_info = self._create_document_info_from_chunks( + document_id, + user, + chunks, + metadata, + version, + ) + + existing_document_info = ( + await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=100, + filter_user_ids=[user.id], + filter_document_ids=[document_id], + ) + )["results"] + if len(existing_document_info) > 0: + existing_doc = existing_document_info[0] + if existing_doc.ingestion_status != IngestionStatus.FAILED: + raise R2RException( + status_code=409, + message=( + f"Document {document_id} was already ingested " + "and is not in a failed state." + ), + ) + + await self.providers.database.documents_handler.upsert_documents_overview( + document_info + ) + return document_info + + async def update_chunk_ingress( + self, + document_id: UUID, + chunk_id: UUID, + text: str, + user: User, + metadata: Optional[dict] = None, + *args: Any, + **kwargs: Any, + ) -> dict: + """Update an individual chunk's text and metadata, re-embed, and re- + store it.""" + # Verify chunk exists and user has access + existing_chunks = ( + await self.providers.database.chunks_handler.list_document_chunks( + document_id=document_id, + offset=0, + limit=1, + ) + ) + if not existing_chunks["results"]: + raise R2RException( + status_code=404, + message=f"Chunk with chunk_id {chunk_id} not found.", + ) + + existing_chunk = ( + await self.providers.database.chunks_handler.get_chunk(chunk_id) + ) + if not existing_chunk: + raise R2RException( + status_code=404, + message=f"Chunk with id {chunk_id} not found", + ) + + if ( + str(existing_chunk["owner_id"]) != str(user.id) + and not user.is_superuser + ): + raise R2RException( + status_code=403, + message="You don't have permission to modify this chunk.", + ) + + # Merge metadata + merged_metadata = {**existing_chunk["metadata"]} + if metadata is not None: + merged_metadata |= metadata + + # Create updated chunk + extraction_data = 
{ + "id": chunk_id, + "document_id": document_id, + "collection_ids": kwargs.get( + "collection_ids", existing_chunk["collection_ids"] + ), + "owner_id": existing_chunk["owner_id"], + "data": text or existing_chunk["text"], + "metadata": merged_metadata, + } + extraction = DocumentChunk(**extraction_data).model_dump() + + # Re-embed + embeddings_generator = self.embed_document( + [extraction], embedding_batch_size=1 + ) + embeddings = [] + async for embedding in embeddings_generator: + embeddings.append(embedding) + + # Re-store + store_gen = self.store_embeddings(embeddings, storage_batch_size=1) + async for _ in store_gen: + pass + + return extraction + + async def _get_enriched_chunk_text( + self, + chunk_idx: int, + chunk: dict, + document_id: UUID, + document_summary: str | None, + chunk_enrichment_settings: ChunkEnrichmentSettings, + list_document_chunks: list[dict], + ) -> VectorEntry: + """Helper for chunk_enrichment. + + Leverages an LLM to rewrite or expand chunk text, then re-embeds it. 
+ """ + preceding_chunks = [ + list_document_chunks[idx]["text"] + for idx in range( + max(0, chunk_idx - chunk_enrichment_settings.n_chunks), + chunk_idx, + ) + ] + succeeding_chunks = [ + list_document_chunks[idx]["text"] + for idx in range( + chunk_idx + 1, + min( + len(list_document_chunks), + chunk_idx + chunk_enrichment_settings.n_chunks + 1, + ), + ) + ] + try: + # Obtain the updated text from the LLM + updated_chunk_text = ( + ( + await self.providers.llm.aget_completion( + messages=await self.providers.database.prompts_handler.get_message_payload( + task_prompt_name=chunk_enrichment_settings.chunk_enrichment_prompt, + task_inputs={ + "document_summary": document_summary or "None", + "chunk": chunk["text"], + "preceding_chunks": ( + "\n".join(preceding_chunks) + if preceding_chunks + else "None" + ), + "succeeding_chunks": ( + "\n".join(succeeding_chunks) + if succeeding_chunks + else "None" + ), + "chunk_size": self.config.ingestion.chunk_size + or 1024, + }, + ), + generation_config=chunk_enrichment_settings.generation_config + or GenerationConfig(model=self.config.app.fast_llm), + ) + ) + .choices[0] + .message.content + ) + except Exception: + updated_chunk_text = chunk["text"] + chunk["metadata"]["chunk_enrichment_status"] = "failed" + else: + chunk["metadata"]["chunk_enrichment_status"] = ( + "success" if updated_chunk_text else "failed" + ) + + if not updated_chunk_text or not isinstance(updated_chunk_text, str): + updated_chunk_text = str(chunk["text"]) + chunk["metadata"]["chunk_enrichment_status"] = "failed" + + # Re-embed + data = await self.providers.embedding.async_get_embedding( + updated_chunk_text + ) + chunk["metadata"]["original_text"] = chunk["text"] + + return VectorEntry( + id=generate_id(str(chunk["id"])), + vector=Vector(data=data, type=VectorType.FIXED, length=len(data)), + document_id=document_id, + owner_id=chunk["owner_id"], + collection_ids=chunk["collection_ids"], + text=updated_chunk_text, + metadata=chunk["metadata"], + ) + + 
async def chunk_enrichment( + self, + document_id: UUID, + document_summary: str | None, + chunk_enrichment_settings: ChunkEnrichmentSettings, + ) -> int: + """Example function that modifies chunk text via an LLM then re-embeds + and re-stores all chunks for the given document.""" + list_document_chunks = ( + await self.providers.database.chunks_handler.list_document_chunks( + document_id=document_id, + offset=0, + limit=-1, + ) + )["results"] + + new_vector_entries: list[VectorEntry] = [] + tasks = [] + total_completed = 0 + + for chunk_idx, chunk in enumerate(list_document_chunks): + tasks.append( + self._get_enriched_chunk_text( + chunk_idx=chunk_idx, + chunk=chunk, + document_id=document_id, + document_summary=document_summary, + chunk_enrichment_settings=chunk_enrichment_settings, + list_document_chunks=list_document_chunks, + ) + ) + + # Process in batches of e.g. 128 concurrency + if len(tasks) == 128: + new_vector_entries.extend(await asyncio.gather(*tasks)) + total_completed += 128 + logger.info( + f"Completed {total_completed} out of {len(list_document_chunks)} chunks for document {document_id}" + ) + tasks = [] + + # Finish any remaining tasks + new_vector_entries.extend(await asyncio.gather(*tasks)) + logger.info( + f"Completed enrichment of {len(list_document_chunks)} chunks for document {document_id}" + ) + + # Delete old chunks from vector db + await self.providers.database.chunks_handler.delete( + filters={"document_id": document_id} + ) + + # Insert the newly enriched entries + await self.providers.database.chunks_handler.upsert_entries( + new_vector_entries + ) + return len(new_vector_entries) + + async def list_chunks( + self, + offset: int, + limit: int, + filters: Optional[dict[str, Any]] = None, + include_vectors: bool = False, + *args: Any, + **kwargs: Any, + ) -> dict: + return await self.providers.database.chunks_handler.list_chunks( + offset=offset, + limit=limit, + filters=filters, + include_vectors=include_vectors, + ) + + async def 
get_chunk( + self, + chunk_id: UUID, + *args: Any, + **kwargs: Any, + ) -> dict: + return await self.providers.database.chunks_handler.get_chunk(chunk_id) + + async def update_document_metadata( + self, + document_id: UUID, + metadata: dict, + user: User, + ) -> None: + # Verify document exists and user has access + existing_document = await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=100, + filter_document_ids=[document_id], + filter_user_ids=[user.id], + ) + if not existing_document["results"]: + raise R2RException( + status_code=404, + message=( + f"Document with id {document_id} not found " + "or you don't have access." + ), + ) + + existing_document = existing_document["results"][0] + + # Merge metadata + merged_metadata = {**existing_document.metadata, **metadata} # type: ignore + + # Update document metadata + existing_document.metadata = merged_metadata # type: ignore + await self.providers.database.documents_handler.upsert_documents_overview( + existing_document # type: ignore + ) + + +class IngestionServiceAdapter: + @staticmethod + def _parse_user_data(user_data) -> User: + if isinstance(user_data, str): + try: + user_data = json.loads(user_data) + except json.JSONDecodeError as e: + raise ValueError( + f"Invalid user data format: {user_data}" + ) from e + return User.from_dict(user_data) + + @staticmethod + def parse_ingest_file_input(data: dict) -> dict: + return { + "user": IngestionServiceAdapter._parse_user_data(data["user"]), + "metadata": data["metadata"], + "document_id": ( + UUID(data["document_id"]) if data["document_id"] else None + ), + "version": data.get("version"), + "ingestion_config": data["ingestion_config"] or {}, + "file_data": data["file_data"], + "size_in_bytes": data["size_in_bytes"], + "collection_ids": data.get("collection_ids", []), + } + + @staticmethod + def parse_ingest_chunks_input(data: dict) -> dict: + return { + "user": IngestionServiceAdapter._parse_user_data(data["user"]), + 
"metadata": data["metadata"], + "document_id": data["document_id"], + "chunks": [ + UnprocessedChunk.from_dict(chunk) for chunk in data["chunks"] + ], + "id": data.get("id"), + } + + @staticmethod + def parse_update_chunk_input(data: dict) -> dict: + return { + "user": IngestionServiceAdapter._parse_user_data(data["user"]), + "document_id": UUID(data["document_id"]), + "id": UUID(data["id"]), + "text": data["text"], + "metadata": data.get("metadata"), + "collection_ids": data.get("collection_ids", []), + } + + @staticmethod + def parse_update_files_input(data: dict) -> dict: + return { + "user": IngestionServiceAdapter._parse_user_data(data["user"]), + "document_ids": [UUID(doc_id) for doc_id in data["document_ids"]], + "metadatas": data["metadatas"], + "ingestion_config": data["ingestion_config"], + "file_sizes_in_bytes": data["file_sizes_in_bytes"], + "file_datas": data["file_datas"], + } + + @staticmethod + def parse_create_vector_index_input(data: dict) -> dict: + return { + "table_name": VectorTableName(data["table_name"]), + "index_method": IndexMethod(data["index_method"]), + "index_measure": IndexMeasure(data["index_measure"]), + "index_name": data["index_name"], + "index_column": data["index_column"], + "index_arguments": data["index_arguments"], + "concurrently": data["concurrently"], + } + + @staticmethod + def parse_list_vector_indices_input(input_data: dict) -> dict: + return {"table_name": input_data["table_name"]} + + @staticmethod + def parse_delete_vector_index_input(input_data: dict) -> dict: + return { + "index_name": input_data["index_name"], + "table_name": input_data.get("table_name"), + "concurrently": input_data.get("concurrently", True), + } + + @staticmethod + def parse_select_vector_index_input(input_data: dict) -> dict: + return { + "index_name": input_data["index_name"], + "table_name": input_data.get("table_name"), + } + + @staticmethod + def parse_update_document_metadata_input(data: dict) -> dict: + return { + "document_id": 
data["document_id"], + "metadata": data["metadata"], + "user": IngestionServiceAdapter._parse_user_data(data["user"]), + } diff --git a/.venv/lib/python3.12/site-packages/core/main/services/management_service.py b/.venv/lib/python3.12/site-packages/core/main/services/management_service.py new file mode 100644 index 00000000..62b4ca0b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/services/management_service.py @@ -0,0 +1,1084 @@ +import logging +import os +from collections import defaultdict +from datetime import datetime, timedelta, timezone +from typing import IO, Any, BinaryIO, Optional, Tuple +from uuid import UUID + +import toml + +from core.base import ( + CollectionResponse, + ConversationResponse, + DocumentResponse, + GenerationConfig, + GraphConstructionStatus, + Message, + MessageResponse, + Prompt, + R2RException, + StoreType, + User, +) + +from ..abstractions import R2RProviders +from ..config import R2RConfig +from .base import Service + +logger = logging.getLogger() + + +class ManagementService(Service): + def __init__( + self, + config: R2RConfig, + providers: R2RProviders, + ): + super().__init__( + config, + providers, + ) + + async def app_settings(self): + prompts = ( + await self.providers.database.prompts_handler.get_all_prompts() + ) + config_toml = self.config.to_toml() + config_dict = toml.loads(config_toml) + try: + project_name = os.environ["R2R_PROJECT_NAME"] + except KeyError: + project_name = "" + return { + "config": config_dict, + "prompts": prompts, + "r2r_project_name": project_name, + } + + async def users_overview( + self, + offset: int, + limit: int, + user_ids: Optional[list[UUID]] = None, + ): + return await self.providers.database.users_handler.get_users_overview( + offset=offset, + limit=limit, + user_ids=user_ids, + ) + + async def delete_documents_and_chunks_by_filter( + self, + filters: dict[str, Any], + ): + """Delete chunks matching the given filters. 
If any documents are now + empty (i.e., have no remaining chunks), delete those documents as well. + + Args: + filters (dict[str, Any]): Filters specifying which chunks to delete. + chunks_handler (PostgresChunksHandler): The handler for chunk operations. + documents_handler (PostgresDocumentsHandler): The handler for document operations. + graphs_handler: Handler for entity and relationship operations in the Graph. + + Returns: + dict: A summary of what was deleted. + """ + + def transform_chunk_id_to_id( + filters: dict[str, Any], + ) -> dict[str, Any]: + """Example transformation function if your filters use `chunk_id` + instead of `id`. + + Recursively transform `chunk_id` to `id`. + """ + if isinstance(filters, dict): + transformed = {} + for key, value in filters.items(): + if key == "chunk_id": + transformed["id"] = value + elif key in ["$and", "$or"]: + transformed[key] = [ + transform_chunk_id_to_id(item) for item in value + ] + else: + transformed[key] = transform_chunk_id_to_id(value) + return transformed + return filters + + # Transform filters if needed. 
+ transformed_filters = transform_chunk_id_to_id(filters) + + # Find chunks that match the filters before deleting + interim_results = ( + await self.providers.database.chunks_handler.list_chunks( + filters=transformed_filters, + offset=0, + limit=1_000, + include_vectors=False, + ) + ) + + results = interim_results["results"] + while interim_results["total_entries"] == 1_000: + # If we hit the limit, we need to paginate to get all results + + interim_results = ( + await self.providers.database.chunks_handler.list_chunks( + filters=transformed_filters, + offset=interim_results["offset"] + 1_000, + limit=1_000, + include_vectors=False, + ) + ) + results.extend(interim_results["results"]) + + document_ids = set() + owner_id = None + + if "$and" in filters: + for condition in filters["$and"]: + if "owner_id" in condition and "$eq" in condition["owner_id"]: + owner_id = condition["owner_id"]["$eq"] + elif ( + "document_id" in condition + and "$eq" in condition["document_id"] + ): + document_ids.add(UUID(condition["document_id"]["$eq"])) + elif "document_id" in filters: + doc_id = filters["document_id"] + if isinstance(doc_id, str): + document_ids.add(UUID(doc_id)) + elif isinstance(doc_id, UUID): + document_ids.add(doc_id) + elif isinstance(doc_id, dict) and "$eq" in doc_id: + value = doc_id["$eq"] + document_ids.add( + UUID(value) if isinstance(value, str) else value + ) + + # Delete matching chunks from the database + delete_results = await self.providers.database.chunks_handler.delete( + transformed_filters + ) + + # Extract the document_ids that were affected. 
+ affected_doc_ids = { + UUID(info["document_id"]) + for info in delete_results.values() + if info.get("document_id") + } + document_ids.update(affected_doc_ids) + + # Check if the document still has any chunks left + docs_to_delete = [] + for doc_id in document_ids: + documents_overview_response = await self.providers.database.documents_handler.get_documents_overview( + offset=0, limit=1, filter_document_ids=[doc_id] + ) + if not documents_overview_response["results"]: + raise R2RException( + status_code=404, message="Document not found" + ) + + document = documents_overview_response["results"][0] + + for collection_id in document.collection_ids: + await self.providers.database.collections_handler.decrement_collection_document_count( + collection_id=collection_id + ) + + if owner_id and str(document.owner_id) != owner_id: + raise R2RException( + status_code=404, + message="Document not found or insufficient permissions", + ) + docs_to_delete.append(doc_id) + + # Delete documents that no longer have associated chunks + for doc_id in docs_to_delete: + # Delete related entities & relationships if needed: + await self.providers.database.graphs_handler.entities.delete( + parent_id=doc_id, + store_type=StoreType.DOCUMENTS, + ) + await self.providers.database.graphs_handler.relationships.delete( + parent_id=doc_id, + store_type=StoreType.DOCUMENTS, + ) + + # Finally, delete the document from documents_overview: + await self.providers.database.documents_handler.delete( + document_id=doc_id + ) + + return { + "success": True, + "deleted_chunks_count": len(delete_results), + "deleted_documents_count": len(docs_to_delete), + "deleted_document_ids": [str(d) for d in docs_to_delete], + } + + async def download_file( + self, document_id: UUID + ) -> Optional[Tuple[str, BinaryIO, int]]: + if result := await self.providers.database.files_handler.retrieve_file( + document_id + ): + return result + return None + + async def export_files( + self, + document_ids: Optional[list[UUID]] 
= None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ) -> tuple[str, BinaryIO, int]: + return ( + await self.providers.database.files_handler.retrieve_files_as_zip( + document_ids=document_ids, + start_date=start_date, + end_date=end_date, + ) + ) + + async def export_collections( + self, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.collections_handler.export_to_csv( + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_documents( + self, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.documents_handler.export_to_csv( + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_document_entities( + self, + id: UUID, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.graphs_handler.entities.export_to_csv( + parent_id=id, + store_type=StoreType.DOCUMENTS, + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_document_relationships( + self, + id: UUID, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.graphs_handler.relationships.export_to_csv( + parent_id=id, + store_type=StoreType.DOCUMENTS, + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_conversations( + self, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.conversations_handler.export_conversations_to_csv( + columns=columns, + 
filters=filters, + include_header=include_header, + ) + + async def export_graph_entities( + self, + id: UUID, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.graphs_handler.entities.export_to_csv( + parent_id=id, + store_type=StoreType.GRAPHS, + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_graph_relationships( + self, + id: UUID, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.graphs_handler.relationships.export_to_csv( + parent_id=id, + store_type=StoreType.GRAPHS, + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_graph_communities( + self, + id: UUID, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.graphs_handler.communities.export_to_csv( + parent_id=id, + store_type=StoreType.GRAPHS, + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_messages( + self, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.conversations_handler.export_messages_to_csv( + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def export_users( + self, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + return await self.providers.database.users_handler.export_to_csv( + columns=columns, + filters=filters, + include_header=include_header, + ) + + async def documents_overview( + self, + offset: int, + limit: int, + user_ids: Optional[list[UUID]] = None, + collection_ids: 
Optional[list[UUID]] = None, + document_ids: Optional[list[UUID]] = None, + ): + return await self.providers.database.documents_handler.get_documents_overview( + offset=offset, + limit=limit, + filter_document_ids=document_ids, + filter_user_ids=user_ids, + filter_collection_ids=collection_ids, + ) + + async def update_document_metadata( + self, + document_id: UUID, + metadata: list[dict], + overwrite: bool = False, + ): + return await self.providers.database.documents_handler.update_document_metadata( + document_id=document_id, + metadata=metadata, + overwrite=overwrite, + ) + + async def list_document_chunks( + self, + document_id: UUID, + offset: int, + limit: int, + include_vectors: bool = False, + ): + return ( + await self.providers.database.chunks_handler.list_document_chunks( + document_id=document_id, + offset=offset, + limit=limit, + include_vectors=include_vectors, + ) + ) + + async def assign_document_to_collection( + self, document_id: UUID, collection_id: UUID + ): + await self.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id, collection_id + ) + await self.providers.database.collections_handler.assign_document_to_collection_relational( + document_id, collection_id + ) + await self.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=GraphConstructionStatus.OUTDATED, + ) + await self.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=GraphConstructionStatus.OUTDATED, + ) + + return {"message": "Document assigned to collection successfully"} + + async def remove_document_from_collection( + self, document_id: UUID, collection_id: UUID + ): + await self.providers.database.collections_handler.remove_document_from_collection_relational( + document_id, collection_id + ) + await self.providers.database.chunks_handler.remove_document_from_collection_vector( + document_id, 
collection_id + ) + # await self.providers.database.graphs_handler.delete_node_via_document_id( + # document_id, collection_id + # ) + return None + + def _process_relationships( + self, relationships: list[Tuple[str, str, str]] + ) -> Tuple[dict[str, list[str]], dict[str, dict[str, list[str]]]]: + graph = defaultdict(list) + grouped: dict[str, dict[str, list[str]]] = defaultdict( + lambda: defaultdict(list) + ) + for subject, relation, obj in relationships: + graph[subject].append(obj) + grouped[subject][relation].append(obj) + if obj not in graph: + graph[obj] = [] + return dict(graph), dict(grouped) + + def generate_output( + self, + grouped_relationships: dict[str, dict[str, list[str]]], + graph: dict[str, list[str]], + descriptions_dict: dict[str, str], + print_descriptions: bool = True, + ) -> list[str]: + output = [] + # Print grouped relationships + for subject, relations in grouped_relationships.items(): + output.append(f"\n== {subject} ==") + if print_descriptions and subject in descriptions_dict: + output.append(f"\tDescription: {descriptions_dict[subject]}") + for relation, objects in relations.items(): + output.append(f" {relation}:") + for obj in objects: + output.append(f" - {obj}") + if print_descriptions and obj in descriptions_dict: + output.append( + f" Description: {descriptions_dict[obj]}" + ) + + # Print basic graph statistics + output.extend( + [ + "\n== Graph Statistics ==", + f"Number of nodes: {len(graph)}", + f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}", + f"Number of connected components: {self._count_connected_components(graph)}", + ] + ) + + # Find central nodes + central_nodes = self._get_central_nodes(graph) + output.extend( + [ + "\n== Most Central Nodes ==", + *( + f" {node}: {centrality:.4f}" + for node, centrality in central_nodes + ), + ] + ) + + return output + + def _count_connected_components(self, graph: dict[str, list[str]]) -> int: + visited = set() + components = 0 + + def dfs(node): + 
visited.add(node) + for neighbor in graph[node]: + if neighbor not in visited: + dfs(neighbor) + + for node in graph: + if node not in visited: + dfs(node) + components += 1 + + return components + + def _get_central_nodes( + self, graph: dict[str, list[str]] + ) -> list[Tuple[str, float]]: + degree = {node: len(neighbors) for node, neighbors in graph.items()} + total_nodes = len(graph) + centrality = { + node: deg / (total_nodes - 1) for node, deg in degree.items() + } + return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5] + + async def create_collection( + self, + owner_id: UUID, + name: Optional[str] = None, + description: str | None = None, + ) -> CollectionResponse: + result = await self.providers.database.collections_handler.create_collection( + owner_id=owner_id, + name=name, + description=description, + ) + await self.providers.database.graphs_handler.create( + collection_id=result.id, + name=name, + description=description, + ) + return result + + async def update_collection( + self, + collection_id: UUID, + name: Optional[str] = None, + description: Optional[str] = None, + generate_description: bool = False, + ) -> CollectionResponse: + if generate_description: + description = await self.summarize_collection( + id=collection_id, offset=0, limit=100 + ) + return await self.providers.database.collections_handler.update_collection( + collection_id=collection_id, + name=name, + description=description, + ) + + async def delete_collection(self, collection_id: UUID) -> bool: + await self.providers.database.collections_handler.delete_collection_relational( + collection_id + ) + await self.providers.database.chunks_handler.delete_collection_vector( + collection_id + ) + try: + await self.providers.database.graphs_handler.delete( + collection_id=collection_id, + ) + except Exception as e: + logger.warning( + f"Error deleting graph for collection {collection_id}: {e}" + ) + return True + + async def collections_overview( + self, + offset: int, + 
        limit: int,
        user_ids: Optional[list[UUID]] = None,
        document_ids: Optional[list[UUID]] = None,
        collection_ids: Optional[list[UUID]] = None,
    ) -> dict[str, list[CollectionResponse] | int]:
        # Paginated collection listing, optionally filtered by owner users,
        # contained documents, or explicit collection ids.
        return await self.providers.database.collections_handler.get_collections_overview(
            offset=offset,
            limit=limit,
            filter_user_ids=user_ids,
            filter_document_ids=document_ids,
            filter_collection_ids=collection_ids,
        )

    async def add_user_to_collection(
        self, user_id: UUID, collection_id: UUID
    ) -> bool:
        """Grant a user membership in a collection."""
        return (
            await self.providers.database.users_handler.add_user_to_collection(
                user_id, collection_id
            )
        )

    async def remove_user_from_collection(
        self, user_id: UUID, collection_id: UUID
    ) -> bool:
        """Revoke a user's membership in a collection."""
        return await self.providers.database.users_handler.remove_user_from_collection(
            user_id, collection_id
        )

    async def get_users_in_collection(
        self, collection_id: UUID, offset: int = 0, limit: int = 100
    ) -> dict[str, list[User] | int]:
        """Paginated list of users belonging to a collection."""
        return await self.providers.database.users_handler.get_users_in_collection(
            collection_id, offset=offset, limit=limit
        )

    async def documents_in_collection(
        self, collection_id: UUID, offset: int = 0, limit: int = 100
    ) -> dict[str, list[DocumentResponse] | int]:
        """Paginated list of documents belonging to a collection."""
        return await self.providers.database.collections_handler.documents_in_collection(
            collection_id, offset=offset, limit=limit
        )

    async def summarize_collection(
        self, id: UUID, offset: int, limit: int
    ) -> str:
        """Generate a collection-level summary from its documents' summaries
        via the configured LLM.

        Raises:
            ValueError: if the LLM returns empty content.

        NOTE(review): assumes every listed document already has a non-None
        `summary` (the `# type: ignore` on the join hints otherwise) — a
        None entry would make the join below raise; confirm upstream
        guarantees documents are summarized before this is called.
        """
        documents_in_collection_response = await self.documents_in_collection(
            collection_id=id,
            offset=offset,
            limit=limit,
        )

        document_summaries = [
            document.summary
            for document in documents_in_collection_response["results"]  # type: ignore
        ]

        logger.info(
            f"Summarizing collection {id} with {len(document_summaries)} of {documents_in_collection_response['total_entries']} documents."
        )

        formatted_summaries = "\n\n".join(document_summaries)  # type: ignore

        messages = await self.providers.database.prompts_handler.get_message_payload(
            system_prompt_name=self.config.database.collection_summary_system_prompt,
            task_prompt_name=self.config.database.collection_summary_prompt,
            task_inputs={"document_summaries": formatted_summaries},
        )

        # Prefer the dedicated summary model; fall back to the fast LLM.
        response = await self.providers.llm.aget_completion(
            messages=messages,
            generation_config=GenerationConfig(
                model=self.config.ingestion.document_summary_model
                or self.config.app.fast_llm
            ),
        )

        if collection_summary := response.choices[0].message.content:
            return collection_summary
        else:
            raise ValueError("Expected a generated response.")

    async def add_prompt(
        self, name: str, template: str, input_types: dict[str, str]
    ) -> dict:
        """Register a new prompt template.

        NOTE(review): despite the `-> dict` annotation this returns a plain
        status string (hence the `# type: ignore`); `delete_prompt` wraps
        its message in a dict. Confirm what callers expect before
        normalizing either side.
        """
        try:
            await self.providers.database.prompts_handler.add_prompt(
                name, template, input_types
            )
            return f"Prompt '{name}' added successfully."  # type: ignore
        except ValueError as e:
            raise R2RException(status_code=400, message=str(e)) from e

    async def get_cached_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        """Fetch a prompt through the handler's cache, wrapped as
        {"message": ...}; 404s when the prompt is unknown."""
        try:
            return {
                "message": (
                    await self.providers.database.prompts_handler.get_cached_prompt(
                        prompt_name=prompt_name,
                        inputs=inputs,
                        prompt_override=prompt_override,
                    )
                )
            }
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e

    async def get_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        """Fetch a prompt (uncached variant); 404s when unknown."""
        try:
            return await self.providers.database.prompts_handler.get_prompt(  # type: ignore
                name=prompt_name,
                inputs=inputs,
                prompt_override=prompt_override,
            )
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e

    async def get_all_prompts(self) -> dict[str, Prompt]:
        """Return every stored prompt keyed by name."""
        return await
        self.providers.database.prompts_handler.get_all_prompts()

    async def update_prompt(
        self,
        name: str,
        template: Optional[str] = None,
        input_types: Optional[dict[str, str]] = None,
    ) -> dict:
        """Update an existing prompt's template and/or input types.

        NOTE(review): returns a status string despite `-> dict` (see the
        `# type: ignore`) — mirrors `add_prompt`; confirm caller
        expectations before changing.
        """
        try:
            await self.providers.database.prompts_handler.update_prompt(
                name, template, input_types
            )
            return f"Prompt '{name}' updated successfully."  # type: ignore
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e

    async def delete_prompt(self, name: str) -> dict:
        """Delete a prompt by name; 404s if it does not exist."""
        try:
            await self.providers.database.prompts_handler.delete_prompt(name)
            return {"message": f"Prompt '{name}' deleted successfully."}
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e

    async def get_conversation(
        self,
        conversation_id: UUID,
        user_ids: Optional[list[UUID]] = None,
    ) -> list[MessageResponse]:
        """Fetch a conversation's messages, optionally scoped to users."""
        return await self.providers.database.conversations_handler.get_conversation(
            conversation_id=conversation_id,
            filter_user_ids=user_ids,
        )

    async def create_conversation(
        self,
        user_id: Optional[UUID] = None,
        name: Optional[str] = None,
    ) -> ConversationResponse:
        """Create a new (optionally named, optionally owned) conversation."""
        return await self.providers.database.conversations_handler.create_conversation(
            user_id=user_id,
            name=name,
        )

    async def conversations_overview(
        self,
        offset: int,
        limit: int,
        conversation_ids: Optional[list[UUID]] = None,
        user_ids: Optional[list[UUID]] = None,
    ) -> dict[str, list[dict] | int]:
        """Paginated conversation listing with optional id/user filters."""
        return await self.providers.database.conversations_handler.get_conversations_overview(
            offset=offset,
            limit=limit,
            filter_user_ids=user_ids,
            conversation_ids=conversation_ids,
        )

    async def add_message(
        self,
        conversation_id: UUID,
        content: Message,
        parent_id: Optional[UUID] = None,
        metadata: Optional[dict] = None,
    ) -> MessageResponse:
        """Append a message to a conversation, optionally as a child of
        parent_id (branching)."""
        return await self.providers.database.conversations_handler.add_message(
            conversation_id=conversation_id,
            content=content,
            parent_id=parent_id,
            metadata=metadata,
        )

    async def edit_message(
        self,
        message_id: UUID,
        new_content: Optional[str] = None,
        additional_metadata: Optional[dict] = None,
    ) -> dict[str, Any]:
        """Edit a message's content and/or merge in extra metadata."""
        return (
            await self.providers.database.conversations_handler.edit_message(
                message_id=message_id,
                new_content=new_content,
                additional_metadata=additional_metadata or {},
            )
        )

    async def update_conversation(
        self, conversation_id: UUID, name: str
    ) -> ConversationResponse:
        """Rename a conversation."""
        return await self.providers.database.conversations_handler.update_conversation(
            conversation_id=conversation_id, name=name
        )

    async def delete_conversation(
        self,
        conversation_id: UUID,
        user_ids: Optional[list[UUID]] = None,
    ) -> None:
        """Delete a conversation, optionally scoped to the given users."""
        await (
            self.providers.database.conversations_handler.delete_conversation(
                conversation_id=conversation_id,
                filter_user_ids=user_ids,
            )
        )

    async def get_user_max_documents(self, user_id: UUID) -> int | None:
        """Per-user document cap: user override wins, else system default."""
        # Fetch the user to see if they have any overrides stored
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        if user.limits_overrides and "max_documents" in user.limits_overrides:
            return user.limits_overrides["max_documents"]
        return self.config.app.default_max_documents_per_user

    async def get_user_max_chunks(self, user_id: UUID) -> int | None:
        """Per-user chunk cap: user override wins, else system default."""
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        if user.limits_overrides and "max_chunks" in user.limits_overrides:
            return user.limits_overrides["max_chunks"]
        return self.config.app.default_max_chunks_per_user

    async def get_user_max_collections(self, user_id: UUID) -> int | None:
        """Per-user collection cap: user override wins, else system default."""
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        if (
            user.limits_overrides
            and "max_collections" in user.limits_overrides
        ):
            return user.limits_overrides["max_collections"]
        return self.config.app.default_max_collections_per_user

    async def get_max_upload_size_by_type(
        self,
        user_id: UUID, file_type_or_ext: str
    ) -> int:
        """Return the maximum allowed upload size (in bytes) for the given
        user's file type/extension. Respects user-level overrides if present,
        falling back to the system config.

        Resolution order: user per-type override -> user global override
        -> system per-type limit -> system global default.

        ```json
        {
            "limits_overrides": {
                "max_file_size": 20_000_000,
                "max_file_size_by_type":
                {
                    "pdf": 50_000_000,
                    "docx": 30_000_000
                },
                ...
            }
        }
        ```
        """
        # 1. Normalize extension (accepts ".PDF", "pdf", etc.)
        ext = file_type_or_ext.lower().lstrip(".")

        # 2. Fetch user from DB to see if we have any overrides
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        user_overrides = user.limits_overrides or {}

        # 3. Check if there's a user-level override for "max_file_size_by_type"
        user_file_type_limits = user_overrides.get("max_file_size_by_type", {})
        if ext in user_file_type_limits:
            return user_file_type_limits[ext]

        # 4. If not, check if there's a user-level fallback "max_file_size"
        if "max_file_size" in user_overrides:
            return user_overrides["max_file_size"]

        # 5. If none exist at user level, use system config
        # Example config paths:
        system_type_limits = self.config.app.max_upload_size_by_type
        if ext in system_type_limits:
            return system_type_limits[ext]

        # 6. Otherwise, return the global default
        return self.config.app.default_max_upload_size

    async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:
        """
        Return a dictionary containing:
          - The system default limits (from self.config.limits)
          - The user's overrides (from user.limits_overrides)
          - The final 'effective' set of limits after merging (overall)
          - The usage for each relevant limit (per-route usage, etc.)
        """
        # 1) Fetch the user
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        user_overrides = user.limits_overrides or {}

        # 2) Grab system defaults
        system_defaults = {
            "global_per_min": self.config.database.limits.global_per_min,
            "route_per_min": self.config.database.limits.route_per_min,
            "monthly_limit": self.config.database.limits.monthly_limit,
            # Add additional fields if your LimitSettings has them
        }

        # 3) Build the overall (global) "effective limits" ignoring any specific route
        overall_effective = (
            self.providers.database.limits_handler.determine_effective_limits(
                user, route=""
            )
        )

        # 4) Build usage data. We'll do top-level usage for global_per_min/monthly,
        #    then do route-by-route usage in a loop.
        usage: dict[str, Any] = {}
        now = datetime.now(timezone.utc)
        one_min_ago = now - timedelta(minutes=1)

        # (a) Global usage (per-minute)
        global_per_min_used = (
            await self.providers.database.limits_handler._count_requests(
                user_id, route=None, since=one_min_ago
            )
        )
        # (a2) Global usage (monthly) - i.e. usage across ALL routes
        global_monthly_used = await self.providers.database.limits_handler._count_monthly_requests(
            user_id, route=None
        )

        # "remaining" is None whenever the corresponding limit is unlimited.
        usage["global_per_min"] = {
            "used": global_per_min_used,
            "limit": overall_effective.global_per_min,
            "remaining": (
                overall_effective.global_per_min - global_per_min_used
                if overall_effective.global_per_min is not None
                else None
            ),
        }
        usage["monthly_limit"] = {
            "used": global_monthly_used,
            "limit": overall_effective.monthly_limit,
            "remaining": (
                overall_effective.monthly_limit - global_monthly_used
                if overall_effective.monthly_limit is not None
                else None
            ),
        }

        # (b) Route-level usage.
        # We'll gather all routes from system + user overrides
        system_route_limits = (
            self.config.database.route_limits
        )  # dict[str, LimitSettings]
        user_route_overrides = user_overrides.get("route_overrides", {})
        route_keys = set(system_route_limits.keys()) | set(
            user_route_overrides.keys()
        )

        usage["routes"] = {}
        for route in route_keys:
            # 1) Get the final merged limits for this specific route
            route_effective = self.providers.database.limits_handler.determine_effective_limits(
                user, route
            )

            # 2) Count requests for the last minute on this route
            route_per_min_used = (
                await self.providers.database.limits_handler._count_requests(
                    user_id, route, one_min_ago
                )
            )

            # 3) Count route-specific monthly usage
            route_monthly_used = await self.providers.database.limits_handler._count_monthly_requests(
                user_id, route
            )

            usage["routes"][route] = {
                "route_per_min": {
                    "used": route_per_min_used,
                    "limit": route_effective.route_per_min,
                    "remaining": (
                        route_effective.route_per_min - route_per_min_used
                        if route_effective.route_per_min is not None
                        else None
                    ),
                },
                "monthly_limit": {
                    "used": route_monthly_used,
                    "limit": route_effective.monthly_limit,
                    "remaining": (
                        route_effective.monthly_limit - route_monthly_used
                        if route_effective.monthly_limit is not None
                        else None
                    ),
                },
            }

        # Storage usage: limit=1 fetches are only used for "total_entries".
        max_documents = await self.get_user_max_documents(user_id)
        used_documents = (
            await self.providers.database.documents_handler.get_documents_overview(
                limit=1, offset=0, filter_user_ids=[user_id]
            )
        )["total_entries"]
        max_chunks = await self.get_user_max_chunks(user_id)
        used_chunks = (
            await self.providers.database.chunks_handler.list_chunks(
                limit=1, offset=0, filters={"owner_id": user_id}
            )
        )["total_entries"]

        max_collections = await self.get_user_max_collections(user_id)
        used_collections: int = (  # type: ignore
            await self.providers.database.collections_handler.get_collections_overview(
                limit=1,
                offset=0, filter_user_ids=[user_id]
            )
        )["total_entries"]

        storage_limits = {
            "chunks": {
                "limit": max_chunks,
                "used": used_chunks,
                "remaining": (
                    max_chunks - used_chunks
                    if max_chunks is not None
                    else None
                ),
            },
            "documents": {
                "limit": max_documents,
                "used": used_documents,
                "remaining": (
                    max_documents - used_documents
                    if max_documents is not None
                    else None
                ),
            },
            "collections": {
                "limit": max_collections,
                "used": used_collections,
                "remaining": (
                    max_collections - used_collections
                    if max_collections is not None
                    else None
                ),
            },
        }
        # 5) Return a structured response
        return {
            "storage_limits": storage_limits,
            "system_defaults": system_defaults,
            "user_overrides": user_overrides,
            "effective_limits": {
                "global_per_min": overall_effective.global_per_min,
                "route_per_min": overall_effective.route_per_min,
                "monthly_limit": overall_effective.monthly_limit,
            },
            "usage": usage,
        }
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py b/.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py
new file mode 100644
index 00000000..2ae4af31
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py
@@ -0,0 +1,2087 @@
import asyncio
import json
import logging
from copy import deepcopy
from datetime import datetime
from typing import Any, AsyncGenerator, Literal, Optional
from uuid import UUID

from fastapi import HTTPException

from core import (
    Citation,
    R2RRAGAgent,
    R2RStreamingRAGAgent,
    R2RStreamingResearchAgent,
    R2RXMLToolsRAGAgent,
    R2RXMLToolsResearchAgent,
    R2RXMLToolsStreamingRAGAgent,
    R2RXMLToolsStreamingResearchAgent,
)
from core.agent.research import R2RResearchAgent
from core.base import (
    AggregateSearchResult,
    ChunkSearchResult,
    DocumentResponse,
    GenerationConfig,
    GraphCommunityResult,
    GraphEntityResult,
    GraphRelationshipResult,
    GraphSearchResult,
    GraphSearchResultType,
    IngestionStatus,
    Message,
    R2RException,
    SearchSettings,
    WebSearchResult,
    format_search_results_for_llm,
)
from core.base.api.models import RAGResponse, User
from core.utils import (
    CitationTracker,
    SearchResultsCollector,
    SSEFormatter,
    dump_collector,
    dump_obj,
    extract_citations,
    find_new_citation_spans,
    num_tokens_from_messages,
)
from shared.api.models.management.responses import MessageResponse

from ..abstractions import R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()


class AgentFactory:
    """
    Factory class that creates appropriate agent instances based on mode,
    model type, and streaming preferences.
    """

    @staticmethod
    def create_agent(
        mode: Literal["rag", "research"],
        database_provider,
        llm_provider,
        config,  # : AgentConfig
        search_settings,  # : SearchSettings
        generation_config,  #: GenerationConfig
        app_config,  #: AppConfig
        knowledge_search_method,
        content_method,
        file_search_method,
        max_tool_context_length: int = 32_768,
        rag_tools: Optional[list[str]] = None,
        research_tools: Optional[list[str]] = None,
        tools: Optional[list[str]] = None,  # For backward compatibility
    ):
        """
        Creates and returns the appropriate agent based on provided parameters.

        Args:
            mode: Either "rag" or "research" to determine agent type
            database_provider: Provider for database operations
            llm_provider: Provider for LLM operations
            config: Agent configuration
            search_settings: Search settings for retrieval
            generation_config: Generation configuration with LLM parameters
            app_config: Application configuration
            knowledge_search_method: Method for knowledge search
            content_method: Method for content retrieval
            file_search_method: Method for file search
            max_tool_context_length: Maximum context length for tools
            rag_tools: Tools specifically for RAG mode
            research_tools: Tools specifically for Research mode
            tools: Deprecated backward compatibility parameter

        Returns:
            An appropriate agent instance
        """
        # Create a deep copy of the config to avoid modifying the original
        agent_config = deepcopy(config)

        # Handle tool specifications based on mode
        if mode == "rag":
            # For RAG mode, prioritize explicitly passed rag_tools, then tools, then config defaults
            if rag_tools:
                agent_config.rag_tools = rag_tools
            elif tools:  # Backward compatibility
                agent_config.rag_tools = tools
            # If neither was provided, the config's default rag_tools will be used
        elif mode == "research":
            # For Research mode, prioritize explicitly passed research_tools, then tools, then config defaults
            if research_tools:
                agent_config.research_tools = research_tools
            elif tools:  # Backward compatibility
                agent_config.research_tools = tools
            # If neither was provided, the config's default research_tools will be used

        # Determine if we need XML-based tools based on model
        # NOTE(review): the model-based XML detection is disabled below, so
        # use_xml_format is currently always False and only the non-XML
        # branches of the dispatch are reachable.
        use_xml_format = False
        # if generation_config.model:
        #     model_str = generation_config.model.lower()
        #     use_xml_format = "deepseek" in model_str or "gemini" in model_str

        # Set streaming mode based on generation config
        is_streaming = generation_config.stream

        # Create the appropriate agent based on all factors.
        # Dispatch axes: mode (rag/research) x streaming x XML tooling.
        # Research agents additionally receive app_config; RAG agents do not.
        if mode == "rag":
            # RAG mode agents
            if is_streaming:
                if use_xml_format:
                    return R2RXMLToolsStreamingRAGAgent(
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )
                else:
                    return R2RStreamingRAGAgent(
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )
            else:
                if use_xml_format:
                    return R2RXMLToolsRAGAgent(
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )
                else:
                    return R2RRAGAgent(
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )
        else:
            # Research mode agents
            if is_streaming:
                if use_xml_format:
                    return R2RXMLToolsStreamingResearchAgent(
                        app_config=app_config,
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )
                else:
                    return R2RStreamingResearchAgent(
                        app_config=app_config,
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )
            else:
                if use_xml_format:
                    return R2RXMLToolsResearchAgent(
                        app_config=app_config,
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )
                else:
                    return R2RResearchAgent(
                        app_config=app_config,
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )


class RetrievalService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ):
        # Thin initializer: all state lives on the shared Service base.
        super().__init__(
            config,
            providers,
        )

    # NOTE(review): SearchSettings() as a default is evaluated once at
    # definition time; safe only if SearchSettings instances are never
    # mutated by callees — confirm.
    async def search(
        self,
        query: str,
        search_settings: SearchSettings = SearchSettings(),
        *args,
        **kwargs,
    ) -> AggregateSearchResult:
        """
        Depending on search_settings.search_strategy, fan out
        to basic, hyde, or rag_fusion method. Each returns
        an AggregateSearchResult that includes chunk + graph results.
+ """ + strategy = search_settings.search_strategy.lower() + + if strategy == "hyde": + return await self._hyde_search(query, search_settings) + elif strategy == "rag_fusion": + return await self._rag_fusion_search(query, search_settings) + else: + # 'vanilla', 'basic', or anything else... + return await self._basic_search(query, search_settings) + + async def _basic_search( + self, query: str, search_settings: SearchSettings + ) -> AggregateSearchResult: + """ + 1) Possibly embed the query (if semantic or hybrid). + 2) Chunk search. + 3) Graph search. + 4) Combine into an AggregateSearchResult. + """ + # -- 1) Possibly embed the query + query_vector = None + if ( + search_settings.use_semantic_search + or search_settings.use_hybrid_search + ): + query_vector = ( + await self.providers.completion_embedding.async_get_embedding( + query # , EmbeddingPurpose.QUERY + ) + ) + + # -- 2) Chunk search + chunk_results = [] + if search_settings.chunk_settings.enabled: + chunk_results = await self._vector_search_logic( + query_text=query, + search_settings=search_settings, + precomputed_vector=query_vector, # Pass in the vector we just computed (if any) + ) + + # -- 3) Graph search + graph_results = [] + if search_settings.graph_settings.enabled: + graph_results = await self._graph_search_logic( + query_text=query, + search_settings=search_settings, + precomputed_vector=query_vector, # same idea + ) + + # -- 4) Combine + return AggregateSearchResult( + chunk_search_results=chunk_results, + graph_search_results=graph_results, + ) + + async def _rag_fusion_search( + self, query: str, search_settings: SearchSettings + ) -> AggregateSearchResult: + """ + Implements 'RAG Fusion': + 1) Generate N sub-queries from the user query + 2) For each sub-query => do chunk & graph search + 3) Combine / fuse all retrieved results using Reciprocal Rank Fusion + 4) Return an AggregateSearchResult + """ + + # 1) Generate sub-queries from the user’s original query + # Typically you want the 
original query to remain in the set as well, + # so that we do not lose the exact user intent. + sub_queries = [query] + if search_settings.num_sub_queries > 1: + # Generate (num_sub_queries - 1) rephrasings + # (Or just generate exactly search_settings.num_sub_queries, + # and remove the first if you prefer.) + extra = await self._generate_similar_queries( + query=query, + num_sub_queries=search_settings.num_sub_queries - 1, + ) + sub_queries.extend(extra) + + # 2) For each sub-query => do chunk + graph search + # We’ll store them in a structure so we can fuse them. + # chunk_results_list is a list of lists of ChunkSearchResult + # graph_results_list is a list of lists of GraphSearchResult + chunk_results_list = [] + graph_results_list = [] + + for sq in sub_queries: + # Recompute or reuse the embedding if desired + # (You could do so, but not mandatory if you have a local approach) + # chunk + graph search + aggr = await self._basic_search(sq, search_settings) + chunk_results_list.append(aggr.chunk_search_results) + graph_results_list.append(aggr.graph_search_results) + + # 3) Fuse the chunk results and fuse the graph results. + # We'll use a simple RRF approach: each sub-query's result list + # is a ranking from best to worst. + fused_chunk_results = self._reciprocal_rank_fusion_chunks( # type: ignore + chunk_results_list # type: ignore + ) + filtered_graph_results = [ + results for results in graph_results_list if results is not None + ] + fused_graph_results = self._reciprocal_rank_fusion_graphs( + filtered_graph_results + ) + + # Optionally, after the RRF, you may want to do a final semantic re-rank + # of the fused results by the user’s original query. 
+ # E.g.: + if fused_chunk_results: + fused_chunk_results = ( + await self.providers.completion_embedding.arerank( + query=query, + results=fused_chunk_results, + limit=search_settings.limit, + ) + ) + + # Sort or slice the graph results if needed: + if fused_graph_results and search_settings.include_scores: + fused_graph_results.sort( + key=lambda g: g.score if g.score is not None else 0.0, + reverse=True, + ) + fused_graph_results = fused_graph_results[: search_settings.limit] + + # 4) Return final AggregateSearchResult + return AggregateSearchResult( + chunk_search_results=fused_chunk_results, + graph_search_results=fused_graph_results, + ) + + async def _generate_similar_queries( + self, query: str, num_sub_queries: int = 2 + ) -> list[str]: + """ + Use your LLM to produce 'similar' queries or rephrasings + that might retrieve different but relevant documents. + + You can prompt your model with something like: + "Given the user query, produce N alternative short queries that + capture possible interpretations or expansions. + Keep them relevant to the user's intent." + """ + if num_sub_queries < 1: + return [] + + # In production, you'd fetch a prompt from your prompts DB: + # Something like: + prompt = f""" + You are a helpful assistant. The user query is: "{query}" + Generate {num_sub_queries} alternative search queries that capture + slightly different phrasings or expansions while preserving the core meaning. + Return each alternative on its own line. 
+ """ + + # For a short generation, we can set minimal tokens + gen_config = GenerationConfig( + model=self.config.app.fast_llm, + max_tokens=128, + temperature=0.8, + stream=False, + ) + response = await self.providers.llm.aget_completion( + messages=[{"role": "system", "content": prompt}], + generation_config=gen_config, + ) + raw_text = ( + response.choices[0].message.content.strip() + if response.choices[0].message.content is not None + else "" + ) + + # Suppose each line is a sub-query + lines = [line.strip() for line in raw_text.split("\n") if line.strip()] + return lines[:num_sub_queries] + + def _reciprocal_rank_fusion_chunks( + self, list_of_rankings: list[list[ChunkSearchResult]], k: float = 60.0 + ) -> list[ChunkSearchResult]: + """ + Simple RRF for chunk results. + list_of_rankings is something like: + [ + [chunkA, chunkB, chunkC], # sub-query #1, in order + [chunkC, chunkD], # sub-query #2, in order + ... + ] + + We'll produce a dictionary mapping chunk.id -> aggregated_score, + then sort descending. + """ + if not list_of_rankings: + return [] + + # Build a map of chunk_id => final_rff_score + score_map: dict[str, float] = {} + + # We also need to store a reference to the chunk object + # (the "first" or "best" instance), so we can reconstruct them later + chunk_map: dict[str, Any] = {} + + for ranking_list in list_of_rankings: + for rank, chunk_result in enumerate(ranking_list, start=1): + if not chunk_result.id: + # fallback if no chunk_id is present + continue + + c_id = chunk_result.id + # RRF scoring + # score = sum(1 / (k + rank)) for each sub-query ranking + # We'll accumulate it. 
+ existing_score = score_map.get(str(c_id), 0.0) + new_score = existing_score + 1.0 / (k + rank) + score_map[str(c_id)] = new_score + + # Keep a reference to chunk + if c_id not in chunk_map: + chunk_map[str(c_id)] = chunk_result + + # Now sort by final score + fused_items = sorted( + score_map.items(), key=lambda x: x[1], reverse=True + ) + + # Rebuild the final list of chunk results with new 'score' + fused_chunks = [] + for c_id, agg_score in fused_items: # type: ignore + # copy the chunk + c = chunk_map[str(c_id)] + # Optionally store the RRF score if you want + c.score = agg_score + fused_chunks.append(c) + + return fused_chunks + + def _reciprocal_rank_fusion_graphs( + self, list_of_rankings: list[list[GraphSearchResult]], k: float = 60.0 + ) -> list[GraphSearchResult]: + """ + Similar RRF logic but for graph results. + """ + if not list_of_rankings: + return [] + + score_map: dict[str, float] = {} + graph_map = {} + + for ranking_list in list_of_rankings: + for rank, g_result in enumerate(ranking_list, start=1): + # We'll do a naive ID approach: + # If your GraphSearchResult has a unique ID in g_result.content.id or so + # we can use that as a key. + # If not, you might have to build a key from the content. 
                g_id = None
                if hasattr(g_result.content, "id"):
                    g_id = str(g_result.content.id)
                else:
                    # fallback: derive a stable-ish key from the serialized content
                    g_id = f"graph_{hash(g_result.content.json())}"

                existing_score = score_map.get(g_id, 0.0)
                new_score = existing_score + 1.0 / (k + rank)
                score_map[g_id] = new_score

                if g_id not in graph_map:
                    graph_map[g_id] = g_result

        # Sort descending by aggregated RRF score
        fused_items = sorted(
            score_map.items(), key=lambda x: x[1], reverse=True
        )

        fused_graphs = []
        for g_id, agg_score in fused_items:
            g = graph_map[g_id]
            g.score = agg_score
            fused_graphs.append(g)

        return fused_graphs

    async def _hyde_search(
        self, query: str, search_settings: SearchSettings
    ) -> AggregateSearchResult:
        """
        1) Generate N hypothetical docs via LLM
        2) For each doc => embed => parallel chunk search & graph search
        3) Merge chunk results => optional re-rank => top K
        4) Merge graph results => (optionally re-rank or keep them distinct)
        """
        # 1) Generate hypothetical docs
        hyde_docs = await self._run_hyde_generation(
            query=query, num_sub_queries=search_settings.num_sub_queries
        )

        chunk_all = []
        graph_all = []

        # We'll gather the per-doc searches in parallel
        tasks = []
        for hypothetical_text in hyde_docs:
            tasks.append(
                asyncio.create_task(
                    self._fanout_chunk_and_graph_search(
                        user_text=query,  # The user’s original query
                        alt_text=hypothetical_text,  # The hypothetical doc
                        search_settings=search_settings,
                    )
                )
            )

        # 2) Wait for them all
        results_list = await asyncio.gather(*tasks)
        # each item in results_list is a tuple: (chunks, graphs)

        # Flatten chunk+graph results
        for c_results, g_results in results_list:
            chunk_all.extend(c_results)
            graph_all.extend(g_results)

        # 3) Re-rank chunk results with the original query
        if chunk_all:
            chunk_all = await self.providers.completion_embedding.arerank(
                query=query,  # final user query
                results=chunk_all,
                limit=int(
                    search_settings.limit * search_settings.num_sub_queries
                ),
                # no limit on results - limit=search_settings.limit,
            )

        # 4) If needed, re-rank graph results or just slice top-K by score
        if search_settings.include_scores and graph_all:
            graph_all.sort(key=lambda g: g.score or 0.0, reverse=True)
            graph_all = (
                graph_all  # no limit on results - [: search_settings.limit]
            )

        return AggregateSearchResult(
            chunk_search_results=chunk_all,
            graph_search_results=graph_all,
        )

    async def _fanout_chunk_and_graph_search(
        self,
        user_text: str,
        alt_text: str,
        search_settings: SearchSettings,
    ) -> tuple[list[ChunkSearchResult], list[GraphSearchResult]]:
        """
        1) embed alt_text (HyDE doc or sub-query, etc.)
        2) chunk search + graph search with that embedding
        """
        # Precompute the embedding of alt_text
        vec = await self.providers.completion_embedding.async_get_embedding(
            alt_text  # , EmbeddingPurpose.QUERY
        )

        # chunk search
        chunk_results = []
        if search_settings.chunk_settings.enabled:
            chunk_results = await self._vector_search_logic(
                query_text=user_text,  # used for text-based stuff & re-ranking
                search_settings=search_settings,
                precomputed_vector=vec,  # use the alt_text vector for semantic/hybrid
            )

        # graph search
        graph_results = []
        if search_settings.graph_settings.enabled:
            graph_results = await self._graph_search_logic(
                query_text=user_text,  # or alt_text if you prefer
                search_settings=search_settings,
                precomputed_vector=vec,
            )

        return (chunk_results, graph_results)

    async def _vector_search_logic(
        self,
        query_text: str,
        search_settings: SearchSettings,
        precomputed_vector: Optional[list[float]] = None,
    ) -> list[ChunkSearchResult]:
        """
        • If precomputed_vector is given, use it for semantic/hybrid search.
          Otherwise embed query_text ourselves.
        • Then do fulltext, semantic, or hybrid search.
        • Optionally re-rank and return results.
        """
        if not search_settings.chunk_settings.enabled:
            return []

        # 1) Possibly embed
        query_vector = precomputed_vector
        if query_vector is None and (
            search_settings.use_semantic_search
            or search_settings.use_hybrid_search
        ):
            query_vector = (
                await self.providers.completion_embedding.async_get_embedding(
                    query_text  # , EmbeddingPurpose.QUERY
                )
            )

        # 2) Choose which search to run
        # hybrid = explicit hybrid flag OR both fulltext + semantic enabled
        if (
            search_settings.use_fulltext_search
            and search_settings.use_semantic_search
        ) or search_settings.use_hybrid_search:
            if query_vector is None:
                raise ValueError("Hybrid search requires a precomputed vector")
            raw_results = (
                await self.providers.database.chunks_handler.hybrid_search(
                    query_vector=query_vector,
                    query_text=query_text,
                    search_settings=search_settings,
                )
            )
        elif search_settings.use_fulltext_search:
            raw_results = (
                await self.providers.database.chunks_handler.full_text_search(
                    query_text=query_text,
                    search_settings=search_settings,
                )
            )
        elif search_settings.use_semantic_search:
            if query_vector is None:
                raise ValueError(
                    "Semantic search requires a precomputed vector"
                )
            raw_results = (
                await self.providers.database.chunks_handler.semantic_search(
                    query_vector=query_vector,
                    search_settings=search_settings,
                )
            )
        else:
            raise ValueError(
                "At least one of use_fulltext_search or use_semantic_search must be True"
            )

        # 3) Re-rank
        reranked = await self.providers.completion_embedding.arerank(
            query=query_text, results=raw_results, limit=search_settings.limit
        )

        # 4) Possibly augment text or metadata
        final_results = []
        for r in reranked:
            # Prefix the chunk text with its document title when available.
            if "title" in r.metadata and search_settings.include_metadatas:
                title = r.metadata["title"]
                r.text = f"Document Title: {title}\n\nText: {r.text}"
            r.metadata["associated_query"] = query_text
            final_results.append(r)

        return final_results

    async def _graph_search_logic(
        self,
        query_text: str,
        search_settings: SearchSettings,
        precomputed_vector: Optional[list[float]] = None,
    ) -> list[GraphSearchResult]:
        """
        Mirrors your previous GraphSearch approach:
        • if precomputed_vector is supplied, use that
        • otherwise embed query_text
        • search entities, relationships, communities
        • return results
        """
        results: list[GraphSearchResult] = []

        if not search_settings.graph_settings.enabled:
            return results

        # 1) Possibly embed
        query_embedding = precomputed_vector
        if query_embedding is None:
            query_embedding = (
                await self.providers.completion_embedding.async_get_embedding(
                    query_text
                )
            )

        base_limit = search_settings.limit
        graph_limits = search_settings.graph_settings.limits or {}

        # Entity search
        entity_limit = graph_limits.get("entities", base_limit)
        entity_cursor = self.providers.database.graphs_handler.graph_search(
            query_text,
            search_type="entities",
            limit=entity_limit,
            query_embedding=query_embedding,
            property_names=["name", "description", "id"],
            filters=search_settings.filters,
        )
        async for ent in entity_cursor:
            score = ent.get("similarity_score")
            metadata = ent.get("metadata", {})
            if isinstance(metadata, str):
                try:
                    metadata = json.loads(metadata)
                # NOTE(review): best-effort parse — on failure the raw
                # string metadata is kept and the error (and bound `e`)
                # is silently discarded; consider at least a debug log.
                except Exception as e:
                    pass

            results.append(
                GraphSearchResult(
                    id=ent.get("id", None),
                    content=GraphEntityResult(
                        name=ent.get("name", ""),
                        description=ent.get("description", ""),
                        id=ent.get("id", None),
                    ),
                    result_type=GraphSearchResultType.ENTITY,
                    score=score if search_settings.include_scores else None,
                    metadata=(
                        {
                            **(metadata or {}),
                            "associated_query": query_text,
                        }
                        if search_settings.include_metadatas
                        else {}
                    ),
                )
            )

        # Relationship search
        rel_limit = graph_limits.get("relationships", base_limit)
        rel_cursor = self.providers.database.graphs_handler.graph_search(
            query_text,
            search_type="relationships",
            limit=rel_limit,
            query_embedding=query_embedding,
            property_names=[
                "id",
                "subject",
                "predicate",
"object", + "description", + "subject_id", + "object_id", + ], + filters=search_settings.filters, + ) + async for rel in rel_cursor: + score = rel.get("similarity_score") + metadata = rel.get("metadata", {}) + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except Exception as e: + pass + + results.append( + GraphSearchResult( + id=ent.get("id", None), + content=GraphRelationshipResult( + id=rel.get("id", None), + subject=rel.get("subject", ""), + predicate=rel.get("predicate", ""), + object=rel.get("object", ""), + subject_id=rel.get("subject_id", None), + object_id=rel.get("object_id", None), + description=rel.get("description", ""), + ), + result_type=GraphSearchResultType.RELATIONSHIP, + score=score if search_settings.include_scores else None, + metadata=( + { + **(metadata or {}), + "associated_query": query_text, + } + if search_settings.include_metadatas + else {} + ), + ) + ) + + # Community search + comm_limit = graph_limits.get("communities", base_limit) + comm_cursor = self.providers.database.graphs_handler.graph_search( + query_text, + search_type="communities", + limit=comm_limit, + query_embedding=query_embedding, + property_names=[ + "id", + "name", + "summary", + ], + filters=search_settings.filters, + ) + async for comm in comm_cursor: + score = comm.get("similarity_score") + metadata = comm.get("metadata", {}) + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except Exception as e: + pass + + results.append( + GraphSearchResult( + id=ent.get("id", None), + content=GraphCommunityResult( + id=comm.get("id", None), + name=comm.get("name", ""), + summary=comm.get("summary", ""), + ), + result_type=GraphSearchResultType.COMMUNITY, + score=score if search_settings.include_scores else None, + metadata=( + { + **(metadata or {}), + "associated_query": query_text, + } + if search_settings.include_metadatas + else {} + ), + ) + ) + + return results + + async def _run_hyde_generation( + self, + query: str, + 
    async def _run_hyde_generation(
        self,
        query: str,
        num_sub_queries: int = 2,
    ) -> list[str]:
        """
        Calls the LLM with a 'HyDE' style prompt to produce multiple
        hypothetical documents/answers, one per line or separated by blank lines.

        Args:
            query: The user's original query.
            num_sub_queries: How many hypothetical documents to request.

        Returns:
            The non-empty, stripped chunks of the LLM output, split on
            blank lines. May contain fewer (or more) items than
            ``num_sub_queries`` — the LLM output format is not enforced.
        """
        # Retrieve the prompt template from your database or config:
        # e.g. your "hyde" prompt has placeholders: {message}, {num_outputs}
        hyde_template = (
            await self.providers.database.prompts_handler.get_cached_prompt(
                prompt_name="hyde",
                inputs={"message": query, "num_outputs": num_sub_queries},
            )
        )

        # Now call the LLM with that as the system or user prompt:
        completion_config = GenerationConfig(
            model=self.config.app.fast_llm,  # or whichever short/cheap model
            max_tokens=512,
            temperature=0.7,
            stream=False,
        )

        response = await self.providers.llm.aget_completion(
            messages=[{"role": "system", "content": hyde_template}],
            generation_config=completion_config,
        )

        # Suppose the LLM returns something like:
        #
        #   "Doc1. Some made up text.\n\nDoc2. Another made up text.\n\n"
        #
        # So we split by double-newline or some pattern:
        raw_text = response.choices[0].message.content
        return [
            chunk.strip()
            for chunk in (raw_text or "").split("\n\n")
            if chunk.strip()
        ]

    async def search_documents(
        self,
        query: str,
        settings: SearchSettings,
        query_embedding: Optional[list[float]] = None,
    ) -> list[DocumentResponse]:
        """Search over document summaries (not chunks).

        Embeds ``query`` if no ``query_embedding`` is supplied, then delegates
        to the documents handler.
        """
        if query_embedding is None:
            query_embedding = (
                await self.providers.completion_embedding.async_get_embedding(
                    query
                )
            )
        result = (
            await self.providers.database.documents_handler.search_documents(
                query_text=query,
                settings=settings,
                query_embedding=query_embedding,
            )
        )
        return result

    async def completion(
        self,
        # Annotation corrected from list[dict]: each element must expose
        # .to_dict(), i.e. be a Message.
        messages: list[Message],
        generation_config: GenerationConfig,
        *args,
        **kwargs,
    ):
        """Thin passthrough to the LLM provider's aget_completion."""
        return await self.providers.llm.aget_completion(
            [message.to_dict() for message in messages],
            generation_config,
            *args,
            **kwargs,
        )

    async def embedding(
        self,
        text: str,
    ):
        """Embed ``text`` with the completion-embedding provider."""
        return await self.providers.completion_embedding.async_get_embedding(
            text=text
        )
    async def rag(
        self,
        query: str,
        rag_generation_config: GenerationConfig,
        # NOTE(review): a shared mutable default instance — every call without
        # an explicit search_settings reuses (and below, mutates the filters
        # of) the same SearchSettings object. Confirm intended.
        search_settings: SearchSettings = SearchSettings(),
        system_prompt_name: str | None = None,
        task_prompt_name: str | None = None,
        include_web_search: bool = False,
        **kwargs,
    ) -> Any:
        """
        A single RAG method that can do EITHER a one-shot synchronous RAG or
        streaming SSE-based RAG, depending on rag_generation_config.stream.

        1) Perform aggregator search => context
        2) Build system+task prompts => messages
        3) If not streaming => normal LLM call => return RAGResponse
        4) If streaming => return an async generator of SSE lines

        Raises:
            HTTPException: 502 when the failure looks like an unreachable
                backend ("NoneType" in the message), 500 otherwise.
        """
        # 1) Possibly fix up any UUID filters in search_settings
        for f, val in list(search_settings.filters.items()):
            if isinstance(val, UUID):
                search_settings.filters[f] = str(val)

        try:
            # 2) Perform search => aggregated_results
            aggregated_results = await self.search(query, search_settings)
            # 3) Optionally add web search results if flag is enabled
            if include_web_search:
                web_results = await self._perform_web_search(query)
                # Merge web search results with existing aggregated results
                if web_results and web_results.web_search_results:
                    if not aggregated_results.web_search_results:
                        aggregated_results.web_search_results = (
                            web_results.web_search_results
                        )
                    else:
                        aggregated_results.web_search_results.extend(
                            web_results.web_search_results
                        )
            # 3) Build context from aggregator
            collector = SearchResultsCollector()
            collector.add_aggregate_result(aggregated_results)
            context_str = format_search_results_for_llm(
                aggregated_results, collector
            )

            # 4) Prepare system+task messages
            system_prompt_name = system_prompt_name or "system"
            task_prompt_name = task_prompt_name or "rag"
            task_prompt = kwargs.get("task_prompt")

            messages = await self.providers.database.prompts_handler.get_message_payload(
                system_prompt_name=system_prompt_name,
                task_prompt_name=task_prompt_name,
                task_inputs={"query": query, "context": context_str},
                task_prompt=task_prompt,
            )

            # 5) Check streaming vs. non-streaming
            if not rag_generation_config.stream:
                # ========== Non-Streaming Logic ==========
                response = await self.providers.llm.aget_completion(
                    messages=messages,
                    generation_config=rag_generation_config,
                )
                llm_text = response.choices[0].message.content

                # (a) Extract short-ID references from final text
                raw_sids = extract_citations(llm_text or "")

                # (b) Possibly prune large content out of metadata
                metadata = response.dict()
                if "choices" in metadata and len(metadata["choices"]) > 0:
                    metadata["choices"][0]["message"].pop("content", None)

                # (c) Build final RAGResponse
                rag_resp = RAGResponse(
                    generated_answer=llm_text or "",
                    search_results=aggregated_results,
                    citations=[
                        Citation(
                            id=f"{sid}",
                            object="citation",
                            payload=dump_obj(  # type: ignore
                                self._find_item_by_shortid(sid, collector)
                            ),
                        )
                        for sid in raw_sids
                    ],
                    metadata=metadata,
                    completion=llm_text or "",
                )
                return rag_resp

            else:
                # ========== Streaming SSE Logic ==========
                async def sse_generator() -> AsyncGenerator[str, None]:
                    # 1) Emit search results via SSEFormatter
                    async for line in SSEFormatter.yield_search_results_event(
                        aggregated_results
                    ):
                        yield line

                    # Initialize citation tracker to manage citation state
                    citation_tracker = CitationTracker()

                    # Store citation payloads by ID for reuse
                    citation_payloads = {}

                    partial_text_buffer = ""

                    # Begin streaming from the LLM
                    msg_stream = self.providers.llm.aget_completion_stream(
                        messages=messages,
                        generation_config=rag_generation_config,
                    )

                    try:
                        async for chunk in msg_stream:
                            delta = chunk.choices[0].delta
                            finish_reason = chunk.choices[0].finish_reason
                            # if delta.thinking:
                            # check if delta has `thinking` attribute

                            if hasattr(delta, "thinking") and delta.thinking:
                                # Emit SSE "thinking" event
                                async for (
                                    line
                                ) in SSEFormatter.yield_thinking_event(
                                    delta.thinking
                                ):
                                    yield line

                            if delta.content:
                                # (b) Emit SSE "message" event for this chunk of text
                                async for (
                                    line
                                ) in SSEFormatter.yield_message_event(
                                    delta.content
                                ):
                                    yield line

                                # Accumulate new text
                                partial_text_buffer += delta.content

                                # (a) Extract citations from updated buffer
                                # For each *new* short ID, emit an SSE "citation" event
                                # Find new citation spans in the accumulated text
                                new_citation_spans = find_new_citation_spans(
                                    partial_text_buffer, citation_tracker
                                )

                                # Process each new citation span
                                for cid, spans in new_citation_spans.items():
                                    for span in spans:
                                        # Check if this is the first time we've seen this citation ID
                                        is_new_citation = (
                                            citation_tracker.is_new_citation(
                                                cid
                                            )
                                        )

                                        # Get payload if it's a new citation
                                        payload = None
                                        if is_new_citation:
                                            source_obj = (
                                                self._find_item_by_shortid(
                                                    cid, collector
                                                )
                                            )
                                            if source_obj:
                                                # Store payload for reuse
                                                payload = dump_obj(source_obj)
                                                citation_payloads[cid] = (
                                                    payload
                                                )

                                        # Create citation event payload
                                        citation_data = {
                                            "id": cid,
                                            "object": "citation",
                                            "is_new": is_new_citation,
                                            "span": {
                                                "start": span[0],
                                                "end": span[1],
                                            },
                                        }

                                        # Only include full payload for new citations
                                        if is_new_citation and payload:
                                            citation_data["payload"] = payload

                                        # Emit the citation event
                                        async for (
                                            line
                                        ) in SSEFormatter.yield_citation_event(
                                            citation_data
                                        ):
                                            yield line

                            # If the LLM signals it’s done
                            if finish_reason == "stop":
                                # Prepare consolidated citations for final answer event
                                consolidated_citations = []
                                # Group citations by ID with all their spans
                                for (
                                    cid,
                                    spans,
                                ) in citation_tracker.get_all_spans().items():
                                    if cid in citation_payloads:
                                        consolidated_citations.append(
                                            {
                                                "id": cid,
                                                "object": "citation",
                                                "spans": [
                                                    {
                                                        "start": s[0],
                                                        "end": s[1],
                                                    }
                                                    for s in spans
                                                ],
                                                "payload": citation_payloads[
                                                    cid
                                                ],
                                            }
                                        )

                                # (c) Emit final answer + all collected citations
                                final_answer_evt = {
                                    "id": "msg_final",
                                    "object": "rag.final_answer",
                                    "generated_answer": partial_text_buffer,
                                    "citations": consolidated_citations,
                                }
                                async for (
                                    line
                                ) in SSEFormatter.yield_final_answer_event(
                                    final_answer_evt
                                ):
                                    yield line

                                # (d) Signal the end of the SSE stream
                                yield SSEFormatter.yield_done_event()
                                break

                    except Exception as e:
                        logger.error(f"Error streaming LLM in rag: {e}")
                        # Optionally yield an SSE "error" event or handle differently
                        raise

                return sse_generator()

        except Exception as e:
            logger.exception(f"Error in RAG pipeline: {e}")
            # Heuristic: "NoneType" in the message is treated as a backend
            # connectivity failure rather than an internal error.
            if "NoneType" in str(e):
                raise HTTPException(
                    status_code=502,
                    detail="Server not reachable or returned an invalid response",
                ) from e
            raise HTTPException(
                status_code=500,
                detail=f"Internal RAG Error - {str(e)}",
            ) from e

    def _find_item_by_shortid(
        self, sid: str, collector: SearchResultsCollector
        # Annotation corrected: despite the old tuple comments, this returns
        # the chunk's dict (via as_dict), the raw aggregator object, or None.
    ) -> Optional[Any]:
        """
        Example helper that tries to match aggregator items by short ID,
        meaning result_obj.id starts with sid.

        Returns the first match: ``as_dict()`` for chunk results, the raw
        result object otherwise, or None when nothing matches.
        """
        for source_type, result_obj in collector.get_all_results():
            # if the aggregator item has an 'id' attribute
            if getattr(result_obj, "id", None) is not None:
                full_id_str = str(result_obj.id)
                if full_id_str.startswith(sid):
                    if source_type == "chunk":
                        return (
                            result_obj.as_dict()
                        )  # (source_type, result_obj.as_dict())
                    else:
                        return result_obj  # (source_type, result_obj)
        return None
    async def agent(
        self,
        rag_generation_config: GenerationConfig,
        rag_tools: Optional[list[str]] = None,
        tools: Optional[list[str]] = None,  # backward compatibility
        # NOTE(review): shared mutable default SearchSettings instance —
        # filters are mutated below; confirm intended.
        search_settings: SearchSettings = SearchSettings(),
        task_prompt: Optional[str] = None,
        include_title_if_available: Optional[bool] = False,
        conversation_id: Optional[UUID] = None,
        message: Optional[Message] = None,
        messages: Optional[list[Message]] = None,
        use_system_context: bool = False,
        max_tool_context_length: int = 32_768,
        research_tools: Optional[list[str]] = None,
        research_generation_config: Optional[GenerationConfig] = None,
        needs_initial_conversation_name: Optional[bool] = None,
        mode: Optional[Literal["rag", "research"]] = "rag",
    ):
        """
        Engage with an intelligent agent for information retrieval, analysis, and research.

        Args:
            rag_generation_config: Configuration for RAG mode generation
            search_settings: Search configuration for retrieving context
            task_prompt: Optional custom prompt override
            include_title_if_available: Whether to include document titles
            conversation_id: Optional conversation ID for continuity
            message: Current message to process
            messages: List of messages (deprecated)
            use_system_context: Whether to use extended prompt
            max_tool_context_length: Maximum context length for tools
            rag_tools: List of tools for RAG mode
            research_tools: List of tools for Research mode
            research_generation_config: Configuration for Research mode generation
            mode: Either "rag" or "research"

        Returns:
            Agent response with messages and conversation ID
        """
        try:
            # Validate message inputs
            if message and messages:
                raise R2RException(
                    status_code=400,
                    message="Only one of message or messages should be provided",
                )

            if not message and not messages:
                raise R2RException(
                    status_code=400,
                    message="Either message or messages should be provided",
                )

            # Ensure 'message' is a Message instance
            if message and not isinstance(message, Message):
                if isinstance(message, dict):
                    message = Message.from_dict(message)
                else:
                    raise R2RException(
                        status_code=400,
                        message="""
                            Invalid message format. The expected format contains:
                                role: MessageType | 'system' | 'user' | 'assistant' | 'function'
                                content: Optional[str]
                                name: Optional[str]
                                function_call: Optional[dict[str, Any]]
                                tool_calls: Optional[list[dict[str, Any]]]
                                """,
                    )

            # Ensure 'messages' is a list of Message instances
            if messages:
                processed_messages = []
                for msg in messages:
                    if isinstance(msg, Message):
                        processed_messages.append(msg)
                    elif hasattr(msg, "dict"):
                        processed_messages.append(
                            Message.from_dict(msg.dict())
                        )
                    elif isinstance(msg, dict):
                        processed_messages.append(Message.from_dict(msg))
                    else:
                        processed_messages.append(Message.from_dict(str(msg)))
                messages = processed_messages
            else:
                messages = []

            # Validate and process mode-specific configurations
            if mode == "rag" and research_tools:
                logger.warning(
                    "research_tools provided but mode is 'rag'. These tools will be ignored."
                )
                research_tools = None

            # Determine effective generation config based on mode
            effective_generation_config = rag_generation_config
            if mode == "research" and research_generation_config:
                effective_generation_config = research_generation_config

            # Set appropriate LLM model based on mode if not explicitly specified
            if "model" not in effective_generation_config.__fields_set__:
                if mode == "rag":
                    effective_generation_config.model = (
                        self.config.app.quality_llm
                    )
                elif mode == "research":
                    effective_generation_config.model = (
                        self.config.app.planning_llm
                    )

            # Transform UUID filters to strings
            for filter_key, value in search_settings.filters.items():
                if isinstance(value, UUID):
                    search_settings.filters[filter_key] = str(value)

            # Process conversation data
            ids = []
            if conversation_id:  # Fetch the existing conversation
                try:
                    conversation_messages = await self.providers.database.conversations_handler.get_conversation(
                        conversation_id=conversation_id,
                    )
                    if needs_initial_conversation_name is None:
                        overview = await self.providers.database.conversations_handler.get_conversations_overview(
                            offset=0,
                            limit=1,
                            conversation_ids=[conversation_id],
                        )
                        if overview.get("total_entries", 0) > 0:
                            needs_initial_conversation_name = (
                                overview.get("results")[0].get("name") is None  # type: ignore
                            )
                except Exception as e:
                    logger.error(f"Error fetching conversation: {str(e)}")

                # NOTE(review): if get_conversation raised above,
                # conversation_messages is unbound here and this line raises
                # NameError (then surfaces as the generic 500 below) — should
                # likely be initialized to None before the try.
                if conversation_messages is not None:
                    messages_from_conversation: list[Message] = []
                    for message_response in conversation_messages:
                        if isinstance(message_response, MessageResponse):
                            messages_from_conversation.append(
                                message_response.message
                            )
                            ids.append(message_response.id)
                        else:
                            logger.warning(
                                f"Unexpected type in conversation found: {type(message_response)}\n{message_response}"
                            )
                    messages = messages_from_conversation + messages
            else:  # Create new conversation
                conversation_response = await self.providers.database.conversations_handler.create_conversation()
                conversation_id = conversation_response.id
                needs_initial_conversation_name = True

            if message:
                messages.append(message)

            if not messages:
                raise R2RException(
                    status_code=400,
                    message="No messages to process",
                )

            current_message = messages[-1]
            logger.debug(
                f"Running the agent with conversation_id = {conversation_id} and message = {current_message}"
            )

            # Save the new message to the conversation
            parent_id = ids[-1] if ids else None
            message_response = await self.providers.database.conversations_handler.add_message(
                conversation_id=conversation_id,
                content=current_message,
                parent_id=parent_id,
            )

            message_id = (
                message_response.id if message_response is not None else None
            )

            # Extract filter information from search settings
            filter_user_id, filter_collection_ids = (
                self._parse_user_and_collection_filters(
                    search_settings.filters
                )
            )

            # Validate system instruction configuration
            if use_system_context and task_prompt:
                raise R2RException(
                    status_code=400,
                    message="Both use_system_context and task_prompt cannot be True at the same time",
                )

            # Build the system instruction
            if task_prompt:
                system_instruction = task_prompt
            else:
                system_instruction = (
                    await self._build_aware_system_instruction(
                        max_tool_context_length=max_tool_context_length,
                        filter_user_id=filter_user_id,
                        filter_collection_ids=filter_collection_ids,
                        model=effective_generation_config.model,
                        use_system_context=use_system_context,
                        mode=mode,
                    )
                )

            # Configure agent with appropriate tools
            agent_config = deepcopy(self.config.agent)
            if mode == "rag":
                # Use provided RAG tools or default from config
                agent_config.rag_tools = (
                    rag_tools or tools or self.config.agent.rag_tools
                )
            else:  # research mode
                # Use provided Research tools or default from config
                agent_config.research_tools = (
                    research_tools or tools or self.config.agent.research_tools
                )

            # Create the agent using our factory
            mode = mode or "rag"

            for msg in messages:
                if msg.content is None:
                    msg.content = ""

            agent = AgentFactory.create_agent(
                mode=mode,
                database_provider=self.providers.database,
                llm_provider=self.providers.llm,
                config=agent_config,
                search_settings=search_settings,
                generation_config=effective_generation_config,
                app_config=self.config.app,
                knowledge_search_method=self.search,
                content_method=self.get_context,
                file_search_method=self.search_documents,
                max_tool_context_length=max_tool_context_length,
                rag_tools=rag_tools,
                research_tools=research_tools,
                tools=tools,  # Backward compatibility
            )

            # Handle streaming vs. non-streaming response
            if effective_generation_config.stream:

                async def stream_response():
                    try:
                        async for chunk in agent.arun(
                            messages=messages,
                            system_instruction=system_instruction,
                            include_title_if_available=include_title_if_available,
                        ):
                            yield chunk
                    except Exception as e:
                        logger.error(f"Error streaming agent output: {e}")
                        raise e
                    finally:
                        # Persist conversation data
                        msgs = [
                            msg.to_dict()
                            for msg in agent.conversation.messages
                        ]
                        input_tokens = num_tokens_from_messages(msgs[:-1])
                        output_tokens = num_tokens_from_messages([msgs[-1]])
                        await self.providers.database.conversations_handler.add_message(
                            conversation_id=conversation_id,
                            content=agent.conversation.messages[-1],
                            parent_id=message_id,
                            metadata={
                                "input_tokens": input_tokens,
                                "output_tokens": output_tokens,
                            },
                        )

                        # Generate conversation name if needed
                        if needs_initial_conversation_name:
                            try:
                                # NOTE(review): unlike the non-streaming path
                                # below, this uses message.to_dict() without a
                                # None guard; "mesasge" typo is in the live
                                # prompt string (left untouched here).
                                prompt = f"Generate a succinct name (3-6 words) for this conversation, given the first input mesasge here = {str(message.to_dict())}"
                                conversation_name = (
                                    (
                                        await self.providers.llm.aget_completion(
                                            [
                                                {
                                                    "role": "system",
                                                    "content": prompt,
                                                }
                                            ],
                                            GenerationConfig(
                                                model=self.config.app.fast_llm
                                            ),
                                        )
                                    )
                                    .choices[0]
                                    .message.content
                                )
                                await self.providers.database.conversations_handler.update_conversation(
                                    conversation_id=conversation_id,
                                    name=conversation_name,
                                )
                            except Exception as e:
                                logger.error(
                                    f"Error generating conversation name: {e}"
                                )

                return stream_response()
            else:
                for idx, msg in enumerate(messages):
                    if msg.content is None:
                        # NOTE(review): both branches assign "" — the
                        # structured_content check is currently a no-op.
                        if (
                            hasattr(msg, "structured_content")
                            and msg.structured_content
                        ):
                            messages[idx].content = ""
                        else:
                            messages[idx].content = ""

                # Non-streaming path
                results = await agent.arun(
                    messages=messages,
                    system_instruction=system_instruction,
                    include_title_if_available=include_title_if_available,
                )

                # Process the agent results
                if isinstance(results[-1], dict):
                    if results[-1].get("content") is None:
                        results[-1]["content"] = ""
                    assistant_message = Message(**results[-1])
                elif isinstance(results[-1], Message):
                    assistant_message = results[-1]
                    if assistant_message.content is None:
                        assistant_message.content = ""
                else:
                    assistant_message = Message(
                        role="assistant", content=str(results[-1])
                    )

                # Get search results collector for citations
                if hasattr(agent, "search_results_collector"):
                    collector = agent.search_results_collector
                else:
                    collector = SearchResultsCollector()

                # Extract content from the message
                structured_content = assistant_message.structured_content
                structured_content = (
                    structured_content[-1].get("text")
                    if structured_content
                    else None
                )
                raw_text = (
                    assistant_message.content or structured_content or ""
                )
                # Process citations
                short_ids = extract_citations(raw_text or "")
                final_citations = []
                for sid in short_ids:
                    obj = collector.find_by_short_id(sid)
                    final_citations.append(
                        {
                            "id": sid,
                            "object": "citation",
                            "payload": dump_obj(obj) if obj else None,
                        }
                    )

                # Persist in conversation DB
                await (
                    self.providers.database.conversations_handler.add_message(
                        conversation_id=conversation_id,
                        content=assistant_message,
                        parent_id=message_id,
                        metadata={
                            "citations": final_citations,
                            "aggregated_search_result": json.dumps(
                                dump_collector(collector)
                            ),
                        },
                    )
                )

                # Generate conversation name if needed
                if needs_initial_conversation_name:
                    conversation_name = None
                    try:
                        prompt = f"Generate a succinct name (3-6 words) for this conversation, given the first input mesasge here = {str(message.to_dict() if message else {})}"
                        conversation_name = (
                            (
                                await self.providers.llm.aget_completion(
                                    [{"role": "system", "content": prompt}],
                                    GenerationConfig(
                                        model=self.config.app.fast_llm
                                    ),
                                )
                            )
                            .choices[0]
                            .message.content
                        )
                    except Exception as e:
                        # Best-effort: naming failure falls through to the
                        # finally-block update with an empty name.
                        pass
                    finally:
                        await self.providers.database.conversations_handler.update_conversation(
                            conversation_id=conversation_id,
                            name=conversation_name or "",
                        )

                tool_calls = []
                if hasattr(agent, "tool_calls"):
                    if agent.tool_calls is not None:
                        tool_calls = agent.tool_calls
                    else:
                        logger.warning(
                            "agent.tool_calls is None, using empty list instead"
                        )
                # Return the final response
                return {
                    "messages": [
                        Message(
                            role="assistant",
                            content=assistant_message.content
                            or structured_content
                            or "",
                            metadata={
                                "citations": final_citations,
                                "tool_calls": tool_calls,
                                "aggregated_search_result": json.dumps(
                                    dump_collector(collector)
                                ),
                            },
                        )
                    ],
                    "conversation_id": str(conversation_id),
                }

        except Exception as e:
            logger.error(f"Error in agent response: {str(e)}")
            if "NoneType" in str(e):
                raise HTTPException(
                    status_code=502,
                    detail="Server not reachable or returned an invalid response",
                ) from e
            raise HTTPException(
                status_code=500,
                detail=f"Internal Server Error - {str(e)}",
            ) from e

    async def get_context(
        self,
        filters: dict[str, Any],
        options: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """
        Return an ordered list of documents (with minimal overview fields),
        plus all associated chunks in ascending chunk order.

        Only the filters: owner_id, collection_ids, and document_id
        are supported. If any other filter or operator is passed in,
        we raise an error.

        Args:
            filters: A dictionary describing the allowed filters
                     (owner_id, collection_ids, document_id).
            options: A dictionary with extra options, e.g. include_summary_embedding
                     or any custom flags for additional logic.

        Returns:
            A list of dicts, where each dict has:
            {
                "document": <DocumentResponse>,
                "chunks": [ <chunk0>, <chunk1>, ... ]
            }
        """
        # 2. Fetch matching documents
        matching_docs = await self.providers.database.documents_handler.get_documents_overview(
            offset=0,
            limit=-1,
            filters=filters,
            include_summary_embedding=options.get(
                "include_summary_embedding", False
            ),
        )

        if not matching_docs["results"]:
            return []

        # 3. For each document, fetch associated chunks in ascending chunk order
        results = []
        for doc_response in matching_docs["results"]:
            doc_id = doc_response.id
            chunk_data = await self.providers.database.chunks_handler.list_document_chunks(
                document_id=doc_id,
                offset=0,
                limit=-1,  # get all chunks
                include_vectors=False,
            )
            chunks = chunk_data["results"]  # already sorted by chunk_order
            doc_response.chunks = chunks
            # 4. Build a returned structure that includes doc + chunks
            results.append(doc_response.model_dump())

        return results
Fetch matching documents + matching_docs = await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=-1, + filters=filters, + include_summary_embedding=options.get( + "include_summary_embedding", False + ), + ) + + if not matching_docs["results"]: + return [] + + # 3. For each document, fetch associated chunks in ascending chunk order + results = [] + for doc_response in matching_docs["results"]: + doc_id = doc_response.id + chunk_data = await self.providers.database.chunks_handler.list_document_chunks( + document_id=doc_id, + offset=0, + limit=-1, # get all chunks + include_vectors=False, + ) + chunks = chunk_data["results"] # already sorted by chunk_order + doc_response.chunks = chunks + # 4. Build a returned structure that includes doc + chunks + results.append(doc_response.model_dump()) + + return results + + def _parse_user_and_collection_filters( + self, + filters: dict[str, Any], + ): + ### TODO - Come up with smarter way to extract owner / collection ids for non-admin + filter_starts_with_and = filters.get("$and") + filter_starts_with_or = filters.get("$or") + if filter_starts_with_and: + try: + filter_starts_with_and_then_or = filter_starts_with_and[0][ + "$or" + ] + + user_id = filter_starts_with_and_then_or[0]["owner_id"]["$eq"] + collection_ids = [ + UUID(ele) + for ele in filter_starts_with_and_then_or[1][ + "collection_ids" + ]["$overlap"] + ] + return user_id, [str(ele) for ele in collection_ids] + except Exception as e: + logger.error( + f"Error: {e}.\n\n While" + + """ parsing filters: expected format {'$or': [{'owner_id': {'$eq': 'uuid-string-here'}, 'collection_ids': {'$overlap': ['uuid-of-some-collection']}}]}, if you are a superuser then this error can be ignored.""" + ) + return None, [] + elif filter_starts_with_or: + try: + user_id = filter_starts_with_or[0]["owner_id"]["$eq"] + collection_ids = [ + UUID(ele) + for ele in filter_starts_with_or[1]["collection_ids"][ + "$overlap" + ] + ] + return user_id, 
[str(ele) for ele in collection_ids] + except Exception as e: + logger.error( + """Error parsing filters: expected format {'$or': [{'owner_id': {'$eq': 'uuid-string-here'}, 'collection_ids': {'$overlap': ['uuid-of-some-collection']}}]}, if you are a superuser then this error can be ignored.""" + ) + return None, [] + else: + # Admin user + return None, [] + + async def _build_documents_context( + self, + filter_user_id: Optional[UUID] = None, + max_summary_length: int = 128, + limit: int = 25, + reverse_order: bool = True, + ) -> str: + """ + Fetches documents matching the given filters and returns a formatted string + enumerating them. + """ + # We only want up to `limit` documents for brevity + docs_data = await self.providers.database.documents_handler.get_documents_overview( + offset=0, + limit=limit, + filter_user_ids=[filter_user_id] if filter_user_id else None, + include_summary_embedding=False, + sort_order="DESC" if reverse_order else "ASC", + ) + + found_max = False + if len(docs_data["results"]) == limit: + found_max = True + + docs = docs_data["results"] + if not docs: + return "No documents found." + + lines = [] + for i, doc in enumerate(docs, start=1): + if ( + not doc.summary + or doc.ingestion_status != IngestionStatus.SUCCESS + ): + lines.append( + f"[{i}] Title: {doc.title}, Summary: (Summary not available), Status:{doc.ingestion_status} ID: {doc.id}" + ) + continue + + # Build a line referencing the doc + title = doc.title or "(Untitled Document)" + lines.append( + f"[{i}] Title: {title}, Summary: {(doc.summary[0:max_summary_length] + ('...' if len(doc.summary) > max_summary_length else ''),)}, Total Tokens: {doc.total_tokens}, ID: {doc.id}" + ) + if found_max: + lines.append( + f"Note: Displaying only the first {limit} documents. Use a filter to narrow down the search if more documents are required." 
+ ) + + return "\n".join(lines) + + async def _build_aware_system_instruction( + self, + max_tool_context_length: int = 10_000, + filter_user_id: Optional[UUID] = None, + filter_collection_ids: Optional[list[UUID]] = None, + model: Optional[str] = None, + use_system_context: bool = False, + mode: Optional[str] = "rag", + ) -> str: + """ + High-level method that: + 1) builds the documents context + 2) builds the collections context + 3) loads the new `dynamic_reasoning_rag_agent` prompt + """ + date_str = str(datetime.now().strftime("%m/%d/%Y")) + + # "dynamic_rag_agent" // "static_rag_agent" + + if mode == "rag": + prompt_name = ( + self.config.agent.rag_agent_dynamic_prompt + if use_system_context + else self.config.agent.rag_rag_agent_static_prompt + ) + else: + prompt_name = "static_research_agent" + return await self.providers.database.prompts_handler.get_cached_prompt( + # We use custom tooling and a custom agent to handle gemini models + prompt_name, + inputs={ + "date": date_str, + }, + ) + + if model is not None and ("deepseek" in model): + prompt_name = f"{prompt_name}_xml_tooling" + + if use_system_context: + doc_context_str = await self._build_documents_context( + filter_user_id=filter_user_id, + ) + logger.debug(f"Loading prompt {prompt_name}") + # Now fetch the prompt from the database prompts handler + # This relies on your "rag_agent_extended" existing with + # placeholders: date, document_context + system_prompt = await self.providers.database.prompts_handler.get_cached_prompt( + # We use custom tooling and a custom agent to handle gemini models + prompt_name, + inputs={ + "date": date_str, + "max_tool_context_length": max_tool_context_length, + "document_context": doc_context_str, + }, + ) + else: + system_prompt = await self.providers.database.prompts_handler.get_cached_prompt( + prompt_name, + inputs={ + "date": date_str, + }, + ) + logger.debug(f"Running agent with system prompt = {system_prompt}") + return system_prompt + + async def 
_perform_web_search( + self, + query: str, + search_settings: SearchSettings = SearchSettings(), + ) -> AggregateSearchResult: + """ + Perform a web search using an external search engine API (Serper). + + Args: + query: The search query string + search_settings: Optional search settings to customize the search + + Returns: + AggregateSearchResult containing web search results + """ + try: + # Import the Serper client here to avoid circular imports + from core.utils.serper import SerperClient + + # Initialize the Serper client + serper_client = SerperClient() + + # Perform the raw search using Serper API + raw_results = serper_client.get_raw(query) + + # Process the raw results into a WebSearchResult object + web_response = WebSearchResult.from_serper_results(raw_results) + + # Create an AggregateSearchResult with the web search results + agg_result = AggregateSearchResult( + chunk_search_results=None, + graph_search_results=None, + web_search_results=web_response.organic_results, + ) + + # Log the search for monitoring purposes + logger.debug(f"Web search completed for query: {query}") + logger.debug( + f"Found {len(web_response.organic_results)} web results" + ) + + return agg_result + + except Exception as e: + logger.error(f"Error performing web search: {str(e)}") + # Return empty results rather than failing completely + return AggregateSearchResult( + chunk_search_results=None, + graph_search_results=None, + web_search_results=[], + ) + + +class RetrievalServiceAdapter: + @staticmethod + def _parse_user_data(user_data): + if isinstance(user_data, str): + try: + user_data = json.loads(user_data) + except json.JSONDecodeError as e: + raise ValueError( + f"Invalid user data format: {user_data}" + ) from e + return User.from_dict(user_data) + + @staticmethod + def prepare_search_input( + query: str, + search_settings: SearchSettings, + user: User, + ) -> dict: + return { + "query": query, + "search_settings": search_settings.to_dict(), + "user": user.to_dict(), + 
class RetrievalServiceAdapter:
    """Static helpers that (de)serialize retrieval-service call arguments.

    The ``prepare_*`` methods flatten rich argument objects into plain,
    JSON-serializable dicts; the matching ``parse_*`` methods rebuild the
    original objects on the other side of a process/queue boundary.
    """

    @staticmethod
    def _parse_user_data(user_data):
        """Decode *user_data* — a dict or its JSON string form — into a User."""
        if not isinstance(user_data, str):
            return User.from_dict(user_data)
        try:
            decoded = json.loads(user_data)
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Invalid user data format: {user_data}"
            ) from e
        return User.from_dict(decoded)

    @staticmethod
    def prepare_search_input(
        query: str,
        search_settings: SearchSettings,
        user: User,
    ) -> dict:
        """Flatten search-call arguments into a serializable payload."""
        payload: dict = {"query": query}
        payload["search_settings"] = search_settings.to_dict()
        payload["user"] = user.to_dict()
        return payload

    @staticmethod
    def parse_search_input(data: dict):
        """Rebuild the arguments produced by `prepare_search_input`."""
        query = data["query"]
        settings = SearchSettings.from_dict(data["search_settings"])
        user = RetrievalServiceAdapter._parse_user_data(data["user"])
        return {
            "query": query,
            "search_settings": settings,
            "user": user,
        }

    @staticmethod
    def prepare_rag_input(
        query: str,
        search_settings: SearchSettings,
        rag_generation_config: GenerationConfig,
        task_prompt: Optional[str],
        include_web_search: bool,
        user: User,
    ) -> dict:
        """Flatten RAG-call arguments into a serializable payload."""
        payload: dict = {"query": query}
        payload["search_settings"] = search_settings.to_dict()
        payload["rag_generation_config"] = rag_generation_config.to_dict()
        payload["task_prompt"] = task_prompt
        payload["include_web_search"] = include_web_search
        payload["user"] = user.to_dict()
        return payload

    @staticmethod
    def parse_rag_input(data: dict):
        """Rebuild the arguments produced by `prepare_rag_input`."""
        query = data["query"]
        settings = SearchSettings.from_dict(data["search_settings"])
        gen_config = GenerationConfig.from_dict(data["rag_generation_config"])
        user = RetrievalServiceAdapter._parse_user_data(data["user"])
        return {
            "query": query,
            "search_settings": settings,
            "rag_generation_config": gen_config,
            "task_prompt": data["task_prompt"],
            "include_web_search": data["include_web_search"],
            "user": user,
        }

    @staticmethod
    def prepare_agent_input(
        message: Message,
        search_settings: SearchSettings,
        rag_generation_config: GenerationConfig,
        task_prompt: Optional[str],
        include_title_if_available: bool,
        user: User,
        conversation_id: Optional[str] = None,
    ) -> dict:
        """Flatten agent-call arguments into a serializable payload."""
        payload: dict = {"message": message.to_dict()}
        payload["search_settings"] = search_settings.to_dict()
        payload["rag_generation_config"] = rag_generation_config.to_dict()
        payload["task_prompt"] = task_prompt
        payload["include_title_if_available"] = include_title_if_available
        payload["user"] = user.to_dict()
        payload["conversation_id"] = conversation_id
        return payload

    @staticmethod
    def parse_agent_input(data: dict):
        """Rebuild the arguments produced by `prepare_agent_input`."""
        message = Message.from_dict(data["message"])
        settings = SearchSettings.from_dict(data["search_settings"])
        gen_config = GenerationConfig.from_dict(data["rag_generation_config"])
        user = RetrievalServiceAdapter._parse_user_data(data["user"])
        return {
            "message": message,
            "search_settings": settings,
            "rag_generation_config": gen_config,
            "task_prompt": data["task_prompt"],
            "include_title_if_available": data["include_title_if_available"],
            "user": user,
            "conversation_id": data.get("conversation_id"),
        }