aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/main
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/main')
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/__init__.py24
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/abstractions.py82
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/api/v3/base_router.py151
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/api/v3/chunks_router.py422
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/api/v3/collections_router.py1207
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/api/v3/conversations_router.py737
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/api/v3/documents_router.py2342
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/api/v3/examples.py1065
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/api/v3/graph_router.py2051
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/api/v3/indices_router.py576
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/api/v3/prompts_router.py387
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/api/v3/retrieval_router.py639
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/api/v3/system_router.py186
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/api/v3/users_router.py1721
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/app.py121
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/app_entry.py125
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py12
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/assembly/builder.py127
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/assembly/factory.py417
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/config.py213
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/orchestration/__init__.py16
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py0
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py539
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py721
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py0
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py222
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py598
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/services/__init__.py14
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/services/auth_service.py316
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/services/base.py14
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/services/graph_service.py1358
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/services/ingestion_service.py983
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/services/management_service.py1084
-rw-r--r--.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py2087
34 files changed, 20557 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/main/__init__.py b/.venv/lib/python3.12/site-packages/core/main/__init__.py
new file mode 100644
index 00000000..7043d029
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/__init__.py
@@ -0,0 +1,24 @@
+from .abstractions import R2RProviders
+from .api import *
+from .app import *
+
+# from .app_entry import r2r_app
+from .assembly import *
+from .orchestration import *
+from .services import *
+
# Explicit public API of `core.main`; the star-imports above re-export
# these names so downstream code can import them from this package root.
__all__ = [
    # R2R Primary
    "R2RProviders",
    "R2RApp",
    "R2RBuilder",
    "R2RConfig",
    # Factory
    "R2RProviderFactory",
    ## R2R SERVICES
    "AuthService",
    "IngestionService",
    "ManagementService",
    "RetrievalService",
    "GraphService",
]
diff --git a/.venv/lib/python3.12/site-packages/core/main/abstractions.py b/.venv/lib/python3.12/site-packages/core/main/abstractions.py
new file mode 100644
index 00000000..3aaf2dbf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/abstractions.py
@@ -0,0 +1,82 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel
+
+from core.providers import (
+ AnthropicCompletionProvider,
+ AsyncSMTPEmailProvider,
+ ClerkAuthProvider,
+ ConsoleMockEmailProvider,
+ HatchetOrchestrationProvider,
+ JwtAuthProvider,
+ LiteLLMCompletionProvider,
+ LiteLLMEmbeddingProvider,
+ MailerSendEmailProvider,
+ OllamaEmbeddingProvider,
+ OpenAICompletionProvider,
+ OpenAIEmbeddingProvider,
+ PostgresDatabaseProvider,
+ R2RAuthProvider,
+ R2RCompletionProvider,
+ R2RIngestionProvider,
+ SendGridEmailProvider,
+ SimpleOrchestrationProvider,
+ SupabaseAuthProvider,
+ UnstructuredIngestionProvider,
+)
+
+if TYPE_CHECKING:
+ from core.main.services.auth_service import AuthService
+ from core.main.services.graph_service import GraphService
+ from core.main.services.ingestion_service import IngestionService
+ from core.main.services.management_service import ManagementService
+ from core.main.services.retrieval_service import ( # type: ignore
+ RetrievalService, # type: ignore
+ )
+
+
class R2RProviders(BaseModel):
    """Typed container bundling the concrete providers the app is wired with.

    Each field is a union of the interchangeable implementations for that
    capability; exactly one concrete provider instance is injected per field
    at assembly time (see ``assembly/factory.py``).
    """

    # Authentication backend: built-in R2R auth or an external identity provider.
    auth: (
        R2RAuthProvider
        | SupabaseAuthProvider
        | JwtAuthProvider
        | ClerkAuthProvider
    )
    # Postgres is the only supported database provider.
    database: PostgresDatabaseProvider
    # Document parsing/chunking pipeline.
    ingestion: R2RIngestionProvider | UnstructuredIngestionProvider
    # Embedding model used at ingestion time.
    embedding: (
        LiteLLMEmbeddingProvider
        | OpenAIEmbeddingProvider
        | OllamaEmbeddingProvider
    )
    # Embedding model used at query/completion time; may differ from `embedding`.
    completion_embedding: (
        LiteLLMEmbeddingProvider
        | OpenAIEmbeddingProvider
        | OllamaEmbeddingProvider
    )
    # Chat/completion LLM backend.
    llm: (
        AnthropicCompletionProvider
        | LiteLLMCompletionProvider
        | OpenAICompletionProvider
        | R2RCompletionProvider
    )
    # Workflow orchestration: Hatchet or the in-process simple runner.
    orchestration: HatchetOrchestrationProvider | SimpleOrchestrationProvider
    # Outbound email backend (SMTP, console mock, SendGrid, MailerSend).
    email: (
        AsyncSMTPEmailProvider
        | ConsoleMockEmailProvider
        | SendGridEmailProvider
        | MailerSendEmailProvider
    )

    class Config:
        # Provider classes are plain objects, not pydantic models, so they
        # must be allowed as field types.
        arbitrary_types_allowed = True
+
+
@dataclass
class R2RServices:
    """Aggregate of the application's service layer.

    Annotations are string forward references (resolved only under
    TYPE_CHECKING above) to avoid circular imports with the service modules.
    """

    auth: "AuthService"
    ingestion: "IngestionService"
    management: "ManagementService"
    retrieval: "RetrievalService"
    graph: "GraphService"
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/base_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/base_router.py
new file mode 100644
index 00000000..ef432420
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/base_router.py
@@ -0,0 +1,151 @@
+import functools
+import logging
+from abc import abstractmethod
+from typing import Callable
+
+from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi.responses import FileResponse, StreamingResponse
+
+from core.base import R2RException
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+
+logger = logging.getLogger()
+
+
class BaseRouterV3:
    """Abstract base class for v3 API routers.

    Owns the FastAPI ``APIRouter``, provides a standard response/error
    envelope (``base_endpoint``) and a per-request rate-limit dependency.
    Subclasses implement ``_setup_routes`` to register their endpoints.
    """

    def __init__(
        self, providers: R2RProviders, services: R2RServices, config: R2RConfig
    ):
        """
        :param providers: Typically includes auth, database, etc.
        :param services: Additional service references (ingestion, etc).
        """
        self.providers = providers
        self.services = services
        self.config = config
        self.router = APIRouter()
        self.openapi_extras = self._load_openapi_extras()

        # Add the rate-limiting dependency
        # (must run before _setup_routes, which references it in Depends()).
        self.set_rate_limiting()

        # Initialize any routes
        self._setup_routes()
        self._register_workflows()

    def get_router(self):
        """Return the underlying APIRouter for inclusion in the app."""
        return self.router

    def base_endpoint(self, func: Callable):
        """
        A decorator to wrap endpoints in a standard pattern:
        - error handling
        - response shaping
        """

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                func_result = await func(*args, **kwargs)
                # Endpoints may return (results, extra_kwargs) to inject
                # additional top-level fields (e.g. pagination totals).
                if isinstance(func_result, tuple) and len(func_result) == 2:
                    results, outer_kwargs = func_result
                else:
                    results, outer_kwargs = func_result, {}

                # Streaming/file responses bypass the JSON envelope.
                if isinstance(results, (StreamingResponse, FileResponse)):
                    return results
                return {"results": results, **outer_kwargs}

            except R2RException:
                # Domain exceptions carry their own status codes; re-raise as-is.
                raise
            except Exception as e:
                logger.error(
                    f"Error in base endpoint {func.__name__}() - {str(e)}",
                    exc_info=True,
                )
                raise HTTPException(
                    status_code=500,
                    detail={
                        "message": f"An error '{e}' occurred during {func.__name__}",
                        "error": str(e),
                        "error_type": type(e).__name__,
                    },
                ) from e

        # Marker attribute so tooling can detect wrapped endpoints.
        wrapper._is_base_endpoint = True # type: ignore
        return wrapper

    @classmethod
    def build_router(cls, engine):
        """Class method for building a router instance (if you have a standard
        pattern)."""
        # NOTE(review): cls(engine) passes a single positional argument, but
        # __init__ takes (providers, services, config) — this looks stale;
        # confirm against callers before relying on it.
        return cls(engine).router

    def _register_workflows(self):
        # Hook for subclasses that register orchestration workflows.
        pass

    def _load_openapi_extras(self):
        # Hook for subclasses to supply extra OpenAPI metadata.
        return {}

    @abstractmethod
    def _setup_routes(self):
        """Subclasses override this to define actual endpoints."""
        # NOTE(review): @abstractmethod is not enforced here because the
        # class does not use ABCMeta, so a subclass missing the override
        # will not fail at instantiation time.
        pass

    def set_rate_limiting(self):
        """Adds a yield-based dependency for rate limiting each request.

        Checks the limits, then logs the request if the check passes.
        """

        async def rate_limit_dependency(
            request: Request,
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ):
            """1) Fetch the user from the DB (including .limits_overrides).

            2) Pass it to limits_handler.check_limits. 3) After the endpoint
            completes, call limits_handler.log_request.
            """
            # If the user is superuser, skip checks
            if auth_user.is_superuser:
                yield
                return

            user_id = auth_user.id
            route = request.scope["path"]

            # 1) Fetch the user from DB
            user = await self.providers.database.users_handler.get_user_by_id(
                user_id
            )
            if not user:
                raise HTTPException(status_code=404, detail="User not found.")

            # 2) Rate-limit check
            try:
                await self.providers.database.limits_handler.check_limits(
                    user=user,
                    route=route, # Pass the User object
                )
            except ValueError as e:
                # If check_limits raises ValueError -> 429 Too Many Requests
                raise HTTPException(status_code=429, detail=str(e)) from e

            request.state.user_id = user_id
            request.state.route = route

            # 3) Execute the route
            try:
                yield
            finally:
                # 4) Log only POST and DELETE requests
                if request.method in ["POST", "DELETE"]:
                    await self.providers.database.limits_handler.log_request(
                        user_id, route
                    )

        # Attach the dependencies so you can use them in your endpoints
        self.rate_limit_dependency = rate_limit_dependency
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/chunks_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/chunks_router.py
new file mode 100644
index 00000000..ab0a62cb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/chunks_router.py
@@ -0,0 +1,422 @@
+import json
+import logging
+import textwrap
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Body, Depends, Path, Query
+
+from core.base import (
+ ChunkResponse,
+ GraphSearchSettings,
+ R2RException,
+ SearchSettings,
+ UpdateChunk,
+ select_search_filters,
+)
+from core.base.api.models import (
+ GenericBooleanResponse,
+ WrappedBooleanResponse,
+ WrappedChunkResponse,
+ WrappedChunksResponse,
+ WrappedVectorSearchResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
# NOTE(review): root logger; getLogger(__name__) would be conventional — confirm.
logger = logging.getLogger()

# Upper bound on chunks accepted in a single bulk request (100 * 1024).
# NOTE(review): not referenced in this module's visible code — presumably
# enforced by a bulk-create endpoint elsewhere; confirm before removing.
MAX_CHUNKS_PER_REQUEST = 1024 * 100
+
+
class ChunksRouter(BaseRouterV3):
    """v3 routes for chunk search, retrieval, update, deletion and listing.

    Fixes applied relative to the previous revision:
    - ``update_chunk`` now enforces ownership (mirrors ``retrieve_chunk``);
      previously any authenticated user could rewrite any chunk.
    - ``list_chunks`` now actually applies the parsed ``metadata_filter``;
      previously it was parsed and silently discarded.
    - ``delete_chunk`` no longer restricts superusers to their own chunks,
      matching its documented contract.
    """

    def __init__(
        self, providers: R2RProviders, services: R2RServices, config: R2RConfig
    ):
        logging.info("Initializing ChunksRouter")
        super().__init__(providers, services, config)

    def _setup_routes(self):
        @self.router.post(
            "/chunks/search",
            summary="Search Chunks",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            response = client.chunks.search(
                                query="search query",
                                search_settings={
                                    "limit": 10
                                }
                            )
                            """),
                    }
                ]
            },
        )
        @self.base_endpoint
        async def search_chunks(
            query: str = Body(...),
            search_settings: SearchSettings = Body(
                default_factory=SearchSettings,
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedVectorSearchResponse:  # type: ignore
            # TODO - Deduplicate this code by sharing the code on the retrieval router
            """Perform a semantic search query over all stored chunks.

            This endpoint allows for complex filtering of search results using PostgreSQL-based queries.
            Filters can be applied to various fields such as document_id, and internal metadata values.

            Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
            """

            # Inject access-control filters for the requesting user.
            search_settings.filters = select_search_filters(
                auth_user, search_settings
            )

            # Chunk search never consults the knowledge graph.
            search_settings.graph_settings = GraphSearchSettings(enabled=False)

            results = await self.services.retrieval.search(
                query=query,
                search_settings=search_settings,
            )
            return results.chunk_search_results  # type: ignore

        @self.router.get(
            "/chunks/{id}",
            summary="Retrieve Chunk",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            response = client.chunks.retrieve(
                                id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.chunks.retrieve({
                                    id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                                });
                            }

                            main();
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def retrieve_chunk(
            id: UUID = Path(...),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedChunkResponse:
            """Get a specific chunk by its ID.

            Returns the chunk's content, metadata, and associated
            document/collection information. Users can only retrieve chunks
            they own or have access to through collections.
            """
            chunk = await self.services.ingestion.get_chunk(id)
            if not chunk:
                raise R2RException("Chunk not found", 404)

            # TODO - Add collection ID check
            if not auth_user.is_superuser and str(auth_user.id) != str(
                chunk["owner_id"]
            ):
                raise R2RException("Not authorized to access this chunk", 403)

            return ChunkResponse(  # type: ignore
                id=chunk["id"],
                document_id=chunk["document_id"],
                owner_id=chunk["owner_id"],
                collection_ids=chunk["collection_ids"],
                text=chunk["text"],
                metadata=chunk["metadata"],
                # vector = chunk["vector"] # TODO - Add include vector flag
            )

        @self.router.post(
            "/chunks/{id}",
            summary="Update Chunk",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            response = client.chunks.update(
                                {
                                    "id": "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                                    "text": "Updated content",
                                    "metadata": {"key": "new value"}
                                }
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.chunks.update({
                                    id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                                    text: "Updated content",
                                    metadata: {key: "new value"}
                                });
                            }

                            main();
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def update_chunk(
            id: UUID = Path(...),
            chunk_update: UpdateChunk = Body(...),
            # TODO: Run with orchestration?
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedChunkResponse:
            """Update an existing chunk's content and/or metadata.

            The chunk's vectors will be automatically recomputed based on the
            new content. Users can only update chunks they own unless they are
            superusers.
            """
            # Get the existing chunk to get its chunk_id
            # NOTE: the body's `chunk_update.id` is authoritative; the path
            # `id` is accepted for URL shape but not cross-validated here.
            existing_chunk = await self.services.ingestion.get_chunk(
                chunk_update.id
            )
            if existing_chunk is None:
                raise R2RException(f"Chunk {chunk_update.id} not found", 404)

            # Ownership check (mirrors retrieve_chunk): only the owner or a
            # superuser may modify a chunk.
            if not auth_user.is_superuser and str(auth_user.id) != str(
                existing_chunk["owner_id"]
            ):
                raise R2RException("Not authorized to modify this chunk", 403)

            workflow_input = {
                "document_id": str(existing_chunk["document_id"]),
                "id": str(chunk_update.id),
                "text": chunk_update.text,
                "metadata": chunk_update.metadata
                or existing_chunk["metadata"],
                "user": auth_user.model_dump_json(),
            }

            logger.info("Running chunk ingestion without orchestration.")
            from core.main.orchestration import simple_ingestion_factory

            # TODO - CLEAN THIS UP

            simple_ingestor = simple_ingestion_factory(self.services.ingestion)
            await simple_ingestor["update-chunk"](workflow_input)

            return ChunkResponse(  # type: ignore
                id=chunk_update.id,
                document_id=existing_chunk["document_id"],
                owner_id=existing_chunk["owner_id"],
                collection_ids=existing_chunk["collection_ids"],
                text=chunk_update.text,
                metadata=chunk_update.metadata or existing_chunk["metadata"],
                # vector = existing_chunk.get('vector')
            )

        @self.router.delete(
            "/chunks/{id}",
            summary="Delete Chunk",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            response = client.chunks.delete(
                                id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.chunks.delete({
                                    id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                                });
                            }

                            main();
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def delete_chunk(
            id: UUID = Path(...),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedBooleanResponse:
            """Delete a specific chunk by ID.

            This permanently removes the chunk and its associated vector
            embeddings. The parent document remains unchanged. Users can only
            delete chunks they own unless they are superusers.
            """
            # Get the existing chunk to get its chunk_id
            existing_chunk = await self.services.ingestion.get_chunk(id)

            if existing_chunk is None:
                raise R2RException(
                    message=f"Chunk {id} not found", status_code=404
                )

            # Superusers may delete any chunk (per the documented contract);
            # everyone else is scoped to chunks they own.
            chunk_filter = {"chunk_id": {"$eq": str(id)}}
            if auth_user.is_superuser:
                filters: dict = chunk_filter
            else:
                filters = {
                    "$and": [
                        {"owner_id": {"$eq": str(auth_user.id)}},
                        chunk_filter,
                    ]
                }
            await (
                self.services.management.delete_documents_and_chunks_by_filter(
                    filters=filters
                )
            )
            return GenericBooleanResponse(success=True)  # type: ignore

        @self.router.get(
            "/chunks",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="List Chunks",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            response = client.chunks.list(
                                metadata_filter={"key": "value"},
                                include_vectors=False,
                                offset=0,
                                limit=10,
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.chunks.list({
                                    metadataFilter: {key: "value"},
                                    includeVectors: false,
                                    offset: 0,
                                    limit: 10,
                                });
                            }

                            main();
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def list_chunks(
            metadata_filter: Optional[str] = Query(
                None, description="Filter by metadata"
            ),
            include_vectors: bool = Query(
                False, description="Include vector data in response"
            ),
            offset: int = Query(
                0,
                ge=0,
                description="Specifies the number of objects to skip. Defaults to 0.",
            ),
            limit: int = Query(
                100,
                ge=1,
                le=1000,
                description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedChunksResponse:
            """List chunks with pagination support.

            Returns a paginated list of chunks that the user has access to.
            Results can be filtered and sorted based on various parameters.
            Vector embeddings are only included if specifically requested.

            Regular users can only list chunks they own or have access to
            through collections. Superusers can list all chunks in the system.
            """
            # Build filters
            filters = {}

            # Add user access control filter
            if not auth_user.is_superuser:
                filters["owner_id"] = {"$eq": str(auth_user.id)}

            # Add metadata filters if provided. (Previously the parsed JSON
            # was discarded; it is now merged into the effective filters.)
            if metadata_filter:
                try:
                    filters.update(json.loads(metadata_filter))
                except json.JSONDecodeError as e:
                    raise R2RException(
                        f"Invalid metadata_filter JSON: {e}", 400
                    ) from e

            # Get chunks using the vector handler's list_chunks method
            results = await self.services.ingestion.list_chunks(
                filters=filters,
                include_vectors=include_vectors,
                offset=offset,
                limit=limit,
            )

            # Convert to response format
            chunks = [
                ChunkResponse(
                    id=chunk["id"],
                    document_id=chunk["document_id"],
                    owner_id=chunk["owner_id"],
                    collection_ids=chunk["collection_ids"],
                    text=chunk["text"],
                    metadata=chunk["metadata"],
                    vector=chunk.get("vector") if include_vectors else None,
                )
                for chunk in results["results"]
            ]

            return (chunks, {"total_entries": results["total_entries"]})  # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/collections_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/collections_router.py
new file mode 100644
index 00000000..462f5ca3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/collections_router.py
@@ -0,0 +1,1207 @@
+import logging
+import textwrap
+from enum import Enum
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Body, Depends, Path, Query
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse
+
+from core.base import R2RException
+from core.base.abstractions import GraphCreationSettings
+from core.base.api.models import (
+ GenericBooleanResponse,
+ WrappedBooleanResponse,
+ WrappedCollectionResponse,
+ WrappedCollectionsResponse,
+ WrappedDocumentsResponse,
+ WrappedGenericMessageResponse,
+ WrappedUsersResponse,
+)
+from core.utils import (
+ generate_default_user_collection_id,
+ update_settings_from_dict,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+logger = logging.getLogger()
+
+
class CollectionAction(str, Enum):
    """Actions that can be authorized against a collection.

    Inherits from ``str`` so members compare and serialize as plain strings.
    """

    VIEW = "view"
    EDIT = "edit"
    DELETE = "delete"
    MANAGE_USERS = "manage_users"
    ADD_DOCUMENT = "add_document"
    REMOVE_DOCUMENT = "remove_document"
+
+
async def authorize_collection_action(
    auth_user, collection_id: UUID, action: CollectionAction, services
) -> bool:
    """Check whether *auth_user* may perform *action* on *collection_id*.

    Access rules:
      * superuser (admin): any action;
      * collection owner: any action;
      * collection member (id present in ``auth_user.collection_ids``): VIEW only;
      * anyone else: no access.

    Returns True when the action is permitted. Raises ``R2RException`` with
    404 when the collection does not exist, or 403 when access is denied.
    """
    # Admins bypass all per-collection checks.
    if auth_user.is_superuser:
        return True

    # Look up the collection so we can compare against its owner.
    overview = await services.management.collections_overview(
        0, 1, collection_ids=[collection_id]
    )
    matches = overview["results"]
    if not matches:
        raise R2RException("The specified collection does not exist.", 404)

    # The owner has unrestricted access to their own collection.
    if auth_user.id == matches[0].owner_id:
        return True

    # Non-owner members are read-only.
    if collection_id in auth_user.collection_ids:
        if action == CollectionAction.VIEW:
            return True
        raise R2RException("Insufficient permissions for this action.", 403)

    # Neither owner nor member: no access at all.
    raise R2RException("You do not have access to this collection.", 403)
+
+
+class CollectionsRouter(BaseRouterV3):
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ logging.info("Initializing CollectionsRouter")
+ super().__init__(providers, services, config)
+
+ def _setup_routes(self):
+ @self.router.post(
+ "/collections",
+ summary="Create a new collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.create(
+ name="My New Collection",
+ description="This is a sample collection"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.collections.create({
+ name: "My New Collection",
+ description: "This is a sample collection"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/collections" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{"name": "My New Collection", "description": "This is a sample collection"}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def create_collection(
+ name: str = Body(..., description="The name of the collection"),
+ description: Optional[str] = Body(
+ None, description="An optional description of the collection"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCollectionResponse:
+ """Create a new collection and automatically add the creating user
+ to it.
+
+ This endpoint allows authenticated users to create a new collection
+ with a specified name and optional description. The user creating
+ the collection is automatically added as a member.
+ """
+ user_collections_count = (
+ await self.services.management.collections_overview(
+ user_ids=[auth_user.id], limit=1, offset=0
+ )
+ )["total_entries"]
+ user_max_collections = (
+ await self.services.management.get_user_max_collections(
+ auth_user.id
+ )
+ )
+ if (user_collections_count + 1) >= user_max_collections: # type: ignore
+ raise R2RException(
+ f"User has reached the maximum number of collections allowed ({user_max_collections}).",
+ 400,
+ )
+ collection = await self.services.management.create_collection(
+ owner_id=auth_user.id,
+ name=name,
+ description=description,
+ )
+ # Add the creating user to the collection
+ await self.services.management.add_user_to_collection(
+ auth_user.id, collection.id
+ )
+ return collection # type: ignore
+
+ @self.router.post(
+ "/collections/export",
+ summary="Export collections to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.collections.export(
+ output_path="export.csv",
+ columns=["id", "name", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+ function main() {
+ await client.collections.export({
+ outputPath: "export.csv",
+ columns: ["id", "name", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/collections/export" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "name", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_collections(
+ background_tasks: BackgroundTasks,
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export collections as a CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_collections(
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="collections_export.csv",
+ )
+
+ @self.router.get(
+ "/collections",
+ summary="List collections",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.list(
+ offset=0,
+ limit=10,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.collections.list();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/collections?offset=0&limit=10&name=Sample" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def list_collections(
+ ids: list[str] = Query(
+ [],
+ description="A list of collection IDs to retrieve. If not provided, all collections will be returned.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCollectionsResponse:
+ """Returns a paginated list of collections the authenticated user
+ has access to.
+
+ Results can be filtered by providing specific collection IDs.
+ Regular users will only see collections they own or have access to.
+ Superusers can see all collections.
+
+ The collections are returned in order of last modification, with
+ most recent first.
+ """
+ requesting_user_id = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+
+ collection_uuids = [UUID(collection_id) for collection_id in ids]
+
+ collections_overview_response = (
+ await self.services.management.collections_overview(
+ user_ids=requesting_user_id,
+ collection_ids=collection_uuids,
+ offset=offset,
+ limit=limit,
+ )
+ )
+
+ return ( # type: ignore
+ collections_overview_response["results"],
+ {
+ "total_entries": collections_overview_response[
+ "total_entries"
+ ]
+ },
+ )
+
+ @self.router.get(
+ "/collections/{id}",
+ summary="Get collection details",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.retrieve("123e4567-e89b-12d3-a456-426614174000")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.collections.retrieve({id: "123e4567-e89b-12d3-a456-426614174000"});
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_collection(
+ id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCollectionResponse:
+ """Get details of a specific collection.
+
+ This endpoint retrieves detailed information about a single
+ collection identified by its UUID. The user must have access to the
+ collection to view its details.
+
+ Raises R2RException 404 if the collection does not exist.
+ """
+ # Access control first: raises if the user may not VIEW this collection.
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.VIEW, self.services
+ )
+
+ # user_ids=None is safe here because access was already authorized above;
+ # the overview query is reused with a single-ID filter.
+ collections_overview_response = (
+ await self.services.management.collections_overview(
+ user_ids=None,
+ collection_ids=[id],
+ offset=0,
+ limit=1,
+ )
+ )
+ overview = collections_overview_response["results"]
+
+ # The collection may have been removed between authorization and lookup.
+ if len(overview) == 0: # type: ignore
+ raise R2RException(
+ "The specified collection does not exist.",
+ 404,
+ )
+ return overview[0] # type: ignore
+
+ @self.router.post(
+ "/collections/{id}",
+ summary="Update collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.update(
+ "123e4567-e89b-12d3-a456-426614174000",
+ name="Updated Collection Name",
+ description="Updated description"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.collections.update({
+ id: "123e4567-e89b-12d3-a456-426614174000",
+ name: "Updated Collection Name",
+ description: "Updated description"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{"name": "Updated Collection Name", "description": "Updated description"}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def update_collection(
+ id: UUID = Path(
+ ...,
+ description="The unique identifier of the collection to update",
+ ),
+ name: Optional[str] = Body(
+ None, description="The name of the collection"
+ ),
+ description: Optional[str] = Body(
+ None, description="An optional description of the collection"
+ ),
+ generate_description: Optional[bool] = Body(
+ False,
+ description="Whether to generate a new synthetic description for the collection",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCollectionResponse:
+ """Update an existing collection's configuration.
+
+ This endpoint allows updating the name and description of an
+ existing collection. The user must have appropriate permissions to
+ modify the collection.
+
+ Raises R2RException 400 when both an explicit description and
+ synthetic generation are requested at once.
+ """
+ # EDIT permission is required before touching the collection.
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.EDIT, self.services
+ )
+
+ # A caller may either supply a description or ask for a generated
+ # one, never both — the two inputs are mutually exclusive.
+ if generate_description and description is not None:
+ raise R2RException(
+ "Cannot provide both a description and request to synthetically generate a new one.",
+ 400,
+ )
+
+ return await self.services.management.update_collection( # type: ignore
+ id,
+ name=name,
+ description=description,
+ generate_description=generate_description or False,
+ )
+
+ @self.router.delete(
+ "/collections/{id}",
+ summary="Delete collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.delete("123e4567-e89b-12d3-a456-426614174000")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.collections.delete({id: "123e4567-e89b-12d3-a456-426614174000"});
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_collection(
+ id: UUID = Path(
+ ...,
+ description="The unique identifier of the collection to delete",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Delete an existing collection.
+
+ This endpoint allows deletion of a collection identified by its
+ UUID. The user must have appropriate permissions to delete the
+ collection. Deleting a collection removes all associations but does
+ not delete the documents within it.
+
+ Raises R2RException 400 when targeting the user's default collection.
+ """
+ # Guard the default collection: its id is derived deterministically
+ # from the user's id, so it can be recognized without a lookup.
+ if id == generate_default_user_collection_id(auth_user.id):
+ raise R2RException(
+ "Cannot delete the default user collection.",
+ 400,
+ )
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.DELETE, self.services
+ )
+
+ await self.services.management.delete_collection(collection_id=id)
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.post(
+ "/collections/{id}/documents/{document_id}",
+ summary="Add document to collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.add_document(
+ "123e4567-e89b-12d3-a456-426614174000",
+ "456e789a-b12c-34d5-e678-901234567890"
+ )
+ """),
+ },
+ # NOTE(review): the JavaScript sample below is missing a comma
+ # between the "id" and "documentId" properties — invalid JS; confirm and fix.
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.collections.addDocument({
+ id: "123e4567-e89b-12d3-a456-426614174000"
+ documentId: "456e789a-b12c-34d5-e678-901234567890"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/documents/456e789a-b12c-34d5-e678-901234567890" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def add_document_to_collection(
+ id: UUID = Path(...),
+ document_id: UUID = Path(...),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Add a document to a collection.
+
+ Requires ADD_DOCUMENT permission on the collection; the service
+ performs the actual association.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.ADD_DOCUMENT, self.services
+ )
+
+ return (
+ await self.services.management.assign_document_to_collection(
+ document_id, id
+ )
+ )
+
+ @self.router.get(
+ "/collections/{id}/documents",
+ summary="List documents in collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.list_documents(
+ "123e4567-e89b-12d3-a456-426614174000",
+ offset=0,
+ limit=10,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.collections.listDocuments({id: "123e4567-e89b-12d3-a456-426614174000"});
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/documents?offset=0&limit=10" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_collection_documents(
+ id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedDocumentsResponse:
+ """Get all documents in a collection with pagination and sorting
+ options.
+
+ This endpoint retrieves a paginated list of documents associated
+ with a specific collection. It supports sorting options to
+ customize the order of returned documents.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.VIEW, self.services
+ )
+
+ documents_in_collection_response = (
+ await self.services.management.documents_in_collection(
+ id, offset, limit
+ )
+ )
+
+ return documents_in_collection_response["results"], { # type: ignore
+ "total_entries": documents_in_collection_response[
+ "total_entries"
+ ]
+ }
+
+ @self.router.delete(
+ "/collections/{id}/documents/{document_id}",
+ summary="Remove document from collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.remove_document(
+ "123e4567-e89b-12d3-a456-426614174000",
+ "456e789a-b12c-34d5-e678-901234567890"
+ )
+ """),
+ },
+ # NOTE(review): the JavaScript sample below is missing a comma
+ # between the "id" and "documentId" properties — invalid JS; confirm and fix.
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.collections.removeDocument({
+ id: "123e4567-e89b-12d3-a456-426614174000"
+ documentId: "456e789a-b12c-34d5-e678-901234567890"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/documents/456e789a-b12c-34d5-e678-901234567890" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def remove_document_from_collection(
+ id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ document_id: UUID = Path(
+ ...,
+ description="The unique identifier of the document to remove",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Remove a document from a collection.
+
+ This endpoint removes the association between a document and a
+ collection. It does not delete the document itself. The user must
+ have permissions to modify the collection.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.REMOVE_DOCUMENT, self.services
+ )
+ await self.services.management.remove_document_from_collection(
+ document_id, id
+ )
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.get(
+ "/collections/{id}/users",
+ summary="List users in collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.list_users(
+ "123e4567-e89b-12d3-a456-426614174000",
+ offset=0,
+ limit=10,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.collections.listUsers({
+ id: "123e4567-e89b-12d3-a456-426614174000"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/users?offset=0&limit=10" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_collection_users(
+ id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedUsersResponse:
+ """Get all users in a collection with pagination and sorting
+ options.
+
+ This endpoint retrieves a paginated list of users who have access
+ to a specific collection. It supports sorting options to customize
+ the order of returned users.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.VIEW, self.services
+ )
+
+ users_in_collection_response = (
+ await self.services.management.get_users_in_collection(
+ collection_id=id,
+ offset=offset,
+ limit=min(max(limit, 1), 1000),
+ )
+ )
+
+ return users_in_collection_response["results"], { # type: ignore
+ "total_entries": users_in_collection_response["total_entries"]
+ }
+
+ @self.router.post(
+ "/collections/{id}/users/{user_id}",
+ summary="Add user to collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.add_user(
+ "123e4567-e89b-12d3-a456-426614174000",
+ "789a012b-c34d-5e6f-g789-012345678901"
+ )
+ """),
+ },
+ # NOTE(review): the JavaScript sample below is missing a comma
+ # between the "id" and "userId" properties — invalid JS; confirm and fix.
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.collections.addUser({
+ id: "123e4567-e89b-12d3-a456-426614174000"
+ userId: "789a012b-c34d-5e6f-g789-012345678901"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/users/789a012b-c34d-5e6f-g789-012345678901" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def add_user_to_collection(
+ id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ user_id: UUID = Path(
+ ..., description="The unique identifier of the user to add"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Add a user to a collection.
+
+ This endpoint grants a user access to a specific collection. The
+ authenticated user must have admin permissions for the collection
+ to add new users.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.MANAGE_USERS, self.services
+ )
+
+ # The service returns a boolean which is propagated to the caller.
+ result = await self.services.management.add_user_to_collection(
+ user_id, id
+ )
+ return GenericBooleanResponse(success=result) # type: ignore
+
+ @self.router.delete(
+ "/collections/{id}/users/{user_id}",
+ summary="Remove user from collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.remove_user(
+ "123e4567-e89b-12d3-a456-426614174000",
+ "789a012b-c34d-5e6f-g789-012345678901"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.collections.removeUser({
+ id: "123e4567-e89b-12d3-a456-426614174000"
+ userId: "789a012b-c34d-5e6f-g789-012345678901"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/users/789a012b-c34d-5e6f-g789-012345678901" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def remove_user_from_collection(
+ id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ user_id: UUID = Path(
+ ..., description="The unique identifier of the user to remove"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Remove a user from a collection.
+
+ This endpoint revokes a user's access to a specific collection. The
+ authenticated user must have admin permissions for the collection
+ to remove users.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.MANAGE_USERS, self.services
+ )
+
+ result = (
+ await self.services.management.remove_user_from_collection(
+ user_id, id
+ )
+ )
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.post(
+ "/collections/{id}/extract",
+ summary="Extract entities and relationships",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.documents.extract(
+ id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1"
+ )
+ """),
+ },
+ ],
+ },
+ )
+ @self.base_endpoint
+ async def extract(
+ id: UUID = Path(
+ ...,
+ description="The ID of the document to extract entities and relationships from.",
+ ),
+ settings: Optional[GraphCreationSettings] = Body(
+ default=None,
+ description="Settings for the entities and relationships extraction process.",
+ ),
+ run_with_orchestration: Optional[bool] = Query(
+ default=True,
+ description="Whether to run the entities and relationships extraction process with orchestration.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Extracts entities and relationships from a document.
+
+ The entities and relationships extraction process involves:
+ 1. Parsing documents into semantic chunks
+ 2. Extracting entities and relationships using LLMs
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.EDIT, self.services
+ )
+
+ settings = settings.dict() if settings else None # type: ignore
+ if not auth_user.is_superuser:
+ logger.warning("Implement permission checks here.")
+
+ # Apply runtime settings overrides
+ server_graph_creation_settings = (
+ self.providers.database.config.graph_creation_settings
+ )
+
+ if settings:
+ server_graph_creation_settings = update_settings_from_dict(
+ server_settings=server_graph_creation_settings,
+ settings_dict=settings, # type: ignore
+ )
+ if run_with_orchestration:
+ try:
+ workflow_input = {
+ "collection_id": str(id),
+ "graph_creation_settings": server_graph_creation_settings.model_dump_json(),
+ "user": auth_user.json(),
+ }
+
+ return await self.providers.orchestration.run_workflow( # type: ignore
+ "graph-extraction", {"request": workflow_input}, {}
+ )
+ except Exception as e: # TODO: Need to find specific error (gRPC most likely?)
+ logger.error(
+ f"Error running orchestrated extraction: {e} \n\nAttempting to run without orchestration."
+ )
+
+ from core.main.orchestration import (
+ simple_graph_search_results_factory,
+ )
+
+ logger.info("Running extract-triples without orchestration.")
+ simple_graph_search_results = simple_graph_search_results_factory(
+ self.services.graph
+ )
+ await simple_graph_search_results["graph-extraction"](
+ workflow_input
+ ) # type: ignore
+ return { # type: ignore
+ "message": "Graph created successfully.",
+ "task_id": None,
+ }
+
+ @self.router.get(
+ "/collections/name/{collection_name}",
+ summary="Get a collection by name",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ )
+ @self.base_endpoint
+ async def get_collection_by_name(
+ collection_name: str = Path(
+ ..., description="The name of the collection"
+ ),
+ owner_id: Optional[UUID] = Query(
+ None,
+ description="(Superuser only) Specify the owner_id to retrieve a collection by name",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCollectionResponse:
+ """Retrieve a collection by its (owner_id, name) combination.
+
+ The authenticated user can only fetch collections they own, or, if
+ superuser, from anyone.
+ """
+ if auth_user.is_superuser:
+ if not owner_id:
+ owner_id = auth_user.id
+ else:
+ owner_id = auth_user.id
+
+ # If not superuser, fetch by (owner_id, name). Otherwise, maybe pass `owner_id=None`.
+ # Decide on the logic for superusers.
+ if not owner_id: # is_superuser
+ # If you want superusers to do /collections/name/<string>?owner_id=...
+ # just parse it from the query. For now, let's say it's not implemented.
+ raise R2RException(
+ "Superuser must specify an owner_id to fetch by name.", 400
+ )
+
+ collection = await self.providers.database.collections_handler.get_collection_by_name(
+ owner_id, collection_name
+ )
+ if not collection:
+ raise R2RException("Collection not found.", 404)
+
+ # Now, authorize the 'view' action just in case:
+ # e.g. await authorize_collection_action(auth_user, collection.id, CollectionAction.VIEW, self.services)
+
+ return collection # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/conversations_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/conversations_router.py
new file mode 100644
index 00000000..d1b6d645
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/conversations_router.py
@@ -0,0 +1,737 @@
+import logging
+import textwrap
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Body, Depends, Path, Query
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse
+
+from core.base import Message, R2RException
+from core.base.api.models import (
+ GenericBooleanResponse,
+ WrappedBooleanResponse,
+ WrappedConversationMessagesResponse,
+ WrappedConversationResponse,
+ WrappedConversationsResponse,
+ WrappedMessageResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+logger = logging.getLogger()
+
+
+class ConversationsRouter(BaseRouterV3):
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ logging.info("Initializing ConversationsRouter")
+ super().__init__(providers, services, config)
+
+ def _setup_routes(self):
+ @self.router.post(
+ "/conversations",
+ summary="Create a new conversation",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.conversations.create()
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.conversations.create();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/conversations" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def create_conversation(
+ name: Optional[str] = Body(
+ None, description="The name of the conversation", embed=True
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedConversationResponse:
+ """Create a new conversation.
+
+ This endpoint initializes a new conversation for the authenticated
+ user. The conversation is always owned by the requesting user.
+ """
+ user_id = auth_user.id
+
+ return await self.services.management.create_conversation( # type: ignore
+ user_id=user_id,
+ name=name,
+ )
+
+ @self.router.get(
+ "/conversations",
+ summary="List conversations",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.conversations.list(
+ offset=0,
+ limit=10,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.conversations.list();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/conversations?offset=0&limit=10" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def list_conversations(
+ ids: list[str] = Query(
+ [],
+ description="A list of conversation IDs to retrieve. If not provided, all conversations will be returned.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedConversationsResponse:
+ """List conversations with pagination and sorting options.
+
+ This endpoint returns a paginated list of conversations for the
+ authenticated user.
+ """
+ requesting_user_id = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+
+ conversation_uuids = [
+ UUID(conversation_id) for conversation_id in ids
+ ]
+
+ conversations_response = (
+ await self.services.management.conversations_overview(
+ offset=offset,
+ limit=limit,
+ conversation_ids=conversation_uuids,
+ user_ids=requesting_user_id,
+ )
+ )
+ return conversations_response["results"], { # type: ignore
+ "total_entries": conversations_response["total_entries"]
+ }
+
+ @self.router.post(
+ "/conversations/export",
+ summary="Export conversations to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.conversations.export(
+ output_path="export.csv",
+ columns=["id", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+ function main() {
+ await client.conversations.export({
+ outputPath: "export.csv",
+ columns: ["id", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/conversations/export" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_conversations(
+ background_tasks: BackgroundTasks,
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export conversations as a downloadable CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_conversations(
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="documents_export.csv",
+ )
+
+ @self.router.post(
+ "/conversations/export_messages",
+ summary="Export messages to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.conversations.export_messages(
+ output_path="export.csv",
+ columns=["id", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+ function main() {
+ await client.conversations.exportMessages({
+ outputPath: "export.csv",
+ columns: ["id", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/conversations/export_messages" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_messages(
+ background_tasks: BackgroundTasks,
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export conversations as a downloadable CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_messages(
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="documents_export.csv",
+ )
+
+ @self.router.get(
+ "/conversations/{id}",
+ summary="Get conversation details",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.conversations.get(
+ "123e4567-e89b-12d3-a456-426614174000"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.conversations.retrieve({
+ id: "123e4567-e89b-12d3-a456-426614174000",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_conversation(
+ id: UUID = Path(
+ ..., description="The unique identifier of the conversation"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedConversationMessagesResponse:
+ """Get details of a specific conversation.
+
+ This endpoint retrieves detailed information about a single
+ conversation identified by its UUID. Non-superusers can only
+ retrieve conversations they own (enforced via the user_ids filter).
+ """
+ # None means "no user filter" (superuser); otherwise scope to self.
+ requesting_user_id = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+
+ conversation = await self.services.management.get_conversation(
+ conversation_id=id,
+ user_ids=requesting_user_id,
+ )
+ return conversation # type: ignore
+
+ @self.router.post(
+ "/conversations/{id}",
+ summary="Update conversation",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.conversations.update("123e4567-e89b-12d3-a456-426614174000", "new_name")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.conversations.update({
+ id: "123e4567-e89b-12d3-a456-426614174000",
+ name: "new_name",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -d '{"name": "new_name"}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def update_conversation(
+ id: UUID = Path(
+ ...,
+ description="The unique identifier of the conversation to delete",
+ ),
+ name: str = Body(
+ ...,
+ description="The updated name for the conversation",
+ embed=True,
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedConversationResponse:
+ """Update an existing conversation.
+
+ This endpoint updates the name of an existing conversation
+ identified by its UUID.
+ """
+ return await self.services.management.update_conversation( # type: ignore
+ conversation_id=id,
+ name=name,
+ )
+
        @self.router.delete(
            "/conversations/{id}",
            summary="Delete conversation",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            result = client.conversations.delete("123e4567-e89b-12d3-a456-426614174000")
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.conversations.delete({
                                    id: "123e4567-e89b-12d3-a456-426614174000",
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X DELETE "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \\
                                -H "Authorization: Bearer YOUR_API_KEY"
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def delete_conversation(
            id: UUID = Path(
                ...,
                description="The unique identifier of the conversation to delete",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedBooleanResponse:
            """Delete an existing conversation.

            This endpoint deletes a conversation identified by its UUID.

            Superusers may delete any conversation; other users are scoped to
            conversations they own (`user_ids=None` disables the owner filter).
            """
            requesting_user_id = (
                None if auth_user.is_superuser else [auth_user.id]
            )

            # The service raises if the conversation does not exist or is not
            # visible to the requesting user; reaching the return implies success.
            await self.services.management.delete_conversation(
                conversation_id=id,
                user_ids=requesting_user_id,
            )
            return GenericBooleanResponse(success=True)  # type: ignore
+
        @self.router.post(
            "/conversations/{id}/messages",
            summary="Add message to conversation",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            result = client.conversations.add_message(
                                "123e4567-e89b-12d3-a456-426614174000",
                                content="Hello, world!",
                                role="user",
                                parent_id="parent_message_id",
                                metadata={"key": "value"}
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.conversations.addMessage({
                                    id: "123e4567-e89b-12d3-a456-426614174000",
                                    content: "Hello, world!",
                                    role: "user",
                                    parentId: "parent_message_id",
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000/messages" \\
                                -H "Authorization: Bearer YOUR_API_KEY" \\
                                -H "Content-Type: application/json" \\
                                -d '{"content": "Hello, world!", "parent_id": "parent_message_id", "metadata": {"key": "value"}}'
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def add_message(
            id: UUID = Path(
                ..., description="The unique identifier of the conversation"
            ),
            content: str = Body(
                ..., description="The content of the message to add"
            ),
            role: str = Body(
                ..., description="The role of the message to add"
            ),
            parent_id: Optional[UUID] = Body(
                None, description="The ID of the parent message, if any"
            ),
            metadata: Optional[dict[str, str]] = Body(
                None, description="Additional metadata for the message"
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedMessageResponse:
            """Add a new message to a conversation.

            This endpoint adds a new message to an existing conversation.

            Validation: content must be non-empty and role must be one of
            ``user``/``assistant``/``system``; both violations yield HTTP 400.
            Note the check rejects only the empty string — whitespace-only
            content passes.
            """
            if content == "":
                raise R2RException("Content cannot be empty", status_code=400)
            if role not in ["user", "assistant", "system"]:
                raise R2RException("Invalid role", status_code=400)
            # The Message object carries role/content; threading (parent_id)
            # and metadata are passed to the service separately.
            message = Message(role=role, content=content)
            return await self.services.management.add_message(  # type: ignore
                conversation_id=id,
                content=message,
                parent_id=parent_id,
                metadata=metadata,
            )
+
        @self.router.post(
            "/conversations/{id}/messages/{message_id}",
            summary="Update message in conversation",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            result = client.conversations.update_message(
                                "123e4567-e89b-12d3-a456-426614174000",
                                "message_id_to_update",
                                content="Updated content"
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.conversations.updateMessage({
                                    id: "123e4567-e89b-12d3-a456-426614174000",
                                    messageId: "message_id_to_update",
                                    content: "Updated content",
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000/messages/message_id_to_update" \\
                                -H "Authorization: Bearer YOUR_API_KEY" \\
                                -H "Content-Type: application/json" \\
                                -d '{"content": "Updated content"}'
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def update_message(
            id: UUID = Path(
                ..., description="The unique identifier of the conversation"
            ),
            message_id: UUID = Path(
                ..., description="The ID of the message to update"
            ),
            content: Optional[str] = Body(
                None, description="The new content for the message"
            ),
            metadata: Optional[dict[str, str]] = Body(
                None, description="Additional metadata for the message"
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedMessageResponse:
            """Update an existing message in a conversation.

            This endpoint updates the content of an existing message in a
            conversation.
            """
            # NOTE(review): the `{id}` path parameter is not forwarded, so the
            # service is not told which conversation the message should belong
            # to — confirm edit_message validates membership and ownership
            # (auth_user is only used for authentication here).
            return await self.services.management.edit_message(  # type: ignore
                message_id=message_id,
                new_content=content,
                additional_metadata=metadata,
            )
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/documents_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/documents_router.py
new file mode 100644
index 00000000..fe152b8b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/documents_router.py
@@ -0,0 +1,2342 @@
+import base64
+import logging
+import mimetypes
+import textwrap
+from datetime import datetime
+from io import BytesIO
+from typing import Any, Optional
+from urllib.parse import quote
+from uuid import UUID
+
+from fastapi import Body, Depends, File, Form, Path, Query, UploadFile
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse, StreamingResponse
+from pydantic import Json
+
+from core.base import (
+ IngestionConfig,
+ IngestionMode,
+ R2RException,
+ SearchMode,
+ SearchSettings,
+ UnprocessedChunk,
+ Workflow,
+ generate_document_id,
+ generate_id,
+ select_search_filters,
+)
+from core.base.abstractions import GraphCreationSettings, StoreType
+from core.base.api.models import (
+ GenericBooleanResponse,
+ WrappedBooleanResponse,
+ WrappedChunksResponse,
+ WrappedCollectionsResponse,
+ WrappedDocumentResponse,
+ WrappedDocumentSearchResponse,
+ WrappedDocumentsResponse,
+ WrappedEntitiesResponse,
+ WrappedGenericMessageResponse,
+ WrappedIngestionResponse,
+ WrappedRelationshipsResponse,
+)
+from core.utils import update_settings_from_dict
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+logger = logging.getLogger()
+MAX_CHUNKS_PER_REQUEST = 1024 * 100
+
+
def merge_search_settings(
    base: SearchSettings, overrides: SearchSettings
) -> SearchSettings:
    """Return a new SearchSettings combining *base* and *overrides*.

    Only fields explicitly set on *overrides* (``exclude_unset=True``) take
    precedence; every other field keeps its value from *base*. Neither input
    is mutated — a fresh model is constructed from the merged data.
    """
    merged = base.model_dump()
    merged.update(overrides.model_dump(exclude_unset=True))
    return SearchSettings(**merged)
+
+
def merge_ingestion_config(
    base: IngestionConfig, overrides: IngestionConfig
) -> IngestionConfig:
    """Overlay explicitly-set fields of *overrides* onto *base*.

    Mirrors :func:`merge_search_settings`: fields the caller actually set on
    *overrides* win; the rest come from *base*. Returns a new model, leaving
    both inputs untouched.
    """
    combined = base.model_dump() | overrides.model_dump(exclude_unset=True)
    return IngestionConfig(**combined)
+
+
+class DocumentsRouter(BaseRouterV3):
+ def __init__(
+ self,
+ providers: R2RProviders,
+ services: R2RServices,
+ config: R2RConfig,
+ ):
+ logging.info("Initializing DocumentsRouter")
+ super().__init__(providers, services, config)
+ self._register_workflows()
+
+ def _prepare_search_settings(
+ self,
+ auth_user: Any,
+ search_mode: SearchMode,
+ search_settings: Optional[SearchSettings],
+ ) -> SearchSettings:
+ """Prepare the effective search settings based on the provided
+ search_mode, optional user-overrides in search_settings, and applied
+ filters."""
+
+ if search_mode != SearchMode.custom:
+ # Start from mode defaults
+ effective_settings = SearchSettings.get_default(search_mode.value)
+ if search_settings:
+ # Merge user-provided overrides
+ effective_settings = merge_search_settings(
+ effective_settings, search_settings
+ )
+ else:
+ # Custom mode: use provided settings or defaults
+ effective_settings = search_settings or SearchSettings()
+
+ # Apply user-specific filters
+ effective_settings.filters = select_search_filters(
+ auth_user, effective_settings
+ )
+
+ return effective_settings
+
+ # TODO - Remove this legacy method
+ def _register_workflows(self):
+ self.providers.orchestration.register_workflows(
+ Workflow.INGESTION,
+ self.services.ingestion,
+ {
+ "ingest-files": (
+ "Ingest files task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Document created and ingested successfully."
+ ),
+ "ingest-chunks": (
+ "Ingest chunks task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Document created and ingested successfully."
+ ),
+ "update-chunk": (
+ "Update chunk task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Chunk update completed successfully."
+ ),
+ "update-document-metadata": (
+ "Update document metadata task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Document metadata update completed successfully."
+ ),
+ "create-vector-index": (
+ "Vector index creation task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Vector index creation task completed successfully."
+ ),
+ "delete-vector-index": (
+ "Vector index deletion task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Vector index deletion task completed successfully."
+ ),
+ "select-vector-index": (
+ "Vector index selection task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Vector index selection task completed successfully."
+ ),
+ },
+ )
+
+ def _prepare_ingestion_config(
+ self,
+ ingestion_mode: IngestionMode,
+ ingestion_config: Optional[IngestionConfig],
+ ) -> IngestionConfig:
+ # If not custom, start from defaults
+ if ingestion_mode != IngestionMode.custom:
+ effective_config = IngestionConfig.get_default(
+ ingestion_mode.value, app=self.providers.auth.config.app
+ )
+ if ingestion_config:
+ effective_config = merge_ingestion_config(
+ effective_config, ingestion_config
+ )
+ else:
+ # custom mode
+ effective_config = ingestion_config or IngestionConfig(
+ app=self.providers.auth.config.app
+ )
+
+ effective_config.validate_config()
+ return effective_config
+
+ def _setup_routes(self):
        @self.router.post(
            "/documents",
            dependencies=[Depends(self.rate_limit_dependency)],
            status_code=202,
            summary="Create a new document",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.documents.create(
                                file_path="pg_essay_1.html",
                                metadata={"metadata_1":"some random metadata"},
                                id=None
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.documents.create({
                                    file: { path: "examples/data/marmeladov.txt", name: "marmeladov.txt" },
                                    metadata: { title: "marmeladov.txt" },
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "https://api.example.com/v3/documents" \\
                                -H "Content-Type: multipart/form-data" \\
                                -H "Authorization: Bearer YOUR_API_KEY" \\
                                -F "file=@pg_essay_1.html;type=text/html" \\
                                -F 'metadata={}' \\
                                -F 'id=null'
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def create_document(
            file: Optional[UploadFile] = File(
                None,
                description="The file to ingest. Exactly one of file, raw_text, or chunks must be provided.",
            ),
            raw_text: Optional[str] = Form(
                None,
                description="Raw text content to ingest. Exactly one of file, raw_text, or chunks must be provided.",
            ),
            chunks: Optional[Json[list[str]]] = Form(
                None,
                description="Pre-processed text chunks to ingest. Exactly one of file, raw_text, or chunks must be provided.",
            ),
            id: Optional[UUID] = Form(
                None,
                description="The ID of the document. If not provided, a new ID will be generated.",
            ),
            collection_ids: Optional[Json[list[UUID]]] = Form(
                None,
                description="Collection IDs to associate with the document. If none are provided, the document will be assigned to the user's default collection.",
            ),
            metadata: Optional[Json[dict]] = Form(
                None,
                description="Metadata to associate with the document, such as title, description, or custom fields.",
            ),
            ingestion_mode: IngestionMode = Form(
                default=IngestionMode.custom,
                description=(
                    "Ingestion modes:\n"
                    "- `hi-res`: Thorough ingestion with full summaries and enrichment.\n"
                    "- `fast`: Quick ingestion with minimal enrichment and no summaries.\n"
                    "- `custom`: Full control via `ingestion_config`.\n\n"
                    "If `filters` or `limit` (in `ingestion_config`) are provided alongside `hi-res` or `fast`, "
                    "they will override the default settings for that mode."
                ),
            ),
            ingestion_config: Optional[Json[IngestionConfig]] = Form(
                None,
                description="An optional dictionary to override the default chunking configuration for the ingestion process. If not provided, the system will use the default server-side chunking configuration.",
            ),
            run_with_orchestration: Optional[bool] = Form(
                True,
                description="Whether or not ingestion runs with orchestration, default is `True`. When set to `False`, the ingestion process will run synchronous and directly return the result.",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedIngestionResponse:
            """
            Creates a new Document object from an input file, text content, or chunks. The chosen `ingestion_mode` determines
            how the ingestion process is configured:

            **Ingestion Modes:**
            - `hi-res`: Comprehensive parsing and enrichment, including summaries and possibly more thorough parsing.
            - `fast`: Speed-focused ingestion that skips certain enrichment steps like summaries.
            - `custom`: Provide a full `ingestion_config` to customize the entire ingestion process.

            Either a file or text content must be provided, but not both. Documents are shared through `Collections` which allow for tightly specified cross-user interactions.

            The ingestion process runs asynchronously and its progress can be tracked using the returned
            task_id.
            """
            # --- Per-user quota enforcement (superusers are exempt) ---------
            # Three independent limits are checked before any work happens:
            # document count, chunk count, and collection count. Each check
            # fetches only the total via a limit-1 overview query.
            if not auth_user.is_superuser:
                user_document_count = (
                    await self.services.management.documents_overview(
                        user_ids=[auth_user.id],
                        offset=0,
                        limit=1,
                    )
                )["total_entries"]
                user_max_documents = (
                    await self.services.management.get_user_max_documents(
                        auth_user.id
                    )
                )

                if user_document_count >= user_max_documents:
                    raise R2RException(
                        status_code=403,
                        message=f"User has reached the maximum number of documents allowed ({user_max_documents}).",
                    )

                # Get chunks using the vector handler's list_chunks method
                user_chunk_count = (
                    await self.services.ingestion.list_chunks(
                        filters={"owner_id": {"$eq": str(auth_user.id)}},
                        offset=0,
                        limit=1,
                    )
                )["total_entries"]
                user_max_chunks = (
                    await self.services.management.get_user_max_chunks(
                        auth_user.id
                    )
                )
                if user_chunk_count >= user_max_chunks:
                    raise R2RException(
                        status_code=403,
                        message=f"User has reached the maximum number of chunks allowed ({user_max_chunks}).",
                    )

                user_collections_count = (
                    await self.services.management.collections_overview(
                        user_ids=[auth_user.id],
                        offset=0,
                        limit=1,
                    )
                )["total_entries"]
                user_max_collections = (
                    await self.services.management.get_user_max_collections(
                        auth_user.id
                    )
                )
                if user_collections_count >= user_max_collections:  # type: ignore
                    raise R2RException(
                        status_code=403,
                        message=f"User has reached the maximum number of collections allowed ({user_max_collections}).",
                    )

            effective_ingestion_config = self._prepare_ingestion_config(
                ingestion_mode=ingestion_mode,
                ingestion_config=ingestion_config,
            )
            # Exactly one of file / raw_text / chunks must be supplied.
            if not file and not raw_text and not chunks:
                raise R2RException(
                    status_code=422,
                    message="Either a `file`, `raw_text`, or `chunks` must be provided.",
                )
            if (
                (file and raw_text)
                or (file and chunks)
                or (raw_text and chunks)
            ):
                raise R2RException(
                    status_code=422,
                    message="Only one of `file`, `raw_text`, or `chunks` may be provided.",
                )
            # Check if the user is a superuser
            metadata = metadata or {}

            if chunks:
                # --- Pre-chunked ingestion path -------------------------
                if len(chunks) == 0:
                    raise R2RException("Empty list of chunks provided", 400)

                if len(chunks) > MAX_CHUNKS_PER_REQUEST:
                    raise R2RException(
                        f"Maximum of {MAX_CHUNKS_PER_REQUEST} chunks per request",
                        400,
                    )

                # Document ID is deterministic over the concatenated chunk
                # text + owner unless the caller supplied one explicitly.
                document_id = id or generate_document_id(
                    "".join(chunks), auth_user.id
                )

                # FIXME: Metadata doesn't seem to be getting passed through
                raw_chunks_for_doc = [
                    UnprocessedChunk(
                        text=chunk,
                        metadata=metadata,
                        id=generate_id(),
                    )
                    for chunk in chunks
                ]

                # Prepare workflow input
                # NOTE(review): unlike the file/raw_text path below, this
                # payload does not include collection_ids — confirm intended.
                workflow_input = {
                    "document_id": str(document_id),
                    "chunks": [
                        chunk.model_dump(mode="json")
                        for chunk in raw_chunks_for_doc
                    ],
                    "metadata": metadata,  # Base metadata for the document
                    "user": auth_user.model_dump_json(),
                    "ingestion_config": effective_ingestion_config.model_dump(
                        mode="json"
                    ),
                }

                if run_with_orchestration:
                    try:
                        # Run ingestion with orchestration
                        raw_message = (
                            await self.providers.orchestration.run_workflow(
                                "ingest-chunks",
                                {"request": workflow_input},
                                options={
                                    "additional_metadata": {
                                        "document_id": str(document_id),
                                    }
                                },
                            )
                        )
                        raw_message["document_id"] = str(document_id)
                        return raw_message  # type: ignore
                    except Exception as e:  # TODO: Need to find specific errors that we should be excepting (gRPC most likely?)
                        # Orchestration failure falls through to the
                        # synchronous path below rather than failing the request.
                        logger.error(
                            f"Error running orchestrated ingestion: {e} \n\nAttempting to run without orchestration."
                        )

                logger.info("Running chunk ingestion without orchestration.")
                from core.main.orchestration import simple_ingestion_factory

                simple_ingestor = simple_ingestion_factory(
                    self.services.ingestion
                )
                await simple_ingestor["ingest-chunks"](workflow_input)

                return {  # type: ignore
                    "message": "Document created and ingested successfully.",
                    "document_id": str(document_id),
                    "task_id": None,
                }

            else:
                # --- File / raw-text ingestion path ---------------------
                if file:
                    file_data = await self._process_file(file)

                    if not file.filename:
                        raise R2RException(
                            status_code=422,
                            message="Uploaded file must have a filename.",
                        )

                    file_ext = file.filename.split(".")[
                        -1
                    ]  # e.g. "pdf", "txt"
                    # Upload size limits can vary per user and per extension.
                    max_allowed_size = await self.services.management.get_max_upload_size_by_type(
                        user_id=auth_user.id, file_type_or_ext=file_ext
                    )

                    content_length = file_data["content_length"]

                    if content_length > max_allowed_size:
                        raise R2RException(
                            status_code=413,  # HTTP 413: Payload Too Large
                            message=(
                                f"File size exceeds maximum of {max_allowed_size} bytes "
                                f"for extension '{file_ext}'."
                            ),
                        )

                    # _process_file returned base64 content; decode into an
                    # in-memory buffer and drop the large string from the dict
                    # before it is forwarded as workflow input.
                    file_content = BytesIO(
                        base64.b64decode(file_data["content"])
                    )

                    file_data.pop("content", None)
                    document_id = id or generate_document_id(
                        file_data["filename"], auth_user.id
                    )
                elif raw_text:
                    content_length = len(raw_text)
                    file_content = BytesIO(raw_text.encode("utf-8"))
                    document_id = id or generate_document_id(
                        raw_text, auth_user.id
                    )
                    # Synthesize a .txt filename from metadata title, if any.
                    title = metadata.get("title", None)
                    title = title + ".txt" if title else None
                    file_data = {
                        "filename": title or "N/A",
                        "content_type": "text/plain",
                    }
                else:
                    # Unreachable in practice: the presence checks above
                    # already guarantee one of file/raw_text/chunks.
                    raise R2RException(
                        status_code=422,
                        message="Either a file or content must be provided.",
                    )

                workflow_input = {
                    "file_data": file_data,
                    "document_id": str(document_id),
                    "collection_ids": (
                        [str(cid) for cid in collection_ids]
                        if collection_ids
                        else None
                    ),
                    "metadata": metadata,
                    "ingestion_config": effective_ingestion_config.model_dump(
                        mode="json"
                    ),
                    "user": auth_user.model_dump_json(),
                    "size_in_bytes": content_length,
                    "version": "v0",
                }

                # Persist the raw file bytes before kicking off ingestion so
                # workers can retrieve them by document_id.
                file_name = file_data["filename"]
                await self.providers.database.files_handler.store_file(
                    document_id,
                    file_name,
                    file_content,
                    file_data["content_type"],
                )

                await self.services.ingestion.ingest_file_ingress(
                    file_data=workflow_input["file_data"],
                    user=auth_user,
                    document_id=workflow_input["document_id"],
                    size_in_bytes=workflow_input["size_in_bytes"],
                    metadata=workflow_input["metadata"],
                    version=workflow_input["version"],
                )

                if run_with_orchestration:
                    try:
                        # TODO - Modify create_chunks so that we can add chunks to existing document

                        workflow_result: dict[
                            str, str | None
                        ] = await self.providers.orchestration.run_workflow(  # type: ignore
                            "ingest-files",
                            {"request": workflow_input},
                            options={
                                "additional_metadata": {
                                    "document_id": str(document_id),
                                }
                            },
                        )
                        workflow_result["document_id"] = str(document_id)
                        return workflow_result  # type: ignore
                    except Exception as e:  # TODO: Need to find specific error (gRPC most likely?)
                        # As above: degrade gracefully to the synchronous path.
                        logger.error(
                            f"Error running orchestrated ingestion: {e} \n\nAttempting to run without orchestration."
                        )
                logger.info(
                    f"Running ingestion without orchestration for file {file_name} and document_id {document_id}."
                )
                # TODO - Clean up implementation logic here to be more explicitly `synchronous`
                from core.main.orchestration import simple_ingestion_factory

                simple_ingestor = simple_ingestion_factory(self.services.ingestion)
                await simple_ingestor["ingest-files"](workflow_input)
                return {  # type: ignore
                    "message": "Document created and ingested successfully.",
                    "document_id": str(document_id),
                    "task_id": None,
                }
+
        @self.router.patch(
            "/documents/{id}/metadata",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Append metadata to a document",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.documents.append_metadata(
                                id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
                                metadata=[{"key": "new_key", "value": "new_value"}]
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.documents.appendMetadata({
                                    id: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
                                    metadata: [{ key: "new_key", value: "new_value" }],
                                });
                            }

                            main();
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def patch_metadata(
            id: UUID = Path(
                ...,
                description="The ID of the document to append metadata to.",
            ),
            metadata: list[dict] = Body(
                ...,
                description="Metadata to append to the document.",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedDocumentResponse:
            """Appends metadata to a document. This endpoint allows adding new metadata fields or updating existing ones."""
            # Superusers may patch any document; others only their own.
            request_user_ids = (
                None if auth_user.is_superuser else [auth_user.id]
            )

            # Existence + ownership check via a single-row overview lookup
            # scoped to the requesting user.
            documents_overview_response = (
                await self.services.management.documents_overview(
                    user_ids=request_user_ids,
                    document_ids=[id],
                    offset=0,
                    limit=1,
                )
            )
            results = documents_overview_response["results"]
            if len(results) == 0:
                raise R2RException("Document not found.", 404)

            # overwrite=False -> merge/append semantics (PUT replaces).
            return await self.services.management.update_document_metadata(
                document_id=id,
                metadata=metadata,
                overwrite=False,
            )
+
+ @self.router.put(
+ "/documents/{id}/metadata",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Replace metadata of a document",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.replace_metadata(
+ id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ metadata=[{"key": "new_key", "value": "new_value"}]
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.documents.replaceMetadata({
+ id: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ metadata: [{ key: "new_key", value: "new_value" }],
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def put_metadata(
+ id: UUID = Path(
+ ...,
+ description="The ID of the document to append metadata to.",
+ ),
+ metadata: list[dict] = Body(
+ ...,
+ description="Metadata to append to the document.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedDocumentResponse:
+ """Replaces metadata in a document. This endpoint allows overwriting existing metadata fields."""
+ request_user_ids = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+
+ documents_overview_response = (
+ await self.services.management.documents_overview(
+ user_ids=request_user_ids,
+ document_ids=[id],
+ offset=0,
+ limit=1,
+ )
+ )
+ results = documents_overview_response["results"]
+ if len(results) == 0:
+ raise R2RException("Document not found.", 404)
+
+ return await self.services.management.update_document_metadata(
+ document_id=id,
+ metadata=metadata,
+ overwrite=True,
+ )
+
        @self.router.post(
            "/documents/export",
            summary="Export documents to CSV",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient("http://localhost:7272")
                            # when using auth, do client.login(...)

                            response = client.documents.export(
                                output_path="export.csv",
                                columns=["id", "title", "created_at"],
                                include_header=True,
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient("http://localhost:7272");

                            function main() {
                                await client.documents.export({
                                    outputPath: "export.csv",
                                    columns: ["id", "title", "created_at"],
                                    includeHeader: true,
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "http://127.0.0.1:7272/v3/documents/export" \
                            -H "Authorization: Bearer YOUR_API_KEY" \
                            -H "Content-Type: application/json" \
                            -H "Accept: text/csv" \
                            -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
                            --output export.csv
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def export_documents(
            background_tasks: BackgroundTasks,
            columns: Optional[list[str]] = Body(
                None, description="Specific columns to export"
            ),
            filters: Optional[dict] = Body(
                None, description="Filters to apply to the export"
            ),
            include_header: Optional[bool] = Body(
                True, description="Whether to include column headers"
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> FileResponse:
            """Export documents as a downloadable CSV file.

            Superuser-only. The CSV is written to a temporary file which is
            closed (and thereby cleaned up) after the response is sent.
            """

            if not auth_user.is_superuser:
                raise R2RException(
                    "Only a superuser can export data.",
                    403,
                )

            (
                csv_file_path,
                temp_file,
            ) = await self.services.management.export_documents(
                columns=columns,
                filters=filters,
                # Body(...) allows an explicit null; coerce that back to True.
                include_header=include_header
                if include_header is not None
                else True,
            )

            # Defer closing the temp file until after FileResponse has
            # streamed it to the client.
            background_tasks.add_task(temp_file.close)

            return FileResponse(
                path=csv_file_path,
                media_type="text/csv",
                filename="documents_export.csv",
            )
+
        @self.router.get(
            "/documents/download_zip",
            dependencies=[Depends(self.rate_limit_dependency)],
            response_class=StreamingResponse,
            summary="Export multiple documents as zip",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            client.documents.download_zip(
                                document_ids=["uuid1", "uuid2"],
                                start_date="2024-01-01",
                                end_date="2024-12-31"
                            )
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X GET "https://api.example.com/v3/documents/download_zip?document_ids=uuid1,uuid2&start_date=2024-01-01&end_date=2024-12-31" \\
                                -H "Authorization: Bearer YOUR_API_KEY"
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def export_files(
            document_ids: Optional[list[UUID]] = Query(
                None,
                description="List of document IDs to include in the export. If not provided, all accessible documents will be included.",
            ),
            start_date: Optional[datetime] = Query(
                None,
                description="Filter documents created on or after this date.",
            ),
            end_date: Optional[datetime] = Query(
                None,
                description="Filter documents created before this date.",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> StreamingResponse:
            """Export multiple documents as a zip file. Documents can be
            filtered by IDs and/or date range.

            The endpoint allows downloading:
            - Specific documents by providing their IDs
            - Documents within a date range
            - All accessible documents if no filters are provided

            Files are streamed as a zip archive to handle potentially large downloads efficiently.
            """
            if not auth_user.is_superuser:
                # For non-superusers, verify access to requested documents
                if document_ids:
                    documents_overview = (
                        await self.services.management.documents_overview(
                            user_ids=[auth_user.id],
                            document_ids=document_ids,
                            offset=0,
                            limit=len(document_ids),
                        )
                    )
                    # Any missing row means at least one requested document is
                    # not visible to this user.
                    if len(documents_overview["results"]) != len(document_ids):
                        raise R2RException(
                            status_code=403,
                            message="You don't have access to one or more requested documents.",
                        )
                # Non-superusers cannot request an unbounded "everything" export.
                if not document_ids:
                    raise R2RException(
                        status_code=403,
                        message="Non-superusers must provide document IDs to export.",
                    )

            (
                zip_name,
                zip_content,
                zip_size,
            ) = await self.services.management.export_files(
                document_ids=document_ids,
                start_date=start_date,
                end_date=end_date,
            )
            # RFC 5987 filename*= encoding so non-ASCII names survive.
            encoded_filename = quote(zip_name)

            # NOTE(review): the archive is fully materialized in memory and
            # yielded in one piece — despite the docstring, this is not
            # incremental streaming; confirm acceptable for large exports.
            async def stream_file():
                yield zip_content.getvalue()

            return StreamingResponse(
                stream_file(),
                media_type="application/zip",
                headers={
                    "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
                    "Content-Length": str(zip_size),
                },
            )
+
@self.router.get(
    "/documents",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="List documents",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # when using auth, do client.login(...)

                    response = client.documents.list(
                        limit=10,
                        offset=0
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();

                    function main() {
                        const response = await client.documents.list({
                            limit: 10,
                            offset: 0,
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X GET "https://api.example.com/v3/documents" \\
                        -H "Authorization: Bearer YOUR_API_KEY"
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def get_documents(
    ids: list[str] = Query(
        [],
        description="A list of document IDs to retrieve. If not provided, all documents will be returned.",
    ),
    offset: int = Query(
        0,
        ge=0,
        description="Specifies the number of objects to skip. Defaults to 0.",
    ),
    limit: int = Query(
        100,
        ge=1,
        le=1000,
        description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
    ),
    include_summary_embeddings: bool = Query(
        False,
        description="Specifies whether or not to include embeddings of each document summary.",
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedDocumentsResponse:
    """Returns a paginated list of documents the authenticated user has
    access to.

    Results can be filtered by providing specific document IDs. Regular
    users will only see documents they own or have access to through
    collections. Superusers can see all documents.

    The documents are returned in order of last modification, with most
    recent first.
    """
    # Superusers see everything (None disables both filters); everyone
    # else is scoped to their own documents and collection memberships.
    requesting_user_id = (
        None if auth_user.is_superuser else [auth_user.id]
    )
    filter_collection_ids = (
        None if auth_user.is_superuser else auth_user.collection_ids
    )

    # Validate caller-supplied IDs up front so a malformed UUID surfaces
    # as a 422 instead of an unhandled ValueError (HTTP 500). Mirrors the
    # handling in the document download endpoint.
    try:
        document_uuids = [UUID(document_id) for document_id in ids]
    except ValueError:
        raise R2RException(
            status_code=422, message="Invalid document ID format."
        ) from None

    documents_overview_response = (
        await self.services.management.documents_overview(
            user_ids=requesting_user_id,
            collection_ids=filter_collection_ids,
            document_ids=document_uuids,
            offset=offset,
            limit=limit,
        )
    )
    # Strip bulky embedding vectors unless explicitly requested.
    if not include_summary_embeddings:
        for document in documents_overview_response["results"]:
            document.summary_embedding = None

    return (  # type: ignore
        documents_overview_response["results"],
        {
            "total_entries": documents_overview_response[
                "total_entries"
            ]
        },
    )
+
@self.router.get(
    "/documents/{id}",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="Retrieve a document",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # when using auth, do client.login(...)

                    response = client.documents.retrieve(
                        id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();

                    function main() {
                        const response = await client.documents.retrieve({
                            id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" \\
                        -H "Authorization: Bearer YOUR_API_KEY"
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def get_document(
    id: UUID = Path(
        ...,
        description="The ID of the document to retrieve.",
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedDocumentResponse:
    """Retrieves detailed information about a specific document by its
    ID.

    This endpoint returns the document's metadata, status, and system information. It does not
    return the document's content - use the `/documents/{id}/download` endpoint for that.

    Users can only retrieve documents they own or have access to through collections.
    Superusers can retrieve any document.
    """
    # Superusers bypass both the ownership and collection filters.
    request_user_ids = (
        None if auth_user.is_superuser else [auth_user.id]
    )
    filter_collection_ids = (
        None if auth_user.is_superuser else auth_user.collection_ids
    )

    documents_overview_response = await self.services.management.documents_overview(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
        user_ids=request_user_ids,
        collection_ids=filter_collection_ids,
        document_ids=[id],
        offset=0,
        limit=100,
    )
    # A single ID was queried, so an empty result means not found (or not
    # visible to this user); otherwise the first row is the document.
    results = documents_overview_response["results"]
    if len(results) == 0:
        raise R2RException("Document not found.", 404)

    return results[0]
+
@self.router.get(
    "/documents/{id}/chunks",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="List document chunks",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # when using auth, do client.login(...)

                    response = client.documents.list_chunks(
                        id="32b6a70f-a995-5c51-85d2-834f06283a1e"
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();

                    function main() {
                        const response = await client.documents.listChunks({
                            id: "32b6a70f-a995-5c51-85d2-834f06283a1e",
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/chunks" \\
                        -H "Authorization: Bearer YOUR_API_KEY"\
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def list_chunks(
    id: UUID = Path(
        ...,
        description="The ID of the document to retrieve chunks for.",
    ),
    offset: int = Query(
        0,
        ge=0,
        description="Specifies the number of objects to skip. Defaults to 0.",
    ),
    limit: int = Query(
        100,
        ge=1,
        le=1000,
        description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
    ),
    include_vectors: Optional[bool] = Query(
        False,
        description="Whether to include vector embeddings in the response.",
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedChunksResponse:
    """Retrieves the text chunks that were generated from a document
    during ingestion. Chunks represent semantic sections of the
    document and are used for retrieval and analysis.

    Users can only access chunks from documents they own or have access
    to through collections. Vector embeddings are only included if
    specifically requested.

    Results are returned in chunk sequence order, representing their
    position in the original document.
    """
    # Chunks are fetched before any access check is performed (see below).
    list_document_chunks = (
        await self.services.management.list_document_chunks(
            document_id=id,
            offset=offset,
            limit=limit,
            include_vectors=include_vectors or False,
        )
    )

    # NOTE(review): the 404 here fires before the authorization check,
    # so a caller without access can still distinguish "document has no
    # chunks" (404) from "document exists" (403) — confirm intended.
    if not list_document_chunks["results"]:
        raise R2RException(
            "No chunks found for the given document ID.", 404
        )

    # Ownership is inferred from the first chunk's owner_id.
    is_owner = str(
        list_document_chunks["results"][0].get("owner_id")
    ) == str(auth_user.id)
    # limit=-1 — presumably "no limit", so all containing collections are
    # fetched; verify against collections_overview's contract.
    document_collections = (
        await self.services.management.collections_overview(
            offset=0,
            limit=-1,
            document_ids=[id],
        )
    )

    # Access = owner, or membership in at least one collection that
    # contains the document.
    user_has_access = (
        is_owner
        or set(auth_user.collection_ids).intersection(
            {ele.id for ele in document_collections["results"]}  # type: ignore
        )
        != set()
    )

    if not user_has_access and not auth_user.is_superuser:
        raise R2RException(
            "Not authorized to access this document's chunks.", 403
        )

    return (  # type: ignore
        list_document_chunks["results"],
        {"total_entries": list_document_chunks["total_entries"]},
    )
+
@self.router.get(
    "/documents/{id}/download",
    dependencies=[Depends(self.rate_limit_dependency)],
    response_class=StreamingResponse,
    summary="Download document content",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # when using auth, do client.login(...)

                    response = client.documents.download(
                        id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();

                    function main() {
                        const response = await client.documents.download({
                            id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/download" \\
                        -H "Authorization: Bearer YOUR_API_KEY"
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def get_document_file(
    id: str = Path(..., description="Document ID"),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> StreamingResponse:
    """Downloads the original file content of a document.

    For uploaded files, returns the original file with its proper MIME
    type. For text-only documents, returns the content as plain text.

    Users can only download documents they own or have access to
    through collections.
    """
    # The path param is taken as a raw string and validated here so a
    # malformed UUID yields a clean 422 rather than a 500.
    try:
        document_uuid = UUID(id)
    except ValueError:
        raise R2RException(
            status_code=422, message="Invalid document ID format."
        ) from None

    # Retrieve the document's information
    documents_overview_response = (
        await self.services.management.documents_overview(
            user_ids=None,
            collection_ids=None,
            document_ids=[document_uuid],
            offset=0,
            limit=1,
        )
    )

    if not documents_overview_response["results"]:
        raise R2RException("Document not found.", 404)

    document = documents_overview_response["results"][0]

    is_owner = str(document.owner_id) == str(auth_user.id)

    # Owners and superusers skip the collection-membership check entirely.
    if not auth_user.is_superuser and not is_owner:
        document_collections = (
            await self.services.management.collections_overview(
                offset=0,
                limit=-1,
                document_ids=[document_uuid],
            )
        )

        document_collection_ids = {
            str(ele.id)
            for ele in document_collections["results"]  # type: ignore
        }

        user_collection_ids = {
            str(cid) for cid in auth_user.collection_ids
        }

        # Non-empty intersection == access via at least one shared collection.
        has_collection_access = user_collection_ids.intersection(
            document_collection_ids
        )

        if not has_collection_access:
            raise R2RException(
                "Not authorized to access this document.", 403
            )

    file_tuple = await self.services.management.download_file(
        document_uuid
    )
    if not file_tuple:
        raise R2RException(status_code=404, message="File not found.")

    file_name, file_content, file_size = file_tuple
    # RFC 5987 filename*= encoding for non-ASCII filenames.
    encoded_filename = quote(file_name)

    # Guess the MIME type from the filename; fall back to a generic
    # binary type when it cannot be determined.
    mime_type, _ = mimetypes.guess_type(file_name)
    if not mime_type:
        mime_type = "application/octet-stream"

    # Stream the file in 1MB chunks to keep memory bounded for large files.
    async def file_stream():
        chunk_size = 1024 * 1024  # 1MB
        while True:
            data = file_content.read(chunk_size)
            if not data:
                break
            yield data

    return StreamingResponse(
        file_stream(),
        media_type=mime_type,
        headers={
            "Content-Disposition": f"inline; filename*=UTF-8''{encoded_filename}",
            "Content-Length": str(file_size),
        },
    )
+
@self.router.delete(
    "/documents/by-filter",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="Delete documents by filter",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient
                    client = R2RClient()
                    # when using auth, do client.login(...)
                    response = client.documents.delete_by_filter(
                        filters={"document_type": {"$eq": "txt"}}
                    )
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X DELETE "https://api.example.com/v3/documents/by-filter?filters=%7B%22document_type%22%3A%7B%22%24eq%22%3A%22text%22%7D%2C%22created_at%22%3A%7B%22%24lt%22%3A%222023-01-01T00%3A00%3A00Z%22%7D%7D" \\
                        -H "Authorization: Bearer YOUR_API_KEY"
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def delete_document_by_filter(
    filters: Json[dict] = Body(
        ..., description="JSON-encoded filters"
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedBooleanResponse:
    """Delete documents based on provided filters.

    Allowed operators
    include: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`,
    `ilike`, `in`, and `nin`. Deletion requests are limited to a
    user's own documents.
    """

    # The caller's filters are unconditionally ANDed with an owner_id
    # constraint, so deletion can never reach another user's documents.
    # NOTE(review): this scoping also applies to superusers (unlike the
    # per-ID delete endpoint) — confirm that is intended.
    filters_dict = {
        "$and": [{"owner_id": {"$eq": str(auth_user.id)}}, filters]
    }
    await (
        self.services.management.delete_documents_and_chunks_by_filter(
            filters=filters_dict
        )
    )

    return GenericBooleanResponse(success=True)  # type: ignore
+
@self.router.delete(
    "/documents/{id}",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="Delete a document",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # when using auth, do client.login(...)

                    response = client.documents.delete(
                        id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();

                    function main() {
                        const response = await client.documents.delete({
                            id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X DELETE "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" \\
                        -H "Authorization: Bearer YOUR_API_KEY"
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def delete_document_by_id(
    id: UUID = Path(..., description="Document ID"),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedBooleanResponse:
    """Delete a specific document. All chunks corresponding to the
    document are deleted, and all other references to the document are
    removed.

    NOTE - Deletions do not yet impact the knowledge graph or other derived data. This feature is planned for a future release.
    """

    # Superusers may delete any document; everyone else is additionally
    # constrained to documents they own.
    document_filter = {"document_id": {"$eq": str(id)}}
    owner_filter = {"owner_id": {"$eq": str(auth_user.id)}}
    filters: dict[str, Any] = (
        document_filter
        if auth_user.is_superuser
        else {"$and": [owner_filter, document_filter]}
    )

    await (
        self.services.management.delete_documents_and_chunks_by_filter(
            filters=filters
        )
    )
    return GenericBooleanResponse(success=True)  # type: ignore
+
@self.router.get(
    "/documents/{id}/collections",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="List document collections",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # when using auth, do client.login(...)

                    response = client.documents.list_collections(
                        id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", offset=0, limit=10
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();

                    function main() {
                        const response = await client.documents.listCollections({
                            id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/collections" \\
                        -H "Authorization: Bearer YOUR_API_KEY"
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def get_document_collections(
    id: str = Path(..., description="Document ID"),
    offset: int = Query(
        0,
        ge=0,
        description="Specifies the number of objects to skip. Defaults to 0.",
    ),
    limit: int = Query(
        100,
        ge=1,
        le=1000,
        description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedCollectionsResponse:
    """Retrieves all collections that contain the specified document.
    This endpoint is restricted to superusers only and provides a
    system-wide view of document organization.

    Collections are used to organize documents and manage access control. A document can belong
    to multiple collections, and users can access documents through collection membership.

    The results are paginated and ordered by collection creation date, with the most recently
    created collections appearing first.

    NOTE - This endpoint is only available to superusers, it will be extended to regular users in a future release.
    """
    if not auth_user.is_superuser:
        raise R2RException(
            "Only a superuser can get the collections belonging to a document.",
            403,
        )

    # The path param arrives as a raw string; validate it here so a
    # malformed UUID surfaces as a 422 instead of an unhandled
    # ValueError (HTTP 500), mirroring the download endpoint.
    try:
        document_uuid = UUID(id)
    except ValueError:
        raise R2RException(
            status_code=422, message="Invalid document ID format."
        ) from None

    collections_response = (
        await self.services.management.collections_overview(
            offset=offset,
            limit=limit,
            document_ids=[document_uuid],
        )
    )

    return collections_response["results"], {  # type: ignore
        "total_entries": collections_response["total_entries"]
    }
+
@self.router.post(
    "/documents/{id}/extract",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="Extract entities and relationships",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # when using auth, do client.login(...)

                    response = client.documents.extract(
                        id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                    )
                    """),
            },
        ],
    },
)
@self.base_endpoint
async def extract(
    id: UUID = Path(
        ...,
        description="The ID of the document to extract entities and relationships from.",
    ),
    settings: Optional[GraphCreationSettings] = Body(
        default=None,
        description="Settings for the entities and relationships extraction process.",
    ),
    run_with_orchestration: Optional[bool] = Body(
        default=True,
        description="Whether to run the entities and relationships extraction process with orchestration.",
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedGenericMessageResponse:
    """Extracts entities and relationships from a document.

    The entities and relationships extraction process involves:

    1. Parsing documents into semantic chunks

    2. Extracting entities and relationships using LLMs

    3. Storing the created entities and relationships in the knowledge graph

    4. Preserving the document's metadata and content, and associating the elements with collections the document belongs to
    """

    # Pydantic v2 API (model_dump), consistent with the deduplicate
    # endpoint and the model_dump_json call below.
    settings = settings.model_dump() if settings else None  # type: ignore
    # Verify the document exists and the caller can see it.
    documents_overview_response = (
        await self.services.management.documents_overview(
            user_ids=(
                None if auth_user.is_superuser else [auth_user.id]
            ),
            collection_ids=(
                None
                if auth_user.is_superuser
                else auth_user.collection_ids
            ),
            document_ids=[id],
            offset=0,
            limit=1,
        )
    )["results"]
    if len(documents_overview_response) == 0:
        raise R2RException("Document not found.", 404)

    # Visibility via a shared collection is not enough — extraction
    # requires ownership (or superuser).
    if (
        not auth_user.is_superuser
        and auth_user.id != documents_overview_response[0].owner_id
    ):
        raise R2RException(
            "Only a superuser can extract entities and relationships from a document they do not own.",
            403,
        )

    # Apply runtime settings overrides
    server_graph_creation_settings = (
        self.providers.database.config.graph_creation_settings
    )

    if settings:
        server_graph_creation_settings = update_settings_from_dict(
            server_settings=server_graph_creation_settings,
            settings_dict=settings,  # type: ignore
        )

    # Build the workflow payload unconditionally: the non-orchestrated
    # fallback below also needs it, and previously it was only assigned
    # inside the orchestration branch — a NameError when
    # run_with_orchestration was False.
    workflow_input = {
        "document_id": str(id),
        "graph_creation_settings": server_graph_creation_settings.model_dump_json(),
        "user": auth_user.json(),
    }

    if run_with_orchestration:
        try:
            return await self.providers.orchestration.run_workflow(  # type: ignore
                "graph-extraction", {"request": workflow_input}, {}
            )
        except Exception as e:  # TODO: Need to find specific errors that we should be excepting (gRPC most likely?)
            # Fall through to the non-orchestrated path below.
            logger.error(
                f"Error running orchestrated extraction: {e} \n\nAttempting to run without orchestration."
            )

    from core.main.orchestration import (
        simple_graph_search_results_factory,
    )

    logger.info("Running extract-triples without orchestration.")
    simple_graph_search_results = simple_graph_search_results_factory(
        self.services.graph
    )
    await simple_graph_search_results["graph-extraction"](
        workflow_input
    )
    return {  # type: ignore
        "message": "Graph created successfully.",
        "task_id": None,
    }
+
@self.router.post(
    "/documents/{id}/deduplicate",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="Deduplicate entities",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()

                    response = client.documents.deduplicate(
                        id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();

                    function main() {
                        const response = await client.documents.deduplicate({
                            id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X POST "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/deduplicate" \\
                        -H "Authorization: Bearer YOUR_API_KEY"
                    """),
            },
        ],
    },
)
@self.base_endpoint
async def deduplicate(
    id: UUID = Path(
        ...,
        description="The ID of the document to extract entities and relationships from.",
    ),
    settings: Optional[GraphCreationSettings] = Body(
        default=None,
        description="Settings for the entities and relationships extraction process.",
    ),
    run_with_orchestration: Optional[bool] = Body(
        default=True,
        description="Whether to run the entities and relationships extraction process with orchestration.",
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedGenericMessageResponse:
    """Deduplicates entities from a document."""

    settings = settings.model_dump() if settings else None  # type: ignore
    # Verify the document exists and the caller can see it.
    documents_overview_response = (
        await self.services.management.documents_overview(
            user_ids=(
                None if auth_user.is_superuser else [auth_user.id]
            ),
            collection_ids=(
                None
                if auth_user.is_superuser
                else auth_user.collection_ids
            ),
            document_ids=[id],
            offset=0,
            limit=1,
        )
    )["results"]
    if len(documents_overview_response) == 0:
        raise R2RException("Document not found.", 404)

    # Visibility via a shared collection is not enough — deduplication
    # requires ownership (or superuser).
    if (
        not auth_user.is_superuser
        and auth_user.id != documents_overview_response[0].owner_id
    ):
        raise R2RException(
            "Only a superuser can run deduplication on a document they do not own.",
            403,
        )

    # Apply runtime settings overrides
    server_graph_creation_settings = (
        self.providers.database.config.graph_creation_settings
    )

    if settings:
        server_graph_creation_settings = update_settings_from_dict(
            server_settings=server_graph_creation_settings,
            settings_dict=settings,  # type: ignore
        )

    # Build the workflow payload unconditionally: the non-orchestrated
    # fallback below also needs it, and previously it was only assigned
    # inside the orchestration branch — a NameError when
    # run_with_orchestration was False.
    workflow_input = {
        "document_id": str(id),
    }

    if run_with_orchestration:
        try:
            return await self.providers.orchestration.run_workflow(  # type: ignore
                "graph-deduplication",
                {"request": workflow_input},
                {},
            )
        except Exception as e:  # TODO: Need to find specific errors that we should be excepting (gRPC most likely?)
            # Fall through to the non-orchestrated path below.
            logger.error(
                f"Error running orchestrated deduplication: {e} \n\nAttempting to run without orchestration."
            )

    from core.main.orchestration import (
        simple_graph_search_results_factory,
    )

    logger.info(
        "Running deduplicate-document-entities without orchestration."
    )
    simple_graph_search_results = simple_graph_search_results_factory(
        self.services.graph
    )
    await simple_graph_search_results["graph-deduplication"](
        workflow_input
    )
    return {  # type: ignore
        "message": "Graph created successfully.",
        "task_id": None,
    }
+
@self.router.get(
    "/documents/{id}/entities",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="Lists the entities from the document",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # when using auth, do client.login(...)

                    response = client.documents.extract(
                        id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                    )
                    """),
            },
        ],
    },
)
@self.base_endpoint
async def get_entities(
    id: UUID = Path(
        ...,
        description="The ID of the document to retrieve entities from.",
    ),
    offset: int = Query(
        0,
        ge=0,
        description="Specifies the number of objects to skip. Defaults to 0.",
    ),
    limit: int = Query(
        100,
        ge=1,
        le=1000,
        description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
    ),
    include_embeddings: Optional[bool] = Query(
        False,
        description="Whether to include vector embeddings in the response.",
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedEntitiesResponse:
    """Retrieves the entities that were extracted from a document.
    These represent important semantic elements like people, places,
    organizations, concepts, etc.

    Users can only access entities from documents they own or have
    access to through collections. Entity embeddings are only included
    if specifically requested.

    Results are returned in the order they were extracted from the
    document.
    """
    # Superseded collection-membership check, kept for reference:
    # if (
    #     not auth_user.is_superuser
    #     and id not in auth_user.collection_ids
    # ):
    #     raise R2RException(
    #         "The currently authenticated user does not have access to the specified collection.",
    #         403,
    #     )

    # First check if the document exists and user has access
    documents_overview_response = (
        await self.services.management.documents_overview(
            user_ids=(
                None if auth_user.is_superuser else [auth_user.id]
            ),
            collection_ids=(
                None
                if auth_user.is_superuser
                else auth_user.collection_ids
            ),
            document_ids=[id],
            offset=0,
            limit=1,
        )
    )

    if not documents_overview_response["results"]:
        raise R2RException("Document not found.", 404)

    # Get all entities for this document from the document_entity table
    (
        entities,
        count,
    ) = await self.providers.database.graphs_handler.entities.get(
        parent_id=id,
        store_type=StoreType.DOCUMENTS,
        offset=offset,
        limit=limit,
        include_embeddings=include_embeddings or False,
    )

    return entities, {"total_entries": count}  # type: ignore
+
@self.router.post(
    "/documents/{id}/entities/export",
    summary="Export document entities to CSV",
    dependencies=[Depends(self.rate_limit_dependency)],
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient("http://localhost:7272")
                    # when using auth, do client.login(...)

                    response = client.documents.export_entities(
                        id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                        output_path="export.csv",
                        columns=["id", "title", "created_at"],
                        include_header=True,
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient("http://localhost:7272");

                    function main() {
                        await client.documents.exportEntities({
                            id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                            outputPath: "export.csv",
                            columns: ["id", "title", "created_at"],
                            includeHeader: true,
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X POST "http://127.0.0.1:7272/v3/documents/export_entities" \
                        -H "Authorization: Bearer YOUR_API_KEY" \
                        -H "Content-Type: application/json" \
                        -H "Accept: text/csv" \
                        -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
                        --output export.csv
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def export_entities(
    background_tasks: BackgroundTasks,
    id: UUID = Path(
        ...,
        description="The ID of the document to export entities from.",
    ),
    columns: Optional[list[str]] = Body(
        None, description="Specific columns to export"
    ),
    filters: Optional[dict] = Body(
        None, description="Filters to apply to the export"
    ),
    include_header: Optional[bool] = Body(
        True, description="Whether to include column headers"
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> FileResponse:
    """Export a document's extracted entities as a downloadable CSV
    file.

    Restricted to superusers; raises 403 otherwise.
    """

    if not auth_user.is_superuser:
        raise R2RException(
            "Only a superuser can export data.",
            403,
        )

    (
        csv_file_path,
        temp_file,
    ) = await self.services.management.export_document_entities(
        id=id,
        columns=columns,
        filters=filters,
        include_header=include_header
        if include_header is not None
        else True,
    )

    # Close (and thereby clean up) the temp file only after the response
    # has been sent.
    background_tasks.add_task(temp_file.close)

    # NOTE(review): download filename says "documents_export.csv" even
    # though this exports entities — looks like a copy-paste; confirm.
    return FileResponse(
        path=csv_file_path,
        media_type="text/csv",
        filename="documents_export.csv",
    )
+
@self.router.get(
    "/documents/{id}/relationships",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="List document relationships",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # when using auth, do client.login(...)

                    response = client.documents.list_relationships(
                        id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                        offset=0,
                        limit=100
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();

                    function main() {
                        const response = await client.documents.listRelationships({
                            id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                            offset: 0,
                            limit: 100,
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/relationships" \\
                        -H "Authorization: Bearer YOUR_API_KEY"
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def get_relationships(
    id: UUID = Path(
        ...,
        description="The ID of the document to retrieve relationships for.",
    ),
    offset: int = Query(
        0,
        ge=0,
        description="Specifies the number of objects to skip. Defaults to 0.",
    ),
    limit: int = Query(
        100,
        ge=1,
        le=1000,
        description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
    ),
    entity_names: Optional[list[str]] = Query(
        None,
        description="Filter relationships by specific entity names.",
    ),
    relationship_types: Optional[list[str]] = Query(
        None,
        description="Filter relationships by specific relationship types.",
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedRelationshipsResponse:
    """Retrieves the relationships between entities that were extracted
    from a document. These represent connections and interactions
    between entities found in the text.

    Users can only access relationships from documents they own or have
    access to through collections. Results can be filtered by entity
    names and relationship types.

    Results are returned in the order they were extracted from the
    document.
    """
    # Superseded collection-membership check, kept for reference:
    # if (
    #     not auth_user.is_superuser
    #     and id not in auth_user.collection_ids
    # ):
    #     raise R2RException(
    #         "The currently authenticated user does not have access to the specified collection.",
    #         403,
    #     )

    # First check if the document exists and user has access
    documents_overview_response = (
        await self.services.management.documents_overview(
            user_ids=(
                None if auth_user.is_superuser else [auth_user.id]
            ),
            collection_ids=(
                None
                if auth_user.is_superuser
                else auth_user.collection_ids
            ),
            document_ids=[id],
            offset=0,
            limit=1,
        )
    )

    if not documents_overview_response["results"]:
        raise R2RException("Document not found.", 404)

    # Get relationships for this document
    (
        relationships,
        count,
    ) = await self.providers.database.graphs_handler.relationships.get(
        parent_id=id,
        store_type=StoreType.DOCUMENTS,
        entity_names=entity_names,
        relationship_types=relationship_types,
        offset=offset,
        limit=limit,
    )

    return relationships, {"total_entries": count}  # type: ignore
+
+ @self.router.post(
+ "/documents/{id}/relationships/export",
+ summary="Export document relationships to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.documents.export_entities(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ output_path="export.csv",
+ columns=["id", "title", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+ function main() {
+ await client.documents.exportEntities({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ outputPath: "export.csv",
+ columns: ["id", "title", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/documents/export_entities" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_relationships(
+ background_tasks: BackgroundTasks,
+ id: UUID = Path(
+ ...,
+ description="The ID of the document to export entities from.",
+ ),
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export documents as a downloadable CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_document_relationships(
+ id=id,
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="documents_export.csv",
+ )
+
+ @self.router.post(
+ "/documents/search",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Search document summaries",
+ )
+ @self.base_endpoint
+ async def search_documents(
+ query: str = Body(
+ ...,
+ description="The search query to perform.",
+ ),
+ search_mode: SearchMode = Body(
+ default=SearchMode.custom,
+ description=(
+ "Default value of `custom` allows full control over search settings.\n\n"
+ "Pre-configured search modes:\n"
+ "`basic`: A simple semantic-based search.\n"
+ "`advanced`: A more powerful hybrid search combining semantic and full-text.\n"
+ "`custom`: Full control via `search_settings`.\n\n"
+ "If `filters` or `limit` are provided alongside `basic` or `advanced`, "
+ "they will override the default settings for that mode."
+ ),
+ ),
+ search_settings: SearchSettings = Body(
+ default_factory=SearchSettings,
+ description="Settings for document search",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedDocumentSearchResponse:
+ """Perform a search query on the automatically generated document
+ summaries in the system.
+
+ This endpoint allows for complex filtering of search results using PostgreSQL-based queries.
+ Filters can be applied to various fields such as document_id, and internal metadata values.
+
+
+ Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
+ """
+ effective_settings = self._prepare_search_settings(
+ auth_user, search_mode, search_settings
+ )
+
+ query_embedding = (
+ await self.providers.embedding.async_get_embedding(query)
+ )
+ results = await self.services.retrieval.search_documents(
+ query=query,
+ query_embedding=query_embedding,
+ settings=effective_settings,
+ )
+ return results # type: ignore
+
+ @staticmethod
+ async def _process_file(file):
+ import base64
+
+ content = await file.read()
+
+ return {
+ "filename": file.filename,
+ "content": base64.b64encode(content).decode("utf-8"),
+ "content_type": file.content_type,
+ "content_length": len(content),
+ }
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/examples.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/examples.py
new file mode 100644
index 00000000..ba588c3b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/examples.py
@@ -0,0 +1,1065 @@
"""
This file contains updated OpenAPI examples for the RetrievalRouterV3 class.
These examples are designed to be included in the openapi_extra field for each route.
"""

import textwrap
+
+# Updated examples for search_app endpoint
+search_app_examples = {
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent(
+ """
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # if using auth, do client.login(...)
+
+ # Basic search
+ response = client.retrieval.search(
+ query="What is DeepSeek R1?",
+ )
+
+ # Advanced mode with specific filters
+ response = client.retrieval.search(
+ query="What is DeepSeek R1?",
+ search_mode="advanced",
+ search_settings={
+ "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}},
+ "limit": 5
+ }
+ )
+
+ # Using hybrid search
+ response = client.retrieval.search(
+ query="What was Uber's profit in 2020?",
+ search_settings={
+ "use_hybrid_search": True,
+ "hybrid_settings": {
+ "full_text_weight": 1.0,
+ "semantic_weight": 5.0,
+ "full_text_limit": 200,
+ "rrf_k": 50
+ },
+ "filters": {"title": {"$in": ["DeepSeek_R1.pdf"]}},
+ }
+ )
+
+ # Advanced filtering
+ results = client.retrieval.search(
+ query="What are the effects of climate change?",
+ search_settings={
+ "filters": {
+ "$and":[
+ {"document_type": {"$eq": "pdf"}},
+ {"metadata.year": {"$gt": 2020}}
+ ]
+ },
+ "limit": 10
+ }
+ )
+
+ # Knowledge graph enhanced search
+ results = client.retrieval.search(
+ query="What was DeepSeek R1",
+ graph_search_settings={
+ "use_graph_search": True,
+ "kg_search_type": "local"
+ }
+ )
+ """
+ ),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent(
+ """
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+ // if using auth, do client.login(...)
+
+ // Basic search
+ const response = await client.retrieval.search({
+ query: "What is DeepSeek R1?",
+ });
+
+ // With specific filters
+ const filteredResponse = await client.retrieval.search({
+ query: "What is DeepSeek R1?",
+ searchSettings: {
+ filters: {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}},
+ limit: 5
+ }
+ });
+
+ // Using hybrid search
+ const hybridResponse = await client.retrieval.search({
+ query: "What was Uber's profit in 2020?",
+ searchSettings: {
+ indexMeasure: "l2_distance",
+ useHybridSearch: true,
+ hybridSettings: {
+ fullTextWeight: 1.0,
+ semanticWeight: 5.0,
+ fullTextLimit: 200,
+ },
+ filters: {"title": {"$in": ["DeepSeek_R1.pdf"]}},
+ }
+ });
+
+ // Advanced filtering
+ const advancedResults = await client.retrieval.search({
+ query: "What are the effects of climate change?",
+ searchSettings: {
+ filters: {
+ $and: [
+ {document_type: {$eq: "pdf"}},
+ {"metadata.year": {$gt: 2020}}
+ ]
+ },
+ limit: 10
+ }
+ });
+
+ // Knowledge graph enhanced search
+ const kgResults = await client.retrieval.search({
+ query: "who was aristotle?",
+ graphSearchSettings: {
+ useKgSearch: true,
+ kgSearchType: "local"
+ }
+ });
+ """
+ ),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent(
+ """
+ # Basic search
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/search" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "query": "What is DeepSeek R1?"
+ }'
+
+ # With hybrid search and filters
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/search" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "query": "What was Uber'\''s profit in 2020?",
+ "search_settings": {
+ "use_hybrid_search": true,
+ "hybrid_settings": {
+ "full_text_weight": 1.0,
+ "semantic_weight": 5.0,
+ "full_text_limit": 200,
+ "rrf_k": 50
+ },
+ "filters": {"title": {"$in": ["DeepSeek_R1.pdf"]}},
+ "limit": 10,
+ "chunk_settings": {
+ "index_measure": "l2_distance",
+ "probes": 25,
+ "ef_search": 100
+ }
+ }
+ }'
+
+ # Knowledge graph enhanced search
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/search" \\
+ -H "Content-Type: application/json" \\
+ -d '{
+ "query": "who was aristotle?",
+ "graph_search_settings": {
+ "use_graph_search": true,
+ "kg_search_type": "local"
+ }
+ }' \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """
+ ),
+ },
+ ]
+}
+
+# Updated examples for rag_app endpoint
+rag_app_examples = {
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent(
+ """
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ # Basic RAG request
+ response = client.retrieval.rag(
+ query="What is DeepSeek R1?",
+ )
+
+ # Advanced RAG with custom search settings
+ response = client.retrieval.rag(
+ query="What is DeepSeek R1?",
+ search_settings={
+ "use_semantic_search": True,
+ "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}},
+ "limit": 10,
+ },
+ rag_generation_config={
+ "stream": False,
+ "temperature": 0.7,
+ "max_tokens": 1500
+ }
+ )
+
+ # Hybrid search in RAG
+ results = client.retrieval.rag(
+ "Who is Jon Snow?",
+ search_settings={"use_hybrid_search": True}
+ )
+
+ # Custom model selection
+ response = client.retrieval.rag(
+ "Who was Aristotle?",
+ rag_generation_config={"model":"anthropic/claude-3-haiku-20240307", "stream": True}
+ )
+ for chunk in response:
+ print(chunk)
+
+ # Streaming RAG
+ from r2r import (
+ CitationEvent,
+ FinalAnswerEvent,
+ MessageEvent,
+ SearchResultsEvent,
+ R2RClient,
+ )
+
+ result_stream = client.retrieval.rag(
+ query="What is DeepSeek R1?",
+ search_settings={"limit": 25},
+ rag_generation_config={"stream": True},
+ )
+
+ # Process different event types
+ for event in result_stream:
+ if isinstance(event, SearchResultsEvent):
+ print("Search results:", event.data)
+ elif isinstance(event, MessageEvent):
+ print("Partial message:", event.data.delta)
+ elif isinstance(event, CitationEvent):
+ print("New citation detected:", event.data.id)
+ elif isinstance(event, FinalAnswerEvent):
+ print("Final answer:", event.data.generated_answer)
+ """
+ ),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent(
+ """
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+ // when using auth, do client.login(...)
+
+ // Basic RAG request
+ const response = await client.retrieval.rag({
+ query: "What is DeepSeek R1?",
+ });
+
+ // RAG with custom settings
+ const advancedResponse = await client.retrieval.rag({
+ query: "What is DeepSeek R1?",
+ searchSettings: {
+ useSemanticSearch: true,
+ filters: {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}},
+ limit: 10,
+ },
+ ragGenerationConfig: {
+ stream: false,
+ temperature: 0.7,
+ maxTokens: 1500
+ }
+ });
+
+ // Hybrid search in RAG
+ const hybridResults = await client.retrieval.rag({
+ query: "Who is Jon Snow?",
+ searchSettings: {
+ useHybridSearch: true
+ },
+ });
+
+ // Custom model
+ const customModelResponse = await client.retrieval.rag({
+ query: "Who was Aristotle?",
+ ragGenerationConfig: {
+ model: 'anthropic/claude-3-haiku-20240307',
+ temperature: 0.7,
+ }
+ });
+
+ // Streaming RAG
+ const resultStream = await client.retrieval.rag({
+ query: "What is DeepSeek R1?",
+ searchSettings: { limit: 25 },
+ ragGenerationConfig: { stream: true },
+ });
+
+ // Process streaming events
+ if (Symbol.asyncIterator in resultStream) {
+ for await (const event of resultStream) {
+ switch (event.event) {
+ case "search_results":
+ console.log("Search results:", event.data);
+ break;
+ case "message":
+ console.log("Partial message delta:", event.data.delta);
+ break;
+ case "citation":
+ console.log("New citation event:", event.data.id);
+ break;
+ case "final_answer":
+ console.log("Final answer:", event.data.generated_answer);
+ break;
+ default:
+ console.log("Unknown or unhandled event:", event);
+ }
+ }
+ }
+ """
+ ),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent(
+ """
+ # Basic RAG request
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/rag" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "query": "What is DeepSeek R1?"
+ }'
+
+ # RAG with custom settings
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/rag" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "query": "What is DeepSeek R1?",
+ "search_settings": {
+ "use_semantic_search": true,
+ "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}},
+ "limit": 10
+ },
+ "rag_generation_config": {
+ "stream": false,
+ "temperature": 0.7,
+ "max_tokens": 1500
+ }
+ }'
+
+ # Hybrid search in RAG
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/rag" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "query": "Who is Jon Snow?",
+ "search_settings": {
+ "use_hybrid_search": true,
+ "filters": {},
+ "limit": 10
+ }
+ }'
+
+ # Custom model
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/rag" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "query": "Who is Jon Snow?",
+ "rag_generation_config": {
+ "model": "anthropic/claude-3-haiku-20240307",
+ "temperature": 0.7
+ }
+ }'
+ """
+ ),
+ },
+ ]
+}
+
+# Updated examples for agent_app endpoint
+agent_app_examples = {
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent(
+ """
+from r2r import (
+ R2RClient,
+ ThinkingEvent,
+ ToolCallEvent,
+ ToolResultEvent,
+ CitationEvent,
+ FinalAnswerEvent,
+ MessageEvent,
+)
+
+client = R2RClient()
+# when using auth, do client.login(...)
+
+# Basic synchronous request
+response = client.retrieval.agent(
+ message={
+ "role": "user",
+ "content": "Do a deep analysis of the philosophical implications of DeepSeek R1"
+ },
+ rag_tools=["web_search", "web_scrape", "search_file_descriptions", "search_file_knowledge", "get_file_content"],
+)
+
+# Advanced analysis with streaming and extended thinking
+streaming_response = client.retrieval.agent(
+ message={
+ "role": "user",
+ "content": "Do a deep analysis of the philosophical implications of DeepSeek R1"
+ },
+ search_settings={"limit": 20},
+ rag_tools=["web_search", "web_scrape", "search_file_descriptions", "search_file_knowledge", "get_file_content"],
+ rag_generation_config={
+ "model": "anthropic/claude-3-7-sonnet-20250219",
+ "extended_thinking": True,
+ "thinking_budget": 4096,
+ "temperature": 1,
+ "top_p": None,
+ "max_tokens": 16000,
+ "stream": True
+ }
+)
+
+# Process streaming events with emoji only on type change
+current_event_type = None
+for event in streaming_response:
+ # Check if the event type has changed
+ event_type = type(event)
+ if event_type != current_event_type:
+ current_event_type = event_type
+ print() # Add newline before new event type
+
+ # Print emoji based on the new event type
+ if isinstance(event, ThinkingEvent):
+ print(f"\n🧠 Thinking: ", end="", flush=True)
+ elif isinstance(event, ToolCallEvent):
+ print(f"\n🔧 Tool call: ", end="", flush=True)
+ elif isinstance(event, ToolResultEvent):
+ print(f"\n📊 Tool result: ", end="", flush=True)
+ elif isinstance(event, CitationEvent):
+ print(f"\n📑 Citation: ", end="", flush=True)
+ elif isinstance(event, MessageEvent):
+ print(f"\n💬 Message: ", end="", flush=True)
+ elif isinstance(event, FinalAnswerEvent):
+ print(f"\n✅ Final answer: ", end="", flush=True)
+
+ # Print the content without the emoji
+ if isinstance(event, ThinkingEvent):
+ print(f"{event.data.delta.content[0].payload.value}", end="", flush=True)
+ elif isinstance(event, ToolCallEvent):
+ print(f"{event.data.name}({event.data.arguments})")
+ elif isinstance(event, ToolResultEvent):
+ print(f"{event.data.content[:60]}...")
+ elif isinstance(event, CitationEvent):
+ print(f"{event.data.id}")
+ elif isinstance(event, MessageEvent):
+ print(f"{event.data.delta.content[0].payload.value}", end="", flush=True)
+ elif isinstance(event, FinalAnswerEvent):
+ print(f"{event.data.generated_answer[:100]}...")
+ print(f" Citations: {len(event.data.citations)} sources referenced")
+
+# Conversation with multiple turns (synchronous)
+conversation = client.conversations.create()
+
+# First message in conversation
+results_1 = client.retrieval.agent(
+ query="What does DeepSeek R1 imply for the future of AI?",
+ rag_generation_config={
+ "model": "anthropic/claude-3-7-sonnet-20250219",
+ "extended_thinking": True,
+ "thinking_budget": 4096,
+ "temperature": 1,
+ "top_p": None,
+ "max_tokens": 16000,
+ "stream": True
+ },
+ conversation_id=conversation.results.id
+)
+
+# Follow-up query in the same conversation
+results_2 = client.retrieval.agent(
+ query="How does it compare to other reasoning models?",
+ rag_generation_config={
+ "model": "anthropic/claude-3-7-sonnet-20250219",
+ "extended_thinking": True,
+ "thinking_budget": 4096,
+ "temperature": 1,
+ "top_p": None,
+ "max_tokens": 16000,
+ "stream": True
+ },
+ conversation_id=conversation.results.id
+)
+
+# Access the final results
+print(f"First response: {results_1.generated_answer[:100]}...")
+print(f"Follow-up response: {results_2.generated_answer[:100]}...")
+"""
+ ),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent(
+ """
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+ // when using auth, do client.login(...)
+
+ async function main() {
+ // Basic synchronous request
+ const ragResponse = await client.retrieval.agent({
+ message: {
+ role: "user",
+ content: "Do a deep analysis of the philosophical implications of DeepSeek R1"
+ },
+ ragTools: ["web_search", "web_scrape", "search_file_descriptions", "search_file_knowledge", "get_file_content"]
+ });
+
+ // Advanced analysis with streaming and extended thinking
+ const streamingResponse = await client.retrieval.agent({
+ message: {
+ role: "user",
+ content: "Do a deep analysis of the philosophical implications of DeepSeek R1"
+ },
+ searchSettings: {limit: 20},
+ ragTools: ["web_search", "web_scrape", "search_file_descriptions", "search_file_knowledge", "get_file_content"],
+ ragGenerationConfig: {
+ model: "anthropic/claude-3-7-sonnet-20250219",
+ extendedThinking: true,
+ thinkingBudget: 4096,
+ temperature: 1,
+ maxTokens: 16000,
+ stream: true
+ }
+ });
+
+ // Process streaming events with emoji only on type change
+ if (Symbol.asyncIterator in streamingResponse) {
+ let currentEventType = null;
+
+ for await (const event of streamingResponse) {
+ // Check if event type has changed
+ const eventType = event.event;
+ if (eventType !== currentEventType) {
+ currentEventType = eventType;
+ console.log(); // Add newline before new event type
+
+ // Print emoji based on the new event type
+ switch(eventType) {
+ case "thinking":
+ process.stdout.write(`🧠 Thinking: `);
+ break;
+ case "tool_call":
+ process.stdout.write(`🔧 Tool call: `);
+ break;
+ case "tool_result":
+ process.stdout.write(`📊 Tool result: `);
+ break;
+ case "citation":
+ process.stdout.write(`📑 Citation: `);
+ break;
+ case "message":
+ process.stdout.write(`💬 Message: `);
+ break;
+ case "final_answer":
+ process.stdout.write(`✅ Final answer: `);
+ break;
+ }
+ }
+
+ // Print content based on event type
+ switch(eventType) {
+ case "thinking":
+ process.stdout.write(`${event.data.delta.content[0].payload.value}`);
+ break;
+ case "tool_call":
+ console.log(`${event.data.name}(${JSON.stringify(event.data.arguments)})`);
+ break;
+ case "tool_result":
+ console.log(`${event.data.content.substring(0, 60)}...`);
+ break;
+ case "citation":
+ console.log(`${event.data.id}`);
+ break;
+ case "message":
+ process.stdout.write(`${event.data.delta.content[0].payload.value}`);
+ break;
+ case "final_answer":
+ console.log(`${event.data.generated_answer.substring(0, 100)}...`);
+ console.log(` Citations: ${event.data.citations.length} sources referenced`);
+ break;
+ }
+ }
+ }
+
+ // Conversation with multiple turns (synchronous)
+ const conversation = await client.conversations.create();
+
+ // First message in conversation
+ const results1 = await client.retrieval.agent({
+ query: "What does DeepSeek R1 imply for the future of AI?",
+ ragGenerationConfig: {
+ model: "anthropic/claude-3-7-sonnet-20250219",
+ extendedThinking: true,
+ thinkingBudget: 4096,
+ temperature: 1,
+ maxTokens: 16000,
+ stream: true
+ },
+ conversationId: conversation.results.id
+ });
+
+ // Follow-up query in the same conversation
+ const results2 = await client.retrieval.agent({
+ query: "How does it compare to other reasoning models?",
+ ragGenerationConfig: {
+ model: "anthropic/claude-3-7-sonnet-20250219",
+ extendedThinking: true,
+ thinkingBudget: 4096,
+ temperature: 1,
+ maxTokens: 16000,
+ stream: true
+ },
+ conversationId: conversation.results.id
+ });
+
+ // Log the results
+ console.log(`First response: ${results1.generated_answer.substring(0, 100)}...`);
+ console.log(`Follow-up response: ${results2.generated_answer.substring(0, 100)}...`);
+ }
+
+ main();
+ """
+ ),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent(
+ """
+ # Basic request
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/agent" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "message": {
+ "role": "user",
+ "content": "What were the key contributions of Aristotle to logic?"
+ },
+ "search_settings": {
+ "use_semantic_search": true,
+ "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}}
+ },
+ "rag_tools": ["search_file_knowledge", "content", "web_search"]
+ }'
+
+ # Advanced analysis with extended thinking
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/agent" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "message": {
+ "role": "user",
+ "content": "Do a deep analysis of the philosophical implications of DeepSeek R1"
+ },
+ "search_settings": {"limit": 20},
+ "research_tools": ["rag", "reasoning", "critique", "python_executor"],
+ "rag_generation_config": {
+ "model": "anthropic/claude-3-7-sonnet-20250219",
+ "extended_thinking": true,
+ "thinking_budget": 4096,
+ "temperature": 1,
+ "top_p": null,
+ "max_tokens": 16000,
+ "stream": true
+ }
+ }'
+
+ # Conversation continuation
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/agent" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "message": {
+ "role": "user",
+ "content": "How does it compare to other reasoning models?"
+ },
+ "conversation_id": "YOUR_CONVERSATION_ID"
+ }'
+ """
+ ),
+ },
+ ]
+}
+
+# Updated examples for completion endpoint
+completion_examples = {
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent(
+ """
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.completion(
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of France?"},
+ {"role": "assistant", "content": "The capital of France is Paris."},
+ {"role": "user", "content": "What about Italy?"}
+ ],
+ generation_config={
+ "model": "openai/gpt-4o-mini",
+ "temperature": 0.7,
+ "max_tokens": 150,
+ "stream": False
+ }
+ )
+ """
+ ),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent(
+ """
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+ // when using auth, do client.login(...)
+
+ async function main() {
+ const response = await client.completion({
+ messages: [
+ { role: "system", content: "You are a helpful assistant." },
+ { role: "user", content: "What is the capital of France?" },
+ { role: "assistant", content: "The capital of France is Paris." },
+ { role: "user", content: "What about Italy?" }
+ ],
+ generationConfig: {
+ model: "openai/gpt-4o-mini",
+ temperature: 0.7,
+ maxTokens: 150,
+ stream: false
+ }
+ });
+ }
+
+ main();
+ """
+ ),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent(
+ """
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/completion" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "messages": [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of France?"},
+ {"role": "assistant", "content": "The capital of France is Paris."},
+ {"role": "user", "content": "What about Italy?"}
+ ],
+ "generation_config": {
+ "model": "openai/gpt-4o-mini",
+ "temperature": 0.7,
+ "max_tokens": 150,
+ "stream": false
+ }
+ }'
+ """
+ ),
+ },
+ ]
+}
+
+# Updated examples for embedding endpoint
# Code samples for the /v3/retrieval/embedding endpoint.
# NOTE: the Shell sample previously had a trailing comma inside the -d JSON
# payload, which made the copy-pasteable curl command send invalid JSON.
embedding_examples = {
    "x-codeSamples": [
        {
            "lang": "Python",
            "source": textwrap.dedent(
                """
                from r2r import R2RClient

                client = R2RClient()
                # when using auth, do client.login(...)

                result = client.retrieval.embedding(
                    text="What is DeepSeek R1?",
                )
                """
            ),
        },
        {
            "lang": "JavaScript",
            "source": textwrap.dedent(
                """
                const { r2rClient } = require("r2r-js");

                const client = new r2rClient();
                // when using auth, do client.login(...)

                async function main() {
                    const response = await client.retrieval.embedding({
                        text: "What is DeepSeek R1?",
                    });
                }

                main();
                """
            ),
        },
        {
            "lang": "Shell",
            "source": textwrap.dedent(
                """
                curl -X POST "https://api.sciphi.ai/v3/retrieval/embedding" \\
                    -H "Content-Type: application/json" \\
                    -H "Authorization: Bearer YOUR_API_KEY" \\
                    -d '{
                        "text": "What is DeepSeek R1?"
                    }'
                """
            ),
        },
    ]
}
+
+# Updated rag_app docstring
+rag_app_docstring = """
+Execute a RAG (Retrieval-Augmented Generation) query.
+
+This endpoint combines search results with language model generation to produce accurate,
+contextually-relevant responses based on your document corpus.
+
+**Features:**
+- Combines vector search, optional knowledge graph integration, and LLM generation
+- Automatically cites sources with unique citation identifiers
+- Supports both streaming and non-streaming responses
+- Compatible with various LLM providers (OpenAI, Anthropic, etc.)
+- Web search integration for up-to-date information
+
+**Search Configuration:**
+All search parameters from the search endpoint apply here, including filters, hybrid search, and graph-enhanced search.
+
+**Generation Configuration:**
+Fine-tune the language model's behavior with `rag_generation_config`:
+```json
+{
+ "model": "openai/gpt-4o-mini", // Model to use
+ "temperature": 0.7, // Control randomness (0-1)
+ "max_tokens": 1500, // Maximum output length
+ "stream": true // Enable token streaming
+}
+```
+
+**Model Support:**
+- OpenAI models (default)
+- Anthropic Claude models (requires ANTHROPIC_API_KEY)
+- Local models via Ollama
+- Any provider supported by LiteLLM
+
+**Streaming Responses:**
+When `stream: true` is set, the endpoint returns Server-Sent Events with the following types:
+- `search_results`: Initial search results from your documents
+- `message`: Partial tokens as they're generated
+- `citation`: Citation metadata when sources are referenced
+- `final_answer`: Complete answer with structured citations
+
+**Example Response:**
+```json
+{
+ "generated_answer": "DeepSeek-R1 is a model that demonstrates impressive performance...[1]",
+ "search_results": { ... },
+ "citations": [
+ {
+ "id": "cit.123456",
+ "object": "citation",
+ "payload": { ... }
+ }
+ ]
+}
+```
+"""
+
+# Updated agent_app docstring
+agent_app_docstring = """
+Engage with an intelligent agent for information retrieval, analysis, and research.
+
+This endpoint offers two operating modes:
+- **RAG mode**: Standard retrieval-augmented generation for answering questions based on knowledge base
+- **Research mode**: Advanced capabilities for deep analysis, reasoning, and computation
+
+### RAG Mode (Default)
+
+The RAG mode provides fast, knowledge-based responses using:
+- Semantic and hybrid search capabilities
+- Document-level and chunk-level content retrieval
+- Optional web search integration
+- Source citation and evidence-based responses
+
+### Research Mode
+
+The Research mode builds on RAG capabilities and adds:
+- A dedicated reasoning system for complex problem-solving
+- Critique capabilities to identify potential biases or logical fallacies
+- Python execution for computational analysis
+- Multi-step reasoning for deeper exploration of topics
+
+### Available Tools
+
+**RAG Tools:**
+- `search_file_knowledge`: Semantic/hybrid search on your ingested documents
+- `search_file_descriptions`: Search over file-level metadata
+- `content`: Fetch entire documents or chunk structures
+- `web_search`: Query external search APIs for up-to-date information
+- `web_scrape`: Scrape and extract content from specific web pages
+
+**Research Tools:**
+- `rag`: Leverage the underlying RAG agent for information retrieval
+- `reasoning`: Call a dedicated model for complex analytical thinking
+- `critique`: Analyze conversation history to identify flaws and biases
+- `python_executor`: Execute Python code for complex calculations and analysis
+
+### Streaming Output
+
+When streaming is enabled, the agent produces different event types:
+- `thinking`: Shows the model's step-by-step reasoning (when extended_thinking=true)
+- `tool_call`: Shows when the agent invokes a tool
+- `tool_result`: Shows the result of a tool call
+- `citation`: Indicates when a citation is added to the response
+- `message`: Streams partial tokens of the response
+- `final_answer`: Contains the complete generated answer and structured citations
+
+### Conversations
+
+Maintain context across multiple turns by including `conversation_id` in each request.
+After your first call, store the returned `conversation_id` and include it in subsequent calls.
+"""
+
+# Updated completion_docstring
+completion_docstring = """
+Generate completions for a list of messages.
+
+This endpoint uses the language model to generate completions for the provided messages.
+The generation process can be customized using the generation_config parameter.
+
+The messages list should contain alternating user and assistant messages, with an optional
+system message at the start. Each message should have a 'role' and 'content'.
+
+**Generation Configuration:**
+Fine-tune the language model's behavior with `generation_config`:
+```json
+{
+ "model": "openai/gpt-4o-mini", // Model to use
+ "temperature": 0.7, // Control randomness (0-1)
+ "max_tokens": 1500, // Maximum output length
+ "stream": true // Enable token streaming
+}
+```
+
+**Multiple LLM Support:**
+- OpenAI models (default)
+- Anthropic Claude models (requires ANTHROPIC_API_KEY)
+- Local models via Ollama
+- Any provider supported by LiteLLM
+"""
+
+# Updated embedding_docstring
+embedding_docstring = """
+Generate embeddings for the provided text using the specified model.
+
+This endpoint uses the language model to generate embeddings for the provided text.
+The model parameter specifies the model to use for generating embeddings.
+
+Embeddings are numerical representations of text that capture semantic meaning,
+allowing for similarity comparisons and other vector operations.
+
+**Uses:**
+- Semantic search
+- Document clustering
+- Text similarity analysis
+- Content recommendation
+"""
+
+# # Example implementation to update the routers in the RetrievalRouterV3 class
+# def update_retrieval_router(router_class):
+# """
+# Update the RetrievalRouterV3 class with the improved docstrings and examples.
+
+# This function demonstrates how the updated examples and docstrings would be
+# integrated into the actual router class.
+# """
+# # Update search_app endpoint
+# router_class.search_app.__doc__ = search_app_docstring
+# router_class.search_app.openapi_extra = search_app_examples
+
+# # Update rag_app endpoint
+# router_class.rag_app.__doc__ = rag_app_docstring
+# router_class.rag_app.openapi_extra = rag_app_examples
+
+# # Update agent_app endpoint
+# router_class.agent_app.__doc__ = agent_app_docstring
+# router_class.agent_app.openapi_extra = agent_app_examples
+
+# # Update completion endpoint
+# router_class.completion.__doc__ = completion_docstring
+# router_class.completion.openapi_extra = completion_examples
+
+# # Update embedding endpoint
+# router_class.embedding.__doc__ = embedding_docstring
+# router_class.embedding.openapi_extra = embedding_examples
+
+# return router_class
+
+# Example showing how the updated router would be integrated
+"""
+from your_module import RetrievalRouterV3
+
+# Apply the updated docstrings and examples
+router = RetrievalRouterV3(providers, services, config)
+router = update_retrieval_router(router)
+
+# Now the router has the improved docstrings and examples
+"""
+
# Registry mapping each retrieval endpoint key to its x-codeSamples dict,
# letting router code look up the examples for an endpoint by name.
EXAMPLES = {
    "search": search_app_examples,
    "rag": rag_app_examples,
    "agent": agent_app_examples,
    "completion": completion_examples,
    "embedding": embedding_examples,
}
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/graph_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/graph_router.py
new file mode 100644
index 00000000..244d76cf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/graph_router.py
@@ -0,0 +1,2051 @@
+import logging
+import textwrap
+from typing import Optional, cast
+from uuid import UUID
+
+from fastapi import Body, Depends, Path, Query
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse
+
+from core.base import GraphConstructionStatus, R2RException, Workflow
+from core.base.abstractions import DocumentResponse, StoreType
+from core.base.api.models import (
+ GenericBooleanResponse,
+ GenericMessageResponse,
+ WrappedBooleanResponse,
+ WrappedCommunitiesResponse,
+ WrappedCommunityResponse,
+ WrappedEntitiesResponse,
+ WrappedEntityResponse,
+ WrappedGenericMessageResponse,
+ WrappedGraphResponse,
+ WrappedGraphsResponse,
+ WrappedRelationshipResponse,
+ WrappedRelationshipsResponse,
+)
+from core.utils import (
+ generate_default_user_collection_id,
+ update_settings_from_dict,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+logger = logging.getLogger()
+
+
+class GraphRouter(BaseRouterV3):
    def __init__(
        self,
        providers: R2RProviders,
        services: R2RServices,
        config: R2RConfig,
    ):
        """Initialize the graph router and register its orchestration workflows."""
        # NOTE(review): logs via the root logger; the module-level `logger`
        # may have been intended here.
        logging.info("Initializing GraphRouter")
        super().__init__(providers, services, config)
        self._register_workflows()
+
+ def _register_workflows(self):
+ workflow_messages = {}
+ if self.providers.orchestration.config.provider == "hatchet":
+ workflow_messages["graph-extraction"] = (
+ "Document extraction task queued successfully."
+ )
+ workflow_messages["graph-clustering"] = (
+ "Graph enrichment task queued successfully."
+ )
+ workflow_messages["graph-deduplication"] = (
+ "Entity deduplication task queued successfully."
+ )
+ else:
+ workflow_messages["graph-extraction"] = (
+ "Document entities and relationships extracted successfully."
+ )
+ workflow_messages["graph-clustering"] = (
+ "Graph communities created successfully."
+ )
+ workflow_messages["graph-deduplication"] = (
+ "Entity deduplication completed successfully."
+ )
+
+ self.providers.orchestration.register_workflows(
+ Workflow.GRAPH,
+ self.services.graph,
+ workflow_messages,
+ )
+
+ async def _get_collection_id(
+ self, collection_id: Optional[UUID], auth_user
+ ) -> UUID:
+ """Helper method to get collection ID, using default if none
+ provided."""
+ if collection_id is None:
+ return generate_default_user_collection_id(auth_user.id)
+ return collection_id
+
+ def _setup_routes(self):
        @self.router.get(
            "/graphs",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="List graphs",
            openapi_extra={
                "x-codeSamples": [
                    { # TODO: Verify
                        "lang": "Python",
                        "source": textwrap.dedent(
                            """
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.graphs.list()
                            """
                        ),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent(
                            """
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.graphs.list({});
                            }

                            main();
                            """
                        ),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def list_graphs(
            collection_ids: list[str] = Query(
                [],
                description="A list of graph IDs to retrieve. If not provided, all graphs will be returned.",
            ),
            offset: int = Query(
                0,
                ge=0,
                description="Specifies the number of objects to skip. Defaults to 0.",
            ),
            limit: int = Query(
                100,
                ge=1,
                le=1000,
                # NOTE(review): description says "between 1 and 100" but le=1000.
                description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedGraphsResponse:
            """Returns a paginated list of graphs the authenticated user has
            access to.

            Results can be filtered by providing specific graph IDs. Regular
            users will only see graphs they own or have access to. Superusers
            can see all graphs.

            The graphs are returned in order of last modification, with most
            recent first.
            """
            # NOTE(review): computed but currently unused — the user_ids
            # filter below is commented out, so no ownership filtering is
            # actually applied despite the docstring.
            requesting_user_id = (
                None if auth_user.is_superuser else [auth_user.id]
            )

            # The query parameter is named `collection_ids` but its values are
            # forwarded as graph IDs; graphs appear to share IDs with their
            # backing collections — confirm.
            graph_uuids = [UUID(graph_id) for graph_id in collection_ids]

            list_graphs_response = await self.services.graph.list_graphs(
                # user_ids=requesting_user_id,
                graph_ids=graph_uuids,
                offset=offset,
                limit=limit,
            )

            # (results, pagination-metadata) tuple envelope.
            return ( # type: ignore
                list_graphs_response["results"],
                {"total_entries": list_graphs_response["total_entries"]},
            )
+
        @self.router.get(
            "/graphs/{collection_id}",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Retrieve graph details",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.graphs.get(
                                collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
                            )"""),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.graphs.retrieve({
                                    collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X GET "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7" \\
                            -H "Authorization: Bearer YOUR_API_KEY" """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def get_graph(
            collection_id: UUID = Path(...),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedGraphResponse:
            """Retrieves detailed information about a specific graph by ID."""
            # Membership in the backing collection is required even for
            # superusers — the superuser bypass is commented out; confirm
            # this is intended.
            if (
                # not auth_user.is_superuser
                collection_id not in auth_user.collection_ids
            ):
                raise R2RException(
                    "The currently authenticated user does not have access to the specified collection associated with the given graph.",
                    403,
                )

            list_graphs_response = await self.services.graph.list_graphs(
                # user_ids=None,
                graph_ids=[collection_id],
                offset=0,
                limit=1,
            )
            # NOTE(review): assumes "results" is non-empty for an accessible
            # collection; an empty list would surface as an IndexError (500)
            # rather than a 404.
            return list_graphs_response["results"][0] # type: ignore
+
+ @self.router.post(
+ "/graphs/{collection_id}/communities/build",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ )
+ @self.base_endpoint
+ async def build_communities(
+ collection_id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ graph_enrichment_settings: Optional[dict] = Body(
+ default=None,
+ description="Settings for the graph enrichment process.",
+ ),
+ run_with_orchestration: Optional[bool] = Body(True),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Creates communities in the graph by analyzing entity
+ relationships and similarities.
+
+ Communities are created through the following process:
+ 1. Analyzes entity relationships and metadata to build a similarity graph
+ 2. Applies advanced community detection algorithms (e.g. Leiden) to identify densely connected groups
+ 3. Creates hierarchical community structure with multiple granularity levels
+ 4. Generates natural language summaries and statistical insights for each community
+
+ The resulting communities can be used to:
+ - Understand high-level graph structure and organization
+ - Identify key entity groupings and their relationships
+ - Navigate and explore the graph at different levels of detail
+ - Generate insights about entity clusters and their characteristics
+
+ The community detection process is configurable through settings like:
+ - Community detection algorithm parameters
+ - Summary generation prompt
+ """
+ collections_overview_response = (
+ await self.services.management.collections_overview(
+ user_ids=[auth_user.id],
+ collection_ids=[collection_id],
+ offset=0,
+ limit=1,
+ )
+ )["results"]
+ if len(collections_overview_response) == 0: # type: ignore
+ raise R2RException("Collection not found.", 404)
+
+ # Check user permissions for graph
+ if (
+ not auth_user.is_superuser
+ and collections_overview_response[0].owner_id != auth_user.id # type: ignore
+ ):
+ raise R2RException(
+ "Only superusers can `build communities` for a graph they do not own.",
+ 403,
+ )
+
+ # If no collection ID is provided, use the default user collection
+ # id = generate_default_user_collection_id(auth_user.id)
+
+ # Apply runtime settings overrides
+ server_graph_enrichment_settings = (
+ self.providers.database.config.graph_enrichment_settings
+ )
+ if graph_enrichment_settings:
+ server_graph_enrichment_settings = update_settings_from_dict(
+ server_graph_enrichment_settings, graph_enrichment_settings
+ )
+
+ workflow_input = {
+ "collection_id": str(collection_id),
+ "graph_enrichment_settings": server_graph_enrichment_settings.model_dump_json(),
+ "user": auth_user.json(),
+ }
+
+ if run_with_orchestration:
+ try:
+ return await self.providers.orchestration.run_workflow( # type: ignore
+ "graph-clustering", {"request": workflow_input}, {}
+ )
+ return GenericMessageResponse(
+ message="Graph communities created successfully."
+ ) # type: ignore
+
+ except Exception as e: # TODO: Need to find specific error (gRPC most likely?)
+ logger.error(
+ f"Error running orchestrated community building: {e} \n\nAttempting to run without orchestration."
+ )
+ from core.main.orchestration import (
+ simple_graph_search_results_factory,
+ )
+
+ logger.info("Running build-communities without orchestration.")
+ simple_graph_search_results = simple_graph_search_results_factory(
+ self.services.graph
+ )
+ await simple_graph_search_results["graph-clustering"](
+ workflow_input
+ )
+ return { # type: ignore
+ "message": "Graph communities created successfully.",
+ "task_id": None,
+ }
+
        @self.router.post(
            "/graphs/{collection_id}/reset",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Reset a graph back to the initial state.",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.graphs.reset(
                                collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                            )"""),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.graphs.reset({
                                    collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/reset" \\
                            -H "Authorization: Bearer YOUR_API_KEY" """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def reset(
            collection_id: UUID = Path(...),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedBooleanResponse:
            """Deletes a graph and all its associated data.

            This endpoint permanently removes the specified graph along with
            all entities and relationships that belong to only this graph. The
            original source entities and relationships extracted from
            underlying documents are not deleted and are managed through the
            document lifecycle.
            """
            if not auth_user.is_superuser:
                raise R2RException("Only superusers can reset a graph", 403)

            # NOTE(review): a superuser who is not a member of the collection
            # is rejected here even though the gate above requires superuser —
            # confirm both checks are intended.
            if (
                # not auth_user.is_superuser
                collection_id not in auth_user.collection_ids
            ):
                raise R2RException(
                    "The currently authenticated user does not have access to the collection associated with the given graph.",
                    403,
                )

            await self.services.graph.reset_graph(id=collection_id)
            # await _pull(collection_id, auth_user)
            return GenericBooleanResponse(success=True) # type: ignore
+
+ # update graph
+ @self.router.post(
+ "/graphs/{collection_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Update graph",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.update(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ graph={
+ "name": "New Name",
+ "description": "New Description"
+ }
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.graphs.update({
+ collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ name: "New Name",
+ description: "New Description",
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def update_graph(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to update",
+ ),
+ name: Optional[str] = Body(
+ None, description="The name of the graph"
+ ),
+ description: Optional[str] = Body(
+ None, description="An optional description of the graph"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGraphResponse:
+ """Update an existing graphs's configuration.
+
+ This endpoint allows updating the name and description of an
+ existing collection. The user must have appropriate permissions to
+ modify the collection.
+ """
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can update graph details", 403
+ )
+
+ if (
+ not auth_user.is_superuser
+ and id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ return await self.services.graph.update_graph( # type: ignore
+ collection_id,
+ name=name,
+ description=description,
+ )
+
        @self.router.get(
            "/graphs/{collection_id}/entities",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.graphs.list_entities(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7")
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.graphs.listEntities({
                                    collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                                });
                            }

                            main();
                            """),
                    },
                ],
            },
        )
        @self.base_endpoint
        async def get_entities(
            collection_id: UUID = Path(
                ...,
                description="The collection ID corresponding to the graph to list entities from.",
            ),
            offset: int = Query(
                0,
                ge=0,
                description="Specifies the number of objects to skip. Defaults to 0.",
            ),
            limit: int = Query(
                100,
                ge=1,
                le=1000,
                # NOTE(review): description says "between 1 and 100" but le=1000.
                description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedEntitiesResponse:
            """Lists all entities in the graph with pagination support."""
            # Collection membership is required; the superuser bypass is
            # commented out.
            if (
                # not auth_user.is_superuser
                collection_id not in auth_user.collection_ids
            ):
                raise R2RException(
                    "The currently authenticated user does not have access to the collection associated with the given graph.",
                    403,
                )

            entities, count = await self.services.graph.get_entities(
                parent_id=collection_id,
                offset=offset,
                limit=limit,
            )

            # (results, pagination-metadata) tuple envelope.
            return entities, { # type: ignore
                "total_entries": count,
            }
+
        @self.router.post(
            "/graphs/{collection_id}/entities/export",
            summary="Export graph entities to CSV",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient("http://localhost:7272")
                            # when using auth, do client.login(...)

                            response = client.graphs.export_entities(
                                collection_id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                                output_path="export.csv",
                                columns=["id", "title", "created_at"],
                                include_header=True,
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient("http://localhost:7272");

                            function main() {
                                await client.graphs.exportEntities({
                                    collectionId: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                                    outputPath: "export.csv",
                                    columns: ["id", "title", "created_at"],
                                    includeHeader: true,
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "http://127.0.0.1:7272/v3/graphs/export_entities" \
                            -H "Authorization: Bearer YOUR_API_KEY" \
                            -H "Content-Type: application/json" \
                            -H "Accept: text/csv" \
                            -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
                            --output export.csv
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def export_entities(
            background_tasks: BackgroundTasks,
            collection_id: UUID = Path(
                ...,
                description="The ID of the collection to export entities from.",
            ),
            columns: Optional[list[str]] = Body(
                None, description="Specific columns to export"
            ),
            filters: Optional[dict] = Body(
                None, description="Filters to apply to the export"
            ),
            include_header: Optional[bool] = Body(
                True, description="Whether to include column headers"
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> FileResponse:
            """Export graph entities as a downloadable CSV file.

            Superuser only. The CSV is written to a temporary file which is
            closed (and thereby cleaned up) after the response is sent.
            """

            if not auth_user.is_superuser:
                raise R2RException(
                    "Only a superuser can export data.",
                    403,
                )

            (
                csv_file_path,
                temp_file,
            ) = await self.services.management.export_graph_entities(
                id=collection_id,
                columns=columns,
                filters=filters,
                include_header=include_header
                if include_header is not None
                else True,
            )

            # Close the temp file only after the response has been streamed.
            background_tasks.add_task(temp_file.close)

            # NOTE(review): filename looks copy-pasted from the documents
            # router — "entities_export.csv" was presumably intended.
            return FileResponse(
                path=csv_file_path,
                media_type="text/csv",
                filename="documents_export.csv",
            )
+
        @self.router.post(
            "/graphs/{collection_id}/entities",
            dependencies=[Depends(self.rate_limit_dependency)],
        )
        @self.base_endpoint
        async def create_entity(
            collection_id: UUID = Path(
                ...,
                description="The collection ID corresponding to the graph to add the entity to.",
            ),
            name: str = Body(
                ..., description="The name of the entity to create."
            ),
            description: str = Body(
                ..., description="The description of the entity to create."
            ),
            category: Optional[str] = Body(
                None, description="The category of the entity to create."
            ),
            metadata: Optional[dict] = Body(
                None, description="The metadata of the entity to create."
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedEntityResponse:
            """Creates a new entity in the graph."""
            # Collection membership is required; unlike create_relationship
            # below, no superuser gate is applied here — confirm intended.
            if (
                # not auth_user.is_superuser
                collection_id not in auth_user.collection_ids
            ):
                raise R2RException(
                    "The currently authenticated user does not have access to the collection associated with the given graph.",
                    403,
                )

            return await self.services.graph.create_entity( # type: ignore
                name=name,
                description=description,
                parent_id=collection_id,
                category=category,
                metadata=metadata,
            )
+
        @self.router.post(
            "/graphs/{collection_id}/relationships",
            dependencies=[Depends(self.rate_limit_dependency)],
        )
        @self.base_endpoint
        async def create_relationship(
            collection_id: UUID = Path(
                ...,
                description="The collection ID corresponding to the graph to add the relationship to.",
            ),
            subject: str = Body(
                ..., description="The subject of the relationship to create."
            ),
            subject_id: UUID = Body(
                ...,
                description="The ID of the subject of the relationship to create.",
            ),
            predicate: str = Body(
                ..., description="The predicate of the relationship to create."
            ),
            object: str = Body(
                ..., description="The object of the relationship to create."
            ),
            object_id: UUID = Body(
                ...,
                description="The ID of the object of the relationship to create.",
            ),
            description: str = Body(
                ...,
                description="The description of the relationship to create.",
            ),
            weight: float = Body(
                1.0, description="The weight of the relationship to create."
            ),
            metadata: Optional[dict] = Body(
                None, description="The metadata of the relationship to create."
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedRelationshipResponse:
            """Creates a new relationship in the graph.

            Requires superuser status AND membership in the collection.
            """
            if not auth_user.is_superuser:
                raise R2RException(
                    "Only superusers can create relationships.", 403
                )

            if (
                # not auth_user.is_superuser
                collection_id not in auth_user.collection_ids
            ):
                raise R2RException(
                    "The currently authenticated user does not have access to the collection associated with the given graph.",
                    403,
                )
            return await self.services.graph.create_relationship( # type: ignore
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                description=description,
                weight=weight,
                metadata=metadata,
                parent_id=collection_id,
            )
+
        @self.router.post(
            "/graphs/{collection_id}/relationships/export",
            summary="Export graph relationships to CSV",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    # NOTE(review): these samples were copy-pasted from the
                    # entities export and still call export_entities /
                    # exportEntities.
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient("http://localhost:7272")
                            # when using auth, do client.login(...)

                            response = client.graphs.export_entities(
                                collection_id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                                output_path="export.csv",
                                columns=["id", "title", "created_at"],
                                include_header=True,
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient("http://localhost:7272");

                            function main() {
                                await client.graphs.exportEntities({
                                    collectionId: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                                    outputPath: "export.csv",
                                    columns: ["id", "title", "created_at"],
                                    includeHeader: true,
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "http://127.0.0.1:7272/v3/graphs/export_relationships" \
                            -H "Authorization: Bearer YOUR_API_KEY" \
                            -H "Content-Type: application/json" \
                            -H "Accept: text/csv" \
                            -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
                            --output export.csv
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def export_relationships(
            background_tasks: BackgroundTasks,
            collection_id: UUID = Path(
                ...,
                # NOTE(review): says "document" but this is a collection ID.
                description="The ID of the document to export entities from.",
            ),
            columns: Optional[list[str]] = Body(
                None, description="Specific columns to export"
            ),
            filters: Optional[dict] = Body(
                None, description="Filters to apply to the export"
            ),
            include_header: Optional[bool] = Body(
                True, description="Whether to include column headers"
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> FileResponse:
            """Export graph relationships as a downloadable CSV file.

            Superuser only. The CSV is written to a temporary file which is
            closed (and thereby cleaned up) after the response is sent.
            """

            if not auth_user.is_superuser:
                raise R2RException(
                    "Only a superuser can export data.",
                    403,
                )

            (
                csv_file_path,
                temp_file,
            ) = await self.services.management.export_graph_relationships(
                id=collection_id,
                columns=columns,
                filters=filters,
                include_header=include_header
                if include_header is not None
                else True,
            )

            # Close the temp file only after the response has been streamed.
            background_tasks.add_task(temp_file.close)

            # NOTE(review): filename copy-pasted from the documents router —
            # "relationships_export.csv" was presumably intended.
            return FileResponse(
                path=csv_file_path,
                media_type="text/csv",
                filename="documents_export.csv",
            )
+
        @self.router.get(
            "/graphs/{collection_id}/entities/{entity_id}",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.graphs.get_entity(
                                collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                                entity_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.graphs.get_entity({
                                    collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                                    entityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
                                });
                            }

                            main();
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def get_entity(
            collection_id: UUID = Path(
                ...,
                description="The collection ID corresponding to the graph containing the entity.",
            ),
            entity_id: UUID = Path(
                ..., description="The ID of the entity to retrieve."
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedEntityResponse:
            """Retrieves a specific entity by its ID."""
            if (
                # not auth_user.is_superuser
                collection_id not in auth_user.collection_ids
            ):
                raise R2RException(
                    "The currently authenticated user does not have access to the collection associated with the given graph.",
                    403,
                )

            # Bypasses the service layer and queries the graph store directly,
            # filtered down to the single requested entity.
            result = await self.providers.database.graphs_handler.entities.get(
                parent_id=collection_id,
                store_type=StoreType.GRAPHS,
                offset=0,
                limit=1,
                entity_ids=[entity_id],
            )
            # `result` is presumably a (rows, ...) tuple — both levels are
            # checked before indexing; TODO confirm the handler's return shape.
            if len(result) == 0 or len(result[0]) == 0:
                raise R2RException("Entity not found", 404)
            return result[0][0]
+
        @self.router.post(
            "/graphs/{collection_id}/entities/{entity_id}",
            dependencies=[Depends(self.rate_limit_dependency)],
        )
        @self.base_endpoint
        async def update_entity(
            collection_id: UUID = Path(
                ...,
                description="The collection ID corresponding to the graph containing the entity.",
            ),
            entity_id: UUID = Path(
                ..., description="The ID of the entity to update."
            ),
            # NOTE(review): annotated Optional but declared required (Body(...)).
            name: Optional[str] = Body(
                ..., description="The updated name of the entity."
            ),
            description: Optional[str] = Body(
                None, description="The updated description of the entity."
            ),
            category: Optional[str] = Body(
                None, description="The updated category of the entity."
            ),
            metadata: Optional[dict] = Body(
                None, description="The updated metadata of the entity."
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedEntityResponse:
            """Updates an existing entity in the graph.

            Requires superuser status AND membership in the collection.
            """
            if not auth_user.is_superuser:
                raise R2RException(
                    "Only superusers can update graph entities.", 403
                )
            if (
                # not auth_user.is_superuser
                collection_id not in auth_user.collection_ids
            ):
                raise R2RException(
                    "The currently authenticated user does not have access to the collection associated with the given graph.",
                    403,
                )

            # collection_id is only used for the access check above; the
            # entity is addressed solely by entity_id here.
            return await self.services.graph.update_entity( # type: ignore
                entity_id=entity_id,
                name=name,
                category=category,
                description=description,
                metadata=metadata,
            )
+
        @self.router.delete(
            "/graphs/{collection_id}/entities/{entity_id}",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Remove an entity",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.graphs.remove_entity(
                                collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                                entity_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.graphs.removeEntity({
                                    collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                                    entityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
                                });
                            }

                            main();
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def delete_entity(
            collection_id: UUID = Path(
                ...,
                description="The collection ID corresponding to the graph to remove the entity from.",
            ),
            entity_id: UUID = Path(
                ...,
                description="The ID of the entity to remove from the graph.",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedBooleanResponse:
            """Removes an entity from the graph.

            Requires superuser status AND membership in the collection.
            """
            if not auth_user.is_superuser:
                raise R2RException(
                    "Only superusers can delete graph details.", 403
                )

            if (
                # not auth_user.is_superuser
                collection_id not in auth_user.collection_ids
            ):
                raise R2RException(
                    "The currently authenticated user does not have access to the collection associated with the given graph.",
                    403,
                )

            await self.services.graph.delete_entity(
                parent_id=collection_id,
                entity_id=entity_id,
            )

            return GenericBooleanResponse(success=True) # type: ignore
+
        @self.router.get(
            "/graphs/{collection_id}/relationships",
            dependencies=[Depends(self.rate_limit_dependency)],
            description="Lists all relationships in the graph with pagination support.",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.graphs.list_relationships(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7")
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.graphs.listRelationships({
                                    collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                                });
                            }

                            main();
                            """),
                    },
                ],
            },
        )
        @self.base_endpoint
        async def get_relationships(
            collection_id: UUID = Path(
                ...,
                description="The collection ID corresponding to the graph to list relationships from.",
            ),
            offset: int = Query(
                0,
                ge=0,
                description="Specifies the number of objects to skip. Defaults to 0.",
            ),
            limit: int = Query(
                100,
                ge=1,
                le=1000,
                # NOTE(review): description says "between 1 and 100" but le=1000.
                description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedRelationshipsResponse:
            """Lists all relationships in the graph with pagination support."""
            if (
                # not auth_user.is_superuser
                collection_id not in auth_user.collection_ids
            ):
                raise R2RException(
                    "The currently authenticated user does not have access to the collection associated with the given graph.",
                    403,
                )

            relationships, count = await self.services.graph.get_relationships(
                parent_id=collection_id,
                offset=offset,
                limit=limit,
            )

            # (results, pagination-metadata) tuple envelope.
            return relationships, { # type: ignore
                "total_entries": count,
            }
+
        @self.router.get(
            "/graphs/{collection_id}/relationships/{relationship_id}",
            dependencies=[Depends(self.rate_limit_dependency)],
            description="Retrieves a specific relationship by its ID.",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            response = client.graphs.get_relationship(
                                collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                                relationship_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.graphs.getRelationship({
                                    collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                                    relationshipId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
                                });
                            }

                            main();
                            """),
                    },
                ],
            },
        )
        @self.base_endpoint
        async def get_relationship(
            collection_id: UUID = Path(
                ...,
                description="The collection ID corresponding to the graph containing the relationship.",
            ),
            relationship_id: UUID = Path(
                ..., description="The ID of the relationship to retrieve."
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedRelationshipResponse:
            """Retrieves a specific relationship by its ID."""
            if (
                # not auth_user.is_superuser
                collection_id not in auth_user.collection_ids
            ):
                raise R2RException(
                    "The currently authenticated user does not have access to the collection associated with the given graph.",
                    403,
                )

            # Direct store query, filtered to the single requested
            # relationship (mirrors get_entity above).
            results = (
                await self.providers.database.graphs_handler.relationships.get(
                    parent_id=collection_id,
                    store_type=StoreType.GRAPHS,
                    offset=0,
                    limit=1,
                    relationship_ids=[relationship_id],
                )
            )
            # `results` is presumably a (rows, ...) tuple — TODO confirm the
            # handler's return shape.
            if len(results) == 0 or len(results[0]) == 0:
                raise R2RException("Relationship not found", 404)
            return results[0][0]
+
+ # Endpoint: POST /graphs/{collection_id}/relationships/{relationship_id}
+ # Superuser-only update of an existing relationship's fields.
+ @self.router.post(
+ "/graphs/{collection_id}/relationships/{relationship_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ )
+ @self.base_endpoint
+ async def update_relationship(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph containing the relationship.",
+ ),
+ relationship_id: UUID = Path(
+ ..., description="The ID of the relationship to update."
+ ),
+ # The following Optional fields use `...` (required) rather than a
+ # default — callers must send them explicitly, though null is accepted.
+ subject: Optional[str] = Body(
+ ..., description="The updated subject of the relationship."
+ ),
+ subject_id: Optional[UUID] = Body(
+ ..., description="The updated subject ID of the relationship."
+ ),
+ predicate: Optional[str] = Body(
+ ..., description="The updated predicate of the relationship."
+ ),
+ object: Optional[str] = Body(
+ ..., description="The updated object of the relationship."
+ ),
+ object_id: Optional[UUID] = Body(
+ ..., description="The updated object ID of the relationship."
+ ),
+ description: Optional[str] = Body(
+ None,
+ description="The updated description of the relationship.",
+ ),
+ weight: Optional[float] = Body(
+ None, description="The updated weight of the relationship."
+ ),
+ metadata: Optional[dict] = Body(
+ None, description="The updated metadata of the relationship."
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedRelationshipResponse:
+ """Updates an existing relationship in the graph."""
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can update graph details", 403
+ )
+
+ # NOTE(review): only superusers reach this point, yet this check
+ # still requires the collection to be in the superuser's own
+ # collection_ids — a superuser without that membership is rejected.
+ # Confirm whether superusers should bypass this check.
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ return await self.services.graph.update_relationship( # type: ignore
+ relationship_id=relationship_id,
+ subject=subject,
+ subject_id=subject_id,
+ predicate=predicate,
+ object=object,
+ object_id=object_id,
+ description=description,
+ weight=weight,
+ metadata=metadata,
+ )
+
+ # Endpoint: DELETE /graphs/{collection_id}/relationships/{relationship_id}
+ # Superuser-only removal of a relationship from the graph.
+ @self.router.delete(
+ "/graphs/{collection_id}/relationships/{relationship_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ description="Removes a relationship from the graph.",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.delete_relationship(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ relationship_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.graphs.deleteRelationship({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ relationshipId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ });
+ }
+
+ main();
+ """),
+ },
+ ],
+ },
+ )
+ @self.base_endpoint
+ async def delete_relationship(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to remove the relationship from.",
+ ),
+ relationship_id: UUID = Path(
+ ...,
+ description="The ID of the relationship to remove from the graph.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Removes a relationship from the graph."""
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can delete a relationship.", 403
+ )
+
+ # NOTE(review): dead code — non-superusers were already rejected
+ # above, so `not auth_user.is_superuser and ...` can never be true
+ # here. Either check can be removed or the intent re-stated.
+ if (
+ not auth_user.is_superuser
+ and collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ await self.services.graph.delete_relationship(
+ parent_id=collection_id,
+ relationship_id=relationship_id,
+ )
+
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ # Endpoint: POST /graphs/{collection_id}/communities
+ # Superuser-only manual creation of a community in a graph.
+ @self.router.post(
+ "/graphs/{collection_id}/communities",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Create a new community",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.create_community(
+ collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ name="My Community",
+ summary="A summary of the community",
+ findings=["Finding 1", "Finding 2"],
+ rating=5,
+ rating_explanation="This is a rating explanation",
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.graphs.createCommunity({
+ collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ name: "My Community",
+ summary: "A summary of the community",
+ findings: ["Finding 1", "Finding 2"],
+ rating: 5,
+ ratingExplanation: "This is a rating explanation",
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def create_community(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to create the community in.",
+ ),
+ name: str = Body(..., description="The name of the community"),
+ summary: str = Body(..., description="A summary of the community"),
+ # NOTE(review): `default=[]` is a shared mutable default; FastAPI
+ # Body defaults are generally not mutated per-request, but `None`
+ # with a fallback would be the safer idiom — confirm.
+ findings: Optional[list[str]] = Body(
+ default=[], description="Findings about the community"
+ ),
+ rating: Optional[float] = Body(
+ default=5, ge=1, le=10, description="Rating between 1 and 10"
+ ),
+ rating_explanation: Optional[str] = Body(
+ default="", description="Explanation for the rating"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCommunityResponse:
+ """Creates a new community in the graph.
+
+ While communities are typically built automatically via the /graphs/{id}/communities/build endpoint,
+ this endpoint allows you to manually create your own communities.
+
+ This can be useful when you want to:
+ - Define custom groupings of entities based on domain knowledge
+ - Add communities that weren't detected by the automatic process
+ - Create hierarchical organization structures
+ - Tag groups of entities with specific metadata
+
+ The created communities will be integrated with any existing automatically detected communities
+ in the graph's community structure.
+ """
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can create a community.", 403
+ )
+
+ # NOTE(review): dead code — non-superusers already raised above,
+ # so this condition can never be true.
+ if (
+ not auth_user.is_superuser
+ and collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ return await self.services.graph.create_community( # type: ignore
+ parent_id=collection_id,
+ name=name,
+ summary=summary,
+ findings=findings,
+ rating=rating,
+ rating_explanation=rating_explanation,
+ )
+
+ # Endpoint: GET /graphs/{collection_id}/communities
+ # Paginated listing of a graph's communities.
+ @self.router.get(
+ "/graphs/{collection_id}/communities",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List communities",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.list_communities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.graphs.listCommunities({
+ collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_communities(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to get communities for.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ # NOTE(review): description says "between 1 and 100" but the
+ # validator allows up to 1000 (le=1000) — align one with the other.
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCommunitiesResponse:
+ """Lists all communities in the graph with pagination support."""
+ # NOTE(review): superuser bypass is commented out — superusers are
+ # subject to the collection-membership check here; confirm intent.
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ communities, count = await self.services.graph.get_communities(
+ parent_id=collection_id,
+ offset=offset,
+ limit=limit,
+ )
+
+ # base_endpoint wraps the (results, metadata) pair into the response.
+ return communities, { # type: ignore
+ "total_entries": count,
+ }
+
+ # Endpoint: GET /graphs/{collection_id}/communities/{community_id}
+ # Fetches a single community from the graph store by ID; 404 if absent.
+ @self.router.get(
+ "/graphs/{collection_id}/communities/{community_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Retrieve a community",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.get_community(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.graphs.getCommunity({
+ collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_community(
+ collection_id: UUID = Path(
+ ...,
+ description="The ID of the collection to get communities for.",
+ ),
+ community_id: UUID = Path(
+ ...,
+ description="The ID of the community to get.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCommunityResponse:
+ """Retrieves a specific community by its ID."""
+ # NOTE(review): superuser bypass commented out, as elsewhere —
+ # superusers must also hold collection membership; confirm intent.
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ # limit=1 + explicit ID filter mirrors get_relationship above.
+ results = (
+ await self.providers.database.graphs_handler.communities.get(
+ parent_id=collection_id,
+ community_ids=[community_id],
+ store_type=StoreType.GRAPHS,
+ offset=0,
+ limit=1,
+ )
+ )
+ if len(results) == 0 or len(results[0]) == 0:
+ raise R2RException("Community not found", 404)
+ return results[0][0]
+
+ # Endpoint: DELETE /graphs/{collection_id}/communities/{community_id}
+ # Removes a community from a graph.
+ @self.router.delete(
+ "/graphs/{collection_id}/communities/{community_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Delete a community",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.delete_community(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ community_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.graphs.deleteCommunity({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ communityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_community(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to delete the community from.",
+ ),
+ community_id: UUID = Path(
+ ...,
+ description="The ID of the community to delete.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ # NOTE(review): this check reads `auth_user.graph_ids`, while every
+ # sibling endpoint uses `auth_user.collection_ids`. Confirm
+ # `graph_ids` exists on the auth user model — otherwise this raises
+ # AttributeError for non-superusers instead of a clean 403.
+ if (
+ not auth_user.is_superuser
+ and collection_id not in auth_user.graph_ids
+ ):
+ raise R2RException(
+ "Only superusers can delete communities", 403
+ )
+
+ # NOTE(review): applies to superusers too (bypass commented out).
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ await self.services.graph.delete_community(
+ parent_id=collection_id,
+ community_id=community_id,
+ )
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ # Endpoint: POST /graphs/{collection_id}/communities/export
+ # Superuser-only CSV export of a graph's communities, streamed as a file.
+ @self.router.post(
+ "/graphs/{collection_id}/communities/export",
+ summary="Export document communities to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.graphs.export_communities(
+ collection_id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ output_path="export.csv",
+ columns=["id", "title", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+ function main() {
+ await client.graphs.exportCommunities({
+ collectionId: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ outputPath: "export.csv",
+ columns: ["id", "title", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/graphs/export_communities" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_communities(
+ background_tasks: BackgroundTasks,
+ # NOTE(review): description says "document ... entities" but this
+ # endpoint exports a collection's graph communities.
+ collection_id: UUID = Path(
+ ...,
+ description="The ID of the document to export entities from.",
+ ),
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export documents as a downloadable CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_graph_communities(
+ id=collection_id,
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ # Close (and thereby clean up) the temp file after the response
+ # has been sent.
+ background_tasks.add_task(temp_file.close)
+
+ # NOTE(review): downloaded filename is "documents_export.csv" for a
+ # communities export — likely copy/paste from the documents router.
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="documents_export.csv",
+ )
+
+ # Endpoint: POST /graphs/{collection_id}/communities/{community_id}
+ # Updates an existing community's fields; None fields are left unchanged
+ # by the service layer (presumably — confirm in GraphService).
+ @self.router.post(
+ "/graphs/{collection_id}/communities/{community_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Update community",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.update_community(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ community_update={
+ "metadata": {
+ "topic": "Technology",
+ "description": "Tech companies and products"
+ }
+ }
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.updateCommunity({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ communityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ communityUpdate: {
+ metadata: {
+ topic: "Technology",
+ description: "Tech companies and products"
+ }
+ }
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def update_community(
+ collection_id: UUID = Path(...),
+ community_id: UUID = Path(...),
+ name: Optional[str] = Body(None),
+ summary: Optional[str] = Body(None),
+ findings: Optional[list[str]] = Body(None),
+ rating: Optional[float] = Body(default=None, ge=1, le=10),
+ rating_explanation: Optional[str] = Body(None),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCommunityResponse:
+ """Updates an existing community in the graph."""
+ # NOTE(review): reads `auth_user.graph_ids` here, unlike siblings
+ # that use `collection_ids` — confirm the attribute exists.
+ if (
+ not auth_user.is_superuser
+ and collection_id not in auth_user.graph_ids
+ ):
+ raise R2RException(
+ "Only superusers can update communities.", 403
+ )
+
+ # NOTE(review): applies to superusers too (bypass commented out).
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ return await self.services.graph.update_community( # type: ignore
+ community_id=community_id,
+ name=name,
+ summary=summary,
+ findings=findings,
+ rating=rating,
+ rating_explanation=rating_explanation,
+ )
+
+ # Endpoint: POST /graphs/{collection_id}/pull
+ # Copies the entities/relationships of every document in the collection
+ # into the graph's own store, skipping documents already present.
+ @self.router.post(
+ "/graphs/{collection_id}/pull",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Pull latest entities to the graph",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.pull(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ # NOTE(review): sample passes snake_case `collection_id`;
+ # other JS samples in this file use camelCase keys.
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.pull({
+ collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def pull(
+ collection_id: UUID = Path(
+ ..., description="The ID of the graph to initialize."
+ ),
+ force: Optional[bool] = Body(
+ False,
+ description="If true, forces a re-pull of all entities and relationships.",
+ ),
+ # document_ids: list[UUID] = Body(
+ # ..., description="List of document IDs to add to the graph."
+ # ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Adds documents to a graph by copying their entities and
+ relationships.
+
+ This endpoint:
+ 1. Copies document entities to the graphs_entities table
+ 2. Copies document relationships to the graphs_relationships table
+ 3. Associates the documents with the graph
+
+ When a document is added:
+ - Its entities and relationships are copied to graph-specific tables
+ - Existing entities/relationships are updated by merging their properties
+ - The document ID is recorded in the graph's document_ids array
+
+ Documents added to a graph will contribute their knowledge to:
+ - Graph analysis and querying
+ - Community detection
+ - Knowledge graph enrichment
+
+ The user must have access to both the graph and the documents being added.
+ """
+
+ # Scoped to the caller's user ID: non-owners get a 404 here rather
+ # than a 403, since the overview query returns nothing for them.
+ collections_overview_response = (
+ await self.services.management.collections_overview(
+ user_ids=[auth_user.id],
+ collection_ids=[collection_id],
+ offset=0,
+ limit=1,
+ )
+ )["results"]
+ if len(collections_overview_response) == 0:  # type: ignore
+ raise R2RException("Collection not found.", 404)
+
+ # Check user permissions for graph
+ if (
+ not auth_user.is_superuser
+ and collections_overview_response[0].owner_id != auth_user.id  # type: ignore
+ ):
+ raise R2RException("Only superusers can `pull` a graph.", 403)
+
+ # NOTE(review): applies to superusers too (bypass commented out).
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ list_graphs_response = await self.services.graph.list_graphs(
+ # user_ids=None,
+ graph_ids=[collection_id],
+ offset=0,
+ limit=1,
+ )
+ if len(list_graphs_response["results"]) == 0:  # type: ignore
+ raise R2RException("Graph not found", 404)
+ collection_id = list_graphs_response["results"][0].collection_id  # type: ignore
+ # Page through the collection's documents 100 at a time.
+ documents: list[DocumentResponse] = []
+ document_req = await self.providers.database.collections_handler.documents_in_collection(
+ collection_id, offset=0, limit=100
+ )
+ results = cast(list[DocumentResponse], document_req["results"])
+ documents.extend(results)
+
+ while len(results) == 100:
+ document_req = await self.providers.database.collections_handler.documents_in_collection(
+ collection_id, offset=len(documents), limit=100
+ )
+ results = cast(list[DocumentResponse], document_req["results"])
+ documents.extend(results)
+
+ success = False
+
+ # NOTE(review): `success` is overwritten on every iteration, so the
+ # final value only reflects the LAST document processed; earlier
+ # failures are logged but not surfaced. Confirm this is intended.
+ for document in documents:
+ # Fetched only to check whether extraction produced any
+ # entities for this document (see len() test below).
+ entities = (
+ await self.providers.database.graphs_handler.entities.get(
+ parent_id=document.id,
+ store_type=StoreType.DOCUMENTS,
+ offset=0,
+ limit=100,
+ )
+ )
+ has_document = (
+ await self.providers.database.graphs_handler.has_document(
+ collection_id, document.id
+ )
+ )
+ if has_document:
+ logger.info(
+ f"Document {document.id} is already in graph {collection_id}, skipping."
+ )
+ continue
+ if len(entities[0]) == 0:
+ if not force:
+ logger.warning(
+ f"Document {document.id} has no entities, extraction may not have been called, skipping."
+ )
+ continue
+ else:
+ logger.warning(
+ f"Document {document.id} has no entities, but force=True, continuing."
+ )
+
+ success = (
+ await self.providers.database.graphs_handler.add_documents(
+ id=collection_id,
+ document_ids=[document.id],
+ )
+ )
+ if not success:
+ logger.warning(
+ f"No documents were added to graph {collection_id}, marking as failed."
+ )
+
+ if success:
+ await self.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.SUCCESS,
+ )
+
+ return GenericBooleanResponse(success=success)  # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/indices_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/indices_router.py
new file mode 100644
index 00000000..29b75226
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/indices_router.py
@@ -0,0 +1,576 @@
+import logging
+import textwrap
+from typing import Optional
+
+from fastapi import Body, Depends, Path, Query
+
+from core.base import IndexConfig, R2RException
+from core.base.abstractions import VectorTableName
+from core.base.api.models import (
+ VectorIndexResponse,
+ VectorIndicesResponse,
+ WrappedGenericMessageResponse,
+ WrappedVectorIndexResponse,
+ WrappedVectorIndicesResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+logger = logging.getLogger()
+
+
+class IndicesRouter(BaseRouterV3):
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ """Initialize the router, delegating setup to BaseRouterV3."""
+ # NOTE(review): uses root `logging.info` rather than the module
+ # `logger` defined above — confirm which is the file convention.
+ logging.info("Initializing IndicesRouter")
+ super().__init__(providers, services, config)
+
+ def _setup_routes(self):
+ ## TODO - Allow developer to pass the index id with the request
+ # Endpoint: POST /indices
+ # Dispatches a "create-vector-index" workflow to the orchestration
+ # provider; the workflow result is returned to the caller as-is.
+ @self.router.post(
+ "/indices",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Create Vector Index",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ # Create an HNSW index for efficient similarity search
+ result = client.indices.create(
+ config={
+ "table_name": "chunks", # The table containing vector embeddings
+ "index_method": "hnsw", # Hierarchical Navigable Small World graph
+ "index_measure": "cosine_distance", # Similarity measure
+ "index_arguments": {
+ "m": 16, # Number of connections per layer
+ "ef_construction": 64,# Size of dynamic candidate list for construction
+ "ef": 40, # Size of dynamic candidate list for search
+ },
+ "index_name": "my_document_embeddings_idx",
+ "index_column": "embedding",
+ "concurrently": True # Build index without blocking table writes
+ },
+ run_with_orchestration=True # Run as orchestrated task for large indices
+ )
+
+ # Create an IVF-Flat index for balanced performance
+ result = client.indices.create(
+ config={
+ "table_name": "chunks",
+ "index_method": "ivf_flat", # Inverted File with Flat storage
+ "index_measure": "l2_distance",
+ "index_arguments": {
+ "lists": 100, # Number of cluster centroids
+ "probe": 10, # Number of clusters to search
+ },
+ "index_name": "my_ivf_embeddings_idx",
+ "index_column": "embedding",
+ "concurrently": True
+ }
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.indicies.create({
+ config: {
+ tableName: "vectors",
+ indexMethod: "hnsw",
+ indexMeasure: "cosine_distance",
+ indexArguments: {
+ m: 16,
+ ef_construction: 64,
+ ef: 40
+ },
+ indexName: "my_document_embeddings_idx",
+ indexColumn: "embedding",
+ concurrently: true
+ },
+ runWithOrchestration: true
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ # Create HNSW Index
+ curl -X POST "https://api.example.com/indices" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "config": {
+ "table_name": "vectors",
+ "index_method": "hnsw",
+ "index_measure": "cosine_distance",
+ "index_arguments": {
+ "m": 16,
+ "ef_construction": 64,
+ "ef": 40
+ },
+ "index_name": "my_document_embeddings_idx",
+ "index_column": "embedding",
+ "concurrently": true
+ },
+ "run_with_orchestration": true
+ }'
+
+ # Create IVF-Flat Index
+ curl -X POST "https://api.example.com/indices" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "config": {
+ "table_name": "vectors",
+ "index_method": "ivf_flat",
+ "index_measure": "l2_distance",
+ "index_arguments": {
+ "lists": 100,
+ "probe": 10
+ },
+ "index_name": "my_ivf_embeddings_idx",
+ "index_column": "embedding",
+ "concurrently": true
+ }
+ }'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def create_index(
+ config: IndexConfig,
+ # NOTE(review): accepted but not used below — the workflow is
+ # always dispatched through the orchestration provider.
+ run_with_orchestration: Optional[bool] = Body(
+ True,
+ description="Whether to run index creation as an orchestrated task (recommended for large indices)",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Create a new vector similarity search index in over the target
+ table. Allowed tables include 'vectors', 'entity',
+ 'document_collections'. Vectors correspond to the chunks of text
+ that are indexed for similarity search, whereas entity and
+ document_collections are created during knowledge graph
+ construction.
+
+ This endpoint creates a database index optimized for efficient similarity search over vector embeddings.
+ It supports two main indexing methods:
+
+ 1. HNSW (Hierarchical Navigable Small World):
+ - Best for: High-dimensional vectors requiring fast approximate nearest neighbor search
+ - Pros: Very fast search, good recall, memory-resident for speed
+ - Cons: Slower index construction, more memory usage
+ - Key parameters:
+ * m: Number of connections per layer (higher = better recall but more memory)
+ * ef_construction: Build-time search width (higher = better recall but slower build)
+ * ef: Query-time search width (higher = better recall but slower search)
+
+ 2. IVF-Flat (Inverted File with Flat Storage):
+ - Best for: Balance between build speed, search speed, and recall
+ - Pros: Faster index construction, less memory usage
+ - Cons: Slightly slower search than HNSW
+ - Key parameters:
+ * lists: Number of clusters (usually sqrt(n) where n is number of vectors)
+ * probe: Number of nearest clusters to search
+
+ Supported similarity measures:
+ - cosine_distance: Best for comparing semantic similarity
+ - l2_distance: Best for comparing absolute distances
+ - ip_distance: Best for comparing raw dot products
+
+ Notes:
+ - Index creation can be resource-intensive for large datasets
+ - Use run_with_orchestration=True for large indices to prevent timeouts
+ - The 'concurrently' option allows other operations while building
+ - Index names must be unique per table
+ """
+ # TODO: Implement index creation logic
+ logger.info(
+ f"Creating vector index for {config.table_name} with method {config.index_method}, measure {config.index_measure}, concurrently {config.concurrently}"
+ )
+
+ result = await self.providers.orchestration.run_workflow(
+ "create-vector-index",
+ {
+ "request": {
+ "table_name": config.table_name,
+ "index_method": config.index_method,
+ "index_measure": config.index_measure,
+ "index_name": config.index_name,
+ "index_column": config.index_column,
+ "index_arguments": config.index_arguments,
+ "concurrently": config.concurrently,
+ },
+ },
+ options={
+ "additional_metadata": {},
+ },
+ )
+
+ return result  # type: ignore
+
+ # Endpoint: GET /indices
+ # Paginated listing of vector indices from the chunks handler.
+ @self.router.get(
+ "/indices",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List Vector Indices",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+
+ # List all indices
+ indices = client.indices.list(
+ offset=0,
+ limit=10
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.indicies.list({
+ offset: 0,
+ limit: 10,
+ filters: { table_name: "vectors" }
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/indices?offset=0&limit=10" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -H "Content-Type: application/json"
+
+ # With filters
+ curl -X GET "https://api.example.com/indices?offset=0&limit=10&filters={\"table_name\":\"vectors\"}" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -H "Content-Type: application/json"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def list_indices(
+ # filters: list[str] = Query([]),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ # NOTE(review): description says "between 1 and 100" but le=1000.
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedVectorIndicesResponse:
+ """List existing vector similarity search indices with pagination
+ support.
+
+ Returns details about each index including:
+ - Name and table name
+ - Indexing method and parameters
+ - Size and row count
+ - Creation timestamp and last updated
+ - Performance statistics (if available)
+
+ The response can be filtered using the filter_by parameter to narrow down results
+ based on table name, index method, or other attributes.
+ """
+ # TODO: Implement index listing logic
+ indices_data = (
+ await self.providers.database.chunks_handler.list_indices(
+ offset=offset, limit=limit
+ )
+ )
+
+ formatted_indices = VectorIndicesResponse(
+ indices=[
+ VectorIndexResponse(index=index_data)
+ for index_data in indices_data["indices"]
+ ]
+ )
+
+ # base_endpoint wraps the (results, metadata) pair into the response.
+ return (  # type: ignore
+ formatted_indices,
+ {"total_entries": indices_data["total_entries"]},
+ )
+
+ # Endpoint: GET /indices/{table_name}/{index_name}
+ # Looks up one index via the filtered list query; 404 if not exactly one.
+ @self.router.get(
+ "/indices/{table_name}/{index_name}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Get Vector Index Details",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+
+ # Get detailed information about a specific index
+ index = client.indices.retrieve("index_1")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.indicies.retrieve({
+ indexName: "index_1",
+ tableName: "vectors"
+ });
+
+ console.log(response);
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/indices/vectors/index_1" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_index(
+ # NOTE(review): both parameter descriptions say "delete" although
+ # this is the retrieve endpoint — likely copied from the DELETE
+ # route's definition.
+ table_name: VectorTableName = Path(
+ ...,
+ description="The table of vector embeddings to delete (e.g. `vectors`, `entity`, `document_collections`)",
+ ),
+ index_name: str = Path(
+ ..., description="The name of the index to delete"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedVectorIndexResponse:
+ """Get detailed information about a specific vector index.
+
+ Returns comprehensive information about the index including:
+ - Configuration details (method, measure, parameters)
+ - Current size and row count
+ - Build progress (if still under construction)
+ - Performance statistics:
+ * Average query time
+ * Memory usage
+ * Cache hit rates
+ * Recent query patterns
+ - Maintenance information:
+ * Last vacuum
+ * Fragmentation level
+ * Recommended optimizations
+ """
+ # TODO: Implement get index logic
+ indices = (
+ await self.providers.database.chunks_handler.list_indices(
+ filters={
+ "index_name": index_name,
+ "table_name": table_name,
+ },
+ limit=1,
+ offset=0,
+ )
+ )
+ if len(indices["indices"]) != 1:
+ raise R2RException(
+ f"Index '{index_name}' not found", status_code=404
+ )
+ return {"index": indices["indices"][0]}  # type: ignore
+
+ # TODO - Implement update index
+ # @self.router.post(
+ # "/indices/{name}",
+ # summary="Update Vector Index",
+ # openapi_extra={
+ # "x-codeSamples": [
+ # {
+ # "lang": "Python",
+ # "source": """
+ # from r2r import R2RClient
+
+ # client = R2RClient()
+
+ # # Update HNSW index parameters
+ # result = client.indices.update(
+ # "550e8400-e29b-41d4-a716-446655440000",
+ # config={
+ # "index_arguments": {
+ # "ef": 80, # Increase search quality
+ # "m": 24 # Increase connections per layer
+ # },
+ # "concurrently": True
+ # },
+ # run_with_orchestration=True
+ # )""",
+ # },
+ # {
+ # "lang": "Shell",
+ # "source": """
+ # curl -X PUT "https://api.example.com/indices/550e8400-e29b-41d4-a716-446655440000" \\
+ # -H "Content-Type: application/json" \\
+ # -H "Authorization: Bearer YOUR_API_KEY" \\
+ # -d '{
+ # "config": {
+ # "index_arguments": {
+ # "ef": 80,
+ # "m": 24
+ # },
+ # "concurrently": true
+ # },
+ # "run_with_orchestration": true
+ # }'""",
+ # },
+ # ]
+ # },
+ # )
+ # @self.base_endpoint
+ # async def update_index(
+ # id: UUID = Path(...),
+ # config: IndexConfig = Body(...),
+ # run_with_orchestration: Optional[bool] = Body(True),
+ # auth_user=Depends(self.providers.auth.auth_wrapper()),
+ # ): # -> WrappedUpdateIndexResponse:
+ # """
+ # Update an existing index's configuration.
+ # """
+ # # TODO: Implement index update logic
+ # pass
+
+ @self.router.delete(
+ "/indices/{table_name}/{index_name}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Delete Vector Index",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+
+ # Delete an index with orchestration for cleanup
+ result = client.indices.delete(
+ index_name="index_1",
+ table_name="vectors",
+ run_with_orchestration=True
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.indicies.delete({
+ indexName: "index_1"
+ tableName: "vectors"
+ });
+
+ console.log(response);
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/indices/index_1" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_index(
+ table_name: VectorTableName = Path(
+ default=...,
+ description="The table of vector embeddings to delete (e.g. `vectors`, `entity`, `document_collections`)",
+ ),
+ index_name: str = Path(
+ ..., description="The name of the index to delete"
+ ),
+ # concurrently: bool = Body(
+ # default=True,
+ # description="Whether to delete the index concurrently (recommended for large indices)",
+ # ),
+ # run_with_orchestration: Optional[bool] = Body(True),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Delete an existing vector similarity search index.
+
+ This endpoint removes the specified index from the database. Important considerations:
+
+ - Deletion is permanent and cannot be undone
+ - Underlying vector data remains intact
+ - Queries will fall back to sequential scan
+ - Running queries during deletion may be slower
+ - Use run_with_orchestration=True for large indices to prevent timeouts
+ - Consider index dependencies before deletion
+
+ The operation returns immediately but cleanup may continue in background.
+ """
+ logger.info(
+ f"Deleting vector index {index_name} from table {table_name}"
+ )
+
+ return await self.providers.orchestration.run_workflow( # type: ignore
+ "delete-vector-index",
+ {
+ "request": {
+ "index_name": index_name,
+ "table_name": table_name,
+ "concurrently": True,
+ },
+ },
+ options={
+ "additional_metadata": {},
+ },
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/prompts_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/prompts_router.py
new file mode 100644
index 00000000..55512143
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/prompts_router.py
@@ -0,0 +1,387 @@
+import logging
+import textwrap
+from typing import Optional
+
+from fastapi import Body, Depends, Path, Query
+
+from core.base import R2RException
+from core.base.api.models import (
+ GenericBooleanResponse,
+ GenericMessageResponse,
+ WrappedBooleanResponse,
+ WrappedGenericMessageResponse,
+ WrappedPromptResponse,
+ WrappedPromptsResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+
class PromptsRouter(BaseRouterV3):
    """Router exposing CRUD endpoints for prompt templates.

    Every endpoint requires a superuser and delegates to the management
    service.
    """

    def __init__(
        self, providers: R2RProviders, services: R2RServices, config: R2RConfig
    ):
        logging.info("Initializing PromptsRouter")
        super().__init__(providers, services, config)

    def _setup_routes(self):
        @self.router.post(
            "/prompts",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Create a new prompt",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            result = client.prompts.create(
                                name="greeting_prompt",
                                template="Hello, {name}!",
                                input_types={"name": "string"}
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        # Fixed sample: `await` requires an async function.
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            async function main() {
                                const response = await client.prompts.create({
                                    name: "greeting_prompt",
                                    template: "Hello, {name}!",
                                    inputTypes: { name: "string" },
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "https://api.example.com/v3/prompts" \\
                            -H "Authorization: Bearer YOUR_API_KEY" \\
                            -H "Content-Type: application/json" \\
                            -d '{"name": "greeting_prompt", "template": "Hello, {name}!", "input_types": {"name": "string"}}'
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def create_prompt(
            name: str = Body(..., description="The name of the prompt"),
            template: str = Body(
                ..., description="The template string for the prompt"
            ),
            input_types: dict[str, str] = Body(
                default={},
                description="A dictionary mapping input names to their types",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedGenericMessageResponse:
            """Create a new prompt with the given configuration.

            This endpoint allows superusers to create a new prompt with a
            specified name, template, and input types.

            Raises:
                R2RException: 403 if the caller is not a superuser.
            """
            if not auth_user.is_superuser:
                raise R2RException(
                    "Only a superuser can create prompts.",
                    403,
                )
            result = await self.services.management.add_prompt(
                name, template, input_types
            )
            return GenericMessageResponse(message=result)  # type: ignore

        @self.router.get(
            "/prompts",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="List all prompts",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            result = client.prompts.list()
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        # Fixed sample: `await` requires an async function.
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            async function main() {
                                const response = await client.prompts.list();
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X GET "https://api.example.com/v3/prompts" \\
                            -H "Authorization: Bearer YOUR_API_KEY"
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def get_prompts(
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedPromptsResponse:
            """List all available prompts.

            This endpoint retrieves a list of all prompts in the system. Only
            superusers can access this endpoint.

            Raises:
                R2RException: 403 if the caller is not a superuser.
            """
            if not auth_user.is_superuser:
                raise R2RException(
                    "Only a superuser can list prompts.",
                    403,
                )
            get_prompts_response = (
                await self.services.management.get_all_prompts()
            )

            # (results, metadata) tuple is interpreted by base_endpoint as
            # a paginated response envelope.
            return (  # type: ignore
                get_prompts_response["results"],
                {
                    "total_entries": get_prompts_response["total_entries"],
                },
            )

        @self.router.post(
            "/prompts/{name}",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Get a specific prompt",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            result = client.prompts.get(
                                "greeting_prompt",
                                inputs={"name": "John"},
                                prompt_override="Hi, {name}!"
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        # Fixed sample: `await` requires an async function.
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            async function main() {
                                const response = await client.prompts.retrieve({
                                    name: "greeting_prompt",
                                    inputs: { name: "John" },
                                    promptOverride: "Hi, {name}!",
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        # Fixed sample: `inputs` is a request-body field, not a
                        # query parameter; only prompt_override is a query param.
                        "source": textwrap.dedent("""
                            curl -X POST "https://api.example.com/v3/prompts/greeting_prompt?prompt_override=Hi%2C%20%7Bname%7D!" \\
                            -H "Authorization: Bearer YOUR_API_KEY" \\
                            -H "Content-Type: application/json" \\
                            -d '{"inputs": {"name": "John"}}'
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def get_prompt(
            name: str = Path(..., description="Prompt name"),
            inputs: Optional[dict[str, str]] = Body(
                None, description="Prompt inputs"
            ),
            prompt_override: Optional[str] = Query(
                None, description="Prompt override"
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedPromptResponse:
            """Get a specific prompt by name, optionally with inputs and
            override.

            This endpoint retrieves a specific prompt and allows for optional
            inputs and template override. Only superusers can access this
            endpoint.

            Raises:
                R2RException: 403 if the caller is not a superuser.
            """
            if not auth_user.is_superuser:
                raise R2RException(
                    "Only a superuser can retrieve prompts.",
                    403,
                )
            result = await self.services.management.get_prompt(
                name, inputs, prompt_override
            )
            return result  # type: ignore

        @self.router.put(
            "/prompts/{name}",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Update an existing prompt",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            result = client.prompts.update(
                                "greeting_prompt",
                                template="Greetings, {name}!",
                                input_types={"name": "string", "age": "integer"}
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        # Fixed sample: `await` requires an async function.
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            async function main() {
                                const response = await client.prompts.update({
                                    name: "greeting_prompt",
                                    template: "Greetings, {name}!",
                                    inputTypes: { name: "string", age: "integer" },
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X PUT "https://api.example.com/v3/prompts/greeting_prompt" \\
                            -H "Authorization: Bearer YOUR_API_KEY" \\
                            -H "Content-Type: application/json" \\
                            -d '{"template": "Greetings, {name}!", "input_types": {"name": "string", "age": "integer"}}'
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def update_prompt(
            name: str = Path(..., description="Prompt name"),
            template: Optional[str] = Body(
                None, description="Updated prompt template"
            ),
            input_types: dict[str, str] = Body(
                default={},
                description="A dictionary mapping input names to their types",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedGenericMessageResponse:
            """Update an existing prompt's template and/or input types.

            This endpoint allows superusers to update the template and input
            types of an existing prompt.

            Raises:
                R2RException: 403 if the caller is not a superuser.
            """
            if not auth_user.is_superuser:
                raise R2RException(
                    "Only a superuser can update prompts.",
                    403,
                )
            # NOTE(review): omitting input_types sends {} to the service —
            # confirm the service treats an empty dict as "no change" rather
            # than clearing the existing input types.
            result = await self.services.management.update_prompt(
                name, template, input_types
            )
            return GenericMessageResponse(message=result)  # type: ignore

        @self.router.delete(
            "/prompts/{name}",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Delete a prompt",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            result = client.prompts.delete("greeting_prompt")
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        # Fixed sample: `await` requires an async function.
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            async function main() {
                                const response = await client.prompts.delete({
                                    name: "greeting_prompt",
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X DELETE "https://api.example.com/v3/prompts/greeting_prompt" \\
                            -H "Authorization: Bearer YOUR_API_KEY"
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def delete_prompt(
            name: str = Path(..., description="Prompt name"),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedBooleanResponse:
            """Delete a prompt by name.

            This endpoint allows superusers to delete an existing prompt.

            Raises:
                R2RException: 403 if the caller is not a superuser.
            """
            if not auth_user.is_superuser:
                raise R2RException(
                    "Only a superuser can delete prompts.",
                    403,
                )
            await self.services.management.delete_prompt(name)
            return GenericBooleanResponse(success=True)  # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/retrieval_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/retrieval_router.py
new file mode 100644
index 00000000..28749319
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/retrieval_router.py
@@ -0,0 +1,639 @@
+import logging
+from typing import Any, Literal, Optional
+from uuid import UUID
+
+from fastapi import Body, Depends
+from fastapi.responses import StreamingResponse
+
+from core.base import (
+ GenerationConfig,
+ Message,
+ R2RException,
+ SearchMode,
+ SearchSettings,
+ select_search_filters,
+)
+from core.base.api.models import (
+ WrappedAgentResponse,
+ WrappedCompletionResponse,
+ WrappedEmbeddingResponse,
+ WrappedLLMChatCompletion,
+ WrappedRAGResponse,
+ WrappedSearchResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+from .examples import EXAMPLES
+
+logger = logging.getLogger(__name__)
+
+
def merge_search_settings(
    base: SearchSettings, overrides: SearchSettings
) -> SearchSettings:
    """Merge two search-settings objects into a new one.

    Fields that were explicitly set on ``overrides`` take precedence;
    all other fields keep the values from ``base``.

    Args:
        base: The starting settings (e.g. a search-mode default).
        overrides: Caller-supplied settings; only fields the caller
            actually set (``exclude_unset``) participate in the merge.

    Returns:
        A freshly constructed ``SearchSettings`` with the merged values.
    """
    merged = base.model_dump()
    # dict.update applies the override fields in one pass — equivalent to
    # assigning each key individually, with overrides winning.
    merged.update(overrides.model_dump(exclude_unset=True))
    return SearchSettings(**merged)
+
+
+class RetrievalRouter(BaseRouterV3):
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ logging.info("Initializing RetrievalRouter")
+ super().__init__(providers, services, config)
+
    def _register_workflows(self):
        """No orchestration workflows are registered for retrieval routes."""
        pass
+
+ def _prepare_search_settings(
+ self,
+ auth_user: Any,
+ search_mode: SearchMode,
+ search_settings: Optional[SearchSettings],
+ ) -> SearchSettings:
+ """Prepare the effective search settings based on the provided
+ search_mode, optional user-overrides in search_settings, and applied
+ filters."""
+ if search_mode != SearchMode.custom:
+ # Start from mode defaults
+ effective_settings = SearchSettings.get_default(search_mode.value)
+ if search_settings:
+ # Merge user-provided overrides
+ effective_settings = merge_search_settings(
+ effective_settings, search_settings
+ )
+ else:
+ # Custom mode: use provided settings or defaults
+ effective_settings = search_settings or SearchSettings()
+
+ # Apply user-specific filters
+ effective_settings.filters = select_search_filters(
+ auth_user, effective_settings
+ )
+ return effective_settings
+
+ def _setup_routes(self):
+ @self.router.post(
+ "/retrieval/search",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Search R2R",
+ openapi_extra=EXAMPLES["search"],
+ )
+ @self.base_endpoint
+ async def search_app(
+ query: str = Body(
+ ...,
+ description="Search query to find relevant documents",
+ ),
+ search_mode: SearchMode = Body(
+ default=SearchMode.custom,
+ description=(
+ "Default value of `custom` allows full control over search settings.\n\n"
+ "Pre-configured search modes:\n"
+ "`basic`: A simple semantic-based search.\n"
+ "`advanced`: A more powerful hybrid search combining semantic and full-text.\n"
+ "`custom`: Full control via `search_settings`.\n\n"
+ "If `filters` or `limit` are provided alongside `basic` or `advanced`, "
+ "they will override the default settings for that mode."
+ ),
+ ),
+ search_settings: Optional[SearchSettings] = Body(
+ None,
+ description=(
+ "The search configuration object. If `search_mode` is `custom`, "
+ "these settings are used as-is. For `basic` or `advanced`, these settings will override the default mode configuration.\n\n"
+ "Common overrides include `filters` to narrow results and `limit` to control how many results are returned."
+ ),
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedSearchResponse:
+ """Perform a search query against vector and/or graph-based
+ databases.
+
+ **Search Modes:**
+ - `basic`: Defaults to semantic search. Simple and easy to use.
+ - `advanced`: Combines semantic search with full-text search for more comprehensive results.
+ - `custom`: Complete control over how search is performed. Provide a full `SearchSettings` object.
+
+ **Filters:**
+ Apply filters directly inside `search_settings.filters`. For example:
+ ```json
+ {
+ "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}}
+ }
+ ```
+ Supported operators: `$eq`, `$neq`, `$gt`, `$gte`, `$lt`, `$lte`, `$like`, `$ilike`, `$in`, `$nin`.
+
+ **Hybrid Search:**
+ Enable hybrid search by setting `use_hybrid_search: true` in search_settings. This combines semantic search with
+ keyword-based search for improved results. Configure with `hybrid_settings`:
+ ```json
+ {
+ "use_hybrid_search": true,
+ "hybrid_settings": {
+ "full_text_weight": 1.0,
+ "semantic_weight": 5.0,
+ "full_text_limit": 200,
+ "rrf_k": 50
+ }
+ }
+ ```
+
+ **Graph-Enhanced Search:**
+ Knowledge graph integration is enabled by default. Control with `graph_search_settings`:
+ ```json
+ {
+ "graph_search_settings": {
+ "use_graph_search": true,
+ "kg_search_type": "local"
+ }
+ }
+ ```
+
+ **Advanced Filtering:**
+ Use complex filters to narrow down results by metadata fields or document properties:
+ ```json
+ {
+ "filters": {
+ "$and":[
+ {"document_type": {"$eq": "pdf"}},
+ {"metadata.year": {"$gt": 2020}}
+ ]
+ }
+ }
+ ```
+
+ **Results:**
+ The response includes vector search results and optional graph search results.
+ Each result contains the matched text, document ID, and relevance score.
+
+ """
+ if query == "":
+ raise R2RException("Query cannot be empty", 400)
+ effective_settings = self._prepare_search_settings(
+ auth_user, search_mode, search_settings
+ )
+ results = await self.services.retrieval.search(
+ query=query,
+ search_settings=effective_settings,
+ )
+ return results # type: ignore
+
+ @self.router.post(
+ "/retrieval/rag",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="RAG Query",
+ response_model=None,
+ openapi_extra=EXAMPLES["rag"],
+ )
+ @self.base_endpoint
+ async def rag_app(
+ query: str = Body(...),
+ search_mode: SearchMode = Body(
+ default=SearchMode.custom,
+ description=(
+ "Default value of `custom` allows full control over search settings.\n\n"
+ "Pre-configured search modes:\n"
+ "`basic`: A simple semantic-based search.\n"
+ "`advanced`: A more powerful hybrid search combining semantic and full-text.\n"
+ "`custom`: Full control via `search_settings`.\n\n"
+ "If `filters` or `limit` are provided alongside `basic` or `advanced`, "
+ "they will override the default settings for that mode."
+ ),
+ ),
+ search_settings: Optional[SearchSettings] = Body(
+ None,
+ description=(
+ "The search configuration object. If `search_mode` is `custom`, "
+ "these settings are used as-is. For `basic` or `advanced`, these settings will override the default mode configuration.\n\n"
+ "Common overrides include `filters` to narrow results and `limit` to control how many results are returned."
+ ),
+ ),
+ rag_generation_config: GenerationConfig = Body(
+ default_factory=GenerationConfig,
+ description="Configuration for RAG generation",
+ ),
+ task_prompt: Optional[str] = Body(
+ default=None,
+ description="Optional custom prompt to override default",
+ ),
+ include_title_if_available: bool = Body(
+ default=False,
+ description="Include document titles in responses when available",
+ ),
+ include_web_search: bool = Body(
+ default=False,
+ description="Include web search results provided to the LLM.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedRAGResponse:
+ """Execute a RAG (Retrieval-Augmented Generation) query.
+
+ This endpoint combines search results with language model generation to produce accurate,
+ contextually-relevant responses based on your document corpus.
+
+ **Features:**
+ - Combines vector search, optional knowledge graph integration, and LLM generation
+ - Automatically cites sources with unique citation identifiers
+ - Supports both streaming and non-streaming responses
+ - Compatible with various LLM providers (OpenAI, Anthropic, etc.)
+ - Web search integration for up-to-date information
+
+ **Search Configuration:**
+ All search parameters from the search endpoint apply here, including filters, hybrid search, and graph-enhanced search.
+
+ **Generation Configuration:**
+ Fine-tune the language model's behavior with `rag_generation_config`:
+ ```json
+ {
+ "model": "openai/gpt-4o-mini", // Model to use
+ "temperature": 0.7, // Control randomness (0-1)
+ "max_tokens": 1500, // Maximum output length
+ "stream": true // Enable token streaming
+ }
+ ```
+
+ **Model Support:**
+ - OpenAI models (default)
+ - Anthropic Claude models (requires ANTHROPIC_API_KEY)
+ - Local models via Ollama
+ - Any provider supported by LiteLLM
+
+ **Streaming Responses:**
+ When `stream: true` is set, the endpoint returns Server-Sent Events with the following types:
+ - `search_results`: Initial search results from your documents
+ - `message`: Partial tokens as they're generated
+ - `citation`: Citation metadata when sources are referenced
+ - `final_answer`: Complete answer with structured citations
+
+ **Example Response:**
+ ```json
+ {
+ "generated_answer": "DeepSeek-R1 is a model that demonstrates impressive performance...[1]",
+ "search_results": { ... },
+ "citations": [
+ {
+ "id": "cit.123456",
+ "object": "citation",
+ "payload": { ... }
+ }
+ ]
+ }
+ ```
+ """
+
+ if "model" not in rag_generation_config.__fields_set__:
+ rag_generation_config.model = self.config.app.quality_llm
+
+ effective_settings = self._prepare_search_settings(
+ auth_user, search_mode, search_settings
+ )
+
+ response = await self.services.retrieval.rag(
+ query=query,
+ search_settings=effective_settings,
+ rag_generation_config=rag_generation_config,
+ task_prompt=task_prompt,
+ include_title_if_available=include_title_if_available,
+ include_web_search=include_web_search,
+ )
+
+ if rag_generation_config.stream:
+ # ========== Streaming path ==========
+ async def stream_generator():
+ try:
+ async for chunk in response:
+ if len(chunk) > 1024:
+ for i in range(0, len(chunk), 1024):
+ yield chunk[i : i + 1024]
+ else:
+ yield chunk
+ except GeneratorExit:
+ # Clean up if needed, then return
+ return
+
+ return StreamingResponse(
+ stream_generator(), media_type="text/event-stream"
+ ) # type: ignore
+ else:
+ # ========== Non-streaming path ==========
+ return response
+
+ @self.router.post(
+ "/retrieval/agent",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="RAG-powered Conversational Agent",
+ openapi_extra=EXAMPLES["agent"],
+ )
+ @self.base_endpoint
+ async def agent_app(
+ message: Optional[Message] = Body(
+ None,
+ description="Current message to process",
+ ),
+ messages: Optional[list[Message]] = Body(
+ None,
+ deprecated=True,
+ description="List of messages (deprecated, use message instead)",
+ ),
+ search_mode: SearchMode = Body(
+ default=SearchMode.custom,
+ description="Pre-configured search modes: basic, advanced, or custom.",
+ ),
+ search_settings: Optional[SearchSettings] = Body(
+ None,
+ description="The search configuration object for retrieving context.",
+ ),
+ # Generation configurations
+ rag_generation_config: GenerationConfig = Body(
+ default_factory=GenerationConfig,
+ description="Configuration for RAG generation in 'rag' mode",
+ ),
+ research_generation_config: Optional[GenerationConfig] = Body(
+ None,
+ description="Configuration for generation in 'research' mode. If not provided but mode='research', rag_generation_config will be used with appropriate model overrides.",
+ ),
+ # Tool configurations
+ rag_tools: Optional[
+ list[
+ Literal[
+ "web_search",
+ "web_scrape",
+ "search_file_descriptions",
+ "search_file_knowledge",
+ "get_file_content",
+ ]
+ ]
+ ] = Body(
+ None,
+ description="List of tools to enable for RAG mode. Available tools: search_file_knowledge, get_file_content, web_search, web_scrape, search_file_descriptions",
+ ),
+ research_tools: Optional[
+ list[
+ Literal["rag", "reasoning", "critique", "python_executor"]
+ ]
+ ] = Body(
+ None,
+ description="List of tools to enable for Research mode. Available tools: rag, reasoning, critique, python_executor",
+ ),
+ # Backward compatibility
+ tools: Optional[list[str]] = Body(
+ None,
+ deprecated=True,
+ description="List of tools to execute (deprecated, use rag_tools or research_tools instead)",
+ ),
+ # Other parameters
+ task_prompt: Optional[str] = Body(
+ default=None,
+ description="Optional custom prompt to override default",
+ ),
+ # Backward compatibility
+ task_prompt_override: Optional[str] = Body(
+ default=None,
+ deprecated=True,
+ description="Optional custom prompt to override default",
+ ),
+ include_title_if_available: bool = Body(
+ default=True,
+ description="Pass document titles from search results into the LLM context window.",
+ ),
+ conversation_id: Optional[UUID] = Body(
+ default=None,
+ description="ID of the conversation",
+ ),
+ max_tool_context_length: Optional[int] = Body(
+ default=32_768,
+ description="Maximum length of returned tool context",
+ ),
+ use_system_context: Optional[bool] = Body(
+ default=True,
+ description="Use extended prompt for generation",
+ ),
+ mode: Optional[Literal["rag", "research"]] = Body(
+ default="rag",
+ description="Mode to use for generation: 'rag' for standard retrieval or 'research' for deep analysis with reasoning capabilities",
+ ),
+ needs_initial_conversation_name: Optional[bool] = Body(
+ default=None,
+ description="If true, the system will automatically assign a conversation name if not already specified previously.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedAgentResponse:
+ """
+ Engage with an intelligent agent for information retrieval, analysis, and research.
+
+ This endpoint offers two operating modes:
+ - **RAG mode**: Standard retrieval-augmented generation for answering questions based on knowledge base
+ - **Research mode**: Advanced capabilities for deep analysis, reasoning, and computation
+
+ ### RAG Mode (Default)
+
+ The RAG mode provides fast, knowledge-based responses using:
+ - Semantic and hybrid search capabilities
+ - Document-level and chunk-level content retrieval
+ - Optional web search integration
+ - Source citation and evidence-based responses
+
+ ### Research Mode
+
+ The Research mode builds on RAG capabilities and adds:
+ - A dedicated reasoning system for complex problem-solving
+ - Critique capabilities to identify potential biases or logical fallacies
+ - Python execution for computational analysis
+ - Multi-step reasoning for deeper exploration of topics
+
+ ### Available Tools
+
+ **RAG Tools:**
+ - `search_file_knowledge`: Semantic/hybrid search on your ingested documents
+ - `search_file_descriptions`: Search over file-level metadata
+ - `content`: Fetch entire documents or chunk structures
+ - `web_search`: Query external search APIs for up-to-date information
+ - `web_scrape`: Scrape and extract content from specific web pages
+
+ **Research Tools:**
+ - `rag`: Leverage the underlying RAG agent for information retrieval
+ - `reasoning`: Call a dedicated model for complex analytical thinking
+ - `critique`: Analyze conversation history to identify flaws and biases
+ - `python_executor`: Execute Python code for complex calculations and analysis
+
+ ### Streaming Output
+
+ When streaming is enabled, the agent produces different event types:
+ - `thinking`: Shows the model's step-by-step reasoning (when extended_thinking=true)
+ - `tool_call`: Shows when the agent invokes a tool
+ - `tool_result`: Shows the result of a tool call
+ - `citation`: Indicates when a citation is added to the response
+ - `message`: Streams partial tokens of the response
+ - `final_answer`: Contains the complete generated answer and structured citations
+
+ ### Conversations
+
+ Maintain context across multiple turns by including `conversation_id` in each request.
+ After your first call, store the returned `conversation_id` and include it in subsequent calls.
+ If no conversation name has already been set for the conversation, the system will automatically assign one.
+
+ """
+ # Handle backward compatibility for task_prompt
+ task_prompt = task_prompt or task_prompt_override
+ # Handle model selection based on mode
+ if "model" not in rag_generation_config.__fields_set__:
+ if mode == "rag":
+ rag_generation_config.model = self.config.app.quality_llm
+ elif mode == "research":
+ rag_generation_config.model = self.config.app.planning_llm
+
+ # Prepare search settings
+ effective_settings = self._prepare_search_settings(
+ auth_user, search_mode, search_settings
+ )
+
+ # Handle tool configuration and backward compatibility
+ if tools: # Handle deprecated tools parameter
+ logger.warning(
+ "The 'tools' parameter is deprecated. Use 'rag_tools' or 'research_tools' based on mode."
+ )
+ rag_tools = tools # type: ignore
+
+ # Determine effective generation config
+ effective_generation_config = rag_generation_config
+ if mode == "research" and research_generation_config:
+ effective_generation_config = research_generation_config
+
+ try:
+ response = await self.services.retrieval.agent(
+ message=message,
+ messages=messages,
+ search_settings=effective_settings,
+ rag_generation_config=rag_generation_config,
+ research_generation_config=research_generation_config,
+ task_prompt=task_prompt,
+ include_title_if_available=include_title_if_available,
+ max_tool_context_length=max_tool_context_length or 32_768,
+ conversation_id=(
+ str(conversation_id) if conversation_id else None # type: ignore
+ ),
+ use_system_context=use_system_context
+ if use_system_context is not None
+ else True,
+ rag_tools=rag_tools, # type: ignore
+ research_tools=research_tools, # type: ignore
+ mode=mode,
+ needs_initial_conversation_name=needs_initial_conversation_name,
+ )
+
+ if effective_generation_config.stream:
+
+ async def stream_generator():
+ try:
+ async for chunk in response:
+ if len(chunk) > 1024:
+ for i in range(0, len(chunk), 1024):
+ yield chunk[i : i + 1024]
+ else:
+ yield chunk
+ except GeneratorExit:
+ # Clean up if needed, then return
+ return
+
+ return StreamingResponse( # type: ignore
+ stream_generator(), media_type="text/event-stream"
+ )
+ else:
+ return response
+ except Exception as e:
+ logger.error(f"Error in agent_app: {e}")
+ raise R2RException(str(e), 500) from e
+
        @self.router.post(
            "/retrieval/completion",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Generate Message Completions",
            openapi_extra=EXAMPLES["completion"],
        )
        @self.base_endpoint
        async def completion(
            messages: list[Message] = Body(
                ...,
                description="List of messages to generate completion for",
                example=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant.",
                    },
                    {
                        "role": "user",
                        "content": "What is the capital of France?",
                    },
                    {
                        "role": "assistant",
                        "content": "The capital of France is Paris.",
                    },
                    {"role": "user", "content": "What about Italy?"},
                ],
            ),
            generation_config: GenerationConfig = Body(
                default_factory=GenerationConfig,
                description="Configuration for text generation",
                example={
                    "model": "openai/gpt-4o-mini",
                    "temperature": 0.7,
                    "max_tokens": 150,
                    "stream": False,
                },
            ),
            # Required for authentication only; the user object is not
            # otherwise consulted in this handler.
            auth_user=Depends(self.providers.auth.auth_wrapper()),
            # NOTE(review): `response_model` is declared as a *function
            # parameter* (default WrappedCompletionResponse) rather than as an
            # argument to the route decorator, and it disagrees with the
            # `-> WrappedLLMChatCompletion` annotation below. Confirm intent;
            # FastAPI does not treat this as the route's response model.
            response_model=WrappedCompletionResponse,
        ) -> WrappedLLMChatCompletion:
            """Generate completions for a list of messages.

            This endpoint uses the language model to generate completions for
            the provided messages. The generation process can be customized
            using the generation_config parameter.

            The messages list should contain alternating user and assistant
            messages, with an optional system message at the start. Each
            message should have a 'role' and 'content'.
            """

            # Thin pass-through: the retrieval service performs the actual
            # LLM call with the supplied generation settings.
            return await self.services.retrieval.completion(
                messages=messages,  # type: ignore
                generation_config=generation_config,
            )
+
        @self.router.post(
            "/retrieval/embedding",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Generate Embeddings",
            openapi_extra=EXAMPLES["embedding"],
        )
        @self.base_endpoint
        # NOTE(review): the docstring mentions a "model parameter", but this
        # route only accepts `text`; the embedding model presumably comes from
        # server configuration — confirm and update the docstring upstream.
        async def embedding(
            text: str = Body(
                ...,
                description="Text to generate embeddings for",
            ),
            # Required for authentication only; the user object is unused here.
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedEmbeddingResponse:
            """Generate embeddings for the provided text using the specified
            model.

            This endpoint uses the language model to generate embeddings for
            the provided text. The model parameter specifies the model to use
            for generating embeddings.
            """

            # Thin pass-through to the retrieval service's embedding provider.
            return await self.services.retrieval.embedding(
                text=text,
            )
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/system_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/system_router.py
new file mode 100644
index 00000000..682be750
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/system_router.py
@@ -0,0 +1,186 @@
+import logging
+import textwrap
+from datetime import datetime, timezone
+
+import psutil
+from fastapi import Depends
+
+from core.base import R2RException
+from core.base.api.models import (
+ GenericMessageResponse,
+ WrappedGenericMessageResponse,
+ WrappedServerStatsResponse,
+ WrappedSettingsResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+
class SystemRouter(BaseRouterV3):
    """Router for system-level endpoints: liveness probe, application
    settings, and server statistics (uptime, CPU, memory)."""

    def __init__(
        self,
        providers: R2RProviders,
        services: R2RServices,
        config: R2RConfig,
    ):
        logging.info("Initializing SystemRouter")
        super().__init__(providers, services, config)
        # Captured once at construction; /system/status reports uptime
        # relative to this instant.
        self.start_time = datetime.now(timezone.utc)

    def _setup_routes(self):
        @self.router.get(
            "/health",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            result = client.system.health()
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            async function main() {
                                const response = await client.system.health();
                            }

                            main();
                            """),
                    },
                    {
                        # Fixed: this is a GET route; the sample previously
                        # used `-X POST` and ended with a dangling `\\`.
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X GET "https://api.example.com/v3/health" \\
                                -H "Content-Type: application/json" \\
                                -H "Authorization: Bearer YOUR_API_KEY"
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def health_check() -> WrappedGenericMessageResponse:
            """Unauthenticated liveness probe; always answers "ok"."""
            return GenericMessageResponse(message="ok")  # type: ignore

        @self.router.get(
            "/system/settings",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            result = client.system.settings()
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            async function main() {
                                const response = await client.system.settings();
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X GET "https://api.example.com/v3/system/settings" \\
                                -H "Content-Type: application/json" \\
                                -H "Authorization: Bearer YOUR_API_KEY"
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def app_settings(
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedSettingsResponse:
            """Return the application's configuration. Superuser-only."""
            if not auth_user.is_superuser:
                raise R2RException(
                    "Only a superuser can call the `system/settings` endpoint.",
                    403,
                )
            return await self.services.management.app_settings()

        @self.router.get(
            "/system/status",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # when using auth, do client.login(...)

                            result = client.system.status()
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            async function main() {
                                const response = await client.system.status();
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X GET "https://api.example.com/v3/system/status" \\
                                -H "Content-Type: application/json" \\
                                -H "Authorization: Bearer YOUR_API_KEY"
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def server_stats(
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedServerStatsResponse:
            """Return uptime plus process-wide CPU/memory usage (psutil).

            Superuser-only; the error message previously said "authorized
            user" although the check is for superusers.
            """
            if not auth_user.is_superuser:
                raise R2RException(
                    "Only a superuser can call the `system/status` endpoint.",
                    403,
                )
            return {  # type: ignore
                "start_time": self.start_time.isoformat(),
                "uptime_seconds": (
                    datetime.now(timezone.utc) - self.start_time
                ).total_seconds(),
                "cpu_usage": psutil.cpu_percent(),
                "memory_usage": psutil.virtual_memory().percent,
            }
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/users_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/users_router.py
new file mode 100644
index 00000000..686f0013
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/users_router.py
@@ -0,0 +1,1721 @@
+import logging
+import os
+import textwrap
+import urllib.parse
+from typing import Optional
+from uuid import UUID
+
+import requests
+from fastapi import Body, Depends, HTTPException, Path, Query
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse
+from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
+from google.auth.transport import requests as google_requests
+from google.oauth2 import id_token
+from pydantic import EmailStr
+
+from core.base import R2RException
+from core.base.api.models import (
+ GenericBooleanResponse,
+ GenericMessageResponse,
+ WrappedAPIKeyResponse,
+ WrappedAPIKeysResponse,
+ WrappedBooleanResponse,
+ WrappedCollectionsResponse,
+ WrappedGenericMessageResponse,
+ WrappedLimitsResponse,
+ WrappedLoginResponse,
+ WrappedTokenResponse,
+ WrappedUserResponse,
+ WrappedUsersResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
# Shared OAuth2 password-flow scheme; used below to extract the bearer
# token from the Authorization header (e.g. for /users/logout).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+class UsersRouter(BaseRouterV3):
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ logging.info("Initializing UsersRouter")
+ super().__init__(providers, services, config)
+ self.google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
+ self.google_client_secret = os.environ.get("GOOGLE_CLIENT_SECRET")
+ self.google_redirect_uri = os.environ.get("GOOGLE_REDIRECT_URI")
+
+ self.github_client_id = os.environ.get("GITHUB_CLIENT_ID")
+ self.github_client_secret = os.environ.get("GITHUB_CLIENT_SECRET")
+ self.github_redirect_uri = os.environ.get("GITHUB_REDIRECT_URI")
+
+ def _setup_routes(self):
+ @self.router.post(
+ "/users",
+ # dependencies=[Depends(self.rate_limit_dependency)],
+ response_model=WrappedUserResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ new_user = client.users.create(
+ email="jane.doe@example.com",
+ password="secure_password123"
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.users.create({
+ email: "jane.doe@example.com",
+ password: "secure_password123"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users" \\
+ -H "Content-Type: application/json" \\
+ -d '{
+ "email": "jane.doe@example.com",
+ "password": "secure_password123"
+ }'"""),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def register(
+ email: EmailStr = Body(..., description="User's email address"),
+ password: str = Body(..., description="User's password"),
+ name: str | None = Body(
+ None, description="The name for the new user"
+ ),
+ bio: str | None = Body(
+ None, description="The bio for the new user"
+ ),
+ profile_picture: str | None = Body(
+ None, description="Updated user profile picture"
+ ),
+ # auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedUserResponse:
+ """Register a new user with the given email and password."""
+
+ # TODO: Do we really want this validation? The default password for the superuser would not pass...
+ def validate_password(password: str) -> bool:
+ if len(password) < 10:
+ return False
+ if not any(c.isupper() for c in password):
+ return False
+ if not any(c.islower() for c in password):
+ return False
+ if not any(c.isdigit() for c in password):
+ return False
+ if not any(c in "!@#$%^&*" for c in password):
+ return False
+ return True
+
+ # if not validate_password(password):
+ # raise R2RException(
+ # f"Password must be at least 10 characters long and contain at least one uppercase letter, one lowercase letter, one digit, and one special character from '!@#$%^&*'.",
+ # 400,
+ # )
+
+ registration_response = await self.services.auth.register(
+ email=email,
+ password=password,
+ name=name,
+ bio=bio,
+ profile_picture=profile_picture,
+ )
+
+ return registration_response # type: ignore
+
        @self.router.post(
            "/users/export",
            summary="Export users to CSV",
            dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient("http://localhost:7272")
                            # when using auth, do client.login(...)

                            response = client.users.export(
                                output_path="export.csv",
                                columns=["id", "name", "created_at"],
                                include_header=True,
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient("http://localhost:7272");

                            function main() {
                                await client.users.export({
                                    outputPath: "export.csv",
                                    columns: ["id", "name", "created_at"],
                                    includeHeader: true,
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "http://127.0.0.1:7272/v3/users/export" \
                                -H "Authorization: Bearer YOUR_API_KEY" \
                                -H "Content-Type: application/json" \
                                -H "Accept: text/csv" \
                                -d '{ "columns": ["id", "name", "created_at"], "include_header": true }' \
                                --output export.csv
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def export_users(
            background_tasks: BackgroundTasks,
            columns: Optional[list[str]] = Body(
                None, description="Specific columns to export"
            ),
            filters: Optional[dict] = Body(
                None, description="Filters to apply to the export"
            ),
            include_header: Optional[bool] = Body(
                True, description="Whether to include column headers"
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> FileResponse:
            """Export users as a CSV file.

            Superuser-only. The management service writes the CSV to a
            temporary file which is streamed back as `users_export.csv`.
            """

            if not auth_user.is_superuser:
                raise R2RException(
                    status_code=403,
                    message="Only a superuser can export data.",
                )

            # The service returns both the on-disk path and the open temp-file
            # handle; the handle must outlive the response.
            (
                csv_file_path,
                temp_file,
            ) = await self.services.management.export_users(
                columns=columns,
                filters=filters,
                # Explicit None (as opposed to the default True) is coerced
                # back to True so the service always receives a bool.
                include_header=include_header
                if include_header is not None
                else True,
            )

            # Close (and thereby release) the temp file only after the
            # response has been sent.
            background_tasks.add_task(temp_file.close)

            return FileResponse(
                path=csv_file_path,
                media_type="text/csv",
                filename="users_export.csv",
            )
+
+ @self.router.post(
+ "/users/verify-email",
+ # dependencies=[Depends(self.rate_limit_dependency)],
+ response_model=WrappedGenericMessageResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ tokens = client.users.verify_email(
+ email="jane.doe@example.com",
+ verification_code="1lklwal!awdclm"
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.users.verifyEmail({
+ email: jane.doe@example.com",
+ verificationCode: "1lklwal!awdclm"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users/login" \\
+ -H "Content-Type: application/x-www-form-urlencoded" \\
+ -d "email=jane.doe@example.com&verification_code=1lklwal!awdclm"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def verify_email(
+ email: EmailStr = Body(..., description="User's email address"),
+ verification_code: str = Body(
+ ..., description="Email verification code"
+ ),
+ ) -> WrappedGenericMessageResponse:
+ """Verify a user's email address."""
+ user = (
+ await self.providers.database.users_handler.get_user_by_email(
+ email
+ )
+ )
+ if user and user.is_verified:
+ raise R2RException(
+ status_code=400,
+ message="This email is already verified. Please log in.",
+ )
+
+ result = await self.services.auth.verify_email(
+ email, verification_code
+ )
+ return GenericMessageResponse(message=result["message"]) # type: ignore
+
+ @self.router.post(
+ "/users/send-verification-email",
+ dependencies=[
+ Depends(self.providers.auth.auth_wrapper(public=True))
+ ],
+ response_model=WrappedGenericMessageResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ tokens = client.users.send_verification_email(
+ email="jane.doe@example.com",
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.users.sendVerificationEmail({
+ email: jane.doe@example.com",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users/send-verification-email" \\
+ -H "Content-Type: application/x-www-form-urlencoded" \\
+ -d "email=jane.doe@example.com"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def send_verification_email(
+ email: EmailStr = Body(..., description="User's email address"),
+ ) -> WrappedGenericMessageResponse:
+ """Send a user's email a verification code."""
+ user = (
+ await self.providers.database.users_handler.get_user_by_email(
+ email
+ )
+ )
+ if user and user.is_verified:
+ raise R2RException(
+ status_code=400,
+ message="This email is already verified. Please log in.",
+ )
+
+ await self.services.auth.send_verification_email(email=email)
+ return GenericMessageResponse(
+ message="A verification email has been sent."
+ ) # type: ignore
+
+ @self.router.post(
+ "/users/login",
+ # dependencies=[Depends(self.rate_limit_dependency)],
+ response_model=WrappedTokenResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ tokens = client.users.login(
+ email="jane.doe@example.com",
+ password="secure_password123"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.users.login({
+ email: jane.doe@example.com",
+ password: "secure_password123"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users/login" \\
+ -H "Content-Type: application/x-www-form-urlencoded" \\
+ -d "username=jane.doe@example.com&password=secure_password123"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def login(
+ form_data: OAuth2PasswordRequestForm = Depends(),
+ ) -> WrappedLoginResponse:
+ """Authenticate a user and provide access tokens."""
+ return await self.services.auth.login( # type: ignore
+ form_data.username, form_data.password
+ )
+
        @self.router.post(
            "/users/logout",
            response_model=WrappedGenericMessageResponse,
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # client.login(...)
                            result = client.users.logout()
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.users.logout();
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "https://api.example.com/v3/users/logout" \\
                                -H "Authorization: Bearer YOUR_API_KEY"
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def logout(
            # `token` is the raw bearer token (invalidated below);
            # `auth_user` ensures the caller is actually authenticated.
            token: str = Depends(oauth2_scheme),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedGenericMessageResponse:
            """Log out the current user by invalidating the presented token."""
            result = await self.services.auth.logout(token)
            return GenericMessageResponse(message=result["message"])  # type: ignore
+
        @self.router.post(
            "/users/refresh-token",
            # dependencies=[Depends(self.rate_limit_dependency)],
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # client.login(...)

                            new_tokens = client.users.refresh_token()
                            # New tokens are automatically stored in the client"""),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.users.refreshAccessToken();
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "https://api.example.com/v3/users/refresh-token" \\
                                -H "Content-Type: application/json" \\
                                -d '{
                                    "refresh_token": "YOUR_REFRESH_TOKEN"
                                }'"""),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def refresh_token(
            refresh_token: str = Body(..., description="Refresh token"),
        ) -> WrappedTokenResponse:
            """Refresh the access token using a refresh token.

            No auth dependency: possession of a valid refresh token is the
            credential here.
            """
            result = await self.services.auth.refresh_access_token(
                refresh_token=refresh_token
            )
            return result  # type: ignore
+
        @self.router.post(
            "/users/change-password",
            dependencies=[Depends(self.rate_limit_dependency)],
            response_model=WrappedGenericMessageResponse,
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # client.login(...)

                            result = client.users.change_password(
                                current_password="old_password123",
                                new_password="new_secure_password456"
                            )"""),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.users.changePassword({
                                    currentPassword: "old_password123",
                                    newPassword: "new_secure_password456"
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "cURL",
                        "source": textwrap.dedent("""
                            curl -X POST "https://api.example.com/v3/users/change-password" \\
                                -H "Authorization: Bearer YOUR_API_KEY" \\
                                -H "Content-Type: application/json" \\
                                -d '{
                                    "current_password": "old_password123",
                                    "new_password": "new_secure_password456"
                                }'"""),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def change_password(
            current_password: str = Body(..., description="Current password"),
            new_password: str = Body(..., description="New password"),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedGenericMessageResponse:
            """Change the authenticated user's password.

            The auth service re-checks `current_password` before applying
            the new one.
            """
            result = await self.services.auth.change_password(
                auth_user, current_password, new_password
            )
            return GenericMessageResponse(message=result["message"])  # type: ignore
+
+ @self.router.post(
+ "/users/request-password-reset",
+ dependencies=[
+ Depends(self.providers.auth.auth_wrapper(public=True))
+ ],
+ response_model=WrappedGenericMessageResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ result = client.users.request_password_reset(
+ email="jane.doe@example.com"
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.users.requestPasswordReset({
+ email: jane.doe@example.com",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users/request-password-reset" \\
+ -H "Content-Type: application/json" \\
+ -d '{
+ "email": "jane.doe@example.com"
+ }'"""),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def request_password_reset(
+ email: EmailStr = Body(..., description="User's email address"),
+ ) -> WrappedGenericMessageResponse:
+ """Request a password reset for a user."""
+ result = await self.services.auth.request_password_reset(email)
+ return GenericMessageResponse(message=result["message"]) # type: ignore
+
+ @self.router.post(
+ "/users/reset-password",
+ dependencies=[
+ Depends(self.providers.auth.auth_wrapper(public=True))
+ ],
+ response_model=WrappedGenericMessageResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ result = client.users.reset_password(
+ reset_token="reset_token_received_via_email",
+ new_password="new_secure_password789"
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.users.resetPassword({
+ resestToken: "reset_token_received_via_email",
+ newPassword: "new_secure_password789"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users/reset-password" \\
+ -H "Content-Type: application/json" \\
+ -d '{
+ "reset_token": "reset_token_received_via_email",
+ "new_password": "new_secure_password789"
+ }'"""),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def reset_password(
+ reset_token: str = Body(..., description="Password reset token"),
+ new_password: str = Body(..., description="New password"),
+ ) -> WrappedGenericMessageResponse:
+ """Reset a user's password using a reset token."""
+ result = await self.services.auth.confirm_password_reset(
+ reset_token, new_password
+ )
+ return GenericMessageResponse(message=result["message"]) # type: ignore
+
+ @self.router.get(
+ "/users",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List Users",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ # List users with filters
+ users = client.users.list(
+ offset=0,
+ limit=100,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ function main() {
+ const response = await client.users.list();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/users?offset=0&limit=100&username=john&email=john@example.com&is_active=true&is_superuser=false" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def list_users(
+ ids: list[str] = Query(
+ [], description="List of user IDs to filter by"
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedUsersResponse:
+ """List all users with pagination and filtering options.
+
+ Only accessible by superusers.
+ """
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ status_code=403,
+ message="Only a superuser can call the `users_overview` endpoint.",
+ )
+
+ user_uuids = [UUID(user_id) for user_id in ids]
+
+ users_overview_response = (
+ await self.services.management.users_overview(
+ user_ids=user_uuids, offset=offset, limit=limit
+ )
+ )
+ return users_overview_response["results"], { # type: ignore
+ "total_entries": users_overview_response["total_entries"]
+ }
+
        @self.router.get(
            "/users/me",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Get the Current User",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # client.login(...)

                            # Get user details
                            users = client.users.me()
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.users.retrieve();
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "Shell",
                        "source": textwrap.dedent("""
                            curl -X GET "https://api.example.com/users/me" \\
                                -H "Authorization: Bearer YOUR_API_KEY"
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def get_current_user(
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedUserResponse:
            """Get detailed information about the currently authenticated
            user."""
            # The auth wrapper has already resolved the user; return it as-is.
            return auth_user
+
        @self.router.get(
            "/users/{id}",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Get User Details",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # client.login(...)

                            # Get user details
                            users = client.users.retrieve(
                                id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                            )
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.users.retrieve({
                                    id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                                });
                            }

                            main();
                            """),
                    },
                    {
                        "lang": "Shell",
                        "source": textwrap.dedent("""
                            curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000" \\
                                -H "Authorization: Bearer YOUR_API_KEY"
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def get_user(
            id: UUID = Path(
                ..., example="550e8400-e29b-41d4-a716-446655440000"
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedUserResponse:
            """Get detailed information about a specific user.

            Users can only access their own information unless they are
            superusers.
            """
            if not auth_user.is_superuser and auth_user.id != id:
                raise R2RException(
                    "Only a superuser can call the get `user` endpoint for other users.",
                    403,
                )

            # Fetch exactly one user via the overview query.
            users_overview_response = (
                await self.services.management.users_overview(
                    offset=0,
                    limit=1,
                    user_ids=[id],
                )
            )

            # NOTE(review): if the id matches no user, `results` is presumably
            # empty and this raises IndexError (-> 500); a 404 would be more
            # appropriate — confirm against the management service.
            return users_overview_response["results"][0]
+
        @self.router.delete(
            "/users/{id}",
            dependencies=[Depends(self.rate_limit_dependency)],
            summary="Delete User",
            openapi_extra={
                "x-codeSamples": [
                    {
                        "lang": "Python",
                        "source": textwrap.dedent("""
                            from r2r import R2RClient

                            client = R2RClient()
                            # client.login(...)

                            # Delete user
                            client.users.delete(id="550e8400-e29b-41d4-a716-446655440000", password="secure_password123")
                            """),
                    },
                    {
                        "lang": "JavaScript",
                        "source": textwrap.dedent("""
                            const { r2rClient } = require("r2r-js");

                            const client = new r2rClient();

                            function main() {
                                const response = await client.users.delete({
                                    id: "550e8400-e29b-41d4-a716-446655440000",
                                    password: "secure_password123"
                                });
                            }

                            main();
                            """),
                    },
                ]
            },
        )
        @self.base_endpoint
        async def delete_user(
            id: UUID = Path(
                ..., example="550e8400-e29b-41d4-a716-446655440000"
            ),
            # Optional so superusers can delete accounts without knowing the
            # password; the auth service decides whether it is required.
            password: Optional[str] = Body(
                None, description="User's current password"
            ),
            delete_vector_data: Optional[bool] = Body(
                False,
                description="Whether to delete the user's vector data",
            ),
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ) -> WrappedBooleanResponse:
            """Delete a specific user.

            Users can only delete their own account unless they are superusers.
            """
            if not auth_user.is_superuser and auth_user.id != id:
                raise R2RException(
                    "Only a superuser can delete other users.",
                    403,
                )

            await self.services.auth.delete_user(
                user_id=id,
                password=password,
                # Coerce an explicit None back to False for the service call.
                delete_vector_data=delete_vector_data or False,
                # Superusers may bypass the password check downstream.
                is_superuser=auth_user.is_superuser,
            )
            return GenericBooleanResponse(success=True)  # type: ignore
+
@self.router.get(
    "/users/{id}/collections",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="Get User Collections",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                # Fixed: sample used `client.user.` — the client namespace
                # is `users` everywhere else in this router.
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # client.login(...)

                    # Get user collections
                    collections = client.users.list_collections(
                        "550e8400-e29b-41d4-a716-446655440000",
                        offset=0,
                        limit=100
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                # Fixed: `await` is only legal inside an async function.
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();

                    async function main() {
                        const response = await client.users.listCollections({
                            id: "550e8400-e29b-41d4-a716-446655440000",
                            offset: 0,
                            limit: 100
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "Shell",
                "source": textwrap.dedent("""
                    curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/collections?offset=0&limit=100" \\
                    -H "Authorization: Bearer YOUR_API_KEY"
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def get_user_collections(
    id: UUID = Path(
        ..., example="550e8400-e29b-41d4-a716-446655440000"
    ),
    offset: int = Query(
        0,
        ge=0,
        description="Specifies the number of objects to skip. Defaults to 0.",
    ),
    limit: int = Query(
        100,
        ge=1,
        le=1000,
        # Fixed: description previously said "between 1 and 100" but the
        # actual validator allows up to 1000 (le=1000).
        description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedCollectionsResponse:
    """Get all collections associated with a specific user.

    Users can only access their own collections unless they are
    superusers.

    Raises:
        R2RException: 403 when the requester is neither the target user
            nor a superuser.
    """
    if auth_user.id != id and not auth_user.is_superuser:
        raise R2RException(
            "The currently authenticated user does not have access to the specified collection.",
            403,
        )
    user_collection_response = (
        await self.services.management.collections_overview(
            offset=offset,
            limit=limit,
            user_ids=[id],
        )
    )
    # (results, metadata) tuple — base_endpoint wraps it into the
    # paginated response envelope.
    return user_collection_response["results"], {  # type: ignore
        "total_entries": user_collection_response["total_entries"]
    }
+
@self.router.post(
    "/users/{id}/collections/{collection_id}",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="Add User to Collection",
    response_model=WrappedBooleanResponse,
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # client.login(...)

                    # Add user to collection
                    client.users.add_to_collection(
                        id="550e8400-e29b-41d4-a716-446655440000",
                        collection_id="750e8400-e29b-41d4-a716-446655440000"
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                # Fixed: `await` is only legal inside an async function.
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();

                    async function main() {
                        const response = await client.users.addToCollection({
                            id: "550e8400-e29b-41d4-a716-446655440000",
                            collectionId: "750e8400-e29b-41d4-a716-446655440000"
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "Shell",
                "source": textwrap.dedent("""
                    curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/collections/750e8400-e29b-41d4-a716-446655440000" \\
                    -H "Authorization: Bearer YOUR_API_KEY"
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def add_user_to_collection(
    id: UUID = Path(
        ..., example="550e8400-e29b-41d4-a716-446655440000"
    ),
    collection_id: UUID = Path(
        ..., example="750e8400-e29b-41d4-a716-446655440000"
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedBooleanResponse:
    """Add a user to a collection.

    Users can only add themselves unless they are superusers.

    Raises:
        R2RException: 403 when the requester is neither the target user
            nor a superuser.
    """
    if auth_user.id != id and not auth_user.is_superuser:
        raise R2RException(
            "The currently authenticated user does not have access to the specified collection.",
            403,
        )

    # TODO - Do we need a check on user access to the collection?
    await self.services.management.add_user_to_collection(  # type: ignore
        id, collection_id
    )
    return GenericBooleanResponse(success=True)  # type: ignore
+
@self.router.delete(
    "/users/{id}/collections/{collection_id}",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="Remove User from Collection",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # client.login(...)

                    # Remove user from collection
                    client.users.remove_from_collection(
                        id="550e8400-e29b-41d4-a716-446655440000",
                        collection_id="750e8400-e29b-41d4-a716-446655440000"
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                # Fixed: `await` is only legal inside an async function.
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();

                    async function main() {
                        const response = await client.users.removeFromCollection({
                            id: "550e8400-e29b-41d4-a716-446655440000",
                            collectionId: "750e8400-e29b-41d4-a716-446655440000"
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "Shell",
                "source": textwrap.dedent("""
                    curl -X DELETE "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/collections/750e8400-e29b-41d4-a716-446655440000" \\
                    -H "Authorization: Bearer YOUR_API_KEY"
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def remove_user_from_collection(
    id: UUID = Path(
        ..., example="550e8400-e29b-41d4-a716-446655440000"
    ),
    collection_id: UUID = Path(
        ..., example="750e8400-e29b-41d4-a716-446655440000"
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedBooleanResponse:
    """Remove a user from a collection.

    Requires either superuser status or access to the collection.

    Raises:
        R2RException: 403 when the requester is neither the target user
            nor a superuser.
    """
    if auth_user.id != id and not auth_user.is_superuser:
        raise R2RException(
            "The currently authenticated user does not have access to the specified collection.",
            403,
        )

    # TODO - Do we need a check on user access to the collection?
    await self.services.management.remove_user_from_collection(  # type: ignore
        id, collection_id
    )
    return GenericBooleanResponse(success=True)  # type: ignore
+
@self.router.post(
    "/users/{id}",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="Update User",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # client.login(...)

                    # Update user
                    updated_user = client.update_user(
                        "550e8400-e29b-41d4-a716-446655440000",
                        name="John Doe"
                    )
                    """),
            },
            {
                "lang": "JavaScript",
                # Fixed: `await` is only legal inside an async function.
                "source": textwrap.dedent("""
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();

                    async function main() {
                        const response = await client.users.update({
                            id: "550e8400-e29b-41d4-a716-446655440000",
                            name: "John Doe"
                        });
                    }

                    main();
                    """),
            },
            {
                "lang": "Shell",
                # Fixed: the sample JSON body had a trailing comma, which
                # is invalid JSON and would be rejected by the server.
                "source": textwrap.dedent("""
                    curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000" \\
                    -H "Authorization: Bearer YOUR_API_KEY" \\
                    -H "Content-Type: application/json" \\
                    -d '{
                    "id": "550e8400-e29b-41d4-a716-446655440000",
                    "name": "John Doe"
                    }'
                    """),
            },
        ]
    },
)
# TODO - Modify update user to have synced params with user object
@self.base_endpoint
async def update_user(
    id: UUID = Path(..., description="ID of the user to update"),
    email: EmailStr | None = Body(
        None, description="Updated email address"
    ),
    is_superuser: bool | None = Body(
        None, description="Updated superuser status"
    ),
    name: str | None = Body(None, description="Updated user name"),
    bio: str | None = Body(None, description="Updated user bio"),
    profile_picture: str | None = Body(
        None, description="Updated profile picture URL"
    ),
    # Fixed annotation: the default is None, so the declared type must
    # admit None (previously `dict`, which made the OpenAPI schema wrong).
    limits_overrides: dict | None = Body(
        None,
        description="Updated limits overrides",
    ),
    metadata: dict[str, str | None] | None = None,
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedUserResponse:
    """Update user information.

    Users can only update their own information unless they are
    superusers. Superuser status can only be modified by existing
    superusers.

    Raises:
        R2RException: 403 when a non-superuser tries to change superuser
            status, another user's data, or limits overrides.
    """

    if is_superuser is not None and not auth_user.is_superuser:
        raise R2RException(
            "Only superusers can update the superuser status of a user",
            403,
        )

    if not auth_user.is_superuser and auth_user.id != id:
        raise R2RException(
            "Only superusers can update other users' information",
            403,
        )

    if not auth_user.is_superuser and limits_overrides is not None:
        raise R2RException(
            "Only superusers can update other users' limits overrides",
            403,
        )

    # Pass `metadata` to our auth or management service so it can do a
    # partial (Stripe-like) merge of metadata.
    return await self.services.auth.update_user(  # type: ignore
        user_id=id,
        email=email,
        is_superuser=is_superuser,
        name=name,
        bio=bio,
        profile_picture=profile_picture,
        limits_overrides=limits_overrides,
        new_metadata=metadata,
    )
+
@self.router.post(
    "/users/{id}/api-keys",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="Create User API Key",
    response_model=WrappedAPIKeyResponse,
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # client.login(...)

                    result = client.users.create_api_key(
                        id="550e8400-e29b-41d4-a716-446655440000",
                        name="My API Key",
                        description="API key for accessing the app",
                    )
                    # result["api_key"] contains the newly created API key
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys" \\
                    -H "Authorization: Bearer YOUR_API_TOKEN" \\
                    -d '{"name": "My API Key", "description": "API key for accessing the app"}'
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def create_user_api_key(
    id: UUID = Path(
        ..., description="ID of the user for whom to create an API key"
    ),
    name: Optional[str] = Body(
        None, description="Name of the API key"
    ),
    description: Optional[str] = Body(
        None, description="Description of the API key"
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedAPIKeyResponse:
    """Create a new API key for the specified user.

    Only superusers or the user themselves may create an API key.

    Raises:
        R2RException: 403 when the requester is neither the target user
            nor a superuser.
    """
    if auth_user.id != id and not auth_user.is_superuser:
        raise R2RException(
            "Only the user themselves or a superuser can create API keys for this user.",
            403,
        )

    # NOTE(review): the response presumably includes the raw key exactly
    # once, at creation time — confirm against AuthService.
    api_key = await self.services.auth.create_user_api_key(
        id, name=name, description=description
    )
    return api_key  # type: ignore
+
@self.router.get(
    "/users/{id}/api-keys",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="List User API Keys",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient

                    client = R2RClient()
                    # client.login(...)

                    keys = client.users.list_api_keys(
                        id="550e8400-e29b-41d4-a716-446655440000"
                    )
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys" \\
                    -H "Authorization: Bearer YOUR_API_TOKEN"
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def list_user_api_keys(
    id: UUID = Path(
        ..., description="ID of the user whose API keys to list"
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedAPIKeysResponse:
    """List all API keys for the specified user.

    Only superusers or the user themselves may list the API keys.

    Raises:
        R2RException: 403 when the requester is neither the target user
            nor a superuser.
    """
    if auth_user.id != id and not auth_user.is_superuser:
        raise R2RException(
            "Only the user themselves or a superuser can list API keys for this user.",
            403,
        )

    # Query the database directly; no pagination here, so total_entries
    # is simply the number of keys returned.
    keys = (
        await self.providers.database.users_handler.get_user_api_keys(
            id
        )
    )
    # (results, metadata) tuple — base_endpoint wraps it into the
    # paginated response envelope.
    return keys, {"total_entries": len(keys)}  # type: ignore
+
@self.router.delete(
    "/users/{id}/api-keys/{key_id}",
    dependencies=[Depends(self.rate_limit_dependency)],
    summary="Delete User API Key",
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": textwrap.dedent("""
                    from r2r import R2RClient
                    from uuid import UUID

                    client = R2RClient()
                    # client.login(...)

                    response = client.users.delete_api_key(
                        id="550e8400-e29b-41d4-a716-446655440000",
                        key_id="d9c562d4-3aef-43e8-8f08-0cf7cd5e0a25"
                    )
                    """),
            },
            {
                "lang": "cURL",
                "source": textwrap.dedent("""
                    curl -X DELETE "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys/d9c562d4-3aef-43e8-8f08-0cf7cd5e0a25" \\
                    -H "Authorization: Bearer YOUR_API_TOKEN"
                    """),
            },
        ]
    },
)
@self.base_endpoint
async def delete_user_api_key(
    id: UUID = Path(..., description="ID of the user"),
    key_id: UUID = Path(
        ..., description="ID of the API key to delete"
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedBooleanResponse:
    """Delete a specific API key for the specified user.

    Only superusers or the user themselves may delete the API key.

    Raises:
        R2RException: 403 when the requester is neither the target user
            nor a superuser; 400 when the key does not exist or cannot
            be deleted.
    """
    if auth_user.id != id and not auth_user.is_superuser:
        raise R2RException(
            "Only the user themselves or a superuser can delete this API key.",
            403,
        )

    success = (
        await self.providers.database.users_handler.delete_api_key(
            id, key_id
        )
    )
    if not success:
        raise R2RException(
            "API key not found or could not be deleted", 400
        )
    # Consistency fix: return the same response model as the sibling
    # boolean endpoints (delete_user, add/remove collection) instead of
    # a bare dict.
    return GenericBooleanResponse(success=True)  # type: ignore
+
@self.router.get(
    "/users/{id}/limits",
    summary="Fetch User Limits",
    responses={
        200: {
            "description": "Returns system default limits, user overrides, and final effective settings."
        },
        403: {
            "description": "If the requesting user is neither the same user nor a superuser."
        },
        404: {"description": "If the user ID does not exist."},
    },
    openapi_extra={
        "x-codeSamples": [
            {
                "lang": "Python",
                "source": """
                    from r2r import R2RClient

                    client = R2RClient()
                    # client.login(...)

                    user_limits = client.users.get_limits("550e8400-e29b-41d4-a716-446655440000")
                """,
            },
            {
                "lang": "JavaScript",
                "source": """
                    const { r2rClient } = require("r2r-js");

                    const client = new r2rClient();
                    // await client.users.login(...)

                    async function main() {
                        const userLimits = await client.users.getLimits({
                            id: "550e8400-e29b-41d4-a716-446655440000"
                        });
                        console.log(userLimits);
                    }

                    main();
                """,
            },
            {
                "lang": "cURL",
                "source": """
                    curl -X GET "https://api.example.com/v3/users/550e8400-e29b-41d4-a716-446655440000/limits" \\
                    -H "Authorization: Bearer YOUR_API_KEY"
                """,
            },
        ]
    },
)
@self.base_endpoint
async def get_user_limits(
    id: UUID = Path(
        ..., description="ID of the user to fetch limits for"
    ),
    auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> WrappedLimitsResponse:
    """Return the system default limits, user-level overrides, and
    final "effective" limit settings for the specified user.

    Only superusers or the user themself may fetch these values.

    Raises:
        R2RException: 403 when the requester is neither the target user
            nor a superuser.
    """
    if (auth_user.id != id) and (not auth_user.is_superuser):
        raise R2RException(
            "Only the user themselves or a superuser can view these limits.",
            status_code=403,
        )

    # Delegates to ManagementService.get_all_user_limits, which returns
    # the defaults/overrides/effective breakdown described in `responses`.
    limits_info = await self.services.management.get_all_user_limits(
        id
    )
    return limits_info  # type: ignore
+
@self.router.get("/users/oauth/google/authorize")
@self.base_endpoint
async def google_authorize() -> WrappedGenericMessageResponse:
    """Redirect user to Google's OAuth 2.0 consent screen.

    Returns the fully-formed authorization URL as a message; the client
    is expected to navigate the browser to it.
    """
    # SECURITY NOTE(review): `state` is a hard-coded placeholder, so it
    # provides no CSRF protection. It should be a random value stored
    # server-side (session/Redis) and checked in the callback — TODO fix.
    state = "some_random_string_or_csrf_token"  # Usually you store a random state in session/Redis
    scope = "openid email profile"

    # Build the Google OAuth URL
    params = {
        "client_id": self.google_client_id,
        "redirect_uri": self.google_redirect_uri,
        "response_type": "code",
        "scope": scope,
        "state": state,
        "access_type": "offline",  # to get refresh token if needed
        "prompt": "consent",  # Force consent each time if you want
    }
    google_auth_url = f"https://accounts.google.com/o/oauth2/v2/auth?{urllib.parse.urlencode(params)}"
    return GenericMessageResponse(message=google_auth_url)  # type: ignore
+
@self.router.get("/users/oauth/google/callback")
@self.base_endpoint
async def google_callback(
    code: str = Query(...), state: str = Query(...)
) -> WrappedLoginResponse:
    """Google's callback that will receive the `code` and `state`.

    We then exchange code for tokens, verify, and log the user in.

    SECURITY NOTE(review): `state` is accepted but never validated
    against a stored value, so the CSRF check is effectively absent —
    pair this with a real `state` issued by the authorize endpoint.
    """
    # 1. Exchange `code` for tokens
    # NOTE(review): `requests.post` is synchronous and blocks the event
    # loop inside this async endpoint — consider an async HTTP client.
    token_data = requests.post(
        "https://oauth2.googleapis.com/token",
        data={
            "code": code,
            "client_id": self.google_client_id,
            "client_secret": self.google_client_secret,
            "redirect_uri": self.google_redirect_uri,
            "grant_type": "authorization_code",
        },
    ).json()
    if "error" in token_data:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to get token: {token_data}",
        )

    # 2. Verify the ID token
    id_token_str = token_data["id_token"]
    try:
        # google_auth.transport.requests.Request() is a session for verifying
        id_info = id_token.verify_oauth2_token(
            id_token_str,
            google_requests.Request(),
            self.google_client_id,
        )
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Token verification failed: {str(e)}",
        ) from e

    # id_info will contain "sub", "email", etc.
    google_id = id_info["sub"]
    email = id_info.get("email")
    # Fall back to a synthetic address when Google returns no email,
    # so downstream user creation always has a unique identifier.
    email = email or f"{google_id}@google_oauth.fake"

    # 3. Now call our R2RAuthProvider method that handles "oauth-based" user creation or login
    return await self.providers.auth.oauth_callback_handler(  # type: ignore
        provider="google",
        oauth_id=google_id,
        email=email,
    )
+
@self.router.get("/users/oauth/github/authorize")
@self.base_endpoint
async def github_authorize() -> WrappedGenericMessageResponse:
    """Redirect user to GitHub's OAuth consent screen.

    Returns the fully-formed authorization URL as a message; the client
    is expected to navigate the browser to it.
    """
    # SECURITY NOTE(review): hard-coded `state` gives no CSRF protection;
    # it should be random and stored server-side — TODO fix (same issue
    # as the Google authorize endpoint).
    state = "some_random_string_or_csrf_token"
    scope = "read:user user:email"

    params = {
        "client_id": self.github_client_id,
        "redirect_uri": self.github_redirect_uri,
        "scope": scope,
        "state": state,
    }
    github_auth_url = f"https://github.com/login/oauth/authorize?{urllib.parse.urlencode(params)}"
    return GenericMessageResponse(message=github_auth_url)  # type: ignore
+
@self.router.get("/users/oauth/github/callback")
@self.base_endpoint
async def github_callback(
    code: str = Query(...), state: str = Query(...)
) -> WrappedLoginResponse:
    """GitHub callback route to exchange code for an access_token, then
    fetch user info from GitHub's API, then do the same 'oauth-based'
    login or registration.

    SECURITY NOTE(review): `state` is forwarded to GitHub but never
    validated against a stored value here — CSRF check is missing.
    """
    # 1. Exchange code for access_token
    # NOTE(review): `requests.post` is synchronous and blocks the event
    # loop inside this async endpoint — consider an async HTTP client.
    token_resp = requests.post(
        "https://github.com/login/oauth/access_token",
        data={
            "client_id": self.github_client_id,
            "client_secret": self.github_client_secret,
            "code": code,
            "redirect_uri": self.github_redirect_uri,
            "state": state,
        },
        headers={"Accept": "application/json"},
    )
    token_data = token_resp.json()
    if "error" in token_data:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to get token: {token_data}",
        )
    access_token = token_data["access_token"]

    # 2. Use the access_token to fetch user info
    user_info_resp = requests.get(
        "https://api.github.com/user",
        headers={"Authorization": f"Bearer {access_token}"},
    ).json()

    github_id = str(
        user_info_resp["id"]
    )  # GitHub user ID is typically an integer
    # fetch email (sometimes you need to call /user/emails endpoint if user sets email private)
    email = user_info_resp.get("email")
    # Synthetic fallback keeps downstream user creation working when the
    # user's email is private/unset.
    email = email or f"{github_id}@github_oauth.fake"
    # 3. Pass to your auth provider
    return await self.providers.auth.oauth_callback_handler(  # type: ignore
        provider="github",
        oauth_id=github_id,
        email=email,
    )
diff --git a/.venv/lib/python3.12/site-packages/core/main/app.py b/.venv/lib/python3.12/site-packages/core/main/app.py
new file mode 100644
index 00000000..ceb13cce
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/app.py
@@ -0,0 +1,121 @@
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.openapi.utils import get_openapi
+from fastapi.responses import JSONResponse
+
+from core.base import R2RException
+from core.providers import (
+ HatchetOrchestrationProvider,
+ SimpleOrchestrationProvider,
+)
+from core.utils.sentry import init_sentry
+
+from .abstractions import R2RServices
+from .api.v3.chunks_router import ChunksRouter
+from .api.v3.collections_router import CollectionsRouter
+from .api.v3.conversations_router import ConversationsRouter
+from .api.v3.documents_router import DocumentsRouter
+from .api.v3.graph_router import GraphRouter
+from .api.v3.indices_router import IndicesRouter
+from .api.v3.prompts_router import PromptsRouter
+from .api.v3.retrieval_router import RetrievalRouter
+from .api.v3.system_router import SystemRouter
+from .api.v3.users_router import UsersRouter
+from .config import R2RConfig
+
+
class R2RApp:
    """Top-level application object: owns the FastAPI instance, mounts all
    v3 routers under ``/v3``, installs the R2RException handler and CORS,
    and knows how to serve itself with uvicorn.
    """

    def __init__(
        self,
        config: R2RConfig,
        orchestration_provider: (
            HatchetOrchestrationProvider | SimpleOrchestrationProvider
        ),
        services: R2RServices,
        chunks_router: ChunksRouter,
        collections_router: CollectionsRouter,
        conversations_router: ConversationsRouter,
        documents_router: DocumentsRouter,
        graph_router: GraphRouter,
        indices_router: IndicesRouter,
        prompts_router: PromptsRouter,
        retrieval_router: RetrievalRouter,
        system_router: SystemRouter,
        users_router: UsersRouter,
    ):
        # Error reporting is initialized first so construction errors
        # downstream are captured.
        init_sentry()

        self.config = config
        self.services = services
        self.chunks_router = chunks_router
        self.collections_router = collections_router
        self.conversations_router = conversations_router
        self.documents_router = documents_router
        self.graph_router = graph_router
        self.indices_router = indices_router
        self.orchestration_provider = orchestration_provider
        self.prompts_router = prompts_router
        self.retrieval_router = retrieval_router
        self.system_router = system_router
        self.users_router = users_router

        self.app = FastAPI()

        @self.app.exception_handler(R2RException)
        async def r2r_exception_handler(request: Request, exc: R2RException):
            # Translate domain exceptions into structured JSON errors with
            # the exception's own HTTP status code.
            return JSONResponse(
                status_code=exc.status_code,
                content={
                    "message": exc.message,
                    "error_type": type(exc).__name__,
                },
            )

        self._setup_routes()
        self._apply_cors()

    def _setup_routes(self):
        """Mount every v3 router under the ``/v3`` prefix and expose the
        OpenAPI spec at ``/openapi_spec`` (hidden from the schema itself)."""
        self.app.include_router(self.chunks_router, prefix="/v3")
        self.app.include_router(self.collections_router, prefix="/v3")
        self.app.include_router(self.conversations_router, prefix="/v3")
        self.app.include_router(self.documents_router, prefix="/v3")
        self.app.include_router(self.graph_router, prefix="/v3")
        self.app.include_router(self.indices_router, prefix="/v3")
        self.app.include_router(self.prompts_router, prefix="/v3")
        self.app.include_router(self.retrieval_router, prefix="/v3")
        self.app.include_router(self.system_router, prefix="/v3")
        self.app.include_router(self.users_router, prefix="/v3")

        @self.app.get("/openapi_spec", include_in_schema=False)
        async def openapi_spec():
            return get_openapi(
                title="R2R Application API",
                version="1.0.0",
                routes=self.app.routes,
            )

    def _apply_cors(self):
        """Install permissive CORS middleware.

        NOTE(review): the wildcard "*" makes the explicit localhost
        entries redundant, and the CORS spec disallows a wildcard origin
        for credentialed requests (allow_credentials=True) — confirm the
        intended policy and tighten for production.
        """
        origins = ["*", "http://localhost:3000", "http://localhost:7272"]
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    async def serve(self, host: str = "0.0.0.0", port: int = 7272):
        """Run the app under uvicorn until the server exits.

        Imports are local so uvicorn is only required when actually
        serving; ``log_config=None`` defers to configure_logging().
        """
        import uvicorn

        from core.utils.logging_config import configure_logging

        configure_logging()

        config = uvicorn.Config(
            self.app,
            host=host,
            port=port,
            log_config=None,
        )
        server = uvicorn.Server(config)
        await server.serve()
diff --git a/.venv/lib/python3.12/site-packages/core/main/app_entry.py b/.venv/lib/python3.12/site-packages/core/main/app_entry.py
new file mode 100644
index 00000000..cd3ea84d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/app_entry.py
@@ -0,0 +1,125 @@
+import logging
+import os
+from contextlib import asynccontextmanager
+from typing import Optional
+
+from apscheduler.schedulers.asyncio import AsyncIOScheduler
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+
+from core.base import R2RException
+from core.utils.logging_config import configure_logging
+
+from .assembly import R2RBuilder, R2RConfig
+
+log_file = configure_logging()
+
+# Global scheduler
+scheduler = AsyncIOScheduler()
+
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: build the R2R application at startup and
    shut down the scheduler on exit.

    Relies on the module-level ``config_name``/``config_path`` globals
    resolved from the environment further down in this module.
    """
    # Startup
    r2r_app = await create_r2r_app(
        config_name=config_name,
        config_path=config_path,
    )

    # Copy all routes from r2r_app to app.
    # NOTE(review): wholesale replacement of `app.router.routes` is a
    # workaround for building the real app only at startup — any routes
    # registered on `app` beforehand are discarded.
    app.router.routes = r2r_app.app.routes

    # Copy middleware and exception handlers
    app.middleware = r2r_app.app.middleware  # type: ignore
    app.exception_handlers = r2r_app.app.exception_handlers

    # Start the scheduler
    scheduler.start()

    # Start the Hatchet worker
    # NOTE(review): presumably a no-op for SimpleOrchestrationProvider —
    # confirm against the provider implementations.
    await r2r_app.orchestration_provider.start_worker()

    yield

    # Shutdown
    scheduler.shutdown()
+
+
async def create_r2r_app(
    config_name: Optional[str] = "default",
    config_path: Optional[str] = None,
):
    """Load an R2RConfig and build a fully-assembled R2RApp from it.

    Args:
        config_name: Name of a built-in configuration preset.
        config_path: Path to an explicit configuration file; both values
            are forwarded to ``R2RConfig.load``.

    Raises:
        ValueError: if the config selects the OpenAI embedding provider
            but OPENAI_API_KEY is not set (fail fast before assembly).
    """
    config = R2RConfig.load(config_name=config_name, config_path=config_path)

    if (
        config.embedding.provider == "openai"
        and "OPENAI_API_KEY" not in os.environ
    ):
        raise ValueError(
            "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
        )

    # Build the R2RApp
    builder = R2RBuilder(config=config)
    return await builder.build()
+
+
# Resolve configuration from the environment; fall back to the built-in
# "default" preset when neither a name nor a path is provided.
config_name = os.getenv("R2R_CONFIG_NAME", None)
config_path = os.getenv("R2R_CONFIG_PATH", None)

if not config_path and not config_name:
    config_name = "default"
host = os.getenv("R2R_HOST", os.getenv("HOST", "0.0.0.0"))
port = int(os.getenv("R2R_PORT", "7272"))

# Log the deployment environment for debugging startup problems.
logging.info(
    f"Environment R2R_IMAGE: {os.getenv('R2R_IMAGE')}",
)
logging.info(
    f"Environment R2R_CONFIG_NAME: {'None' if config_name is None else config_name}"
)
logging.info(
    f"Environment R2R_CONFIG_PATH: {'None' if config_path is None else config_path}"
)
logging.info(f"Environment R2R_PROJECT_NAME: {os.getenv('R2R_PROJECT_NAME')}")

logging.info(
    f"Environment R2R_POSTGRES_HOST: {os.getenv('R2R_POSTGRES_HOST')}"
)
logging.info(
    f"Environment R2R_POSTGRES_DBNAME: {os.getenv('R2R_POSTGRES_DBNAME')}"
)
logging.info(
    f"Environment R2R_POSTGRES_PORT: {os.getenv('R2R_POSTGRES_PORT')}"
)
# SECURITY FIX: never log the raw database password — only whether it is set.
logging.info(
    "Environment R2R_POSTGRES_PASSWORD: %s",
    "<set>" if os.getenv("R2R_POSTGRES_PASSWORD") else "None",
)

# Create the FastAPI app; its routes/middleware are populated by `lifespan`
# at startup.
app = FastAPI(
    lifespan=lifespan,
    log_config=None,
)


@app.exception_handler(R2RException)
async def r2r_exception_handler(request: Request, exc: R2RException):
    """Translate R2RException into a structured JSON error response."""
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "message": exc.message,
            "error_type": type(exc).__name__,
        },
    )


# Add CORS middleware.
# NOTE(review): the CORS spec forbids a wildcard origin for credentialed
# requests (allow_credentials=True) — confirm the intended policy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
diff --git a/.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py b/.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py
new file mode 100644
index 00000000..3d10f2b6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py
@@ -0,0 +1,12 @@
+from ..config import R2RConfig
+from .builder import R2RBuilder
+from .factory import R2RProviderFactory
+
+__all__ = [
+ # Builder
+ "R2RBuilder",
+ # Config
+ "R2RConfig",
+ # Factory
+ "R2RProviderFactory",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/main/assembly/builder.py b/.venv/lib/python3.12/site-packages/core/main/assembly/builder.py
new file mode 100644
index 00000000..f72a15c9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/assembly/builder.py
@@ -0,0 +1,127 @@
+import logging
+from typing import Any, Type
+
+from ..abstractions import R2RProviders, R2RServices
+from ..api.v3.chunks_router import ChunksRouter
+from ..api.v3.collections_router import CollectionsRouter
+from ..api.v3.conversations_router import ConversationsRouter
+from ..api.v3.documents_router import DocumentsRouter
+from ..api.v3.graph_router import GraphRouter
+from ..api.v3.indices_router import IndicesRouter
+from ..api.v3.prompts_router import PromptsRouter
+from ..api.v3.retrieval_router import RetrievalRouter
+from ..api.v3.system_router import SystemRouter
+from ..api.v3.users_router import UsersRouter
+from ..app import R2RApp
+from ..config import R2RConfig
+from ..services.auth_service import AuthService # noqa: F401
+from ..services.graph_service import GraphService # noqa: F401
+from ..services.ingestion_service import IngestionService # noqa: F401
+from ..services.management_service import ManagementService # noqa: F401
+from ..services.retrieval_service import ( # type: ignore
+ RetrievalService, # noqa: F401 # type: ignore
+)
+from .factory import R2RProviderFactory
+
+logger = logging.getLogger()
+
+
class R2RBuilder:
    """Assembles a fully-wired :class:`R2RApp` from an :class:`R2RConfig`:
    providers first, then services, then one router instance per v3 API
    surface, all handed to the app constructor.
    """

    # Service name stems; each resolves to a `<Stem.capitalize()>Service`
    # class imported at module scope (see _create_services).
    _SERVICES = ["auth", "ingestion", "management", "retrieval", "graph"]

    def __init__(self, config: R2RConfig):
        self.config = config

    async def build(self, *args, **kwargs) -> R2RApp:
        """Build and return the assembled R2RApp.

        Extra args/kwargs are forwarded to provider creation. Raises the
        underlying exception after logging if provider creation fails.
        """
        provider_factory = R2RProviderFactory

        try:
            providers = await self._create_providers(
                provider_factory, *args, **kwargs
            )
        except Exception as e:
            logger.error(f"Error {e} while creating R2RProviders.")
            raise

        # Every service receives the same (config, providers) pair.
        service_params = {
            "config": self.config,
            "providers": providers,
        }

        services = self._create_services(service_params)

        # Each router is constructed with the full provider/service set and
        # reduced to its bare APIRouter via get_router().
        routers = {
            "chunks_router": ChunksRouter(
                providers=providers,
                services=services,
                config=self.config,
            ).get_router(),
            "collections_router": CollectionsRouter(
                providers=providers,
                services=services,
                config=self.config,
            ).get_router(),
            "conversations_router": ConversationsRouter(
                providers=providers,
                services=services,
                config=self.config,
            ).get_router(),
            "documents_router": DocumentsRouter(
                providers=providers,
                services=services,
                config=self.config,
            ).get_router(),
            "graph_router": GraphRouter(
                providers=providers,
                services=services,
                config=self.config,
            ).get_router(),
            "indices_router": IndicesRouter(
                providers=providers,
                services=services,
                config=self.config,
            ).get_router(),
            "prompts_router": PromptsRouter(
                providers=providers,
                services=services,
                config=self.config,
            ).get_router(),
            "retrieval_router": RetrievalRouter(
                providers=providers,
                services=services,
                config=self.config,
            ).get_router(),
            "system_router": SystemRouter(
                providers=providers,
                services=services,
                config=self.config,
            ).get_router(),
            "users_router": UsersRouter(
                providers=providers,
                services=services,
                config=self.config,
            ).get_router(),
        }

        return R2RApp(
            config=self.config,
            orchestration_provider=providers.orchestration,
            services=services,
            **routers,
        )

    async def _create_providers(
        self, provider_factory: Type[R2RProviderFactory], *args, **kwargs
    ) -> R2RProviders:
        """Instantiate the provider factory and create all providers."""
        factory = provider_factory(self.config)
        return await factory.create_providers(*args, **kwargs)

    def _create_services(self, service_params: dict[str, Any]) -> R2RServices:
        """Instantiate one service per _SERVICES entry.

        NOTE(review): resolves classes via globals() string lookup
        ("auth" -> AuthService, ...); this is why the module imports the
        service classes with `# noqa: F401` — renaming a service class
        breaks this silently at runtime.
        """
        services = R2RBuilder._SERVICES
        service_instances = {}

        for service_type in services:
            service_class = globals()[f"{service_type.capitalize()}Service"]
            service_instances[service_type] = service_class(**service_params)

        return R2RServices(**service_instances)
diff --git a/.venv/lib/python3.12/site-packages/core/main/assembly/factory.py b/.venv/lib/python3.12/site-packages/core/main/assembly/factory.py
new file mode 100644
index 00000000..b982aa18
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/assembly/factory.py
@@ -0,0 +1,417 @@
+import logging
+import math
+import os
+from typing import Any, Optional
+
+from core.base import (
+ AuthConfig,
+ CompletionConfig,
+ CompletionProvider,
+ CryptoConfig,
+ DatabaseConfig,
+ EmailConfig,
+ EmbeddingConfig,
+ EmbeddingProvider,
+ IngestionConfig,
+ OrchestrationConfig,
+)
+from core.providers import (
+ AnthropicCompletionProvider,
+ AsyncSMTPEmailProvider,
+ BcryptCryptoConfig,
+ BCryptCryptoProvider,
+ ClerkAuthProvider,
+ ConsoleMockEmailProvider,
+ HatchetOrchestrationProvider,
+ JwtAuthProvider,
+ LiteLLMCompletionProvider,
+ LiteLLMEmbeddingProvider,
+ MailerSendEmailProvider,
+ NaClCryptoConfig,
+ NaClCryptoProvider,
+ OllamaEmbeddingProvider,
+ OpenAICompletionProvider,
+ OpenAIEmbeddingProvider,
+ PostgresDatabaseProvider,
+ R2RAuthProvider,
+ R2RCompletionProvider,
+ R2RIngestionConfig,
+ R2RIngestionProvider,
+ SendGridEmailProvider,
+ SimpleOrchestrationProvider,
+ SupabaseAuthProvider,
+ UnstructuredIngestionConfig,
+ UnstructuredIngestionProvider,
+)
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+
+logger = logging.getLogger()
+
+
class R2RProviderFactory:
    """Builds concrete provider instances from an :class:`R2RConfig`.

    Each ``create_*`` method dispatches on the relevant config section's
    ``provider`` string and returns the matching implementation;
    :meth:`create_providers` assembles the full :class:`R2RProviders`
    bundle, honoring optional per-provider overrides.
    """

    def __init__(self, config: R2RConfig):
        # Shared configuration used by the instance-level factory methods
        # (database, email, auth, and create_providers()).
        self.config = config

    @staticmethod
    async def create_auth_provider(
        auth_config: AuthConfig,
        crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
        database_provider: PostgresDatabaseProvider,
        email_provider: (
            AsyncSMTPEmailProvider
            | ConsoleMockEmailProvider
            | SendGridEmailProvider
            | MailerSendEmailProvider
        ),
        *args,
        **kwargs,
    ) -> (
        R2RAuthProvider
        | SupabaseAuthProvider
        | JwtAuthProvider
        | ClerkAuthProvider
    ):
        """Create the auth provider named by ``auth_config.provider``.

        Only the "r2r" provider is async-initialized here; the other
        providers are returned without an ``initialize()`` call.

        Raises:
            ValueError: if the configured provider name is not supported.
        """
        if auth_config.provider == "r2r":
            r2r_auth = R2RAuthProvider(
                auth_config, crypto_provider, database_provider, email_provider
            )
            await r2r_auth.initialize()
            return r2r_auth
        elif auth_config.provider == "supabase":
            return SupabaseAuthProvider(
                auth_config, crypto_provider, database_provider, email_provider
            )
        elif auth_config.provider == "jwt":
            return JwtAuthProvider(
                auth_config, crypto_provider, database_provider, email_provider
            )
        elif auth_config.provider == "clerk":
            return ClerkAuthProvider(
                auth_config, crypto_provider, database_provider, email_provider
            )
        else:
            raise ValueError(
                f"Auth provider {auth_config.provider} not supported."
            )

    @staticmethod
    def create_crypto_provider(
        crypto_config: CryptoConfig, *args, **kwargs
    ) -> BCryptCryptoProvider | NaClCryptoProvider:
        """Create a crypto provider ("bcrypt" or "nacl") from its config.

        The generic ``CryptoConfig`` is re-hydrated into the
        provider-specific config class before construction.

        Raises:
            ValueError: if the configured provider name is not supported.
        """
        if crypto_config.provider == "bcrypt":
            return BCryptCryptoProvider(
                BcryptCryptoConfig(**crypto_config.model_dump())
            )
        # Plain `if` (not elif) is safe here: the branch above returns.
        if crypto_config.provider == "nacl":
            return NaClCryptoProvider(
                NaClCryptoConfig(**crypto_config.model_dump())
            )
        else:
            raise ValueError(
                f"Crypto provider {crypto_config.provider} not supported."
            )

    @staticmethod
    def create_ingestion_provider(
        ingestion_config: IngestionConfig,
        database_provider: PostgresDatabaseProvider,
        llm_provider: (
            AnthropicCompletionProvider
            | LiteLLMCompletionProvider
            | OpenAICompletionProvider
            | R2RCompletionProvider
        ),
        *args,
        **kwargs,
    ) -> R2RIngestionProvider | UnstructuredIngestionProvider:
        """Create an ingestion provider ("r2r" or unstructured variants).

        Accepts either an ``IngestionConfig`` model or a plain dict with
        the same keys.
        """
        # Normalize to a dict so both model and dict inputs are handled.
        config_dict = (
            ingestion_config.model_dump()
            if isinstance(ingestion_config, IngestionConfig)
            else ingestion_config
        )

        # Provider-specific settings travel in "extra_fields" and are
        # re-applied when building the typed config below.
        extra_fields = config_dict.pop("extra_fields", {})

        if config_dict["provider"] == "r2r":
            r2r_ingestion_config = R2RIngestionConfig(
                **config_dict, **extra_fields
            )
            return R2RIngestionProvider(
                r2r_ingestion_config, database_provider, llm_provider
            )
        elif config_dict["provider"] in [
            "unstructured_local",
            "unstructured_api",
        ]:
            unstructured_ingestion_config = UnstructuredIngestionConfig(
                **config_dict, **extra_fields
            )

            return UnstructuredIngestionProvider(
                unstructured_ingestion_config, database_provider, llm_provider
            )
        else:
            # NOTE(review): if a plain dict was passed, `.provider` raises
            # AttributeError before this ValueError can be shown — confirm
            # whether config_dict["provider"] was intended here.
            raise ValueError(
                f"Ingestion provider {ingestion_config.provider} not supported"
            )

    @staticmethod
    def create_orchestration_provider(
        config: OrchestrationConfig, *args, **kwargs
    ) -> HatchetOrchestrationProvider | SimpleOrchestrationProvider:
        """Create the orchestration provider ("hatchet" or "simple").

        For hatchet, a worker named "r2r-worker" is eagerly registered as
        a construction side effect.
        """
        if config.provider == "hatchet":
            orchestration_provider = HatchetOrchestrationProvider(config)
            orchestration_provider.get_worker("r2r-worker")
            return orchestration_provider
        elif config.provider == "simple":
            # NOTE(review): this local import shadows the module-level one
            # already available above; it appears redundant.
            from core.providers import SimpleOrchestrationProvider

            return SimpleOrchestrationProvider(config)
        else:
            raise ValueError(
                f"Orchestration provider {config.provider} not supported"
            )

    async def create_database_provider(
        self,
        db_config: DatabaseConfig,
        crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
        *args,
        **kwargs,
    ) -> PostgresDatabaseProvider:
        """Create and async-initialize the Postgres database provider.

        The embedding config must declare ``base_dimension`` because the
        vector store's column dimension is fixed at initialization.

        Raises:
            ValueError: if base_dimension is unset or the provider name is
                not "postgres".
        """
        if not self.config.embedding.base_dimension:
            raise ValueError(
                "Embedding config must have a base dimension to initialize database."
            )

        dimension = self.config.embedding.base_dimension
        quantization_type = (
            self.config.embedding.quantization_settings.quantization_type
        )
        if db_config.provider == "postgres":
            database_provider = PostgresDatabaseProvider(
                db_config,
                dimension,
                crypto_provider=crypto_provider,
                quantization_type=quantization_type,
            )
            await database_provider.initialize()
            return database_provider
        else:
            raise ValueError(
                f"Database provider {db_config.provider} not supported"
            )

    @staticmethod
    def create_embedding_provider(
        embedding: EmbeddingConfig, *args, **kwargs
    ) -> (
        LiteLLMEmbeddingProvider
        | OllamaEmbeddingProvider
        | OpenAIEmbeddingProvider
    ):
        """Create an embedding provider ("openai", "litellm", "ollama").

        The OpenAI branch fails fast when OPENAI_API_KEY is absent from
        the environment.
        """
        embedding_provider: Optional[EmbeddingProvider] = None

        if embedding.provider == "openai":
            if not os.getenv("OPENAI_API_KEY"):
                raise ValueError(
                    "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
                )
            # NOTE(review): these function-local imports duplicate the
            # module-level imports at the top of the file.
            from core.providers import OpenAIEmbeddingProvider

            embedding_provider = OpenAIEmbeddingProvider(embedding)

        elif embedding.provider == "litellm":
            from core.providers import LiteLLMEmbeddingProvider

            embedding_provider = LiteLLMEmbeddingProvider(embedding)

        elif embedding.provider == "ollama":
            from core.providers import OllamaEmbeddingProvider

            embedding_provider = OllamaEmbeddingProvider(embedding)

        else:
            raise ValueError(
                f"Embedding provider {embedding.provider} not supported"
            )

        return embedding_provider

    @staticmethod
    def create_llm_provider(
        llm_config: CompletionConfig, *args, **kwargs
    ) -> (
        AnthropicCompletionProvider
        | LiteLLMCompletionProvider
        | OpenAICompletionProvider
        | R2RCompletionProvider
    ):
        """Create a completion (LLM) provider from its config.

        Raises:
            ValueError: if the configured provider name is not supported.
        """
        llm_provider: Optional[CompletionProvider] = None
        if llm_config.provider == "anthropic":
            llm_provider = AnthropicCompletionProvider(llm_config)
        elif llm_config.provider == "litellm":
            llm_provider = LiteLLMCompletionProvider(llm_config)
        elif llm_config.provider == "openai":
            llm_provider = OpenAICompletionProvider(llm_config)
        elif llm_config.provider == "r2r":
            llm_provider = R2RCompletionProvider(llm_config)
        else:
            raise ValueError(
                f"Language model provider {llm_config.provider} not supported"
            )
        # Defensive double-check; every branch above either assigns or
        # raises, so this should be unreachable.
        if not llm_provider:
            raise ValueError("Language model provider not found")
        return llm_provider

    @staticmethod
    async def create_email_provider(
        email_config: Optional[EmailConfig] = None, *args, **kwargs
    ) -> (
        AsyncSMTPEmailProvider
        | ConsoleMockEmailProvider
        | SendGridEmailProvider
        | MailerSendEmailProvider
    ):
        """Creates an email provider based on configuration.

        Raises:
            ValueError: if no email config was supplied or the provider
                name is not supported.
        """
        if not email_config:
            raise ValueError(
                "No email configuration provided for email provider, please add `[email]` to your `r2r.toml`."
            )

        if email_config.provider == "smtp":
            return AsyncSMTPEmailProvider(email_config)
        elif email_config.provider == "console_mock":
            return ConsoleMockEmailProvider(email_config)
        elif email_config.provider == "sendgrid":
            return SendGridEmailProvider(email_config)
        elif email_config.provider == "mailersend":
            return MailerSendEmailProvider(email_config)
        else:
            raise ValueError(
                f"Email provider {email_config.provider} not supported."
            )

    async def create_providers(
        self,
        auth_provider_override: Optional[
            R2RAuthProvider | SupabaseAuthProvider
        ] = None,
        crypto_provider_override: Optional[
            BCryptCryptoProvider | NaClCryptoProvider
        ] = None,
        database_provider_override: Optional[PostgresDatabaseProvider] = None,
        email_provider_override: Optional[
            AsyncSMTPEmailProvider
            | ConsoleMockEmailProvider
            | SendGridEmailProvider
            | MailerSendEmailProvider
        ] = None,
        embedding_provider_override: Optional[
            LiteLLMEmbeddingProvider
            | OpenAIEmbeddingProvider
            | OllamaEmbeddingProvider
        ] = None,
        ingestion_provider_override: Optional[
            R2RIngestionProvider | UnstructuredIngestionProvider
        ] = None,
        llm_provider_override: Optional[Any] = None,
        orchestration_provider_override: Optional[Any] = None,
        *args,
        **kwargs,
    ) -> R2RProviders:
        """Build all providers and return them as an :class:`R2RProviders`.

        Every provider may be replaced via its ``*_override`` parameter;
        otherwise the corresponding ``create_*`` method is used.
        Construction order matters: crypto -> database -> ingestion ->
        email -> auth, since later providers consume earlier ones.

        Raises:
            ValueError: if the two embedding configs disagree on vector
                dimension, or any sub-factory rejects its config.
        """
        # Both embedding configs must agree on dimension because they
        # share one vector store. NaN is treated as "unset": reject when
        # exactly one side is set, or both are set but differ.
        if (
            math.isnan(self.config.embedding.base_dimension)
            != math.isnan(self.config.completion_embedding.base_dimension)
        ) or (
            not math.isnan(self.config.embedding.base_dimension)
            and not math.isnan(self.config.completion_embedding.base_dimension)
            and self.config.embedding.base_dimension
            != self.config.completion_embedding.base_dimension
        ):
            raise ValueError(
                f"Both embedding configurations must use the same dimensions. Got {self.config.embedding.base_dimension} and {self.config.completion_embedding.base_dimension}"
            )

        embedding_provider = (
            embedding_provider_override
            or self.create_embedding_provider(
                self.config.embedding, *args, **kwargs
            )
        )

        # NOTE(review): this reuses embedding_provider_override for the
        # completion embedding too — confirm a separate
        # completion_embedding override is not needed.
        completion_embedding_provider = (
            embedding_provider_override
            or self.create_embedding_provider(
                self.config.completion_embedding, *args, **kwargs
            )
        )

        llm_provider = llm_provider_override or self.create_llm_provider(
            self.config.completion, *args, **kwargs
        )

        crypto_provider = (
            crypto_provider_override
            or self.create_crypto_provider(self.config.crypto, *args, **kwargs)
        )

        database_provider = (
            database_provider_override
            or await self.create_database_provider(
                self.config.database, crypto_provider, *args, **kwargs
            )
        )

        ingestion_provider = (
            ingestion_provider_override
            or self.create_ingestion_provider(
                self.config.ingestion,
                database_provider,
                llm_provider,
                *args,
                **kwargs,
            )
        )

        # NOTE(review): crypto_provider is absorbed by *args here —
        # create_email_provider never uses it; likely a leftover argument.
        email_provider = (
            email_provider_override
            or await self.create_email_provider(
                self.config.email, crypto_provider, *args, **kwargs
            )
        )

        auth_provider = (
            auth_provider_override
            or await self.create_auth_provider(
                self.config.auth,
                crypto_provider,
                database_provider,
                email_provider,
                *args,
                **kwargs,
            )
        )

        orchestration_provider = (
            orchestration_provider_override
            or self.create_orchestration_provider(self.config.orchestration)
        )

        return R2RProviders(
            auth=auth_provider,
            database=database_provider,
            embedding=embedding_provider,
            completion_embedding=completion_embedding_provider,
            ingestion=ingestion_provider,
            llm=llm_provider,
            email=email_provider,
            orchestration=orchestration_provider,
        )
diff --git a/.venv/lib/python3.12/site-packages/core/main/config.py b/.venv/lib/python3.12/site-packages/core/main/config.py
new file mode 100644
index 00000000..f49b4041
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/config.py
@@ -0,0 +1,213 @@
+# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
+import logging
+import os
+from enum import Enum
+from typing import Any, Optional
+
+import toml
+from pydantic import BaseModel
+
+from ..base.abstractions import GenerationConfig
+from ..base.agent.agent import RAGAgentConfig # type: ignore
+from ..base.providers import AppConfig
+from ..base.providers.auth import AuthConfig
+from ..base.providers.crypto import CryptoConfig
+from ..base.providers.database import DatabaseConfig
+from ..base.providers.email import EmailConfig
+from ..base.providers.embedding import EmbeddingConfig
+from ..base.providers.ingestion import IngestionConfig
+from ..base.providers.llm import CompletionConfig
+from ..base.providers.orchestration import OrchestrationConfig
+from ..base.utils import deep_update
+
+logger = logging.getLogger()
+
+
class R2RConfig:
    """Aggregated application configuration loaded from TOML.

    Raw TOML sections are merged over the bundled default config,
    validated, then normalized into typed ``*Config`` objects. Class-level
    attributes below are computed once at import time.
    """

    # Filesystem anchors for the bundled configs and the default r2r.toml.
    current_file_path = os.path.dirname(__file__)
    config_dir_root = os.path.join(current_file_path, "..", "configs")
    default_config_path = os.path.join(
        current_file_path, "..", "..", "r2r", "r2r.toml"
    )

    # Maps config name -> TOML path; "default" maps to None, which
    # from_toml() interprets as "use default_config_path".
    CONFIG_OPTIONS: dict[str, Optional[str]] = {}
    for file_ in os.listdir(config_dir_root):
        if file_.endswith(".toml"):
            CONFIG_OPTIONS[file_.removesuffix(".toml")] = os.path.join(
                config_dir_root, file_
            )
    CONFIG_OPTIONS["default"] = None

    # Per-section keys that must be present when the section's provider
    # is enabled; validated in __init__.
    REQUIRED_KEYS: dict[str, list] = {
        "app": [],
        "completion": ["provider"],
        "crypto": ["provider"],
        "email": ["provider"],
        "auth": ["provider"],
        "embedding": [
            "provider",
            "base_model",
            "base_dimension",
            "batch_size",
            "add_title_as_prefix",
        ],
        "completion_embedding": [
            "provider",
            "base_model",
            "base_dimension",
            "batch_size",
            "add_title_as_prefix",
        ],
        # TODO - deprecated, remove
        "ingestion": ["provider"],
        "logging": ["provider", "log_table"],
        "database": ["provider"],
        "agent": ["generation_config"],
        "orchestration": ["provider"],
    }

    # Typed views of the validated sections, populated in __init__.
    app: AppConfig
    auth: AuthConfig
    completion: CompletionConfig
    crypto: CryptoConfig
    database: DatabaseConfig
    embedding: EmbeddingConfig
    completion_embedding: EmbeddingConfig
    email: EmailConfig
    ingestion: IngestionConfig
    agent: RAGAgentConfig
    orchestration: OrchestrationConfig

    def __init__(self, config_data: dict[str, Any]):
        """Merge ``config_data`` over the default config, validate each
        required section, and normalize sections into typed configs.

        :param config_data: dictionary of configuration parameters that
            override the bundled defaults (deep-merged).
        """
        # Load the default configuration
        default_config = self.load_default_config()

        # Override the default configuration with the passed configuration
        default_config = deep_update(default_config, config_data)

        # Validate and set the configuration
        for section, keys in R2RConfig.REQUIRED_KEYS.items():
            # Check the keys when provider is set
            # TODO - remove after deprecation
            if section in ["graph", "file"] and section not in default_config:
                continue
            # NOTE(review): default_config[section] raises KeyError for a
            # missing section before _validate_config_section can emit its
            # friendlier ValueError — confirm whether that is acceptable.
            if "provider" in default_config[section] and (
                default_config[section]["provider"] is not None
                and default_config[section]["provider"] != "None"
                and default_config[section]["provider"] != "null"
            ):
                self._validate_config_section(default_config, section, keys)
            # Sections are first stored as raw dicts, then replaced with
            # typed config objects by the .create(...) calls below.
            setattr(self, section, default_config[section])

        self.app = AppConfig.create(**self.app)  # type: ignore
        self.auth = AuthConfig.create(**self.auth, app=self.app)  # type: ignore
        self.completion = CompletionConfig.create(
            **self.completion, app=self.app
        )  # type: ignore
        self.crypto = CryptoConfig.create(**self.crypto, app=self.app)  # type: ignore
        self.email = EmailConfig.create(**self.email, app=self.app)  # type: ignore
        self.database = DatabaseConfig.create(**self.database, app=self.app)  # type: ignore
        self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app)  # type: ignore
        self.completion_embedding = EmbeddingConfig.create(
            **self.completion_embedding, app=self.app
        )  # type: ignore
        self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app)  # type: ignore
        self.agent = RAGAgentConfig.create(**self.agent, app=self.app)  # type: ignore
        self.orchestration = OrchestrationConfig.create(
            **self.orchestration, app=self.app
        )  # type: ignore

        # Propagate this instance's ingestion settings as the process-wide
        # defaults for IngestionConfig.
        IngestionConfig.set_default(**self.ingestion.dict())

        # override GenerationConfig defaults
        if self.completion.generation_config:
            GenerationConfig.set_default(
                **self.completion.generation_config.dict()
            )

    def _validate_config_section(
        self, config_data: dict[str, Any], section: str, keys: list
    ):
        """Raise ValueError when ``section`` is missing or lacks any of the
        required ``keys``."""
        if section not in config_data:
            raise ValueError(f"Missing '{section}' section in config")
        if missing_keys := [
            key for key in keys if key not in config_data[section]
        ]:
            raise ValueError(
                f"Missing required keys in '{section}' config: {', '.join(missing_keys)}"
            )

    @classmethod
    def from_toml(cls, config_path: Optional[str] = None) -> "R2RConfig":
        """Load a config from ``config_path`` (or the bundled default when
        None) and construct an R2RConfig from it."""
        if config_path is None:
            config_path = R2RConfig.default_config_path

        # Load configuration from TOML file
        with open(config_path, encoding="utf-8") as f:
            config_data = toml.load(f)

        return cls(config_data)

    def to_toml(self):
        """Serialize the validated configuration back to a TOML string.

        The nested ``app`` config is stripped from every section to avoid
        duplicating it in the output.
        """
        config_data = {}
        for section in R2RConfig.REQUIRED_KEYS.keys():
            section_data = self._serialize_config(getattr(self, section))
            if isinstance(section_data, dict):
                # Remove app from nested configs before serializing
                section_data.pop("app", None)
                config_data[section] = section_data
            # NOTE(review): sections that serialize to a non-dict are
            # silently dropped here — confirm that is intended.
        return toml.dumps(config_data)

    @classmethod
    def load_default_config(cls) -> dict:
        """Parse and return the bundled default r2r.toml as a dict."""
        with open(R2RConfig.default_config_path, encoding="utf-8") as f:
            return toml.load(f)

    @staticmethod
    def _serialize_config(config_section: Any):
        """Serialize config section while excluding internal state.

        Recursively converts dicts, sequences, Enums, and pydantic models
        to plain Python values, removing the embedded ``app`` config.
        """
        if isinstance(config_section, dict):
            return {
                R2RConfig._serialize_key(k): R2RConfig._serialize_config(v)
                for k, v in config_section.items()
                if k != "app"  # Exclude app from serialization
            }
        elif isinstance(config_section, (list, tuple)):
            return [
                R2RConfig._serialize_config(item) for item in config_section
            ]
        elif isinstance(config_section, Enum):
            return config_section.value
        elif isinstance(config_section, BaseModel):
            data = config_section.model_dump(exclude_none=True)
            data.pop("app", None)  # Remove app from the serialized data
            return R2RConfig._serialize_config(data)
        else:
            return config_section

    @staticmethod
    def _serialize_key(key: Any) -> str:
        """Render a dict key as a TOML-safe string (Enum -> its value)."""
        return key.value if isinstance(key, Enum) else str(key)

    @classmethod
    def load(
        cls,
        config_name: Optional[str] = None,
        config_path: Optional[str] = None,
    ) -> "R2RConfig":
        """Load a config by registered name or by explicit path.

        Environment variables take precedence: R2R_CONFIG_PATH over
        ``config_path``, then R2R_CONFIG_NAME over ``config_name``; the
        final fallback is the "default" bundled config.

        Raises:
            ValueError: if both name and path are given, or the name is
                not a registered option.
        """
        if config_path and config_name:
            raise ValueError(
                f"Cannot specify both config_path and config_name. Got: {config_path}, {config_name}"
            )

        if config_path := os.getenv("R2R_CONFIG_PATH") or config_path:
            return cls.from_toml(config_path)

        config_name = os.getenv("R2R_CONFIG_NAME") or config_name or "default"
        if config_name not in R2RConfig.CONFIG_OPTIONS:
            raise ValueError(f"Invalid config name: {config_name}")
        return cls.from_toml(R2RConfig.CONFIG_OPTIONS[config_name])
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/__init__.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/__init__.py
new file mode 100644
index 00000000..19cb0428
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/__init__.py
@@ -0,0 +1,16 @@
+# FIXME: Once the Hatchet workflows are type annotated, remove the type: ignore comments
+from .hatchet.graph_workflow import ( # type: ignore
+ hatchet_graph_search_results_factory,
+)
+from .hatchet.ingestion_workflow import ( # type: ignore
+ hatchet_ingestion_factory,
+)
+from .simple.graph_workflow import simple_graph_search_results_factory
+from .simple.ingestion_workflow import simple_ingestion_factory
+
+__all__ = [
+ "hatchet_ingestion_factory",
+ "hatchet_graph_search_results_factory",
+ "simple_ingestion_factory",
+ "simple_graph_search_results_factory",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py
new file mode 100644
index 00000000..cc128b0f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py
@@ -0,0 +1,539 @@
+# type: ignore
+import asyncio
+import contextlib
+import json
+import logging
+import math
+import time
+import uuid
+from typing import TYPE_CHECKING
+
+from hatchet_sdk import ConcurrencyLimitStrategy, Context
+
+from core import GenerationConfig
+from core.base import OrchestrationProvider, R2RException
+from core.base.abstractions import (
+ GraphConstructionStatus,
+ GraphExtractionStatus,
+)
+
+from ...services import GraphService
+
+if TYPE_CHECKING:
+ from hatchet_sdk import Hatchet
+
+logger = logging.getLogger()
+
+
+def hatchet_graph_search_results_factory(
+ orchestration_provider: OrchestrationProvider, service: GraphService
+) -> dict[str, "Hatchet.Workflow"]:
def convert_to_dict(input_data):
    """Converts input data back to a plain dictionary format, handling
    special cases like UUID and GenerationConfig. This is the inverse of
    get_input_data_dict.

    Args:
        input_data: Dictionary containing the input data with potentially special types

    Returns:
        Dictionary with all values converted to basic Python types
    """
    output_data = {}

    for key, value in input_data.items():
        if value is None:
            output_data[key] = None
            continue

        # Convert UUID to string.
        if isinstance(value, uuid.UUID):
            output_data[key] = str(value)
            # BUG FIX: without this `continue`, control fell through to the
            # model_dump() attempt below, which raised for UUIDs and then
            # re-assigned the raw UUID in the fallback `else`, clobbering
            # the string conversion.
            continue

        try:
            # Pydantic models (e.g. GenerationConfig) expose model_dump().
            output_data[key] = value.model_dump()
        except Exception:
            # Handle nested dictionaries that might contain settings
            if isinstance(value, dict):
                output_data[key] = convert_to_dict(value)

            # Handle lists that might contain dictionaries
            elif isinstance(value, list):
                output_data[key] = [
                    (
                        convert_to_dict(item)
                        if isinstance(item, dict)
                        else item
                    )
                    for item in value
                ]

            # All other types can be directly assigned
            else:
                output_data[key] = value

    return output_data
+
def get_input_data_dict(input_data):
    """Normalize a raw workflow-request dict in place.

    Known id fields are coerced to UUIDs, settings payloads are parsed
    from JSON when needed, and any embedded generation_config is promoted
    to a GenerationConfig whose model defaults to the app's fast LLM.
    Returns the same (mutated) dict.
    """
    id_fields = ("document_id", "collection_id", "graph_id")
    settings_fields = ("graph_creation_settings", "graph_enrichment_settings")

    for key, value in input_data.items():
        if value is None:
            continue

        if key in id_fields and not isinstance(value, uuid.UUID):
            input_data[key] = uuid.UUID(value)

        if key in settings_fields:
            # Settings may arrive as a JSON string; ensure a dict.
            settings = value if isinstance(value, dict) else json.loads(value)
            input_data[key] = settings

            if "generation_config" in settings:
                gen_cfg = settings["generation_config"]
                if isinstance(gen_cfg, dict):
                    gen_cfg = GenerationConfig(**gen_cfg)
                elif not isinstance(gen_cfg, GenerationConfig):
                    gen_cfg = GenerationConfig()
                # Fall back to the configured fast LLM when unset.
                gen_cfg.model = gen_cfg.model or service.config.app.fast_llm
                settings["generation_config"] = gen_cfg

    return input_data
+
@orchestration_provider.workflow(name="graph-extraction", timeout="360m")
class GraphExtractionWorkflow:
    """Hatchet workflow that extracts graph entities/relationships.

    With a document_id it extracts and stores relationships for that
    document, then describes entities; with only a collection_id it fans
    out one child "graph-extraction" workflow per eligible document.
    """

    @orchestration_provider.concurrency(  # type: ignore
        max_runs=orchestration_provider.config.graph_search_results_concurrency_limit,  # type: ignore
        limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
    )
    def concurrency(self, context: Context) -> str:
        # Concurrency key groups runs per collection for round-robin limits.
        # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun
        # NOTE(review): if workflow_input() raises, this falls through and
        # implicitly returns None despite the declared -> str.
        with contextlib.suppress(Exception):
            return str(
                context.workflow_input()["request"]["collection_id"]
            )

    def __init__(self, graph_search_results_service: GraphService):
        # Service used by every step for extraction and status updates.
        self.graph_search_results_service = graph_search_results_service

    @orchestration_provider.step(retries=1, timeout="360m")
    async def graph_search_results_extraction(
        self, context: Context
    ) -> dict:
        """Step 1: extract relationships (or fan out per document)."""
        request = context.workflow_input()["request"]

        input_data = get_input_data_dict(request)
        document_id = input_data.get("document_id", None)
        collection_id = input_data.get("collection_id", None)

        # NOTE(review): in the collection fan-out case document_id is None
        # here — confirm set_workflow_status tolerates a None id.
        await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status(
            id=document_id,
            status_type="extraction_status",
            status=GraphExtractionStatus.PROCESSING,
        )

        if collection_id and not document_id:
            # Fan out: spawn one child extraction workflow per document.
            document_ids = await self.graph_search_results_service.get_document_ids_for_create_graph(
                collection_id=collection_id,
                **input_data["graph_creation_settings"],
            )
            workflows = []

            for document_id in document_ids:
                input_data_copy = input_data.copy()
                input_data_copy["collection_id"] = str(
                    input_data_copy["collection_id"]
                )
                input_data_copy["document_id"] = str(document_id)

                workflows.append(
                    context.aio.spawn_workflow(
                        "graph-extraction",
                        {
                            "request": {
                                **convert_to_dict(input_data_copy),
                            }
                        },
                        key=str(document_id),
                    )
                )
            # Wait for all workflows to complete
            results = await asyncio.gather(*workflows)
            # NOTE(review): the message references the last loop
            # document_id and the "document_id" key carries the
            # collection_id — looks like a copy/paste slip; confirm.
            return {
                "result": f"successfully submitted graph_search_results relationships extraction for document {document_id}",
                "document_id": str(collection_id),
            }

        else:
            # Extract relationships and store them
            extractions = []
            async for extraction in self.graph_search_results_service.graph_search_results_extraction(
                document_id=document_id,
                **input_data["graph_creation_settings"],
            ):
                logger.info(
                    f"Found extraction with {len(extraction.entities)} entities"
                )
                extractions.append(extraction)

            await self.graph_search_results_service.store_graph_search_results_extractions(
                extractions
            )

            logger.info(
                f"Successfully ran graph_search_results relationships extraction for document {document_id}"
            )

            return {
                "result": f"successfully ran graph_search_results relationships extraction for document {document_id}",
                "document_id": str(document_id),
            }

    @orchestration_provider.step(
        retries=1,
        timeout="360m",
        parents=["graph_search_results_extraction"],
    )
    async def graph_search_results_entity_description(
        self, context: Context
    ) -> dict:
        """Step 2: describe extracted entities; optionally trigger
        deduplication when the collection config requests it."""
        input_data = get_input_data_dict(
            context.workflow_input()["request"]
        )
        document_id = input_data.get("document_id", None)

        # Describe the entities in the graph
        await self.graph_search_results_service.graph_search_results_entity_description(
            document_id=document_id,
            **input_data["graph_creation_settings"],
        )

        logger.info(
            f"Successfully ran graph_search_results entity description for document {document_id}"
        )

        if service.providers.database.config.graph_creation_settings.automatic_deduplication:
            extract_input = {
                "document_id": str(document_id),
            }

            # Spawn the dedup workflow and wait for its completion.
            extract_result = (
                await context.aio.spawn_workflow(
                    "graph-deduplication",
                    {"request": extract_input},
                )
            ).result()

            await asyncio.gather(extract_result)

        return {
            "result": f"successfully ran graph_search_results entity description for document {document_id}"
        }

    @orchestration_provider.failure()
    async def on_failure(self, context: Context) -> None:
        """Failure hook: mark the document's extraction status FAILED."""
        request = context.workflow_input().get("request", {})
        document_id = request.get("document_id")

        if not document_id:
            logger.info(
                "No document id was found in workflow input to mark a failure."
            )
            return

        try:
            await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status(
                id=uuid.UUID(document_id),
                status_type="extraction_status",
                status=GraphExtractionStatus.FAILED,
            )
            logger.info(
                f"Updated Graph extraction status for {document_id} to FAILED"
            )
        except Exception as e:
            # Best-effort status update: log but never raise from a
            # failure hook.
            logger.error(
                f"Failed to update document status for {document_id}: {e}"
            )
+
@orchestration_provider.workflow(name="graph-clustering", timeout="360m")
class GraphClusteringWorkflow:
    """Hatchet workflow that clusters the graph into communities and then
    fans out community-summarization child workflows."""

    def __init__(self, graph_search_results_service: GraphService):
        # Service used for clustering and for workflow-status bookkeeping.
        self.graph_search_results_service = graph_search_results_service

    @orchestration_provider.step(retries=1, timeout="360m", parents=[])
    async def graph_search_results_clustering(
        self, context: Context
    ) -> dict:
        """Step 1: run community clustering for the collection/graph.

        Refuses to re-run when communities already exist; marks the
        collection FAILED on any clustering error before re-raising.
        """
        logger.info("Running Graph Clustering")

        input_data = get_input_data_dict(
            context.workflow_input()["request"]
        )

        # Get the collection_id and graph_id
        collection_id = input_data.get("collection_id", None)
        graph_id = input_data.get("graph_id", None)

        # Check current workflow status
        workflow_status = await self.graph_search_results_service.providers.database.documents_handler.get_workflow_status(
            id=collection_id,
            status_type="graph_cluster_status",
        )

        if workflow_status == GraphConstructionStatus.SUCCESS:
            raise R2RException(
                "Communities have already been built for this collection. To build communities again, first reset the graph.",
                400,
            )

        # Run clustering
        try:
            graph_search_results_clustering_results = await self.graph_search_results_service.graph_search_results_clustering(
                collection_id=collection_id,
                graph_id=graph_id,
                **input_data["graph_enrichment_settings"],
            )

            # The service returns "num_communities" as a sequence; the
            # first element is the count used downstream.
            num_communities = graph_search_results_clustering_results[
                "num_communities"
            ][0]

            if num_communities == 0:
                raise R2RException("No communities found", 400)

            return {
                "result": graph_search_results_clustering_results,
            }
        except Exception as e:
            await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status(
                id=collection_id,
                status_type="graph_cluster_status",
                status=GraphConstructionStatus.FAILED,
            )
            raise e

    @orchestration_provider.step(
        retries=1,
        timeout="360m",
        parents=["graph_search_results_clustering"],
    )
    async def graph_search_results_community_summary(
        self, context: Context
    ) -> dict:
        """Step 2: fan out community-summarization workflows in batches,
        then mark documents ENRICHED and the collection SUCCESS."""
        input_data = get_input_data_dict(
            context.workflow_input()["request"]
        )
        collection_id = input_data.get("collection_id", None)
        graph_id = input_data.get("graph_id", None)
        # Get number of communities from previous step
        num_communities = context.step_output(
            "graph_search_results_clustering"
        )["result"]["num_communities"][0]

        # Calculate batching: at most 100 communities per child workflow.
        parallel_communities = min(100, num_communities)
        total_workflows = math.ceil(num_communities / parallel_communities)
        workflows = []

        logger.info(
            f"Running Graph Community Summary for {num_communities} communities, spawning {total_workflows} workflows"
        )

        # Spawn summary workflows
        for i in range(total_workflows):
            offset = i * parallel_communities
            limit = min(parallel_communities, num_communities - offset)

            # Each spawn is awaited immediately; its .result() handle is
            # collected and the batch is awaited together below.
            workflows.append(
                (
                    await context.aio.spawn_workflow(
                        "graph-community-summarization",
                        {
                            "request": {
                                "offset": offset,
                                "limit": limit,
                                "graph_id": (
                                    str(graph_id) if graph_id else None
                                ),
                                "collection_id": (
                                    str(collection_id)
                                    if collection_id
                                    else None
                                ),
                                "graph_enrichment_settings": convert_to_dict(
                                    input_data["graph_enrichment_settings"]
                                ),
                            }
                        },
                        key=f"{i}/{total_workflows}_community_summary",
                    )
                ).result()
            )

        results = await asyncio.gather(*workflows)
        logger.info(
            f"Completed {len(results)} community summary workflows"
        )

        # Update statuses: successfully-extracted documents become
        # ENRICHED, and the collection's cluster status becomes SUCCESS.
        document_ids = await self.graph_search_results_service.providers.database.documents_handler.get_document_ids_by_status(
            status_type="extraction_status",
            status=GraphExtractionStatus.SUCCESS,
            collection_id=collection_id,
        )

        await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status(
            id=document_ids,
            status_type="extraction_status",
            status=GraphExtractionStatus.ENRICHED,
        )

        await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status(
            id=collection_id,
            status_type="graph_cluster_status",
            status=GraphConstructionStatus.SUCCESS,
        )

        return {
            "result": f"Successfully completed enrichment with {len(results)} summary workflows"
        }

    @orchestration_provider.failure()
    async def on_failure(self, context: Context) -> None:
        """Failure hook: mark the collection's cluster status FAILED."""
        collection_id = context.workflow_input()["request"].get(
            "collection_id", None
        )
        if collection_id:
            await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status(
                id=uuid.UUID(collection_id),
                status_type="graph_cluster_status",
                status=GraphConstructionStatus.FAILED,
            )
+
+ @orchestration_provider.workflow(
+ name="graph-community-summarization", timeout="360m"
+ )
+ class GraphCommunitySummarizerWorkflow:
+ def __init__(self, graph_search_results_service: GraphService):
+ self.graph_search_results_service = graph_search_results_service
+
+ @orchestration_provider.concurrency( # type: ignore
+ max_runs=orchestration_provider.config.graph_search_results_concurrency_limit, # type: ignore
+ limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
+ )
+ def concurrency(self, context: Context) -> str:
+ # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun
+ try:
+ return str(
+ context.workflow_input()["request"]["collection_id"]
+ )
+ except Exception:
+ return str(uuid.uuid4())
+
+ @orchestration_provider.step(retries=1, timeout="360m")
+ async def graph_search_results_community_summary(
+ self, context: Context
+ ) -> dict:
+ start_time = time.time()
+
+ input_data = get_input_data_dict(
+ context.workflow_input()["request"]
+ )
+
+ base_args = {
+ k: v
+ for k, v in input_data.items()
+ if k != "graph_enrichment_settings"
+ }
+ enrichment_args = input_data.get("graph_enrichment_settings", {})
+
+ # Merge them together.
+ # Note: if there is any key overlap, values from enrichment_args will override those from base_args.
+ merged_args = {**base_args, **enrichment_args}
+
+ # Now call the service method with all arguments at the top level.
+ # This ensures that keys like "max_summary_input_length" and "generation_config" are present.
+ community_summary = await self.graph_search_results_service.graph_search_results_community_summary(
+ **merged_args
+ )
+ logger.info(
+ f"Successfully ran graph_search_results community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)} in {time.time() - start_time:.2f} seconds "
+ )
+ return {
+ "result": f"successfully ran graph_search_results community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)}"
+ }
+
+ @orchestration_provider.workflow(
+ name="graph-deduplication", timeout="360m"
+ )
+ class GraphDeduplicationWorkflow:
+ def __init__(self, graph_search_results_service: GraphService):
+ self.graph_search_results_service = graph_search_results_service
+
+ @orchestration_provider.concurrency( # type: ignore
+ max_runs=orchestration_provider.config.graph_search_results_concurrency_limit, # type: ignore
+ limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
+ )
+ def concurrency(self, context: Context) -> str:
+ # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun
+ try:
+ return str(context.workflow_input()["request"]["document_id"])
+ except Exception:
+ return str(uuid.uuid4())
+
+ @orchestration_provider.step(retries=1, timeout="360m")
+ async def deduplicate_document_entities(
+ self, context: Context
+ ) -> dict:
+ start_time = time.time()
+
+ input_data = get_input_data_dict(
+ context.workflow_input()["request"]
+ )
+
+ document_id = input_data.get("document_id", None)
+
+ await service.deduplicate_document_entities(
+ document_id=document_id,
+ )
+ logger.info(
+ f"Successfully ran deduplication for document {document_id} in {time.time() - start_time:.2f} seconds "
+ )
+ return {
+ "result": f"Successfully ran deduplication for document {document_id}"
+ }
+
    # Map each workflow instance to the name the orchestrator dispatches on
    # (these match the @workflow(name=...) registrations above).
    return {
        "graph-extraction": GraphExtractionWorkflow(service),
        "graph-clustering": GraphClusteringWorkflow(service),
        "graph-community-summarization": GraphCommunitySummarizerWorkflow(
            service
        ),
        "graph-deduplication": GraphDeduplicationWorkflow(service),
    }
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py
new file mode 100644
index 00000000..96d7aebb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -0,0 +1,721 @@
+# type: ignore
+import asyncio
+import logging
+import uuid
+from typing import TYPE_CHECKING
+from uuid import UUID
+
+import tiktoken
+from fastapi import HTTPException
+from hatchet_sdk import ConcurrencyLimitStrategy, Context
+from litellm import AuthenticationError
+
+from core.base import (
+ DocumentChunk,
+ GraphConstructionStatus,
+ IngestionStatus,
+ OrchestrationProvider,
+ generate_extraction_id,
+)
+from core.base.abstractions import DocumentResponse, R2RException
+from core.utils import (
+ generate_default_user_collection_id,
+ update_settings_from_dict,
+)
+
+from ...services import IngestionService, IngestionServiceAdapter
+
+if TYPE_CHECKING:
+ from hatchet_sdk import Hatchet
+
+logger = logging.getLogger()
+
+
+# FIXME: No need to duplicate this function between the workflows, consolidate it into a shared module
# FIXME: No need to duplicate this function between the workflows, consolidate it into a shared module
def count_tokens_for_text(text: str, model: str = "gpt-4o") -> int:
    """Return the number of tokens in *text* under *model*'s tokenizer.

    Unknown model names fall back to the ``cl100k_base`` encoding so the
    count never fails outright.
    """
    try:
        enc = tiktoken.encoding_for_model(model)
    except KeyError:
        # Model not recognized by tiktoken: use a widely compatible default.
        enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(text, disallowed_special=())
    return len(tokens)
+
+
+def hatchet_ingestion_factory(
+ orchestration_provider: OrchestrationProvider, service: IngestionService
+) -> dict[str, "Hatchet.Workflow"]:
+ @orchestration_provider.workflow(
+ name="ingest-files",
+ timeout="60m",
+ )
+ class HatchetIngestFilesWorkflow:
+ def __init__(self, ingestion_service: IngestionService):
+ self.ingestion_service = ingestion_service
+
+ @orchestration_provider.concurrency( # type: ignore
+ max_runs=orchestration_provider.config.ingestion_concurrency_limit, # type: ignore
+ limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
+ )
+ def concurrency(self, context: Context) -> str:
+ # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun
+ try:
+ input_data = context.workflow_input()["request"]
+ parsed_data = IngestionServiceAdapter.parse_ingest_file_input(
+ input_data
+ )
+ return str(parsed_data["user"].id)
+ except Exception:
+ return str(uuid.uuid4())
+
+ @orchestration_provider.step(retries=0, timeout="60m")
+ async def parse(self, context: Context) -> dict:
+ try:
+ logger.info("Initiating ingestion workflow, step: parse")
+ input_data = context.workflow_input()["request"]
+ parsed_data = IngestionServiceAdapter.parse_ingest_file_input(
+ input_data
+ )
+
+ # ingestion_result = (
+ # await self.ingestion_service.ingest_file_ingress(
+ # **parsed_data
+ # )
+ # )
+
+ # document_info = ingestion_result["info"]
+ document_info = (
+ self.ingestion_service.create_document_info_from_file(
+ parsed_data["document_id"],
+ parsed_data["user"],
+ parsed_data["file_data"]["filename"],
+ parsed_data["metadata"],
+ parsed_data["version"],
+ parsed_data["size_in_bytes"],
+ )
+ )
+
+ await self.ingestion_service.update_document_status(
+ document_info,
+ status=IngestionStatus.PARSING,
+ )
+
+ ingestion_config = parsed_data["ingestion_config"] or {}
+ extractions_generator = self.ingestion_service.parse_file(
+ document_info, ingestion_config
+ )
+
+ extractions = []
+ async for extraction in extractions_generator:
+ extractions.append(extraction)
+
+ # 2) Sum tokens
+ total_tokens = 0
+ for chunk in extractions:
+ text_data = chunk.data
+ if not isinstance(text_data, str):
+ text_data = text_data.decode("utf-8", errors="ignore")
+ total_tokens += count_tokens_for_text(text_data)
+ document_info.total_tokens = total_tokens
+
+ if not ingestion_config.get("skip_document_summary", False):
+ await service.update_document_status(
+ document_info, status=IngestionStatus.AUGMENTING
+ )
+ await service.augment_document_info(
+ document_info,
+ [extraction.to_dict() for extraction in extractions],
+ )
+
+ await self.ingestion_service.update_document_status(
+ document_info,
+ status=IngestionStatus.EMBEDDING,
+ )
+
+ # extractions = context.step_output("parse")["extractions"]
+
+ embedding_generator = self.ingestion_service.embed_document(
+ [extraction.to_dict() for extraction in extractions]
+ )
+
+ embeddings = []
+ async for embedding in embedding_generator:
+ embeddings.append(embedding)
+
+ await self.ingestion_service.update_document_status(
+ document_info,
+ status=IngestionStatus.STORING,
+ )
+
+ storage_generator = self.ingestion_service.store_embeddings( # type: ignore
+ embeddings
+ )
+
+ async for _ in storage_generator:
+ pass
+
+ await self.ingestion_service.finalize_ingestion(document_info)
+
+ await self.ingestion_service.update_document_status(
+ document_info,
+ status=IngestionStatus.SUCCESS,
+ )
+
+ collection_ids = context.workflow_input()["request"].get(
+ "collection_ids"
+ )
+ if not collection_ids:
+ # TODO: Move logic onto the `management service`
+ collection_id = generate_default_user_collection_id(
+ document_info.owner_id
+ )
+ await service.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ else:
+ for collection_id_str in collection_ids:
+ collection_id = UUID(collection_id_str)
+ try:
+ name = document_info.title or "N/A"
+ description = ""
+ await service.providers.database.collections_handler.create_collection(
+ owner_id=document_info.owner_id,
+ name=name,
+ description=description,
+ collection_id=collection_id,
+ )
+ await (
+ self.providers.database.graphs_handler.create(
+ collection_id=collection_id,
+ name=name,
+ description=description,
+ graph_id=collection_id,
+ )
+ )
+
+ except Exception as e:
+ logger.warning(
+ f"Warning, could not create collection with error: {str(e)}"
+ )
+
+ await service.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+ status=GraphConstructionStatus.OUTDATED,
+ )
+
+ # get server chunk enrichment settings and override parts of it if provided in the ingestion config
+ if server_chunk_enrichment_settings := getattr(
+ service.providers.ingestion.config,
+ "chunk_enrichment_settings",
+ None,
+ ):
+ chunk_enrichment_settings = update_settings_from_dict(
+ server_chunk_enrichment_settings,
+ ingestion_config.get("chunk_enrichment_settings", {})
+ or {},
+ )
+
+ if chunk_enrichment_settings.enable_chunk_enrichment:
+ logger.info("Enriching document with contextual chunks")
+
+ document_info: DocumentResponse = (
+ await self.ingestion_service.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=1,
+ filter_user_ids=[document_info.owner_id],
+ filter_document_ids=[document_info.id],
+ )
+ )["results"][0]
+
+ await self.ingestion_service.update_document_status(
+ document_info,
+ status=IngestionStatus.ENRICHING,
+ )
+
+ await self.ingestion_service.chunk_enrichment(
+ document_id=document_info.id,
+ document_summary=document_info.summary,
+ chunk_enrichment_settings=chunk_enrichment_settings,
+ )
+
+ await self.ingestion_service.update_document_status(
+ document_info,
+ status=IngestionStatus.SUCCESS,
+ )
+ # ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+ if service.providers.ingestion.config.automatic_extraction:
+ extract_input = {
+ "document_id": str(document_info.id),
+ "graph_creation_settings": self.ingestion_service.providers.database.config.graph_creation_settings.model_dump_json(),
+ "user": input_data["user"],
+ }
+
+ extract_result = (
+ await context.aio.spawn_workflow(
+ "graph-extraction",
+ {"request": extract_input},
+ )
+ ).result()
+
+ await asyncio.gather(extract_result)
+
+ return {
+ "status": "Successfully finalized ingestion",
+ "document_info": document_info.to_dict(),
+ }
+
+ except AuthenticationError:
+ raise R2RException(
+ status_code=401,
+ message="Authentication error: Invalid API key or credentials.",
+ ) from None
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error during ingestion: {str(e)}",
+ ) from e
+
+ @orchestration_provider.failure()
+ async def on_failure(self, context: Context) -> None:
+ request = context.workflow_input().get("request", {})
+ document_id = request.get("document_id")
+
+ if not document_id:
+ logger.error(
+ "No document id was found in workflow input to mark a failure."
+ )
+ return
+
+ try:
+ documents_overview = (
+ await self.ingestion_service.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=1,
+ filter_document_ids=[document_id],
+ )
+ )["results"]
+
+ if not documents_overview:
+ logger.error(
+ f"Document with id {document_id} not found in database to mark failure."
+ )
+ return
+
+ document_info = documents_overview[0]
+
+ # Update the document status to FAILED
+ if document_info.ingestion_status != IngestionStatus.SUCCESS:
+ await self.ingestion_service.update_document_status(
+ document_info,
+ status=IngestionStatus.FAILED,
+ metadata={"failure": f"{context.step_run_errors()}"},
+ )
+
+ except Exception as e:
+ logger.error(
+ f"Failed to update document status for {document_id}: {e}"
+ )
+
+ @orchestration_provider.workflow(
+ name="ingest-chunks",
+ timeout="60m",
+ )
+ class HatchetIngestChunksWorkflow:
+ def __init__(self, ingestion_service: IngestionService):
+ self.ingestion_service = ingestion_service
+
+ @orchestration_provider.step(timeout="60m")
+ async def ingest(self, context: Context) -> dict:
+ input_data = context.workflow_input()["request"]
+ parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input(
+ input_data
+ )
+
+ document_info = await self.ingestion_service.ingest_chunks_ingress(
+ **parsed_data
+ )
+
+ await self.ingestion_service.update_document_status(
+ document_info, status=IngestionStatus.EMBEDDING
+ )
+ document_id = document_info.id
+
+ extractions = [
+ DocumentChunk(
+ id=generate_extraction_id(document_id, i),
+ document_id=document_id,
+ collection_ids=[],
+ owner_id=document_info.owner_id,
+ data=chunk.text,
+ metadata=parsed_data["metadata"],
+ ).to_dict()
+ for i, chunk in enumerate(parsed_data["chunks"])
+ ]
+
+ # 2) Sum tokens
+ total_tokens = 0
+ for chunk in extractions:
+ text_data = chunk["data"]
+ if not isinstance(text_data, str):
+ text_data = text_data.decode("utf-8", errors="ignore")
+ total_tokens += count_tokens_for_text(text_data)
+ document_info.total_tokens = total_tokens
+
+ return {
+ "status": "Successfully ingested chunks",
+ "extractions": extractions,
+ "document_info": document_info.to_dict(),
+ }
+
+ @orchestration_provider.step(parents=["ingest"], timeout="60m")
+ async def embed(self, context: Context) -> dict:
+ document_info_dict = context.step_output("ingest")["document_info"]
+ document_info = DocumentResponse(**document_info_dict)
+
+ extractions = context.step_output("ingest")["extractions"]
+
+ embedding_generator = self.ingestion_service.embed_document(
+ extractions
+ )
+ embeddings = [
+ embedding.model_dump()
+ async for embedding in embedding_generator
+ ]
+
+ await self.ingestion_service.update_document_status(
+ document_info, status=IngestionStatus.STORING
+ )
+
+ storage_generator = self.ingestion_service.store_embeddings(
+ embeddings
+ )
+ async for _ in storage_generator:
+ pass
+
+ return {
+ "status": "Successfully embedded and stored chunks",
+ "document_info": document_info.to_dict(),
+ }
+
+ @orchestration_provider.step(parents=["embed"], timeout="60m")
+ async def finalize(self, context: Context) -> dict:
+ document_info_dict = context.step_output("embed")["document_info"]
+ document_info = DocumentResponse(**document_info_dict)
+
+ await self.ingestion_service.finalize_ingestion(document_info)
+
+ await self.ingestion_service.update_document_status(
+ document_info, status=IngestionStatus.SUCCESS
+ )
+
+ try:
+ # TODO - Move logic onto the `management service`
+ collection_ids = context.workflow_input()["request"].get(
+ "collection_ids"
+ )
+ if not collection_ids:
+ # TODO: Move logic onto the `management service`
+ collection_id = generate_default_user_collection_id(
+ document_info.owner_id
+ )
+ await service.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ else:
+ for collection_id_str in collection_ids:
+ collection_id = UUID(collection_id_str)
+ try:
+ name = document_info.title or "N/A"
+ description = ""
+ await service.providers.database.collections_handler.create_collection(
+ owner_id=document_info.owner_id,
+ name=name,
+ description=description,
+ collection_id=collection_id,
+ )
+ await (
+ self.providers.database.graphs_handler.create(
+ collection_id=collection_id,
+ name=name,
+ description=description,
+ graph_id=collection_id,
+ )
+ )
+
+ except Exception as e:
+ logger.warning(
+ f"Warning, could not create collection with error: {str(e)}"
+ )
+
+ await service.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+
+ await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+ )
+ except Exception as e:
+ logger.error(
+ f"Error during assigning document to collection: {str(e)}"
+ )
+
+ return {
+ "status": "Successfully finalized ingestion",
+ "document_info": document_info.to_dict(),
+ }
+
+ @orchestration_provider.failure()
+ async def on_failure(self, context: Context) -> None:
+ request = context.workflow_input().get("request", {})
+ document_id = request.get("document_id")
+
+ if not document_id:
+ logger.error(
+ "No document id was found in workflow input to mark a failure."
+ )
+ return
+
+ try:
+ documents_overview = (
+ await self.ingestion_service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
+ offset=0,
+ limit=100,
+ filter_document_ids=[document_id],
+ )
+ )["results"]
+
+ if not documents_overview:
+ logger.error(
+ f"Document with id {document_id} not found in database to mark failure."
+ )
+ return
+
+ document_info = documents_overview[0]
+
+ if document_info.ingestion_status != IngestionStatus.SUCCESS:
+ await self.ingestion_service.update_document_status(
+ document_info, status=IngestionStatus.FAILED
+ )
+
+ except Exception as e:
+ logger.error(
+ f"Failed to update document status for {document_id}: {e}"
+ )
+
+ @orchestration_provider.workflow(
+ name="update-chunk",
+ timeout="60m",
+ )
+ class HatchetUpdateChunkWorkflow:
+ def __init__(self, ingestion_service: IngestionService):
+ self.ingestion_service = ingestion_service
+
+ @orchestration_provider.step(timeout="60m")
+ async def update_chunk(self, context: Context) -> dict:
+ try:
+ input_data = context.workflow_input()["request"]
+ parsed_data = IngestionServiceAdapter.parse_update_chunk_input(
+ input_data
+ )
+
+ document_uuid = (
+ UUID(parsed_data["document_id"])
+ if isinstance(parsed_data["document_id"], str)
+ else parsed_data["document_id"]
+ )
+ extraction_uuid = (
+ UUID(parsed_data["id"])
+ if isinstance(parsed_data["id"], str)
+ else parsed_data["id"]
+ )
+
+ await self.ingestion_service.update_chunk_ingress(
+ document_id=document_uuid,
+ chunk_id=extraction_uuid,
+ text=parsed_data.get("text"),
+ user=parsed_data["user"],
+ metadata=parsed_data.get("metadata"),
+ collection_ids=parsed_data.get("collection_ids"),
+ )
+
+ return {
+ "message": "Chunk update completed successfully.",
+ "task_id": context.workflow_run_id(),
+ "document_ids": [str(document_uuid)],
+ }
+
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error during chunk update: {str(e)}",
+ ) from e
+
+ @orchestration_provider.failure()
+ async def on_failure(self, context: Context) -> None:
+ # Handle failure case if necessary
+ pass
+
+ @orchestration_provider.workflow(
+ name="create-vector-index", timeout="360m"
+ )
+ class HatchetCreateVectorIndexWorkflow:
+ def __init__(self, ingestion_service: IngestionService):
+ self.ingestion_service = ingestion_service
+
+ @orchestration_provider.step(timeout="60m")
+ async def create_vector_index(self, context: Context) -> dict:
+ input_data = context.workflow_input()["request"]
+ parsed_data = (
+ IngestionServiceAdapter.parse_create_vector_index_input(
+ input_data
+ )
+ )
+
+ await self.ingestion_service.providers.database.chunks_handler.create_index(
+ **parsed_data
+ )
+
+ return {
+ "status": "Vector index creation queued successfully.",
+ }
+
+ @orchestration_provider.workflow(name="delete-vector-index", timeout="30m")
+ class HatchetDeleteVectorIndexWorkflow:
+ def __init__(self, ingestion_service: IngestionService):
+ self.ingestion_service = ingestion_service
+
+ @orchestration_provider.step(timeout="10m")
+ async def delete_vector_index(self, context: Context) -> dict:
+ input_data = context.workflow_input()["request"]
+ parsed_data = (
+ IngestionServiceAdapter.parse_delete_vector_index_input(
+ input_data
+ )
+ )
+
+ await self.ingestion_service.providers.database.chunks_handler.delete_index(
+ **parsed_data
+ )
+
+ return {"status": "Vector index deleted successfully."}
+
+ @orchestration_provider.workflow(
+ name="update-document-metadata",
+ timeout="30m",
+ )
+ class HatchetUpdateDocumentMetadataWorkflow:
+ def __init__(self, ingestion_service: IngestionService):
+ self.ingestion_service = ingestion_service
+
+ @orchestration_provider.step(timeout="30m")
+ async def update_document_metadata(self, context: Context) -> dict:
+ try:
+ input_data = context.workflow_input()["request"]
+ parsed_data = IngestionServiceAdapter.parse_update_document_metadata_input(
+ input_data
+ )
+
+ document_id = UUID(parsed_data["document_id"])
+ metadata = parsed_data["metadata"]
+ user = parsed_data["user"]
+
+ await self.ingestion_service.update_document_metadata(
+ document_id=document_id,
+ metadata=metadata,
+ user=user,
+ )
+
+ return {
+ "message": "Document metadata update completed successfully.",
+ "document_id": str(document_id),
+ "task_id": context.workflow_run_id(),
+ }
+
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error during document metadata update: {str(e)}",
+ ) from e
+
+ @orchestration_provider.failure()
+ async def on_failure(self, context: Context) -> None:
+ # Handle failure case if necessary
+ pass
+
    # Instantiate each workflow once with the shared ingestion service and
    # expose them under the keys the orchestrator registers/dispatches on.
    ingest_files_workflow = HatchetIngestFilesWorkflow(service)
    ingest_chunks_workflow = HatchetIngestChunksWorkflow(service)
    update_chunks_workflow = HatchetUpdateChunkWorkflow(service)
    update_document_metadata_workflow = HatchetUpdateDocumentMetadataWorkflow(
        service
    )
    create_vector_index_workflow = HatchetCreateVectorIndexWorkflow(service)
    delete_vector_index_workflow = HatchetDeleteVectorIndexWorkflow(service)

    return {
        "ingest_files": ingest_files_workflow,
        "ingest_chunks": ingest_chunks_workflow,
        "update_chunk": update_chunks_workflow,
        "update_document_metadata": update_document_metadata_workflow,
        "create_vector_index": create_vector_index_workflow,
        "delete_vector_index": delete_vector_index_workflow,
    }
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py
new file mode 100644
index 00000000..9e043263
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py
@@ -0,0 +1,222 @@
+import json
+import logging
+import math
+import uuid
+
+from core import GenerationConfig, R2RException
+from core.base.abstractions import (
+ GraphConstructionStatus,
+ GraphExtractionStatus,
+)
+
+from ...services import GraphService
+
+logger = logging.getLogger()
+
+
+def simple_graph_search_results_factory(service: GraphService):
+ def get_input_data_dict(input_data):
+ for key, value in input_data.items():
+ if value is None:
+ continue
+
+ if key == "document_id":
+ input_data[key] = (
+ uuid.UUID(value)
+ if not isinstance(value, uuid.UUID)
+ else value
+ )
+
+ if key == "collection_id":
+ input_data[key] = (
+ uuid.UUID(value)
+ if not isinstance(value, uuid.UUID)
+ else value
+ )
+
+ if key == "graph_id":
+ input_data[key] = (
+ uuid.UUID(value)
+ if not isinstance(value, uuid.UUID)
+ else value
+ )
+
+ if key in ["graph_creation_settings", "graph_enrichment_settings"]:
+ # Ensure we have a dict (if not already)
+ input_data[key] = (
+ json.loads(value) if not isinstance(value, dict) else value
+ )
+
+ if "generation_config" in input_data[key]:
+ if isinstance(input_data[key]["generation_config"], dict):
+ input_data[key]["generation_config"] = (
+ GenerationConfig(
+ **input_data[key]["generation_config"]
+ )
+ )
+ elif not isinstance(
+ input_data[key]["generation_config"], GenerationConfig
+ ):
+ input_data[key]["generation_config"] = (
+ GenerationConfig()
+ )
+
+ input_data[key]["generation_config"].model = (
+ input_data[key]["generation_config"].model
+ or service.config.app.fast_llm
+ )
+
+ return input_data
+
    async def graph_extraction(input_data):
        """Extract graph relationships for one document or a collection.

        If ``document_id`` is present, only that document is processed;
        otherwise all documents in ``collection_id`` are fetched in batches
        of 100 and processed sequentially.  Each document is marked
        PROCESSING before extraction; errors are logged and re-raised.
        """
        input_data = get_input_data_dict(input_data)

        if input_data.get("document_id"):
            document_ids = [input_data.get("document_id")]
        else:
            documents = []
            collection_id = input_data.get("collection_id")
            batch_size = 100
            offset = 0
            while True:
                # Fetch current batch
                batch = (
                    await service.providers.database.collections_handler.documents_in_collection(
                        collection_id=collection_id,
                        offset=offset,
                        limit=batch_size,
                    )
                )["results"]

                # If no documents returned, we've reached the end
                if not batch:
                    break

                # Add current batch to results
                documents.extend(batch)

                # Update offset for next batch
                offset += batch_size

                # Optional: If batch is smaller than batch_size, we've reached the end
                if len(batch) < batch_size:
                    break

            document_ids = [document.id for document in documents]

        logger.info(
            f"Creating graph for {len(document_ids)} documents with IDs: {document_ids}"
        )

        for _, document_id in enumerate(document_ids):
            await service.providers.database.documents_handler.set_workflow_status(
                id=document_id,
                status_type="extraction_status",
                status=GraphExtractionStatus.PROCESSING,
            )

            # Extract relationships from the document
            try:
                extractions = []
                async for (
                    extraction
                ) in service.graph_search_results_extraction(
                    document_id=document_id,
                    **input_data["graph_creation_settings"],
                ):
                    extractions.append(extraction)
                await service.store_graph_search_results_extractions(
                    extractions
                )

                # Describe the entities in the graph
                await service.graph_search_results_entity_description(
                    document_id=document_id,
                    **input_data["graph_creation_settings"],
                )

                if service.providers.database.config.graph_creation_settings.automatic_deduplication:
                    # Dedup runs only in the hatchet orchestrator; warn here.
                    logger.warning(
                        "Automatic deduplication is not yet implemented for `simple` workflows."
                    )

            except Exception as e:
                logger.error(
                    f"Error in creating graph for document {document_id}: {e}"
                )
                raise e
+
    async def graph_clustering(input_data):
        """Cluster the collection's graph into communities and summarize them.

        Refuses to run (400) when communities were already built; on any
        failure the collection's graph_cluster_status is set to FAILED and
        the exception is re-raised.
        """
        input_data = get_input_data_dict(input_data)
        workflow_status = await service.providers.database.documents_handler.get_workflow_status(
            id=input_data.get("collection_id", None),
            status_type="graph_cluster_status",
        )
        if workflow_status == GraphConstructionStatus.SUCCESS:
            raise R2RException(
                "Communities have already been built for this collection. To build communities again, first submit a POST request to `graphs/{collection_id}/reset` to erase the previously built communities.",
                400,
            )

        try:
            num_communities = await service.graph_search_results_clustering(
                collection_id=input_data.get("collection_id", None),
                # graph_id=input_data.get("graph_id", None),
                **input_data["graph_enrichment_settings"],
            )
            num_communities = num_communities["num_communities"][0]
            # TODO - Do not hardcode the number of parallel communities,
            # make it a configurable parameter at runtime & add server-side defaults

            if num_communities == 0:
                raise R2RException("No communities found", 400)

            parallel_communities = min(100, num_communities)

            # Summarize communities in sequential windows of up to 100.
            total_workflows = math.ceil(num_communities / parallel_communities)
            for i in range(total_workflows):
                input_data_copy = input_data.copy()
                input_data_copy["offset"] = i * parallel_communities
                input_data_copy["limit"] = min(
                    parallel_communities,
                    num_communities - i * parallel_communities,
                )

                logger.info(
                    f"Running graph_search_results community summary for workflow {i + 1} of {total_workflows}"
                )

                await service.graph_search_results_community_summary(
                    offset=input_data_copy["offset"],
                    limit=input_data_copy["limit"],
                    collection_id=input_data_copy.get("collection_id", None),
                    # graph_id=input_data_copy.get("graph_id", None),
                    **input_data_copy["graph_enrichment_settings"],
                )

            await service.providers.database.documents_handler.set_workflow_status(
                id=input_data.get("collection_id", None),
                status_type="graph_cluster_status",
                status=GraphConstructionStatus.SUCCESS,
            )

        except Exception as e:
            # Record the failure before propagating (includes the
            # "No communities found" case raised above).
            await service.providers.database.documents_handler.set_workflow_status(
                id=input_data.get("collection_id", None),
                status_type="graph_cluster_status",
                status=GraphConstructionStatus.FAILED,
            )

            raise e
+
+ async def graph_deduplication(input_data):
+ input_data = get_input_data_dict(input_data)
+ await service.deduplicate_document_entities(
+ document_id=input_data.get("document_id", None),
+ )
+
    # Map each coroutine to the workflow name the simple orchestrator
    # dispatches on (mirrors the hatchet workflow names).
    return {
        "graph-extraction": graph_extraction,
        "graph-clustering": graph_clustering,
        "graph-deduplication": graph_deduplication,
    }
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py
new file mode 100644
index 00000000..60a696c1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py
@@ -0,0 +1,598 @@
+import asyncio
+import logging
+from uuid import UUID
+
+import tiktoken
+from fastapi import HTTPException
+from litellm import AuthenticationError
+
+from core.base import (
+ DocumentChunk,
+ GraphConstructionStatus,
+ R2RException,
+ increment_version,
+)
+from core.utils import (
+ generate_default_user_collection_id,
+ generate_extraction_id,
+ update_settings_from_dict,
+)
+
+from ...services import IngestionService
+
+logger = logging.getLogger()
+
+
+# FIXME: No need to duplicate this function between the workflows, consolidate it into a shared module
def count_tokens_for_text(text: str, model: str = "gpt-4o") -> int:
    """Return the number of tokens `text` occupies under `model`'s encoding.

    Unrecognized model names fall back to the ``cl100k_base`` encoding, so a
    count is always produced. Special tokens are not disallowed during
    encoding (``disallowed_special=()``), matching raw-document ingestion.
    """
    try:
        enc = tiktoken.encoding_for_model(model)
    except KeyError:
        enc = tiktoken.get_encoding("cl100k_base")
    return len(enc.encode(text, disallowed_special=()))
+
+
def simple_ingestion_factory(service: IngestionService):
    """Build the task-name -> coroutine map for the `simple` orchestration
    backend's ingestion workflows.

    Each handler accepts one serialized ``input_data`` payload and runs the
    entire workflow inline against the ``service`` captured by closure (no
    external task queue is involved).
    """

    async def ingest_files(input_data):
        """Parse, summarize, embed and store one uploaded file, then assign
        the resulting document to its collection(s) and optionally enrich
        its chunks."""
        # Tracked so failures after document creation can be recorded
        # against the document row.
        document_info = None
        try:
            from core.base import IngestionStatus
            from core.main import IngestionServiceAdapter

            parsed_data = IngestionServiceAdapter.parse_ingest_file_input(
                input_data
            )

            document_info = service.create_document_info_from_file(
                parsed_data["document_id"],
                parsed_data["user"],
                parsed_data["file_data"]["filename"],
                parsed_data["metadata"],
                parsed_data["version"],
                parsed_data["size_in_bytes"],
            )

            await service.update_document_status(
                document_info, status=IngestionStatus.PARSING
            )

            ingestion_config = parsed_data["ingestion_config"]
            extractions_generator = service.parse_file(
                document_info=document_info,
                ingestion_config=ingestion_config,
            )
            extractions = [
                extraction.model_dump()
                async for extraction in extractions_generator
            ]

            # 2) Sum tokens
            # Non-str chunk data is decoded leniently so one bad byte
            # cannot abort the whole ingestion.
            total_tokens = 0
            for chunk_dict in extractions:
                text_data = chunk_dict["data"]
                if not isinstance(text_data, str):
                    text_data = text_data.decode("utf-8", errors="ignore")
                total_tokens += count_tokens_for_text(text_data)
            document_info.total_tokens = total_tokens

            if not ingestion_config.get("skip_document_summary", False):
                await service.update_document_status(
                    document_info=document_info,
                    status=IngestionStatus.AUGMENTING,
                )
                await service.augment_document_info(document_info, extractions)

            await service.update_document_status(
                document_info, status=IngestionStatus.EMBEDDING
            )
            embedding_generator = service.embed_document(extractions)
            embeddings = [
                embedding.model_dump()
                async for embedding in embedding_generator
            ]

            await service.update_document_status(
                document_info, status=IngestionStatus.STORING
            )
            # store_embeddings is a generator; drain it for its side effects.
            storage_generator = service.store_embeddings(embeddings)
            async for _ in storage_generator:
                pass

            await service.finalize_ingestion(document_info)

            await service.update_document_status(
                document_info, status=IngestionStatus.SUCCESS
            )

            collection_ids = parsed_data.get("collection_ids")

            # Collection assignment is best-effort: any failure below is
            # logged but does not fail the (already SUCCESS) ingestion.
            try:
                if not collection_ids:
                    # TODO: Move logic onto the `management service`
                    collection_id = generate_default_user_collection_id(
                        document_info.owner_id
                    )
                    await service.providers.database.collections_handler.assign_document_to_collection_relational(
                        document_id=document_info.id,
                        collection_id=collection_id,
                    )
                    await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
                        document_id=document_info.id,
                        collection_id=collection_id,
                    )
                    await service.providers.database.documents_handler.set_workflow_status(
                        id=collection_id,
                        status_type="graph_sync_status",
                        status=GraphConstructionStatus.OUTDATED,
                    )
                    await service.providers.database.documents_handler.set_workflow_status(
                        id=collection_id,
                        status_type="graph_cluster_status",
                        status=GraphConstructionStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
                    )
                else:
                    for collection_id in collection_ids:
                        try:
                            # FIXME: Right now we just throw a warning if the collection already exists, but we should probably handle this more gracefully
                            name = "My Collection"
                            description = f"A collection started during {document_info.title} ingestion"

                            await service.providers.database.collections_handler.create_collection(
                                owner_id=document_info.owner_id,
                                name=name,
                                description=description,
                                collection_id=collection_id,
                            )
                            await service.providers.database.graphs_handler.create(
                                collection_id=collection_id,
                                name=name,
                                description=description,
                                graph_id=collection_id,
                            )
                        except Exception as e:
                            logger.warning(
                                f"Warning, could not create collection with error: {str(e)}"
                            )

                        await service.providers.database.collections_handler.assign_document_to_collection_relational(
                            document_id=document_info.id,
                            collection_id=collection_id,
                        )

                        await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
                            document_id=document_info.id,
                            collection_id=collection_id,
                        )
                        await service.providers.database.documents_handler.set_workflow_status(
                            id=collection_id,
                            status_type="graph_sync_status",
                            status=GraphConstructionStatus.OUTDATED,
                        )
                        await service.providers.database.documents_handler.set_workflow_status(
                            id=collection_id,
                            status_type="graph_cluster_status",
                            status=GraphConstructionStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
                        )
            except Exception as e:
                logger.error(
                    f"Error during assigning document to collection: {str(e)}"
                )

            # Chunk enrichment
            # Only runs when the server config declares chunk_enrichment
            # settings; request-level settings are merged over them.
            if server_chunk_enrichment_settings := getattr(
                service.providers.ingestion.config,
                "chunk_enrichment_settings",
                None,
            ):
                chunk_enrichment_settings = update_settings_from_dict(
                    server_chunk_enrichment_settings,
                    ingestion_config.get("chunk_enrichment_settings", {})
                    or {},
                )

                if chunk_enrichment_settings.enable_chunk_enrichment:
                    logger.info("Enriching document with contextual chunks")

                    # Get updated document info with collection IDs
                    document_info = (
                        await service.providers.database.documents_handler.get_documents_overview(
                            offset=0,
                            limit=100,
                            filter_user_ids=[document_info.owner_id],
                            filter_document_ids=[document_info.id],
                        )
                    )["results"][0]

                    await service.update_document_status(
                        document_info,
                        status=IngestionStatus.ENRICHING,
                    )

                    await service.chunk_enrichment(
                        document_id=document_info.id,
                        document_summary=document_info.summary,
                        chunk_enrichment_settings=chunk_enrichment_settings,
                    )

                    await service.update_document_status(
                        document_info,
                        status=IngestionStatus.SUCCESS,
                    )

            # Automatic extraction
            if service.providers.ingestion.config.automatic_extraction:
                logger.warning(
                    "Automatic extraction not yet implemented for `simple` ingestion workflows."
                )

        except AuthenticationError as e:
            # LLM-provider credential failures are surfaced as 401 after
            # marking the document FAILED.
            if document_info is not None:
                await service.update_document_status(
                    document_info,
                    status=IngestionStatus.FAILED,
                    metadata={"failure": f"{str(e)}"},
                )
            raise R2RException(
                status_code=401,
                message="Authentication error: Invalid API key or credentials.",
            ) from e
        except Exception as e:
            if document_info is not None:
                await service.update_document_status(
                    document_info,
                    status=IngestionStatus.FAILED,
                    metadata={"failure": f"{str(e)}"},
                )
            # Preserve deliberate R2RExceptions; wrap everything else as 500.
            if isinstance(e, R2RException):
                raise
            raise HTTPException(
                status_code=500, detail=f"Error during ingestion: {str(e)}"
            ) from e

    async def update_files(input_data):
        """Re-ingest new versions of existing documents, bumping each
        document's version and running the full ingest pipeline per file."""
        from core.main import IngestionServiceAdapter

        parsed_data = IngestionServiceAdapter.parse_update_files_input(
            input_data
        )

        file_datas = parsed_data["file_datas"]
        user = parsed_data["user"]
        document_ids = parsed_data["document_ids"]
        metadatas = parsed_data["metadatas"]
        ingestion_config = parsed_data["ingestion_config"]
        file_sizes_in_bytes = parsed_data["file_sizes_in_bytes"]

        if not file_datas:
            raise R2RException(
                status_code=400, message="No files provided for update."
            ) from None
        if len(document_ids) != len(file_datas):
            raise R2RException(
                status_code=400,
                message="Number of ids does not match number of files.",
            ) from None

        documents_overview = (
            await service.providers.database.documents_handler.get_documents_overview(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
                offset=0,
                limit=100,
                filter_user_ids=None if user.is_superuser else [user.id],
                filter_document_ids=document_ids,
            )
        )["results"]

        if len(documents_overview) != len(document_ids):
            raise R2RException(
                status_code=404,
                message="One or more documents not found.",
            ) from None

        results = []

        for idx, (
            file_data,
            doc_id,
            doc_info,
            file_size_in_bytes,
        ) in enumerate(
            zip(
                file_datas,
                document_ids,
                documents_overview,
                file_sizes_in_bytes,
                strict=False,
            )
        ):
            new_version = increment_version(doc_info.version)

            updated_metadata = (
                metadatas[idx] if metadatas else doc_info.metadata
            )
            updated_metadata["title"] = (
                updated_metadata.get("title")
                or file_data["filename"].split("/")[-1]
            )

            ingest_input = {
                "file_data": file_data,
                "user": user.model_dump(),
                "metadata": updated_metadata,
                "document_id": str(doc_id),
                "version": new_version,
                "ingestion_config": ingestion_config,
                "size_in_bytes": file_size_in_bytes,
            }

            # ingest_files is a coroutine function: each call creates a
            # coroutine object; they all run concurrently in gather() below.
            result = ingest_files(ingest_input)
            results.append(result)

        await asyncio.gather(*results)
        if service.providers.ingestion.config.automatic_extraction:
            raise R2RException(
                status_code=501,
                message="Automatic extraction not yet implemented for `simple` ingestion workflows.",
            ) from None

    async def ingest_chunks(input_data):
        """Ingest pre-chunked text: embed and store the supplied chunks for
        a new document, then assign it to collection(s)."""
        document_info = None
        try:
            from core.base import IngestionStatus
            from core.main import IngestionServiceAdapter

            parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input(
                input_data
            )

            document_info = await service.ingest_chunks_ingress(**parsed_data)

            await service.update_document_status(
                document_info, status=IngestionStatus.EMBEDDING
            )
            document_id = document_info.id

            # Chunks without an explicit id get a deterministic extraction id
            # derived from (document_id, position).
            extractions = [
                DocumentChunk(
                    id=(
                        generate_extraction_id(document_id, i)
                        if chunk.id is None
                        else chunk.id
                    ),
                    document_id=document_id,
                    collection_ids=[],
                    owner_id=document_info.owner_id,
                    data=chunk.text,
                    metadata=parsed_data["metadata"],
                ).model_dump()
                for i, chunk in enumerate(parsed_data["chunks"])
            ]

            embedding_generator = service.embed_document(extractions)
            embeddings = [
                embedding.model_dump()
                async for embedding in embedding_generator
            ]

            await service.update_document_status(
                document_info, status=IngestionStatus.STORING
            )
            storage_generator = service.store_embeddings(embeddings)
            async for _ in storage_generator:
                pass

            await service.finalize_ingestion(document_info)

            await service.update_document_status(
                document_info, status=IngestionStatus.SUCCESS
            )

            collection_ids = parsed_data.get("collection_ids")

            try:
                # TODO - Move logic onto management service
                if not collection_ids:
                    collection_id = generate_default_user_collection_id(
                        document_info.owner_id
                    )

                    await service.providers.database.collections_handler.assign_document_to_collection_relational(
                        document_id=document_info.id,
                        collection_id=collection_id,
                    )

                    await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
                        document_id=document_info.id,
                        collection_id=collection_id,
                    )

                    await service.providers.database.documents_handler.set_workflow_status(
                        id=collection_id,
                        status_type="graph_sync_status",
                        status=GraphConstructionStatus.OUTDATED,
                    )
                    await service.providers.database.documents_handler.set_workflow_status(
                        id=collection_id,
                        status_type="graph_cluster_status",
                        status=GraphConstructionStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
                    )

                else:
                    for collection_id in collection_ids:
                        try:
                            name = document_info.title or "N/A"
                            description = ""
                            # return value unused
                            result = await service.providers.database.collections_handler.create_collection(
                                owner_id=document_info.owner_id,
                                name=name,
                                description=description,
                                collection_id=collection_id,
                            )
                            await service.providers.database.graphs_handler.create(
                                collection_id=collection_id,
                                name=name,
                                description=description,
                                graph_id=collection_id,
                            )
                        except Exception as e:
                            logger.warning(
                                f"Warning, could not create collection with error: {str(e)}"
                            )
                        await service.providers.database.collections_handler.assign_document_to_collection_relational(
                            document_id=document_info.id,
                            collection_id=collection_id,
                        )
                        await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
                            document_id=document_info.id,
                            collection_id=collection_id,
                        )
                        await service.providers.database.documents_handler.set_workflow_status(
                            id=collection_id,
                            status_type="graph_sync_status",
                            status=GraphConstructionStatus.OUTDATED,
                        )
                        await service.providers.database.documents_handler.set_workflow_status(
                            id=collection_id,
                            status_type="graph_cluster_status",
                            status=GraphConstructionStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
                        )

                # NOTE(review): this 501 is raised inside the try whose
                # `except Exception` below only logs — so it never reaches
                # the caller. Looks unintended (compare update_files, where
                # the same raise propagates); confirm.
                if service.providers.ingestion.config.automatic_extraction:
                    raise R2RException(
                        status_code=501,
                        message="Automatic extraction not yet implemented for `simple` ingestion workflows.",
                    ) from None

            except Exception as e:
                logger.error(
                    f"Error during assigning document to collection: {str(e)}"
                )

        except Exception as e:
            if document_info is not None:
                await service.update_document_status(
                    document_info,
                    status=IngestionStatus.FAILED,
                    metadata={"failure": f"{str(e)}"},
                )
            raise HTTPException(
                status_code=500,
                detail=f"Error during chunk ingestion: {str(e)}",
            ) from e

    async def update_chunk(input_data):
        """Update the text/metadata of a single existing chunk."""
        from core.main import IngestionServiceAdapter

        try:
            parsed_data = IngestionServiceAdapter.parse_update_chunk_input(
                input_data
            )
            # Ids may arrive serialized as strings; normalize to UUID.
            document_uuid = (
                UUID(parsed_data["document_id"])
                if isinstance(parsed_data["document_id"], str)
                else parsed_data["document_id"]
            )
            extraction_uuid = (
                UUID(parsed_data["id"])
                if isinstance(parsed_data["id"], str)
                else parsed_data["id"]
            )

            await service.update_chunk_ingress(
                document_id=document_uuid,
                chunk_id=extraction_uuid,
                text=parsed_data.get("text"),
                user=parsed_data["user"],
                metadata=parsed_data.get("metadata"),
                collection_ids=parsed_data.get("collection_ids"),
            )

        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error during chunk update: {str(e)}",
            ) from e

    async def create_vector_index(input_data):
        """Create a vector index via the chunks handler."""
        try:
            from core.main import IngestionServiceAdapter

            parsed_data = (
                IngestionServiceAdapter.parse_create_vector_index_input(
                    input_data
                )
            )

            await service.providers.database.chunks_handler.create_index(
                **parsed_data
            )

        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error during vector index creation: {str(e)}",
            ) from e

    async def delete_vector_index(input_data):
        """Delete a vector index via the chunks handler."""
        try:
            from core.main import IngestionServiceAdapter

            parsed_data = (
                IngestionServiceAdapter.parse_delete_vector_index_input(
                    input_data
                )
            )

            await service.providers.database.chunks_handler.delete_index(
                **parsed_data
            )

            return {"status": "Vector index deleted successfully."}

        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error during vector index deletion: {str(e)}",
            ) from e

    async def update_document_metadata(input_data):
        """Replace/merge a document's metadata on behalf of `user`."""
        try:
            from core.main import IngestionServiceAdapter

            parsed_data = (
                IngestionServiceAdapter.parse_update_document_metadata_input(
                    input_data
                )
            )

            document_id = parsed_data["document_id"]
            metadata = parsed_data["metadata"]
            user = parsed_data["user"]

            await service.update_document_metadata(
                document_id=document_id,
                metadata=metadata,
                user=user,
            )

            return {
                "message": "Document metadata update completed successfully.",
                "document_id": str(document_id),
                "task_id": None,
            }

        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error during document metadata update: {str(e)}",
            ) from e

    # Workflow-name -> handler map consumed by the simple orchestrator.
    # Note: update_files is defined above but not registered here.
    return {
        "ingest-files": ingest_files,
        "ingest-chunks": ingest_chunks,
        "update-chunk": update_chunk,
        "update-document-metadata": update_document_metadata,
        "create-vector-index": create_vector_index,
        "delete-vector-index": delete_vector_index,
    }
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/__init__.py b/.venv/lib/python3.12/site-packages/core/main/services/__init__.py
new file mode 100644
index 00000000..e6a6dec0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/__init__.py
@@ -0,0 +1,14 @@
+from .auth_service import AuthService
+from .graph_service import GraphService
+from .ingestion_service import IngestionService, IngestionServiceAdapter
+from .management_service import ManagementService
+from .retrieval_service import RetrievalService # type: ignore
+
+__all__ = [
+ "AuthService",
+ "IngestionService",
+ "IngestionServiceAdapter",
+ "ManagementService",
+ "GraphService",
+ "RetrievalService",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/auth_service.py b/.venv/lib/python3.12/site-packages/core/main/services/auth_service.py
new file mode 100644
index 00000000..c04dd78c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/auth_service.py
@@ -0,0 +1,316 @@
+import logging
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+from core.base import R2RException, Token
+from core.base.api.models import User
+from core.utils import generate_default_user_collection_id
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger()
+
+
class AuthService(Service):
    """Authentication and user-account service.

    Most methods are thin delegations to the configured auth provider or the
    database users handler; this class is the single entry point the API
    layer uses for auth flows (registration, login, tokens, API keys,
    account lifecycle).
    """

    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ):
        super().__init__(
            config,
            providers,
        )

    async def register(
        self,
        email: str,
        password: str,
        name: Optional[str] = None,
        bio: Optional[str] = None,
        profile_picture: Optional[str] = None,
    ) -> User:
        """Create a new user account via the auth provider."""
        return await self.providers.auth.register(
            email=email,
            password=password,
            name=name,
            bio=bio,
            profile_picture=profile_picture,
        )

    async def send_verification_email(
        self, email: str
    ) -> tuple[str, datetime]:
        """Send a verification email; returns a (str, datetime) pair —
        presumably the verification code and its expiry (delegated to the
        auth provider)."""
        return await self.providers.auth.send_verification_email(email=email)

    async def verify_email(
        self, email: str, verification_code: str
    ) -> dict[str, str]:
        """Validate `verification_code` for `email` and mark the user verified.

        Raises R2RException(400) when verification is disabled by config, or
        when the code does not resolve to a user with the given email.
        """
        if not self.config.auth.require_email_verification:
            raise R2RException(
                status_code=400, message="Email verification is not required"
            )

        user_id = await self.providers.database.users_handler.get_user_id_by_verification_code(
            verification_code
        )
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        if not user or user.email != email:
            raise R2RException(
                status_code=400, message="Invalid or expired verification code"
            )

        await self.providers.database.users_handler.mark_user_as_verified(
            user_id
        )
        # One-shot code: consumed on successful verification.
        await self.providers.database.users_handler.remove_verification_code(
            verification_code
        )
        return {"message": f"User account {user_id} verified successfully."}

    async def login(self, email: str, password: str) -> dict[str, Token]:
        """Authenticate credentials; returns the auth provider's token dict."""
        return await self.providers.auth.login(email, password)

    async def user(self, token: str) -> User:
        """Resolve an access token to its User.

        Raises R2RException(401) when the token has no email claim or the
        email does not match a known user.
        """
        token_data = await self.providers.auth.decode_token(token)
        if not token_data.email:
            raise R2RException(
                status_code=401, message="Invalid authentication credentials"
            )
        user = await self.providers.database.users_handler.get_user_by_email(
            token_data.email
        )
        if user is None:
            raise R2RException(
                status_code=401, message="Invalid authentication credentials"
            )
        return user

    async def refresh_access_token(
        self, refresh_token: str
    ) -> dict[str, Token]:
        """Exchange a refresh token for new tokens via the auth provider."""
        return await self.providers.auth.refresh_access_token(refresh_token)

    async def change_password(
        self, user: User, current_password: str, new_password: str
    ) -> dict[str, str]:
        """Change a user's password after checking the current one."""
        if not user:
            raise R2RException(status_code=404, message="User not found")
        return await self.providers.auth.change_password(
            user, current_password, new_password
        )

    async def request_password_reset(self, email: str) -> dict[str, str]:
        """Start the password-reset flow for `email`."""
        return await self.providers.auth.request_password_reset(email)

    async def confirm_password_reset(
        self, reset_token: str, new_password: str
    ) -> dict[str, str]:
        """Complete a password reset using a previously issued token."""
        return await self.providers.auth.confirm_password_reset(
            reset_token, new_password
        )

    async def logout(self, token: str) -> dict[str, str]:
        """Invalidate an access token via the auth provider."""
        return await self.providers.auth.logout(token)

    async def update_user(
        self,
        user_id: UUID,
        email: Optional[str] = None,
        is_superuser: Optional[bool] = None,
        name: Optional[str] = None,
        bio: Optional[str] = None,
        profile_picture: Optional[str] = None,
        limits_overrides: Optional[dict] = None,
        merge_limits: bool = False,
        new_metadata: Optional[dict] = None,
    ) -> User:
        """Apply a partial update to a user record and persist it.

        Only non-None arguments overwrite the corresponding fields;
        `merge_limits` and `new_metadata` are forwarded to the users handler.
        Raises R2RException(404) when the user does not exist.
        """
        user: User = (
            await self.providers.database.users_handler.get_user_by_id(user_id)
        )
        if not user:
            raise R2RException(status_code=404, message="User not found")
        if email is not None:
            user.email = email
        if is_superuser is not None:
            user.is_superuser = is_superuser
        if name is not None:
            user.name = name
        if bio is not None:
            user.bio = bio
        if profile_picture is not None:
            user.profile_picture = profile_picture
        if limits_overrides is not None:
            user.limits_overrides = limits_overrides
        return await self.providers.database.users_handler.update_user(
            user, merge_limits=merge_limits, new_metadata=new_metadata
        )

    async def delete_user(
        self,
        user_id: UUID,
        password: Optional[str] = None,
        delete_vector_data: bool = False,
        is_superuser: bool = False,
    ) -> dict[str, str]:
        """Delete a user account and its default collection.

        Non-superusers must supply the correct account password
        (R2RException 422 when missing, 400 when incorrect); superusers
        bypass the password check. With `delete_vector_data`, the user's
        vectors and the default collection's vectors are removed as well.
        """
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        if not user:
            raise R2RException(status_code=404, message="User not found")
        if not is_superuser and not password:
            raise R2RException(
                status_code=422, message="Password is required for deletion"
            )
        if not (
            is_superuser
            or (
                user.hashed_password is not None
                and password is not None
                and self.providers.auth.crypto_provider.verify_password(
                    plain_password=password,
                    hashed_password=user.hashed_password,
                )
            )
        ):
            raise R2RException(status_code=400, message="Incorrect password")
        await self.providers.database.users_handler.delete_user_relational(
            user_id
        )

        # Delete user's default collection
        # TODO: We need to better define what happens to the user's data when they are deleted
        collection_id = generate_default_user_collection_id(user_id)
        await self.providers.database.collections_handler.delete_collection_relational(
            collection_id
        )

        # Graph deletion is best-effort; a missing graph should not block
        # account deletion.
        try:
            await self.providers.database.graphs_handler.delete(
                collection_id=collection_id,
            )
        except Exception as e:
            logger.warning(
                f"Error deleting graph for collection {collection_id}: {e}"
            )

        if delete_vector_data:
            await self.providers.database.chunks_handler.delete_user_vector(
                user_id
            )
            await self.providers.database.chunks_handler.delete_collection_vector(
                collection_id
            )

        return {"message": f"User account {user_id} deleted successfully."}

    async def clean_expired_blacklisted_tokens(
        self,
        max_age_hours: int = 7 * 24,
        current_time: Optional[datetime] = None,
    ):
        """Purge blacklisted tokens older than `max_age_hours` (default 7 days)."""
        await self.providers.database.token_handler.clean_expired_blacklisted_tokens(
            max_age_hours, current_time
        )

    async def get_user_verification_code(
        self,
        user_id: UUID,
    ) -> dict:
        """Get only the verification code data for a specific user.

        This method should be called after superuser authorization has been
        verified.
        """
        verification_data = await self.providers.database.users_handler.get_user_validation_data(
            user_id=user_id
        )
        return {
            "verification_code": verification_data["verification_data"][
                "verification_code"
            ],
            "expiry": verification_data["verification_data"][
                "verification_code_expiry"
            ],
        }

    async def get_user_reset_token(
        self,
        user_id: UUID,
    ) -> dict:
        """Get only the password-reset token data for a specific user.

        This method should be called after superuser authorization has been
        verified.
        """
        verification_data = await self.providers.database.users_handler.get_user_validation_data(
            user_id=user_id
        )
        return {
            "reset_token": verification_data["verification_data"][
                "reset_token"
            ],
            "expiry": verification_data["verification_data"][
                "reset_token_expiry"
            ],
        }

    async def send_reset_email(self, email: str) -> dict:
        """Generate a new verification code and send a reset email to the user.
        Returns the verification code for testing/sandbox environments.

        Args:
            email (str): The email address of the user

        Returns:
            dict: Contains verification_code and message
        """
        return await self.providers.auth.send_reset_email(email)

    async def create_user_api_key(
        self, user_id: UUID, name: Optional[str], description: Optional[str]
    ) -> dict:
        """Generate a new API key for the user with optional name and
        description.

        Args:
            user_id (UUID): The ID of the user
            name (Optional[str]): Name of the API key
            description (Optional[str]): Description of the API key

        Returns:
            dict: Contains the API key and message
        """
        return await self.providers.auth.create_user_api_key(
            user_id=user_id, name=name, description=description
        )

    async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
        """Delete the API key for the user.

        Args:
            user_id (UUID): The ID of the user
            key_id (str): The ID of the API key

        Returns:
            bool: True if the API key was deleted successfully
        """
        return await self.providers.auth.delete_user_api_key(
            user_id=user_id, key_id=key_id
        )

    async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
        """List all API keys for the user.

        Args:
            user_id (UUID): The ID of the user

        Returns:
            dict: Contains the list of API keys
        """
        return await self.providers.auth.list_user_api_keys(user_id)
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/base.py b/.venv/lib/python3.12/site-packages/core/main/services/base.py
new file mode 100644
index 00000000..dcd98fd5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/base.py
@@ -0,0 +1,14 @@
+from abc import ABC
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+
+
class Service(ABC):
    """Abstract base for all R2R services.

    Stores the shared application `config` and the `providers` bundle that
    concrete services (auth, ingestion, graph, ...) operate through.
    """

    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ):
        # config: global R2RConfig for the running application
        # providers: R2RProviders container of backing providers
        self.config = config
        self.providers = providers
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/graph_service.py b/.venv/lib/python3.12/site-packages/core/main/services/graph_service.py
new file mode 100644
index 00000000..56f32cf8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/graph_service.py
@@ -0,0 +1,1358 @@
+import asyncio
+import logging
+import math
+import random
+import re
+import time
+import uuid
+import xml.etree.ElementTree as ET
+from typing import Any, AsyncGenerator, Coroutine, Optional
+from uuid import UUID
+from xml.etree.ElementTree import Element
+
+from core.base import (
+ DocumentChunk,
+ GraphExtraction,
+ GraphExtractionStatus,
+ R2RDocumentProcessingError,
+)
+from core.base.abstractions import (
+ Community,
+ Entity,
+ GenerationConfig,
+ GraphConstructionStatus,
+ R2RException,
+ Relationship,
+ StoreType,
+)
+from core.base.api.models import GraphResponse
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger()
+
+MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH = 128
+
+
+async def _collect_async_results(result_gen: AsyncGenerator) -> list[Any]:
+ """Collects all results from an async generator into a list."""
+ results = []
+ async for res in result_gen:
+ results.append(res)
+ return results
+
+
+class GraphService(Service):
+ def __init__(
+ self,
+ config: R2RConfig,
+ providers: R2RProviders,
+ ):
+ super().__init__(
+ config,
+ providers,
+ )
+
+ async def create_entity(
+ self,
+ name: str,
+ description: str,
+ parent_id: UUID,
+ category: Optional[str] = None,
+ metadata: Optional[dict] = None,
+ ) -> Entity:
+ description_embedding = str(
+ await self.providers.embedding.async_get_embedding(description)
+ )
+
+ return await self.providers.database.graphs_handler.entities.create(
+ name=name,
+ parent_id=parent_id,
+ store_type=StoreType.GRAPHS,
+ category=category,
+ description=description,
+ description_embedding=description_embedding,
+ metadata=metadata,
+ )
+
+ async def update_entity(
+ self,
+ entity_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ category: Optional[str] = None,
+ metadata: Optional[dict] = None,
+ ) -> Entity:
+ description_embedding = None
+ if description is not None:
+ description_embedding = str(
+ await self.providers.embedding.async_get_embedding(description)
+ )
+
+ return await self.providers.database.graphs_handler.entities.update(
+ entity_id=entity_id,
+ store_type=StoreType.GRAPHS,
+ name=name,
+ description=description,
+ description_embedding=description_embedding,
+ category=category,
+ metadata=metadata,
+ )
+
+ async def delete_entity(
+ self,
+ parent_id: UUID,
+ entity_id: UUID,
+ ):
+ return await self.providers.database.graphs_handler.entities.delete(
+ parent_id=parent_id,
+ entity_ids=[entity_id],
+ store_type=StoreType.GRAPHS,
+ )
+
+ async def get_entities(
+ self,
+ parent_id: UUID,
+ offset: int,
+ limit: int,
+ entity_ids: Optional[list[UUID]] = None,
+ entity_names: Optional[list[str]] = None,
+ include_embeddings: bool = False,
+ ):
+ return await self.providers.database.graphs_handler.get_entities(
+ parent_id=parent_id,
+ offset=offset,
+ limit=limit,
+ entity_ids=entity_ids,
+ entity_names=entity_names,
+ include_embeddings=include_embeddings,
+ )
+
+ async def create_relationship(
+ self,
+ subject: str,
+ subject_id: UUID,
+ predicate: str,
+ object: str,
+ object_id: UUID,
+ parent_id: UUID,
+ description: str | None = None,
+ weight: float | None = 1.0,
+ metadata: Optional[dict[str, Any] | str] = None,
+ ) -> Relationship:
+ description_embedding = None
+ if description:
+ description_embedding = str(
+ await self.providers.embedding.async_get_embedding(description)
+ )
+
+ return (
+ await self.providers.database.graphs_handler.relationships.create(
+ subject=subject,
+ subject_id=subject_id,
+ predicate=predicate,
+ object=object,
+ object_id=object_id,
+ parent_id=parent_id,
+ description=description,
+ description_embedding=description_embedding,
+ weight=weight,
+ metadata=metadata,
+ store_type=StoreType.GRAPHS,
+ )
+ )
+
+ async def delete_relationship(
+ self,
+ parent_id: UUID,
+ relationship_id: UUID,
+ ):
+ return (
+ await self.providers.database.graphs_handler.relationships.delete(
+ parent_id=parent_id,
+ relationship_ids=[relationship_id],
+ store_type=StoreType.GRAPHS,
+ )
+ )
+
+ async def update_relationship(
+ self,
+ relationship_id: UUID,
+ subject: Optional[str] = None,
+ subject_id: Optional[UUID] = None,
+ predicate: Optional[str] = None,
+ object: Optional[str] = None,
+ object_id: Optional[UUID] = None,
+ description: Optional[str] = None,
+ weight: Optional[float] = None,
+ metadata: Optional[dict[str, Any] | str] = None,
+ ) -> Relationship:
+ description_embedding = None
+ if description is not None:
+ description_embedding = str(
+ await self.providers.embedding.async_get_embedding(description)
+ )
+
+ return (
+ await self.providers.database.graphs_handler.relationships.update(
+ relationship_id=relationship_id,
+ subject=subject,
+ subject_id=subject_id,
+ predicate=predicate,
+ object=object,
+ object_id=object_id,
+ description=description,
+ description_embedding=description_embedding,
+ weight=weight,
+ metadata=metadata,
+ store_type=StoreType.GRAPHS,
+ )
+ )
+
+ async def get_relationships(
+ self,
+ parent_id: UUID,
+ offset: int,
+ limit: int,
+ relationship_ids: Optional[list[UUID]] = None,
+ entity_names: Optional[list[str]] = None,
+ ):
+ return await self.providers.database.graphs_handler.relationships.get(
+ parent_id=parent_id,
+ store_type=StoreType.GRAPHS,
+ offset=offset,
+ limit=limit,
+ relationship_ids=relationship_ids,
+ entity_names=entity_names,
+ )
+
+ async def create_community(
+ self,
+ parent_id: UUID,
+ name: str,
+ summary: str,
+ findings: Optional[list[str]],
+ rating: Optional[float],
+ rating_explanation: Optional[str],
+ ) -> Community:
+ description_embedding = str(
+ await self.providers.embedding.async_get_embedding(summary)
+ )
+ return await self.providers.database.graphs_handler.communities.create(
+ parent_id=parent_id,
+ store_type=StoreType.GRAPHS,
+ name=name,
+ summary=summary,
+ description_embedding=description_embedding,
+ findings=findings,
+ rating=rating,
+ rating_explanation=rating_explanation,
+ )
+
+ async def update_community(
+ self,
+ community_id: UUID,
+ name: Optional[str],
+ summary: Optional[str],
+ findings: Optional[list[str]],
+ rating: Optional[float],
+ rating_explanation: Optional[str],
+ ) -> Community:
+ summary_embedding = None
+ if summary is not None:
+ summary_embedding = str(
+ await self.providers.embedding.async_get_embedding(summary)
+ )
+
+ return await self.providers.database.graphs_handler.communities.update(
+ community_id=community_id,
+ store_type=StoreType.GRAPHS,
+ name=name,
+ summary=summary,
+ summary_embedding=summary_embedding,
+ findings=findings,
+ rating=rating,
+ rating_explanation=rating_explanation,
+ )
+
+ async def delete_community(
+ self,
+ parent_id: UUID,
+ community_id: UUID,
+ ) -> None:
+ await self.providers.database.graphs_handler.communities.delete(
+ parent_id=parent_id,
+ community_id=community_id,
+ )
+
+ async def get_communities(
+ self,
+ parent_id: UUID,
+ offset: int,
+ limit: int,
+ community_ids: Optional[list[UUID]] = None,
+ community_names: Optional[list[str]] = None,
+ include_embeddings: bool = False,
+ ):
+ return await self.providers.database.graphs_handler.get_communities(
+ parent_id=parent_id,
+ offset=offset,
+ limit=limit,
+ community_ids=community_ids,
+ include_embeddings=include_embeddings,
+ )
+
+ async def list_graphs(
+ self,
+ offset: int,
+ limit: int,
+ graph_ids: Optional[list[UUID]] = None,
+ collection_id: Optional[UUID] = None,
+ ) -> dict[str, list[GraphResponse] | int]:
+ return await self.providers.database.graphs_handler.list_graphs(
+ offset=offset,
+ limit=limit,
+ filter_graph_ids=graph_ids,
+ filter_collection_id=collection_id,
+ )
+
+ async def update_graph(
+ self,
+ collection_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> GraphResponse:
+ return await self.providers.database.graphs_handler.update(
+ collection_id=collection_id,
+ name=name,
+ description=description,
+ )
+
+ async def reset_graph(self, id: UUID) -> bool:
+ await self.providers.database.graphs_handler.reset(
+ parent_id=id,
+ )
+ await self.providers.database.documents_handler.set_workflow_status(
+ id=id,
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.PENDING,
+ )
+ return True
+
+ async def get_document_ids_for_create_graph(
+ self,
+ collection_id: UUID,
+ **kwargs,
+ ):
+ document_status_filter = [
+ GraphExtractionStatus.PENDING,
+ GraphExtractionStatus.FAILED,
+ ]
+
+ return await self.providers.database.documents_handler.get_document_ids_by_status(
+ status_type="extraction_status",
+ status=[str(ele) for ele in document_status_filter],
+ collection_id=collection_id,
+ )
+
+ async def graph_search_results_entity_description(
+ self,
+ document_id: UUID,
+ max_description_input_length: int,
+ batch_size: int = 256,
+ **kwargs,
+ ):
+ """A new implementation of the old GraphDescriptionPipe logic inline.
+ No references to pipe objects.
+
+ We:
+ 1) Count how many entities are in the document
+ 2) Process them in batches of `batch_size`
+ 3) For each batch, we retrieve the entity map and possibly call LLM for missing descriptions
+ """
+ start_time = time.time()
+ logger.info(
+ f"GraphService: Running graph_search_results_entity_description for doc={document_id}"
+ )
+
+ # Count how many doc-entities exist
+ entity_count = (
+ await self.providers.database.graphs_handler.get_entity_count(
+ document_id=document_id,
+ distinct=True,
+ entity_table_name="documents_entities", # or whichever table
+ )
+ )
+ logger.info(
+ f"GraphService: Found {entity_count} doc-entities to describe."
+ )
+
+ all_results = []
+ num_batches = math.ceil(entity_count / batch_size)
+
+ for i in range(num_batches):
+ offset = i * batch_size
+ limit = batch_size
+
+ logger.info(
+ f"GraphService: describing batch {i + 1}/{num_batches}, offset={offset}, limit={limit}"
+ )
+
+ # Actually handle describing the entities in the batch
+ # We'll collect them into a list via an async generator
+ gen = self._describe_entities_in_document_batch(
+ document_id=document_id,
+ offset=offset,
+ limit=limit,
+ max_description_input_length=max_description_input_length,
+ )
+ batch_results = await _collect_async_results(gen)
+ all_results.append(batch_results)
+
+ # Mark the doc's extraction status as success
+ await self.providers.database.documents_handler.set_workflow_status(
+ id=document_id,
+ status_type="extraction_status",
+ status=GraphExtractionStatus.SUCCESS,
+ )
+ logger.info(
+ f"GraphService: Completed graph_search_results_entity_description for doc {document_id} in {time.time() - start_time:.2f}s."
+ )
+ return all_results
+
+ async def _describe_entities_in_document_batch(
+ self,
+ document_id: UUID,
+ offset: int,
+ limit: int,
+ max_description_input_length: int,
+ ) -> AsyncGenerator[str, None]:
+ """Core logic that replaces GraphDescriptionPipe._run_logic for a
+ particular document/batch.
+
+ Yields entity-names or some textual result as each entity is updated.
+ """
+ start_time = time.time()
+ logger.info(
+ f"Started describing doc={document_id}, offset={offset}, limit={limit}"
+ )
+
+ # 1) Get the "entity map" from the DB
+ entity_map = (
+ await self.providers.database.graphs_handler.get_entity_map(
+ offset=offset, limit=limit, document_id=document_id
+ )
+ )
+ total_entities = len(entity_map)
+ logger.info(
+ f"_describe_entities_in_document_batch: got {total_entities} items in entity_map for doc={document_id}."
+ )
+
+ # 2) For each entity name in the map, we gather sub-entities and relationships
+ tasks: list[Coroutine[Any, Any, str]] = []
+ tasks.extend(
+ self._process_entity_for_description(
+ entities=[
+ entity if isinstance(entity, Entity) else Entity(**entity)
+ for entity in entity_info["entities"]
+ ],
+ relationships=[
+ rel
+ if isinstance(rel, Relationship)
+ else Relationship(**rel)
+ for rel in entity_info["relationships"]
+ ],
+ document_id=document_id,
+ max_description_input_length=max_description_input_length,
+ )
+ for entity_name, entity_info in entity_map.items()
+ )
+
+ # 3) Wait for all tasks, yield as they complete
+ idx = 0
+ for coro in asyncio.as_completed(tasks):
+ result = await coro
+ idx += 1
+ if idx % 100 == 0:
+ logger.info(
+ f"_describe_entities_in_document_batch: {idx}/{total_entities} described for doc={document_id}"
+ )
+ yield result
+
+ logger.info(
+ f"Finished describing doc={document_id} batch offset={offset} in {time.time() - start_time:.2f}s."
+ )
+
+ async def _process_entity_for_description(
+ self,
+ entities: list[Entity],
+ relationships: list[Relationship],
+ document_id: UUID,
+ max_description_input_length: int,
+ ) -> str:
+ """Adapted from the old process_entity function in
+ GraphDescriptionPipe.
+
+ If entity has no description, call an LLM to create one, then store it.
+ Returns the name of the top entity (or could store more details).
+ """
+
+ def truncate_info(info_list: list[str], max_length: int) -> str:
+ """Shuffles lines of info to try to keep them distinct, then
+ accumulates until hitting max_length."""
+ random.shuffle(info_list)
+ truncated_info = ""
+ current_length = 0
+ for info in info_list:
+ if current_length + len(info) > max_length:
+ break
+ truncated_info += info + "\n"
+ current_length += len(info)
+ return truncated_info
+
+ # Grab a doc-level summary (optional) to feed into the prompt
+ response = await self.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=1,
+ filter_document_ids=[document_id],
+ )
+ document_summary = (
+ response["results"][0].summary if response["results"] else None
+ )
+
+ # Synthesize a minimal “entity info” string + relationship summary
+ entity_info = [
+ f"{e.name}, {e.description or 'NONE'}" for e in entities
+ ]
+ relationships_txt = [
+ f"{i + 1}: {r.subject}, {r.object}, {r.predicate} - Summary: {r.description or ''}"
+ for i, r in enumerate(relationships)
+ ]
+
+ # We'll describe only the first entity for simplicity
+ # or you could do them all if needed
+ main_entity = entities[0]
+
+ if not main_entity.description:
+ # We only call LLM if the entity is missing a description
+ messages = await self.providers.database.prompts_handler.get_message_payload(
+ task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt,
+ task_inputs={
+ "document_summary": document_summary,
+ "entity_info": truncate_info(
+ entity_info, max_description_input_length
+ ),
+ "relationships_txt": truncate_info(
+ relationships_txt, max_description_input_length
+ ),
+ },
+ )
+
+ # Call the LLM
+ gen_config = (
+ self.providers.database.config.graph_creation_settings.generation_config
+ or GenerationConfig(model=self.config.app.fast_llm)
+ )
+ llm_resp = await self.providers.llm.aget_completion(
+ messages=messages,
+ generation_config=gen_config,
+ )
+ new_description = llm_resp.choices[0].message.content
+
+ if not new_description:
+ logger.error(
+ f"No LLM description returned for entity={main_entity.name}"
+ )
+ return main_entity.name
+
+ # create embedding
+ embed = (
+ await self.providers.embedding.async_get_embeddings(
+ [new_description]
+ )
+ )[0]
+
+ # update DB
+ main_entity.description = new_description
+ main_entity.description_embedding = embed
+
+ # Use a method to upsert entity in `documents_entities` or your table
+ await self.providers.database.graphs_handler.add_entities(
+ [main_entity],
+ table_name="documents_entities",
+ )
+
+ return main_entity.name
+
+ async def graph_search_results_clustering(
+ self,
+ collection_id: UUID,
+ generation_config: GenerationConfig,
+ leiden_params: dict,
+ **kwargs,
+ ):
+ """
+ Replacement for the old GraphClusteringPipe logic:
+ 1) call perform_graph_clustering on the DB
+ 2) return the result
+ """
+ logger.info(
+ f"Running inline clustering for collection={collection_id} with params={leiden_params}"
+ )
+ return await self._perform_graph_clustering(
+ collection_id=collection_id,
+ generation_config=generation_config,
+ leiden_params=leiden_params,
+ )
+
+ async def _perform_graph_clustering(
+ self,
+ collection_id: UUID,
+ generation_config: GenerationConfig,
+ leiden_params: dict,
+ ) -> dict:
+ """The actual clustering logic (previously in
+ GraphClusteringPipe.cluster_graph_search_results)."""
+ num_communities = await self.providers.database.graphs_handler.perform_graph_clustering(
+ collection_id=collection_id,
+ leiden_params=leiden_params,
+ )
+ return {"num_communities": num_communities}
+
+ async def graph_search_results_community_summary(
+ self,
+ offset: int,
+ limit: int,
+ max_summary_input_length: int,
+ generation_config: GenerationConfig,
+ collection_id: UUID,
+ leiden_params: Optional[dict] = None,
+ **kwargs,
+ ):
+ """Replacement for the old GraphCommunitySummaryPipe logic.
+
+ Summarizes communities after clustering. Returns an async generator or
+ you can collect into a list.
+ """
+ logger.info(
+ f"Running inline community summaries for coll={collection_id}, offset={offset}, limit={limit}"
+ )
+ # We call an internal function that yields summaries
+ gen = self._summarize_communities(
+ offset=offset,
+ limit=limit,
+ max_summary_input_length=max_summary_input_length,
+ generation_config=generation_config,
+ collection_id=collection_id,
+ leiden_params=leiden_params or {},
+ )
+ return await _collect_async_results(gen)
+
+ async def _summarize_communities(
+ self,
+ offset: int,
+ limit: int,
+ max_summary_input_length: int,
+ generation_config: GenerationConfig,
+ collection_id: UUID,
+ leiden_params: dict,
+ ) -> AsyncGenerator[dict, None]:
+ """Does the community summary logic from
+ GraphCommunitySummaryPipe._run_logic.
+
+ Yields each summary dictionary as it completes.
+ """
+ start_time = time.time()
+ logger.info(
+ f"Starting community summarization for collection={collection_id}"
+ )
+
+ # get all entities & relationships
+ (
+ all_entities,
+ _,
+ ) = await self.providers.database.graphs_handler.get_entities(
+ parent_id=collection_id,
+ offset=0,
+ limit=-1,
+ include_embeddings=False,
+ )
+ (
+ all_relationships,
+ _,
+ ) = await self.providers.database.graphs_handler.get_relationships(
+ parent_id=collection_id,
+ offset=0,
+ limit=-1,
+ include_embeddings=False,
+ )
+
+ # We can optionally re-run the clustering to produce fresh community assignments
+ (
+ _,
+ community_clusters,
+ ) = await self.providers.database.graphs_handler._cluster_and_add_community_info(
+ relationships=all_relationships,
+ leiden_params=leiden_params,
+ collection_id=collection_id,
+ )
+
+ # Group clusters
+ clusters: dict[Any, list[str]] = {}
+ for item in community_clusters:
+ cluster_id = item["cluster"]
+ node_name = item["node"]
+ clusters.setdefault(cluster_id, []).append(node_name)
+
+ # create an async job for each cluster
+ tasks: list[Coroutine[Any, Any, dict]] = []
+
+ tasks.extend(
+ self._process_community_summary(
+ community_id=uuid.uuid4(),
+ nodes=nodes,
+ all_entities=all_entities,
+ all_relationships=all_relationships,
+ max_summary_input_length=max_summary_input_length,
+ generation_config=generation_config,
+ collection_id=collection_id,
+ )
+ for nodes in clusters.values()
+ )
+
+ total_jobs = len(tasks)
+ results_returned = 0
+ total_errors = 0
+
+ for coro in asyncio.as_completed(tasks):
+ summary = await coro
+ results_returned += 1
+ if results_returned % 50 == 0:
+ logger.info(
+ f"Community summaries: {results_returned}/{total_jobs} done in {time.time() - start_time:.2f}s"
+ )
+ if "error" in summary:
+ total_errors += 1
+ yield summary
+
+ if total_errors > 0:
+ logger.warning(
+ f"{total_errors} communities failed summarization out of {total_jobs}"
+ )
+
+ async def _process_community_summary(
+ self,
+ community_id: UUID,
+ nodes: list[str],
+ all_entities: list[Entity],
+ all_relationships: list[Relationship],
+ max_summary_input_length: int,
+ generation_config: GenerationConfig,
+ collection_id: UUID,
+ ) -> dict:
+ """
+ Summarize a single community: gather all relevant entities/relationships, call LLM to generate an XML block,
+ parse it, store the result as a community in DB.
+ """
+ # (Equivalent to process_community in old code)
+ # fetch the collection description (optional)
+ response = await self.providers.database.collections_handler.get_collections_overview(
+ offset=0,
+ limit=1,
+ filter_collection_ids=[collection_id],
+ )
+ collection_description = (
+ response["results"][0].description if response["results"] else None # type: ignore
+ )
+
+ # filter out relevant entities / relationships
+ entities = [e for e in all_entities if e.name in nodes]
+ relationships = [
+ r
+ for r in all_relationships
+ if r.subject in nodes and r.object in nodes
+ ]
+ if not entities and not relationships:
+ return {
+ "community_id": community_id,
+ "error": f"No data in this community (nodes={nodes})",
+ }
+
+ # Create the big input text for the LLM
+ input_text = await self._community_summary_prompt(
+ entities,
+ relationships,
+ max_summary_input_length,
+ )
+
+ # Attempt up to 3 times to parse
+ for attempt in range(3):
+ try:
+ # Build the prompt
+ messages = await self.providers.database.prompts_handler.get_message_payload(
+ task_prompt_name=self.providers.database.config.graph_enrichment_settings.graph_communities_prompt,
+ task_inputs={
+ "collection_description": collection_description,
+ "input_text": input_text,
+ },
+ )
+ llm_resp = await self.providers.llm.aget_completion(
+ messages=messages,
+ generation_config=generation_config,
+ )
+ llm_text = llm_resp.choices[0].message.content or ""
+
+ # find <community>...</community> XML
+ match = re.search(
+ r"<community>.*?</community>", llm_text, re.DOTALL
+ )
+ if not match:
+ raise ValueError(
+ "No <community> XML found in LLM response"
+ )
+
+ xml_content = match.group(0)
+ root = ET.fromstring(xml_content)
+
+ # extract fields
+ name_elem = root.find("name")
+ summary_elem = root.find("summary")
+ rating_elem = root.find("rating")
+ rating_expl_elem = root.find("rating_explanation")
+ findings_elem = root.find("findings")
+
+ name = name_elem.text if name_elem is not None else ""
+ summary = summary_elem.text if summary_elem is not None else ""
+ rating = (
+ float(rating_elem.text)
+ if isinstance(rating_elem, Element) and rating_elem.text
+ else ""
+ )
+ rating_explanation = (
+ rating_expl_elem.text
+ if rating_expl_elem is not None
+ else None
+ )
+ findings = (
+ [f.text for f in findings_elem.findall("finding")]
+ if findings_elem is not None
+ else []
+ )
+
+ # build embedding
+ embed_text = (
+ "Summary:\n"
+ + (summary or "")
+ + "\n\nFindings:\n"
+ + "\n".join(
+ finding for finding in findings if finding is not None
+ )
+ )
+ embedding = await self.providers.embedding.async_get_embedding(
+ embed_text
+ )
+
+ # build Community object
+ community = Community(
+ community_id=community_id,
+ collection_id=collection_id,
+ name=name,
+ summary=summary,
+ rating=rating,
+ rating_explanation=rating_explanation,
+ findings=findings,
+ description_embedding=embedding,
+ )
+
+ # store it
+ await self.providers.database.graphs_handler.add_community(
+ community
+ )
+
+ return {
+ "community_id": community_id,
+ "name": name,
+ }
+
+ except Exception as e:
+ logger.error(
+ f"Error summarizing community {community_id}: {e}"
+ )
+ if attempt == 2:
+ return {"community_id": community_id, "error": str(e)}
+ await asyncio.sleep(1)
+
+ # fallback
+ return {"community_id": community_id, "error": "Failed after retries"}
+
+ async def _community_summary_prompt(
+ self,
+ entities: list[Entity],
+ relationships: list[Relationship],
+ max_summary_input_length: int,
+ ) -> str:
+ """Gathers the entity/relationship text, tries not to exceed
+ `max_summary_input_length`."""
+ # Group them by entity.name
+ entity_map: dict[str, dict] = {}
+ for e in entities:
+ entity_map.setdefault(
+ e.name, {"entities": [], "relationships": []}
+ )
+ entity_map[e.name]["entities"].append(e)
+
+ for r in relationships:
+ # subject
+ entity_map.setdefault(
+ r.subject, {"entities": [], "relationships": []}
+ )
+ entity_map[r.subject]["relationships"].append(r)
+
+ # sort by # of relationships
+ sorted_entries = sorted(
+ entity_map.items(),
+ key=lambda x: len(x[1]["relationships"]),
+ reverse=True,
+ )
+
+ # build up the prompt text
+ prompt_chunks = []
+ cur_len = 0
+ for entity_name, data in sorted_entries:
+ block = f"\nEntity: {entity_name}\nDescriptions:\n"
+ block += "\n".join(
+ f"{e.id},{(e.description or '')}" for e in data["entities"]
+ )
+ block += "\nRelationships:\n"
+ block += "\n".join(
+ f"{r.id},{r.subject},{r.object},{r.predicate},{r.description or ''}"
+ for r in data["relationships"]
+ )
+ # check length
+ if cur_len + len(block) > max_summary_input_length:
+ prompt_chunks.append(
+ block[: max_summary_input_length - cur_len]
+ )
+ break
+ else:
+ prompt_chunks.append(block)
+ cur_len += len(block)
+
+ return "".join(prompt_chunks)
+
+ async def delete(
+ self,
+ collection_id: UUID,
+ **kwargs,
+ ):
+ return await self.providers.database.graphs_handler.delete(
+ collection_id=collection_id,
+ )
+
+ async def graph_search_results_extraction(
+ self,
+ document_id: UUID,
+ generation_config: GenerationConfig,
+ entity_types: list[str],
+ relation_types: list[str],
+ chunk_merge_count: int,
+ filter_out_existing_chunks: bool = True,
+ total_tasks: Optional[int] = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[GraphExtraction | R2RDocumentProcessingError, None]:
+ """The original “extract Graph from doc” logic, but inlined instead of
+ referencing a pipe."""
+ start_time = time.time()
+
+ logger.info(
+ f"Graph Extraction: Processing document {document_id} for graph extraction"
+ )
+
+ # Retrieve chunks from DB
+ chunks = []
+ limit = 100
+ offset = 0
+ while True:
+ chunk_req = await self.providers.database.chunks_handler.list_document_chunks(
+ document_id=document_id,
+ offset=offset,
+ limit=limit,
+ )
+ new_chunk_objs = [
+ DocumentChunk(
+ id=chunk["id"],
+ document_id=chunk["document_id"],
+ owner_id=chunk["owner_id"],
+ collection_ids=chunk["collection_ids"],
+ data=chunk["text"],
+ metadata=chunk["metadata"],
+ )
+ for chunk in chunk_req["results"]
+ ]
+ chunks.extend(new_chunk_objs)
+ if len(chunk_req["results"]) < limit:
+ break
+ offset += limit
+
+ if not chunks:
+ logger.info(f"No chunks found for document {document_id}")
+ raise R2RException(
+ message="No chunks found for document",
+ status_code=404,
+ )
+
+ # Possibly filter out any chunks that have already been processed
+ if filter_out_existing_chunks:
+ existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids(
+ document_id=document_id
+ )
+ before_count = len(chunks)
+ chunks = [c for c in chunks if c.id not in existing_chunk_ids]
+ logger.info(
+ f"Filtered out {len(existing_chunk_ids)} existing chunk-IDs. {before_count}->{len(chunks)} remain."
+ )
+ if not chunks:
+ return # nothing left to yield
+
+ # sort by chunk_order if present
+ chunks = sorted(
+ chunks,
+ key=lambda x: x.metadata.get("chunk_order", float("inf")),
+ )
+
+ # group them
+ grouped_chunks = [
+ chunks[i : i + chunk_merge_count]
+ for i in range(0, len(chunks), chunk_merge_count)
+ ]
+
+ logger.info(
+ f"Graph Extraction: Created {len(grouped_chunks)} tasks for doc={document_id}"
+ )
+ tasks = [
+ asyncio.create_task(
+ self._extract_graph_search_results_from_chunk_group(
+ chunk_group,
+ generation_config,
+ entity_types,
+ relation_types,
+ )
+ )
+ for chunk_group in grouped_chunks
+ ]
+
+ completed_tasks = 0
+ for t in asyncio.as_completed(tasks):
+ try:
+ yield await t
+ completed_tasks += 1
+ if completed_tasks % 100 == 0:
+ logger.info(
+ f"Graph Extraction: completed {completed_tasks}/{len(tasks)} tasks"
+ )
+ except Exception as e:
+ logger.error(f"Error extracting from chunk group: {e}")
+ yield R2RDocumentProcessingError(
+ document_id=document_id,
+ error_message=str(e),
+ )
+
+ logger.info(
+ f"Graph Extraction: done with {document_id}, time={time.time() - start_time:.2f}s"
+ )
+
+ async def _extract_graph_search_results_from_chunk_group(
+ self,
+ chunks: list[DocumentChunk],
+ generation_config: GenerationConfig,
+ entity_types: list[str],
+ relation_types: list[str],
+ retries: int = 5,
+ delay: int = 2,
+ ) -> GraphExtraction:
+ """(Equivalent to _extract_graph_search_results in old code.) Merges
+ chunk data, calls LLM, parses XML, returns GraphExtraction object."""
+ combined_extraction: str = " ".join(
+ [
+ c.data.decode("utf-8") if isinstance(c.data, bytes) else c.data
+ for c in chunks
+ if c.data
+ ]
+ )
+
+ # Possibly get doc-level summary
+ doc_id = chunks[0].document_id
+ response = await self.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=1,
+ filter_document_ids=[doc_id],
+ )
+ document_summary = (
+ response["results"][0].summary if response["results"] else None
+ )
+
+ # Build messages/prompt
+ prompt_name = self.providers.database.config.graph_creation_settings.graph_extraction_prompt
+ messages = (
+ await self.providers.database.prompts_handler.get_message_payload(
+ task_prompt_name=prompt_name,
+ task_inputs={
+ "document_summary": document_summary or "",
+ "input": combined_extraction,
+ "entity_types": "\n".join(entity_types),
+ "relation_types": "\n".join(relation_types),
+ },
+ )
+ )
+
+ for attempt in range(retries):
+ try:
+ resp = await self.providers.llm.aget_completion(
+ messages, generation_config=generation_config
+ )
+ graph_search_results_str = resp.choices[0].message.content
+
+ if not graph_search_results_str:
+ raise R2RException(
+ "No extraction found in LLM response.",
+ 400,
+ )
+
+ # parse the XML
+ (
+ entities,
+ relationships,
+ ) = await self._parse_graph_search_results_extraction_xml(
+ graph_search_results_str, chunks
+ )
+ return GraphExtraction(
+ entities=entities, relationships=relationships
+ )
+
+ except Exception as e:
+ if attempt < retries - 1:
+ await asyncio.sleep(delay)
+ continue
+ else:
+ logger.error(
+ f"All extraction attempts for doc={doc_id} and chunks{[chunk.id for chunk in chunks]} failed with error:\n{e}"
+ )
+ return GraphExtraction(entities=[], relationships=[])
+
+ return GraphExtraction(entities=[], relationships=[])
+
    async def _parse_graph_search_results_extraction_xml(
        self, response_str: str, chunks: list[DocumentChunk]
    ) -> tuple[list[Entity], list[Relationship]]:
        """Helper to parse the LLM's XML format, handle edge cases/cleanup,
        produce Entities/Relationships.

        All produced entities/relationships are parented to the document of
        ``chunks[0]`` and tagged with every chunk id in ``chunks``.
        """

        def sanitize_xml(r: str) -> str:
            # Strip common LLM output artifacts so ElementTree can parse it.
            # Remove markdown fences
            r = re.sub(r"```xml|```", "", r)
            # Remove xml instructions or userStyle
            r = re.sub(r"<\?.*?\?>", "", r)
            r = re.sub(r"<userStyle>.*?</userStyle>", "", r)
            # Replace bare `&` with `&amp;`
            r = re.sub(r"&(?!amp;|quot;|apos;|lt;|gt;)", "&amp;", r)
            # Also remove <root> if it appears
            r = r.replace("<root>", "").replace("</root>", "")
            return r.strip()

        cleaned_xml = sanitize_xml(response_str)
        # Wrap in a single root so a flat list of elements is valid XML.
        wrapped = f"<root>{cleaned_xml}</root>"
        try:
            root = ET.fromstring(wrapped)
        except ET.ParseError:
            raise R2RException(
                f"Failed to parse XML:\nData: {wrapped[:1000]}...", 400
            ) from None

        entities_elems = root.findall(".//entity")
        # A long response with zero <entity> tags almost certainly means the
        # LLM ignored the schema; surface that as a client-visible 400.
        if (
            len(response_str) > MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH
            and len(entities_elems) == 0
        ):
            raise R2RException(
                f"No <entity> found in LLM XML, possibly malformed. Response excerpt: {response_str[:300]}",
                400,
            )

        # build entity objects
        doc_id = chunks[0].document_id
        chunk_ids = [c.id for c in chunks]
        entities_list: list[Entity] = []
        for element in entities_elems:
            name_attr = element.get("name")
            type_elem = element.find("type")
            desc_elem = element.find("description")
            category = type_elem.text if type_elem is not None else None
            desc = desc_elem.text if desc_elem is not None else None
            # Embed the description (empty string when missing).
            desc_embed = await self.providers.embedding.async_get_embedding(
                desc or ""
            )
            ent = Entity(
                category=category,
                description=desc,
                name=name_attr,
                parent_id=doc_id,
                chunk_ids=chunk_ids,
                description_embedding=desc_embed,
                attributes={},
            )
            entities_list.append(ent)

        # build relationship objects
        relationships_list: list[Relationship] = []
        rel_elems = root.findall(".//relationship")
        for r_elem in rel_elems:
            source_elem = r_elem.find("source")
            target_elem = r_elem.find("target")
            type_elem = r_elem.find("type")
            desc_elem = r_elem.find("description")
            weight_elem = r_elem.find("weight")
            try:
                subject = source_elem.text if source_elem is not None else ""
                object_ = target_elem.text if target_elem is not None else ""
                predicate = type_elem.text if type_elem is not None else ""
                desc = desc_elem.text if desc_elem is not None else ""
                # NOTE(review): when <weight> is absent or empty this falls
                # back to "" (not a float); if Relationship rejects that, the
                # whole relationship is silently dropped by the except below.
                # Confirm whether a numeric default (e.g. 1.0) was intended.
                weight = (
                    float(weight_elem.text)
                    if isinstance(weight_elem, Element) and weight_elem.text
                    else ""
                )
                embed = await self.providers.embedding.async_get_embedding(
                    desc or ""
                )

                rel = Relationship(
                    subject=subject,
                    predicate=predicate,
                    object=object_,
                    description=desc,
                    weight=weight,
                    parent_id=doc_id,
                    chunk_ids=chunk_ids,
                    attributes={},
                    description_embedding=embed,
                )
                relationships_list.append(rel)
            except Exception:
                # Malformed relationship entries are skipped rather than
                # failing the whole extraction.
                continue
        return entities_list, relationships_list
+
    async def store_graph_search_results_extractions(
        self,
        graph_search_results_extractions: list[GraphExtraction],
    ):
        """Stores a batch of knowledge graph extractions in the DB.

        Entities are created first so relationships can be linked to their
        generated ids by name. Entities without a parent_id, and
        relationships whose endpoints were not created, are skipped with a
        warning rather than aborting the batch.
        """
        for extraction in graph_search_results_extractions:
            # Map name->id after creation
            entities_id_map = {}
            for e in extraction.entities:
                if e.parent_id is not None:
                    result = await self.providers.database.graphs_handler.entities.create(
                        name=e.name,
                        parent_id=e.parent_id,
                        store_type=StoreType.DOCUMENTS,
                        category=e.category,
                        description=e.description,
                        description_embedding=e.description_embedding,
                        chunk_ids=e.chunk_ids,
                        metadata=e.metadata,
                    )
                    entities_id_map[e.name] = result.id
                else:
                    logger.warning(f"Skipping entity with None parent_id: {e}")

            # Insert relationships
            for rel in extraction.relationships:
                subject_id = entities_id_map.get(rel.subject)
                object_id = entities_id_map.get(rel.object)
                parent_id = rel.parent_id

                if any(
                    id is None for id in (subject_id, object_id, parent_id)
                ):
                    logger.warning(f"Missing ID for relationship: {rel}")
                    continue

                # The any() guard above ensures these are real UUIDs.
                assert isinstance(subject_id, UUID)
                assert isinstance(object_id, UUID)
                assert isinstance(parent_id, UUID)

                await self.providers.database.graphs_handler.relationships.create(
                    subject=rel.subject,
                    subject_id=subject_id,
                    predicate=rel.predicate,
                    object=rel.object,
                    object_id=object_id,
                    parent_id=parent_id,
                    description=rel.description,
                    description_embedding=rel.description_embedding,
                    weight=rel.weight,
                    metadata=rel.metadata,
                    store_type=StoreType.DOCUMENTS,
                )
+
    async def deduplicate_document_entities(
        self,
        document_id: UUID,
    ):
        """
        Inlined from old code: merges duplicates by name, calls LLM for a new consolidated description, updates the record.
        """
        merged_results = await self.providers.database.entities_handler.merge_duplicate_name_blocks(
            parent_id=document_id,
            store_type=StoreType.DOCUMENTS,
        )

        # Grab doc summary
        response = await self.providers.database.documents_handler.get_documents_overview(
            offset=0,
            limit=1,
            filter_document_ids=[document_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )

        # For each merged entity
        for original_entities, merged_entity in merged_results:
            # Summarize them with LLM
            entity_info = "\n".join(
                e.description for e in original_entities if e.description
            )
            messages = await self.providers.database.prompts_handler.get_message_payload(
                task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt,
                task_inputs={
                    "document_summary": document_summary,
                    "entity_info": f"{merged_entity.name}\n{entity_info}",
                    "relationships_txt": "",
                },
            )
            # Prefer the configured generation settings; fall back to the
            # app's fast LLM.
            gen_config = (
                self.config.database.graph_creation_settings.generation_config
                or GenerationConfig(model=self.config.app.fast_llm)
            )
            resp = await self.providers.llm.aget_completion(
                messages, generation_config=gen_config
            )
            new_description = resp.choices[0].message.content

            new_embedding = await self.providers.embedding.async_get_embedding(
                new_description or ""
            )

            if merged_entity.id is not None:
                # NOTE(review): the embedding is stringified here while other
                # call sites pass the raw vector — confirm the update handler
                # expects a string representation.
                await self.providers.database.graphs_handler.entities.update(
                    entity_id=merged_entity.id,
                    store_type=StoreType.DOCUMENTS,
                    description=new_description,
                    description_embedding=str(new_embedding),
                )
            else:
                logger.warning("Skipping update for entity with None id")
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/ingestion_service.py b/.venv/lib/python3.12/site-packages/core/main/services/ingestion_service.py
new file mode 100644
index 00000000..55b06911
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/ingestion_service.py
@@ -0,0 +1,983 @@
+import asyncio
+import json
+import logging
+from datetime import datetime
+from typing import Any, AsyncGenerator, Optional, Sequence
+from uuid import UUID
+
+from fastapi import HTTPException
+
+from core.base import (
+ Document,
+ DocumentChunk,
+ DocumentResponse,
+ DocumentType,
+ GenerationConfig,
+ IngestionStatus,
+ R2RException,
+ RawChunk,
+ UnprocessedChunk,
+ Vector,
+ VectorEntry,
+ VectorType,
+ generate_id,
+)
+from core.base.abstractions import (
+ ChunkEnrichmentSettings,
+ IndexMeasure,
+ IndexMethod,
+ R2RDocumentProcessingError,
+ VectorTableName,
+)
+from core.base.api.models import User
+from shared.abstractions import PDFParsingError, PopplerNotFoundError
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+
+logger = logging.getLogger()
+STARTING_VERSION = "v0"
+
+
+class IngestionService:
+ """A refactored IngestionService that inlines all pipe logic for parsing,
+ embedding, and vector storage directly in its methods."""
+
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ) -> None:
        # config: app + ingestion settings; providers: bundle exposing the
        # database, embedding, llm and ingestion providers used throughout.
        self.config = config
        self.providers = providers
+
    async def ingest_file_ingress(
        self,
        file_data: dict,
        user: User,
        document_id: UUID,
        size_in_bytes,
        metadata: Optional[dict] = None,
        version: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        """Pre-ingests a file by creating or validating the DocumentResponse
        entry.

        Does not actually parse/ingest the content. (See parse_file() for that
        step.)

        Returns:
            dict with key "info" holding the upserted DocumentResponse.

        Raises:
            R2RException: 400 on missing file data/name; 409 when the
                document already exists or is mid-ingestion.
            HTTPException: 500 on unexpected failures.
        """
        try:
            if not file_data:
                raise R2RException(
                    status_code=400, message="No files provided for ingestion."
                )
            if not file_data.get("filename"):
                raise R2RException(
                    status_code=400, message="File name not provided."
                )

            metadata = metadata or {}
            version = version or STARTING_VERSION

            document_info = self.create_document_info_from_file(
                document_id,
                user,
                file_data["filename"],
                metadata,
                version,
                size_in_bytes,
            )

            existing_document_info = (
                await self.providers.database.documents_handler.get_documents_overview(
                    offset=0,
                    limit=100,
                    filter_user_ids=[user.id],
                    filter_document_ids=[document_id],
                )
            )["results"]

            # Validate ingestion status for re-ingestion: only documents in a
            # FAILED state may be re-ingested under the same id.
            if len(existing_document_info) > 0:
                existing_doc = existing_document_info[0]
                if existing_doc.ingestion_status == IngestionStatus.SUCCESS:
                    raise R2RException(
                        status_code=409,
                        message=(
                            f"Document {document_id} already exists. "
                            "Submit a DELETE request to `/documents/{document_id}` "
                            "to delete this document and allow for re-ingestion."
                        ),
                    )
                elif existing_doc.ingestion_status != IngestionStatus.FAILED:
                    raise R2RException(
                        status_code=409,
                        message=(
                            f"Document {document_id} is currently ingesting "
                            f"with status {existing_doc.ingestion_status}."
                        ),
                    )

            # Set to PARSING until we actually parse
            document_info.ingestion_status = IngestionStatus.PARSING
            await self.providers.database.documents_handler.upsert_documents_overview(
                document_info
            )

            return {
                "info": document_info,
            }
        except R2RException as e:
            logger.error(f"R2RException in ingest_file_ingress: {str(e)}")
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Error during ingestion: {str(e)}"
            ) from e
+
+ def create_document_info_from_file(
+ self,
+ document_id: UUID,
+ user: User,
+ file_name: str,
+ metadata: dict,
+ version: str,
+ size_in_bytes: int,
+ ) -> DocumentResponse:
+ file_extension = (
+ file_name.split(".")[-1].lower() if file_name != "N/A" else "txt"
+ )
+ if file_extension.upper() not in DocumentType.__members__:
+ raise R2RException(
+ status_code=415,
+ message=f"'{file_extension}' is not a valid DocumentType.",
+ )
+
+ metadata = metadata or {}
+ metadata["version"] = version
+
+ return DocumentResponse(
+ id=document_id,
+ owner_id=user.id,
+ collection_ids=metadata.get("collection_ids", []),
+ document_type=DocumentType[file_extension.upper()],
+ title=(
+ metadata.get("title", file_name.split("/")[-1])
+ if file_name != "N/A"
+ else "N/A"
+ ),
+ metadata=metadata,
+ version=version,
+ size_in_bytes=size_in_bytes,
+ ingestion_status=IngestionStatus.PENDING,
+ created_at=datetime.now(),
+ updated_at=datetime.now(),
+ )
+
+ def _create_document_info_from_chunks(
+ self,
+ document_id: UUID,
+ user: User,
+ chunks: list[RawChunk],
+ metadata: dict,
+ version: str,
+ ) -> DocumentResponse:
+ metadata = metadata or {}
+ metadata["version"] = version
+
+ return DocumentResponse(
+ id=document_id,
+ owner_id=user.id,
+ collection_ids=metadata.get("collection_ids", []),
+ document_type=DocumentType.TXT,
+ title=metadata.get("title", f"Ingested Chunks - {document_id}"),
+ metadata=metadata,
+ version=version,
+ size_in_bytes=sum(
+ len(chunk.text.encode("utf-8")) for chunk in chunks
+ ),
+ ingestion_status=IngestionStatus.PENDING,
+ created_at=datetime.now(),
+ updated_at=datetime.now(),
+ )
+
+ async def parse_file(
+ self,
+ document_info: DocumentResponse,
+ ingestion_config: dict | None,
+ ) -> AsyncGenerator[DocumentChunk, None]:
+ """Reads the file content from the DB, calls the ingestion
+ provider to parse, and yields DocumentChunk objects."""
+ version = document_info.version or "v0"
+ ingestion_config_override = ingestion_config or {}
+
+ # The ingestion config might specify a different provider, etc.
+ override_provider = ingestion_config_override.pop("provider", None)
+ if (
+ override_provider
+ and override_provider != self.providers.ingestion.config.provider
+ ):
+ raise ValueError(
+ f"Provider '{override_provider}' does not match ingestion provider "
+ f"'{self.providers.ingestion.config.provider}'."
+ )
+
+ try:
+ # Pull file from DB
+ retrieved = (
+ await self.providers.database.files_handler.retrieve_file(
+ document_info.id
+ )
+ )
+ if not retrieved:
+ # No file found in the DB, can't parse
+ raise R2RDocumentProcessingError(
+ document_id=document_info.id,
+ error_message="No file content found in DB for this document.",
+ )
+
+ file_name, file_wrapper, file_size = retrieved
+
+ # Read the content
+ with file_wrapper as file_content_stream:
+ file_content = file_content_stream.read()
+
+ # Build a barebones Document object
+ doc = Document(
+ id=document_info.id,
+ collection_ids=document_info.collection_ids,
+ owner_id=document_info.owner_id,
+ metadata={
+ "document_type": document_info.document_type.value,
+ **document_info.metadata,
+ },
+ document_type=document_info.document_type,
+ )
+
+ # Delegate to the ingestion provider to parse
+ async for extraction in self.providers.ingestion.parse(
+ file_content, # raw bytes
+ doc,
+ ingestion_config_override,
+ ):
+ # Adjust chunk ID to incorporate version
+ # or any other needed transformations
+ extraction.id = generate_id(f"{extraction.id}_{version}")
+ extraction.metadata["version"] = version
+ yield extraction
+
+ except (PopplerNotFoundError, PDFParsingError) as e:
+ raise R2RDocumentProcessingError(
+ error_message=e.message,
+ document_id=document_info.id,
+ status_code=e.status_code,
+ ) from None
+ except Exception as e:
+ if isinstance(e, R2RException):
+ raise
+ raise R2RDocumentProcessingError(
+ document_id=document_info.id,
+ error_message=f"Error parsing document: {str(e)}",
+ ) from e
+
    async def augment_document_info(
        self,
        document_info: DocumentResponse,
        chunked_documents: list[dict],
    ) -> None:
        """Generate an LLM summary (and its embedding) for the document and
        store both on ``document_info`` in place.

        No-op when ``skip_document_summary`` is configured. Only the first
        ``chunks_for_document_summary`` chunks are fed to the prompt, and
        the assembled text is truncated to ``document_summary_max_length``.
        """
        if not self.config.ingestion.skip_document_summary:
            document = f"Document Title: {document_info.title}\n"
            if document_info.metadata != {}:
                document += f"Document Metadata: {json.dumps(document_info.metadata)}\n"

            document += "Document Text:\n"
            for chunk in chunked_documents[
                : self.config.ingestion.chunks_for_document_summary
            ]:
                document += chunk["data"]

            messages = await self.providers.database.prompts_handler.get_message_payload(
                system_prompt_name=self.config.ingestion.document_summary_system_prompt,
                task_prompt_name=self.config.ingestion.document_summary_task_prompt,
                task_inputs={
                    "document": document[
                        : self.config.ingestion.document_summary_max_length
                    ]
                },
            )

            response = await self.providers.llm.aget_completion(
                messages=messages,
                generation_config=GenerationConfig(
                    model=self.config.ingestion.document_summary_model
                    or self.config.app.fast_llm
                ),
            )

            document_info.summary = response.choices[0].message.content  # type: ignore

            if not document_info.summary:
                raise ValueError("Expected a generated response.")

            embedding = await self.providers.embedding.async_get_embedding(
                text=document_info.summary,
            )
            document_info.summary_embedding = embedding
        return
+
    async def embed_document(
        self,
        chunked_documents: list[dict],
        embedding_batch_size: int = 8,
    ) -> AsyncGenerator[VectorEntry, None]:
        """Inline replacement for the old embedding_pipe.run(...).

        Batches the embedding calls and yields VectorEntry objects.

        Up to ``concurrent_request_limit`` batches (default 5) are in flight
        at once; results are yielded as batches complete, so output order is
        not guaranteed to match input order.
        """
        if not chunked_documents:
            return

        concurrency_limit = (
            self.providers.embedding.config.concurrent_request_limit or 5
        )
        extraction_batch: list[DocumentChunk] = []
        tasks: set[asyncio.Task] = set()

        async def process_batch(
            batch: list[DocumentChunk],
        ) -> list[VectorEntry]:
            # All text from the batch
            texts = [
                (
                    ex.data.decode("utf-8")
                    if isinstance(ex.data, bytes)
                    else ex.data
                )
                for ex in batch
            ]
            # Retrieve embeddings in bulk
            vectors = await self.providers.embedding.async_get_embeddings(
                texts,  # list of strings
            )
            # Zip them back together
            results = []
            for raw_vector, extraction in zip(vectors, batch, strict=False):
                results.append(
                    VectorEntry(
                        id=extraction.id,
                        document_id=extraction.document_id,
                        owner_id=extraction.owner_id,
                        collection_ids=extraction.collection_ids,
                        vector=Vector(data=raw_vector, type=VectorType.FIXED),
                        text=(
                            extraction.data.decode("utf-8")
                            if isinstance(extraction.data, bytes)
                            else str(extraction.data)
                        ),
                        metadata={**extraction.metadata},
                    )
                )
            return results

        async def run_process_batch(batch: list[DocumentChunk]):
            # Thin wrapper so create_task captures the batch list by value.
            return await process_batch(batch)

        # Convert each chunk dict to a DocumentChunk
        for chunk_dict in chunked_documents:
            extraction = DocumentChunk.from_dict(chunk_dict)
            extraction_batch.append(extraction)

            # If we hit a batch threshold, spawn a task
            if len(extraction_batch) >= embedding_batch_size:
                tasks.add(
                    asyncio.create_task(run_process_batch(extraction_batch))
                )
                extraction_batch = []

            # If tasks are at concurrency limit, wait for the first to finish
            while len(tasks) >= concurrency_limit:
                done, tasks = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED
                )
                for t in done:
                    for vector_entry in await t:
                        yield vector_entry

        # Handle any leftover items
        if extraction_batch:
            tasks.add(asyncio.create_task(run_process_batch(extraction_batch)))

        # Gather remaining tasks
        for future_task in asyncio.as_completed(tasks):
            for vector_entry in await future_task:
                yield vector_entry
+
    async def store_embeddings(
        self,
        embeddings: Sequence[dict | VectorEntry],
        storage_batch_size: int = 128,
    ) -> AsyncGenerator[str, None]:
        """Inline replacement for the old vector_storage_pipe.run(...).

        Batches up the vector entries, enforces usage limits, stores them, and
        yields a success/error string (or you could yield a StorageResult).
        """
        if not embeddings:
            return

        # Normalize dict inputs to VectorEntry objects.
        vector_entries: list[VectorEntry] = []
        for item in embeddings:
            if isinstance(item, VectorEntry):
                vector_entries.append(item)
            else:
                vector_entries.append(VectorEntry.from_dict(item))

        vector_batch: list[VectorEntry] = []
        document_counts: dict[UUID, int] = {}

        # We'll track usage from the first user we see; if your scenario allows
        # multiple user owners in a single ingestion, you'd need to refine usage checks.
        current_usage = None
        user_id_for_usage_check: UUID | None = None

        count = 0

        for msg in vector_entries:
            # If we haven't set usage yet, do so on the first chunk.
            # max_chunks is also bound here (first iteration) and reused below.
            if current_usage is None:
                user_id_for_usage_check = msg.owner_id
                usage_data = (
                    await self.providers.database.chunks_handler.list_chunks(
                        limit=1,
                        offset=0,
                        filters={"owner_id": msg.owner_id},
                    )
                )
                current_usage = usage_data["total_entries"]

                # Figure out the user's limit
                user = await self.providers.database.users_handler.get_user_by_id(
                    msg.owner_id
                )
                max_chunks = (
                    self.providers.database.config.app.default_max_chunks_per_user
                )
                if user.limits_overrides and "max_chunks" in user.limits_overrides:
                    max_chunks = user.limits_overrides["max_chunks"]

            # Add to our local batch
            vector_batch.append(msg)
            document_counts[msg.document_id] = (
                document_counts.get(msg.document_id, 0) + 1
            )
            count += 1

            # Check usage
            # NOTE(review): `count` already includes the items currently in
            # vector_batch, so this sum double-counts the in-flight batch and
            # trips the limit early — confirm whether that is intentional
            # head-room or a bug.
            if (
                current_usage is not None
                and (current_usage + len(vector_batch) + count) > max_chunks
            ):
                error_message = f"User {msg.owner_id} has exceeded the maximum number of allowed chunks: {max_chunks}"
                logger.error(error_message)
                yield error_message
                continue

            # Once we hit our batch size, store them
            if len(vector_batch) >= storage_batch_size:
                try:
                    await (
                        self.providers.database.chunks_handler.upsert_entries(
                            vector_batch
                        )
                    )
                except Exception as e:
                    logger.error(f"Failed to store vector batch: {e}")
                    yield f"Error: {e}"
                vector_batch.clear()

        # Store any leftover items
        if vector_batch:
            try:
                await self.providers.database.chunks_handler.upsert_entries(
                    vector_batch
                )
            except Exception as e:
                logger.error(f"Failed to store final vector batch: {e}")
                yield f"Error: {e}"

        # Summaries
        for doc_id, cnt in document_counts.items():
            info_msg = f"Successful ingestion for document_id: {doc_id}, with vector count: {cnt}"
            logger.info(info_msg)
            yield info_msg
+
    async def finalize_ingestion(
        self, document_info: DocumentResponse
    ) -> AsyncGenerator[DocumentResponse, None]:
        """Called at the end of a successful ingestion pipeline to set the
        document status to SUCCESS or similar final steps.

        Returns an async generator yielding ``document_info`` exactly once
        (kept for pipeline-style callers that iterate the result). The
        declared ``-> None`` in the original was incorrect.
        """

        async def empty_generator():
            yield document_info

        await self.update_document_status(
            document_info, IngestionStatus.SUCCESS
        )
        return empty_generator()
+
    async def update_document_status(
        self,
        document_info: DocumentResponse,
        status: IngestionStatus,
        metadata: Optional[dict] = None,
    ) -> None:
        """Set the ingestion status (and optionally merge extra metadata)
        on ``document_info``, then persist it."""
        document_info.ingestion_status = status
        if metadata:
            # Merge, with the new metadata winning on key collisions.
            document_info.metadata = {**document_info.metadata, **metadata}
        await self._update_document_status_in_db(document_info)
+
    async def _update_document_status_in_db(
        self, document_info: DocumentResponse
    ):
        """Best-effort persistence of the document row; DB failures are
        logged but deliberately not re-raised so status updates never abort
        the ingestion flow."""
        try:
            await self.providers.database.documents_handler.upsert_documents_overview(
                document_info
            )
        except Exception as e:
            logger.error(
                f"Failed to update document status: {document_info.id}. Error: {str(e)}"
            )
+
    async def ingest_chunks_ingress(
        self,
        document_id: UUID,
        metadata: Optional[dict],
        chunks: list[RawChunk],
        user: User,
        *args: Any,
        **kwargs: Any,
    ) -> DocumentResponse:
        """Directly ingest user-provided text chunks (rather than from a
        file).

        Creates (or re-creates, if the prior attempt FAILED) the document
        row; raises 409 when a non-failed document with this id exists.
        """
        if not chunks:
            raise R2RException(
                status_code=400, message="No chunks provided for ingestion."
            )
        metadata = metadata or {}
        version = STARTING_VERSION

        document_info = self._create_document_info_from_chunks(
            document_id,
            user,
            chunks,
            metadata,
            version,
        )

        existing_document_info = (
            await self.providers.database.documents_handler.get_documents_overview(
                offset=0,
                limit=100,
                filter_user_ids=[user.id],
                filter_document_ids=[document_id],
            )
        )["results"]
        if len(existing_document_info) > 0:
            existing_doc = existing_document_info[0]
            if existing_doc.ingestion_status != IngestionStatus.FAILED:
                raise R2RException(
                    status_code=409,
                    message=(
                        f"Document {document_id} was already ingested "
                        "and is not in a failed state."
                    ),
                )

        await self.providers.database.documents_handler.upsert_documents_overview(
            document_info
        )
        return document_info
+
    async def update_chunk_ingress(
        self,
        document_id: UUID,
        chunk_id: UUID,
        text: str,
        user: User,
        metadata: Optional[dict] = None,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        """Update an individual chunk's text and metadata, re-embed, and re-
        store it.

        Returns the updated chunk as a plain dict (model_dump of the new
        DocumentChunk).

        Raises:
            R2RException: 404 when the chunk/document is missing, 403 when
                the caller neither owns the chunk nor is a superuser.
        """
        # Verify chunk exists and user has access
        # NOTE(review): this limit=1 listing only proves the *document* has
        # some chunk, not that chunk_id belongs to it — the real per-chunk
        # check is the get_chunk call below; confirm this pre-check is needed.
        existing_chunks = (
            await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=0,
                limit=1,
            )
        )
        if not existing_chunks["results"]:
            raise R2RException(
                status_code=404,
                message=f"Chunk with chunk_id {chunk_id} not found.",
            )

        existing_chunk = (
            await self.providers.database.chunks_handler.get_chunk(chunk_id)
        )
        if not existing_chunk:
            raise R2RException(
                status_code=404,
                message=f"Chunk with id {chunk_id} not found",
            )

        if (
            str(existing_chunk["owner_id"]) != str(user.id)
            and not user.is_superuser
        ):
            raise R2RException(
                status_code=403,
                message="You don't have permission to modify this chunk.",
            )

        # Merge metadata (caller-supplied keys win on collision)
        merged_metadata = {**existing_chunk["metadata"]}
        if metadata is not None:
            merged_metadata |= metadata

        # Create updated chunk
        extraction_data = {
            "id": chunk_id,
            "document_id": document_id,
            "collection_ids": kwargs.get(
                "collection_ids", existing_chunk["collection_ids"]
            ),
            "owner_id": existing_chunk["owner_id"],
            "data": text or existing_chunk["text"],
            "metadata": merged_metadata,
        }
        extraction = DocumentChunk(**extraction_data).model_dump()

        # Re-embed
        embeddings_generator = self.embed_document(
            [extraction], embedding_batch_size=1
        )
        embeddings = []
        async for embedding in embeddings_generator:
            embeddings.append(embedding)

        # Re-store
        store_gen = self.store_embeddings(embeddings, storage_batch_size=1)
        async for _ in store_gen:
            pass

        return extraction
+
    async def _get_enriched_chunk_text(
        self,
        chunk_idx: int,
        chunk: dict,
        document_id: UUID,
        document_summary: str | None,
        chunk_enrichment_settings: ChunkEnrichmentSettings,
        list_document_chunks: list[dict],
    ) -> VectorEntry:
        """Helper for chunk_enrichment.

        Leverages an LLM to rewrite or expand chunk text, then re-embeds it.

        On any LLM failure (or an empty/non-string result) the original text
        is kept and chunk_enrichment_status is set to "failed"; the original
        text is always preserved under metadata["original_text"].
        """
        # Window of up to n_chunks neighbors on each side for context.
        preceding_chunks = [
            list_document_chunks[idx]["text"]
            for idx in range(
                max(0, chunk_idx - chunk_enrichment_settings.n_chunks),
                chunk_idx,
            )
        ]
        succeeding_chunks = [
            list_document_chunks[idx]["text"]
            for idx in range(
                chunk_idx + 1,
                min(
                    len(list_document_chunks),
                    chunk_idx + chunk_enrichment_settings.n_chunks + 1,
                ),
            )
        ]
        try:
            # Obtain the updated text from the LLM
            updated_chunk_text = (
                (
                    await self.providers.llm.aget_completion(
                        messages=await self.providers.database.prompts_handler.get_message_payload(
                            task_prompt_name=chunk_enrichment_settings.chunk_enrichment_prompt,
                            task_inputs={
                                "document_summary": document_summary or "None",
                                "chunk": chunk["text"],
                                "preceding_chunks": (
                                    "\n".join(preceding_chunks)
                                    if preceding_chunks
                                    else "None"
                                ),
                                "succeeding_chunks": (
                                    "\n".join(succeeding_chunks)
                                    if succeeding_chunks
                                    else "None"
                                ),
                                "chunk_size": self.config.ingestion.chunk_size
                                or 1024,
                            },
                        ),
                        generation_config=chunk_enrichment_settings.generation_config
                        or GenerationConfig(model=self.config.app.fast_llm),
                    )
                )
                .choices[0]
                .message.content
            )
        except Exception:
            # Enrichment is best-effort: fall back to the original text.
            updated_chunk_text = chunk["text"]
            chunk["metadata"]["chunk_enrichment_status"] = "failed"
        else:
            chunk["metadata"]["chunk_enrichment_status"] = (
                "success" if updated_chunk_text else "failed"
            )

        if not updated_chunk_text or not isinstance(updated_chunk_text, str):
            updated_chunk_text = str(chunk["text"])
            chunk["metadata"]["chunk_enrichment_status"] = "failed"

        # Re-embed
        data = await self.providers.embedding.async_get_embedding(
            updated_chunk_text
        )
        chunk["metadata"]["original_text"] = chunk["text"]

        return VectorEntry(
            id=generate_id(str(chunk["id"])),
            vector=Vector(data=data, type=VectorType.FIXED, length=len(data)),
            document_id=document_id,
            owner_id=chunk["owner_id"],
            collection_ids=chunk["collection_ids"],
            text=updated_chunk_text,
            metadata=chunk["metadata"],
        )
+
    async def chunk_enrichment(
        self,
        document_id: UUID,
        document_summary: str | None,
        chunk_enrichment_settings: ChunkEnrichmentSettings,
    ) -> int:
        """Example function that modifies chunk text via an LLM then re-embeds
        and re-stores all chunks for the given document.

        Returns:
            The number of enriched vector entries written back.
        """
        list_document_chunks = (
            await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=0,
                limit=-1,
            )
        )["results"]

        new_vector_entries: list[VectorEntry] = []
        tasks = []
        total_completed = 0

        for chunk_idx, chunk in enumerate(list_document_chunks):
            tasks.append(
                self._get_enriched_chunk_text(
                    chunk_idx=chunk_idx,
                    chunk=chunk,
                    document_id=document_id,
                    document_summary=document_summary,
                    chunk_enrichment_settings=chunk_enrichment_settings,
                    list_document_chunks=list_document_chunks,
                )
            )

            # Process in batches of e.g. 128 concurrency
            if len(tasks) == 128:
                new_vector_entries.extend(await asyncio.gather(*tasks))
                total_completed += 128
                logger.info(
                    f"Completed {total_completed} out of {len(list_document_chunks)} chunks for document {document_id}"
                )
                tasks = []

        # Finish any remaining tasks
        new_vector_entries.extend(await asyncio.gather(*tasks))
        logger.info(
            f"Completed enrichment of {len(list_document_chunks)} chunks for document {document_id}"
        )

        # Delete old chunks from vector db
        await self.providers.database.chunks_handler.delete(
            filters={"document_id": document_id}
        )

        # Insert the newly enriched entries
        await self.providers.database.chunks_handler.upsert_entries(
            new_vector_entries
        )
        return len(new_vector_entries)
+
    async def list_chunks(
        self,
        offset: int,
        limit: int,
        filters: Optional[dict[str, Any]] = None,
        include_vectors: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        """Thin pass-through to the chunks handler's paginated listing."""
        return await self.providers.database.chunks_handler.list_chunks(
            offset=offset,
            limit=limit,
            filters=filters,
            include_vectors=include_vectors,
        )
+
    async def get_chunk(
        self,
        chunk_id: UUID,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        """Thin pass-through to fetch a single chunk by id."""
        return await self.providers.database.chunks_handler.get_chunk(chunk_id)
+
    async def update_document_metadata(
        self,
        document_id: UUID,
        metadata: dict,
        user: User,
    ) -> None:
        """Merge ``metadata`` into an existing document's metadata and
        persist it.

        Raises:
            R2RException: 404 when the document does not exist or the user
                lacks access.
        """
        # Verify document exists and user has access
        existing_document = await self.providers.database.documents_handler.get_documents_overview(
            offset=0,
            limit=100,
            filter_document_ids=[document_id],
            filter_user_ids=[user.id],
        )
        if not existing_document["results"]:
            raise R2RException(
                status_code=404,
                message=(
                    f"Document with id {document_id} not found "
                    "or you don't have access."
                ),
            )

        existing_document = existing_document["results"][0]

        # Merge metadata (new keys win on collision)
        merged_metadata = {**existing_document.metadata, **metadata}  # type: ignore

        # Update document metadata
        existing_document.metadata = merged_metadata  # type: ignore
        await self.providers.database.documents_handler.upsert_documents_overview(
            existing_document  # type: ignore
        )
+
+
+class IngestionServiceAdapter:
+ @staticmethod
+ def _parse_user_data(user_data) -> User:
+ if isinstance(user_data, str):
+ try:
+ user_data = json.loads(user_data)
+ except json.JSONDecodeError as e:
+ raise ValueError(
+ f"Invalid user data format: {user_data}"
+ ) from e
+ return User.from_dict(user_data)
+
+ @staticmethod
+ def parse_ingest_file_input(data: dict) -> dict:
+ return {
+ "user": IngestionServiceAdapter._parse_user_data(data["user"]),
+ "metadata": data["metadata"],
+ "document_id": (
+ UUID(data["document_id"]) if data["document_id"] else None
+ ),
+ "version": data.get("version"),
+ "ingestion_config": data["ingestion_config"] or {},
+ "file_data": data["file_data"],
+ "size_in_bytes": data["size_in_bytes"],
+ "collection_ids": data.get("collection_ids", []),
+ }
+
+ @staticmethod
+ def parse_ingest_chunks_input(data: dict) -> dict:
+ return {
+ "user": IngestionServiceAdapter._parse_user_data(data["user"]),
+ "metadata": data["metadata"],
+ "document_id": data["document_id"],
+ "chunks": [
+ UnprocessedChunk.from_dict(chunk) for chunk in data["chunks"]
+ ],
+ "id": data.get("id"),
+ }
+
+ @staticmethod
+ def parse_update_chunk_input(data: dict) -> dict:
+ return {
+ "user": IngestionServiceAdapter._parse_user_data(data["user"]),
+ "document_id": UUID(data["document_id"]),
+ "id": UUID(data["id"]),
+ "text": data["text"],
+ "metadata": data.get("metadata"),
+ "collection_ids": data.get("collection_ids", []),
+ }
+
+ @staticmethod
+ def parse_update_files_input(data: dict) -> dict:
+ return {
+ "user": IngestionServiceAdapter._parse_user_data(data["user"]),
+ "document_ids": [UUID(doc_id) for doc_id in data["document_ids"]],
+ "metadatas": data["metadatas"],
+ "ingestion_config": data["ingestion_config"],
+ "file_sizes_in_bytes": data["file_sizes_in_bytes"],
+ "file_datas": data["file_datas"],
+ }
+
+ @staticmethod
+ def parse_create_vector_index_input(data: dict) -> dict:
+ return {
+ "table_name": VectorTableName(data["table_name"]),
+ "index_method": IndexMethod(data["index_method"]),
+ "index_measure": IndexMeasure(data["index_measure"]),
+ "index_name": data["index_name"],
+ "index_column": data["index_column"],
+ "index_arguments": data["index_arguments"],
+ "concurrently": data["concurrently"],
+ }
+
+ @staticmethod
+ def parse_list_vector_indices_input(input_data: dict) -> dict:
+ return {"table_name": input_data["table_name"]}
+
+ @staticmethod
+ def parse_delete_vector_index_input(input_data: dict) -> dict:
+ return {
+ "index_name": input_data["index_name"],
+ "table_name": input_data.get("table_name"),
+ "concurrently": input_data.get("concurrently", True),
+ }
+
+ @staticmethod
+ def parse_select_vector_index_input(input_data: dict) -> dict:
+ return {
+ "index_name": input_data["index_name"],
+ "table_name": input_data.get("table_name"),
+ }
+
+ @staticmethod
+ def parse_update_document_metadata_input(data: dict) -> dict:
+ return {
+ "document_id": data["document_id"],
+ "metadata": data["metadata"],
+ "user": IngestionServiceAdapter._parse_user_data(data["user"]),
+ }
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/management_service.py b/.venv/lib/python3.12/site-packages/core/main/services/management_service.py
new file mode 100644
index 00000000..62b4ca0b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/management_service.py
@@ -0,0 +1,1084 @@
+import logging
+import os
+from collections import defaultdict
+from datetime import datetime, timedelta, timezone
+from typing import IO, Any, BinaryIO, Optional, Tuple
+from uuid import UUID
+
+import toml
+
+from core.base import (
+ CollectionResponse,
+ ConversationResponse,
+ DocumentResponse,
+ GenerationConfig,
+ GraphConstructionStatus,
+ Message,
+ MessageResponse,
+ Prompt,
+ R2RException,
+ StoreType,
+ User,
+)
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger()
+
+
+class ManagementService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ):
        """Initialize the management service with app config and providers.

        Simply forwards both arguments to the ``Service`` base class.
        """
        super().__init__(
            config,
            providers,
        )
+
+ async def app_settings(self):
+ prompts = (
+ await self.providers.database.prompts_handler.get_all_prompts()
+ )
+ config_toml = self.config.to_toml()
+ config_dict = toml.loads(config_toml)
+ try:
+ project_name = os.environ["R2R_PROJECT_NAME"]
+ except KeyError:
+ project_name = ""
+ return {
+ "config": config_dict,
+ "prompts": prompts,
+ "r2r_project_name": project_name,
+ }
+
    async def users_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
    ):
        """Return a paginated overview of users, optionally filtered by id."""
        return await self.providers.database.users_handler.get_users_overview(
            offset=offset,
            limit=limit,
            user_ids=user_ids,
        )
+
    async def delete_documents_and_chunks_by_filter(
        self,
        filters: dict[str, Any],
    ):
        """Delete chunks matching the given filters, then delete the affected
        documents and their document-level graph entities/relationships.

        Args:
            filters (dict[str, Any]): Filters specifying which chunks to
                delete. ``chunk_id`` keys are rewritten to ``id``; an
                ``owner_id``/``document_id`` equality inside a top-level
                ``$and`` (or a top-level ``document_id``) is also used for
                ownership checks and document bookkeeping.

        Returns:
            dict: A summary with chunk/document deletion counts and the
            deleted document ids.

        Raises:
            R2RException: 404 when an affected document cannot be found, or
                when an ``owner_id`` filter does not match the document's
                owner.

        NOTE(review): despite the original summary ("if any documents are now
        empty"), every document whose chunks matched is deleted
        unconditionally — no remaining-chunk check is performed. Confirm
        this is intended before relying on partial-chunk deletion.
        """

        def transform_chunk_id_to_id(
            filters: dict[str, Any],
        ) -> dict[str, Any]:
            """Example transformation function if your filters use `chunk_id`
            instead of `id`.

            Recursively transform `chunk_id` to `id`.
            """
            if isinstance(filters, dict):
                transformed = {}
                for key, value in filters.items():
                    if key == "chunk_id":
                        transformed["id"] = value
                    elif key in ["$and", "$or"]:
                        transformed[key] = [
                            transform_chunk_id_to_id(item) for item in value
                        ]
                    else:
                        transformed[key] = transform_chunk_id_to_id(value)
                return transformed
            return filters

        # Transform filters if needed.
        transformed_filters = transform_chunk_id_to_id(filters)

        # Find chunks that match the filters before deleting
        interim_results = (
            await self.providers.database.chunks_handler.list_chunks(
                filters=transformed_filters,
                offset=0,
                limit=1_000,
                include_vectors=False,
            )
        )

        results = interim_results["results"]
        # NOTE(review): paginates while total_entries == 1_000; assumes the
        # handler reports per-page "offset" and caps total_entries at the
        # page size — confirm against the chunks handler.
        while interim_results["total_entries"] == 1_000:
            # If we hit the limit, we need to paginate to get all results

            interim_results = (
                await self.providers.database.chunks_handler.list_chunks(
                    filters=transformed_filters,
                    offset=interim_results["offset"] + 1_000,
                    limit=1_000,
                    include_vectors=False,
                )
            )
            results.extend(interim_results["results"])

        document_ids = set()
        owner_id = None

        # Pull owner_id / document_id equality constraints out of the raw
        # (untransformed) filters for the permission check below.
        if "$and" in filters:
            for condition in filters["$and"]:
                if "owner_id" in condition and "$eq" in condition["owner_id"]:
                    owner_id = condition["owner_id"]["$eq"]
                elif (
                    "document_id" in condition
                    and "$eq" in condition["document_id"]
                ):
                    document_ids.add(UUID(condition["document_id"]["$eq"]))
        elif "document_id" in filters:
            doc_id = filters["document_id"]
            if isinstance(doc_id, str):
                document_ids.add(UUID(doc_id))
            elif isinstance(doc_id, UUID):
                document_ids.add(doc_id)
            elif isinstance(doc_id, dict) and "$eq" in doc_id:
                value = doc_id["$eq"]
                document_ids.add(
                    UUID(value) if isinstance(value, str) else value
                )

        # Delete matching chunks from the database
        delete_results = await self.providers.database.chunks_handler.delete(
            transformed_filters
        )

        # Extract the document_ids that were affected.
        affected_doc_ids = {
            UUID(info["document_id"])
            for info in delete_results.values()
            if info.get("document_id")
        }
        document_ids.update(affected_doc_ids)

        # Verify each affected document and update collection counts.
        # NOTE(review): raising here aborts after chunks were already
        # deleted, leaving earlier documents of this loop partially
        # processed — there is no transaction around this method.
        docs_to_delete = []
        for doc_id in document_ids:
            documents_overview_response = await self.providers.database.documents_handler.get_documents_overview(
                offset=0, limit=1, filter_document_ids=[doc_id]
            )
            if not documents_overview_response["results"]:
                raise R2RException(
                    status_code=404, message="Document not found"
                )

            document = documents_overview_response["results"][0]

            for collection_id in document.collection_ids:
                await self.providers.database.collections_handler.decrement_collection_document_count(
                    collection_id=collection_id
                )

            if owner_id and str(document.owner_id) != owner_id:
                raise R2RException(
                    status_code=404,
                    message="Document not found or insufficient permissions",
                )
            docs_to_delete.append(doc_id)

        # Delete documents that no longer have associated chunks
        for doc_id in docs_to_delete:
            # Delete related entities & relationships if needed:
            await self.providers.database.graphs_handler.entities.delete(
                parent_id=doc_id,
                store_type=StoreType.DOCUMENTS,
            )
            await self.providers.database.graphs_handler.relationships.delete(
                parent_id=doc_id,
                store_type=StoreType.DOCUMENTS,
            )

            # Finally, delete the document from documents_overview:
            await self.providers.database.documents_handler.delete(
                document_id=doc_id
            )

        return {
            "success": True,
            "deleted_chunks_count": len(delete_results),
            "deleted_documents_count": len(docs_to_delete),
            "deleted_document_ids": [str(d) for d in docs_to_delete],
        }
+
+ async def download_file(
+ self, document_id: UUID
+ ) -> Optional[Tuple[str, BinaryIO, int]]:
+ if result := await self.providers.database.files_handler.retrieve_file(
+ document_id
+ ):
+ return result
+ return None
+
    async def export_files(
        self,
        document_ids: Optional[list[UUID]] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> tuple[str, BinaryIO, int]:
        """Bundle stored files (optionally filtered by id/date range) into a
        zip; returns (name, binary stream, size) per the annotation."""
        return (
            await self.providers.database.files_handler.retrieve_files_as_zip(
                document_ids=document_ids,
                start_date=start_date,
                end_date=end_date,
            )
        )
+
    async def export_collections(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Export collections to CSV; returns (name, stream) from the handler."""
        return await self.providers.database.collections_handler.export_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )
+
    async def export_documents(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Export documents to CSV; returns (name, stream) from the handler."""
        return await self.providers.database.documents_handler.export_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )
+
    async def export_document_entities(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Export a document's graph entities to CSV (DOCUMENTS store)."""
        return await self.providers.database.graphs_handler.entities.export_to_csv(
            parent_id=id,
            store_type=StoreType.DOCUMENTS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )
+
    async def export_document_relationships(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Export a document's graph relationships to CSV (DOCUMENTS store)."""
        return await self.providers.database.graphs_handler.relationships.export_to_csv(
            parent_id=id,
            store_type=StoreType.DOCUMENTS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )
+
    async def export_conversations(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Export conversations to CSV; returns (name, stream) from the handler."""
        return await self.providers.database.conversations_handler.export_conversations_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )
+
    async def export_graph_entities(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Export a collection graph's entities to CSV (GRAPHS store)."""
        return await self.providers.database.graphs_handler.entities.export_to_csv(
            parent_id=id,
            store_type=StoreType.GRAPHS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )
+
    async def export_graph_relationships(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Export a collection graph's relationships to CSV (GRAPHS store)."""
        return await self.providers.database.graphs_handler.relationships.export_to_csv(
            parent_id=id,
            store_type=StoreType.GRAPHS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )
+
    async def export_graph_communities(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Export a collection graph's communities to CSV (GRAPHS store)."""
        return await self.providers.database.graphs_handler.communities.export_to_csv(
            parent_id=id,
            store_type=StoreType.GRAPHS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )
+
    async def export_messages(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Export conversation messages to CSV via the conversations handler."""
        return await self.providers.database.conversations_handler.export_messages_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )
+
    async def export_users(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Export users to CSV; returns (name, stream) from the handler."""
        return await self.providers.database.users_handler.export_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )
+
    async def documents_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
        collection_ids: Optional[list[UUID]] = None,
        document_ids: Optional[list[UUID]] = None,
    ):
        """Paginated document overview, optionally filtered by owner,
        collection, and/or document ids."""
        return await self.providers.database.documents_handler.get_documents_overview(
            offset=offset,
            limit=limit,
            filter_document_ids=document_ids,
            filter_user_ids=user_ids,
            filter_collection_ids=collection_ids,
        )
+
    async def update_document_metadata(
        self,
        document_id: UUID,
        metadata: list[dict],
        overwrite: bool = False,
    ):
        """Update a document's metadata; ``overwrite=True`` presumably
        replaces rather than merges — confirm against the handler."""
        return await self.providers.database.documents_handler.update_document_metadata(
            document_id=document_id,
            metadata=metadata,
            overwrite=overwrite,
        )
+
    async def list_document_chunks(
        self,
        document_id: UUID,
        offset: int,
        limit: int,
        include_vectors: bool = False,
    ):
        """Paginated chunks of a document; embeddings only when requested."""
        return (
            await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=offset,
                limit=limit,
                include_vectors=include_vectors,
            )
        )
+
+ async def assign_document_to_collection(
+ self, document_id: UUID, collection_id: UUID
+ ):
+ await self.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id, collection_id
+ )
+ await self.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id, collection_id
+ )
+ await self.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ await self.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+
+ return {"message": "Document assigned to collection successfully"}
+
+ async def remove_document_from_collection(
+ self, document_id: UUID, collection_id: UUID
+ ):
+ await self.providers.database.collections_handler.remove_document_from_collection_relational(
+ document_id, collection_id
+ )
+ await self.providers.database.chunks_handler.remove_document_from_collection_vector(
+ document_id, collection_id
+ )
+ # await self.providers.database.graphs_handler.delete_node_via_document_id(
+ # document_id, collection_id
+ # )
+ return None
+
+ def _process_relationships(
+ self, relationships: list[Tuple[str, str, str]]
+ ) -> Tuple[dict[str, list[str]], dict[str, dict[str, list[str]]]]:
+ graph = defaultdict(list)
+ grouped: dict[str, dict[str, list[str]]] = defaultdict(
+ lambda: defaultdict(list)
+ )
+ for subject, relation, obj in relationships:
+ graph[subject].append(obj)
+ grouped[subject][relation].append(obj)
+ if obj not in graph:
+ graph[obj] = []
+ return dict(graph), dict(grouped)
+
    def generate_output(
        self,
        grouped_relationships: dict[str, dict[str, list[str]]],
        graph: dict[str, list[str]],
        descriptions_dict: dict[str, str],
        print_descriptions: bool = True,
    ) -> list[str]:
        """Render a human-readable report of a relationship graph.

        Args:
            grouped_relationships: subject -> relation -> objects mapping
                (as produced by ``_process_relationships``).
            graph: adjacency list of the same graph.
            descriptions_dict: optional node descriptions to interleave.
            print_descriptions: include descriptions when True.

        Returns:
            The report as a list of lines: grouped relationships, basic
            statistics, then the top-5 most central nodes.
        """
        output = []
        # Print grouped relationships
        for subject, relations in grouped_relationships.items():
            output.append(f"\n== {subject} ==")
            if print_descriptions and subject in descriptions_dict:
                output.append(f"\tDescription: {descriptions_dict[subject]}")
            for relation, objects in relations.items():
                output.append(f"  {relation}:")
                for obj in objects:
                    output.append(f"    - {obj}")
                    if print_descriptions and obj in descriptions_dict:
                        output.append(
                            f"      Description: {descriptions_dict[obj]}"
                        )

        # Print basic graph statistics
        output.extend(
            [
                "\n== Graph Statistics ==",
                f"Number of nodes: {len(graph)}",
                f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}",
                f"Number of connected components: {self._count_connected_components(graph)}",
            ]
        )

        # Find central nodes
        central_nodes = self._get_central_nodes(graph)
        output.extend(
            [
                "\n== Most Central Nodes ==",
                *(
                    f"  {node}: {centrality:.4f}"
                    for node, centrality in central_nodes
                ),
            ]
        )

        return output
+
+ def _count_connected_components(self, graph: dict[str, list[str]]) -> int:
+ visited = set()
+ components = 0
+
+ def dfs(node):
+ visited.add(node)
+ for neighbor in graph[node]:
+ if neighbor not in visited:
+ dfs(neighbor)
+
+ for node in graph:
+ if node not in visited:
+ dfs(node)
+ components += 1
+
+ return components
+
+ def _get_central_nodes(
+ self, graph: dict[str, list[str]]
+ ) -> list[Tuple[str, float]]:
+ degree = {node: len(neighbors) for node, neighbors in graph.items()}
+ total_nodes = len(graph)
+ centrality = {
+ node: deg / (total_nodes - 1) for node, deg in degree.items()
+ }
+ return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]
+
    async def create_collection(
        self,
        owner_id: UUID,
        name: Optional[str] = None,
        description: str | None = None,
    ) -> CollectionResponse:
        """Create a collection and its companion graph (same id/name/desc)."""
        result = await self.providers.database.collections_handler.create_collection(
            owner_id=owner_id,
            name=name,
            description=description,
        )
        # Every collection gets a graph keyed by the collection id.
        await self.providers.database.graphs_handler.create(
            collection_id=result.id,
            name=name,
            description=description,
        )
        return result
+
    async def update_collection(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        generate_description: bool = False,
    ) -> CollectionResponse:
        """Update a collection's name/description.

        When ``generate_description`` is True, an LLM-generated summary of
        up to the first 100 documents replaces any provided description.
        """
        if generate_description:
            description = await self.summarize_collection(
                id=collection_id, offset=0, limit=100
            )
        return await self.providers.database.collections_handler.update_collection(
            collection_id=collection_id,
            name=name,
            description=description,
        )
+
    async def delete_collection(self, collection_id: UUID) -> bool:
        """Delete a collection's relational row, vector data, and graph.

        Graph deletion failures are logged and swallowed (best effort);
        the method still reports success.
        """
        await self.providers.database.collections_handler.delete_collection_relational(
            collection_id
        )
        await self.providers.database.chunks_handler.delete_collection_vector(
            collection_id
        )
        try:
            await self.providers.database.graphs_handler.delete(
                collection_id=collection_id,
            )
        except Exception as e:
            # Deliberate best-effort: the collection itself is already gone.
            logger.warning(
                f"Error deleting graph for collection {collection_id}: {e}"
            )
        return True
+
    async def collections_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
        document_ids: Optional[list[UUID]] = None,
        collection_ids: Optional[list[UUID]] = None,
    ) -> dict[str, list[CollectionResponse] | int]:
        """Paginated collection overview with optional user/document/id filters."""
        return await self.providers.database.collections_handler.get_collections_overview(
            offset=offset,
            limit=limit,
            filter_user_ids=user_ids,
            filter_document_ids=document_ids,
            filter_collection_ids=collection_ids,
        )
+
    async def add_user_to_collection(
        self, user_id: UUID, collection_id: UUID
    ) -> bool:
        """Add a user to a collection; returns the handler's success flag."""
        return (
            await self.providers.database.users_handler.add_user_to_collection(
                user_id, collection_id
            )
        )
+
    async def remove_user_from_collection(
        self, user_id: UUID, collection_id: UUID
    ) -> bool:
        """Remove a user from a collection; returns the handler's success flag."""
        return await self.providers.database.users_handler.remove_user_from_collection(
            user_id, collection_id
        )
+
    async def get_users_in_collection(
        self, collection_id: UUID, offset: int = 0, limit: int = 100
    ) -> dict[str, list[User] | int]:
        """Paginated list of users belonging to a collection."""
        return await self.providers.database.users_handler.get_users_in_collection(
            collection_id, offset=offset, limit=limit
        )
+
    async def documents_in_collection(
        self, collection_id: UUID, offset: int = 0, limit: int = 100
    ) -> dict[str, list[DocumentResponse] | int]:
        """Paginated list of documents belonging to a collection."""
        return await self.providers.database.collections_handler.documents_in_collection(
            collection_id, offset=offset, limit=limit
        )
+
    async def summarize_collection(
        self, id: UUID, offset: int, limit: int
    ) -> str:
        """Generate an LLM summary of a collection from its document summaries.

        Args:
            id: Collection id.
            offset/limit: Page of documents whose summaries are included.

        Returns:
            The generated summary text.

        Raises:
            ValueError: when the LLM returns no content.

        NOTE(review): documents without a summary yield None entries, which
        would make the ``"\\n\\n".join`` below raise TypeError — the
        ``# type: ignore`` suggests this is known; confirm summaries are
        always populated before this is called.
        """
        documents_in_collection_response = await self.documents_in_collection(
            collection_id=id,
            offset=offset,
            limit=limit,
        )

        document_summaries = [
            document.summary
            for document in documents_in_collection_response["results"]  # type: ignore
        ]

        logger.info(
            f"Summarizing collection {id} with {len(document_summaries)} of {documents_in_collection_response['total_entries']} documents."
        )

        formatted_summaries = "\n\n".join(document_summaries)  # type: ignore

        messages = await self.providers.database.prompts_handler.get_message_payload(
            system_prompt_name=self.config.database.collection_summary_system_prompt,
            task_prompt_name=self.config.database.collection_summary_prompt,
            task_inputs={"document_summaries": formatted_summaries},
        )

        response = await self.providers.llm.aget_completion(
            messages=messages,
            generation_config=GenerationConfig(
                model=self.config.ingestion.document_summary_model
                or self.config.app.fast_llm
            ),
        )

        if collection_summary := response.choices[0].message.content:
            return collection_summary
        else:
            raise ValueError("Expected a generated response.")
+
    async def add_prompt(
        self, name: str, template: str, input_types: dict[str, str]
    ) -> str:
        """Register a new prompt; returns a confirmation message string.

        Raises:
            R2RException: 400 when the handler rejects the prompt
                (surfaced from its ValueError).
        """
        # Annotation corrected from `-> dict`: the method returns a str
        # (the old `# type: ignore` papered over the mismatch).
        try:
            await self.providers.database.prompts_handler.add_prompt(
                name, template, input_types
            )
            return f"Prompt '{name}' added successfully."
        except ValueError as e:
            raise R2RException(status_code=400, message=str(e)) from e
+
    async def get_cached_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        """Render a prompt via the handler's cache, wrapped as {"message": ...}.

        Raises:
            R2RException: 404 when the prompt is unknown (handler ValueError).
        """
        try:
            return {
                "message": (
                    await self.providers.database.prompts_handler.get_cached_prompt(
                        prompt_name=prompt_name,
                        inputs=inputs,
                        prompt_override=prompt_override,
                    )
                )
            }
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e
+
    async def get_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        """Fetch a prompt (uncached) by name.

        Raises:
            R2RException: 404 when the prompt is unknown (handler ValueError).
        """
        try:
            return await self.providers.database.prompts_handler.get_prompt(  # type: ignore
                name=prompt_name,
                inputs=inputs,
                prompt_override=prompt_override,
            )
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e
+
    async def get_all_prompts(self) -> dict[str, Prompt]:
        """Return every stored prompt, keyed by name."""
        return await self.providers.database.prompts_handler.get_all_prompts()
+
    async def update_prompt(
        self,
        name: str,
        template: Optional[str] = None,
        input_types: Optional[dict[str, str]] = None,
    ) -> str:
        """Update an existing prompt; returns a confirmation message string.

        Raises:
            R2RException: 404 when the prompt is unknown (handler ValueError).
        """
        # Annotation corrected from `-> dict`: the method returns a str
        # (the old `# type: ignore` papered over the mismatch).
        try:
            await self.providers.database.prompts_handler.update_prompt(
                name, template, input_types
            )
            return f"Prompt '{name}' updated successfully."
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e
+
    async def delete_prompt(self, name: str) -> dict:
        """Delete a prompt by name; 404s (via R2RException) when unknown."""
        try:
            await self.providers.database.prompts_handler.delete_prompt(name)
            return {"message": f"Prompt '{name}' deleted successfully."}
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e
+
    async def get_conversation(
        self,
        conversation_id: UUID,
        user_ids: Optional[list[UUID]] = None,
    ) -> list[MessageResponse]:
        """Fetch a conversation's messages, optionally scoped to user ids."""
        return await self.providers.database.conversations_handler.get_conversation(
            conversation_id=conversation_id,
            filter_user_ids=user_ids,
        )
+
    async def create_conversation(
        self,
        user_id: Optional[UUID] = None,
        name: Optional[str] = None,
    ) -> ConversationResponse:
        """Create a (possibly anonymous, possibly unnamed) conversation."""
        return await self.providers.database.conversations_handler.create_conversation(
            user_id=user_id,
            name=name,
        )
+
    async def conversations_overview(
        self,
        offset: int,
        limit: int,
        conversation_ids: Optional[list[UUID]] = None,
        user_ids: Optional[list[UUID]] = None,
    ) -> dict[str, list[dict] | int]:
        """Paginated conversation overview, filterable by conversation/user ids."""
        return await self.providers.database.conversations_handler.get_conversations_overview(
            offset=offset,
            limit=limit,
            filter_user_ids=user_ids,
            conversation_ids=conversation_ids,
        )
+
    async def add_message(
        self,
        conversation_id: UUID,
        content: Message,
        parent_id: Optional[UUID] = None,
        metadata: Optional[dict] = None,
    ) -> MessageResponse:
        """Append a message to a conversation, optionally under a parent."""
        return await self.providers.database.conversations_handler.add_message(
            conversation_id=conversation_id,
            content=content,
            parent_id=parent_id,
            metadata=metadata,
        )
+
    async def edit_message(
        self,
        message_id: UUID,
        new_content: Optional[str] = None,
        additional_metadata: Optional[dict] = None,
    ) -> dict[str, Any]:
        """Edit a message's content and/or merge extra metadata (None -> {})."""
        return (
            await self.providers.database.conversations_handler.edit_message(
                message_id=message_id,
                new_content=new_content,
                additional_metadata=additional_metadata or {},
            )
        )
+
    async def update_conversation(
        self, conversation_id: UUID, name: str
    ) -> ConversationResponse:
        """Rename a conversation."""
        return await self.providers.database.conversations_handler.update_conversation(
            conversation_id=conversation_id, name=name
        )
+
    async def delete_conversation(
        self,
        conversation_id: UUID,
        user_ids: Optional[list[UUID]] = None,
    ) -> None:
        """Delete a conversation, optionally restricted to the given users."""
        await (
            self.providers.database.conversations_handler.delete_conversation(
                conversation_id=conversation_id,
                filter_user_ids=user_ids,
            )
        )
+
+ async def get_user_max_documents(self, user_id: UUID) -> int | None:
+ # Fetch the user to see if they have any overrides stored
+ user = await self.providers.database.users_handler.get_user_by_id(
+ user_id
+ )
+ if user.limits_overrides and "max_documents" in user.limits_overrides:
+ return user.limits_overrides["max_documents"]
+ return self.config.app.default_max_documents_per_user
+
+ async def get_user_max_chunks(self, user_id: UUID) -> int | None:
+ user = await self.providers.database.users_handler.get_user_by_id(
+ user_id
+ )
+ if user.limits_overrides and "max_chunks" in user.limits_overrides:
+ return user.limits_overrides["max_chunks"]
+ return self.config.app.default_max_chunks_per_user
+
+ async def get_user_max_collections(self, user_id: UUID) -> int | None:
+ user = await self.providers.database.users_handler.get_user_by_id(
+ user_id
+ )
+ if (
+ user.limits_overrides
+ and "max_collections" in user.limits_overrides
+ ):
+ return user.limits_overrides["max_collections"]
+ return self.config.app.default_max_collections_per_user
+
    async def get_max_upload_size_by_type(
        self, user_id: UUID, file_type_or_ext: str
    ) -> int:
        """Return the maximum allowed upload size (in bytes) for the given
        user's file type/extension. Respects user-level overrides if present,
        falling back to the system config.

        Args:
            user_id: User whose limits apply.
            file_type_or_ext: Extension or type name; a leading dot and
                case are normalized away.

        Returns:
            The byte limit, resolved in priority order: per-user per-type
            override, per-user global override, system per-type limit,
            system default.

        ```json
        {
          "limits_overrides": {
            "max_file_size": 20_000_000,
            "max_file_size_by_type":
            {
              "pdf": 50_000_000,
              "docx": 30_000_000
            },
            ...
          }
        }
        ```
        """
        # 1. Normalize extension
        ext = file_type_or_ext.lower().lstrip(".")

        # 2. Fetch user from DB to see if we have any overrides
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        user_overrides = user.limits_overrides or {}

        # 3. Check if there's a user-level override for "max_file_size_by_type"
        user_file_type_limits = user_overrides.get("max_file_size_by_type", {})
        if ext in user_file_type_limits:
            return user_file_type_limits[ext]

        # 4. If not, check if there's a user-level fallback "max_file_size"
        if "max_file_size" in user_overrides:
            return user_overrides["max_file_size"]

        # 5. If none exist at user level, use system config
        # Example config paths:
        system_type_limits = self.config.app.max_upload_size_by_type
        if ext in system_type_limits:
            return system_type_limits[ext]

        # 6. Otherwise, return the global default
        return self.config.app.default_max_upload_size
+
    async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:
        """
        Return a dictionary containing:
          - The system default limits (from the database limit settings)
          - The user's overrides (from user.limits_overrides)
          - The final 'effective' set of limits after merging (overall)
          - The usage for each relevant limit (global, per-route, storage)

        NOTE(review): the inner ``_count_requests`` / ``_count_monthly_requests``
        calls are underscore-private on the limits handler — this service
        reaches into its internals; consider a public accessor.
        """
        # 1) Fetch the user
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        user_overrides = user.limits_overrides or {}

        # 2) Grab system defaults
        system_defaults = {
            "global_per_min": self.config.database.limits.global_per_min,
            "route_per_min": self.config.database.limits.route_per_min,
            "monthly_limit": self.config.database.limits.monthly_limit,
            # Add additional fields if your LimitSettings has them
        }

        # 3) Build the overall (global) "effective limits" ignoring any specific route
        overall_effective = (
            self.providers.database.limits_handler.determine_effective_limits(
                user, route=""
            )
        )

        # 4) Build usage data. We'll do top-level usage for global_per_min/monthly,
        #    then do route-by-route usage in a loop.
        usage: dict[str, Any] = {}
        now = datetime.now(timezone.utc)
        one_min_ago = now - timedelta(minutes=1)

        # (a) Global usage (per-minute)
        global_per_min_used = (
            await self.providers.database.limits_handler._count_requests(
                user_id, route=None, since=one_min_ago
            )
        )
        # (a2) Global usage (monthly) - i.e. usage across ALL routes
        global_monthly_used = await self.providers.database.limits_handler._count_monthly_requests(
            user_id, route=None
        )

        # "remaining" is None whenever the corresponding limit is unlimited.
        usage["global_per_min"] = {
            "used": global_per_min_used,
            "limit": overall_effective.global_per_min,
            "remaining": (
                overall_effective.global_per_min - global_per_min_used
                if overall_effective.global_per_min is not None
                else None
            ),
        }
        usage["monthly_limit"] = {
            "used": global_monthly_used,
            "limit": overall_effective.monthly_limit,
            "remaining": (
                overall_effective.monthly_limit - global_monthly_used
                if overall_effective.monthly_limit is not None
                else None
            ),
        }

        # (b) Route-level usage. We'll gather all routes from system + user overrides
        system_route_limits = (
            self.config.database.route_limits
        )  # dict[str, LimitSettings]
        user_route_overrides = user_overrides.get("route_overrides", {})
        route_keys = set(system_route_limits.keys()) | set(
            user_route_overrides.keys()
        )

        usage["routes"] = {}
        for route in route_keys:
            # 1) Get the final merged limits for this specific route
            route_effective = self.providers.database.limits_handler.determine_effective_limits(
                user, route
            )

            # 2) Count requests for the last minute on this route
            route_per_min_used = (
                await self.providers.database.limits_handler._count_requests(
                    user_id, route, one_min_ago
                )
            )

            # 3) Count route-specific monthly usage
            route_monthly_used = await self.providers.database.limits_handler._count_monthly_requests(
                user_id, route
            )

            usage["routes"][route] = {
                "route_per_min": {
                    "used": route_per_min_used,
                    "limit": route_effective.route_per_min,
                    "remaining": (
                        route_effective.route_per_min - route_per_min_used
                        if route_effective.route_per_min is not None
                        else None
                    ),
                },
                "monthly_limit": {
                    "used": route_monthly_used,
                    "limit": route_effective.monthly_limit,
                    "remaining": (
                        route_effective.monthly_limit - route_monthly_used
                        if route_effective.monthly_limit is not None
                        else None
                    ),
                },
            }

        # Storage usage: counts come from limit=1 overview queries — only
        # "total_entries" is consumed, not the single row fetched.
        max_documents = await self.get_user_max_documents(user_id)
        used_documents = (
            await self.providers.database.documents_handler.get_documents_overview(
                limit=1, offset=0, filter_user_ids=[user_id]
            )
        )["total_entries"]
        max_chunks = await self.get_user_max_chunks(user_id)
        used_chunks = (
            await self.providers.database.chunks_handler.list_chunks(
                limit=1, offset=0, filters={"owner_id": user_id}
            )
        )["total_entries"]

        max_collections = await self.get_user_max_collections(user_id)
        used_collections: int = (  # type: ignore
            await self.providers.database.collections_handler.get_collections_overview(
                limit=1, offset=0, filter_user_ids=[user_id]
            )
        )["total_entries"]

        storage_limits = {
            "chunks": {
                "limit": max_chunks,
                "used": used_chunks,
                "remaining": (
                    max_chunks - used_chunks
                    if max_chunks is not None
                    else None
                ),
            },
            "documents": {
                "limit": max_documents,
                "used": used_documents,
                "remaining": (
                    max_documents - used_documents
                    if max_documents is not None
                    else None
                ),
            },
            "collections": {
                "limit": max_collections,
                "used": used_collections,
                "remaining": (
                    max_collections - used_collections
                    if max_collections is not None
                    else None
                ),
            },
        }
        # 5) Return a structured response
        return {
            "storage_limits": storage_limits,
            "system_defaults": system_defaults,
            "user_overrides": user_overrides,
            "effective_limits": {
                "global_per_min": overall_effective.global_per_min,
                "route_per_min": overall_effective.route_per_min,
                "monthly_limit": overall_effective.monthly_limit,
            },
            "usage": usage,
        }
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py b/.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py
new file mode 100644
index 00000000..2ae4af31
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py
@@ -0,0 +1,2087 @@
+import asyncio
+import json
+import logging
+from copy import deepcopy
+from datetime import datetime
+from typing import Any, AsyncGenerator, Literal, Optional
+from uuid import UUID
+
+from fastapi import HTTPException
+
+from core import (
+ Citation,
+ R2RRAGAgent,
+ R2RStreamingRAGAgent,
+ R2RStreamingResearchAgent,
+ R2RXMLToolsRAGAgent,
+ R2RXMLToolsResearchAgent,
+ R2RXMLToolsStreamingRAGAgent,
+ R2RXMLToolsStreamingResearchAgent,
+)
+from core.agent.research import R2RResearchAgent
+from core.base import (
+ AggregateSearchResult,
+ ChunkSearchResult,
+ DocumentResponse,
+ GenerationConfig,
+ GraphCommunityResult,
+ GraphEntityResult,
+ GraphRelationshipResult,
+ GraphSearchResult,
+ GraphSearchResultType,
+ IngestionStatus,
+ Message,
+ R2RException,
+ SearchSettings,
+ WebSearchResult,
+ format_search_results_for_llm,
+)
+from core.base.api.models import RAGResponse, User
+from core.utils import (
+ CitationTracker,
+ SearchResultsCollector,
+ SSEFormatter,
+ dump_collector,
+ dump_obj,
+ extract_citations,
+ find_new_citation_spans,
+ num_tokens_from_messages,
+)
+from shared.api.models.management.responses import MessageResponse
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger()
+
+
class AgentFactory:
    """
    Factory that instantiates the correct agent class for a given mode
    ("rag" or "research"), tool format (XML vs. native), and streaming
    preference, wiring in the shared provider and search dependencies.
    """

    @staticmethod
    def create_agent(
        mode: Literal["rag", "research"],
        database_provider,
        llm_provider,
        config,  # : AgentConfig
        search_settings,  # : SearchSettings
        generation_config,  #: GenerationConfig
        app_config,  #: AppConfig
        knowledge_search_method,
        content_method,
        file_search_method,
        max_tool_context_length: int = 32_768,
        rag_tools: Optional[list[str]] = None,
        research_tools: Optional[list[str]] = None,
        tools: Optional[list[str]] = None,  # For backward compatibility
    ):
        """
        Create and return the appropriate agent instance.

        Args:
            mode: Either "rag" or "research" to determine agent type.
            database_provider: Provider for database operations.
            llm_provider: Provider for LLM operations.
            config: Agent configuration (deep-copied; never mutated).
            search_settings: Search settings for retrieval.
            generation_config: Generation configuration; its ``stream``
                flag selects the streaming vs. non-streaming agent.
            app_config: Application configuration (research agents only).
            knowledge_search_method: Method for knowledge search.
            content_method: Method for content retrieval.
            file_search_method: Method for file search.
            max_tool_context_length: Maximum context length for tools.
            rag_tools: Tools specifically for RAG mode.
            research_tools: Tools specifically for Research mode.
            tools: Deprecated backward-compatibility parameter.

        Returns:
            An agent instance matching (mode, streaming, tool format).
        """
        # Deep-copy so tool overrides never leak into the caller's config.
        agent_config = deepcopy(config)

        # Tool overrides: explicit per-mode tools win, then the deprecated
        # `tools` parameter, then the config's own defaults.
        if mode == "rag":
            if rag_tools:
                agent_config.rag_tools = rag_tools
            elif tools:  # Backward compatibility
                agent_config.rag_tools = tools
        elif mode == "research":
            if research_tools:
                agent_config.research_tools = research_tools
            elif tools:  # Backward compatibility
                agent_config.research_tools = tools

        # XML-based tool calling is currently disabled for all models
        # (the model-sniffing heuristic was intentionally turned off).
        use_xml_format = False

        # Streaming mode comes straight from the generation config.
        is_streaming = bool(generation_config.stream)

        # Dispatch table replaces the previous 8-way nested if/else that
        # duplicated every constructor call:
        # (mode, streaming, xml) -> agent class.
        agent_classes = {
            ("rag", True, True): R2RXMLToolsStreamingRAGAgent,
            ("rag", True, False): R2RStreamingRAGAgent,
            ("rag", False, True): R2RXMLToolsRAGAgent,
            ("rag", False, False): R2RRAGAgent,
            ("research", True, True): R2RXMLToolsStreamingResearchAgent,
            ("research", True, False): R2RStreamingResearchAgent,
            ("research", False, True): R2RXMLToolsResearchAgent,
            ("research", False, False): R2RResearchAgent,
        }

        # Preserve original semantics: any non-"rag" mode takes the
        # research path.
        effective_mode = "rag" if mode == "rag" else "research"
        agent_cls = agent_classes[
            (effective_mode, is_streaming, use_xml_format)
        ]

        kwargs = dict(
            database_provider=database_provider,
            llm_provider=llm_provider,
            config=agent_config,
            search_settings=search_settings,
            rag_generation_config=generation_config,
            max_tool_context_length=max_tool_context_length,
            knowledge_search_method=knowledge_search_method,
            content_method=content_method,
            file_search_method=file_search_method,
        )
        # Only research agents receive the application config.
        if effective_mode == "research":
            kwargs["app_config"] = app_config

        return agent_cls(**kwargs)
+
+
+class RetrievalService(Service):
+ def __init__(
+ self,
+ config: R2RConfig,
+ providers: R2RProviders,
+ ):
+ super().__init__(
+ config,
+ providers,
+ )
+
+ async def search(
+ self,
+ query: str,
+ search_settings: SearchSettings = SearchSettings(),
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """
+ Depending on search_settings.search_strategy, fan out
+ to basic, hyde, or rag_fusion method. Each returns
+ an AggregateSearchResult that includes chunk + graph results.
+ """
+ strategy = search_settings.search_strategy.lower()
+
+ if strategy == "hyde":
+ return await self._hyde_search(query, search_settings)
+ elif strategy == "rag_fusion":
+ return await self._rag_fusion_search(query, search_settings)
+ else:
+ # 'vanilla', 'basic', or anything else...
+ return await self._basic_search(query, search_settings)
+
+ async def _basic_search(
+ self, query: str, search_settings: SearchSettings
+ ) -> AggregateSearchResult:
+ """
+ 1) Possibly embed the query (if semantic or hybrid).
+ 2) Chunk search.
+ 3) Graph search.
+ 4) Combine into an AggregateSearchResult.
+ """
+ # -- 1) Possibly embed the query
+ query_vector = None
+ if (
+ search_settings.use_semantic_search
+ or search_settings.use_hybrid_search
+ ):
+ query_vector = (
+ await self.providers.completion_embedding.async_get_embedding(
+ query # , EmbeddingPurpose.QUERY
+ )
+ )
+
+ # -- 2) Chunk search
+ chunk_results = []
+ if search_settings.chunk_settings.enabled:
+ chunk_results = await self._vector_search_logic(
+ query_text=query,
+ search_settings=search_settings,
+ precomputed_vector=query_vector, # Pass in the vector we just computed (if any)
+ )
+
+ # -- 3) Graph search
+ graph_results = []
+ if search_settings.graph_settings.enabled:
+ graph_results = await self._graph_search_logic(
+ query_text=query,
+ search_settings=search_settings,
+ precomputed_vector=query_vector, # same idea
+ )
+
+ # -- 4) Combine
+ return AggregateSearchResult(
+ chunk_search_results=chunk_results,
+ graph_search_results=graph_results,
+ )
+
    async def _rag_fusion_search(
        self, query: str, search_settings: SearchSettings
    ) -> AggregateSearchResult:
        """
        Implements 'RAG Fusion':
        1) Generate N sub-queries from the user query
        2) For each sub-query => do chunk & graph search
        3) Combine / fuse all retrieved results using Reciprocal Rank Fusion
        4) Return an AggregateSearchResult

        The original user query is always kept as the first sub-query so
        the exact user intent is never lost.
        """

        # 1) Generate sub-queries from the user's original query.
        sub_queries = [query]
        if search_settings.num_sub_queries > 1:
            # Generate (num_sub_queries - 1) rephrasings; the original
            # query already occupies one slot.
            extra = await self._generate_similar_queries(
                query=query,
                num_sub_queries=search_settings.num_sub_queries - 1,
            )
            sub_queries.extend(extra)

        # 2) For each sub-query, run chunk + graph search (sequentially).
        # chunk_results_list: one list of ChunkSearchResult per sub-query.
        # graph_results_list: one list of GraphSearchResult per sub-query.
        chunk_results_list = []
        graph_results_list = []

        for sq in sub_queries:
            # _basic_search embeds each sub-query itself when needed.
            aggr = await self._basic_search(sq, search_settings)
            chunk_results_list.append(aggr.chunk_search_results)
            graph_results_list.append(aggr.graph_search_results)

        # 3) Fuse chunk results and graph results separately with RRF;
        # each sub-query's list is treated as a best-to-worst ranking.
        fused_chunk_results = self._reciprocal_rank_fusion_chunks(  # type: ignore
            chunk_results_list  # type: ignore
        )
        filtered_graph_results = [
            results for results in graph_results_list if results is not None
        ]
        fused_graph_results = self._reciprocal_rank_fusion_graphs(
            filtered_graph_results
        )

        # Final semantic re-rank of the fused chunks against the user's
        # ORIGINAL query (not the rephrasings), capped at the result limit.
        if fused_chunk_results:
            fused_chunk_results = (
                await self.providers.completion_embedding.arerank(
                    query=query,
                    results=fused_chunk_results,
                    limit=search_settings.limit,
                )
            )

        # Sort and slice the graph results by score when scores are
        # requested; unscored entries sort as 0.0.
        if fused_graph_results and search_settings.include_scores:
            fused_graph_results.sort(
                key=lambda g: g.score if g.score is not None else 0.0,
                reverse=True,
            )
            fused_graph_results = fused_graph_results[: search_settings.limit]

        # 4) Return the final AggregateSearchResult.
        return AggregateSearchResult(
            chunk_search_results=fused_chunk_results,
            graph_search_results=fused_graph_results,
        )
+
+ async def _generate_similar_queries(
+ self, query: str, num_sub_queries: int = 2
+ ) -> list[str]:
+ """
+ Use your LLM to produce 'similar' queries or rephrasings
+ that might retrieve different but relevant documents.
+
+ You can prompt your model with something like:
+ "Given the user query, produce N alternative short queries that
+ capture possible interpretations or expansions.
+ Keep them relevant to the user's intent."
+ """
+ if num_sub_queries < 1:
+ return []
+
+ # In production, you'd fetch a prompt from your prompts DB:
+ # Something like:
+ prompt = f"""
+ You are a helpful assistant. The user query is: "{query}"
+ Generate {num_sub_queries} alternative search queries that capture
+ slightly different phrasings or expansions while preserving the core meaning.
+ Return each alternative on its own line.
+ """
+
+ # For a short generation, we can set minimal tokens
+ gen_config = GenerationConfig(
+ model=self.config.app.fast_llm,
+ max_tokens=128,
+ temperature=0.8,
+ stream=False,
+ )
+ response = await self.providers.llm.aget_completion(
+ messages=[{"role": "system", "content": prompt}],
+ generation_config=gen_config,
+ )
+ raw_text = (
+ response.choices[0].message.content.strip()
+ if response.choices[0].message.content is not None
+ else ""
+ )
+
+ # Suppose each line is a sub-query
+ lines = [line.strip() for line in raw_text.split("\n") if line.strip()]
+ return lines[:num_sub_queries]
+
+ def _reciprocal_rank_fusion_chunks(
+ self, list_of_rankings: list[list[ChunkSearchResult]], k: float = 60.0
+ ) -> list[ChunkSearchResult]:
+ """
+ Simple RRF for chunk results.
+ list_of_rankings is something like:
+ [
+ [chunkA, chunkB, chunkC], # sub-query #1, in order
+ [chunkC, chunkD], # sub-query #2, in order
+ ...
+ ]
+
+ We'll produce a dictionary mapping chunk.id -> aggregated_score,
+ then sort descending.
+ """
+ if not list_of_rankings:
+ return []
+
+ # Build a map of chunk_id => final_rff_score
+ score_map: dict[str, float] = {}
+
+ # We also need to store a reference to the chunk object
+ # (the "first" or "best" instance), so we can reconstruct them later
+ chunk_map: dict[str, Any] = {}
+
+ for ranking_list in list_of_rankings:
+ for rank, chunk_result in enumerate(ranking_list, start=1):
+ if not chunk_result.id:
+ # fallback if no chunk_id is present
+ continue
+
+ c_id = chunk_result.id
+ # RRF scoring
+ # score = sum(1 / (k + rank)) for each sub-query ranking
+ # We'll accumulate it.
+ existing_score = score_map.get(str(c_id), 0.0)
+ new_score = existing_score + 1.0 / (k + rank)
+ score_map[str(c_id)] = new_score
+
+ # Keep a reference to chunk
+ if c_id not in chunk_map:
+ chunk_map[str(c_id)] = chunk_result
+
+ # Now sort by final score
+ fused_items = sorted(
+ score_map.items(), key=lambda x: x[1], reverse=True
+ )
+
+ # Rebuild the final list of chunk results with new 'score'
+ fused_chunks = []
+ for c_id, agg_score in fused_items: # type: ignore
+ # copy the chunk
+ c = chunk_map[str(c_id)]
+ # Optionally store the RRF score if you want
+ c.score = agg_score
+ fused_chunks.append(c)
+
+ return fused_chunks
+
+ def _reciprocal_rank_fusion_graphs(
+ self, list_of_rankings: list[list[GraphSearchResult]], k: float = 60.0
+ ) -> list[GraphSearchResult]:
+ """
+ Similar RRF logic but for graph results.
+ """
+ if not list_of_rankings:
+ return []
+
+ score_map: dict[str, float] = {}
+ graph_map = {}
+
+ for ranking_list in list_of_rankings:
+ for rank, g_result in enumerate(ranking_list, start=1):
+ # We'll do a naive ID approach:
+ # If your GraphSearchResult has a unique ID in g_result.content.id or so
+ # we can use that as a key.
+ # If not, you might have to build a key from the content.
+ g_id = None
+ if hasattr(g_result.content, "id"):
+ g_id = str(g_result.content.id)
+ else:
+ # fallback
+ g_id = f"graph_{hash(g_result.content.json())}"
+
+ existing_score = score_map.get(g_id, 0.0)
+ new_score = existing_score + 1.0 / (k + rank)
+ score_map[g_id] = new_score
+
+ if g_id not in graph_map:
+ graph_map[g_id] = g_result
+
+ # Sort descending by aggregated RRF score
+ fused_items = sorted(
+ score_map.items(), key=lambda x: x[1], reverse=True
+ )
+
+ fused_graphs = []
+ for g_id, agg_score in fused_items:
+ g = graph_map[g_id]
+ g.score = agg_score
+ fused_graphs.append(g)
+
+ return fused_graphs
+
+ async def _hyde_search(
+ self, query: str, search_settings: SearchSettings
+ ) -> AggregateSearchResult:
+ """
+ 1) Generate N hypothetical docs via LLM
+ 2) For each doc => embed => parallel chunk search & graph search
+ 3) Merge chunk results => optional re-rank => top K
+ 4) Merge graph results => (optionally re-rank or keep them distinct)
+ """
+ # 1) Generate hypothetical docs
+ hyde_docs = await self._run_hyde_generation(
+ query=query, num_sub_queries=search_settings.num_sub_queries
+ )
+
+ chunk_all = []
+ graph_all = []
+
+ # We'll gather the per-doc searches in parallel
+ tasks = []
+ for hypothetical_text in hyde_docs:
+ tasks.append(
+ asyncio.create_task(
+ self._fanout_chunk_and_graph_search(
+ user_text=query, # The user’s original query
+ alt_text=hypothetical_text, # The hypothetical doc
+ search_settings=search_settings,
+ )
+ )
+ )
+
+ # 2) Wait for them all
+ results_list = await asyncio.gather(*tasks)
+ # each item in results_list is a tuple: (chunks, graphs)
+
+ # Flatten chunk+graph results
+ for c_results, g_results in results_list:
+ chunk_all.extend(c_results)
+ graph_all.extend(g_results)
+
+ # 3) Re-rank chunk results with the original query
+ if chunk_all:
+ chunk_all = await self.providers.completion_embedding.arerank(
+ query=query, # final user query
+ results=chunk_all,
+ limit=int(
+ search_settings.limit * search_settings.num_sub_queries
+ ),
+ # no limit on results - limit=search_settings.limit,
+ )
+
+ # 4) If needed, re-rank graph results or just slice top-K by score
+ if search_settings.include_scores and graph_all:
+ graph_all.sort(key=lambda g: g.score or 0.0, reverse=True)
+ graph_all = (
+ graph_all # no limit on results - [: search_settings.limit]
+ )
+
+ return AggregateSearchResult(
+ chunk_search_results=chunk_all,
+ graph_search_results=graph_all,
+ )
+
+ async def _fanout_chunk_and_graph_search(
+ self,
+ user_text: str,
+ alt_text: str,
+ search_settings: SearchSettings,
+ ) -> tuple[list[ChunkSearchResult], list[GraphSearchResult]]:
+ """
+ 1) embed alt_text (HyDE doc or sub-query, etc.)
+ 2) chunk search + graph search with that embedding
+ """
+ # Precompute the embedding of alt_text
+ vec = await self.providers.completion_embedding.async_get_embedding(
+ alt_text # , EmbeddingPurpose.QUERY
+ )
+
+ # chunk search
+ chunk_results = []
+ if search_settings.chunk_settings.enabled:
+ chunk_results = await self._vector_search_logic(
+ query_text=user_text, # used for text-based stuff & re-ranking
+ search_settings=search_settings,
+ precomputed_vector=vec, # use the alt_text vector for semantic/hybrid
+ )
+
+ # graph search
+ graph_results = []
+ if search_settings.graph_settings.enabled:
+ graph_results = await self._graph_search_logic(
+ query_text=user_text, # or alt_text if you prefer
+ search_settings=search_settings,
+ precomputed_vector=vec,
+ )
+
+ return (chunk_results, graph_results)
+
+ async def _vector_search_logic(
+ self,
+ query_text: str,
+ search_settings: SearchSettings,
+ precomputed_vector: Optional[list[float]] = None,
+ ) -> list[ChunkSearchResult]:
+ """
+ • If precomputed_vector is given, use it for semantic/hybrid search.
+ Otherwise embed query_text ourselves.
+ • Then do fulltext, semantic, or hybrid search.
+ • Optionally re-rank and return results.
+ """
+ if not search_settings.chunk_settings.enabled:
+ return []
+
+ # 1) Possibly embed
+ query_vector = precomputed_vector
+ if query_vector is None and (
+ search_settings.use_semantic_search
+ or search_settings.use_hybrid_search
+ ):
+ query_vector = (
+ await self.providers.completion_embedding.async_get_embedding(
+ query_text # , EmbeddingPurpose.QUERY
+ )
+ )
+
+ # 2) Choose which search to run
+ if (
+ search_settings.use_fulltext_search
+ and search_settings.use_semantic_search
+ ) or search_settings.use_hybrid_search:
+ if query_vector is None:
+ raise ValueError("Hybrid search requires a precomputed vector")
+ raw_results = (
+ await self.providers.database.chunks_handler.hybrid_search(
+ query_vector=query_vector,
+ query_text=query_text,
+ search_settings=search_settings,
+ )
+ )
+ elif search_settings.use_fulltext_search:
+ raw_results = (
+ await self.providers.database.chunks_handler.full_text_search(
+ query_text=query_text,
+ search_settings=search_settings,
+ )
+ )
+ elif search_settings.use_semantic_search:
+ if query_vector is None:
+ raise ValueError(
+ "Semantic search requires a precomputed vector"
+ )
+ raw_results = (
+ await self.providers.database.chunks_handler.semantic_search(
+ query_vector=query_vector,
+ search_settings=search_settings,
+ )
+ )
+ else:
+ raise ValueError(
+ "At least one of use_fulltext_search or use_semantic_search must be True"
+ )
+
+ # 3) Re-rank
+ reranked = await self.providers.completion_embedding.arerank(
+ query=query_text, results=raw_results, limit=search_settings.limit
+ )
+
+ # 4) Possibly augment text or metadata
+ final_results = []
+ for r in reranked:
+ if "title" in r.metadata and search_settings.include_metadatas:
+ title = r.metadata["title"]
+ r.text = f"Document Title: {title}\n\nText: {r.text}"
+ r.metadata["associated_query"] = query_text
+ final_results.append(r)
+
+ return final_results
+
+ async def _graph_search_logic(
+ self,
+ query_text: str,
+ search_settings: SearchSettings,
+ precomputed_vector: Optional[list[float]] = None,
+ ) -> list[GraphSearchResult]:
+ """
+ Mirrors your previous GraphSearch approach:
+ • if precomputed_vector is supplied, use that
+ • otherwise embed query_text
+ • search entities, relationships, communities
+ • return results
+ """
+ results: list[GraphSearchResult] = []
+
+ if not search_settings.graph_settings.enabled:
+ return results
+
+ # 1) Possibly embed
+ query_embedding = precomputed_vector
+ if query_embedding is None:
+ query_embedding = (
+ await self.providers.completion_embedding.async_get_embedding(
+ query_text
+ )
+ )
+
+ base_limit = search_settings.limit
+ graph_limits = search_settings.graph_settings.limits or {}
+
+ # Entity search
+ entity_limit = graph_limits.get("entities", base_limit)
+ entity_cursor = self.providers.database.graphs_handler.graph_search(
+ query_text,
+ search_type="entities",
+ limit=entity_limit,
+ query_embedding=query_embedding,
+ property_names=["name", "description", "id"],
+ filters=search_settings.filters,
+ )
+ async for ent in entity_cursor:
+ score = ent.get("similarity_score")
+ metadata = ent.get("metadata", {})
+ if isinstance(metadata, str):
+ try:
+ metadata = json.loads(metadata)
+ except Exception as e:
+ pass
+
+ results.append(
+ GraphSearchResult(
+ id=ent.get("id", None),
+ content=GraphEntityResult(
+ name=ent.get("name", ""),
+ description=ent.get("description", ""),
+ id=ent.get("id", None),
+ ),
+ result_type=GraphSearchResultType.ENTITY,
+ score=score if search_settings.include_scores else None,
+ metadata=(
+ {
+ **(metadata or {}),
+ "associated_query": query_text,
+ }
+ if search_settings.include_metadatas
+ else {}
+ ),
+ )
+ )
+
+ # Relationship search
+ rel_limit = graph_limits.get("relationships", base_limit)
+ rel_cursor = self.providers.database.graphs_handler.graph_search(
+ query_text,
+ search_type="relationships",
+ limit=rel_limit,
+ query_embedding=query_embedding,
+ property_names=[
+ "id",
+ "subject",
+ "predicate",
+ "object",
+ "description",
+ "subject_id",
+ "object_id",
+ ],
+ filters=search_settings.filters,
+ )
+ async for rel in rel_cursor:
+ score = rel.get("similarity_score")
+ metadata = rel.get("metadata", {})
+ if isinstance(metadata, str):
+ try:
+ metadata = json.loads(metadata)
+ except Exception as e:
+ pass
+
+ results.append(
+ GraphSearchResult(
+ id=ent.get("id", None),
+ content=GraphRelationshipResult(
+ id=rel.get("id", None),
+ subject=rel.get("subject", ""),
+ predicate=rel.get("predicate", ""),
+ object=rel.get("object", ""),
+ subject_id=rel.get("subject_id", None),
+ object_id=rel.get("object_id", None),
+ description=rel.get("description", ""),
+ ),
+ result_type=GraphSearchResultType.RELATIONSHIP,
+ score=score if search_settings.include_scores else None,
+ metadata=(
+ {
+ **(metadata or {}),
+ "associated_query": query_text,
+ }
+ if search_settings.include_metadatas
+ else {}
+ ),
+ )
+ )
+
+ # Community search
+ comm_limit = graph_limits.get("communities", base_limit)
+ comm_cursor = self.providers.database.graphs_handler.graph_search(
+ query_text,
+ search_type="communities",
+ limit=comm_limit,
+ query_embedding=query_embedding,
+ property_names=[
+ "id",
+ "name",
+ "summary",
+ ],
+ filters=search_settings.filters,
+ )
+ async for comm in comm_cursor:
+ score = comm.get("similarity_score")
+ metadata = comm.get("metadata", {})
+ if isinstance(metadata, str):
+ try:
+ metadata = json.loads(metadata)
+ except Exception as e:
+ pass
+
+ results.append(
+ GraphSearchResult(
+ id=ent.get("id", None),
+ content=GraphCommunityResult(
+ id=comm.get("id", None),
+ name=comm.get("name", ""),
+ summary=comm.get("summary", ""),
+ ),
+ result_type=GraphSearchResultType.COMMUNITY,
+ score=score if search_settings.include_scores else None,
+ metadata=(
+ {
+ **(metadata or {}),
+ "associated_query": query_text,
+ }
+ if search_settings.include_metadatas
+ else {}
+ ),
+ )
+ )
+
+ return results
+
+ async def _run_hyde_generation(
+ self,
+ query: str,
+ num_sub_queries: int = 2,
+ ) -> list[str]:
+ """
+ Calls the LLM with a 'HyDE' style prompt to produce multiple
+ hypothetical documents/answers, one per line or separated by blank lines.
+ """
+ # Retrieve the prompt template from your database or config:
+ # e.g. your "hyde" prompt has placeholders: {message}, {num_outputs}
+ hyde_template = (
+ await self.providers.database.prompts_handler.get_cached_prompt(
+ prompt_name="hyde",
+ inputs={"message": query, "num_outputs": num_sub_queries},
+ )
+ )
+
+ # Now call the LLM with that as the system or user prompt:
+ completion_config = GenerationConfig(
+ model=self.config.app.fast_llm, # or whichever short/cheap model
+ max_tokens=512,
+ temperature=0.7,
+ stream=False,
+ )
+
+ response = await self.providers.llm.aget_completion(
+ messages=[{"role": "system", "content": hyde_template}],
+ generation_config=completion_config,
+ )
+
+ # Suppose the LLM returns something like:
+ #
+ # "Doc1. Some made up text.\n\nDoc2. Another made up text.\n\n"
+ #
+ # So we split by double-newline or some pattern:
+ raw_text = response.choices[0].message.content
+ return [
+ chunk.strip()
+ for chunk in (raw_text or "").split("\n\n")
+ if chunk.strip()
+ ]
+
+ async def search_documents(
+ self,
+ query: str,
+ settings: SearchSettings,
+ query_embedding: Optional[list[float]] = None,
+ ) -> list[DocumentResponse]:
+ if query_embedding is None:
+ query_embedding = (
+ await self.providers.completion_embedding.async_get_embedding(
+ query
+ )
+ )
+ result = (
+ await self.providers.database.documents_handler.search_documents(
+ query_text=query,
+ settings=settings,
+ query_embedding=query_embedding,
+ )
+ )
+ return result
+
+ async def completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ *args,
+ **kwargs,
+ ):
+ return await self.providers.llm.aget_completion(
+ [message.to_dict() for message in messages], # type: ignore
+ generation_config,
+ *args,
+ **kwargs,
+ )
+
+ async def embedding(
+ self,
+ text: str,
+ ):
+ return await self.providers.completion_embedding.async_get_embedding(
+ text=text
+ )
+
+ async def rag(
+ self,
+ query: str,
+ rag_generation_config: GenerationConfig,
+ search_settings: SearchSettings = SearchSettings(),
+ system_prompt_name: str | None = None,
+ task_prompt_name: str | None = None,
+ include_web_search: bool = False,
+ **kwargs,
+ ) -> Any:
+ """
+ A single RAG method that can do EITHER a one-shot synchronous RAG or
+ streaming SSE-based RAG, depending on rag_generation_config.stream.
+
+ 1) Perform aggregator search => context
+ 2) Build system+task prompts => messages
+ 3) If not streaming => normal LLM call => return RAGResponse
+ 4) If streaming => return an async generator of SSE lines
+ """
+ # 1) Possibly fix up any UUID filters in search_settings
+ for f, val in list(search_settings.filters.items()):
+ if isinstance(val, UUID):
+ search_settings.filters[f] = str(val)
+
+ try:
+ # 2) Perform search => aggregated_results
+ aggregated_results = await self.search(query, search_settings)
+ # 3) Optionally add web search results if flag is enabled
+ if include_web_search:
+ web_results = await self._perform_web_search(query)
+ # Merge web search results with existing aggregated results
+ if web_results and web_results.web_search_results:
+ if not aggregated_results.web_search_results:
+ aggregated_results.web_search_results = (
+ web_results.web_search_results
+ )
+ else:
+ aggregated_results.web_search_results.extend(
+ web_results.web_search_results
+ )
+ # 3) Build context from aggregator
+ collector = SearchResultsCollector()
+ collector.add_aggregate_result(aggregated_results)
+ context_str = format_search_results_for_llm(
+ aggregated_results, collector
+ )
+
+ # 4) Prepare system+task messages
+ system_prompt_name = system_prompt_name or "system"
+ task_prompt_name = task_prompt_name or "rag"
+ task_prompt = kwargs.get("task_prompt")
+
+ messages = await self.providers.database.prompts_handler.get_message_payload(
+ system_prompt_name=system_prompt_name,
+ task_prompt_name=task_prompt_name,
+ task_inputs={"query": query, "context": context_str},
+ task_prompt=task_prompt,
+ )
+
+ # 5) Check streaming vs. non-streaming
+ if not rag_generation_config.stream:
+ # ========== Non-Streaming Logic ==========
+ response = await self.providers.llm.aget_completion(
+ messages=messages,
+ generation_config=rag_generation_config,
+ )
+ llm_text = response.choices[0].message.content
+
+ # (a) Extract short-ID references from final text
+ raw_sids = extract_citations(llm_text or "")
+
+ # (b) Possibly prune large content out of metadata
+ metadata = response.dict()
+ if "choices" in metadata and len(metadata["choices"]) > 0:
+ metadata["choices"][0]["message"].pop("content", None)
+
+ # (c) Build final RAGResponse
+ rag_resp = RAGResponse(
+ generated_answer=llm_text or "",
+ search_results=aggregated_results,
+ citations=[
+ Citation(
+ id=f"{sid}",
+ object="citation",
+ payload=dump_obj( # type: ignore
+ self._find_item_by_shortid(sid, collector)
+ ),
+ )
+ for sid in raw_sids
+ ],
+ metadata=metadata,
+ completion=llm_text or "",
+ )
+ return rag_resp
+
+ else:
+ # ========== Streaming SSE Logic ==========
+ async def sse_generator() -> AsyncGenerator[str, None]:
+ # 1) Emit search results via SSEFormatter
+ async for line in SSEFormatter.yield_search_results_event(
+ aggregated_results
+ ):
+ yield line
+
+ # Initialize citation tracker to manage citation state
+ citation_tracker = CitationTracker()
+
+ # Store citation payloads by ID for reuse
+ citation_payloads = {}
+
+ partial_text_buffer = ""
+
+ # Begin streaming from the LLM
+ msg_stream = self.providers.llm.aget_completion_stream(
+ messages=messages,
+ generation_config=rag_generation_config,
+ )
+
+ try:
+ async for chunk in msg_stream:
+ delta = chunk.choices[0].delta
+ finish_reason = chunk.choices[0].finish_reason
+ # if delta.thinking:
+ # check if delta has `thinking` attribute
+
+ if hasattr(delta, "thinking") and delta.thinking:
+ # Emit SSE "thinking" event
+ async for (
+ line
+ ) in SSEFormatter.yield_thinking_event(
+ delta.thinking
+ ):
+ yield line
+
+ if delta.content:
+ # (b) Emit SSE "message" event for this chunk of text
+ async for (
+ line
+ ) in SSEFormatter.yield_message_event(
+ delta.content
+ ):
+ yield line
+
+ # Accumulate new text
+ partial_text_buffer += delta.content
+
+ # (a) Extract citations from updated buffer
+ # For each *new* short ID, emit an SSE "citation" event
+ # Find new citation spans in the accumulated text
+ new_citation_spans = find_new_citation_spans(
+ partial_text_buffer, citation_tracker
+ )
+
+ # Process each new citation span
+ for cid, spans in new_citation_spans.items():
+ for span in spans:
+ # Check if this is the first time we've seen this citation ID
+ is_new_citation = (
+ citation_tracker.is_new_citation(
+ cid
+ )
+ )
+
+ # Get payload if it's a new citation
+ payload = None
+ if is_new_citation:
+ source_obj = (
+ self._find_item_by_shortid(
+ cid, collector
+ )
+ )
+ if source_obj:
+ # Store payload for reuse
+ payload = dump_obj(source_obj)
+ citation_payloads[cid] = (
+ payload
+ )
+
+ # Create citation event payload
+ citation_data = {
+ "id": cid,
+ "object": "citation",
+ "is_new": is_new_citation,
+ "span": {
+ "start": span[0],
+ "end": span[1],
+ },
+ }
+
+ # Only include full payload for new citations
+ if is_new_citation and payload:
+ citation_data["payload"] = payload
+
+ # Emit the citation event
+ async for (
+ line
+ ) in SSEFormatter.yield_citation_event(
+ citation_data
+ ):
+ yield line
+
+ # If the LLM signals it’s done
+ if finish_reason == "stop":
+ # Prepare consolidated citations for final answer event
+ consolidated_citations = []
+ # Group citations by ID with all their spans
+ for (
+ cid,
+ spans,
+ ) in citation_tracker.get_all_spans().items():
+ if cid in citation_payloads:
+ consolidated_citations.append(
+ {
+ "id": cid,
+ "object": "citation",
+ "spans": [
+ {
+ "start": s[0],
+ "end": s[1],
+ }
+ for s in spans
+ ],
+ "payload": citation_payloads[
+ cid
+ ],
+ }
+ )
+
+ # (c) Emit final answer + all collected citations
+ final_answer_evt = {
+ "id": "msg_final",
+ "object": "rag.final_answer",
+ "generated_answer": partial_text_buffer,
+ "citations": consolidated_citations,
+ }
+ async for (
+ line
+ ) in SSEFormatter.yield_final_answer_event(
+ final_answer_evt
+ ):
+ yield line
+
+ # (d) Signal the end of the SSE stream
+ yield SSEFormatter.yield_done_event()
+ break
+
+ except Exception as e:
+ logger.error(f"Error streaming LLM in rag: {e}")
+ # Optionally yield an SSE "error" event or handle differently
+ raise
+
+ return sse_generator()
+
+ except Exception as e:
+ logger.exception(f"Error in RAG pipeline: {e}")
+ if "NoneType" in str(e):
+ raise HTTPException(
+ status_code=502,
+ detail="Server not reachable or returned an invalid response",
+ ) from e
+ raise HTTPException(
+ status_code=500,
+ detail=f"Internal RAG Error - {str(e)}",
+ ) from e
+
+ def _find_item_by_shortid(
+ self, sid: str, collector: SearchResultsCollector
+ ) -> Optional[tuple[str, Any, int]]:
+ """
+ Example helper that tries to match aggregator items by short ID,
+ meaning result_obj.id starts with sid.
+ """
+ for source_type, result_obj in collector.get_all_results():
+ # if the aggregator item has an 'id' attribute
+ if getattr(result_obj, "id", None) is not None:
+ full_id_str = str(result_obj.id)
+ if full_id_str.startswith(sid):
+ if source_type == "chunk":
+ return (
+ result_obj.as_dict()
+ ) # (source_type, result_obj.as_dict())
+ else:
+ return result_obj # (source_type, result_obj)
+ return None
+
+ async def agent(
+ self,
+ rag_generation_config: GenerationConfig,
+ rag_tools: Optional[list[str]] = None,
+ tools: Optional[list[str]] = None, # backward compatibility
+ search_settings: SearchSettings = SearchSettings(),
+ task_prompt: Optional[str] = None,
+ include_title_if_available: Optional[bool] = False,
+ conversation_id: Optional[UUID] = None,
+ message: Optional[Message] = None,
+ messages: Optional[list[Message]] = None,
+ use_system_context: bool = False,
+ max_tool_context_length: int = 32_768,
+ research_tools: Optional[list[str]] = None,
+ research_generation_config: Optional[GenerationConfig] = None,
+ needs_initial_conversation_name: Optional[bool] = None,
+ mode: Optional[Literal["rag", "research"]] = "rag",
+ ):
+ """
+ Engage with an intelligent agent for information retrieval, analysis, and research.
+
+ Args:
+ rag_generation_config: Configuration for RAG mode generation
+ search_settings: Search configuration for retrieving context
+ task_prompt: Optional custom prompt override
+ include_title_if_available: Whether to include document titles
+ conversation_id: Optional conversation ID for continuity
+ message: Current message to process
+ messages: List of messages (deprecated)
+ use_system_context: Whether to use extended prompt
+ max_tool_context_length: Maximum context length for tools
+ rag_tools: List of tools for RAG mode
+ research_tools: List of tools for Research mode
+ research_generation_config: Configuration for Research mode generation
+ mode: Either "rag" or "research"
+
+ Returns:
+ Agent response with messages and conversation ID
+ """
+ try:
+ # Validate message inputs
+ if message and messages:
+ raise R2RException(
+ status_code=400,
+ message="Only one of message or messages should be provided",
+ )
+
+ if not message and not messages:
+ raise R2RException(
+ status_code=400,
+ message="Either message or messages should be provided",
+ )
+
+ # Ensure 'message' is a Message instance
+ if message and not isinstance(message, Message):
+ if isinstance(message, dict):
+ message = Message.from_dict(message)
+ else:
+ raise R2RException(
+ status_code=400,
+ message="""
+ Invalid message format. The expected format contains:
+ role: MessageType | 'system' | 'user' | 'assistant' | 'function'
+ content: Optional[str]
+ name: Optional[str]
+ function_call: Optional[dict[str, Any]]
+ tool_calls: Optional[list[dict[str, Any]]]
+ """,
+ )
+
+ # Ensure 'messages' is a list of Message instances
+ if messages:
+ processed_messages = []
+ for msg in messages:
+ if isinstance(msg, Message):
+ processed_messages.append(msg)
+ elif hasattr(msg, "dict"):
+ processed_messages.append(
+ Message.from_dict(msg.dict())
+ )
+ elif isinstance(msg, dict):
+ processed_messages.append(Message.from_dict(msg))
+ else:
+ processed_messages.append(Message.from_dict(str(msg)))
+ messages = processed_messages
+ else:
+ messages = []
+
+ # Validate and process mode-specific configurations
+ if mode == "rag" and research_tools:
+ logger.warning(
+ "research_tools provided but mode is 'rag'. These tools will be ignored."
+ )
+ research_tools = None
+
+ # Determine effective generation config based on mode
+ effective_generation_config = rag_generation_config
+ if mode == "research" and research_generation_config:
+ effective_generation_config = research_generation_config
+
+ # Set appropriate LLM model based on mode if not explicitly specified
+ if "model" not in effective_generation_config.__fields_set__:
+ if mode == "rag":
+ effective_generation_config.model = (
+ self.config.app.quality_llm
+ )
+ elif mode == "research":
+ effective_generation_config.model = (
+ self.config.app.planning_llm
+ )
+
+ # Transform UUID filters to strings
+ for filter_key, value in search_settings.filters.items():
+ if isinstance(value, UUID):
+ search_settings.filters[filter_key] = str(value)
+
+ # Process conversation data
+ ids = []
+ if conversation_id: # Fetch the existing conversation
+ try:
+ conversation_messages = await self.providers.database.conversations_handler.get_conversation(
+ conversation_id=conversation_id,
+ )
+ if needs_initial_conversation_name is None:
+ overview = await self.providers.database.conversations_handler.get_conversations_overview(
+ offset=0,
+ limit=1,
+ conversation_ids=[conversation_id],
+ )
+ if overview.get("total_entries", 0) > 0:
+ needs_initial_conversation_name = (
+ overview.get("results")[0].get("name") is None # type: ignore
+ )
+ except Exception as e:
+ logger.error(f"Error fetching conversation: {str(e)}")
+
+ if conversation_messages is not None:
+ messages_from_conversation: list[Message] = []
+ for message_response in conversation_messages:
+ if isinstance(message_response, MessageResponse):
+ messages_from_conversation.append(
+ message_response.message
+ )
+ ids.append(message_response.id)
+ else:
+ logger.warning(
+ f"Unexpected type in conversation found: {type(message_response)}\n{message_response}"
+ )
+ messages = messages_from_conversation + messages
+ else: # Create new conversation
+ conversation_response = await self.providers.database.conversations_handler.create_conversation()
+ conversation_id = conversation_response.id
+ needs_initial_conversation_name = True
+
+ if message:
+ messages.append(message)
+
+ if not messages:
+ raise R2RException(
+ status_code=400,
+ message="No messages to process",
+ )
+
+ current_message = messages[-1]
+ logger.debug(
+ f"Running the agent with conversation_id = {conversation_id} and message = {current_message}"
+ )
+
+ # Save the new message to the conversation
+ parent_id = ids[-1] if ids else None
+ message_response = await self.providers.database.conversations_handler.add_message(
+ conversation_id=conversation_id,
+ content=current_message,
+ parent_id=parent_id,
+ )
+
+ message_id = (
+ message_response.id if message_response is not None else None
+ )
+
+ # Extract filter information from search settings
+ filter_user_id, filter_collection_ids = (
+ self._parse_user_and_collection_filters(
+ search_settings.filters
+ )
+ )
+
+ # Validate system instruction configuration
+ if use_system_context and task_prompt:
+ raise R2RException(
+ status_code=400,
+ message="Both use_system_context and task_prompt cannot be True at the same time",
+ )
+
+ # Build the system instruction
+ if task_prompt:
+ system_instruction = task_prompt
+ else:
+ system_instruction = (
+ await self._build_aware_system_instruction(
+ max_tool_context_length=max_tool_context_length,
+ filter_user_id=filter_user_id,
+ filter_collection_ids=filter_collection_ids,
+ model=effective_generation_config.model,
+ use_system_context=use_system_context,
+ mode=mode,
+ )
+ )
+
+ # Configure agent with appropriate tools
+ agent_config = deepcopy(self.config.agent)
+ if mode == "rag":
+ # Use provided RAG tools or default from config
+ agent_config.rag_tools = (
+ rag_tools or tools or self.config.agent.rag_tools
+ )
+ else: # research mode
+ # Use provided Research tools or default from config
+ agent_config.research_tools = (
+ research_tools or tools or self.config.agent.research_tools
+ )
+
+ # Create the agent using our factory
+ mode = mode or "rag"
+
+ for msg in messages:
+ if msg.content is None:
+ msg.content = ""
+
+ agent = AgentFactory.create_agent(
+ mode=mode,
+ database_provider=self.providers.database,
+ llm_provider=self.providers.llm,
+ config=agent_config,
+ search_settings=search_settings,
+ generation_config=effective_generation_config,
+ app_config=self.config.app,
+ knowledge_search_method=self.search,
+ content_method=self.get_context,
+ file_search_method=self.search_documents,
+ max_tool_context_length=max_tool_context_length,
+ rag_tools=rag_tools,
+ research_tools=research_tools,
+ tools=tools, # Backward compatibility
+ )
+
+ # Handle streaming vs. non-streaming response
+ if effective_generation_config.stream:
+
+ async def stream_response():
+ try:
+ async for chunk in agent.arun(
+ messages=messages,
+ system_instruction=system_instruction,
+ include_title_if_available=include_title_if_available,
+ ):
+ yield chunk
+ except Exception as e:
+ logger.error(f"Error streaming agent output: {e}")
+ raise e
+ finally:
+ # Persist conversation data
+ msgs = [
+ msg.to_dict()
+ for msg in agent.conversation.messages
+ ]
+ input_tokens = num_tokens_from_messages(msgs[:-1])
+ output_tokens = num_tokens_from_messages([msgs[-1]])
+ await self.providers.database.conversations_handler.add_message(
+ conversation_id=conversation_id,
+ content=agent.conversation.messages[-1],
+ parent_id=message_id,
+ metadata={
+ "input_tokens": input_tokens,
+ "output_tokens": output_tokens,
+ },
+ )
+
+ # Generate conversation name if needed
+ if needs_initial_conversation_name:
+ try:
+ prompt = f"Generate a succinct name (3-6 words) for this conversation, given the first input mesasge here = {str(message.to_dict())}"
+ conversation_name = (
+ (
+ await self.providers.llm.aget_completion(
+ [
+ {
+ "role": "system",
+ "content": prompt,
+ }
+ ],
+ GenerationConfig(
+ model=self.config.app.fast_llm
+ ),
+ )
+ )
+ .choices[0]
+ .message.content
+ )
+ await self.providers.database.conversations_handler.update_conversation(
+ conversation_id=conversation_id,
+ name=conversation_name,
+ )
+ except Exception as e:
+ logger.error(
+ f"Error generating conversation name: {e}"
+ )
+
+ return stream_response()
+ else:
+ for idx, msg in enumerate(messages):
+ if msg.content is None:
+ if (
+ hasattr(msg, "structured_content")
+ and msg.structured_content
+ ):
+ messages[idx].content = ""
+ else:
+ messages[idx].content = ""
+
+ # Non-streaming path
+ results = await agent.arun(
+ messages=messages,
+ system_instruction=system_instruction,
+ include_title_if_available=include_title_if_available,
+ )
+
+ # Process the agent results
+ if isinstance(results[-1], dict):
+ if results[-1].get("content") is None:
+ results[-1]["content"] = ""
+ assistant_message = Message(**results[-1])
+ elif isinstance(results[-1], Message):
+ assistant_message = results[-1]
+ if assistant_message.content is None:
+ assistant_message.content = ""
+ else:
+ assistant_message = Message(
+ role="assistant", content=str(results[-1])
+ )
+
+ # Get search results collector for citations
+ if hasattr(agent, "search_results_collector"):
+ collector = agent.search_results_collector
+ else:
+ collector = SearchResultsCollector()
+
+ # Extract content from the message
+ structured_content = assistant_message.structured_content
+ structured_content = (
+ structured_content[-1].get("text")
+ if structured_content
+ else None
+ )
+ raw_text = (
+ assistant_message.content or structured_content or ""
+ )
+ # Process citations
+ short_ids = extract_citations(raw_text or "")
+ final_citations = []
+ for sid in short_ids:
+ obj = collector.find_by_short_id(sid)
+ final_citations.append(
+ {
+ "id": sid,
+ "object": "citation",
+ "payload": dump_obj(obj) if obj else None,
+ }
+ )
+
+ # Persist in conversation DB
+ await (
+ self.providers.database.conversations_handler.add_message(
+ conversation_id=conversation_id,
+ content=assistant_message,
+ parent_id=message_id,
+ metadata={
+ "citations": final_citations,
+ "aggregated_search_result": json.dumps(
+ dump_collector(collector)
+ ),
+ },
+ )
+ )
+
+ # Generate conversation name if needed
+ if needs_initial_conversation_name:
+ conversation_name = None
+ try:
+ prompt = f"Generate a succinct name (3-6 words) for this conversation, given the first input mesasge here = {str(message.to_dict() if message else {})}"
+ conversation_name = (
+ (
+ await self.providers.llm.aget_completion(
+ [{"role": "system", "content": prompt}],
+ GenerationConfig(
+ model=self.config.app.fast_llm
+ ),
+ )
+ )
+ .choices[0]
+ .message.content
+ )
+ except Exception as e:
+ pass
+ finally:
+ await self.providers.database.conversations_handler.update_conversation(
+ conversation_id=conversation_id,
+ name=conversation_name or "",
+ )
+
+ tool_calls = []
+ if hasattr(agent, "tool_calls"):
+ if agent.tool_calls is not None:
+ tool_calls = agent.tool_calls
+ else:
+ logger.warning(
+ "agent.tool_calls is None, using empty list instead"
+ )
+ # Return the final response
+ return {
+ "messages": [
+ Message(
+ role="assistant",
+ content=assistant_message.content
+ or structured_content
+ or "",
+ metadata={
+ "citations": final_citations,
+ "tool_calls": tool_calls,
+ "aggregated_search_result": json.dumps(
+ dump_collector(collector)
+ ),
+ },
+ )
+ ],
+ "conversation_id": str(conversation_id),
+ }
+
+ except Exception as e:
+ logger.error(f"Error in agent response: {str(e)}")
+ if "NoneType" in str(e):
+ raise HTTPException(
+ status_code=502,
+ detail="Server not reachable or returned an invalid response",
+ ) from e
+ raise HTTPException(
+ status_code=500,
+ detail=f"Internal Server Error - {str(e)}",
+ ) from e
+
+ async def get_context(
+ self,
+ filters: dict[str, Any],
+ options: dict[str, Any],
+ ) -> list[dict[str, Any]]:
+ """
+ Return an ordered list of documents (with minimal overview fields),
+ plus all associated chunks in ascending chunk order.
+
+ Only the filters: owner_id, collection_ids, and document_id
+ are supported. If any other filter or operator is passed in,
+ we raise an error.
+
+ Args:
+ filters: A dictionary describing the allowed filters
+ (owner_id, collection_ids, document_id).
+ options: A dictionary with extra options, e.g. include_summary_embedding
+ or any custom flags for additional logic.
+
+ Returns:
+ A list of dicts, where each dict has:
+ {
+ "document": <DocumentResponse>,
+ "chunks": [ <chunk0>, <chunk1>, ... ]
+ }
+ """
+ # 2. Fetch matching documents
+ matching_docs = await self.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=-1,
+ filters=filters,
+ include_summary_embedding=options.get(
+ "include_summary_embedding", False
+ ),
+ )
+
+ if not matching_docs["results"]:
+ return []
+
+ # 3. For each document, fetch associated chunks in ascending chunk order
+ results = []
+ for doc_response in matching_docs["results"]:
+ doc_id = doc_response.id
+ chunk_data = await self.providers.database.chunks_handler.list_document_chunks(
+ document_id=doc_id,
+ offset=0,
+ limit=-1, # get all chunks
+ include_vectors=False,
+ )
+ chunks = chunk_data["results"] # already sorted by chunk_order
+ doc_response.chunks = chunks
+ # 4. Build a returned structure that includes doc + chunks
+ results.append(doc_response.model_dump())
+
+ return results
+
+ def _parse_user_and_collection_filters(
+ self,
+ filters: dict[str, Any],
+ ):
+ ### TODO - Come up with smarter way to extract owner / collection ids for non-admin
+ filter_starts_with_and = filters.get("$and")
+ filter_starts_with_or = filters.get("$or")
+ if filter_starts_with_and:
+ try:
+ filter_starts_with_and_then_or = filter_starts_with_and[0][
+ "$or"
+ ]
+
+ user_id = filter_starts_with_and_then_or[0]["owner_id"]["$eq"]
+ collection_ids = [
+ UUID(ele)
+ for ele in filter_starts_with_and_then_or[1][
+ "collection_ids"
+ ]["$overlap"]
+ ]
+ return user_id, [str(ele) for ele in collection_ids]
+ except Exception as e:
+ logger.error(
+ f"Error: {e}.\n\n While"
+ + """ parsing filters: expected format {'$or': [{'owner_id': {'$eq': 'uuid-string-here'}, 'collection_ids': {'$overlap': ['uuid-of-some-collection']}}]}, if you are a superuser then this error can be ignored."""
+ )
+ return None, []
+ elif filter_starts_with_or:
+ try:
+ user_id = filter_starts_with_or[0]["owner_id"]["$eq"]
+ collection_ids = [
+ UUID(ele)
+ for ele in filter_starts_with_or[1]["collection_ids"][
+ "$overlap"
+ ]
+ ]
+ return user_id, [str(ele) for ele in collection_ids]
+ except Exception as e:
+ logger.error(
+ """Error parsing filters: expected format {'$or': [{'owner_id': {'$eq': 'uuid-string-here'}, 'collection_ids': {'$overlap': ['uuid-of-some-collection']}}]}, if you are a superuser then this error can be ignored."""
+ )
+ return None, []
+ else:
+ # Admin user
+ return None, []
+
+ async def _build_documents_context(
+ self,
+ filter_user_id: Optional[UUID] = None,
+ max_summary_length: int = 128,
+ limit: int = 25,
+ reverse_order: bool = True,
+ ) -> str:
+ """
+ Fetches documents matching the given filters and returns a formatted string
+ enumerating them.
+ """
+ # We only want up to `limit` documents for brevity
+ docs_data = await self.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=limit,
+ filter_user_ids=[filter_user_id] if filter_user_id else None,
+ include_summary_embedding=False,
+ sort_order="DESC" if reverse_order else "ASC",
+ )
+
+ found_max = False
+ if len(docs_data["results"]) == limit:
+ found_max = True
+
+ docs = docs_data["results"]
+ if not docs:
+ return "No documents found."
+
+ lines = []
+ for i, doc in enumerate(docs, start=1):
+ if (
+ not doc.summary
+ or doc.ingestion_status != IngestionStatus.SUCCESS
+ ):
+ lines.append(
+ f"[{i}] Title: {doc.title}, Summary: (Summary not available), Status:{doc.ingestion_status} ID: {doc.id}"
+ )
+ continue
+
+ # Build a line referencing the doc
+ title = doc.title or "(Untitled Document)"
+ lines.append(
+ f"[{i}] Title: {title}, Summary: {(doc.summary[0:max_summary_length] + ('...' if len(doc.summary) > max_summary_length else ''),)}, Total Tokens: {doc.total_tokens}, ID: {doc.id}"
+ )
+ if found_max:
+ lines.append(
+ f"Note: Displaying only the first {limit} documents. Use a filter to narrow down the search if more documents are required."
+ )
+
+ return "\n".join(lines)
+
    async def _build_aware_system_instruction(
        self,
        max_tool_context_length: int = 10_000,
        filter_user_id: Optional[UUID] = None,
        filter_collection_ids: Optional[list[UUID]] = None,
        model: Optional[str] = None,
        use_system_context: bool = False,
        mode: Optional[str] = "rag",
    ) -> str:
        """
        Select and render the agent's system prompt.

        High-level flow:
        1) pick a prompt name from mode and `use_system_context`
        2) for research mode, return the static research prompt immediately
        3) for RAG with system context, build the documents context and
           render the dynamic prompt with it

        Note:
            `filter_collection_ids` is accepted but not used in this body —
            presumably reserved for a collections context; confirm with
            callers before removing.
        """
        date_str = str(datetime.now().strftime("%m/%d/%Y"))

        # "dynamic_rag_agent" // "static_rag_agent"

        if mode == "rag":
            # Dynamic (context-aware) vs. static prompt per caller request.
            prompt_name = (
                self.config.agent.rag_agent_dynamic_prompt
                if use_system_context
                else self.config.agent.rag_rag_agent_static_prompt
            )
        else:
            # Research mode: single static prompt, rendered with the date
            # only — returns early, skipping the document-context build.
            prompt_name = "static_research_agent"
            return await self.providers.database.prompts_handler.get_cached_prompt(
                # We use custom tooling and a custom agent to handle gemini models
                prompt_name,
                inputs={
                    "date": date_str,
                },
            )

        # Deepseek models get an XML-tooling variant of the prompt.
        if model is not None and ("deepseek" in model):
            prompt_name = f"{prompt_name}_xml_tooling"

        if use_system_context:
            # Summaries of the user's documents are injected into the prompt.
            doc_context_str = await self._build_documents_context(
                filter_user_id=filter_user_id,
            )
            logger.debug(f"Loading prompt {prompt_name}")
            # Now fetch the prompt from the database prompts handler.
            # This relies on the dynamic prompt existing with placeholders:
            # date, max_tool_context_length, document_context.
            system_prompt = await self.providers.database.prompts_handler.get_cached_prompt(
                # We use custom tooling and a custom agent to handle gemini models
                prompt_name,
                inputs={
                    "date": date_str,
                    "max_tool_context_length": max_tool_context_length,
                    "document_context": doc_context_str,
                },
            )
        else:
            # Static prompt only needs the current date.
            system_prompt = await self.providers.database.prompts_handler.get_cached_prompt(
                prompt_name,
                inputs={
                    "date": date_str,
                },
            )
        logger.debug(f"Running agent with system prompt = {system_prompt}")
        return system_prompt
+
+ async def _perform_web_search(
+ self,
+ query: str,
+ search_settings: SearchSettings = SearchSettings(),
+ ) -> AggregateSearchResult:
+ """
+ Perform a web search using an external search engine API (Serper).
+
+ Args:
+ query: The search query string
+ search_settings: Optional search settings to customize the search
+
+ Returns:
+ AggregateSearchResult containing web search results
+ """
+ try:
+ # Import the Serper client here to avoid circular imports
+ from core.utils.serper import SerperClient
+
+ # Initialize the Serper client
+ serper_client = SerperClient()
+
+ # Perform the raw search using Serper API
+ raw_results = serper_client.get_raw(query)
+
+ # Process the raw results into a WebSearchResult object
+ web_response = WebSearchResult.from_serper_results(raw_results)
+
+ # Create an AggregateSearchResult with the web search results
+ agg_result = AggregateSearchResult(
+ chunk_search_results=None,
+ graph_search_results=None,
+ web_search_results=web_response.organic_results,
+ )
+
+ # Log the search for monitoring purposes
+ logger.debug(f"Web search completed for query: {query}")
+ logger.debug(
+ f"Found {len(web_response.organic_results)} web results"
+ )
+
+ return agg_result
+
+ except Exception as e:
+ logger.error(f"Error performing web search: {str(e)}")
+ # Return empty results rather than failing completely
+ return AggregateSearchResult(
+ chunk_search_results=None,
+ graph_search_results=None,
+ web_search_results=[],
+ )
+
+
class RetrievalServiceAdapter:
    """Static helpers that (de)serialize retrieval-service call arguments so
    they can cross a JSON orchestration boundary and be reconstructed on the
    far side."""

    @staticmethod
    def _parse_user_data(user_data):
        """Decode a User from a dict or a JSON string.

        Raises:
            ValueError: when a string payload is not valid JSON.
        """
        if isinstance(user_data, str):
            try:
                user_data = json.loads(user_data)
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Invalid user data format: {user_data}"
                ) from e
        return User.from_dict(user_data)

    @staticmethod
    def prepare_search_input(
        query: str,
        search_settings: SearchSettings,
        user: User,
    ) -> dict:
        """Serialize search-call arguments into a plain dict."""
        payload: dict = {"query": query}
        payload["search_settings"] = search_settings.to_dict()
        payload["user"] = user.to_dict()
        return payload

    @staticmethod
    def parse_search_input(data: dict):
        """Inverse of `prepare_search_input`: rebuild rich argument types."""
        settings = SearchSettings.from_dict(data["search_settings"])
        user = RetrievalServiceAdapter._parse_user_data(data["user"])
        return {
            "query": data["query"],
            "search_settings": settings,
            "user": user,
        }

    @staticmethod
    def prepare_rag_input(
        query: str,
        search_settings: SearchSettings,
        rag_generation_config: GenerationConfig,
        task_prompt: Optional[str],
        include_web_search: bool,
        user: User,
    ) -> dict:
        """Serialize RAG-call arguments into a plain dict."""
        payload: dict = {"query": query}
        payload["search_settings"] = search_settings.to_dict()
        payload["rag_generation_config"] = rag_generation_config.to_dict()
        payload["task_prompt"] = task_prompt
        payload["include_web_search"] = include_web_search
        payload["user"] = user.to_dict()
        return payload

    @staticmethod
    def parse_rag_input(data: dict):
        """Inverse of `prepare_rag_input`: rebuild rich argument types."""
        settings = SearchSettings.from_dict(data["search_settings"])
        gen_config = GenerationConfig.from_dict(data["rag_generation_config"])
        user = RetrievalServiceAdapter._parse_user_data(data["user"])
        return {
            "query": data["query"],
            "search_settings": settings,
            "rag_generation_config": gen_config,
            "task_prompt": data["task_prompt"],
            "include_web_search": data["include_web_search"],
            "user": user,
        }

    @staticmethod
    def prepare_agent_input(
        message: Message,
        search_settings: SearchSettings,
        rag_generation_config: GenerationConfig,
        task_prompt: Optional[str],
        include_title_if_available: bool,
        user: User,
        conversation_id: Optional[str] = None,
    ) -> dict:
        """Serialize agent-call arguments into a plain dict."""
        payload: dict = {"message": message.to_dict()}
        payload["search_settings"] = search_settings.to_dict()
        payload["rag_generation_config"] = rag_generation_config.to_dict()
        payload["task_prompt"] = task_prompt
        payload["include_title_if_available"] = include_title_if_available
        payload["user"] = user.to_dict()
        payload["conversation_id"] = conversation_id
        return payload

    @staticmethod
    def parse_agent_input(data: dict):
        """Inverse of `prepare_agent_input`: rebuild rich argument types."""
        settings = SearchSettings.from_dict(data["search_settings"])
        gen_config = GenerationConfig.from_dict(data["rag_generation_config"])
        user = RetrievalServiceAdapter._parse_user_data(data["user"])
        return {
            "message": Message.from_dict(data["message"]),
            "search_settings": settings,
            "rag_generation_config": gen_config,
            "task_prompt": data["task_prompt"],
            "include_title_if_available": data["include_title_if_available"],
            "user": user,
            "conversation_id": data.get("conversation_id"),
        }