Diffstat (limited to '.venv/lib/python3.12/site-packages/core')
-rw-r--r--  .venv/lib/python3.12/site-packages/core/__init__.py  175
-rw-r--r--  .venv/lib/python3.12/site-packages/core/agent/__init__.py  36
-rw-r--r--  .venv/lib/python3.12/site-packages/core/agent/base.py  1484
-rw-r--r--  .venv/lib/python3.12/site-packages/core/agent/rag.py  662
-rw-r--r--  .venv/lib/python3.12/site-packages/core/agent/research.py  697
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/__init__.py  130
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/abstractions/__init__.py  154
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/agent/__init__.py  17
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/agent/agent.py  291
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/agent/base.py  22
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/api/models/__init__.py  208
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/parsers/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/parsers/base_parser.py  12
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/providers/__init__.py  59
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/providers/auth.py  231
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/providers/base.py  135
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/providers/crypto.py  120
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/providers/database.py  197
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/providers/email.py  96
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/providers/embedding.py  197
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/providers/ingestion.py  172
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/providers/llm.py  200
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/providers/orchestration.py  70
-rw-r--r--  .venv/lib/python3.12/site-packages/core/base/utils/__init__.py  43
-rw-r--r--  .venv/lib/python3.12/site-packages/core/configs/full.toml  21
-rw-r--r--  .venv/lib/python3.12/site-packages/core/configs/full_azure.toml  46
-rw-r--r--  .venv/lib/python3.12/site-packages/core/configs/full_lm_studio.toml  57
-rw-r--r--  .venv/lib/python3.12/site-packages/core/configs/full_ollama.toml  63
-rw-r--r--  .venv/lib/python3.12/site-packages/core/configs/gemini.toml  21
-rw-r--r--  .venv/lib/python3.12/site-packages/core/configs/lm_studio.toml  42
-rw-r--r--  .venv/lib/python3.12/site-packages/core/configs/ollama.toml  48
-rw-r--r--  .venv/lib/python3.12/site-packages/core/configs/r2r_azure.toml  23
-rw-r--r--  .venv/lib/python3.12/site-packages/core/configs/r2r_azure_with_test_limits.toml  37
-rw-r--r--  .venv/lib/python3.12/site-packages/core/configs/r2r_with_auth.toml  8
-rw-r--r--  .venv/lib/python3.12/site-packages/core/examples/__init__.py  0
-rw-r--r--  .venv/lib/python3.12/site-packages/core/examples/hello_r2r.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/__init__.py  24
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/abstractions.py  82
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/api/v3/base_router.py  151
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/api/v3/chunks_router.py  422
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/api/v3/collections_router.py  1207
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/api/v3/conversations_router.py  737
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/api/v3/documents_router.py  2342
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/api/v3/examples.py  1065
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/api/v3/graph_router.py  2051
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/api/v3/indices_router.py  576
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/api/v3/prompts_router.py  387
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/api/v3/retrieval_router.py  639
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/api/v3/system_router.py  186
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/api/v3/users_router.py  1721
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/app.py  121
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/app_entry.py  125
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/assembly/__init__.py  12
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/assembly/builder.py  127
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/assembly/factory.py  417
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/config.py  213
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/orchestration/__init__.py  16
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py  0
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py  539
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py  721
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py  0
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py  222
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py  598
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/services/__init__.py  14
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/services/auth_service.py  316
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/services/base.py  14
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/services/graph_service.py  1358
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/services/ingestion_service.py  983
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/services/management_service.py  1084
-rw-r--r--  .venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py  2087
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/__init__.py  35
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/media/__init__.py  26
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/media/audio_parser.py  74
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/media/bmp_parser.py  78
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/media/doc_parser.py  108
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/media/docx_parser.py  38
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/media/img_parser.py  281
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/media/odt_parser.py  60
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/media/pdf_parser.py  363
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/media/ppt_parser.py  88
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/media/pptx_parser.py  40
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/media/rtf_parser.py  45
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/structured/__init__.py  28
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/structured/csv_parser.py  108
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/structured/eml_parser.py  63
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/structured/epub_parser.py  121
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/structured/json_parser.py  94
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/structured/msg_parser.py  65
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/structured/org_parser.py  72
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/structured/p7s_parser.py  178
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/structured/rst_parser.py  58
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/structured/tsv_parser.py  109
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/structured/xls_parser.py  140
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/structured/xlsx_parser.py  100
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/text/__init__.py  10
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/text/html_parser.py  32
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/text/md_parser.py  39
-rw-r--r--  .venv/lib/python3.12/site-packages/core/parsers/text/text_parser.py  30
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/__init__.py  77
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/auth/__init__.py  11
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/auth/clerk.py  133
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/auth/jwt.py  166
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py  701
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/auth/supabase.py  249
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/crypto/__init__.py  9
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/crypto/bcrypt.py  195
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/crypto/nacl.py  181
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/base.py  247
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/chunks.py  1316
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/collections.py  701
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/conversations.py  858
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/documents.py  1172
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/files.py  334
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/filters.py  478
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/graphs.py  2884
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/limits.py  434
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/postgres.py  286
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/__init__.py  0
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/chunk_enrichment.yaml  56
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/collection_summary.yaml  41
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent.yaml  28
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent_xml_tooling.yaml  99
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_communities.yaml  74
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_entity_description.yaml  40
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_extraction.yaml  100
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/hyde.yaml  29
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/rag.yaml  29
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/rag_fusion.yaml  27
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/static_rag_agent.yaml  16
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/static_research_agent.yaml  61
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/summary.yaml  18
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/system.yaml  3
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_img.yaml  4
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_pdf.yaml  42
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts_handler.py  748
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/tokens.py  67
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/users.py  1325
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/email/__init__.py  11
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/email/console_mock.py  67
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/email/mailersend.py  281
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/email/sendgrid.py  257
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/email/smtp.py  176
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/embeddings/__init__.py  9
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/embeddings/litellm.py  305
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py  194
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py  243
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/ingestion/__init__.py  13
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/ingestion/r2r/base.py  355
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/ingestion/unstructured/base.py  396
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/__init__.py  11
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/anthropic.py  925
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/azure_foundry.py  110
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/litellm.py  80
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/openai.py  522
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/r2r_llm.py  96
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/utils.py  106
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py  4
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py  105
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py  61
-rw-r--r--  .venv/lib/python3.12/site-packages/core/utils/__init__.py  182
-rw-r--r--  .venv/lib/python3.12/site-packages/core/utils/logging_config.py  164
-rw-r--r--  .venv/lib/python3.12/site-packages/core/utils/sentry.py  22
-rw-r--r--  .venv/lib/python3.12/site-packages/core/utils/serper.py  107
164 files changed, 46888 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/__init__.py b/.venv/lib/python3.12/site-packages/core/__init__.py
new file mode 100644
index 00000000..e80319c0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/__init__.py
@@ -0,0 +1,175 @@
+import logging
+
+# Keep '*' imports for enhanced development velocity
+from .agent import *
+from .base import *
+from .main import *
+from .parsers import *
+from .providers import *
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+# Create a console handler and set the level to info
+ch = logging.StreamHandler()
+ch.setLevel(logging.INFO)
+
+# Create a formatter and set it for the handler
+formatter = logging.Formatter(
+ "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
+)
+ch.setFormatter(formatter)
+
+# Add the handler to the logger
+logger.addHandler(ch)
+
+# Optional: Prevent propagation to the root logger
+logger.propagate = False
+
+logging.getLogger("httpx").setLevel(logging.WARNING)
+logging.getLogger("LiteLLM").setLevel(logging.WARNING)
+
+__all__ = [
+ "ThinkingEvent",
+ "ToolCallEvent",
+ "ToolResultEvent",
+ "CitationEvent",
+ "Citation",
+ "R2RAgent",
+ "SearchResultsCollector",
+ "R2RRAGAgent",
+ "R2RXMLToolsRAGAgent",
+ "R2RStreamingRAGAgent",
+ "R2RXMLToolsStreamingRAGAgent",
+ "AsyncSyncMeta",
+ "syncable",
+ "MessageType",
+ "Document",
+ "DocumentChunk",
+ "DocumentResponse",
+ "IngestionStatus",
+ "GraphExtractionStatus",
+ "GraphConstructionStatus",
+ "DocumentType",
+ "EmbeddingPurpose",
+ "default_embedding_prefixes",
+ "R2RDocumentProcessingError",
+ "R2RException",
+ "Entity",
+ "GraphExtraction",
+ "Relationship",
+ "GenerationConfig",
+ "LLMChatCompletion",
+ "LLMChatCompletionChunk",
+ "RAGCompletion",
+ "Prompt",
+ "AggregateSearchResult",
+ "WebSearchResult",
+ "GraphSearchResult",
+ "ChunkSearchSettings",
+ "GraphSearchSettings",
+ "ChunkSearchResult",
+ "WebPageSearchResult",
+ "SearchSettings",
+ "select_search_filters",
+ "SearchMode",
+ "HybridSearchSettings",
+ "Token",
+ "TokenData",
+ "Vector",
+ "VectorEntry",
+ "VectorType",
+ "IndexConfig",
+ "Agent",
+ "AgentConfig",
+ "Conversation",
+ "Message",
+ "Tool",
+ "ToolResult",
+ "TokenResponse",
+ "User",
+ "AppConfig",
+ "Provider",
+ "ProviderConfig",
+ "AuthConfig",
+ "AuthProvider",
+ "CryptoConfig",
+ "CryptoProvider",
+ "EmailConfig",
+ "EmailProvider",
+ "LimitSettings",
+ "DatabaseConfig",
+ "DatabaseProvider",
+ "EmbeddingConfig",
+ "EmbeddingProvider",
+ "CompletionConfig",
+ "CompletionProvider",
+ "RecursiveCharacterTextSplitter",
+ "TextSplitter",
+ "generate_id",
+ "increment_version",
+ "validate_uuid",
+ "yield_sse_event",
+ "convert_nonserializable_objects",
+ "num_tokens",
+ "num_tokens_from_messages",
+ "SearchResultsCollector",
+ "R2RProviders",
+ "R2RApp",
+ "R2RBuilder",
+ "R2RConfig",
+ "R2RProviderFactory",
+ "AuthService",
+ "IngestionService",
+ "ManagementService",
+ "RetrievalService",
+ "GraphService",
+ "AudioParser",
+ "BMPParser",
+ "DOCParser",
+ "DOCXParser",
+ "ImageParser",
+ "ODTParser",
+ "VLMPDFParser",
+ "BasicPDFParser",
+ "PDFParserUnstructured",
+ "PPTParser",
+ "PPTXParser",
+ "RTFParser",
+ "CSVParser",
+ "CSVParserAdvanced",
+ "EMLParser",
+ "EPUBParser",
+ "JSONParser",
+ "MSGParser",
+ "ORGParser",
+ "P7SParser",
+ "RSTParser",
+ "TSVParser",
+ "XLSParser",
+ "XLSXParser",
+ "XLSXParserAdvanced",
+ "MDParser",
+ "HTMLParser",
+ "TextParser",
+ "SupabaseAuthProvider",
+ "R2RAuthProvider",
+ "JwtAuthProvider",
+ "ClerkAuthProvider",
+ # Email
+ # Crypto
+ "BCryptCryptoProvider",
+ "BcryptCryptoConfig",
+ "NaClCryptoConfig",
+ "NaClCryptoProvider",
+ "PostgresDatabaseProvider",
+ "LiteLLMEmbeddingProvider",
+ "OpenAIEmbeddingProvider",
+ "OllamaEmbeddingProvider",
+ "OpenAICompletionProvider",
+ "R2RCompletionProvider",
+ "LiteLLMCompletionProvider",
+ "UnstructuredIngestionProvider",
+ "R2RIngestionProvider",
+ "ChunkingStrategy",
+]
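
The `__all__` list above is the package's public surface, re-exported from the subpackages via the wildcard imports at the top of the file. A minimal usage sketch, assuming the package is importable; `Message(role=..., content=...)` mirrors how messages are constructed later in this diff, and no other constructor signatures are assumed:

from core import Message

# Build a chat message the same way the agent code further down does.
msg = Message(role="user", content="What is in my knowledge base?")
print(msg.role, msg.content)
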
diff --git a/.venv/lib/python3.12/site-packages/core/agent/__init__.py b/.venv/lib/python3.12/site-packages/core/agent/__init__.py
new file mode 100644
index 00000000..bd6dda79
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/agent/__init__.py
@@ -0,0 +1,36 @@
+# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
+from .base import ( # type: ignore
+ R2RAgent,
+ R2RStreamingAgent,
+ R2RXMLStreamingAgent,
+)
+from .rag import ( # type: ignore
+ R2RRAGAgent,
+ R2RStreamingRAGAgent,
+ R2RXMLToolsRAGAgent,
+ R2RXMLToolsStreamingRAGAgent,
+)
+
+# Import the concrete implementations
+from .research import (
+ R2RResearchAgent,
+ R2RStreamingResearchAgent,
+ R2RXMLToolsResearchAgent,
+ R2RXMLToolsStreamingResearchAgent,
+)
+
+__all__ = [
+ # Base
+ "R2RAgent",
+ "R2RStreamingAgent",
+ "R2RXMLStreamingAgent",
+ # RAG Agents
+ "R2RRAGAgent",
+ "R2RXMLToolsRAGAgent",
+ "R2RStreamingRAGAgent",
+ "R2RXMLToolsStreamingRAGAgent",
+ "R2RResearchAgent",
+ "R2RStreamingResearchAgent",
+ "R2RXMLToolsResearchAgent",
+ "R2RXMLToolsStreamingResearchAgent",
+]
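
The exports above split along two axes: streaming versus non-streaming (the `Streaming` variants' `arun` yields SSE-formatted lines, the others return a list of message dicts) and native tool calling versus XML-formatted tool calls (the `XMLTools` variants). A minimal sketch of consuming a streaming agent, assuming `agent` is an already-constructed streaming variant (provider and config wiring is defined elsewhere in the package):

from core import Message

async def stream_demo(agent):
    # Streaming `arun` is an async generator of SSE lines.
    async for sse_line in agent.arun(
        messages=[Message(role="user", content="Summarize my documents")],
        system_instruction=None,
    ):
        print(sse_line, end="")
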
diff --git a/.venv/lib/python3.12/site-packages/core/agent/base.py b/.venv/lib/python3.12/site-packages/core/agent/base.py
new file mode 100644
index 00000000..84aae3f2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/agent/base.py
@@ -0,0 +1,1484 @@
+import asyncio
+import json
+import logging
+import re
+from abc import ABCMeta
+from typing import AsyncGenerator, Optional, Tuple
+
+from core.base import AsyncSyncMeta, LLMChatCompletion, Message, syncable
+from core.base.agent import Agent, Conversation
+from core.utils import (
+ CitationTracker,
+ SearchResultsCollector,
+ SSEFormatter,
+ convert_nonserializable_objects,
+ dump_obj,
+ find_new_citation_spans,
+)
+
+logger = logging.getLogger()
+
+
+class CombinedMeta(AsyncSyncMeta, ABCMeta):
+ pass
+
+
+def sync_wrapper(async_gen):
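+    """Consume an async generator as a plain blocking generator: each item is
+    pulled with run_until_complete, and the generator is closed once iteration
+    stops."""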
+ loop = asyncio.get_event_loop()
+
+ def wrapper():
+ try:
+ while True:
+ try:
+ yield loop.run_until_complete(async_gen.__anext__())
+ except StopAsyncIteration:
+ break
+ finally:
+ loop.run_until_complete(async_gen.aclose())
+
+ return wrapper()
+
+
+class R2RAgent(Agent, metaclass=CombinedMeta):
+ def __init__(self, *args, **kwargs):
+ self.search_results_collector = SearchResultsCollector()
+ super().__init__(*args, **kwargs)
+ self._reset()
+
+ async def _generate_llm_summary(self, iterations_count: int) -> str:
+ """
+ Generate a summary of the conversation using the LLM when max iterations are exceeded.
+
+ Args:
+ iterations_count: The number of iterations that were completed
+
+ Returns:
+ A string containing the LLM-generated summary
+ """
+ try:
+ # Get all messages in the conversation
+ all_messages = await self.conversation.get_messages()
+
+ # Create a prompt for the LLM to summarize
+ summary_prompt = {
+ "role": "user",
+ "content": (
+ f"The conversation has reached the maximum limit of {iterations_count} iterations "
+ f"without completing the task. Please provide a concise summary of: "
+ f"1) The key information you've gathered that's relevant to the original query, "
+ f"2) What you've attempted so far and why it's incomplete, and "
+ f"3) A specific recommendation for how to proceed. "
+ f"Keep your summary brief (3-4 sentences total) and focused on the most valuable insights. If it is possible to answer the original user query, then do so now instead."
+ f"Start with '⚠️ **Maximum iterations exceeded**'"
+ ),
+ }
+
+ # Create a new message list with just the conversation history and summary request
+ summary_messages = all_messages + [summary_prompt]
+
+ # Get a completion for the summary
+ generation_config = self.get_generation_config(summary_prompt)
+ response = await self.llm_provider.aget_completion(
+ summary_messages,
+ generation_config,
+ )
+
+ return response.choices[0].message.content
+ except Exception as e:
+ logger.error(f"Error generating LLM summary: {str(e)}")
+ # Fall back to basic summary if LLM generation fails
+ return (
+ "⚠️ **Maximum iterations exceeded**\n\n"
+ "The agent reached the maximum iteration limit without completing the task. "
+ "Consider breaking your request into smaller steps or refining your query."
+ )
+
+ def _reset(self):
+ self._completed = False
+ self.conversation = Conversation()
+
+ @syncable
+ async def arun(
+ self,
+ messages: list[Message],
+ system_instruction: Optional[str] = None,
+ *args,
+ **kwargs,
+ ) -> list[dict]:
+ self._reset()
+ await self._setup(system_instruction)
+
+ if messages:
+ for message in messages:
+ await self.conversation.add_message(message)
+ iterations_count = 0
+ while (
+ not self._completed
+ and iterations_count < self.config.max_iterations
+ ):
+ iterations_count += 1
+ messages_list = await self.conversation.get_messages()
+ generation_config = self.get_generation_config(messages_list[-1])
+ response = await self.llm_provider.aget_completion(
+ messages_list,
+ generation_config,
+ )
+ logger.debug(f"R2RAgent response: {response}")
+ await self.process_llm_response(response, *args, **kwargs)
+
+ if not self._completed:
+ # Generate a summary of the conversation using the LLM
+ summary = await self._generate_llm_summary(iterations_count)
+ await self.conversation.add_message(
+ Message(role="assistant", content=summary)
+ )
+
+ # Return final content
+ all_messages: list[dict] = await self.conversation.get_messages()
+ all_messages.reverse()
+
+ output_messages = []
+ for message_2 in all_messages:
+ if (
+ # message_2.get("content")
+ message_2.get("content") != messages[-1].content
+ ):
+ output_messages.append(message_2)
+ else:
+ break
+ output_messages.reverse()
+
+ return output_messages
+
+ async def process_llm_response(
+ self, response: LLMChatCompletion, *args, **kwargs
+ ) -> None:
+ if not self._completed:
+ message = response.choices[0].message
+ finish_reason = response.choices[0].finish_reason
+
+ if finish_reason == "stop":
+ self._completed = True
+
+ # Determine which provider we're using
+ using_anthropic = (
+ "anthropic" in self.rag_generation_config.model.lower()
+ )
+
+ # OPENAI HANDLING
+ if not using_anthropic:
+ if message.tool_calls:
+ assistant_msg = Message(
+ role="assistant",
+ content="",
+ tool_calls=[msg.dict() for msg in message.tool_calls],
+ )
+ await self.conversation.add_message(assistant_msg)
+
+ # If there are multiple tool_calls, call them sequentially here
+ for tool_call in message.tool_calls:
+ await self.handle_function_or_tool_call(
+ tool_call.function.name,
+ tool_call.function.arguments,
+ tool_id=tool_call.id,
+ *args,
+ **kwargs,
+ )
+ else:
+ await self.conversation.add_message(
+ Message(role="assistant", content=message.content)
+ )
+ self._completed = True
+
+ else:
+ # First handle thinking blocks if present
+ if (
+ hasattr(message, "structured_content")
+ and message.structured_content
+ ):
+ # Check if structured_content contains any tool_use blocks
+ has_tool_use = any(
+ block.get("type") == "tool_use"
+ for block in message.structured_content
+ )
+
+ if not has_tool_use and message.tool_calls:
+ # If it has thinking but no tool_use, add a separate message with structured_content
+ assistant_msg = Message(
+ role="assistant",
+ structured_content=message.structured_content, # Use structured_content field
+ )
+ await self.conversation.add_message(assistant_msg)
+
+ # Add explicit tool_use blocks in a separate message
+ tool_uses = []
+ for tool_call in message.tool_calls:
+ # Safely parse arguments if they're a string
+ try:
+ if isinstance(
+ tool_call.function.arguments, str
+ ):
+ input_args = json.loads(
+ tool_call.function.arguments
+ )
+ else:
+ input_args = tool_call.function.arguments
+ except json.JSONDecodeError:
+ logger.error(
+ f"Failed to parse tool arguments: {tool_call.function.arguments}"
+ )
+ input_args = {
+ "_raw": tool_call.function.arguments
+ }
+
+ tool_uses.append(
+ {
+ "type": "tool_use",
+ "id": tool_call.id,
+ "name": tool_call.function.name,
+ "input": input_args,
+ }
+ )
+
+ # Add tool_use blocks as a separate assistant message with structured content
+ if tool_uses:
+ await self.conversation.add_message(
+ Message(
+ role="assistant",
+ structured_content=tool_uses,
+ content="",
+ )
+ )
+ else:
+ # If it already has tool_use or no tool_calls, preserve original structure
+ assistant_msg = Message(
+ role="assistant",
+ structured_content=message.structured_content,
+ )
+ await self.conversation.add_message(assistant_msg)
+
+ elif message.content:
+ # For regular text content
+ await self.conversation.add_message(
+ Message(role="assistant", content=message.content)
+ )
+
+ # If there are tool calls, add them as structured content
+ if message.tool_calls:
+ tool_uses = []
+ for tool_call in message.tool_calls:
+ # Same safe parsing as above
+ try:
+ if isinstance(
+ tool_call.function.arguments, str
+ ):
+ input_args = json.loads(
+ tool_call.function.arguments
+ )
+ else:
+ input_args = tool_call.function.arguments
+ except json.JSONDecodeError:
+ logger.error(
+ f"Failed to parse tool arguments: {tool_call.function.arguments}"
+ )
+ input_args = {
+ "_raw": tool_call.function.arguments
+ }
+
+ tool_uses.append(
+ {
+ "type": "tool_use",
+ "id": tool_call.id,
+ "name": tool_call.function.name,
+ "input": input_args,
+ }
+ )
+
+ await self.conversation.add_message(
+ Message(
+ role="assistant", structured_content=tool_uses
+ )
+ )
+
+ # NEW CASE: Handle tool_calls with no content or structured_content
+ elif message.tool_calls:
+ # Create tool_uses for the message with only tool_calls
+ tool_uses = []
+ for tool_call in message.tool_calls:
+ try:
+ if isinstance(tool_call.function.arguments, str):
+ input_args = json.loads(
+ tool_call.function.arguments
+ )
+ else:
+ input_args = tool_call.function.arguments
+ except json.JSONDecodeError:
+ logger.error(
+ f"Failed to parse tool arguments: {tool_call.function.arguments}"
+ )
+ input_args = {"_raw": tool_call.function.arguments}
+
+ tool_uses.append(
+ {
+ "type": "tool_use",
+ "id": tool_call.id,
+ "name": tool_call.function.name,
+ "input": input_args,
+ }
+ )
+
+ # Add tool_use blocks as a message before processing tools
+ if tool_uses:
+ await self.conversation.add_message(
+ Message(
+ role="assistant",
+ structured_content=tool_uses,
+ )
+ )
+
+ # Process the tool calls
+ if message.tool_calls:
+ for tool_call in message.tool_calls:
+ await self.handle_function_or_tool_call(
+ tool_call.function.name,
+ tool_call.function.arguments,
+ tool_id=tool_call.id,
+ *args,
+ **kwargs,
+ )
+
+
+class R2RStreamingAgent(R2RAgent):
+ """
+ Base class for all streaming agents with core streaming functionality.
+ Supports emitting messages, tool calls, and results as SSE events.
+ """
+
+ # These two regexes will detect bracket references and then find short IDs.
+ BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
+ SHORT_ID_PATTERN = re.compile(
+ r"[A-Za-z0-9]{7,8}"
+ ) # 7-8 chars, for example
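+    # e.g. a streamed answer containing "[abc1234]" produces a bracket match whose
+    # short ID is then resolved via self.search_results_collector.find_by_short_id.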
+
+ def __init__(self, *args, **kwargs):
+ # Force streaming on
+ if hasattr(kwargs.get("config", {}), "stream"):
+ kwargs["config"].stream = True
+ super().__init__(*args, **kwargs)
+
+ async def arun(
+ self,
+ system_instruction: str | None = None,
+ messages: list[Message] | None = None,
+ *args,
+ **kwargs,
+ ) -> AsyncGenerator[str, None]:
+ """
+ Main streaming entrypoint: returns an async generator of SSE lines.
+ """
+ self._reset()
+ await self._setup(system_instruction)
+
+ if messages:
+ for m in messages:
+ await self.conversation.add_message(m)
+
+ # Initialize citation tracker for this run
+ citation_tracker = CitationTracker()
+
+ # Dictionary to store citation payloads by ID
+ citation_payloads = {}
+
+ # Track all citations emitted during streaming for final persistence
+ self.streaming_citations: list[dict] = []
+
+ async def sse_generator() -> AsyncGenerator[str, None]:
+ pending_tool_calls = {}
+ partial_text_buffer = ""
+ iterations_count = 0
+
+ try:
+ # Keep streaming until we complete
+ while (
+ not self._completed
+ and iterations_count < self.config.max_iterations
+ ):
+ iterations_count += 1
+ # 1) Get current messages
+ msg_list = await self.conversation.get_messages()
+ gen_cfg = self.get_generation_config(
+ msg_list[-1], stream=True
+ )
+
+ accumulated_thinking = ""
+ thinking_signatures = {} # Map thinking content to signatures
+
+ # 2) Start streaming from LLM
+ llm_stream = self.llm_provider.aget_completion_stream(
+ msg_list, gen_cfg
+ )
+ async for chunk in llm_stream:
+ delta = chunk.choices[0].delta
+ finish_reason = chunk.choices[0].finish_reason
+
+ if hasattr(delta, "thinking") and delta.thinking:
+ # Accumulate thinking for later use in messages
+ accumulated_thinking += delta.thinking
+
+ # Emit SSE "thinking" event
+ async for (
+ line
+ ) in SSEFormatter.yield_thinking_event(
+ delta.thinking
+ ):
+ yield line
+
+ # Add this new handler for thinking signatures
+ if hasattr(delta, "thinking_signature"):
+ thinking_signatures[accumulated_thinking] = (
+ delta.thinking_signature
+ )
+ accumulated_thinking = ""
+
+ # 3) If new text, accumulate it
+ if delta.content:
+ partial_text_buffer += delta.content
+
+ # (a) Now emit the newly streamed text as a "message" event
+ async for line in SSEFormatter.yield_message_event(
+ delta.content
+ ):
+ yield line
+
+ # (b) Find new citation spans in the accumulated text
+ new_citation_spans = find_new_citation_spans(
+ partial_text_buffer, citation_tracker
+ )
+
+ # Process each new citation span
+ for cid, spans in new_citation_spans.items():
+ for span in spans:
+ # Check if this is the first time we've seen this citation ID
+ is_new_citation = (
+ citation_tracker.is_new_citation(cid)
+ )
+
+ # Get payload if it's a new citation
+ payload = None
+ if is_new_citation:
+ source_obj = self.search_results_collector.find_by_short_id(
+ cid
+ )
+ if source_obj:
+ # Store payload for reuse
+ payload = dump_obj(source_obj)
+ citation_payloads[cid] = payload
+
+ # Create citation event payload
+ citation_data = {
+ "id": cid,
+ "object": "citation",
+ "is_new": is_new_citation,
+ "span": {
+ "start": span[0],
+ "end": span[1],
+ },
+ }
+
+ # Only include full payload for new citations
+ if is_new_citation and payload:
+ citation_data["payload"] = payload
+
+ # Add to streaming citations for final answer
+ self.streaming_citations.append(
+ citation_data
+ )
+
+ # Emit the citation event
+ async for (
+ line
+ ) in SSEFormatter.yield_citation_event(
+ citation_data
+ ):
+ yield line
+
+ if delta.tool_calls:
+ for tc in delta.tool_calls:
+ idx = tc.index
+ if idx not in pending_tool_calls:
+ pending_tool_calls[idx] = {
+ "id": tc.id,
+ "name": tc.function.name or "",
+ "arguments": tc.function.arguments
+ or "",
+ }
+ else:
+ # Accumulate partial name/arguments
+ if tc.function.name:
+ pending_tool_calls[idx]["name"] = (
+ tc.function.name
+ )
+ if tc.function.arguments:
+ pending_tool_calls[idx][
+ "arguments"
+ ] += tc.function.arguments
+
+ # 5) If the stream signals we should handle "tool_calls"
+ if finish_reason == "tool_calls":
+ # Handle thinking if present
+ await self._handle_thinking(
+ thinking_signatures, accumulated_thinking
+ )
+
+ calls_list = []
+ for idx in sorted(pending_tool_calls.keys()):
+ cinfo = pending_tool_calls[idx]
+ calls_list.append(
+ {
+ "tool_call_id": cinfo["id"]
+ or f"call_{idx}",
+ "name": cinfo["name"],
+ "arguments": cinfo["arguments"],
+ }
+ )
+
+ # (a) Emit SSE "tool_call" events
+ for c in calls_list:
+ tc_data = self._create_tool_call_data(c)
+ async for (
+ line
+ ) in SSEFormatter.yield_tool_call_event(
+ tc_data
+ ):
+ yield line
+
+ # (b) Add an assistant message capturing these calls
+ await self._add_tool_calls_message(
+ calls_list, partial_text_buffer
+ )
+
+ # (c) Execute each tool call in parallel
+ await asyncio.gather(
+ *[
+ self.handle_function_or_tool_call(
+ c["name"],
+ c["arguments"],
+ tool_id=c["tool_call_id"],
+ )
+ for c in calls_list
+ ]
+ )
+
+ # Reset buffer & calls
+ pending_tool_calls.clear()
+ partial_text_buffer = ""
+
+ elif finish_reason == "stop":
+ # Handle thinking if present
+ await self._handle_thinking(
+ thinking_signatures, accumulated_thinking
+ )
+
+ # 6) The LLM is done. If we have any leftover partial text,
+ # finalize it in the conversation
+ if partial_text_buffer:
+ # Create the final message with metadata including citations
+ final_message = Message(
+ role="assistant",
+ content=partial_text_buffer,
+ metadata={
+ "citations": self.streaming_citations
+ },
+ )
+
+ # Add it to the conversation
+ await self.conversation.add_message(
+ final_message
+ )
+
+ # (a) Prepare final answer with optimized citations
+ consolidated_citations = []
+ # Group citations by ID with all their spans
+ for (
+ cid,
+ spans,
+ ) in citation_tracker.get_all_spans().items():
+ if cid in citation_payloads:
+ consolidated_citations.append(
+ {
+ "id": cid,
+ "object": "citation",
+ "spans": [
+ {"start": s[0], "end": s[1]}
+ for s in spans
+ ],
+ "payload": citation_payloads[cid],
+ }
+ )
+
+ # Create final answer payload
+ final_evt_payload = {
+ "id": "msg_final",
+ "object": "agent.final_answer",
+ "generated_answer": partial_text_buffer,
+ "citations": consolidated_citations,
+ }
+
+ # Emit final answer event
+ async for (
+ line
+ ) in SSEFormatter.yield_final_answer_event(
+ final_evt_payload
+ ):
+ yield line
+
+ # (b) Signal the end of the SSE stream
+ yield SSEFormatter.yield_done_event()
+ self._completed = True
+ break
+
+ # If we exit the while loop due to hitting max iterations
+ if not self._completed:
+ # Generate a summary using the LLM
+ summary = await self._generate_llm_summary(
+ iterations_count
+ )
+
+ # Send the summary as a message event
+ async for line in SSEFormatter.yield_message_event(
+ summary
+ ):
+ yield line
+
+ # Add summary to conversation with citations metadata
+ await self.conversation.add_message(
+ Message(
+ role="assistant",
+ content=summary,
+ metadata={"citations": self.streaming_citations},
+ )
+ )
+
+ # Create and emit a final answer payload with the summary
+ final_evt_payload = {
+ "id": "msg_final",
+ "object": "agent.final_answer",
+ "generated_answer": summary,
+ "citations": consolidated_citations,
+ }
+
+ async for line in SSEFormatter.yield_final_answer_event(
+ final_evt_payload
+ ):
+ yield line
+
+ # Signal the end of the SSE stream
+ yield SSEFormatter.yield_done_event()
+ self._completed = True
+
+ except Exception as e:
+ logger.error(f"Error in streaming agent: {str(e)}")
+ # Emit error event for client
+ async for line in SSEFormatter.yield_error_event(
+ f"Agent error: {str(e)}"
+ ):
+ yield line
+ # Send done event to close the stream
+ yield SSEFormatter.yield_done_event()
+
+ # Finally, we return the async generator
+ async for line in sse_generator():
+ yield line
+
+ async def _handle_thinking(
+ self, thinking_signatures, accumulated_thinking
+ ):
+ """Process any accumulated thinking content"""
+ if accumulated_thinking:
+ structured_content = [
+ {
+ "type": "thinking",
+ "thinking": accumulated_thinking,
+ # Anthropic will validate this in their API
+ "signature": "placeholder_signature",
+ }
+ ]
+
+ assistant_msg = Message(
+ role="assistant",
+ structured_content=structured_content,
+ )
+ await self.conversation.add_message(assistant_msg)
+
+ elif thinking_signatures:
+ for (
+ accumulated_thinking,
+ thinking_signature,
+ ) in thinking_signatures.items():
+ structured_content = [
+ {
+ "type": "thinking",
+ "thinking": accumulated_thinking,
+ # Anthropic will validate this in their API
+ "signature": thinking_signature,
+ }
+ ]
+
+ assistant_msg = Message(
+ role="assistant",
+ structured_content=structured_content,
+ )
+ await self.conversation.add_message(assistant_msg)
+
+ async def _add_tool_calls_message(self, calls_list, partial_text_buffer):
+ """Add a message with tool calls to the conversation"""
+ assistant_msg = Message(
+ role="assistant",
+ content=partial_text_buffer or "",
+ tool_calls=[
+ {
+ "id": c["tool_call_id"],
+ "type": "function",
+ "function": {
+ "name": c["name"],
+ "arguments": c["arguments"],
+ },
+ }
+ for c in calls_list
+ ],
+ )
+ await self.conversation.add_message(assistant_msg)
+
+ def _create_tool_call_data(self, call_info):
+ """Create tool call data structure from call info"""
+ return {
+ "tool_call_id": call_info["tool_call_id"],
+ "name": call_info["name"],
+ "arguments": call_info["arguments"],
+ }
+
+ def _create_citation_payload(self, short_id, payload):
+ """Create citation payload for a short ID"""
+ # This will be overridden in RAG subclasses
+ # check if as_dict is on payload
+ if hasattr(payload, "as_dict"):
+ payload = payload.as_dict()
+ if hasattr(payload, "dict"):
+ payload = payload.dict
+ if hasattr(payload, "to_dict"):
+ payload = payload.to_dict()
+
+ return {
+ "id": f"{short_id}",
+ "object": "citation",
+ "payload": dump_obj(payload), # Will be populated in RAG agents
+ }
+
+ def _create_final_answer_payload(self, answer_text, citations):
+ """Create the final answer payload"""
+ # This will be extended in RAG subclasses
+ return {
+ "id": "msg_final",
+ "object": "agent.final_answer",
+ "generated_answer": answer_text,
+ "citations": citations,
+ }
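+    # Shape of each per-citation SSE payload built in `arun` above (illustrative
+    # values; the actual SSE framing comes from SSEFormatter.yield_citation_event):
+    #   {"id": "abc1234", "object": "citation", "is_new": True,
+    #    "span": {"start": 102, "end": 110}, "payload": {...}}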
+
+
+class R2RXMLStreamingAgent(R2RStreamingAgent):
+ """
+ A streaming agent that parses XML-formatted responses with special handling for:
+ - <think> or <Thought> blocks for chain-of-thought reasoning
+ - <Action>, <ToolCalls>, <ToolCall> blocks for tool execution
+ """
+
+ # We treat <think> or <Thought> as the same token boundaries
+ THOUGHT_OPEN = re.compile(r"<(Thought|think)>", re.IGNORECASE)
+ THOUGHT_CLOSE = re.compile(r"</(Thought|think)>", re.IGNORECASE)
+
+ # Regexes to parse out <Action>, <ToolCalls>, <ToolCall>, <Name>, <Parameters>, <Response>
+ ACTION_PATTERN = re.compile(
+ r"<Action>(.*?)</Action>", re.IGNORECASE | re.DOTALL
+ )
+ TOOLCALLS_PATTERN = re.compile(
+ r"<ToolCalls>(.*?)</ToolCalls>", re.IGNORECASE | re.DOTALL
+ )
+ TOOLCALL_PATTERN = re.compile(
+ r"<ToolCall>(.*?)</ToolCall>", re.IGNORECASE | re.DOTALL
+ )
+ NAME_PATTERN = re.compile(r"<Name>(.*?)</Name>", re.IGNORECASE | re.DOTALL)
+ PARAMS_PATTERN = re.compile(
+ r"<Parameters>(.*?)</Parameters>", re.IGNORECASE | re.DOTALL
+ )
+ RESPONSE_PATTERN = re.compile(
+ r"<Response>(.*?)</Response>", re.IGNORECASE | re.DOTALL
+ )
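+    # Together these patterns target responses of the (illustrative) form:
+    #   <Thought>...reasoning...</Thought>
+    #   <Action><ToolCalls><ToolCall><Name>web_search</Name>
+    #   <Parameters>{"query": "..."}</Parameters></ToolCall></ToolCalls></Action>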
+
+ async def arun(
+ self,
+ system_instruction: str | None = None,
+ messages: list[Message] | None = None,
+ *args,
+ **kwargs,
+ ) -> AsyncGenerator[str, None]:
+ """
+ Main streaming entrypoint: returns an async generator of SSE lines.
+ """
+ self._reset()
+ await self._setup(system_instruction)
+
+ if messages:
+ for m in messages:
+ await self.conversation.add_message(m)
+
+ # Initialize citation tracker for this run
+ citation_tracker = CitationTracker()
+
+ # Dictionary to store citation payloads by ID
+ citation_payloads = {}
+
+ # Track all citations emitted during streaming for final persistence
+ self.streaming_citations: list[dict] = []
+
+ async def sse_generator() -> AsyncGenerator[str, None]:
+ iterations_count = 0
+
+ try:
+ # Keep streaming until we complete
+ while (
+ not self._completed
+ and iterations_count < self.config.max_iterations
+ ):
+ iterations_count += 1
+ # 1) Get current messages
+ msg_list = await self.conversation.get_messages()
+ gen_cfg = self.get_generation_config(
+ msg_list[-1], stream=True
+ )
+
+ # 2) Start streaming from LLM
+ llm_stream = self.llm_provider.aget_completion_stream(
+ msg_list, gen_cfg
+ )
+
+ # Create state variables for each iteration
+ iteration_buffer = ""
+ yielded_first_event = False
+ in_action_block = False
+ is_thinking = False
+ accumulated_thinking = ""
+ thinking_signatures = {}
+
+ async for chunk in llm_stream:
+ delta = chunk.choices[0].delta
+ finish_reason = chunk.choices[0].finish_reason
+
+ # Handle thinking if present
+ if hasattr(delta, "thinking") and delta.thinking:
+ # Accumulate thinking for later use in messages
+ accumulated_thinking += delta.thinking
+
+ # Emit SSE "thinking" event
+ async for (
+ line
+ ) in SSEFormatter.yield_thinking_event(
+ delta.thinking
+ ):
+ yield line
+
+ # Add this new handler for thinking signatures
+ if hasattr(delta, "thinking_signature"):
+ thinking_signatures[accumulated_thinking] = (
+ delta.thinking_signature
+ )
+ accumulated_thinking = ""
+
+ # 3) If new text, accumulate it
+ if delta.content:
+ iteration_buffer += delta.content
+
+ # Check if we have accumulated enough text for a `<Thought>` block
+ if len(iteration_buffer) < len("<Thought>"):
+ continue
+
+ # Check if we have yielded the first event
+ if not yielded_first_event:
+ # Emit the first chunk
+ if self.THOUGHT_OPEN.findall(iteration_buffer):
+ is_thinking = True
+ async for (
+ line
+ ) in SSEFormatter.yield_thinking_event(
+ iteration_buffer
+ ):
+ yield line
+ else:
+ async for (
+ line
+ ) in SSEFormatter.yield_message_event(
+ iteration_buffer
+ ):
+ yield line
+
+ # Mark as yielded
+ yielded_first_event = True
+ continue
+
+ # Check if we are in a thinking block
+ if is_thinking:
+ # Still thinking, so keep yielding thinking events
+ if not self.THOUGHT_CLOSE.findall(
+ iteration_buffer
+ ):
+ # Emit SSE "thinking" event
+ async for (
+ line
+ ) in SSEFormatter.yield_thinking_event(
+ delta.content
+ ):
+ yield line
+
+ continue
+ # Done thinking, so emit the last thinking event
+ else:
+ is_thinking = False
+ thought_text = delta.content.split(
+ "</Thought>"
+ )[0].split("</think>")[0]
+ async for (
+ line
+ ) in SSEFormatter.yield_thinking_event(
+ thought_text
+ ):
+ yield line
+ post_thought_text = delta.content.split(
+ "</Thought>"
+ )[-1].split("</think>")[-1]
+ delta.content = post_thought_text
+
+ # (b) Find new citation spans in the accumulated text
+ new_citation_spans = find_new_citation_spans(
+ iteration_buffer, citation_tracker
+ )
+
+ # Process each new citation span
+ for cid, spans in new_citation_spans.items():
+ for span in spans:
+ # Check if this is the first time we've seen this citation ID
+ is_new_citation = (
+ citation_tracker.is_new_citation(cid)
+ )
+
+ # Get payload if it's a new citation
+ payload = None
+ if is_new_citation:
+ source_obj = self.search_results_collector.find_by_short_id(
+ cid
+ )
+ if source_obj:
+ # Store payload for reuse
+ payload = dump_obj(source_obj)
+ citation_payloads[cid] = payload
+
+ # Create citation event payload
+ citation_data = {
+ "id": cid,
+ "object": "citation",
+ "is_new": is_new_citation,
+ "span": {
+ "start": span[0],
+ "end": span[1],
+ },
+ }
+
+ # Only include full payload for new citations
+ if is_new_citation and payload:
+ citation_data["payload"] = payload
+
+ # Add to streaming citations for final answer
+ self.streaming_citations.append(
+ citation_data
+ )
+
+ # Emit the citation event
+ async for (
+ line
+ ) in SSEFormatter.yield_citation_event(
+ citation_data
+ ):
+ yield line
+
+ # Now prepare to emit the newly streamed text as a "message" event
+ if (
+ iteration_buffer.count("<")
+ and not in_action_block
+ ):
+ in_action_block = True
+
+ if (
+ in_action_block
+ and len(
+ self.ACTION_PATTERN.findall(
+ iteration_buffer
+ )
+ )
+ < 2
+ ):
+ continue
+
+ elif in_action_block:
+ in_action_block = False
+ # Emit the post action block text, if it is there
+ post_action_text = iteration_buffer.split(
+ "</Action>"
+ )[-1]
+ if post_action_text:
+ async for (
+ line
+ ) in SSEFormatter.yield_message_event(
+ post_action_text
+ ):
+ yield line
+
+ else:
+ async for (
+ line
+ ) in SSEFormatter.yield_message_event(
+ delta.content
+ ):
+ yield line
+
+ elif finish_reason == "stop":
+ break
+
+ # Process any accumulated thinking
+ await self._handle_thinking(
+ thinking_signatures, accumulated_thinking
+ )
+
+ # 6) The LLM is done. If we have any leftover partial text,
+ # finalize it in the conversation
+ if iteration_buffer:
+ # Create the final message with metadata including citations
+ final_message = Message(
+ role="assistant",
+ content=iteration_buffer,
+ metadata={"citations": self.streaming_citations},
+ )
+
+ # Add it to the conversation
+ await self.conversation.add_message(final_message)
+
+ # --- 4) Process any <Action>/<ToolCalls> blocks, or mark completed
+ action_matches = self.ACTION_PATTERN.findall(
+ iteration_buffer
+ )
+
+ if len(action_matches) > 0:
+ # Process each ToolCall
+ xml_toolcalls = "<ToolCalls>"
+
+ for action_block in action_matches:
+ tool_calls_text = []
+ # Look for ToolCalls wrapper, or use the raw action block
+ calls_wrapper = self.TOOLCALLS_PATTERN.findall(
+ action_block
+ )
+ if calls_wrapper:
+ for tw in calls_wrapper:
+ tool_calls_text.append(tw)
+ else:
+ tool_calls_text.append(action_block)
+
+ for calls_region in tool_calls_text:
+ calls_found = self.TOOLCALL_PATTERN.findall(
+ calls_region
+ )
+ for tc_block in calls_found:
+ tool_name, tool_params = (
+ self._parse_single_tool_call(tc_block)
+ )
+ if tool_name:
+ # Emit SSE event for tool call
+ tool_call_id = (
+ f"call_{abs(hash(tc_block))}"
+ )
+ call_evt_data = {
+ "tool_call_id": tool_call_id,
+ "name": tool_name,
+ "arguments": json.dumps(
+ tool_params
+ ),
+ }
+ async for line in (
+ SSEFormatter.yield_tool_call_event(
+ call_evt_data
+ )
+ ):
+ yield line
+
+ try:
+ tool_result = await self.handle_function_or_tool_call(
+ tool_name,
+ json.dumps(tool_params),
+ tool_id=tool_call_id,
+ save_messages=False,
+ )
+ result_content = tool_result.llm_formatted_result
+ except Exception as e:
+ result_content = f"Error in tool '{tool_name}': {str(e)}"
+
+ xml_toolcalls += (
+ f"<ToolCall>"
+ f"<Name>{tool_name}</Name>"
+ f"<Parameters>{json.dumps(tool_params)}</Parameters>"
+ f"<Result>{result_content}</Result>"
+ f"</ToolCall>"
+ )
+
+ # Emit SSE tool result for non-result tools
+ result_data = {
+ "tool_call_id": tool_call_id,
+ "role": "tool",
+ "content": json.dumps(
+ convert_nonserializable_objects(
+ result_content
+ )
+ ),
+ }
+ async for line in SSEFormatter.yield_tool_result_event(
+ result_data
+ ):
+ yield line
+
+ xml_toolcalls += "</ToolCalls>"
+ pre_action_text = iteration_buffer[
+ : iteration_buffer.find(action_block)
+ ]
+ post_action_text = iteration_buffer[
+ iteration_buffer.find(action_block)
+ + len(action_block) :
+ ]
+ iteration_text = (
+ pre_action_text + xml_toolcalls + post_action_text
+ )
+
+ # Update the conversation with tool results
+ await self.conversation.add_message(
+ Message(
+ role="assistant",
+ content=iteration_text,
+ metadata={
+ "citations": self.streaming_citations
+ },
+ )
+ )
+ else:
+ # (a) Prepare final answer with optimized citations
+ consolidated_citations = []
+ # Group citations by ID with all their spans
+ for (
+ cid,
+ spans,
+ ) in citation_tracker.get_all_spans().items():
+ if cid in citation_payloads:
+ consolidated_citations.append(
+ {
+ "id": cid,
+ "object": "citation",
+ "spans": [
+ {"start": s[0], "end": s[1]}
+ for s in spans
+ ],
+ "payload": citation_payloads[cid],
+ }
+ )
+
+ # Create final answer payload
+ final_evt_payload = {
+ "id": "msg_final",
+ "object": "agent.final_answer",
+ "generated_answer": iteration_buffer,
+ "citations": consolidated_citations,
+ }
+
+ # Emit final answer event
+ async for (
+ line
+ ) in SSEFormatter.yield_final_answer_event(
+ final_evt_payload
+ ):
+ yield line
+
+ # (b) Signal the end of the SSE stream
+ yield SSEFormatter.yield_done_event()
+ self._completed = True
+
+ # If we exit the while loop due to hitting max iterations
+ if not self._completed:
+ # Generate a summary using the LLM
+ summary = await self._generate_llm_summary(
+ iterations_count
+ )
+
+ # Send the summary as a message event
+ async for line in SSEFormatter.yield_message_event(
+ summary
+ ):
+ yield line
+
+ # Add summary to conversation with citations metadata
+ await self.conversation.add_message(
+ Message(
+ role="assistant",
+ content=summary,
+ metadata={"citations": self.streaming_citations},
+ )
+ )
+
+ # Create and emit a final answer payload with the summary
+ final_evt_payload = {
+ "id": "msg_final",
+ "object": "agent.final_answer",
+ "generated_answer": summary,
+ "citations": consolidated_citations,
+ }
+
+ async for line in SSEFormatter.yield_final_answer_event(
+ final_evt_payload
+ ):
+ yield line
+
+ # Signal the end of the SSE stream
+ yield SSEFormatter.yield_done_event()
+ self._completed = True
+
+ except Exception as e:
+ logger.error(f"Error in streaming agent: {str(e)}")
+ # Emit error event for client
+ async for line in SSEFormatter.yield_error_event(
+ f"Agent error: {str(e)}"
+ ):
+ yield line
+ # Send done event to close the stream
+ yield SSEFormatter.yield_done_event()
+
+ # Finally, we return the async generator
+ async for line in sse_generator():
+ yield line
+
+ def _parse_single_tool_call(
+ self, toolcall_text: str
+ ) -> Tuple[Optional[str], dict]:
+ """
+ Parse a ToolCall block to extract the name and parameters.
+
+ Args:
+ toolcall_text: The text content of a ToolCall block
+
+ Returns:
+ Tuple of (tool_name, tool_parameters)
+ """
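+        # Example: <ToolCall><Name>web_search</Name><Parameters>{"query": "R2R"}</Parameters></ToolCall>
+        # parses to ("web_search", {"query": "R2R"}).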
+ name_match = self.NAME_PATTERN.search(toolcall_text)
+ if not name_match:
+ return None, {}
+ tool_name = name_match.group(1).strip()
+
+ params_match = self.PARAMS_PATTERN.search(toolcall_text)
+ if not params_match:
+ return tool_name, {}
+
+ raw_params = params_match.group(1).strip()
+ try:
+ # Handle potential JSON parsing issues
+ # First try direct parsing
+ tool_params = json.loads(raw_params)
+ except json.JSONDecodeError:
+ # If that fails, try to clean up the JSON string
+ try:
+ # Replace escaped quotes that might cause issues
+ cleaned_params = raw_params.replace('\\"', '"')
+ # Try again with the cleaned string
+ tool_params = json.loads(cleaned_params)
+ except json.JSONDecodeError:
+ # If all else fails, treat as a plain string value
+ tool_params = {"value": raw_params}
+
+ return tool_name, tool_params
+
+
+class R2RXMLToolsAgent(R2RAgent):
+ """
+ A non-streaming agent that:
+ - parses <think> or <Thought> blocks as chain-of-thought
+ - filters out XML tags related to tool calls and actions
+ - processes <Action><ToolCalls><ToolCall> blocks
+ - properly extracts citations when they appear in the text
+ """
+
+ # We treat <think> or <Thought> as the same token boundaries
+ THOUGHT_OPEN = re.compile(r"<(Thought|think)>", re.IGNORECASE)
+ THOUGHT_CLOSE = re.compile(r"</(Thought|think)>", re.IGNORECASE)
+
+ # Regexes to parse out <Action>, <ToolCalls>, <ToolCall>, <Name>, <Parameters>, <Response>
+ ACTION_PATTERN = re.compile(
+ r"<Action>(.*?)</Action>", re.IGNORECASE | re.DOTALL
+ )
+ TOOLCALLS_PATTERN = re.compile(
+ r"<ToolCalls>(.*?)</ToolCalls>", re.IGNORECASE | re.DOTALL
+ )
+ TOOLCALL_PATTERN = re.compile(
+ r"<ToolCall>(.*?)</ToolCall>", re.IGNORECASE | re.DOTALL
+ )
+ NAME_PATTERN = re.compile(r"<Name>(.*?)</Name>", re.IGNORECASE | re.DOTALL)
+ PARAMS_PATTERN = re.compile(
+ r"<Parameters>(.*?)</Parameters>", re.IGNORECASE | re.DOTALL
+ )
+ RESPONSE_PATTERN = re.compile(
+ r"<Response>(.*?)</Response>", re.IGNORECASE | re.DOTALL
+ )
+
+ async def process_llm_response(self, response, *args, **kwargs):
+ """
+ Override the base process_llm_response to handle XML structured responses
+ including thoughts and tool calls.
+ """
+ if self._completed:
+ return
+
+ message = response.choices[0].message
+ finish_reason = response.choices[0].finish_reason
+
+ if not message.content:
+ # If there's no content, let the parent class handle the normal tool_calls flow
+ return await super().process_llm_response(
+ response, *args, **kwargs
+ )
+
+ # Get the response content
+ content = message.content
+
+ # HACK for gemini
+ content = content.replace("```action", "")
+ content = content.replace("```tool_code", "")
+ content = content.replace("```", "")
+
+ if (
+ not content.startswith("<")
+ and "deepseek" in self.rag_generation_config.model
+ ): # HACK - fix issues with adding `<think>` to the beginning
+ content = "<think>" + content
+
+ # Process any tool calls in the content
+ action_matches = self.ACTION_PATTERN.findall(content)
+ if action_matches:
+ xml_toolcalls = "<ToolCalls>"
+ for action_block in action_matches:
+ tool_calls_text = []
+ # Look for ToolCalls wrapper, or use the raw action block
+ calls_wrapper = self.TOOLCALLS_PATTERN.findall(action_block)
+ if calls_wrapper:
+ for tw in calls_wrapper:
+ tool_calls_text.append(tw)
+ else:
+ tool_calls_text.append(action_block)
+
+ # Process each ToolCall
+ for calls_region in tool_calls_text:
+ calls_found = self.TOOLCALL_PATTERN.findall(calls_region)
+ for tc_block in calls_found:
+ tool_name, tool_params = self._parse_single_tool_call(
+ tc_block
+ )
+ if tool_name:
+ tool_call_id = f"call_{abs(hash(tc_block))}"
+ try:
+ tool_result = (
+ await self.handle_function_or_tool_call(
+ tool_name,
+ json.dumps(tool_params),
+ tool_id=tool_call_id,
+ save_messages=False,
+ )
+ )
+
+ # Add tool result to XML
+ xml_toolcalls += (
+ f"<ToolCall>"
+ f"<Name>{tool_name}</Name>"
+ f"<Parameters>{json.dumps(tool_params)}</Parameters>"
+ f"<Result>{tool_result.llm_formatted_result}</Result>"
+ f"</ToolCall>"
+ )
+
+ except Exception as e:
+ logger.error(f"Error in tool call: {str(e)}")
+ # Add error to XML
+ xml_toolcalls += (
+ f"<ToolCall>"
+ f"<Name>{tool_name}</Name>"
+ f"<Parameters>{json.dumps(tool_params)}</Parameters>"
+ f"<Result>Error: {str(e)}</Result>"
+ f"</ToolCall>"
+ )
+
+ xml_toolcalls += "</ToolCalls>"
+ pre_action_text = content[: content.find(action_block)]
+ post_action_text = content[
+ content.find(action_block) + len(action_block) :
+ ]
+ iteration_text = pre_action_text + xml_toolcalls + post_action_text
+
+ # Create the assistant message
+ await self.conversation.add_message(
+ Message(role="assistant", content=iteration_text)
+ )
+ else:
+ # Create an assistant message with the content as-is
+ await self.conversation.add_message(
+ Message(role="assistant", content=content)
+ )
+
+ # Only mark as completed if the finish_reason is "stop" or there are no action calls
+ # This allows the agent to continue the conversation when tool calls are processed
+ if finish_reason == "stop":
+ self._completed = True
+
+ def _parse_single_tool_call(
+ self, toolcall_text: str
+ ) -> Tuple[Optional[str], dict]:
+ """
+ Parse a ToolCall block to extract the name and parameters.
+
+ Args:
+ toolcall_text: The text content of a ToolCall block
+
+ Returns:
+ Tuple of (tool_name, tool_parameters)
+ """
+ name_match = self.NAME_PATTERN.search(toolcall_text)
+ if not name_match:
+ return None, {}
+ tool_name = name_match.group(1).strip()
+
+ params_match = self.PARAMS_PATTERN.search(toolcall_text)
+ if not params_match:
+ return tool_name, {}
+
+ raw_params = params_match.group(1).strip()
+ try:
+ # Handle potential JSON parsing issues
+ # First try direct parsing
+ tool_params = json.loads(raw_params)
+ except json.JSONDecodeError:
+ # If that fails, try to clean up the JSON string
+ try:
+ # Replace escaped quotes that might cause issues
+ cleaned_params = raw_params.replace('\\"', '"')
+ # Try again with the cleaned string
+ tool_params = json.loads(cleaned_params)
+ except json.JSONDecodeError:
+ # If all else fails, treat as a plain string value
+ tool_params = {"value": raw_params}
+
+ return tool_name, tool_params
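
Taken together, `base.py` implements the core agent loop: the non-streaming `R2RAgent.arun` repeatedly calls the LLM and executes any tool calls until the model finishes with "stop" or `max_iterations` is reached, while the streaming subclasses surface the same progress as SSE events. A minimal driver sketch for the non-streaming path, assuming `agent` is an already-constructed `R2RAgent` subclass (its constructor arguments are defined elsewhere in the package):

import asyncio

from core import Message

async def run_once(agent):
    # `arun` returns the messages produced during this run (see above).
    replies = await agent.arun(
        messages=[Message(role="user", content="Which documents mention embeddings?")],
        system_instruction="You are a helpful research assistant.",
    )
    for reply in replies:
        print(reply)

# asyncio.run(run_once(agent))  # with a concrete agent instance
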
diff --git a/.venv/lib/python3.12/site-packages/core/agent/rag.py b/.venv/lib/python3.12/site-packages/core/agent/rag.py
new file mode 100644
index 00000000..6f3ab630
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/agent/rag.py
@@ -0,0 +1,662 @@
+# type: ignore
+import logging
+from typing import Any, Callable, Optional
+
+from core.base import (
+ format_search_results_for_llm,
+)
+from core.base.abstractions import (
+ AggregateSearchResult,
+ GenerationConfig,
+ SearchSettings,
+ WebPageSearchResult,
+ WebSearchResult,
+)
+from core.base.agent import Tool
+from core.base.providers import DatabaseProvider
+from core.providers import (
+ AnthropicCompletionProvider,
+ LiteLLMCompletionProvider,
+ OpenAICompletionProvider,
+ R2RCompletionProvider,
+)
+from core.utils import (
+ SearchResultsCollector,
+ generate_id,
+ num_tokens,
+)
+
+from ..base.agent.agent import RAGAgentConfig
+
+# Import the base classes from the refactored base file
+from .base import (
+ R2RAgent,
+ R2RStreamingAgent,
+ R2RXMLStreamingAgent,
+ R2RXMLToolsAgent,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class RAGAgentMixin:
+ """
+ A Mixin for adding search_file_knowledge, web_search, and content tools
+ to your R2R Agents. This allows your agent to:
+ - call knowledge_search_method (semantic/hybrid search)
+ - call content_method (fetch entire doc/chunk structures)
+ - call an external web search API
+ """
+
+ def __init__(
+ self,
+ *args,
+ search_settings: SearchSettings,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length=10_000,
+ max_context_window_tokens=512_000,
+ **kwargs,
+ ):
+ # Save references to the retrieval logic
+ self.search_settings = search_settings
+ self.knowledge_search_method = knowledge_search_method
+ self.content_method = content_method
+ self.file_search_method = file_search_method
+ self.max_tool_context_length = max_tool_context_length
+ self.max_context_window_tokens = max_context_window_tokens
+ self.search_results_collector = SearchResultsCollector()
+ super().__init__(*args, **kwargs)
+
+ def _register_tools(self):
+ """
+ Called by the base R2RAgent to register all requested tools from self.config.rag_tools.
+ """
+ if not self.config.rag_tools:
+ return
+
+ for tool_name in set(self.config.rag_tools):
+ if tool_name == "get_file_content":
+ self._tools.append(self.content())
+ elif tool_name == "web_scrape":
+ self._tools.append(self.web_scrape())
+ elif tool_name == "search_file_knowledge":
+ self._tools.append(self.search_file_knowledge())
+ elif tool_name == "search_file_descriptions":
+ self._tools.append(self.search_files())
+ elif tool_name == "web_search":
+ self._tools.append(self.web_search())
+ else:
+ raise ValueError(f"Unknown tool requested: {tool_name}")
+ logger.debug(f"Registered {len(self._tools)} RAG tools.")
+
+ # Local Search Tool
+ def search_file_knowledge(self) -> Tool:
+ """
+ Tool to do a semantic/hybrid search on the local knowledge base
+ using self.knowledge_search_method.
+ """
+ return Tool(
+ name="search_file_knowledge",
+ description=(
+ "Search your local knowledge base using the R2R system. "
+ "Use this when you want relevant text chunks or knowledge graph data."
+ ),
+ results_function=self._file_knowledge_search_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "User query to search in the local DB.",
+ },
+ },
+ "required": ["query"],
+ },
+ )
+
+ async def _file_knowledge_search_function(
+ self,
+ query: str,
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """
+ Calls the passed-in `knowledge_search_method(query, search_settings)`.
+ Expects either an AggregateSearchResult or a dict with chunk_search_results, etc.
+ """
+ if not self.knowledge_search_method:
+ raise ValueError(
+ "No knowledge_search_method provided to RAGAgentMixin."
+ )
+
+ raw_response = await self.knowledge_search_method(
+ query=query, search_settings=self.search_settings
+ )
+
+ if isinstance(raw_response, AggregateSearchResult):
+ agg = raw_response
+ else:
+ agg = AggregateSearchResult(
+ chunk_search_results=raw_response.get(
+ "chunk_search_results", []
+ ),
+ graph_search_results=raw_response.get(
+ "graph_search_results", []
+ ),
+ )
+
+ # 1) Store them so that we can do final citations later
+ self.search_results_collector.add_aggregate_result(agg)
+ return agg
+
+ # 2) Local Context
+ def content(self) -> Tool:
+ """Tool to fetch entire documents from the local database.
+
+ Typically used if the agent needs deeper or more structured context
+ from documents, not just chunk-level hits.
+ """
+ if "gemini" in self.rag_generation_config.model:
+ tool = Tool(
+ name="get_file_content",
+ description=(
+ "Fetches the complete contents of all user documents from the local database. "
+ "Can be used alongside filter criteria (e.g. doc IDs, collection IDs, etc.) to restrict the query."
+ "For instance, a single document can be returned with a filter like so:"
+ "{'document_id': {'$eq': '...'}}."
+ "Be sure to use the full 32 character hexidecimal document ID, and not the shortened 8 character ID."
+ ),
+ results_function=self._content_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "filters": {
+ "type": "string",
+ "description": (
+ "Dictionary with filter criteria, such as "
+ '{"$and": [{"document_id": {"$eq": "6c9d1c39..."}, {"collection_ids": {"$overlap": [...]}]}'
+ ),
+ },
+ },
+ "required": ["filters"],
+ },
+ )
+
+ else:
+ tool = Tool(
+ name="get_file_content",
+ description=(
+ "Fetches the complete contents of all user documents from the local database. "
+ "Can be used alongside filter criteria (e.g. doc IDs, collection IDs, etc.) to restrict the query."
+ "For instance, a single document can be returned with a filter like so:"
+ "{'document_id': {'$eq': '...'}}."
+ ),
+ results_function=self._content_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "filters": {
+ "type": "object",
+ "description": (
+ "Dictionary with filter criteria, such as "
+ '{"$and": [{"document_id": {"$eq": "6c9d1c39..."}, {"collection_ids": {"$overlap": [...]}]}'
+ ),
+ },
+ },
+ "required": ["filters"],
+ },
+ )
+ return tool
+
+ async def _content_function(
+ self,
+ filters: Optional[dict[str, Any]] = None,
+ options: Optional[dict[str, Any]] = None,
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """Calls the passed-in `content_method(filters, options)` to fetch
+ doc+chunk structures.
+
+ Typically returns a list of dicts:
+ [
+ { 'document': {...}, 'chunks': [ {...}, {...}, ... ] },
+ ...
+ ]
+ We'll store these in a new field `document_search_results` of
+ AggregateSearchResult so we don't collide with chunk_search_results.
+ """
+ if not self.content_method:
+ raise ValueError("No content_method provided to RAGAgentMixin.")
+
+ if filters:
+ if "document_id" in filters:
+ filters["id"] = filters.pop("document_id")
+ if self.search_settings.filters != {}:
+ filters = {"$and": [filters, self.search_settings.filters]}
+ else:
+ filters = self.search_settings.filters
+
+ options = options or {}
+
+ # Actually call your data retrieval
+ content = await self.content_method(filters, options)
+ # raw_context presumably is a list[dict], each with 'document' + 'chunks'.
+
+ # Return them in the new aggregator field
+ agg = AggregateSearchResult(
+ # We won't put them in chunk_search_results:
+ chunk_search_results=None,
+ graph_search_results=None,
+ web_search_results=None,
+ document_search_results=content,
+ )
+ self.search_results_collector.add_aggregate_result(agg)
+ return agg
+
+ # Web Search Tool
+ def web_search(self) -> Tool:
+ return Tool(
+ name="web_search",
+ description=(
+ "Search for information on the web - use this tool when the user "
+ "query needs LIVE or recent data from the internet."
+ ),
+ results_function=self._web_search_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "The query to search with an external web API.",
+ },
+ },
+ "required": ["query"],
+ },
+ )
+
+ async def _web_search_function(
+ self,
+ query: str,
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """
+ Calls an external search engine (Serper, Google, etc.) asynchronously
+ and returns results in an AggregateSearchResult.
+ """
+ import asyncio
+
+ from ..utils.serper import SerperClient # adjust your import
+
+ serper_client = SerperClient()
+
+ # If SerperClient.get_raw is not already async, wrap it in run_in_executor
+ raw_results = await asyncio.get_event_loop().run_in_executor(
+ None, # Uses the default executor
+ lambda: serper_client.get_raw(query),
+ )
+
+ # If from_serper_results is not already async, wrap it in run_in_executor too
+ web_response = await asyncio.get_event_loop().run_in_executor(
+ None, lambda: WebSearchResult.from_serper_results(raw_results)
+ )
+
+ agg = AggregateSearchResult(
+ chunk_search_results=None,
+ graph_search_results=None,
+ web_search_results=web_response.organic_results,
+ )
+ self.search_results_collector.add_aggregate_result(agg)
+ return agg
+
+ def search_files(self) -> Tool:
+ """
+ A tool to search over file-level metadata (titles, doc-level descriptions, etc.)
+ returning a list of DocumentResponse objects.
+ """
+ return Tool(
+ name="search_file_descriptions",
+ description=(
+ "Semantic search over the stored documents over AI generated summaries of input documents. "
+ "This does NOT retrieve chunk-level contents or knowledge-graph relationships. "
+ "Use this when you need a broad overview of which documents (files) might be relevant."
+ ),
+ results_function=self._search_files_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "Query string to semantic search over available files 'list documents about XYZ'.",
+ }
+ },
+ "required": ["query"],
+ },
+ )
+
+ async def _search_files_function(
+ self, query: str, *args, **kwargs
+ ) -> AggregateSearchResult:
+ if not self.file_search_method:
+ raise ValueError(
+ "No file_search_method provided to RAGAgentMixin."
+ )
+
+        # Call the doc-level search.
+        # FIXME: This is going to fail, as it requires an embedding, NOT a query.
+        # I've moved 'search_settings' to 'settings'; the old keyword had been causing
+        # a silent failure that produced null content in the Message object.
+ doc_results = await self.file_search_method(
+ query=query,
+ settings=self.search_settings,
+ )
+
+ # Wrap them in an AggregateSearchResult
+ agg = AggregateSearchResult(document_search_results=doc_results)
+
+ # Add them to the collector
+ self.search_results_collector.add_aggregate_result(agg)
+ return agg
+
+ def format_search_results_for_llm(
+ self, results: AggregateSearchResult
+ ) -> str:
+ context = format_search_results_for_llm(
+ results, self.search_results_collector
+ )
+ context_tokens = num_tokens(context) + 1
+ frac_to_return = self.max_tool_context_length / (context_tokens)
+
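+        # Rough truncation: assume tokens are spread evenly over characters, so the
+        # context is cut by character count in the same proportion as the token budget.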
+ if frac_to_return > 1:
+ return context
+ else:
+ return context[: int(frac_to_return * len(context))]
+
+ def web_scrape(self) -> Tool:
+ """
+ A new Tool that uses Firecrawl to scrape a single URL and return
+ its contents in an LLM-friendly format (e.g. markdown).
+ """
+ return Tool(
+ name="web_scrape",
+ description=(
+ "Use Firecrawl to scrape a single webpage and retrieve its contents "
+ "as clean markdown. Useful when you need the entire body of a page, "
+ "not just a quick snippet or standard web search result."
+ ),
+ results_function=self._web_scrape_function,
+ llm_format_function=self.format_search_results_for_llm,
+ parameters={
+ "type": "object",
+ "properties": {
+ "url": {
+ "type": "string",
+ "description": (
+ "The absolute URL of the webpage you want to scrape. "
+ "Example: 'https://docs.firecrawl.dev/getting-started'"
+ ),
+ }
+ },
+ "required": ["url"],
+ },
+ )
+
+ async def _web_scrape_function(
+ self,
+ url: str,
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """
+ Performs the Firecrawl scrape asynchronously, returning results
+ as an `AggregateSearchResult` with a single WebPageSearchResult.
+ """
+ import asyncio
+
+ from firecrawl import FirecrawlApp
+
+ app = FirecrawlApp()
+ logger.debug(f"[Firecrawl] Scraping URL={url}")
+
+ # Create a proper async wrapper for the synchronous scrape_url method
+ # This offloads the blocking operation to a thread pool
+ response = await asyncio.get_event_loop().run_in_executor(
+ None, # Uses the default executor
+ lambda: app.scrape_url(
+ url=url,
+ params={"formats": ["markdown"]},
+ ),
+ )
+
+ markdown_text = response.get("markdown", "")
+ metadata = response.get("metadata", {})
+ page_title = metadata.get("title", "Untitled page")
+
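+        # Keep the tool context bounded: cap very large pages at 100k characters.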
+ if len(markdown_text) > 100_000:
+ markdown_text = (
+ markdown_text[:100_000] + "...FURTHER CONTENT TRUNCATED..."
+ )
+
+ # Create a single WebPageSearchResult HACK - TODO FIX
+ web_result = WebPageSearchResult(
+ title=page_title,
+ link=url,
+ snippet=markdown_text,
+ position=0,
+ id=generate_id(markdown_text),
+ type="firecrawl",
+ )
+
+ agg = AggregateSearchResult(web_search_results=[web_result])
+
+ # Add results to the collector
+ if self.search_results_collector:
+ self.search_results_collector.add_aggregate_result(agg)
+
+ return agg
+
+
+class R2RRAGAgent(RAGAgentMixin, R2RAgent):
+ """
+ Non-streaming RAG Agent that supports search_file_knowledge, content, web_search.
+ """
+
+ def __init__(
+ self,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 20_000,
+ ):
+ # Initialize base R2RAgent
+ R2RAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ rag_generation_config=rag_generation_config,
+ )
+ # Initialize the RAGAgentMixin
+ RAGAgentMixin.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ file_search_method=file_search_method,
+ content_method=content_method,
+ )
+
+
+class R2RXMLToolsRAGAgent(RAGAgentMixin, R2RXMLToolsAgent):
+ """
+    Non-streaming RAG Agent that uses XML-formatted tool calls and supports search_file_knowledge, content, web_search.
+ """
+
+ def __init__(
+ self,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 20_000,
+ ):
+ # Initialize base R2RAgent
+ R2RXMLToolsAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ rag_generation_config=rag_generation_config,
+ )
+ # Initialize the RAGAgentMixin
+ RAGAgentMixin.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ file_search_method=file_search_method,
+ content_method=content_method,
+ )
+
+
+class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent):
+ """
+ Streaming-capable RAG Agent that supports search_file_knowledge, content, web_search,
+ and emits citations as [abc1234] short IDs if the LLM includes them in brackets.
+ """
+
+ def __init__(
+ self,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 10_000,
+ ):
+ # Force streaming on
+ config.stream = True
+
+ # Initialize base R2RStreamingAgent
+ R2RStreamingAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ rag_generation_config=rag_generation_config,
+ )
+
+ # Initialize the RAGAgentMixin
+ RAGAgentMixin.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+
+
+class R2RXMLToolsStreamingRAGAgent(RAGAgentMixin, R2RXMLStreamingAgent):
+ """
+ A streaming agent that:
+ - treats <think> or <Thought> blocks as chain-of-thought
+ and emits them incrementally as SSE "thinking" events.
+ - accumulates user-visible text outside those tags as SSE "message" events.
+ - filters out all XML tags related to tool calls and actions.
+ - upon finishing each iteration, it parses <Action><ToolCalls><ToolCall> blocks,
+ calls the appropriate tool, and emits SSE "tool_call" / "tool_result".
+ - properly emits citations when they appear in the text
+ """
+
+ def __init__(
+ self,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 10_000,
+ ):
+ # Force streaming on
+ config.stream = True
+
+ # Initialize base R2RXMLStreamingAgent
+ R2RXMLStreamingAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ rag_generation_config=rag_generation_config,
+ )
+
+ # Initialize the RAGAgentMixin
+ RAGAgentMixin.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/agent/research.py b/.venv/lib/python3.12/site-packages/core/agent/research.py
new file mode 100644
index 00000000..6ea35783
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/agent/research.py
@@ -0,0 +1,697 @@
+import logging
+import os
+import subprocess
+import sys
+import tempfile
+from copy import copy
+from typing import Any, Callable, Optional
+
+from core.base import AppConfig
+from core.base.abstractions import GenerationConfig, Message, SearchSettings
+from core.base.agent import Tool
+from core.base.providers import DatabaseProvider
+from core.providers import (
+ AnthropicCompletionProvider,
+ LiteLLMCompletionProvider,
+ OpenAICompletionProvider,
+ R2RCompletionProvider,
+)
+from core.utils import extract_citations
+
+from ..base.agent.agent import RAGAgentConfig # type: ignore
+
+# Import the RAG agents we'll leverage
+from .rag import ( # type: ignore
+ R2RRAGAgent,
+ R2RStreamingRAGAgent,
+ R2RXMLToolsRAGAgent,
+ R2RXMLToolsStreamingRAGAgent,
+ RAGAgentMixin,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ResearchAgentMixin(RAGAgentMixin):
+ """
+ A mixin that extends RAGAgentMixin to add research capabilities to any R2R agent.
+
+ This mixin provides all RAG capabilities plus additional research tools:
+ - A RAG tool for knowledge retrieval (which leverages the underlying RAG capabilities)
+ - A Python execution tool for code execution and computation
+ - A reasoning tool for complex problem solving
+ - A critique tool for analyzing conversation history
+ """
+
+ def __init__(
+ self,
+ *args,
+ app_config: AppConfig,
+ search_settings: SearchSettings,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length=10_000,
+ **kwargs,
+ ):
+ # Store the app configuration needed for research tools
+ self.app_config = app_config
+
+ # Call the parent RAGAgentMixin's __init__ with explicitly passed parameters
+ super().__init__(
+ *args,
+ search_settings=search_settings,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ max_tool_context_length=max_tool_context_length,
+ **kwargs,
+ )
+
+ # Register our research-specific tools
+ self._register_research_tools()
+
+ def _register_research_tools(self):
+ """
+ Register research-specific tools to the agent.
+ This is called by the mixin's __init__ after the parent class initialization.
+ """
+        # Build the research tool set; this replaces any previously registered RAG
+        # tools, since RAG capabilities are exposed through the dedicated "rag" tool.
+ research_tools = []
+ for tool_name in set(self.config.research_tools):
+ if tool_name == "rag":
+ research_tools.append(self.rag_tool())
+ elif tool_name == "reasoning":
+ research_tools.append(self.reasoning_tool())
+ elif tool_name == "critique":
+ research_tools.append(self.critique_tool())
+ elif tool_name == "python_executor":
+ research_tools.append(self.python_execution_tool())
+ else:
+ logger.warning(f"Unknown research tool: {tool_name}")
+ raise ValueError(f"Unknown research tool: {tool_name}")
+
+ logger.debug(f"Registered research tools: {research_tools}")
+ self.tools = research_tools
+
+ def rag_tool(self) -> Tool:
+ """Tool that provides access to the RAG agent's search capabilities."""
+ return Tool(
+ name="rag",
+ description=(
+ "Search for information using RAG (Retrieval-Augmented Generation). "
+ "This tool searches across relevant sources and returns comprehensive information. "
+ "Use this tool when you need to find specific information on any topic. Be sure to pose your query as a comprehensive query."
+ ),
+ results_function=self._rag,
+ llm_format_function=self._format_search_results,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "The search query to find information.",
+ }
+ },
+ "required": ["query"],
+ },
+ )
+
+ def reasoning_tool(self) -> Tool:
+ """Tool that provides access to a strong reasoning model."""
+ return Tool(
+ name="reasoning",
+ description=(
+ "A dedicated reasoning system that excels at solving complex problems through step-by-step analysis. "
+ "This tool connects to a separate AI system optimized for deep analytical thinking.\n\n"
+ "USAGE GUIDELINES:\n"
+ "1. Formulate your request as a complete, standalone question to a reasoning expert.\n"
+ "2. Clearly state the problem/question at the beginning.\n"
+ "3. Provide all relevant context, data, and constraints.\n\n"
+ "IMPORTANT: This system has no memory of previous interactions or context from your conversation.\n\n"
+ "STRENGTHS: Mathematical reasoning, logical analysis, evaluating complex scenarios, "
+ "solving multi-step problems, and identifying potential errors in reasoning."
+ ),
+ results_function=self._reason,
+ llm_format_function=self._format_search_results,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "A complete, standalone question with all necessary context, appropriate for a dedicated reasoning system.",
+ }
+ },
+ "required": ["query"],
+ },
+ )
+
+ def critique_tool(self) -> Tool:
+ """Tool that provides critical analysis of the reasoning done so far in the conversation."""
+ return Tool(
+ name="critique",
+ description=(
+ "Analyzes the conversation history to identify potential flaws, biases, and alternative "
+ "approaches to the reasoning presented so far.\n\n"
+ "Use this tool to get a second opinion on your reasoning, find overlooked considerations, "
+ "identify biases or fallacies, explore alternative hypotheses, and improve the robustness "
+ "of your conclusions."
+ ),
+ results_function=self._critique,
+ llm_format_function=self._format_search_results,
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "A specific aspect of the reasoning you want critiqued, or leave empty for a general critique.",
+ },
+ "focus_areas": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": "Optional specific areas to focus the critique (e.g., ['logical fallacies', 'methodology'])",
+ },
+ },
+ "required": ["query"],
+ },
+ )
+
+ def python_execution_tool(self) -> Tool:
+ """Tool that provides Python code execution capabilities."""
+ return Tool(
+ name="python_executor",
+ description=(
+ "Executes Python code and returns the results, output, and any errors. "
+ "Use this tool for complex calculations, statistical operations, or algorithmic implementations.\n\n"
+ "The execution environment includes common libraries such as numpy, pandas, sympy, scipy, statsmodels, biopython, etc.\n\n"
+ "USAGE:\n"
+ "1. Send complete, executable Python code as a string.\n"
+ "2. Use print statements for output you want to see.\n"
+ "3. Assign to the 'result' variable for values you want to return.\n"
+ "4. Do not use input() or plotting (matplotlib). Output is text-based."
+ ),
+ results_function=self._execute_python_with_process_timeout,
+ llm_format_function=self._format_python_results,
+ parameters={
+ "type": "object",
+ "properties": {
+ "code": {
+ "type": "string",
+ "description": "Python code to execute.",
+ }
+ },
+ "required": ["code"],
+ },
+ )
+
+ async def _rag(
+ self,
+ query: str,
+ *args,
+ **kwargs,
+    ) -> str:
+ """Execute a search using an internal RAG agent."""
+ # Create a copy of the current configuration for the RAG agent
+ config_copy = copy(self.config)
+ config_copy.max_iterations = 10 # Could be configurable
+ config_copy.rag_tools = [
+ "web_search",
+ "web_scrape",
+ ] # HACK HACK TODO - Fix.
+
+ # Create a generation config for the RAG agent
+ generation_config = GenerationConfig(
+ model=self.app_config.quality_llm,
+ max_tokens_to_sample=16000,
+ )
+
+ # Create a new RAG agent - we'll use the non-streaming variant for consistent results
+ rag_agent = R2RRAGAgent(
+ database_provider=self.database_provider,
+ llm_provider=self.llm_provider,
+ config=config_copy,
+ search_settings=self.search_settings,
+ rag_generation_config=generation_config,
+ knowledge_search_method=self.knowledge_search_method,
+ content_method=self.content_method,
+ file_search_method=self.file_search_method,
+ max_tool_context_length=self.max_tool_context_length,
+ )
+
+ # Run the RAG agent with the query
+ user_message = Message(role="user", content=query)
+ response = await rag_agent.arun(messages=[user_message])
+
+ # Get the content from the response
+ structured_content = response[-1].get("structured_content")
+ if structured_content:
+ possible_text = structured_content[-1].get("text")
+ content = response[-1].get("content") or possible_text
+ else:
+ content = response[-1].get("content")
+
+ # Extract citations and transfer search results from RAG agent to research agent
+ short_ids = extract_citations(content)
+ if short_ids:
+ logger.info(f"Found citations in RAG response: {short_ids}")
+
+ for short_id in short_ids:
+ result = rag_agent.search_results_collector.find_by_short_id(
+ short_id
+ )
+ if result:
+ self.search_results_collector.add_result(result)
+
+ # Log confirmation for successful transfer
+ logger.info(
+ "Transferred search results from RAG agent to research agent for citations"
+ )
+ return content
+
+ async def _reason(
+ self,
+ query: str,
+ *args,
+ **kwargs,
+    ) -> str:
+ """Execute a reasoning query using a specialized reasoning LLM."""
+ msg_list = await self.conversation.get_messages()
+
+ # Create a specialized generation config for reasoning
+ gen_cfg = self.get_generation_config(msg_list[-1], stream=False)
+ gen_cfg.model = self.app_config.reasoning_llm
+ gen_cfg.top_p = None
+ gen_cfg.temperature = 0.1
+ gen_cfg.max_tokens_to_sample = 64000
+ gen_cfg.stream = False
+ gen_cfg.tools = None
+ gen_cfg.functions = None
+ gen_cfg.reasoning_effort = "high"
+ gen_cfg.add_generation_kwargs = None
+
+ # Call the LLM with the reasoning request
+ response = await self.llm_provider.aget_completion(
+ [{"role": "user", "content": query}], gen_cfg
+ )
+ return response.choices[0].message.content
+
+ async def _critique(
+ self,
+ query: str,
+ focus_areas: Optional[list] = None,
+ *args,
+ **kwargs,
+    ) -> str:
+ """Critique the conversation history."""
+ msg_list = await self.conversation.get_messages()
+ if not focus_areas:
+ focus_areas = []
+ # Build the critique prompt
+ critique_prompt = (
+ "You are a critical reasoning expert. Your task is to analyze the following conversation "
+ "and critique the reasoning. Look for:\n"
+ "1. Logical fallacies or inconsistencies\n"
+ "2. Cognitive biases\n"
+ "3. Overlooked questions or considerations\n"
+ "4. Alternative approaches\n"
+ "5. Improvements in rigor\n\n"
+ )
+
+ if focus_areas:
+ critique_prompt += f"Focus areas: {', '.join(focus_areas)}\n\n"
+
+ if query.strip():
+ critique_prompt += f"Specific question: {query}\n\n"
+
+ critique_prompt += (
+ "Structure your critique:\n"
+ "1. Summary\n"
+ "2. Key strengths\n"
+ "3. Potential issues\n"
+ "4. Alternatives\n"
+ "5. Recommendations\n\n"
+ )
+
+ # Add the conversation history to the prompt
+ conversation_text = "\n--- CONVERSATION HISTORY ---\n\n"
+ for msg in msg_list:
+ role = msg.get("role", "")
+ content = msg.get("content", "")
+ if content and role in ["user", "assistant", "system"]:
+ conversation_text += f"{role.upper()}: {content}\n\n"
+
+ final_prompt = critique_prompt + conversation_text
+
+ # Use the reasoning tool to process the critique
+ return await self._reason(final_prompt, *args, **kwargs)
+
+ async def _execute_python_with_process_timeout(
+ self, code: str, timeout: int = 10, *args, **kwargs
+ ) -> dict[str, Any]:
+ """
+ Executes Python code in a separate subprocess with a timeout.
+ This provides isolation and prevents re-importing the current agent module.
+
+ Parameters:
+ code (str): Python code to execute.
+ timeout (int): Timeout in seconds (default: 10).
+
+ Returns:
+ dict[str, Any]: Dictionary containing stdout, stderr, return code, etc.
+ """
+ # Write user code to a temporary file
+ with tempfile.NamedTemporaryFile(
+ mode="w", suffix=".py", delete=False
+ ) as tmp_file:
+ tmp_file.write(code)
+ script_path = tmp_file.name
+
+ try:
+ # Run the script in a fresh subprocess
+ result = subprocess.run(
+ [sys.executable, script_path],
+ capture_output=True,
+ text=True,
+ timeout=timeout,
+ )
+
+ return {
+ "result": None, # We'll parse from stdout if needed
+ "stdout": result.stdout,
+ "stderr": result.stderr,
+ "error": (
+ None
+ if result.returncode == 0
+ else {
+ "type": "SubprocessError",
+ "message": f"Process exited with code {result.returncode}",
+ "traceback": "",
+ }
+ ),
+ "locals": {}, # No direct local var capture in a separate process
+ "success": (result.returncode == 0),
+ "timed_out": False,
+ "timeout": timeout,
+ }
+ except subprocess.TimeoutExpired as e:
+ return {
+ "result": None,
+ "stdout": e.output or "",
+ "stderr": e.stderr or "",
+ "error": {
+ "type": "TimeoutError",
+ "message": f"Execution exceeded {timeout} second limit.",
+ "traceback": "",
+ },
+ "locals": {},
+ "success": False,
+ "timed_out": True,
+ "timeout": timeout,
+ }
+ finally:
+ # Clean up the temp file
+ if os.path.exists(script_path):
+ os.remove(script_path)
+
+ def _format_python_results(self, results: dict[str, Any]) -> str:
+ """Format Python execution results for display."""
+ output = []
+
+ # Timeout notification
+ if results.get("timed_out", False):
+ output.append(
+ f"⚠️ **Execution Timeout**: Code exceeded the {results.get('timeout', 10)} second limit."
+ )
+ output.append("")
+
+ # Stdout
+ if results.get("stdout"):
+ output.append("## Output:")
+ output.append("```")
+ output.append(results["stdout"].rstrip())
+ output.append("```")
+ output.append("")
+
+ # If there's a 'result' variable to display
+ if results.get("result") is not None:
+ output.append("## Result:")
+ output.append("```")
+ output.append(str(results["result"]))
+ output.append("```")
+ output.append("")
+
+ # Error info
+ if not results.get("success", True):
+ output.append("## Error:")
+ output.append("```")
+ stderr_out = results.get("stderr", "").rstrip()
+ if stderr_out:
+ output.append(stderr_out)
+
+ err_obj = results.get("error")
+ if err_obj and err_obj.get("message"):
+ output.append(err_obj["message"])
+ output.append("```")
+
+ # Return formatted output
+ return (
+ "\n".join(output)
+ if output
+ else "Code executed with no output or result."
+ )
+
+ def _format_search_results(self, results) -> str:
+ """Simple pass-through formatting for RAG search results."""
+ return results
+
+
+class R2RResearchAgent(ResearchAgentMixin, R2RRAGAgent):
+ """
+ A non-streaming research agent that uses the standard R2R agent as its base.
+
+ This agent combines research capabilities with the non-streaming RAG agent,
+ providing tools for deep research through tool-based interaction.
+ """
+
+ def __init__(
+ self,
+ app_config: AppConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 20_000,
+ ):
+        # Default to a higher max_iterations for research tasks when none is set
+ config.max_iterations = config.max_iterations or 15
+
+ # Initialize the RAG agent first
+ R2RRAGAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ max_tool_context_length=max_tool_context_length,
+ )
+
+ # Then initialize the ResearchAgentMixin
+ ResearchAgentMixin.__init__(
+ self,
+ app_config=app_config,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ file_search_method=file_search_method,
+ content_method=content_method,
+ )
+
+
+class R2RStreamingResearchAgent(ResearchAgentMixin, R2RStreamingRAGAgent):
+ """
+ A streaming research agent that uses the streaming RAG agent as its base.
+
+ This agent combines research capabilities with streaming text generation,
+ providing real-time responses while still offering research tools.
+ """
+
+ def __init__(
+ self,
+ app_config: AppConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 10_000,
+ ):
+ # Force streaming on
+ config.stream = True
+ config.max_iterations = config.max_iterations or 15
+
+ # Initialize the streaming RAG agent first
+ R2RStreamingRAGAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ max_tool_context_length=max_tool_context_length,
+ )
+
+ # Then initialize the ResearchAgentMixin
+ ResearchAgentMixin.__init__(
+ self,
+ app_config=app_config,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+
+
+class R2RXMLToolsResearchAgent(ResearchAgentMixin, R2RXMLToolsRAGAgent):
+ """
+ A non-streaming research agent that uses XML tool formatting.
+
+ This agent combines research capabilities with the XML-based tool calling format,
+ which might be more appropriate for certain LLM providers.
+ """
+
+ def __init__(
+ self,
+ app_config: AppConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 20_000,
+ ):
+        # Default to a higher max_iterations when none is set
+ config.max_iterations = config.max_iterations or 15
+
+ # Initialize the XML Tools RAG agent first
+ R2RXMLToolsRAGAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ max_tool_context_length=max_tool_context_length,
+ )
+
+ # Then initialize the ResearchAgentMixin
+ ResearchAgentMixin.__init__(
+ self,
+            app_config=app_config,
+            database_provider=database_provider,
+            llm_provider=llm_provider,
+            config=config,
+            rag_generation_config=rag_generation_config,
+ search_settings=search_settings,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ max_tool_context_length=max_tool_context_length,
+ )
+
+
+class R2RXMLToolsStreamingResearchAgent(
+ ResearchAgentMixin, R2RXMLToolsStreamingRAGAgent
+):
+ """
+ A streaming research agent that uses XML tool formatting.
+
+ This agent combines research capabilities with streaming and XML-based tool calling,
+ providing real-time responses in a format suitable for certain LLM providers.
+ """
+
+ def __init__(
+ self,
+ app_config: AppConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ config: RAGAgentConfig,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ knowledge_search_method: Callable,
+ content_method: Callable,
+ file_search_method: Callable,
+ max_tool_context_length: int = 10_000,
+ ):
+ # Force streaming on
+ config.stream = True
+ config.max_iterations = config.max_iterations or 15
+
+ # Initialize the XML Tools Streaming RAG agent first
+ R2RXMLToolsStreamingRAGAgent.__init__(
+ self,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=config,
+ search_settings=search_settings,
+ rag_generation_config=rag_generation_config,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ max_tool_context_length=max_tool_context_length,
+ )
+
+ # Then initialize the ResearchAgentMixin
+ ResearchAgentMixin.__init__(
+ self,
+            app_config=app_config,
+            database_provider=database_provider,
+            llm_provider=llm_provider,
+            config=config,
+            rag_generation_config=rag_generation_config,
+ search_settings=search_settings,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ max_tool_context_length=max_tool_context_length,
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/base/__init__.py b/.venv/lib/python3.12/site-packages/core/base/__init__.py
new file mode 100644
index 00000000..1e872799
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/__init__.py
@@ -0,0 +1,130 @@
+from .abstractions import *
+from .agent import *
+from .api.models import *
+from .parsers import *
+from .providers import *
+from .utils import *
+
+__all__ = [
+ "ThinkingEvent",
+ "ToolCallEvent",
+ "ToolResultEvent",
+ "CitationEvent",
+ "Citation",
+ ## ABSTRACTIONS
+ # Base abstractions
+ "AsyncSyncMeta",
+ "syncable",
+ # Completion abstractions
+ "MessageType",
+ # Document abstractions
+ "Document",
+ "DocumentChunk",
+ "DocumentResponse",
+ "IngestionStatus",
+ "GraphExtractionStatus",
+ "GraphConstructionStatus",
+ "DocumentType",
+ # Embedding abstractions
+ "EmbeddingPurpose",
+ "default_embedding_prefixes",
+ # Exception abstractions
+ "R2RDocumentProcessingError",
+ "R2RException",
+ # Graph abstractions
+ "Entity",
+ "GraphExtraction",
+ "Relationship",
+ "Community",
+ "GraphCreationSettings",
+ "GraphEnrichmentSettings",
+ # LLM abstractions
+ "GenerationConfig",
+ "LLMChatCompletion",
+ "LLMChatCompletionChunk",
+ "RAGCompletion",
+ # Prompt abstractions
+ "Prompt",
+ # Search abstractions
+ "AggregateSearchResult",
+ "WebSearchResult",
+ "GraphSearchResult",
+ "GraphSearchSettings",
+ "ChunkSearchSettings",
+ "ChunkSearchResult",
+ "WebPageSearchResult",
+ "SearchSettings",
+ "select_search_filters",
+ "SearchMode",
+ "HybridSearchSettings",
+ # User abstractions
+ "Token",
+ "TokenData",
+ # Vector abstractions
+ "Vector",
+ "VectorEntry",
+ "VectorType",
+ "StorageResult",
+ "IndexConfig",
+ ## AGENT
+ # Agent abstractions
+ "Agent",
+ "AgentConfig",
+ "Conversation",
+ "Message",
+ "Tool",
+ "ToolResult",
+ ## API
+ # Auth Responses
+ "TokenResponse",
+ "User",
+ ## PARSERS
+ # Base parser
+ "AsyncParser",
+ ## PROVIDERS
+ # Base provider classes
+ "AppConfig",
+ "Provider",
+ "ProviderConfig",
+ # Auth provider
+ "AuthConfig",
+ "AuthProvider",
+ # Crypto provider
+ "CryptoConfig",
+ "CryptoProvider",
+ # Email provider
+ "EmailConfig",
+ "EmailProvider",
+ # Database providers
+ "LimitSettings",
+ "DatabaseConfig",
+ "DatabaseProvider",
+ "Handler",
+ "PostgresConfigurationSettings",
+ # Embedding provider
+ "EmbeddingConfig",
+ "EmbeddingProvider",
+ # Ingestion provider
+ "IngestionMode",
+ "IngestionConfig",
+ "IngestionProvider",
+ "ChunkingStrategy",
+ # LLM provider
+ "CompletionConfig",
+ "CompletionProvider",
+ ## UTILS
+ "RecursiveCharacterTextSplitter",
+ "TextSplitter",
+ "format_search_results_for_llm",
+ "validate_uuid",
+ # ID generation
+ "generate_id",
+ "generate_document_id",
+ "generate_extraction_id",
+ "generate_default_user_collection_id",
+ "generate_user_id",
+ "increment_version",
+ "yield_sse_event",
+ "dump_collector",
+ "dump_obj",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/abstractions/__init__.py b/.venv/lib/python3.12/site-packages/core/base/abstractions/__init__.py
new file mode 100644
index 00000000..bb1363fe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/abstractions/__init__.py
@@ -0,0 +1,154 @@
+from shared.abstractions.base import AsyncSyncMeta, R2RSerializable, syncable
+from shared.abstractions.document import (
+ ChunkEnrichmentSettings,
+ Document,
+ DocumentChunk,
+ DocumentResponse,
+ DocumentType,
+ GraphConstructionStatus,
+ GraphExtractionStatus,
+ IngestionStatus,
+ RawChunk,
+ UnprocessedChunk,
+ UpdateChunk,
+)
+from shared.abstractions.embedding import (
+ EmbeddingPurpose,
+ default_embedding_prefixes,
+)
+from shared.abstractions.exception import (
+ R2RDocumentProcessingError,
+ R2RException,
+)
+from shared.abstractions.graph import (
+ Community,
+ Entity,
+ Graph,
+ GraphCommunitySettings,
+ GraphCreationSettings,
+ GraphEnrichmentSettings,
+ GraphExtraction,
+ Relationship,
+ StoreType,
+)
+from shared.abstractions.llm import (
+ GenerationConfig,
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+ Message,
+ MessageType,
+ RAGCompletion,
+)
+from shared.abstractions.prompt import Prompt
+from shared.abstractions.search import (
+ AggregateSearchResult,
+ ChunkSearchResult,
+ ChunkSearchSettings,
+ GraphCommunityResult,
+ GraphEntityResult,
+ GraphRelationshipResult,
+ GraphSearchResult,
+ GraphSearchResultType,
+ GraphSearchSettings,
+ HybridSearchSettings,
+ SearchMode,
+ SearchSettings,
+ WebPageSearchResult,
+ WebSearchResult,
+ select_search_filters,
+)
+from shared.abstractions.user import Token, TokenData, User
+from shared.abstractions.vector import (
+ IndexArgsHNSW,
+ IndexArgsIVFFlat,
+ IndexConfig,
+ IndexMeasure,
+ IndexMethod,
+ StorageResult,
+ Vector,
+ VectorEntry,
+ VectorQuantizationSettings,
+ VectorQuantizationType,
+ VectorTableName,
+ VectorType,
+)
+
+__all__ = [
+ # Base abstractions
+ "R2RSerializable",
+ "AsyncSyncMeta",
+ "syncable",
+ # Completion abstractions
+ "MessageType",
+ # Document abstractions
+ "Document",
+ "DocumentChunk",
+ "DocumentResponse",
+ "DocumentType",
+ "IngestionStatus",
+ "GraphExtractionStatus",
+ "GraphConstructionStatus",
+ "RawChunk",
+ "UnprocessedChunk",
+ "UpdateChunk",
+ # Embedding abstractions
+ "EmbeddingPurpose",
+ "default_embedding_prefixes",
+ # Exception abstractions
+ "R2RDocumentProcessingError",
+ "R2RException",
+ # Graph abstractions
+ "Entity",
+ "Graph",
+ "Community",
+ "StoreType",
+ "GraphExtraction",
+ "Relationship",
+ # Index abstractions
+ "IndexConfig",
+ # LLM abstractions
+ "GenerationConfig",
+ "LLMChatCompletion",
+ "LLMChatCompletionChunk",
+ "Message",
+ "RAGCompletion",
+ # Prompt abstractions
+ "Prompt",
+ # Search abstractions
+ "WebSearchResult",
+ "AggregateSearchResult",
+ "GraphSearchResult",
+ "GraphSearchResultType",
+ "GraphEntityResult",
+ "GraphRelationshipResult",
+ "GraphCommunityResult",
+ "GraphSearchSettings",
+ "ChunkSearchSettings",
+ "ChunkSearchResult",
+ "WebPageSearchResult",
+ "SearchSettings",
+ "select_search_filters",
+ "SearchMode",
+ "HybridSearchSettings",
+ # Graph abstractions
+ "GraphCreationSettings",
+ "GraphEnrichmentSettings",
+ "GraphCommunitySettings",
+ # User abstractions
+ "Token",
+ "TokenData",
+ "User",
+ # Vector abstractions
+ "Vector",
+ "VectorEntry",
+ "VectorType",
+ "IndexMeasure",
+ "IndexMethod",
+ "VectorTableName",
+ "IndexArgsHNSW",
+ "IndexArgsIVFFlat",
+ "VectorQuantizationSettings",
+ "VectorQuantizationType",
+ "StorageResult",
+ "ChunkEnrichmentSettings",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py b/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py
new file mode 100644
index 00000000..815b9ae7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py
@@ -0,0 +1,17 @@
+# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
+from .agent import ( # type: ignore
+ Agent,
+ AgentConfig,
+ Conversation,
+ Tool,
+ ToolResult,
+)
+
+__all__ = [
+ # Agent abstractions
+ "Agent",
+ "AgentConfig",
+ "Conversation",
+ "Tool",
+ "ToolResult",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/agent.py b/.venv/lib/python3.12/site-packages/core/base/agent/agent.py
new file mode 100644
index 00000000..6813dd21
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/agent/agent.py
@@ -0,0 +1,291 @@
+# type: ignore
+import asyncio
+import json
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from json import JSONDecodeError
+from typing import Any, AsyncGenerator, Optional, Type
+
+from pydantic import BaseModel
+
+from core.base.abstractions import (
+ GenerationConfig,
+ LLMChatCompletion,
+ Message,
+)
+from core.base.providers import CompletionProvider, DatabaseProvider
+
+from .base import Tool, ToolResult
+
+logger = logging.getLogger()
+
+
+class Conversation:
+ def __init__(self):
+ self.messages: list[Message] = []
+ self._lock = asyncio.Lock()
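+        # The lock serializes concurrent appends and reads across coroutines.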
+
+ async def add_message(self, message):
+ async with self._lock:
+ self.messages.append(message)
+
+ async def get_messages(self) -> list[dict[str, Any]]:
+ async with self._lock:
+ return [
+ {**msg.model_dump(exclude_none=True), "role": str(msg.role)}
+ for msg in self.messages
+ ]
+
+
+# TODO - Move agents to provider pattern
+class AgentConfig(BaseModel):
+ rag_rag_agent_static_prompt: str = "static_rag_agent"
+ rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
+ stream: bool = False
+ include_tools: bool = True
+ max_iterations: int = 10
+
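+    # Illustrative usage (assumed): AgentConfig.create(stream=True, unknown_key=1)
+    # keeps only declared fields and maps the literal string "None" to None.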
+ @classmethod
+ def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
+ base_args = cls.model_fields.keys()
+ filtered_kwargs = {
+ k: v if v != "None" else None
+ for k, v in kwargs.items()
+ if k in base_args
+ }
+ return cls(**filtered_kwargs) # type: ignore
+
+
+class Agent(ABC):
+ def __init__(
+ self,
+ llm_provider: CompletionProvider,
+ database_provider: DatabaseProvider,
+ config: AgentConfig,
+ rag_generation_config: GenerationConfig,
+ ):
+ self.llm_provider = llm_provider
+ self.database_provider: DatabaseProvider = database_provider
+ self.config = config
+ self.conversation = Conversation()
+ self._completed = False
+ self._tools: list[Tool] = []
+ self.tool_calls: list[dict] = []
+ self.rag_generation_config = rag_generation_config
+ self._register_tools()
+
+ @abstractmethod
+ def _register_tools(self):
+ pass
+
+ async def _setup(
+ self, system_instruction: Optional[str] = None, *args, **kwargs
+ ):
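+        # Seed the conversation with a system message: either the explicit
+        # instruction or the cached static RAG prompt with today's date filled in.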
+ await self.conversation.add_message(
+ Message(
+ role="system",
+ content=system_instruction
+ or (
+ await self.database_provider.prompts_handler.get_cached_prompt(
+ self.config.rag_rag_agent_static_prompt,
+ inputs={
+ "date": str(datetime.now().strftime("%m/%d/%Y"))
+ },
+ )
+ + f"\n Note,you only have {self.config.max_iterations} iterations or tool calls to reach a conclusion before your operation terminates."
+ ),
+ )
+ )
+
+ @property
+ def tools(self) -> list[Tool]:
+ return self._tools
+
+ @tools.setter
+ def tools(self, tools: list[Tool]):
+ self._tools = tools
+
+ @abstractmethod
+ async def arun(
+ self,
+ system_instruction: Optional[str] = None,
+ messages: Optional[list[Message]] = None,
+ *args,
+ **kwargs,
+ ) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]:
+ pass
+
+ @abstractmethod
+ async def process_llm_response(
+ self,
+ response: Any,
+ *args,
+ **kwargs,
+ ) -> None | AsyncGenerator[str, None]:
+ pass
+
+ async def execute_tool(self, tool_name: str, *args, **kwargs) -> str:
+ if tool := next((t for t in self.tools if t.name == tool_name), None):
+ return await tool.results_function(*args, **kwargs)
+ else:
+ return f"Error: Tool {tool_name} not found."
+
+ def get_generation_config(
+ self, last_message: dict, stream: bool = False
+ ) -> GenerationConfig:
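+        # If the last message is a tool/function result on an Ollama model, or tools
+        # are disabled entirely, return a config without tool schemas attached.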
+ if (
+ last_message["role"] in ["tool", "function"]
+ and last_message["content"] != ""
+ and "ollama" in self.rag_generation_config.model
+ or not self.config.include_tools
+ ):
+ return GenerationConfig(
+ **self.rag_generation_config.model_dump(
+ exclude={"functions", "tools", "stream"}
+ ),
+ stream=stream,
+ )
+
+ return GenerationConfig(
+ **self.rag_generation_config.model_dump(
+ exclude={"functions", "tools", "stream"}
+ ),
+ # FIXME: Use tools instead of functions
+ # TODO - Investigate why `tools` fails with OpenAI+LiteLLM
+ tools=(
+ [
+ {
+ "function": {
+ "name": tool.name,
+ "description": tool.description,
+ "parameters": tool.parameters,
+ },
+ "type": "function",
+ "name": tool.name,
+ }
+ for tool in self.tools
+ ]
+ if self.tools
+ else None
+ ),
+ stream=stream,
+ )
+
+ async def handle_function_or_tool_call(
+ self,
+ function_name: str,
+ function_arguments: str,
+ tool_id: Optional[str] = None,
+ save_messages: bool = True,
+ *args,
+ **kwargs,
+ ) -> ToolResult:
+ logger.debug(
+ f"Calling function: {function_name}, args: {function_arguments}, tool_id: {tool_id}"
+ )
+ if tool := next(
+ (t for t in self.tools if t.name == function_name), None
+ ):
+ try:
+ function_args = json.loads(function_arguments)
+
+ except JSONDecodeError as e:
+ error_message = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with `JSONDecodeError`."
+ if save_messages:
+ await self.conversation.add_message(
+ Message(
+ role="tool" if tool_id else "function",
+ content=error_message,
+ name=function_name,
+ tool_call_id=tool_id,
+ )
+ )
+
+ # raise R2RException(
+ # message=f"Error parsing function arguments: {e}, agent likely produced invalid tool inputs.",
+ # status_code=400,
+                # )
+
+                # Fall back to empty arguments so the tool call below does not hit a
+                # NameError from the unbound `function_args`.
+                function_args = {}
+
+ merged_kwargs = {**kwargs, **function_args}
+ try:
+ raw_result = await tool.results_function(
+ *args, **merged_kwargs
+ )
+ llm_formatted_result = tool.llm_format_function(raw_result)
+ except Exception as e:
+ raw_result = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with an exception: {e}."
+ logger.error(raw_result)
+ llm_formatted_result = raw_result
+
+ tool_result = ToolResult(
+ raw_result=raw_result,
+ llm_formatted_result=llm_formatted_result,
+ )
+ if tool.stream_function:
+ tool_result.stream_result = tool.stream_function(raw_result)
+
+ if save_messages:
+ await self.conversation.add_message(
+ Message(
+ role="tool" if tool_id else "function",
+ content=str(tool_result.llm_formatted_result),
+ name=function_name,
+ tool_call_id=tool_id,
+ )
+ )
+ # HACK - to fix issues with claude thinking + tool use [https://github.com/anthropics/anthropic-cookbook/blob/main/extended_thinking/extended_thinking_with_tool_use.ipynb]
+ if self.rag_generation_config.extended_thinking:
+ await self.conversation.add_message(
+ Message(
+ role="user",
+ content="Continue...",
+ )
+ )
+
+ self.tool_calls.append(
+ {
+ "name": function_name,
+ "args": function_arguments,
+ }
+ )
+ return tool_result
+
+
+# TODO - Move agents to provider pattern
+class RAGAgentConfig(AgentConfig):
+ rag_rag_agent_static_prompt: str = "static_rag_agent"
+ rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
+ stream: bool = False
+ include_tools: bool = True
+ max_iterations: int = 10
+ # tools: list[str] = [] # HACK - unused variable.
+
+ # Default RAG tools
+ rag_tools: list[str] = [
+ "search_file_descriptions",
+ "search_file_knowledge",
+ "get_file_content",
+ ]
+
+ # Default Research tools
+ research_tools: list[str] = [
+ "rag",
+ "reasoning",
+ # DISABLED by default
+ "critique",
+ "python_executor",
+ ]
+
+ @classmethod
+ def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
+ base_args = cls.model_fields.keys()
+ filtered_kwargs = {
+ k: v if v != "None" else None
+ for k, v in kwargs.items()
+ if k in base_args
+ }
+ filtered_kwargs["tools"] = kwargs.get("tools", None) or kwargs.get(
+ "tool_names", None
+ )
+ return cls(**filtered_kwargs) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/base.py b/.venv/lib/python3.12/site-packages/core/base/agent/base.py
new file mode 100644
index 00000000..0d8f15ee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/agent/base.py
@@ -0,0 +1,22 @@
+from typing import Any, Callable, Optional
+
+from ..abstractions import R2RSerializable
+
+
+class Tool(R2RSerializable):
+ name: str
+ description: str
+ results_function: Callable
+ llm_format_function: Callable
+ stream_function: Optional[Callable] = None
+ parameters: Optional[dict[str, Any]] = None
+
+ class Config:
+ populate_by_name = True
+ arbitrary_types_allowed = True
+
+
+class ToolResult(R2RSerializable):
+ raw_result: Any
+ llm_formatted_result: str
+ stream_result: Optional[str] = None
diff --git a/.venv/lib/python3.12/site-packages/core/base/api/models/__init__.py b/.venv/lib/python3.12/site-packages/core/base/api/models/__init__.py
new file mode 100644
index 00000000..dc0b041f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/api/models/__init__.py
@@ -0,0 +1,208 @@
+from shared.api.models.auth.responses import (
+ TokenResponse,
+ WrappedTokenResponse,
+)
+from shared.api.models.base import (
+ GenericBooleanResponse,
+ GenericMessageResponse,
+ PaginatedR2RResult,
+ R2RResults,
+ WrappedBooleanResponse,
+ WrappedGenericMessageResponse,
+)
+from shared.api.models.graph.responses import ( # TODO: Need to review anything above this
+ Community,
+ Entity,
+ GraphResponse,
+ Relationship,
+ WrappedCommunitiesResponse,
+ WrappedCommunityResponse,
+ WrappedEntitiesResponse,
+ WrappedEntityResponse,
+ WrappedGraphResponse,
+ WrappedGraphsResponse,
+ WrappedRelationshipResponse,
+ WrappedRelationshipsResponse,
+)
+from shared.api.models.ingestion.responses import (
+ IngestionResponse,
+ UpdateResponse,
+ VectorIndexResponse,
+ VectorIndicesResponse,
+ WrappedIngestionResponse,
+ WrappedMetadataUpdateResponse,
+ WrappedUpdateResponse,
+ WrappedVectorIndexResponse,
+ WrappedVectorIndicesResponse,
+)
+from shared.api.models.management.responses import ( # Document Responses; Prompt Responses; Chunk Responses; Conversation Responses; User Responses; TODO: anything below this hasn't been reviewed
+ ChunkResponse,
+ CollectionResponse,
+ ConversationResponse,
+ MessageResponse,
+ PromptResponse,
+ ServerStats,
+ SettingsResponse,
+ User,
+ WrappedAPIKeyResponse,
+ WrappedAPIKeysResponse,
+ WrappedChunkResponse,
+ WrappedChunksResponse,
+ WrappedCollectionResponse,
+ WrappedCollectionsResponse,
+ WrappedConversationMessagesResponse,
+ WrappedConversationResponse,
+ WrappedConversationsResponse,
+ WrappedDocumentResponse,
+ WrappedDocumentsResponse,
+ WrappedLimitsResponse,
+ WrappedLoginResponse,
+ WrappedMessageResponse,
+ WrappedMessagesResponse,
+ WrappedPromptResponse,
+ WrappedPromptsResponse,
+ WrappedServerStatsResponse,
+ WrappedSettingsResponse,
+ WrappedUserResponse,
+ WrappedUsersResponse,
+)
+from shared.api.models.retrieval.responses import (
+ AgentEvent,
+ AgentResponse,
+ Citation,
+ CitationData,
+ CitationEvent,
+ Delta,
+ DeltaPayload,
+ FinalAnswerData,
+ FinalAnswerEvent,
+ MessageData,
+ MessageDelta,
+ MessageEvent,
+ RAGEvent,
+ RAGResponse,
+ SearchResultsData,
+ SearchResultsEvent,
+ SSEEventBase,
+ ThinkingData,
+ ThinkingEvent,
+ ToolCallData,
+ ToolCallEvent,
+ ToolResultData,
+ ToolResultEvent,
+ UnknownEvent,
+ WrappedAgentResponse,
+ WrappedCompletionResponse,
+ WrappedDocumentSearchResponse,
+ WrappedEmbeddingResponse,
+ WrappedLLMChatCompletion,
+ WrappedRAGResponse,
+ WrappedSearchResponse,
+ WrappedVectorSearchResponse,
+)
+
+__all__ = [
+ # Auth Responses
+ "TokenResponse",
+ "WrappedTokenResponse",
+ "WrappedGenericMessageResponse",
+ # Ingestion Responses
+ "IngestionResponse",
+ "WrappedIngestionResponse",
+ "WrappedUpdateResponse",
+ "WrappedMetadataUpdateResponse",
+ "WrappedVectorIndexResponse",
+ "WrappedVectorIndicesResponse",
+ "UpdateResponse",
+ "VectorIndexResponse",
+ "VectorIndicesResponse",
+ # Knowledge Graph Responses
+ "Entity",
+ "Relationship",
+ "Community",
+ "WrappedEntityResponse",
+ "WrappedEntitiesResponse",
+ "WrappedRelationshipResponse",
+ "WrappedRelationshipsResponse",
+ "WrappedCommunityResponse",
+ "WrappedCommunitiesResponse",
+ # TODO: Need to review anything above this
+ "GraphResponse",
+ "WrappedGraphResponse",
+ "WrappedGraphsResponse",
+ # Management Responses
+ "PromptResponse",
+ "ServerStats",
+ "SettingsResponse",
+ "ChunkResponse",
+ "CollectionResponse",
+ "WrappedServerStatsResponse",
+ "WrappedSettingsResponse",
+ "WrappedDocumentResponse",
+ "WrappedDocumentsResponse",
+ "WrappedCollectionResponse",
+ "WrappedCollectionsResponse",
+ # Conversation Responses
+ "ConversationResponse",
+ "WrappedConversationMessagesResponse",
+ "WrappedConversationResponse",
+ "WrappedConversationsResponse",
+ # Prompt Responses
+ "WrappedPromptResponse",
+ "WrappedPromptsResponse",
+ # Conversation Responses
+ "MessageResponse",
+ "WrappedMessageResponse",
+ "WrappedMessagesResponse",
+ # Chunk Responses
+ "WrappedChunkResponse",
+ "WrappedChunksResponse",
+ # User Responses
+ "User",
+ "WrappedUserResponse",
+ "WrappedUsersResponse",
+ "WrappedAPIKeyResponse",
+ "WrappedLimitsResponse",
+ "WrappedAPIKeysResponse",
+ "WrappedLoginResponse",
+ # Base Responses
+ "PaginatedR2RResult",
+ "R2RResults",
+ "GenericBooleanResponse",
+ "GenericMessageResponse",
+ "WrappedBooleanResponse",
+ "WrappedGenericMessageResponse",
+ # Retrieval Responses
+ "SSEEventBase",
+ "SearchResultsData",
+ "SearchResultsEvent",
+ "MessageDelta",
+ "MessageData",
+ "MessageEvent",
+ "DeltaPayload",
+ "Delta",
+ "CitationData",
+ "CitationEvent",
+ "FinalAnswerData",
+ "FinalAnswerEvent",
+ "ToolCallData",
+ "ToolCallEvent",
+ "ToolResultData",
+ "ToolResultEvent",
+ "ThinkingData",
+ "ThinkingEvent",
+ "RAGEvent",
+ "AgentEvent",
+ "UnknownEvent",
+ "RAGResponse",
+ "Citation",
+ "AgentResponse",
+ "WrappedDocumentSearchResponse",
+ "WrappedSearchResponse",
+ "WrappedVectorSearchResponse",
+ "WrappedCompletionResponse",
+ "WrappedRAGResponse",
+ "WrappedAgentResponse",
+ "WrappedLLMChatCompletion",
+ "WrappedEmbeddingResponse",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/parsers/__init__.py b/.venv/lib/python3.12/site-packages/core/base/parsers/__init__.py
new file mode 100644
index 00000000..d7696202
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/parsers/__init__.py
@@ -0,0 +1,5 @@
+from .base_parser import AsyncParser
+
+__all__ = [
+ "AsyncParser",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/parsers/base_parser.py b/.venv/lib/python3.12/site-packages/core/base/parsers/base_parser.py
new file mode 100644
index 00000000..fb40d767
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/parsers/base_parser.py
@@ -0,0 +1,12 @@
+"""Abstract base class for parsers."""
+
+from abc import ABC, abstractmethod
+from typing import AsyncGenerator, Generic, TypeVar
+
+T = TypeVar("T")
+
+
+class AsyncParser(ABC, Generic[T]):
+ @abstractmethod
+ async def ingest(self, data: T, **kwargs) -> AsyncGenerator[str, None]:
+ pass
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py b/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py
new file mode 100644
index 00000000..b902944d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py
@@ -0,0 +1,59 @@
+from .auth import AuthConfig, AuthProvider
+from .base import AppConfig, Provider, ProviderConfig
+from .crypto import CryptoConfig, CryptoProvider
+from .database import (
+ DatabaseConfig,
+ DatabaseConnectionManager,
+ DatabaseProvider,
+ Handler,
+ LimitSettings,
+ PostgresConfigurationSettings,
+)
+from .email import EmailConfig, EmailProvider
+from .embedding import EmbeddingConfig, EmbeddingProvider
+from .ingestion import (
+ ChunkingStrategy,
+ IngestionConfig,
+ IngestionMode,
+ IngestionProvider,
+)
+from .llm import CompletionConfig, CompletionProvider
+from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow
+
+__all__ = [
+ # Auth provider
+ "AuthConfig",
+ "AuthProvider",
+ # Base provider classes
+ "AppConfig",
+ "Provider",
+ "ProviderConfig",
+ # Ingestion provider
+ "IngestionMode",
+ "IngestionConfig",
+ "IngestionProvider",
+ "ChunkingStrategy",
+ # Crypto provider
+ "CryptoConfig",
+ "CryptoProvider",
+ # Email provider
+ "EmailConfig",
+ "EmailProvider",
+ # Database providers
+ "DatabaseConnectionManager",
+ "DatabaseConfig",
+ "LimitSettings",
+ "PostgresConfigurationSettings",
+ "DatabaseProvider",
+ "Handler",
+ # Embedding provider
+ "EmbeddingConfig",
+ "EmbeddingProvider",
+ # LLM provider
+ "CompletionConfig",
+ "CompletionProvider",
+ # Orchestration provider
+ "OrchestrationConfig",
+ "OrchestrationProvider",
+ "Workflow",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/auth.py b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
new file mode 100644
index 00000000..352c3331
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
@@ -0,0 +1,231 @@
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import TYPE_CHECKING, Optional
+
+from fastapi import Security
+from fastapi.security import (
+ APIKeyHeader,
+ HTTPAuthorizationCredentials,
+ HTTPBearer,
+)
+
+from ..abstractions import R2RException, Token, TokenData
+from ..api.models import User
+from .base import Provider, ProviderConfig
+from .crypto import CryptoProvider
+from .email import EmailProvider
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+ from core.providers.database import PostgresDatabaseProvider
+
+api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
+
+
+class AuthConfig(ProviderConfig):
+ secret_key: Optional[str] = None
+ require_authentication: bool = False
+ require_email_verification: bool = False
+ default_admin_email: str = "admin@example.com"
+ default_admin_password: str = "change_me_immediately"
+ access_token_lifetime_in_minutes: Optional[int] = None
+ refresh_token_lifetime_in_days: Optional[int] = None
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["r2r"]
+
+ def validate_config(self) -> None:
+ pass
+
+
+class AuthProvider(Provider, ABC):
+ security = HTTPBearer(auto_error=False)
+ crypto_provider: CryptoProvider
+ email_provider: EmailProvider
+ database_provider: "PostgresDatabaseProvider"
+
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: "PostgresDatabaseProvider",
+ email_provider: EmailProvider,
+ ):
+ if not isinstance(config, AuthConfig):
+ raise ValueError(
+ "AuthProvider must be initialized with an AuthConfig"
+ )
+ self.config = config
+ self.admin_email = config.default_admin_email
+ self.admin_password = config.default_admin_password
+ self.crypto_provider = crypto_provider
+ self.database_provider = database_provider
+ self.email_provider = email_provider
+ super().__init__(config)
+ self.config: AuthConfig = config
+ self.database_provider: "PostgresDatabaseProvider" = database_provider
+
+ async def _get_default_admin_user(self) -> User:
+ return await self.database_provider.users_handler.get_user_by_email(
+ self.admin_email
+ )
+
+ @abstractmethod
+ def create_access_token(self, data: dict) -> str:
+ pass
+
+ @abstractmethod
+ def create_refresh_token(self, data: dict) -> str:
+ pass
+
+ @abstractmethod
+ async def decode_token(self, token: str) -> TokenData:
+ pass
+
+ @abstractmethod
+ async def user(self, token: str) -> User:
+ pass
+
+ @abstractmethod
+ def get_current_active_user(self, current_user: User) -> User:
+ pass
+
+ @abstractmethod
+ async def register(self, email: str, password: str) -> User:
+ pass
+
+ @abstractmethod
+ async def send_verification_email(
+ self, email: str, user: Optional[User] = None
+ ) -> tuple[str, datetime]:
+ pass
+
+ @abstractmethod
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ pass
+
+ @abstractmethod
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ pass
+
+ def auth_wrapper(
+ self,
+ public: bool = False,
+ ):
+ async def _auth_wrapper(
+ auth: Optional[HTTPAuthorizationCredentials] = Security(
+ self.security
+ ),
+ api_key: Optional[str] = Security(api_key_header),
+ ) -> User:
+ # If authentication is not required and no credentials are provided, return the default admin user
+ if (
+ ((not self.config.require_authentication) or public)
+ and auth is None
+ and api_key is None
+ ):
+ return await self._get_default_admin_user()
+ if not auth and not api_key:
+ raise R2RException(
+ message="No credentials provided. Create an account at https://app.sciphi.ai and set your API key using `r2r configure key` OR change your base URL to a custom deployment.",
+ status_code=401,
+ )
+ if auth and api_key:
+ raise R2RException(
+ message="Cannot have both Bearer token and API key",
+ status_code=400,
+ )
+ # 1. Try JWT if `auth` is present (Bearer token)
+ if auth is not None:
+ credentials = auth.credentials
+ try:
+ token_data = await self.decode_token(credentials)
+ user = await self.database_provider.users_handler.get_user_by_email(
+ token_data.email
+ )
+ if user is not None:
+ return user
+ except R2RException:
+ # JWT decoding failed for logical reasons (invalid token)
+ pass
+ except Exception as e:
+ # JWT decoding failed unexpectedly, log and continue
+ logger.debug(f"JWT verification failed: {e}")
+
+ # 2. If JWT failed, try API key from Bearer token
+ # Expected format: key_id.raw_api_key
+ if "." in credentials:
+ key_id, raw_api_key = credentials.split(".", 1)
+ api_key_record = await self.database_provider.users_handler.get_api_key_record(
+ key_id
+ )
+ if api_key_record is not None:
+ hashed_key = api_key_record["hashed_key"]
+ if self.crypto_provider.verify_api_key(
+ raw_api_key, hashed_key
+ ):
+ user = await self.database_provider.users_handler.get_user_by_id(
+ api_key_record["user_id"]
+ )
+ if user is not None and user.is_active:
+ return user
+
+ # 3. If no Bearer token worked, try the X-API-Key header
+ if api_key is not None and "." in api_key:
+ key_id, raw_api_key = api_key.split(".", 1)
+ api_key_record = await self.database_provider.users_handler.get_api_key_record(
+ key_id
+ )
+ if api_key_record is not None:
+ hashed_key = api_key_record["hashed_key"]
+ if self.crypto_provider.verify_api_key(
+ raw_api_key, hashed_key
+ ):
+ user = await self.database_provider.users_handler.get_user_by_id(
+ api_key_record["user_id"]
+ )
+ if user is not None and user.is_active:
+ return user
+
+ # If we reach here, both JWT and API key auth failed
+ raise R2RException(
+ message="Invalid token or API key",
+ status_code=401,
+ )
+
+ return _auth_wrapper
+
+ @abstractmethod
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def logout(self, token: str) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def send_reset_email(self, email: str) -> dict[str, str]:
+ pass
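+
+
+# Illustrative wiring sketch (not part of the original module): `auth_wrapper()`
+# returns an async callable intended for `fastapi.Depends`. The `auth_provider`
+# instance below stands for any concrete AuthProvider subclass and is
+# hypothetical.
+#
+#   from fastapi import APIRouter, Depends
+#
+#   router = APIRouter()
+#
+#   @router.get("/whoami")
+#   async def whoami(user: User = Depends(auth_provider.auth_wrapper())):
+#       # `user` is resolved from a Bearer JWT, a Bearer "key_id.raw_api_key"
+#       # credential, or the X-API-Key header, falling back to the default
+#       # admin user when authentication is not required.
+#       return user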
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/base.py b/.venv/lib/python3.12/site-packages/core/base/providers/base.py
new file mode 100644
index 00000000..3f72a5ea
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/base.py
@@ -0,0 +1,135 @@
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Type
+
+from pydantic import BaseModel
+
+
+class InnerConfig(BaseModel, ABC):
+ """A base provider configuration class."""
+
+ extra_fields: dict[str, Any] = {}
+
+ class Config:
+ populate_by_name = True
+ arbitrary_types_allowed = True
+ ignore_extra = True
+
+ @classmethod
+ def create(cls: Type["InnerConfig"], **kwargs: Any) -> "InnerConfig":
+ base_args = cls.model_fields.keys()
+ filtered_kwargs = {
+ k: v if v != "None" else None
+ for k, v in kwargs.items()
+ if k in base_args
+ }
+ instance = cls(**filtered_kwargs) # type: ignore
+ for k, v in kwargs.items():
+ if k not in base_args:
+ instance.extra_fields[k] = v
+ return instance
+
+
+class AppConfig(InnerConfig):
+ project_name: Optional[str] = None
+ default_max_documents_per_user: Optional[int] = 100
+ default_max_chunks_per_user: Optional[int] = 10_000
+ default_max_collections_per_user: Optional[int] = 5
+ default_max_upload_size: int = 2_000_000 # e.g. ~2 MB
+ quality_llm: Optional[str] = None
+ fast_llm: Optional[str] = None
+ vlm: Optional[str] = None
+ audio_lm: Optional[str] = None
+ reasoning_llm: Optional[str] = None
+ planning_llm: Optional[str] = None
+
+ # File extension to max-size mapping
+ # These are examples; adjust sizes as needed.
+ max_upload_size_by_type: dict[str, int] = {
+ # Common text-based formats
+ "txt": 2_000_000,
+ "md": 2_000_000,
+ "tsv": 2_000_000,
+ "csv": 5_000_000,
+ "xml": 2_000_000,
+ "html": 5_000_000,
+ # Office docs
+ "doc": 10_000_000,
+ "docx": 10_000_000,
+ "ppt": 20_000_000,
+ "pptx": 20_000_000,
+ "xls": 10_000_000,
+ "xlsx": 10_000_000,
+ "odt": 5_000_000,
+ # PDFs can expand quite a bit when converted to text
+ "pdf": 30_000_000,
+ # E-mail
+ "eml": 5_000_000,
+ "msg": 5_000_000,
+ "p7s": 5_000_000,
+ # Images
+ "bmp": 5_000_000,
+ "heic": 5_000_000,
+ "jpeg": 5_000_000,
+ "jpg": 5_000_000,
+ "png": 5_000_000,
+ "tiff": 5_000_000,
+ # Others
+ "epub": 10_000_000,
+ "rtf": 5_000_000,
+ "rst": 5_000_000,
+ "org": 5_000_000,
+ }
+
+
+class ProviderConfig(BaseModel, ABC):
+ """A base provider configuration class."""
+
+    app: AppConfig  # Shared application-level configuration
+ extra_fields: dict[str, Any] = {}
+ provider: Optional[str] = None
+
+ class Config:
+ populate_by_name = True
+ arbitrary_types_allowed = True
+ ignore_extra = True
+
+ @abstractmethod
+ def validate_config(self) -> None:
+ pass
+
+ @classmethod
+ def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
+ base_args = cls.model_fields.keys()
+ filtered_kwargs = {
+ k: v if v != "None" else None
+ for k, v in kwargs.items()
+ if k in base_args
+ }
+ instance = cls(**filtered_kwargs) # type: ignore
+ for k, v in kwargs.items():
+ if k not in base_args:
+ instance.extra_fields[k] = v
+ return instance
+
+ @property
+ @abstractmethod
+ def supported_providers(self) -> list[str]:
+ """Define a list of supported providers."""
+ pass
+
+ @classmethod
+ def from_dict(
+ cls: Type["ProviderConfig"], data: dict[str, Any]
+ ) -> "ProviderConfig":
+ """Create a new instance of the config from a dictionary."""
+ return cls.create(**data)
+
+
+class Provider(ABC):
+ """A base provider class to provide a common interface for all
+ providers."""
+
+ def __init__(self, config: ProviderConfig, *args, **kwargs):
+ if config:
+ config.validate_config()
+ self.config = config
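+
+
+# Illustrative usage sketch (not part of the original module): `create` keeps
+# only declared model fields, converts the literal string "None" into None, and
+# stashes everything else in `extra_fields`. The `custom_dashboard_url` key is
+# made up for this example.
+if __name__ == "__main__":
+    app = AppConfig.create(
+        project_name="demo",
+        quality_llm="None",  # coerced to None
+        custom_dashboard_url="http://localhost:3000",  # unknown key -> extra_fields
+    )
+    print(app.project_name, app.quality_llm, app.extra_fields)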
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py b/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py
new file mode 100644
index 00000000..bdf794b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py
@@ -0,0 +1,120 @@
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import Optional, Tuple
+
+from .base import Provider, ProviderConfig
+
+
+class CryptoConfig(ProviderConfig):
+ provider: Optional[str] = None
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["bcrypt", "nacl"]
+
+ def validate_config(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Unsupported crypto provider: {self.provider}")
+
+
+class CryptoProvider(Provider, ABC):
+ def __init__(self, config: CryptoConfig):
+ if not isinstance(config, CryptoConfig):
+ raise ValueError(
+ "CryptoProvider must be initialized with a CryptoConfig"
+ )
+ super().__init__(config)
+
+ @abstractmethod
+ def get_password_hash(self, password: str) -> str:
+ """Hash a plaintext password using a secure password hashing algorithm
+ (e.g., Argon2i)."""
+ pass
+
+ @abstractmethod
+ def verify_password(
+ self, plain_password: str, hashed_password: str
+ ) -> bool:
+ """Verify that a plaintext password matches the given hashed
+ password."""
+ pass
+
+ @abstractmethod
+ def generate_verification_code(self, length: int = 32) -> str:
+ """Generate a random code for email verification or reset tokens."""
+ pass
+
+ @abstractmethod
+ def generate_signing_keypair(self) -> Tuple[str, str, str]:
+ """Generate a new Ed25519 signing keypair for request signing.
+
+ Returns:
+ A tuple of (key_id, private_key, public_key).
+ - key_id: A unique identifier for this keypair.
+ - private_key: Base64 encoded Ed25519 private key.
+ - public_key: Base64 encoded Ed25519 public key.
+ """
+ pass
+
+ @abstractmethod
+ def sign_request(self, private_key: str, data: str) -> str:
+ """Sign request data with an Ed25519 private key, returning the
+ signature."""
+ pass
+
+ @abstractmethod
+ def verify_request_signature(
+ self, public_key: str, signature: str, data: str
+ ) -> bool:
+ """Verify a request signature using the corresponding Ed25519 public
+ key."""
+ pass
+
+ @abstractmethod
+ def generate_api_key(self) -> Tuple[str, str]:
+ """Generate a new API key for a user.
+
+ Returns:
+ A tuple (key_id, raw_api_key):
+ - key_id: A unique identifier for the API key.
+ - raw_api_key: The plaintext API key to provide to the user.
+ """
+ pass
+
+ @abstractmethod
+ def hash_api_key(self, raw_api_key: str) -> str:
+ """Hash a raw API key for secure storage in the database.
+
+ Use strong parameters suitable for long-term secrets.
+ """
+ pass
+
+ @abstractmethod
+ def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
+ """Verify that a provided API key matches the stored hashed version."""
+ pass
+
+ @abstractmethod
+ def generate_secure_token(self, data: dict, expiry: datetime) -> str:
+ """Generate a secure, signed token (e.g., JWT) embedding claims.
+
+ Args:
+ data: The claims to include in the token.
+ expiry: A datetime at which the token expires.
+
+ Returns:
+ A JWT string signed with a secret key.
+ """
+ pass
+
+ @abstractmethod
+ def verify_secure_token(self, token: str) -> Optional[dict]:
+ """Verify a secure token (e.g., JWT).
+
+ Args:
+ token: The token string to verify.
+
+ Returns:
+ The token payload if valid, otherwise None.
+ """
+ pass
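+
+
+# Illustrative flow sketch (not part of the original module): how the API-key
+# methods of a concrete CryptoProvider fit together. The `crypto` instance
+# below is hypothetical.
+#
+#   key_id, raw_api_key = crypto.generate_api_key()
+#   stored_hash = crypto.hash_api_key(raw_api_key)  # persist alongside key_id
+#   # Later, when a client presents "key_id.raw_api_key":
+#   assert crypto.verify_api_key(raw_api_key, stored_hash)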
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/database.py b/.venv/lib/python3.12/site-packages/core/base/providers/database.py
new file mode 100644
index 00000000..845a8109
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/database.py
@@ -0,0 +1,197 @@
+"""Base classes for database providers."""
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Sequence, cast
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from core.base.abstractions import (
+ GraphCreationSettings,
+ GraphEnrichmentSettings,
+ GraphSearchSettings,
+)
+
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class DatabaseConnectionManager(ABC):
+ @abstractmethod
+ def execute_query(
+ self,
+ query: str,
+ params: Optional[dict[str, Any] | Sequence[Any]] = None,
+ isolation_level: Optional[str] = None,
+ ):
+ pass
+
+ @abstractmethod
+ async def execute_many(self, query, params=None, batch_size=1000):
+ pass
+
+ @abstractmethod
+ def fetch_query(
+ self,
+ query: str,
+ params: Optional[dict[str, Any] | Sequence[Any]] = None,
+ ):
+ pass
+
+ @abstractmethod
+ def fetchrow_query(
+ self,
+ query: str,
+ params: Optional[dict[str, Any] | Sequence[Any]] = None,
+ ):
+ pass
+
+ @abstractmethod
+ async def initialize(self, pool: Any):
+ pass
+
+
+class Handler(ABC):
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: DatabaseConnectionManager,
+ ):
+ self.project_name = project_name
+ self.connection_manager = connection_manager
+
+ def _get_table_name(self, base_name: str) -> str:
+ return f"{self.project_name}.{base_name}"
+
+ @abstractmethod
+ def create_tables(self):
+ pass
+
+
+class PostgresConfigurationSettings(BaseModel):
+ """Configuration settings with defaults defined by the PGVector docker
+ image.
+
+ These settings are helpful in managing the connections to the database. To
+ tune these settings for a specific deployment, see
+ https://pgtune.leopard.in.ua/
+ """
+
+ checkpoint_completion_target: Optional[float] = 0.9
+ default_statistics_target: Optional[int] = 100
+ effective_io_concurrency: Optional[int] = 1
+ effective_cache_size: Optional[int] = 524288
+ huge_pages: Optional[str] = "try"
+ maintenance_work_mem: Optional[int] = 65536
+ max_connections: Optional[int] = 256
+ max_parallel_workers_per_gather: Optional[int] = 2
+ max_parallel_workers: Optional[int] = 8
+ max_parallel_maintenance_workers: Optional[int] = 2
+ max_wal_size: Optional[int] = 1024
+ max_worker_processes: Optional[int] = 8
+ min_wal_size: Optional[int] = 80
+ shared_buffers: Optional[int] = 16384
+ statement_cache_size: Optional[int] = 100
+ random_page_cost: Optional[float] = 4
+ wal_buffers: Optional[int] = 512
+ work_mem: Optional[int] = 4096
+
+
+class LimitSettings(BaseModel):
+ global_per_min: Optional[int] = None
+ route_per_min: Optional[int] = None
+ monthly_limit: Optional[int] = None
+
+ def merge_with_defaults(
+ self, defaults: "LimitSettings"
+ ) -> "LimitSettings":
+ return LimitSettings(
+ global_per_min=self.global_per_min or defaults.global_per_min,
+ route_per_min=self.route_per_min or defaults.route_per_min,
+ monthly_limit=self.monthly_limit or defaults.monthly_limit,
+ )
+
+
+class DatabaseConfig(ProviderConfig):
+ """A base database configuration class."""
+
+ provider: str = "postgres"
+ user: Optional[str] = None
+ password: Optional[str] = None
+ host: Optional[str] = None
+ port: Optional[int] = None
+ db_name: Optional[str] = None
+ project_name: Optional[str] = None
+ postgres_configuration_settings: Optional[
+ PostgresConfigurationSettings
+ ] = None
+ default_collection_name: str = "Default"
+ default_collection_description: str = "Your default collection."
+ collection_summary_system_prompt: str = "system"
+ collection_summary_prompt: str = "collection_summary"
+ enable_fts: bool = False
+
+ # Graph settings
+ batch_size: Optional[int] = 1
+ graph_search_results_store_path: Optional[str] = None
+ graph_enrichment_settings: GraphEnrichmentSettings = (
+ GraphEnrichmentSettings()
+ )
+ graph_creation_settings: GraphCreationSettings = GraphCreationSettings()
+ graph_search_settings: GraphSearchSettings = GraphSearchSettings()
+
+ # Rate limits
+ limits: LimitSettings = LimitSettings(
+ global_per_min=60, route_per_min=20, monthly_limit=10000
+ )
+ route_limits: dict[str, LimitSettings] = {}
+ user_limits: dict[UUID, LimitSettings] = {}
+
+ def validate_config(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["postgres"]
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
+ instance = cls.create(**data)
+
+ instance = cast(DatabaseConfig, instance)
+
+ limits_data = data.get("limits", {})
+ default_limits = LimitSettings(
+ global_per_min=limits_data.get("global_per_min", 60),
+ route_per_min=limits_data.get("route_per_min", 20),
+ monthly_limit=limits_data.get("monthly_limit", 10000),
+ )
+
+ instance.limits = default_limits
+
+ route_limits_data = limits_data.get("routes", {})
+ for route_str, route_cfg in route_limits_data.items():
+ instance.route_limits[route_str] = LimitSettings(**route_cfg)
+
+ return instance
+
+
+class DatabaseProvider(Provider):
+ connection_manager: DatabaseConnectionManager
+ config: DatabaseConfig
+ project_name: str
+
+ def __init__(self, config: DatabaseConfig):
+ logger.info(f"Initializing DatabaseProvider with config {config}.")
+ super().__init__(config)
+
+ @abstractmethod
+ async def __aenter__(self):
+ pass
+
+ @abstractmethod
+ async def __aexit__(self, exc_type, exc, tb):
+ pass
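+
+
+# Illustrative usage sketch (not part of the original module): per-route limits
+# fall back to the configured defaults field by field.
+if __name__ == "__main__":
+    defaults = LimitSettings(
+        global_per_min=60, route_per_min=20, monthly_limit=10000
+    )
+    route_override = LimitSettings(route_per_min=5)
+    print(route_override.merge_with_defaults(defaults))
+    # -> global_per_min=60 route_per_min=5 monthly_limit=10000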
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/email.py b/.venv/lib/python3.12/site-packages/core/base/providers/email.py
new file mode 100644
index 00000000..73f88162
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/email.py
@@ -0,0 +1,96 @@
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from .base import Provider, ProviderConfig
+
+
+class EmailConfig(ProviderConfig):
+ smtp_server: Optional[str] = None
+ smtp_port: Optional[int] = None
+ smtp_username: Optional[str] = None
+ smtp_password: Optional[str] = None
+ from_email: Optional[str] = None
+ use_tls: Optional[bool] = True
+ sendgrid_api_key: Optional[str] = None
+ mailersend_api_key: Optional[str] = None
+ verify_email_template_id: Optional[str] = None
+ reset_password_template_id: Optional[str] = None
+ password_changed_template_id: Optional[str] = None
+ frontend_url: Optional[str] = None
+ sender_name: Optional[str] = None
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return [
+ "smtp",
+ "console",
+ "sendgrid",
+ "mailersend",
+        ]  # Could add more providers, e.g. AWS SES.
+
+ def validate_config(self) -> None:
+ if (
+ self.provider == "sendgrid"
+ and not self.sendgrid_api_key
+ and not os.getenv("SENDGRID_API_KEY")
+ ):
+ raise ValueError(
+ "SendGrid API key is required when using SendGrid provider"
+ )
+
+ if (
+ self.provider == "mailersend"
+ and not self.mailersend_api_key
+ and not os.getenv("MAILERSEND_API_KEY")
+ ):
+ raise ValueError(
+ "MailerSend API key is required when using MailerSend provider"
+ )
+
+
+logger = logging.getLogger(__name__)
+
+
+class EmailProvider(Provider, ABC):
+ def __init__(self, config: EmailConfig):
+ if not isinstance(config, EmailConfig):
+ raise ValueError(
+ "EmailProvider must be initialized with an EmailConfig"
+ )
+ super().__init__(config)
+ self.config: EmailConfig = config
+
+ @abstractmethod
+ async def send_email(
+ self,
+ to_email: str,
+ subject: str,
+ body: str,
+ html_body: Optional[str] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ pass
+
+ @abstractmethod
+ async def send_verification_email(
+ self, to_email: str, verification_code: str, *args, **kwargs
+ ) -> None:
+ pass
+
+ @abstractmethod
+ async def send_password_reset_email(
+ self, to_email: str, reset_token: str, *args, **kwargs
+ ) -> None:
+ pass
+
+ @abstractmethod
+ async def send_password_changed_email(
+ self,
+ to_email: str,
+ *args,
+ **kwargs,
+ ) -> None:
+ pass
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
new file mode 100644
index 00000000..d1f9f9d6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
@@ -0,0 +1,197 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import VectorQuantizationSettings
+
+from ..abstractions import (
+ ChunkSearchResult,
+ EmbeddingPurpose,
+ default_embedding_prefixes,
+)
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class EmbeddingConfig(ProviderConfig):
+ provider: str
+ base_model: str
+ base_dimension: int | float
+ rerank_model: Optional[str] = None
+ rerank_url: Optional[str] = None
+ batch_size: int = 1
+ prefixes: Optional[dict[str, str]] = None
+ add_title_as_prefix: bool = True
+ concurrent_request_limit: int = 256
+ max_retries: int = 3
+ initial_backoff: float = 1
+ max_backoff: float = 64.0
+ quantization_settings: VectorQuantizationSettings = (
+ VectorQuantizationSettings()
+ )
+
+ ## deprecated
+ rerank_dimension: Optional[int] = None
+ rerank_transformer_type: Optional[str] = None
+
+ def validate_config(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["litellm", "openai", "ollama"]
+
+
+class EmbeddingProvider(Provider):
+ class Step(Enum):
+ BASE = 1
+ RERANK = 2
+
+ def __init__(self, config: EmbeddingConfig):
+ if not isinstance(config, EmbeddingConfig):
+ raise ValueError(
+ "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
+ )
+ logger.info(f"Initializing EmbeddingProvider with config {config}.")
+
+ super().__init__(config)
+ self.config: EmbeddingConfig = config
+ self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
+ self.current_requests = 0
+
+ async def _execute_with_backoff_async(self, task: dict[str, Any]):
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ async with self.semaphore:
+ return await self._execute_task(task)
+ except AuthenticationError:
+ raise
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ await asyncio.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ def _execute_with_backoff_sync(self, task: dict[str, Any]):
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ return self._execute_task_sync(task)
+ except AuthenticationError:
+ raise
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ time.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ @abstractmethod
+ async def _execute_task(self, task: dict[str, Any]):
+ pass
+
+ @abstractmethod
+ def _execute_task_sync(self, task: dict[str, Any]):
+ pass
+
+ async def async_get_embedding(
+ self,
+ text: str,
+ stage: Step = Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ ):
+ task = {
+ "text": text,
+ "stage": stage,
+ "purpose": purpose,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: Step = Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ ):
+ task = {
+ "text": text,
+ "stage": stage,
+ "purpose": purpose,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: Step = Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ ):
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: Step = Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ ) -> list[list[float]]:
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ @abstractmethod
+ def rerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: Step = Step.RERANK,
+ limit: int = 10,
+ ):
+ pass
+
+ @abstractmethod
+ async def arerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: Step = Step.RERANK,
+ limit: int = 10,
+ ):
+ pass
+
+ def set_prefixes(self, config_prefixes: dict[str, str], base_model: str):
+ self.prefixes = {}
+
+ for t, p in config_prefixes.items():
+ purpose = EmbeddingPurpose(t.lower())
+ self.prefixes[purpose] = p
+
+ if base_model in default_embedding_prefixes:
+ for t, p in default_embedding_prefixes[base_model].items():
+ if t not in self.prefixes:
+ self.prefixes[t] = p
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py b/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py
new file mode 100644
index 00000000..70d0d3a0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py
@@ -0,0 +1,172 @@
+import logging
+from abc import ABC
+from enum import Enum
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
+
+from pydantic import Field
+
+from core.base.abstractions import ChunkEnrichmentSettings
+
+from .base import AppConfig, Provider, ProviderConfig
+from .llm import CompletionProvider
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+ from core.providers.database import PostgresDatabaseProvider
+
+
+class ChunkingStrategy(str, Enum):
+ RECURSIVE = "recursive"
+ CHARACTER = "character"
+ BASIC = "basic"
+ BY_TITLE = "by_title"
+
+
+class IngestionMode(str, Enum):
+ hi_res = "hi-res"
+ fast = "fast"
+ custom = "custom"
+
+
+class IngestionConfig(ProviderConfig):
+ _defaults: ClassVar[dict] = {
+ "app": AppConfig(),
+ "provider": "r2r",
+ "excluded_parsers": ["mp4"],
+ "chunking_strategy": "recursive",
+ "chunk_size": 1024,
+ "chunk_enrichment_settings": ChunkEnrichmentSettings(),
+ "extra_parsers": {},
+ "audio_transcription_model": None,
+ "vision_img_prompt_name": "vision_img",
+ "vision_pdf_prompt_name": "vision_pdf",
+ "skip_document_summary": False,
+ "document_summary_system_prompt": "system",
+ "document_summary_task_prompt": "summary",
+ "document_summary_max_length": 100_000,
+ "chunks_for_document_summary": 128,
+ "document_summary_model": None,
+ "parser_overrides": {},
+ "extra_fields": {},
+ "automatic_extraction": False,
+ }
+
+ provider: str = Field(
+ default_factory=lambda: IngestionConfig._defaults["provider"]
+ )
+ excluded_parsers: list[str] = Field(
+ default_factory=lambda: IngestionConfig._defaults["excluded_parsers"]
+ )
+ chunking_strategy: str | ChunkingStrategy = Field(
+ default_factory=lambda: IngestionConfig._defaults["chunking_strategy"]
+ )
+ chunk_size: int = Field(
+ default_factory=lambda: IngestionConfig._defaults["chunk_size"]
+ )
+ chunk_enrichment_settings: ChunkEnrichmentSettings = Field(
+ default_factory=lambda: IngestionConfig._defaults[
+ "chunk_enrichment_settings"
+ ]
+ )
+ extra_parsers: dict[str, Any] = Field(
+ default_factory=lambda: IngestionConfig._defaults["extra_parsers"]
+ )
+ audio_transcription_model: Optional[str] = Field(
+ default_factory=lambda: IngestionConfig._defaults[
+ "audio_transcription_model"
+ ]
+ )
+ vision_img_prompt_name: str = Field(
+ default_factory=lambda: IngestionConfig._defaults[
+ "vision_img_prompt_name"
+ ]
+ )
+ vision_pdf_prompt_name: str = Field(
+ default_factory=lambda: IngestionConfig._defaults[
+ "vision_pdf_prompt_name"
+ ]
+ )
+ skip_document_summary: bool = Field(
+ default_factory=lambda: IngestionConfig._defaults[
+ "skip_document_summary"
+ ]
+ )
+ document_summary_system_prompt: str = Field(
+ default_factory=lambda: IngestionConfig._defaults[
+ "document_summary_system_prompt"
+ ]
+ )
+ document_summary_task_prompt: str = Field(
+ default_factory=lambda: IngestionConfig._defaults[
+ "document_summary_task_prompt"
+ ]
+ )
+ chunks_for_document_summary: int = Field(
+ default_factory=lambda: IngestionConfig._defaults[
+ "chunks_for_document_summary"
+ ]
+ )
+ document_summary_model: Optional[str] = Field(
+ default_factory=lambda: IngestionConfig._defaults[
+ "document_summary_model"
+ ]
+ )
+ parser_overrides: dict[str, str] = Field(
+ default_factory=lambda: IngestionConfig._defaults["parser_overrides"]
+ )
+ automatic_extraction: bool = Field(
+ default_factory=lambda: IngestionConfig._defaults[
+ "automatic_extraction"
+ ]
+ )
+ document_summary_max_length: int = Field(
+ default_factory=lambda: IngestionConfig._defaults[
+ "document_summary_max_length"
+ ]
+ )
+
+ @classmethod
+ def set_default(cls, **kwargs):
+ for key, value in kwargs.items():
+ if key in cls._defaults:
+ cls._defaults[key] = value
+ else:
+ raise AttributeError(
+ f"No default attribute '{key}' in IngestionConfig"
+ )
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["r2r", "unstructured_local", "unstructured_api"]
+
+ def validate_config(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider {self.provider} is not supported.")
+
+ @classmethod
+ def get_default(cls, mode: str, app) -> "IngestionConfig":
+ """Return default ingestion configuration for a given mode."""
+ if mode == "hi-res":
+ return cls(app=app, parser_overrides={"pdf": "zerox"})
+ if mode == "fast":
+ return cls(app=app, skip_document_summary=True)
+ else:
+ return cls(app=app)
+
+
+class IngestionProvider(Provider, ABC):
+ config: IngestionConfig
+ database_provider: "PostgresDatabaseProvider"
+ llm_provider: CompletionProvider
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: "PostgresDatabaseProvider",
+ llm_provider: CompletionProvider,
+ ):
+ super().__init__(config)
+ self.config: IngestionConfig = config
+ self.llm_provider = llm_provider
+ self.database_provider: "PostgresDatabaseProvider" = database_provider
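+
+
+# Illustrative usage sketch (not part of the original module): class-wide
+# defaults can be adjusted before configs are built, and `get_default` maps the
+# ingestion mode names onto preset configurations.
+if __name__ == "__main__":
+    IngestionConfig.set_default(chunk_size=2_048)
+    hi_res = IngestionConfig.get_default("hi-res", app=AppConfig())
+    fast = IngestionConfig.get_default("fast", app=AppConfig())
+    print(hi_res.chunk_size, hi_res.parser_overrides, fast.skip_document_summary)
+    # -> 2048 {'pdf': 'zerox'} True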
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/llm.py b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py
new file mode 100644
index 00000000..669dfc4f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py
@@ -0,0 +1,200 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import (
+ GenerationConfig,
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+)
+
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class CompletionConfig(ProviderConfig):
+ provider: Optional[str] = None
+ generation_config: Optional[GenerationConfig] = None
+ concurrent_request_limit: int = 256
+ max_retries: int = 3
+ initial_backoff: float = 1.0
+ max_backoff: float = 64.0
+
+ def validate_config(self) -> None:
+ if not self.provider:
+ raise ValueError("Provider must be set.")
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["anthropic", "litellm", "openai", "r2r"]
+
+
+class CompletionProvider(Provider):
+ def __init__(self, config: CompletionConfig) -> None:
+ if not isinstance(config, CompletionConfig):
+ raise ValueError(
+ "CompletionProvider must be initialized with a `CompletionConfig`."
+ )
+ logger.info(f"Initializing CompletionProvider with config: {config}")
+ super().__init__(config)
+ self.config: CompletionConfig = config
+ self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
+ self.thread_pool = ThreadPoolExecutor(
+ max_workers=config.concurrent_request_limit
+ )
+
+ async def _execute_with_backoff_async(self, task: dict[str, Any]):
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ async with self.semaphore:
+ return await self._execute_task(task)
+ except AuthenticationError:
+ raise
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ await asyncio.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ async def _execute_with_backoff_async_stream(
+ self, task: dict[str, Any]
+ ) -> AsyncGenerator[Any, None]:
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ async with self.semaphore:
+ async for chunk in await self._execute_task(task):
+ yield chunk
+ return # Successful completion of the stream
+ except AuthenticationError:
+ raise
+ except Exception as e:
+ logger.warning(
+ f"Streaming request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ await asyncio.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ def _execute_with_backoff_sync(self, task: dict[str, Any]):
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ return self._execute_task_sync(task)
+ except Exception as e:
+ logger.warning(
+ f"Request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ time.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ def _execute_with_backoff_sync_stream(
+ self, task: dict[str, Any]
+ ) -> Generator[Any, None, None]:
+ retries = 0
+ backoff = self.config.initial_backoff
+ while retries < self.config.max_retries:
+ try:
+ yield from self._execute_task_sync(task)
+ return # Successful completion of the stream
+ except Exception as e:
+ logger.warning(
+ f"Streaming request failed (attempt {retries + 1}): {str(e)}"
+ )
+ retries += 1
+ if retries == self.config.max_retries:
+ raise
+ time.sleep(random.uniform(0, backoff))
+ backoff = min(backoff * 2, self.config.max_backoff)
+
+ @abstractmethod
+ async def _execute_task(self, task: dict[str, Any]):
+ pass
+
+ @abstractmethod
+ def _execute_task_sync(self, task: dict[str, Any]):
+ pass
+
+ async def aget_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletion:
+ task = {
+ "messages": messages,
+ "generation_config": generation_config,
+ "kwargs": kwargs,
+ }
+ response = await self._execute_with_backoff_async(task)
+ return LLMChatCompletion(**response.dict())
+
+ async def aget_completion_stream(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> AsyncGenerator[LLMChatCompletionChunk, None]:
+ generation_config.stream = True
+ task = {
+ "messages": messages,
+ "generation_config": generation_config,
+ "kwargs": kwargs,
+ }
+ async for chunk in self._execute_with_backoff_async_stream(task):
+ if isinstance(chunk, dict):
+ yield LLMChatCompletionChunk(**chunk)
+ continue
+
+ chunk.choices[0].finish_reason = (
+ chunk.choices[0].finish_reason
+ if chunk.choices[0].finish_reason != ""
+ else None
+ ) # handle error output conventions
+ chunk.choices[0].finish_reason = (
+ chunk.choices[0].finish_reason
+ if chunk.choices[0].finish_reason != "eos"
+ else "stop"
+ ) # hardcode `eos` to `stop` for consistency
+ try:
+ yield LLMChatCompletionChunk(**(chunk.dict()))
+ except Exception as e:
+ logger.error(f"Error parsing chunk: {e}")
+ yield LLMChatCompletionChunk(**(chunk.as_dict()))
+
+ def get_completion_stream(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> Generator[LLMChatCompletionChunk, None, None]:
+ generation_config.stream = True
+ task = {
+ "messages": messages,
+ "generation_config": generation_config,
+ "kwargs": kwargs,
+ }
+ for chunk in self._execute_with_backoff_sync_stream(task):
+ yield LLMChatCompletionChunk(**chunk.dict())
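+
+
+# Illustrative usage sketch (not part of the original module): a concrete
+# subclass only implements `_execute_task` / `_execute_task_sync`, and callers
+# then go through the retry/backoff wrappers above. The subclass name, the use
+# of AppConfig, and the GenerationConfig fields below are assumptions.
+#
+#   provider = SomeCompletionProvider(
+#       CompletionConfig(provider="litellm", app=AppConfig())
+#   )
+#   completion = await provider.aget_completion(
+#       messages=[{"role": "user", "content": "Say hello."}],
+#       generation_config=GenerationConfig(model="openai/gpt-4o-mini"),
+#   )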
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py b/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py
new file mode 100644
index 00000000..c3105f30
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py
@@ -0,0 +1,70 @@
+from abc import abstractmethod
+from enum import Enum
+from typing import Any
+
+from .base import Provider, ProviderConfig
+
+
+class Workflow(Enum):
+ INGESTION = "ingestion"
+ GRAPH = "graph"
+
+
+class OrchestrationConfig(ProviderConfig):
+ provider: str
+ max_runs: int = 2_048
+ graph_search_results_creation_concurrency_limit: int = 32
+ ingestion_concurrency_limit: int = 16
+ graph_search_results_concurrency_limit: int = 8
+
+ def validate_config(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider {self.provider} is not supported.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["hatchet", "simple"]
+
+
+class OrchestrationProvider(Provider):
+ def __init__(self, config: OrchestrationConfig):
+ super().__init__(config)
+ self.config = config
+ self.worker = None
+
+ @abstractmethod
+ async def start_worker(self):
+ pass
+
+ @abstractmethod
+ def get_worker(self, name: str, max_runs: int) -> Any:
+ pass
+
+ @abstractmethod
+ def step(self, *args, **kwargs) -> Any:
+ pass
+
+ @abstractmethod
+ def workflow(self, *args, **kwargs) -> Any:
+ pass
+
+ @abstractmethod
+ def failure(self, *args, **kwargs) -> Any:
+ pass
+
+ @abstractmethod
+ def register_workflows(
+ self, workflow: Workflow, service: Any, messages: dict
+ ) -> None:
+ pass
+
+ @abstractmethod
+ async def run_workflow(
+ self,
+ workflow_name: str,
+ parameters: dict,
+ options: dict,
+ *args,
+ **kwargs,
+ ) -> dict[str, str]:
+ pass
diff --git a/.venv/lib/python3.12/site-packages/core/base/utils/__init__.py b/.venv/lib/python3.12/site-packages/core/base/utils/__init__.py
new file mode 100644
index 00000000..948a1069
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/utils/__init__.py
@@ -0,0 +1,43 @@
+from shared.utils import (
+ RecursiveCharacterTextSplitter,
+ TextSplitter,
+ _decorate_vector_type,
+ _get_vector_column_str,
+ decrement_version,
+ deep_update,
+ dump_collector,
+ dump_obj,
+ format_search_results_for_llm,
+ generate_default_prompt_id,
+ generate_default_user_collection_id,
+ generate_document_id,
+ generate_entity_document_id,
+ generate_extraction_id,
+ generate_id,
+ generate_user_id,
+ increment_version,
+ validate_uuid,
+ yield_sse_event,
+)
+
+__all__ = [
+ "format_search_results_for_llm",
+ "generate_id",
+ "generate_default_user_collection_id",
+ "increment_version",
+ "decrement_version",
+ "generate_document_id",
+ "generate_extraction_id",
+ "generate_user_id",
+ "generate_entity_document_id",
+ "generate_default_prompt_id",
+ "RecursiveCharacterTextSplitter",
+ "TextSplitter",
+ "validate_uuid",
+ "deep_update",
+ "_decorate_vector_type",
+ "_get_vector_column_str",
+ "yield_sse_event",
+ "dump_collector",
+ "dump_obj",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/configs/full.toml b/.venv/lib/python3.12/site-packages/core/configs/full.toml
new file mode 100644
index 00000000..0bf70631
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/configs/full.toml
@@ -0,0 +1,21 @@
+[completion]
+provider = "r2r"
+concurrent_request_limit = 128
+
+[ingestion]
+provider = "unstructured_local"
+strategy = "auto"
+chunking_strategy = "by_title"
+new_after_n_chars = 2_048
+max_characters = 4_096
+combine_under_n_chars = 1_024
+overlap = 1_024
+
+ [ingestion.extra_parsers]
+ pdf = "zerox"
+
+[orchestration]
+provider = "hatchet"
+kg_creation_concurrency_limit = 32
+ingestion_concurrency_limit = 16
+kg_concurrency_limit = 8
diff --git a/.venv/lib/python3.12/site-packages/core/configs/full_azure.toml b/.venv/lib/python3.12/site-packages/core/configs/full_azure.toml
new file mode 100644
index 00000000..c6ebb199
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/configs/full_azure.toml
@@ -0,0 +1,46 @@
+[app]
+# LLM used for internal operations, like deriving conversation names
+fast_llm = "azure/gpt-4o-mini"
+
+# LLM used for user-facing output, like RAG replies
+quality_llm = "azure/gpt-4o"
+
+# LLM used for ingesting visual inputs
+vlm = "azure/gpt-4o"
+
+# LLM used for transcription
+audio_lm = "azure/whisper-1"
+
+# Reasoning model, used for `research` agent
+reasoning_llm = "azure/o3-mini"
+# Planning model, used for `research` agent
+planning_llm = "azure/o3-mini"
+
+[embedding]
+base_model = "azure/text-embedding-3-small"
+
+[completion_embedding]
+base_model = "azure/text-embedding-3-small"
+
+[ingestion]
+provider = "unstructured_local"
+strategy = "auto"
+chunking_strategy = "by_title"
+new_after_n_chars = 2_048
+max_characters = 4_096
+combine_under_n_chars = 1_024
+overlap = 1_024
+document_summary_model = "azure/gpt-4o-mini"
+automatic_extraction = true # enable automatic extraction of entities and relations
+
+ [ingestion.extra_parsers]
+ pdf = "zerox"
+
+ [ingestion.chunk_enrichment_settings]
+ generation_config = { model = "azure/gpt-4o-mini" }
+
+[orchestration]
+provider = "hatchet"
+kg_creation_concurrency_limit = 32
+ingestion_concurrency_limit = 4
+kg_concurrency_limit = 8
diff --git a/.venv/lib/python3.12/site-packages/core/configs/full_lm_studio.toml b/.venv/lib/python3.12/site-packages/core/configs/full_lm_studio.toml
new file mode 100644
index 00000000..daae73a1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/configs/full_lm_studio.toml
@@ -0,0 +1,57 @@
+[app]
+# LLM used for internal operations, like deriving conversation names
+fast_llm = "lm_studio/llama-3.2-3b-instruct"
+
+# LLM used for user-facing output, like RAG replies
+quality_llm = "lm_studio/llama-3.2-3b-instruct"
+
+# LLM used for ingesting visual inputs
+vlm = "lm_studio/llama3.2-vision" # TODO - Replace with viable candidate
+
+# LLM used for transcription
+audio_lm = "lm_studio/llama-3.2-3b-instruct" # TODO - Replace with viable candidate
+
+[embedding]
+provider = "litellm"
+base_model = "lm_studio/text-embedding-nomic-embed-text-v1.5"
+base_dimension = nan
+batch_size = 128
+add_title_as_prefix = true
+concurrent_request_limit = 2
+
+[completion_embedding]
+# Generally this should be the same as the embedding config, but advanced users may want to run with a different provider to reduce latency
+provider = "litellm"
+base_model = "lm_studio/text-embedding-nomic-embed-text-v1.5"
+base_dimension = nan
+batch_size = 128
+add_title_as_prefix = true
+concurrent_request_limit = 2
+
+[agent]
+tools = ["search_file_knowledge"]
+
+[completion]
+provider = "litellm"
+concurrent_request_limit = 1
+
+ [completion.generation_config]
+ temperature = 0.1
+ top_p = 1
+ max_tokens_to_sample = 1_024
+ stream = false
+
+[ingestion]
+provider = "unstructured_local"
+strategy = "auto"
+chunking_strategy = "by_title"
+new_after_n_chars = 512
+max_characters = 1_024
+combine_under_n_chars = 128
+overlap = 20
+chunks_for_document_summary = 16
+document_summary_model = "lm_studio/llama-3.2-3b-instruct"
+automatic_extraction = false
+
+[orchestration]
+provider = "hatchet"
diff --git a/.venv/lib/python3.12/site-packages/core/configs/full_ollama.toml b/.venv/lib/python3.12/site-packages/core/configs/full_ollama.toml
new file mode 100644
index 00000000..8ec2fc77
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/configs/full_ollama.toml
@@ -0,0 +1,63 @@
+[app]
+# LLM used for internal operations, like deriving conversation names
+fast_llm = "openai/llama3.1"
+
+# LLM used for user-facing output, like RAG replies
+quality_llm = "openai/llama3.1"
+
+# LLM used for ingesting visual inputs
+vlm = "openai/llama3.1" # TODO - Replace with viable candidate
+
+# LLM used for transcription
+audio_lm = "openai/llama3.1" # TODO - Replace with viable candidate
+
+
+# Reasoning model, used for `research` agent
+reasoning_llm = "openai/llama3.1"
+# Planning model, used for `research` agent
+planning_llm = "openai/llama3.1"
+
+[embedding]
+provider = "ollama"
+base_model = "mxbai-embed-large"
+base_dimension = 1_024
+batch_size = 128
+add_title_as_prefix = true
+concurrent_request_limit = 2
+
+[completion_embedding]
+provider = "ollama"
+base_model = "mxbai-embed-large"
+base_dimension = 1_024
+batch_size = 128
+add_title_as_prefix = true
+concurrent_request_limit = 2
+
+[agent]
+tools = ["search_file_knowledge"]
+
+[completion]
+provider = "litellm"
+concurrent_request_limit = 1
+
+ [completion.generation_config]
+ temperature = 0.1
+ top_p = 1
+ max_tokens_to_sample = 1_024
+ stream = false
+ api_base = "http://localhost:11434/v1"
+
+[ingestion]
+provider = "unstructured_local"
+strategy = "auto"
+chunking_strategy = "by_title"
+new_after_n_chars = 512
+max_characters = 1_024
+combine_under_n_chars = 128
+overlap = 20
+chunks_for_document_summary = 16
+document_summary_model = "ollama/llama3.1"
+automatic_extraction = false
+
+[orchestration]
+provider = "hatchet"
diff --git a/.venv/lib/python3.12/site-packages/core/configs/gemini.toml b/.venv/lib/python3.12/site-packages/core/configs/gemini.toml
new file mode 100644
index 00000000..50739a6c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/configs/gemini.toml
@@ -0,0 +1,21 @@
+[app]
+fast_llm = "gemini/gemini-2.0-flash-lite"
+quality_llm = "gemini/gemini-2.0-flash"
+vlm = "gemini/gemini-2.0-flash"
+audio_lm = "gemini/gemini-2.0-flash-lite"
+
+[embedding]
+provider = "litellm"
+base_model = "gemini/text-embedding-004"
+base_dimension = nan
+batch_size = 128
+add_title_as_prefix = true
+concurrent_request_limit = 2
+
+[completion_embedding]
+provider = "litellm"
+base_model = "gemini/text-embedding-004"
+base_dimension = nan
+batch_size = 128
+add_title_as_prefix = true
+concurrent_request_limit = 2
diff --git a/.venv/lib/python3.12/site-packages/core/configs/lm_studio.toml b/.venv/lib/python3.12/site-packages/core/configs/lm_studio.toml
new file mode 100644
index 00000000..1b8acb8f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/configs/lm_studio.toml
@@ -0,0 +1,42 @@
+[app]
+# LLM used for internal operations, like deriving conversation names
+fast_llm = "lm_studio/llama-3.2-3b-instruct"
+
+# LLM used for user-facing output, like RAG replies
+quality_llm = "lm_studio/llama-3.2-3b-instruct"
+
+# LLM used for ingesting visual inputs
+vlm = "lm_studio/llama3.2-vision" # TODO - Replace with viable candidate
+
+# LLM used for transcription
+audio_lm = "lm_studio/llama-3.2-3b-instruct" # TODO - Replace with viable candidate
+
+[embedding]
+provider = "litellm"
+base_model = "lm_studio/text-embedding-nomic-embed-text-v1.5"
+base_dimension = nan
+batch_size = 128
+add_title_as_prefix = true
+concurrent_request_limit = 2
+
+[completion_embedding]
+# Generally this should be the same as the embedding config, but advanced users may want to run with a different provider to reduce latency
+provider = "litellm"
+base_model = "lm_studio/text-embedding-nomic-embed-text-v1.5"
+base_dimension = nan
+batch_size = 128
+add_title_as_prefix = true
+concurrent_request_limit = 2
+
+[agent]
+tools = ["search_file_knowledge"]
+
+[completion]
+provider = "litellm"
+concurrent_request_limit = 1
+
+ [completion.generation_config]
+ temperature = 0.1
+ top_p = 1
+ max_tokens_to_sample = 1_024
+ stream = false
diff --git a/.venv/lib/python3.12/site-packages/core/configs/ollama.toml b/.venv/lib/python3.12/site-packages/core/configs/ollama.toml
new file mode 100644
index 00000000..5226eebf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/configs/ollama.toml
@@ -0,0 +1,48 @@
+[app]
+# LLM used for internal operations, like deriving conversation names
+fast_llm = "openai/llama3.1" # NOTE: the `openai` prefix with `api_base = "http://localhost:11434/v1"` is recommended for best results; the `ollama` prefix via `litellm` also works
+
+# LLM used for user-facing output, like RAG replies
+quality_llm = "openai/llama3.1"
+
+# LLM used for ingesting visual inputs
+vlm = "openai/llama3.1" # TODO - Replace with viable candidate
+
+# LLM used for transcription
+audio_lm = "openai/llama3.1" # TODO - Replace with viable candidate
+
+
+# Reasoning model, used for `research` agent
+reasoning_llm = "openai/llama3.1"
+# Planning model, used for `research` agent
+planning_llm = "openai/llama3.1"
+
+[embedding]
+provider = "ollama"
+base_model = "mxbai-embed-large"
+base_dimension = 1_024
+batch_size = 128
+add_title_as_prefix = true
+concurrent_request_limit = 2
+
+[completion_embedding]
+provider = "ollama"
+base_model = "mxbai-embed-large"
+base_dimension = 1_024
+batch_size = 128
+add_title_as_prefix = true
+concurrent_request_limit = 2
+
+[agent]
+tools = ["search_file_knowledge"]
+
+[completion]
+provider = "litellm"
+concurrent_request_limit = 1
+
+ [completion.generation_config]
+ temperature = 0.1
+ top_p = 1
+ max_tokens_to_sample = 1_024
+ stream = false
+ api_base = "http://localhost:11434/v1"
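The recommendation above to use the `openai/` model prefix together with `api_base = "http://localhost:11434/v1"` works because Ollama exposes an OpenAI-compatible endpoint under `/v1`. A quick way to check that endpoint outside of this config, assuming Ollama is running locally and `llama3.1` has been pulled (the `api_key` value is a placeholder; Ollama does not check it):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
    resp = client.chat.completions.create(
        model="llama3.1",
        messages=[{"role": "user", "content": "Reply with one short sentence."}],
        temperature=0.1,
        max_tokens=64,
    )
    print(resp.choices[0].message.content)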
diff --git a/.venv/lib/python3.12/site-packages/core/configs/r2r_azure.toml b/.venv/lib/python3.12/site-packages/core/configs/r2r_azure.toml
new file mode 100644
index 00000000..fec2b026
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/configs/r2r_azure.toml
@@ -0,0 +1,23 @@
+[app]
+# LLM used for internal operations, like deriving conversation names
+fast_llm = "azure/gpt-4o-mini"
+
+# LLM used for user-facing output, like RAG replies
+quality_llm = "azure/gpt-4o"
+
+# LLM used for ingesting visual inputs
+vlm = "azure/gpt-4o"
+
+# LLM used for transcription
+audio_lm = "azure/whisper-1"
+
+# Reasoning model, used for `research` agent
+reasoning_llm = "azure/o3-mini"
+# Planning model, used for `research` agent
+planning_llm = "azure/o3-mini"
+
+[embedding]
+base_model = "azure/text-embedding-3-small"
+
+[completion_embedding]
+base_model = "azure/text-embedding-3-small"
diff --git a/.venv/lib/python3.12/site-packages/core/configs/r2r_azure_with_test_limits.toml b/.venv/lib/python3.12/site-packages/core/configs/r2r_azure_with_test_limits.toml
new file mode 100644
index 00000000..d26e7683
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/configs/r2r_azure_with_test_limits.toml
@@ -0,0 +1,37 @@
+[app]
+# LLM used for internal operations, like deriving conversation names
+fast_llm = "azure/gpt-4o-mini"
+
+# LLM used for user-facing output, like RAG replies
+quality_llm = "azure/gpt-4o"
+
+# LLM used for ingesting visual inputs
+vlm = "azure/gpt-4o"
+
+# LLM used for transcription
+audio_lm = "azure/whisper-1"
+
+
+# Reasoning model, used for `research` agent
+reasoning_llm = "azure/o3-mini"
+# Planning model, used for `research` agent
+planning_llm = "azure/o3-mini"
+
+[embedding]
+base_model = "azure/text-embedding-3-small"
+base_dimension = 512
+
+[completion_embedding]
+base_model = "azure/text-embedding-3-small"
+
+[database]
+ [database.limits]
+ global_per_min = 10 # Small enough to test quickly
+ monthly_limit = 20 # Small enough to test in one run
+
+ [database.route_limits]
+ "/v3/retrieval/search" = { route_per_min = 5, monthly_limit = 10 }
+
+ [database.user_limits."47e53676-b478-5b3f-a409-234ca2164de5"]
+ global_per_min = 2
+ route_per_min = 1
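The quoted key under `[database.route_limits]` and the dotted `[database.user_limits."<user-id>"]` header above are standard TOML; they load as nested dictionaries keyed by the route path and the user UUID respectively. A small sketch of how those tables read back (values copied from this config):

    import tomllib

    cfg = tomllib.loads('''
    [database.limits]
    global_per_min = 10
    monthly_limit = 20

    [database.route_limits]
    "/v3/retrieval/search" = { route_per_min = 5, monthly_limit = 10 }

    [database.user_limits."47e53676-b478-5b3f-a409-234ca2164de5"]
    global_per_min = 2
    route_per_min = 1
    ''')

    db = cfg["database"]
    print(db["route_limits"]["/v3/retrieval/search"])                  # {'route_per_min': 5, 'monthly_limit': 10}
    print(db["user_limits"]["47e53676-b478-5b3f-a409-234ca2164de5"])   # {'global_per_min': 2, 'route_per_min': 1}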
diff --git a/.venv/lib/python3.12/site-packages/core/configs/r2r_with_auth.toml b/.venv/lib/python3.12/site-packages/core/configs/r2r_with_auth.toml
new file mode 100644
index 00000000..f36e8bb3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/configs/r2r_with_auth.toml
@@ -0,0 +1,8 @@
+[auth]
+provider = "r2r"
+access_token_lifetime_in_minutes = 60
+refresh_token_lifetime_in_days = 7
+require_authentication = true
+require_email_verification = false
+default_admin_email = "admin@example.com"
+default_admin_password = "change_me_immediately"
diff --git a/.venv/lib/python3.12/site-packages/core/examples/__init__.py b/.venv/lib/python3.12/site-packages/core/examples/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/examples/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/core/examples/hello_r2r.py b/.venv/lib/python3.12/site-packages/core/examples/hello_r2r.py
new file mode 100644
index 00000000..e7b21497
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/examples/hello_r2r.py
@@ -0,0 +1,23 @@
+from r2r import R2RClient
+
+client = R2RClient()
+
+with open("test.txt", "w") as file:
+ file.write("John is a person that works at Google.")
+
+client.ingest_files(file_paths=["test.txt"])
+
+# Call RAG directly on an R2R object
+rag_response = client.rag(
+ query="Who is john",
+ rag_generation_config={"model": "gpt-4o-mini", "temperature": 0.0},
+)
+results = rag_response["results"]
+print(f"Search Results:\n{results['search_results']}")
+print(f"Completion:\n{results['completion']}")
+
+# RAG Results:
+# Search Results:
+# AggregateSearchResult(chunk_search_results=[ChunkSearchResult(id=2d71e689-0a0e-5491-a50b-4ecb9494c832, score=0.6848798582029441, metadata={'text': 'John is a person that works at Google.', 'version': 'v0', 'chunk_order': 0, 'document_id': 'ed76b6ee-dd80-5172-9263-919d493b439a', 'id': '1ba494d7-cb2f-5f0e-9f64-76c31da11381', 'associatedQuery': 'Who is john'})], graph_search_results=None)
+# Completion:
+# ChatCompletion(id='chatcmpl-9g0HnjGjyWDLADe7E2EvLWa35cMkB', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='John is a person that works at Google [1].', role='assistant', function_call=None, tool_calls=None))], created=1719797903, model='gpt-4o-mini', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=11, prompt_tokens=145, total_tokens=156))
diff --git a/.venv/lib/python3.12/site-packages/core/main/__init__.py b/.venv/lib/python3.12/site-packages/core/main/__init__.py
new file mode 100644
index 00000000..7043d029
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/__init__.py
@@ -0,0 +1,24 @@
+from .abstractions import R2RProviders
+from .api import *
+from .app import *
+
+# from .app_entry import r2r_app
+from .assembly import *
+from .orchestration import *
+from .services import *
+
+__all__ = [
+ # R2R Primary
+ "R2RProviders",
+ "R2RApp",
+ "R2RBuilder",
+ "R2RConfig",
+ # Factory
+ "R2RProviderFactory",
+ ## R2R SERVICES
+ "AuthService",
+ "IngestionService",
+ "ManagementService",
+ "RetrievalService",
+ "GraphService",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/main/abstractions.py b/.venv/lib/python3.12/site-packages/core/main/abstractions.py
new file mode 100644
index 00000000..3aaf2dbf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/abstractions.py
@@ -0,0 +1,82 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel
+
+from core.providers import (
+ AnthropicCompletionProvider,
+ AsyncSMTPEmailProvider,
+ ClerkAuthProvider,
+ ConsoleMockEmailProvider,
+ HatchetOrchestrationProvider,
+ JwtAuthProvider,
+ LiteLLMCompletionProvider,
+ LiteLLMEmbeddingProvider,
+ MailerSendEmailProvider,
+ OllamaEmbeddingProvider,
+ OpenAICompletionProvider,
+ OpenAIEmbeddingProvider,
+ PostgresDatabaseProvider,
+ R2RAuthProvider,
+ R2RCompletionProvider,
+ R2RIngestionProvider,
+ SendGridEmailProvider,
+ SimpleOrchestrationProvider,
+ SupabaseAuthProvider,
+ UnstructuredIngestionProvider,
+)
+
+if TYPE_CHECKING:
+ from core.main.services.auth_service import AuthService
+ from core.main.services.graph_service import GraphService
+ from core.main.services.ingestion_service import IngestionService
+ from core.main.services.management_service import ManagementService
+ from core.main.services.retrieval_service import ( # type: ignore
+ RetrievalService, # type: ignore
+ )
+
+
+class R2RProviders(BaseModel):
+ auth: (
+ R2RAuthProvider
+ | SupabaseAuthProvider
+ | JwtAuthProvider
+ | ClerkAuthProvider
+ )
+ database: PostgresDatabaseProvider
+ ingestion: R2RIngestionProvider | UnstructuredIngestionProvider
+ embedding: (
+ LiteLLMEmbeddingProvider
+ | OpenAIEmbeddingProvider
+ | OllamaEmbeddingProvider
+ )
+ completion_embedding: (
+ LiteLLMEmbeddingProvider
+ | OpenAIEmbeddingProvider
+ | OllamaEmbeddingProvider
+ )
+ llm: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ )
+ orchestration: HatchetOrchestrationProvider | SimpleOrchestrationProvider
+ email: (
+ AsyncSMTPEmailProvider
+ | ConsoleMockEmailProvider
+ | SendGridEmailProvider
+ | MailerSendEmailProvider
+ )
+
+ class Config:
+ arbitrary_types_allowed = True
+
+
+@dataclass
+class R2RServices:
+ auth: "AuthService"
+ ingestion: "IngestionService"
+ management: "ManagementService"
+ retrieval: "RetrievalService"
+ graph: "GraphService"
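`R2RProviders` above relies on two pydantic features: union-typed fields that accept any one of several provider classes, and `arbitrary_types_allowed`, which lets plain (non-pydantic) classes be used as field types and validated by `isinstance`. A stripped-down sketch of the pattern with stand-in classes (not the real providers):

    from pydantic import BaseModel

    class OllamaEmbedding:       # stand-in for a provider class
        pass

    class LiteLLMEmbedding:      # stand-in for another provider class
        pass

    class Providers(BaseModel):
        # Either provider type is accepted; anything else fails validation.
        embedding: OllamaEmbedding | LiteLLMEmbedding

        class Config:
            arbitrary_types_allowed = True  # required for plain classes like these

    p = Providers(embedding=OllamaEmbedding())
    print(type(p.embedding).__name__)        # OllamaEmbedding
    # Providers(embedding="oops")            # would raise a ValidationError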
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/base_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/base_router.py
new file mode 100644
index 00000000..ef432420
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/base_router.py
@@ -0,0 +1,151 @@
+import functools
+import logging
+from abc import abstractmethod
+from typing import Callable
+
+from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi.responses import FileResponse, StreamingResponse
+
+from core.base import R2RException
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+
+logger = logging.getLogger()
+
+
+class BaseRouterV3:
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ """
+ :param providers: Typically includes auth, database, etc.
+ :param services: Additional service references (ingestion, etc).
+ """
+ self.providers = providers
+ self.services = services
+ self.config = config
+ self.router = APIRouter()
+ self.openapi_extras = self._load_openapi_extras()
+
+ # Add the rate-limiting dependency
+ self.set_rate_limiting()
+
+ # Initialize any routes
+ self._setup_routes()
+ self._register_workflows()
+
+ def get_router(self):
+ return self.router
+
+ def base_endpoint(self, func: Callable):
+ """
+ A decorator to wrap endpoints in a standard pattern:
+ - error handling
+ - response shaping
+ """
+
+ @functools.wraps(func)
+ async def wrapper(*args, **kwargs):
+ try:
+ func_result = await func(*args, **kwargs)
+ if isinstance(func_result, tuple) and len(func_result) == 2:
+ results, outer_kwargs = func_result
+ else:
+ results, outer_kwargs = func_result, {}
+
+ if isinstance(results, (StreamingResponse, FileResponse)):
+ return results
+ return {"results": results, **outer_kwargs}
+
+ except R2RException:
+ raise
+ except Exception as e:
+ logger.error(
+ f"Error in base endpoint {func.__name__}() - {str(e)}",
+ exc_info=True,
+ )
+ raise HTTPException(
+ status_code=500,
+ detail={
+ "message": f"An error '{e}' occurred during {func.__name__}",
+ "error": str(e),
+ "error_type": type(e).__name__,
+ },
+ ) from e
+
+ wrapper._is_base_endpoint = True # type: ignore
+ return wrapper
+
+ @classmethod
+ def build_router(cls, engine):
+ """Class method for building a router instance (if you have a standard
+ pattern)."""
+ return cls(engine).router
+
+ def _register_workflows(self):
+ pass
+
+ def _load_openapi_extras(self):
+ return {}
+
+ @abstractmethod
+ def _setup_routes(self):
+ """Subclasses override this to define actual endpoints."""
+ pass
+
+ def set_rate_limiting(self):
+ """Adds a yield-based dependency for rate limiting each request.
+
+ Checks the limits, then logs the request if the check passes.
+ """
+
+ async def rate_limit_dependency(
+ request: Request,
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ):
+            """1) Fetch the user from the DB (including .limits_overrides).
+            2) Pass it to limits_handler.check_limits.
+            3) After the endpoint completes, call limits_handler.log_request.
+            """
+ # If the user is superuser, skip checks
+ if auth_user.is_superuser:
+ yield
+ return
+
+ user_id = auth_user.id
+ route = request.scope["path"]
+
+ # 1) Fetch the user from DB
+ user = await self.providers.database.users_handler.get_user_by_id(
+ user_id
+ )
+ if not user:
+ raise HTTPException(status_code=404, detail="User not found.")
+
+ # 2) Rate-limit check
+ try:
+ await self.providers.database.limits_handler.check_limits(
+                    user=user,  # the full User object, including limits_overrides
+                    route=route,
+ )
+ except ValueError as e:
+ # If check_limits raises ValueError -> 429 Too Many Requests
+ raise HTTPException(status_code=429, detail=str(e)) from e
+
+ request.state.user_id = user_id
+ request.state.route = route
+
+ # 3) Execute the route
+ try:
+ yield
+ finally:
+ # 4) Log only POST and DELETE requests
+ if request.method in ["POST", "DELETE"]:
+ await self.providers.database.limits_handler.log_request(
+ user_id, route
+ )
+
+ # Attach the dependencies so you can use them in your endpoints
+ self.rate_limit_dependency = rate_limit_dependency
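The `base_endpoint` decorator above gives every route a uniform envelope: a bare return value becomes `{"results": ...}`, while a `(results, extras)` tuple is flattened into the same dictionary. A self-contained sketch of just that shaping logic, outside FastAPI:

    import asyncio
    import functools

    def base_endpoint(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            result = await func(*args, **kwargs)
            if isinstance(result, tuple) and len(result) == 2:
                results, extras = result
            else:
                results, extras = result, {}
            return {"results": results, **extras}
        return wrapper

    @base_endpoint
    async def list_things():
        # Mirrors endpoints such as list_chunks, which return (items, {"total_entries": n}).
        return ["a", "b"], {"total_entries": 2}

    print(asyncio.run(list_things()))  # {'results': ['a', 'b'], 'total_entries': 2}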
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/chunks_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/chunks_router.py
new file mode 100644
index 00000000..ab0a62cb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/chunks_router.py
@@ -0,0 +1,422 @@
+import json
+import logging
+import textwrap
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Body, Depends, Path, Query
+
+from core.base import (
+ ChunkResponse,
+ GraphSearchSettings,
+ R2RException,
+ SearchSettings,
+ UpdateChunk,
+ select_search_filters,
+)
+from core.base.api.models import (
+ GenericBooleanResponse,
+ WrappedBooleanResponse,
+ WrappedChunkResponse,
+ WrappedChunksResponse,
+ WrappedVectorSearchResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+logger = logging.getLogger()
+
+MAX_CHUNKS_PER_REQUEST = 1024 * 100
+
+
+class ChunksRouter(BaseRouterV3):
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ logging.info("Initializing ChunksRouter")
+ super().__init__(providers, services, config)
+
+ def _setup_routes(self):
+ @self.router.post(
+ "/chunks/search",
+ summary="Search Chunks",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ response = client.chunks.search(
+ query="search query",
+ search_settings={
+ "limit": 10
+ }
+ )
+ """),
+ }
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def search_chunks(
+ query: str = Body(...),
+ search_settings: SearchSettings = Body(
+ default_factory=SearchSettings,
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedVectorSearchResponse: # type: ignore
+ # TODO - Deduplicate this code by sharing the code on the retrieval router
+ """Perform a semantic search query over all stored chunks.
+
+ This endpoint allows for complex filtering of search results using PostgreSQL-based queries.
+            Filters can be applied to various fields such as document_id and internal metadata values.
+
+ Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
+ """
+
+ search_settings.filters = select_search_filters(
+ auth_user, search_settings
+ )
+
+ search_settings.graph_settings = GraphSearchSettings(enabled=False)
+
+ results = await self.services.retrieval.search(
+ query=query,
+ search_settings=search_settings,
+ )
+ return results.chunk_search_results # type: ignore
+
+ @self.router.get(
+ "/chunks/{id}",
+ summary="Retrieve Chunk",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ response = client.chunks.retrieve(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.chunks.retrieve({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def retrieve_chunk(
+ id: UUID = Path(...),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedChunkResponse:
+ """Get a specific chunk by its ID.
+
+ Returns the chunk's content, metadata, and associated
+ document/collection information. Users can only retrieve chunks
+ they own or have access to through collections.
+ """
+ chunk = await self.services.ingestion.get_chunk(id)
+ if not chunk:
+ raise R2RException("Chunk not found", 404)
+
+ # TODO - Add collection ID check
+ if not auth_user.is_superuser and str(auth_user.id) != str(
+ chunk["owner_id"]
+ ):
+ raise R2RException("Not authorized to access this chunk", 403)
+
+ return ChunkResponse( # type: ignore
+ id=chunk["id"],
+ document_id=chunk["document_id"],
+ owner_id=chunk["owner_id"],
+ collection_ids=chunk["collection_ids"],
+ text=chunk["text"],
+ metadata=chunk["metadata"],
+ # vector = chunk["vector"] # TODO - Add include vector flag
+ )
+
+ @self.router.post(
+ "/chunks/{id}",
+ summary="Update Chunk",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ response = client.chunks.update(
+ {
+ "id": "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ "text": "Updated content",
+ "metadata": {"key": "new value"}
+ }
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.chunks.update({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ text: "Updated content",
+ metadata: {key: "new value"}
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def update_chunk(
+ id: UUID = Path(...),
+ chunk_update: UpdateChunk = Body(...),
+ # TODO: Run with orchestration?
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedChunkResponse:
+ """Update an existing chunk's content and/or metadata.
+
+ The chunk's vectors will be automatically recomputed based on the
+ new content. Users can only update chunks they own unless they are
+ superusers.
+ """
+ # Get the existing chunk to get its chunk_id
+ existing_chunk = await self.services.ingestion.get_chunk(
+ chunk_update.id
+ )
+ if existing_chunk is None:
+ raise R2RException(f"Chunk {chunk_update.id} not found", 404)
+
+ workflow_input = {
+ "document_id": str(existing_chunk["document_id"]),
+ "id": str(chunk_update.id),
+ "text": chunk_update.text,
+ "metadata": chunk_update.metadata
+ or existing_chunk["metadata"],
+ "user": auth_user.model_dump_json(),
+ }
+
+ logger.info("Running chunk ingestion without orchestration.")
+ from core.main.orchestration import simple_ingestion_factory
+
+ # TODO - CLEAN THIS UP
+
+ simple_ingestor = simple_ingestion_factory(self.services.ingestion)
+ await simple_ingestor["update-chunk"](workflow_input)
+
+ return ChunkResponse( # type: ignore
+ id=chunk_update.id,
+ document_id=existing_chunk["document_id"],
+ owner_id=existing_chunk["owner_id"],
+ collection_ids=existing_chunk["collection_ids"],
+ text=chunk_update.text,
+ metadata=chunk_update.metadata or existing_chunk["metadata"],
+ # vector = existing_chunk.get('vector')
+ )
+
+ @self.router.delete(
+ "/chunks/{id}",
+ summary="Delete Chunk",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ response = client.chunks.delete(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.chunks.delete({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_chunk(
+ id: UUID = Path(...),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Delete a specific chunk by ID.
+
+ This permanently removes the chunk and its associated vector
+ embeddings. The parent document remains unchanged. Users can only
+ delete chunks they own unless they are superusers.
+ """
+ # Get the existing chunk to get its chunk_id
+ existing_chunk = await self.services.ingestion.get_chunk(id)
+
+ if existing_chunk is None:
+ raise R2RException(
+ message=f"Chunk {id} not found", status_code=404
+ )
+
+ filters = {
+ "$and": [
+ {"owner_id": {"$eq": str(auth_user.id)}},
+ {"chunk_id": {"$eq": str(id)}},
+ ]
+ }
+ await (
+ self.services.management.delete_documents_and_chunks_by_filter(
+ filters=filters
+ )
+ )
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.get(
+ "/chunks",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List Chunks",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ response = client.chunks.list(
+ metadata_filter={"key": "value"},
+ include_vectors=False,
+ offset=0,
+ limit=10,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.chunks.list({
+ metadataFilter: {key: "value"},
+ includeVectors: false,
+ offset: 0,
+ limit: 10,
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def list_chunks(
+ metadata_filter: Optional[str] = Query(
+ None, description="Filter by metadata"
+ ),
+ include_vectors: bool = Query(
+ False, description="Include vector data in response"
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+            description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedChunksResponse:
+ """List chunks with pagination support.
+
+ Returns a paginated list of chunks that the user has access to.
+ Results can be filtered and sorted based on various parameters.
+ Vector embeddings are only included if specifically requested.
+
+ Regular users can only list chunks they own or have access to
+ through collections. Superusers can list all chunks in the system.
+            """
+            # Build filters
+ filters = {}
+
+ # Add user access control filter
+ if not auth_user.is_superuser:
+ filters["owner_id"] = {"$eq": str(auth_user.id)}
+
+            # Add metadata filters if provided (parse the JSON string and merge
+            # it into the filter set so it is actually applied downstream)
+            if metadata_filter:
+                filters.update(json.loads(metadata_filter))
+
+ # Get chunks using the vector handler's list_chunks method
+ results = await self.services.ingestion.list_chunks(
+ filters=filters,
+ include_vectors=include_vectors,
+ offset=offset,
+ limit=limit,
+ )
+
+ # Convert to response format
+ chunks = [
+ ChunkResponse(
+ id=chunk["id"],
+ document_id=chunk["document_id"],
+ owner_id=chunk["owner_id"],
+ collection_ids=chunk["collection_ids"],
+ text=chunk["text"],
+ metadata=chunk["metadata"],
+ vector=chunk.get("vector") if include_vectors else None,
+ )
+ for chunk in results["results"]
+ ]
+
+ return (chunks, {"total_entries": results["total_entries"]}) # type: ignore
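The chunk routes above all funnel access control through the same filter grammar: nested dictionaries whose keys are either field names or operators such as `$and` and `$eq` (the search endpoint's docstring lists the allowed operators). A short sketch of building such a filter on the client side, mirroring the `$and` block that `delete_chunk` constructs server-side (the IDs are illustrative):

    import json
    from uuid import UUID

    owner_id = UUID("123e4567-e89b-12d3-a456-426614174000")
    document_id = UUID("456e789a-b12c-34d5-e678-901234567890")

    filters = {
        "$and": [
            {"owner_id": {"$eq": str(owner_id)}},
            {"document_id": {"$eq": str(document_id)}},
        ]
    }

    # Filters are plain JSON, so they can be logged or passed to search settings as-is.
    print(json.dumps(filters, indent=2))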
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/collections_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/collections_router.py
new file mode 100644
index 00000000..462f5ca3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/collections_router.py
@@ -0,0 +1,1207 @@
+import logging
+import textwrap
+from enum import Enum
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Body, Depends, Path, Query
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse
+
+from core.base import R2RException
+from core.base.abstractions import GraphCreationSettings
+from core.base.api.models import (
+ GenericBooleanResponse,
+ WrappedBooleanResponse,
+ WrappedCollectionResponse,
+ WrappedCollectionsResponse,
+ WrappedDocumentsResponse,
+ WrappedGenericMessageResponse,
+ WrappedUsersResponse,
+)
+from core.utils import (
+ generate_default_user_collection_id,
+ update_settings_from_dict,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+logger = logging.getLogger()
+
+
+class CollectionAction(str, Enum):
+ VIEW = "view"
+ EDIT = "edit"
+ DELETE = "delete"
+ MANAGE_USERS = "manage_users"
+ ADD_DOCUMENT = "add_document"
+ REMOVE_DOCUMENT = "remove_document"
+
+
+async def authorize_collection_action(
+ auth_user, collection_id: UUID, action: CollectionAction, services
+) -> bool:
+ """Authorize a user's action on a given collection based on:
+
+ - If user is superuser (admin): Full access.
+ - If user is owner of the collection: Full access.
+ - If user is a member of the collection (in `collection_ids`): VIEW only.
+ - Otherwise: No access.
+ """
+
+ # Superusers have complete access
+ if auth_user.is_superuser:
+ return True
+
+ # Fetch collection details: owner_id and members
+ results = (
+ await services.management.collections_overview(
+ 0, 1, collection_ids=[collection_id]
+ )
+ )["results"]
+ if len(results) == 0:
+ raise R2RException("The specified collection does not exist.", 404)
+ details = results[0]
+ owner_id = details.owner_id
+
+ # Check if user is owner
+ if auth_user.id == owner_id:
+ # Owner can do all actions
+ return True
+
+ # Check if user is a member (non-owner)
+ if collection_id in auth_user.collection_ids:
+ # Members can only view
+ if action == CollectionAction.VIEW:
+ return True
+ else:
+ raise R2RException(
+ "Insufficient permissions for this action.", 403
+ )
+
+ # User is neither owner nor member
+ raise R2RException("You do not have access to this collection.", 403)
+
+
+class CollectionsRouter(BaseRouterV3):
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ logging.info("Initializing CollectionsRouter")
+ super().__init__(providers, services, config)
+
+ def _setup_routes(self):
+ @self.router.post(
+ "/collections",
+ summary="Create a new collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.create(
+ name="My New Collection",
+ description="This is a sample collection"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.collections.create({
+ name: "My New Collection",
+ description: "This is a sample collection"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/collections" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{"name": "My New Collection", "description": "This is a sample collection"}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def create_collection(
+ name: str = Body(..., description="The name of the collection"),
+ description: Optional[str] = Body(
+ None, description="An optional description of the collection"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCollectionResponse:
+ """Create a new collection and automatically add the creating user
+ to it.
+
+ This endpoint allows authenticated users to create a new collection
+ with a specified name and optional description. The user creating
+ the collection is automatically added as a member.
+ """
+ user_collections_count = (
+ await self.services.management.collections_overview(
+ user_ids=[auth_user.id], limit=1, offset=0
+ )
+ )["total_entries"]
+ user_max_collections = (
+ await self.services.management.get_user_max_collections(
+ auth_user.id
+ )
+ )
+ if (user_collections_count + 1) >= user_max_collections: # type: ignore
+ raise R2RException(
+ f"User has reached the maximum number of collections allowed ({user_max_collections}).",
+ 400,
+ )
+ collection = await self.services.management.create_collection(
+ owner_id=auth_user.id,
+ name=name,
+ description=description,
+ )
+ # Add the creating user to the collection
+ await self.services.management.add_user_to_collection(
+ auth_user.id, collection.id
+ )
+ return collection # type: ignore
+
+ @self.router.post(
+ "/collections/export",
+ summary="Export collections to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.collections.export(
+ output_path="export.csv",
+ columns=["id", "name", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+ await client.collections.export({
+ outputPath: "export.csv",
+ columns: ["id", "name", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+                            curl -X POST "http://127.0.0.1:7272/v3/collections/export" \\
+                            -H "Authorization: Bearer YOUR_API_KEY" \\
+                            -H "Content-Type: application/json" \\
+                            -H "Accept: text/csv" \\
+                            -d '{ "columns": ["id", "name", "created_at"], "include_header": true }' \\
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_collections(
+ background_tasks: BackgroundTasks,
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export collections as a CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_collections(
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="collections_export.csv",
+ )
+
+ @self.router.get(
+ "/collections",
+ summary="List collections",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.list(
+ offset=0,
+ limit=10,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.collections.list();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+                            curl -X GET "https://api.example.com/v3/collections?offset=0&limit=10" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def list_collections(
+ ids: list[str] = Query(
+ [],
+ description="A list of collection IDs to retrieve. If not provided, all collections will be returned.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+            description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCollectionsResponse:
+ """Returns a paginated list of collections the authenticated user
+ has access to.
+
+ Results can be filtered by providing specific collection IDs.
+ Regular users will only see collections they own or have access to.
+ Superusers can see all collections.
+
+ The collections are returned in order of last modification, with
+ most recent first.
+ """
+ requesting_user_id = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+
+ collection_uuids = [UUID(collection_id) for collection_id in ids]
+
+ collections_overview_response = (
+ await self.services.management.collections_overview(
+ user_ids=requesting_user_id,
+ collection_ids=collection_uuids,
+ offset=offset,
+ limit=limit,
+ )
+ )
+
+ return ( # type: ignore
+ collections_overview_response["results"],
+ {
+ "total_entries": collections_overview_response[
+ "total_entries"
+ ]
+ },
+ )
+
+ @self.router.get(
+ "/collections/{id}",
+ summary="Get collection details",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.retrieve("123e4567-e89b-12d3-a456-426614174000")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.collections.retrieve({id: "123e4567-e89b-12d3-a456-426614174000"});
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_collection(
+ id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCollectionResponse:
+ """Get details of a specific collection.
+
+ This endpoint retrieves detailed information about a single
+ collection identified by its UUID. The user must have access to the
+ collection to view its details.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.VIEW, self.services
+ )
+
+ collections_overview_response = (
+ await self.services.management.collections_overview(
+ user_ids=None,
+ collection_ids=[id],
+ offset=0,
+ limit=1,
+ )
+ )
+ overview = collections_overview_response["results"]
+
+ if len(overview) == 0: # type: ignore
+ raise R2RException(
+ "The specified collection does not exist.",
+ 404,
+ )
+ return overview[0] # type: ignore
+
+ @self.router.post(
+ "/collections/{id}",
+ summary="Update collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.update(
+ "123e4567-e89b-12d3-a456-426614174000",
+ name="Updated Collection Name",
+ description="Updated description"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.collections.update({
+ id: "123e4567-e89b-12d3-a456-426614174000",
+ name: "Updated Collection Name",
+ description: "Updated description"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{"name": "Updated Collection Name", "description": "Updated description"}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def update_collection(
+ id: UUID = Path(
+ ...,
+ description="The unique identifier of the collection to update",
+ ),
+ name: Optional[str] = Body(
+ None, description="The name of the collection"
+ ),
+ description: Optional[str] = Body(
+ None, description="An optional description of the collection"
+ ),
+ generate_description: Optional[bool] = Body(
+ False,
+ description="Whether to generate a new synthetic description for the collection",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCollectionResponse:
+ """Update an existing collection's configuration.
+
+ This endpoint allows updating the name and description of an
+ existing collection. The user must have appropriate permissions to
+ modify the collection.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.EDIT, self.services
+ )
+
+ if generate_description and description is not None:
+ raise R2RException(
+ "Cannot provide both a description and request to synthetically generate a new one.",
+ 400,
+ )
+
+ return await self.services.management.update_collection( # type: ignore
+ id,
+ name=name,
+ description=description,
+ generate_description=generate_description or False,
+ )
+
+ @self.router.delete(
+ "/collections/{id}",
+ summary="Delete collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.delete("123e4567-e89b-12d3-a456-426614174000")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.collections.delete({id: "123e4567-e89b-12d3-a456-426614174000"});
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_collection(
+ id: UUID = Path(
+ ...,
+ description="The unique identifier of the collection to delete",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Delete an existing collection.
+
+ This endpoint allows deletion of a collection identified by its
+ UUID. The user must have appropriate permissions to delete the
+ collection. Deleting a collection removes all associations but does
+ not delete the documents within it.
+ """
+ if id == generate_default_user_collection_id(auth_user.id):
+ raise R2RException(
+ "Cannot delete the default user collection.",
+ 400,
+ )
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.DELETE, self.services
+ )
+
+ await self.services.management.delete_collection(collection_id=id)
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.post(
+ "/collections/{id}/documents/{document_id}",
+ summary="Add document to collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.add_document(
+ "123e4567-e89b-12d3-a456-426614174000",
+ "456e789a-b12c-34d5-e678-901234567890"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+                                const response = await client.collections.addDocument({
+                                    id: "123e4567-e89b-12d3-a456-426614174000",
+                                    documentId: "456e789a-b12c-34d5-e678-901234567890"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/documents/456e789a-b12c-34d5-e678-901234567890" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def add_document_to_collection(
+ id: UUID = Path(...),
+ document_id: UUID = Path(...),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Add a document to a collection."""
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.ADD_DOCUMENT, self.services
+ )
+
+ return (
+ await self.services.management.assign_document_to_collection(
+ document_id, id
+ )
+ )
+
+ @self.router.get(
+ "/collections/{id}/documents",
+ summary="List documents in collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.list_documents(
+ "123e4567-e89b-12d3-a456-426614174000",
+ offset=0,
+ limit=10,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.collections.listDocuments({id: "123e4567-e89b-12d3-a456-426614174000"});
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/documents?offset=0&limit=10" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_collection_documents(
+ id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+            description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedDocumentsResponse:
+ """Get all documents in a collection with pagination and sorting
+ options.
+
+ This endpoint retrieves a paginated list of documents associated
+ with a specific collection. It supports sorting options to
+ customize the order of returned documents.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.VIEW, self.services
+ )
+
+ documents_in_collection_response = (
+ await self.services.management.documents_in_collection(
+ id, offset, limit
+ )
+ )
+
+ return documents_in_collection_response["results"], { # type: ignore
+ "total_entries": documents_in_collection_response[
+ "total_entries"
+ ]
+ }
+
+ @self.router.delete(
+ "/collections/{id}/documents/{document_id}",
+ summary="Remove document from collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.remove_document(
+ "123e4567-e89b-12d3-a456-426614174000",
+ "456e789a-b12c-34d5-e678-901234567890"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+                                const response = await client.collections.removeDocument({
+                                    id: "123e4567-e89b-12d3-a456-426614174000",
+                                    documentId: "456e789a-b12c-34d5-e678-901234567890"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/documents/456e789a-b12c-34d5-e678-901234567890" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def remove_document_from_collection(
+ id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ document_id: UUID = Path(
+ ...,
+ description="The unique identifier of the document to remove",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Remove a document from a collection.
+
+ This endpoint removes the association between a document and a
+ collection. It does not delete the document itself. The user must
+ have permissions to modify the collection.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.REMOVE_DOCUMENT, self.services
+ )
+ await self.services.management.remove_document_from_collection(
+ document_id, id
+ )
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.get(
+ "/collections/{id}/users",
+ summary="List users in collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.list_users(
+ "123e4567-e89b-12d3-a456-426614174000",
+ offset=0,
+ limit=10,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.collections.listUsers({
+ id: "123e4567-e89b-12d3-a456-426614174000"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/users?offset=0&limit=10" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_collection_users(
+ id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+            description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedUsersResponse:
+ """Get all users in a collection with pagination and sorting
+ options.
+
+ This endpoint retrieves a paginated list of users who have access
+ to a specific collection. It supports sorting options to customize
+ the order of returned users.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.VIEW, self.services
+ )
+
+ users_in_collection_response = (
+ await self.services.management.get_users_in_collection(
+ collection_id=id,
+ offset=offset,
+ limit=min(max(limit, 1), 1000),
+ )
+ )
+
+ return users_in_collection_response["results"], { # type: ignore
+ "total_entries": users_in_collection_response["total_entries"]
+ }
+
+ @self.router.post(
+ "/collections/{id}/users/{user_id}",
+ summary="Add user to collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.add_user(
+ "123e4567-e89b-12d3-a456-426614174000",
+ "789a012b-c34d-5e6f-g789-012345678901"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+                                const response = await client.collections.addUser({
+                                    id: "123e4567-e89b-12d3-a456-426614174000",
+                                    userId: "789a012b-c34d-5e6f-g789-012345678901"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/users/789a012b-c34d-5e6f-g789-012345678901" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def add_user_to_collection(
+ id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ user_id: UUID = Path(
+ ..., description="The unique identifier of the user to add"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Add a user to a collection.
+
+ This endpoint grants a user access to a specific collection. The
+ authenticated user must have admin permissions for the collection
+ to add new users.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.MANAGE_USERS, self.services
+ )
+
+ result = await self.services.management.add_user_to_collection(
+ user_id, id
+ )
+ return GenericBooleanResponse(success=result) # type: ignore
+
+ @self.router.delete(
+ "/collections/{id}/users/{user_id}",
+ summary="Remove user from collection",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.collections.remove_user(
+ "123e4567-e89b-12d3-a456-426614174000",
+ "789a012b-c34d-5e6f-g789-012345678901"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+                                const response = await client.collections.removeUser({
+                                    id: "123e4567-e89b-12d3-a456-426614174000",
+                                    userId: "789a012b-c34d-5e6f-g789-012345678901"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/users/789a012b-c34d-5e6f-g789-012345678901" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def remove_user_from_collection(
+ id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ user_id: UUID = Path(
+ ..., description="The unique identifier of the user to remove"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Remove a user from a collection.
+
+ This endpoint revokes a user's access to a specific collection. The
+ authenticated user must have admin permissions for the collection
+ to remove users.
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.MANAGE_USERS, self.services
+ )
+
+ result = (
+ await self.services.management.remove_user_from_collection(
+ user_id, id
+ )
+ )
+            return GenericBooleanResponse(success=result)  # type: ignore
+
+ @self.router.post(
+ "/collections/{id}/extract",
+ summary="Extract entities and relationships",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.documents.extract(
+ id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1"
+ )
+ """),
+ },
+ ],
+ },
+ )
+ @self.base_endpoint
+ async def extract(
+ id: UUID = Path(
+ ...,
+                description="The ID of the collection to extract entities and relationships from.",
+ ),
+ settings: Optional[GraphCreationSettings] = Body(
+ default=None,
+ description="Settings for the entities and relationships extraction process.",
+ ),
+ run_with_orchestration: Optional[bool] = Query(
+ default=True,
+ description="Whether to run the entities and relationships extraction process with orchestration.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+            """Extracts entities and relationships from the documents in a collection.
+
+ The entities and relationships extraction process involves:
+ 1. Parsing documents into semantic chunks
+ 2. Extracting entities and relationships using LLMs
+ """
+ await authorize_collection_action(
+ auth_user, id, CollectionAction.EDIT, self.services
+ )
+
+ settings = settings.dict() if settings else None # type: ignore
+ if not auth_user.is_superuser:
+ logger.warning("Implement permission checks here.")
+
+ # Apply runtime settings overrides
+ server_graph_creation_settings = (
+ self.providers.database.config.graph_creation_settings
+ )
+
+ if settings:
+ server_graph_creation_settings = update_settings_from_dict(
+ server_settings=server_graph_creation_settings,
+ settings_dict=settings, # type: ignore
+ )
+            # Build the workflow input up front so it is also available to the
+            # non-orchestrated fallback below.
+            workflow_input = {
+                "collection_id": str(id),
+                "graph_creation_settings": server_graph_creation_settings.model_dump_json(),
+                "user": auth_user.json(),
+            }
+
+            if run_with_orchestration:
+                try:
+                    return await self.providers.orchestration.run_workflow(  # type: ignore
+                        "graph-extraction", {"request": workflow_input}, {}
+                    )
+ except Exception as e: # TODO: Need to find specific error (gRPC most likely?)
+ logger.error(
+ f"Error running orchestrated extraction: {e} \n\nAttempting to run without orchestration."
+ )
+
+ from core.main.orchestration import (
+ simple_graph_search_results_factory,
+ )
+
+ logger.info("Running extract-triples without orchestration.")
+ simple_graph_search_results = simple_graph_search_results_factory(
+ self.services.graph
+ )
+ await simple_graph_search_results["graph-extraction"](
+ workflow_input
+ ) # type: ignore
+ return { # type: ignore
+ "message": "Graph created successfully.",
+ "task_id": None,
+ }
+
+ @self.router.get(
+ "/collections/name/{collection_name}",
+ summary="Get a collection by name",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ )
+ @self.base_endpoint
+ async def get_collection_by_name(
+ collection_name: str = Path(
+ ..., description="The name of the collection"
+ ),
+ owner_id: Optional[UUID] = Query(
+ None,
+ description="(Superuser only) Specify the owner_id to retrieve a collection by name",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCollectionResponse:
+ """Retrieve a collection by its (owner_id, name) combination.
+
+ The authenticated user can only fetch collections they own, or, if
+ superuser, from anyone.
+ """
+ if auth_user.is_superuser:
+ if not owner_id:
+ owner_id = auth_user.id
+ else:
+ owner_id = auth_user.id
+
+            # Non-superusers always look up their own (owner_id, name) pair; superusers
+            # may pass ?owner_id=... to look up another user's collection and otherwise
+            # default to their own id above. This branch is defensive and should not be reached.
+            if not owner_id:
+ raise R2RException(
+ "Superuser must specify an owner_id to fetch by name.", 400
+ )
+
+ collection = await self.providers.database.collections_handler.get_collection_by_name(
+ owner_id, collection_name
+ )
+ if not collection:
+ raise R2RException("Collection not found.", 404)
+
+ # Now, authorize the 'view' action just in case:
+ # e.g. await authorize_collection_action(auth_user, collection.id, CollectionAction.VIEW, self.services)
+
+ return collection # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/conversations_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/conversations_router.py
new file mode 100644
index 00000000..d1b6d645
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/conversations_router.py
@@ -0,0 +1,737 @@
+import logging
+import textwrap
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Body, Depends, Path, Query
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse
+
+from core.base import Message, R2RException
+from core.base.api.models import (
+ GenericBooleanResponse,
+ WrappedBooleanResponse,
+ WrappedConversationMessagesResponse,
+ WrappedConversationResponse,
+ WrappedConversationsResponse,
+ WrappedMessageResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+logger = logging.getLogger()
+
+
+class ConversationsRouter(BaseRouterV3):
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ logging.info("Initializing ConversationsRouter")
+ super().__init__(providers, services, config)
+
+ def _setup_routes(self):
+ @self.router.post(
+ "/conversations",
+ summary="Create a new conversation",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.conversations.create()
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.conversations.create();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/conversations" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def create_conversation(
+ name: Optional[str] = Body(
+ None, description="The name of the conversation", embed=True
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedConversationResponse:
+ """Create a new conversation.
+
+ This endpoint initializes a new conversation for the authenticated
+ user.
+ """
+ user_id = auth_user.id
+
+ return await self.services.management.create_conversation( # type: ignore
+ user_id=user_id,
+ name=name,
+ )
+
+ @self.router.get(
+ "/conversations",
+ summary="List conversations",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.conversations.list(
+ offset=0,
+ limit=10,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.conversations.list();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/conversations?offset=0&limit=10" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def list_conversations(
+ ids: list[str] = Query(
+ [],
+ description="A list of conversation IDs to retrieve. If not provided, all conversations will be returned.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+                description="Specifies a limit on the number of objects to return, ranging between 1 and 1,000. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedConversationsResponse:
+ """List conversations with pagination and sorting options.
+
+ This endpoint returns a paginated list of conversations for the
+ authenticated user.
+ """
+ requesting_user_id = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+
+ conversation_uuids = [
+ UUID(conversation_id) for conversation_id in ids
+ ]
+
+ conversations_response = (
+ await self.services.management.conversations_overview(
+ offset=offset,
+ limit=limit,
+ conversation_ids=conversation_uuids,
+ user_ids=requesting_user_id,
+ )
+ )
+ return conversations_response["results"], { # type: ignore
+ "total_entries": conversations_response["total_entries"]
+ }
+
+ @self.router.post(
+ "/conversations/export",
+ summary="Export conversations to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.conversations.export(
+ output_path="export.csv",
+ columns=["id", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+ await client.conversations.export({
+ outputPath: "export.csv",
+ columns: ["id", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/conversations/export" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_conversations(
+ background_tasks: BackgroundTasks,
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export conversations as a downloadable CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_conversations(
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+                filename="conversations_export.csv",
+ )
+
+ @self.router.post(
+ "/conversations/export_messages",
+ summary="Export messages to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.conversations.export_messages(
+ output_path="export.csv",
+ columns=["id", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+ await client.conversations.exportMessages({
+ outputPath: "export.csv",
+ columns: ["id", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/conversations/export_messages" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_messages(
+ background_tasks: BackgroundTasks,
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+            """Export messages as a downloadable CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_messages(
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+                filename="messages_export.csv",
+ )
+
+ @self.router.get(
+ "/conversations/{id}",
+ summary="Get conversation details",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.conversations.get(
+ "123e4567-e89b-12d3-a456-426614174000"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.conversations.retrieve({
+ id: "123e4567-e89b-12d3-a456-426614174000",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_conversation(
+ id: UUID = Path(
+ ..., description="The unique identifier of the conversation"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedConversationMessagesResponse:
+ """Get details of a specific conversation.
+
+ This endpoint retrieves detailed information about a single
+ conversation identified by its UUID.
+ """
+ requesting_user_id = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+
+ conversation = await self.services.management.get_conversation(
+ conversation_id=id,
+ user_ids=requesting_user_id,
+ )
+ return conversation # type: ignore
+
+ @self.router.post(
+ "/conversations/{id}",
+ summary="Update conversation",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.conversations.update("123e4567-e89b-12d3-a456-426614174000", "new_name")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.conversations.update({
+ id: "123e4567-e89b-12d3-a456-426614174000",
+ name: "new_name",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -d '{"name": "new_name"}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def update_conversation(
+ id: UUID = Path(
+ ...,
+                description="The unique identifier of the conversation to update",
+ ),
+ name: str = Body(
+ ...,
+ description="The updated name for the conversation",
+ embed=True,
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedConversationResponse:
+ """Update an existing conversation.
+
+ This endpoint updates the name of an existing conversation
+ identified by its UUID.
+ """
+ return await self.services.management.update_conversation( # type: ignore
+ conversation_id=id,
+ name=name,
+ )
+
+ @self.router.delete(
+ "/conversations/{id}",
+ summary="Delete conversation",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.conversations.delete("123e4567-e89b-12d3-a456-426614174000")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.conversations.delete({
+ id: "123e4567-e89b-12d3-a456-426614174000",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_conversation(
+ id: UUID = Path(
+ ...,
+ description="The unique identifier of the conversation to delete",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Delete an existing conversation.
+
+ This endpoint deletes a conversation identified by its UUID.
+ """
+ requesting_user_id = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+
+ await self.services.management.delete_conversation(
+ conversation_id=id,
+ user_ids=requesting_user_id,
+ )
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.post(
+ "/conversations/{id}/messages",
+ summary="Add message to conversation",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.conversations.add_message(
+ "123e4567-e89b-12d3-a456-426614174000",
+ content="Hello, world!",
+ role="user",
+ parent_id="parent_message_id",
+ metadata={"key": "value"}
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.conversations.addMessage({
+ id: "123e4567-e89b-12d3-a456-426614174000",
+ content: "Hello, world!",
+ role: "user",
+ parentId: "parent_message_id",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000/messages" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -H "Content-Type: application/json" \\
+ -d '{"content": "Hello, world!", "parent_id": "parent_message_id", "metadata": {"key": "value"}}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def add_message(
+ id: UUID = Path(
+ ..., description="The unique identifier of the conversation"
+ ),
+ content: str = Body(
+ ..., description="The content of the message to add"
+ ),
+ role: str = Body(
+ ..., description="The role of the message to add"
+ ),
+ parent_id: Optional[UUID] = Body(
+ None, description="The ID of the parent message, if any"
+ ),
+ metadata: Optional[dict[str, str]] = Body(
+ None, description="Additional metadata for the message"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedMessageResponse:
+ """Add a new message to a conversation.
+
+ This endpoint adds a new message to an existing conversation.
+ """
+ if content == "":
+ raise R2RException("Content cannot be empty", status_code=400)
+ if role not in ["user", "assistant", "system"]:
+ raise R2RException("Invalid role", status_code=400)
+ message = Message(role=role, content=content)
+ return await self.services.management.add_message( # type: ignore
+ conversation_id=id,
+ content=message,
+ parent_id=parent_id,
+ metadata=metadata,
+ )
+
+ @self.router.post(
+ "/conversations/{id}/messages/{message_id}",
+ summary="Update message in conversation",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.conversations.update_message(
+ "123e4567-e89b-12d3-a456-426614174000",
+ "message_id_to_update",
+ content="Updated content"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.conversations.updateMessage({
+ id: "123e4567-e89b-12d3-a456-426614174000",
+ messageId: "message_id_to_update",
+ content: "Updated content",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000/messages/message_id_to_update" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -H "Content-Type: application/json" \\
+ -d '{"content": "Updated content"}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def update_message(
+ id: UUID = Path(
+ ..., description="The unique identifier of the conversation"
+ ),
+ message_id: UUID = Path(
+ ..., description="The ID of the message to update"
+ ),
+ content: Optional[str] = Body(
+ None, description="The new content for the message"
+ ),
+ metadata: Optional[dict[str, str]] = Body(
+ None, description="Additional metadata for the message"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedMessageResponse:
+ """Update an existing message in a conversation.
+
+ This endpoint updates the content of an existing message in a
+ conversation.
+ """
+ return await self.services.management.edit_message( # type: ignore
+ message_id=message_id,
+ new_content=content,
+ additional_metadata=metadata,
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/documents_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/documents_router.py
new file mode 100644
index 00000000..fe152b8b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/documents_router.py
@@ -0,0 +1,2342 @@
+import base64
+import logging
+import mimetypes
+import textwrap
+from datetime import datetime
+from io import BytesIO
+from typing import Any, Optional
+from urllib.parse import quote
+from uuid import UUID
+
+from fastapi import Body, Depends, File, Form, Path, Query, UploadFile
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse, StreamingResponse
+from pydantic import Json
+
+from core.base import (
+ IngestionConfig,
+ IngestionMode,
+ R2RException,
+ SearchMode,
+ SearchSettings,
+ UnprocessedChunk,
+ Workflow,
+ generate_document_id,
+ generate_id,
+ select_search_filters,
+)
+from core.base.abstractions import GraphCreationSettings, StoreType
+from core.base.api.models import (
+ GenericBooleanResponse,
+ WrappedBooleanResponse,
+ WrappedChunksResponse,
+ WrappedCollectionsResponse,
+ WrappedDocumentResponse,
+ WrappedDocumentSearchResponse,
+ WrappedDocumentsResponse,
+ WrappedEntitiesResponse,
+ WrappedGenericMessageResponse,
+ WrappedIngestionResponse,
+ WrappedRelationshipsResponse,
+)
+from core.utils import update_settings_from_dict
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+logger = logging.getLogger()
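+# Hard cap on the number of pre-processed chunks accepted in a single
+# /documents request (1024 * 100 = 102,400).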
+MAX_CHUNKS_PER_REQUEST = 1024 * 100
+
+
+def merge_search_settings(
+ base: SearchSettings, overrides: SearchSettings
+) -> SearchSettings:
+ # Convert both to dict
+ base_dict = base.model_dump()
+ overrides_dict = overrides.model_dump(exclude_unset=True)
+
+ # Update base_dict with values from overrides_dict
+ # This ensures that any field set in overrides takes precedence
+ for k, v in overrides_dict.items():
+ base_dict[k] = v
+
+ # Construct a new SearchSettings from the merged dict
+ return SearchSettings(**base_dict)
+
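+# Usage sketch for merge_search_settings (assuming SearchSettings exposes a
+# `limit` field): fields explicitly set on `overrides` win, everything else
+# keeps its base value.
+#
+#   merged = merge_search_settings(SearchSettings(limit=10), SearchSettings(limit=25))
+#   assert merged.limit == 25
+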
+
+def merge_ingestion_config(
+ base: IngestionConfig, overrides: IngestionConfig
+) -> IngestionConfig:
+ base_dict = base.model_dump()
+ overrides_dict = overrides.model_dump(exclude_unset=True)
+
+ for k, v in overrides_dict.items():
+ base_dict[k] = v
+
+ return IngestionConfig(**base_dict)
+
+
+class DocumentsRouter(BaseRouterV3):
+ def __init__(
+ self,
+ providers: R2RProviders,
+ services: R2RServices,
+ config: R2RConfig,
+ ):
+ logging.info("Initializing DocumentsRouter")
+ super().__init__(providers, services, config)
+ self._register_workflows()
+
+ def _prepare_search_settings(
+ self,
+ auth_user: Any,
+ search_mode: SearchMode,
+ search_settings: Optional[SearchSettings],
+ ) -> SearchSettings:
+ """Prepare the effective search settings based on the provided
+ search_mode, optional user-overrides in search_settings, and applied
+ filters."""
+
+ if search_mode != SearchMode.custom:
+ # Start from mode defaults
+ effective_settings = SearchSettings.get_default(search_mode.value)
+ if search_settings:
+ # Merge user-provided overrides
+ effective_settings = merge_search_settings(
+ effective_settings, search_settings
+ )
+ else:
+ # Custom mode: use provided settings or defaults
+ effective_settings = search_settings or SearchSettings()
+
+ # Apply user-specific filters
+ effective_settings.filters = select_search_filters(
+ auth_user, effective_settings
+ )
+
+ return effective_settings
+
+ # TODO - Remove this legacy method
+ def _register_workflows(self):
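+        # Map each ingestion workflow name to the user-facing status message;
+        # the wording depends on whether a real orchestrator or the in-process
+        # "simple" provider runs the task.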
+ self.providers.orchestration.register_workflows(
+ Workflow.INGESTION,
+ self.services.ingestion,
+ {
+ "ingest-files": (
+ "Ingest files task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Document created and ingested successfully."
+ ),
+ "ingest-chunks": (
+ "Ingest chunks task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Document created and ingested successfully."
+ ),
+ "update-chunk": (
+ "Update chunk task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Chunk update completed successfully."
+ ),
+ "update-document-metadata": (
+ "Update document metadata task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Document metadata update completed successfully."
+ ),
+ "create-vector-index": (
+ "Vector index creation task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Vector index creation task completed successfully."
+ ),
+ "delete-vector-index": (
+ "Vector index deletion task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Vector index deletion task completed successfully."
+ ),
+ "select-vector-index": (
+ "Vector index selection task queued successfully."
+ if self.providers.orchestration.config.provider != "simple"
+ else "Vector index selection task completed successfully."
+ ),
+ },
+ )
+
+ def _prepare_ingestion_config(
+ self,
+ ingestion_mode: IngestionMode,
+ ingestion_config: Optional[IngestionConfig],
+ ) -> IngestionConfig:
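+        # Merge sketch (the `chunk_size` key is illustrative): a non-custom mode
+        # such as "fast" supplies the defaults, and an ingestion_config override
+        # like {"chunk_size": 1024} replaces only that field.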
+ # If not custom, start from defaults
+ if ingestion_mode != IngestionMode.custom:
+ effective_config = IngestionConfig.get_default(
+ ingestion_mode.value, app=self.providers.auth.config.app
+ )
+ if ingestion_config:
+ effective_config = merge_ingestion_config(
+ effective_config, ingestion_config
+ )
+ else:
+ # custom mode
+ effective_config = ingestion_config or IngestionConfig(
+ app=self.providers.auth.config.app
+ )
+
+ effective_config.validate_config()
+ return effective_config
+
+ def _setup_routes(self):
+ @self.router.post(
+ "/documents",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ status_code=202,
+ summary="Create a new document",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.create(
+ file_path="pg_essay_1.html",
+ metadata={"metadata_1":"some random metadata"},
+ id=None
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.documents.create({
+ file: { path: "examples/data/marmeladov.txt", name: "marmeladov.txt" },
+ metadata: { title: "marmeladov.txt" },
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/documents" \\
+ -H "Content-Type: multipart/form-data" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -F "file=@pg_essay_1.html;type=text/html" \\
+ -F 'metadata={}' \\
+ -F 'id=null'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def create_document(
+ file: Optional[UploadFile] = File(
+ None,
+ description="The file to ingest. Exactly one of file, raw_text, or chunks must be provided.",
+ ),
+ raw_text: Optional[str] = Form(
+ None,
+ description="Raw text content to ingest. Exactly one of file, raw_text, or chunks must be provided.",
+ ),
+ chunks: Optional[Json[list[str]]] = Form(
+ None,
+ description="Pre-processed text chunks to ingest. Exactly one of file, raw_text, or chunks must be provided.",
+ ),
+ id: Optional[UUID] = Form(
+ None,
+ description="The ID of the document. If not provided, a new ID will be generated.",
+ ),
+ collection_ids: Optional[Json[list[UUID]]] = Form(
+ None,
+ description="Collection IDs to associate with the document. If none are provided, the document will be assigned to the user's default collection.",
+ ),
+ metadata: Optional[Json[dict]] = Form(
+ None,
+ description="Metadata to associate with the document, such as title, description, or custom fields.",
+ ),
+ ingestion_mode: IngestionMode = Form(
+ default=IngestionMode.custom,
+ description=(
+ "Ingestion modes:\n"
+ "- `hi-res`: Thorough ingestion with full summaries and enrichment.\n"
+ "- `fast`: Quick ingestion with minimal enrichment and no summaries.\n"
+ "- `custom`: Full control via `ingestion_config`.\n\n"
+ "If `filters` or `limit` (in `ingestion_config`) are provided alongside `hi-res` or `fast`, "
+ "they will override the default settings for that mode."
+ ),
+ ),
+ ingestion_config: Optional[Json[IngestionConfig]] = Form(
+ None,
+ description="An optional dictionary to override the default chunking configuration for the ingestion process. If not provided, the system will use the default server-side chunking configuration.",
+ ),
+ run_with_orchestration: Optional[bool] = Form(
+ True,
+                description="Whether or not ingestion runs with orchestration; defaults to `True`. When set to `False`, the ingestion process runs synchronously and returns the result directly.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedIngestionResponse:
+ """
+ Creates a new Document object from an input file, text content, or chunks. The chosen `ingestion_mode` determines
+ how the ingestion process is configured:
+
+ **Ingestion Modes:**
+            - `hi-res`: Comprehensive parsing and enrichment, including document summaries.
+ - `fast`: Speed-focused ingestion that skips certain enrichment steps like summaries.
+ - `custom`: Provide a full `ingestion_config` to customize the entire ingestion process.
+
+            Exactly one of `file`, `raw_text`, or `chunks` must be provided. Documents are shared through `Collections`, which allow for tightly specified cross-user interactions.
+
+ The ingestion process runs asynchronously and its progress can be tracked using the returned
+ task_id.
+ """
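+            # Client-side sketch (the ingestion_config keys shown are
+            # illustrative): a custom-mode upload that overrides chunking
+            # might look like
+            #   client.documents.create(
+            #       file_path="report.pdf",
+            #       ingestion_mode="custom",
+            #       ingestion_config={"chunk_size": 512},
+            #   )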
+ if not auth_user.is_superuser:
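+                # Enforce per-user quotas (documents, chunks, collections);
+                # superusers skip these checks.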
+ user_document_count = (
+ await self.services.management.documents_overview(
+ user_ids=[auth_user.id],
+ offset=0,
+ limit=1,
+ )
+ )["total_entries"]
+ user_max_documents = (
+ await self.services.management.get_user_max_documents(
+ auth_user.id
+ )
+ )
+
+ if user_document_count >= user_max_documents:
+ raise R2RException(
+ status_code=403,
+ message=f"User has reached the maximum number of documents allowed ({user_max_documents}).",
+ )
+
+ # Get chunks using the vector handler's list_chunks method
+ user_chunk_count = (
+ await self.services.ingestion.list_chunks(
+ filters={"owner_id": {"$eq": str(auth_user.id)}},
+ offset=0,
+ limit=1,
+ )
+ )["total_entries"]
+ user_max_chunks = (
+ await self.services.management.get_user_max_chunks(
+ auth_user.id
+ )
+ )
+ if user_chunk_count >= user_max_chunks:
+ raise R2RException(
+ status_code=403,
+ message=f"User has reached the maximum number of chunks allowed ({user_max_chunks}).",
+ )
+
+ user_collections_count = (
+ await self.services.management.collections_overview(
+ user_ids=[auth_user.id],
+ offset=0,
+ limit=1,
+ )
+ )["total_entries"]
+ user_max_collections = (
+ await self.services.management.get_user_max_collections(
+ auth_user.id
+ )
+ )
+ if user_collections_count >= user_max_collections: # type: ignore
+ raise R2RException(
+ status_code=403,
+ message=f"User has reached the maximum number of collections allowed ({user_max_collections}).",
+ )
+
+ effective_ingestion_config = self._prepare_ingestion_config(
+ ingestion_mode=ingestion_mode,
+ ingestion_config=ingestion_config,
+ )
+ if not file and not raw_text and not chunks:
+ raise R2RException(
+ status_code=422,
+ message="Either a `file`, `raw_text`, or `chunks` must be provided.",
+ )
+ if (
+ (file and raw_text)
+ or (file and chunks)
+ or (raw_text and chunks)
+ ):
+ raise R2RException(
+ status_code=422,
+ message="Only one of `file`, `raw_text`, or `chunks` may be provided.",
+ )
+            # Default to an empty metadata dict when none is provided
+            metadata = metadata or {}
+
+ if chunks:
+ if len(chunks) == 0:
+ raise R2RException("Empty list of chunks provided", 400)
+
+ if len(chunks) > MAX_CHUNKS_PER_REQUEST:
+ raise R2RException(
+ f"Maximum of {MAX_CHUNKS_PER_REQUEST} chunks per request",
+ 400,
+ )
+
+ document_id = id or generate_document_id(
+ "".join(chunks), auth_user.id
+ )
+
+ # FIXME: Metadata doesn't seem to be getting passed through
+ raw_chunks_for_doc = [
+ UnprocessedChunk(
+ text=chunk,
+ metadata=metadata,
+ id=generate_id(),
+ )
+ for chunk in chunks
+ ]
+
+ # Prepare workflow input
+ workflow_input = {
+ "document_id": str(document_id),
+ "chunks": [
+ chunk.model_dump(mode="json")
+ for chunk in raw_chunks_for_doc
+ ],
+ "metadata": metadata, # Base metadata for the document
+ "user": auth_user.model_dump_json(),
+ "ingestion_config": effective_ingestion_config.model_dump(
+ mode="json"
+ ),
+ }
+
+ if run_with_orchestration:
+ try:
+ # Run ingestion with orchestration
+ raw_message = (
+ await self.providers.orchestration.run_workflow(
+ "ingest-chunks",
+ {"request": workflow_input},
+ options={
+ "additional_metadata": {
+ "document_id": str(document_id),
+ }
+ },
+ )
+ )
+ raw_message["document_id"] = str(document_id)
+ return raw_message # type: ignore
+ except Exception as e: # TODO: Need to find specific errors that we should be excepting (gRPC most likely?)
+ logger.error(
+ f"Error running orchestrated ingestion: {e} \n\nAttempting to run without orchestration."
+ )
+
+ logger.info("Running chunk ingestion without orchestration.")
+ from core.main.orchestration import simple_ingestion_factory
+
+ simple_ingestor = simple_ingestion_factory(
+ self.services.ingestion
+ )
+ await simple_ingestor["ingest-chunks"](workflow_input)
+
+ return { # type: ignore
+ "message": "Document created and ingested successfully.",
+ "document_id": str(document_id),
+ "task_id": None,
+ }
+
+ else:
+ if file:
+ file_data = await self._process_file(file)
+
+ if not file.filename:
+ raise R2RException(
+ status_code=422,
+ message="Uploaded file must have a filename.",
+ )
+
+ file_ext = file.filename.split(".")[
+ -1
+ ] # e.g. "pdf", "txt"
+ max_allowed_size = await self.services.management.get_max_upload_size_by_type(
+ user_id=auth_user.id, file_type_or_ext=file_ext
+ )
+
+ content_length = file_data["content_length"]
+
+ if content_length > max_allowed_size:
+ raise R2RException(
+ status_code=413, # HTTP 413: Payload Too Large
+ message=(
+ f"File size exceeds maximum of {max_allowed_size} bytes "
+ f"for extension '{file_ext}'."
+ ),
+ )
+
+ file_content = BytesIO(
+ base64.b64decode(file_data["content"])
+ )
+
+ file_data.pop("content", None)
+ document_id = id or generate_document_id(
+ file_data["filename"], auth_user.id
+ )
+ elif raw_text:
+ content_length = len(raw_text)
+ file_content = BytesIO(raw_text.encode("utf-8"))
+ document_id = id or generate_document_id(
+ raw_text, auth_user.id
+ )
+ title = metadata.get("title", None)
+ title = title + ".txt" if title else None
+ file_data = {
+ "filename": title or "N/A",
+ "content_type": "text/plain",
+ }
+ else:
+ raise R2RException(
+ status_code=422,
+ message="Either a file or content must be provided.",
+ )
+
+ workflow_input = {
+ "file_data": file_data,
+ "document_id": str(document_id),
+ "collection_ids": (
+ [str(cid) for cid in collection_ids]
+ if collection_ids
+ else None
+ ),
+ "metadata": metadata,
+ "ingestion_config": effective_ingestion_config.model_dump(
+ mode="json"
+ ),
+ "user": auth_user.model_dump_json(),
+ "size_in_bytes": content_length,
+ "version": "v0",
+ }
+
+ file_name = file_data["filename"]
+ await self.providers.database.files_handler.store_file(
+ document_id,
+ file_name,
+ file_content,
+ file_data["content_type"],
+ )
+
+ await self.services.ingestion.ingest_file_ingress(
+ file_data=workflow_input["file_data"],
+ user=auth_user,
+ document_id=workflow_input["document_id"],
+ size_in_bytes=workflow_input["size_in_bytes"],
+ metadata=workflow_input["metadata"],
+ version=workflow_input["version"],
+ )
+
+ if run_with_orchestration:
+ try:
+ # TODO - Modify create_chunks so that we can add chunks to existing document
+
+ workflow_result: dict[
+ str, str | None
+ ] = await self.providers.orchestration.run_workflow( # type: ignore
+ "ingest-files",
+ {"request": workflow_input},
+ options={
+ "additional_metadata": {
+ "document_id": str(document_id),
+ }
+ },
+ )
+ workflow_result["document_id"] = str(document_id)
+ return workflow_result # type: ignore
+ except Exception as e: # TODO: Need to find specific error (gRPC most likely?)
+ logger.error(
+ f"Error running orchestrated ingestion: {e} \n\nAttempting to run without orchestration."
+ )
+ logger.info(
+ f"Running ingestion without orchestration for file {file_name} and document_id {document_id}."
+ )
+ # TODO - Clean up implementation logic here to be more explicitly `synchronous`
+ from core.main.orchestration import simple_ingestion_factory
+
+ simple_ingestor = simple_ingestion_factory(self.services.ingestion)
+ await simple_ingestor["ingest-files"](workflow_input)
+ return { # type: ignore
+ "message": "Document created and ingested successfully.",
+ "document_id": str(document_id),
+ "task_id": None,
+ }
+
+ @self.router.patch(
+ "/documents/{id}/metadata",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Append metadata to a document",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.append_metadata(
+ id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ metadata=[{"key": "new_key", "value": "new_value"}]
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.documents.appendMetadata({
+ id: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ metadata: [{ key: "new_key", value: "new_value" }],
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def patch_metadata(
+ id: UUID = Path(
+ ...,
+ description="The ID of the document to append metadata to.",
+ ),
+ metadata: list[dict] = Body(
+ ...,
+ description="Metadata to append to the document.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedDocumentResponse:
+ """Appends metadata to a document. This endpoint allows adding new metadata fields or updating existing ones."""
+ request_user_ids = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+
+ documents_overview_response = (
+ await self.services.management.documents_overview(
+ user_ids=request_user_ids,
+ document_ids=[id],
+ offset=0,
+ limit=1,
+ )
+ )
+ results = documents_overview_response["results"]
+ if len(results) == 0:
+ raise R2RException("Document not found.", 404)
+
+ return await self.services.management.update_document_metadata(
+ document_id=id,
+ metadata=metadata,
+ overwrite=False,
+ )
+
+ @self.router.put(
+ "/documents/{id}/metadata",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Replace metadata of a document",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.replace_metadata(
+ id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ metadata=[{"key": "new_key", "value": "new_value"}]
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.documents.replaceMetadata({
+ id: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ metadata: [{ key: "new_key", value: "new_value" }],
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def put_metadata(
+ id: UUID = Path(
+ ...,
+                description="The ID of the document whose metadata will be replaced.",
+ ),
+ metadata: list[dict] = Body(
+ ...,
+                description="Metadata to replace the document's existing metadata with.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedDocumentResponse:
+ """Replaces metadata in a document. This endpoint allows overwriting existing metadata fields."""
+ request_user_ids = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+
+ documents_overview_response = (
+ await self.services.management.documents_overview(
+ user_ids=request_user_ids,
+ document_ids=[id],
+ offset=0,
+ limit=1,
+ )
+ )
+ results = documents_overview_response["results"]
+ if len(results) == 0:
+ raise R2RException("Document not found.", 404)
+
+ return await self.services.management.update_document_metadata(
+ document_id=id,
+ metadata=metadata,
+ overwrite=True,
+ )
+
+ @self.router.post(
+ "/documents/export",
+ summary="Export documents to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.documents.export(
+ output_path="export.csv",
+ columns=["id", "title", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+ await client.documents.export({
+ outputPath: "export.csv",
+ columns: ["id", "title", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/documents/export" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_documents(
+ background_tasks: BackgroundTasks,
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export documents as a downloadable CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_documents(
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="documents_export.csv",
+ )
+
+ @self.router.get(
+ "/documents/download_zip",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ response_class=StreamingResponse,
+ summary="Export multiple documents as zip",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ client.documents.download_zip(
+ document_ids=["uuid1", "uuid2"],
+ start_date="2024-01-01",
+ end_date="2024-12-31"
+ )
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/documents/download_zip?document_ids=uuid1,uuid2&start_date=2024-01-01&end_date=2024-12-31" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_files(
+ document_ids: Optional[list[UUID]] = Query(
+ None,
+ description="List of document IDs to include in the export. If not provided, all accessible documents will be included.",
+ ),
+ start_date: Optional[datetime] = Query(
+ None,
+ description="Filter documents created on or after this date.",
+ ),
+ end_date: Optional[datetime] = Query(
+ None,
+ description="Filter documents created before this date.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> StreamingResponse:
+ """Export multiple documents as a zip file. Documents can be
+ filtered by IDs and/or date range.
+
+ The endpoint allows downloading:
+ - Specific documents by providing their IDs
+ - Documents within a date range
+ - All accessible documents if no filters are provided
+
+ Files are streamed as a zip archive to handle potentially large downloads efficiently.
+ """
+ if not auth_user.is_superuser:
+ # For non-superusers, verify access to requested documents
+ if document_ids:
+ documents_overview = (
+ await self.services.management.documents_overview(
+ user_ids=[auth_user.id],
+ document_ids=document_ids,
+ offset=0,
+ limit=len(document_ids),
+ )
+ )
+ if len(documents_overview["results"]) != len(document_ids):
+ raise R2RException(
+ status_code=403,
+ message="You don't have access to one or more requested documents.",
+ )
+ if not document_ids:
+ raise R2RException(
+ status_code=403,
+ message="Non-superusers must provide document IDs to export.",
+ )
+
+ (
+ zip_name,
+ zip_content,
+ zip_size,
+ ) = await self.services.management.export_files(
+ document_ids=document_ids,
+ start_date=start_date,
+ end_date=end_date,
+ )
+ encoded_filename = quote(zip_name)
+
+ async def stream_file():
+ yield zip_content.getvalue()
+
+ return StreamingResponse(
+ stream_file(),
+ media_type="application/zip",
+ headers={
+ "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
+ "Content-Length": str(zip_size),
+ },
+ )
+
+ @self.router.get(
+ "/documents",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List documents",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.list(
+ limit=10,
+ offset=0
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.documents.list({
+ limit: 10,
+ offset: 0,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/documents" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_documents(
+ ids: list[str] = Query(
+ [],
+ description="A list of document IDs to retrieve. If not provided, all documents will be returned.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+                description="Specifies a limit on the number of objects to return, ranging between 1 and 1,000. Defaults to 100.",
+ ),
+ include_summary_embeddings: bool = Query(
+ False,
+ description="Specifies whether or not to include embeddings of each document summary.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedDocumentsResponse:
+ """Returns a paginated list of documents the authenticated user has
+ access to.
+
+ Results can be filtered by providing specific document IDs. Regular
+ users will only see documents they own or have access to through
+ collections. Superusers can see all documents.
+
+ The documents are returned in order of last modification, with most
+ recent first.
+ """
+ requesting_user_id = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+ filter_collection_ids = (
+ None if auth_user.is_superuser else auth_user.collection_ids
+ )
+
+ document_uuids = [UUID(document_id) for document_id in ids]
+ documents_overview_response = (
+ await self.services.management.documents_overview(
+ user_ids=requesting_user_id,
+ collection_ids=filter_collection_ids,
+ document_ids=document_uuids,
+ offset=offset,
+ limit=limit,
+ )
+ )
+ if not include_summary_embeddings:
+ for document in documents_overview_response["results"]:
+ document.summary_embedding = None
+
+ return ( # type: ignore
+ documents_overview_response["results"],
+ {
+ "total_entries": documents_overview_response[
+ "total_entries"
+ ]
+ },
+ )
+
+ @self.router.get(
+ "/documents/{id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Retrieve a document",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.retrieve(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.documents.retrieve({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_document(
+ id: UUID = Path(
+ ...,
+ description="The ID of the document to retrieve.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedDocumentResponse:
+ """Retrieves detailed information about a specific document by its
+ ID.
+
+ This endpoint returns the document's metadata, status, and system information. It does not
+ return the document's content - use the `/documents/{id}/download` endpoint for that.
+
+ Users can only retrieve documents they own or have access to through collections.
+ Superusers can retrieve any document.
+ """
+ request_user_ids = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+ filter_collection_ids = (
+ None if auth_user.is_superuser else auth_user.collection_ids
+ )
+
+ documents_overview_response = await self.services.management.documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
+ user_ids=request_user_ids,
+ collection_ids=filter_collection_ids,
+ document_ids=[id],
+ offset=0,
+ limit=100,
+ )
+ results = documents_overview_response["results"]
+ if len(results) == 0:
+ raise R2RException("Document not found.", 404)
+
+ return results[0]
+
+ @self.router.get(
+ "/documents/{id}/chunks",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List document chunks",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.list_chunks(
+ id="32b6a70f-a995-5c51-85d2-834f06283a1e"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.documents.listChunks({
+ id: "32b6a70f-a995-5c51-85d2-834f06283a1e",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/chunks" \\
+                            -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def list_chunks(
+ id: UUID = Path(
+ ...,
+ description="The ID of the document to retrieve chunks for.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+                description="Specifies a limit on the number of objects to return, ranging between 1 and 1,000. Defaults to 100.",
+ ),
+ include_vectors: Optional[bool] = Query(
+ False,
+ description="Whether to include vector embeddings in the response.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedChunksResponse:
+ """Retrieves the text chunks that were generated from a document
+ during ingestion. Chunks represent semantic sections of the
+ document and are used for retrieval and analysis.
+
+ Users can only access chunks from documents they own or have access
+ to through collections. Vector embeddings are only included if
+ specifically requested.
+
+ Results are returned in chunk sequence order, representing their
+ position in the original document.
+ """
+ list_document_chunks = (
+ await self.services.management.list_document_chunks(
+ document_id=id,
+ offset=offset,
+ limit=limit,
+ include_vectors=include_vectors or False,
+ )
+ )
+
+ if not list_document_chunks["results"]:
+ raise R2RException(
+ "No chunks found for the given document ID.", 404
+ )
+
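+            # Access rule: the requester must own the document or share at
+            # least one collection with it; superusers bypass the check below.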
+ is_owner = str(
+ list_document_chunks["results"][0].get("owner_id")
+ ) == str(auth_user.id)
+ document_collections = (
+ await self.services.management.collections_overview(
+ offset=0,
+ limit=-1,
+ document_ids=[id],
+ )
+ )
+
+ user_has_access = (
+ is_owner
+ or set(auth_user.collection_ids).intersection(
+ {ele.id for ele in document_collections["results"]} # type: ignore
+ )
+ != set()
+ )
+
+ if not user_has_access and not auth_user.is_superuser:
+ raise R2RException(
+ "Not authorized to access this document's chunks.", 403
+ )
+
+ return ( # type: ignore
+ list_document_chunks["results"],
+ {"total_entries": list_document_chunks["total_entries"]},
+ )
+
+ @self.router.get(
+ "/documents/{id}/download",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ response_class=StreamingResponse,
+ summary="Download document content",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.download(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.documents.download({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/download" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_document_file(
+ id: str = Path(..., description="Document ID"),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> StreamingResponse:
+ """Downloads the original file content of a document.
+
+ For uploaded files, returns the original file with its proper MIME
+ type. For text-only documents, returns the content as plain text.
+
+ Users can only download documents they own or have access to
+ through collections.
+ """
+ try:
+ document_uuid = UUID(id)
+ except ValueError:
+ raise R2RException(
+ status_code=422, message="Invalid document ID format."
+ ) from None
+
+ # Retrieve the document's information
+ documents_overview_response = (
+ await self.services.management.documents_overview(
+ user_ids=None,
+ collection_ids=None,
+ document_ids=[document_uuid],
+ offset=0,
+ limit=1,
+ )
+ )
+
+ if not documents_overview_response["results"]:
+ raise R2RException("Document not found.", 404)
+
+ document = documents_overview_response["results"][0]
+
+ is_owner = str(document.owner_id) == str(auth_user.id)
+
+ if not auth_user.is_superuser and not is_owner:
+ document_collections = (
+ await self.services.management.collections_overview(
+ offset=0,
+ limit=-1,
+ document_ids=[document_uuid],
+ )
+ )
+
+ document_collection_ids = {
+ str(ele.id)
+ for ele in document_collections["results"] # type: ignore
+ }
+
+ user_collection_ids = {
+ str(cid) for cid in auth_user.collection_ids
+ }
+
+ has_collection_access = user_collection_ids.intersection(
+ document_collection_ids
+ )
+
+ if not has_collection_access:
+ raise R2RException(
+ "Not authorized to access this document.", 403
+ )
+
+ file_tuple = await self.services.management.download_file(
+ document_uuid
+ )
+ if not file_tuple:
+ raise R2RException(status_code=404, message="File not found.")
+
+ file_name, file_content, file_size = file_tuple
+ encoded_filename = quote(file_name)
+
+ mime_type, _ = mimetypes.guess_type(file_name)
+ if not mime_type:
+ mime_type = "application/octet-stream"
+
+ async def file_stream():
+ chunk_size = 1024 * 1024 # 1MB
+ while True:
+ data = file_content.read(chunk_size)
+ if not data:
+ break
+ yield data
+
+ return StreamingResponse(
+ file_stream(),
+ media_type=mime_type,
+ headers={
+ "Content-Disposition": f"inline; filename*=UTF-8''{encoded_filename}",
+ "Content-Length": str(file_size),
+ },
+ )
+
+ @self.router.delete(
+ "/documents/by-filter",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Delete documents by filter",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+ client = R2RClient()
+ # when using auth, do client.login(...)
+ response = client.documents.delete_by_filter(
+ filters={"document_type": {"$eq": "txt"}}
+ )
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/v3/documents/by-filter?filters=%7B%22document_type%22%3A%7B%22%24eq%22%3A%22text%22%7D%2C%22created_at%22%3A%7B%22%24lt%22%3A%222023-01-01T00%3A00%3A00Z%22%7D%7D" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_document_by_filter(
+ filters: Json[dict] = Body(
+ ..., description="JSON-encoded filters"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Delete documents based on provided filters.
+
+ Allowed operators
+ include: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`,
+ `ilike`, `in`, and `nin`. Deletion requests are limited to a
+ user's own documents.
+ """
+
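+            # Illustrative caller-supplied filter: plain-text documents created
+            # before 2024, e.g.
+            #   {"document_type": {"$eq": "txt"},
+            #    "created_at": {"$lt": "2024-01-01T00:00:00Z"}}
+            # The owner_id restriction below is always added on top of it.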
+ filters_dict = {
+ "$and": [{"owner_id": {"$eq": str(auth_user.id)}}, filters]
+ }
+ await (
+ self.services.management.delete_documents_and_chunks_by_filter(
+ filters=filters_dict
+ )
+ )
+
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.delete(
+ "/documents/{id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Delete a document",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.delete(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.documents.delete({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_document_by_id(
+ id: UUID = Path(..., description="Document ID"),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Delete a specific document. All chunks corresponding to the
+ document are deleted, and all other references to the document are
+ removed.
+
+ NOTE - Deletions do not yet impact the knowledge graph or other derived data. This feature is planned for a future release.
+ """
+
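+ # Superusers may delete any document; other users are scoped to documents they own.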
+ filters: dict[str, Any] = {"document_id": {"$eq": str(id)}}
+ if not auth_user.is_superuser:
+ filters = {
+ "$and": [
+ {"owner_id": {"$eq": str(auth_user.id)}},
+ {"document_id": {"$eq": str(id)}},
+ ]
+ }
+
+ await (
+ self.services.management.delete_documents_and_chunks_by_filter(
+ filters=filters
+ )
+ )
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.get(
+ "/documents/{id}/collections",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List document collections",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.list_collections(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", offset=0, limit=10
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.documents.listCollections({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/collections" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_document_collections(
+ id: str = Path(..., description="Document ID"),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCollectionsResponse:
+ """Retrieves all collections that contain the specified document.
+ This endpoint is restricted to superusers only and provides a
+ system-wide view of document organization.
+
+ Collections are used to organize documents and manage access control. A document can belong
+ to multiple collections, and users can access documents through collection membership.
+
+ The results are paginated and ordered by collection creation date, with the most recently
+ created collections appearing first.
+
+ NOTE - This endpoint is only available to superusers; it will be extended to regular users in a future release.
+ """
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can get the collections belonging to a document.",
+ 403,
+ )
+
+ collections_response = (
+ await self.services.management.collections_overview(
+ offset=offset,
+ limit=limit,
+ document_ids=[UUID(id)], # Convert string ID to UUID
+ )
+ )
+
+ return collections_response["results"], { # type: ignore
+ "total_entries": collections_response["total_entries"]
+ }
+
+ @self.router.post(
+ "/documents/{id}/extract",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Extract entities and relationships",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.extract(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
+ )
+ """),
+ },
+ ],
+ },
+ )
+ @self.base_endpoint
+ async def extract(
+ id: UUID = Path(
+ ...,
+ description="The ID of the document to extract entities and relationships from.",
+ ),
+ settings: Optional[GraphCreationSettings] = Body(
+ default=None,
+ description="Settings for the entities and relationships extraction process.",
+ ),
+ run_with_orchestration: Optional[bool] = Body(
+ default=True,
+ description="Whether to run the entities and relationships extraction process with orchestration.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Extracts entities and relationships from a document.
+
+ The entities and relationships extraction process involves:
+
+ 1. Parsing documents into semantic chunks
+
+ 2. Extracting entities and relationships using LLMs
+
+ 3. Storing the created entities and relationships in the knowledge graph
+
+ 4. Preserving the document's metadata and content, and associating the elements with collections the document belongs to
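+
+ For example, calling `POST /v3/documents/{id}/extract` with body `{"run_with_orchestration": false}` (illustrative) runs the extraction synchronously instead of queueing an orchestrated job.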
+ """
+
+ settings = settings.model_dump() if settings else None # type: ignore
+ documents_overview_response = (
+ await self.services.management.documents_overview(
+ user_ids=(
+ None if auth_user.is_superuser else [auth_user.id]
+ ),
+ collection_ids=(
+ None
+ if auth_user.is_superuser
+ else auth_user.collection_ids
+ ),
+ document_ids=[id],
+ offset=0,
+ limit=1,
+ )
+ )["results"]
+ if len(documents_overview_response) == 0:
+ raise R2RException("Document not found.", 404)
+
+ if (
+ not auth_user.is_superuser
+ and auth_user.id != documents_overview_response[0].owner_id
+ ):
+ raise R2RException(
+ "Only a superuser can extract entities and relationships from a document they do not own.",
+ 403,
+ )
+
+ # Apply runtime settings overrides
+ server_graph_creation_settings = (
+ self.providers.database.config.graph_creation_settings
+ )
+
+ if settings:
+ server_graph_creation_settings = update_settings_from_dict(
+ server_settings=server_graph_creation_settings,
+ settings_dict=settings, # type: ignore
+ )
+
+ if run_with_orchestration:
+ try:
+ workflow_input = {
+ "document_id": str(id),
+ "graph_creation_settings": server_graph_creation_settings.model_dump_json(),
+ "user": auth_user.json(),
+ }
+
+ return await self.providers.orchestration.run_workflow( # type: ignore
+ "graph-extraction", {"request": workflow_input}, {}
+ )
+ except Exception as e: # TODO: Need to find specific errors that we should be excepting (gRPC most likely?)
+ logger.error(
+ f"Error running orchestrated extraction: {e} \n\nAttempting to run without orchestration."
+ )
+
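+ # Fallback path: run extraction in-process, either because orchestration was disabled or because the orchestrated run failed above.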
+ from core.main.orchestration import (
+ simple_graph_search_results_factory,
+ )
+
+ logger.info("Running extract-triples without orchestration.")
+ simple_graph_search_results = simple_graph_search_results_factory(
+ self.services.graph
+ )
+ await simple_graph_search_results["graph-extraction"](
+ workflow_input
+ )
+ return { # type: ignore
+ "message": "Graph created successfully.",
+ "task_id": None,
+ }
+
+ @self.router.post(
+ "/documents/{id}/deduplicate",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Deduplicate entities",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+
+ response = client.documents.deduplicate(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.documents.deduplicate({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/deduplicate" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ],
+ },
+ )
+ @self.base_endpoint
+ async def deduplicate(
+ id: UUID = Path(
+ ...,
+ description="The ID of the document to extract entities and relationships from.",
+ ),
+ settings: Optional[GraphCreationSettings] = Body(
+ default=None,
+ description="Settings for the entities and relationships extraction process.",
+ ),
+ run_with_orchestration: Optional[bool] = Body(
+ default=True,
+ description="Whether to run the entities and relationships extraction process with orchestration.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Deduplicates entities from a document."""
+
+ settings = settings.model_dump() if settings else None # type: ignore
+ documents_overview_response = (
+ await self.services.management.documents_overview(
+ user_ids=(
+ None if auth_user.is_superuser else [auth_user.id]
+ ),
+ collection_ids=(
+ None
+ if auth_user.is_superuser
+ else auth_user.collection_ids
+ ),
+ document_ids=[id],
+ offset=0,
+ limit=1,
+ )
+ )["results"]
+ if len(documents_overview_response) == 0:
+ raise R2RException("Document not found.", 404)
+
+ if (
+ not auth_user.is_superuser
+ and auth_user.id != documents_overview_response[0].owner_id
+ ):
+ raise R2RException(
+ "Only a superuser can run deduplication on a document they do not own.",
+ 403,
+ )
+
+ # Apply runtime settings overrides
+ server_graph_creation_settings = (
+ self.providers.database.config.graph_creation_settings
+ )
+
+ if settings:
+ server_graph_creation_settings = update_settings_from_dict(
+ server_settings=server_graph_creation_settings,
+ settings_dict=settings, # type: ignore
+ )
+
+ if run_with_orchestration:
+ try:
+ workflow_input = {
+ "document_id": str(id),
+ }
+
+ return await self.providers.orchestration.run_workflow( # type: ignore
+ "graph-deduplication",
+ {"request": workflow_input},
+ {},
+ )
+ except Exception as e: # TODO: Need to find specific errors that we should be excepting (gRPC most likely?)
+ logger.error(
+ f"Error running orchestrated deduplication: {e} \n\nAttempting to run without orchestration."
+ )
+
+ from core.main.orchestration import (
+ simple_graph_search_results_factory,
+ )
+
+ logger.info(
+ "Running deduplicate-document-entities without orchestration."
+ )
+ simple_graph_search_results = simple_graph_search_results_factory(
+ self.services.graph
+ )
+ await simple_graph_search_results["graph-deduplication"](
+ workflow_input
+ )
+ return { # type: ignore
+ "message": "Graph created successfully.",
+ "task_id": None,
+ }
+
+ @self.router.get(
+ "/documents/{id}/entities",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Lists the entities from the document",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.list_entities(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
+ )
+ """),
+ },
+ ],
+ },
+ )
+ @self.base_endpoint
+ async def get_entities(
+ id: UUID = Path(
+ ...,
+ description="The ID of the document to retrieve entities from.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ include_embeddings: Optional[bool] = Query(
+ False,
+ description="Whether to include vector embeddings in the response.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedEntitiesResponse:
+ """Retrieves the entities that were extracted from a document.
+ These represent important semantic elements like people, places,
+ organizations, concepts, etc.
+
+ Users can only access entities from documents they own or have
+ access to through collections. Entity embeddings are only included
+ if specifically requested.
+
+ Results are returned in the order they were extracted from the
+ document.
+ """
+ # if (
+ # not auth_user.is_superuser
+ # and id not in auth_user.collection_ids
+ # ):
+ # raise R2RException(
+ # "The currently authenticated user does not have access to the specified collection.",
+ # 403,
+ # )
+
+ # First check if the document exists and user has access
+ documents_overview_response = (
+ await self.services.management.documents_overview(
+ user_ids=(
+ None if auth_user.is_superuser else [auth_user.id]
+ ),
+ collection_ids=(
+ None
+ if auth_user.is_superuser
+ else auth_user.collection_ids
+ ),
+ document_ids=[id],
+ offset=0,
+ limit=1,
+ )
+ )
+
+ if not documents_overview_response["results"]:
+ raise R2RException("Document not found.", 404)
+
+ # Get all entities for this document from the document_entity table
+ (
+ entities,
+ count,
+ ) = await self.providers.database.graphs_handler.entities.get(
+ parent_id=id,
+ store_type=StoreType.DOCUMENTS,
+ offset=offset,
+ limit=limit,
+ include_embeddings=include_embeddings or False,
+ )
+
+ return entities, {"total_entries": count} # type: ignore
+
+ @self.router.post(
+ "/documents/{id}/entities/export",
+ summary="Export document entities to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.documents.export_entities(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ output_path="export.csv",
+ columns=["id", "title", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+ async function main() {
+ await client.documents.exportEntities({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ outputPath: "export.csv",
+ columns: ["id", "title", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/documents/export_entities" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_entities(
+ background_tasks: BackgroundTasks,
+ id: UUID = Path(
+ ...,
+ description="The ID of the document to export entities from.",
+ ),
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export documents as a downloadable CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_document_entities(
+ id=id,
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="documents_export.csv",
+ )
+
+ @self.router.get(
+ "/documents/{id}/relationships",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List document relationships",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.documents.list_relationships(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ offset=0,
+ limit=100
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.documents.listRelationships({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ offset: 0,
+ limit: 100,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/relationships" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_relationships(
+ id: UUID = Path(
+ ...,
+ description="The ID of the document to retrieve relationships for.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ entity_names: Optional[list[str]] = Query(
+ None,
+ description="Filter relationships by specific entity names.",
+ ),
+ relationship_types: Optional[list[str]] = Query(
+ None,
+ description="Filter relationships by specific relationship types.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedRelationshipsResponse:
+ """Retrieves the relationships between entities that were extracted
+ from a document. These represent connections and interactions
+ between entities found in the text.
+
+ Users can only access relationships from documents they own or have
+ access to through collections. Results can be filtered by entity
+ names and relationship types.
+
+ Results are returned in the order they were extracted from the
+ document.
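+
+ For example, `GET /v3/documents/{id}/relationships?entity_names=Aristotle&relationship_types=INFLUENCED` (illustrative values) returns only relationships involving that entity name and type.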
+ """
+ # if (
+ # not auth_user.is_superuser
+ # and id not in auth_user.collection_ids
+ # ):
+ # raise R2RException(
+ # "The currently authenticated user does not have access to the specified collection.",
+ # 403,
+ # )
+
+ # First check if the document exists and user has access
+ documents_overview_response = (
+ await self.services.management.documents_overview(
+ user_ids=(
+ None if auth_user.is_superuser else [auth_user.id]
+ ),
+ collection_ids=(
+ None
+ if auth_user.is_superuser
+ else auth_user.collection_ids
+ ),
+ document_ids=[id],
+ offset=0,
+ limit=1,
+ )
+ )
+
+ if not documents_overview_response["results"]:
+ raise R2RException("Document not found.", 404)
+
+ # Get relationships for this document
+ (
+ relationships,
+ count,
+ ) = await self.providers.database.graphs_handler.relationships.get(
+ parent_id=id,
+ store_type=StoreType.DOCUMENTS,
+ entity_names=entity_names,
+ relationship_types=relationship_types,
+ offset=offset,
+ limit=limit,
+ )
+
+ return relationships, {"total_entries": count} # type: ignore
+
+ @self.router.post(
+ "/documents/{id}/relationships/export",
+ summary="Export document relationships to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.documents.export_relationships(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ output_path="export.csv",
+ columns=["id", "title", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+ async function main() {
+ await client.documents.exportRelationships({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ outputPath: "export.csv",
+ columns: ["id", "title", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/documents/export_entities" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_relationships(
+ background_tasks: BackgroundTasks,
+ id: UUID = Path(
+ ...,
+ description="The ID of the document to export entities from.",
+ ),
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export documents as a downloadable CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_document_relationships(
+ id=id,
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="documents_export.csv",
+ )
+
+ @self.router.post(
+ "/documents/search",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Search document summaries",
+ )
+ @self.base_endpoint
+ async def search_documents(
+ query: str = Body(
+ ...,
+ description="The search query to perform.",
+ ),
+ search_mode: SearchMode = Body(
+ default=SearchMode.custom,
+ description=(
+ "Default value of `custom` allows full control over search settings.\n\n"
+ "Pre-configured search modes:\n"
+ "`basic`: A simple semantic-based search.\n"
+ "`advanced`: A more powerful hybrid search combining semantic and full-text.\n"
+ "`custom`: Full control via `search_settings`.\n\n"
+ "If `filters` or `limit` are provided alongside `basic` or `advanced`, "
+ "they will override the default settings for that mode."
+ ),
+ ),
+ search_settings: SearchSettings = Body(
+ default_factory=SearchSettings,
+ description="Settings for document search",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedDocumentSearchResponse:
+ """Perform a search query on the automatically generated document
+ summaries in the system.
+
+ This endpoint allows for complex filtering of search results using PostgreSQL-based queries.
+ Filters can be applied to various fields, such as `document_id` and internal metadata values.
+
+ Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
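+
+ For example, a `filters` value of `{"$and": [{"document_id": {"$eq": "<uuid>"}}, {"metadata.year": {"$gt": 2020}}]}` (illustrative values) limits the summary search to a single document with a matching metadata field.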
+ """
+ effective_settings = self._prepare_search_settings(
+ auth_user, search_mode, search_settings
+ )
+
+ query_embedding = (
+ await self.providers.embedding.async_get_embedding(query)
+ )
+ results = await self.services.retrieval.search_documents(
+ query=query,
+ query_embedding=query_embedding,
+ settings=effective_settings,
+ )
+ return results # type: ignore
+
+ @staticmethod
+ async def _process_file(file):
+ import base64
+
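+ # Read the upload fully and base64-encode it so the raw bytes can be carried in a JSON-serializable payload.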
+ content = await file.read()
+
+ return {
+ "filename": file.filename,
+ "content": base64.b64encode(content).decode("utf-8"),
+ "content_type": file.content_type,
+ "content_length": len(content),
+ }
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/examples.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/examples.py
new file mode 100644
index 00000000..ba588c3b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/examples.py
@@ -0,0 +1,1065 @@
+import textwrap
+
+"""
+This file contains updated OpenAPI examples for the RetrievalRouterV3 class.
+These examples are designed to be included in the openapi_extra field for each route.
+"""
+
+# Updated examples for search_app endpoint
+search_app_examples = {
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent(
+ """
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # if using auth, do client.login(...)
+
+ # Basic search
+ response = client.retrieval.search(
+ query="What is DeepSeek R1?",
+ )
+
+ # Advanced mode with specific filters
+ response = client.retrieval.search(
+ query="What is DeepSeek R1?",
+ search_mode="advanced",
+ search_settings={
+ "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}},
+ "limit": 5
+ }
+ )
+
+ # Using hybrid search
+ response = client.retrieval.search(
+ query="What was Uber's profit in 2020?",
+ search_settings={
+ "use_hybrid_search": True,
+ "hybrid_settings": {
+ "full_text_weight": 1.0,
+ "semantic_weight": 5.0,
+ "full_text_limit": 200,
+ "rrf_k": 50
+ },
+ "filters": {"title": {"$in": ["DeepSeek_R1.pdf"]}},
+ }
+ )
+
+ # Advanced filtering
+ results = client.retrieval.search(
+ query="What are the effects of climate change?",
+ search_settings={
+ "filters": {
+ "$and":[
+ {"document_type": {"$eq": "pdf"}},
+ {"metadata.year": {"$gt": 2020}}
+ ]
+ },
+ "limit": 10
+ }
+ )
+
+ # Knowledge graph enhanced search
+ results = client.retrieval.search(
+ query="What was DeepSeek R1",
+ graph_search_settings={
+ "use_graph_search": True,
+ "kg_search_type": "local"
+ }
+ )
+ """
+ ),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent(
+ """
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+ // if using auth, do client.login(...)
+
+ // Basic search
+ const response = await client.retrieval.search({
+ query: "What is DeepSeek R1?",
+ });
+
+ // With specific filters
+ const filteredResponse = await client.retrieval.search({
+ query: "What is DeepSeek R1?",
+ searchSettings: {
+ filters: {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}},
+ limit: 5
+ }
+ });
+
+ // Using hybrid search
+ const hybridResponse = await client.retrieval.search({
+ query: "What was Uber's profit in 2020?",
+ searchSettings: {
+ indexMeasure: "l2_distance",
+ useHybridSearch: true,
+ hybridSettings: {
+ fullTextWeight: 1.0,
+ semanticWeight: 5.0,
+ fullTextLimit: 200,
+ },
+ filters: {"title": {"$in": ["DeepSeek_R1.pdf"]}},
+ }
+ });
+
+ // Advanced filtering
+ const advancedResults = await client.retrieval.search({
+ query: "What are the effects of climate change?",
+ searchSettings: {
+ filters: {
+ $and: [
+ {document_type: {$eq: "pdf"}},
+ {"metadata.year": {$gt: 2020}}
+ ]
+ },
+ limit: 10
+ }
+ });
+
+ // Knowledge graph enhanced search
+ const kgResults = await client.retrieval.search({
+ query: "who was aristotle?",
+ graphSearchSettings: {
+ useKgSearch: true,
+ kgSearchType: "local"
+ }
+ });
+ """
+ ),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent(
+ """
+ # Basic search
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/search" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "query": "What is DeepSeek R1?"
+ }'
+
+ # With hybrid search and filters
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/search" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "query": "What was Uber'\''s profit in 2020?",
+ "search_settings": {
+ "use_hybrid_search": true,
+ "hybrid_settings": {
+ "full_text_weight": 1.0,
+ "semantic_weight": 5.0,
+ "full_text_limit": 200,
+ "rrf_k": 50
+ },
+ "filters": {"title": {"$in": ["DeepSeek_R1.pdf"]}},
+ "limit": 10,
+ "chunk_settings": {
+ "index_measure": "l2_distance",
+ "probes": 25,
+ "ef_search": 100
+ }
+ }
+ }'
+
+ # Knowledge graph enhanced search
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/search" \\
+ -H "Content-Type: application/json" \\
+ -d '{
+ "query": "who was aristotle?",
+ "graph_search_settings": {
+ "use_graph_search": true,
+ "kg_search_type": "local"
+ }
+ }' \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """
+ ),
+ },
+ ]
+}
+
+# Updated examples for rag_app endpoint
+rag_app_examples = {
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent(
+ """
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ # Basic RAG request
+ response = client.retrieval.rag(
+ query="What is DeepSeek R1?",
+ )
+
+ # Advanced RAG with custom search settings
+ response = client.retrieval.rag(
+ query="What is DeepSeek R1?",
+ search_settings={
+ "use_semantic_search": True,
+ "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}},
+ "limit": 10,
+ },
+ rag_generation_config={
+ "stream": False,
+ "temperature": 0.7,
+ "max_tokens": 1500
+ }
+ )
+
+ # Hybrid search in RAG
+ results = client.retrieval.rag(
+ "Who is Jon Snow?",
+ search_settings={"use_hybrid_search": True}
+ )
+
+ # Custom model selection
+ response = client.retrieval.rag(
+ "Who was Aristotle?",
+ rag_generation_config={"model":"anthropic/claude-3-haiku-20240307", "stream": True}
+ )
+ for chunk in response:
+ print(chunk)
+
+ # Streaming RAG
+ from r2r import (
+ CitationEvent,
+ FinalAnswerEvent,
+ MessageEvent,
+ SearchResultsEvent,
+ R2RClient,
+ )
+
+ result_stream = client.retrieval.rag(
+ query="What is DeepSeek R1?",
+ search_settings={"limit": 25},
+ rag_generation_config={"stream": True},
+ )
+
+ # Process different event types
+ for event in result_stream:
+ if isinstance(event, SearchResultsEvent):
+ print("Search results:", event.data)
+ elif isinstance(event, MessageEvent):
+ print("Partial message:", event.data.delta)
+ elif isinstance(event, CitationEvent):
+ print("New citation detected:", event.data.id)
+ elif isinstance(event, FinalAnswerEvent):
+ print("Final answer:", event.data.generated_answer)
+ """
+ ),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent(
+ """
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+ // when using auth, do client.login(...)
+
+ // Basic RAG request
+ const response = await client.retrieval.rag({
+ query: "What is DeepSeek R1?",
+ });
+
+ // RAG with custom settings
+ const advancedResponse = await client.retrieval.rag({
+ query: "What is DeepSeek R1?",
+ searchSettings: {
+ useSemanticSearch: true,
+ filters: {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}},
+ limit: 10,
+ },
+ ragGenerationConfig: {
+ stream: false,
+ temperature: 0.7,
+ maxTokens: 1500
+ }
+ });
+
+ // Hybrid search in RAG
+ const hybridResults = await client.retrieval.rag({
+ query: "Who is Jon Snow?",
+ searchSettings: {
+ useHybridSearch: true
+ },
+ });
+
+ // Custom model
+ const customModelResponse = await client.retrieval.rag({
+ query: "Who was Aristotle?",
+ ragGenerationConfig: {
+ model: 'anthropic/claude-3-haiku-20240307',
+ temperature: 0.7,
+ }
+ });
+
+ // Streaming RAG
+ const resultStream = await client.retrieval.rag({
+ query: "What is DeepSeek R1?",
+ searchSettings: { limit: 25 },
+ ragGenerationConfig: { stream: true },
+ });
+
+ // Process streaming events
+ if (Symbol.asyncIterator in resultStream) {
+ for await (const event of resultStream) {
+ switch (event.event) {
+ case "search_results":
+ console.log("Search results:", event.data);
+ break;
+ case "message":
+ console.log("Partial message delta:", event.data.delta);
+ break;
+ case "citation":
+ console.log("New citation event:", event.data.id);
+ break;
+ case "final_answer":
+ console.log("Final answer:", event.data.generated_answer);
+ break;
+ default:
+ console.log("Unknown or unhandled event:", event);
+ }
+ }
+ }
+ """
+ ),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent(
+ """
+ # Basic RAG request
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/rag" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "query": "What is DeepSeek R1?"
+ }'
+
+ # RAG with custom settings
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/rag" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "query": "What is DeepSeek R1?",
+ "search_settings": {
+ "use_semantic_search": true,
+ "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}},
+ "limit": 10
+ },
+ "rag_generation_config": {
+ "stream": false,
+ "temperature": 0.7,
+ "max_tokens": 1500
+ }
+ }'
+
+ # Hybrid search in RAG
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/rag" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "query": "Who is Jon Snow?",
+ "search_settings": {
+ "use_hybrid_search": true,
+ "filters": {},
+ "limit": 10
+ }
+ }'
+
+ # Custom model
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/rag" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "query": "Who is Jon Snow?",
+ "rag_generation_config": {
+ "model": "anthropic/claude-3-haiku-20240307",
+ "temperature": 0.7
+ }
+ }'
+ """
+ ),
+ },
+ ]
+}
+
+# Updated examples for agent_app endpoint
+agent_app_examples = {
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent(
+ """
+from r2r import (
+ R2RClient,
+ ThinkingEvent,
+ ToolCallEvent,
+ ToolResultEvent,
+ CitationEvent,
+ FinalAnswerEvent,
+ MessageEvent,
+)
+
+client = R2RClient()
+# when using auth, do client.login(...)
+
+# Basic synchronous request
+response = client.retrieval.agent(
+ message={
+ "role": "user",
+ "content": "Do a deep analysis of the philosophical implications of DeepSeek R1"
+ },
+ rag_tools=["web_search", "web_scrape", "search_file_descriptions", "search_file_knowledge", "get_file_content"],
+)
+
+# Advanced analysis with streaming and extended thinking
+streaming_response = client.retrieval.agent(
+ message={
+ "role": "user",
+ "content": "Do a deep analysis of the philosophical implications of DeepSeek R1"
+ },
+ search_settings={"limit": 20},
+ rag_tools=["web_search", "web_scrape", "search_file_descriptions", "search_file_knowledge", "get_file_content"],
+ rag_generation_config={
+ "model": "anthropic/claude-3-7-sonnet-20250219",
+ "extended_thinking": True,
+ "thinking_budget": 4096,
+ "temperature": 1,
+ "top_p": None,
+ "max_tokens": 16000,
+ "stream": True
+ }
+)
+
+# Process streaming events with emoji only on type change
+current_event_type = None
+for event in streaming_response:
+ # Check if the event type has changed
+ event_type = type(event)
+ if event_type != current_event_type:
+ current_event_type = event_type
+ print() # Add newline before new event type
+
+ # Print emoji based on the new event type
+ if isinstance(event, ThinkingEvent):
+ print(f"\n🧠 Thinking: ", end="", flush=True)
+ elif isinstance(event, ToolCallEvent):
+ print(f"\n🔧 Tool call: ", end="", flush=True)
+ elif isinstance(event, ToolResultEvent):
+ print(f"\n📊 Tool result: ", end="", flush=True)
+ elif isinstance(event, CitationEvent):
+ print(f"\n📑 Citation: ", end="", flush=True)
+ elif isinstance(event, MessageEvent):
+ print(f"\n💬 Message: ", end="", flush=True)
+ elif isinstance(event, FinalAnswerEvent):
+ print(f"\n✅ Final answer: ", end="", flush=True)
+
+ # Print the content without the emoji
+ if isinstance(event, ThinkingEvent):
+ print(f"{event.data.delta.content[0].payload.value}", end="", flush=True)
+ elif isinstance(event, ToolCallEvent):
+ print(f"{event.data.name}({event.data.arguments})")
+ elif isinstance(event, ToolResultEvent):
+ print(f"{event.data.content[:60]}...")
+ elif isinstance(event, CitationEvent):
+ print(f"{event.data.id}")
+ elif isinstance(event, MessageEvent):
+ print(f"{event.data.delta.content[0].payload.value}", end="", flush=True)
+ elif isinstance(event, FinalAnswerEvent):
+ print(f"{event.data.generated_answer[:100]}...")
+ print(f" Citations: {len(event.data.citations)} sources referenced")
+
+# Conversation with multiple turns (synchronous)
+conversation = client.conversations.create()
+
+# First message in conversation
+results_1 = client.retrieval.agent(
+ query="What does DeepSeek R1 imply for the future of AI?",
+ rag_generation_config={
+ "model": "anthropic/claude-3-7-sonnet-20250219",
+ "extended_thinking": True,
+ "thinking_budget": 4096,
+ "temperature": 1,
+ "top_p": None,
+ "max_tokens": 16000,
+ "stream": True
+ },
+ conversation_id=conversation.results.id
+)
+
+# Follow-up query in the same conversation
+results_2 = client.retrieval.agent(
+ query="How does it compare to other reasoning models?",
+ rag_generation_config={
+ "model": "anthropic/claude-3-7-sonnet-20250219",
+ "extended_thinking": True,
+ "thinking_budget": 4096,
+ "temperature": 1,
+ "top_p": None,
+ "max_tokens": 16000,
+ "stream": True
+ },
+ conversation_id=conversation.results.id
+)
+
+# Access the final results
+print(f"First response: {results_1.generated_answer[:100]}...")
+print(f"Follow-up response: {results_2.generated_answer[:100]}...")
+"""
+ ),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent(
+ """
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+ // when using auth, do client.login(...)
+
+ async function main() {
+ // Basic synchronous request
+ const ragResponse = await client.retrieval.agent({
+ message: {
+ role: "user",
+ content: "Do a deep analysis of the philosophical implications of DeepSeek R1"
+ },
+ ragTools: ["web_search", "web_scrape", "search_file_descriptions", "search_file_knowledge", "get_file_content"]
+ });
+
+ // Advanced analysis with streaming and extended thinking
+ const streamingResponse = await client.retrieval.agent({
+ message: {
+ role: "user",
+ content: "Do a deep analysis of the philosophical implications of DeepSeek R1"
+ },
+ searchSettings: {limit: 20},
+ ragTools: ["web_search", "web_scrape", "search_file_descriptions", "search_file_knowledge", "get_file_content"],
+ ragGenerationConfig: {
+ model: "anthropic/claude-3-7-sonnet-20250219",
+ extendedThinking: true,
+ thinkingBudget: 4096,
+ temperature: 1,
+ maxTokens: 16000,
+ stream: true
+ }
+ });
+
+ // Process streaming events with emoji only on type change
+ if (Symbol.asyncIterator in streamingResponse) {
+ let currentEventType = null;
+
+ for await (const event of streamingResponse) {
+ // Check if event type has changed
+ const eventType = event.event;
+ if (eventType !== currentEventType) {
+ currentEventType = eventType;
+ console.log(); // Add newline before new event type
+
+ // Print emoji based on the new event type
+ switch(eventType) {
+ case "thinking":
+ process.stdout.write(`🧠 Thinking: `);
+ break;
+ case "tool_call":
+ process.stdout.write(`🔧 Tool call: `);
+ break;
+ case "tool_result":
+ process.stdout.write(`📊 Tool result: `);
+ break;
+ case "citation":
+ process.stdout.write(`📑 Citation: `);
+ break;
+ case "message":
+ process.stdout.write(`💬 Message: `);
+ break;
+ case "final_answer":
+ process.stdout.write(`✅ Final answer: `);
+ break;
+ }
+ }
+
+ // Print content based on event type
+ switch(eventType) {
+ case "thinking":
+ process.stdout.write(`${event.data.delta.content[0].payload.value}`);
+ break;
+ case "tool_call":
+ console.log(`${event.data.name}(${JSON.stringify(event.data.arguments)})`);
+ break;
+ case "tool_result":
+ console.log(`${event.data.content.substring(0, 60)}...`);
+ break;
+ case "citation":
+ console.log(`${event.data.id}`);
+ break;
+ case "message":
+ process.stdout.write(`${event.data.delta.content[0].payload.value}`);
+ break;
+ case "final_answer":
+ console.log(`${event.data.generated_answer.substring(0, 100)}...`);
+ console.log(` Citations: ${event.data.citations.length} sources referenced`);
+ break;
+ }
+ }
+ }
+
+ // Conversation with multiple turns (synchronous)
+ const conversation = await client.conversations.create();
+
+ // First message in conversation
+ const results1 = await client.retrieval.agent({
+ query: "What does DeepSeek R1 imply for the future of AI?",
+ ragGenerationConfig: {
+ model: "anthropic/claude-3-7-sonnet-20250219",
+ extendedThinking: true,
+ thinkingBudget: 4096,
+ temperature: 1,
+ maxTokens: 16000,
+ stream: false
+ },
+ conversationId: conversation.results.id
+ });
+
+ // Follow-up query in the same conversation
+ const results2 = await client.retrieval.agent({
+ query: "How does it compare to other reasoning models?",
+ ragGenerationConfig: {
+ model: "anthropic/claude-3-7-sonnet-20250219",
+ extendedThinking: true,
+ thinkingBudget: 4096,
+ temperature: 1,
+ maxTokens: 16000,
+ stream: false
+ },
+ conversationId: conversation.results.id
+ });
+
+ // Log the results
+ console.log(`First response: ${results1.generated_answer.substring(0, 100)}...`);
+ console.log(`Follow-up response: ${results2.generated_answer.substring(0, 100)}...`);
+ }
+
+ main();
+ """
+ ),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent(
+ """
+ # Basic request
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/agent" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "message": {
+ "role": "user",
+ "content": "What were the key contributions of Aristotle to logic?"
+ },
+ "search_settings": {
+ "use_semantic_search": true,
+ "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}}
+ },
+ "rag_tools": ["search_file_knowledge", "content", "web_search"]
+ }'
+
+ # Advanced analysis with extended thinking
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/agent" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "message": {
+ "role": "user",
+ "content": "Do a deep analysis of the philosophical implications of DeepSeek R1"
+ },
+ "search_settings": {"limit": 20},
+ "research_tools": ["rag", "reasoning", "critique", "python_executor"],
+ "rag_generation_config": {
+ "model": "anthropic/claude-3-7-sonnet-20250219",
+ "extended_thinking": true,
+ "thinking_budget": 4096,
+ "temperature": 1,
+ "top_p": null,
+ "max_tokens": 16000,
+ "stream": true
+ }
+ }'
+
+ # Conversation continuation
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/agent" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "message": {
+ "role": "user",
+ "content": "How does it compare to other reasoning models?"
+ },
+ "conversation_id": "YOUR_CONVERSATION_ID"
+ }'
+ """
+ ),
+ },
+ ]
+}
+
+# Updated examples for completion endpoint
+completion_examples = {
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent(
+ """
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.completion(
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of France?"},
+ {"role": "assistant", "content": "The capital of France is Paris."},
+ {"role": "user", "content": "What about Italy?"}
+ ],
+ generation_config={
+ "model": "openai/gpt-4o-mini",
+ "temperature": 0.7,
+ "max_tokens": 150,
+ "stream": False
+ }
+ )
+ """
+ ),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent(
+ """
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+ // when using auth, do client.login(...)
+
+ async function main() {
+ const response = await client.completion({
+ messages: [
+ { role: "system", content: "You are a helpful assistant." },
+ { role: "user", content: "What is the capital of France?" },
+ { role: "assistant", content: "The capital of France is Paris." },
+ { role: "user", content: "What about Italy?" }
+ ],
+ generationConfig: {
+ model: "openai/gpt-4o-mini",
+ temperature: 0.7,
+ maxTokens: 150,
+ stream: false
+ }
+ });
+ }
+
+ main();
+ """
+ ),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent(
+ """
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/completion" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "messages": [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of France?"},
+ {"role": "assistant", "content": "The capital of France is Paris."},
+ {"role": "user", "content": "What about Italy?"}
+ ],
+ "generation_config": {
+ "model": "openai/gpt-4o-mini",
+ "temperature": 0.7,
+ "max_tokens": 150,
+ "stream": false
+ }
+ }'
+ """
+ ),
+ },
+ ]
+}
+
+# Updated examples for embedding endpoint
+embedding_examples = {
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent(
+ """
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.retrieval.embedding(
+ text="What is DeepSeek R1?",
+ )
+ """
+ ),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent(
+ """
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+ // when using auth, do client.login(...)
+
+ async function main() {
+ const response = await client.retrieval.embedding({
+ text: "What is DeepSeek R1?",
+ });
+ }
+
+ main();
+ """
+ ),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent(
+ """
+ curl -X POST "https://api.sciphi.ai/v3/retrieval/embedding" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "text": "What is DeepSeek R1?",
+ }'
+ """
+ ),
+ },
+ ]
+}
+
+# Updated rag_app docstring
+rag_app_docstring = """
+Execute a RAG (Retrieval-Augmented Generation) query.
+
+This endpoint combines search results with language model generation to produce accurate,
+contextually-relevant responses based on your document corpus.
+
+**Features:**
+- Combines vector search, optional knowledge graph integration, and LLM generation
+- Automatically cites sources with unique citation identifiers
+- Supports both streaming and non-streaming responses
+- Compatible with various LLM providers (OpenAI, Anthropic, etc.)
+- Web search integration for up-to-date information
+
+**Search Configuration:**
+All search parameters from the search endpoint apply here, including filters, hybrid search, and graph-enhanced search.
+
+**Generation Configuration:**
+Fine-tune the language model's behavior with `rag_generation_config`:
+```json
+{
+ "model": "openai/gpt-4o-mini", // Model to use
+ "temperature": 0.7, // Control randomness (0-1)
+ "max_tokens": 1500, // Maximum output length
+ "stream": true // Enable token streaming
+}
+```
+
+**Model Support:**
+- OpenAI models (default)
+- Anthropic Claude models (requires ANTHROPIC_API_KEY)
+- Local models via Ollama
+- Any provider supported by LiteLLM
+
+**Streaming Responses:**
+When `stream: true` is set, the endpoint returns Server-Sent Events with the following types:
+- `search_results`: Initial search results from your documents
+- `message`: Partial tokens as they're generated
+- `citation`: Citation metadata when sources are referenced
+- `final_answer`: Complete answer with structured citations
+
+**Example Response:**
+```json
+{
+ "generated_answer": "DeepSeek-R1 is a model that demonstrates impressive performance...[1]",
+ "search_results": { ... },
+ "citations": [
+ {
+ "id": "cit.123456",
+ "object": "citation",
+ "payload": { ... }
+ }
+ ]
+}
+```
+"""
+
+# Updated agent_app docstring
+agent_app_docstring = """
+Engage with an intelligent agent for information retrieval, analysis, and research.
+
+This endpoint offers two operating modes:
+- **RAG mode**: Standard retrieval-augmented generation for answering questions based on knowledge base
+- **Research mode**: Advanced capabilities for deep analysis, reasoning, and computation
+
+### RAG Mode (Default)
+
+The RAG mode provides fast, knowledge-based responses using:
+- Semantic and hybrid search capabilities
+- Document-level and chunk-level content retrieval
+- Optional web search integration
+- Source citation and evidence-based responses
+
+### Research Mode
+
+The Research mode builds on RAG capabilities and adds:
+- A dedicated reasoning system for complex problem-solving
+- Critique capabilities to identify potential biases or logical fallacies
+- Python execution for computational analysis
+- Multi-step reasoning for deeper exploration of topics
+
+### Available Tools
+
+**RAG Tools:**
+- `search_file_knowledge`: Semantic/hybrid search on your ingested documents
+- `search_file_descriptions`: Search over file-level metadata
+- `content`: Fetch entire documents or chunk structures
+- `web_search`: Query external search APIs for up-to-date information
+- `web_scrape`: Scrape and extract content from specific web pages
+
+**Research Tools:**
+- `rag`: Leverage the underlying RAG agent for information retrieval
+- `reasoning`: Call a dedicated model for complex analytical thinking
+- `critique`: Analyze conversation history to identify flaws and biases
+- `python_executor`: Execute Python code for complex calculations and analysis
+
+### Streaming Output
+
+When streaming is enabled, the agent produces different event types:
+- `thinking`: Shows the model's step-by-step reasoning (when extended_thinking=true)
+- `tool_call`: Shows when the agent invokes a tool
+- `tool_result`: Shows the result of a tool call
+- `citation`: Indicates when a citation is added to the response
+- `message`: Streams partial tokens of the response
+- `final_answer`: Contains the complete generated answer and structured citations
+
+### Conversations
+
+Maintain context across multiple turns by including `conversation_id` in each request.
+After your first call, store the returned `conversation_id` and include it in subsequent calls.
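+
+A follow-up turn in an existing conversation can be as small as the following (sketch; mirrors the cURL example for this endpoint):
+```json
+{
+  "message": {"role": "user", "content": "How does it compare to other reasoning models?"},
+  "conversation_id": "YOUR_CONVERSATION_ID"
+}
+```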
+"""
+
+# Updated completion_docstring
+completion_docstring = """
+Generate completions for a list of messages.
+
+This endpoint uses the language model to generate completions for the provided messages.
+The generation process can be customized using the generation_config parameter.
+
+The messages list should contain alternating user and assistant messages, with an optional
+system message at the start. Each message should have a 'role' and 'content'.
+
+**Generation Configuration:**
+Fine-tune the language model's behavior with `generation_config`:
+```json
+{
+ "model": "openai/gpt-4o-mini", // Model to use
+ "temperature": 0.7, // Control randomness (0-1)
+ "max_tokens": 1500, // Maximum output length
+ "stream": true // Enable token streaming
+}
+```
+
+**Multiple LLM Support:**
+- OpenAI models (default)
+- Anthropic Claude models (requires ANTHROPIC_API_KEY)
+- Local models via Ollama
+- Any provider supported by LiteLLM
+"""
+
+# Updated embedding_docstring
+embedding_docstring = """
+Generate embeddings for the provided text using the specified model.
+
+This endpoint uses the language model to generate embeddings for the provided text.
+The model parameter specifies the model to use for generating embeddings.
+
+Embeddings are numerical representations of text that capture semantic meaning,
+allowing for similarity comparisons and other vector operations.
+
+**Uses:**
+- Semantic search
+- Document clustering
+- Text similarity analysis
+- Content recommendation
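+
+A minimal request body looks like this (sketch, mirroring the cURL example above):
+```json
+{
+  "text": "What is DeepSeek R1?"
+}
+```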
+"""
+
+# # Example implementation to update the routers in the RetrievalRouterV3 class
+# def update_retrieval_router(router_class):
+# """
+# Update the RetrievalRouterV3 class with the improved docstrings and examples.
+
+# This function demonstrates how the updated examples and docstrings would be
+# integrated into the actual router class.
+# """
+# # Update search_app endpoint
+# router_class.search_app.__doc__ = search_app_docstring
+# router_class.search_app.openapi_extra = search_app_examples
+
+# # Update rag_app endpoint
+# router_class.rag_app.__doc__ = rag_app_docstring
+# router_class.rag_app.openapi_extra = rag_app_examples
+
+# # Update agent_app endpoint
+# router_class.agent_app.__doc__ = agent_app_docstring
+# router_class.agent_app.openapi_extra = agent_app_examples
+
+# # Update completion endpoint
+# router_class.completion.__doc__ = completion_docstring
+# router_class.completion.openapi_extra = completion_examples
+
+# # Update embedding endpoint
+# router_class.embedding.__doc__ = embedding_docstring
+# router_class.embedding.openapi_extra = embedding_examples
+
+# return router_class
+
+# Example showing how the updated router would be integrated
+"""
+from your_module import RetrievalRouterV3
+
+# Apply the updated docstrings and examples
+router = RetrievalRouterV3(providers, services, config)
+router = update_retrieval_router(router)
+
+# Now the router has the improved docstrings and examples
+"""
+
+EXAMPLES = {
+ "search": search_app_examples,
+ "rag": rag_app_examples,
+ "agent": agent_app_examples,
+ "completion": completion_examples,
+ "embedding": embedding_examples,
+}
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/graph_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/graph_router.py
new file mode 100644
index 00000000..244d76cf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/graph_router.py
@@ -0,0 +1,2051 @@
+import logging
+import textwrap
+from typing import Optional, cast
+from uuid import UUID
+
+from fastapi import Body, Depends, Path, Query
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse
+
+from core.base import GraphConstructionStatus, R2RException, Workflow
+from core.base.abstractions import DocumentResponse, StoreType
+from core.base.api.models import (
+ GenericBooleanResponse,
+ GenericMessageResponse,
+ WrappedBooleanResponse,
+ WrappedCommunitiesResponse,
+ WrappedCommunityResponse,
+ WrappedEntitiesResponse,
+ WrappedEntityResponse,
+ WrappedGenericMessageResponse,
+ WrappedGraphResponse,
+ WrappedGraphsResponse,
+ WrappedRelationshipResponse,
+ WrappedRelationshipsResponse,
+)
+from core.utils import (
+ generate_default_user_collection_id,
+ update_settings_from_dict,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+logger = logging.getLogger()
+
+
+class GraphRouter(BaseRouterV3):
+ def __init__(
+ self,
+ providers: R2RProviders,
+ services: R2RServices,
+ config: R2RConfig,
+ ):
+ logging.info("Initializing GraphRouter")
+ super().__init__(providers, services, config)
+ self._register_workflows()
+
+ def _register_workflows(self):
+ workflow_messages = {}
+ if self.providers.orchestration.config.provider == "hatchet":
+ workflow_messages["graph-extraction"] = (
+ "Document extraction task queued successfully."
+ )
+ workflow_messages["graph-clustering"] = (
+ "Graph enrichment task queued successfully."
+ )
+ workflow_messages["graph-deduplication"] = (
+ "Entity deduplication task queued successfully."
+ )
+ else:
+ workflow_messages["graph-extraction"] = (
+ "Document entities and relationships extracted successfully."
+ )
+ workflow_messages["graph-clustering"] = (
+ "Graph communities created successfully."
+ )
+ workflow_messages["graph-deduplication"] = (
+ "Entity deduplication completed successfully."
+ )
+
+ self.providers.orchestration.register_workflows(
+ Workflow.GRAPH,
+ self.services.graph,
+ workflow_messages,
+ )
+
+ async def _get_collection_id(
+ self, collection_id: Optional[UUID], auth_user
+ ) -> UUID:
+ """Helper method to get collection ID, using default if none
+ provided."""
+ if collection_id is None:
+ return generate_default_user_collection_id(auth_user.id)
+ return collection_id
+
+ def _setup_routes(self):
+ @self.router.get(
+ "/graphs",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List graphs",
+ openapi_extra={
+ "x-codeSamples": [
+ { # TODO: Verify
+ "lang": "Python",
+ "source": textwrap.dedent(
+ """
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.list()
+ """
+ ),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent(
+ """
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.list({});
+ }
+
+ main();
+ """
+ ),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def list_graphs(
+ collection_ids: list[str] = Query(
+ [],
+ description="A list of graph IDs to retrieve. If not provided, all graphs will be returned.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGraphsResponse:
+ """Returns a paginated list of graphs the authenticated user has
+ access to.
+
+ Results can be filtered by providing specific graph IDs. Regular
+ users will only see graphs they own or have access to. Superusers
+ can see all graphs.
+
+ The graphs are returned in order of last modification, with most
+ recent first.
+ """
+ requesting_user_id = (
+ None if auth_user.is_superuser else [auth_user.id]
+ )
+
+ graph_uuids = [UUID(graph_id) for graph_id in collection_ids]
+
+ list_graphs_response = await self.services.graph.list_graphs(
+ # user_ids=requesting_user_id,
+ graph_ids=graph_uuids,
+ offset=offset,
+ limit=limit,
+ )
+
+ return ( # type: ignore
+ list_graphs_response["results"],
+ {"total_entries": list_graphs_response["total_entries"]},
+ )
+
+ @self.router.get(
+ "/graphs/{collection_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Retrieve graph details",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.get(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.retrieve({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7" \\
+ -H "Authorization: Bearer YOUR_API_KEY" """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_graph(
+ collection_id: UUID = Path(...),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGraphResponse:
+ """Retrieves detailed information about a specific graph by ID."""
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the specified collection associated with the given graph.",
+ 403,
+ )
+
+ list_graphs_response = await self.services.graph.list_graphs(
+ # user_ids=None,
+ graph_ids=[collection_id],
+ offset=0,
+ limit=1,
+ )
+ return list_graphs_response["results"][0] # type: ignore
+
+ @self.router.post(
+ "/graphs/{collection_id}/communities/build",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ )
+ @self.base_endpoint
+ async def build_communities(
+ collection_id: UUID = Path(
+ ..., description="The unique identifier of the collection"
+ ),
+ graph_enrichment_settings: Optional[dict] = Body(
+ default=None,
+ description="Settings for the graph enrichment process.",
+ ),
+ run_with_orchestration: Optional[bool] = Body(True),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Creates communities in the graph by analyzing entity
+ relationships and similarities.
+
+ Communities are created through the following process:
+ 1. Analyzes entity relationships and metadata to build a similarity graph
+ 2. Applies advanced community detection algorithms (e.g. Leiden) to identify densely connected groups
+ 3. Creates hierarchical community structure with multiple granularity levels
+ 4. Generates natural language summaries and statistical insights for each community
+
+ The resulting communities can be used to:
+ - Understand high-level graph structure and organization
+ - Identify key entity groupings and their relationships
+ - Navigate and explore the graph at different levels of detail
+ - Generate insights about entity clusters and their characteristics
+
+ The community detection process is configurable through settings like:
+ - Community detection algorithm parameters
+ - Summary generation prompt
+ """
+ collections_overview_response = (
+ await self.services.management.collections_overview(
+ user_ids=[auth_user.id],
+ collection_ids=[collection_id],
+ offset=0,
+ limit=1,
+ )
+ )["results"]
+ if len(collections_overview_response) == 0: # type: ignore
+ raise R2RException("Collection not found.", 404)
+
+ # Check user permissions for graph
+ if (
+ not auth_user.is_superuser
+ and collections_overview_response[0].owner_id != auth_user.id # type: ignore
+ ):
+ raise R2RException(
+ "Only superusers can `build communities` for a graph they do not own.",
+ 403,
+ )
+
+ # If no collection ID is provided, use the default user collection
+ # id = generate_default_user_collection_id(auth_user.id)
+
+ # Apply runtime settings overrides
+ server_graph_enrichment_settings = (
+ self.providers.database.config.graph_enrichment_settings
+ )
+ if graph_enrichment_settings:
+ server_graph_enrichment_settings = update_settings_from_dict(
+ server_graph_enrichment_settings, graph_enrichment_settings
+ )
+
+ workflow_input = {
+ "collection_id": str(collection_id),
+ "graph_enrichment_settings": server_graph_enrichment_settings.model_dump_json(),
+ "user": auth_user.json(),
+ }
+
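+ # Try the orchestrated workflow first; if orchestration fails (or is disabled),
+ # fall back to the in-process implementation below.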
+ if run_with_orchestration:
+ try:
+ return await self.providers.orchestration.run_workflow( # type: ignore
+ "graph-clustering", {"request": workflow_input}, {}
+ )
+ except Exception as e: # TODO: Need to find specific error (gRPC most likely?)
+ logger.error(
+ f"Error running orchestrated community building: {e} \n\nAttempting to run without orchestration."
+ )
+ from core.main.orchestration import (
+ simple_graph_search_results_factory,
+ )
+
+ logger.info("Running build-communities without orchestration.")
+ simple_graph_search_results = simple_graph_search_results_factory(
+ self.services.graph
+ )
+ await simple_graph_search_results["graph-clustering"](
+ workflow_input
+ )
+ return { # type: ignore
+ "message": "Graph communities created successfully.",
+ "task_id": None,
+ }
+
+ @self.router.post(
+ "/graphs/{collection_id}/reset",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Reset a graph back to the initial state.",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.reset(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.reset({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/reset" \\
+ -H "Authorization: Bearer YOUR_API_KEY" """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def reset(
+ collection_id: UUID = Path(...),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Deletes a graph and all its associated data.
+
+ This endpoint permanently removes the specified graph along with
+ all entities and relationships that belong to only this graph. The
+ original source entities and relationships extracted from
+ underlying documents are not deleted and are managed through the
+ document lifecycle.
+ """
+ if not auth_user.is_superuser:
+ raise R2RException("Only superusers can reset a graph", 403)
+
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ await self.services.graph.reset_graph(id=collection_id)
+ # await _pull(collection_id, auth_user)
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ # update graph
+ @self.router.post(
+ "/graphs/{collection_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Update graph",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.update(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ graph={
+ "name": "New Name",
+ "description": "New Description"
+ }
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.update({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ name: "New Name",
+ description: "New Description",
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def update_graph(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to update",
+ ),
+ name: Optional[str] = Body(
+ None, description="The name of the graph"
+ ),
+ description: Optional[str] = Body(
+ None, description="An optional description of the graph"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGraphResponse:
+ """Update an existing graphs's configuration.
+
+ This endpoint allows updating the name and description of an
+ existing collection. The user must have appropriate permissions to
+ modify the collection.
+ """
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can update graph details", 403
+ )
+
+ if (
+ not auth_user.is_superuser
+ and collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ return await self.services.graph.update_graph( # type: ignore
+ collection_id,
+ name=name,
+ description=description,
+ )
+
+ @self.router.get(
+ "/graphs/{collection_id}/entities",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.list_entities(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.listEntities({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ });
+ }
+
+ main();
+ """),
+ },
+ ],
+ },
+ )
+ @self.base_endpoint
+ async def get_entities(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to list entities from.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedEntitiesResponse:
+ """Lists all entities in the graph with pagination support."""
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ entities, count = await self.services.graph.get_entities(
+ parent_id=collection_id,
+ offset=offset,
+ limit=limit,
+ )
+
+ return entities, { # type: ignore
+ "total_entries": count,
+ }
+
+ @self.router.post(
+ "/graphs/{collection_id}/entities/export",
+ summary="Export graph entities to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.graphs.export_entities(
+ collection_id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ output_path="export.csv",
+ columns=["id", "title", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+ async function main() {
+ await client.graphs.exportEntities({
+ collectionId: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ outputPath: "export.csv",
+ columns: ["id", "title", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/graphs/export_entities" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_entities(
+ background_tasks: BackgroundTasks,
+ collection_id: UUID = Path(
+ ...,
+ description="The ID of the collection to export entities from.",
+ ),
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export documents as a downloadable CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_graph_entities(
+ id=collection_id,
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
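+ # The export is written to a temporary file; close it once the response has
+ # been sent so it can be cleaned up.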
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="documents_export.csv",
+ )
+
+ @self.router.post(
+ "/graphs/{collection_id}/entities",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ )
+ @self.base_endpoint
+ async def create_entity(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to add the entity to.",
+ ),
+ name: str = Body(
+ ..., description="The name of the entity to create."
+ ),
+ description: str = Body(
+ ..., description="The description of the entity to create."
+ ),
+ category: Optional[str] = Body(
+ None, description="The category of the entity to create."
+ ),
+ metadata: Optional[dict] = Body(
+ None, description="The metadata of the entity to create."
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedEntityResponse:
+ """Creates a new entity in the graph."""
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ return await self.services.graph.create_entity( # type: ignore
+ name=name,
+ description=description,
+ parent_id=collection_id,
+ category=category,
+ metadata=metadata,
+ )
+
+ @self.router.post(
+ "/graphs/{collection_id}/relationships",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ )
+ @self.base_endpoint
+ async def create_relationship(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to add the relationship to.",
+ ),
+ subject: str = Body(
+ ..., description="The subject of the relationship to create."
+ ),
+ subject_id: UUID = Body(
+ ...,
+ description="The ID of the subject of the relationship to create.",
+ ),
+ predicate: str = Body(
+ ..., description="The predicate of the relationship to create."
+ ),
+ object: str = Body(
+ ..., description="The object of the relationship to create."
+ ),
+ object_id: UUID = Body(
+ ...,
+ description="The ID of the object of the relationship to create.",
+ ),
+ description: str = Body(
+ ...,
+ description="The description of the relationship to create.",
+ ),
+ weight: float = Body(
+ 1.0, description="The weight of the relationship to create."
+ ),
+ metadata: Optional[dict] = Body(
+ None, description="The metadata of the relationship to create."
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedRelationshipResponse:
+ """Creates a new relationship in the graph."""
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can create relationships.", 403
+ )
+
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+ return await self.services.graph.create_relationship( # type: ignore
+ subject=subject,
+ subject_id=subject_id,
+ predicate=predicate,
+ object=object,
+ object_id=object_id,
+ description=description,
+ weight=weight,
+ metadata=metadata,
+ parent_id=collection_id,
+ )
+
+ @self.router.post(
+ "/graphs/{collection_id}/relationships/export",
+ summary="Export graph relationships to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.graphs.export_relationships(
+ collection_id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ output_path="export.csv",
+ columns=["id", "title", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+ async function main() {
+ await client.graphs.exportRelationships({
+ collectionId: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ outputPath: "export.csv",
+ columns: ["id", "title", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/graphs/export_relationships" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_relationships(
+ background_tasks: BackgroundTasks,
+ collection_id: UUID = Path(
+ ...,
+ description="The ID of the document to export entities from.",
+ ),
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export documents as a downloadable CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_graph_relationships(
+ id=collection_id,
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="documents_export.csv",
+ )
+
+ @self.router.get(
+ "/graphs/{collection_id}/entities/{entity_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.get_entity(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ entity_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.get_entity({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ entityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_entity(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph containing the entity.",
+ ),
+ entity_id: UUID = Path(
+ ..., description="The ID of the entity to retrieve."
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedEntityResponse:
+ """Retrieves a specific entity by its ID."""
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ result = await self.providers.database.graphs_handler.entities.get(
+ parent_id=collection_id,
+ store_type=StoreType.GRAPHS,
+ offset=0,
+ limit=1,
+ entity_ids=[entity_id],
+ )
+ if len(result) == 0 or len(result[0]) == 0:
+ raise R2RException("Entity not found", 404)
+ return result[0][0]
+
+ @self.router.post(
+ "/graphs/{collection_id}/entities/{entity_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ )
+ @self.base_endpoint
+ async def update_entity(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph containing the entity.",
+ ),
+ entity_id: UUID = Path(
+ ..., description="The ID of the entity to update."
+ ),
+ name: Optional[str] = Body(
+ ..., description="The updated name of the entity."
+ ),
+ description: Optional[str] = Body(
+ None, description="The updated description of the entity."
+ ),
+ category: Optional[str] = Body(
+ None, description="The updated category of the entity."
+ ),
+ metadata: Optional[dict] = Body(
+ None, description="The updated metadata of the entity."
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedEntityResponse:
+ """Updates an existing entity in the graph."""
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can update graph entities.", 403
+ )
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ return await self.services.graph.update_entity( # type: ignore
+ entity_id=entity_id,
+ name=name,
+ category=category,
+ description=description,
+ metadata=metadata,
+ )
+
+ @self.router.delete(
+ "/graphs/{collection_id}/entities/{entity_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Remove an entity",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.remove_entity(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ entity_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.removeEntity({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ entityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_entity(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to remove the entity from.",
+ ),
+ entity_id: UUID = Path(
+ ...,
+ description="The ID of the entity to remove from the graph.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Removes an entity from the graph."""
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can delete graph details.", 403
+ )
+
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ await self.services.graph.delete_entity(
+ parent_id=collection_id,
+ entity_id=entity_id,
+ )
+
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.get(
+ "/graphs/{collection_id}/relationships",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ description="Lists all relationships in the graph with pagination support.",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.list_relationships(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.listRelationships({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ });
+ }
+
+ main();
+ """),
+ },
+ ],
+ },
+ )
+ @self.base_endpoint
+ async def get_relationships(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to list relationships from.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedRelationshipsResponse:
+ """Lists all relationships in the graph with pagination support."""
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ relationships, count = await self.services.graph.get_relationships(
+ parent_id=collection_id,
+ offset=offset,
+ limit=limit,
+ )
+
+ return relationships, { # type: ignore
+ "total_entries": count,
+ }
+
+ @self.router.get(
+ "/graphs/{collection_id}/relationships/{relationship_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ description="Retrieves a specific relationship by its ID.",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.get_relationship(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ relationship_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.getRelationship({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ relationshipId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ });
+ }
+
+ main();
+ """),
+ },
+ ],
+ },
+ )
+ @self.base_endpoint
+ async def get_relationship(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph containing the relationship.",
+ ),
+ relationship_id: UUID = Path(
+ ..., description="The ID of the relationship to retrieve."
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedRelationshipResponse:
+ """Retrieves a specific relationship by its ID."""
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ results = (
+ await self.providers.database.graphs_handler.relationships.get(
+ parent_id=collection_id,
+ store_type=StoreType.GRAPHS,
+ offset=0,
+ limit=1,
+ relationship_ids=[relationship_id],
+ )
+ )
+ if len(results) == 0 or len(results[0]) == 0:
+ raise R2RException("Relationship not found", 404)
+ return results[0][0]
+
+ @self.router.post(
+ "/graphs/{collection_id}/relationships/{relationship_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ )
+ @self.base_endpoint
+ async def update_relationship(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph containing the relationship.",
+ ),
+ relationship_id: UUID = Path(
+ ..., description="The ID of the relationship to update."
+ ),
+ subject: Optional[str] = Body(
+ ..., description="The updated subject of the relationship."
+ ),
+ subject_id: Optional[UUID] = Body(
+ ..., description="The updated subject ID of the relationship."
+ ),
+ predicate: Optional[str] = Body(
+ ..., description="The updated predicate of the relationship."
+ ),
+ object: Optional[str] = Body(
+ ..., description="The updated object of the relationship."
+ ),
+ object_id: Optional[UUID] = Body(
+ ..., description="The updated object ID of the relationship."
+ ),
+ description: Optional[str] = Body(
+ None,
+ description="The updated description of the relationship.",
+ ),
+ weight: Optional[float] = Body(
+ None, description="The updated weight of the relationship."
+ ),
+ metadata: Optional[dict] = Body(
+ None, description="The updated metadata of the relationship."
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedRelationshipResponse:
+ """Updates an existing relationship in the graph."""
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can update graph details", 403
+ )
+
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ return await self.services.graph.update_relationship( # type: ignore
+ relationship_id=relationship_id,
+ subject=subject,
+ subject_id=subject_id,
+ predicate=predicate,
+ object=object,
+ object_id=object_id,
+ description=description,
+ weight=weight,
+ metadata=metadata,
+ )
+
+ @self.router.delete(
+ "/graphs/{collection_id}/relationships/{relationship_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ description="Removes a relationship from the graph.",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.delete_relationship(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ relationship_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.deleteRelationship({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ relationshipId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ });
+ }
+
+ main();
+ """),
+ },
+ ],
+ },
+ )
+ @self.base_endpoint
+ async def delete_relationship(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to remove the relationship from.",
+ ),
+ relationship_id: UUID = Path(
+ ...,
+ description="The ID of the relationship to remove from the graph.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Removes a relationship from the graph."""
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can delete a relationship.", 403
+ )
+
+ if (
+ not auth_user.is_superuser
+ and collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ await self.services.graph.delete_relationship(
+ parent_id=collection_id,
+ relationship_id=relationship_id,
+ )
+
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.post(
+ "/graphs/{collection_id}/communities",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Create a new community",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.create_community(
+ collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ name="My Community",
+ summary="A summary of the community",
+ findings=["Finding 1", "Finding 2"],
+ rating=5,
+ rating_explanation="This is a rating explanation",
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.createCommunity({
+ collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ name: "My Community",
+ summary: "A summary of the community",
+ findings: ["Finding 1", "Finding 2"],
+ rating: 5,
+ ratingExplanation: "This is a rating explanation",
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def create_community(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to create the community in.",
+ ),
+ name: str = Body(..., description="The name of the community"),
+ summary: str = Body(..., description="A summary of the community"),
+ findings: Optional[list[str]] = Body(
+ default=[], description="Findings about the community"
+ ),
+ rating: Optional[float] = Body(
+ default=5, ge=1, le=10, description="Rating between 1 and 10"
+ ),
+ rating_explanation: Optional[str] = Body(
+ default="", description="Explanation for the rating"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCommunityResponse:
+ """Creates a new community in the graph.
+
+ While communities are typically built automatically via the /graphs/{id}/communities/build endpoint,
+ this endpoint allows you to manually create your own communities.
+
+ This can be useful when you want to:
+ - Define custom groupings of entities based on domain knowledge
+ - Add communities that weren't detected by the automatic process
+ - Create hierarchical organization structures
+ - Tag groups of entities with specific metadata
+
+ The created communities will be integrated with any existing automatically detected communities
+ in the graph's community structure.
+ """
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can create a community.", 403
+ )
+
+ if (
+ not auth_user.is_superuser
+ and collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ return await self.services.graph.create_community( # type: ignore
+ parent_id=collection_id,
+ name=name,
+ summary=summary,
+ findings=findings,
+ rating=rating,
+ rating_explanation=rating_explanation,
+ )
+
+ @self.router.get(
+ "/graphs/{collection_id}/communities",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List communities",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.list_communities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.listCommunities({
+ collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_communities(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to get communities for.",
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCommunitiesResponse:
+ """Lists all communities in the graph with pagination support."""
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ communities, count = await self.services.graph.get_communities(
+ parent_id=collection_id,
+ offset=offset,
+ limit=limit,
+ )
+
+ return communities, { # type: ignore
+ "total_entries": count,
+ }
+
+ @self.router.get(
+ "/graphs/{collection_id}/communities/{community_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Retrieve a community",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.get_community(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", community_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.getCommunity({
+ collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ communityId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_community(
+ collection_id: UUID = Path(
+ ...,
+ description="The ID of the collection to get communities for.",
+ ),
+ community_id: UUID = Path(
+ ...,
+ description="The ID of the community to get.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCommunityResponse:
+ """Retrieves a specific community by its ID."""
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ results = (
+ await self.providers.database.graphs_handler.communities.get(
+ parent_id=collection_id,
+ community_ids=[community_id],
+ store_type=StoreType.GRAPHS,
+ offset=0,
+ limit=1,
+ )
+ )
+ if len(results) == 0 or len(results[0]) == 0:
+ raise R2RException("Community not found", 404)
+ return results[0][0]
+
+ @self.router.delete(
+ "/graphs/{collection_id}/communities/{community_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Delete a community",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.delete_community(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ community_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.deleteCommunity({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ communityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_community(
+ collection_id: UUID = Path(
+ ...,
+ description="The collection ID corresponding to the graph to delete the community from.",
+ ),
+ community_id: UUID = Path(
+ ...,
+ description="The ID of the community to delete.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
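+ """Deletes a community from the graph."""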
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can delete communities", 403
+ )
+
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ await self.services.graph.delete_community(
+ parent_id=collection_id,
+ community_id=community_id,
+ )
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.post(
+ "/graphs/{collection_id}/communities/export",
+ summary="Export document communities to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.graphs.export_communities(
+ collection_id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ output_path="export.csv",
+ columns=["id", "title", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+ async function main() {
+ await client.graphs.exportCommunities({
+ collectionId: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+ outputPath: "export.csv",
+ columns: ["id", "title", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "http://127.0.0.1:7272/v3/graphs/export_communities" \
+ -H "Authorization: Bearer YOUR_API_KEY" \
+ -H "Content-Type: application/json" \
+ -H "Accept: text/csv" \
+ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_communities(
+ background_tasks: BackgroundTasks,
+ collection_id: UUID = Path(
+ ...,
+ description="The ID of the document to export entities from.",
+ ),
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export documents as a downloadable CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can export data.",
+ 403,
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_graph_communities(
+ id=collection_id,
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="documents_export.csv",
+ )
+
+ @self.router.post(
+ "/graphs/{collection_id}/communities/{community_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Update community",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.update_community(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ community_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ name="Updated Name",
+ summary="An updated summary of the community",
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.updateCommunity({
+ collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ communityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
+ name: "Updated Name",
+ summary: "An updated summary of the community"
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def update_community(
+ collection_id: UUID = Path(...),
+ community_id: UUID = Path(...),
+ name: Optional[str] = Body(None),
+ summary: Optional[str] = Body(None),
+ findings: Optional[list[str]] = Body(None),
+ rating: Optional[float] = Body(default=None, ge=1, le=10),
+ rating_explanation: Optional[str] = Body(None),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCommunityResponse:
+ """Updates an existing community in the graph."""
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can update communities.", 403
+ )
+
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ return await self.services.graph.update_community( # type: ignore
+ community_id=community_id,
+ name=name,
+ summary=summary,
+ findings=findings,
+ rating=rating,
+ rating_explanation=rating_explanation,
+ )
+
+ @self.router.post(
+ "/graphs/{collection_id}/pull",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Pull latest entities to the graph",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ response = client.graphs.pull(
+ collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.graphs.pull({
+ collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
+ });
+ }
+
+ main();
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def pull(
+ collection_id: UUID = Path(
+ ..., description="The ID of the graph to initialize."
+ ),
+ force: Optional[bool] = Body(
+ False,
+ description="If true, forces a re-pull of all entities and relationships.",
+ ),
+ # document_ids: list[UUID] = Body(
+ # ..., description="List of document IDs to add to the graph."
+ # ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Adds documents to a graph by copying their entities and
+ relationships.
+
+ This endpoint:
+ 1. Copies document entities to the graphs_entities table
+ 2. Copies document relationships to the graphs_relationships table
+ 3. Associates the documents with the graph
+
+ When a document is added:
+ - Its entities and relationships are copied to graph-specific tables
+ - Existing entities/relationships are updated by merging their properties
+ - The document ID is recorded in the graph's document_ids array
+
+ Documents added to a graph will contribute their knowledge to:
+ - Graph analysis and querying
+ - Community detection
+ - Knowledge graph enrichment
+
+ The user must have access to both the graph and the documents being added.
+ """
+
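+ # Steps below: verify the caller can access the collection, page through its
+ # documents, then copy each document's extracted graph data into the graph.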
+ collections_overview_response = (
+ await self.services.management.collections_overview(
+ user_ids=[auth_user.id],
+ collection_ids=[collection_id],
+ offset=0,
+ limit=1,
+ )
+ )["results"]
+ if len(collections_overview_response) == 0: # type: ignore
+ raise R2RException("Collection not found.", 404)
+
+ # Check user permissions for graph
+ if (
+ not auth_user.is_superuser
+ and collections_overview_response[0].owner_id != auth_user.id # type: ignore
+ ):
+ raise R2RException("Only superusers can `pull` a graph.", 403)
+
+ if (
+ # not auth_user.is_superuser
+ collection_id not in auth_user.collection_ids
+ ):
+ raise R2RException(
+ "The currently authenticated user does not have access to the collection associated with the given graph.",
+ 403,
+ )
+
+ list_graphs_response = await self.services.graph.list_graphs(
+ # user_ids=None,
+ graph_ids=[collection_id],
+ offset=0,
+ limit=1,
+ )
+ if len(list_graphs_response["results"]) == 0: # type: ignore
+ raise R2RException("Graph not found", 404)
+ collection_id = list_graphs_response["results"][0].collection_id # type: ignore
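+ # Page through every document in the collection, 100 at a time.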
+ documents: list[DocumentResponse] = []
+ document_req = await self.providers.database.collections_handler.documents_in_collection(
+ collection_id, offset=0, limit=100
+ )
+ results = cast(list[DocumentResponse], document_req["results"])
+ documents.extend(results)
+
+ while len(results) == 100:
+ document_req = await self.providers.database.collections_handler.documents_in_collection(
+ collection_id, offset=len(documents), limit=100
+ )
+ results = cast(list[DocumentResponse], document_req["results"])
+ documents.extend(results)
+
+ success = False
+
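+ # Copy each document's extracted entities/relationships into the graph, skipping
+ # documents already in the graph or (unless force=True) documents with no entities.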
+ for document in documents:
+ entities = (
+ await self.providers.database.graphs_handler.entities.get(
+ parent_id=document.id,
+ store_type=StoreType.DOCUMENTS,
+ offset=0,
+ limit=100,
+ )
+ )
+ has_document = (
+ await self.providers.database.graphs_handler.has_document(
+ collection_id, document.id
+ )
+ )
+ if has_document:
+ logger.info(
+ f"Document {document.id} is already in graph {collection_id}, skipping."
+ )
+ continue
+ if len(entities[0]) == 0:
+ if not force:
+ logger.warning(
+ f"Document {document.id} has no entities, extraction may not have been called, skipping."
+ )
+ continue
+ else:
+ logger.warning(
+ f"Document {document.id} has no entities, but force=True, continuing."
+ )
+
+ success = (
+ await self.providers.database.graphs_handler.add_documents(
+ id=collection_id,
+ document_ids=[document.id],
+ )
+ )
+ if not success:
+ logger.warning(
+ f"No documents were added to graph {collection_id}, marking as failed."
+ )
+
+ if success:
+ await self.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.SUCCESS,
+ )
+
+ return GenericBooleanResponse(success=success) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/indices_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/indices_router.py
new file mode 100644
index 00000000..29b75226
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/indices_router.py
@@ -0,0 +1,576 @@
+import logging
+import textwrap
+from typing import Optional
+
+from fastapi import Body, Depends, Path, Query
+
+from core.base import IndexConfig, R2RException
+from core.base.abstractions import VectorTableName
+from core.base.api.models import (
+ VectorIndexResponse,
+ VectorIndicesResponse,
+ WrappedGenericMessageResponse,
+ WrappedVectorIndexResponse,
+ WrappedVectorIndicesResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+logger = logging.getLogger()
+
+
+class IndicesRouter(BaseRouterV3):
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ logging.info("Initializing IndicesRouter")
+ super().__init__(providers, services, config)
+
+ def _setup_routes(self):
+ ## TODO - Allow developer to pass the index id with the request
+ @self.router.post(
+ "/indices",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Create Vector Index",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ # Create an HNSW index for efficient similarity search
+ result = client.indices.create(
+ config={
+ "table_name": "chunks", # The table containing vector embeddings
+ "index_method": "hnsw", # Hierarchical Navigable Small World graph
+ "index_measure": "cosine_distance", # Similarity measure
+ "index_arguments": {
+ "m": 16, # Number of connections per layer
+ "ef_construction": 64,# Size of dynamic candidate list for construction
+ "ef": 40, # Size of dynamic candidate list for search
+ },
+ "index_name": "my_document_embeddings_idx",
+ "index_column": "embedding",
+ "concurrently": True # Build index without blocking table writes
+ },
+ run_with_orchestration=True # Run as orchestrated task for large indices
+ )
+
+ # Create an IVF-Flat index for balanced performance
+ result = client.indices.create(
+ config={
+ "table_name": "chunks",
+ "index_method": "ivf_flat", # Inverted File with Flat storage
+ "index_measure": "l2_distance",
+ "index_arguments": {
+ "lists": 100, # Number of cluster centroids
+ "probe": 10, # Number of clusters to search
+ },
+ "index_name": "my_ivf_embeddings_idx",
+ "index_column": "embedding",
+ "concurrently": True
+ }
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.indicies.create({
+ config: {
+ tableName: "vectors",
+ indexMethod: "hnsw",
+ indexMeasure: "cosine_distance",
+ indexArguments: {
+ m: 16,
+ ef_construction: 64,
+ ef: 40
+ },
+ indexName: "my_document_embeddings_idx",
+ indexColumn: "embedding",
+ concurrently: true
+ },
+ runWithOrchestration: true
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ # Create HNSW Index
+ curl -X POST "https://api.example.com/indices" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "config": {
+ "table_name": "vectors",
+ "index_method": "hnsw",
+ "index_measure": "cosine_distance",
+ "index_arguments": {
+ "m": 16,
+ "ef_construction": 64,
+ "ef": 40
+ },
+ "index_name": "my_document_embeddings_idx",
+ "index_column": "embedding",
+ "concurrently": true
+ },
+ "run_with_orchestration": true
+ }'
+
+ # Create IVF-Flat Index
+ curl -X POST "https://api.example.com/indices" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -d '{
+ "config": {
+ "table_name": "vectors",
+ "index_method": "ivf_flat",
+ "index_measure": "l2_distance",
+ "index_arguments": {
+ "lists": 100,
+ "probe": 10
+ },
+ "index_name": "my_ivf_embeddings_idx",
+ "index_column": "embedding",
+ "concurrently": true
+ }
+ }'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def create_index(
+ config: IndexConfig,
+ run_with_orchestration: Optional[bool] = Body(
+ True,
+ description="Whether to run index creation as an orchestrated task (recommended for large indices)",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Create a new vector similarity search index in over the target
+ table. Allowed tables include 'vectors', 'entity',
+ 'document_collections'. Vectors correspond to the chunks of text
+ that are indexed for similarity search, whereas entity and
+ document_collections are created during knowledge graph
+ construction.
+
+ This endpoint creates a database index optimized for efficient similarity search over vector embeddings.
+ It supports two main indexing methods:
+
+ 1. HNSW (Hierarchical Navigable Small World):
+ - Best for: High-dimensional vectors requiring fast approximate nearest neighbor search
+ - Pros: Very fast search, good recall, memory-resident for speed
+ - Cons: Slower index construction, more memory usage
+ - Key parameters:
+ * m: Number of connections per layer (higher = better recall but more memory)
+ * ef_construction: Build-time search width (higher = better recall but slower build)
+ * ef: Query-time search width (higher = better recall but slower search)
+
+ 2. IVF-Flat (Inverted File with Flat Storage):
+ - Best for: Balance between build speed, search speed, and recall
+ - Pros: Faster index construction, less memory usage
+ - Cons: Slightly slower search than HNSW
+ - Key parameters:
+ * lists: Number of clusters (usually sqrt(n) where n is number of vectors)
+ * probe: Number of nearest clusters to search
+
+ Supported similarity measures:
+ - cosine_distance: Best for comparing semantic similarity
+ - l2_distance: Best for comparing absolute distances
+ - ip_distance: Best for comparing raw dot products
+
+ Notes:
+ - Index creation can be resource-intensive for large datasets
+ - Use run_with_orchestration=True for large indices to prevent timeouts
+ - The 'concurrently' option allows other operations while building
+ - Index names must be unique per table
+ """
+ # TODO: Implement index creation logic
+ logger.info(
+ f"Creating vector index for {config.table_name} with method {config.index_method}, measure {config.index_measure}, concurrently {config.concurrently}"
+ )
+
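+ # Queue the index build as a "create-vector-index" workflow with the
+ # orchestration provider.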
+ result = await self.providers.orchestration.run_workflow(
+ "create-vector-index",
+ {
+ "request": {
+ "table_name": config.table_name,
+ "index_method": config.index_method,
+ "index_measure": config.index_measure,
+ "index_name": config.index_name,
+ "index_column": config.index_column,
+ "index_arguments": config.index_arguments,
+ "concurrently": config.concurrently,
+ },
+ },
+ options={
+ "additional_metadata": {},
+ },
+ )
+
+ return result # type: ignore
+
+ @self.router.get(
+ "/indices",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List Vector Indices",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+
+ # List all indices
+ indices = client.indices.list(
+ offset=0,
+ limit=10
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.indicies.list({
+ offset: 0,
+ limit: 10,
+ filters: { table_name: "vectors" }
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/indices?offset=0&limit=10" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -H "Content-Type: application/json"
+
+ # With filters
+                            curl -X GET "https://api.example.com/indices?offset=0&limit=10&filters=%7B%22table_name%22%3A%22vectors%22%7D" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -H "Content-Type: application/json"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def list_indices(
+ # filters: list[str] = Query([]),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+                description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedVectorIndicesResponse:
+ """List existing vector similarity search indices with pagination
+ support.
+
+ Returns details about each index including:
+ - Name and table name
+ - Indexing method and parameters
+ - Size and row count
+ - Creation timestamp and last updated
+ - Performance statistics (if available)
+
+            Use the offset and limit parameters to page through the full set of indices.
+ """
+ # TODO: Implement index listing logic
+ indices_data = (
+ await self.providers.database.chunks_handler.list_indices(
+ offset=offset, limit=limit
+ )
+ )
+
+ formatted_indices = VectorIndicesResponse(
+ indices=[
+ VectorIndexResponse(index=index_data)
+ for index_data in indices_data["indices"]
+ ]
+ )
+
+ return ( # type: ignore
+ formatted_indices,
+ {"total_entries": indices_data["total_entries"]},
+ )
+
+ @self.router.get(
+ "/indices/{table_name}/{index_name}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Get Vector Index Details",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+
+ # Get detailed information about a specific index
+ index = client.indices.retrieve("index_1")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.indicies.retrieve({
+ indexName: "index_1",
+ tableName: "vectors"
+ });
+
+ console.log(response);
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/indices/vectors/index_1" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_index(
+ table_name: VectorTableName = Path(
+ ...,
+                    description="The table of vector embeddings to query (e.g. `vectors`, `entity`, `document_collections`)",
+ ),
+ index_name: str = Path(
+                    ..., description="The name of the index to retrieve"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedVectorIndexResponse:
+ """Get detailed information about a specific vector index.
+
+ Returns comprehensive information about the index including:
+ - Configuration details (method, measure, parameters)
+ - Current size and row count
+ - Build progress (if still under construction)
+ - Performance statistics:
+ * Average query time
+ * Memory usage
+ * Cache hit rates
+ * Recent query patterns
+ - Maintenance information:
+ * Last vacuum
+ * Fragmentation level
+ * Recommended optimizations
+ """
+ # TODO: Implement get index logic
+ indices = (
+ await self.providers.database.chunks_handler.list_indices(
+ filters={
+ "index_name": index_name,
+ "table_name": table_name,
+ },
+ limit=1,
+ offset=0,
+ )
+ )
+ if len(indices["indices"]) != 1:
+ raise R2RException(
+ f"Index '{index_name}' not found", status_code=404
+ )
+ return {"index": indices["indices"][0]} # type: ignore
+
+ # TODO - Implement update index
+ # @self.router.post(
+ # "/indices/{name}",
+ # summary="Update Vector Index",
+ # openapi_extra={
+ # "x-codeSamples": [
+ # {
+ # "lang": "Python",
+ # "source": """
+ # from r2r import R2RClient
+
+ # client = R2RClient()
+
+ # # Update HNSW index parameters
+ # result = client.indices.update(
+ # "550e8400-e29b-41d4-a716-446655440000",
+ # config={
+ # "index_arguments": {
+ # "ef": 80, # Increase search quality
+ # "m": 24 # Increase connections per layer
+ # },
+ # "concurrently": True
+ # },
+ # run_with_orchestration=True
+ # )""",
+ # },
+ # {
+ # "lang": "Shell",
+ # "source": """
+ # curl -X PUT "https://api.example.com/indices/550e8400-e29b-41d4-a716-446655440000" \\
+ # -H "Content-Type: application/json" \\
+ # -H "Authorization: Bearer YOUR_API_KEY" \\
+ # -d '{
+ # "config": {
+ # "index_arguments": {
+ # "ef": 80,
+ # "m": 24
+ # },
+ # "concurrently": true
+ # },
+ # "run_with_orchestration": true
+ # }'""",
+ # },
+ # ]
+ # },
+ # )
+ # @self.base_endpoint
+ # async def update_index(
+ # id: UUID = Path(...),
+ # config: IndexConfig = Body(...),
+ # run_with_orchestration: Optional[bool] = Body(True),
+ # auth_user=Depends(self.providers.auth.auth_wrapper()),
+ # ): # -> WrappedUpdateIndexResponse:
+ # """
+ # Update an existing index's configuration.
+ # """
+ # # TODO: Implement index update logic
+ # pass
+
+ @self.router.delete(
+ "/indices/{table_name}/{index_name}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Delete Vector Index",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+
+ # Delete an index with orchestration for cleanup
+ result = client.indices.delete(
+ index_name="index_1",
+ table_name="vectors",
+ run_with_orchestration=True
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+                                const response = await client.indicies.delete({
+                                    indexName: "index_1",
+                                    tableName: "vectors"
+ });
+
+ console.log(response);
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+                            curl -X DELETE "https://api.example.com/indices/vectors/index_1" \\
+ -H "Content-Type: application/json" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_index(
+ table_name: VectorTableName = Path(
+ default=...,
+ description="The table of vector embeddings to delete (e.g. `vectors`, `entity`, `document_collections`)",
+ ),
+ index_name: str = Path(
+ ..., description="The name of the index to delete"
+ ),
+ # concurrently: bool = Body(
+ # default=True,
+ # description="Whether to delete the index concurrently (recommended for large indices)",
+ # ),
+ # run_with_orchestration: Optional[bool] = Body(True),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Delete an existing vector similarity search index.
+
+ This endpoint removes the specified index from the database. Important considerations:
+
+ - Deletion is permanent and cannot be undone
+ - Underlying vector data remains intact
+ - Queries will fall back to sequential scan
+ - Running queries during deletion may be slower
+ - Use run_with_orchestration=True for large indices to prevent timeouts
+ - Consider index dependencies before deletion
+
+ The operation returns immediately but cleanup may continue in background.
+ """
+ logger.info(
+ f"Deleting vector index {index_name} from table {table_name}"
+ )
+
+ return await self.providers.orchestration.run_workflow( # type: ignore
+ "delete-vector-index",
+ {
+ "request": {
+ "index_name": index_name,
+ "table_name": table_name,
+ "concurrently": True,
+ },
+ },
+ options={
+ "additional_metadata": {},
+ },
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/prompts_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/prompts_router.py
new file mode 100644
index 00000000..55512143
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/prompts_router.py
@@ -0,0 +1,387 @@
+import logging
+import textwrap
+from typing import Optional
+
+from fastapi import Body, Depends, Path, Query
+
+from core.base import R2RException
+from core.base.api.models import (
+ GenericBooleanResponse,
+ GenericMessageResponse,
+ WrappedBooleanResponse,
+ WrappedGenericMessageResponse,
+ WrappedPromptResponse,
+ WrappedPromptsResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+
+class PromptsRouter(BaseRouterV3):
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ logging.info("Initializing PromptsRouter")
+ super().__init__(providers, services, config)
+
+ def _setup_routes(self):
+ @self.router.post(
+ "/prompts",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Create a new prompt",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.prompts.create(
+ name="greeting_prompt",
+ template="Hello, {name}!",
+ input_types={"name": "string"}
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.prompts.create({
+ name: "greeting_prompt",
+ template: "Hello, {name}!",
+ inputTypes: { name: "string" },
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/prompts" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -H "Content-Type: application/json" \\
+ -d '{"name": "greeting_prompt", "template": "Hello, {name}!", "input_types": {"name": "string"}}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def create_prompt(
+ name: str = Body(..., description="The name of the prompt"),
+ template: str = Body(
+ ..., description="The template string for the prompt"
+ ),
+ input_types: dict[str, str] = Body(
+ default={},
+ description="A dictionary mapping input names to their types",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Create a new prompt with the given configuration.
+
+ This endpoint allows superusers to create a new prompt with a
+ specified name, template, and input types.
+ """
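+            # Example (illustrative, mirroring the code samples above): the
+            # template's placeholders should match the keys of input_types,
+            # e.g. template="Hello, {name}!" with input_types={"name": "string"}.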
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can create prompts.",
+ 403,
+ )
+ result = await self.services.management.add_prompt(
+ name, template, input_types
+ )
+ return GenericMessageResponse(message=result) # type: ignore
+
+ @self.router.get(
+ "/prompts",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List all prompts",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.prompts.list()
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.prompts.list();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/v3/prompts" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_prompts(
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedPromptsResponse:
+ """List all available prompts.
+
+ This endpoint retrieves a list of all prompts in the system. Only
+ superusers can access this endpoint.
+ """
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can list prompts.",
+ 403,
+ )
+ get_prompts_response = (
+ await self.services.management.get_all_prompts()
+ )
+
+ return ( # type: ignore
+ get_prompts_response["results"],
+ {
+ "total_entries": get_prompts_response["total_entries"],
+ },
+ )
+
+ @self.router.post(
+ "/prompts/{name}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Get a specific prompt",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.prompts.get(
+ "greeting_prompt",
+ inputs={"name": "John"},
+ prompt_override="Hi, {name}!"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.prompts.retrieve({
+ name: "greeting_prompt",
+ inputs: { name: "John" },
+ promptOverride: "Hi, {name}!",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/prompts/greeting_prompt?inputs=%7B%22name%22%3A%22John%22%7D&prompt_override=Hi%2C%20%7Bname%7D!" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_prompt(
+ name: str = Path(..., description="Prompt name"),
+ inputs: Optional[dict[str, str]] = Body(
+ None, description="Prompt inputs"
+ ),
+ prompt_override: Optional[str] = Query(
+ None, description="Prompt override"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedPromptResponse:
+ """Get a specific prompt by name, optionally with inputs and
+ override.
+
+ This endpoint retrieves a specific prompt and allows for optional
+ inputs and template override. Only superusers can access this
+ endpoint.
+ """
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can retrieve prompts.",
+ 403,
+ )
+ result = await self.services.management.get_prompt(
+ name, inputs, prompt_override
+ )
+ return result # type: ignore
+
+ @self.router.put(
+ "/prompts/{name}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Update an existing prompt",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.prompts.update(
+ "greeting_prompt",
+ template="Greetings, {name}!",
+ input_types={"name": "string", "age": "integer"}
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.prompts.update({
+ name: "greeting_prompt",
+ template: "Greetings, {name}!",
+ inputTypes: { name: "string", age: "integer" },
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X PUT "https://api.example.com/v3/prompts/greeting_prompt" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -H "Content-Type: application/json" \\
+ -d '{"template": "Greetings, {name}!", "input_types": {"name": "string", "age": "integer"}}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def update_prompt(
+ name: str = Path(..., description="Prompt name"),
+ template: Optional[str] = Body(
+ None, description="Updated prompt template"
+ ),
+ input_types: dict[str, str] = Body(
+ default={},
+ description="A dictionary mapping input names to their types",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Update an existing prompt's template and/or input types.
+
+ This endpoint allows superusers to update the template and input
+ types of an existing prompt.
+ """
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can update prompts.",
+ 403,
+ )
+ result = await self.services.management.update_prompt(
+ name, template, input_types
+ )
+ return GenericMessageResponse(message=result) # type: ignore
+
+ @self.router.delete(
+ "/prompts/{name}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Delete a prompt",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.prompts.delete("greeting_prompt")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.prompts.delete({
+ name: "greeting_prompt",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/v3/prompts/greeting_prompt" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_prompt(
+ name: str = Path(..., description="Prompt name"),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Delete a prompt by name.
+
+ This endpoint allows superusers to delete an existing prompt.
+ """
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can delete prompts.",
+ 403,
+ )
+ await self.services.management.delete_prompt(name)
+ return GenericBooleanResponse(success=True) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/retrieval_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/retrieval_router.py
new file mode 100644
index 00000000..28749319
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/retrieval_router.py
@@ -0,0 +1,639 @@
+import logging
+from typing import Any, Literal, Optional
+from uuid import UUID
+
+from fastapi import Body, Depends
+from fastapi.responses import StreamingResponse
+
+from core.base import (
+ GenerationConfig,
+ Message,
+ R2RException,
+ SearchMode,
+ SearchSettings,
+ select_search_filters,
+)
+from core.base.api.models import (
+ WrappedAgentResponse,
+ WrappedCompletionResponse,
+ WrappedEmbeddingResponse,
+ WrappedLLMChatCompletion,
+ WrappedRAGResponse,
+ WrappedSearchResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+from .examples import EXAMPLES
+
+logger = logging.getLogger(__name__)
+
+
+def merge_search_settings(
+ base: SearchSettings, overrides: SearchSettings
+) -> SearchSettings:
+ # Convert both to dict
+ base_dict = base.model_dump()
+ overrides_dict = overrides.model_dump(exclude_unset=True)
+
+ # Update base_dict with values from overrides_dict
+ # This ensures that any field set in overrides takes precedence
+ for k, v in overrides_dict.items():
+ base_dict[k] = v
+
+ # Construct a new SearchSettings from the merged dict
+ return SearchSettings(**base_dict)
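+
+# Illustrative sketch of the merge semantics above (not executed; `limit` is
+# used here only as a generic example of an overridable field):
+#
+#   base = SearchSettings()                      # defaults for the chosen mode
+#   overrides = SearchSettings(limit=5)          # only `limit` explicitly set
+#   merged = merge_search_settings(base, overrides)
+#   # merged keeps every default except `limit`, because only fields that were
+#   # explicitly set on `overrides` (model_dump(exclude_unset=True)) are copied.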
+
+
+class RetrievalRouter(BaseRouterV3):
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ logging.info("Initializing RetrievalRouter")
+ super().__init__(providers, services, config)
+
+ def _register_workflows(self):
+ pass
+
+ def _prepare_search_settings(
+ self,
+ auth_user: Any,
+ search_mode: SearchMode,
+ search_settings: Optional[SearchSettings],
+ ) -> SearchSettings:
+ """Prepare the effective search settings based on the provided
+ search_mode, optional user-overrides in search_settings, and applied
+ filters."""
+ if search_mode != SearchMode.custom:
+ # Start from mode defaults
+ effective_settings = SearchSettings.get_default(search_mode.value)
+ if search_settings:
+ # Merge user-provided overrides
+ effective_settings = merge_search_settings(
+ effective_settings, search_settings
+ )
+ else:
+ # Custom mode: use provided settings or defaults
+ effective_settings = search_settings or SearchSettings()
+
+ # Apply user-specific filters
+ effective_settings.filters = select_search_filters(
+ auth_user, effective_settings
+ )
+ return effective_settings
+
+ def _setup_routes(self):
+ @self.router.post(
+ "/retrieval/search",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Search R2R",
+ openapi_extra=EXAMPLES["search"],
+ )
+ @self.base_endpoint
+ async def search_app(
+ query: str = Body(
+ ...,
+ description="Search query to find relevant documents",
+ ),
+ search_mode: SearchMode = Body(
+ default=SearchMode.custom,
+ description=(
+ "Default value of `custom` allows full control over search settings.\n\n"
+ "Pre-configured search modes:\n"
+ "`basic`: A simple semantic-based search.\n"
+ "`advanced`: A more powerful hybrid search combining semantic and full-text.\n"
+ "`custom`: Full control via `search_settings`.\n\n"
+ "If `filters` or `limit` are provided alongside `basic` or `advanced`, "
+ "they will override the default settings for that mode."
+ ),
+ ),
+ search_settings: Optional[SearchSettings] = Body(
+ None,
+ description=(
+ "The search configuration object. If `search_mode` is `custom`, "
+ "these settings are used as-is. For `basic` or `advanced`, these settings will override the default mode configuration.\n\n"
+ "Common overrides include `filters` to narrow results and `limit` to control how many results are returned."
+ ),
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedSearchResponse:
+ """Perform a search query against vector and/or graph-based
+ databases.
+
+ **Search Modes:**
+ - `basic`: Defaults to semantic search. Simple and easy to use.
+ - `advanced`: Combines semantic search with full-text search for more comprehensive results.
+ - `custom`: Complete control over how search is performed. Provide a full `SearchSettings` object.
+
+ **Filters:**
+ Apply filters directly inside `search_settings.filters`. For example:
+ ```json
+ {
+ "filters": {"document_id": {"$eq": "e43864f5-a36f-548e-aacd-6f8d48b30c7f"}}
+ }
+ ```
+ Supported operators: `$eq`, `$neq`, `$gt`, `$gte`, `$lt`, `$lte`, `$like`, `$ilike`, `$in`, `$nin`.
+
+ **Hybrid Search:**
+ Enable hybrid search by setting `use_hybrid_search: true` in search_settings. This combines semantic search with
+ keyword-based search for improved results. Configure with `hybrid_settings`:
+ ```json
+ {
+ "use_hybrid_search": true,
+ "hybrid_settings": {
+ "full_text_weight": 1.0,
+ "semantic_weight": 5.0,
+ "full_text_limit": 200,
+ "rrf_k": 50
+ }
+ }
+ ```
+
+ **Graph-Enhanced Search:**
+ Knowledge graph integration is enabled by default. Control with `graph_search_settings`:
+ ```json
+ {
+ "graph_search_settings": {
+ "use_graph_search": true,
+ "kg_search_type": "local"
+ }
+ }
+ ```
+
+ **Advanced Filtering:**
+ Use complex filters to narrow down results by metadata fields or document properties:
+ ```json
+ {
+ "filters": {
+ "$and":[
+ {"document_type": {"$eq": "pdf"}},
+ {"metadata.year": {"$gt": 2020}}
+ ]
+ }
+ }
+ ```
+
+ **Results:**
+ The response includes vector search results and optional graph search results.
+ Each result contains the matched text, document ID, and relevance score.
+
+ """
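+            # Illustrative request body combining the options documented above
+            # (assumed values, shown only for documentation):
+            #
+            #   POST /retrieval/search
+            #   {
+            #       "query": "What is retrieval-augmented generation?",
+            #       "search_mode": "custom",
+            #       "search_settings": {
+            #           "use_hybrid_search": true,
+            #           "filters": {"document_type": {"$eq": "pdf"}},
+            #           "limit": 10
+            #       }
+            #   }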
+ if query == "":
+ raise R2RException("Query cannot be empty", 400)
+ effective_settings = self._prepare_search_settings(
+ auth_user, search_mode, search_settings
+ )
+ results = await self.services.retrieval.search(
+ query=query,
+ search_settings=effective_settings,
+ )
+ return results # type: ignore
+
+ @self.router.post(
+ "/retrieval/rag",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="RAG Query",
+ response_model=None,
+ openapi_extra=EXAMPLES["rag"],
+ )
+ @self.base_endpoint
+ async def rag_app(
+ query: str = Body(...),
+ search_mode: SearchMode = Body(
+ default=SearchMode.custom,
+ description=(
+ "Default value of `custom` allows full control over search settings.\n\n"
+ "Pre-configured search modes:\n"
+ "`basic`: A simple semantic-based search.\n"
+ "`advanced`: A more powerful hybrid search combining semantic and full-text.\n"
+ "`custom`: Full control via `search_settings`.\n\n"
+ "If `filters` or `limit` are provided alongside `basic` or `advanced`, "
+ "they will override the default settings for that mode."
+ ),
+ ),
+ search_settings: Optional[SearchSettings] = Body(
+ None,
+ description=(
+ "The search configuration object. If `search_mode` is `custom`, "
+ "these settings are used as-is. For `basic` or `advanced`, these settings will override the default mode configuration.\n\n"
+ "Common overrides include `filters` to narrow results and `limit` to control how many results are returned."
+ ),
+ ),
+ rag_generation_config: GenerationConfig = Body(
+ default_factory=GenerationConfig,
+ description="Configuration for RAG generation",
+ ),
+ task_prompt: Optional[str] = Body(
+ default=None,
+ description="Optional custom prompt to override default",
+ ),
+ include_title_if_available: bool = Body(
+ default=False,
+ description="Include document titles in responses when available",
+ ),
+ include_web_search: bool = Body(
+ default=False,
+ description="Include web search results provided to the LLM.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedRAGResponse:
+ """Execute a RAG (Retrieval-Augmented Generation) query.
+
+ This endpoint combines search results with language model generation to produce accurate,
+ contextually-relevant responses based on your document corpus.
+
+ **Features:**
+ - Combines vector search, optional knowledge graph integration, and LLM generation
+ - Automatically cites sources with unique citation identifiers
+ - Supports both streaming and non-streaming responses
+ - Compatible with various LLM providers (OpenAI, Anthropic, etc.)
+ - Web search integration for up-to-date information
+
+ **Search Configuration:**
+ All search parameters from the search endpoint apply here, including filters, hybrid search, and graph-enhanced search.
+
+ **Generation Configuration:**
+ Fine-tune the language model's behavior with `rag_generation_config`:
+ ```json
+ {
+ "model": "openai/gpt-4o-mini", // Model to use
+ "temperature": 0.7, // Control randomness (0-1)
+ "max_tokens": 1500, // Maximum output length
+ "stream": true // Enable token streaming
+ }
+ ```
+
+ **Model Support:**
+ - OpenAI models (default)
+ - Anthropic Claude models (requires ANTHROPIC_API_KEY)
+ - Local models via Ollama
+ - Any provider supported by LiteLLM
+
+ **Streaming Responses:**
+ When `stream: true` is set, the endpoint returns Server-Sent Events with the following types:
+ - `search_results`: Initial search results from your documents
+ - `message`: Partial tokens as they're generated
+ - `citation`: Citation metadata when sources are referenced
+ - `final_answer`: Complete answer with structured citations
+
+ **Example Response:**
+ ```json
+ {
+ "generated_answer": "DeepSeek-R1 is a model that demonstrates impressive performance...[1]",
+ "search_results": { ... },
+ "citations": [
+ {
+ "id": "cit.123456",
+ "object": "citation",
+ "payload": { ... }
+ }
+ ]
+ }
+ ```
+ """
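+            # Illustrative request body assembled from the parameters above
+            # (assumed values, shown only for documentation):
+            #
+            #   POST /retrieval/rag
+            #   {
+            #       "query": "What does DeepSeek-R1 demonstrate?",
+            #       "search_settings": {"use_hybrid_search": true, "limit": 10},
+            #       "rag_generation_config": {
+            #           "model": "openai/gpt-4o-mini",
+            #           "temperature": 0.7,
+            #           "stream": false
+            #       },
+            #       "include_web_search": false
+            #   }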
+
+ if "model" not in rag_generation_config.__fields_set__:
+ rag_generation_config.model = self.config.app.quality_llm
+
+ effective_settings = self._prepare_search_settings(
+ auth_user, search_mode, search_settings
+ )
+
+ response = await self.services.retrieval.rag(
+ query=query,
+ search_settings=effective_settings,
+ rag_generation_config=rag_generation_config,
+ task_prompt=task_prompt,
+ include_title_if_available=include_title_if_available,
+ include_web_search=include_web_search,
+ )
+
+ if rag_generation_config.stream:
+ # ========== Streaming path ==========
+ async def stream_generator():
+ try:
+ async for chunk in response:
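+                        # Split large SSE chunks into pieces of at most 1024
+                        # characters so they are flushed to the client incrementally.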
+ if len(chunk) > 1024:
+ for i in range(0, len(chunk), 1024):
+ yield chunk[i : i + 1024]
+ else:
+ yield chunk
+ except GeneratorExit:
+ # Clean up if needed, then return
+ return
+
+ return StreamingResponse(
+ stream_generator(), media_type="text/event-stream"
+ ) # type: ignore
+ else:
+ # ========== Non-streaming path ==========
+ return response
+
+ @self.router.post(
+ "/retrieval/agent",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="RAG-powered Conversational Agent",
+ openapi_extra=EXAMPLES["agent"],
+ )
+ @self.base_endpoint
+ async def agent_app(
+ message: Optional[Message] = Body(
+ None,
+ description="Current message to process",
+ ),
+ messages: Optional[list[Message]] = Body(
+ None,
+ deprecated=True,
+ description="List of messages (deprecated, use message instead)",
+ ),
+ search_mode: SearchMode = Body(
+ default=SearchMode.custom,
+ description="Pre-configured search modes: basic, advanced, or custom.",
+ ),
+ search_settings: Optional[SearchSettings] = Body(
+ None,
+ description="The search configuration object for retrieving context.",
+ ),
+ # Generation configurations
+ rag_generation_config: GenerationConfig = Body(
+ default_factory=GenerationConfig,
+ description="Configuration for RAG generation in 'rag' mode",
+ ),
+ research_generation_config: Optional[GenerationConfig] = Body(
+ None,
+ description="Configuration for generation in 'research' mode. If not provided but mode='research', rag_generation_config will be used with appropriate model overrides.",
+ ),
+ # Tool configurations
+ rag_tools: Optional[
+ list[
+ Literal[
+ "web_search",
+ "web_scrape",
+ "search_file_descriptions",
+ "search_file_knowledge",
+ "get_file_content",
+ ]
+ ]
+ ] = Body(
+ None,
+ description="List of tools to enable for RAG mode. Available tools: search_file_knowledge, get_file_content, web_search, web_scrape, search_file_descriptions",
+ ),
+ research_tools: Optional[
+ list[
+ Literal["rag", "reasoning", "critique", "python_executor"]
+ ]
+ ] = Body(
+ None,
+ description="List of tools to enable for Research mode. Available tools: rag, reasoning, critique, python_executor",
+ ),
+ # Backward compatibility
+ tools: Optional[list[str]] = Body(
+ None,
+ deprecated=True,
+ description="List of tools to execute (deprecated, use rag_tools or research_tools instead)",
+ ),
+ # Other parameters
+ task_prompt: Optional[str] = Body(
+ default=None,
+ description="Optional custom prompt to override default",
+ ),
+ # Backward compatibility
+ task_prompt_override: Optional[str] = Body(
+ default=None,
+ deprecated=True,
+ description="Optional custom prompt to override default",
+ ),
+ include_title_if_available: bool = Body(
+ default=True,
+ description="Pass document titles from search results into the LLM context window.",
+ ),
+ conversation_id: Optional[UUID] = Body(
+ default=None,
+ description="ID of the conversation",
+ ),
+ max_tool_context_length: Optional[int] = Body(
+ default=32_768,
+ description="Maximum length of returned tool context",
+ ),
+ use_system_context: Optional[bool] = Body(
+ default=True,
+ description="Use extended prompt for generation",
+ ),
+ mode: Optional[Literal["rag", "research"]] = Body(
+ default="rag",
+ description="Mode to use for generation: 'rag' for standard retrieval or 'research' for deep analysis with reasoning capabilities",
+ ),
+ needs_initial_conversation_name: Optional[bool] = Body(
+ default=None,
+ description="If true, the system will automatically assign a conversation name if not already specified previously.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedAgentResponse:
+ """
+ Engage with an intelligent agent for information retrieval, analysis, and research.
+
+ This endpoint offers two operating modes:
+ - **RAG mode**: Standard retrieval-augmented generation for answering questions based on knowledge base
+ - **Research mode**: Advanced capabilities for deep analysis, reasoning, and computation
+
+ ### RAG Mode (Default)
+
+ The RAG mode provides fast, knowledge-based responses using:
+ - Semantic and hybrid search capabilities
+ - Document-level and chunk-level content retrieval
+ - Optional web search integration
+ - Source citation and evidence-based responses
+
+ ### Research Mode
+
+ The Research mode builds on RAG capabilities and adds:
+ - A dedicated reasoning system for complex problem-solving
+ - Critique capabilities to identify potential biases or logical fallacies
+ - Python execution for computational analysis
+ - Multi-step reasoning for deeper exploration of topics
+
+ ### Available Tools
+
+ **RAG Tools:**
+ - `search_file_knowledge`: Semantic/hybrid search on your ingested documents
+ - `search_file_descriptions`: Search over file-level metadata
+ - `content`: Fetch entire documents or chunk structures
+ - `web_search`: Query external search APIs for up-to-date information
+ - `web_scrape`: Scrape and extract content from specific web pages
+
+ **Research Tools:**
+ - `rag`: Leverage the underlying RAG agent for information retrieval
+ - `reasoning`: Call a dedicated model for complex analytical thinking
+ - `critique`: Analyze conversation history to identify flaws and biases
+ - `python_executor`: Execute Python code for complex calculations and analysis
+
+ ### Streaming Output
+
+ When streaming is enabled, the agent produces different event types:
+ - `thinking`: Shows the model's step-by-step reasoning (when extended_thinking=true)
+ - `tool_call`: Shows when the agent invokes a tool
+ - `tool_result`: Shows the result of a tool call
+ - `citation`: Indicates when a citation is added to the response
+ - `message`: Streams partial tokens of the response
+ - `final_answer`: Contains the complete generated answer and structured citations
+
+ ### Conversations
+
+ Maintain context across multiple turns by including `conversation_id` in each request.
+ After your first call, store the returned `conversation_id` and include it in subsequent calls.
+ If no conversation name has already been set for the conversation, the system will automatically assign one.
+
+ """
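+            # Illustrative research-mode request body based on the parameters
+            # above (assumed values for documentation only):
+            #
+            #   POST /retrieval/agent
+            #   {
+            #       "message": {"role": "user", "content": "Compare HNSW and IVF-Flat trade-offs"},
+            #       "mode": "research",
+            #       "research_tools": ["rag", "reasoning", "python_executor"],
+            #       "rag_generation_config": {"stream": true},
+            #       "conversation_id": "550e8400-e29b-41d4-a716-446655440000"
+            #   }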
+ # Handle backward compatibility for task_prompt
+ task_prompt = task_prompt or task_prompt_override
+ # Handle model selection based on mode
+ if "model" not in rag_generation_config.__fields_set__:
+ if mode == "rag":
+ rag_generation_config.model = self.config.app.quality_llm
+ elif mode == "research":
+ rag_generation_config.model = self.config.app.planning_llm
+
+ # Prepare search settings
+ effective_settings = self._prepare_search_settings(
+ auth_user, search_mode, search_settings
+ )
+
+ # Handle tool configuration and backward compatibility
+ if tools: # Handle deprecated tools parameter
+ logger.warning(
+ "The 'tools' parameter is deprecated. Use 'rag_tools' or 'research_tools' based on mode."
+ )
+ rag_tools = tools # type: ignore
+
+ # Determine effective generation config
+ effective_generation_config = rag_generation_config
+ if mode == "research" and research_generation_config:
+ effective_generation_config = research_generation_config
+
+ try:
+ response = await self.services.retrieval.agent(
+ message=message,
+ messages=messages,
+ search_settings=effective_settings,
+ rag_generation_config=rag_generation_config,
+ research_generation_config=research_generation_config,
+ task_prompt=task_prompt,
+ include_title_if_available=include_title_if_available,
+ max_tool_context_length=max_tool_context_length or 32_768,
+ conversation_id=(
+ str(conversation_id) if conversation_id else None # type: ignore
+ ),
+ use_system_context=use_system_context
+ if use_system_context is not None
+ else True,
+ rag_tools=rag_tools, # type: ignore
+ research_tools=research_tools, # type: ignore
+ mode=mode,
+ needs_initial_conversation_name=needs_initial_conversation_name,
+ )
+
+ if effective_generation_config.stream:
+
+ async def stream_generator():
+ try:
+ async for chunk in response:
+ if len(chunk) > 1024:
+ for i in range(0, len(chunk), 1024):
+ yield chunk[i : i + 1024]
+ else:
+ yield chunk
+ except GeneratorExit:
+ # Clean up if needed, then return
+ return
+
+ return StreamingResponse( # type: ignore
+ stream_generator(), media_type="text/event-stream"
+ )
+ else:
+ return response
+ except Exception as e:
+ logger.error(f"Error in agent_app: {e}")
+ raise R2RException(str(e), 500) from e
+
+ @self.router.post(
+ "/retrieval/completion",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Generate Message Completions",
+ openapi_extra=EXAMPLES["completion"],
+ )
+ @self.base_endpoint
+ async def completion(
+ messages: list[Message] = Body(
+ ...,
+ description="List of messages to generate completion for",
+ example=[
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "role": "user",
+ "content": "What is the capital of France?",
+ },
+ {
+ "role": "assistant",
+ "content": "The capital of France is Paris.",
+ },
+ {"role": "user", "content": "What about Italy?"},
+ ],
+ ),
+ generation_config: GenerationConfig = Body(
+ default_factory=GenerationConfig,
+ description="Configuration for text generation",
+ example={
+ "model": "openai/gpt-4o-mini",
+ "temperature": 0.7,
+ "max_tokens": 150,
+ "stream": False,
+ },
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ response_model=WrappedCompletionResponse,
+ ) -> WrappedLLMChatCompletion:
+ """Generate completions for a list of messages.
+
+ This endpoint uses the language model to generate completions for
+ the provided messages. The generation process can be customized
+ using the generation_config parameter.
+
+ The messages list should contain alternating user and assistant
+ messages, with an optional system message at the start. Each
+ message should have a 'role' and 'content'.
+ """
+
+ return await self.services.retrieval.completion(
+ messages=messages, # type: ignore
+ generation_config=generation_config,
+ )
+
+ @self.router.post(
+ "/retrieval/embedding",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Generate Embeddings",
+ openapi_extra=EXAMPLES["embedding"],
+ )
+ @self.base_endpoint
+ async def embedding(
+ text: str = Body(
+ ...,
+ description="Text to generate embeddings for",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedEmbeddingResponse:
+            """Generate embeddings for the provided text.
+
+            This endpoint uses the embedding model configured for the server
+            to generate embeddings for the provided text.
+            """
+
+ return await self.services.retrieval.embedding(
+ text=text,
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/system_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/system_router.py
new file mode 100644
index 00000000..682be750
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/system_router.py
@@ -0,0 +1,186 @@
+import logging
+import textwrap
+from datetime import datetime, timezone
+
+import psutil
+from fastapi import Depends
+
+from core.base import R2RException
+from core.base.api.models import (
+ GenericMessageResponse,
+ WrappedGenericMessageResponse,
+ WrappedServerStatsResponse,
+ WrappedSettingsResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+
+class SystemRouter(BaseRouterV3):
+ def __init__(
+ self,
+ providers: R2RProviders,
+ services: R2RServices,
+ config: R2RConfig,
+ ):
+ logging.info("Initializing SystemRouter")
+ super().__init__(providers, services, config)
+ self.start_time = datetime.now(timezone.utc)
+
+ def _setup_routes(self):
+ @self.router.get(
+ "/health",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.system.health()
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.system.health();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+                            curl -X GET "https://api.example.com/v3/health" \\
+                                -H "Content-Type: application/json" \\
+                                -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def health_check() -> WrappedGenericMessageResponse:
+ return GenericMessageResponse(message="ok") # type: ignore
+
+ @self.router.get(
+ "/system/settings",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.system.settings()
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.system.settings();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+                            curl -X GET "https://api.example.com/v3/system/settings" \\
+                                -H "Content-Type: application/json" \\
+                                -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def app_settings(
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedSettingsResponse:
+ if not auth_user.is_superuser:
+ raise R2RException(
+ "Only a superuser can call the `system/settings` endpoint.",
+ 403,
+ )
+ return await self.services.management.app_settings()
+
+ @self.router.get(
+ "/system/status",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # when using auth, do client.login(...)
+
+ result = client.system.status()
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                            async function main() {
+ const response = await client.system.status();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+                            curl -X GET "https://api.example.com/v3/system/status" \\
+                                -H "Content-Type: application/json" \\
+                                -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def server_stats(
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedServerStatsResponse:
+ if not auth_user.is_superuser:
+ raise R2RException(
+                "Only a superuser can call the `system/status` endpoint.",
+ 403,
+ )
+ return { # type: ignore
+ "start_time": self.start_time.isoformat(),
+ "uptime_seconds": (
+ datetime.now(timezone.utc) - self.start_time
+ ).total_seconds(),
+ "cpu_usage": psutil.cpu_percent(),
+ "memory_usage": psutil.virtual_memory().percent,
+ }
diff --git a/.venv/lib/python3.12/site-packages/core/main/api/v3/users_router.py b/.venv/lib/python3.12/site-packages/core/main/api/v3/users_router.py
new file mode 100644
index 00000000..686f0013
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/api/v3/users_router.py
@@ -0,0 +1,1721 @@
+import logging
+import os
+import textwrap
+import urllib.parse
+from typing import Optional
+from uuid import UUID
+
+import requests
+from fastapi import Body, Depends, HTTPException, Path, Query
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse
+from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
+from google.auth.transport import requests as google_requests
+from google.oauth2 import id_token
+from pydantic import EmailStr
+
+from core.base import R2RException
+from core.base.api.models import (
+ GenericBooleanResponse,
+ GenericMessageResponse,
+ WrappedAPIKeyResponse,
+ WrappedAPIKeysResponse,
+ WrappedBooleanResponse,
+ WrappedCollectionsResponse,
+ WrappedGenericMessageResponse,
+ WrappedLimitsResponse,
+ WrappedLoginResponse,
+ WrappedTokenResponse,
+ WrappedUserResponse,
+ WrappedUsersResponse,
+)
+
+from ...abstractions import R2RProviders, R2RServices
+from ...config import R2RConfig
+from .base_router import BaseRouterV3
+
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+class UsersRouter(BaseRouterV3):
+ def __init__(
+ self, providers: R2RProviders, services: R2RServices, config: R2RConfig
+ ):
+ logging.info("Initializing UsersRouter")
+ super().__init__(providers, services, config)
+ self.google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
+ self.google_client_secret = os.environ.get("GOOGLE_CLIENT_SECRET")
+ self.google_redirect_uri = os.environ.get("GOOGLE_REDIRECT_URI")
+
+ self.github_client_id = os.environ.get("GITHUB_CLIENT_ID")
+ self.github_client_secret = os.environ.get("GITHUB_CLIENT_SECRET")
+ self.github_redirect_uri = os.environ.get("GITHUB_REDIRECT_URI")
+
+ def _setup_routes(self):
+ @self.router.post(
+ "/users",
+ # dependencies=[Depends(self.rate_limit_dependency)],
+ response_model=WrappedUserResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ new_user = client.users.create(
+ email="jane.doe@example.com",
+ password="secure_password123"
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                        async function main() {
+ const response = await client.users.create({
+ email: "jane.doe@example.com",
+ password: "secure_password123"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users" \\
+ -H "Content-Type: application/json" \\
+ -d '{
+ "email": "jane.doe@example.com",
+ "password": "secure_password123"
+ }'"""),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def register(
+ email: EmailStr = Body(..., description="User's email address"),
+ password: str = Body(..., description="User's password"),
+ name: str | None = Body(
+ None, description="The name for the new user"
+ ),
+ bio: str | None = Body(
+ None, description="The bio for the new user"
+ ),
+ profile_picture: str | None = Body(
+ None, description="Updated user profile picture"
+ ),
+ # auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedUserResponse:
+ """Register a new user with the given email and password."""
+
+ # TODO: Do we really want this validation? The default password for the superuser would not pass...
+ def validate_password(password: str) -> bool:
+ if len(password) < 10:
+ return False
+ if not any(c.isupper() for c in password):
+ return False
+ if not any(c.islower() for c in password):
+ return False
+ if not any(c.isdigit() for c in password):
+ return False
+ if not any(c in "!@#$%^&*" for c in password):
+ return False
+ return True
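+
+            # For illustration, a password such as "Example#Pass1" would satisfy
+            # every rule above (>= 10 chars, upper, lower, digit, and one of
+            # "!@#$%^&*"); note the check is currently commented out below.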
+
+ # if not validate_password(password):
+ # raise R2RException(
+ # f"Password must be at least 10 characters long and contain at least one uppercase letter, one lowercase letter, one digit, and one special character from '!@#$%^&*'.",
+ # 400,
+ # )
+
+ registration_response = await self.services.auth.register(
+ email=email,
+ password=password,
+ name=name,
+ bio=bio,
+ profile_picture=profile_picture,
+ )
+
+ return registration_response # type: ignore
+
+ @self.router.post(
+ "/users/export",
+ summary="Export users to CSV",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient("http://localhost:7272")
+ # when using auth, do client.login(...)
+
+ response = client.users.export(
+ output_path="export.csv",
+ columns=["id", "name", "created_at"],
+ include_header=True,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient("http://localhost:7272");
+
+                        async function main() {
+ await client.users.export({
+ outputPath: "export.csv",
+ columns: ["id", "name", "created_at"],
+ includeHeader: true,
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+                            curl -X POST "http://127.0.0.1:7272/v3/users/export" \\
+                                -H "Authorization: Bearer YOUR_API_KEY" \\
+                                -H "Content-Type: application/json" \\
+                                -H "Accept: text/csv" \\
+                                -d '{ "columns": ["id", "name", "created_at"], "include_header": true }' \\
+ --output export.csv
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def export_users(
+ background_tasks: BackgroundTasks,
+ columns: Optional[list[str]] = Body(
+ None, description="Specific columns to export"
+ ),
+ filters: Optional[dict] = Body(
+ None, description="Filters to apply to the export"
+ ),
+ include_header: Optional[bool] = Body(
+ True, description="Whether to include column headers"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> FileResponse:
+ """Export users as a CSV file."""
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ status_code=403,
+ message="Only a superuser can export data.",
+ )
+
+ (
+ csv_file_path,
+ temp_file,
+ ) = await self.services.management.export_users(
+ columns=columns,
+ filters=filters,
+ include_header=include_header
+ if include_header is not None
+ else True,
+ )
+
+ background_tasks.add_task(temp_file.close)
+
+ return FileResponse(
+ path=csv_file_path,
+ media_type="text/csv",
+ filename="users_export.csv",
+ )
+
+ @self.router.post(
+ "/users/verify-email",
+ # dependencies=[Depends(self.rate_limit_dependency)],
+ response_model=WrappedGenericMessageResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ tokens = client.users.verify_email(
+ email="jane.doe@example.com",
+ verification_code="1lklwal!awdclm"
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                        async function main() {
+                            const response = await client.users.verifyEmail({
+                                email: "jane.doe@example.com",
+ verificationCode: "1lklwal!awdclm"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+                            curl -X POST "https://api.example.com/v3/users/verify-email" \\
+                                -H "Content-Type: application/json" \\
+                                -d '{"email": "jane.doe@example.com", "verification_code": "1lklwal!awdclm"}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def verify_email(
+ email: EmailStr = Body(..., description="User's email address"),
+ verification_code: str = Body(
+ ..., description="Email verification code"
+ ),
+ ) -> WrappedGenericMessageResponse:
+ """Verify a user's email address."""
+ user = (
+ await self.providers.database.users_handler.get_user_by_email(
+ email
+ )
+ )
+ if user and user.is_verified:
+ raise R2RException(
+ status_code=400,
+ message="This email is already verified. Please log in.",
+ )
+
+ result = await self.services.auth.verify_email(
+ email, verification_code
+ )
+ return GenericMessageResponse(message=result["message"]) # type: ignore
+
+ @self.router.post(
+ "/users/send-verification-email",
+ dependencies=[
+ Depends(self.providers.auth.auth_wrapper(public=True))
+ ],
+ response_model=WrappedGenericMessageResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ tokens = client.users.send_verification_email(
+ email="jane.doe@example.com",
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                        async function main() {
+                            const response = await client.users.sendVerificationEmail({
+                                email: "jane.doe@example.com",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+                            curl -X POST "https://api.example.com/v3/users/send-verification-email" \\
+                                -H "Content-Type: application/json" \\
+                                -d '{"email": "jane.doe@example.com"}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def send_verification_email(
+ email: EmailStr = Body(..., description="User's email address"),
+ ) -> WrappedGenericMessageResponse:
+ """Send a user's email a verification code."""
+ user = (
+ await self.providers.database.users_handler.get_user_by_email(
+ email
+ )
+ )
+ if user and user.is_verified:
+ raise R2RException(
+ status_code=400,
+ message="This email is already verified. Please log in.",
+ )
+
+ await self.services.auth.send_verification_email(email=email)
+ return GenericMessageResponse(
+ message="A verification email has been sent."
+ ) # type: ignore
+
+ @self.router.post(
+ "/users/login",
+ # dependencies=[Depends(self.rate_limit_dependency)],
+ response_model=WrappedTokenResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ tokens = client.users.login(
+ email="jane.doe@example.com",
+ password="secure_password123"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                        async function main() {
+                            const response = await client.users.login({
+                                email: "jane.doe@example.com",
+ password: "secure_password123"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users/login" \\
+ -H "Content-Type: application/x-www-form-urlencoded" \\
+ -d "username=jane.doe@example.com&password=secure_password123"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def login(
+ form_data: OAuth2PasswordRequestForm = Depends(),
+ ) -> WrappedLoginResponse:
+ """Authenticate a user and provide access tokens."""
+ return await self.services.auth.login( # type: ignore
+ form_data.username, form_data.password
+ )
+
+ @self.router.post(
+ "/users/logout",
+ response_model=WrappedGenericMessageResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+ result = client.users.logout()
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                        async function main() {
+ const response = await client.users.logout();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users/logout" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def logout(
+ token: str = Depends(oauth2_scheme),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Log out the current user."""
+ result = await self.services.auth.logout(token)
+ return GenericMessageResponse(message=result["message"]) # type: ignore
+
+ @self.router.post(
+ "/users/refresh-token",
+ # dependencies=[Depends(self.rate_limit_dependency)],
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ new_tokens = client.users.refresh_token()
+ # New tokens are automatically stored in the client"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                        async function main() {
+ const response = await client.users.refreshAccessToken();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users/refresh-token" \\
+ -H "Content-Type: application/json" \\
+ -d '{
+ "refresh_token": "YOUR_REFRESH_TOKEN"
+ }'"""),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def refresh_token(
+ refresh_token: str = Body(..., description="Refresh token"),
+ ) -> WrappedTokenResponse:
+ """Refresh the access token using a refresh token."""
+ result = await self.services.auth.refresh_access_token(
+ refresh_token=refresh_token
+ )
+ return result # type: ignore
+
+ @self.router.post(
+ "/users/change-password",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ response_model=WrappedGenericMessageResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ result = client.users.change_password(
+ current_password="old_password123",
+ new_password="new_secure_password456"
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+                        async function main() {
+ const response = await client.users.changePassword({
+ currentPassword: "old_password123",
+ newPassword: "new_secure_password456"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users/change-password" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -H "Content-Type: application/json" \\
+ -d '{
+ "current_password": "old_password123",
+ "new_password": "new_secure_password456"
+ }'"""),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def change_password(
+ current_password: str = Body(..., description="Current password"),
+ new_password: str = Body(..., description="New password"),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedGenericMessageResponse:
+ """Change the authenticated user's password."""
+ result = await self.services.auth.change_password(
+ auth_user, current_password, new_password
+ )
+ return GenericMessageResponse(message=result["message"]) # type: ignore
+
+ @self.router.post(
+ "/users/request-password-reset",
+ dependencies=[
+ Depends(self.providers.auth.auth_wrapper(public=True))
+ ],
+ response_model=WrappedGenericMessageResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ result = client.users.request_password_reset(
+ email="jane.doe@example.com"
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.users.requestPasswordReset({
+ email: "jane.doe@example.com",
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users/request-password-reset" \\
+ -H "Content-Type: application/json" \\
+ -d '{
+ "email": "jane.doe@example.com"
+ }'"""),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def request_password_reset(
+ email: EmailStr = Body(..., description="User's email address"),
+ ) -> WrappedGenericMessageResponse:
+ """Request a password reset for a user."""
+ result = await self.services.auth.request_password_reset(email)
+ return GenericMessageResponse(message=result["message"]) # type: ignore
+
+ @self.router.post(
+ "/users/reset-password",
+ dependencies=[
+ Depends(self.providers.auth.auth_wrapper(public=True))
+ ],
+ response_model=WrappedGenericMessageResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ result = client.users.reset_password(
+ reset_token="reset_token_received_via_email",
+ new_password="new_secure_password789"
+ )"""),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.users.resetPassword({
+ resetToken: "reset_token_received_via_email",
+ newPassword: "new_secure_password789"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/v3/users/reset-password" \\
+ -H "Content-Type: application/json" \\
+ -d '{
+ "reset_token": "reset_token_received_via_email",
+ "new_password": "new_secure_password789"
+ }'"""),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def reset_password(
+ reset_token: str = Body(..., description="Password reset token"),
+ new_password: str = Body(..., description="New password"),
+ ) -> WrappedGenericMessageResponse:
+ """Reset a user's password using a reset token."""
+ result = await self.services.auth.confirm_password_reset(
+ reset_token, new_password
+ )
+ return GenericMessageResponse(message=result["message"]) # type: ignore
+
+ @self.router.get(
+ "/users",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List Users",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ # List users with filters
+ users = client.users.list(
+ offset=0,
+ limit=100,
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.users.list();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/users?offset=0&limit=100&username=john&email=john@example.com&is_active=true&is_superuser=false" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def list_users(
+ ids: list[str] = Query(
+ [], description="List of user IDs to filter by"
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedUsersResponse:
+ """List all users with pagination and filtering options.
+
+ Only accessible by superusers.
+ """
+
+ if not auth_user.is_superuser:
+ raise R2RException(
+ status_code=403,
+ message="Only a superuser can call the `users_overview` endpoint.",
+ )
+
+ user_uuids = [UUID(user_id) for user_id in ids]
+
+ users_overview_response = (
+ await self.services.management.users_overview(
+ user_ids=user_uuids, offset=offset, limit=limit
+ )
+ )
+ return users_overview_response["results"], { # type: ignore
+ "total_entries": users_overview_response["total_entries"]
+ }
+
+ @self.router.get(
+ "/users/me",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Get the Current User",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ # Get user details
+ user = client.users.me()
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.users.retrieve();
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/users/me" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_current_user(
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedUserResponse:
+ """Get detailed information about the currently authenticated
+ user."""
+ return auth_user
+
+ @self.router.get(
+ "/users/{id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Get User Details",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ # Get user details
+ user = client.users.retrieve(
+ id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.users.retrieve({
+ id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_user(
+ id: UUID = Path(
+ ..., example="550e8400-e29b-41d4-a716-446655440000"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedUserResponse:
+ """Get detailed information about a specific user.
+
+ Users can only access their own information unless they are
+ superusers.
+ """
+ if not auth_user.is_superuser and auth_user.id != id:
+ raise R2RException(
+ "Only a superuser can call the get `user` endpoint for other users.",
+ 403,
+ )
+
+ users_overview_response = (
+ await self.services.management.users_overview(
+ offset=0,
+ limit=1,
+ user_ids=[id],
+ )
+ )
+
+ return users_overview_response["results"][0]
+
+ @self.router.delete(
+ "/users/{id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Delete User",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ # Delete user
+ client.users.delete(id="550e8400-e29b-41d4-a716-446655440000", password="secure_password123")
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.users.delete({
+ id: "550e8400-e29b-41d4-a716-446655440000",
+ password: "secure_password123"
+ });
+ }
+
+ main();
+ """),
+ },
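+ # A minimal Shell sketch, added for parity with the other endpoints; the
+ # request body shape is inferred from this handler's Body(...) parameters.
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/v3/users/550e8400-e29b-41d4-a716-446655440000" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -H "Content-Type: application/json" \\
+ -d '{
+ "password": "secure_password123"
+ }'
+ """),
+ },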
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_user(
+ id: UUID = Path(
+ ..., example="550e8400-e29b-41d4-a716-446655440000"
+ ),
+ password: Optional[str] = Body(
+ None, description="User's current password"
+ ),
+ delete_vector_data: Optional[bool] = Body(
+ False,
+ description="Whether to delete the user's vector data",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Delete a specific user.
+
+ Users can only delete their own account unless they are superusers.
+ """
+ if not auth_user.is_superuser and auth_user.id != id:
+ raise R2RException(
+ "Only a superuser can delete other users.",
+ 403,
+ )
+
+ await self.services.auth.delete_user(
+ user_id=id,
+ password=password,
+ delete_vector_data=delete_vector_data or False,
+ is_superuser=auth_user.is_superuser,
+ )
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.get(
+ "/users/{id}/collections",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Get User Collections",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ # Get user collections
+ collections = client.users.list_collections(
+ "550e8400-e29b-41d4-a716-446655440000",
+ offset=0,
+ limit=100
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.users.listCollections({
+ id: "550e8400-e29b-41d4-a716-446655440000",
+ offset: 0,
+ limit: 100
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/collections?offset=0&limit=100" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_user_collections(
+ id: UUID = Path(
+ ..., example="550e8400-e29b-41d4-a716-446655440000"
+ ),
+ offset: int = Query(
+ 0,
+ ge=0,
+ description="Specifies the number of objects to skip. Defaults to 0.",
+ ),
+ limit: int = Query(
+ 100,
+ ge=1,
+ le=1000,
+ description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedCollectionsResponse:
+ """Get all collections associated with a specific user.
+
+ Users can only access their own collections unless they are
+ superusers.
+ """
+ if auth_user.id != id and not auth_user.is_superuser:
+ raise R2RException(
+ "The currently authenticated user does not have access to the specified collection.",
+ 403,
+ )
+ user_collection_response = (
+ await self.services.management.collections_overview(
+ offset=offset,
+ limit=limit,
+ user_ids=[id],
+ )
+ )
+ return user_collection_response["results"], { # type: ignore
+ "total_entries": user_collection_response["total_entries"]
+ }
+
+ @self.router.post(
+ "/users/{id}/collections/{collection_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Add User to Collection",
+ response_model=WrappedBooleanResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ # Add user to collection
+ client.users.add_to_collection(
+ id="550e8400-e29b-41d4-a716-446655440000",
+ collection_id="750e8400-e29b-41d4-a716-446655440000"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.users.addToCollection({
+ id: "550e8400-e29b-41d4-a716-446655440000",
+ collectionId: "750e8400-e29b-41d4-a716-446655440000"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/collections/750e8400-e29b-41d4-a716-446655440000" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def add_user_to_collection(
+ id: UUID = Path(
+ ..., example="550e8400-e29b-41d4-a716-446655440000"
+ ),
+ collection_id: UUID = Path(
+ ..., example="750e8400-e29b-41d4-a716-446655440000"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ if auth_user.id != id and not auth_user.is_superuser:
+ raise R2RException(
+ "The currently authenticated user does not have access to the specified collection.",
+ 403,
+ )
+
+ # TODO - Do we need a check on user access to the collection?
+ await self.services.management.add_user_to_collection( # type: ignore
+ id, collection_id
+ )
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.delete(
+ "/users/{id}/collections/{collection_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Remove User from Collection",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ # Remove user from collection
+ client.users.remove_from_collection(
+ id="550e8400-e29b-41d4-a716-446655440000",
+ collection_id="750e8400-e29b-41d4-a716-446655440000"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.users.removeFromCollection({
+ id: "550e8400-e29b-41d4-a716-446655440000",
+ collectionId: "750e8400-e29b-41d4-a716-446655440000"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/collections/750e8400-e29b-41d4-a716-446655440000" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def remove_user_from_collection(
+ id: UUID = Path(
+ ..., example="550e8400-e29b-41d4-a716-446655440000"
+ ),
+ collection_id: UUID = Path(
+ ..., example="750e8400-e29b-41d4-a716-446655440000"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Remove a user from a collection.
+
+ Requires either superuser status or access to the collection.
+ """
+ if auth_user.id != id and not auth_user.is_superuser:
+ raise R2RException(
+ "The currently authenticated user does not have access to the specified collection.",
+ 403,
+ )
+
+ # TODO - Do we need a check on user access to the collection?
+ await self.services.management.remove_user_from_collection( # type: ignore
+ id, collection_id
+ )
+ return GenericBooleanResponse(success=True) # type: ignore
+
+ @self.router.post(
+ "/users/{id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Update User",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ # Update user
+ updated_user = client.users.update(
+ "550e8400-e29b-41d4-a716-446655440000",
+ name="John Doe"
+ )
+ """),
+ },
+ {
+ "lang": "JavaScript",
+ "source": textwrap.dedent("""
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+
+ async function main() {
+ const response = await client.users.update({
+ id: "550e8400-e29b-41d4-a716-446655440000",
+ name: "John Doe"
+ });
+ }
+
+ main();
+ """),
+ },
+ {
+ "lang": "Shell",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000" \\
+ -H "Authorization: Bearer YOUR_API_KEY" \\
+ -H "Content-Type: application/json" \\
+ -d '{
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "name": "John Doe",
+ }'
+ """),
+ },
+ ]
+ },
+ )
+ # TODO - Modify update user to have synced params with user object
+ @self.base_endpoint
+ async def update_user(
+ id: UUID = Path(..., description="ID of the user to update"),
+ email: EmailStr | None = Body(
+ None, description="Updated email address"
+ ),
+ is_superuser: bool | None = Body(
+ None, description="Updated superuser status"
+ ),
+ name: str | None = Body(None, description="Updated user name"),
+ bio: str | None = Body(None, description="Updated user bio"),
+ profile_picture: str | None = Body(
+ None, description="Updated profile picture URL"
+ ),
+ limits_overrides: dict | None = Body(
+ None,
+ description="Updated limits overrides",
+ ),
+ metadata: dict[str, str | None] | None = None,
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedUserResponse:
+ """Update user information.
+
+ Users can only update their own information unless they are
+ superusers. Superuser status can only be modified by existing
+ superusers.
+ """
+
+ if is_superuser is not None and not auth_user.is_superuser:
+ raise R2RException(
+ "Only superusers can update the superuser status of a user",
+ 403,
+ )
+
+ if not auth_user.is_superuser and auth_user.id != id:
+ raise R2RException(
+ "Only superusers can update other users' information",
+ 403,
+ )
+
+ if not auth_user.is_superuser and limits_overrides is not None:
+ raise R2RException(
+ "Only superusers can update other users' limits overrides",
+ 403,
+ )
+
+ # Pass `metadata` to our auth or management service so it can do a
+ # partial (Stripe-like) merge of metadata.
+ return await self.services.auth.update_user( # type: ignore
+ user_id=id,
+ email=email,
+ is_superuser=is_superuser,
+ name=name,
+ bio=bio,
+ profile_picture=profile_picture,
+ limits_overrides=limits_overrides,
+ new_metadata=metadata,
+ )
+
+ @self.router.post(
+ "/users/{id}/api-keys",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Create User API Key",
+ response_model=WrappedAPIKeyResponse,
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ result = client.users.create_api_key(
+ id="550e8400-e29b-41d4-a716-446655440000",
+ name="My API Key",
+ description="API key for accessing the app",
+ )
+ # result["api_key"] contains the newly created API key
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys" \\
+ -H "Authorization: Bearer YOUR_API_TOKEN" \\
+ -d '{"name": "My API Key", "description": "API key for accessing the app"}'
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def create_user_api_key(
+ id: UUID = Path(
+ ..., description="ID of the user for whom to create an API key"
+ ),
+ name: Optional[str] = Body(
+ None, description="Name of the API key"
+ ),
+ description: Optional[str] = Body(
+ None, description="Description of the API key"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedAPIKeyResponse:
+ """Create a new API key for the specified user.
+
+ Only superusers or the user themselves may create an API key.
+ """
+ if auth_user.id != id and not auth_user.is_superuser:
+ raise R2RException(
+ "Only the user themselves or a superuser can create API keys for this user.",
+ 403,
+ )
+
+ api_key = await self.services.auth.create_user_api_key(
+ id, name=name, description=description
+ )
+ return api_key # type: ignore
+
+ @self.router.get(
+ "/users/{id}/api-keys",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="List User API Keys",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ keys = client.users.list_api_keys(
+ id="550e8400-e29b-41d4-a716-446655440000"
+ )
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys" \\
+ -H "Authorization: Bearer YOUR_API_TOKEN"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def list_user_api_keys(
+ id: UUID = Path(
+ ..., description="ID of the user whose API keys to list"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedAPIKeysResponse:
+ """List all API keys for the specified user.
+
+ Only superusers or the user themselves may list the API keys.
+ """
+ if auth_user.id != id and not auth_user.is_superuser:
+ raise R2RException(
+ "Only the user themselves or a superuser can list API keys for this user.",
+ 403,
+ )
+
+ keys = (
+ await self.providers.database.users_handler.get_user_api_keys(
+ id
+ )
+ )
+ return keys, {"total_entries": len(keys)} # type: ignore
+
+ @self.router.delete(
+ "/users/{id}/api-keys/{key_id}",
+ dependencies=[Depends(self.rate_limit_dependency)],
+ summary="Delete User API Key",
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": textwrap.dedent("""
+ from r2r import R2RClient
+ from uuid import UUID
+
+ client = R2RClient()
+ # client.login(...)
+
+ response = client.users.delete_api_key(
+ id="550e8400-e29b-41d4-a716-446655440000",
+ key_id="d9c562d4-3aef-43e8-8f08-0cf7cd5e0a25"
+ )
+ """),
+ },
+ {
+ "lang": "cURL",
+ "source": textwrap.dedent("""
+ curl -X DELETE "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys/d9c562d4-3aef-43e8-8f08-0cf7cd5e0a25" \\
+ -H "Authorization: Bearer YOUR_API_TOKEN"
+ """),
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def delete_user_api_key(
+ id: UUID = Path(..., description="ID of the user"),
+ key_id: UUID = Path(
+ ..., description="ID of the API key to delete"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedBooleanResponse:
+ """Delete a specific API key for the specified user.
+
+ Only superusers or the user themselves may delete the API key.
+ """
+ if auth_user.id != id and not auth_user.is_superuser:
+ raise R2RException(
+ "Only the user themselves or a superuser can delete this API key.",
+ 403,
+ )
+
+ success = (
+ await self.providers.database.users_handler.delete_api_key(
+ id, key_id
+ )
+ )
+ if not success:
+ raise R2RException(
+ "API key not found or could not be deleted", 400
+ )
+ return {"success": True} # type: ignore
+
+ @self.router.get(
+ "/users/{id}/limits",
+ summary="Fetch User Limits",
+ responses={
+ 200: {
+ "description": "Returns system default limits, user overrides, and final effective settings."
+ },
+ 403: {
+ "description": "If the requesting user is neither the same user nor a superuser."
+ },
+ 404: {"description": "If the user ID does not exist."},
+ },
+ openapi_extra={
+ "x-codeSamples": [
+ {
+ "lang": "Python",
+ "source": """
+ from r2r import R2RClient
+
+ client = R2RClient()
+ # client.login(...)
+
+ user_limits = client.users.get_limits("550e8400-e29b-41d4-a716-446655440000")
+ """,
+ },
+ {
+ "lang": "JavaScript",
+ "source": """
+ const { r2rClient } = require("r2r-js");
+
+ const client = new r2rClient();
+ // await client.users.login(...)
+
+ async function main() {
+ const userLimits = await client.users.getLimits({
+ id: "550e8400-e29b-41d4-a716-446655440000"
+ });
+ console.log(userLimits);
+ }
+
+ main();
+ """,
+ },
+ {
+ "lang": "cURL",
+ "source": """
+ curl -X GET "https://api.example.com/v3/users/550e8400-e29b-41d4-a716-446655440000/limits" \\
+ -H "Authorization: Bearer YOUR_API_KEY"
+ """,
+ },
+ ]
+ },
+ )
+ @self.base_endpoint
+ async def get_user_limits(
+ id: UUID = Path(
+ ..., description="ID of the user to fetch limits for"
+ ),
+ auth_user=Depends(self.providers.auth.auth_wrapper()),
+ ) -> WrappedLimitsResponse:
+ """Return the system default limits, user-level overrides, and
+ final "effective" limit settings for the specified user.
+
+ Only superusers or the user themself may fetch these values.
+ """
+ if (auth_user.id != id) and (not auth_user.is_superuser):
+ raise R2RException(
+ "Only the user themselves or a superuser can view these limits.",
+ status_code=403,
+ )
+
+ # Delegate to the ManagementService helper that aggregates limit information
+ limits_info = await self.services.management.get_all_user_limits(
+ id
+ )
+ return limits_info # type: ignore
+
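+ # --- OAuth 2.0 login endpoints (Google and GitHub). The authorize routes
+ # return the provider's consent URL; the callback routes exchange the code
+ # for tokens and delegate to the auth provider's oauth_callback_handler. ---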
+ @self.router.get("/users/oauth/google/authorize")
+ @self.base_endpoint
+ async def google_authorize() -> WrappedGenericMessageResponse:
+ """Redirect user to Google's OAuth 2.0 consent screen."""
+ state = "some_random_string_or_csrf_token" # Usually you store a random state in session/Redis
+ scope = "openid email profile"
+
+ # Build the Google OAuth URL
+ params = {
+ "client_id": self.google_client_id,
+ "redirect_uri": self.google_redirect_uri,
+ "response_type": "code",
+ "scope": scope,
+ "state": state,
+ "access_type": "offline", # to get refresh token if needed
+ "prompt": "consent", # Force consent each time if you want
+ }
+ google_auth_url = f"https://accounts.google.com/o/oauth2/v2/auth?{urllib.parse.urlencode(params)}"
+ return GenericMessageResponse(message=google_auth_url) # type: ignore
+
+ @self.router.get("/users/oauth/google/callback")
+ @self.base_endpoint
+ async def google_callback(
+ code: str = Query(...), state: str = Query(...)
+ ) -> WrappedLoginResponse:
+ """Google's callback that will receive the `code` and `state`.
+
+ We then exchange code for tokens, verify, and log the user in.
+ """
+ # 1. Exchange `code` for tokens
+ token_data = requests.post(
+ "https://oauth2.googleapis.com/token",
+ data={
+ "code": code,
+ "client_id": self.google_client_id,
+ "client_secret": self.google_client_secret,
+ "redirect_uri": self.google_redirect_uri,
+ "grant_type": "authorization_code",
+ },
+ ).json()
+ if "error" in token_data:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to get token: {token_data}",
+ )
+
+ # 2. Verify the ID token
+ id_token_str = token_data["id_token"]
+ try:
+ # google_requests.Request() provides the HTTP transport used to verify the ID token
+ id_info = id_token.verify_oauth2_token(
+ id_token_str,
+ google_requests.Request(),
+ self.google_client_id,
+ )
+ except ValueError as e:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Token verification failed: {str(e)}",
+ ) from e
+
+ # id_info will contain "sub", "email", etc.
+ google_id = id_info["sub"]
+ email = id_info.get("email")
+ email = email or f"{google_id}@google_oauth.fake"
+
+ # 3. Now call our R2RAuthProvider method that handles "oauth-based" user creation or login
+ return await self.providers.auth.oauth_callback_handler( # type: ignore
+ provider="google",
+ oauth_id=google_id,
+ email=email,
+ )
+
+ @self.router.get("/users/oauth/github/authorize")
+ @self.base_endpoint
+ async def github_authorize() -> WrappedGenericMessageResponse:
+ """Redirect user to GitHub's OAuth consent screen."""
+ state = "some_random_string_or_csrf_token"
+ scope = "read:user user:email"
+
+ params = {
+ "client_id": self.github_client_id,
+ "redirect_uri": self.github_redirect_uri,
+ "scope": scope,
+ "state": state,
+ }
+ github_auth_url = f"https://github.com/login/oauth/authorize?{urllib.parse.urlencode(params)}"
+ return GenericMessageResponse(message=github_auth_url) # type: ignore
+
+ @self.router.get("/users/oauth/github/callback")
+ @self.base_endpoint
+ async def github_callback(
+ code: str = Query(...), state: str = Query(...)
+ ) -> WrappedLoginResponse:
+ """GitHub callback route to exchange code for an access_token, then
+ fetch user info from GitHub's API, then do the same 'oauth-based'
+ login or registration."""
+ # 1. Exchange code for access_token
+ token_resp = requests.post(
+ "https://github.com/login/oauth/access_token",
+ data={
+ "client_id": self.github_client_id,
+ "client_secret": self.github_client_secret,
+ "code": code,
+ "redirect_uri": self.github_redirect_uri,
+ "state": state,
+ },
+ headers={"Accept": "application/json"},
+ )
+ token_data = token_resp.json()
+ if "error" in token_data:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Failed to get token: {token_data}",
+ )
+ access_token = token_data["access_token"]
+
+ # 2. Use the access_token to fetch user info
+ user_info_resp = requests.get(
+ "https://api.github.com/user",
+ headers={"Authorization": f"Bearer {access_token}"},
+ ).json()
+
+ github_id = str(
+ user_info_resp["id"]
+ ) # GitHub user ID is typically an integer
+ # Fetch the email (the /user/emails endpoint may be needed if the user keeps their email private)
+ email = user_info_resp.get("email")
+ email = email or f"{github_id}@github_oauth.fake"
+ # 3. Pass to your auth provider
+ return await self.providers.auth.oauth_callback_handler( # type: ignore
+ provider="github",
+ oauth_id=github_id,
+ email=email,
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/main/app.py b/.venv/lib/python3.12/site-packages/core/main/app.py
new file mode 100644
index 00000000..ceb13cce
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/app.py
@@ -0,0 +1,121 @@
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.openapi.utils import get_openapi
+from fastapi.responses import JSONResponse
+
+from core.base import R2RException
+from core.providers import (
+ HatchetOrchestrationProvider,
+ SimpleOrchestrationProvider,
+)
+from core.utils.sentry import init_sentry
+
+from .abstractions import R2RServices
+from .api.v3.chunks_router import ChunksRouter
+from .api.v3.collections_router import CollectionsRouter
+from .api.v3.conversations_router import ConversationsRouter
+from .api.v3.documents_router import DocumentsRouter
+from .api.v3.graph_router import GraphRouter
+from .api.v3.indices_router import IndicesRouter
+from .api.v3.prompts_router import PromptsRouter
+from .api.v3.retrieval_router import RetrievalRouter
+from .api.v3.system_router import SystemRouter
+from .api.v3.users_router import UsersRouter
+from .config import R2RConfig
+
+
+class R2RApp:
+ def __init__(
+ self,
+ config: R2RConfig,
+ orchestration_provider: (
+ HatchetOrchestrationProvider | SimpleOrchestrationProvider
+ ),
+ services: R2RServices,
+ chunks_router: ChunksRouter,
+ collections_router: CollectionsRouter,
+ conversations_router: ConversationsRouter,
+ documents_router: DocumentsRouter,
+ graph_router: GraphRouter,
+ indices_router: IndicesRouter,
+ prompts_router: PromptsRouter,
+ retrieval_router: RetrievalRouter,
+ system_router: SystemRouter,
+ users_router: UsersRouter,
+ ):
+ init_sentry()
+
+ self.config = config
+ self.services = services
+ self.chunks_router = chunks_router
+ self.collections_router = collections_router
+ self.conversations_router = conversations_router
+ self.documents_router = documents_router
+ self.graph_router = graph_router
+ self.indices_router = indices_router
+ self.orchestration_provider = orchestration_provider
+ self.prompts_router = prompts_router
+ self.retrieval_router = retrieval_router
+ self.system_router = system_router
+ self.users_router = users_router
+
+ self.app = FastAPI()
+
+ @self.app.exception_handler(R2RException)
+ async def r2r_exception_handler(request: Request, exc: R2RException):
+ return JSONResponse(
+ status_code=exc.status_code,
+ content={
+ "message": exc.message,
+ "error_type": type(exc).__name__,
+ },
+ )
+
+ self._setup_routes()
+ self._apply_cors()
+
+ def _setup_routes(self):
+ self.app.include_router(self.chunks_router, prefix="/v3")
+ self.app.include_router(self.collections_router, prefix="/v3")
+ self.app.include_router(self.conversations_router, prefix="/v3")
+ self.app.include_router(self.documents_router, prefix="/v3")
+ self.app.include_router(self.graph_router, prefix="/v3")
+ self.app.include_router(self.indices_router, prefix="/v3")
+ self.app.include_router(self.prompts_router, prefix="/v3")
+ self.app.include_router(self.retrieval_router, prefix="/v3")
+ self.app.include_router(self.system_router, prefix="/v3")
+ self.app.include_router(self.users_router, prefix="/v3")
+
+ @self.app.get("/openapi_spec", include_in_schema=False)
+ async def openapi_spec():
+ return get_openapi(
+ title="R2R Application API",
+ version="1.0.0",
+ routes=self.app.routes,
+ )
+
+ def _apply_cors(self):
+ origins = ["*", "http://localhost:3000", "http://localhost:7272"]
+ self.app.add_middleware(
+ CORSMiddleware,
+ allow_origins=origins,
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+
+ async def serve(self, host: str = "0.0.0.0", port: int = 7272):
+ import uvicorn
+
+ from core.utils.logging_config import configure_logging
+
+ configure_logging()
+
+ config = uvicorn.Config(
+ self.app,
+ host=host,
+ port=port,
+ log_config=None,
+ )
+ server = uvicorn.Server(config)
+ await server.serve()
diff --git a/.venv/lib/python3.12/site-packages/core/main/app_entry.py b/.venv/lib/python3.12/site-packages/core/main/app_entry.py
new file mode 100644
index 00000000..cd3ea84d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/app_entry.py
@@ -0,0 +1,125 @@
+import logging
+import os
+from contextlib import asynccontextmanager
+from typing import Optional
+
+from apscheduler.schedulers.asyncio import AsyncIOScheduler
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+
+from core.base import R2RException
+from core.utils.logging_config import configure_logging
+
+from .assembly import R2RBuilder, R2RConfig
+
+log_file = configure_logging()
+
+# Global scheduler
+scheduler = AsyncIOScheduler()
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ # Startup
+ r2r_app = await create_r2r_app(
+ config_name=config_name,
+ config_path=config_path,
+ )
+
+ # Copy all routes from r2r_app to app
+ app.router.routes = r2r_app.app.routes
+
+ # Copy middleware and exception handlers
+ app.middleware = r2r_app.app.middleware # type: ignore
+ app.exception_handlers = r2r_app.app.exception_handlers
+
+ # Start the scheduler
+ scheduler.start()
+
+ # Start the Hatchet worker
+ await r2r_app.orchestration_provider.start_worker()
+
+ yield
+
+ # Shutdown
+ scheduler.shutdown()
+
+
+async def create_r2r_app(
+ config_name: Optional[str] = "default",
+ config_path: Optional[str] = None,
+):
+ config = R2RConfig.load(config_name=config_name, config_path=config_path)
+
+ if (
+ config.embedding.provider == "openai"
+ and "OPENAI_API_KEY" not in os.environ
+ ):
+ raise ValueError(
+ "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+ )
+
+ # Build the R2RApp
+ builder = R2RBuilder(config=config)
+ return await builder.build()
+
+
+config_name = os.getenv("R2R_CONFIG_NAME", None)
+config_path = os.getenv("R2R_CONFIG_PATH", None)
+
+if not config_path and not config_name:
+ config_name = "default"
+host = os.getenv("R2R_HOST", os.getenv("HOST", "0.0.0.0"))
+port = int(os.getenv("R2R_PORT", "7272"))
+
+logging.info(
+ f"Environment R2R_IMAGE: {os.getenv('R2R_IMAGE')}",
+)
+logging.info(
+ f"Environment R2R_CONFIG_NAME: {'None' if config_name is None else config_name}"
+)
+logging.info(
+ f"Environment R2R_CONFIG_PATH: {'None' if config_path is None else config_path}"
+)
+logging.info(f"Environment R2R_PROJECT_NAME: {os.getenv('R2R_PROJECT_NAME')}")
+
+logging.info(
+ f"Environment R2R_POSTGRES_HOST: {os.getenv('R2R_POSTGRES_HOST')}"
+)
+logging.info(
+ f"Environment R2R_POSTGRES_DBNAME: {os.getenv('R2R_POSTGRES_DBNAME')}"
+)
+logging.info(
+ f"Environment R2R_POSTGRES_PORT: {os.getenv('R2R_POSTGRES_PORT')}"
+)
+logging.info(
+ f"Environment R2R_POSTGRES_PASSWORD: {os.getenv('R2R_POSTGRES_PASSWORD')}"
+)
+
+# Create the FastAPI app
+app = FastAPI(
+ lifespan=lifespan,
+ log_config=None,
+)
+
+
+@app.exception_handler(R2RException)
+async def r2r_exception_handler(request: Request, exc: R2RException):
+ return JSONResponse(
+ status_code=exc.status_code,
+ content={
+ "message": exc.message,
+ "error_type": type(exc).__name__,
+ },
+ )
+
+
+# Add CORS middleware
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
diff --git a/.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py b/.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py
new file mode 100644
index 00000000..3d10f2b6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/assembly/__init__.py
@@ -0,0 +1,12 @@
+from ..config import R2RConfig
+from .builder import R2RBuilder
+from .factory import R2RProviderFactory
+
+__all__ = [
+ # Builder
+ "R2RBuilder",
+ # Config
+ "R2RConfig",
+ # Factory
+ "R2RProviderFactory",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/main/assembly/builder.py b/.venv/lib/python3.12/site-packages/core/main/assembly/builder.py
new file mode 100644
index 00000000..f72a15c9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/assembly/builder.py
@@ -0,0 +1,127 @@
+import logging
+from typing import Any, Type
+
+from ..abstractions import R2RProviders, R2RServices
+from ..api.v3.chunks_router import ChunksRouter
+from ..api.v3.collections_router import CollectionsRouter
+from ..api.v3.conversations_router import ConversationsRouter
+from ..api.v3.documents_router import DocumentsRouter
+from ..api.v3.graph_router import GraphRouter
+from ..api.v3.indices_router import IndicesRouter
+from ..api.v3.prompts_router import PromptsRouter
+from ..api.v3.retrieval_router import RetrievalRouter
+from ..api.v3.system_router import SystemRouter
+from ..api.v3.users_router import UsersRouter
+from ..app import R2RApp
+from ..config import R2RConfig
+from ..services.auth_service import AuthService # noqa: F401
+from ..services.graph_service import GraphService # noqa: F401
+from ..services.ingestion_service import IngestionService # noqa: F401
+from ..services.management_service import ManagementService # noqa: F401
+from ..services.retrieval_service import ( # type: ignore
+ RetrievalService, # noqa: F401 # type: ignore
+)
+from .factory import R2RProviderFactory
+
+logger = logging.getLogger()
+
+
+class R2RBuilder:
+ _SERVICES = ["auth", "ingestion", "management", "retrieval", "graph"]
+
+ def __init__(self, config: R2RConfig):
+ self.config = config
+
+ async def build(self, *args, **kwargs) -> R2RApp:
+ provider_factory = R2RProviderFactory
+
+ try:
+ providers = await self._create_providers(
+ provider_factory, *args, **kwargs
+ )
+ except Exception as e:
+ logger.error(f"Error {e} while creating R2RProviders.")
+ raise
+
+ service_params = {
+ "config": self.config,
+ "providers": providers,
+ }
+
+ services = self._create_services(service_params)
+
+ routers = {
+ "chunks_router": ChunksRouter(
+ providers=providers,
+ services=services,
+ config=self.config,
+ ).get_router(),
+ "collections_router": CollectionsRouter(
+ providers=providers,
+ services=services,
+ config=self.config,
+ ).get_router(),
+ "conversations_router": ConversationsRouter(
+ providers=providers,
+ services=services,
+ config=self.config,
+ ).get_router(),
+ "documents_router": DocumentsRouter(
+ providers=providers,
+ services=services,
+ config=self.config,
+ ).get_router(),
+ "graph_router": GraphRouter(
+ providers=providers,
+ services=services,
+ config=self.config,
+ ).get_router(),
+ "indices_router": IndicesRouter(
+ providers=providers,
+ services=services,
+ config=self.config,
+ ).get_router(),
+ "prompts_router": PromptsRouter(
+ providers=providers,
+ services=services,
+ config=self.config,
+ ).get_router(),
+ "retrieval_router": RetrievalRouter(
+ providers=providers,
+ services=services,
+ config=self.config,
+ ).get_router(),
+ "system_router": SystemRouter(
+ providers=providers,
+ services=services,
+ config=self.config,
+ ).get_router(),
+ "users_router": UsersRouter(
+ providers=providers,
+ services=services,
+ config=self.config,
+ ).get_router(),
+ }
+
+ return R2RApp(
+ config=self.config,
+ orchestration_provider=providers.orchestration,
+ services=services,
+ **routers,
+ )
+
+ async def _create_providers(
+ self, provider_factory: Type[R2RProviderFactory], *args, **kwargs
+ ) -> R2RProviders:
+ factory = provider_factory(self.config)
+ return await factory.create_providers(*args, **kwargs)
+
+ def _create_services(self, service_params: dict[str, Any]) -> R2RServices:
+ services = R2RBuilder._SERVICES
+ service_instances = {}
+
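+ # Resolve each service class by naming convention ("auth" -> AuthService)
+ # using the service classes imported at the top of this module.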
+ for service_type in services:
+ service_class = globals()[f"{service_type.capitalize()}Service"]
+ service_instances[service_type] = service_class(**service_params)
+
+ return R2RServices(**service_instances)
diff --git a/.venv/lib/python3.12/site-packages/core/main/assembly/factory.py b/.venv/lib/python3.12/site-packages/core/main/assembly/factory.py
new file mode 100644
index 00000000..b982aa18
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/assembly/factory.py
@@ -0,0 +1,417 @@
+import logging
+import math
+import os
+from typing import Any, Optional
+
+from core.base import (
+ AuthConfig,
+ CompletionConfig,
+ CompletionProvider,
+ CryptoConfig,
+ DatabaseConfig,
+ EmailConfig,
+ EmbeddingConfig,
+ EmbeddingProvider,
+ IngestionConfig,
+ OrchestrationConfig,
+)
+from core.providers import (
+ AnthropicCompletionProvider,
+ AsyncSMTPEmailProvider,
+ BcryptCryptoConfig,
+ BCryptCryptoProvider,
+ ClerkAuthProvider,
+ ConsoleMockEmailProvider,
+ HatchetOrchestrationProvider,
+ JwtAuthProvider,
+ LiteLLMCompletionProvider,
+ LiteLLMEmbeddingProvider,
+ MailerSendEmailProvider,
+ NaClCryptoConfig,
+ NaClCryptoProvider,
+ OllamaEmbeddingProvider,
+ OpenAICompletionProvider,
+ OpenAIEmbeddingProvider,
+ PostgresDatabaseProvider,
+ R2RAuthProvider,
+ R2RCompletionProvider,
+ R2RIngestionConfig,
+ R2RIngestionProvider,
+ SendGridEmailProvider,
+ SimpleOrchestrationProvider,
+ SupabaseAuthProvider,
+ UnstructuredIngestionConfig,
+ UnstructuredIngestionProvider,
+)
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+
+logger = logging.getLogger()
+
+
+class R2RProviderFactory:
+ def __init__(self, config: R2RConfig):
+ self.config = config
+
+ @staticmethod
+ async def create_auth_provider(
+ auth_config: AuthConfig,
+ crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: (
+ AsyncSMTPEmailProvider
+ | ConsoleMockEmailProvider
+ | SendGridEmailProvider
+ | MailerSendEmailProvider
+ ),
+ *args,
+ **kwargs,
+ ) -> (
+ R2RAuthProvider
+ | SupabaseAuthProvider
+ | JwtAuthProvider
+ | ClerkAuthProvider
+ ):
+ if auth_config.provider == "r2r":
+ r2r_auth = R2RAuthProvider(
+ auth_config, crypto_provider, database_provider, email_provider
+ )
+ await r2r_auth.initialize()
+ return r2r_auth
+ elif auth_config.provider == "supabase":
+ return SupabaseAuthProvider(
+ auth_config, crypto_provider, database_provider, email_provider
+ )
+ elif auth_config.provider == "jwt":
+ return JwtAuthProvider(
+ auth_config, crypto_provider, database_provider, email_provider
+ )
+ elif auth_config.provider == "clerk":
+ return ClerkAuthProvider(
+ auth_config, crypto_provider, database_provider, email_provider
+ )
+ else:
+ raise ValueError(
+ f"Auth provider {auth_config.provider} not supported."
+ )
+
+ @staticmethod
+ def create_crypto_provider(
+ crypto_config: CryptoConfig, *args, **kwargs
+ ) -> BCryptCryptoProvider | NaClCryptoProvider:
+ if crypto_config.provider == "bcrypt":
+ return BCryptCryptoProvider(
+ BcryptCryptoConfig(**crypto_config.model_dump())
+ )
+ if crypto_config.provider == "nacl":
+ return NaClCryptoProvider(
+ NaClCryptoConfig(**crypto_config.model_dump())
+ )
+ else:
+ raise ValueError(
+ f"Crypto provider {crypto_config.provider} not supported."
+ )
+
+ @staticmethod
+ def create_ingestion_provider(
+ ingestion_config: IngestionConfig,
+ database_provider: PostgresDatabaseProvider,
+ llm_provider: (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ *args,
+ **kwargs,
+ ) -> R2RIngestionProvider | UnstructuredIngestionProvider:
+ config_dict = (
+ ingestion_config.model_dump()
+ if isinstance(ingestion_config, IngestionConfig)
+ else ingestion_config
+ )
+
+ extra_fields = config_dict.pop("extra_fields", {})
+
+ if config_dict["provider"] == "r2r":
+ r2r_ingestion_config = R2RIngestionConfig(
+ **config_dict, **extra_fields
+ )
+ return R2RIngestionProvider(
+ r2r_ingestion_config, database_provider, llm_provider
+ )
+ elif config_dict["provider"] in [
+ "unstructured_local",
+ "unstructured_api",
+ ]:
+ unstructured_ingestion_config = UnstructuredIngestionConfig(
+ **config_dict, **extra_fields
+ )
+
+ return UnstructuredIngestionProvider(
+ unstructured_ingestion_config, database_provider, llm_provider
+ )
+ else:
+ raise ValueError(
+ f"Ingestion provider {ingestion_config.provider} not supported"
+ )
+
+ @staticmethod
+ def create_orchestration_provider(
+ config: OrchestrationConfig, *args, **kwargs
+ ) -> HatchetOrchestrationProvider | SimpleOrchestrationProvider:
+ if config.provider == "hatchet":
+ orchestration_provider = HatchetOrchestrationProvider(config)
+ orchestration_provider.get_worker("r2r-worker")
+ return orchestration_provider
+ elif config.provider == "simple":
+ from core.providers import SimpleOrchestrationProvider
+
+ return SimpleOrchestrationProvider(config)
+ else:
+ raise ValueError(
+ f"Orchestration provider {config.provider} not supported"
+ )
+
+ async def create_database_provider(
+ self,
+ db_config: DatabaseConfig,
+ crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
+ *args,
+ **kwargs,
+ ) -> PostgresDatabaseProvider:
+ if not self.config.embedding.base_dimension:
+ raise ValueError(
+ "Embedding config must have a base dimension to initialize database."
+ )
+
+ dimension = self.config.embedding.base_dimension
+ quantization_type = (
+ self.config.embedding.quantization_settings.quantization_type
+ )
+ if db_config.provider == "postgres":
+ database_provider = PostgresDatabaseProvider(
+ db_config,
+ dimension,
+ crypto_provider=crypto_provider,
+ quantization_type=quantization_type,
+ )
+ await database_provider.initialize()
+ return database_provider
+ else:
+ raise ValueError(
+ f"Database provider {db_config.provider} not supported"
+ )
+
+ @staticmethod
+ def create_embedding_provider(
+ embedding: EmbeddingConfig, *args, **kwargs
+ ) -> (
+ LiteLLMEmbeddingProvider
+ | OllamaEmbeddingProvider
+ | OpenAIEmbeddingProvider
+ ):
+ embedding_provider: Optional[EmbeddingProvider] = None
+
+ if embedding.provider == "openai":
+ if not os.getenv("OPENAI_API_KEY"):
+ raise ValueError(
+ "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+ )
+ from core.providers import OpenAIEmbeddingProvider
+
+ embedding_provider = OpenAIEmbeddingProvider(embedding)
+
+ elif embedding.provider == "litellm":
+ from core.providers import LiteLLMEmbeddingProvider
+
+ embedding_provider = LiteLLMEmbeddingProvider(embedding)
+
+ elif embedding.provider == "ollama":
+ from core.providers import OllamaEmbeddingProvider
+
+ embedding_provider = OllamaEmbeddingProvider(embedding)
+
+ else:
+ raise ValueError(
+ f"Embedding provider {embedding.provider} not supported"
+ )
+
+ return embedding_provider
+
+ @staticmethod
+ def create_llm_provider(
+ llm_config: CompletionConfig, *args, **kwargs
+ ) -> (
+ AnthropicCompletionProvider
+ | LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ):
+ llm_provider: Optional[CompletionProvider] = None
+ if llm_config.provider == "anthropic":
+ llm_provider = AnthropicCompletionProvider(llm_config)
+ elif llm_config.provider == "litellm":
+ llm_provider = LiteLLMCompletionProvider(llm_config)
+ elif llm_config.provider == "openai":
+ llm_provider = OpenAICompletionProvider(llm_config)
+ elif llm_config.provider == "r2r":
+ llm_provider = R2RCompletionProvider(llm_config)
+ else:
+ raise ValueError(
+ f"Language model provider {llm_config.provider} not supported"
+ )
+ if not llm_provider:
+ raise ValueError("Language model provider not found")
+ return llm_provider
+
+ @staticmethod
+ async def create_email_provider(
+ email_config: Optional[EmailConfig] = None, *args, **kwargs
+ ) -> (
+ AsyncSMTPEmailProvider
+ | ConsoleMockEmailProvider
+ | SendGridEmailProvider
+ | MailerSendEmailProvider
+ ):
+ """Creates an email provider based on configuration."""
+ if not email_config:
+ raise ValueError(
+ "No email configuration provided for email provider, please add `[email]` to your `r2r.toml`."
+ )
+
+ if email_config.provider == "smtp":
+ return AsyncSMTPEmailProvider(email_config)
+ elif email_config.provider == "console_mock":
+ return ConsoleMockEmailProvider(email_config)
+ elif email_config.provider == "sendgrid":
+ return SendGridEmailProvider(email_config)
+ elif email_config.provider == "mailersend":
+ return MailerSendEmailProvider(email_config)
+ else:
+ raise ValueError(
+ f"Email provider {email_config.provider} not supported."
+ )
+
+ async def create_providers(
+ self,
+ auth_provider_override: Optional[
+ R2RAuthProvider | SupabaseAuthProvider
+ ] = None,
+ crypto_provider_override: Optional[
+ BCryptCryptoProvider | NaClCryptoProvider
+ ] = None,
+ database_provider_override: Optional[PostgresDatabaseProvider] = None,
+ email_provider_override: Optional[
+ AsyncSMTPEmailProvider
+ | ConsoleMockEmailProvider
+ | SendGridEmailProvider
+ | MailerSendEmailProvider
+ ] = None,
+ embedding_provider_override: Optional[
+ LiteLLMEmbeddingProvider
+ | OpenAIEmbeddingProvider
+ | OllamaEmbeddingProvider
+ ] = None,
+ ingestion_provider_override: Optional[
+ R2RIngestionProvider | UnstructuredIngestionProvider
+ ] = None,
+ llm_provider_override: Optional[
+ AnthropicCompletionProvider
+ | OpenAICompletionProvider
+ | LiteLLMCompletionProvider
+ | R2RCompletionProvider
+ ] = None,
+ orchestration_provider_override: Optional[Any] = None,
+ *args,
+ **kwargs,
+ ) -> R2RProviders:
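+ # Sanity check: both embedding configs must either be unset (NaN) or agree
+ # on the same base dimension before any providers are constructed.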
+ if (
+ math.isnan(self.config.embedding.base_dimension)
+ != math.isnan(self.config.completion_embedding.base_dimension)
+ ) or (
+ not math.isnan(self.config.embedding.base_dimension)
+ and not math.isnan(self.config.completion_embedding.base_dimension)
+ and self.config.embedding.base_dimension
+ != self.config.completion_embedding.base_dimension
+ ):
+ raise ValueError(
+ f"Both embedding configurations must use the same dimensions. Got {self.config.embedding.base_dimension} and {self.config.completion_embedding.base_dimension}"
+ )
+
+ embedding_provider = (
+ embedding_provider_override
+ or self.create_embedding_provider(
+ self.config.embedding, *args, **kwargs
+ )
+ )
+
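+ # NOTE: the same embedding override (if any) is reused for the
+ # completion-embedding provider; there is no separate override parameter.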
+ completion_embedding_provider = (
+ embedding_provider_override
+ or self.create_embedding_provider(
+ self.config.completion_embedding, *args, **kwargs
+ )
+ )
+
+ llm_provider = llm_provider_override or self.create_llm_provider(
+ self.config.completion, *args, **kwargs
+ )
+
+ crypto_provider = (
+ crypto_provider_override
+ or self.create_crypto_provider(self.config.crypto, *args, **kwargs)
+ )
+
+ database_provider = (
+ database_provider_override
+ or await self.create_database_provider(
+ self.config.database, crypto_provider, *args, **kwargs
+ )
+ )
+
+ ingestion_provider = (
+ ingestion_provider_override
+ or self.create_ingestion_provider(
+ self.config.ingestion,
+ database_provider,
+ llm_provider,
+ *args,
+ **kwargs,
+ )
+ )
+
+ email_provider = (
+ email_provider_override
+ or await self.create_email_provider(
+ self.config.email, crypto_provider, *args, **kwargs
+ )
+ )
+
+ auth_provider = (
+ auth_provider_override
+ or await self.create_auth_provider(
+ self.config.auth,
+ crypto_provider,
+ database_provider,
+ email_provider,
+ *args,
+ **kwargs,
+ )
+ )
+
+ orchestration_provider = (
+ orchestration_provider_override
+ or self.create_orchestration_provider(self.config.orchestration)
+ )
+
+ return R2RProviders(
+ auth=auth_provider,
+ database=database_provider,
+ embedding=embedding_provider,
+ completion_embedding=completion_embedding_provider,
+ ingestion=ingestion_provider,
+ llm=llm_provider,
+ email=email_provider,
+ orchestration=orchestration_provider,
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/main/config.py b/.venv/lib/python3.12/site-packages/core/main/config.py
new file mode 100644
index 00000000..f49b4041
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/config.py
@@ -0,0 +1,213 @@
+# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
+import logging
+import os
+from enum import Enum
+from typing import Any, Optional
+
+import toml
+from pydantic import BaseModel
+
+from ..base.abstractions import GenerationConfig
+from ..base.agent.agent import RAGAgentConfig # type: ignore
+from ..base.providers import AppConfig
+from ..base.providers.auth import AuthConfig
+from ..base.providers.crypto import CryptoConfig
+from ..base.providers.database import DatabaseConfig
+from ..base.providers.email import EmailConfig
+from ..base.providers.embedding import EmbeddingConfig
+from ..base.providers.ingestion import IngestionConfig
+from ..base.providers.llm import CompletionConfig
+from ..base.providers.orchestration import OrchestrationConfig
+from ..base.utils import deep_update
+
+logger = logging.getLogger()
+
+
+class R2RConfig:
+ current_file_path = os.path.dirname(__file__)
+ config_dir_root = os.path.join(current_file_path, "..", "configs")
+ default_config_path = os.path.join(
+ current_file_path, "..", "..", "r2r", "r2r.toml"
+ )
+
+ CONFIG_OPTIONS: dict[str, Optional[str]] = {}
+ for file_ in os.listdir(config_dir_root):
+ if file_.endswith(".toml"):
+ CONFIG_OPTIONS[file_.removesuffix(".toml")] = os.path.join(
+ config_dir_root, file_
+ )
+ CONFIG_OPTIONS["default"] = None
+
+ REQUIRED_KEYS: dict[str, list] = {
+ "app": [],
+ "completion": ["provider"],
+ "crypto": ["provider"],
+ "email": ["provider"],
+ "auth": ["provider"],
+ "embedding": [
+ "provider",
+ "base_model",
+ "base_dimension",
+ "batch_size",
+ "add_title_as_prefix",
+ ],
+ "completion_embedding": [
+ "provider",
+ "base_model",
+ "base_dimension",
+ "batch_size",
+ "add_title_as_prefix",
+ ],
+ # TODO - deprecated, remove
+ "ingestion": ["provider"],
+ "logging": ["provider", "log_table"],
+ "database": ["provider"],
+ "agent": ["generation_config"],
+ "orchestration": ["provider"],
+ }
+
+ app: AppConfig
+ auth: AuthConfig
+ completion: CompletionConfig
+ crypto: CryptoConfig
+ database: DatabaseConfig
+ embedding: EmbeddingConfig
+ completion_embedding: EmbeddingConfig
+ email: EmailConfig
+ ingestion: IngestionConfig
+ agent: RAGAgentConfig
+ orchestration: OrchestrationConfig
+
+ def __init__(self, config_data: dict[str, Any]):
+ """
+ :param config_data: dictionary of configuration parameters
+ """
+ # Load the default configuration
+ default_config = self.load_default_config()
+
+ # Override the default configuration with the passed configuration
+ default_config = deep_update(default_config, config_data)
+
+ # Validate and set the configuration
+ for section, keys in R2RConfig.REQUIRED_KEYS.items():
+ # Check the keys when provider is set
+ # TODO - remove after deprecation
+ if section in ["graph", "file"] and section not in default_config:
+ continue
+ if "provider" in default_config[section] and (
+ default_config[section]["provider"] is not None
+ and default_config[section]["provider"] != "None"
+ and default_config[section]["provider"] != "null"
+ ):
+ self._validate_config_section(default_config, section, keys)
+ setattr(self, section, default_config[section])
+
+ self.app = AppConfig.create(**self.app) # type: ignore
+ self.auth = AuthConfig.create(**self.auth, app=self.app) # type: ignore
+ self.completion = CompletionConfig.create(
+ **self.completion, app=self.app
+ ) # type: ignore
+ self.crypto = CryptoConfig.create(**self.crypto, app=self.app) # type: ignore
+ self.email = EmailConfig.create(**self.email, app=self.app) # type: ignore
+ self.database = DatabaseConfig.create(**self.database, app=self.app) # type: ignore
+ self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app) # type: ignore
+ self.completion_embedding = EmbeddingConfig.create(
+ **self.completion_embedding, app=self.app
+ ) # type: ignore
+ self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app) # type: ignore
+ self.agent = RAGAgentConfig.create(**self.agent, app=self.app) # type: ignore
+ self.orchestration = OrchestrationConfig.create(
+ **self.orchestration, app=self.app
+ ) # type: ignore
+
+ IngestionConfig.set_default(**self.ingestion.dict())
+
+ # override GenerationConfig defaults
+ if self.completion.generation_config:
+ GenerationConfig.set_default(
+ **self.completion.generation_config.dict()
+ )
+
+ def _validate_config_section(
+ self, config_data: dict[str, Any], section: str, keys: list
+ ):
+ if section not in config_data:
+ raise ValueError(f"Missing '{section}' section in config")
+ if missing_keys := [
+ key for key in keys if key not in config_data[section]
+ ]:
+ raise ValueError(
+ f"Missing required keys in '{section}' config: {', '.join(missing_keys)}"
+ )
+
+ @classmethod
+ def from_toml(cls, config_path: Optional[str] = None) -> "R2RConfig":
+ if config_path is None:
+ config_path = R2RConfig.default_config_path
+
+ # Load configuration from TOML file
+ with open(config_path, encoding="utf-8") as f:
+ config_data = toml.load(f)
+
+ return cls(config_data)
+
+ def to_toml(self):
+ config_data = {}
+ for section in R2RConfig.REQUIRED_KEYS.keys():
+ section_data = self._serialize_config(getattr(self, section))
+ if isinstance(section_data, dict):
+ # Remove app from nested configs before serializing
+ section_data.pop("app", None)
+ config_data[section] = section_data
+ return toml.dumps(config_data)
+
+ @classmethod
+ def load_default_config(cls) -> dict:
+ with open(R2RConfig.default_config_path, encoding="utf-8") as f:
+ return toml.load(f)
+
+ @staticmethod
+ def _serialize_config(config_section: Any):
+ """Serialize config section while excluding internal state."""
+ if isinstance(config_section, dict):
+ return {
+ R2RConfig._serialize_key(k): R2RConfig._serialize_config(v)
+ for k, v in config_section.items()
+ if k != "app" # Exclude app from serialization
+ }
+ elif isinstance(config_section, (list, tuple)):
+ return [
+ R2RConfig._serialize_config(item) for item in config_section
+ ]
+ elif isinstance(config_section, Enum):
+ return config_section.value
+ elif isinstance(config_section, BaseModel):
+ data = config_section.model_dump(exclude_none=True)
+ data.pop("app", None) # Remove app from the serialized data
+ return R2RConfig._serialize_config(data)
+ else:
+ return config_section
+
+ @staticmethod
+ def _serialize_key(key: Any) -> str:
+ return key.value if isinstance(key, Enum) else str(key)
+
+ @classmethod
+ def load(
+ cls,
+ config_name: Optional[str] = None,
+ config_path: Optional[str] = None,
+ ) -> "R2RConfig":
+ if config_path and config_name:
+ raise ValueError(
+ f"Cannot specify both config_path and config_name. Got: {config_path}, {config_name}"
+ )
+
+ if config_path := os.getenv("R2R_CONFIG_PATH") or config_path:
+ return cls.from_toml(config_path)
+
+ config_name = os.getenv("R2R_CONFIG_NAME") or config_name or "default"
+ if config_name not in R2RConfig.CONFIG_OPTIONS:
+ raise ValueError(f"Invalid config name: {config_name}")
+ return cls.from_toml(R2RConfig.CONFIG_OPTIONS[config_name])
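+
+
+# NOTE: illustrative sketch only (not used by the package), showing how the
+# resolution order above plays out: R2R_CONFIG_PATH > explicit name > "default".
+def _example_load_config() -> "R2RConfig":
+    """Hypothetical helper demonstrating R2RConfig.load()."""
+    # A named config resolves against CONFIG_OPTIONS (here assuming a config
+    # named "full" ships in core/configs); R2R_CONFIG_PATH, when set, always
+    # wins and the file it points to is loaded via from_toml instead.
+    return R2RConfig.load(config_name="full")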
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/__init__.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/__init__.py
new file mode 100644
index 00000000..19cb0428
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/__init__.py
@@ -0,0 +1,16 @@
+# FIXME: Once the Hatchet workflows are type annotated, remove the type: ignore comments
+from .hatchet.graph_workflow import ( # type: ignore
+ hatchet_graph_search_results_factory,
+)
+from .hatchet.ingestion_workflow import ( # type: ignore
+ hatchet_ingestion_factory,
+)
+from .simple.graph_workflow import simple_graph_search_results_factory
+from .simple.ingestion_workflow import simple_ingestion_factory
+
+__all__ = [
+ "hatchet_ingestion_factory",
+ "hatchet_graph_search_results_factory",
+ "simple_ingestion_factory",
+ "simple_graph_search_results_factory",
+]
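+
+# Illustrative wiring sketch (hypothetical provider name and call site): the
+# factories above are typically selected by the configured orchestration
+# provider and their name -> workflow mappings registered with it, e.g.
+#
+#   if config.orchestration.provider == "hatchet":
+#       workflows = hatchet_ingestion_factory(orchestration_provider, ingestion_service)
+#   else:
+#       workflows = simple_ingestion_factory(ingestion_service)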
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py
new file mode 100644
index 00000000..cc128b0f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/graph_workflow.py
@@ -0,0 +1,539 @@
+# type: ignore
+import asyncio
+import contextlib
+import json
+import logging
+import math
+import time
+import uuid
+from typing import TYPE_CHECKING
+
+from hatchet_sdk import ConcurrencyLimitStrategy, Context
+
+from core import GenerationConfig
+from core.base import OrchestrationProvider, R2RException
+from core.base.abstractions import (
+ GraphConstructionStatus,
+ GraphExtractionStatus,
+)
+
+from ...services import GraphService
+
+if TYPE_CHECKING:
+ from hatchet_sdk import Hatchet
+
+logger = logging.getLogger()
+
+
+def hatchet_graph_search_results_factory(
+ orchestration_provider: OrchestrationProvider, service: GraphService
+) -> dict[str, "Hatchet.Workflow"]:
+ def convert_to_dict(input_data):
+ """Converts input data back to a plain dictionary format, handling
+ special cases like UUID and GenerationConfig. This is the inverse of
+ get_input_data_dict.
+
+ Args:
+ input_data: Dictionary containing the input data with potentially special types
+
+ Returns:
+ Dictionary with all values converted to basic Python types
+ """
+ output_data = {}
+
+ for key, value in input_data.items():
+ if value is None:
+ output_data[key] = None
+ continue
+
+            # Convert UUID to string and skip the generic handling below,
+            # otherwise the fallback branch would overwrite it with the raw UUID
+            if isinstance(value, uuid.UUID):
+                output_data[key] = str(value)
+                continue
+
+ try:
+ output_data[key] = value.model_dump()
+ except Exception:
+ # Handle nested dictionaries that might contain settings
+ if isinstance(value, dict):
+ output_data[key] = convert_to_dict(value)
+
+ # Handle lists that might contain dictionaries
+ elif isinstance(value, list):
+ output_data[key] = [
+ (
+ convert_to_dict(item)
+ if isinstance(item, dict)
+ else item
+ )
+ for item in value
+ ]
+
+ # All other types can be directly assigned
+ else:
+ output_data[key] = value
+
+ return output_data
+
+ def get_input_data_dict(input_data):
+ for key, value in input_data.items():
+ if value is None:
+ continue
+
+ if key == "document_id":
+ input_data[key] = (
+ uuid.UUID(value)
+ if not isinstance(value, uuid.UUID)
+ else value
+ )
+
+ if key == "collection_id":
+ input_data[key] = (
+ uuid.UUID(value)
+ if not isinstance(value, uuid.UUID)
+ else value
+ )
+
+ if key == "graph_id":
+ input_data[key] = (
+ uuid.UUID(value)
+ if not isinstance(value, uuid.UUID)
+ else value
+ )
+
+ if key in ["graph_creation_settings", "graph_enrichment_settings"]:
+ # Ensure we have a dict (if not already)
+ input_data[key] = (
+ json.loads(value) if not isinstance(value, dict) else value
+ )
+
+ if "generation_config" in input_data[key]:
+ gen_cfg = input_data[key]["generation_config"]
+ # If it's a dict, convert it
+ if isinstance(gen_cfg, dict):
+ input_data[key]["generation_config"] = (
+ GenerationConfig(**gen_cfg)
+ )
+ # If it's not already a GenerationConfig, default it
+ elif not isinstance(gen_cfg, GenerationConfig):
+ input_data[key]["generation_config"] = (
+ GenerationConfig()
+ )
+
+ input_data[key]["generation_config"].model = (
+ input_data[key]["generation_config"].model
+ or service.config.app.fast_llm
+ )
+
+ return input_data
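+
+    # Illustrative round trip (hypothetical values): get_input_data_dict
+    # coerces wire-format values into rich types, and convert_to_dict reverses
+    # that before a payload is re-submitted to a spawned workflow:
+    #
+    #   raw = {
+    #       "document_id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+    #       "graph_creation_settings": '{"generation_config": {}}',
+    #   }
+    #   typed = get_input_data_dict(raw)
+    #   # typed["document_id"] is a uuid.UUID and the nested generation_config
+    #   # is a GenerationConfig whose model falls back to app.fast_llm
+    #   wire = convert_to_dict(typed)  # back to JSON-serializable primitives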
+
+ @orchestration_provider.workflow(name="graph-extraction", timeout="360m")
+ class GraphExtractionWorkflow:
+ @orchestration_provider.concurrency( # type: ignore
+ max_runs=orchestration_provider.config.graph_search_results_concurrency_limit, # type: ignore
+ limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
+ )
+        def concurrency(self, context: Context) -> str:
+            # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun
+            with contextlib.suppress(Exception):
+                return str(
+                    context.workflow_input()["request"]["collection_id"]
+                )
+            # Fall back to a random key so the annotated return type holds
+            # even when the workflow input is unavailable on rerun.
+            return str(uuid.uuid4())
+
+ def __init__(self, graph_search_results_service: GraphService):
+ self.graph_search_results_service = graph_search_results_service
+
+ @orchestration_provider.step(retries=1, timeout="360m")
+ async def graph_search_results_extraction(
+ self, context: Context
+ ) -> dict:
+ request = context.workflow_input()["request"]
+
+ input_data = get_input_data_dict(request)
+ document_id = input_data.get("document_id", None)
+ collection_id = input_data.get("collection_id", None)
+
+            if document_id:
+                await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status(
+                    id=document_id,
+                    status_type="extraction_status",
+                    status=GraphExtractionStatus.PROCESSING,
+                )
+
+ if collection_id and not document_id:
+ document_ids = await self.graph_search_results_service.get_document_ids_for_create_graph(
+ collection_id=collection_id,
+ **input_data["graph_creation_settings"],
+ )
+ workflows = []
+
+ for document_id in document_ids:
+ input_data_copy = input_data.copy()
+ input_data_copy["collection_id"] = str(
+ input_data_copy["collection_id"]
+ )
+ input_data_copy["document_id"] = str(document_id)
+
+ workflows.append(
+ context.aio.spawn_workflow(
+ "graph-extraction",
+ {
+ "request": {
+ **convert_to_dict(input_data_copy),
+ }
+ },
+ key=str(document_id),
+ )
+ )
+                # Wait for all spawned per-document workflows to complete
+                results = await asyncio.gather(*workflows)
+                return {
+                    "result": f"successfully submitted graph_search_results relationships extraction for {len(results)} documents in collection {collection_id}",
+                    "collection_id": str(collection_id),
+                }
+
+ else:
+ # Extract relationships and store them
+ extractions = []
+ async for extraction in self.graph_search_results_service.graph_search_results_extraction(
+ document_id=document_id,
+ **input_data["graph_creation_settings"],
+ ):
+ logger.info(
+ f"Found extraction with {len(extraction.entities)} entities"
+ )
+ extractions.append(extraction)
+
+ await self.graph_search_results_service.store_graph_search_results_extractions(
+ extractions
+ )
+
+ logger.info(
+ f"Successfully ran graph_search_results relationships extraction for document {document_id}"
+ )
+
+ return {
+ "result": f"successfully ran graph_search_results relationships extraction for document {document_id}",
+ "document_id": str(document_id),
+ }
+
+ @orchestration_provider.step(
+ retries=1,
+ timeout="360m",
+ parents=["graph_search_results_extraction"],
+ )
+ async def graph_search_results_entity_description(
+ self, context: Context
+ ) -> dict:
+ input_data = get_input_data_dict(
+ context.workflow_input()["request"]
+ )
+ document_id = input_data.get("document_id", None)
+
+ # Describe the entities in the graph
+ await self.graph_search_results_service.graph_search_results_entity_description(
+ document_id=document_id,
+ **input_data["graph_creation_settings"],
+ )
+
+ logger.info(
+ f"Successfully ran graph_search_results entity description for document {document_id}"
+ )
+
+ if service.providers.database.config.graph_creation_settings.automatic_deduplication:
+ extract_input = {
+ "document_id": str(document_id),
+ }
+
+ extract_result = (
+ await context.aio.spawn_workflow(
+ "graph-deduplication",
+ {"request": extract_input},
+ )
+ ).result()
+
+ await asyncio.gather(extract_result)
+
+ return {
+ "result": f"successfully ran graph_search_results entity description for document {document_id}"
+ }
+
+ @orchestration_provider.failure()
+ async def on_failure(self, context: Context) -> None:
+ request = context.workflow_input().get("request", {})
+ document_id = request.get("document_id")
+
+ if not document_id:
+ logger.info(
+ "No document id was found in workflow input to mark a failure."
+ )
+ return
+
+ try:
+ await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status(
+ id=uuid.UUID(document_id),
+ status_type="extraction_status",
+ status=GraphExtractionStatus.FAILED,
+ )
+ logger.info(
+ f"Updated Graph extraction status for {document_id} to FAILED"
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to update document status for {document_id}: {e}"
+ )
+
+ @orchestration_provider.workflow(name="graph-clustering", timeout="360m")
+ class GraphClusteringWorkflow:
+ def __init__(self, graph_search_results_service: GraphService):
+ self.graph_search_results_service = graph_search_results_service
+
+ @orchestration_provider.step(retries=1, timeout="360m", parents=[])
+ async def graph_search_results_clustering(
+ self, context: Context
+ ) -> dict:
+ logger.info("Running Graph Clustering")
+
+ input_data = get_input_data_dict(
+ context.workflow_input()["request"]
+ )
+
+ # Get the collection_id and graph_id
+ collection_id = input_data.get("collection_id", None)
+ graph_id = input_data.get("graph_id", None)
+
+ # Check current workflow status
+ workflow_status = await self.graph_search_results_service.providers.database.documents_handler.get_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status",
+ )
+
+ if workflow_status == GraphConstructionStatus.SUCCESS:
+ raise R2RException(
+ "Communities have already been built for this collection. To build communities again, first reset the graph.",
+ 400,
+ )
+
+ # Run clustering
+ try:
+ graph_search_results_clustering_results = await self.graph_search_results_service.graph_search_results_clustering(
+ collection_id=collection_id,
+ graph_id=graph_id,
+ **input_data["graph_enrichment_settings"],
+ )
+
+ num_communities = graph_search_results_clustering_results[
+ "num_communities"
+ ][0]
+
+ if num_communities == 0:
+ raise R2RException("No communities found", 400)
+
+ return {
+ "result": graph_search_results_clustering_results,
+ }
+ except Exception as e:
+ await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.FAILED,
+ )
+ raise e
+
+ @orchestration_provider.step(
+ retries=1,
+ timeout="360m",
+ parents=["graph_search_results_clustering"],
+ )
+ async def graph_search_results_community_summary(
+ self, context: Context
+ ) -> dict:
+ input_data = get_input_data_dict(
+ context.workflow_input()["request"]
+ )
+ collection_id = input_data.get("collection_id", None)
+ graph_id = input_data.get("graph_id", None)
+ # Get number of communities from previous step
+ num_communities = context.step_output(
+ "graph_search_results_clustering"
+ )["result"]["num_communities"][0]
+
+ # Calculate batching
+ parallel_communities = min(100, num_communities)
+ total_workflows = math.ceil(num_communities / parallel_communities)
+ workflows = []
+
+ logger.info(
+ f"Running Graph Community Summary for {num_communities} communities, spawning {total_workflows} workflows"
+ )
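+
+            # Worked example (hypothetical numbers): with num_communities=250,
+            # parallel_communities=100 and total_workflows=3, the spawned
+            # batches cover offsets 0, 100, 200 with limits 100, 100 and 50.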
+
+ # Spawn summary workflows
+ for i in range(total_workflows):
+ offset = i * parallel_communities
+ limit = min(parallel_communities, num_communities - offset)
+
+ workflows.append(
+ (
+ await context.aio.spawn_workflow(
+ "graph-community-summarization",
+ {
+ "request": {
+ "offset": offset,
+ "limit": limit,
+ "graph_id": (
+ str(graph_id) if graph_id else None
+ ),
+ "collection_id": (
+ str(collection_id)
+ if collection_id
+ else None
+ ),
+ "graph_enrichment_settings": convert_to_dict(
+ input_data["graph_enrichment_settings"]
+ ),
+ }
+ },
+ key=f"{i}/{total_workflows}_community_summary",
+ )
+ ).result()
+ )
+
+ results = await asyncio.gather(*workflows)
+ logger.info(
+ f"Completed {len(results)} community summary workflows"
+ )
+
+ # Update statuses
+ document_ids = await self.graph_search_results_service.providers.database.documents_handler.get_document_ids_by_status(
+ status_type="extraction_status",
+ status=GraphExtractionStatus.SUCCESS,
+ collection_id=collection_id,
+ )
+
+ await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status(
+ id=document_ids,
+ status_type="extraction_status",
+ status=GraphExtractionStatus.ENRICHED,
+ )
+
+ await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.SUCCESS,
+ )
+
+ return {
+ "result": f"Successfully completed enrichment with {len(results)} summary workflows"
+ }
+
+ @orchestration_provider.failure()
+ async def on_failure(self, context: Context) -> None:
+ collection_id = context.workflow_input()["request"].get(
+ "collection_id", None
+ )
+ if collection_id:
+ await self.graph_search_results_service.providers.database.documents_handler.set_workflow_status(
+ id=uuid.UUID(collection_id),
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.FAILED,
+ )
+
+ @orchestration_provider.workflow(
+ name="graph-community-summarization", timeout="360m"
+ )
+ class GraphCommunitySummarizerWorkflow:
+ def __init__(self, graph_search_results_service: GraphService):
+ self.graph_search_results_service = graph_search_results_service
+
+ @orchestration_provider.concurrency( # type: ignore
+ max_runs=orchestration_provider.config.graph_search_results_concurrency_limit, # type: ignore
+ limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
+ )
+ def concurrency(self, context: Context) -> str:
+ # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun
+ try:
+ return str(
+ context.workflow_input()["request"]["collection_id"]
+ )
+ except Exception:
+ return str(uuid.uuid4())
+
+ @orchestration_provider.step(retries=1, timeout="360m")
+ async def graph_search_results_community_summary(
+ self, context: Context
+ ) -> dict:
+ start_time = time.time()
+
+ input_data = get_input_data_dict(
+ context.workflow_input()["request"]
+ )
+
+ base_args = {
+ k: v
+ for k, v in input_data.items()
+ if k != "graph_enrichment_settings"
+ }
+ enrichment_args = input_data.get("graph_enrichment_settings", {})
+
+ # Merge them together.
+ # Note: if there is any key overlap, values from enrichment_args will override those from base_args.
+ merged_args = {**base_args, **enrichment_args}
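+            # For example (hypothetical keys): {**{"limit": 10}, **{"limit": 50}}
+            # evaluates to {"limit": 50} -- the right-hand dict wins.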
+
+ # Now call the service method with all arguments at the top level.
+ # This ensures that keys like "max_summary_input_length" and "generation_config" are present.
+ community_summary = await self.graph_search_results_service.graph_search_results_community_summary(
+ **merged_args
+ )
+ logger.info(
+ f"Successfully ran graph_search_results community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)} in {time.time() - start_time:.2f} seconds "
+ )
+ return {
+ "result": f"successfully ran graph_search_results community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)}"
+ }
+
+ @orchestration_provider.workflow(
+ name="graph-deduplication", timeout="360m"
+ )
+ class GraphDeduplicationWorkflow:
+ def __init__(self, graph_search_results_service: GraphService):
+ self.graph_search_results_service = graph_search_results_service
+
+ @orchestration_provider.concurrency( # type: ignore
+ max_runs=orchestration_provider.config.graph_search_results_concurrency_limit, # type: ignore
+ limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
+ )
+ def concurrency(self, context: Context) -> str:
+ # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun
+ try:
+ return str(context.workflow_input()["request"]["document_id"])
+ except Exception:
+ return str(uuid.uuid4())
+
+ @orchestration_provider.step(retries=1, timeout="360m")
+ async def deduplicate_document_entities(
+ self, context: Context
+ ) -> dict:
+ start_time = time.time()
+
+ input_data = get_input_data_dict(
+ context.workflow_input()["request"]
+ )
+
+ document_id = input_data.get("document_id", None)
+
+ await service.deduplicate_document_entities(
+ document_id=document_id,
+ )
+ logger.info(
+ f"Successfully ran deduplication for document {document_id} in {time.time() - start_time:.2f} seconds "
+ )
+ return {
+ "result": f"Successfully ran deduplication for document {document_id}"
+ }
+
+ return {
+ "graph-extraction": GraphExtractionWorkflow(service),
+ "graph-clustering": GraphClusteringWorkflow(service),
+ "graph-community-summarization": GraphCommunitySummarizerWorkflow(
+ service
+ ),
+ "graph-deduplication": GraphDeduplicationWorkflow(service),
+ }
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py
new file mode 100644
index 00000000..96d7aebb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -0,0 +1,721 @@
+# type: ignore
+import asyncio
+import logging
+import uuid
+from typing import TYPE_CHECKING
+from uuid import UUID
+
+import tiktoken
+from fastapi import HTTPException
+from hatchet_sdk import ConcurrencyLimitStrategy, Context
+from litellm import AuthenticationError
+
+from core.base import (
+ DocumentChunk,
+ GraphConstructionStatus,
+ IngestionStatus,
+ OrchestrationProvider,
+ generate_extraction_id,
+)
+from core.base.abstractions import DocumentResponse, R2RException
+from core.utils import (
+ generate_default_user_collection_id,
+ update_settings_from_dict,
+)
+
+from ...services import IngestionService, IngestionServiceAdapter
+
+if TYPE_CHECKING:
+ from hatchet_sdk import Hatchet
+
+logger = logging.getLogger()
+
+
+# FIXME: No need to duplicate this function between the workflows, consolidate it into a shared module
+def count_tokens_for_text(text: str, model: str = "gpt-4o") -> int:
+ try:
+ encoding = tiktoken.encoding_for_model(model)
+ except KeyError:
+ # Fallback to a known encoding if model not recognized
+ encoding = tiktoken.get_encoding("cl100k_base")
+
+ return len(encoding.encode(text, disallowed_special=()))
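+
+# Illustrative only (hypothetical input): with the gpt-4o encoding,
+# count_tokens_for_text("hello world") comes out to two tokens, and an
+# unrecognized model name silently falls back to the cl100k_base encoding
+# rather than raising.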
+
+
+def hatchet_ingestion_factory(
+ orchestration_provider: OrchestrationProvider, service: IngestionService
+) -> dict[str, "Hatchet.Workflow"]:
+ @orchestration_provider.workflow(
+ name="ingest-files",
+ timeout="60m",
+ )
+ class HatchetIngestFilesWorkflow:
+ def __init__(self, ingestion_service: IngestionService):
+ self.ingestion_service = ingestion_service
+
+ @orchestration_provider.concurrency( # type: ignore
+ max_runs=orchestration_provider.config.ingestion_concurrency_limit, # type: ignore
+ limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
+ )
+ def concurrency(self, context: Context) -> str:
+ # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun
+ try:
+ input_data = context.workflow_input()["request"]
+ parsed_data = IngestionServiceAdapter.parse_ingest_file_input(
+ input_data
+ )
+ return str(parsed_data["user"].id)
+ except Exception:
+ return str(uuid.uuid4())
+
+ @orchestration_provider.step(retries=0, timeout="60m")
+ async def parse(self, context: Context) -> dict:
+ try:
+ logger.info("Initiating ingestion workflow, step: parse")
+ input_data = context.workflow_input()["request"]
+ parsed_data = IngestionServiceAdapter.parse_ingest_file_input(
+ input_data
+ )
+
+ # ingestion_result = (
+ # await self.ingestion_service.ingest_file_ingress(
+ # **parsed_data
+ # )
+ # )
+
+ # document_info = ingestion_result["info"]
+ document_info = (
+ self.ingestion_service.create_document_info_from_file(
+ parsed_data["document_id"],
+ parsed_data["user"],
+ parsed_data["file_data"]["filename"],
+ parsed_data["metadata"],
+ parsed_data["version"],
+ parsed_data["size_in_bytes"],
+ )
+ )
+
+ await self.ingestion_service.update_document_status(
+ document_info,
+ status=IngestionStatus.PARSING,
+ )
+
+ ingestion_config = parsed_data["ingestion_config"] or {}
+ extractions_generator = self.ingestion_service.parse_file(
+ document_info, ingestion_config
+ )
+
+ extractions = []
+ async for extraction in extractions_generator:
+ extractions.append(extraction)
+
+ # 2) Sum tokens
+ total_tokens = 0
+ for chunk in extractions:
+ text_data = chunk.data
+ if not isinstance(text_data, str):
+ text_data = text_data.decode("utf-8", errors="ignore")
+ total_tokens += count_tokens_for_text(text_data)
+ document_info.total_tokens = total_tokens
+
+ if not ingestion_config.get("skip_document_summary", False):
+ await service.update_document_status(
+ document_info, status=IngestionStatus.AUGMENTING
+ )
+ await service.augment_document_info(
+ document_info,
+ [extraction.to_dict() for extraction in extractions],
+ )
+
+ await self.ingestion_service.update_document_status(
+ document_info,
+ status=IngestionStatus.EMBEDDING,
+ )
+
+ # extractions = context.step_output("parse")["extractions"]
+
+ embedding_generator = self.ingestion_service.embed_document(
+ [extraction.to_dict() for extraction in extractions]
+ )
+
+ embeddings = []
+ async for embedding in embedding_generator:
+ embeddings.append(embedding)
+
+ await self.ingestion_service.update_document_status(
+ document_info,
+ status=IngestionStatus.STORING,
+ )
+
+ storage_generator = self.ingestion_service.store_embeddings( # type: ignore
+ embeddings
+ )
+
+ async for _ in storage_generator:
+ pass
+
+ await self.ingestion_service.finalize_ingestion(document_info)
+
+ await self.ingestion_service.update_document_status(
+ document_info,
+ status=IngestionStatus.SUCCESS,
+ )
+
+ collection_ids = context.workflow_input()["request"].get(
+ "collection_ids"
+ )
+ if not collection_ids:
+ # TODO: Move logic onto the `management service`
+ collection_id = generate_default_user_collection_id(
+ document_info.owner_id
+ )
+ await service.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ else:
+ for collection_id_str in collection_ids:
+ collection_id = UUID(collection_id_str)
+ try:
+ name = document_info.title or "N/A"
+ description = ""
+ await service.providers.database.collections_handler.create_collection(
+ owner_id=document_info.owner_id,
+ name=name,
+ description=description,
+ collection_id=collection_id,
+ )
+                        await self.ingestion_service.providers.database.graphs_handler.create(
+                            collection_id=collection_id,
+                            name=name,
+                            description=description,
+                            graph_id=collection_id,
+                        )
+
+ except Exception as e:
+ logger.warning(
+ f"Warning, could not create collection with error: {str(e)}"
+ )
+
+ await service.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+ status=GraphConstructionStatus.OUTDATED,
+ )
+
+                # get server chunk enrichment settings and override parts of it if provided in the ingestion config
+                if server_chunk_enrichment_settings := getattr(
+                    service.providers.ingestion.config,
+                    "chunk_enrichment_settings",
+                    None,
+                ):
+                    chunk_enrichment_settings = update_settings_from_dict(
+                        server_chunk_enrichment_settings,
+                        ingestion_config.get("chunk_enrichment_settings", {})
+                        or {},
+                    )
+
+                    if chunk_enrichment_settings.enable_chunk_enrichment:
+                        logger.info("Enriching document with contextual chunks")
+
+                        document_info: DocumentResponse = (
+                            await self.ingestion_service.providers.database.documents_handler.get_documents_overview(
+                                offset=0,
+                                limit=1,
+                                filter_user_ids=[document_info.owner_id],
+                                filter_document_ids=[document_info.id],
+                            )
+                        )["results"][0]
+
+                        await self.ingestion_service.update_document_status(
+                            document_info,
+                            status=IngestionStatus.ENRICHING,
+                        )
+
+                        await self.ingestion_service.chunk_enrichment(
+                            document_id=document_info.id,
+                            document_summary=document_info.summary,
+                            chunk_enrichment_settings=chunk_enrichment_settings,
+                        )
+
+                        await self.ingestion_service.update_document_status(
+                            document_info,
+                            status=IngestionStatus.SUCCESS,
+                        )
+                # --------------------------------------------------------------
+
+                if service.providers.ingestion.config.automatic_extraction:
+                    extract_input = {
+                        "document_id": str(document_info.id),
+                        "graph_creation_settings": self.ingestion_service.providers.database.config.graph_creation_settings.model_dump_json(),
+                        "user": input_data["user"],
+                    }
+
+                    extract_result = (
+                        await context.aio.spawn_workflow(
+                            "graph-extraction",
+                            {"request": extract_input},
+                        )
+                    ).result()
+
+                    await asyncio.gather(extract_result)
+
+                return {
+                    "status": "Successfully finalized ingestion",
+                    "document_info": document_info.to_dict(),
+                }
+
+ except AuthenticationError:
+ raise R2RException(
+ status_code=401,
+ message="Authentication error: Invalid API key or credentials.",
+ ) from None
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error during ingestion: {str(e)}",
+ ) from e
+
+ @orchestration_provider.failure()
+ async def on_failure(self, context: Context) -> None:
+ request = context.workflow_input().get("request", {})
+ document_id = request.get("document_id")
+
+ if not document_id:
+ logger.error(
+ "No document id was found in workflow input to mark a failure."
+ )
+ return
+
+ try:
+ documents_overview = (
+ await self.ingestion_service.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=1,
+ filter_document_ids=[document_id],
+ )
+ )["results"]
+
+ if not documents_overview:
+ logger.error(
+ f"Document with id {document_id} not found in database to mark failure."
+ )
+ return
+
+ document_info = documents_overview[0]
+
+ # Update the document status to FAILED
+ if document_info.ingestion_status != IngestionStatus.SUCCESS:
+ await self.ingestion_service.update_document_status(
+ document_info,
+ status=IngestionStatus.FAILED,
+ metadata={"failure": f"{context.step_run_errors()}"},
+ )
+
+ except Exception as e:
+ logger.error(
+ f"Failed to update document status for {document_id}: {e}"
+ )
+
+ @orchestration_provider.workflow(
+ name="ingest-chunks",
+ timeout="60m",
+ )
+ class HatchetIngestChunksWorkflow:
+ def __init__(self, ingestion_service: IngestionService):
+ self.ingestion_service = ingestion_service
+
+ @orchestration_provider.step(timeout="60m")
+ async def ingest(self, context: Context) -> dict:
+ input_data = context.workflow_input()["request"]
+ parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input(
+ input_data
+ )
+
+ document_info = await self.ingestion_service.ingest_chunks_ingress(
+ **parsed_data
+ )
+
+ await self.ingestion_service.update_document_status(
+ document_info, status=IngestionStatus.EMBEDDING
+ )
+ document_id = document_info.id
+
+ extractions = [
+ DocumentChunk(
+ id=generate_extraction_id(document_id, i),
+ document_id=document_id,
+ collection_ids=[],
+ owner_id=document_info.owner_id,
+ data=chunk.text,
+ metadata=parsed_data["metadata"],
+ ).to_dict()
+ for i, chunk in enumerate(parsed_data["chunks"])
+ ]
+
+ # 2) Sum tokens
+ total_tokens = 0
+ for chunk in extractions:
+ text_data = chunk["data"]
+ if not isinstance(text_data, str):
+ text_data = text_data.decode("utf-8", errors="ignore")
+ total_tokens += count_tokens_for_text(text_data)
+ document_info.total_tokens = total_tokens
+
+ return {
+ "status": "Successfully ingested chunks",
+ "extractions": extractions,
+ "document_info": document_info.to_dict(),
+ }
+
+ @orchestration_provider.step(parents=["ingest"], timeout="60m")
+ async def embed(self, context: Context) -> dict:
+ document_info_dict = context.step_output("ingest")["document_info"]
+ document_info = DocumentResponse(**document_info_dict)
+
+ extractions = context.step_output("ingest")["extractions"]
+
+ embedding_generator = self.ingestion_service.embed_document(
+ extractions
+ )
+ embeddings = [
+ embedding.model_dump()
+ async for embedding in embedding_generator
+ ]
+
+ await self.ingestion_service.update_document_status(
+ document_info, status=IngestionStatus.STORING
+ )
+
+ storage_generator = self.ingestion_service.store_embeddings(
+ embeddings
+ )
+ async for _ in storage_generator:
+ pass
+
+ return {
+ "status": "Successfully embedded and stored chunks",
+ "document_info": document_info.to_dict(),
+ }
+
+ @orchestration_provider.step(parents=["embed"], timeout="60m")
+ async def finalize(self, context: Context) -> dict:
+ document_info_dict = context.step_output("embed")["document_info"]
+ document_info = DocumentResponse(**document_info_dict)
+
+ await self.ingestion_service.finalize_ingestion(document_info)
+
+ await self.ingestion_service.update_document_status(
+ document_info, status=IngestionStatus.SUCCESS
+ )
+
+ try:
+ # TODO - Move logic onto the `management service`
+ collection_ids = context.workflow_input()["request"].get(
+ "collection_ids"
+ )
+ if not collection_ids:
+ # TODO: Move logic onto the `management service`
+ collection_id = generate_default_user_collection_id(
+ document_info.owner_id
+ )
+ await service.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ else:
+ for collection_id_str in collection_ids:
+ collection_id = UUID(collection_id_str)
+ try:
+ name = document_info.title or "N/A"
+ description = ""
+ await service.providers.database.collections_handler.create_collection(
+ owner_id=document_info.owner_id,
+ name=name,
+ description=description,
+ collection_id=collection_id,
+ )
+                            await self.ingestion_service.providers.database.graphs_handler.create(
+                                collection_id=collection_id,
+                                name=name,
+                                description=description,
+                                graph_id=collection_id,
+                            )
+
+ except Exception as e:
+ logger.warning(
+ f"Warning, could not create collection with error: {str(e)}"
+ )
+
+ await service.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+
+ await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+ )
+ except Exception as e:
+ logger.error(
+ f"Error during assigning document to collection: {str(e)}"
+ )
+
+ return {
+ "status": "Successfully finalized ingestion",
+ "document_info": document_info.to_dict(),
+ }
+
+ @orchestration_provider.failure()
+ async def on_failure(self, context: Context) -> None:
+ request = context.workflow_input().get("request", {})
+ document_id = request.get("document_id")
+
+ if not document_id:
+ logger.error(
+ "No document id was found in workflow input to mark a failure."
+ )
+ return
+
+ try:
+ documents_overview = (
+ await self.ingestion_service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
+ offset=0,
+ limit=100,
+ filter_document_ids=[document_id],
+ )
+ )["results"]
+
+ if not documents_overview:
+ logger.error(
+ f"Document with id {document_id} not found in database to mark failure."
+ )
+ return
+
+ document_info = documents_overview[0]
+
+ if document_info.ingestion_status != IngestionStatus.SUCCESS:
+ await self.ingestion_service.update_document_status(
+ document_info, status=IngestionStatus.FAILED
+ )
+
+ except Exception as e:
+ logger.error(
+ f"Failed to update document status for {document_id}: {e}"
+ )
+
+ @orchestration_provider.workflow(
+ name="update-chunk",
+ timeout="60m",
+ )
+ class HatchetUpdateChunkWorkflow:
+ def __init__(self, ingestion_service: IngestionService):
+ self.ingestion_service = ingestion_service
+
+ @orchestration_provider.step(timeout="60m")
+ async def update_chunk(self, context: Context) -> dict:
+ try:
+ input_data = context.workflow_input()["request"]
+ parsed_data = IngestionServiceAdapter.parse_update_chunk_input(
+ input_data
+ )
+
+ document_uuid = (
+ UUID(parsed_data["document_id"])
+ if isinstance(parsed_data["document_id"], str)
+ else parsed_data["document_id"]
+ )
+ extraction_uuid = (
+ UUID(parsed_data["id"])
+ if isinstance(parsed_data["id"], str)
+ else parsed_data["id"]
+ )
+
+ await self.ingestion_service.update_chunk_ingress(
+ document_id=document_uuid,
+ chunk_id=extraction_uuid,
+ text=parsed_data.get("text"),
+ user=parsed_data["user"],
+ metadata=parsed_data.get("metadata"),
+ collection_ids=parsed_data.get("collection_ids"),
+ )
+
+ return {
+ "message": "Chunk update completed successfully.",
+ "task_id": context.workflow_run_id(),
+ "document_ids": [str(document_uuid)],
+ }
+
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error during chunk update: {str(e)}",
+ ) from e
+
+ @orchestration_provider.failure()
+ async def on_failure(self, context: Context) -> None:
+ # Handle failure case if necessary
+ pass
+
+ @orchestration_provider.workflow(
+ name="create-vector-index", timeout="360m"
+ )
+ class HatchetCreateVectorIndexWorkflow:
+ def __init__(self, ingestion_service: IngestionService):
+ self.ingestion_service = ingestion_service
+
+ @orchestration_provider.step(timeout="60m")
+ async def create_vector_index(self, context: Context) -> dict:
+ input_data = context.workflow_input()["request"]
+ parsed_data = (
+ IngestionServiceAdapter.parse_create_vector_index_input(
+ input_data
+ )
+ )
+
+ await self.ingestion_service.providers.database.chunks_handler.create_index(
+ **parsed_data
+ )
+
+ return {
+ "status": "Vector index creation queued successfully.",
+ }
+
+ @orchestration_provider.workflow(name="delete-vector-index", timeout="30m")
+ class HatchetDeleteVectorIndexWorkflow:
+ def __init__(self, ingestion_service: IngestionService):
+ self.ingestion_service = ingestion_service
+
+ @orchestration_provider.step(timeout="10m")
+ async def delete_vector_index(self, context: Context) -> dict:
+ input_data = context.workflow_input()["request"]
+ parsed_data = (
+ IngestionServiceAdapter.parse_delete_vector_index_input(
+ input_data
+ )
+ )
+
+ await self.ingestion_service.providers.database.chunks_handler.delete_index(
+ **parsed_data
+ )
+
+ return {"status": "Vector index deleted successfully."}
+
+ @orchestration_provider.workflow(
+ name="update-document-metadata",
+ timeout="30m",
+ )
+ class HatchetUpdateDocumentMetadataWorkflow:
+ def __init__(self, ingestion_service: IngestionService):
+ self.ingestion_service = ingestion_service
+
+ @orchestration_provider.step(timeout="30m")
+ async def update_document_metadata(self, context: Context) -> dict:
+ try:
+ input_data = context.workflow_input()["request"]
+ parsed_data = IngestionServiceAdapter.parse_update_document_metadata_input(
+ input_data
+ )
+
+ document_id = UUID(parsed_data["document_id"])
+ metadata = parsed_data["metadata"]
+ user = parsed_data["user"]
+
+ await self.ingestion_service.update_document_metadata(
+ document_id=document_id,
+ metadata=metadata,
+ user=user,
+ )
+
+ return {
+ "message": "Document metadata update completed successfully.",
+ "document_id": str(document_id),
+ "task_id": context.workflow_run_id(),
+ }
+
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error during document metadata update: {str(e)}",
+ ) from e
+
+ @orchestration_provider.failure()
+ async def on_failure(self, context: Context) -> None:
+ # Handle failure case if necessary
+ pass
+
+    # Instantiate each workflow once and return them keyed by task name
+ ingest_files_workflow = HatchetIngestFilesWorkflow(service)
+ ingest_chunks_workflow = HatchetIngestChunksWorkflow(service)
+ update_chunks_workflow = HatchetUpdateChunkWorkflow(service)
+ update_document_metadata_workflow = HatchetUpdateDocumentMetadataWorkflow(
+ service
+ )
+ create_vector_index_workflow = HatchetCreateVectorIndexWorkflow(service)
+ delete_vector_index_workflow = HatchetDeleteVectorIndexWorkflow(service)
+
+ return {
+ "ingest_files": ingest_files_workflow,
+ "ingest_chunks": ingest_chunks_workflow,
+ "update_chunk": update_chunks_workflow,
+ "update_document_metadata": update_document_metadata_workflow,
+ "create_vector_index": create_vector_index_workflow,
+ "delete_vector_index": delete_vector_index_workflow,
+ }
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py
new file mode 100644
index 00000000..9e043263
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/graph_workflow.py
@@ -0,0 +1,222 @@
+import json
+import logging
+import math
+import uuid
+
+from core import GenerationConfig, R2RException
+from core.base.abstractions import (
+ GraphConstructionStatus,
+ GraphExtractionStatus,
+)
+
+from ...services import GraphService
+
+logger = logging.getLogger()
+
+
+def simple_graph_search_results_factory(service: GraphService):
+ def get_input_data_dict(input_data):
+ for key, value in input_data.items():
+ if value is None:
+ continue
+
+ if key == "document_id":
+ input_data[key] = (
+ uuid.UUID(value)
+ if not isinstance(value, uuid.UUID)
+ else value
+ )
+
+ if key == "collection_id":
+ input_data[key] = (
+ uuid.UUID(value)
+ if not isinstance(value, uuid.UUID)
+ else value
+ )
+
+ if key == "graph_id":
+ input_data[key] = (
+ uuid.UUID(value)
+ if not isinstance(value, uuid.UUID)
+ else value
+ )
+
+ if key in ["graph_creation_settings", "graph_enrichment_settings"]:
+ # Ensure we have a dict (if not already)
+ input_data[key] = (
+ json.loads(value) if not isinstance(value, dict) else value
+ )
+
+ if "generation_config" in input_data[key]:
+ if isinstance(input_data[key]["generation_config"], dict):
+ input_data[key]["generation_config"] = (
+ GenerationConfig(
+ **input_data[key]["generation_config"]
+ )
+ )
+ elif not isinstance(
+ input_data[key]["generation_config"], GenerationConfig
+ ):
+ input_data[key]["generation_config"] = (
+ GenerationConfig()
+ )
+
+ input_data[key]["generation_config"].model = (
+ input_data[key]["generation_config"].model
+ or service.config.app.fast_llm
+ )
+
+ return input_data
+
+ async def graph_extraction(input_data):
+ input_data = get_input_data_dict(input_data)
+
+ if input_data.get("document_id"):
+ document_ids = [input_data.get("document_id")]
+ else:
+ documents = []
+ collection_id = input_data.get("collection_id")
+ batch_size = 100
+ offset = 0
+ while True:
+ # Fetch current batch
+ batch = (
+ await service.providers.database.collections_handler.documents_in_collection(
+ collection_id=collection_id,
+ offset=offset,
+ limit=batch_size,
+ )
+ )["results"]
+
+ # If no documents returned, we've reached the end
+ if not batch:
+ break
+
+ # Add current batch to results
+ documents.extend(batch)
+
+ # Update offset for next batch
+ offset += batch_size
+
+ # Optional: If batch is smaller than batch_size, we've reached the end
+ if len(batch) < batch_size:
+ break
+
+ document_ids = [document.id for document in documents]
+
+ logger.info(
+ f"Creating graph for {len(document_ids)} documents with IDs: {document_ids}"
+ )
+
+ for _, document_id in enumerate(document_ids):
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=document_id,
+ status_type="extraction_status",
+ status=GraphExtractionStatus.PROCESSING,
+ )
+
+ # Extract relationships from the document
+ try:
+ extractions = []
+ async for (
+ extraction
+ ) in service.graph_search_results_extraction(
+ document_id=document_id,
+ **input_data["graph_creation_settings"],
+ ):
+ extractions.append(extraction)
+ await service.store_graph_search_results_extractions(
+ extractions
+ )
+
+ # Describe the entities in the graph
+ await service.graph_search_results_entity_description(
+ document_id=document_id,
+ **input_data["graph_creation_settings"],
+ )
+
+ if service.providers.database.config.graph_creation_settings.automatic_deduplication:
+ logger.warning(
+ "Automatic deduplication is not yet implemented for `simple` workflows."
+ )
+
+ except Exception as e:
+ logger.error(
+ f"Error in creating graph for document {document_id}: {e}"
+ )
+ raise e
+
+ async def graph_clustering(input_data):
+ input_data = get_input_data_dict(input_data)
+ workflow_status = await service.providers.database.documents_handler.get_workflow_status(
+ id=input_data.get("collection_id", None),
+ status_type="graph_cluster_status",
+ )
+ if workflow_status == GraphConstructionStatus.SUCCESS:
+ raise R2RException(
+ "Communities have already been built for this collection. To build communities again, first submit a POST request to `graphs/{collection_id}/reset` to erase the previously built communities.",
+ 400,
+ )
+
+ try:
+ num_communities = await service.graph_search_results_clustering(
+ collection_id=input_data.get("collection_id", None),
+ # graph_id=input_data.get("graph_id", None),
+ **input_data["graph_enrichment_settings"],
+ )
+ num_communities = num_communities["num_communities"][0]
+ # TODO - Do not hardcode the number of parallel communities,
+ # make it a configurable parameter at runtime & add server-side defaults
+
+ if num_communities == 0:
+ raise R2RException("No communities found", 400)
+
+ parallel_communities = min(100, num_communities)
+
+ total_workflows = math.ceil(num_communities / parallel_communities)
+ for i in range(total_workflows):
+ input_data_copy = input_data.copy()
+ input_data_copy["offset"] = i * parallel_communities
+ input_data_copy["limit"] = min(
+ parallel_communities,
+ num_communities - i * parallel_communities,
+ )
+
+ logger.info(
+ f"Running graph_search_results community summary for workflow {i + 1} of {total_workflows}"
+ )
+
+ await service.graph_search_results_community_summary(
+ offset=input_data_copy["offset"],
+ limit=input_data_copy["limit"],
+ collection_id=input_data_copy.get("collection_id", None),
+ # graph_id=input_data_copy.get("graph_id", None),
+ **input_data_copy["graph_enrichment_settings"],
+ )
+
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=input_data.get("collection_id", None),
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.SUCCESS,
+ )
+
+ except Exception as e:
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=input_data.get("collection_id", None),
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.FAILED,
+ )
+
+ raise e
+
+ async def graph_deduplication(input_data):
+ input_data = get_input_data_dict(input_data)
+ await service.deduplicate_document_entities(
+ document_id=input_data.get("document_id", None),
+ )
+
+ return {
+ "graph-extraction": graph_extraction,
+ "graph-clustering": graph_clustering,
+ "graph-deduplication": graph_deduplication,
+ }
diff --git a/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py
new file mode 100644
index 00000000..60a696c1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/orchestration/simple/ingestion_workflow.py
@@ -0,0 +1,598 @@
+import asyncio
+import logging
+from uuid import UUID
+
+import tiktoken
+from fastapi import HTTPException
+from litellm import AuthenticationError
+
+from core.base import (
+ DocumentChunk,
+ GraphConstructionStatus,
+ R2RException,
+ increment_version,
+)
+from core.utils import (
+ generate_default_user_collection_id,
+ generate_extraction_id,
+ update_settings_from_dict,
+)
+
+from ...services import IngestionService
+
+logger = logging.getLogger()
+
+
+# FIXME: No need to duplicate this function between the workflows, consolidate it into a shared module
+def count_tokens_for_text(text: str, model: str = "gpt-4o") -> int:
+ try:
+ encoding = tiktoken.encoding_for_model(model)
+ except KeyError:
+ # Fallback to a known encoding if model not recognized
+ encoding = tiktoken.get_encoding("cl100k_base")
+
+ return len(encoding.encode(text, disallowed_special=()))
+
+
+def simple_ingestion_factory(service: IngestionService):
+ async def ingest_files(input_data):
+ document_info = None
+ try:
+ from core.base import IngestionStatus
+ from core.main import IngestionServiceAdapter
+
+ parsed_data = IngestionServiceAdapter.parse_ingest_file_input(
+ input_data
+ )
+
+ document_info = service.create_document_info_from_file(
+ parsed_data["document_id"],
+ parsed_data["user"],
+ parsed_data["file_data"]["filename"],
+ parsed_data["metadata"],
+ parsed_data["version"],
+ parsed_data["size_in_bytes"],
+ )
+
+ await service.update_document_status(
+ document_info, status=IngestionStatus.PARSING
+ )
+
+            ingestion_config = parsed_data["ingestion_config"] or {}
+ extractions_generator = service.parse_file(
+ document_info=document_info,
+ ingestion_config=ingestion_config,
+ )
+ extractions = [
+ extraction.model_dump()
+ async for extraction in extractions_generator
+ ]
+
+ # 2) Sum tokens
+ total_tokens = 0
+ for chunk_dict in extractions:
+ text_data = chunk_dict["data"]
+ if not isinstance(text_data, str):
+ text_data = text_data.decode("utf-8", errors="ignore")
+ total_tokens += count_tokens_for_text(text_data)
+ document_info.total_tokens = total_tokens
+
+ if not ingestion_config.get("skip_document_summary", False):
+ await service.update_document_status(
+ document_info=document_info,
+ status=IngestionStatus.AUGMENTING,
+ )
+ await service.augment_document_info(document_info, extractions)
+
+ await service.update_document_status(
+ document_info, status=IngestionStatus.EMBEDDING
+ )
+ embedding_generator = service.embed_document(extractions)
+ embeddings = [
+ embedding.model_dump()
+ async for embedding in embedding_generator
+ ]
+
+ await service.update_document_status(
+ document_info, status=IngestionStatus.STORING
+ )
+ storage_generator = service.store_embeddings(embeddings)
+ async for _ in storage_generator:
+ pass
+
+ await service.finalize_ingestion(document_info)
+
+ await service.update_document_status(
+ document_info, status=IngestionStatus.SUCCESS
+ )
+
+ collection_ids = parsed_data.get("collection_ids")
+
+ try:
+ if not collection_ids:
+ # TODO: Move logic onto the `management service`
+ collection_id = generate_default_user_collection_id(
+ document_info.owner_id
+ )
+ await service.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+ )
+ else:
+ for collection_id in collection_ids:
+ try:
+ # FIXME: Right now we just throw a warning if the collection already exists, but we should probably handle this more gracefully
+ name = "My Collection"
+ description = f"A collection started during {document_info.title} ingestion"
+
+ await service.providers.database.collections_handler.create_collection(
+ owner_id=document_info.owner_id,
+ name=name,
+ description=description,
+ collection_id=collection_id,
+ )
+ await service.providers.database.graphs_handler.create(
+ collection_id=collection_id,
+ name=name,
+ description=description,
+ graph_id=collection_id,
+ )
+ except Exception as e:
+ logger.warning(
+ f"Warning, could not create collection with error: {str(e)}"
+ )
+
+ await service.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+
+ await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
+ )
+ except Exception as e:
+ logger.error(
+ f"Error during assigning document to collection: {str(e)}"
+ )
+
+ # Chunk enrichment
+ if server_chunk_enrichment_settings := getattr(
+ service.providers.ingestion.config,
+ "chunk_enrichment_settings",
+ None,
+ ):
+ chunk_enrichment_settings = update_settings_from_dict(
+ server_chunk_enrichment_settings,
+ ingestion_config.get("chunk_enrichment_settings", {})
+ or {},
+ )
+
+ if chunk_enrichment_settings.enable_chunk_enrichment:
+ logger.info("Enriching document with contextual chunks")
+
+ # Get updated document info with collection IDs
+ document_info = (
+ await service.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=100,
+ filter_user_ids=[document_info.owner_id],
+ filter_document_ids=[document_info.id],
+ )
+ )["results"][0]
+
+ await service.update_document_status(
+ document_info,
+ status=IngestionStatus.ENRICHING,
+ )
+
+ await service.chunk_enrichment(
+ document_id=document_info.id,
+ document_summary=document_info.summary,
+ chunk_enrichment_settings=chunk_enrichment_settings,
+ )
+
+ await service.update_document_status(
+ document_info,
+ status=IngestionStatus.SUCCESS,
+ )
+
+ # Automatic extraction
+ if service.providers.ingestion.config.automatic_extraction:
+ logger.warning(
+ "Automatic extraction not yet implemented for `simple` ingestion workflows."
+ )
+
+ except AuthenticationError as e:
+ if document_info is not None:
+ await service.update_document_status(
+ document_info,
+ status=IngestionStatus.FAILED,
+ metadata={"failure": f"{str(e)}"},
+ )
+ raise R2RException(
+ status_code=401,
+ message="Authentication error: Invalid API key or credentials.",
+ ) from e
+ except Exception as e:
+ if document_info is not None:
+ await service.update_document_status(
+ document_info,
+ status=IngestionStatus.FAILED,
+ metadata={"failure": f"{str(e)}"},
+ )
+ if isinstance(e, R2RException):
+ raise
+ raise HTTPException(
+ status_code=500, detail=f"Error during ingestion: {str(e)}"
+ ) from e
+
+ async def update_files(input_data):
+ from core.main import IngestionServiceAdapter
+
+ parsed_data = IngestionServiceAdapter.parse_update_files_input(
+ input_data
+ )
+
+ file_datas = parsed_data["file_datas"]
+ user = parsed_data["user"]
+ document_ids = parsed_data["document_ids"]
+ metadatas = parsed_data["metadatas"]
+ ingestion_config = parsed_data["ingestion_config"]
+ file_sizes_in_bytes = parsed_data["file_sizes_in_bytes"]
+
+ if not file_datas:
+ raise R2RException(
+ status_code=400, message="No files provided for update."
+ ) from None
+ if len(document_ids) != len(file_datas):
+ raise R2RException(
+ status_code=400,
+                message="Number of document IDs does not match the number of files.",
+ ) from None
+
+ documents_overview = (
+            await service.providers.database.documents_handler.get_documents_overview( # FIXME: this previously relied on default pagination; review whether offset=0 / limit=100 is the intended behavior
+ offset=0,
+ limit=100,
+ filter_user_ids=None if user.is_superuser else [user.id],
+ filter_document_ids=document_ids,
+ )
+ )["results"]
+
+ if len(documents_overview) != len(document_ids):
+ raise R2RException(
+ status_code=404,
+ message="One or more documents not found.",
+ ) from None
+
+ results = []
+
+ for idx, (
+ file_data,
+ doc_id,
+ doc_info,
+ file_size_in_bytes,
+ ) in enumerate(
+ zip(
+ file_datas,
+ document_ids,
+ documents_overview,
+ file_sizes_in_bytes,
+ strict=False,
+ )
+ ):
+ new_version = increment_version(doc_info.version)
+
+ updated_metadata = (
+ metadatas[idx] if metadatas else doc_info.metadata
+ )
+ updated_metadata["title"] = (
+ updated_metadata.get("title")
+ or file_data["filename"].split("/")[-1]
+ )
+
+ ingest_input = {
+ "file_data": file_data,
+ "user": user.model_dump(),
+ "metadata": updated_metadata,
+ "document_id": str(doc_id),
+ "version": new_version,
+ "ingestion_config": ingestion_config,
+ "size_in_bytes": file_size_in_bytes,
+ }
+
+ result = ingest_files(ingest_input)
+ results.append(result)
+
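+        # Each ingest_files(...) call above returned a coroutine; gather runs them concurrently.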
+ await asyncio.gather(*results)
+ if service.providers.ingestion.config.automatic_extraction:
+ raise R2RException(
+ status_code=501,
+ message="Automatic extraction not yet implemented for `simple` ingestion workflows.",
+ ) from None
+
+ async def ingest_chunks(input_data):
+ document_info = None
+ try:
+ from core.base import IngestionStatus
+ from core.main import IngestionServiceAdapter
+
+ parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input(
+ input_data
+ )
+
+ document_info = await service.ingest_chunks_ingress(**parsed_data)
+
+ await service.update_document_status(
+ document_info, status=IngestionStatus.EMBEDDING
+ )
+ document_id = document_info.id
+
+ extractions = [
+ DocumentChunk(
+ id=(
+ generate_extraction_id(document_id, i)
+ if chunk.id is None
+ else chunk.id
+ ),
+ document_id=document_id,
+ collection_ids=[],
+ owner_id=document_info.owner_id,
+ data=chunk.text,
+ metadata=parsed_data["metadata"],
+ ).model_dump()
+ for i, chunk in enumerate(parsed_data["chunks"])
+ ]
+
+ embedding_generator = service.embed_document(extractions)
+ embeddings = [
+ embedding.model_dump()
+ async for embedding in embedding_generator
+ ]
+
+ await service.update_document_status(
+ document_info, status=IngestionStatus.STORING
+ )
+ storage_generator = service.store_embeddings(embeddings)
+ async for _ in storage_generator:
+ pass
+
+ await service.finalize_ingestion(document_info)
+
+ await service.update_document_status(
+ document_info, status=IngestionStatus.SUCCESS
+ )
+
+ collection_ids = parsed_data.get("collection_ids")
+
+ try:
+ # TODO - Move logic onto management service
+ if not collection_ids:
+ collection_id = generate_default_user_collection_id(
+ document_info.owner_id
+ )
+
+ await service.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+
+ await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status",
+                        status=GraphConstructionStatus.OUTDATED, # NOTE: we should first check that a cluster has actually been built; if not, this should stay PENDING
+ )
+
+ else:
+ for collection_id in collection_ids:
+ try:
+ name = document_info.title or "N/A"
+ description = ""
+                            await service.providers.database.collections_handler.create_collection(
+ owner_id=document_info.owner_id,
+ name=name,
+ description=description,
+ collection_id=collection_id,
+ )
+ await service.providers.database.graphs_handler.create(
+ collection_id=collection_id,
+ name=name,
+ description=description,
+ graph_id=collection_id,
+ )
+ except Exception as e:
+ logger.warning(
+                                f"Could not create collection: {str(e)}"
+ )
+ await service.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id=document_info.id,
+ collection_id=collection_id,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ await service.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status",
+                            status=GraphConstructionStatus.OUTDATED, # NOTE: we should first check that a cluster has actually been built; if not, this should stay PENDING
+ )
+
+ if service.providers.ingestion.config.automatic_extraction:
+ raise R2RException(
+ status_code=501,
+ message="Automatic extraction not yet implemented for `simple` ingestion workflows.",
+ ) from None
+
+ except Exception as e:
+ logger.error(
+                    f"Error while assigning document to collection: {str(e)}"
+ )
+
+ except Exception as e:
+ if document_info is not None:
+ await service.update_document_status(
+ document_info,
+ status=IngestionStatus.FAILED,
+ metadata={"failure": f"{str(e)}"},
+ )
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error during chunk ingestion: {str(e)}",
+ ) from e
+
+ async def update_chunk(input_data):
+ from core.main import IngestionServiceAdapter
+
+ try:
+ parsed_data = IngestionServiceAdapter.parse_update_chunk_input(
+ input_data
+ )
+ document_uuid = (
+ UUID(parsed_data["document_id"])
+ if isinstance(parsed_data["document_id"], str)
+ else parsed_data["document_id"]
+ )
+ extraction_uuid = (
+ UUID(parsed_data["id"])
+ if isinstance(parsed_data["id"], str)
+ else parsed_data["id"]
+ )
+
+ await service.update_chunk_ingress(
+ document_id=document_uuid,
+ chunk_id=extraction_uuid,
+ text=parsed_data.get("text"),
+ user=parsed_data["user"],
+ metadata=parsed_data.get("metadata"),
+ collection_ids=parsed_data.get("collection_ids"),
+ )
+
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error during chunk update: {str(e)}",
+ ) from e
+
+ async def create_vector_index(input_data):
+ try:
+ from core.main import IngestionServiceAdapter
+
+ parsed_data = (
+ IngestionServiceAdapter.parse_create_vector_index_input(
+ input_data
+ )
+ )
+
+ await service.providers.database.chunks_handler.create_index(
+ **parsed_data
+ )
+
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error during vector index creation: {str(e)}",
+ ) from e
+
+ async def delete_vector_index(input_data):
+ try:
+ from core.main import IngestionServiceAdapter
+
+ parsed_data = (
+ IngestionServiceAdapter.parse_delete_vector_index_input(
+ input_data
+ )
+ )
+
+ await service.providers.database.chunks_handler.delete_index(
+ **parsed_data
+ )
+
+ return {"status": "Vector index deleted successfully."}
+
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error during vector index deletion: {str(e)}",
+ ) from e
+
+ async def update_document_metadata(input_data):
+ try:
+ from core.main import IngestionServiceAdapter
+
+ parsed_data = (
+ IngestionServiceAdapter.parse_update_document_metadata_input(
+ input_data
+ )
+ )
+
+ document_id = parsed_data["document_id"]
+ metadata = parsed_data["metadata"]
+ user = parsed_data["user"]
+
+ await service.update_document_metadata(
+ document_id=document_id,
+ metadata=metadata,
+ user=user,
+ )
+
+ return {
+ "message": "Document metadata update completed successfully.",
+ "document_id": str(document_id),
+ "task_id": None,
+ }
+
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Error during document metadata update: {str(e)}",
+ ) from e
+
+ return {
+ "ingest-files": ingest_files,
+ "ingest-chunks": ingest_chunks,
+ "update-chunk": update_chunk,
+ "update-document-metadata": update_document_metadata,
+ "create-vector-index": create_vector_index,
+ "delete-vector-index": delete_vector_index,
+ }
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/__init__.py b/.venv/lib/python3.12/site-packages/core/main/services/__init__.py
new file mode 100644
index 00000000..e6a6dec0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/__init__.py
@@ -0,0 +1,14 @@
+from .auth_service import AuthService
+from .graph_service import GraphService
+from .ingestion_service import IngestionService, IngestionServiceAdapter
+from .management_service import ManagementService
+from .retrieval_service import RetrievalService # type: ignore
+
+__all__ = [
+ "AuthService",
+ "IngestionService",
+ "IngestionServiceAdapter",
+ "ManagementService",
+ "GraphService",
+ "RetrievalService",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/auth_service.py b/.venv/lib/python3.12/site-packages/core/main/services/auth_service.py
new file mode 100644
index 00000000..c04dd78c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/auth_service.py
@@ -0,0 +1,316 @@
+import logging
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+from core.base import R2RException, Token
+from core.base.api.models import User
+from core.utils import generate_default_user_collection_id
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger()
+
+
+class AuthService(Service):
+ def __init__(
+ self,
+ config: R2RConfig,
+ providers: R2RProviders,
+ ):
+ super().__init__(
+ config,
+ providers,
+ )
+
+ async def register(
+ self,
+ email: str,
+ password: str,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ ) -> User:
+ return await self.providers.auth.register(
+ email=email,
+ password=password,
+ name=name,
+ bio=bio,
+ profile_picture=profile_picture,
+ )
+
+ async def send_verification_email(
+ self, email: str
+ ) -> tuple[str, datetime]:
+ return await self.providers.auth.send_verification_email(email=email)
+
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ if not self.config.auth.require_email_verification:
+ raise R2RException(
+ status_code=400, message="Email verification is not required"
+ )
+
+ user_id = await self.providers.database.users_handler.get_user_id_by_verification_code(
+ verification_code
+ )
+ user = await self.providers.database.users_handler.get_user_by_id(
+ user_id
+ )
+ if not user or user.email != email:
+ raise R2RException(
+ status_code=400, message="Invalid or expired verification code"
+ )
+
+ await self.providers.database.users_handler.mark_user_as_verified(
+ user_id
+ )
+ await self.providers.database.users_handler.remove_verification_code(
+ verification_code
+ )
+ return {"message": f"User account {user_id} verified successfully."}
+
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ return await self.providers.auth.login(email, password)
+
+ async def user(self, token: str) -> User:
+ token_data = await self.providers.auth.decode_token(token)
+ if not token_data.email:
+ raise R2RException(
+ status_code=401, message="Invalid authentication credentials"
+ )
+ user = await self.providers.database.users_handler.get_user_by_email(
+ token_data.email
+ )
+ if user is None:
+ raise R2RException(
+ status_code=401, message="Invalid authentication credentials"
+ )
+ return user
+
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ return await self.providers.auth.refresh_access_token(refresh_token)
+
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ if not user:
+ raise R2RException(status_code=404, message="User not found")
+ return await self.providers.auth.change_password(
+ user, current_password, new_password
+ )
+
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ return await self.providers.auth.request_password_reset(email)
+
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ return await self.providers.auth.confirm_password_reset(
+ reset_token, new_password
+ )
+
+ async def logout(self, token: str) -> dict[str, str]:
+ return await self.providers.auth.logout(token)
+
+ async def update_user(
+ self,
+ user_id: UUID,
+ email: Optional[str] = None,
+ is_superuser: Optional[bool] = None,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ limits_overrides: Optional[dict] = None,
+ merge_limits: bool = False,
+ new_metadata: Optional[dict] = None,
+ ) -> User:
+ user: User = (
+ await self.providers.database.users_handler.get_user_by_id(user_id)
+ )
+ if not user:
+ raise R2RException(status_code=404, message="User not found")
+ if email is not None:
+ user.email = email
+ if is_superuser is not None:
+ user.is_superuser = is_superuser
+ if name is not None:
+ user.name = name
+ if bio is not None:
+ user.bio = bio
+ if profile_picture is not None:
+ user.profile_picture = profile_picture
+ if limits_overrides is not None:
+ user.limits_overrides = limits_overrides
+ return await self.providers.database.users_handler.update_user(
+ user, merge_limits=merge_limits, new_metadata=new_metadata
+ )
+
+ async def delete_user(
+ self,
+ user_id: UUID,
+ password: Optional[str] = None,
+ delete_vector_data: bool = False,
+ is_superuser: bool = False,
+ ) -> dict[str, str]:
+ user = await self.providers.database.users_handler.get_user_by_id(
+ user_id
+ )
+ if not user:
+ raise R2RException(status_code=404, message="User not found")
+ if not is_superuser and not password:
+ raise R2RException(
+ status_code=422, message="Password is required for deletion"
+ )
+ if not (
+ is_superuser
+ or (
+ user.hashed_password is not None
+ and password is not None
+ and self.providers.auth.crypto_provider.verify_password(
+ plain_password=password,
+ hashed_password=user.hashed_password,
+ )
+ )
+ ):
+ raise R2RException(status_code=400, message="Incorrect password")
+ await self.providers.database.users_handler.delete_user_relational(
+ user_id
+ )
+
+ # Delete user's default collection
+ # TODO: We need to better define what happens to the user's data when they are deleted
+ collection_id = generate_default_user_collection_id(user_id)
+ await self.providers.database.collections_handler.delete_collection_relational(
+ collection_id
+ )
+
+ try:
+ await self.providers.database.graphs_handler.delete(
+ collection_id=collection_id,
+ )
+ except Exception as e:
+ logger.warning(
+ f"Error deleting graph for collection {collection_id}: {e}"
+ )
+
+ if delete_vector_data:
+ await self.providers.database.chunks_handler.delete_user_vector(
+ user_id
+ )
+ await self.providers.database.chunks_handler.delete_collection_vector(
+ collection_id
+ )
+
+ return {"message": f"User account {user_id} deleted successfully."}
+
+ async def clean_expired_blacklisted_tokens(
+ self,
+ max_age_hours: int = 7 * 24,
+ current_time: Optional[datetime] = None,
+ ):
+ await self.providers.database.token_handler.clean_expired_blacklisted_tokens(
+ max_age_hours, current_time
+ )
+
+ async def get_user_verification_code(
+ self,
+ user_id: UUID,
+ ) -> dict:
+ """Get only the verification code data for a specific user.
+
+ This method should be called after superuser authorization has been
+ verified.
+ """
+ verification_data = await self.providers.database.users_handler.get_user_validation_data(
+ user_id=user_id
+ )
+ return {
+ "verification_code": verification_data["verification_data"][
+ "verification_code"
+ ],
+ "expiry": verification_data["verification_data"][
+ "verification_code_expiry"
+ ],
+ }
+
+ async def get_user_reset_token(
+ self,
+ user_id: UUID,
+ ) -> dict:
+ """Get only the verification code data for a specific user.
+
+ This method should be called after superuser authorization has been
+ verified.
+ """
+ verification_data = await self.providers.database.users_handler.get_user_validation_data(
+ user_id=user_id
+ )
+ return {
+ "reset_token": verification_data["verification_data"][
+ "reset_token"
+ ],
+ "expiry": verification_data["verification_data"][
+ "reset_token_expiry"
+ ],
+ }
+
+ async def send_reset_email(self, email: str) -> dict:
+ """Generate a new verification code and send a reset email to the user.
+ Returns the verification code for testing/sandbox environments.
+
+ Args:
+ email (str): The email address of the user
+
+ Returns:
+ dict: Contains verification_code and message
+ """
+ return await self.providers.auth.send_reset_email(email)
+
+ async def create_user_api_key(
+ self, user_id: UUID, name: Optional[str], description: Optional[str]
+ ) -> dict:
+ """Generate a new API key for the user with optional name and
+ description.
+
+ Args:
+ user_id (UUID): The ID of the user
+ name (Optional[str]): Name of the API key
+ description (Optional[str]): Description of the API key
+
+ Returns:
+ dict: Contains the API key and message
+ """
+ return await self.providers.auth.create_user_api_key(
+ user_id=user_id, name=name, description=description
+ )
+
+ async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+ """Delete the API key for the user.
+
+ Args:
+ user_id (UUID): The ID of the user
+            key_id (UUID): The ID of the API key
+
+ Returns:
+ bool: True if the API key was deleted successfully
+ """
+ return await self.providers.auth.delete_user_api_key(
+ user_id=user_id, key_id=key_id
+ )
+
+ async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
+ """List all API keys for the user.
+
+ Args:
+ user_id (UUID): The ID of the user
+
+ Returns:
+            list[dict]: The list of API keys for the user
+ """
+ return await self.providers.auth.list_user_api_keys(user_id)
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/base.py b/.venv/lib/python3.12/site-packages/core/main/services/base.py
new file mode 100644
index 00000000..dcd98fd5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/base.py
@@ -0,0 +1,14 @@
+from abc import ABC
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+
+
+class Service(ABC):
+ def __init__(
+ self,
+ config: R2RConfig,
+ providers: R2RProviders,
+ ):
+ self.config = config
+ self.providers = providers
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/graph_service.py b/.venv/lib/python3.12/site-packages/core/main/services/graph_service.py
new file mode 100644
index 00000000..56f32cf8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/graph_service.py
@@ -0,0 +1,1358 @@
+import asyncio
+import logging
+import math
+import random
+import re
+import time
+import uuid
+import xml.etree.ElementTree as ET
+from typing import Any, AsyncGenerator, Coroutine, Optional
+from uuid import UUID
+from xml.etree.ElementTree import Element
+
+from core.base import (
+ DocumentChunk,
+ GraphExtraction,
+ GraphExtractionStatus,
+ R2RDocumentProcessingError,
+)
+from core.base.abstractions import (
+ Community,
+ Entity,
+ GenerationConfig,
+ GraphConstructionStatus,
+ R2RException,
+ Relationship,
+ StoreType,
+)
+from core.base.api.models import GraphResponse
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger()
+
+MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH = 128
+
+
+async def _collect_async_results(result_gen: AsyncGenerator) -> list[Any]:
+ """Collects all results from an async generator into a list."""
+ results = []
+ async for res in result_gen:
+ results.append(res)
+ return results
+
+
+class GraphService(Service):
+ def __init__(
+ self,
+ config: R2RConfig,
+ providers: R2RProviders,
+ ):
+ super().__init__(
+ config,
+ providers,
+ )
+
+ async def create_entity(
+ self,
+ name: str,
+ description: str,
+ parent_id: UUID,
+ category: Optional[str] = None,
+ metadata: Optional[dict] = None,
+ ) -> Entity:
+ description_embedding = str(
+ await self.providers.embedding.async_get_embedding(description)
+ )
+
+ return await self.providers.database.graphs_handler.entities.create(
+ name=name,
+ parent_id=parent_id,
+ store_type=StoreType.GRAPHS,
+ category=category,
+ description=description,
+ description_embedding=description_embedding,
+ metadata=metadata,
+ )
+
+ async def update_entity(
+ self,
+ entity_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ category: Optional[str] = None,
+ metadata: Optional[dict] = None,
+ ) -> Entity:
+ description_embedding = None
+ if description is not None:
+ description_embedding = str(
+ await self.providers.embedding.async_get_embedding(description)
+ )
+
+ return await self.providers.database.graphs_handler.entities.update(
+ entity_id=entity_id,
+ store_type=StoreType.GRAPHS,
+ name=name,
+ description=description,
+ description_embedding=description_embedding,
+ category=category,
+ metadata=metadata,
+ )
+
+ async def delete_entity(
+ self,
+ parent_id: UUID,
+ entity_id: UUID,
+ ):
+ return await self.providers.database.graphs_handler.entities.delete(
+ parent_id=parent_id,
+ entity_ids=[entity_id],
+ store_type=StoreType.GRAPHS,
+ )
+
+ async def get_entities(
+ self,
+ parent_id: UUID,
+ offset: int,
+ limit: int,
+ entity_ids: Optional[list[UUID]] = None,
+ entity_names: Optional[list[str]] = None,
+ include_embeddings: bool = False,
+ ):
+ return await self.providers.database.graphs_handler.get_entities(
+ parent_id=parent_id,
+ offset=offset,
+ limit=limit,
+ entity_ids=entity_ids,
+ entity_names=entity_names,
+ include_embeddings=include_embeddings,
+ )
+
+ async def create_relationship(
+ self,
+ subject: str,
+ subject_id: UUID,
+ predicate: str,
+ object: str,
+ object_id: UUID,
+ parent_id: UUID,
+ description: str | None = None,
+ weight: float | None = 1.0,
+ metadata: Optional[dict[str, Any] | str] = None,
+ ) -> Relationship:
+ description_embedding = None
+ if description:
+ description_embedding = str(
+ await self.providers.embedding.async_get_embedding(description)
+ )
+
+ return (
+ await self.providers.database.graphs_handler.relationships.create(
+ subject=subject,
+ subject_id=subject_id,
+ predicate=predicate,
+ object=object,
+ object_id=object_id,
+ parent_id=parent_id,
+ description=description,
+ description_embedding=description_embedding,
+ weight=weight,
+ metadata=metadata,
+ store_type=StoreType.GRAPHS,
+ )
+ )
+
+ async def delete_relationship(
+ self,
+ parent_id: UUID,
+ relationship_id: UUID,
+ ):
+ return (
+ await self.providers.database.graphs_handler.relationships.delete(
+ parent_id=parent_id,
+ relationship_ids=[relationship_id],
+ store_type=StoreType.GRAPHS,
+ )
+ )
+
+ async def update_relationship(
+ self,
+ relationship_id: UUID,
+ subject: Optional[str] = None,
+ subject_id: Optional[UUID] = None,
+ predicate: Optional[str] = None,
+ object: Optional[str] = None,
+ object_id: Optional[UUID] = None,
+ description: Optional[str] = None,
+ weight: Optional[float] = None,
+ metadata: Optional[dict[str, Any] | str] = None,
+ ) -> Relationship:
+ description_embedding = None
+ if description is not None:
+ description_embedding = str(
+ await self.providers.embedding.async_get_embedding(description)
+ )
+
+ return (
+ await self.providers.database.graphs_handler.relationships.update(
+ relationship_id=relationship_id,
+ subject=subject,
+ subject_id=subject_id,
+ predicate=predicate,
+ object=object,
+ object_id=object_id,
+ description=description,
+ description_embedding=description_embedding,
+ weight=weight,
+ metadata=metadata,
+ store_type=StoreType.GRAPHS,
+ )
+ )
+
+ async def get_relationships(
+ self,
+ parent_id: UUID,
+ offset: int,
+ limit: int,
+ relationship_ids: Optional[list[UUID]] = None,
+ entity_names: Optional[list[str]] = None,
+ ):
+ return await self.providers.database.graphs_handler.relationships.get(
+ parent_id=parent_id,
+ store_type=StoreType.GRAPHS,
+ offset=offset,
+ limit=limit,
+ relationship_ids=relationship_ids,
+ entity_names=entity_names,
+ )
+
+ async def create_community(
+ self,
+ parent_id: UUID,
+ name: str,
+ summary: str,
+ findings: Optional[list[str]],
+ rating: Optional[float],
+ rating_explanation: Optional[str],
+ ) -> Community:
+ description_embedding = str(
+ await self.providers.embedding.async_get_embedding(summary)
+ )
+ return await self.providers.database.graphs_handler.communities.create(
+ parent_id=parent_id,
+ store_type=StoreType.GRAPHS,
+ name=name,
+ summary=summary,
+ description_embedding=description_embedding,
+ findings=findings,
+ rating=rating,
+ rating_explanation=rating_explanation,
+ )
+
+ async def update_community(
+ self,
+ community_id: UUID,
+ name: Optional[str],
+ summary: Optional[str],
+ findings: Optional[list[str]],
+ rating: Optional[float],
+ rating_explanation: Optional[str],
+ ) -> Community:
+ summary_embedding = None
+ if summary is not None:
+ summary_embedding = str(
+ await self.providers.embedding.async_get_embedding(summary)
+ )
+
+ return await self.providers.database.graphs_handler.communities.update(
+ community_id=community_id,
+ store_type=StoreType.GRAPHS,
+ name=name,
+ summary=summary,
+ summary_embedding=summary_embedding,
+ findings=findings,
+ rating=rating,
+ rating_explanation=rating_explanation,
+ )
+
+ async def delete_community(
+ self,
+ parent_id: UUID,
+ community_id: UUID,
+ ) -> None:
+ await self.providers.database.graphs_handler.communities.delete(
+ parent_id=parent_id,
+ community_id=community_id,
+ )
+
+ async def get_communities(
+ self,
+ parent_id: UUID,
+ offset: int,
+ limit: int,
+ community_ids: Optional[list[UUID]] = None,
+ community_names: Optional[list[str]] = None,
+ include_embeddings: bool = False,
+ ):
+ return await self.providers.database.graphs_handler.get_communities(
+ parent_id=parent_id,
+ offset=offset,
+ limit=limit,
+ community_ids=community_ids,
+ include_embeddings=include_embeddings,
+ )
+
+ async def list_graphs(
+ self,
+ offset: int,
+ limit: int,
+ graph_ids: Optional[list[UUID]] = None,
+ collection_id: Optional[UUID] = None,
+ ) -> dict[str, list[GraphResponse] | int]:
+ return await self.providers.database.graphs_handler.list_graphs(
+ offset=offset,
+ limit=limit,
+ filter_graph_ids=graph_ids,
+ filter_collection_id=collection_id,
+ )
+
+ async def update_graph(
+ self,
+ collection_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> GraphResponse:
+ return await self.providers.database.graphs_handler.update(
+ collection_id=collection_id,
+ name=name,
+ description=description,
+ )
+
+ async def reset_graph(self, id: UUID) -> bool:
+ await self.providers.database.graphs_handler.reset(
+ parent_id=id,
+ )
+ await self.providers.database.documents_handler.set_workflow_status(
+ id=id,
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.PENDING,
+ )
+ return True
+
+ async def get_document_ids_for_create_graph(
+ self,
+ collection_id: UUID,
+ **kwargs,
+ ):
+ document_status_filter = [
+ GraphExtractionStatus.PENDING,
+ GraphExtractionStatus.FAILED,
+ ]
+
+ return await self.providers.database.documents_handler.get_document_ids_by_status(
+ status_type="extraction_status",
+ status=[str(ele) for ele in document_status_filter],
+ collection_id=collection_id,
+ )
+
+ async def graph_search_results_entity_description(
+ self,
+ document_id: UUID,
+ max_description_input_length: int,
+ batch_size: int = 256,
+ **kwargs,
+ ):
+ """A new implementation of the old GraphDescriptionPipe logic inline.
+ No references to pipe objects.
+
+ We:
+ 1) Count how many entities are in the document
+ 2) Process them in batches of `batch_size`
+ 3) For each batch, we retrieve the entity map and possibly call LLM for missing descriptions
+ """
+ start_time = time.time()
+ logger.info(
+ f"GraphService: Running graph_search_results_entity_description for doc={document_id}"
+ )
+
+ # Count how many doc-entities exist
+ entity_count = (
+ await self.providers.database.graphs_handler.get_entity_count(
+ document_id=document_id,
+ distinct=True,
+ entity_table_name="documents_entities", # or whichever table
+ )
+ )
+ logger.info(
+ f"GraphService: Found {entity_count} doc-entities to describe."
+ )
+
+ all_results = []
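+        # Each element of all_results is the list of entity names described in one batch.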
+ num_batches = math.ceil(entity_count / batch_size)
+
+ for i in range(num_batches):
+ offset = i * batch_size
+ limit = batch_size
+
+ logger.info(
+ f"GraphService: describing batch {i + 1}/{num_batches}, offset={offset}, limit={limit}"
+ )
+
+ # Actually handle describing the entities in the batch
+ # We'll collect them into a list via an async generator
+ gen = self._describe_entities_in_document_batch(
+ document_id=document_id,
+ offset=offset,
+ limit=limit,
+ max_description_input_length=max_description_input_length,
+ )
+ batch_results = await _collect_async_results(gen)
+ all_results.append(batch_results)
+
+ # Mark the doc's extraction status as success
+ await self.providers.database.documents_handler.set_workflow_status(
+ id=document_id,
+ status_type="extraction_status",
+ status=GraphExtractionStatus.SUCCESS,
+ )
+ logger.info(
+ f"GraphService: Completed graph_search_results_entity_description for doc {document_id} in {time.time() - start_time:.2f}s."
+ )
+ return all_results
+
+ async def _describe_entities_in_document_batch(
+ self,
+ document_id: UUID,
+ offset: int,
+ limit: int,
+ max_description_input_length: int,
+ ) -> AsyncGenerator[str, None]:
+ """Core logic that replaces GraphDescriptionPipe._run_logic for a
+ particular document/batch.
+
+        Yields the name of each entity as it is processed and updated.
+ """
+ start_time = time.time()
+ logger.info(
+ f"Started describing doc={document_id}, offset={offset}, limit={limit}"
+ )
+
+ # 1) Get the "entity map" from the DB
+ entity_map = (
+ await self.providers.database.graphs_handler.get_entity_map(
+ offset=offset, limit=limit, document_id=document_id
+ )
+ )
+ total_entities = len(entity_map)
+ logger.info(
+ f"_describe_entities_in_document_batch: got {total_entities} items in entity_map for doc={document_id}."
+ )
+
+ # 2) For each entity name in the map, we gather sub-entities and relationships
+ tasks: list[Coroutine[Any, Any, str]] = []
+ tasks.extend(
+ self._process_entity_for_description(
+ entities=[
+ entity if isinstance(entity, Entity) else Entity(**entity)
+ for entity in entity_info["entities"]
+ ],
+ relationships=[
+ rel
+ if isinstance(rel, Relationship)
+ else Relationship(**rel)
+ for rel in entity_info["relationships"]
+ ],
+ document_id=document_id,
+ max_description_input_length=max_description_input_length,
+ )
+ for entity_name, entity_info in entity_map.items()
+ )
+
+ # 3) Wait for all tasks, yield as they complete
+ idx = 0
+ for coro in asyncio.as_completed(tasks):
+ result = await coro
+ idx += 1
+ if idx % 100 == 0:
+ logger.info(
+ f"_describe_entities_in_document_batch: {idx}/{total_entities} described for doc={document_id}"
+ )
+ yield result
+
+ logger.info(
+ f"Finished describing doc={document_id} batch offset={offset} in {time.time() - start_time:.2f}s."
+ )
+
+ async def _process_entity_for_description(
+ self,
+ entities: list[Entity],
+ relationships: list[Relationship],
+ document_id: UUID,
+ max_description_input_length: int,
+ ) -> str:
+ """Adapted from the old process_entity function in
+ GraphDescriptionPipe.
+
+        If the entity has no description, call an LLM to create one, then store it.
+        Returns the name of the first entity in the group.
+ """
+
+ def truncate_info(info_list: list[str], max_length: int) -> str:
+ """Shuffles lines of info to try to keep them distinct, then
+ accumulates until hitting max_length."""
+ random.shuffle(info_list)
+ truncated_info = ""
+ current_length = 0
+ for info in info_list:
+ if current_length + len(info) > max_length:
+ break
+ truncated_info += info + "\n"
+ current_length += len(info)
+ return truncated_info
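+        # Illustration: truncate_info(lines, 1024) keeps shuffled whole lines until
+        # adding the next line would push the running total past 1024 characters.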
+
+ # Grab a doc-level summary (optional) to feed into the prompt
+ response = await self.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=1,
+ filter_document_ids=[document_id],
+ )
+ document_summary = (
+ response["results"][0].summary if response["results"] else None
+ )
+
+ # Synthesize a minimal “entity info” string + relationship summary
+ entity_info = [
+ f"{e.name}, {e.description or 'NONE'}" for e in entities
+ ]
+ relationships_txt = [
+ f"{i + 1}: {r.subject}, {r.object}, {r.predicate} - Summary: {r.description or ''}"
+ for i, r in enumerate(relationships)
+ ]
+
+        # For simplicity, describe only the first entity in the group; the others
+        # could be handled the same way if needed.
+ main_entity = entities[0]
+
+ if not main_entity.description:
+ # We only call LLM if the entity is missing a description
+ messages = await self.providers.database.prompts_handler.get_message_payload(
+ task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt,
+ task_inputs={
+ "document_summary": document_summary,
+ "entity_info": truncate_info(
+ entity_info, max_description_input_length
+ ),
+ "relationships_txt": truncate_info(
+ relationships_txt, max_description_input_length
+ ),
+ },
+ )
+
+ # Call the LLM
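+            # (Falls back to the app's fast_llm when graph_creation_settings provides no generation_config.)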
+ gen_config = (
+ self.providers.database.config.graph_creation_settings.generation_config
+ or GenerationConfig(model=self.config.app.fast_llm)
+ )
+ llm_resp = await self.providers.llm.aget_completion(
+ messages=messages,
+ generation_config=gen_config,
+ )
+ new_description = llm_resp.choices[0].message.content
+
+ if not new_description:
+ logger.error(
+ f"No LLM description returned for entity={main_entity.name}"
+ )
+ return main_entity.name
+
+ # create embedding
+ embed = (
+ await self.providers.embedding.async_get_embeddings(
+ [new_description]
+ )
+ )[0]
+
+ # update DB
+ main_entity.description = new_description
+ main_entity.description_embedding = embed
+
+            # Upsert the entity into the `documents_entities` table
+ await self.providers.database.graphs_handler.add_entities(
+ [main_entity],
+ table_name="documents_entities",
+ )
+
+ return main_entity.name
+
+ async def graph_search_results_clustering(
+ self,
+ collection_id: UUID,
+ generation_config: GenerationConfig,
+ leiden_params: dict,
+ **kwargs,
+ ):
+ """
+ Replacement for the old GraphClusteringPipe logic:
+ 1) call perform_graph_clustering on the DB
+ 2) return the result
+ """
+ logger.info(
+ f"Running inline clustering for collection={collection_id} with params={leiden_params}"
+ )
+ return await self._perform_graph_clustering(
+ collection_id=collection_id,
+ generation_config=generation_config,
+ leiden_params=leiden_params,
+ )
+
+ async def _perform_graph_clustering(
+ self,
+ collection_id: UUID,
+ generation_config: GenerationConfig,
+ leiden_params: dict,
+ ) -> dict:
+ """The actual clustering logic (previously in
+ GraphClusteringPipe.cluster_graph_search_results)."""
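+        # leiden_params is passed through to graphs_handler.perform_graph_clustering;
+        # the accepted keys (e.g. a Leiden resolution setting) are defined by that handler.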
+ num_communities = await self.providers.database.graphs_handler.perform_graph_clustering(
+ collection_id=collection_id,
+ leiden_params=leiden_params,
+ )
+ return {"num_communities": num_communities}
+
+ async def graph_search_results_community_summary(
+ self,
+ offset: int,
+ limit: int,
+ max_summary_input_length: int,
+ generation_config: GenerationConfig,
+ collection_id: UUID,
+ leiden_params: Optional[dict] = None,
+ **kwargs,
+ ):
+ """Replacement for the old GraphCommunitySummaryPipe logic.
+
+        Summarizes communities after clustering and returns the collected list of
+        summary results.
+ """
+ logger.info(
+ f"Running inline community summaries for coll={collection_id}, offset={offset}, limit={limit}"
+ )
+ # We call an internal function that yields summaries
+ gen = self._summarize_communities(
+ offset=offset,
+ limit=limit,
+ max_summary_input_length=max_summary_input_length,
+ generation_config=generation_config,
+ collection_id=collection_id,
+ leiden_params=leiden_params or {},
+ )
+ return await _collect_async_results(gen)
+
+ async def _summarize_communities(
+ self,
+ offset: int,
+ limit: int,
+ max_summary_input_length: int,
+ generation_config: GenerationConfig,
+ collection_id: UUID,
+ leiden_params: dict,
+ ) -> AsyncGenerator[dict, None]:
+ """Does the community summary logic from
+ GraphCommunitySummaryPipe._run_logic.
+
+ Yields each summary dictionary as it completes.
+ """
+ start_time = time.time()
+ logger.info(
+ f"Starting community summarization for collection={collection_id}"
+ )
+
+ # get all entities & relationships
+ (
+ all_entities,
+ _,
+ ) = await self.providers.database.graphs_handler.get_entities(
+ parent_id=collection_id,
+ offset=0,
+ limit=-1,
+ include_embeddings=False,
+ )
+ (
+ all_relationships,
+ _,
+ ) = await self.providers.database.graphs_handler.get_relationships(
+ parent_id=collection_id,
+ offset=0,
+ limit=-1,
+ include_embeddings=False,
+ )
+
+        # Re-run the clustering to produce fresh community assignments
+ (
+ _,
+ community_clusters,
+ ) = await self.providers.database.graphs_handler._cluster_and_add_community_info(
+ relationships=all_relationships,
+ leiden_params=leiden_params,
+ collection_id=collection_id,
+ )
+
+ # Group clusters
+ clusters: dict[Any, list[str]] = {}
+ for item in community_clusters:
+ cluster_id = item["cluster"]
+ node_name = item["node"]
+ clusters.setdefault(cluster_id, []).append(node_name)
+
+ # create an async job for each cluster
+ tasks: list[Coroutine[Any, Any, dict]] = []
+
+ tasks.extend(
+ self._process_community_summary(
+ community_id=uuid.uuid4(),
+ nodes=nodes,
+ all_entities=all_entities,
+ all_relationships=all_relationships,
+ max_summary_input_length=max_summary_input_length,
+ generation_config=generation_config,
+ collection_id=collection_id,
+ )
+ for nodes in clusters.values()
+ )
+
+ total_jobs = len(tasks)
+ results_returned = 0
+ total_errors = 0
+
+ for coro in asyncio.as_completed(tasks):
+ summary = await coro
+ results_returned += 1
+ if results_returned % 50 == 0:
+ logger.info(
+ f"Community summaries: {results_returned}/{total_jobs} done in {time.time() - start_time:.2f}s"
+ )
+ if "error" in summary:
+ total_errors += 1
+ yield summary
+
+ if total_errors > 0:
+ logger.warning(
+ f"{total_errors} communities failed summarization out of {total_jobs}"
+ )
+
+ async def _process_community_summary(
+ self,
+ community_id: UUID,
+ nodes: list[str],
+ all_entities: list[Entity],
+ all_relationships: list[Relationship],
+ max_summary_input_length: int,
+ generation_config: GenerationConfig,
+ collection_id: UUID,
+ ) -> dict:
+ """
+        Summarize a single community: gather the relevant entities/relationships, call the LLM
+        to generate an XML block, parse it, and store the result as a community in the DB.
+ """
+ # (Equivalent to process_community in old code)
+ # fetch the collection description (optional)
+ response = await self.providers.database.collections_handler.get_collections_overview(
+ offset=0,
+ limit=1,
+ filter_collection_ids=[collection_id],
+ )
+ collection_description = (
+ response["results"][0].description if response["results"] else None # type: ignore
+ )
+
+ # filter out relevant entities / relationships
+ entities = [e for e in all_entities if e.name in nodes]
+ relationships = [
+ r
+ for r in all_relationships
+ if r.subject in nodes and r.object in nodes
+ ]
+ if not entities and not relationships:
+ return {
+ "community_id": community_id,
+ "error": f"No data in this community (nodes={nodes})",
+ }
+
+ # Create the big input text for the LLM
+ input_text = await self._community_summary_prompt(
+ entities,
+ relationships,
+ max_summary_input_length,
+ )
+
+ # Attempt up to 3 times to parse
+ for attempt in range(3):
+ try:
+ # Build the prompt
+ messages = await self.providers.database.prompts_handler.get_message_payload(
+ task_prompt_name=self.providers.database.config.graph_enrichment_settings.graph_communities_prompt,
+ task_inputs={
+ "collection_description": collection_description,
+ "input_text": input_text,
+ },
+ )
+ llm_resp = await self.providers.llm.aget_completion(
+ messages=messages,
+ generation_config=generation_config,
+ )
+ llm_text = llm_resp.choices[0].message.content or ""
+
+ # find <community>...</community> XML
+ match = re.search(
+ r"<community>.*?</community>", llm_text, re.DOTALL
+ )
+ if not match:
+ raise ValueError(
+ "No <community> XML found in LLM response"
+ )
+
+ xml_content = match.group(0)
+ root = ET.fromstring(xml_content)
+
+ # extract fields
+ name_elem = root.find("name")
+ summary_elem = root.find("summary")
+ rating_elem = root.find("rating")
+ rating_expl_elem = root.find("rating_explanation")
+ findings_elem = root.find("findings")
+
+ name = name_elem.text if name_elem is not None else ""
+ summary = summary_elem.text if summary_elem is not None else ""
+ rating = (
+ float(rating_elem.text)
+ if isinstance(rating_elem, Element) and rating_elem.text
+                    else None
+ )
+ rating_explanation = (
+ rating_expl_elem.text
+ if rating_expl_elem is not None
+ else None
+ )
+ findings = (
+ [f.text for f in findings_elem.findall("finding")]
+ if findings_elem is not None
+ else []
+ )
+
+ # build embedding
+ embed_text = (
+ "Summary:\n"
+ + (summary or "")
+ + "\n\nFindings:\n"
+ + "\n".join(
+ finding for finding in findings if finding is not None
+ )
+ )
+ embedding = await self.providers.embedding.async_get_embedding(
+ embed_text
+ )
+
+ # build Community object
+ community = Community(
+ community_id=community_id,
+ collection_id=collection_id,
+ name=name,
+ summary=summary,
+ rating=rating,
+ rating_explanation=rating_explanation,
+ findings=findings,
+ description_embedding=embedding,
+ )
+
+ # store it
+ await self.providers.database.graphs_handler.add_community(
+ community
+ )
+
+ return {
+ "community_id": community_id,
+ "name": name,
+ }
+
+ except Exception as e:
+ logger.error(
+ f"Error summarizing community {community_id}: {e}"
+ )
+ if attempt == 2:
+ return {"community_id": community_id, "error": str(e)}
+ await asyncio.sleep(1)
+
+ # fallback
+ return {"community_id": community_id, "error": "Failed after retries"}
+
+ async def _community_summary_prompt(
+ self,
+ entities: list[Entity],
+ relationships: list[Relationship],
+ max_summary_input_length: int,
+ ) -> str:
+ """Gathers the entity/relationship text, tries not to exceed
+ `max_summary_input_length`."""
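+        # Each entity block below is rendered roughly as:
+        #   Entity: <name>
+        #   Descriptions:
+        #   <entity id>,<description>
+        #   Relationships:
+        #   <relationship id>,<subject>,<object>,<predicate>,<description>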
+ # Group them by entity.name
+ entity_map: dict[str, dict] = {}
+ for e in entities:
+ entity_map.setdefault(
+ e.name, {"entities": [], "relationships": []}
+ )
+ entity_map[e.name]["entities"].append(e)
+
+ for r in relationships:
+ # subject
+ entity_map.setdefault(
+ r.subject, {"entities": [], "relationships": []}
+ )
+ entity_map[r.subject]["relationships"].append(r)
+
+ # sort by # of relationships
+ sorted_entries = sorted(
+ entity_map.items(),
+ key=lambda x: len(x[1]["relationships"]),
+ reverse=True,
+ )
+
+ # build up the prompt text
+ prompt_chunks = []
+ cur_len = 0
+ for entity_name, data in sorted_entries:
+ block = f"\nEntity: {entity_name}\nDescriptions:\n"
+ block += "\n".join(
+ f"{e.id},{(e.description or '')}" for e in data["entities"]
+ )
+ block += "\nRelationships:\n"
+ block += "\n".join(
+ f"{r.id},{r.subject},{r.object},{r.predicate},{r.description or ''}"
+ for r in data["relationships"]
+ )
+ # check length
+ if cur_len + len(block) > max_summary_input_length:
+ prompt_chunks.append(
+ block[: max_summary_input_length - cur_len]
+ )
+ break
+ else:
+ prompt_chunks.append(block)
+ cur_len += len(block)
+
+ return "".join(prompt_chunks)
+
+ async def delete(
+ self,
+ collection_id: UUID,
+ **kwargs,
+ ):
+ return await self.providers.database.graphs_handler.delete(
+ collection_id=collection_id,
+ )
+
+ async def graph_search_results_extraction(
+ self,
+ document_id: UUID,
+ generation_config: GenerationConfig,
+ entity_types: list[str],
+ relation_types: list[str],
+ chunk_merge_count: int,
+ filter_out_existing_chunks: bool = True,
+ total_tasks: Optional[int] = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[GraphExtraction | R2RDocumentProcessingError, None]:
+ """The original “extract Graph from doc” logic, but inlined instead of
+ referencing a pipe."""
+ start_time = time.time()
+
+ logger.info(
+ f"Graph Extraction: Processing document {document_id} for graph extraction"
+ )
+
+ # Retrieve chunks from DB
+ chunks = []
+ limit = 100
+ offset = 0
+ while True:
+ chunk_req = await self.providers.database.chunks_handler.list_document_chunks(
+ document_id=document_id,
+ offset=offset,
+ limit=limit,
+ )
+ new_chunk_objs = [
+ DocumentChunk(
+ id=chunk["id"],
+ document_id=chunk["document_id"],
+ owner_id=chunk["owner_id"],
+ collection_ids=chunk["collection_ids"],
+ data=chunk["text"],
+ metadata=chunk["metadata"],
+ )
+ for chunk in chunk_req["results"]
+ ]
+ chunks.extend(new_chunk_objs)
+ if len(chunk_req["results"]) < limit:
+ break
+ offset += limit
+
+ if not chunks:
+ logger.info(f"No chunks found for document {document_id}")
+ raise R2RException(
+ message="No chunks found for document",
+ status_code=404,
+ )
+
+ # Possibly filter out any chunks that have already been processed
+ if filter_out_existing_chunks:
+ existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids(
+ document_id=document_id
+ )
+ before_count = len(chunks)
+ chunks = [c for c in chunks if c.id not in existing_chunk_ids]
+ logger.info(
+                f"Filtered out {before_count - len(chunks)} already-extracted chunks; {before_count} -> {len(chunks)} remain."
+ )
+ if not chunks:
+ return # nothing left to yield
+
+ # sort by chunk_order if present
+ chunks = sorted(
+ chunks,
+ key=lambda x: x.metadata.get("chunk_order", float("inf")),
+ )
+
+ # group them
+ grouped_chunks = [
+ chunks[i : i + chunk_merge_count]
+ for i in range(0, len(chunks), chunk_merge_count)
+ ]
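+        # e.g. with chunk_merge_count=4, chunks [c0..c9] are grouped as [[c0..c3], [c4..c7], [c8, c9]]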
+
+ logger.info(
+ f"Graph Extraction: Created {len(grouped_chunks)} tasks for doc={document_id}"
+ )
+ tasks = [
+ asyncio.create_task(
+ self._extract_graph_search_results_from_chunk_group(
+ chunk_group,
+ generation_config,
+ entity_types,
+ relation_types,
+ )
+ )
+ for chunk_group in grouped_chunks
+ ]
+
+ completed_tasks = 0
+ for t in asyncio.as_completed(tasks):
+ try:
+ yield await t
+ completed_tasks += 1
+ if completed_tasks % 100 == 0:
+ logger.info(
+ f"Graph Extraction: completed {completed_tasks}/{len(tasks)} tasks"
+ )
+ except Exception as e:
+ logger.error(f"Error extracting from chunk group: {e}")
+ yield R2RDocumentProcessingError(
+ document_id=document_id,
+ error_message=str(e),
+ )
+
+ logger.info(
+ f"Graph Extraction: done with {document_id}, time={time.time() - start_time:.2f}s"
+ )
+
+ async def _extract_graph_search_results_from_chunk_group(
+ self,
+ chunks: list[DocumentChunk],
+ generation_config: GenerationConfig,
+ entity_types: list[str],
+ relation_types: list[str],
+ retries: int = 5,
+ delay: int = 2,
+ ) -> GraphExtraction:
+ """(Equivalent to _extract_graph_search_results in old code.) Merges
+ chunk data, calls LLM, parses XML, returns GraphExtraction object."""
+ combined_extraction: str = " ".join(
+ [
+ c.data.decode("utf-8") if isinstance(c.data, bytes) else c.data
+ for c in chunks
+ if c.data
+ ]
+ )
+
+ # Possibly get doc-level summary
+ doc_id = chunks[0].document_id
+ response = await self.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=1,
+ filter_document_ids=[doc_id],
+ )
+ document_summary = (
+ response["results"][0].summary if response["results"] else None
+ )
+
+ # Build messages/prompt
+ prompt_name = self.providers.database.config.graph_creation_settings.graph_extraction_prompt
+ messages = (
+ await self.providers.database.prompts_handler.get_message_payload(
+ task_prompt_name=prompt_name,
+ task_inputs={
+ "document_summary": document_summary or "",
+ "input": combined_extraction,
+ "entity_types": "\n".join(entity_types),
+ "relation_types": "\n".join(relation_types),
+ },
+ )
+ )
+
+ for attempt in range(retries):
+ try:
+ resp = await self.providers.llm.aget_completion(
+ messages, generation_config=generation_config
+ )
+ graph_search_results_str = resp.choices[0].message.content
+
+ if not graph_search_results_str:
+ raise R2RException(
+ "No extraction found in LLM response.",
+ 400,
+ )
+
+ # parse the XML
+ (
+ entities,
+ relationships,
+ ) = await self._parse_graph_search_results_extraction_xml(
+ graph_search_results_str, chunks
+ )
+ return GraphExtraction(
+ entities=entities, relationships=relationships
+ )
+
+ except Exception as e:
+ if attempt < retries - 1:
+ await asyncio.sleep(delay)
+ continue
+ else:
+ logger.error(
+                        f"All extraction attempts for doc={doc_id} and chunks={[chunk.id for chunk in chunks]} failed with error:\n{e}"
+ )
+ return GraphExtraction(entities=[], relationships=[])
+
+ return GraphExtraction(entities=[], relationships=[])
+
+ async def _parse_graph_search_results_extraction_xml(
+ self, response_str: str, chunks: list[DocumentChunk]
+ ) -> tuple[list[Entity], list[Relationship]]:
+ """Helper to parse the LLM's XML format, handle edge cases/cleanup,
+ produce Entities/Relationships."""
+
+ def sanitize_xml(r: str) -> str:
+ # Remove markdown fences
+ r = re.sub(r"```xml|```", "", r)
+ # Remove xml instructions or userStyle
+ r = re.sub(r"<\?.*?\?>", "", r)
+ r = re.sub(r"<userStyle>.*?</userStyle>", "", r)
+ # Replace bare `&` with `&amp;`
+ r = re.sub(r"&(?!amp;|quot;|apos;|lt;|gt;)", "&amp;", r)
+ # Also remove <root> if it appears
+ r = r.replace("<root>", "").replace("</root>", "")
+ return r.strip()
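+        # Illustration: '```xml\n<entity name="A&B"/>\n```' sanitizes to '<entity name="A&amp;B"/>'.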
+
+ cleaned_xml = sanitize_xml(response_str)
+ wrapped = f"<root>{cleaned_xml}</root>"
+ try:
+ root = ET.fromstring(wrapped)
+ except ET.ParseError:
+ raise R2RException(
+ f"Failed to parse XML:\nData: {wrapped[:1000]}...", 400
+ ) from None
+
+ entities_elems = root.findall(".//entity")
+ if (
+ len(response_str) > MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH
+ and len(entities_elems) == 0
+ ):
+ raise R2RException(
+ f"No <entity> found in LLM XML, possibly malformed. Response excerpt: {response_str[:300]}",
+ 400,
+ )
+
+ # build entity objects
+ doc_id = chunks[0].document_id
+ chunk_ids = [c.id for c in chunks]
+ entities_list: list[Entity] = []
+ for element in entities_elems:
+ name_attr = element.get("name")
+ type_elem = element.find("type")
+ desc_elem = element.find("description")
+ category = type_elem.text if type_elem is not None else None
+ desc = desc_elem.text if desc_elem is not None else None
+ desc_embed = await self.providers.embedding.async_get_embedding(
+ desc or ""
+ )
+ ent = Entity(
+ category=category,
+ description=desc,
+ name=name_attr,
+ parent_id=doc_id,
+ chunk_ids=chunk_ids,
+ description_embedding=desc_embed,
+ attributes={},
+ )
+ entities_list.append(ent)
+
+ # build relationship objects
+ relationships_list: list[Relationship] = []
+ rel_elems = root.findall(".//relationship")
+ for r_elem in rel_elems:
+ source_elem = r_elem.find("source")
+ target_elem = r_elem.find("target")
+ type_elem = r_elem.find("type")
+ desc_elem = r_elem.find("description")
+ weight_elem = r_elem.find("weight")
+ try:
+ subject = source_elem.text if source_elem is not None else ""
+ object_ = target_elem.text if target_elem is not None else ""
+ predicate = type_elem.text if type_elem is not None else ""
+ desc = desc_elem.text if desc_elem is not None else ""
+ weight = (
+ float(weight_elem.text)
+ if isinstance(weight_elem, Element) and weight_elem.text
+                    else None
+ )
+ embed = await self.providers.embedding.async_get_embedding(
+ desc or ""
+ )
+
+ rel = Relationship(
+ subject=subject,
+ predicate=predicate,
+ object=object_,
+ description=desc,
+ weight=weight,
+ parent_id=doc_id,
+ chunk_ids=chunk_ids,
+ attributes={},
+ description_embedding=embed,
+ )
+ relationships_list.append(rel)
+ except Exception:
+ continue
+ return entities_list, relationships_list
+
+ async def store_graph_search_results_extractions(
+ self,
+ graph_search_results_extractions: list[GraphExtraction],
+ ):
+ """Stores a batch of knowledge graph extractions in the DB."""
+ for extraction in graph_search_results_extractions:
+ # Map name->id after creation
+ entities_id_map = {}
+ for e in extraction.entities:
+ if e.parent_id is not None:
+ result = await self.providers.database.graphs_handler.entities.create(
+ name=e.name,
+ parent_id=e.parent_id,
+ store_type=StoreType.DOCUMENTS,
+ category=e.category,
+ description=e.description,
+ description_embedding=e.description_embedding,
+ chunk_ids=e.chunk_ids,
+ metadata=e.metadata,
+ )
+ entities_id_map[e.name] = result.id
+ else:
+ logger.warning(f"Skipping entity with None parent_id: {e}")
+
+ # Insert relationships
+ for rel in extraction.relationships:
+ subject_id = entities_id_map.get(rel.subject)
+ object_id = entities_id_map.get(rel.object)
+ parent_id = rel.parent_id
+
+ if any(
+ id is None for id in (subject_id, object_id, parent_id)
+ ):
+ logger.warning(f"Missing ID for relationship: {rel}")
+ continue
+
+ assert isinstance(subject_id, UUID)
+ assert isinstance(object_id, UUID)
+ assert isinstance(parent_id, UUID)
+
+ await self.providers.database.graphs_handler.relationships.create(
+ subject=rel.subject,
+ subject_id=subject_id,
+ predicate=rel.predicate,
+ object=rel.object,
+ object_id=object_id,
+ parent_id=parent_id,
+ description=rel.description,
+ description_embedding=rel.description_embedding,
+ weight=rel.weight,
+ metadata=rel.metadata,
+ store_type=StoreType.DOCUMENTS,
+ )
+
+ async def deduplicate_document_entities(
+ self,
+ document_id: UUID,
+ ):
+ """
+ Inlined from old code: merges duplicates by name, calls LLM for a new consolidated description, updates the record.
+ """
+ merged_results = await self.providers.database.entities_handler.merge_duplicate_name_blocks(
+ parent_id=document_id,
+ store_type=StoreType.DOCUMENTS,
+ )
+
+ # Grab doc summary
+ response = await self.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=1,
+ filter_document_ids=[document_id],
+ )
+ document_summary = (
+ response["results"][0].summary if response["results"] else None
+ )
+
+ # For each merged entity
+ for original_entities, merged_entity in merged_results:
+ # Summarize them with LLM
+ entity_info = "\n".join(
+ e.description for e in original_entities if e.description
+ )
+ messages = await self.providers.database.prompts_handler.get_message_payload(
+ task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt,
+ task_inputs={
+ "document_summary": document_summary,
+ "entity_info": f"{merged_entity.name}\n{entity_info}",
+ "relationships_txt": "",
+ },
+ )
+ gen_config = (
+ self.config.database.graph_creation_settings.generation_config
+ or GenerationConfig(model=self.config.app.fast_llm)
+ )
+ resp = await self.providers.llm.aget_completion(
+ messages, generation_config=gen_config
+ )
+ new_description = resp.choices[0].message.content
+
+ new_embedding = await self.providers.embedding.async_get_embedding(
+ new_description or ""
+ )
+
+ if merged_entity.id is not None:
+ await self.providers.database.graphs_handler.entities.update(
+ entity_id=merged_entity.id,
+ store_type=StoreType.DOCUMENTS,
+ description=new_description,
+ description_embedding=str(new_embedding),
+ )
+ else:
+ logger.warning("Skipping update for entity with None id")
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/ingestion_service.py b/.venv/lib/python3.12/site-packages/core/main/services/ingestion_service.py
new file mode 100644
index 00000000..55b06911
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/ingestion_service.py
@@ -0,0 +1,983 @@
+import asyncio
+import json
+import logging
+from datetime import datetime
+from typing import Any, AsyncGenerator, Optional, Sequence
+from uuid import UUID
+
+from fastapi import HTTPException
+
+from core.base import (
+ Document,
+ DocumentChunk,
+ DocumentResponse,
+ DocumentType,
+ GenerationConfig,
+ IngestionStatus,
+ R2RException,
+ RawChunk,
+ UnprocessedChunk,
+ Vector,
+ VectorEntry,
+ VectorType,
+ generate_id,
+)
+from core.base.abstractions import (
+ ChunkEnrichmentSettings,
+ IndexMeasure,
+ IndexMethod,
+ R2RDocumentProcessingError,
+ VectorTableName,
+)
+from core.base.api.models import User
+from shared.abstractions import PDFParsingError, PopplerNotFoundError
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+
+logger = logging.getLogger()
+STARTING_VERSION = "v0"
+
+
+class IngestionService:
+ """A refactored IngestionService that inlines all pipe logic for parsing,
+ embedding, and vector storage directly in its methods."""
+
+ def __init__(
+ self,
+ config: R2RConfig,
+ providers: R2RProviders,
+ ) -> None:
+ self.config = config
+ self.providers = providers
+
+ async def ingest_file_ingress(
+ self,
+ file_data: dict,
+ user: User,
+ document_id: UUID,
+        size_in_bytes: int,
+ metadata: Optional[dict] = None,
+ version: Optional[str] = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> dict:
+ """Pre-ingests a file by creating or validating the DocumentResponse
+ entry.
+
+ Does not actually parse/ingest the content. (See parse_file() for that
+ step.)
+ """
+ try:
+ if not file_data:
+ raise R2RException(
+ status_code=400, message="No files provided for ingestion."
+ )
+ if not file_data.get("filename"):
+ raise R2RException(
+ status_code=400, message="File name not provided."
+ )
+
+ metadata = metadata or {}
+ version = version or STARTING_VERSION
+
+ document_info = self.create_document_info_from_file(
+ document_id,
+ user,
+ file_data["filename"],
+ metadata,
+ version,
+ size_in_bytes,
+ )
+
+ existing_document_info = (
+ await self.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=100,
+ filter_user_ids=[user.id],
+ filter_document_ids=[document_id],
+ )
+ )["results"]
+
+ # Validate ingestion status for re-ingestion
+ if len(existing_document_info) > 0:
+ existing_doc = existing_document_info[0]
+ if existing_doc.ingestion_status == IngestionStatus.SUCCESS:
+ raise R2RException(
+ status_code=409,
+ message=(
+ f"Document {document_id} already exists. "
+ "Submit a DELETE request to `/documents/{document_id}` "
+ "to delete this document and allow for re-ingestion."
+ ),
+ )
+ elif existing_doc.ingestion_status != IngestionStatus.FAILED:
+ raise R2RException(
+ status_code=409,
+ message=(
+ f"Document {document_id} is currently ingesting "
+ f"with status {existing_doc.ingestion_status}."
+ ),
+ )
+
+ # Set to PARSING until we actually parse
+ document_info.ingestion_status = IngestionStatus.PARSING
+ await self.providers.database.documents_handler.upsert_documents_overview(
+ document_info
+ )
+
+ return {
+ "info": document_info,
+ }
+ except R2RException as e:
+ logger.error(f"R2RException in ingest_file_ingress: {str(e)}")
+ raise
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Error during ingestion: {str(e)}"
+ ) from e
+
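+    # Illustrative usage sketch (comment only, not executed; `service`, `user`,
+    # and `document_id` are assumed to be supplied by the calling workflow):
+    #
+    #     ingress = await service.ingest_file_ingress(
+    #         file_data={"filename": "report.pdf"},
+    #         user=user,
+    #         document_id=document_id,
+    #         size_in_bytes=1_048_576,
+    #     )
+    #     document_info = ingress["info"]
+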
+ def create_document_info_from_file(
+ self,
+ document_id: UUID,
+ user: User,
+ file_name: str,
+ metadata: dict,
+ version: str,
+ size_in_bytes: int,
+ ) -> DocumentResponse:
+ file_extension = (
+ file_name.split(".")[-1].lower() if file_name != "N/A" else "txt"
+ )
+ if file_extension.upper() not in DocumentType.__members__:
+ raise R2RException(
+ status_code=415,
+ message=f"'{file_extension}' is not a valid DocumentType.",
+ )
+
+ metadata = metadata or {}
+ metadata["version"] = version
+
+ return DocumentResponse(
+ id=document_id,
+ owner_id=user.id,
+ collection_ids=metadata.get("collection_ids", []),
+ document_type=DocumentType[file_extension.upper()],
+ title=(
+ metadata.get("title", file_name.split("/")[-1])
+ if file_name != "N/A"
+ else "N/A"
+ ),
+ metadata=metadata,
+ version=version,
+ size_in_bytes=size_in_bytes,
+ ingestion_status=IngestionStatus.PENDING,
+ created_at=datetime.now(),
+ updated_at=datetime.now(),
+ )
+
+ def _create_document_info_from_chunks(
+ self,
+ document_id: UUID,
+ user: User,
+ chunks: list[RawChunk],
+ metadata: dict,
+ version: str,
+ ) -> DocumentResponse:
+ metadata = metadata or {}
+ metadata["version"] = version
+
+ return DocumentResponse(
+ id=document_id,
+ owner_id=user.id,
+ collection_ids=metadata.get("collection_ids", []),
+ document_type=DocumentType.TXT,
+ title=metadata.get("title", f"Ingested Chunks - {document_id}"),
+ metadata=metadata,
+ version=version,
+ size_in_bytes=sum(
+ len(chunk.text.encode("utf-8")) for chunk in chunks
+ ),
+ ingestion_status=IngestionStatus.PENDING,
+ created_at=datetime.now(),
+ updated_at=datetime.now(),
+ )
+
+ async def parse_file(
+ self,
+ document_info: DocumentResponse,
+ ingestion_config: dict | None,
+ ) -> AsyncGenerator[DocumentChunk, None]:
+ """Reads the file content from the DB, calls the ingestion
+ provider to parse, and yields DocumentChunk objects."""
+ version = document_info.version or "v0"
+ ingestion_config_override = ingestion_config or {}
+
+ # The ingestion config might specify a different provider, etc.
+ override_provider = ingestion_config_override.pop("provider", None)
+ if (
+ override_provider
+ and override_provider != self.providers.ingestion.config.provider
+ ):
+ raise ValueError(
+ f"Provider '{override_provider}' does not match ingestion provider "
+ f"'{self.providers.ingestion.config.provider}'."
+ )
+
+ try:
+ # Pull file from DB
+ retrieved = (
+ await self.providers.database.files_handler.retrieve_file(
+ document_info.id
+ )
+ )
+ if not retrieved:
+ # No file found in the DB, can't parse
+ raise R2RDocumentProcessingError(
+ document_id=document_info.id,
+ error_message="No file content found in DB for this document.",
+ )
+
+ file_name, file_wrapper, file_size = retrieved
+
+ # Read the content
+ with file_wrapper as file_content_stream:
+ file_content = file_content_stream.read()
+
+ # Build a barebones Document object
+ doc = Document(
+ id=document_info.id,
+ collection_ids=document_info.collection_ids,
+ owner_id=document_info.owner_id,
+ metadata={
+ "document_type": document_info.document_type.value,
+ **document_info.metadata,
+ },
+ document_type=document_info.document_type,
+ )
+
+ # Delegate to the ingestion provider to parse
+ async for extraction in self.providers.ingestion.parse(
+ file_content, # raw bytes
+ doc,
+ ingestion_config_override,
+ ):
+ # Adjust chunk ID to incorporate version
+ # or any other needed transformations
+ extraction.id = generate_id(f"{extraction.id}_{version}")
+ extraction.metadata["version"] = version
+ yield extraction
+
+ except (PopplerNotFoundError, PDFParsingError) as e:
+ raise R2RDocumentProcessingError(
+ error_message=e.message,
+ document_id=document_info.id,
+ status_code=e.status_code,
+ ) from None
+ except Exception as e:
+ if isinstance(e, R2RException):
+ raise
+ raise R2RDocumentProcessingError(
+ document_id=document_info.id,
+ error_message=f"Error parsing document: {str(e)}",
+ ) from e
+
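+    # Illustrative sketch (comment only, not executed): parse_file is an async
+    # generator, so callers are expected to iterate it, e.g.
+    #
+    #     chunks = []
+    #     async for chunk in service.parse_file(document_info, ingestion_config=None):
+    #         chunks.append(chunk)
+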
+ async def augment_document_info(
+ self,
+ document_info: DocumentResponse,
+ chunked_documents: list[dict],
+ ) -> None:
+ if not self.config.ingestion.skip_document_summary:
+ document = f"Document Title: {document_info.title}\n"
+ if document_info.metadata != {}:
+ document += f"Document Metadata: {json.dumps(document_info.metadata)}\n"
+
+ document += "Document Text:\n"
+ for chunk in chunked_documents[
+ : self.config.ingestion.chunks_for_document_summary
+ ]:
+ document += chunk["data"]
+
+ messages = await self.providers.database.prompts_handler.get_message_payload(
+ system_prompt_name=self.config.ingestion.document_summary_system_prompt,
+ task_prompt_name=self.config.ingestion.document_summary_task_prompt,
+ task_inputs={
+ "document": document[
+ : self.config.ingestion.document_summary_max_length
+ ]
+ },
+ )
+
+ response = await self.providers.llm.aget_completion(
+ messages=messages,
+ generation_config=GenerationConfig(
+ model=self.config.ingestion.document_summary_model
+ or self.config.app.fast_llm
+ ),
+ )
+
+ document_info.summary = response.choices[0].message.content # type: ignore
+
+ if not document_info.summary:
+ raise ValueError("Expected a generated response.")
+
+ embedding = await self.providers.embedding.async_get_embedding(
+ text=document_info.summary,
+ )
+ document_info.summary_embedding = embedding
+ return
+
+ async def embed_document(
+ self,
+ chunked_documents: list[dict],
+ embedding_batch_size: int = 8,
+ ) -> AsyncGenerator[VectorEntry, None]:
+ """Inline replacement for the old embedding_pipe.run(...).
+
+ Batches the embedding calls and yields VectorEntry objects.
+ """
+ if not chunked_documents:
+ return
+
+ concurrency_limit = (
+ self.providers.embedding.config.concurrent_request_limit or 5
+ )
+ extraction_batch: list[DocumentChunk] = []
+ tasks: set[asyncio.Task] = set()
+
+ async def process_batch(
+ batch: list[DocumentChunk],
+ ) -> list[VectorEntry]:
+ # All text from the batch
+ texts = [
+ (
+ ex.data.decode("utf-8")
+ if isinstance(ex.data, bytes)
+ else ex.data
+ )
+ for ex in batch
+ ]
+ # Retrieve embeddings in bulk
+ vectors = await self.providers.embedding.async_get_embeddings(
+ texts, # list of strings
+ )
+ # Zip them back together
+ results = []
+ for raw_vector, extraction in zip(vectors, batch, strict=False):
+ results.append(
+ VectorEntry(
+ id=extraction.id,
+ document_id=extraction.document_id,
+ owner_id=extraction.owner_id,
+ collection_ids=extraction.collection_ids,
+ vector=Vector(data=raw_vector, type=VectorType.FIXED),
+ text=(
+ extraction.data.decode("utf-8")
+ if isinstance(extraction.data, bytes)
+ else str(extraction.data)
+ ),
+ metadata={**extraction.metadata},
+ )
+ )
+ return results
+
+ async def run_process_batch(batch: list[DocumentChunk]):
+ return await process_batch(batch)
+
+ # Convert each chunk dict to a DocumentChunk
+ for chunk_dict in chunked_documents:
+ extraction = DocumentChunk.from_dict(chunk_dict)
+ extraction_batch.append(extraction)
+
+ # If we hit a batch threshold, spawn a task
+ if len(extraction_batch) >= embedding_batch_size:
+ tasks.add(
+ asyncio.create_task(run_process_batch(extraction_batch))
+ )
+ extraction_batch = []
+
+ # If tasks are at concurrency limit, wait for the first to finish
+ while len(tasks) >= concurrency_limit:
+ done, tasks = await asyncio.wait(
+ tasks, return_when=asyncio.FIRST_COMPLETED
+ )
+ for t in done:
+ for vector_entry in await t:
+ yield vector_entry
+
+ # Handle any leftover items
+ if extraction_batch:
+ tasks.add(asyncio.create_task(run_process_batch(extraction_batch)))
+
+ # Gather remaining tasks
+ for future_task in asyncio.as_completed(tasks):
+ for vector_entry in await future_task:
+ yield vector_entry
+
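+    # Illustrative sketch (comment only, not executed; the chunk dicts are
+    # assumed to be DocumentChunk.model_dump() payloads produced by parse_file):
+    #
+    #     vector_entries = [
+    #         entry
+    #         async for entry in service.embed_document(
+    #             chunked_documents=[c.model_dump() for c in chunks],
+    #             embedding_batch_size=8,
+    #         )
+    #     ]
+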
+ async def store_embeddings(
+ self,
+ embeddings: Sequence[dict | VectorEntry],
+ storage_batch_size: int = 128,
+ ) -> AsyncGenerator[str, None]:
+ """Inline replacement for the old vector_storage_pipe.run(...).
+
+        Batches the vector entries, enforces per-user chunk limits, stores
+        them, and yields a status string for each stored batch and document.
+ """
+ if not embeddings:
+ return
+
+ vector_entries: list[VectorEntry] = []
+ for item in embeddings:
+ if isinstance(item, VectorEntry):
+ vector_entries.append(item)
+ else:
+ vector_entries.append(VectorEntry.from_dict(item))
+
+ vector_batch: list[VectorEntry] = []
+ document_counts: dict[UUID, int] = {}
+
+        # Usage is tracked from the first owner encountered; if a single
+        # ingestion can contain chunks owned by multiple users, these usage
+        # checks would need to be refined.
+ current_usage = None
+ user_id_for_usage_check: UUID | None = None
+
+ count = 0
+
+ for msg in vector_entries:
+ # If we haven't set usage yet, do so on the first chunk
+ if current_usage is None:
+ user_id_for_usage_check = msg.owner_id
+ usage_data = (
+ await self.providers.database.chunks_handler.list_chunks(
+ limit=1,
+ offset=0,
+ filters={"owner_id": msg.owner_id},
+ )
+ )
+ current_usage = usage_data["total_entries"]
+
+ # Figure out the user's limit
+ user = await self.providers.database.users_handler.get_user_by_id(
+ msg.owner_id
+ )
+ max_chunks = (
+ self.providers.database.config.app.default_max_chunks_per_user
+ )
+ if user.limits_overrides and "max_chunks" in user.limits_overrides:
+ max_chunks = user.limits_overrides["max_chunks"]
+
+ # Add to our local batch
+ vector_batch.append(msg)
+ document_counts[msg.document_id] = (
+ document_counts.get(msg.document_id, 0) + 1
+ )
+ count += 1
+
+ # Check usage
+ if (
+ current_usage is not None
+ and (current_usage + len(vector_batch) + count) > max_chunks
+ ):
+ error_message = f"User {msg.owner_id} has exceeded the maximum number of allowed chunks: {max_chunks}"
+ logger.error(error_message)
+ yield error_message
+ continue
+
+ # Once we hit our batch size, store them
+ if len(vector_batch) >= storage_batch_size:
+ try:
+ await (
+ self.providers.database.chunks_handler.upsert_entries(
+ vector_batch
+ )
+ )
+ except Exception as e:
+ logger.error(f"Failed to store vector batch: {e}")
+ yield f"Error: {e}"
+ vector_batch.clear()
+
+ # Store any leftover items
+ if vector_batch:
+ try:
+ await self.providers.database.chunks_handler.upsert_entries(
+ vector_batch
+ )
+ except Exception as e:
+ logger.error(f"Failed to store final vector batch: {e}")
+ yield f"Error: {e}"
+
+ # Summaries
+ for doc_id, cnt in document_counts.items():
+ info_msg = f"Successful ingestion for document_id: {doc_id}, with vector count: {cnt}"
+ logger.info(info_msg)
+ yield info_msg
+
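+    # Illustrative sketch (comment only, not executed): store_embeddings yields
+    # one status string per stored batch/document, so callers typically just
+    # drain the generator, e.g.
+    #
+    #     async for status in service.store_embeddings(vector_entries):
+    #         logger.debug(status)
+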
+ async def finalize_ingestion(
+ self, document_info: DocumentResponse
+ ) -> None:
+ """Called at the end of a successful ingestion pipeline to set the
+ document status to SUCCESS or similar final steps."""
+
+ async def empty_generator():
+ yield document_info
+
+ await self.update_document_status(
+ document_info, IngestionStatus.SUCCESS
+ )
+ return empty_generator()
+
+ async def update_document_status(
+ self,
+ document_info: DocumentResponse,
+ status: IngestionStatus,
+ metadata: Optional[dict] = None,
+ ) -> None:
+ document_info.ingestion_status = status
+ if metadata:
+ document_info.metadata = {**document_info.metadata, **metadata}
+ await self._update_document_status_in_db(document_info)
+
+ async def _update_document_status_in_db(
+ self, document_info: DocumentResponse
+ ):
+ try:
+ await self.providers.database.documents_handler.upsert_documents_overview(
+ document_info
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to update document status: {document_info.id}. Error: {str(e)}"
+ )
+
+ async def ingest_chunks_ingress(
+ self,
+ document_id: UUID,
+ metadata: Optional[dict],
+ chunks: list[RawChunk],
+ user: User,
+ *args: Any,
+ **kwargs: Any,
+ ) -> DocumentResponse:
+ """Directly ingest user-provided text chunks (rather than from a
+ file)."""
+ if not chunks:
+ raise R2RException(
+ status_code=400, message="No chunks provided for ingestion."
+ )
+ metadata = metadata or {}
+ version = STARTING_VERSION
+
+ document_info = self._create_document_info_from_chunks(
+ document_id,
+ user,
+ chunks,
+ metadata,
+ version,
+ )
+
+ existing_document_info = (
+ await self.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=100,
+ filter_user_ids=[user.id],
+ filter_document_ids=[document_id],
+ )
+ )["results"]
+ if len(existing_document_info) > 0:
+ existing_doc = existing_document_info[0]
+ if existing_doc.ingestion_status != IngestionStatus.FAILED:
+ raise R2RException(
+ status_code=409,
+ message=(
+ f"Document {document_id} was already ingested "
+ "and is not in a failed state."
+ ),
+ )
+
+ await self.providers.database.documents_handler.upsert_documents_overview(
+ document_info
+ )
+ return document_info
+
+ async def update_chunk_ingress(
+ self,
+ document_id: UUID,
+ chunk_id: UUID,
+ text: str,
+ user: User,
+ metadata: Optional[dict] = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> dict:
+ """Update an individual chunk's text and metadata, re-embed, and re-
+ store it."""
+ # Verify chunk exists and user has access
+ existing_chunks = (
+ await self.providers.database.chunks_handler.list_document_chunks(
+ document_id=document_id,
+ offset=0,
+ limit=1,
+ )
+ )
+ if not existing_chunks["results"]:
+ raise R2RException(
+ status_code=404,
+ message=f"Chunk with chunk_id {chunk_id} not found.",
+ )
+
+ existing_chunk = (
+ await self.providers.database.chunks_handler.get_chunk(chunk_id)
+ )
+ if not existing_chunk:
+ raise R2RException(
+ status_code=404,
+ message=f"Chunk with id {chunk_id} not found",
+ )
+
+ if (
+ str(existing_chunk["owner_id"]) != str(user.id)
+ and not user.is_superuser
+ ):
+ raise R2RException(
+ status_code=403,
+ message="You don't have permission to modify this chunk.",
+ )
+
+ # Merge metadata
+ merged_metadata = {**existing_chunk["metadata"]}
+ if metadata is not None:
+ merged_metadata |= metadata
+
+ # Create updated chunk
+ extraction_data = {
+ "id": chunk_id,
+ "document_id": document_id,
+ "collection_ids": kwargs.get(
+ "collection_ids", existing_chunk["collection_ids"]
+ ),
+ "owner_id": existing_chunk["owner_id"],
+ "data": text or existing_chunk["text"],
+ "metadata": merged_metadata,
+ }
+ extraction = DocumentChunk(**extraction_data).model_dump()
+
+ # Re-embed
+ embeddings_generator = self.embed_document(
+ [extraction], embedding_batch_size=1
+ )
+ embeddings = []
+ async for embedding in embeddings_generator:
+ embeddings.append(embedding)
+
+ # Re-store
+ store_gen = self.store_embeddings(embeddings, storage_batch_size=1)
+ async for _ in store_gen:
+ pass
+
+ return extraction
+
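+    # Illustrative sketch (comment only, not executed; the IDs are assumed to
+    # refer to an existing chunk owned by `user`):
+    #
+    #     updated = await service.update_chunk_ingress(
+    #         document_id=document_id,
+    #         chunk_id=chunk_id,
+    #         text="Corrected chunk text.",
+    #         user=user,
+    #         metadata={"reviewed": True},
+    #     )
+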
+ async def _get_enriched_chunk_text(
+ self,
+ chunk_idx: int,
+ chunk: dict,
+ document_id: UUID,
+ document_summary: str | None,
+ chunk_enrichment_settings: ChunkEnrichmentSettings,
+ list_document_chunks: list[dict],
+ ) -> VectorEntry:
+ """Helper for chunk_enrichment.
+
+ Leverages an LLM to rewrite or expand chunk text, then re-embeds it.
+ """
+ preceding_chunks = [
+ list_document_chunks[idx]["text"]
+ for idx in range(
+ max(0, chunk_idx - chunk_enrichment_settings.n_chunks),
+ chunk_idx,
+ )
+ ]
+ succeeding_chunks = [
+ list_document_chunks[idx]["text"]
+ for idx in range(
+ chunk_idx + 1,
+ min(
+ len(list_document_chunks),
+ chunk_idx + chunk_enrichment_settings.n_chunks + 1,
+ ),
+ )
+ ]
+ try:
+ # Obtain the updated text from the LLM
+ updated_chunk_text = (
+ (
+ await self.providers.llm.aget_completion(
+ messages=await self.providers.database.prompts_handler.get_message_payload(
+ task_prompt_name=chunk_enrichment_settings.chunk_enrichment_prompt,
+ task_inputs={
+ "document_summary": document_summary or "None",
+ "chunk": chunk["text"],
+ "preceding_chunks": (
+ "\n".join(preceding_chunks)
+ if preceding_chunks
+ else "None"
+ ),
+ "succeeding_chunks": (
+ "\n".join(succeeding_chunks)
+ if succeeding_chunks
+ else "None"
+ ),
+ "chunk_size": self.config.ingestion.chunk_size
+ or 1024,
+ },
+ ),
+ generation_config=chunk_enrichment_settings.generation_config
+ or GenerationConfig(model=self.config.app.fast_llm),
+ )
+ )
+ .choices[0]
+ .message.content
+ )
+ except Exception:
+ updated_chunk_text = chunk["text"]
+ chunk["metadata"]["chunk_enrichment_status"] = "failed"
+ else:
+ chunk["metadata"]["chunk_enrichment_status"] = (
+ "success" if updated_chunk_text else "failed"
+ )
+
+ if not updated_chunk_text or not isinstance(updated_chunk_text, str):
+ updated_chunk_text = str(chunk["text"])
+ chunk["metadata"]["chunk_enrichment_status"] = "failed"
+
+ # Re-embed
+ data = await self.providers.embedding.async_get_embedding(
+ updated_chunk_text
+ )
+ chunk["metadata"]["original_text"] = chunk["text"]
+
+ return VectorEntry(
+ id=generate_id(str(chunk["id"])),
+ vector=Vector(data=data, type=VectorType.FIXED, length=len(data)),
+ document_id=document_id,
+ owner_id=chunk["owner_id"],
+ collection_ids=chunk["collection_ids"],
+ text=updated_chunk_text,
+ metadata=chunk["metadata"],
+ )
+
+ async def chunk_enrichment(
+ self,
+ document_id: UUID,
+ document_summary: str | None,
+ chunk_enrichment_settings: ChunkEnrichmentSettings,
+ ) -> int:
+ """Example function that modifies chunk text via an LLM then re-embeds
+ and re-stores all chunks for the given document."""
+ list_document_chunks = (
+ await self.providers.database.chunks_handler.list_document_chunks(
+ document_id=document_id,
+ offset=0,
+ limit=-1,
+ )
+ )["results"]
+
+ new_vector_entries: list[VectorEntry] = []
+ tasks = []
+ total_completed = 0
+
+ for chunk_idx, chunk in enumerate(list_document_chunks):
+ tasks.append(
+ self._get_enriched_chunk_text(
+ chunk_idx=chunk_idx,
+ chunk=chunk,
+ document_id=document_id,
+ document_summary=document_summary,
+ chunk_enrichment_settings=chunk_enrichment_settings,
+ list_document_chunks=list_document_chunks,
+ )
+ )
+
+            # Process enrichment tasks in batches of 128 concurrent calls
+ if len(tasks) == 128:
+ new_vector_entries.extend(await asyncio.gather(*tasks))
+ total_completed += 128
+ logger.info(
+ f"Completed {total_completed} out of {len(list_document_chunks)} chunks for document {document_id}"
+ )
+ tasks = []
+
+ # Finish any remaining tasks
+ new_vector_entries.extend(await asyncio.gather(*tasks))
+ logger.info(
+ f"Completed enrichment of {len(list_document_chunks)} chunks for document {document_id}"
+ )
+
+ # Delete old chunks from vector db
+ await self.providers.database.chunks_handler.delete(
+ filters={"document_id": document_id}
+ )
+
+ # Insert the newly enriched entries
+ await self.providers.database.chunks_handler.upsert_entries(
+ new_vector_entries
+ )
+ return len(new_vector_entries)
+
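+    # Illustrative sketch (comment only, not executed; the
+    # ChunkEnrichmentSettings values shown are assumptions, not defaults):
+    #
+    #     enriched_count = await service.chunk_enrichment(
+    #         document_id=document_id,
+    #         document_summary=document_info.summary,
+    #         chunk_enrichment_settings=ChunkEnrichmentSettings(n_chunks=2),
+    #     )
+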
+ async def list_chunks(
+ self,
+ offset: int,
+ limit: int,
+ filters: Optional[dict[str, Any]] = None,
+ include_vectors: bool = False,
+ *args: Any,
+ **kwargs: Any,
+ ) -> dict:
+ return await self.providers.database.chunks_handler.list_chunks(
+ offset=offset,
+ limit=limit,
+ filters=filters,
+ include_vectors=include_vectors,
+ )
+
+ async def get_chunk(
+ self,
+ chunk_id: UUID,
+ *args: Any,
+ **kwargs: Any,
+ ) -> dict:
+ return await self.providers.database.chunks_handler.get_chunk(chunk_id)
+
+ async def update_document_metadata(
+ self,
+ document_id: UUID,
+ metadata: dict,
+ user: User,
+ ) -> None:
+ # Verify document exists and user has access
+ existing_document = await self.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=100,
+ filter_document_ids=[document_id],
+ filter_user_ids=[user.id],
+ )
+ if not existing_document["results"]:
+ raise R2RException(
+ status_code=404,
+ message=(
+ f"Document with id {document_id} not found "
+ "or you don't have access."
+ ),
+ )
+
+ existing_document = existing_document["results"][0]
+
+ # Merge metadata
+ merged_metadata = {**existing_document.metadata, **metadata} # type: ignore
+
+ # Update document metadata
+ existing_document.metadata = merged_metadata # type: ignore
+ await self.providers.database.documents_handler.upsert_documents_overview(
+ existing_document # type: ignore
+ )
+
+
+class IngestionServiceAdapter:
+ @staticmethod
+ def _parse_user_data(user_data) -> User:
+ if isinstance(user_data, str):
+ try:
+ user_data = json.loads(user_data)
+ except json.JSONDecodeError as e:
+ raise ValueError(
+ f"Invalid user data format: {user_data}"
+ ) from e
+ return User.from_dict(user_data)
+
+ @staticmethod
+ def parse_ingest_file_input(data: dict) -> dict:
+ return {
+ "user": IngestionServiceAdapter._parse_user_data(data["user"]),
+ "metadata": data["metadata"],
+ "document_id": (
+ UUID(data["document_id"]) if data["document_id"] else None
+ ),
+ "version": data.get("version"),
+ "ingestion_config": data["ingestion_config"] or {},
+ "file_data": data["file_data"],
+ "size_in_bytes": data["size_in_bytes"],
+ "collection_ids": data.get("collection_ids", []),
+ }
+
+ @staticmethod
+ def parse_ingest_chunks_input(data: dict) -> dict:
+ return {
+ "user": IngestionServiceAdapter._parse_user_data(data["user"]),
+ "metadata": data["metadata"],
+ "document_id": data["document_id"],
+ "chunks": [
+ UnprocessedChunk.from_dict(chunk) for chunk in data["chunks"]
+ ],
+ "id": data.get("id"),
+ }
+
+ @staticmethod
+ def parse_update_chunk_input(data: dict) -> dict:
+ return {
+ "user": IngestionServiceAdapter._parse_user_data(data["user"]),
+ "document_id": UUID(data["document_id"]),
+ "id": UUID(data["id"]),
+ "text": data["text"],
+ "metadata": data.get("metadata"),
+ "collection_ids": data.get("collection_ids", []),
+ }
+
+ @staticmethod
+ def parse_update_files_input(data: dict) -> dict:
+ return {
+ "user": IngestionServiceAdapter._parse_user_data(data["user"]),
+ "document_ids": [UUID(doc_id) for doc_id in data["document_ids"]],
+ "metadatas": data["metadatas"],
+ "ingestion_config": data["ingestion_config"],
+ "file_sizes_in_bytes": data["file_sizes_in_bytes"],
+ "file_datas": data["file_datas"],
+ }
+
+ @staticmethod
+ def parse_create_vector_index_input(data: dict) -> dict:
+ return {
+ "table_name": VectorTableName(data["table_name"]),
+ "index_method": IndexMethod(data["index_method"]),
+ "index_measure": IndexMeasure(data["index_measure"]),
+ "index_name": data["index_name"],
+ "index_column": data["index_column"],
+ "index_arguments": data["index_arguments"],
+ "concurrently": data["concurrently"],
+ }
+
+ @staticmethod
+ def parse_list_vector_indices_input(input_data: dict) -> dict:
+ return {"table_name": input_data["table_name"]}
+
+ @staticmethod
+ def parse_delete_vector_index_input(input_data: dict) -> dict:
+ return {
+ "index_name": input_data["index_name"],
+ "table_name": input_data.get("table_name"),
+ "concurrently": input_data.get("concurrently", True),
+ }
+
+ @staticmethod
+ def parse_select_vector_index_input(input_data: dict) -> dict:
+ return {
+ "index_name": input_data["index_name"],
+ "table_name": input_data.get("table_name"),
+ }
+
+ @staticmethod
+ def parse_update_document_metadata_input(data: dict) -> dict:
+ return {
+ "document_id": data["document_id"],
+ "metadata": data["metadata"],
+ "user": IngestionServiceAdapter._parse_user_data(data["user"]),
+ }
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/management_service.py b/.venv/lib/python3.12/site-packages/core/main/services/management_service.py
new file mode 100644
index 00000000..62b4ca0b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/management_service.py
@@ -0,0 +1,1084 @@
+import logging
+import os
+from collections import defaultdict
+from datetime import datetime, timedelta, timezone
+from typing import IO, Any, BinaryIO, Optional, Tuple
+from uuid import UUID
+
+import toml
+
+from core.base import (
+ CollectionResponse,
+ ConversationResponse,
+ DocumentResponse,
+ GenerationConfig,
+ GraphConstructionStatus,
+ Message,
+ MessageResponse,
+ Prompt,
+ R2RException,
+ StoreType,
+ User,
+)
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger()
+
+
+class ManagementService(Service):
+ def __init__(
+ self,
+ config: R2RConfig,
+ providers: R2RProviders,
+ ):
+ super().__init__(
+ config,
+ providers,
+ )
+
+ async def app_settings(self):
+ prompts = (
+ await self.providers.database.prompts_handler.get_all_prompts()
+ )
+ config_toml = self.config.to_toml()
+ config_dict = toml.loads(config_toml)
+ try:
+ project_name = os.environ["R2R_PROJECT_NAME"]
+ except KeyError:
+ project_name = ""
+ return {
+ "config": config_dict,
+ "prompts": prompts,
+ "r2r_project_name": project_name,
+ }
+
+ async def users_overview(
+ self,
+ offset: int,
+ limit: int,
+ user_ids: Optional[list[UUID]] = None,
+ ):
+ return await self.providers.database.users_handler.get_users_overview(
+ offset=offset,
+ limit=limit,
+ user_ids=user_ids,
+ )
+
+ async def delete_documents_and_chunks_by_filter(
+ self,
+ filters: dict[str, Any],
+ ):
+ """Delete chunks matching the given filters. If any documents are now
+ empty (i.e., have no remaining chunks), delete those documents as well.
+
+ Args:
+ filters (dict[str, Any]): Filters specifying which chunks to delete.
+ chunks_handler (PostgresChunksHandler): The handler for chunk operations.
+ documents_handler (PostgresDocumentsHandler): The handler for document operations.
+ graphs_handler: Handler for entity and relationship operations in the Graph.
+
+ Returns:
+ dict: A summary of what was deleted.
+ """
+
+ def transform_chunk_id_to_id(
+ filters: dict[str, Any],
+ ) -> dict[str, Any]:
+ """Example transformation function if your filters use `chunk_id`
+ instead of `id`.
+
+ Recursively transform `chunk_id` to `id`.
+ """
+ if isinstance(filters, dict):
+ transformed = {}
+ for key, value in filters.items():
+ if key == "chunk_id":
+ transformed["id"] = value
+ elif key in ["$and", "$or"]:
+ transformed[key] = [
+ transform_chunk_id_to_id(item) for item in value
+ ]
+ else:
+ transformed[key] = transform_chunk_id_to_id(value)
+ return transformed
+ return filters
+
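+        # Illustrative mapping (comment only): a filter such as
+        #     {"$and": [{"chunk_id": {"$eq": "<uuid>"}}, {"owner_id": {"$eq": "<uuid>"}}]}
+        # is rewritten to
+        #     {"$and": [{"id": {"$eq": "<uuid>"}}, {"owner_id": {"$eq": "<uuid>"}}]}
+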
+ # Transform filters if needed.
+ transformed_filters = transform_chunk_id_to_id(filters)
+
+ # Find chunks that match the filters before deleting
+ interim_results = (
+ await self.providers.database.chunks_handler.list_chunks(
+ filters=transformed_filters,
+ offset=0,
+ limit=1_000,
+ include_vectors=False,
+ )
+ )
+
+ results = interim_results["results"]
+ while interim_results["total_entries"] == 1_000:
+ # If we hit the limit, we need to paginate to get all results
+
+ interim_results = (
+ await self.providers.database.chunks_handler.list_chunks(
+ filters=transformed_filters,
+ offset=interim_results["offset"] + 1_000,
+ limit=1_000,
+ include_vectors=False,
+ )
+ )
+ results.extend(interim_results["results"])
+
+ document_ids = set()
+ owner_id = None
+
+ if "$and" in filters:
+ for condition in filters["$and"]:
+ if "owner_id" in condition and "$eq" in condition["owner_id"]:
+ owner_id = condition["owner_id"]["$eq"]
+ elif (
+ "document_id" in condition
+ and "$eq" in condition["document_id"]
+ ):
+ document_ids.add(UUID(condition["document_id"]["$eq"]))
+ elif "document_id" in filters:
+ doc_id = filters["document_id"]
+ if isinstance(doc_id, str):
+ document_ids.add(UUID(doc_id))
+ elif isinstance(doc_id, UUID):
+ document_ids.add(doc_id)
+ elif isinstance(doc_id, dict) and "$eq" in doc_id:
+ value = doc_id["$eq"]
+ document_ids.add(
+ UUID(value) if isinstance(value, str) else value
+ )
+
+ # Delete matching chunks from the database
+ delete_results = await self.providers.database.chunks_handler.delete(
+ transformed_filters
+ )
+
+ # Extract the document_ids that were affected.
+ affected_doc_ids = {
+ UUID(info["document_id"])
+ for info in delete_results.values()
+ if info.get("document_id")
+ }
+ document_ids.update(affected_doc_ids)
+
+ # Check if the document still has any chunks left
+ docs_to_delete = []
+ for doc_id in document_ids:
+ documents_overview_response = await self.providers.database.documents_handler.get_documents_overview(
+ offset=0, limit=1, filter_document_ids=[doc_id]
+ )
+ if not documents_overview_response["results"]:
+ raise R2RException(
+ status_code=404, message="Document not found"
+ )
+
+ document = documents_overview_response["results"][0]
+
+ for collection_id in document.collection_ids:
+ await self.providers.database.collections_handler.decrement_collection_document_count(
+ collection_id=collection_id
+ )
+
+ if owner_id and str(document.owner_id) != owner_id:
+ raise R2RException(
+ status_code=404,
+ message="Document not found or insufficient permissions",
+ )
+ docs_to_delete.append(doc_id)
+
+ # Delete documents that no longer have associated chunks
+ for doc_id in docs_to_delete:
+ # Delete related entities & relationships if needed:
+ await self.providers.database.graphs_handler.entities.delete(
+ parent_id=doc_id,
+ store_type=StoreType.DOCUMENTS,
+ )
+ await self.providers.database.graphs_handler.relationships.delete(
+ parent_id=doc_id,
+ store_type=StoreType.DOCUMENTS,
+ )
+
+ # Finally, delete the document from documents_overview:
+ await self.providers.database.documents_handler.delete(
+ document_id=doc_id
+ )
+
+ return {
+ "success": True,
+ "deleted_chunks_count": len(delete_results),
+ "deleted_documents_count": len(docs_to_delete),
+ "deleted_document_ids": [str(d) for d in docs_to_delete],
+ }
+
+ async def download_file(
+ self, document_id: UUID
+ ) -> Optional[Tuple[str, BinaryIO, int]]:
+ if result := await self.providers.database.files_handler.retrieve_file(
+ document_id
+ ):
+ return result
+ return None
+
+ async def export_files(
+ self,
+ document_ids: Optional[list[UUID]] = None,
+ start_date: Optional[datetime] = None,
+ end_date: Optional[datetime] = None,
+ ) -> tuple[str, BinaryIO, int]:
+ return (
+ await self.providers.database.files_handler.retrieve_files_as_zip(
+ document_ids=document_ids,
+ start_date=start_date,
+ end_date=end_date,
+ )
+ )
+
+ async def export_collections(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ return await self.providers.database.collections_handler.export_to_csv(
+ columns=columns,
+ filters=filters,
+ include_header=include_header,
+ )
+
+ async def export_documents(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ return await self.providers.database.documents_handler.export_to_csv(
+ columns=columns,
+ filters=filters,
+ include_header=include_header,
+ )
+
+ async def export_document_entities(
+ self,
+ id: UUID,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ return await self.providers.database.graphs_handler.entities.export_to_csv(
+ parent_id=id,
+ store_type=StoreType.DOCUMENTS,
+ columns=columns,
+ filters=filters,
+ include_header=include_header,
+ )
+
+ async def export_document_relationships(
+ self,
+ id: UUID,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ return await self.providers.database.graphs_handler.relationships.export_to_csv(
+ parent_id=id,
+ store_type=StoreType.DOCUMENTS,
+ columns=columns,
+ filters=filters,
+ include_header=include_header,
+ )
+
+ async def export_conversations(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ return await self.providers.database.conversations_handler.export_conversations_to_csv(
+ columns=columns,
+ filters=filters,
+ include_header=include_header,
+ )
+
+ async def export_graph_entities(
+ self,
+ id: UUID,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ return await self.providers.database.graphs_handler.entities.export_to_csv(
+ parent_id=id,
+ store_type=StoreType.GRAPHS,
+ columns=columns,
+ filters=filters,
+ include_header=include_header,
+ )
+
+ async def export_graph_relationships(
+ self,
+ id: UUID,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ return await self.providers.database.graphs_handler.relationships.export_to_csv(
+ parent_id=id,
+ store_type=StoreType.GRAPHS,
+ columns=columns,
+ filters=filters,
+ include_header=include_header,
+ )
+
+ async def export_graph_communities(
+ self,
+ id: UUID,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ return await self.providers.database.graphs_handler.communities.export_to_csv(
+ parent_id=id,
+ store_type=StoreType.GRAPHS,
+ columns=columns,
+ filters=filters,
+ include_header=include_header,
+ )
+
+ async def export_messages(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ return await self.providers.database.conversations_handler.export_messages_to_csv(
+ columns=columns,
+ filters=filters,
+ include_header=include_header,
+ )
+
+ async def export_users(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ return await self.providers.database.users_handler.export_to_csv(
+ columns=columns,
+ filters=filters,
+ include_header=include_header,
+ )
+
+ async def documents_overview(
+ self,
+ offset: int,
+ limit: int,
+ user_ids: Optional[list[UUID]] = None,
+ collection_ids: Optional[list[UUID]] = None,
+ document_ids: Optional[list[UUID]] = None,
+ ):
+ return await self.providers.database.documents_handler.get_documents_overview(
+ offset=offset,
+ limit=limit,
+ filter_document_ids=document_ids,
+ filter_user_ids=user_ids,
+ filter_collection_ids=collection_ids,
+ )
+
+ async def update_document_metadata(
+ self,
+ document_id: UUID,
+ metadata: list[dict],
+ overwrite: bool = False,
+ ):
+ return await self.providers.database.documents_handler.update_document_metadata(
+ document_id=document_id,
+ metadata=metadata,
+ overwrite=overwrite,
+ )
+
+ async def list_document_chunks(
+ self,
+ document_id: UUID,
+ offset: int,
+ limit: int,
+ include_vectors: bool = False,
+ ):
+ return (
+ await self.providers.database.chunks_handler.list_document_chunks(
+ document_id=document_id,
+ offset=offset,
+ limit=limit,
+ include_vectors=include_vectors,
+ )
+ )
+
+ async def assign_document_to_collection(
+ self, document_id: UUID, collection_id: UUID
+ ):
+ await self.providers.database.chunks_handler.assign_document_chunks_to_collection(
+ document_id, collection_id
+ )
+ await self.providers.database.collections_handler.assign_document_to_collection_relational(
+ document_id, collection_id
+ )
+ await self.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_sync_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+ await self.providers.database.documents_handler.set_workflow_status(
+ id=collection_id,
+ status_type="graph_cluster_status",
+ status=GraphConstructionStatus.OUTDATED,
+ )
+
+ return {"message": "Document assigned to collection successfully"}
+
+ async def remove_document_from_collection(
+ self, document_id: UUID, collection_id: UUID
+ ):
+ await self.providers.database.collections_handler.remove_document_from_collection_relational(
+ document_id, collection_id
+ )
+ await self.providers.database.chunks_handler.remove_document_from_collection_vector(
+ document_id, collection_id
+ )
+ # await self.providers.database.graphs_handler.delete_node_via_document_id(
+ # document_id, collection_id
+ # )
+ return None
+
+ def _process_relationships(
+ self, relationships: list[Tuple[str, str, str]]
+ ) -> Tuple[dict[str, list[str]], dict[str, dict[str, list[str]]]]:
+ graph = defaultdict(list)
+ grouped: dict[str, dict[str, list[str]]] = defaultdict(
+ lambda: defaultdict(list)
+ )
+ for subject, relation, obj in relationships:
+ graph[subject].append(obj)
+ grouped[subject][relation].append(obj)
+ if obj not in graph:
+ graph[obj] = []
+ return dict(graph), dict(grouped)
+
+ def generate_output(
+ self,
+ grouped_relationships: dict[str, dict[str, list[str]]],
+ graph: dict[str, list[str]],
+ descriptions_dict: dict[str, str],
+ print_descriptions: bool = True,
+ ) -> list[str]:
+ output = []
+ # Print grouped relationships
+ for subject, relations in grouped_relationships.items():
+ output.append(f"\n== {subject} ==")
+ if print_descriptions and subject in descriptions_dict:
+ output.append(f"\tDescription: {descriptions_dict[subject]}")
+ for relation, objects in relations.items():
+ output.append(f" {relation}:")
+ for obj in objects:
+ output.append(f" - {obj}")
+ if print_descriptions and obj in descriptions_dict:
+ output.append(
+ f" Description: {descriptions_dict[obj]}"
+ )
+
+ # Print basic graph statistics
+ output.extend(
+ [
+ "\n== Graph Statistics ==",
+ f"Number of nodes: {len(graph)}",
+ f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}",
+ f"Number of connected components: {self._count_connected_components(graph)}",
+ ]
+ )
+
+ # Find central nodes
+ central_nodes = self._get_central_nodes(graph)
+ output.extend(
+ [
+ "\n== Most Central Nodes ==",
+ *(
+ f" {node}: {centrality:.4f}"
+ for node, centrality in central_nodes
+ ),
+ ]
+ )
+
+ return output
+
+ def _count_connected_components(self, graph: dict[str, list[str]]) -> int:
+ visited = set()
+ components = 0
+
+ def dfs(node):
+ visited.add(node)
+ for neighbor in graph[node]:
+ if neighbor not in visited:
+ dfs(neighbor)
+
+ for node in graph:
+ if node not in visited:
+ dfs(node)
+ components += 1
+
+ return components
+
+    def _get_central_nodes(
+        self, graph: dict[str, list[str]]
+    ) -> list[Tuple[str, float]]:
+        degree = {node: len(neighbors) for node, neighbors in graph.items()}
+        total_nodes = len(graph)
+        if total_nodes <= 1:
+            # Degree centrality is undefined for graphs with fewer than two nodes.
+            return []
+        centrality = {
+            node: deg / (total_nodes - 1) for node, deg in degree.items()
+        }
+        return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]
+
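+    # Worked example (comment only): for a graph {"a": ["b", "c"], "b": [], "c": []}
+    # the degree of "a" is 2 and total_nodes is 3, so its degree centrality is
+    # 2 / (3 - 1) = 1.0, while "b" and "c" score 0.0.
+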
+ async def create_collection(
+ self,
+ owner_id: UUID,
+ name: Optional[str] = None,
+ description: str | None = None,
+ ) -> CollectionResponse:
+ result = await self.providers.database.collections_handler.create_collection(
+ owner_id=owner_id,
+ name=name,
+ description=description,
+ )
+ await self.providers.database.graphs_handler.create(
+ collection_id=result.id,
+ name=name,
+ description=description,
+ )
+ return result
+
+ async def update_collection(
+ self,
+ collection_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ generate_description: bool = False,
+ ) -> CollectionResponse:
+ if generate_description:
+ description = await self.summarize_collection(
+ id=collection_id, offset=0, limit=100
+ )
+ return await self.providers.database.collections_handler.update_collection(
+ collection_id=collection_id,
+ name=name,
+ description=description,
+ )
+
+ async def delete_collection(self, collection_id: UUID) -> bool:
+ await self.providers.database.collections_handler.delete_collection_relational(
+ collection_id
+ )
+ await self.providers.database.chunks_handler.delete_collection_vector(
+ collection_id
+ )
+ try:
+ await self.providers.database.graphs_handler.delete(
+ collection_id=collection_id,
+ )
+ except Exception as e:
+ logger.warning(
+ f"Error deleting graph for collection {collection_id}: {e}"
+ )
+ return True
+
+ async def collections_overview(
+ self,
+ offset: int,
+ limit: int,
+ user_ids: Optional[list[UUID]] = None,
+ document_ids: Optional[list[UUID]] = None,
+ collection_ids: Optional[list[UUID]] = None,
+ ) -> dict[str, list[CollectionResponse] | int]:
+ return await self.providers.database.collections_handler.get_collections_overview(
+ offset=offset,
+ limit=limit,
+ filter_user_ids=user_ids,
+ filter_document_ids=document_ids,
+ filter_collection_ids=collection_ids,
+ )
+
+ async def add_user_to_collection(
+ self, user_id: UUID, collection_id: UUID
+ ) -> bool:
+ return (
+ await self.providers.database.users_handler.add_user_to_collection(
+ user_id, collection_id
+ )
+ )
+
+ async def remove_user_from_collection(
+ self, user_id: UUID, collection_id: UUID
+ ) -> bool:
+ return await self.providers.database.users_handler.remove_user_from_collection(
+ user_id, collection_id
+ )
+
+ async def get_users_in_collection(
+ self, collection_id: UUID, offset: int = 0, limit: int = 100
+ ) -> dict[str, list[User] | int]:
+ return await self.providers.database.users_handler.get_users_in_collection(
+ collection_id, offset=offset, limit=limit
+ )
+
+ async def documents_in_collection(
+ self, collection_id: UUID, offset: int = 0, limit: int = 100
+ ) -> dict[str, list[DocumentResponse] | int]:
+ return await self.providers.database.collections_handler.documents_in_collection(
+ collection_id, offset=offset, limit=limit
+ )
+
+ async def summarize_collection(
+ self, id: UUID, offset: int, limit: int
+ ) -> str:
+ documents_in_collection_response = await self.documents_in_collection(
+ collection_id=id,
+ offset=offset,
+ limit=limit,
+ )
+
+ document_summaries = [
+ document.summary
+ for document in documents_in_collection_response["results"] # type: ignore
+ ]
+
+ logger.info(
+ f"Summarizing collection {id} with {len(document_summaries)} of {documents_in_collection_response['total_entries']} documents."
+ )
+
+ formatted_summaries = "\n\n".join(document_summaries) # type: ignore
+
+ messages = await self.providers.database.prompts_handler.get_message_payload(
+ system_prompt_name=self.config.database.collection_summary_system_prompt,
+ task_prompt_name=self.config.database.collection_summary_prompt,
+ task_inputs={"document_summaries": formatted_summaries},
+ )
+
+ response = await self.providers.llm.aget_completion(
+ messages=messages,
+ generation_config=GenerationConfig(
+ model=self.config.ingestion.document_summary_model
+ or self.config.app.fast_llm
+ ),
+ )
+
+ if collection_summary := response.choices[0].message.content:
+ return collection_summary
+ else:
+ raise ValueError("Expected a generated response.")
+
+ async def add_prompt(
+ self, name: str, template: str, input_types: dict[str, str]
+ ) -> dict:
+ try:
+ await self.providers.database.prompts_handler.add_prompt(
+ name, template, input_types
+ )
+ return f"Prompt '{name}' added successfully." # type: ignore
+ except ValueError as e:
+ raise R2RException(status_code=400, message=str(e)) from e
+
+ async def get_cached_prompt(
+ self,
+ prompt_name: str,
+ inputs: Optional[dict[str, Any]] = None,
+ prompt_override: Optional[str] = None,
+ ) -> dict:
+ try:
+ return {
+ "message": (
+ await self.providers.database.prompts_handler.get_cached_prompt(
+ prompt_name=prompt_name,
+ inputs=inputs,
+ prompt_override=prompt_override,
+ )
+ )
+ }
+ except ValueError as e:
+ raise R2RException(status_code=404, message=str(e)) from e
+
+ async def get_prompt(
+ self,
+ prompt_name: str,
+ inputs: Optional[dict[str, Any]] = None,
+ prompt_override: Optional[str] = None,
+ ) -> dict:
+ try:
+ return await self.providers.database.prompts_handler.get_prompt( # type: ignore
+ name=prompt_name,
+ inputs=inputs,
+ prompt_override=prompt_override,
+ )
+ except ValueError as e:
+ raise R2RException(status_code=404, message=str(e)) from e
+
+ async def get_all_prompts(self) -> dict[str, Prompt]:
+ return await self.providers.database.prompts_handler.get_all_prompts()
+
+ async def update_prompt(
+ self,
+ name: str,
+ template: Optional[str] = None,
+ input_types: Optional[dict[str, str]] = None,
+ ) -> dict:
+ try:
+ await self.providers.database.prompts_handler.update_prompt(
+ name, template, input_types
+ )
+ return f"Prompt '{name}' updated successfully." # type: ignore
+ except ValueError as e:
+ raise R2RException(status_code=404, message=str(e)) from e
+
+ async def delete_prompt(self, name: str) -> dict:
+ try:
+ await self.providers.database.prompts_handler.delete_prompt(name)
+ return {"message": f"Prompt '{name}' deleted successfully."}
+ except ValueError as e:
+ raise R2RException(status_code=404, message=str(e)) from e
+
+ async def get_conversation(
+ self,
+ conversation_id: UUID,
+ user_ids: Optional[list[UUID]] = None,
+ ) -> list[MessageResponse]:
+ return await self.providers.database.conversations_handler.get_conversation(
+ conversation_id=conversation_id,
+ filter_user_ids=user_ids,
+ )
+
+ async def create_conversation(
+ self,
+ user_id: Optional[UUID] = None,
+ name: Optional[str] = None,
+ ) -> ConversationResponse:
+ return await self.providers.database.conversations_handler.create_conversation(
+ user_id=user_id,
+ name=name,
+ )
+
+ async def conversations_overview(
+ self,
+ offset: int,
+ limit: int,
+ conversation_ids: Optional[list[UUID]] = None,
+ user_ids: Optional[list[UUID]] = None,
+ ) -> dict[str, list[dict] | int]:
+ return await self.providers.database.conversations_handler.get_conversations_overview(
+ offset=offset,
+ limit=limit,
+ filter_user_ids=user_ids,
+ conversation_ids=conversation_ids,
+ )
+
+ async def add_message(
+ self,
+ conversation_id: UUID,
+ content: Message,
+ parent_id: Optional[UUID] = None,
+ metadata: Optional[dict] = None,
+ ) -> MessageResponse:
+ return await self.providers.database.conversations_handler.add_message(
+ conversation_id=conversation_id,
+ content=content,
+ parent_id=parent_id,
+ metadata=metadata,
+ )
+
+ async def edit_message(
+ self,
+ message_id: UUID,
+ new_content: Optional[str] = None,
+ additional_metadata: Optional[dict] = None,
+ ) -> dict[str, Any]:
+ return (
+ await self.providers.database.conversations_handler.edit_message(
+ message_id=message_id,
+ new_content=new_content,
+ additional_metadata=additional_metadata or {},
+ )
+ )
+
+ async def update_conversation(
+ self, conversation_id: UUID, name: str
+ ) -> ConversationResponse:
+ return await self.providers.database.conversations_handler.update_conversation(
+ conversation_id=conversation_id, name=name
+ )
+
+ async def delete_conversation(
+ self,
+ conversation_id: UUID,
+ user_ids: Optional[list[UUID]] = None,
+ ) -> None:
+ await (
+ self.providers.database.conversations_handler.delete_conversation(
+ conversation_id=conversation_id,
+ filter_user_ids=user_ids,
+ )
+ )
+
+ async def get_user_max_documents(self, user_id: UUID) -> int | None:
+ # Fetch the user to see if they have any overrides stored
+ user = await self.providers.database.users_handler.get_user_by_id(
+ user_id
+ )
+ if user.limits_overrides and "max_documents" in user.limits_overrides:
+ return user.limits_overrides["max_documents"]
+ return self.config.app.default_max_documents_per_user
+
+ async def get_user_max_chunks(self, user_id: UUID) -> int | None:
+ user = await self.providers.database.users_handler.get_user_by_id(
+ user_id
+ )
+ if user.limits_overrides and "max_chunks" in user.limits_overrides:
+ return user.limits_overrides["max_chunks"]
+ return self.config.app.default_max_chunks_per_user
+
+ async def get_user_max_collections(self, user_id: UUID) -> int | None:
+ user = await self.providers.database.users_handler.get_user_by_id(
+ user_id
+ )
+ if (
+ user.limits_overrides
+ and "max_collections" in user.limits_overrides
+ ):
+ return user.limits_overrides["max_collections"]
+ return self.config.app.default_max_collections_per_user
+
+ async def get_max_upload_size_by_type(
+ self, user_id: UUID, file_type_or_ext: str
+ ) -> int:
+ """Return the maximum allowed upload size (in bytes) for the given
+ user's file type/extension. Respects user-level overrides if present,
+ falling back to the system config.
+
+ ```json
+ {
+ "limits_overrides": {
+ "max_file_size": 20_000_000,
+ "max_file_size_by_type":
+ {
+ "pdf": 50_000_000,
+ "docx": 30_000_000
+ },
+ ...
+ }
+ }
+ ```
+ """
+ # 1. Normalize extension
+ ext = file_type_or_ext.lower().lstrip(".")
+
+ # 2. Fetch user from DB to see if we have any overrides
+ user = await self.providers.database.users_handler.get_user_by_id(
+ user_id
+ )
+ user_overrides = user.limits_overrides or {}
+
+ # 3. Check if there's a user-level override for "max_file_size_by_type"
+ user_file_type_limits = user_overrides.get("max_file_size_by_type", {})
+ if ext in user_file_type_limits:
+ return user_file_type_limits[ext]
+
+ # 4. If not, check if there's a user-level fallback "max_file_size"
+ if "max_file_size" in user_overrides:
+ return user_overrides["max_file_size"]
+
+ # 5. If none exist at user level, use system config
+ # Example config paths:
+ system_type_limits = self.config.app.max_upload_size_by_type
+ if ext in system_type_limits:
+ return system_type_limits[ext]
+
+ # 6. Otherwise, return the global default
+ return self.config.app.default_max_upload_size
+
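+    # Illustrative resolution order (comment only): given the overrides shown
+    # in the docstring above, a request for "PDF" normalizes to "pdf" and
+    # returns 50_000_000; an unlisted type such as "csv" falls back to the
+    # user-level "max_file_size" of 20_000_000; with no user overrides at all,
+    # the system config's per-type limit or default_max_upload_size applies.
+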
+ async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:
+ """
+ Return a dictionary containing:
+        - The system default limits (from self.config.database.limits)
+ - The user's overrides (from user.limits_overrides)
+ - The final 'effective' set of limits after merging (overall)
+ - The usage for each relevant limit (per-route usage, etc.)
+ """
+ # 1) Fetch the user
+ user = await self.providers.database.users_handler.get_user_by_id(
+ user_id
+ )
+ user_overrides = user.limits_overrides or {}
+
+ # 2) Grab system defaults
+ system_defaults = {
+ "global_per_min": self.config.database.limits.global_per_min,
+ "route_per_min": self.config.database.limits.route_per_min,
+ "monthly_limit": self.config.database.limits.monthly_limit,
+ # Add additional fields if your LimitSettings has them
+ }
+
+ # 3) Build the overall (global) "effective limits" ignoring any specific route
+ overall_effective = (
+ self.providers.database.limits_handler.determine_effective_limits(
+ user, route=""
+ )
+ )
+
+ # 4) Build usage data. We'll do top-level usage for global_per_min/monthly,
+ # then do route-by-route usage in a loop.
+ usage: dict[str, Any] = {}
+ now = datetime.now(timezone.utc)
+ one_min_ago = now - timedelta(minutes=1)
+
+ # (a) Global usage (per-minute)
+ global_per_min_used = (
+ await self.providers.database.limits_handler._count_requests(
+ user_id, route=None, since=one_min_ago
+ )
+ )
+ # (a2) Global usage (monthly) - i.e. usage across ALL routes
+ global_monthly_used = await self.providers.database.limits_handler._count_monthly_requests(
+ user_id, route=None
+ )
+
+ usage["global_per_min"] = {
+ "used": global_per_min_used,
+ "limit": overall_effective.global_per_min,
+ "remaining": (
+ overall_effective.global_per_min - global_per_min_used
+ if overall_effective.global_per_min is not None
+ else None
+ ),
+ }
+ usage["monthly_limit"] = {
+ "used": global_monthly_used,
+ "limit": overall_effective.monthly_limit,
+ "remaining": (
+ overall_effective.monthly_limit - global_monthly_used
+ if overall_effective.monthly_limit is not None
+ else None
+ ),
+ }
+
+ # (b) Route-level usage. We'll gather all routes from system + user overrides
+ system_route_limits = (
+ self.config.database.route_limits
+ ) # dict[str, LimitSettings]
+ user_route_overrides = user_overrides.get("route_overrides", {})
+ route_keys = set(system_route_limits.keys()) | set(
+ user_route_overrides.keys()
+ )
+
+ usage["routes"] = {}
+ for route in route_keys:
+ # 1) Get the final merged limits for this specific route
+ route_effective = self.providers.database.limits_handler.determine_effective_limits(
+ user, route
+ )
+
+ # 2) Count requests for the last minute on this route
+ route_per_min_used = (
+ await self.providers.database.limits_handler._count_requests(
+ user_id, route, one_min_ago
+ )
+ )
+
+ # 3) Count route-specific monthly usage
+ route_monthly_used = await self.providers.database.limits_handler._count_monthly_requests(
+ user_id, route
+ )
+
+ usage["routes"][route] = {
+ "route_per_min": {
+ "used": route_per_min_used,
+ "limit": route_effective.route_per_min,
+ "remaining": (
+ route_effective.route_per_min - route_per_min_used
+ if route_effective.route_per_min is not None
+ else None
+ ),
+ },
+ "monthly_limit": {
+ "used": route_monthly_used,
+ "limit": route_effective.monthly_limit,
+ "remaining": (
+ route_effective.monthly_limit - route_monthly_used
+ if route_effective.monthly_limit is not None
+ else None
+ ),
+ },
+ }
+
+ max_documents = await self.get_user_max_documents(user_id)
+ used_documents = (
+ await self.providers.database.documents_handler.get_documents_overview(
+ limit=1, offset=0, filter_user_ids=[user_id]
+ )
+ )["total_entries"]
+ max_chunks = await self.get_user_max_chunks(user_id)
+ used_chunks = (
+ await self.providers.database.chunks_handler.list_chunks(
+ limit=1, offset=0, filters={"owner_id": user_id}
+ )
+ )["total_entries"]
+
+ max_collections = await self.get_user_max_collections(user_id)
+ used_collections: int = ( # type: ignore
+ await self.providers.database.collections_handler.get_collections_overview(
+ limit=1, offset=0, filter_user_ids=[user_id]
+ )
+ )["total_entries"]
+
+ storage_limits = {
+ "chunks": {
+ "limit": max_chunks,
+ "used": used_chunks,
+ "remaining": (
+ max_chunks - used_chunks
+ if max_chunks is not None
+ else None
+ ),
+ },
+ "documents": {
+ "limit": max_documents,
+ "used": used_documents,
+ "remaining": (
+ max_documents - used_documents
+ if max_documents is not None
+ else None
+ ),
+ },
+ "collections": {
+ "limit": max_collections,
+ "used": used_collections,
+ "remaining": (
+ max_collections - used_collections
+ if max_collections is not None
+ else None
+ ),
+ },
+ }
+ # 5) Return a structured response
+ return {
+ "storage_limits": storage_limits,
+ "system_defaults": system_defaults,
+ "user_overrides": user_overrides,
+ "effective_limits": {
+ "global_per_min": overall_effective.global_per_min,
+ "route_per_min": overall_effective.route_per_min,
+ "monthly_limit": overall_effective.monthly_limit,
+ },
+ "usage": usage,
+ }
diff --git a/.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py b/.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py
new file mode 100644
index 00000000..2ae4af31
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/main/services/retrieval_service.py
@@ -0,0 +1,2087 @@
+import asyncio
+import json
+import logging
+from copy import deepcopy
+from datetime import datetime
+from typing import Any, AsyncGenerator, Literal, Optional
+from uuid import UUID
+
+from fastapi import HTTPException
+
+from core import (
+ Citation,
+ R2RRAGAgent,
+ R2RStreamingRAGAgent,
+ R2RStreamingResearchAgent,
+ R2RXMLToolsRAGAgent,
+ R2RXMLToolsResearchAgent,
+ R2RXMLToolsStreamingRAGAgent,
+ R2RXMLToolsStreamingResearchAgent,
+)
+from core.agent.research import R2RResearchAgent
+from core.base import (
+ AggregateSearchResult,
+ ChunkSearchResult,
+ DocumentResponse,
+ GenerationConfig,
+ GraphCommunityResult,
+ GraphEntityResult,
+ GraphRelationshipResult,
+ GraphSearchResult,
+ GraphSearchResultType,
+ IngestionStatus,
+ Message,
+ R2RException,
+ SearchSettings,
+ WebSearchResult,
+ format_search_results_for_llm,
+)
+from core.base.api.models import RAGResponse, User
+from core.utils import (
+ CitationTracker,
+ SearchResultsCollector,
+ SSEFormatter,
+ dump_collector,
+ dump_obj,
+ extract_citations,
+ find_new_citation_spans,
+ num_tokens_from_messages,
+)
+from shared.api.models.management.responses import MessageResponse
+
+from ..abstractions import R2RProviders
+from ..config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger()
+
+
+class AgentFactory:
+ """
+ Factory class that creates appropriate agent instances based on mode,
+ model type, and streaming preferences.
+ """
+
+ @staticmethod
+ def create_agent(
+ mode: Literal["rag", "research"],
+ database_provider,
+ llm_provider,
+ config, # : AgentConfig
+ search_settings, # : SearchSettings
+ generation_config, #: GenerationConfig
+ app_config, #: AppConfig
+ knowledge_search_method,
+ content_method,
+ file_search_method,
+ max_tool_context_length: int = 32_768,
+ rag_tools: Optional[list[str]] = None,
+ research_tools: Optional[list[str]] = None,
+ tools: Optional[list[str]] = None, # For backward compatibility
+ ):
+ """
+ Creates and returns the appropriate agent based on provided parameters.
+
+ Args:
+ mode: Either "rag" or "research" to determine agent type
+ database_provider: Provider for database operations
+ llm_provider: Provider for LLM operations
+ config: Agent configuration
+ search_settings: Search settings for retrieval
+ generation_config: Generation configuration with LLM parameters
+ app_config: Application configuration
+ knowledge_search_method: Method for knowledge search
+ content_method: Method for content retrieval
+ file_search_method: Method for file search
+ max_tool_context_length: Maximum context length for tools
+ rag_tools: Tools specifically for RAG mode
+ research_tools: Tools specifically for Research mode
+ tools: Deprecated backward compatibility parameter
+
+ Returns:
+ An appropriate agent instance
+ """
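+ # The concrete class is chosen along three axes: mode (rag vs. research), streaming vs. non-streaming,
+ # and XML-style tool calling vs. native tool calling (XML detection is currently disabled below).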
+ # Create a deep copy of the config to avoid modifying the original
+ agent_config = deepcopy(config)
+
+ # Handle tool specifications based on mode
+ if mode == "rag":
+ # For RAG mode, prioritize explicitly passed rag_tools, then tools, then config defaults
+ if rag_tools:
+ agent_config.rag_tools = rag_tools
+ elif tools: # Backward compatibility
+ agent_config.rag_tools = tools
+ # If neither was provided, the config's default rag_tools will be used
+ elif mode == "research":
+ # For Research mode, prioritize explicitly passed research_tools, then tools, then config defaults
+ if research_tools:
+ agent_config.research_tools = research_tools
+ elif tools: # Backward compatibility
+ agent_config.research_tools = tools
+ # If neither was provided, the config's default research_tools will be used
+
+ # Determine if we need XML-based tools based on model
+ use_xml_format = False
+ # if generation_config.model:
+ # model_str = generation_config.model.lower()
+ # use_xml_format = "deepseek" in model_str or "gemini" in model_str
+
+ # Set streaming mode based on generation config
+ is_streaming = generation_config.stream
+
+ # Create the appropriate agent based on all factors
+ if mode == "rag":
+ # RAG mode agents
+ if is_streaming:
+ if use_xml_format:
+ return R2RXMLToolsStreamingRAGAgent(
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=agent_config,
+ search_settings=search_settings,
+ rag_generation_config=generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+ else:
+ return R2RStreamingRAGAgent(
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=agent_config,
+ search_settings=search_settings,
+ rag_generation_config=generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+ else:
+ if use_xml_format:
+ return R2RXMLToolsRAGAgent(
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=agent_config,
+ search_settings=search_settings,
+ rag_generation_config=generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+ else:
+ return R2RRAGAgent(
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=agent_config,
+ search_settings=search_settings,
+ rag_generation_config=generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+ else:
+ # Research mode agents
+ if is_streaming:
+ if use_xml_format:
+ return R2RXMLToolsStreamingResearchAgent(
+ app_config=app_config,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=agent_config,
+ search_settings=search_settings,
+ rag_generation_config=generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+ else:
+ return R2RStreamingResearchAgent(
+ app_config=app_config,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=agent_config,
+ search_settings=search_settings,
+ rag_generation_config=generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+ else:
+ if use_xml_format:
+ return R2RXMLToolsResearchAgent(
+ app_config=app_config,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=agent_config,
+ search_settings=search_settings,
+ rag_generation_config=generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+ else:
+ return R2RResearchAgent(
+ app_config=app_config,
+ database_provider=database_provider,
+ llm_provider=llm_provider,
+ config=agent_config,
+ search_settings=search_settings,
+ rag_generation_config=generation_config,
+ max_tool_context_length=max_tool_context_length,
+ knowledge_search_method=knowledge_search_method,
+ content_method=content_method,
+ file_search_method=file_search_method,
+ )
+
+
+class RetrievalService(Service):
+ def __init__(
+ self,
+ config: R2RConfig,
+ providers: R2RProviders,
+ ):
+ super().__init__(
+ config,
+ providers,
+ )
+
+ async def search(
+ self,
+ query: str,
+ search_settings: SearchSettings = SearchSettings(),
+ *args,
+ **kwargs,
+ ) -> AggregateSearchResult:
+ """
+ Depending on search_settings.search_strategy, fan out
+ to basic, hyde, or rag_fusion method. Each returns
+ an AggregateSearchResult that includes chunk + graph results.
+ """
+ strategy = search_settings.search_strategy.lower()
+
+ if strategy == "hyde":
+ return await self._hyde_search(query, search_settings)
+ elif strategy == "rag_fusion":
+ return await self._rag_fusion_search(query, search_settings)
+ else:
+ # 'vanilla', 'basic', or anything else...
+ return await self._basic_search(query, search_settings)
+
+ async def _basic_search(
+ self, query: str, search_settings: SearchSettings
+ ) -> AggregateSearchResult:
+ """
+ 1) Possibly embed the query (if semantic or hybrid).
+ 2) Chunk search.
+ 3) Graph search.
+ 4) Combine into an AggregateSearchResult.
+ """
+ # -- 1) Possibly embed the query
+ query_vector = None
+ if (
+ search_settings.use_semantic_search
+ or search_settings.use_hybrid_search
+ ):
+ query_vector = (
+ await self.providers.completion_embedding.async_get_embedding(
+ query # , EmbeddingPurpose.QUERY
+ )
+ )
+
+ # -- 2) Chunk search
+ chunk_results = []
+ if search_settings.chunk_settings.enabled:
+ chunk_results = await self._vector_search_logic(
+ query_text=query,
+ search_settings=search_settings,
+ precomputed_vector=query_vector, # Pass in the vector we just computed (if any)
+ )
+
+ # -- 3) Graph search
+ graph_results = []
+ if search_settings.graph_settings.enabled:
+ graph_results = await self._graph_search_logic(
+ query_text=query,
+ search_settings=search_settings,
+ precomputed_vector=query_vector, # same idea
+ )
+
+ # -- 4) Combine
+ return AggregateSearchResult(
+ chunk_search_results=chunk_results,
+ graph_search_results=graph_results,
+ )
+
+ async def _rag_fusion_search(
+ self, query: str, search_settings: SearchSettings
+ ) -> AggregateSearchResult:
+ """
+ Implements 'RAG Fusion':
+ 1) Generate N sub-queries from the user query
+ 2) For each sub-query => do chunk & graph search
+ 3) Combine / fuse all retrieved results using Reciprocal Rank Fusion
+ 4) Return an AggregateSearchResult
+ """
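+ # Note: each sub-query is searched independently and the per-query rankings are merged with
+ # Reciprocal Rank Fusion (RRF) before an optional final re-rank against the original query.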
+
+ # 1) Generate sub-queries from the user’s original query
+ # Typically you want the original query to remain in the set as well,
+ # so that we do not lose the exact user intent.
+ sub_queries = [query]
+ if search_settings.num_sub_queries > 1:
+ # Generate (num_sub_queries - 1) rephrasings
+ # (Or just generate exactly search_settings.num_sub_queries,
+ # and remove the first if you prefer.)
+ extra = await self._generate_similar_queries(
+ query=query,
+ num_sub_queries=search_settings.num_sub_queries - 1,
+ )
+ sub_queries.extend(extra)
+
+ # 2) For each sub-query => do chunk + graph search
+ # We’ll store them in a structure so we can fuse them.
+ # chunk_results_list is a list of lists of ChunkSearchResult
+ # graph_results_list is a list of lists of GraphSearchResult
+ chunk_results_list = []
+ graph_results_list = []
+
+ for sq in sub_queries:
+ # Recompute or reuse the embedding if desired
+ # (You could do so, but not mandatory if you have a local approach)
+ # chunk + graph search
+ aggr = await self._basic_search(sq, search_settings)
+ chunk_results_list.append(aggr.chunk_search_results)
+ graph_results_list.append(aggr.graph_search_results)
+
+ # 3) Fuse the chunk results and fuse the graph results.
+ # We'll use a simple RRF approach: each sub-query's result list
+ # is a ranking from best to worst.
+ fused_chunk_results = self._reciprocal_rank_fusion_chunks( # type: ignore
+ chunk_results_list # type: ignore
+ )
+ filtered_graph_results = [
+ results for results in graph_results_list if results is not None
+ ]
+ fused_graph_results = self._reciprocal_rank_fusion_graphs(
+ filtered_graph_results
+ )
+
+ # Optionally, after the RRF, you may want to do a final semantic re-rank
+ # of the fused results by the user’s original query.
+ # E.g.:
+ if fused_chunk_results:
+ fused_chunk_results = (
+ await self.providers.completion_embedding.arerank(
+ query=query,
+ results=fused_chunk_results,
+ limit=search_settings.limit,
+ )
+ )
+
+ # Sort or slice the graph results if needed:
+ if fused_graph_results and search_settings.include_scores:
+ fused_graph_results.sort(
+ key=lambda g: g.score if g.score is not None else 0.0,
+ reverse=True,
+ )
+ fused_graph_results = fused_graph_results[: search_settings.limit]
+
+ # 4) Return final AggregateSearchResult
+ return AggregateSearchResult(
+ chunk_search_results=fused_chunk_results,
+ graph_search_results=fused_graph_results,
+ )
+
+ async def _generate_similar_queries(
+ self, query: str, num_sub_queries: int = 2
+ ) -> list[str]:
+ """
+ Use your LLM to produce 'similar' queries or rephrasings
+ that might retrieve different but relevant documents.
+
+ You can prompt your model with something like:
+ "Given the user query, produce N alternative short queries that
+ capture possible interpretations or expansions.
+ Keep them relevant to the user's intent."
+ """
+ if num_sub_queries < 1:
+ return []
+
+ # In production, you'd fetch a prompt from your prompts DB:
+ # Something like:
+ prompt = f"""
+ You are a helpful assistant. The user query is: "{query}"
+ Generate {num_sub_queries} alternative search queries that capture
+ slightly different phrasings or expansions while preserving the core meaning.
+ Return each alternative on its own line.
+ """
+
+ # For a short generation, we can set minimal tokens
+ gen_config = GenerationConfig(
+ model=self.config.app.fast_llm,
+ max_tokens=128,
+ temperature=0.8,
+ stream=False,
+ )
+ response = await self.providers.llm.aget_completion(
+ messages=[{"role": "system", "content": prompt}],
+ generation_config=gen_config,
+ )
+ raw_text = (
+ response.choices[0].message.content.strip()
+ if response.choices[0].message.content is not None
+ else ""
+ )
+
+ # Suppose each line is a sub-query
+ lines = [line.strip() for line in raw_text.split("\n") if line.strip()]
+ return lines[:num_sub_queries]
+
+ def _reciprocal_rank_fusion_chunks(
+ self, list_of_rankings: list[list[ChunkSearchResult]], k: float = 60.0
+ ) -> list[ChunkSearchResult]:
+ """
+ Simple RRF for chunk results.
+ list_of_rankings is something like:
+ [
+ [chunkA, chunkB, chunkC], # sub-query #1, in order
+ [chunkC, chunkD], # sub-query #2, in order
+ ...
+ ]
+
+ We'll produce a dictionary mapping chunk.id -> aggregated_score,
+ then sort descending.
+ """
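+ # RRF: each chunk receives sum(1 / (k + rank)) across all sub-query rankings; the constant k
+ # (default 60) dampens the advantage of top-ranked positions so that broad agreement wins.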
+ if not list_of_rankings:
+ return []
+
+ # Build a map of chunk_id => final_rff_score
+ score_map: dict[str, float] = {}
+
+ # We also need to store a reference to the chunk object
+ # (the "first" or "best" instance), so we can reconstruct them later
+ chunk_map: dict[str, Any] = {}
+
+ for ranking_list in list_of_rankings:
+ for rank, chunk_result in enumerate(ranking_list, start=1):
+ if not chunk_result.id:
+ # fallback if no chunk_id is present
+ continue
+
+ c_id = chunk_result.id
+ # RRF scoring
+ # score = sum(1 / (k + rank)) for each sub-query ranking
+ # We'll accumulate it.
+ existing_score = score_map.get(str(c_id), 0.0)
+ new_score = existing_score + 1.0 / (k + rank)
+ score_map[str(c_id)] = new_score
+
+ # Keep a reference to chunk
+ if str(c_id) not in chunk_map:
+ chunk_map[str(c_id)] = chunk_result
+
+ # Now sort by final score
+ fused_items = sorted(
+ score_map.items(), key=lambda x: x[1], reverse=True
+ )
+
+ # Rebuild the final list of chunk results with new 'score'
+ fused_chunks = []
+ for c_id, agg_score in fused_items: # type: ignore
+ # copy the chunk
+ c = chunk_map[str(c_id)]
+ # Optionally store the RRF score if you want
+ c.score = agg_score
+ fused_chunks.append(c)
+
+ return fused_chunks
+
+ def _reciprocal_rank_fusion_graphs(
+ self, list_of_rankings: list[list[GraphSearchResult]], k: float = 60.0
+ ) -> list[GraphSearchResult]:
+ """
+ Similar RRF logic but for graph results.
+ """
+ if not list_of_rankings:
+ return []
+
+ score_map: dict[str, float] = {}
+ graph_map = {}
+
+ for ranking_list in list_of_rankings:
+ for rank, g_result in enumerate(ranking_list, start=1):
+ # We'll do a naive ID approach:
+ # If your GraphSearchResult has a unique ID in g_result.content.id or so
+ # we can use that as a key.
+ # If not, you might have to build a key from the content.
+ g_id = None
+ if hasattr(g_result.content, "id"):
+ g_id = str(g_result.content.id)
+ else:
+ # fallback
+ g_id = f"graph_{hash(g_result.content.json())}"
+
+ existing_score = score_map.get(g_id, 0.0)
+ new_score = existing_score + 1.0 / (k + rank)
+ score_map[g_id] = new_score
+
+ if g_id not in graph_map:
+ graph_map[g_id] = g_result
+
+ # Sort descending by aggregated RRF score
+ fused_items = sorted(
+ score_map.items(), key=lambda x: x[1], reverse=True
+ )
+
+ fused_graphs = []
+ for g_id, agg_score in fused_items:
+ g = graph_map[g_id]
+ g.score = agg_score
+ fused_graphs.append(g)
+
+ return fused_graphs
+
+ async def _hyde_search(
+ self, query: str, search_settings: SearchSettings
+ ) -> AggregateSearchResult:
+ """
+ 1) Generate N hypothetical docs via LLM
+ 2) For each doc => embed => parallel chunk search & graph search
+ 3) Merge chunk results => optional re-rank => top K
+ 4) Merge graph results => (optionally re-rank or keep them distinct)
+ """
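+ # HyDE searches with embeddings of hypothetical answers instead of the raw query, which tends
+ # to land closer to relevant passages in embedding space than the (often short) user query.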
+ # 1) Generate hypothetical docs
+ hyde_docs = await self._run_hyde_generation(
+ query=query, num_sub_queries=search_settings.num_sub_queries
+ )
+
+ chunk_all = []
+ graph_all = []
+
+ # We'll gather the per-doc searches in parallel
+ tasks = []
+ for hypothetical_text in hyde_docs:
+ tasks.append(
+ asyncio.create_task(
+ self._fanout_chunk_and_graph_search(
+ user_text=query, # The user’s original query
+ alt_text=hypothetical_text, # The hypothetical doc
+ search_settings=search_settings,
+ )
+ )
+ )
+
+ # 2) Wait for them all
+ results_list = await asyncio.gather(*tasks)
+ # each item in results_list is a tuple: (chunks, graphs)
+
+ # Flatten chunk+graph results
+ for c_results, g_results in results_list:
+ chunk_all.extend(c_results)
+ graph_all.extend(g_results)
+
+ # 3) Re-rank chunk results with the original query
+ if chunk_all:
+ chunk_all = await self.providers.completion_embedding.arerank(
+ query=query, # final user query
+ results=chunk_all,
+ limit=int(
+ search_settings.limit * search_settings.num_sub_queries
+ ),
+ # no limit on results - limit=search_settings.limit,
+ )
+
+ # 4) If needed, re-rank graph results or just slice top-K by score
+ if search_settings.include_scores and graph_all:
+ graph_all.sort(key=lambda g: g.score or 0.0, reverse=True)
+ graph_all = (
+ graph_all # no limit on results - [: search_settings.limit]
+ )
+
+ return AggregateSearchResult(
+ chunk_search_results=chunk_all,
+ graph_search_results=graph_all,
+ )
+
+ async def _fanout_chunk_and_graph_search(
+ self,
+ user_text: str,
+ alt_text: str,
+ search_settings: SearchSettings,
+ ) -> tuple[list[ChunkSearchResult], list[GraphSearchResult]]:
+ """
+ 1) embed alt_text (HyDE doc or sub-query, etc.)
+ 2) chunk search + graph search with that embedding
+ """
+ # Precompute the embedding of alt_text
+ vec = await self.providers.completion_embedding.async_get_embedding(
+ alt_text # , EmbeddingPurpose.QUERY
+ )
+
+ # chunk search
+ chunk_results = []
+ if search_settings.chunk_settings.enabled:
+ chunk_results = await self._vector_search_logic(
+ query_text=user_text, # used for full-text matching & re-ranking
+ search_settings=search_settings,
+ precomputed_vector=vec, # use the alt_text vector for semantic/hybrid
+ )
+
+ # graph search
+ graph_results = []
+ if search_settings.graph_settings.enabled:
+ graph_results = await self._graph_search_logic(
+ query_text=user_text, # or alt_text if you prefer
+ search_settings=search_settings,
+ precomputed_vector=vec,
+ )
+
+ return (chunk_results, graph_results)
+
+ async def _vector_search_logic(
+ self,
+ query_text: str,
+ search_settings: SearchSettings,
+ precomputed_vector: Optional[list[float]] = None,
+ ) -> list[ChunkSearchResult]:
+ """
+ • If precomputed_vector is given, use it for semantic/hybrid search.
+ Otherwise embed query_text ourselves.
+ • Then do fulltext, semantic, or hybrid search.
+ • Optionally re-rank and return results.
+ """
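+ # Routing: hybrid search runs when both fulltext and semantic are enabled (or use_hybrid_search
+ # is set); otherwise a fulltext-only or semantic-only search is performed.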
+ if not search_settings.chunk_settings.enabled:
+ return []
+
+ # 1) Possibly embed
+ query_vector = precomputed_vector
+ if query_vector is None and (
+ search_settings.use_semantic_search
+ or search_settings.use_hybrid_search
+ ):
+ query_vector = (
+ await self.providers.completion_embedding.async_get_embedding(
+ query_text # , EmbeddingPurpose.QUERY
+ )
+ )
+
+ # 2) Choose which search to run
+ if (
+ search_settings.use_fulltext_search
+ and search_settings.use_semantic_search
+ ) or search_settings.use_hybrid_search:
+ if query_vector is None:
+ raise ValueError("Hybrid search requires a precomputed vector")
+ raw_results = (
+ await self.providers.database.chunks_handler.hybrid_search(
+ query_vector=query_vector,
+ query_text=query_text,
+ search_settings=search_settings,
+ )
+ )
+ elif search_settings.use_fulltext_search:
+ raw_results = (
+ await self.providers.database.chunks_handler.full_text_search(
+ query_text=query_text,
+ search_settings=search_settings,
+ )
+ )
+ elif search_settings.use_semantic_search:
+ if query_vector is None:
+ raise ValueError(
+ "Semantic search requires a precomputed vector"
+ )
+ raw_results = (
+ await self.providers.database.chunks_handler.semantic_search(
+ query_vector=query_vector,
+ search_settings=search_settings,
+ )
+ )
+ else:
+ raise ValueError(
+ "At least one of use_fulltext_search or use_semantic_search must be True"
+ )
+
+ # 3) Re-rank
+ reranked = await self.providers.completion_embedding.arerank(
+ query=query_text, results=raw_results, limit=search_settings.limit
+ )
+
+ # 4) Possibly augment text or metadata
+ final_results = []
+ for r in reranked:
+ if "title" in r.metadata and search_settings.include_metadatas:
+ title = r.metadata["title"]
+ r.text = f"Document Title: {title}\n\nText: {r.text}"
+ r.metadata["associated_query"] = query_text
+ final_results.append(r)
+
+ return final_results
+
+ async def _graph_search_logic(
+ self,
+ query_text: str,
+ search_settings: SearchSettings,
+ precomputed_vector: Optional[list[float]] = None,
+ ) -> list[GraphSearchResult]:
+ """
+ Mirrors your previous GraphSearch approach:
+ • if precomputed_vector is supplied, use that
+ • otherwise embed query_text
+ • search entities, relationships, communities
+ • return results
+ """
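+ # Three sequential passes over the graph: entities, relationships, then communities, each capped
+ # by graph_settings.limits (falling back to search_settings.limit when no per-type limit is set).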
+ results: list[GraphSearchResult] = []
+
+ if not search_settings.graph_settings.enabled:
+ return results
+
+ # 1) Possibly embed
+ query_embedding = precomputed_vector
+ if query_embedding is None:
+ query_embedding = (
+ await self.providers.completion_embedding.async_get_embedding(
+ query_text
+ )
+ )
+
+ base_limit = search_settings.limit
+ graph_limits = search_settings.graph_settings.limits or {}
+
+ # Entity search
+ entity_limit = graph_limits.get("entities", base_limit)
+ entity_cursor = self.providers.database.graphs_handler.graph_search(
+ query_text,
+ search_type="entities",
+ limit=entity_limit,
+ query_embedding=query_embedding,
+ property_names=["name", "description", "id"],
+ filters=search_settings.filters,
+ )
+ async for ent in entity_cursor:
+ score = ent.get("similarity_score")
+ metadata = ent.get("metadata", {})
+ if isinstance(metadata, str):
+ try:
+ metadata = json.loads(metadata)
+ except Exception:
+ # Leave metadata as the raw string if it cannot be parsed as JSON
+ pass
+
+ results.append(
+ GraphSearchResult(
+ id=ent.get("id", None),
+ content=GraphEntityResult(
+ name=ent.get("name", ""),
+ description=ent.get("description", ""),
+ id=ent.get("id", None),
+ ),
+ result_type=GraphSearchResultType.ENTITY,
+ score=score if search_settings.include_scores else None,
+ metadata=(
+ {
+ **(metadata or {}),
+ "associated_query": query_text,
+ }
+ if search_settings.include_metadatas
+ else {}
+ ),
+ )
+ )
+
+ # Relationship search
+ rel_limit = graph_limits.get("relationships", base_limit)
+ rel_cursor = self.providers.database.graphs_handler.graph_search(
+ query_text,
+ search_type="relationships",
+ limit=rel_limit,
+ query_embedding=query_embedding,
+ property_names=[
+ "id",
+ "subject",
+ "predicate",
+ "object",
+ "description",
+ "subject_id",
+ "object_id",
+ ],
+ filters=search_settings.filters,
+ )
+ async for rel in rel_cursor:
+ score = rel.get("similarity_score")
+ metadata = rel.get("metadata", {})
+ if isinstance(metadata, str):
+ try:
+ metadata = json.loads(metadata)
+ except Exception:
+ # Leave metadata as the raw string if it cannot be parsed as JSON
+ pass
+
+ results.append(
+ GraphSearchResult(
+ id=rel.get("id", None),
+ content=GraphRelationshipResult(
+ id=rel.get("id", None),
+ subject=rel.get("subject", ""),
+ predicate=rel.get("predicate", ""),
+ object=rel.get("object", ""),
+ subject_id=rel.get("subject_id", None),
+ object_id=rel.get("object_id", None),
+ description=rel.get("description", ""),
+ ),
+ result_type=GraphSearchResultType.RELATIONSHIP,
+ score=score if search_settings.include_scores else None,
+ metadata=(
+ {
+ **(metadata or {}),
+ "associated_query": query_text,
+ }
+ if search_settings.include_metadatas
+ else {}
+ ),
+ )
+ )
+
+ # Community search
+ comm_limit = graph_limits.get("communities", base_limit)
+ comm_cursor = self.providers.database.graphs_handler.graph_search(
+ query_text,
+ search_type="communities",
+ limit=comm_limit,
+ query_embedding=query_embedding,
+ property_names=[
+ "id",
+ "name",
+ "summary",
+ ],
+ filters=search_settings.filters,
+ )
+ async for comm in comm_cursor:
+ score = comm.get("similarity_score")
+ metadata = comm.get("metadata", {})
+ if isinstance(metadata, str):
+ try:
+ metadata = json.loads(metadata)
+ except Exception:
+ # Leave metadata as the raw string if it cannot be parsed as JSON
+ pass
+
+ results.append(
+ GraphSearchResult(
+ id=comm.get("id", None),
+ content=GraphCommunityResult(
+ id=comm.get("id", None),
+ name=comm.get("name", ""),
+ summary=comm.get("summary", ""),
+ ),
+ result_type=GraphSearchResultType.COMMUNITY,
+ score=score if search_settings.include_scores else None,
+ metadata=(
+ {
+ **(metadata or {}),
+ "associated_query": query_text,
+ }
+ if search_settings.include_metadatas
+ else {}
+ ),
+ )
+ )
+
+ return results
+
+ async def _run_hyde_generation(
+ self,
+ query: str,
+ num_sub_queries: int = 2,
+ ) -> list[str]:
+ """
+ Calls the LLM with a 'HyDE' style prompt to produce multiple
+ hypothetical documents/answers, one per line or separated by blank lines.
+ """
+ # Retrieve the prompt template from your database or config:
+ # e.g. your "hyde" prompt has placeholders: {message}, {num_outputs}
+ hyde_template = (
+ await self.providers.database.prompts_handler.get_cached_prompt(
+ prompt_name="hyde",
+ inputs={"message": query, "num_outputs": num_sub_queries},
+ )
+ )
+
+ # Now call the LLM with that as the system or user prompt:
+ completion_config = GenerationConfig(
+ model=self.config.app.fast_llm, # or whichever short/cheap model
+ max_tokens=512,
+ temperature=0.7,
+ stream=False,
+ )
+
+ response = await self.providers.llm.aget_completion(
+ messages=[{"role": "system", "content": hyde_template}],
+ generation_config=completion_config,
+ )
+
+ # Suppose the LLM returns something like:
+ #
+ # "Doc1. Some made up text.\n\nDoc2. Another made up text.\n\n"
+ #
+ # So we split by double-newline or some pattern:
+ raw_text = response.choices[0].message.content
+ return [
+ chunk.strip()
+ for chunk in (raw_text or "").split("\n\n")
+ if chunk.strip()
+ ]
+
+ async def search_documents(
+ self,
+ query: str,
+ settings: SearchSettings,
+ query_embedding: Optional[list[float]] = None,
+ ) -> list[DocumentResponse]:
+ if query_embedding is None:
+ query_embedding = (
+ await self.providers.completion_embedding.async_get_embedding(
+ query
+ )
+ )
+ result = (
+ await self.providers.database.documents_handler.search_documents(
+ query_text=query,
+ settings=settings,
+ query_embedding=query_embedding,
+ )
+ )
+ return result
+
+ async def completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ *args,
+ **kwargs,
+ ):
+ return await self.providers.llm.aget_completion(
+ [message.to_dict() for message in messages], # type: ignore
+ generation_config,
+ *args,
+ **kwargs,
+ )
+
+ async def embedding(
+ self,
+ text: str,
+ ):
+ return await self.providers.completion_embedding.async_get_embedding(
+ text=text
+ )
+
+ async def rag(
+ self,
+ query: str,
+ rag_generation_config: GenerationConfig,
+ search_settings: SearchSettings = SearchSettings(),
+ system_prompt_name: str | None = None,
+ task_prompt_name: str | None = None,
+ include_web_search: bool = False,
+ **kwargs,
+ ) -> Any:
+ """
+ A single RAG method that can do EITHER a one-shot synchronous RAG or
+ streaming SSE-based RAG, depending on rag_generation_config.stream.
+
+ 1) Perform aggregator search => context
+ 2) Build system+task prompts => messages
+ 3) If not streaming => normal LLM call => return RAGResponse
+ 4) If streaming => return an async generator of SSE lines
+ """
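+ # The SSE stream below emits, in order: a search_results event, then message / thinking /
+ # citation events as tokens arrive, and finally a final_answer event followed by a done event.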
+ # 1) Possibly fix up any UUID filters in search_settings
+ for f, val in list(search_settings.filters.items()):
+ if isinstance(val, UUID):
+ search_settings.filters[f] = str(val)
+
+ try:
+ # 2) Perform search => aggregated_results
+ aggregated_results = await self.search(query, search_settings)
+ # 3) Optionally add web search results if flag is enabled
+ if include_web_search:
+ web_results = await self._perform_web_search(query)
+ # Merge web search results with existing aggregated results
+ if web_results and web_results.web_search_results:
+ if not aggregated_results.web_search_results:
+ aggregated_results.web_search_results = (
+ web_results.web_search_results
+ )
+ else:
+ aggregated_results.web_search_results.extend(
+ web_results.web_search_results
+ )
+ # 3) Build context from aggregator
+ collector = SearchResultsCollector()
+ collector.add_aggregate_result(aggregated_results)
+ context_str = format_search_results_for_llm(
+ aggregated_results, collector
+ )
+
+ # 4) Prepare system+task messages
+ system_prompt_name = system_prompt_name or "system"
+ task_prompt_name = task_prompt_name or "rag"
+ task_prompt = kwargs.get("task_prompt")
+
+ messages = await self.providers.database.prompts_handler.get_message_payload(
+ system_prompt_name=system_prompt_name,
+ task_prompt_name=task_prompt_name,
+ task_inputs={"query": query, "context": context_str},
+ task_prompt=task_prompt,
+ )
+
+ # 5) Check streaming vs. non-streaming
+ if not rag_generation_config.stream:
+ # ========== Non-Streaming Logic ==========
+ response = await self.providers.llm.aget_completion(
+ messages=messages,
+ generation_config=rag_generation_config,
+ )
+ llm_text = response.choices[0].message.content
+
+ # (a) Extract short-ID references from final text
+ raw_sids = extract_citations(llm_text or "")
+
+ # (b) Possibly prune large content out of metadata
+ metadata = response.dict()
+ if "choices" in metadata and len(metadata["choices"]) > 0:
+ metadata["choices"][0]["message"].pop("content", None)
+
+ # (c) Build final RAGResponse
+ rag_resp = RAGResponse(
+ generated_answer=llm_text or "",
+ search_results=aggregated_results,
+ citations=[
+ Citation(
+ id=f"{sid}",
+ object="citation",
+ payload=dump_obj( # type: ignore
+ self._find_item_by_shortid(sid, collector)
+ ),
+ )
+ for sid in raw_sids
+ ],
+ metadata=metadata,
+ completion=llm_text or "",
+ )
+ return rag_resp
+
+ else:
+ # ========== Streaming SSE Logic ==========
+ async def sse_generator() -> AsyncGenerator[str, None]:
+ # 1) Emit search results via SSEFormatter
+ async for line in SSEFormatter.yield_search_results_event(
+ aggregated_results
+ ):
+ yield line
+
+ # Initialize citation tracker to manage citation state
+ citation_tracker = CitationTracker()
+
+ # Store citation payloads by ID for reuse
+ citation_payloads = {}
+
+ partial_text_buffer = ""
+
+ # Begin streaming from the LLM
+ msg_stream = self.providers.llm.aget_completion_stream(
+ messages=messages,
+ generation_config=rag_generation_config,
+ )
+
+ try:
+ async for chunk in msg_stream:
+ delta = chunk.choices[0].delta
+ finish_reason = chunk.choices[0].finish_reason
+ # Check whether the delta carries a `thinking` attribute before emitting a thinking event
+
+ if hasattr(delta, "thinking") and delta.thinking:
+ # Emit SSE "thinking" event
+ async for (
+ line
+ ) in SSEFormatter.yield_thinking_event(
+ delta.thinking
+ ):
+ yield line
+
+ if delta.content:
+ # (b) Emit SSE "message" event for this chunk of text
+ async for (
+ line
+ ) in SSEFormatter.yield_message_event(
+ delta.content
+ ):
+ yield line
+
+ # Accumulate new text
+ partial_text_buffer += delta.content
+
+ # (a) Extract citations from updated buffer
+ # For each *new* short ID, emit an SSE "citation" event
+ # Find new citation spans in the accumulated text
+ new_citation_spans = find_new_citation_spans(
+ partial_text_buffer, citation_tracker
+ )
+
+ # Process each new citation span
+ for cid, spans in new_citation_spans.items():
+ for span in spans:
+ # Check if this is the first time we've seen this citation ID
+ is_new_citation = (
+ citation_tracker.is_new_citation(
+ cid
+ )
+ )
+
+ # Get payload if it's a new citation
+ payload = None
+ if is_new_citation:
+ source_obj = (
+ self._find_item_by_shortid(
+ cid, collector
+ )
+ )
+ if source_obj:
+ # Store payload for reuse
+ payload = dump_obj(source_obj)
+ citation_payloads[cid] = (
+ payload
+ )
+
+ # Create citation event payload
+ citation_data = {
+ "id": cid,
+ "object": "citation",
+ "is_new": is_new_citation,
+ "span": {
+ "start": span[0],
+ "end": span[1],
+ },
+ }
+
+ # Only include full payload for new citations
+ if is_new_citation and payload:
+ citation_data["payload"] = payload
+
+ # Emit the citation event
+ async for (
+ line
+ ) in SSEFormatter.yield_citation_event(
+ citation_data
+ ):
+ yield line
+
+ # If the LLM signals it’s done
+ if finish_reason == "stop":
+ # Prepare consolidated citations for final answer event
+ consolidated_citations = []
+ # Group citations by ID with all their spans
+ for (
+ cid,
+ spans,
+ ) in citation_tracker.get_all_spans().items():
+ if cid in citation_payloads:
+ consolidated_citations.append(
+ {
+ "id": cid,
+ "object": "citation",
+ "spans": [
+ {
+ "start": s[0],
+ "end": s[1],
+ }
+ for s in spans
+ ],
+ "payload": citation_payloads[
+ cid
+ ],
+ }
+ )
+
+ # (c) Emit final answer + all collected citations
+ final_answer_evt = {
+ "id": "msg_final",
+ "object": "rag.final_answer",
+ "generated_answer": partial_text_buffer,
+ "citations": consolidated_citations,
+ }
+ async for (
+ line
+ ) in SSEFormatter.yield_final_answer_event(
+ final_answer_evt
+ ):
+ yield line
+
+ # (d) Signal the end of the SSE stream
+ yield SSEFormatter.yield_done_event()
+ break
+
+ except Exception as e:
+ logger.error(f"Error streaming LLM in rag: {e}")
+ # Optionally yield an SSE "error" event or handle differently
+ raise
+
+ return sse_generator()
+
+ except Exception as e:
+ logger.exception(f"Error in RAG pipeline: {e}")
+ if "NoneType" in str(e):
+ raise HTTPException(
+ status_code=502,
+ detail="Server not reachable or returned an invalid response",
+ ) from e
+ raise HTTPException(
+ status_code=500,
+ detail=f"Internal RAG Error - {str(e)}",
+ ) from e
+
+ def _find_item_by_shortid(
+ self, sid: str, collector: SearchResultsCollector
+ ) -> Optional[Any]:
+ """
+ Example helper that tries to match aggregator items by short ID,
+ meaning result_obj.id starts with sid.
+ """
+ for source_type, result_obj in collector.get_all_results():
+ # if the aggregator item has an 'id' attribute
+ if getattr(result_obj, "id", None) is not None:
+ full_id_str = str(result_obj.id)
+ if full_id_str.startswith(sid):
+ if source_type == "chunk":
+ return result_obj.as_dict()
+ return result_obj
+ return None
+
+ async def agent(
+ self,
+ rag_generation_config: GenerationConfig,
+ rag_tools: Optional[list[str]] = None,
+ tools: Optional[list[str]] = None, # backward compatibility
+ search_settings: SearchSettings = SearchSettings(),
+ task_prompt: Optional[str] = None,
+ include_title_if_available: Optional[bool] = False,
+ conversation_id: Optional[UUID] = None,
+ message: Optional[Message] = None,
+ messages: Optional[list[Message]] = None,
+ use_system_context: bool = False,
+ max_tool_context_length: int = 32_768,
+ research_tools: Optional[list[str]] = None,
+ research_generation_config: Optional[GenerationConfig] = None,
+ needs_initial_conversation_name: Optional[bool] = None,
+ mode: Optional[Literal["rag", "research"]] = "rag",
+ ):
+ """
+ Engage with an intelligent agent for information retrieval, analysis, and research.
+
+ Args:
+ rag_generation_config: Configuration for RAG mode generation
+ search_settings: Search configuration for retrieving context
+ task_prompt: Optional custom prompt override
+ include_title_if_available: Whether to include document titles
+ conversation_id: Optional conversation ID for continuity
+ message: Current message to process
+ messages: List of messages (deprecated)
+ use_system_context: Whether to use extended prompt
+ max_tool_context_length: Maximum context length for tools
+ rag_tools: List of tools for RAG mode
+ research_tools: List of tools for Research mode
+ research_generation_config: Configuration for Research mode generation
+ mode: Either "rag" or "research"
+
+ Returns:
+ Agent response with messages and conversation ID
+ """
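+ # Flow: validate message inputs -> load or create the conversation -> build the system
+ # instruction -> construct the agent via AgentFactory -> stream SSE output, or persist and
+ # return the final assistant message with citations.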
+ try:
+ # Validate message inputs
+ if message and messages:
+ raise R2RException(
+ status_code=400,
+ message="Only one of message or messages should be provided",
+ )
+
+ if not message and not messages:
+ raise R2RException(
+ status_code=400,
+ message="Either message or messages should be provided",
+ )
+
+ # Ensure 'message' is a Message instance
+ if message and not isinstance(message, Message):
+ if isinstance(message, dict):
+ message = Message.from_dict(message)
+ else:
+ raise R2RException(
+ status_code=400,
+ message="""
+ Invalid message format. The expected format contains:
+ role: MessageType | 'system' | 'user' | 'assistant' | 'function'
+ content: Optional[str]
+ name: Optional[str]
+ function_call: Optional[dict[str, Any]]
+ tool_calls: Optional[list[dict[str, Any]]]
+ """,
+ )
+
+ # Ensure 'messages' is a list of Message instances
+ if messages:
+ processed_messages = []
+ for msg in messages:
+ if isinstance(msg, Message):
+ processed_messages.append(msg)
+ elif hasattr(msg, "dict"):
+ processed_messages.append(
+ Message.from_dict(msg.dict())
+ )
+ elif isinstance(msg, dict):
+ processed_messages.append(Message.from_dict(msg))
+ else:
+ processed_messages.append(Message.from_dict(str(msg)))
+ messages = processed_messages
+ else:
+ messages = []
+
+ # Validate and process mode-specific configurations
+ if mode == "rag" and research_tools:
+ logger.warning(
+ "research_tools provided but mode is 'rag'. These tools will be ignored."
+ )
+ research_tools = None
+
+ # Determine effective generation config based on mode
+ effective_generation_config = rag_generation_config
+ if mode == "research" and research_generation_config:
+ effective_generation_config = research_generation_config
+
+ # Set appropriate LLM model based on mode if not explicitly specified
+ if "model" not in effective_generation_config.__fields_set__:
+ if mode == "rag":
+ effective_generation_config.model = (
+ self.config.app.quality_llm
+ )
+ elif mode == "research":
+ effective_generation_config.model = (
+ self.config.app.planning_llm
+ )
+
+ # Transform UUID filters to strings
+ for filter_key, value in search_settings.filters.items():
+ if isinstance(value, UUID):
+ search_settings.filters[filter_key] = str(value)
+
+ # Process conversation data
+ ids = []
+ if conversation_id: # Fetch the existing conversation
+ conversation_messages = None
+ try:
+ conversation_messages = await self.providers.database.conversations_handler.get_conversation(
+ conversation_id=conversation_id,
+ )
+ if needs_initial_conversation_name is None:
+ overview = await self.providers.database.conversations_handler.get_conversations_overview(
+ offset=0,
+ limit=1,
+ conversation_ids=[conversation_id],
+ )
+ if overview.get("total_entries", 0) > 0:
+ needs_initial_conversation_name = (
+ overview.get("results")[0].get("name") is None # type: ignore
+ )
+ except Exception as e:
+ logger.error(f"Error fetching conversation: {str(e)}")
+
+ if conversation_messages is not None:
+ messages_from_conversation: list[Message] = []
+ for message_response in conversation_messages:
+ if isinstance(message_response, MessageResponse):
+ messages_from_conversation.append(
+ message_response.message
+ )
+ ids.append(message_response.id)
+ else:
+ logger.warning(
+ f"Unexpected type in conversation found: {type(message_response)}\n{message_response}"
+ )
+ messages = messages_from_conversation + messages
+ else: # Create new conversation
+ conversation_response = await self.providers.database.conversations_handler.create_conversation()
+ conversation_id = conversation_response.id
+ needs_initial_conversation_name = True
+
+ if message:
+ messages.append(message)
+
+ if not messages:
+ raise R2RException(
+ status_code=400,
+ message="No messages to process",
+ )
+
+ current_message = messages[-1]
+ logger.debug(
+ f"Running the agent with conversation_id = {conversation_id} and message = {current_message}"
+ )
+
+ # Save the new message to the conversation
+ parent_id = ids[-1] if ids else None
+ message_response = await self.providers.database.conversations_handler.add_message(
+ conversation_id=conversation_id,
+ content=current_message,
+ parent_id=parent_id,
+ )
+
+ message_id = (
+ message_response.id if message_response is not None else None
+ )
+
+ # Extract filter information from search settings
+ filter_user_id, filter_collection_ids = (
+ self._parse_user_and_collection_filters(
+ search_settings.filters
+ )
+ )
+
+ # Validate system instruction configuration
+ if use_system_context and task_prompt:
+ raise R2RException(
+ status_code=400,
+ message="Both use_system_context and task_prompt cannot be True at the same time",
+ )
+
+ # Build the system instruction
+ if task_prompt:
+ system_instruction = task_prompt
+ else:
+ system_instruction = (
+ await self._build_aware_system_instruction(
+ max_tool_context_length=max_tool_context_length,
+ filter_user_id=filter_user_id,
+ filter_collection_ids=filter_collection_ids,
+ model=effective_generation_config.model,
+ use_system_context=use_system_context,
+ mode=mode,
+ )
+ )
+
+ # Configure agent with appropriate tools
+ agent_config = deepcopy(self.config.agent)
+ if mode == "rag":
+ # Use provided RAG tools or default from config
+ agent_config.rag_tools = (
+ rag_tools or tools or self.config.agent.rag_tools
+ )
+ else: # research mode
+ # Use provided Research tools or default from config
+ agent_config.research_tools = (
+ research_tools or tools or self.config.agent.research_tools
+ )
+
+ # Create the agent using our factory
+ mode = mode or "rag"
+
+ for msg in messages:
+ if msg.content is None:
+ msg.content = ""
+
+ agent = AgentFactory.create_agent(
+ mode=mode,
+ database_provider=self.providers.database,
+ llm_provider=self.providers.llm,
+ config=agent_config,
+ search_settings=search_settings,
+ generation_config=effective_generation_config,
+ app_config=self.config.app,
+ knowledge_search_method=self.search,
+ content_method=self.get_context,
+ file_search_method=self.search_documents,
+ max_tool_context_length=max_tool_context_length,
+ rag_tools=rag_tools,
+ research_tools=research_tools,
+ tools=tools, # Backward compatibility
+ )
+
+ # Handle streaming vs. non-streaming response
+ if effective_generation_config.stream:
+
+ async def stream_response():
+ try:
+ async for chunk in agent.arun(
+ messages=messages,
+ system_instruction=system_instruction,
+ include_title_if_available=include_title_if_available,
+ ):
+ yield chunk
+ except Exception as e:
+ logger.error(f"Error streaming agent output: {e}")
+ raise e
+ finally:
+ # Persist conversation data
+ msgs = [
+ msg.to_dict()
+ for msg in agent.conversation.messages
+ ]
+ input_tokens = num_tokens_from_messages(msgs[:-1])
+ output_tokens = num_tokens_from_messages([msgs[-1]])
+ await self.providers.database.conversations_handler.add_message(
+ conversation_id=conversation_id,
+ content=agent.conversation.messages[-1],
+ parent_id=message_id,
+ metadata={
+ "input_tokens": input_tokens,
+ "output_tokens": output_tokens,
+ },
+ )
+
+ # Generate conversation name if needed
+ if needs_initial_conversation_name:
+ try:
+ prompt = f"Generate a succinct name (3-6 words) for this conversation, given the first input message here = {str(message.to_dict() if message else {})}"
+ conversation_name = (
+ (
+ await self.providers.llm.aget_completion(
+ [
+ {
+ "role": "system",
+ "content": prompt,
+ }
+ ],
+ GenerationConfig(
+ model=self.config.app.fast_llm
+ ),
+ )
+ )
+ .choices[0]
+ .message.content
+ )
+ await self.providers.database.conversations_handler.update_conversation(
+ conversation_id=conversation_id,
+ name=conversation_name,
+ )
+ except Exception as e:
+ logger.error(
+ f"Error generating conversation name: {e}"
+ )
+
+ return stream_response()
+ else:
+ for idx, msg in enumerate(messages):
+ if msg.content is None:
+ messages[idx].content = ""
+
+ # Non-streaming path
+ results = await agent.arun(
+ messages=messages,
+ system_instruction=system_instruction,
+ include_title_if_available=include_title_if_available,
+ )
+
+ # Process the agent results
+ if isinstance(results[-1], dict):
+ if results[-1].get("content") is None:
+ results[-1]["content"] = ""
+ assistant_message = Message(**results[-1])
+ elif isinstance(results[-1], Message):
+ assistant_message = results[-1]
+ if assistant_message.content is None:
+ assistant_message.content = ""
+ else:
+ assistant_message = Message(
+ role="assistant", content=str(results[-1])
+ )
+
+ # Get search results collector for citations
+ if hasattr(agent, "search_results_collector"):
+ collector = agent.search_results_collector
+ else:
+ collector = SearchResultsCollector()
+
+ # Extract content from the message
+ structured_content = assistant_message.structured_content
+ structured_content = (
+ structured_content[-1].get("text")
+ if structured_content
+ else None
+ )
+ raw_text = (
+ assistant_message.content or structured_content or ""
+ )
+ # Process citations
+ short_ids = extract_citations(raw_text or "")
+ final_citations = []
+ for sid in short_ids:
+ obj = collector.find_by_short_id(sid)
+ final_citations.append(
+ {
+ "id": sid,
+ "object": "citation",
+ "payload": dump_obj(obj) if obj else None,
+ }
+ )
+
+ # Persist in conversation DB
+ await (
+ self.providers.database.conversations_handler.add_message(
+ conversation_id=conversation_id,
+ content=assistant_message,
+ parent_id=message_id,
+ metadata={
+ "citations": final_citations,
+ "aggregated_search_result": json.dumps(
+ dump_collector(collector)
+ ),
+ },
+ )
+ )
+
+ # Generate conversation name if needed
+ if needs_initial_conversation_name:
+ conversation_name = None
+ try:
+ prompt = f"Generate a succinct name (3-6 words) for this conversation, given the first input message here = {str(message.to_dict() if message else {})}"
+ conversation_name = (
+ (
+ await self.providers.llm.aget_completion(
+ [{"role": "system", "content": prompt}],
+ GenerationConfig(
+ model=self.config.app.fast_llm
+ ),
+ )
+ )
+ .choices[0]
+ .message.content
+ )
+ except Exception:
+ # Fall back to an empty conversation name in the finally block below
+ pass
+ finally:
+ await self.providers.database.conversations_handler.update_conversation(
+ conversation_id=conversation_id,
+ name=conversation_name or "",
+ )
+
+ tool_calls = []
+ if hasattr(agent, "tool_calls"):
+ if agent.tool_calls is not None:
+ tool_calls = agent.tool_calls
+ else:
+ logger.warning(
+ "agent.tool_calls is None, using empty list instead"
+ )
+ # Return the final response
+ return {
+ "messages": [
+ Message(
+ role="assistant",
+ content=assistant_message.content
+ or structured_content
+ or "",
+ metadata={
+ "citations": final_citations,
+ "tool_calls": tool_calls,
+ "aggregated_search_result": json.dumps(
+ dump_collector(collector)
+ ),
+ },
+ )
+ ],
+ "conversation_id": str(conversation_id),
+ }
+
+ except Exception as e:
+ logger.error(f"Error in agent response: {str(e)}")
+ if "NoneType" in str(e):
+ raise HTTPException(
+ status_code=502,
+ detail="Server not reachable or returned an invalid response",
+ ) from e
+ raise HTTPException(
+ status_code=500,
+ detail=f"Internal Server Error - {str(e)}",
+ ) from e
+
+ async def get_context(
+ self,
+ filters: dict[str, Any],
+ options: dict[str, Any],
+ ) -> list[dict[str, Any]]:
+ """
+ Return an ordered list of documents (with minimal overview fields),
+ plus all associated chunks in ascending chunk order.
+
+ Only the filters: owner_id, collection_ids, and document_id
+ are supported. If any other filter or operator is passed in,
+ we raise an error.
+
+ Args:
+ filters: A dictionary describing the allowed filters
+ (owner_id, collection_ids, document_id).
+ options: A dictionary with extra options, e.g. include_summary_embedding
+ or any custom flags for additional logic.
+
+ Returns:
+ A list of dicts, where each dict has:
+ {
+ "document": <DocumentResponse>,
+ "chunks": [ <chunk0>, <chunk1>, ... ]
+ }
+ """
+ # 2. Fetch matching documents
+ matching_docs = await self.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=-1,
+ filters=filters,
+ include_summary_embedding=options.get(
+ "include_summary_embedding", False
+ ),
+ )
+
+ if not matching_docs["results"]:
+ return []
+
+ # 3. For each document, fetch associated chunks in ascending chunk order
+ results = []
+ for doc_response in matching_docs["results"]:
+ doc_id = doc_response.id
+ chunk_data = await self.providers.database.chunks_handler.list_document_chunks(
+ document_id=doc_id,
+ offset=0,
+ limit=-1, # get all chunks
+ include_vectors=False,
+ )
+ chunks = chunk_data["results"] # already sorted by chunk_order
+ doc_response.chunks = chunks
+ # 4. Build a returned structure that includes doc + chunks
+ results.append(doc_response.model_dump())
+
+ return results
+
+ def _parse_user_and_collection_filters(
+ self,
+ filters: dict[str, Any],
+ ):
+ ### TODO - Come up with smarter way to extract owner / collection ids for non-admin
+ filter_starts_with_and = filters.get("$and")
+ filter_starts_with_or = filters.get("$or")
+ if filter_starts_with_and:
+ try:
+ filter_starts_with_and_then_or = filter_starts_with_and[0][
+ "$or"
+ ]
+
+ user_id = filter_starts_with_and_then_or[0]["owner_id"]["$eq"]
+ collection_ids = [
+ UUID(ele)
+ for ele in filter_starts_with_and_then_or[1][
+ "collection_ids"
+ ]["$overlap"]
+ ]
+ return user_id, [str(ele) for ele in collection_ids]
+ except Exception as e:
+ logger.error(
+ f"Error: {e}.\n\n While"
+ + """ parsing filters: expected format {'$or': [{'owner_id': {'$eq': 'uuid-string-here'}, 'collection_ids': {'$overlap': ['uuid-of-some-collection']}}]}, if you are a superuser then this error can be ignored."""
+ )
+ return None, []
+ elif filter_starts_with_or:
+ try:
+ user_id = filter_starts_with_or[0]["owner_id"]["$eq"]
+ collection_ids = [
+ UUID(ele)
+ for ele in filter_starts_with_or[1]["collection_ids"][
+ "$overlap"
+ ]
+ ]
+ return user_id, [str(ele) for ele in collection_ids]
+ except Exception as e:
+ logger.error(
+ """Error parsing filters: expected format {'$or': [{'owner_id': {'$eq': 'uuid-string-here'}, 'collection_ids': {'$overlap': ['uuid-of-some-collection']}}]}, if you are a superuser then this error can be ignored."""
+ )
+ return None, []
+ else:
+ # Admin user
+ return None, []
+
+ async def _build_documents_context(
+ self,
+ filter_user_id: Optional[UUID] = None,
+ max_summary_length: int = 128,
+ limit: int = 25,
+ reverse_order: bool = True,
+ ) -> str:
+ """
+ Fetches documents matching the given filters and returns a formatted string
+ enumerating them.
+ """
+ # We only want up to `limit` documents for brevity
+ docs_data = await self.providers.database.documents_handler.get_documents_overview(
+ offset=0,
+ limit=limit,
+ filter_user_ids=[filter_user_id] if filter_user_id else None,
+ include_summary_embedding=False,
+ sort_order="DESC" if reverse_order else "ASC",
+ )
+
+ found_max = False
+ if len(docs_data["results"]) == limit:
+ found_max = True
+
+ docs = docs_data["results"]
+ if not docs:
+ return "No documents found."
+
+ lines = []
+ for i, doc in enumerate(docs, start=1):
+ if (
+ not doc.summary
+ or doc.ingestion_status != IngestionStatus.SUCCESS
+ ):
+ lines.append(
+ f"[{i}] Title: {doc.title}, Summary: (Summary not available), Status: {doc.ingestion_status}, ID: {doc.id}"
+ )
+ continue
+
+ # Build a line referencing the doc
+ title = doc.title or "(Untitled Document)"
+ lines.append(
+ f"[{i}] Title: {title}, Summary: {doc.summary[0:max_summary_length] + ('...' if len(doc.summary) > max_summary_length else '')}, Total Tokens: {doc.total_tokens}, ID: {doc.id}"
+ )
+ if found_max:
+ lines.append(
+ f"Note: Displaying only the first {limit} documents. Use a filter to narrow down the search if more documents are required."
+ )
+
+ return "\n".join(lines)
+
+ async def _build_aware_system_instruction(
+ self,
+ max_tool_context_length: int = 10_000,
+ filter_user_id: Optional[UUID] = None,
+ filter_collection_ids: Optional[list[UUID]] = None,
+ model: Optional[str] = None,
+ use_system_context: bool = False,
+ mode: Optional[str] = "rag",
+ ) -> str:
+ """
+ High-level method that:
+ 1) builds the documents context
+ 2) builds the collections context
+ 3) loads the new `dynamic_reasoning_rag_agent` prompt
+ """
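+ # Prompt selection: research mode short-circuits to "static_research_agent"; rag mode picks the
+ # dynamic or static rag-agent prompt based on use_system_context, with an "_xml_tooling" variant
+ # for deepseek models.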
+ date_str = str(datetime.now().strftime("%m/%d/%Y"))
+
+ # "dynamic_rag_agent" // "static_rag_agent"
+
+ if mode == "rag":
+ prompt_name = (
+ self.config.agent.rag_agent_dynamic_prompt
+ if use_system_context
+ else self.config.agent.rag_rag_agent_static_prompt
+ )
+ else:
+ prompt_name = "static_research_agent"
+ return await self.providers.database.prompts_handler.get_cached_prompt(
+ # We use custom tooling and a custom agent to handle gemini models
+ prompt_name,
+ inputs={
+ "date": date_str,
+ },
+ )
+
+ if model is not None and ("deepseek" in model):
+ prompt_name = f"{prompt_name}_xml_tooling"
+
+ if use_system_context:
+ doc_context_str = await self._build_documents_context(
+ filter_user_id=filter_user_id,
+ )
+ logger.debug(f"Loading prompt {prompt_name}")
+ # Now fetch the prompt from the database prompts handler
+ # This relies on your "rag_agent_extended" existing with
+ # placeholders: date, document_context
+ system_prompt = await self.providers.database.prompts_handler.get_cached_prompt(
+ # We use custom tooling and a custom agent to handle gemini models
+ prompt_name,
+ inputs={
+ "date": date_str,
+ "max_tool_context_length": max_tool_context_length,
+ "document_context": doc_context_str,
+ },
+ )
+ else:
+ system_prompt = await self.providers.database.prompts_handler.get_cached_prompt(
+ prompt_name,
+ inputs={
+ "date": date_str,
+ },
+ )
+ logger.debug(f"Running agent with system prompt = {system_prompt}")
+ return system_prompt
+
+ async def _perform_web_search(
+ self,
+ query: str,
+ search_settings: SearchSettings = SearchSettings(),
+ ) -> AggregateSearchResult:
+ """
+ Perform a web search using an external search engine API (Serper).
+
+ Args:
+ query: The search query string
+ search_settings: Optional search settings to customize the search
+
+ Returns:
+ AggregateSearchResult containing web search results
+ """
+ try:
+ # Import the Serper client here to avoid circular imports
+ from core.utils.serper import SerperClient
+
+ # Initialize the Serper client
+ serper_client = SerperClient()
+
+ # Perform the raw search using Serper API
+ raw_results = serper_client.get_raw(query)
+
+ # Process the raw results into a WebSearchResult object
+ web_response = WebSearchResult.from_serper_results(raw_results)
+
+ # Create an AggregateSearchResult with the web search results
+ agg_result = AggregateSearchResult(
+ chunk_search_results=None,
+ graph_search_results=None,
+ web_search_results=web_response.organic_results,
+ )
+
+ # Log the search for monitoring purposes
+ logger.debug(f"Web search completed for query: {query}")
+ logger.debug(
+ f"Found {len(web_response.organic_results)} web results"
+ )
+
+ return agg_result
+
+ except Exception as e:
+ logger.error(f"Error performing web search: {str(e)}")
+ # Return empty results rather than failing completely
+ return AggregateSearchResult(
+ chunk_search_results=None,
+ graph_search_results=None,
+ web_search_results=[],
+ )
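+
+    # Example (illustrative sketch only; assumes `service` is an instance of this
+    # service class):
+    #
+    #     result = await service._perform_web_search("retrieval augmented generation")
+    #     for hit in result.web_search_results:
+    #         ...  # entries come from WebSearchResult.from_serper_results(...)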
+
+
+class RetrievalServiceAdapter:
+ @staticmethod
+ def _parse_user_data(user_data):
+ if isinstance(user_data, str):
+ try:
+ user_data = json.loads(user_data)
+ except json.JSONDecodeError as e:
+ raise ValueError(
+ f"Invalid user data format: {user_data}"
+ ) from e
+ return User.from_dict(user_data)
+
+ @staticmethod
+ def prepare_search_input(
+ query: str,
+ search_settings: SearchSettings,
+ user: User,
+ ) -> dict:
+ return {
+ "query": query,
+ "search_settings": search_settings.to_dict(),
+ "user": user.to_dict(),
+ }
+
+ @staticmethod
+ def parse_search_input(data: dict):
+ return {
+ "query": data["query"],
+ "search_settings": SearchSettings.from_dict(
+ data["search_settings"]
+ ),
+ "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
+ }
+
+ @staticmethod
+ def prepare_rag_input(
+ query: str,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ task_prompt: Optional[str],
+ include_web_search: bool,
+ user: User,
+ ) -> dict:
+ return {
+ "query": query,
+ "search_settings": search_settings.to_dict(),
+ "rag_generation_config": rag_generation_config.to_dict(),
+ "task_prompt": task_prompt,
+ "include_web_search": include_web_search,
+ "user": user.to_dict(),
+ }
+
+ @staticmethod
+ def parse_rag_input(data: dict):
+ return {
+ "query": data["query"],
+ "search_settings": SearchSettings.from_dict(
+ data["search_settings"]
+ ),
+ "rag_generation_config": GenerationConfig.from_dict(
+ data["rag_generation_config"]
+ ),
+ "task_prompt": data["task_prompt"],
+ "include_web_search": data["include_web_search"],
+ "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
+ }
+
+ @staticmethod
+ def prepare_agent_input(
+ message: Message,
+ search_settings: SearchSettings,
+ rag_generation_config: GenerationConfig,
+ task_prompt: Optional[str],
+ include_title_if_available: bool,
+ user: User,
+ conversation_id: Optional[str] = None,
+ ) -> dict:
+ return {
+ "message": message.to_dict(),
+ "search_settings": search_settings.to_dict(),
+ "rag_generation_config": rag_generation_config.to_dict(),
+ "task_prompt": task_prompt,
+ "include_title_if_available": include_title_if_available,
+ "user": user.to_dict(),
+ "conversation_id": conversation_id,
+ }
+
+ @staticmethod
+ def parse_agent_input(data: dict):
+ return {
+ "message": Message.from_dict(data["message"]),
+ "search_settings": SearchSettings.from_dict(
+ data["search_settings"]
+ ),
+ "rag_generation_config": GenerationConfig.from_dict(
+ data["rag_generation_config"]
+ ),
+ "task_prompt": data["task_prompt"],
+ "include_title_if_available": data["include_title_if_available"],
+ "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
+ "conversation_id": data.get("conversation_id"),
+ }
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/__init__.py b/.venv/lib/python3.12/site-packages/core/parsers/__init__.py
new file mode 100644
index 00000000..8a7d5bbe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/__init__.py
@@ -0,0 +1,35 @@
+from .media import *
+from .structured import *
+from .text import *
+
+__all__ = [
+ "AudioParser",
+ "BMPParser",
+ "DOCParser",
+ "DOCXParser",
+ "ImageParser",
+ "ODTParser",
+ "VLMPDFParser",
+ "BasicPDFParser",
+ "PDFParserUnstructured",
+ "VLMPDFParser",
+ "PPTParser",
+ "PPTXParser",
+ "RTFParser",
+ "CSVParser",
+ "CSVParserAdvanced",
+ "EMLParser",
+ "EPUBParser",
+ "JSONParser",
+ "MSGParser",
+ "ORGParser",
+ "P7SParser",
+ "RSTParser",
+ "TSVParser",
+ "XLSParser",
+ "XLSXParser",
+ "XLSXParserAdvanced",
+ "MDParser",
+ "HTMLParser",
+ "TextParser",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/media/__init__.py b/.venv/lib/python3.12/site-packages/core/parsers/media/__init__.py
new file mode 100644
index 00000000..c268b673
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/media/__init__.py
@@ -0,0 +1,26 @@
+# type: ignore
+from .audio_parser import AudioParser
+from .bmp_parser import BMPParser
+from .doc_parser import DOCParser
+from .docx_parser import DOCXParser
+from .img_parser import ImageParser
+from .odt_parser import ODTParser
+from .pdf_parser import BasicPDFParser, PDFParserUnstructured, VLMPDFParser
+from .ppt_parser import PPTParser
+from .pptx_parser import PPTXParser
+from .rtf_parser import RTFParser
+
+__all__ = [
+ "AudioParser",
+ "BMPParser",
+ "DOCParser",
+ "DOCXParser",
+ "ImageParser",
+ "ODTParser",
+ "VLMPDFParser",
+ "BasicPDFParser",
+ "PDFParserUnstructured",
+ "PPTParser",
+ "PPTXParser",
+ "RTFParser",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/media/audio_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/media/audio_parser.py
new file mode 100644
index 00000000..7d5f9f1d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/media/audio_parser.py
@@ -0,0 +1,74 @@
+# type: ignore
+import logging
+import os
+import tempfile
+from typing import AsyncGenerator
+
+from litellm import atranscription
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+logger = logging.getLogger()
+
+
+class AudioParser(AsyncParser[bytes]):
+ """A parser for audio data using Whisper transcription."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.atranscription = atranscription
+
+ async def ingest( # type: ignore
+ self, data: bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest audio data and yield a transcription using Whisper via
+ LiteLLM.
+
+ Args:
+ data: Raw audio bytes
+ *args, **kwargs: Additional arguments passed to the transcription call
+
+ Yields:
+ Chunks of transcribed text
+ """
+ try:
+ # Create a temporary file to store the audio data
+ with tempfile.NamedTemporaryFile(
+ suffix=".wav", delete=False
+ ) as temp_file:
+ temp_file.write(data)
+ temp_file_path = temp_file.name
+
+            # Call Whisper transcription; keep the file handle scoped so it is
+            # closed once the request completes.
+            with open(temp_file_path, "rb") as audio_file:
+                response = await self.atranscription(
+                    model=self.config.audio_transcription_model
+                    or self.config.app.audio_lm,
+                    file=audio_file,
+                    **kwargs,
+                )
+
+ # The response should contain the transcribed text directly
+ yield response.text
+
+ except Exception as e:
+ logger.error(f"Error processing audio with Whisper: {str(e)}")
+ raise
+
+ finally:
+ # Clean up the temporary file
+ try:
+ os.unlink(temp_file_path)
+ except Exception as e:
+ logger.warning(f"Failed to delete temporary file: {str(e)}")
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/media/bmp_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/media/bmp_parser.py
new file mode 100644
index 00000000..78646da7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/media/bmp_parser.py
@@ -0,0 +1,78 @@
+# type: ignore
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class BMPParser(AsyncParser[str | bytes]):
+ """A parser for BMP image data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+
+ import struct
+
+ self.struct = struct
+
+ async def extract_bmp_metadata(self, data: bytes) -> dict:
+ """Extract metadata from BMP file header."""
+ try:
+ # BMP header format
+ header_format = "<2sIHHI"
+ header_size = self.struct.calcsize(header_format)
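+            # Layout note: "<2sIHHI" is the 14-byte BITMAPFILEHEADER (signature,
+            # file size, two reserved words, pixel-data offset); the "<IiiHHIIiiII"
+            # DIB format below corresponds to the 40-byte BITMAPINFOHEADER at offset 14.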
+
+ # Unpack header data
+ (
+ signature,
+ file_size,
+ reserved,
+ reserved2,
+ data_offset,
+ ) = self.struct.unpack(header_format, data[:header_size])
+
+ # DIB header
+ dib_format = "<IiiHHIIiiII"
+ dib_size = self.struct.calcsize(dib_format)
+ dib_data = self.struct.unpack(dib_format, data[14 : 14 + dib_size])
+
+ width = dib_data[1]
+ height = abs(dib_data[2]) # Height can be negative
+ bits_per_pixel = dib_data[4]
+ compression = dib_data[5]
+
+ return {
+ "width": width,
+ "height": height,
+ "bits_per_pixel": bits_per_pixel,
+ "file_size": file_size,
+ "compression": compression,
+ }
+ except Exception as e:
+ return {"error": f"Failed to parse BMP header: {str(e)}"}
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest BMP data and yield metadata description."""
+ if isinstance(data, str):
+ # Convert base64 string to bytes if needed
+ import base64
+
+ data = base64.b64decode(data)
+
+ metadata = await self.extract_bmp_metadata(data)
+
+ # Generate description of the BMP file
+ yield f"BMP image with dimensions {metadata.get('width', 'unknown')}x{metadata.get('height', 'unknown')} pixels, {metadata.get('bits_per_pixel', 'unknown')} bits per pixel, file size: {metadata.get('file_size', 'unknown')} bytes"
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/media/doc_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/media/doc_parser.py
new file mode 100644
index 00000000..5b49e2cc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/media/doc_parser.py
@@ -0,0 +1,108 @@
+# type: ignore
+import re
+from io import BytesIO
+from typing import AsyncGenerator
+
+import olefile
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class DOCParser(AsyncParser[str | bytes]):
+ """A parser for DOC (legacy Microsoft Word) data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.olefile = olefile
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest DOC data and yield text from the document."""
+ if isinstance(data, str):
+ raise ValueError("DOC data must be in bytes format.")
+
+        # Create BytesIO object from the data
+        file_obj = BytesIO(data)
+        ole = None
+
+        try:
+            # Open the DOC file using olefile
+            ole = self.olefile.OleFileIO(file_obj)
+
+ # Check if it's a Word document
+ if not ole.exists("WordDocument"):
+ raise ValueError("Not a valid Word document")
+
+ # Read the WordDocument stream
+ word_stream = ole.openstream("WordDocument").read()
+
+ # Read the text from the 0Table or 1Table stream (contains the text)
+ if ole.exists("1Table"):
+ table_stream = ole.openstream("1Table").read()
+ elif ole.exists("0Table"):
+ table_stream = ole.openstream("0Table").read()
+ else:
+ table_stream = b""
+
+ # Extract text content
+ text = self._extract_text(word_stream, table_stream)
+
+ # Clean and split the text
+ paragraphs = self._clean_text(text)
+
+ # Yield non-empty paragraphs
+ for paragraph in paragraphs:
+ if paragraph.strip():
+ yield paragraph.strip()
+
+ except Exception as e:
+ raise ValueError(f"Error processing DOC file: {str(e)}") from e
+ finally:
+            if ole is not None:
+                ole.close()
+ file_obj.close()
+
+ def _extract_text(self, word_stream: bytes, table_stream: bytes) -> str:
+ """Extract text from Word document streams."""
+ try:
+ text = word_stream.replace(b"\x00", b"").decode(
+ "utf-8", errors="ignore"
+ )
+
+ # If table_stream exists, try to extract additional text
+ if table_stream:
+ table_text = table_stream.replace(b"\x00", b"").decode(
+ "utf-8", errors="ignore"
+ )
+ text += table_text
+
+ return text
+ except Exception as e:
+ raise ValueError(f"Error extracting text: {str(e)}") from e
+
+ def _clean_text(self, text: str) -> list[str]:
+ """Clean and split the extracted text into paragraphs."""
+ # Remove binary artifacts and control characters
+ text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\xFF]", "", text)
+
+ # Remove multiple spaces and newlines
+ text = re.sub(r"\s+", " ", text)
+
+ # Split into paragraphs on double newlines or other common separators
+ paragraphs = re.split(r"\n\n|\r\n\r\n|\f", text)
+
+ # Remove empty or whitespace-only paragraphs
+ paragraphs = [p.strip() for p in paragraphs if p.strip()]
+
+ return paragraphs
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/media/docx_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/media/docx_parser.py
new file mode 100644
index 00000000..988f8341
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/media/docx_parser.py
@@ -0,0 +1,38 @@
+# type: ignore
+from io import BytesIO
+from typing import AsyncGenerator
+
+from docx import Document
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class DOCXParser(AsyncParser[str | bytes]):
+ """A parser for DOCX data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.Document = Document
+
+ async def ingest(
+ self, data: str | bytes, *args, **kwargs
+ ) -> AsyncGenerator[str, None]: # type: ignore
+ """Ingest DOCX data and yield text from each paragraph."""
+ if isinstance(data, str):
+ raise ValueError("DOCX data must be in bytes format.")
+
+ doc = self.Document(BytesIO(data))
+ for paragraph in doc.paragraphs:
+ yield paragraph.text
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/media/img_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/media/img_parser.py
new file mode 100644
index 00000000..bcb37eab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/media/img_parser.py
@@ -0,0 +1,281 @@
+# type: ignore
+import base64
+import logging
+from io import BytesIO
+from typing import AsyncGenerator, Optional
+
+import filetype
+import pillow_heif
+from PIL import Image
+
+from core.base.abstractions import GenerationConfig
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+logger = logging.getLogger()
+
+
+class ImageParser(AsyncParser[str | bytes]):
+ # Mapping of file extensions to MIME types
+ MIME_TYPE_MAPPING = {
+ "bmp": "image/bmp",
+ "gif": "image/gif",
+ "heic": "image/heic",
+ "jpeg": "image/jpeg",
+ "jpg": "image/jpeg",
+ "png": "image/png",
+ "tiff": "image/tiff",
+ "tif": "image/tiff",
+ "webp": "image/webp",
+ }
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.vision_prompt_text = None
+ self.Image = Image
+ self.pillow_heif = pillow_heif
+ self.pillow_heif.register_heif_opener()
+
+ def _is_heic(self, data: bytes) -> bool:
+ """Detect HEIC format using magic numbers and patterns."""
+        heic_patterns = [
+            b"ftyp",
+            b"heic",
+            b"heix",
+            b"hevc",
+            b"HEIC",
+            b"mif1",
+            b"msf1",
+            b"hevx",
+        ]
+
+ try:
+ header = data[:32] # Get first 32 bytes
+ return any(pattern in header for pattern in heic_patterns)
+ except Exception as e:
+ logger.error(f"Error checking for HEIC format: {str(e)}")
+ return False
+
+ async def _convert_heic_to_jpeg(self, data: bytes) -> bytes:
+ """Convert HEIC image to JPEG format."""
+ try:
+ # Create BytesIO object for input
+ input_buffer = BytesIO(data)
+
+ # Load HEIC image using pillow_heif
+ heif_file = self.pillow_heif.read_heif(input_buffer)
+
+ # Get the primary image - API changed, need to get first image
+ heif_image = heif_file[0] # Get first image in the container
+
+ # Convert to PIL Image directly from the HEIF image
+ pil_image = heif_image.to_pillow()
+
+ # Convert to RGB if needed
+ if pil_image.mode != "RGB":
+ pil_image = pil_image.convert("RGB")
+
+ # Save as JPEG
+ output_buffer = BytesIO()
+ pil_image.save(output_buffer, format="JPEG", quality=95)
+ return output_buffer.getvalue()
+
+ except Exception as e:
+ logger.error(f"Error converting HEIC to JPEG: {str(e)}")
+ raise
+
+ def _is_jpeg(self, data: bytes) -> bool:
+ """Detect JPEG format using magic numbers."""
+ return len(data) >= 2 and data[0] == 0xFF and data[1] == 0xD8
+
+ def _is_png(self, data: bytes) -> bool:
+ """Detect PNG format using magic numbers."""
+ png_signature = b"\x89PNG\r\n\x1a\n"
+ return data.startswith(png_signature)
+
+ def _is_bmp(self, data: bytes) -> bool:
+ """Detect BMP format using magic numbers."""
+ return data.startswith(b"BM")
+
+ def _is_tiff(self, data: bytes) -> bool:
+ """Detect TIFF format using magic numbers."""
+        return (
+            data.startswith(b"II*\x00")  # Little-endian
+            or data.startswith(b"MM\x00*")  # Big-endian
+        )
+
+ def _get_image_media_type(
+ self, data: bytes, filename: Optional[str] = None
+ ) -> str:
+ """
+ Determine the correct media type based on image data and/or filename.
+
+ Args:
+ data: The binary image data
+ filename: Optional filename which may contain extension information
+
+ Returns:
+ str: The MIME type for the image
+ """
+ try:
+ # First, try format-specific detection functions
+ if self._is_heic(data):
+ return "image/heic"
+ if self._is_jpeg(data):
+ return "image/jpeg"
+ if self._is_png(data):
+ return "image/png"
+ if self._is_bmp(data):
+ return "image/bmp"
+ if self._is_tiff(data):
+ return "image/tiff"
+
+            # Try using filetype as a fallback; guess() returns a type object,
+            # not a string, so map its extension and fall back to its MIME type
+            kind = filetype.guess(data)
+            if kind:
+                return self.MIME_TYPE_MAPPING.get(kind.extension, kind.mime)
+
+ # If we have a filename, try to get the type from the extension
+ if filename:
+ extension = filename.split(".")[-1].lower()
+ if extension in self.MIME_TYPE_MAPPING:
+ return self.MIME_TYPE_MAPPING[extension]
+
+ # If all else fails, default to octet-stream (generic binary)
+ logger.warning(
+ "Could not determine image type, using application/octet-stream"
+ )
+ return "application/octet-stream"
+
+ except Exception as e:
+ logger.error(f"Error determining image media type: {str(e)}")
+ return "application/octet-stream" # Default to generic binary as fallback
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ if not self.vision_prompt_text:
+ self.vision_prompt_text = (
+ await self.database_provider.prompts_handler.get_cached_prompt(
+ prompt_name=self.config.vision_img_prompt_name
+ )
+ )
+ try:
+ filename = kwargs.get("filename", None)
+ # Whether to convert HEIC to JPEG (default: True for backward compatibility)
+ convert_heic = kwargs.get("convert_heic", True)
+
+ if isinstance(data, bytes):
+ try:
+ # First detect the original media type
+ original_media_type = self._get_image_media_type(
+ data, filename
+ )
+ logger.debug(
+ f"Detected original image type: {original_media_type}"
+ )
+
+ # Determine if we need to convert HEIC
+ is_heic_format = self._is_heic(data)
+
+ # Handle HEIC images
+ if is_heic_format and convert_heic:
+ logger.debug(
+ "Detected HEIC format, converting to JPEG"
+ )
+ data = await self._convert_heic_to_jpeg(data)
+ media_type = "image/jpeg"
+ else:
+ # Keep original format and media type
+ media_type = original_media_type
+
+ # Encode the data to base64
+ image_data = base64.b64encode(data).decode("utf-8")
+
+ except Exception as e:
+ logger.error(f"Error processing image data: {str(e)}")
+ raise
+ else:
+ # If data is already a string (base64), we assume it has a reliable content type
+ # from the source that encoded it
+ image_data = data
+
+ # Try to determine the media type from the context if available
+ media_type = kwargs.get(
+ "media_type", "application/octet-stream"
+ )
+
+ # Get the model from kwargs or config
+ model = kwargs.get("vlm", None) or self.config.app.vlm
+
+ generation_config = GenerationConfig(
+ model=model,
+ stream=False,
+ )
+
+ logger.debug(f"Using model: {model}, media_type: {media_type}")
+
+ if "anthropic" in model:
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": self.vision_prompt_text},
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": media_type,
+ "data": image_data,
+ },
+ },
+ ],
+ }
+ ]
+ else:
+ # For OpenAI-style APIs, use their format
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": self.vision_prompt_text},
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:{media_type};base64,{image_data}"
+ },
+ },
+ ],
+ }
+ ]
+
+ response = await self.llm_provider.aget_completion(
+ messages=messages, generation_config=generation_config
+ )
+
+ if response.choices and response.choices[0].message:
+ content = response.choices[0].message.content
+ if not content:
+ raise ValueError("No content in response")
+ yield content
+ else:
+ raise ValueError("No response content")
+
+ except Exception as e:
+ logger.error(f"Error processing image with vision model: {str(e)}")
+ raise
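+
+
+# Example usage (illustrative sketch; the config and provider objects come from
+# the ingestion pipeline and the filename hint is optional):
+#
+#     parser = ImageParser(config=ingestion_config, database_provider=db, llm_provider=llm)
+#     async for description in parser.ingest(image_bytes, filename="photo.heic"):
+#         print(description)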
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/media/odt_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/media/odt_parser.py
new file mode 100644
index 00000000..cb146464
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/media/odt_parser.py
@@ -0,0 +1,60 @@
+# type: ignore
+import xml.etree.ElementTree as ET
+import zipfile
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class ODTParser(AsyncParser[str | bytes]):
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.zipfile = zipfile
+ self.ET = ET
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ if isinstance(data, str):
+ raise ValueError("ODT data must be in bytes format.")
+
+ from io import BytesIO
+
+ file_obj = BytesIO(data)
+
+ try:
+ with self.zipfile.ZipFile(file_obj) as odt:
+ # ODT files are zip archives containing content.xml
+ content = odt.read("content.xml")
+ root = self.ET.fromstring(content)
+
+ # ODT XML namespace
+ ns = {"text": "urn:oasis:names:tc:opendocument:xmlns:text:1.0"}
+
+ # Extract paragraphs and headers
+ for p in root.findall(".//text:p", ns):
+ text = "".join(p.itertext())
+ if text.strip():
+ yield text.strip()
+
+ for h in root.findall(".//text:h", ns):
+ text = "".join(h.itertext())
+ if text.strip():
+ yield text.strip()
+
+ except Exception as e:
+ raise ValueError(f"Error processing ODT file: {str(e)}") from e
+ finally:
+ file_obj.close()
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/media/pdf_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/media/pdf_parser.py
new file mode 100644
index 00000000..b33ccb63
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/media/pdf_parser.py
@@ -0,0 +1,363 @@
+# type: ignore
+import asyncio
+import base64
+import json
+import logging
+import string
+import time
+import unicodedata
+from io import BytesIO
+from typing import AsyncGenerator
+
+from pdf2image import convert_from_bytes, convert_from_path
+from pdf2image.exceptions import PDFInfoNotInstalledError
+from PIL import Image
+from pypdf import PdfReader
+
+from core.base.abstractions import GenerationConfig
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+from shared.abstractions import PDFParsingError, PopplerNotFoundError
+
+logger = logging.getLogger()
+
+
+class VLMPDFParser(AsyncParser[str | bytes]):
+ """A parser for PDF documents using vision models for page processing."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.vision_prompt_text = None
+
+ async def convert_pdf_to_images(
+ self, data: str | bytes
+ ) -> list[Image.Image]:
+ """Convert PDF pages to images asynchronously using in-memory
+ conversion."""
+ logger.info("Starting PDF conversion to images.")
+ start_time = time.perf_counter()
+ options = {
+ "dpi": 300, # You can make this configurable via self.config if needed
+ "fmt": "jpeg",
+ "thread_count": 4,
+ "paths_only": False, # Return PIL Image objects instead of writing to disk
+ }
+ try:
+ if isinstance(data, bytes):
+ images = await asyncio.to_thread(
+ convert_from_bytes, data, **options
+ )
+ else:
+ images = await asyncio.to_thread(
+ convert_from_path, data, **options
+ )
+ elapsed = time.perf_counter() - start_time
+ logger.info(
+ f"PDF conversion completed in {elapsed:.2f} seconds, total pages: {len(images)}"
+ )
+ return images
+ except PDFInfoNotInstalledError as e:
+ logger.error(
+ "PDFInfoNotInstalledError encountered during PDF conversion."
+ )
+ raise PopplerNotFoundError() from e
+ except Exception as err:
+ logger.error(
+ f"Error converting PDF to images: {err} type: {type(err)}"
+ )
+ raise PDFParsingError(
+ f"Failed to process PDF: {str(err)}", err
+ ) from err
+
+ async def process_page(
+ self, image: Image.Image, page_num: int
+ ) -> dict[str, str]:
+ """Process a single PDF page using the vision model."""
+ page_start = time.perf_counter()
+ try:
+ # Convert PIL image to JPEG bytes in-memory
+ buf = BytesIO()
+ image.save(buf, format="JPEG")
+ buf.seek(0)
+ image_data = buf.read()
+ image_base64 = base64.b64encode(image_data).decode("utf-8")
+
+ model = self.config.app.vlm
+
+ # Configure generation parameters
+ generation_config = GenerationConfig(
+ model=self.config.app.vlm,
+ stream=False,
+ )
+
+ is_anthropic = model and "anthropic/" in model
+
+ # FIXME: This is a hacky fix to handle the different formats
+ # that was causing an outage. This logic really needs to be refactored
+ # and cleaned up such that it handles providers more robustly.
+
+ # Prepare message with image content
+ if is_anthropic:
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": self.vision_prompt_text},
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": image_base64,
+ },
+ },
+ ],
+ }
+ ]
+ else:
+ # Use OpenAI format
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": self.vision_prompt_text},
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/jpeg;base64,{image_base64}"
+ },
+ },
+ ],
+ }
+ ]
+
+ logger.debug(f"Sending page {page_num} to vision model.")
+ req_start = time.perf_counter()
+ if is_anthropic:
+ response = await self.llm_provider.aget_completion(
+ messages=messages,
+ generation_config=generation_config,
+ tools=[
+ {
+ "name": "parse_pdf_page",
+ "description": "Parse text content from a PDF page",
+ "input_schema": {
+ "type": "object",
+ "properties": {
+ "page_content": {
+ "type": "string",
+ "description": "Extracted text from the PDF page",
+ },
+ },
+ "required": ["page_content"],
+ },
+ }
+ ],
+ tool_choice={"type": "tool", "name": "parse_pdf_page"},
+ )
+
+ if (
+ response.choices
+ and response.choices[0].message
+ and response.choices[0].message.tool_calls
+ ):
+ tool_call = response.choices[0].message.tool_calls[0]
+ args = json.loads(tool_call.function.arguments)
+ content = args.get("page_content", "")
+ page_elapsed = time.perf_counter() - page_start
+ logger.debug(
+ f"Processed page {page_num} in {page_elapsed:.2f} seconds."
+ )
+
+ return {"page": str(page_num), "content": content}
+                else:
+                    logger.warning(
+                        f"No valid tool call in response for page {page_num}, document might be missing text."
+                    )
+                    # Return an empty page rather than falling through and
+                    # implicitly returning None, which would break callers.
+                    return {"page": str(page_num), "content": ""}
+ else:
+ response = await self.llm_provider.aget_completion(
+ messages=messages, generation_config=generation_config
+ )
+ req_elapsed = time.perf_counter() - req_start
+ logger.debug(
+ f"Vision model response for page {page_num} received in {req_elapsed:.2f} seconds."
+ )
+
+ if response.choices and response.choices[0].message:
+ content = response.choices[0].message.content
+ page_elapsed = time.perf_counter() - page_start
+ logger.debug(
+ f"Processed page {page_num} in {page_elapsed:.2f} seconds."
+ )
+ return {"page": str(page_num), "content": content}
+ else:
+ msg = f"No response content for page {page_num}"
+ logger.error(msg)
+ raise ValueError(msg)
+ except Exception as e:
+ logger.error(
+ f"Error processing page {page_num} with vision model: {str(e)}"
+ )
+ raise
+
+ async def ingest(
+ self, data: str | bytes, maintain_order: bool = True, **kwargs
+ ) -> AsyncGenerator[dict[str, str | int], None]:
+ """Ingest PDF data and yield the text description for each page using
+ the vision model.
+
+ (This version yields a string per page rather than a dictionary.)
+ """
+ ingest_start = time.perf_counter()
+ logger.info("Starting PDF ingestion using VLMPDFParser.")
+ if not self.vision_prompt_text:
+ self.vision_prompt_text = (
+ await self.database_provider.prompts_handler.get_cached_prompt(
+ prompt_name=self.config.vision_pdf_prompt_name
+ )
+ )
+ logger.info("Retrieved vision prompt text from database.")
+
+ try:
+ # Convert PDF to images (in-memory)
+ images = await self.convert_pdf_to_images(data)
+
+ # Create asynchronous tasks for processing each page
+ tasks = {
+ asyncio.create_task(
+ self.process_page(image, page_num)
+ ): page_num
+ for page_num, image in enumerate(images, 1)
+ }
+
+ if maintain_order:
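+                # Buffer completed pages and emit them strictly in page order:
+                # out-of-order completions wait in `results` until the next
+                # expected page number has finished.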
+ pending = set(tasks.keys())
+ results = {}
+ next_page = 1
+ while pending:
+ done, pending = await asyncio.wait(
+ pending, return_when=asyncio.FIRST_COMPLETED
+ )
+ for task in done:
+ result = await task
+ page_num = int(result["page"])
+ results[page_num] = result
+ while next_page in results:
+ yield {
+ "content": results[next_page]["content"] or "",
+ "page_number": next_page,
+ }
+ results.pop(next_page)
+ next_page += 1
+ else:
+ # Yield results as tasks complete
+ for coro in asyncio.as_completed(tasks.keys()):
+ result = await coro
+ yield {
+ "content": result["content"],
+ "page_number": int(result["page"]),
+ }
+ total_elapsed = time.perf_counter() - ingest_start
+ logger.info(
+ f"Completed PDF ingestion in {total_elapsed:.2f} seconds using VLMPDFParser."
+ )
+ except Exception as e:
+ logger.error(f"Error processing PDF: {str(e)}")
+ raise
+
+
+class BasicPDFParser(AsyncParser[str | bytes]):
+ """A parser for PDF data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.PdfReader = PdfReader
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest PDF data and yield text from each page."""
+ if isinstance(data, str):
+ raise ValueError("PDF data must be in bytes format.")
+ pdf = self.PdfReader(BytesIO(data))
+ for page in pdf.pages:
+ page_text = page.extract_text()
+ if page_text is not None:
+ page_text = "".join(
+ filter(
+ lambda x: (
+ unicodedata.category(x)
+ in [
+ "Ll",
+ "Lu",
+ "Lt",
+ "Lm",
+ "Lo",
+ "Nl",
+ "No",
+ ] # Keep letters and numbers
+ or "\u4e00" <= x <= "\u9fff" # Chinese characters
+ or "\u0600" <= x <= "\u06ff" # Arabic characters
+ or "\u0400" <= x <= "\u04ff" # Cyrillic letters
+ or "\u0370" <= x <= "\u03ff" # Greek letters
+ or "\u0e00" <= x <= "\u0e7f" # Thai
+ or "\u3040" <= x <= "\u309f" # Japanese Hiragana
+ or "\u30a0" <= x <= "\u30ff" # Katakana
+ or x in string.printable
+ ),
+ page_text,
+ )
+                )  # Keep letters/numbers in common scripts plus printable ASCII
+ yield page_text
+
+
+class PDFParserUnstructured(AsyncParser[str | bytes]):
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ try:
+ from unstructured.partition.pdf import partition_pdf
+
+ self.partition_pdf = partition_pdf
+
+ except ImportError as e:
+ logger.error("PDFParserUnstructured ImportError : ", e)
+
+ async def ingest(
+ self,
+ data: str | bytes,
+ partition_strategy: str = "hi_res",
+ chunking_strategy="by_title",
+ ) -> AsyncGenerator[str, None]:
+ # partition the pdf
+ elements = self.partition_pdf(
+ file=BytesIO(data),
+ partition_strategy=partition_strategy,
+ chunking_strategy=chunking_strategy,
+ )
+ for element in elements:
+ yield element.text
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/media/ppt_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/media/ppt_parser.py
new file mode 100644
index 00000000..c8bbaa55
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/media/ppt_parser.py
@@ -0,0 +1,88 @@
+# type: ignore
+import struct
+from io import BytesIO
+from typing import AsyncGenerator
+
+import olefile
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class PPTParser(AsyncParser[str | bytes]):
+ """A parser for legacy PPT (PowerPoint 97-2003) files."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.olefile = olefile
+
+ def _extract_text_from_record(self, data: bytes) -> str:
+ """Extract text from a PPT text record."""
+ try:
+ # Skip record header
+ text_data = data[8:]
+ # Convert from UTF-16-LE
+ return text_data.decode("utf-16-le", errors="ignore").strip()
+ except Exception:
+ return ""
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest PPT data and yield text from each slide."""
+ if isinstance(data, str):
+ raise ValueError("PPT data must be in bytes format.")
+
+        ole = None
+        try:
+            ole = self.olefile.OleFileIO(BytesIO(data))
+
+ # PPT stores text in PowerPoint Document stream
+ if not ole.exists("PowerPoint Document"):
+ raise ValueError("Not a valid PowerPoint file")
+
+ # Read PowerPoint Document stream
+ ppt_stream = ole.openstream("PowerPoint Document")
+ content = ppt_stream.read()
+
+ # Text records start with 0x0FA0 or 0x0FD0
+ text_markers = [b"\xa0\x0f", b"\xd0\x0f"]
+
+ current_position = 0
+ while current_position < len(content):
+ # Look for text markers
+ for marker in text_markers:
+ marker_pos = content.find(marker, current_position)
+ if marker_pos != -1:
+ # Get record size from header (4 bytes after marker)
+ size_bytes = content[marker_pos + 2 : marker_pos + 6]
+ record_size = struct.unpack("<I", size_bytes)[0]
+
+ # Extract record data
+ record_data = content[
+ marker_pos : marker_pos + record_size + 8
+ ]
+ text = self._extract_text_from_record(record_data)
+
+ if text.strip():
+ yield text.strip()
+
+ current_position = marker_pos + record_size + 8
+ break
+ else:
+ current_position += 1
+
+ except Exception as e:
+ raise ValueError(f"Error processing PPT file: {str(e)}") from e
+ finally:
+            if ole is not None:
+                ole.close()
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/media/pptx_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/media/pptx_parser.py
new file mode 100644
index 00000000..8685c8fb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/media/pptx_parser.py
@@ -0,0 +1,40 @@
+# type: ignore
+from io import BytesIO
+from typing import AsyncGenerator
+
+from pptx import Presentation
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class PPTXParser(AsyncParser[str | bytes]):
+ """A parser for PPT data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.Presentation = Presentation
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]: # type: ignore
+ """Ingest PPT data and yield text from each slide."""
+ if isinstance(data, str):
+ raise ValueError("PPT data must be in bytes format.")
+
+ prs = self.Presentation(BytesIO(data))
+ for slide in prs.slides:
+ for shape in slide.shapes:
+ if hasattr(shape, "text"):
+ yield shape.text
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/media/rtf_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/media/rtf_parser.py
new file mode 100644
index 00000000..6be12076
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/media/rtf_parser.py
@@ -0,0 +1,45 @@
+# type: ignore
+from typing import AsyncGenerator
+
+from striprtf.striprtf import rtf_to_text
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class RTFParser(AsyncParser[str | bytes]):
+ """Parser for Rich Text Format (.rtf) files."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.striprtf = rtf_to_text
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ if isinstance(data, bytes):
+ data = data.decode("utf-8", errors="ignore")
+
+ try:
+ # Convert RTF to plain text
+ plain_text = self.striprtf(data)
+
+ # Split into paragraphs and yield non-empty ones
+ paragraphs = plain_text.split("\n\n")
+ for paragraph in paragraphs:
+ if paragraph.strip():
+ yield paragraph.strip()
+
+ except Exception as e:
+ raise ValueError(f"Error processing RTF file: {str(e)}") from e
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/structured/__init__.py b/.venv/lib/python3.12/site-packages/core/parsers/structured/__init__.py
new file mode 100644
index 00000000..a770502e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/structured/__init__.py
@@ -0,0 +1,28 @@
+# type: ignore
+from .csv_parser import CSVParser, CSVParserAdvanced
+from .eml_parser import EMLParser
+from .epub_parser import EPUBParser
+from .json_parser import JSONParser
+from .msg_parser import MSGParser
+from .org_parser import ORGParser
+from .p7s_parser import P7SParser
+from .rst_parser import RSTParser
+from .tsv_parser import TSVParser
+from .xls_parser import XLSParser
+from .xlsx_parser import XLSXParser, XLSXParserAdvanced
+
+__all__ = [
+ "CSVParser",
+ "CSVParserAdvanced",
+ "EMLParser",
+ "EPUBParser",
+ "JSONParser",
+ "MSGParser",
+ "ORGParser",
+ "P7SParser",
+ "RSTParser",
+ "TSVParser",
+ "XLSParser",
+ "XLSXParser",
+ "XLSXParserAdvanced",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/structured/csv_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/structured/csv_parser.py
new file mode 100644
index 00000000..d80d5d07
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/structured/csv_parser.py
@@ -0,0 +1,108 @@
+# type: ignore
+from typing import IO, AsyncGenerator, Optional
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class CSVParser(AsyncParser[str | bytes]):
+ """A parser for CSV data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+
+ import csv
+ from io import StringIO
+
+ self.csv = csv
+ self.StringIO = StringIO
+
+ async def ingest(
+ self, data: str | bytes, *args, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest CSV data and yield text from each row."""
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+ csv_reader = self.csv.reader(self.StringIO(data))
+ for row in csv_reader:
+ yield ", ".join(row)
+
+
+class CSVParserAdvanced(AsyncParser[str | bytes]):
+ """A parser for CSV data."""
+
+ def __init__(
+ self, config: IngestionConfig, llm_provider: CompletionProvider
+ ):
+ self.llm_provider = llm_provider
+ self.config = config
+
+ import csv
+ from io import StringIO
+
+ self.csv = csv
+ self.StringIO = StringIO
+
+ def get_delimiter(
+ self, file_path: Optional[str] = None, file: Optional[IO[bytes]] = None
+ ):
+ sniffer = self.csv.Sniffer()
+ num_bytes = 65536
+
+ if file:
+ lines = file.readlines(num_bytes)
+ file.seek(0)
+ data = "\n".join(ln.decode("utf-8") for ln in lines)
+ elif file_path is not None:
+ with open(file_path) as f:
+ data = "\n".join(f.readlines(num_bytes))
+
+ return sniffer.sniff(data, delimiters=",;").delimiter
+
+ async def ingest(
+ self,
+ data: str | bytes,
+ num_col_times_num_rows: int = 100,
+ *args,
+ **kwargs,
+ ) -> AsyncGenerator[str, None]:
+ """Ingest CSV data and yield text from each row."""
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+ # let the first row be the header
+ delimiter = self.get_delimiter(file=self.StringIO(data))
+
+ csv_reader = self.csv.reader(self.StringIO(data), delimiter=delimiter)
+
+        header = next(csv_reader)
+        # csv.reader has already split the header row into a list of columns
+        num_cols = len(header)
+        num_rows = max(1, num_col_times_num_rows // max(num_cols, 1))
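+        # e.g. with the default num_col_times_num_rows=100 and a 5-column file,
+        # each chunk carries roughly 100 // 5 = 20 data rows plus the header row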
+
+ chunk_rows = []
+ for row_num, row in enumerate(csv_reader):
+ chunk_rows.append(row)
+ if row_num % num_rows == 0:
+ yield (
+ ", ".join(header)
+ + "\n"
+ + "\n".join([", ".join(row) for row in chunk_rows])
+ )
+ chunk_rows = []
+
+ if chunk_rows:
+ yield (
+ ", ".join(header)
+ + "\n"
+ + "\n".join([", ".join(row) for row in chunk_rows])
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/structured/eml_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/structured/eml_parser.py
new file mode 100644
index 00000000..57a5ceab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/structured/eml_parser.py
@@ -0,0 +1,63 @@
+# type: ignore
+from email import message_from_bytes, policy
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class EMLParser(AsyncParser[str | bytes]):
+ """Parser for EML (email) files."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest EML data and yield email content."""
+ if isinstance(data, str):
+ raise ValueError("EML data must be in bytes format.")
+
+ # Parse email with policy for modern email handling
+ email_message = message_from_bytes(data, policy=policy.default)
+
+ # Extract and yield email metadata
+ metadata = []
+ if email_message["Subject"]:
+ metadata.append(f"Subject: {email_message['Subject']}")
+ if email_message["From"]:
+ metadata.append(f"From: {email_message['From']}")
+ if email_message["To"]:
+ metadata.append(f"To: {email_message['To']}")
+ if email_message["Date"]:
+ metadata.append(f"Date: {email_message['Date']}")
+
+ if metadata:
+ yield "\n".join(metadata)
+
+ # Extract and yield email body
+ if email_message.is_multipart():
+ for part in email_message.walk():
+ if part.get_content_type() == "text/plain":
+ text = part.get_content()
+ if text.strip():
+ yield text.strip()
+ elif part.get_content_type() == "text/html":
+ # Could add HTML parsing here if needed
+ continue
+ else:
+ body = email_message.get_content()
+ if body.strip():
+ yield body.strip()
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/structured/epub_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/structured/epub_parser.py
new file mode 100644
index 00000000..ff51fb86
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/structured/epub_parser.py
@@ -0,0 +1,121 @@
+# type: ignore
+import logging
+from typing import AsyncGenerator
+
+import epub
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class EPUBParser(AsyncParser[str | bytes]):
+ """Parser for EPUB electronic book files."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.epub = epub
+
+ def _safe_get_metadata(self, book, field: str) -> str | None:
+ """Safely extract metadata field from epub book."""
+ try:
+ return getattr(book, field, None) or getattr(book.opf, field, None)
+ except Exception as e:
+ logger.debug(f"Error getting {field} metadata: {e}")
+ return None
+
+ def _clean_text(self, content: bytes) -> str:
+ """Clean HTML content and return plain text."""
+ try:
+ import re
+
+ text = content.decode("utf-8", errors="ignore")
+ # Remove HTML tags
+ text = re.sub(r"<[^>]+>", " ", text)
+ # Normalize whitespace
+ text = re.sub(r"\s+", " ", text)
+ # Remove any remaining HTML entities
+ text = re.sub(r"&[^;]+;", " ", text)
+ return text.strip()
+ except Exception as e:
+ logger.warning(f"Error cleaning text: {e}")
+ return ""
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest EPUB data and yield book content."""
+ if isinstance(data, str):
+ raise ValueError("EPUB data must be in bytes format.")
+
+ from io import BytesIO
+
+ file_obj = BytesIO(data)
+
+ try:
+ book = self.epub.open_epub(file_obj)
+
+ # Safely extract metadata
+ metadata = []
+ for field, label in [
+ ("title", "Title"),
+ ("creator", "Author"),
+ ("language", "Language"),
+ ("publisher", "Publisher"),
+ ("date", "Date"),
+ ]:
+ if value := self._safe_get_metadata(book, field):
+ metadata.append(f"{label}: {value}")
+
+ if metadata:
+ yield "\n".join(metadata)
+
+ # Extract content from items
+ try:
+ manifest = getattr(book.opf, "manifest", {}) or {}
+ for item in manifest.values():
+ try:
+ if (
+ getattr(item, "mime_type", "")
+ == "application/xhtml+xml"
+ ):
+ if content := book.read_item(item):
+ if cleaned_text := self._clean_text(content):
+ yield cleaned_text
+ except Exception as e:
+ logger.warning(f"Error processing item: {e}")
+ continue
+
+ except Exception as e:
+ logger.warning(f"Error accessing manifest: {e}")
+ # Fallback: try to get content directly
+ if hasattr(book, "read_item"):
+ for item_id in getattr(book, "items", []):
+ try:
+ if content := book.read_item(item_id):
+ if cleaned_text := self._clean_text(content):
+ yield cleaned_text
+ except Exception as e:
+ logger.warning(f"Error in fallback reading: {e}")
+ continue
+
+ except Exception as e:
+ logger.error(f"Error processing EPUB file: {str(e)}")
+ raise ValueError(f"Error processing EPUB file: {str(e)}") from e
+ finally:
+ try:
+ file_obj.close()
+ except Exception as e:
+ logger.warning(f"Error closing file: {e}")
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/structured/json_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/structured/json_parser.py
new file mode 100644
index 00000000..3948e4de
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/structured/json_parser.py
@@ -0,0 +1,94 @@
+# type: ignore
+import asyncio
+import json
+from typing import AsyncGenerator
+
+from core.base import R2RException
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class JSONParser(AsyncParser[str | bytes]):
+ """A parser for JSON data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+
+ async def ingest(
+ self, data: str | bytes, *args, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest JSON data and yield a formatted text representation.
+
+ :param data: The JSON data to parse.
+ :param kwargs: Additional keyword arguments.
+ """
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+
+ loop = asyncio.get_event_loop()
+
+ try:
+ parsed_json = await loop.run_in_executor(None, json.loads, data)
+ formatted_text = await loop.run_in_executor(
+ None, self._parse_json, parsed_json
+ )
+ except json.JSONDecodeError as e:
+ raise R2RException(
+ message=f"Failed to parse JSON data, likely due to invalid JSON: {str(e)}",
+ status_code=400,
+ ) from e
+
+ chunk_size = kwargs.get("chunk_size")
+ if chunk_size and isinstance(chunk_size, int):
+ # If chunk_size is provided and is an integer, yield the formatted text in chunks
+ for i in range(0, len(formatted_text), chunk_size):
+ yield formatted_text[i : i + chunk_size]
+ await asyncio.sleep(0)
+ else:
+ # If no valid chunk_size is provided, yield the entire formatted text
+ yield formatted_text
+
+ def _parse_json(self, data: dict) -> str:
+ def remove_objects_with_null(obj):
+ if not isinstance(obj, dict):
+ return obj
+ result = obj.copy()
+ for key, value in obj.items():
+ if isinstance(value, dict):
+ result[key] = remove_objects_with_null(value)
+ elif value is None:
+ del result[key]
+ return result
+
+ def format_json_as_text(obj, indent=0):
+ lines = []
+ indent_str = " " * indent
+
+ if isinstance(obj, dict):
+ for key, value in obj.items():
+ if isinstance(value, (dict, list)):
+ nested = format_json_as_text(value, indent + 2)
+ lines.append(f"{indent_str}{key}:\n{nested}")
+ else:
+ lines.append(f"{indent_str}{key}: {value}")
+ elif isinstance(obj, list):
+ for item in obj:
+ nested = format_json_as_text(item, indent + 2)
+ lines.append(f"{nested}")
+ else:
+ return f"{indent_str}{obj}"
+
+ return "\n".join(lines)
+
+ return format_json_as_text(remove_objects_with_null(data))
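+
+
+# Illustrative example of the flattening performed by _parse_json:
+#
+#     {"title": "R2R", "meta": {"pages": 3, "draft": null}}
+#
+# becomes (null-valued keys removed, nesting shown by indentation):
+#
+#     title: R2R
+#     meta:
+#       pages: 3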
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/structured/msg_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/structured/msg_parser.py
new file mode 100644
index 00000000..4a024ecf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/structured/msg_parser.py
@@ -0,0 +1,65 @@
+# type: ignore
+import os
+import tempfile
+from typing import AsyncGenerator
+
+from msg_parser import MsOxMessage
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class MSGParser(AsyncParser[str | bytes]):
+ """Parser for MSG (Outlook Message) files using msg_parser."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest MSG data and yield email content."""
+ if isinstance(data, str):
+ raise ValueError("MSG data must be in bytes format.")
+
+ tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".msg")
+ try:
+ tmp_file.write(data)
+ tmp_file.close()
+
+ msg = MsOxMessage(tmp_file.name)
+
+ metadata = []
+
+ if msg.subject:
+ metadata.append(f"Subject: {msg.subject}")
+ if msg.sender:
+ metadata.append(f"From: {msg.sender}")
+ if msg.to:
+ metadata.append(f"To: {', '.join(msg.to)}")
+ if msg.sent_date:
+ metadata.append(f"Date: {msg.sent_date}")
+ if metadata:
+ yield "\n".join(metadata)
+ if msg.body:
+ yield msg.body.strip()
+
+ for attachment in msg.attachments:
+ if attachment.Filename:
+ yield f"\nAttachment: {attachment.Filename}"
+
+ except Exception as e:
+ raise ValueError(f"Error processing MSG file: {str(e)}") from e
+ finally:
+ os.remove(tmp_file.name)
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/structured/org_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/structured/org_parser.py
new file mode 100644
index 00000000..2ea3f857
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/structured/org_parser.py
@@ -0,0 +1,72 @@
+# type: ignore
+from typing import AsyncGenerator
+
+import orgparse
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class ORGParser(AsyncParser[str | bytes]):
+ """Parser for ORG (Emacs Org-mode) files."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.orgparse = orgparse
+
+ def _process_node(self, node) -> list[str]:
+ """Process an org-mode node and return its content."""
+ contents = []
+
+ # Add heading with proper level of asterisks
+ if node.level > 0:
+ contents.append(f"{'*' * node.level} {node.heading}")
+
+ # Add body content if exists
+ if node.body:
+ contents.append(node.body.strip())
+
+ return contents
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest ORG data and yield document content."""
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+
+ try:
+ # Create a temporary file-like object for orgparse
+ from io import StringIO
+
+ file_obj = StringIO(data)
+
+ # Parse the org file
+ root = self.orgparse.load(file_obj)
+
+ # Process root node if it has content
+ if root.body:
+ yield root.body.strip()
+
+ # Process all nodes
+ for node in root[1:]: # Skip root node in iteration
+ contents = self._process_node(node)
+ for content in contents:
+ if content.strip():
+ yield content.strip()
+
+ except Exception as e:
+ raise ValueError(f"Error processing ORG file: {str(e)}") from e
+ finally:
+ file_obj.close()
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/structured/p7s_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/structured/p7s_parser.py
new file mode 100644
index 00000000..84983494
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/structured/p7s_parser.py
@@ -0,0 +1,178 @@
+# type: ignore
+import email
+import logging
+from base64 import b64decode
+from datetime import datetime
+from email.message import Message
+from typing import AsyncGenerator
+
+from cryptography import x509
+from cryptography.hazmat.primitives.serialization import pkcs7
+from cryptography.x509.oid import NameOID
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class P7SParser(AsyncParser[str | bytes]):
+ """Parser for S/MIME messages containing a P7S (PKCS#7 Signature) file."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.x509 = x509
+ self.pkcs7 = pkcs7
+ self.NameOID = NameOID
+
+ def _format_datetime(self, dt: datetime) -> str:
+ """Format datetime in a readable way."""
+ return dt.strftime("%Y-%m-%d %H:%M:%S UTC")
+
+ def _get_name_attribute(self, name, oid):
+ """Safely get name attribute."""
+ try:
+ return name.get_attributes_for_oid(oid)[0].value
+ except (IndexError, ValueError):
+ return None
+
+ def _extract_cert_info(self, cert) -> dict:
+ """Extract relevant information from a certificate."""
+ try:
+ subject = cert.subject
+ issuer = cert.issuer
+
+ info = {
+ "common_name": self._get_name_attribute(
+ subject, self.NameOID.COMMON_NAME
+ ),
+ "organization": self._get_name_attribute(
+ subject, self.NameOID.ORGANIZATION_NAME
+ ),
+ "email": self._get_name_attribute(
+ subject, self.NameOID.EMAIL_ADDRESS
+ ),
+ "issuer_common_name": self._get_name_attribute(
+ issuer, self.NameOID.COMMON_NAME
+ ),
+ "issuer_organization": self._get_name_attribute(
+ issuer, self.NameOID.ORGANIZATION_NAME
+ ),
+ "serial_number": hex(cert.serial_number)[2:],
+ "not_valid_before": self._format_datetime(
+ cert.not_valid_before
+ ),
+ "not_valid_after": self._format_datetime(cert.not_valid_after),
+ "version": cert.version.name,
+ }
+
+ return {k: v for k, v in info.items() if v is not None}
+
+ except Exception as e:
+ logger.warning(f"Error extracting certificate info: {e}")
+ return {}
+
+ def _try_parse_signature(self, data: bytes):
+ """Try to parse the signature data as PKCS7 containing certificates."""
+ exceptions = []
+
+ # Try DER format PKCS7
+ try:
+ certs = self.pkcs7.load_der_pkcs7_certificates(data)
+ if certs is not None:
+ return certs
+ except Exception as e:
+ exceptions.append(f"DER PKCS7 parsing failed: {str(e)}")
+
+ # Try PEM format PKCS7
+ try:
+ certs = self.pkcs7.load_pem_pkcs7_certificates(data)
+ if certs is not None:
+ return certs
+ except Exception as e:
+ exceptions.append(f"PEM PKCS7 parsing failed: {str(e)}")
+
+ raise ValueError(
+ "Unable to parse signature file as PKCS7 with certificates. Attempted methods:\n"
+ + "\n".join(exceptions)
+ )
+
+ def _extract_p7s_data_from_mime(self, raw_data: bytes) -> bytes:
+ """Extract the raw PKCS#7 signature data from a MIME message."""
+ msg: Message = email.message_from_bytes(raw_data)
+
+ # If the message is multipart, find the part with application/x-pkcs7-signature
+ if msg.is_multipart():
+ for part in msg.walk():
+ ctype = part.get_content_type()
+ if ctype == "application/x-pkcs7-signature":
+ # Get the base64 encoded data from the payload
+ payload = part.get_payload(decode=False)
+ # payload at this stage is a base64 string
+ try:
+ return b64decode(payload)
+ except Exception as e:
+ raise ValueError(
+ f"Failed to decode base64 PKCS#7 signature: {str(e)}"
+ ) from e
+ # If we reach here, no PKCS#7 part was found
+ raise ValueError(
+ "No application/x-pkcs7-signature part found in the MIME message."
+ )
+ else:
+ # Not multipart, try to parse directly if it's just a raw P7S
+ # This scenario is less common; usually it's multipart.
+ if msg.get_content_type() == "application/x-pkcs7-signature":
+ payload = msg.get_payload(decode=False)
+ return b64decode(payload)
+
+ raise ValueError(
+ "The provided data does not contain a valid S/MIME signed message."
+ )
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest an S/MIME message and extract the PKCS#7 signature
+ information."""
+ # If data is a string, it might be base64 encoded, or it might be the raw MIME text.
+ # We should assume it's raw MIME text here because the input includes MIME headers.
+ if isinstance(data, str):
+ # Convert to bytes (raw MIME)
+ data = data.encode("utf-8")
+
+ try:
+ # Extract the raw PKCS#7 data (der/pem) from the MIME message
+ p7s_data = self._extract_p7s_data_from_mime(data)
+
+ # Parse the PKCS#7 data for certificates
+ certificates = self._try_parse_signature(p7s_data)
+
+ if not certificates:
+ yield "No certificates found in the provided P7S file."
+ return
+
+ # Process each certificate
+ for i, cert in enumerate(certificates, 1):
+ if cert_info := self._extract_cert_info(cert):
+ yield f"Certificate {i}:"
+ for key, value in cert_info.items():
+ if value:
+ yield f"{key.replace('_', ' ').title()}: {value}"
+ yield "" # Empty line between certificates
+ else:
+ yield f"Certificate {i}: No detailed information extracted."
+
+ except Exception as e:
+ raise ValueError(f"Error processing P7S file: {str(e)}") from e
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/structured/rst_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/structured/rst_parser.py
new file mode 100644
index 00000000..76390655
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/structured/rst_parser.py
@@ -0,0 +1,58 @@
+# type: ignore
+from typing import AsyncGenerator
+
+from docutils.core import publish_string
+from docutils.writers import html5_polyglot
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class RSTParser(AsyncParser[str | bytes]):
+ """Parser for reStructuredText (.rst) files."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.publish_string = publish_string
+ self.html5_polyglot = html5_polyglot
+
+ async def ingest(
+ self, data: str | bytes, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+
+ try:
+ # Convert RST to HTML
+ html = self.publish_string(
+ source=data,
+ writer=self.html5_polyglot.Writer(),
+ settings_overrides={"report_level": 5},
+ )
+
+            # Basic HTML cleanup
+            import re
+
+            text = html.decode("utf-8")
+            text = re.sub(r"<[^>]+>", " ", text)
+
+            # Split into paragraphs before collapsing whitespace; collapsing
+            # newlines first would merge the entire document into one chunk.
+            for paragraph in text.split("\n\n"):
+                cleaned = re.sub(r"\s+", " ", paragraph).strip()
+                if cleaned:
+                    yield cleaned
+
+ except Exception as e:
+ raise ValueError(f"Error processing RST file: {str(e)}") from e
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/structured/tsv_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/structured/tsv_parser.py
new file mode 100644
index 00000000..35478360
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/structured/tsv_parser.py
@@ -0,0 +1,109 @@
+# type: ignore
+from typing import IO, AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class TSVParser(AsyncParser[str | bytes]):
+ """A parser for TSV (Tab Separated Values) data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+
+ import csv
+ from io import StringIO
+
+ self.csv = csv
+ self.StringIO = StringIO
+
+ async def ingest(
+ self, data: str | bytes, *args, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest TSV data and yield text from each row."""
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+ tsv_reader = self.csv.reader(self.StringIO(data), delimiter="\t")
+ for row in tsv_reader:
+ yield ", ".join(row) # Still join with comma for readability
+
+
+class TSVParserAdvanced(AsyncParser[str | bytes]):
+ """An advanced parser for TSV data with chunking support."""
+
+ def __init__(
+ self, config: IngestionConfig, llm_provider: CompletionProvider
+ ):
+ self.llm_provider = llm_provider
+ self.config = config
+
+ import csv
+ from io import StringIO
+
+ self.csv = csv
+ self.StringIO = StringIO
+
+    def validate_tsv(self, file: IO) -> bool:
+        """Validate that the file is actually tab-delimited."""
+        num_bytes = 65536
+        lines = file.readlines(num_bytes)
+        file.seek(0)
+
+        if not lines:
+            return False
+
+        # Check whether tabs exist in the first few lines; callers pass a
+        # StringIO here, so lines may be str as well as bytes.
+        sample = "\n".join(
+            ln.decode("utf-8") if isinstance(ln, bytes) else ln
+            for ln in lines[:5]
+        )
+        return "\t" in sample
+
+ async def ingest(
+ self,
+ data: str | bytes,
+ num_col_times_num_rows: int = 100,
+ *args,
+ **kwargs,
+ ) -> AsyncGenerator[str, None]:
+ """Ingest TSV data and yield text in chunks."""
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+
+ # Validate TSV format
+ if not self.validate_tsv(self.StringIO(data)):
+ raise ValueError("File does not appear to be tab-delimited")
+
+ tsv_reader = self.csv.reader(self.StringIO(data), delimiter="\t")
+
+        # Get header and derive rows per chunk; keep at least one row per
+        # chunk so very wide files do not produce a zero divisor below.
+        header = next(tsv_reader)
+        num_cols = len(header)
+        num_rows = max(1, num_col_times_num_rows // num_cols)
+
+        chunk_rows = []
+        for row_num, row in enumerate(tsv_reader, start=1):  # 1-based so chunks are full-sized
+            chunk_rows.append(row)
+            if row_num % num_rows == 0:
+                yield (
+                    ", ".join(header)
+                    + "\n"
+                    + "\n".join([", ".join(row) for row in chunk_rows])
+                )
+                chunk_rows = []
+
+ # Yield remaining rows
+ if chunk_rows:
+ yield (
+ ", ".join(header)
+ + "\n"
+ + "\n".join([", ".join(row) for row in chunk_rows])
+ )
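
The advanced parser targets roughly `num_col_times_num_rows` cells per chunk, so rows per chunk follow from the column count. A toy illustration of that arithmetic with the standard-library csv reader (sample data is illustrative):

import csv
from io import StringIO

data = "a\tb\tc\n1\t2\t3\n4\t5\t6\n"
reader = csv.reader(StringIO(data), delimiter="\t")
header = next(reader)
rows_per_chunk = max(1, 100 // len(header))  # ~100 cells per chunk, 3 columns -> 33 rows
chunks = [", ".join(header) + "\n" + ", ".join(row) for row in reader]
print(rows_per_chunk, chunks)
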
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/structured/xls_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/structured/xls_parser.py
new file mode 100644
index 00000000..0bda9510
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/structured/xls_parser.py
@@ -0,0 +1,140 @@
+# type: ignore
+from typing import AsyncGenerator
+
+import networkx as nx
+import numpy as np
+import xlrd
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class XLSParser(AsyncParser[str | bytes]):
+ """A parser for XLS (Excel 97-2003) data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.xlrd = xlrd
+
+ async def ingest(
+ self, data: bytes, *args, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest XLS data and yield text from each row."""
+ if isinstance(data, str):
+ raise ValueError("XLS data must be in bytes format.")
+
+ wb = self.xlrd.open_workbook(file_contents=data)
+ for sheet in wb.sheets():
+ for row_idx in range(sheet.nrows):
+ # Get all values in the row
+ row_values = []
+ for col_idx in range(sheet.ncols):
+ cell = sheet.cell(row_idx, col_idx)
+ # Handle different cell types
+ if cell.ctype == self.xlrd.XL_CELL_DATE:
+ try:
+ value = self.xlrd.xldate_as_datetime(
+ cell.value, wb.datemode
+ ).strftime("%Y-%m-%d")
+ except Exception:
+ value = str(cell.value)
+ elif cell.ctype == self.xlrd.XL_CELL_BOOLEAN:
+ value = str(bool(cell.value)).lower()
+ elif cell.ctype == self.xlrd.XL_CELL_ERROR:
+ value = "#ERROR#"
+ else:
+ value = str(cell.value).strip()
+
+ row_values.append(value)
+
+ # Yield non-empty rows
+ if any(val.strip() for val in row_values):
+ yield ", ".join(row_values)
+
+
+class XLSParserAdvanced(AsyncParser[str | bytes]):
+ """An advanced parser for XLS data with chunking support."""
+
+ def __init__(
+ self, config: IngestionConfig, llm_provider: CompletionProvider
+ ):
+ self.llm_provider = llm_provider
+ self.config = config
+ self.nx = nx
+ self.np = np
+ self.xlrd = xlrd
+
+ def connected_components(self, arr):
+ g = self.nx.grid_2d_graph(len(arr), len(arr[0]))
+ empty_cell_indices = list(zip(*self.np.where(arr == ""), strict=False))
+ g.remove_nodes_from(empty_cell_indices)
+ components = self.nx.connected_components(g)
+ for component in components:
+ rows, cols = zip(*component, strict=False)
+ min_row, max_row = min(rows), max(rows)
+ min_col, max_col = min(cols), max(cols)
+ yield arr[min_row : max_row + 1, min_col : max_col + 1]
+
+ def get_cell_value(self, cell, workbook):
+ """Extract cell value handling different data types."""
+ if cell.ctype == self.xlrd.XL_CELL_DATE:
+ try:
+ return self.xlrd.xldate_as_datetime(
+ cell.value, workbook.datemode
+ ).strftime("%Y-%m-%d")
+ except Exception:
+ return str(cell.value)
+ elif cell.ctype == self.xlrd.XL_CELL_BOOLEAN:
+ return str(bool(cell.value)).lower()
+ elif cell.ctype == self.xlrd.XL_CELL_ERROR:
+ return "#ERROR#"
+ else:
+ return str(cell.value).strip()
+
+ async def ingest(
+ self, data: bytes, num_col_times_num_rows: int = 100, *args, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest XLS data and yield text from each connected component."""
+ if isinstance(data, str):
+ raise ValueError("XLS data must be in bytes format.")
+
+ workbook = self.xlrd.open_workbook(file_contents=data)
+
+ for sheet in workbook.sheets():
+ # Convert sheet to numpy array with proper value handling
+ ws_data = self.np.array(
+ [
+ [
+ self.get_cell_value(sheet.cell(row, col), workbook)
+ for col in range(sheet.ncols)
+ ]
+ for row in range(sheet.nrows)
+ ]
+ )
+
+ for table in self.connected_components(ws_data):
+ if len(table) <= 1:
+ continue
+
+            num_rows = len(table)
+            # Keep at least one row per chunk so the range step below is never zero
+            num_rows_per_chunk = max(1, num_col_times_num_rows // num_rows)
+            headers = ", ".join(table[0])
+
+ for i in range(1, num_rows, num_rows_per_chunk):
+ chunk = table[i : i + num_rows_per_chunk]
+ yield (
+ headers
+ + "\n"
+ + "\n".join([", ".join(row) for row in chunk])
+ )
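
The connected-components trick above treats the sheet as a 2-D grid graph, deletes empty cells, and takes the bounding box of each remaining component as one table. A self-contained sketch of the same idea on a toy array (values are illustrative):

import networkx as nx
import numpy as np

arr = np.array(
    [["h1", "h2", ""], ["1", "2", ""], ["", "", "x"]], dtype=object
)
g = nx.grid_2d_graph(*arr.shape)
g.remove_nodes_from(zip(*np.where(arr == "")))
for component in nx.connected_components(g):
    rows, cols = zip(*component)
    # Each component corresponds to one rectangular block of non-empty cells
    print(arr[min(rows) : max(rows) + 1, min(cols) : max(cols) + 1])
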
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/structured/xlsx_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/structured/xlsx_parser.py
new file mode 100644
index 00000000..4c303177
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/structured/xlsx_parser.py
@@ -0,0 +1,100 @@
+# type: ignore
+from io import BytesIO
+from typing import AsyncGenerator
+
+import networkx as nx
+import numpy as np
+from openpyxl import load_workbook
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class XLSXParser(AsyncParser[str | bytes]):
+ """A parser for XLSX data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+ self.load_workbook = load_workbook
+
+ async def ingest(
+ self, data: bytes, *args, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest XLSX data and yield text from each row."""
+ if isinstance(data, str):
+ raise ValueError("XLSX data must be in bytes format.")
+
+ wb = self.load_workbook(filename=BytesIO(data))
+ for sheet in wb.worksheets:
+ for row in sheet.iter_rows(values_only=True):
+ yield ", ".join(map(str, row))
+
+
+class XLSXParserAdvanced(AsyncParser[str | bytes]):
+ """A parser for XLSX data."""
+
+ # identifies connected components in the excel graph and extracts data from each component
+ def __init__(
+ self, config: IngestionConfig, llm_provider: CompletionProvider
+ ):
+ self.llm_provider = llm_provider
+ self.config = config
+ self.nx = nx
+ self.np = np
+ self.load_workbook = load_workbook
+
+ def connected_components(self, arr):
+ g = self.nx.grid_2d_graph(len(arr), len(arr[0]))
+        # `arr is None` is a scalar identity test and is always False for an
+        # ndarray; compare elementwise so empty cells are actually removed.
+        empty_cell_indices = list(
+            zip(*self.np.where(arr == None), strict=False)  # noqa: E711
+        )
+ g.remove_nodes_from(empty_cell_indices)
+ components = self.nx.connected_components(g)
+ for component in components:
+ rows, cols = zip(*component, strict=False)
+ min_row, max_row = min(rows), max(rows)
+ min_col, max_col = min(cols), max(cols)
+ yield arr[min_row : max_row + 1, min_col : max_col + 1].astype(
+ "str"
+ )
+
+ async def ingest(
+ self, data: bytes, num_col_times_num_rows: int = 100, *args, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest XLSX data and yield text from each connected component."""
+ if isinstance(data, str):
+ raise ValueError("XLSX data must be in bytes format.")
+
+ workbook = self.load_workbook(filename=BytesIO(data))
+
+ for ws in workbook.worksheets:
+ ws_data = self.np.array(
+ [[cell.value for cell in row] for row in ws.iter_rows()]
+ )
+ for table in self.connected_components(ws_data):
+ # parse like a csv parser, assumes that the first row has column names
+ if len(table) <= 1:
+ continue
+
+            num_rows = len(table)
+            # Keep at least one row per chunk so the range step below is never zero
+            num_rows_per_chunk = max(1, num_col_times_num_rows // num_rows)
+            headers = ", ".join(table[0])
+ # add header to each one
+ for i in range(1, num_rows, num_rows_per_chunk):
+ chunk = table[i : i + num_rows_per_chunk]
+ yield (
+ headers
+ + "\n"
+ + "\n".join([", ".join(row) for row in chunk])
+ )
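
A quick way to see the intermediate structure the advanced XLSX parser works with is to build a tiny workbook in memory and dump its cell values, as sketched below (workbook contents are illustrative):

from io import BytesIO

import numpy as np
from openpyxl import Workbook, load_workbook

# Build a tiny workbook in memory
wb = Workbook()
ws = wb.active
ws.append(["name", "score"])
ws.append(["ada", 3])
buf = BytesIO()
wb.save(buf)

# Re-load it and collect cell values the same way the advanced parser does
sheet = load_workbook(filename=BytesIO(buf.getvalue())).active
ws_data = np.array([[cell.value for cell in row] for row in sheet.iter_rows()])
print(ws_data)
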
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/text/__init__.py b/.venv/lib/python3.12/site-packages/core/parsers/text/__init__.py
new file mode 100644
index 00000000..8f85d046
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/text/__init__.py
@@ -0,0 +1,10 @@
+# type: ignore
+from .html_parser import HTMLParser
+from .md_parser import MDParser
+from .text_parser import TextParser
+
+__all__ = [
+ "MDParser",
+ "HTMLParser",
+ "TextParser",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/text/html_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/text/html_parser.py
new file mode 100644
index 00000000..a04331e0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/text/html_parser.py
@@ -0,0 +1,32 @@
+# type: ignore
+from typing import AsyncGenerator
+
+from bs4 import BeautifulSoup
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class HTMLParser(AsyncParser[str | bytes]):
+ """A parser for HTML data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+
+ async def ingest(
+ self, data: str | bytes, *args, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest HTML data and yield text."""
+ soup = BeautifulSoup(data, "html.parser")
+ yield soup.get_text()
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/text/md_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/text/md_parser.py
new file mode 100644
index 00000000..7ab11d92
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/text/md_parser.py
@@ -0,0 +1,39 @@
+# type: ignore
+from typing import AsyncGenerator
+
+from bs4 import BeautifulSoup
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class MDParser(AsyncParser[str | bytes]):
+ """A parser for Markdown data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+
+ import markdown
+
+ self.markdown = markdown
+
+ async def ingest(
+ self, data: str | bytes, *args, **kwargs
+ ) -> AsyncGenerator[str, None]:
+ """Ingest Markdown data and yield text."""
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+ html = self.markdown.markdown(data)
+ soup = BeautifulSoup(html, "html.parser")
+ yield soup.get_text()
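
Both text parsers above reduce markup to plain text via BeautifulSoup; for Markdown the input is first rendered to HTML. A minimal sketch of that pipeline (sample input is illustrative):

import markdown
from bs4 import BeautifulSoup

md_text = "# Heading\n\nSome **bold** text."
html = markdown.markdown(md_text)
print(BeautifulSoup(html, "html.parser").get_text())  # tags stripped, plain text remains
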
diff --git a/.venv/lib/python3.12/site-packages/core/parsers/text/text_parser.py b/.venv/lib/python3.12/site-packages/core/parsers/text/text_parser.py
new file mode 100644
index 00000000..51ff1cbd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/parsers/text/text_parser.py
@@ -0,0 +1,30 @@
+# type: ignore
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+ CompletionProvider,
+ DatabaseProvider,
+ IngestionConfig,
+)
+
+
+class TextParser(AsyncParser[str | bytes]):
+ """A parser for raw text data."""
+
+ def __init__(
+ self,
+ config: IngestionConfig,
+ database_provider: DatabaseProvider,
+ llm_provider: CompletionProvider,
+ ):
+ self.database_provider = database_provider
+ self.llm_provider = llm_provider
+ self.config = config
+
+ async def ingest(
+ self, data: str | bytes, *args, **kwargs
+ ) -> AsyncGenerator[str | bytes, None]:
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+ yield data
diff --git a/.venv/lib/python3.12/site-packages/core/providers/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/__init__.py
new file mode 100644
index 00000000..7cfa82eb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/__init__.py
@@ -0,0 +1,77 @@
+from .auth import (
+ ClerkAuthProvider,
+ JwtAuthProvider,
+ R2RAuthProvider,
+ SupabaseAuthProvider,
+)
+from .crypto import (
+ BcryptCryptoConfig,
+ BCryptCryptoProvider,
+ NaClCryptoConfig,
+ NaClCryptoProvider,
+)
+from .database import PostgresDatabaseProvider
+from .email import (
+ AsyncSMTPEmailProvider,
+ ConsoleMockEmailProvider,
+ MailerSendEmailProvider,
+ SendGridEmailProvider,
+)
+from .embeddings import (
+ LiteLLMEmbeddingProvider,
+ OllamaEmbeddingProvider,
+ OpenAIEmbeddingProvider,
+)
+from .ingestion import ( # type: ignore
+ R2RIngestionConfig,
+ R2RIngestionProvider,
+ UnstructuredIngestionConfig,
+ UnstructuredIngestionProvider,
+)
+from .llm import (
+ AnthropicCompletionProvider,
+ LiteLLMCompletionProvider,
+ OpenAICompletionProvider,
+ R2RCompletionProvider,
+)
+from .orchestration import (
+ HatchetOrchestrationProvider,
+ SimpleOrchestrationProvider,
+)
+
+__all__ = [
+ # Auth
+ "R2RAuthProvider",
+ "SupabaseAuthProvider",
+ "JwtAuthProvider",
+ "ClerkAuthProvider",
+ # Ingestion
+ "R2RIngestionProvider",
+ "R2RIngestionConfig",
+ "UnstructuredIngestionProvider",
+ "UnstructuredIngestionConfig",
+ # Crypto
+ "BCryptCryptoProvider",
+ "BcryptCryptoConfig",
+ "NaClCryptoConfig",
+ "NaClCryptoProvider",
+ # Embeddings
+ "LiteLLMEmbeddingProvider",
+ "OllamaEmbeddingProvider",
+ "OpenAIEmbeddingProvider",
+ # Database
+ "PostgresDatabaseProvider",
+ # Email
+ "AsyncSMTPEmailProvider",
+ "ConsoleMockEmailProvider",
+ "SendGridEmailProvider",
+ "MailerSendEmailProvider",
+ # Orchestration
+ "HatchetOrchestrationProvider",
+ "SimpleOrchestrationProvider",
+ # LLM
+ "AnthropicCompletionProvider",
+ "OpenAICompletionProvider",
+ "R2RCompletionProvider",
+ "LiteLLMCompletionProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py
new file mode 100644
index 00000000..9f116ffa
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py
@@ -0,0 +1,11 @@
+from .clerk import ClerkAuthProvider
+from .jwt import JwtAuthProvider
+from .r2r_auth import R2RAuthProvider
+from .supabase import SupabaseAuthProvider
+
+__all__ = [
+ "R2RAuthProvider",
+ "SupabaseAuthProvider",
+ "JwtAuthProvider",
+ "ClerkAuthProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py b/.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py
new file mode 100644
index 00000000..0db665e0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py
@@ -0,0 +1,133 @@
+import logging
+import os
+from datetime import datetime, timezone
+
+from core.base import (
+ AuthConfig,
+ CryptoProvider,
+ EmailProvider,
+ R2RException,
+ TokenData,
+)
+
+from ..database import PostgresDatabaseProvider
+from .jwt import JwtAuthProvider
+
+logger = logging.getLogger(__name__)
+
+
+class ClerkAuthProvider(JwtAuthProvider):
+ """
+ ClerkAuthProvider extends JwtAuthProvider to support token verification with Clerk.
+ It uses Clerk's SDK to verify the JWT token and extract user information.
+ """
+
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: EmailProvider,
+ ):
+ super().__init__(
+ config=config,
+ crypto_provider=crypto_provider,
+ database_provider=database_provider,
+ email_provider=email_provider,
+ )
+ try:
+ from clerk_backend_api.jwks_helpers.verifytoken import (
+ VerifyTokenOptions,
+ verify_token,
+ )
+
+ self.verify_token = verify_token
+ self.VerifyTokenOptions = VerifyTokenOptions
+ except ImportError as e:
+ raise R2RException(
+ status_code=500,
+ message="Clerk SDK is not installed. Run `pip install clerk-backend-api`",
+ ) from e
+
+ async def decode_token(self, token: str) -> TokenData:
+ """
+ Decode and verify the JWT token using Clerk's verify_token function.
+
+ Args:
+ token: The JWT token to decode
+
+ Returns:
+ TokenData: The decoded token data with user information
+
+ Raises:
+ R2RException: If the token is invalid or verification fails
+ """
+ clerk_secret_key = os.getenv("CLERK_SECRET_KEY")
+ if not clerk_secret_key:
+ raise R2RException(
+ status_code=500,
+ message="CLERK_SECRET_KEY environment variable is not set",
+ )
+
+ try:
+ # Configure verification options
+ options = self.VerifyTokenOptions(
+ secret_key=clerk_secret_key,
+ # Optional: specify audience if needed
+ # audience="your-audience",
+ # Optional: specify authorized parties if needed
+ # authorized_parties=["https://your-domain.com"]
+ )
+
+ # Verify the token using Clerk's SDK
+ payload = self.verify_token(token, options)
+
+ # Check for the expected claims in the token payload
+ if not payload.get("sub") or not payload.get("email"):
+ raise R2RException(
+ status_code=401,
+ message="Invalid token: missing required claims",
+ )
+
+ # Create user in database if not exists
+ try:
+ await self.database_provider.users_handler.get_user_by_email(
+ payload.get("email")
+ )
+ # TODO do we want to update user info here based on what's in the token?
+ except Exception:
+ # user doesn't exist, create in db
+ logger.debug(f"Creating new user: {payload.get('email')}")
+ try:
+ # Construct name from first_name and last_name if available
+ first_name = payload.get("first_name", "")
+ last_name = payload.get("last_name", "")
+ name = payload.get("name")
+
+ # If name not directly provided, try to build it from first and last names
+ if not name and (first_name or last_name):
+ name = f"{first_name} {last_name}".strip()
+
+ await self.database_provider.users_handler.create_user(
+ email=payload.get("email"),
+ account_type="external",
+ name=name,
+ )
+ except Exception as e:
+ logger.error(f"Error creating user: {e}")
+ raise R2RException(
+ status_code=500, message="Failed to create user"
+ ) from e
+
+ # Return the token data
+ return TokenData(
+ email=payload.get("email"),
+ token_type="bearer",
+                exp=datetime.fromtimestamp(payload.get("exp"), tz=timezone.utc),
+ )
+
+ except Exception as e:
+ logger.info(f"Clerk token verification failed: {e}")
+ raise R2RException(
+ status_code=401, message="Invalid token", detail=str(e)
+ ) from e
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py b/.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py
new file mode 100644
index 00000000..08f85e6d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py
@@ -0,0 +1,166 @@
+import logging
+import os
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+import jwt
+from fastapi import Depends
+
+from core.base import (
+ AuthConfig,
+ AuthProvider,
+ CryptoProvider,
+ EmailProvider,
+ R2RException,
+ Token,
+ TokenData,
+)
+from core.base.api.models import User
+
+from ..database import PostgresDatabaseProvider
+
+logger = logging.getLogger()
+
+
+class JwtAuthProvider(AuthProvider):
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: EmailProvider,
+ ):
+ super().__init__(
+ config, crypto_provider, database_provider, email_provider
+ )
+
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ raise NotImplementedError("Not implemented")
+
+ async def oauth_callback(self, code: str) -> dict[str, Token]:
+ raise NotImplementedError("Not implemented")
+
+ async def user(self, token: str) -> User:
+ raise NotImplementedError("Not implemented")
+
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ def create_access_token(self, data: dict) -> str:
+ raise NotImplementedError("Not implemented")
+
+ def create_refresh_token(self, data: dict) -> str:
+ raise NotImplementedError("Not implemented")
+
+ async def decode_token(self, token: str) -> TokenData:
+        # Use PyJWT to validate and decode the token
+        jwt_secret = os.getenv("JWT_SECRET")
+        if jwt_secret is None:
+            raise R2RException(
+                status_code=500,
+                message="JWT_SECRET environment variable is not set",
+            )
+        try:
+            user = jwt.decode(token, jwt_secret, algorithms=["HS256"])
+        except Exception as e:
+            logger.info(f"JWT verification failed: {e}")
+            raise R2RException(
+                status_code=401, message="Invalid JWT token", detail=str(e)
+            ) from e
+ if user:
+ # Create user in database if not exists
+ try:
+ await self.database_provider.users_handler.get_user_by_email(
+ user.get("email")
+ )
+ # TODO do we want to update user info here based on what's in the token?
+ except Exception:
+ # user doesn't exist, create in db
+ logger.debug(f"Creating new user: {user.get('email')}")
+ try:
+ await self.database_provider.users_handler.create_user(
+ email=user.get("email"),
+ account_type="external",
+ name=user.get("name"),
+ )
+ except Exception as e:
+ logger.error(f"Error creating user: {e}")
+ raise R2RException(
+ status_code=500, message="Failed to create user"
+ ) from e
+ return TokenData(
+ email=user.get("email"),
+ token_type="bearer",
+ exp=user.get("exp"),
+ )
+ else:
+ raise R2RException(status_code=401, message="Invalid JWT token")
+
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ raise NotImplementedError("Not implemented")
+
+ def get_current_active_user(
+ self, current_user: User = Depends(user)
+ ) -> User:
+ # Check if user is active
+ if not current_user.is_active:
+ raise R2RException(status_code=400, message="Inactive user")
+ return current_user
+
+ async def logout(self, token: str) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def register(
+ self,
+ email: str,
+ password: str,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ ) -> User: # type: ignore
+ raise NotImplementedError("Not implemented")
+
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def send_reset_email(self, email: str) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def create_user_api_key(
+ self,
+ user_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def send_verification_email(
+ self, email: str, user: Optional[User] = None
+ ) -> tuple[str, datetime]:
+ raise NotImplementedError("Not implemented")
+
+ async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
+ raise NotImplementedError("Not implemented")
+
+ async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+ raise NotImplementedError("Not implemented")
+
+ async def oauth_callback_handler(
+ self, provider: str, oauth_id: str, email: str
+ ) -> dict[str, Token]:
+ raise NotImplementedError("Not implemented")
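
`decode_token` above relies on PyJWT's `jwt.decode` with the HS256 algorithm and a shared secret taken from `JWT_SECRET`. A round-trip sketch with PyJWT, using an inline secret purely for illustration:

import datetime
import jwt  # PyJWT

secret = "example-secret"  # illustrative only; the provider reads JWT_SECRET from the environment
claims = {
    "email": "user@example.com",
    "name": "Example User",
    "exp": datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1),
}
token = jwt.encode(claims, secret, algorithm="HS256")
print(jwt.decode(token, secret, algorithms=["HS256"])["email"])  # expiry is validated on decode
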
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py b/.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py
new file mode 100644
index 00000000..762884ce
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py
@@ -0,0 +1,701 @@
+import logging
+import os
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Depends, HTTPException
+from fastapi.security import OAuth2PasswordBearer
+
+from core.base import (
+ AuthConfig,
+ AuthProvider,
+ CollectionResponse,
+ CryptoProvider,
+ EmailProvider,
+ R2RException,
+ Token,
+ TokenData,
+)
+from core.base.api.models import User
+
+from ..database import PostgresDatabaseProvider
+
+DEFAULT_ACCESS_LIFETIME_IN_MINUTES = 3600
+DEFAULT_REFRESH_LIFETIME_IN_DAYS = 7
+
+logger = logging.getLogger()
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+def normalize_email(email: str) -> str:
+ """Normalizes an email address by converting it to lowercase. This ensures
+ consistent email handling throughout the application.
+
+ Args:
+ email: The email address to normalize
+
+ Returns:
+ The normalized (lowercase) email address
+ """
+ return email.lower() if email else ""
+
+
+class R2RAuthProvider(AuthProvider):
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: EmailProvider,
+ ):
+ super().__init__(
+ config, crypto_provider, database_provider, email_provider
+ )
+ self.database_provider: PostgresDatabaseProvider = database_provider
+ logger.debug(f"Initializing R2RAuthProvider with config: {config}")
+
+ # We no longer use a local secret_key or defaults here.
+ # All key handling is done in the crypto_provider.
+ self.access_token_lifetime_in_minutes = (
+ config.access_token_lifetime_in_minutes
+ or os.getenv("R2R_ACCESS_LIFE_IN_MINUTES")
+ or DEFAULT_ACCESS_LIFETIME_IN_MINUTES
+ )
+ self.refresh_token_lifetime_in_days = (
+ config.refresh_token_lifetime_in_days
+ or os.getenv("R2R_REFRESH_LIFE_IN_DAYS")
+ or DEFAULT_REFRESH_LIFETIME_IN_DAYS
+ )
+ self.config: AuthConfig = config
+
+ async def initialize(self):
+ try:
+ user = await self.register(
+ email=normalize_email(self.admin_email),
+ password=self.admin_password,
+ is_superuser=True,
+ )
+ await self.database_provider.users_handler.mark_user_as_superuser(
+ id=user.id
+ )
+ except R2RException:
+ logger.info("Default admin user already exists.")
+
+ def create_access_token(self, data: dict) -> str:
+ expire = datetime.now(timezone.utc) + timedelta(
+ minutes=float(self.access_token_lifetime_in_minutes)
+ )
+ # Add token_type and pass data/expiry to crypto_provider
+ data_with_type = {**data, "token_type": "access"}
+ return self.crypto_provider.generate_secure_token(
+ data=data_with_type,
+ expiry=expire,
+ )
+
+ def create_refresh_token(self, data: dict) -> str:
+ expire = datetime.now(timezone.utc) + timedelta(
+ days=float(self.refresh_token_lifetime_in_days)
+ )
+ data_with_type = {**data, "token_type": "refresh"}
+ return self.crypto_provider.generate_secure_token(
+ data=data_with_type,
+ expiry=expire,
+ )
+
+ async def decode_token(self, token: str) -> TokenData:
+ if "token=" in token:
+ token = token.split("token=")[1]
+ if "&tokenType=refresh" in token:
+ token = token.split("&tokenType=refresh")[0]
+ # First, check if the token is blacklisted
+ if await self.database_provider.token_handler.is_token_blacklisted(
+ token=token
+ ):
+ raise R2RException(
+ status_code=401, message="Token has been invalidated"
+ )
+
+ # Verify token using crypto_provider
+ payload = self.crypto_provider.verify_secure_token(token=token)
+ if payload is None:
+ raise R2RException(
+ status_code=401, message="Invalid or expired token"
+ )
+
+ email = payload.get("sub")
+ token_type = payload.get("token_type")
+ exp = payload.get("exp")
+
+ if email is None or token_type is None or exp is None:
+ raise R2RException(status_code=401, message="Invalid token claims")
+
+ email_str: str = email
+ token_type_str: str = token_type
+ exp_float: float = exp
+
+ exp_datetime = datetime.fromtimestamp(exp_float, tz=timezone.utc)
+ if exp_datetime < datetime.now(timezone.utc):
+ raise R2RException(status_code=401, message="Token has expired")
+
+ return TokenData(
+ email=normalize_email(email_str),
+ token_type=token_type_str,
+ exp=exp_datetime,
+ )
+
+ async def authenticate_api_key(self, api_key: str) -> User:
+ """Authenticate using an API key of the form "public_key.raw_key".
+
+ Returns a User if successful, or raises R2RException if not.
+ """
+ try:
+ key_id, raw_key = api_key.split(".", 1)
+ except ValueError as e:
+ raise R2RException(
+ status_code=401, message="Invalid API key format"
+ ) from e
+
+ key_record = (
+ await self.database_provider.users_handler.get_api_key_record(
+ key_id=key_id
+ )
+ )
+ if not key_record:
+ raise R2RException(status_code=401, message="Invalid API key")
+
+ if not self.crypto_provider.verify_api_key(
+ raw_api_key=raw_key, hashed_key=key_record["hashed_key"]
+ ):
+ raise R2RException(status_code=401, message="Invalid API key")
+
+ user = await self.database_provider.users_handler.get_user_by_id(
+ id=key_record["user_id"]
+ )
+ if not user.is_active:
+ raise R2RException(
+ status_code=401, message="User account is inactive"
+ )
+
+ return user
+
+ async def user(self, token: str = Depends(oauth2_scheme)) -> User:
+ """Attempt to authenticate via JWT first, then fallback to API key."""
+ # Try JWT auth
+ try:
+ token_data = await self.decode_token(token=token)
+ if not token_data.email:
+ raise R2RException(
+ status_code=401, message="Could not validate credentials"
+ )
+ user = (
+ await self.database_provider.users_handler.get_user_by_email(
+ email=normalize_email(token_data.email)
+ )
+ )
+ if user is None:
+ raise R2RException(
+ status_code=401,
+ message="Invalid authentication credentials",
+ )
+ return user
+ except R2RException:
+ # If JWT fails, try API key auth
+ # OAuth2PasswordBearer provides token as "Bearer xxx", strip it if needed
+ token = token.removeprefix("Bearer ")
+ return await self.authenticate_api_key(api_key=token)
+
+ def get_current_active_user(
+ self, current_user: User = Depends(user)
+ ) -> User:
+ if not current_user.is_active:
+ raise R2RException(status_code=400, message="Inactive user")
+ return current_user
+
+ async def register(
+ self,
+ email: str,
+ password: Optional[str] = None,
+ is_superuser: bool = False,
+ account_type: str = "password",
+ github_id: Optional[str] = None,
+ google_id: Optional[str] = None,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ ) -> User:
+ if account_type == "password":
+ if not password:
+ raise R2RException(
+ status_code=400,
+ message="Password is required for password accounts",
+ )
+ else:
+ if github_id and google_id:
+ raise R2RException(
+ status_code=400,
+ message="Cannot register OAuth with both GitHub and Google IDs",
+ )
+ if not github_id and not google_id:
+ raise R2RException(
+ status_code=400,
+ message="Invalid OAuth specification without GitHub or Google ID",
+ )
+ new_user = await self.database_provider.users_handler.create_user(
+ email=normalize_email(email),
+ password=password,
+ is_superuser=is_superuser,
+ account_type=account_type,
+ github_id=github_id,
+ google_id=google_id,
+ name=name,
+ bio=bio,
+ profile_picture=profile_picture,
+ )
+ default_collection: CollectionResponse = (
+ await self.database_provider.collections_handler.create_collection(
+ owner_id=new_user.id,
+ )
+ )
+ await self.database_provider.graphs_handler.create(
+ collection_id=default_collection.id,
+ name=default_collection.name,
+ description=default_collection.description,
+ )
+
+ await self.database_provider.users_handler.add_user_to_collection(
+ new_user.id, default_collection.id
+ )
+
+ new_user = await self.database_provider.users_handler.get_user_by_id(
+ new_user.id
+ )
+
+ if self.config.require_email_verification:
+ verification_code, _ = await self.send_verification_email(
+ email=normalize_email(email), user=new_user
+ )
+ else:
+ expiry = datetime.now(timezone.utc) + timedelta(hours=366 * 10)
+ await self.database_provider.users_handler.store_verification_code(
+ id=new_user.id,
+ verification_code=str(-1),
+ expiry=expiry,
+ )
+ await self.database_provider.users_handler.mark_user_as_verified(
+ id=new_user.id
+ )
+
+ return new_user
+
+ async def send_verification_email(
+ self, email: str, user: Optional[User] = None
+ ) -> tuple[str, datetime]:
+ if user is None:
+ user = (
+ await self.database_provider.users_handler.get_user_by_email(
+ email=normalize_email(email)
+ )
+ )
+ if not user:
+ raise R2RException(status_code=404, message="User not found")
+
+ verification_code = self.crypto_provider.generate_verification_code()
+ expiry = datetime.now(timezone.utc) + timedelta(hours=24)
+
+ await self.database_provider.users_handler.store_verification_code(
+ id=user.id,
+ verification_code=verification_code,
+ expiry=expiry,
+ )
+
+ if hasattr(user, "verification_code_expiry"):
+ user.verification_code_expiry = expiry
+
+ first_name = (
+ user.name.split(" ")[0] if user.name else email.split("@")[0]
+ )
+
+ await self.email_provider.send_verification_email(
+ to_email=user.email,
+ verification_code=verification_code,
+ dynamic_template_data={"first_name": first_name},
+ )
+
+ return verification_code, expiry
+
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ user_id = await self.database_provider.users_handler.get_user_id_by_verification_code(
+ verification_code=verification_code
+ )
+ await self.database_provider.users_handler.mark_user_as_verified(
+ id=user_id
+ )
+ await self.database_provider.users_handler.remove_verification_code(
+ verification_code=verification_code
+ )
+ return {"message": "Email verified successfully"}
+
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ logger.debug(f"Attempting login for email: {email}")
+ user = await self.database_provider.users_handler.get_user_by_email(
+ email=normalize_email(email)
+ )
+
+ if user.account_type != "password":
+ logger.warning(
+ f"Password login not allowed for {user.account_type} accounts: {email}"
+ )
+ raise R2RException(
+ status_code=401,
+ message=f"This account is configured for {user.account_type} login, not password.",
+ )
+
+ logger.debug(f"User found: {user}")
+
+ if not isinstance(user.hashed_password, str):
+ logger.error(
+ f"Invalid hashed_password type: {type(user.hashed_password)}"
+ )
+ raise HTTPException(
+ status_code=500,
+ detail="Invalid password hash in database",
+ )
+
+ try:
+ password_verified = self.crypto_provider.verify_password(
+ plain_password=password,
+ hashed_password=user.hashed_password,
+ )
+ except Exception as e:
+ logger.error(f"Error during password verification: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail="Error during password verification",
+ ) from e
+
+ if not password_verified:
+ logger.warning(f"Invalid password for user: {email}")
+ raise R2RException(
+ status_code=401, message="Incorrect email or password"
+ )
+
+ if not user.is_verified and self.config.require_email_verification:
+ logger.warning(f"Unverified user attempted login: {email}")
+ raise R2RException(status_code=401, message="Email not verified")
+
+ access_token = self.create_access_token(
+ data={"sub": normalize_email(user.email)}
+ )
+ refresh_token = self.create_refresh_token(
+ data={"sub": normalize_email(user.email)}
+ )
+ return {
+ "access_token": Token(token=access_token, token_type="access"),
+ "refresh_token": Token(token=refresh_token, token_type="refresh"),
+ }
+
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ token_data = await self.decode_token(refresh_token)
+ if token_data.token_type != "refresh":
+ raise R2RException(
+ status_code=401, message="Invalid refresh token"
+ )
+
+ # Invalidate the old refresh token and create a new one
+ await self.database_provider.token_handler.blacklist_token(
+ token=refresh_token
+ )
+
+ new_access_token = self.create_access_token(
+ data={"sub": normalize_email(token_data.email)}
+ )
+ new_refresh_token = self.create_refresh_token(
+ data={"sub": normalize_email(token_data.email)}
+ )
+ return {
+ "access_token": Token(token=new_access_token, token_type="access"),
+ "refresh_token": Token(
+ token=new_refresh_token, token_type="refresh"
+ ),
+ }
+
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ if not isinstance(user.hashed_password, str):
+ logger.error(
+ f"Invalid hashed_password type: {type(user.hashed_password)}"
+ )
+ raise HTTPException(
+ status_code=500,
+ detail="Invalid password hash in database",
+ )
+
+ if not self.crypto_provider.verify_password(
+ plain_password=current_password,
+ hashed_password=user.hashed_password,
+ ):
+ raise R2RException(
+ status_code=400, message="Incorrect current password"
+ )
+
+ hashed_new_password = self.crypto_provider.get_password_hash(
+ password=new_password
+ )
+ await self.database_provider.users_handler.update_user_password(
+ id=user.id,
+ new_hashed_password=hashed_new_password,
+ )
+ try:
+ await self.email_provider.send_password_changed_email(
+ to_email=normalize_email(user.email),
+ dynamic_template_data={
+ "first_name": (
+ user.name.split(" ")[0] or "User"
+ if user.name
+ else "User"
+ )
+ },
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to send password change notification: {str(e)}"
+ )
+
+ return {"message": "Password changed successfully"}
+
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ try:
+ user = (
+ await self.database_provider.users_handler.get_user_by_email(
+ email=normalize_email(email)
+ )
+ )
+
+ reset_token = self.crypto_provider.generate_verification_code()
+ expiry = datetime.now(timezone.utc) + timedelta(hours=1)
+ await self.database_provider.users_handler.store_reset_token(
+ id=user.id,
+ reset_token=reset_token,
+ expiry=expiry,
+ )
+
+ first_name = (
+ user.name.split(" ")[0] if user.name else email.split("@")[0]
+ )
+ await self.email_provider.send_password_reset_email(
+ to_email=normalize_email(email),
+ reset_token=reset_token,
+ dynamic_template_data={"first_name": first_name},
+ )
+
+ return {
+ "message": "If the email exists, a reset link has been sent"
+ }
+ except R2RException as e:
+ if e.status_code == 404:
+ # User doesn't exist; return a success message anyway
+ return {
+ "message": "If the email exists, a reset link has been sent"
+ }
+ else:
+ raise
+
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ user_id = await self.database_provider.users_handler.get_user_id_by_reset_token(
+ reset_token=reset_token
+ )
+ if not user_id:
+ raise R2RException(
+ status_code=400, message="Invalid or expired reset token"
+ )
+
+ hashed_new_password = self.crypto_provider.get_password_hash(
+ password=new_password
+ )
+ await self.database_provider.users_handler.update_user_password(
+ id=user_id,
+ new_hashed_password=hashed_new_password,
+ )
+ await self.database_provider.users_handler.remove_reset_token(
+ id=user_id
+ )
+ # Get the user information
+ user = await self.database_provider.users_handler.get_user_by_id(
+ id=user_id
+ )
+
+ try:
+ await self.email_provider.send_password_changed_email(
+ to_email=normalize_email(user.email),
+ dynamic_template_data={
+ "first_name": (
+ user.name.split(" ")[0] or "User"
+ if user.name
+ else "User"
+ )
+ },
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to send password change notification: {str(e)}"
+ )
+
+ return {"message": "Password reset successfully"}
+
+ async def logout(self, token: str) -> dict[str, str]:
+ await self.database_provider.token_handler.blacklist_token(token=token)
+ return {"message": "Logged out successfully"}
+
+ async def clean_expired_blacklisted_tokens(self):
+ await self.database_provider.token_handler.clean_expired_blacklisted_tokens()
+
+ async def send_reset_email(self, email: str) -> dict:
+ verification_code, expiry = await self.send_verification_email(
+ email=normalize_email(email)
+ )
+
+ return {
+ "verification_code": verification_code,
+ "expiry": expiry,
+ "message": f"Verification email sent successfully to {email}",
+ }
+
+ async def create_user_api_key(
+ self,
+ user_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> dict[str, str]:
+ key_id, raw_api_key = self.crypto_provider.generate_api_key()
+ hashed_key = self.crypto_provider.hash_api_key(raw_api_key)
+
+ api_key_uuid = (
+ await self.database_provider.users_handler.store_user_api_key(
+ user_id=user_id,
+ key_id=key_id,
+ hashed_key=hashed_key,
+ name=name,
+ description=description,
+ )
+ )
+
+ return {
+ "api_key": f"{key_id}.{raw_api_key}",
+ "key_id": str(api_key_uuid),
+ "public_key": key_id,
+ "name": name or "",
+ }
+
+ async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
+ return await self.database_provider.users_handler.get_user_api_keys(
+ user_id=user_id
+ )
+
+ async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+ return await self.database_provider.users_handler.delete_api_key(
+ user_id=user_id,
+ key_id=key_id,
+ )
+
+ async def rename_api_key(
+ self, user_id: UUID, key_id: UUID, new_name: str
+ ) -> bool:
+ return await self.database_provider.users_handler.update_api_key_name(
+ user_id=user_id,
+ key_id=key_id,
+ name=new_name,
+ )
+
+ async def oauth_callback_handler(
+ self, provider: str, oauth_id: str, email: str
+ ) -> dict[str, Token]:
+ """Handles a login/registration flow for OAuth providers (e.g., Google
+ or GitHub).
+
+ :param provider: "google" or "github"
+ :param oauth_id: The unique ID from the OAuth provider (e.g. Google's
+ 'sub')
+ :param email: The user's email from the provider, if available.
+ :return: dict with access_token and refresh_token
+ """
+ # 1) Attempt to find user by google_id or github_id, or by email
+ # The logic depends on your preference. We'll assume "google" => google_id, etc.
+ try:
+ if provider == "google":
+ try:
+ user = await self.database_provider.users_handler.get_user_by_email(
+ normalize_email(email)
+ )
+ # If user found, check if user.google_id matches or is null. If null, update it
+ if user and not user.google_id:
+ raise R2RException(
+ status_code=401,
+ message="User already exists and is not linked to Google account",
+ )
+ except Exception:
+ # Create new user
+ user = await self.register(
+ email=normalize_email(email)
+ or f"{oauth_id}@google_oauth.fake", # fallback
+ password=None, # no password
+ account_type="oauth",
+ google_id=oauth_id,
+ )
+ elif provider == "github":
+ try:
+ user = await self.database_provider.users_handler.get_user_by_email(
+ normalize_email(email)
+ )
+                    # If user found, check if user.github_id matches or is null. If null, update it
+ if user and not user.github_id:
+ raise R2RException(
+ status_code=401,
+ message="User already exists and is not linked to Github account",
+ )
+ except Exception:
+ # Create new user
+ user = await self.register(
+ email=normalize_email(email)
+ or f"{oauth_id}@github_oauth.fake", # fallback
+ password=None, # no password
+ account_type="oauth",
+ github_id=oauth_id,
+ )
+ # else handle other providers
+
+ except R2RException:
+ # If no user found or creation fails
+ raise R2RException(
+ status_code=401, message="Could not create or fetch user"
+ ) from None
+
+ # If user is inactive, etc.
+ if not user.is_active:
+ raise R2RException(
+ status_code=401, message="User account is inactive"
+ )
+
+ # Possibly mark user as verified if you trust the OAuth provider's email
+ user.is_verified = True
+ await self.database_provider.users_handler.update_user(user)
+
+ # 2) Generate tokens
+ access_token = self.create_access_token(
+ data={"sub": normalize_email(user.email)}
+ )
+ refresh_token = self.create_refresh_token(
+ data={"sub": normalize_email(user.email)}
+ )
+
+ return {
+ "access_token": Token(token=access_token, token_type="access"),
+ "refresh_token": Token(token=refresh_token, token_type="refresh"),
+ }
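
API keys issued by `create_user_api_key` take the form `<key_id>.<raw_key>`, and `authenticate_api_key` splits on the first dot before verifying the hash. A tiny sketch of that parsing step (the sample key is illustrative):

def split_api_key(api_key: str) -> tuple[str, str]:
    # Keys look like "<key_id>.<raw_key>"; only the first dot is significant
    key_id, raw_key = api_key.split(".", 1)
    return key_id, raw_key

print(split_api_key("pk_1234.raw.secret.material"))  # -> ('pk_1234', 'raw.secret.material')
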
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py b/.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py
new file mode 100644
index 00000000..5fc0e0bf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py
@@ -0,0 +1,249 @@
+import logging
+import os
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Depends, HTTPException
+from fastapi.security import OAuth2PasswordBearer
+from supabase import Client, create_client
+
+from core.base import (
+ AuthConfig,
+ AuthProvider,
+ CryptoProvider,
+ EmailProvider,
+ R2RException,
+ Token,
+ TokenData,
+)
+from core.base.api.models import User
+
+from ..database import PostgresDatabaseProvider
+
+logger = logging.getLogger()
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+class SupabaseAuthProvider(AuthProvider):
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: EmailProvider,
+ ):
+ super().__init__(
+ config, crypto_provider, database_provider, email_provider
+ )
+ self.supabase_url = config.extra_fields.get(
+ "supabase_url", None
+ ) or os.getenv("SUPABASE_URL")
+ self.supabase_key = config.extra_fields.get(
+ "supabase_key", None
+ ) or os.getenv("SUPABASE_KEY")
+ if not self.supabase_url or not self.supabase_key:
+ raise HTTPException(
+ status_code=500,
+ detail="Supabase URL and key must be provided",
+ )
+ self.supabase: Client = create_client(
+ self.supabase_url, self.supabase_key
+ )
+
+ async def initialize(self):
+ # No initialization needed for Supabase
+ pass
+
+ def create_access_token(self, data: dict) -> str:
+ raise NotImplementedError(
+ "create_access_token is not used with Supabase authentication"
+ )
+
+ def create_refresh_token(self, data: dict) -> str:
+ raise NotImplementedError(
+ "create_refresh_token is not used with Supabase authentication"
+ )
+
+ async def decode_token(self, token: str) -> TokenData:
+ raise NotImplementedError(
+ "decode_token is not used with Supabase authentication"
+ )
+
+ async def register(
+ self,
+ email: str,
+ password: str,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ ) -> User: # type: ignore
+ # Use Supabase client to create a new user
+
+ if self.supabase.auth.sign_up(email=email, password=password):
+ raise R2RException(
+ status_code=400,
+ message="Supabase provider implementation is still under construction",
+ )
+ else:
+ raise R2RException(
+ status_code=400, message="User registration failed"
+ )
+
+ async def send_verification_email(
+ self, email: str, user: Optional[User] = None
+ ) -> tuple[str, datetime]:
+ raise NotImplementedError(
+ "send_verification_email is not used with Supabase"
+ )
+
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ # Use Supabase client to verify email
+ if self.supabase.auth.verify_email(email, verification_code):
+ return {"message": "Email verified successfully"}
+ else:
+ raise R2RException(
+ status_code=400, message="Invalid or expired verification code"
+ )
+
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ # Use Supabase client to authenticate user and get tokens
+ if response := self.supabase.auth.sign_in(
+ email=email, password=password
+ ):
+ access_token = response.access_token
+ refresh_token = response.refresh_token
+ return {
+ "access_token": Token(token=access_token, token_type="access"),
+ "refresh_token": Token(
+ token=refresh_token, token_type="refresh"
+ ),
+ }
+ else:
+ raise R2RException(
+ status_code=401, message="Invalid email or password"
+ )
+
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ # Use Supabase client to refresh access token
+ if response := self.supabase.auth.refresh_access_token(refresh_token):
+ new_access_token = response.access_token
+ new_refresh_token = response.refresh_token
+ return {
+ "access_token": Token(
+ token=new_access_token, token_type="access"
+ ),
+ "refresh_token": Token(
+ token=new_refresh_token, token_type="refresh"
+ ),
+ }
+ else:
+ raise R2RException(
+ status_code=401, message="Invalid refresh token"
+ )
+
+ async def user(self, token: str = Depends(oauth2_scheme)) -> User:
+ # Use Supabase client to get user details from token
+ if user := self.supabase.auth.get_user(token).user:
+ return User(
+ id=user.id,
+ email=user.email,
+ is_active=True, # Assuming active if exists in Supabase
+ is_superuser=False, # Default to False unless explicitly set
+ created_at=user.created_at,
+ updated_at=user.updated_at,
+ is_verified=user.email_confirmed_at is not None,
+ name=user.user_metadata.get("full_name"),
+ # Set other optional fields if available in user metadata
+ )
+
+ else:
+ raise R2RException(status_code=401, message="Invalid token")
+
+ def get_current_active_user(
+ self, current_user: User = Depends(user)
+ ) -> User:
+ # Check if user is active
+ if not current_user.is_active:
+ raise R2RException(status_code=400, message="Inactive user")
+ return current_user
+
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ # Use Supabase client to update user password
+ if self.supabase.auth.update(user.id, {"password": new_password}):
+ return {"message": "Password changed successfully"}
+ else:
+ raise R2RException(
+ status_code=400, message="Failed to change password"
+ )
+
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ # Use Supabase client to send password reset email
+ if self.supabase.auth.send_password_reset_email(email):
+ return {
+ "message": "If the email exists, a reset link has been sent"
+ }
+ else:
+ raise R2RException(
+ status_code=400, message="Failed to send password reset email"
+ )
+
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ # Use Supabase client to reset password with token
+ if self.supabase.auth.reset_password_for_email(
+ reset_token, new_password
+ ):
+ return {"message": "Password reset successfully"}
+ else:
+ raise R2RException(
+ status_code=400, message="Invalid or expired reset token"
+ )
+
+ async def logout(self, token: str) -> dict[str, str]:
+ # Use Supabase client to logout user and revoke token
+ self.supabase.auth.sign_out(token)
+ return {"message": "Logged out successfully"}
+
+ async def clean_expired_blacklisted_tokens(self):
+ # Not applicable for Supabase, tokens are managed by Supabase
+ pass
+
+ async def send_reset_email(self, email: str) -> dict[str, str]:
+ raise NotImplementedError("send_reset_email is not used with Supabase")
+
+ async def create_user_api_key(
+ self,
+ user_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> dict[str, str]:
+ raise NotImplementedError(
+ "API key management is not supported with Supabase authentication"
+ )
+
+ async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
+ raise NotImplementedError(
+ "API key management is not supported with Supabase authentication"
+ )
+
+ async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+ raise NotImplementedError(
+ "API key management is not supported with Supabase authentication"
+ )
+
+ async def oauth_callback_handler(
+ self, provider: str, oauth_id: str, email: str
+ ) -> dict[str, Token]:
+ raise NotImplementedError(
+ "API key management is not supported with Supabase authentication"
+ )
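+
+
+# Illustrative usage sketch (not part of the original module; `auth_provider`
+# is a hypothetical, already-configured instance of this provider):
+#
+#     tokens = await auth_provider.login("user@example.com", "s3cret")
+#     access = tokens["access_token"].token
+#     current_user = await auth_provider.user(access)
+#     await auth_provider.logout(access)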
diff --git a/.venv/lib/python3.12/site-packages/core/providers/crypto/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/crypto/__init__.py
new file mode 100644
index 00000000..e509f990
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/crypto/__init__.py
@@ -0,0 +1,9 @@
+from .bcrypt import BcryptCryptoConfig, BCryptCryptoProvider
+from .nacl import NaClCryptoConfig, NaClCryptoProvider
+
+__all__ = [
+ "BCryptCryptoProvider",
+ "BcryptCryptoConfig",
+ "NaClCryptoConfig",
+ "NaClCryptoProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/crypto/bcrypt.py b/.venv/lib/python3.12/site-packages/core/providers/crypto/bcrypt.py
new file mode 100644
index 00000000..9c39977c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/crypto/bcrypt.py
@@ -0,0 +1,195 @@
+import base64
+import logging
+import os
+from abc import ABC
+from datetime import datetime, timezone
+from typing import Optional, Tuple
+
+import bcrypt
+import jwt
+import nacl.encoding
+import nacl.exceptions
+import nacl.signing
+import nacl.utils
+
+from core.base import CryptoConfig, CryptoProvider
+
+DEFAULT_BCRYPT_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM" # Replace or load from env or secrets manager
+
+
+class BcryptCryptoConfig(CryptoConfig):
+ provider: str = "bcrypt"
+ # Number of rounds for bcrypt (increasing this makes hashing slower but more secure)
+ bcrypt_rounds: int = 12
+ secret_key: Optional[str] = None
+ api_key_bytes: int = 32 # Length of raw API keys
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["bcrypt"]
+
+ def validate_config(self) -> None:
+ super().validate_config()
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Unsupported crypto provider: {self.provider}")
+ if self.bcrypt_rounds < 4 or self.bcrypt_rounds > 31:
+ raise ValueError("bcrypt_rounds must be between 4 and 31")
+
+ def verify_password(
+ self, plain_password: str, hashed_password: str
+ ) -> bool:
+ try:
+ # First try to decode as base64 (new format)
+ stored_hash = base64.b64decode(hashed_password.encode("utf-8"))
+ except Exception:
+ # If that fails, treat as raw bcrypt hash (old format)
+ stored_hash = hashed_password.encode("utf-8")
+
+ return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash)
+
+
+class BCryptCryptoProvider(CryptoProvider, ABC):
+ def __init__(self, config: BcryptCryptoConfig):
+ if not isinstance(config, BcryptCryptoConfig):
+ raise ValueError(
+ "BcryptCryptoProvider must be initialized with a BcryptCryptoConfig"
+ )
+ logging.info("Initializing BcryptCryptoProvider")
+ super().__init__(config)
+ self.config: BcryptCryptoConfig = config
+
+        # Load the secret key for JWT
+        # Priority: config.secret_key > environment variable > built-in default
+ self.secret_key = (
+ config.secret_key
+ or os.getenv("R2R_SECRET_KEY")
+ or DEFAULT_BCRYPT_SECRET_KEY
+ )
+ if not self.secret_key:
+ raise ValueError(
+ "No secret key provided for BcryptCryptoProvider."
+ )
+
+ def get_password_hash(self, password: str) -> str:
+ # Bcrypt expects bytes
+ password_bytes = password.encode("utf-8")
+ hashed = bcrypt.hashpw(
+ password_bytes, bcrypt.gensalt(rounds=self.config.bcrypt_rounds)
+ )
+ return base64.b64encode(hashed).decode("utf-8")
+
+ def verify_password(
+ self, plain_password: str, hashed_password: str
+ ) -> bool:
+ try:
+ # First try to decode as base64 (new format)
+ stored_hash = base64.b64decode(hashed_password.encode("utf-8"))
+            if not stored_hash.startswith(b"$2b$"):  # decoded value is not a bcrypt hash
+ stored_hash = hashed_password.encode("utf-8")
+ except Exception:
+ # Otherwise raw bcrypt hash (old format)
+ stored_hash = hashed_password.encode("utf-8")
+
+ try:
+ return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash)
+ except ValueError as e:
+ if "Invalid salt" in str(e):
+ # If it's an invalid salt, the hash format is wrong - try the other format
+ try:
+ stored_hash = (
+ hashed_password
+ if isinstance(hashed_password, bytes)
+ else hashed_password.encode("utf-8")
+ )
+ return bcrypt.checkpw(
+ plain_password.encode("utf-8"), stored_hash
+ )
+ except ValueError:
+ return False
+ raise
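+
+    # Round-trip sketch of the dual-format handling above (illustrative;
+    # `provider` is a hypothetical BCryptCryptoProvider instance). New hashes
+    # are base64-wrapped bcrypt strings, while legacy raw "$2b$..." hashes
+    # still verify through the fallback path:
+    #
+    #     hashed = provider.get_password_hash("s3cret")
+    #     provider.verify_password("s3cret", hashed)   # -> True
+    #     provider.verify_password("wrong", hashed)    # -> False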
+
+ def generate_verification_code(self, length: int = 32) -> str:
+ random_bytes = nacl.utils.random(length)
+ return base64.urlsafe_b64encode(random_bytes)[:length].decode("utf-8")
+
+ def generate_signing_keypair(self) -> Tuple[str, str, str]:
+ signing_key = nacl.signing.SigningKey.generate()
+ verify_key = signing_key.verify_key
+
+ # Generate unique key_id
+ key_entropy = nacl.utils.random(16)
+ key_id = f"sk_{base64.urlsafe_b64encode(key_entropy).decode()}"
+
+ private_key = base64.b64encode(bytes(signing_key)).decode()
+ public_key = base64.b64encode(bytes(verify_key)).decode()
+ return key_id, private_key, public_key
+
+ def sign_request(self, private_key: str, data: str) -> str:
+ try:
+ key_bytes = base64.b64decode(private_key)
+ signing_key = nacl.signing.SigningKey(key_bytes)
+ signature = signing_key.sign(data.encode())
+ return base64.b64encode(signature.signature).decode()
+ except Exception as e:
+ raise ValueError(
+ f"Invalid private key or signing error: {str(e)}"
+ ) from e
+
+ def verify_request_signature(
+ self, public_key: str, signature: str, data: str
+ ) -> bool:
+ try:
+ key_bytes = base64.b64decode(public_key)
+ verify_key = nacl.signing.VerifyKey(key_bytes)
+ signature_bytes = base64.b64decode(signature)
+ verify_key.verify(data.encode(), signature_bytes)
+ return True
+ except (nacl.exceptions.BadSignatureError, ValueError):
+ return False
+
+ def generate_api_key(self) -> Tuple[str, str]:
+        # Similar approach to the NaCl provider:
+ key_id_bytes = nacl.utils.random(16)
+ key_id = f"key_{base64.urlsafe_b64encode(key_id_bytes).decode()}"
+
+ # Generate raw API key
+ raw_api_key = base64.urlsafe_b64encode(
+ nacl.utils.random(self.config.api_key_bytes)
+ ).decode()
+ return key_id, raw_api_key
+
+ def hash_api_key(self, raw_api_key: str) -> str:
+ # Hash with bcrypt
+ hashed = bcrypt.hashpw(
+ raw_api_key.encode("utf-8"),
+ bcrypt.gensalt(rounds=self.config.bcrypt_rounds),
+ )
+ return base64.b64encode(hashed).decode("utf-8")
+
+ def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
+ stored_hash = base64.b64decode(hashed_key.encode("utf-8"))
+ return bcrypt.checkpw(raw_api_key.encode("utf-8"), stored_hash)
+
+ def generate_secure_token(self, data: dict, expiry: datetime) -> str:
+ now = datetime.now(timezone.utc)
+ to_encode = {
+ **data,
+ "exp": expiry.timestamp(),
+ "iat": now.timestamp(),
+ "nbf": now.timestamp(),
+ "jti": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
+ "nonce": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
+ }
+ return jwt.encode(to_encode, self.secret_key, algorithm="HS256")
+
+ def verify_secure_token(self, token: str) -> Optional[dict]:
+ try:
+ payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
+ exp = payload.get("exp")
+ if exp is None or datetime.fromtimestamp(
+ exp, tz=timezone.utc
+ ) < datetime.now(timezone.utc):
+ return None
+ return payload
+ except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
+ return None
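+
+
+# JWT round-trip sketch (illustrative; `provider` is a hypothetical
+# BCryptCryptoProvider instance):
+#
+#     from datetime import datetime, timedelta, timezone
+#     expiry = datetime.now(timezone.utc) + timedelta(hours=1)
+#     token = provider.generate_secure_token({"sub": "user-id"}, expiry)
+#     payload = provider.verify_secure_token(token)
+#     # payload contains "sub", "exp", "iat", "nbf", "jti", "nonce" -- or is
+#     # None if the token is expired or tampered with.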
diff --git a/.venv/lib/python3.12/site-packages/core/providers/crypto/nacl.py b/.venv/lib/python3.12/site-packages/core/providers/crypto/nacl.py
new file mode 100644
index 00000000..63232565
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/crypto/nacl.py
@@ -0,0 +1,181 @@
+import base64
+import logging
+import os
+import string
+from datetime import datetime, timezone
+from typing import Optional, Tuple
+
+import jwt
+import nacl.encoding
+import nacl.exceptions
+import nacl.pwhash
+import nacl.signing
+from nacl.exceptions import BadSignatureError
+from nacl.pwhash import argon2i
+
+from core.base import CryptoConfig, CryptoProvider
+
+DEFAULT_NACL_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM" # Replace or load from env or secrets manager
+
+
+def encode_bytes_readable(random_bytes: bytes, chars: str) -> str:
+ """Convert random bytes to a readable string using the given character
+ set."""
+ # Each byte gives us 8 bits of randomness
+ # We use modulo to map each byte to our character set
+ result = []
+ for byte in random_bytes:
+ # Use modulo to map the byte (0-255) to our character set length
+ idx = byte % len(chars)
+ result.append(chars[idx])
+ return "".join(result)
+
+
+class NaClCryptoConfig(CryptoConfig):
+ provider: str = "nacl"
+ # Interactive parameters for password ops (fast)
+ ops_limit: int = argon2i.OPSLIMIT_MIN
+ mem_limit: int = argon2i.MEMLIMIT_MIN
+ # Sensitive parameters for API key generation (slow but more secure)
+ api_ops_limit: int = argon2i.OPSLIMIT_INTERACTIVE
+ api_mem_limit: int = argon2i.MEMLIMIT_INTERACTIVE
+ api_key_bytes: int = 32
+ secret_key: Optional[str] = None
+
+
+class NaClCryptoProvider(CryptoProvider):
+ def __init__(self, config: NaClCryptoConfig):
+ if not isinstance(config, NaClCryptoConfig):
+ raise ValueError(
+ "NaClCryptoProvider must be initialized with a NaClCryptoConfig"
+ )
+ super().__init__(config)
+ self.config: NaClCryptoConfig = config
+ logging.info("Initializing NaClCryptoProvider")
+
+ # Securely load the secret key for JWT
+ # Priority: config.secret_key > environment variable > default
+ self.secret_key = (
+ config.secret_key
+ or os.getenv("R2R_SECRET_KEY")
+ or DEFAULT_NACL_SECRET_KEY
+ )
+
+ def get_password_hash(self, password: str) -> str:
+ password_bytes = password.encode("utf-8")
+ hashed = nacl.pwhash.argon2i.str(
+ password_bytes,
+ opslimit=self.config.ops_limit,
+ memlimit=self.config.mem_limit,
+ )
+ return base64.b64encode(hashed).decode("utf-8")
+
+ def verify_password(
+ self, plain_password: str, hashed_password: str
+ ) -> bool:
+ try:
+ stored_hash = base64.b64decode(hashed_password.encode("utf-8"))
+ nacl.pwhash.verify(stored_hash, plain_password.encode("utf-8"))
+ return True
+ except nacl.exceptions.InvalidkeyError:
+ return False
+
+ def generate_verification_code(self, length: int = 32) -> str:
+ random_bytes = nacl.utils.random(length)
+ return base64.urlsafe_b64encode(random_bytes)[:length].decode("utf-8")
+
+ def generate_api_key(self) -> Tuple[str, str]:
+ # Define our character set (excluding ambiguous characters)
+ chars = string.ascii_letters.replace("l", "").replace("I", "").replace(
+ "O", ""
+ ) + string.digits.replace("0", "").replace("1", "")
+
+ # Generate a unique key_id
+ key_id_bytes = nacl.utils.random(16) # 16 random bytes
+ key_id = f"pk_{encode_bytes_readable(key_id_bytes, chars)}"
+
+ # Generate a high-entropy API key
+ raw_api_key = f"sk_{encode_bytes_readable(nacl.utils.random(self.config.api_key_bytes), chars)}"
+
+ # The caller will store the hashed version in the database
+ return key_id, raw_api_key
+
+ def hash_api_key(self, raw_api_key: str) -> str:
+ hashed = nacl.pwhash.argon2i.str(
+ raw_api_key.encode("utf-8"),
+ opslimit=self.config.api_ops_limit,
+ memlimit=self.config.api_mem_limit,
+ )
+ return base64.b64encode(hashed).decode("utf-8")
+
+ def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
+ try:
+ stored_hash = base64.b64decode(hashed_key.encode("utf-8"))
+ nacl.pwhash.verify(stored_hash, raw_api_key.encode("utf-8"))
+ return True
+ except nacl.exceptions.InvalidkeyError:
+ return False
+
+ def sign_request(self, private_key: str, data: str) -> str:
+ try:
+ key_bytes = base64.b64decode(private_key)
+ signing_key = nacl.signing.SigningKey(key_bytes)
+ signature = signing_key.sign(data.encode())
+ return base64.b64encode(signature.signature).decode()
+ except Exception as e:
+ raise ValueError(
+ f"Invalid private key or signing error: {str(e)}"
+ ) from e
+
+ def verify_request_signature(
+ self, public_key: str, signature: str, data: str
+ ) -> bool:
+ try:
+ key_bytes = base64.b64decode(public_key)
+ verify_key = nacl.signing.VerifyKey(key_bytes)
+ signature_bytes = base64.b64decode(signature)
+ verify_key.verify(data.encode(), signature_bytes)
+ return True
+ except (BadSignatureError, ValueError):
+ return False
+
+ def generate_secure_token(self, data: dict, expiry: datetime) -> str:
+ """Generate a secure token using JWT with HS256.
+
+        The secret_key is used for symmetric (HMAC) signing.
+ """
+ now = datetime.now(timezone.utc)
+ to_encode = {
+ **data,
+ "exp": expiry.timestamp(),
+ "iat": now.timestamp(),
+ "nbf": now.timestamp(),
+ "jti": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
+ "nonce": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
+ }
+
+ return jwt.encode(to_encode, self.secret_key, algorithm="HS256")
+
+ def verify_secure_token(self, token: str) -> Optional[dict]:
+ """Verify a secure token using the shared secret_key and JWT."""
+ try:
+ payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
+ exp = payload.get("exp")
+ if exp is None or datetime.fromtimestamp(
+ exp, tz=timezone.utc
+ ) < datetime.now(timezone.utc):
+ return None
+ return payload
+ except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
+ return None
+
+ def generate_signing_keypair(self) -> Tuple[str, str, str]:
+ signing_key = nacl.signing.SigningKey.generate()
+ private_key_b64 = base64.b64encode(signing_key.encode()).decode()
+ public_key_b64 = base64.b64encode(
+ signing_key.verify_key.encode()
+ ).decode()
+ # Generate a unique key_id
+ key_id_bytes = nacl.utils.random(16)
+ key_id = f"sign_{base64.urlsafe_b64encode(key_id_bytes).decode()}"
+ return (key_id, private_key_b64, public_key_b64)
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/database/__init__.py
new file mode 100644
index 00000000..72e6cba8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/__init__.py
@@ -0,0 +1,5 @@
+from .postgres import PostgresDatabaseProvider
+
+__all__ = [
+ "PostgresDatabaseProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/base.py b/.venv/lib/python3.12/site-packages/core/providers/database/base.py
new file mode 100644
index 00000000..c70c1352
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/base.py
@@ -0,0 +1,247 @@
+import asyncio
+import logging
+import textwrap
+from contextlib import asynccontextmanager
+from typing import Optional
+
+import asyncpg
+
+from core.base.providers import DatabaseConnectionManager
+
+logger = logging.getLogger()
+
+
+class SemaphoreConnectionPool:
+ def __init__(self, connection_string, postgres_configuration_settings):
+ self.connection_string = connection_string
+ self.postgres_configuration_settings = postgres_configuration_settings
+
+ async def initialize(self):
+ try:
+ logger.info(
+ f"Connecting with {int(self.postgres_configuration_settings.max_connections * 0.9)} connections to `asyncpg.create_pool`."
+ )
+
+ self.semaphore = asyncio.Semaphore(
+ int(self.postgres_configuration_settings.max_connections * 0.9)
+ )
+
+ self.pool = await asyncpg.create_pool(
+ self.connection_string,
+ max_size=self.postgres_configuration_settings.max_connections,
+ statement_cache_size=self.postgres_configuration_settings.statement_cache_size,
+ )
+
+ logger.info(
+ "Successfully connected to Postgres database and created connection pool."
+ )
+ except Exception as e:
+ raise ValueError(
+ f"Error {e} occurred while attempting to connect to relational database."
+ ) from e
+
+ @asynccontextmanager
+ async def get_connection(self):
+ async with self.semaphore:
+ async with self.pool.acquire() as conn:
+ yield conn
+
+ async def close(self):
+ await self.pool.close()
+
+
+class QueryBuilder:
+ def __init__(self, table_name: str):
+ self.table_name = table_name
+ self.conditions: list[str] = []
+ self.params: list = []
+ self.select_fields = "*"
+ self.operation = "SELECT"
+ self.limit_value: Optional[int] = None
+ self.offset_value: Optional[int] = None
+ self.order_by_fields: Optional[str] = None
+ self.returning_fields: Optional[list[str]] = None
+ self.insert_data: Optional[dict] = None
+ self.update_data: Optional[dict] = None
+ self.param_counter = 1
+
+ def select(self, fields: list[str]):
+ self.select_fields = ", ".join(fields)
+ return self
+
+ def insert(self, data: dict):
+ self.operation = "INSERT"
+ self.insert_data = data
+ return self
+
+ def update(self, data: dict):
+ self.operation = "UPDATE"
+ self.update_data = data
+ return self
+
+ def delete(self):
+ self.operation = "DELETE"
+ return self
+
+ def where(self, condition: str):
+ self.conditions.append(condition)
+ return self
+
+ def limit(self, value: Optional[int]):
+ self.limit_value = value
+ return self
+
+ def offset(self, value: int):
+ self.offset_value = value
+ return self
+
+ def order_by(self, fields: str):
+ self.order_by_fields = fields
+ return self
+
+ def returning(self, fields: list[str]):
+ self.returning_fields = fields
+ return self
+
+ def build(self):
+ if self.operation == "SELECT":
+ query = f"SELECT {self.select_fields} FROM {self.table_name}"
+
+ elif self.operation == "INSERT":
+ columns = ", ".join(self.insert_data.keys())
+ placeholders = ", ".join(
+ f"${i}" for i in range(1, len(self.insert_data) + 1)
+ )
+ query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
+ self.params.extend(list(self.insert_data.values()))
+
+ elif self.operation == "UPDATE":
+ set_clauses = []
+ for i, (key, value) in enumerate(
+ self.update_data.items(), start=len(self.params) + 1
+ ):
+ set_clauses.append(f"{key} = ${i}")
+ self.params.append(value)
+ query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"
+
+ elif self.operation == "DELETE":
+ query = f"DELETE FROM {self.table_name}"
+
+ else:
+ raise ValueError(f"Unsupported operation: {self.operation}")
+
+ if self.conditions:
+ query += " WHERE " + " AND ".join(self.conditions)
+
+ if self.order_by_fields and self.operation == "SELECT":
+ query += f" ORDER BY {self.order_by_fields}"
+
+ if self.offset_value is not None:
+ query += f" OFFSET {self.offset_value}"
+
+ if self.limit_value is not None:
+ query += f" LIMIT {self.limit_value}"
+
+ if self.returning_fields:
+ query += f" RETURNING {', '.join(self.returning_fields)}"
+
+ return query, self.params
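+
+    # Illustrative usage (not part of the original module). Note that where()
+    # only records the SQL fragment, so values for placeholders like $1 must
+    # be passed to the connection manager by the caller:
+    #
+    #     query, params = (
+    #         QueryBuilder("myschema.chunks")        # hypothetical table name
+    #         .select(["id", "text"])
+    #         .where("document_id = $1")
+    #         .limit(10)
+    #         .build()
+    #     )
+    #     # query == "SELECT id, text FROM myschema.chunks WHERE document_id = $1 LIMIT 10"
+    #     # params == []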
+
+
+class PostgresConnectionManager(DatabaseConnectionManager):
+ def __init__(self):
+ self.pool: Optional[SemaphoreConnectionPool] = None
+
+ async def initialize(self, pool: SemaphoreConnectionPool):
+ self.pool = pool
+
+ async def execute_query(self, query, params=None, isolation_level=None):
+ if not self.pool:
+ raise ValueError("PostgresConnectionManager is not initialized.")
+ async with self.pool.get_connection() as conn:
+ if isolation_level:
+ async with conn.transaction(isolation=isolation_level):
+ if params:
+ return await conn.execute(query, *params)
+ else:
+ return await conn.execute(query)
+ else:
+ if params:
+ return await conn.execute(query, *params)
+ else:
+ return await conn.execute(query)
+
+ async def execute_many(self, query, params=None, batch_size=1000):
+ if not self.pool:
+ raise ValueError("PostgresConnectionManager is not initialized.")
+ async with self.pool.get_connection() as conn:
+ async with conn.transaction():
+ if params:
+ results = []
+ for i in range(0, len(params), batch_size):
+ param_batch = params[i : i + batch_size]
+ result = await conn.executemany(query, param_batch)
+ results.append(result)
+ return results
+ else:
+                    # asyncpg's executemany requires an args sequence; with no
+                    # params there is nothing to execute
+                    return await conn.executemany(query, [])
+
+ async def fetch_query(self, query, params=None):
+ if not self.pool:
+ raise ValueError("PostgresConnectionManager is not initialized.")
+ try:
+ async with self.pool.get_connection() as conn:
+ async with conn.transaction():
+ return (
+ await conn.fetch(query, *params)
+ if params
+ else await conn.fetch(query)
+ )
+ except asyncpg.exceptions.DuplicatePreparedStatementError:
+ error_msg = textwrap.dedent("""
+ Database Configuration Error
+
+ Your database provider does not support statement caching.
+
+ To fix this, either:
+ • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
+ • Add statement_cache_size = 0 to your database configuration:
+
+ [database.postgres_configuration_settings]
+ statement_cache_size = 0
+
+ This is required when using connection poolers like PgBouncer or
+ managed database services like Supabase.
+ """).strip()
+ raise ValueError(error_msg) from None
+
+ async def fetchrow_query(self, query, params=None):
+ if not self.pool:
+ raise ValueError("PostgresConnectionManager is not initialized.")
+ async with self.pool.get_connection() as conn:
+ async with conn.transaction():
+ if params:
+ return await conn.fetchrow(query, *params)
+ else:
+ return await conn.fetchrow(query)
+
+ @asynccontextmanager
+ async def transaction(self, isolation_level=None):
+ """Async context manager for database transactions.
+
+ Args:
+ isolation_level: Optional isolation level for the transaction
+
+ Yields:
+ The connection manager instance for use within the transaction
+ """
+ if not self.pool:
+ raise ValueError("PostgresConnectionManager is not initialized.")
+
+ async with self.pool.get_connection() as conn:
+ async with conn.transaction(isolation=isolation_level):
+ try:
+ yield self
+ except Exception as e:
+ logger.error(f"Transaction failed: {str(e)}")
+ raise
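+
+
+# Illustrative wiring sketch (hypothetical `connection_string` and `settings`
+# values; not part of this module):
+#
+#     pool = SemaphoreConnectionPool(connection_string, settings)
+#     await pool.initialize()
+#     manager = PostgresConnectionManager()
+#     await manager.initialize(pool)
+#     rows = await manager.fetch_query("SELECT $1::int AS n", [42])
+#     await pool.close()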
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/chunks.py b/.venv/lib/python3.12/site-packages/core/providers/database/chunks.py
new file mode 100644
index 00000000..177f3395
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/chunks.py
@@ -0,0 +1,1316 @@
+import copy
+import json
+import logging
+import math
+import time
+import uuid
+from typing import Any, Optional, TypedDict
+from uuid import UUID
+
+import numpy as np
+
+from core.base import (
+ ChunkSearchResult,
+ Handler,
+ IndexArgsHNSW,
+ IndexArgsIVFFlat,
+ IndexMeasure,
+ IndexMethod,
+ R2RException,
+ SearchSettings,
+ VectorEntry,
+ VectorQuantizationType,
+ VectorTableName,
+)
+from core.base.utils import _decorate_vector_type
+
+from .base import PostgresConnectionManager
+from .filters import apply_filters
+
+logger = logging.getLogger()
+
+
+def psql_quote_literal(value: str) -> str:
+ """Safely quote a string literal for PostgreSQL to prevent SQL injection.
+
+ This is a simple implementation - in production, you should use proper parameterization
+ or your database driver's quoting functions.
+ """
+ return "'" + value.replace("'", "''") + "'"
+
+
+def index_measure_to_ops(
+ measure: IndexMeasure,
+ quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
+):
+ return _decorate_vector_type(measure.ops, quantization_type)
+
+
+def quantize_vector_to_binary(
+ vector: list[float] | np.ndarray,
+ threshold: float = 0.0,
+) -> bytes:
+ """Quantizes a float vector to a binary vector string for PostgreSQL bit
+ type. Used when quantization_type is INT1.
+
+ Args:
+ vector (List[float] | np.ndarray): Input vector of floats
+ threshold (float, optional): Threshold for binarization. Defaults to 0.0.
+
+ Returns:
+        bytes: ASCII-encoded string of 1s and 0s for the PostgreSQL bit type
+ """
+ # Convert input to numpy array if it isn't already
+ if not isinstance(vector, np.ndarray):
+ vector = np.array(vector)
+
+ # Convert to binary (1 where value > threshold, 0 otherwise)
+ binary_vector = (vector > threshold).astype(int)
+
+    # Convert to a string of 1s and 0s, then encode to ASCII bytes
+ binary_string = "".join(map(str, binary_vector))
+ return binary_string.encode("ascii")
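+
+# For example, quantize_vector_to_binary([0.5, -0.2, 0.1]) returns b"101":
+# values above the 0.0 threshold become 1, everything else becomes 0.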
+
+
+class HybridSearchIntermediateResult(TypedDict):
+ semantic_rank: int
+ full_text_rank: int
+ data: ChunkSearchResult
+ rrf_score: float
+
+
+class PostgresChunksHandler(Handler):
+ TABLE_NAME = VectorTableName.CHUNKS
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ dimension: int | float,
+ quantization_type: VectorQuantizationType,
+ ):
+ super().__init__(project_name, connection_manager)
+ self.dimension = dimension
+ self.quantization_type = quantization_type
+
+ async def create_tables(self):
+ # First check if table already exists and validate dimensions
+ table_exists_query = """
+ SELECT EXISTS (
+ SELECT FROM pg_tables
+ WHERE schemaname = $1
+ AND tablename = $2
+ );
+ """
+ table_name = VectorTableName.CHUNKS
+ table_exists = await self.connection_manager.fetch_query(
+ table_exists_query, (self.project_name, table_name)
+ )
+
+ if len(table_exists) > 0 and table_exists[0]["exists"]:
+ # Table exists, check vector dimension
+ vector_dim_query = """
+ SELECT a.atttypmod as dimension
+ FROM pg_attribute a
+ JOIN pg_class c ON a.attrelid = c.oid
+ JOIN pg_namespace n ON c.relnamespace = n.oid
+ WHERE n.nspname = $1
+ AND c.relname = $2
+ AND a.attname = 'vec';
+ """
+
+ vector_dim_result = await self.connection_manager.fetch_query(
+ vector_dim_query, (self.project_name, table_name)
+ )
+
+ if vector_dim_result and len(vector_dim_result) > 0:
+ existing_dimension = vector_dim_result[0]["dimension"]
+ # In pgvector, dimension is stored as atttypmod - 4
+ if existing_dimension > 0: # If it has a specific dimension
+ # Compare with provided dimension
+ if (
+ self.dimension > 0
+ and existing_dimension != self.dimension
+ ):
+ raise ValueError(
+ f"Dimension mismatch: Table '{self.project_name}.{table_name}' was created with "
+ f"dimension {existing_dimension}, but {self.dimension} was provided. "
+ f"You must use the same dimension for existing tables."
+ )
+
+ # Check for old table name
+ check_query = """
+ SELECT EXISTS (
+ SELECT FROM pg_tables
+ WHERE schemaname = $1
+ AND tablename = $2
+ );
+ """
+ old_table_exists = await self.connection_manager.fetch_query(
+ check_query, (self.project_name, self.project_name)
+ )
+
+ if len(old_table_exists) > 0 and old_table_exists[0]["exists"]:
+ raise ValueError(
+ f"Found old vector table '{self.project_name}.{self.project_name}'. "
+ "Please run `r2r db upgrade` with the CLI, or to run manually, "
+ "run in R2R/py/migrations with 'alembic upgrade head' to update "
+ "your database schema to the new version."
+ )
+
+ binary_col = (
+ ""
+ if self.quantization_type != VectorQuantizationType.INT1
+ else f"vec_binary bit({self.dimension}),"
+ )
+
+ if self.dimension > 0:
+ vector_col = f"vec vector({self.dimension})"
+ else:
+ vector_col = "vec vector"
+
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY,
+ document_id UUID,
+ owner_id UUID,
+ collection_ids UUID[],
+ {vector_col},
+ {binary_col}
+ text TEXT,
+ metadata JSONB,
+ fts tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
+ );
+ CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (document_id);
+ CREATE INDEX IF NOT EXISTS idx_vectors_owner_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (owner_id);
+ CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (collection_ids);
+ CREATE INDEX IF NOT EXISTS idx_vectors_text ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (to_tsvector('english', text));
+ """
+
+ await self.connection_manager.execute_query(query)
+
+ async def upsert(self, entry: VectorEntry) -> None:
+ """Upsert function that handles vector quantization only when
+ quantization_type is INT1.
+
+ Matches the table schema where vec_binary column only exists for INT1
+ quantization.
+ """
+ # Check the quantization type to determine which columns to use
+ if self.quantization_type == VectorQuantizationType.INT1:
+ bit_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+
+ # For quantized vectors, use vec_binary column
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6::bit({bit_dim}), $7, $8)
+ ON CONFLICT (id) DO UPDATE SET
+ document_id = EXCLUDED.document_id,
+ owner_id = EXCLUDED.owner_id,
+ collection_ids = EXCLUDED.collection_ids,
+ vec = EXCLUDED.vec,
+ vec_binary = EXCLUDED.vec_binary,
+ text = EXCLUDED.text,
+ metadata = EXCLUDED.metadata;
+ """
+ await self.connection_manager.execute_query(
+ query,
+ (
+ entry.id,
+ entry.document_id,
+ entry.owner_id,
+ entry.collection_ids,
+ str(entry.vector.data),
+ quantize_vector_to_binary(
+ entry.vector.data
+ ), # Convert to binary
+ entry.text,
+ json.dumps(entry.metadata),
+ ),
+ )
+ else:
+ # For regular vectors, use vec column only
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ (id, document_id, owner_id, collection_ids, vec, text, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
+ ON CONFLICT (id) DO UPDATE SET
+ document_id = EXCLUDED.document_id,
+ owner_id = EXCLUDED.owner_id,
+ collection_ids = EXCLUDED.collection_ids,
+ vec = EXCLUDED.vec,
+ text = EXCLUDED.text,
+ metadata = EXCLUDED.metadata;
+ """
+
+ await self.connection_manager.execute_query(
+ query,
+ (
+ entry.id,
+ entry.document_id,
+ entry.owner_id,
+ entry.collection_ids,
+ str(entry.vector.data),
+ entry.text,
+ json.dumps(entry.metadata),
+ ),
+ )
+
+ async def upsert_entries(self, entries: list[VectorEntry]) -> None:
+ """Batch upsert function that handles vector quantization only when
+ quantization_type is INT1.
+
+ Matches the table schema where vec_binary column only exists for INT1
+ quantization.
+ """
+ if self.quantization_type == VectorQuantizationType.INT1:
+ bit_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+
+ # For quantized vectors, use vec_binary column
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6::bit({bit_dim}), $7, $8)
+ ON CONFLICT (id) DO UPDATE SET
+ document_id = EXCLUDED.document_id,
+ owner_id = EXCLUDED.owner_id,
+ collection_ids = EXCLUDED.collection_ids,
+ vec = EXCLUDED.vec,
+ vec_binary = EXCLUDED.vec_binary,
+ text = EXCLUDED.text,
+ metadata = EXCLUDED.metadata;
+ """
+ bin_params = [
+ (
+ entry.id,
+ entry.document_id,
+ entry.owner_id,
+ entry.collection_ids,
+ str(entry.vector.data),
+ quantize_vector_to_binary(
+ entry.vector.data
+ ), # Convert to binary
+ entry.text,
+ json.dumps(entry.metadata),
+ )
+ for entry in entries
+ ]
+ await self.connection_manager.execute_many(query, bin_params)
+
+ else:
+ # For regular vectors, use vec column only
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ (id, document_id, owner_id, collection_ids, vec, text, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
+ ON CONFLICT (id) DO UPDATE SET
+ document_id = EXCLUDED.document_id,
+ owner_id = EXCLUDED.owner_id,
+ collection_ids = EXCLUDED.collection_ids,
+ vec = EXCLUDED.vec,
+ text = EXCLUDED.text,
+ metadata = EXCLUDED.metadata;
+ """
+ params = [
+ (
+ entry.id,
+ entry.document_id,
+ entry.owner_id,
+ entry.collection_ids,
+ str(entry.vector.data),
+ entry.text,
+ json.dumps(entry.metadata),
+ )
+ for entry in entries
+ ]
+
+ await self.connection_manager.execute_many(query, params)
+
+ async def semantic_search(
+ self, query_vector: list[float], search_settings: SearchSettings
+ ) -> list[ChunkSearchResult]:
+ try:
+ imeasure_obj = IndexMeasure(
+ search_settings.chunk_settings.index_measure
+ )
+ except ValueError:
+ raise ValueError("Invalid index measure") from None
+
+ table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
+ cols = [
+ f"{table_name}.id",
+ f"{table_name}.document_id",
+ f"{table_name}.owner_id",
+ f"{table_name}.collection_ids",
+ f"{table_name}.text",
+ ]
+
+ params: list[str | int | bytes] = []
+
+ # For binary vectors (INT1), implement two-stage search
+ if self.quantization_type == VectorQuantizationType.INT1:
+ # Convert query vector to binary format
+ binary_query = quantize_vector_to_binary(query_vector)
+ # TODO - Put depth multiplier in config / settings
+ extended_limit = (
+ search_settings.limit * 20
+ ) # Get 20x candidates for re-ranking
+
+ if (
+ imeasure_obj == IndexMeasure.hamming_distance
+ or imeasure_obj == IndexMeasure.jaccard_distance
+ ):
+ binary_search_measure_repr = imeasure_obj.pgvector_repr
+ else:
+ binary_search_measure_repr = (
+ IndexMeasure.hamming_distance.pgvector_repr
+ )
+
+ # Use binary column and binary-specific distance measures for first stage
+ bit_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+ stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit{bit_dim}"
+ stage1_param = binary_query
+
+ cols.append(
+ f"{table_name}.vec"
+ ) # Need original vector for re-ranking
+ if search_settings.include_metadatas:
+ cols.append(f"{table_name}.metadata")
+
+ select_clause = ", ".join(cols)
+ where_clause = ""
+ params.append(stage1_param)
+
+ if search_settings.filters:
+ where_clause, params = apply_filters(
+ search_settings.filters, params, mode="where_clause"
+ )
+
+ vector_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+
+ # First stage: Get candidates using binary search
+ query = f"""
+ WITH candidates AS (
+ SELECT {select_clause},
+ ({stage1_distance}) as binary_distance
+ FROM {table_name}
+ {where_clause}
+ ORDER BY {stage1_distance}
+ LIMIT ${len(params) + 1}
+ OFFSET ${len(params) + 2}
+ )
+ -- Second stage: Re-rank using original vectors
+ SELECT
+ id,
+ document_id,
+ owner_id,
+ collection_ids,
+ text,
+ {"metadata," if search_settings.include_metadatas else ""}
+ (vec <=> ${len(params) + 4}::vector{vector_dim}) as distance
+ FROM candidates
+ ORDER BY distance
+ LIMIT ${len(params) + 3}
+ """
+
+ params.extend(
+ [
+ extended_limit, # First stage limit
+ search_settings.offset,
+ search_settings.limit, # Final limit
+ str(query_vector), # For re-ranking
+ ]
+ )
+
+ else:
+ # Standard float vector handling
+ vector_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+ distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector{vector_dim}"
+ query_param = str(query_vector)
+
+ if search_settings.include_scores:
+ cols.append(f"({distance_calc}) AS distance")
+ if search_settings.include_metadatas:
+ cols.append(f"{table_name}.metadata")
+
+ select_clause = ", ".join(cols)
+ where_clause = ""
+ params.append(query_param)
+
+ if search_settings.filters:
+ where_clause, new_params = apply_filters(
+ search_settings.filters,
+ params,
+ mode="where_clause", # Get just conditions without WHERE
+ )
+ params = new_params
+
+ query = f"""
+ SELECT {select_clause}
+ FROM {table_name}
+ {where_clause}
+ ORDER BY {distance_calc}
+ LIMIT ${len(params) + 1}
+ OFFSET ${len(params) + 2}
+ """
+ params.extend([search_settings.limit, search_settings.offset])
+ results = await self.connection_manager.fetch_query(query, params)
+
+ return [
+ ChunkSearchResult(
+ id=UUID(str(result["id"])),
+ document_id=UUID(str(result["document_id"])),
+ owner_id=UUID(str(result["owner_id"])),
+ collection_ids=result["collection_ids"],
+ text=result["text"],
+ score=(
+ (1 - float(result["distance"]))
+ if "distance" in result
+ else -1
+ ),
+ metadata=(
+ json.loads(result["metadata"])
+ if search_settings.include_metadatas
+ else {}
+ ),
+ )
+ for result in results
+ ]
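+
+    # Sketch of the INT1 two-stage flow above (illustrative numbers): with
+    # limit=10, the first stage pulls 10 * 20 == 200 candidates ordered by
+    # Hamming distance on vec_binary, and the second stage re-ranks just those
+    # candidates with the exact <=> (cosine) distance on the original vec
+    # column before applying the final limit.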
+
+ async def full_text_search(
+ self, query_text: str, search_settings: SearchSettings
+ ) -> list[ChunkSearchResult]:
+ conditions = []
+ params: list[str | int | bytes] = [query_text]
+
+ conditions.append("fts @@ websearch_to_tsquery('english', $1)")
+
+ if search_settings.filters:
+ filter_condition, params = apply_filters(
+ search_settings.filters, params, mode="condition_only"
+ )
+ if filter_condition:
+ conditions.append(filter_condition)
+
+ where_clause = "WHERE " + " AND ".join(conditions)
+
+ query = f"""
+ SELECT
+ id,
+ document_id,
+ owner_id,
+ collection_ids,
+ text,
+ metadata,
+ ts_rank(fts, websearch_to_tsquery('english', $1), 32) as rank
+ FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ {where_clause}
+ ORDER BY rank DESC
+ OFFSET ${len(params) + 1}
+ LIMIT ${len(params) + 2}
+ """
+
+ params.extend(
+ [
+ search_settings.offset,
+ search_settings.hybrid_settings.full_text_limit,
+ ]
+ )
+
+ results = await self.connection_manager.fetch_query(query, params)
+ return [
+ ChunkSearchResult(
+ id=UUID(str(r["id"])),
+ document_id=UUID(str(r["document_id"])),
+ owner_id=UUID(str(r["owner_id"])),
+ collection_ids=r["collection_ids"],
+ text=r["text"],
+ score=float(r["rank"]),
+ metadata=json.loads(r["metadata"]),
+ )
+ for r in results
+ ]
+
+ async def hybrid_search(
+ self,
+ query_text: str,
+ query_vector: list[float],
+ search_settings: SearchSettings,
+ *args,
+ **kwargs,
+ ) -> list[ChunkSearchResult]:
+ if search_settings.hybrid_settings is None:
+ raise ValueError(
+ "Please provide a valid `hybrid_settings` in the `search_settings`."
+ )
+ if (
+ search_settings.hybrid_settings.full_text_limit
+ < search_settings.limit
+ ):
+ raise ValueError(
+ "The `full_text_limit` must be greater than or equal to the `limit`."
+ )
+
+ semantic_settings = copy.deepcopy(search_settings)
+ semantic_settings.limit += search_settings.offset
+
+ full_text_settings = copy.deepcopy(search_settings)
+ full_text_settings.hybrid_settings.full_text_limit += (
+ search_settings.offset
+ )
+
+ semantic_results: list[ChunkSearchResult] = await self.semantic_search(
+ query_vector, semantic_settings
+ )
+ full_text_results: list[
+ ChunkSearchResult
+ ] = await self.full_text_search(query_text, full_text_settings)
+
+ semantic_limit = search_settings.limit
+ full_text_limit = search_settings.hybrid_settings.full_text_limit
+ semantic_weight = search_settings.hybrid_settings.semantic_weight
+ full_text_weight = search_settings.hybrid_settings.full_text_weight
+ rrf_k = search_settings.hybrid_settings.rrf_k
+
+ combined_results: dict[uuid.UUID, HybridSearchIntermediateResult] = {}
+
+ for rank, result in enumerate(semantic_results, 1):
+ combined_results[result.id] = {
+ "semantic_rank": rank,
+ "full_text_rank": full_text_limit,
+ "data": result,
+ "rrf_score": 0.0, # Initialize with 0, will be calculated later
+ }
+
+ for rank, result in enumerate(full_text_results, 1):
+ if result.id in combined_results:
+ combined_results[result.id]["full_text_rank"] = rank
+ else:
+ combined_results[result.id] = {
+ "semantic_rank": semantic_limit,
+ "full_text_rank": rank,
+ "data": result,
+ "rrf_score": 0.0, # Initialize with 0, will be calculated later
+ }
+
+ combined_results = {
+ k: v
+ for k, v in combined_results.items()
+ if v["semantic_rank"] <= semantic_limit * 2
+ and v["full_text_rank"] <= full_text_limit * 2
+ }
+
+ for hyb_result in combined_results.values():
+ semantic_score = 1 / (rrf_k + hyb_result["semantic_rank"])
+ full_text_score = 1 / (rrf_k + hyb_result["full_text_rank"])
+ hyb_result["rrf_score"] = (
+ semantic_score * semantic_weight
+ + full_text_score * full_text_weight
+ ) / (semantic_weight + full_text_weight)
+
+ sorted_results = sorted(
+ combined_results.values(),
+ key=lambda x: x["rrf_score"],
+ reverse=True,
+ )
+ offset_results = sorted_results[
+ search_settings.offset : search_settings.offset
+ + search_settings.limit
+ ]
+
+ return [
+ ChunkSearchResult(
+ id=result["data"].id,
+ document_id=result["data"].document_id,
+ owner_id=result["data"].owner_id,
+ collection_ids=result["data"].collection_ids,
+ text=result["data"].text,
+ score=result["rrf_score"],
+ metadata={
+ **result["data"].metadata,
+ "semantic_rank": result["semantic_rank"],
+ "full_text_rank": result["full_text_rank"],
+ },
+ )
+ for result in offset_results
+ ]
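+
+    # Worked example of the weighted reciprocal-rank-fusion score above
+    # (illustrative weights): with rrf_k=60, semantic_rank=1,
+    # full_text_rank=3, semantic_weight=5.0 and full_text_weight=1.0,
+    #     rrf_score = (5.0 * 1/61 + 1.0 * 1/63) / 6.0 ≈ 0.0163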
+
+ async def delete(
+ self, filters: dict[str, Any]
+ ) -> dict[str, dict[str, str]]:
+ params: list[str | int | bytes] = []
+ where_clause, params = apply_filters(
+ filters, params, mode="condition_only"
+ )
+
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ WHERE {where_clause}
+ RETURNING id, document_id, text;
+ """
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ return {
+ str(result["id"]): {
+ "status": "deleted",
+ "id": str(result["id"]),
+ "document_id": str(result["document_id"]),
+ "text": result["text"],
+ }
+ for result in results
+ }
+
+ async def assign_document_chunks_to_collection(
+ self, document_id: UUID, collection_id: UUID
+ ) -> None:
+ query = f"""
+ UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ SET collection_ids = array_append(collection_ids, $1)
+ WHERE document_id = $2 AND NOT ($1 = ANY(collection_ids));
+ """
+ return await self.connection_manager.execute_query(
+ query, (str(collection_id), str(document_id))
+ )
+
+ async def remove_document_from_collection_vector(
+ self, document_id: UUID, collection_id: UUID
+ ) -> None:
+ query = f"""
+ UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ SET collection_ids = array_remove(collection_ids, $1)
+ WHERE document_id = $2;
+ """
+ await self.connection_manager.execute_query(
+ query, (collection_id, document_id)
+ )
+
+ async def delete_user_vector(self, owner_id: UUID) -> None:
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ WHERE owner_id = $1;
+ """
+ await self.connection_manager.execute_query(query, (owner_id,))
+
+ async def delete_collection_vector(self, collection_id: UUID) -> None:
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ WHERE $1 = ANY(collection_ids)
+ RETURNING collection_ids
+ """
+ await self.connection_manager.fetchrow_query(query, (collection_id,))
+ return None
+
+ async def list_document_chunks(
+ self,
+ document_id: UUID,
+ offset: int,
+ limit: int,
+ include_vectors: bool = False,
+ ) -> dict[str, Any]:
+ vector_select = ", vec" if include_vectors else ""
+ limit_clause = f"LIMIT {limit}" if limit > -1 else ""
+
+ query = f"""
+ SELECT id, document_id, owner_id, collection_ids, text, metadata{vector_select}, COUNT(*) OVER() AS total
+ FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ WHERE document_id = $1
+ ORDER BY (metadata->>'chunk_order')::integer
+ OFFSET $2
+ {limit_clause};
+ """
+
+ params = [document_id, offset]
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ chunks = []
+ total = 0
+ if results:
+ total = results[0].get("total", 0)
+ chunks = [
+ {
+ "id": result["id"],
+ "document_id": result["document_id"],
+ "owner_id": result["owner_id"],
+ "collection_ids": result["collection_ids"],
+ "text": result["text"],
+ "metadata": json.loads(result["metadata"]),
+ "vector": (
+ json.loads(result["vec"]) if include_vectors else None
+ ),
+ }
+ for result in results
+ ]
+
+ return {"results": chunks, "total_entries": total}
+
+ async def get_chunk(self, id: UUID) -> dict:
+ query = f"""
+ SELECT id, document_id, owner_id, collection_ids, text, metadata
+ FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ WHERE id = $1;
+ """
+
+ result = await self.connection_manager.fetchrow_query(query, (id,))
+
+ if result:
+ return {
+ "id": result["id"],
+ "document_id": result["document_id"],
+ "owner_id": result["owner_id"],
+ "collection_ids": result["collection_ids"],
+ "text": result["text"],
+ "metadata": json.loads(result["metadata"]),
+ }
+ raise R2RException(
+ message=f"Chunk with ID {id} not found", status_code=404
+ )
+
+ async def create_index(
+ self,
+ table_name: Optional[VectorTableName] = None,
+ index_measure: IndexMeasure = IndexMeasure.cosine_distance,
+ index_method: IndexMethod = IndexMethod.auto,
+ index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
+ index_name: Optional[str] = None,
+ index_column: Optional[str] = None,
+ concurrently: bool = True,
+ ) -> None:
+ """Creates an index for the collection.
+
+ Note:
+ When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step
+ process that enables performant indexes to be built for large collections with low end
+ database hardware.
+
+ Those steps are:
+
+ - Creates a new table with a different name
+ - Randomly selects records from the existing table
+ - Inserts the random records from the existing table into the new table
+ - Creates the requested vector index on the new table
+ - Upserts all data from the existing table into the new table
+ - Drops the existing table
+            - Renames the new table to the existing table's name
+
+ If you create dependencies (like views) on the table that underpins
+            a `vecs.Collection`, the `create_index` step may require you to drop those dependencies before
+ it will succeed.
+
+ Args:
+ index_measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
+ index_method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
+            index_arguments (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments
+ index_name (str, optional): The name of the index to create. Defaults to None.
+ concurrently (bool, optional): Whether to create the index concurrently. Defaults to True.
+ Raises:
+            ValueError: If an invalid index method is used, or if the index arguments do not match the chosen index method.
+ """
+
+ if table_name == VectorTableName.CHUNKS:
+ table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}" # TODO - Fix bug in vector table naming convention
+ if index_column:
+ col_name = index_column
+ else:
+ col_name = (
+ "vec"
+ if (
+ index_measure != IndexMeasure.hamming_distance
+ and index_measure != IndexMeasure.jaccard_distance
+ )
+ else "vec_binary"
+ )
+ elif table_name == VectorTableName.ENTITIES_DOCUMENT:
+ table_name_str = (
+ f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
+ )
+ col_name = "description_embedding"
+ elif table_name == VectorTableName.GRAPHS_ENTITIES:
+ table_name_str = (
+ f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
+ )
+ col_name = "description_embedding"
+ elif table_name == VectorTableName.COMMUNITIES:
+ table_name_str = (
+ f"{self.project_name}.{VectorTableName.COMMUNITIES}"
+ )
+ col_name = "embedding"
+ else:
+ raise ValueError("invalid table name")
+
+ if index_method not in (
+ IndexMethod.ivfflat,
+ IndexMethod.hnsw,
+ IndexMethod.auto,
+ ):
+ raise ValueError("invalid index method")
+
+ if index_arguments:
+ # Disallow case where user submits index arguments but uses the
+ # IndexMethod.auto index (index build arguments should only be
+ # used with a specific index)
+ if index_method == IndexMethod.auto:
+ raise ValueError(
+ "Index build parameters are not allowed when using the IndexMethod.auto index."
+ )
+ # Disallow case where user specifies one index type but submits
+ # index build arguments for the other index type
+ if (
+ isinstance(index_arguments, IndexArgsHNSW)
+ and index_method != IndexMethod.hnsw
+ ) or (
+ isinstance(index_arguments, IndexArgsIVFFlat)
+ and index_method != IndexMethod.ivfflat
+ ):
+ raise ValueError(
+ f"{index_arguments.__class__.__name__} build parameters were supplied but {index_method} index was specified."
+ )
+
+ if index_method == IndexMethod.auto:
+ index_method = IndexMethod.hnsw
+
+ ops = index_measure_to_ops(
+ index_measure # , quantization_type=self.quantization_type
+ )
+
+ if ops is None:
+ raise ValueError("Unknown index measure")
+
+ concurrently_sql = "CONCURRENTLY" if concurrently else ""
+
+ index_name = (
+ index_name
+ or f"ix_{ops}_{index_method}__{col_name}_{time.strftime('%Y%m%d%H%M%S')}"
+ )
+
+ create_index_sql = f"""
+ CREATE INDEX {concurrently_sql} {index_name}
+ ON {table_name_str}
+ USING {index_method} ({col_name} {ops}) {self._get_index_options(index_method, index_arguments)};
+ """
+
+ try:
+ if concurrently:
+ async with (
+ self.connection_manager.pool.get_connection() as conn # type: ignore
+ ):
+ # Disable automatic transaction management
+ await conn.execute(
+ "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
+ )
+ await conn.execute(create_index_sql)
+ else:
+ # Non-concurrent index creation can use normal query execution
+ await self.connection_manager.execute_query(create_index_sql)
+ except Exception as e:
+ raise Exception(f"Failed to create index: {e}") from e
+ return None
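+
+    # Illustrative call (argument values are placeholders; `handler` is a
+    # hypothetical PostgresChunksHandler instance):
+    #
+    #     await handler.create_index(
+    #         table_name=VectorTableName.CHUNKS,
+    #         index_measure=IndexMeasure.cosine_distance,
+    #         index_method=IndexMethod.hnsw,
+    #         concurrently=True,
+    #     )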
+
+ async def list_indices(
+ self,
+ offset: int,
+ limit: int,
+ filters: Optional[dict[str, Any]] = None,
+ ) -> dict:
+ where_clauses = []
+ params: list[Any] = [self.project_name] # Start with schema name
+ param_count = 1
+
+ # Handle filtering
+ if filters:
+ if "table_name" in filters:
+ where_clauses.append(f"i.tablename = ${param_count + 1}")
+ params.append(filters["table_name"])
+ param_count += 1
+ if "index_method" in filters:
+ where_clauses.append(f"am.amname = ${param_count + 1}")
+ params.append(filters["index_method"])
+ param_count += 1
+ if "index_name" in filters:
+ where_clauses.append(
+ f"LOWER(i.indexname) LIKE LOWER(${param_count + 1})"
+ )
+ params.append(f"%{filters['index_name']}%")
+ param_count += 1
+
+ where_clause = " AND ".join(where_clauses) if where_clauses else ""
+ if where_clause:
+ where_clause = f"AND {where_clause}"
+
+ query = f"""
+ WITH index_info AS (
+ SELECT
+ i.indexname as name,
+ i.tablename as table_name,
+ i.indexdef as definition,
+ am.amname as method,
+ pg_relation_size(c.oid) as size_in_bytes,
+ c.reltuples::bigint as row_estimate,
+ COALESCE(psat.idx_scan, 0) as number_of_scans,
+ COALESCE(psat.idx_tup_read, 0) as tuples_read,
+ COALESCE(psat.idx_tup_fetch, 0) as tuples_fetched,
+ COUNT(*) OVER() as total_count
+ FROM pg_indexes i
+ JOIN pg_class c ON c.relname = i.indexname
+ JOIN pg_am am ON c.relam = am.oid
+ LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelname = i.indexname
+ AND psat.schemaname = i.schemaname
+ WHERE i.schemaname = $1
+ AND i.indexdef LIKE '%vector%'
+ {where_clause}
+ )
+ SELECT *
+ FROM index_info
+ ORDER BY name
+ LIMIT ${param_count + 1}
+ OFFSET ${param_count + 2}
+ """
+
+ # Add limit and offset to params
+ params.extend([limit, offset])
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ indices = []
+ total_entries = 0
+
+ if results:
+ total_entries = results[0]["total_count"]
+ for result in results:
+ index_info = {
+ "name": result["name"],
+ "table_name": result["table_name"],
+ "definition": result["definition"],
+ "size_in_bytes": result["size_in_bytes"],
+ "row_estimate": result["row_estimate"],
+ "number_of_scans": result["number_of_scans"],
+ "tuples_read": result["tuples_read"],
+ "tuples_fetched": result["tuples_fetched"],
+ }
+ indices.append(index_info)
+
+ return {"indices": indices, "total_entries": total_entries}
+
+ async def delete_index(
+ self,
+ index_name: str,
+ table_name: Optional[VectorTableName] = None,
+ concurrently: bool = True,
+ ) -> None:
+ """Deletes a vector index.
+
+ Args:
+ index_name (str): Name of the index to delete
+ table_name (VectorTableName, optional): Table the index belongs to
+ concurrently (bool): Whether to drop the index concurrently
+
+ Raises:
+ ValueError: If table name is invalid or index doesn't exist
+ Exception: If index deletion fails
+ """
+ # Validate table name and get column name
+ if table_name == VectorTableName.CHUNKS:
+ table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"
+ col_name = "vec"
+ elif table_name == VectorTableName.ENTITIES_DOCUMENT:
+ table_name_str = (
+ f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
+ )
+ col_name = "description_embedding"
+ elif table_name == VectorTableName.GRAPHS_ENTITIES:
+ table_name_str = (
+ f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
+ )
+ col_name = "description_embedding"
+ elif table_name == VectorTableName.COMMUNITIES:
+ table_name_str = (
+ f"{self.project_name}.{VectorTableName.COMMUNITIES}"
+ )
+ col_name = "description_embedding"
+ else:
+ raise ValueError("invalid table name")
+
+ # Extract schema and base table name
+ schema_name, base_table_name = table_name_str.split(".")
+
+ # Verify index exists and is a vector index
+ query = """
+ SELECT indexdef
+ FROM pg_indexes
+ WHERE indexname = $1
+ AND schemaname = $2
+ AND tablename = $3
+ AND indexdef LIKE $4
+ """
+
+ result = await self.connection_manager.fetchrow_query(
+ query, (index_name, schema_name, base_table_name, f"%({col_name}%")
+ )
+
+ if not result:
+ raise ValueError(
+ f"Vector index '{index_name}' does not exist on table {table_name_str}"
+ )
+
+ # Drop the index
+ concurrently_sql = "CONCURRENTLY" if concurrently else ""
+ drop_query = (
+ f"DROP INDEX {concurrently_sql} {schema_name}.{index_name}"
+ )
+
+ try:
+ if concurrently:
+ async with (
+ self.connection_manager.pool.get_connection() as conn # type: ignore
+ ):
+ # Disable automatic transaction management
+ await conn.execute(
+ "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
+ )
+ await conn.execute(drop_query)
+ else:
+ await self.connection_manager.execute_query(drop_query)
+ except Exception as e:
+ raise Exception(f"Failed to delete index: {e}") from e
+
+ async def list_chunks(
+ self,
+ offset: int,
+ limit: int,
+ filters: Optional[dict[str, Any]] = None,
+ include_vectors: bool = False,
+ ) -> dict[str, Any]:
+ """List chunks with pagination support.
+
+ Args:
+            offset (int): Number of records to skip.
+            limit (int): Maximum number of records to return.
+ filters (dict, optional): Dictionary of filters to apply. Defaults to None.
+ include_vectors (bool, optional): Whether to include vector data. Defaults to False.
+
+ Returns:
+ dict: Dictionary containing:
+ - results: List of chunk records
+ - total_entries: Total number of chunks matching the filters
+ """
+ vector_select = ", vec" if include_vectors else ""
+ select_clause = f"""
+ id, document_id, owner_id, collection_ids,
+ text, metadata{vector_select}, COUNT(*) OVER() AS total_entries
+ """
+
+ params: list[str | int | bytes] = []
+ where_clause = ""
+ if filters:
+ where_clause, params = apply_filters(
+ filters, params, mode="where_clause"
+ )
+
+ query = f"""
+ SELECT {select_clause}
+ FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ {where_clause}
+ LIMIT ${len(params) + 1}
+ OFFSET ${len(params) + 2}
+ """
+
+ params.extend([limit, offset])
+
+ # Execute the query
+ results = await self.connection_manager.fetch_query(query, params)
+
+ # Process results
+ chunks = []
+ total_entries = 0
+ if results:
+ total_entries = results[0].get("total_entries", 0)
+ chunks = [
+ {
+ "id": str(result["id"]),
+ "document_id": str(result["document_id"]),
+ "owner_id": str(result["owner_id"]),
+ "collection_ids": result["collection_ids"],
+ "text": result["text"],
+ "metadata": json.loads(result["metadata"]),
+ "vector": (
+ json.loads(result["vec"]) if include_vectors else None
+ ),
+ }
+ for result in results
+ ]
+
+ return {"results": chunks, "total_entries": total_entries}
+
+ async def search_documents(
+ self,
+ query_text: str,
+ settings: SearchSettings,
+ ) -> list[dict[str, Any]]:
+ """Search for documents based on their metadata fields and/or body
+ text. Joins with documents table to get complete document metadata.
+
+ Args:
+ query_text (str): The search query text
+ settings (SearchSettings): Search settings including search preferences and filters
+
+ Returns:
+ list[dict[str, Any]]: List of documents with their search scores and complete metadata
+ """
+ where_clauses = []
+ params: list[str | int | bytes] = [query_text]
+
+ search_over_body = getattr(settings, "search_over_body", True)
+ search_over_metadata = getattr(settings, "search_over_metadata", True)
+ metadata_weight = getattr(settings, "metadata_weight", 3.0)
+ title_weight = getattr(settings, "title_weight", 1.0)
+ metadata_keys = getattr(
+ settings, "metadata_keys", ["title", "description"]
+ )
+
+ # Build the dynamic metadata field search expression
+ metadata_fields_expr = " || ' ' || ".join(
+ [
+ f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')"
+ for key in metadata_keys # type: ignore
+ ]
+ )
+
+ query = f"""
+ WITH
+ -- Metadata search scores
+ metadata_scores AS (
+ SELECT DISTINCT ON (v.document_id)
+ v.document_id,
+ d.metadata as doc_metadata,
+ CASE WHEN $1 = '' THEN 0.0
+ ELSE
+ ts_rank_cd(
+ setweight(to_tsvector('english', {metadata_fields_expr}), 'A'),
+ websearch_to_tsquery('english', $1),
+ 32
+ )
+ END as metadata_rank
+ FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} v
+ LEFT JOIN {self._get_table_name("documents")} d ON v.document_id = d.id
+ WHERE v.metadata IS NOT NULL
+ ),
+ -- Body search scores
+ body_scores AS (
+ SELECT
+ document_id,
+ AVG(
+ ts_rank_cd(
+ setweight(to_tsvector('english', COALESCE(text, '')), 'B'),
+ websearch_to_tsquery('english', $1),
+ 32
+ )
+ ) as body_rank
+ FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ WHERE $1 != ''
+ {"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if search_over_body else ""}
+ GROUP BY document_id
+ ),
+ -- Combined scores with document metadata
+ combined_scores AS (
+ SELECT
+ COALESCE(m.document_id, b.document_id) as document_id,
+ m.doc_metadata as metadata,
+ COALESCE(m.metadata_rank, 0) as debug_metadata_rank,
+ COALESCE(b.body_rank, 0) as debug_body_rank,
+ CASE
+ WHEN {str(search_over_metadata).lower()} AND {str(search_over_body).lower()} THEN
+ COALESCE(m.metadata_rank, 0) * {metadata_weight} + COALESCE(b.body_rank, 0) * {title_weight}
+ WHEN {str(search_over_metadata).lower()} THEN
+ COALESCE(m.metadata_rank, 0)
+ WHEN {str(search_over_body).lower()} THEN
+ COALESCE(b.body_rank, 0)
+ ELSE 0
+ END as rank
+ FROM metadata_scores m
+ FULL OUTER JOIN body_scores b ON m.document_id = b.document_id
+ WHERE (
+ ($1 = '') OR
+ ({str(search_over_metadata).lower()} AND m.metadata_rank > 0) OR
+ ({str(search_over_body).lower()} AND b.body_rank > 0)
+ )
+ """
+
+ # Add any additional filters
+ if settings.filters:
+ filter_clause, params = apply_filters(settings.filters, params)
+ where_clauses.append(filter_clause)
+
+ if where_clauses:
+ query += f" AND {' AND '.join(where_clauses)}"
+
+ query += """
+ )
+ SELECT
+ document_id,
+ metadata,
+ rank as score,
+ debug_metadata_rank,
+ debug_body_rank
+ FROM combined_scores
+ WHERE rank > 0
+ ORDER BY rank DESC
+ OFFSET ${offset_param} LIMIT ${limit_param}
+ """.format(
+ offset_param=len(params) + 1,
+ limit_param=len(params) + 2,
+ )
+
+ # Add offset and limit to params
+ params.extend([settings.offset, settings.limit])
+
+ # Execute query
+ results = await self.connection_manager.fetch_query(query, params)
+
+ # Format results with complete document metadata
+ return [
+ {
+ "document_id": str(r["document_id"]),
+ "metadata": (
+ json.loads(r["metadata"])
+ if isinstance(r["metadata"], str)
+ else r["metadata"]
+ ),
+ "score": float(r["score"]),
+ "debug_metadata_rank": float(r["debug_metadata_rank"]),
+ "debug_body_rank": float(r["debug_body_rank"]),
+ }
+ for r in results
+ ]
+
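+    # Illustrative usage sketch for search_documents (hedged; the
+    # SearchSettings constructor fields shown are assumptions based on the
+    # attributes read above):
+    #
+    #     settings = SearchSettings(filters={}, offset=0, limit=10)
+    #     hits = await chunks_handler.search_documents("vector indexes", settings)
+    #     # Each hit includes document_id, metadata, a combined score, and the
+    #     # debug_metadata_rank / debug_body_rank components used to build it.
+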
+ def _get_index_options(
+ self,
+ method: IndexMethod,
+ index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW],
+ ) -> str:
+ if method == IndexMethod.ivfflat:
+ if isinstance(index_arguments, IndexArgsIVFFlat):
+ return f"WITH (lists={index_arguments.n_lists})"
+ else:
+ # Default value if no arguments provided
+ return "WITH (lists=100)"
+ elif method == IndexMethod.hnsw:
+ if isinstance(index_arguments, IndexArgsHNSW):
+ return f"WITH (m={index_arguments.m}, ef_construction={index_arguments.ef_construction})"
+ else:
+ # Default values if no arguments provided
+ return "WITH (m=16, ef_construction=64)"
+ else:
+ return "" # No options for other methods
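+
+    # Sketch of how the option string above is typically consumed (hedged
+    # illustration; the CREATE INDEX statement issued elsewhere in this module
+    # may differ in detail):
+    #
+    #     options = self._get_index_options(
+    #         IndexMethod.hnsw, IndexArgsHNSW(m=16, ef_construction=64)
+    #     )
+    #     # -> "WITH (m=16, ef_construction=64)"
+    #     create_stmt = (
+    #         f"CREATE INDEX CONCURRENTLY idx_vec_hnsw ON {table} "
+    #         f"USING hnsw (vec vector_cosine_ops) {options}"
+    #     )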
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/collections.py b/.venv/lib/python3.12/site-packages/core/providers/database/collections.py
new file mode 100644
index 00000000..72adaff2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/collections.py
@@ -0,0 +1,701 @@
+import csv
+import json
+import logging
+import tempfile
+from typing import IO, Any, Optional
+from uuid import UUID, uuid4
+
+from asyncpg.exceptions import UniqueViolationError
+from fastapi import HTTPException
+
+from core.base import (
+ DatabaseConfig,
+ GraphExtractionStatus,
+ Handler,
+ R2RException,
+ generate_default_user_collection_id,
+)
+from core.base.abstractions import (
+ DocumentResponse,
+ DocumentType,
+ IngestionStatus,
+)
+from core.base.api.models import CollectionResponse
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger()
+
+
+class PostgresCollectionsHandler(Handler):
+ TABLE_NAME = "collections"
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ config: DatabaseConfig,
+ ):
+ self.config = config
+ super().__init__(project_name, connection_manager)
+
+ async def create_tables(self) -> None:
+ # 1. Create the table if it does not exist.
+ create_table_query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ owner_id UUID,
+ name TEXT NOT NULL,
+ description TEXT,
+ graph_sync_status TEXT DEFAULT 'pending',
+ graph_cluster_status TEXT DEFAULT 'pending',
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ user_count INT DEFAULT 0,
+ document_count INT DEFAULT 0
+ );
+ """
+ await self.connection_manager.execute_query(create_table_query)
+
+ # 2. Check for duplicate rows that would violate the uniqueness constraint.
+ check_duplicates_query = f"""
+ SELECT owner_id, name, COUNT(*) AS cnt
+ FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ GROUP BY owner_id, name
+ HAVING COUNT(*) > 1
+ """
+ duplicates = await self.connection_manager.fetch_query(
+ check_duplicates_query
+ )
+ if duplicates:
+ logger.warning(
+ "Cannot add unique constraint (owner_id, name) because duplicates exist. "
+ "Please resolve duplicates first. Found duplicates: %s",
+ duplicates,
+ )
+            return  # Skip adding the constraint until the duplicates are resolved.
+
+ # 3. Parse the qualified table name into schema and table.
+ qualified_table = self._get_table_name(
+ PostgresCollectionsHandler.TABLE_NAME
+ )
+ if "." in qualified_table:
+ schema, table = qualified_table.split(".", 1)
+ else:
+ schema = "public"
+ table = qualified_table
+
+ # 4. Add the unique constraint if it does not already exist.
+ alter_table_constraint = f"""
+ DO $$
+ BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint c
+ JOIN pg_class t ON c.conrelid = t.oid
+ JOIN pg_namespace n ON n.oid = t.relnamespace
+ WHERE t.relname = '{table}'
+ AND n.nspname = '{schema}'
+ AND c.conname = 'unique_owner_collection_name'
+ ) THEN
+ ALTER TABLE {qualified_table}
+ ADD CONSTRAINT unique_owner_collection_name
+ UNIQUE (owner_id, name);
+ END IF;
+ END;
+ $$;
+ """
+ await self.connection_manager.execute_query(alter_table_constraint)
+
+ async def collection_exists(self, collection_id: UUID) -> bool:
+ """Check if a collection exists."""
+ query = f"""
+ SELECT 1 FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ WHERE id = $1
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [collection_id]
+ )
+ return result is not None
+
+ async def create_collection(
+ self,
+ owner_id: UUID,
+ name: Optional[str] = None,
+ description: str | None = None,
+ collection_id: Optional[UUID] = None,
+ ) -> CollectionResponse:
+ if not name and not collection_id:
+ name = self.config.default_collection_name
+ collection_id = generate_default_user_collection_id(owner_id)
+
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ (id, owner_id, name, description)
+ VALUES ($1, $2, $3, $4)
+ RETURNING id, owner_id, name, description, graph_sync_status, graph_cluster_status, created_at, updated_at
+ """
+ params = [
+ collection_id or uuid4(),
+ owner_id,
+ name,
+ description,
+ ]
+
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+ if not result:
+ raise R2RException(
+ status_code=404, message="Collection not found"
+ )
+
+ return CollectionResponse(
+ id=result["id"],
+ owner_id=result["owner_id"],
+ name=result["name"],
+ description=result["description"],
+ graph_cluster_status=result["graph_cluster_status"],
+ graph_sync_status=result["graph_sync_status"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ user_count=0,
+ document_count=0,
+ )
+ except UniqueViolationError:
+ raise R2RException(
+                message="Collection with this ID or owner/name pair already exists",
+ status_code=409,
+ ) from None
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while creating the collection: {e}",
+ ) from e
+
+ async def update_collection(
+ self,
+ collection_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> CollectionResponse:
+ """Update an existing collection."""
+ if not await self.collection_exists(collection_id):
+ raise R2RException(status_code=404, message="Collection not found")
+
+ update_fields = []
+ params: list = []
+ param_index = 1
+
+ if name is not None:
+ update_fields.append(f"name = ${param_index}")
+ params.append(name)
+ param_index += 1
+
+ if description is not None:
+ update_fields.append(f"description = ${param_index}")
+ params.append(description)
+ param_index += 1
+
+ if not update_fields:
+ raise R2RException(status_code=400, message="No fields to update")
+
+ update_fields.append("updated_at = NOW()")
+ params.append(collection_id)
+
+ query = f"""
+ WITH updated_collection AS (
+ UPDATE {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ SET {", ".join(update_fields)}
+ WHERE id = ${param_index}
+ RETURNING id, owner_id, name, description, graph_sync_status, graph_cluster_status, created_at, updated_at
+ )
+ SELECT
+ uc.*,
+ COUNT(DISTINCT u.id) FILTER (WHERE u.id IS NOT NULL) as user_count,
+ COUNT(DISTINCT d.id) FILTER (WHERE d.id IS NOT NULL) as document_count
+ FROM updated_collection uc
+ LEFT JOIN {self._get_table_name("users")} u ON uc.id = ANY(u.collection_ids)
+ LEFT JOIN {self._get_table_name("documents")} d ON uc.id = ANY(d.collection_ids)
+ GROUP BY uc.id, uc.owner_id, uc.name, uc.description, uc.graph_sync_status, uc.graph_cluster_status, uc.created_at, uc.updated_at
+ """
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query, params
+ )
+ if not result:
+ raise R2RException(
+ status_code=404, message="Collection not found"
+ )
+
+ return CollectionResponse(
+ id=result["id"],
+ owner_id=result["owner_id"],
+ name=result["name"],
+ description=result["description"],
+ graph_sync_status=result["graph_sync_status"],
+ graph_cluster_status=result["graph_cluster_status"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ user_count=result["user_count"],
+ document_count=result["document_count"],
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while updating the collection: {e}",
+ ) from e
+
+ async def delete_collection_relational(self, collection_id: UUID) -> None:
+ # Remove collection_id from users
+ user_update_query = f"""
+ UPDATE {self._get_table_name("users")}
+ SET collection_ids = array_remove(collection_ids, $1)
+ WHERE $1 = ANY(collection_ids)
+ """
+ await self.connection_manager.execute_query(
+ user_update_query, [collection_id]
+ )
+
+ # Remove collection_id from documents
+ document_update_query = f"""
+ WITH updated AS (
+ UPDATE {self._get_table_name("documents")}
+ SET collection_ids = array_remove(collection_ids, $1)
+ WHERE $1 = ANY(collection_ids)
+ RETURNING 1
+ )
+ SELECT COUNT(*) AS affected_rows FROM updated
+ """
+ await self.connection_manager.fetchrow_query(
+ document_update_query, [collection_id]
+ )
+
+ # Delete the collection
+ delete_query = f"""
+ DELETE FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ WHERE id = $1
+ RETURNING id
+ """
+ deleted = await self.connection_manager.fetchrow_query(
+ delete_query, [collection_id]
+ )
+
+ if not deleted:
+ raise R2RException(status_code=404, message="Collection not found")
+
+ async def documents_in_collection(
+ self, collection_id: UUID, offset: int, limit: int
+ ) -> dict[str, list[DocumentResponse] | int]:
+ """Get all documents in a specific collection with pagination.
+
+ Args:
+ collection_id (UUID): The ID of the collection to get documents from.
+ offset (int): The number of documents to skip.
+            limit (int): The maximum number of documents to return; pass -1 to return all remaining documents.
+        Returns:
+            dict: Dictionary with "results" (a list of DocumentResponse objects) and "total_entries" (the total number of documents in the collection).
+        Raises:
+ R2RException: If the collection doesn't exist.
+ """
+ if not await self.collection_exists(collection_id):
+ raise R2RException(status_code=404, message="Collection not found")
+ query = f"""
+ SELECT d.id, d.owner_id, d.type, d.metadata, d.title, d.version,
+ d.size_in_bytes, d.ingestion_status, d.extraction_status, d.created_at, d.updated_at, d.summary,
+ COUNT(*) OVER() AS total_entries
+ FROM {self._get_table_name("documents")} d
+ WHERE $1 = ANY(d.collection_ids)
+ ORDER BY d.created_at DESC
+ OFFSET $2
+ """
+
+ conditions = [collection_id, offset]
+ if limit != -1:
+ query += " LIMIT $3"
+ conditions.append(limit)
+
+ results = await self.connection_manager.fetch_query(query, conditions)
+ documents = [
+ DocumentResponse(
+ id=row["id"],
+ collection_ids=[collection_id],
+ owner_id=row["owner_id"],
+ document_type=DocumentType(row["type"]),
+ metadata=json.loads(row["metadata"]),
+ title=row["title"],
+ version=row["version"],
+ size_in_bytes=row["size_in_bytes"],
+ ingestion_status=IngestionStatus(row["ingestion_status"]),
+ extraction_status=GraphExtractionStatus(
+ row["extraction_status"]
+ ),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ summary=row["summary"],
+ )
+ for row in results
+ ]
+ total_entries = results[0]["total_entries"] if results else 0
+
+ return {"results": documents, "total_entries": total_entries}
+
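+    # Illustrative pagination sketch (hedged; `collections_handler` is an
+    # assumed, already-initialized PostgresCollectionsHandler):
+    #
+    #     page = await collections_handler.documents_in_collection(
+    #         collection_id=collection_id, offset=0, limit=50
+    #     )
+    #     docs, total = page["results"], page["total_entries"]
+    #     # Passing limit=-1 omits the LIMIT clause and returns every document
+    #     # after the offset.
+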
+ async def get_collections_overview(
+ self,
+ offset: int,
+ limit: int,
+ filter_user_ids: Optional[list[UUID]] = None,
+ filter_document_ids: Optional[list[UUID]] = None,
+ filter_collection_ids: Optional[list[UUID]] = None,
+ ) -> dict[str, list[CollectionResponse] | int]:
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ if filter_user_ids:
+ conditions.append(f"""
+ c.id IN (
+ SELECT unnest(collection_ids)
+ FROM {self.project_name}.users
+ WHERE id = ANY(${param_index})
+ )
+ """)
+ params.append(filter_user_ids)
+ param_index += 1
+
+ if filter_document_ids:
+ conditions.append(f"""
+ c.id IN (
+ SELECT unnest(collection_ids)
+ FROM {self.project_name}.documents
+ WHERE id = ANY(${param_index})
+ )
+ """)
+ params.append(filter_document_ids)
+ param_index += 1
+
+ if filter_collection_ids:
+ conditions.append(f"c.id = ANY(${param_index})")
+ params.append(filter_collection_ids)
+ param_index += 1
+
+ where_clause = (
+ f"WHERE {' AND '.join(conditions)}" if conditions else ""
+ )
+
+ query = f"""
+ SELECT
+ c.*,
+ COUNT(*) OVER() as total_entries
+ FROM {self.project_name}.collections c
+ {where_clause}
+ ORDER BY created_at DESC
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ query += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ try:
+ results = await self.connection_manager.fetch_query(query, params)
+
+ if not results:
+ return {"results": [], "total_entries": 0}
+
+ total_entries = results[0]["total_entries"] if results else 0
+
+ collections = [CollectionResponse(**row) for row in results]
+
+ return {"results": collections, "total_entries": total_entries}
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while fetching collections: {e}",
+ ) from e
+
+ async def assign_document_to_collection_relational(
+ self,
+ document_id: UUID,
+ collection_id: UUID,
+ ) -> UUID:
+ """Assign a document to a collection.
+
+ Args:
+ document_id (UUID): The ID of the document to assign.
+ collection_id (UUID): The ID of the collection to assign the document to.
+
+        Returns:
+            UUID: The ID of the collection the document was assigned to.
+
+        Raises:
+            R2RException: If the collection doesn't exist, if the document is not found,
+                or if the document is already assigned to the collection.
+ """
+ try:
+ if not await self.collection_exists(collection_id):
+ raise R2RException(
+ status_code=404, message="Collection not found"
+ )
+
+ # First, check if the document exists
+ document_check_query = f"""
+ SELECT 1 FROM {self._get_table_name("documents")}
+ WHERE id = $1
+ """
+ document_exists = await self.connection_manager.fetchrow_query(
+ document_check_query, [document_id]
+ )
+
+ if not document_exists:
+ raise R2RException(
+ status_code=404, message="Document not found"
+ )
+
+ # If document exists, proceed with the assignment
+ assign_query = f"""
+ UPDATE {self._get_table_name("documents")}
+ SET collection_ids = array_append(collection_ids, $1)
+ WHERE id = $2 AND NOT ($1 = ANY(collection_ids))
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ assign_query, [collection_id, document_id]
+ )
+
+ if not result:
+ # Document exists but was already assigned to the collection
+ raise R2RException(
+ status_code=409,
+ message="Document is already assigned to the collection",
+ )
+
+ update_collection_query = f"""
+ UPDATE {self._get_table_name("collections")}
+ SET document_count = document_count + 1
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(
+ query=update_collection_query, params=[collection_id]
+ )
+
+ return collection_id
+
+ except R2RException:
+ # Re-raise R2RExceptions as they are already handled
+ raise
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error '{e}' occurred while assigning the document to the collection",
+ ) from e
+
+ async def remove_document_from_collection_relational(
+ self, document_id: UUID, collection_id: UUID
+ ) -> None:
+ """Remove a document from a collection.
+
+ Args:
+ document_id (UUID): The ID of the document to remove.
+ collection_id (UUID): The ID of the collection to remove the document from.
+
+ Raises:
+ R2RException: If the collection doesn't exist or if the document is not in the collection.
+ """
+ if not await self.collection_exists(collection_id):
+ raise R2RException(status_code=404, message="Collection not found")
+
+ query = f"""
+ UPDATE {self._get_table_name("documents")}
+ SET collection_ids = array_remove(collection_ids, $1)
+ WHERE id = $2 AND $1 = ANY(collection_ids)
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [collection_id, document_id]
+ )
+
+ if not result:
+ raise R2RException(
+ status_code=404,
+ message="Document not found in the specified collection",
+ )
+
+ await self.decrement_collection_document_count(
+ collection_id=collection_id
+ )
+
+ async def decrement_collection_document_count(
+ self, collection_id: UUID, decrement_by: int = 1
+ ) -> None:
+ """Decrement the document count for a collection.
+
+ Args:
+ collection_id (UUID): The ID of the collection to update
+ decrement_by (int): Number to decrease the count by (default: 1)
+ """
+ collection_query = f"""
+ UPDATE {self._get_table_name("collections")}
+ SET document_count = document_count - $1
+ WHERE id = $2
+ """
+ await self.connection_manager.execute_query(
+ collection_query, [decrement_by, collection_id]
+ )
+
+ async def export_to_csv(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+        """Creates a CSV file from the PostgreSQL data and returns the path
+        and the open temp file handle."""
+ valid_columns = {
+ "id",
+ "owner_id",
+ "name",
+ "description",
+ "graph_sync_status",
+ "graph_cluster_status",
+ "created_at",
+ "updated_at",
+ "user_count",
+ "document_count",
+ }
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ owner_id::text,
+ name,
+ description,
+ graph_sync_status,
+ graph_cluster_status,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
+ user_count,
+ document_count
+ FROM {self._get_table_name(self.TABLE_NAME)}
+ """
+
+ params = []
+ if filters:
+ conditions = []
+ param_index = 1
+
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "owner_id": row[1],
+ "name": row[2],
+ "description": row[3],
+ "graph_sync_status": row[4],
+ "graph_cluster_status": row[5],
+ "created_at": row[6],
+ "updated_at": row[7],
+ "user_count": row[8],
+ "document_count": row[9],
+ }
+ writer.writerow([row_dict[col] for col in columns])
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
+
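+    # Illustrative export sketch (hedged; the filter operators mirror the
+    # $eq / $gt / $lt handling implemented above):
+    #
+    #     csv_path, csv_file = await collections_handler.export_to_csv(
+    #         columns=["id", "name", "document_count", "created_at"],
+    #         filters={"document_count": {"$gt": 0}},
+    #         include_header=True,
+    #     )
+    #     # The NamedTemporaryFile is created with delete=True, so keep
+    #     # csv_file open until the export has been consumed.
+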
+ async def get_collection_by_name(
+ self, owner_id: UUID, name: str
+ ) -> Optional[CollectionResponse]:
+ """Fetch a collection by owner_id + name combination.
+
+        Raises a 404 R2RException if no matching collection exists.
+ """
+ query = f"""
+ SELECT
+ id, owner_id, name, description, graph_sync_status,
+ graph_cluster_status, created_at, updated_at, user_count, document_count
+ FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ WHERE owner_id = $1 AND name = $2
+ LIMIT 1
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [owner_id, name]
+ )
+ if not result:
+ raise R2RException(
+ status_code=404,
+ message="No collection found with the specified name",
+ )
+ return CollectionResponse(
+ id=result["id"],
+ owner_id=result["owner_id"],
+ name=result["name"],
+ description=result["description"],
+ graph_sync_status=result["graph_sync_status"],
+ graph_cluster_status=result["graph_cluster_status"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ user_count=result["user_count"],
+ document_count=result["document_count"],
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/conversations.py b/.venv/lib/python3.12/site-packages/core/providers/database/conversations.py
new file mode 100644
index 00000000..2be2356c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/conversations.py
@@ -0,0 +1,858 @@
+import csv
+import json
+import logging
+import tempfile
+from datetime import datetime
+from typing import IO, Any, Optional
+from uuid import UUID, uuid4
+
+from fastapi import HTTPException
+
+from core.base import Handler, Message, R2RException
+from shared.api.models.management.responses import (
+ ConversationResponse,
+ MessageResponse,
+)
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger(__name__)
+
+
+def _validate_image_size(
+ message: Message, max_size_bytes: int = 5 * 1024 * 1024
+) -> None:
+ """
+ Validates that images in a message don't exceed the maximum allowed size.
+
+ Args:
+ message: Message object to validate
+ max_size_bytes: Maximum allowed size for base64-encoded images (default: 5MB)
+
+ Raises:
+ R2RException: If image is too large
+ """
+ if (
+ hasattr(message, "image_data")
+ and message.image_data
+ and "data" in message.image_data
+ ):
+ base64_data = message.image_data["data"]
+
+ # Calculate approximate decoded size (base64 increases size by ~33%)
+ # The formula is: decoded_size = encoded_size * 3/4
+ estimated_size_bytes = len(base64_data) * 0.75
+
+ if estimated_size_bytes > max_size_bytes:
+ raise R2RException(
+ status_code=413, # Payload Too Large
+ message=f"Image too large: {estimated_size_bytes / 1024 / 1024:.2f}MB exceeds the maximum allowed size of {max_size_bytes / 1024 / 1024:.2f}MB",
+ )
+
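+# Worked example of the estimate above (illustration only): a base64 payload of
+# 8_000_000 characters decodes to roughly 8_000_000 * 0.75 = 6_000_000 bytes
+# (~5.7 MB), exceeding the default 5 * 1024 * 1024 = 5_242_880-byte limit, so
+# _validate_image_size raises R2RException with status 413. A payload of
+# 6_000_000 characters (~4.3 MB decoded) passes.
+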
+
+def _json_default(obj: Any) -> str:
+ """Default handler for objects not serializable by the standard json
+ encoder."""
+ if isinstance(obj, datetime):
+ # Return ISO8601 string
+ return obj.isoformat()
+ elif isinstance(obj, UUID):
+ # Convert UUID to string
+ return str(obj)
+ # If you have other special types, handle them here...
+ # e.g. decimal.Decimal -> str(obj)
+
+ # If we get here, raise an error or just default to string:
+ raise TypeError(f"Type {type(obj)} not serializable")
+
+
+def safe_dumps(obj: Any) -> str:
+ """Wrap `json.dumps` with a default that serializes UUID and datetime."""
+ return json.dumps(obj, default=_json_default)
+
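+
+def _example_safe_dumps() -> str:
+    """Minimal sketch (not referenced by the handler below): shows that
+    safe_dumps serializes UUID and datetime values which plain json.dumps
+    would reject with a TypeError."""
+    sample = {"id": uuid4(), "created_at": datetime.now(), "role": "user"}
+    return safe_dumps(sample)
+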
+
+class PostgresConversationsHandler(Handler):
+ def __init__(
+ self, project_name: str, connection_manager: PostgresConnectionManager
+ ):
+ self.project_name = project_name
+ self.connection_manager = connection_manager
+
+ async def create_tables(self):
+ create_conversations_query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name("conversations")} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ user_id UUID,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ name TEXT
+ );
+ """
+
+ create_messages_query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name("messages")} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ conversation_id UUID NOT NULL,
+ parent_id UUID,
+ content JSONB,
+ metadata JSONB,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ FOREIGN KEY (conversation_id) REFERENCES {self._get_table_name("conversations")}(id),
+ FOREIGN KEY (parent_id) REFERENCES {self._get_table_name("messages")}(id)
+ );
+ """
+ await self.connection_manager.execute_query(create_conversations_query)
+ await self.connection_manager.execute_query(create_messages_query)
+
+ async def create_conversation(
+ self,
+ user_id: Optional[UUID] = None,
+ name: Optional[str] = None,
+ ) -> ConversationResponse:
+ query = f"""
+ INSERT INTO {self._get_table_name("conversations")} (user_id, name)
+ VALUES ($1, $2)
+ RETURNING id, extract(epoch from created_at) as created_at_epoch
+ """
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query, [user_id, name]
+ )
+
+ return ConversationResponse(
+ id=result["id"],
+ created_at=result["created_at_epoch"],
+ user_id=user_id or None,
+ name=name or None,
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to create conversation: {str(e)}",
+ ) from e
+
+ async def get_conversations_overview(
+ self,
+ offset: int,
+ limit: int,
+ filter_user_ids: Optional[list[UUID]] = None,
+ conversation_ids: Optional[list[UUID]] = None,
+ ) -> dict[str, Any]:
+ conditions = []
+ params: list = []
+ param_index = 1
+
+ if filter_user_ids:
+ conditions.append(f"""
+ c.user_id IN (
+ SELECT id
+ FROM {self.project_name}.users
+ WHERE id = ANY(${param_index})
+ )
+ """)
+ params.append(filter_user_ids)
+ param_index += 1
+
+ if conversation_ids:
+ conditions.append(f"c.id = ANY(${param_index})")
+ params.append(conversation_ids)
+ param_index += 1
+
+ where_clause = (
+ "WHERE " + " AND ".join(conditions) if conditions else ""
+ )
+
+ query = f"""
+ WITH conversation_overview AS (
+ SELECT c.id,
+ extract(epoch from c.created_at) as created_at_epoch,
+ c.user_id,
+ c.name
+ FROM {self._get_table_name("conversations")} c
+ {where_clause}
+ ),
+ counted_overview AS (
+ SELECT *,
+ COUNT(*) OVER() AS total_entries
+ FROM conversation_overview
+ )
+ SELECT * FROM counted_overview
+ ORDER BY created_at_epoch DESC
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ query += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ if not results:
+ return {"results": [], "total_entries": 0}
+
+ total_entries = results[0]["total_entries"]
+ conversations = [
+ {
+ "id": str(row["id"]),
+ "created_at": row["created_at_epoch"],
+ "user_id": str(row["user_id"]) if row["user_id"] else None,
+ "name": row["name"] or None,
+ }
+ for row in results
+ ]
+
+ return {"results": conversations, "total_entries": total_entries}
+
+ async def add_message(
+ self,
+ conversation_id: UUID,
+ content: Message,
+ parent_id: Optional[UUID] = None,
+ metadata: Optional[dict] = None,
+ max_image_size_bytes: int = 5 * 1024 * 1024, # 5MB default
+ ) -> MessageResponse:
+ # Validate image size
+ try:
+ _validate_image_size(content, max_image_size_bytes)
+ except R2RException:
+ # Re-raise validation exceptions
+ raise
+ except Exception as e:
+ # Handle unexpected errors during validation
+ logger.error(f"Error validating image: {str(e)}")
+ raise R2RException(
+ status_code=400, message=f"Invalid image data: {str(e)}"
+ ) from e
+
+        # 1) Validate that the conversation and the parent message exist
+ conv_check_query = f"""
+ SELECT 1 FROM {self._get_table_name("conversations")}
+ WHERE id = $1
+ """
+ conv_row = await self.connection_manager.fetchrow_query(
+ conv_check_query, [conversation_id]
+ )
+ if not conv_row:
+ raise R2RException(
+ status_code=404,
+ message=f"Conversation {conversation_id} not found.",
+ )
+
+ if parent_id:
+ parent_check_query = f"""
+ SELECT 1 FROM {self._get_table_name("messages")}
+ WHERE id = $1 AND conversation_id = $2
+ """
+ parent_row = await self.connection_manager.fetchrow_query(
+ parent_check_query, [parent_id, conversation_id]
+ )
+ if not parent_row:
+ raise R2RException(
+ status_code=404,
+ message=f"Parent message {parent_id} not found in conversation {conversation_id}.",
+ )
+
+ # 2) Add image info to metadata for tracking/analytics if images are present
+ metadata = metadata or {}
+ if hasattr(content, "image_url") and content.image_url:
+ metadata["has_image"] = True
+ metadata["image_type"] = "url"
+ elif hasattr(content, "image_data") and content.image_data:
+ metadata["has_image"] = True
+ metadata["image_type"] = "base64"
+ # Don't store the actual base64 data in metadata as it would be redundant
+
+ # 3) Convert the content & metadata to JSON strings
+ message_id = uuid4()
+ # Using safe_dumps to handle any type of serialization
+ content_str = safe_dumps(content.model_dump())
+ metadata_str = safe_dumps(metadata)
+
+        # 4) Insert the message
+ query = f"""
+ INSERT INTO {self._get_table_name("messages")}
+ (id, conversation_id, parent_id, content, created_at, metadata)
+ VALUES ($1, $2, $3, $4::jsonb, NOW(), $5::jsonb)
+ RETURNING id
+ """
+ inserted = await self.connection_manager.fetchrow_query(
+ query,
+ [
+ message_id,
+ conversation_id,
+ parent_id,
+ content_str,
+ metadata_str,
+ ],
+ )
+ if not inserted:
+ raise R2RException(
+ status_code=500, message="Failed to insert message."
+ )
+
+ return MessageResponse(id=message_id, message=content)
+
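+    # Illustrative usage sketch for add_message (hedged; `conversations_handler`
+    # and the Message fields shown are assumptions about the shared model):
+    #
+    #     msg = Message(role="user", content="What does chunk 3 say?")
+    #     response = await conversations_handler.add_message(
+    #         conversation_id=conversation_id,
+    #         content=msg,
+    #         parent_id=None,
+    #         metadata={"source": "api"},
+    #     )
+    #     # response.id is the new message's UUID; image payloads larger than
+    #     # max_image_size_bytes are rejected with HTTP 413 before the insert.
+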
+ async def edit_message(
+ self,
+ message_id: UUID,
+ new_content: str | None = None,
+ additional_metadata: dict | None = None,
+ ) -> dict[str, Any]:
+ # Get the original message
+ query = f"""
+ SELECT conversation_id, parent_id, content, metadata, created_at
+ FROM {self._get_table_name("messages")}
+ WHERE id = $1
+ """
+ row = await self.connection_manager.fetchrow_query(query, [message_id])
+ if not row:
+ raise R2RException(
+ status_code=404,
+ message=f"Message {message_id} not found.",
+ )
+
+ old_content = json.loads(row["content"])
+ old_metadata = json.loads(row["metadata"])
+
+ if new_content is not None:
+ old_message = Message(**old_content)
+ edited_message = Message(
+ role=old_message.role,
+ content=new_content,
+ name=old_message.name,
+ function_call=old_message.function_call,
+ tool_calls=old_message.tool_calls,
+ # Preserve image content if it exists
+ image_url=getattr(old_message, "image_url", None),
+ image_data=getattr(old_message, "image_data", None),
+ )
+ content_to_save = edited_message.model_dump()
+ else:
+ content_to_save = old_content
+
+ additional_metadata = additional_metadata or {}
+
+ new_metadata = {
+ **old_metadata,
+ **additional_metadata,
+ "edited": (
+ True
+ if new_content is not None
+ else old_metadata.get("edited", False)
+ ),
+ }
+
+ # Update message without changing the timestamp
+ update_query = f"""
+ UPDATE {self._get_table_name("messages")}
+ SET content = $1::jsonb,
+ metadata = $2::jsonb,
+ created_at = $3
+ WHERE id = $4
+ RETURNING id
+ """
+ updated = await self.connection_manager.fetchrow_query(
+ update_query,
+ [
+ json.dumps(content_to_save),
+ json.dumps(new_metadata),
+ row["created_at"],
+ message_id,
+ ],
+ )
+ if not updated:
+ raise R2RException(
+ status_code=500, message="Failed to update message."
+ )
+
+ return {
+ "id": str(message_id),
+ "message": (
+ Message(**content_to_save)
+ if isinstance(content_to_save, dict)
+ else content_to_save
+ ),
+ "metadata": new_metadata,
+ }
+
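+    # Illustrative sketch of the edit semantics (hedged example):
+    #
+    #     edited = await conversations_handler.edit_message(
+    #         message_id=message_id,
+    #         new_content="Corrected wording",
+    #         additional_metadata={"reviewed": True},
+    #     )
+    #     # The stored created_at is preserved, the old metadata is merged with
+    #     # additional_metadata, and "edited" becomes True because new content
+    #     # was supplied.
+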
+ async def update_message_metadata(
+ self, message_id: UUID, metadata: dict
+ ) -> None:
+ # Fetch current metadata
+ query = f"""
+ SELECT metadata FROM {self._get_table_name("messages")}
+ WHERE id = $1
+ """
+ row = await self.connection_manager.fetchrow_query(query, [message_id])
+ if not row:
+ raise R2RException(
+ status_code=404, message=f"Message {message_id} not found."
+ )
+
+ current_metadata = json.loads(row["metadata"]) or {}
+ updated_metadata = {**current_metadata, **metadata}
+
+ update_query = f"""
+ UPDATE {self._get_table_name("messages")}
+ SET metadata = $1::jsonb
+ WHERE id = $2
+ """
+ await self.connection_manager.execute_query(
+ update_query, [json.dumps(updated_metadata), message_id]
+ )
+
+ async def get_conversation(
+ self,
+ conversation_id: UUID,
+ filter_user_ids: Optional[list[UUID]] = None,
+ ) -> list[MessageResponse]:
+        # Validate that the conversation exists (optionally scoped to filter_user_ids)
+ conditions = ["c.id = $1"]
+ params: list = [conversation_id]
+
+ if filter_user_ids:
+ param_index = 2
+ conditions.append(f"""
+ c.user_id IN (
+ SELECT id
+ FROM {self.project_name}.users
+ WHERE id = ANY(${param_index})
+ )
+ """)
+ params.append(filter_user_ids)
+
+ query = f"""
+ SELECT c.id, extract(epoch from c.created_at) AS created_at_epoch
+ FROM {self._get_table_name("conversations")} c
+ WHERE {" AND ".join(conditions)}
+ """
+
+ conv_row = await self.connection_manager.fetchrow_query(query, params)
+ if not conv_row:
+ raise R2RException(
+ status_code=404,
+ message=f"Conversation {conversation_id} not found.",
+ )
+
+ # Retrieve messages in chronological order
+ msg_query = f"""
+ SELECT id, content, metadata
+ FROM {self._get_table_name("messages")}
+ WHERE conversation_id = $1
+ ORDER BY created_at ASC
+ """
+ results = await self.connection_manager.fetch_query(
+ msg_query, [conversation_id]
+ )
+
+ response_messages = []
+ for row in results:
+ try:
+ # Parse the message content
+ content_json = json.loads(row["content"])
+ # Create a Message object with the parsed content
+ message = Message(**content_json)
+ # Create a MessageResponse
+ response_messages.append(
+ MessageResponse(
+ id=row["id"],
+ message=message,
+ metadata=json.loads(row["metadata"]),
+ )
+ )
+            except Exception as e:
+                # If there's an error parsing the message (e.g., due to version mismatch),
+                # log it and create a fallback message. Re-parse defensively in
+                # case json.loads itself failed and content_json is unbound.
+                logger.warning(f"Error parsing message {row['id']}: {str(e)}")
+                try:
+                    content_json = json.loads(row["content"])
+                except (json.JSONDecodeError, TypeError):
+                    content_json = {}
+                fallback_content = content_json.get(
+                    "content", "Message could not be loaded"
+                )
+                fallback_role = content_json.get("role", "assistant")
+
+ # Create a basic fallback message
+ fallback_message = Message(
+ role=fallback_role,
+ content=f"[Message format incompatible: {fallback_content}]",
+ )
+
+ response_messages.append(
+ MessageResponse(
+ id=row["id"],
+ message=fallback_message,
+ metadata=json.loads(row["metadata"]),
+ )
+ )
+
+ return response_messages
+
+ async def update_conversation(
+ self, conversation_id: UUID, name: str
+ ) -> ConversationResponse:
+ try:
+ # Check if conversation exists
+ conv_query = f"SELECT 1 FROM {self._get_table_name('conversations')} WHERE id = $1"
+ conv_row = await self.connection_manager.fetchrow_query(
+ conv_query, [conversation_id]
+ )
+ if not conv_row:
+ raise R2RException(
+ status_code=404,
+ message=f"Conversation {conversation_id} not found.",
+ )
+
+ update_query = f"""
+ UPDATE {self._get_table_name("conversations")}
+ SET name = $1 WHERE id = $2
+ RETURNING user_id, extract(epoch from created_at) as created_at_epoch
+ """
+ updated_row = await self.connection_manager.fetchrow_query(
+ update_query, [name, conversation_id]
+ )
+ return ConversationResponse(
+ id=conversation_id,
+ created_at=updated_row["created_at_epoch"],
+ user_id=updated_row["user_id"] or None,
+ name=name,
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to update conversation: {str(e)}",
+ ) from e
+
+ async def delete_conversation(
+ self,
+ conversation_id: UUID,
+ filter_user_ids: Optional[list[UUID]] = None,
+ ) -> None:
+ conditions = ["c.id = $1"]
+ params: list = [conversation_id]
+
+ if filter_user_ids:
+ param_index = 2
+ conditions.append(f"""
+ c.user_id IN (
+ SELECT id
+ FROM {self.project_name}.users
+ WHERE id = ANY(${param_index})
+ )
+ """)
+ params.append(filter_user_ids)
+
+ conv_query = f"""
+ SELECT 1
+ FROM {self._get_table_name("conversations")} c
+ WHERE {" AND ".join(conditions)}
+ """
+ conv_row = await self.connection_manager.fetchrow_query(
+ conv_query, params
+ )
+ if not conv_row:
+ raise R2RException(
+ status_code=404,
+ message=f"Conversation {conversation_id} not found.",
+ )
+
+ # Delete all messages
+ del_messages_query = f"DELETE FROM {self._get_table_name('messages')} WHERE conversation_id = $1"
+ await self.connection_manager.execute_query(
+ del_messages_query, [conversation_id]
+ )
+
+ # Delete conversation
+ del_conv_query = f"DELETE FROM {self._get_table_name('conversations')} WHERE id = $1"
+ await self.connection_manager.execute_query(
+ del_conv_query, [conversation_id]
+ )
+
+ async def export_conversations_to_csv(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+        """Creates a CSV file from the PostgreSQL data and returns the path
+        and the open temp file handle."""
+ valid_columns = {
+ "id",
+ "user_id",
+ "created_at",
+ "name",
+ }
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ user_id::text,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ name
+ FROM {self._get_table_name("conversations")}
+ """
+
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "user_id": row[1],
+ "created_at": row[2],
+ "name": row[3],
+ }
+ writer.writerow([row_dict[col] for col in columns])
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
+
+ async def export_messages_to_csv(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ handle_images: str = "metadata_only", # Options: "full", "metadata_only", "exclude"
+ ) -> tuple[str, IO]:
+ """
+        Creates a CSV file from the PostgreSQL data and returns the path and the open temp file handle.
+
+ Args:
+ columns: List of columns to include in export
+ filters: Filter criteria for messages
+ include_header: Whether to include header row
+ handle_images: How to handle image data in exports:
+ - "full": Include complete image data (warning: may create large files)
+ - "metadata_only": Replace image data with metadata only
+ - "exclude": Remove image data completely
+ """
+ valid_columns = {
+ "id",
+ "conversation_id",
+ "parent_id",
+ "content",
+ "metadata",
+ "created_at",
+ "has_image", # New virtual column to indicate image presence
+ }
+
+ if not columns:
+ columns = list(valid_columns - {"has_image"})
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ # Add virtual column for image presence
+ virtual_columns = []
+ has_image_column = False
+
+ if "has_image" in columns:
+ virtual_columns.append(
+ "(content->>'image_url' IS NOT NULL OR content->>'image_data' IS NOT NULL) as has_image"
+ )
+ columns.remove("has_image")
+ has_image_column = True
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ conversation_id::text,
+ parent_id::text,
+ content::text,
+ metadata::text,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at
+ {", " + ", ".join(virtual_columns) if virtual_columns else ""}
+ FROM {self._get_table_name("messages")}
+ """
+
+        # Build the filter conditions
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns or field == "has_image":
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ # Special filter for has_image
+ if filters and "has_image" in filters:
+ if filters["has_image"]:
+ conditions.append(
+ "(content->>'image_url' IS NOT NULL OR content->>'image_data' IS NOT NULL)"
+ )
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ # Prepare export columns
+ export_columns = list(columns)
+ if has_image_column:
+ export_columns.append("has_image")
+
+ if include_header:
+ writer.writerow(export_columns)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "conversation_id": row[1],
+ "parent_id": row[2],
+ "content": row[3],
+ "metadata": row[4],
+ "created_at": row[5],
+ }
+
+ # Add virtual column if present
+ if has_image_column:
+ row_dict["has_image"] = (
+ "true" if row[6] else "false"
+ )
+
+ # Process image data based on handle_images setting
+ if (
+ "content" in columns
+ and handle_images != "full"
+ ):
+ try:
+ content_json = json.loads(
+ row_dict["content"]
+ )
+
+ if (
+ "image_data" in content_json
+ and content_json["image_data"]
+ ):
+ media_type = content_json[
+ "image_data"
+ ].get("media_type", "image/jpeg")
+
+ if handle_images == "metadata_only":
+ content_json["image_data"] = {
+ "media_type": media_type,
+ "data": "[BASE64_DATA_EXCLUDED_FROM_EXPORT]",
+ }
+ elif handle_images == "exclude":
+ content_json.pop(
+ "image_data", None
+ )
+
+ row_dict["content"] = json.dumps(
+ content_json
+ )
+ except (json.JSONDecodeError, TypeError) as e:
+ logger.warning(
+ f"Error processing message content for export: {e}"
+ )
+
+ writer.writerow(
+ [row_dict[col] for col in export_columns]
+ )
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
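+
+    # Illustrative export sketch (hedged): the handle_images modes trade file
+    # size against fidelity.
+    #
+    #     path, handle = await conversations_handler.export_messages_to_csv(
+    #         columns=["id", "conversation_id", "content", "has_image"],
+    #         filters={"has_image": True},
+    #         handle_images="metadata_only",
+    #     )
+    #     # "full" keeps base64 image data verbatim, "metadata_only" replaces it
+    #     # with a media_type stub, and "exclude" drops image_data entirely.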
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/documents.py b/.venv/lib/python3.12/site-packages/core/providers/database/documents.py
new file mode 100644
index 00000000..19781037
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/documents.py
@@ -0,0 +1,1172 @@
+import asyncio
+import copy
+import csv
+import json
+import logging
+import math
+import tempfile
+from typing import IO, Any, Optional
+from uuid import UUID
+
+import asyncpg
+from fastapi import HTTPException
+
+from core.base import (
+ DocumentResponse,
+ DocumentType,
+ GraphConstructionStatus,
+ GraphExtractionStatus,
+ Handler,
+ IngestionStatus,
+ R2RException,
+ SearchSettings,
+)
+
+from .base import PostgresConnectionManager
+from .filters import apply_filters
+
+logger = logging.getLogger()
+
+
+def transform_filter_fields(filters: dict[str, Any]) -> dict[str, Any]:
+ """Recursively transform filter field names by replacing 'document_id' with
+ 'id'. Handles nested logical operators like $and, $or, etc.
+
+ Args:
+ filters (dict[str, Any]): The original filters dictionary
+
+ Returns:
+ dict[str, Any]: A new dictionary with transformed field names
+ """
+ if not filters:
+ return {}
+
+ transformed = {}
+
+ for key, value in filters.items():
+ # Handle logical operators recursively
+ if key in ("$and", "$or", "$not"):
+ if isinstance(value, list):
+ transformed[key] = [
+ transform_filter_fields(item) for item in value
+ ]
+ else:
+ transformed[key] = transform_filter_fields(value) # type: ignore
+ continue
+
+ # Replace 'document_id' with 'id'
+ new_key = "id" if key == "document_id" else key
+
+ # Handle nested dictionary cases (e.g., for operators like $eq, $gt, etc.)
+ if isinstance(value, dict):
+ transformed[new_key] = transform_filter_fields(value) # type: ignore
+ else:
+ transformed[new_key] = value
+
+ logger.debug(f"Transformed filters from {filters} to {transformed}")
+ return transformed
+
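+
+def _example_transform_filter_fields() -> dict[str, Any]:
+    """Minimal sketch (not used by the handler below): document_id keys are
+    rewritten to id, including inside nested logical operators."""
+    filters = {
+        "$and": [
+            {"document_id": {"$eq": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1"}},
+            {"metadata": {"$eq": {"source": "upload"}}},
+        ]
+    }
+    # Returns {"$and": [{"id": {"$eq": "9fbe..."}},
+    #                   {"metadata": {"$eq": {"source": "upload"}}}]}
+    return transform_filter_fields(filters)
+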
+
+class PostgresDocumentsHandler(Handler):
+ TABLE_NAME = "documents"
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ dimension: int | float,
+ ):
+ self.dimension = dimension
+ super().__init__(project_name, connection_manager)
+
+ async def create_tables(self):
+ logger.info(
+ f"Creating table, if it does not exist: {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
+ )
+
+ vector_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+ vector_type = f"vector{vector_dim}"
+
+ try:
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY,
+ collection_ids UUID[],
+ owner_id UUID,
+ type TEXT,
+ metadata JSONB,
+ title TEXT,
+ summary TEXT NULL,
+ summary_embedding {vector_type} NULL,
+ version TEXT,
+ size_in_bytes INT,
+ ingestion_status TEXT DEFAULT 'pending',
+ extraction_status TEXT DEFAULT 'pending',
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ ingestion_attempt_number INT DEFAULT 0,
+ raw_tsvector tsvector GENERATED ALWAYS AS (
+ setweight(to_tsvector('english', COALESCE(title, '')), 'A') ||
+ setweight(to_tsvector('english', COALESCE(summary, '')), 'B') ||
+ setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C')
+ ) STORED,
+ total_tokens INT DEFAULT 0
+ );
+ CREATE INDEX IF NOT EXISTS idx_collection_ids_{self.project_name}
+ ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} USING GIN (collection_ids);
+
+ -- Full text search index
+ CREATE INDEX IF NOT EXISTS idx_doc_search_{self.project_name}
+ ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ USING GIN (raw_tsvector);
+ """
+ await self.connection_manager.execute_query(query)
+
+ # ---------------------------------------------------------------
+ # Now check if total_tokens column exists in the 'documents' table
+ # ---------------------------------------------------------------
+            # 1) Parse the qualified table name into schema and table for the column check
+ table_full_name = self._get_table_name(
+ PostgresDocumentsHandler.TABLE_NAME
+ )
+ parsed_schema = "public"
+ parsed_table_name = table_full_name
+ if "." in table_full_name:
+ parts = table_full_name.split(".", maxsplit=1)
+ parsed_schema = parts[0].replace('"', "").strip()
+ parsed_table_name = parts[1].replace('"', "").strip()
+ else:
+ parsed_table_name = parsed_table_name.replace('"', "").strip()
+
+            # 2) Check which columns exist
+ column_check_query = f"""
+ SELECT column_name
+ FROM information_schema.columns
+ WHERE table_name = '{parsed_table_name}'
+ AND table_schema = '{parsed_schema}'
+ """
+ existing_columns = await self.connection_manager.fetch_query(
+ column_check_query
+ )
+
+ existing_column_names = {
+ row["column_name"] for row in existing_columns
+ }
+
+ if "total_tokens" not in existing_column_names:
+                # 3) The column is missing; check whether the table already has data
+ doc_count_query = f"SELECT COUNT(*) AS doc_count FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
+ row = await self.connection_manager.fetchrow_query(
+ doc_count_query
+ )
+ if row is None:
+ doc_count = 0
+ else:
+ doc_count = row[
+ "doc_count"
+ ] # or row[0] if you prefer positional indexing
+
+ if doc_count > 0:
+                    # We already have documents but no total_tokens column;
+                    # warn before altering the table, since existing rows will
+                    # default to 0 tokens.
+                    logger.warning(
+                        "Adding the missing 'total_tokens' column to the 'documents' table; existing documents will default to 0 tokens."
+ )
+
+ create_tokens_col = f"""
+ ALTER TABLE {table_full_name}
+ ADD COLUMN total_tokens INT DEFAULT 0
+ """
+ await self.connection_manager.execute_query(create_tokens_col)
+
+ except Exception as e:
+ logger.warning(f"Error {e} when creating document table.")
+ raise e
+
+ async def upsert_documents_overview(
+ self, documents_overview: DocumentResponse | list[DocumentResponse]
+ ) -> None:
+ if isinstance(documents_overview, DocumentResponse):
+ documents_overview = [documents_overview]
+
+ # TODO: make this an arg
+ max_retries = 20
+ for document in documents_overview:
+ retries = 0
+ while retries < max_retries:
+ try:
+ async with (
+ self.connection_manager.pool.get_connection() as conn # type: ignore
+ ):
+ async with conn.transaction():
+ # Lock the row for update
+ check_query = f"""
+ SELECT ingestion_attempt_number, ingestion_status FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ WHERE id = $1 FOR UPDATE
+ """
+ existing_doc = await conn.fetchrow(
+ check_query, document.id
+ )
+
+ db_entry = document.convert_to_db_entry()
+
+ if existing_doc:
+ db_version = existing_doc[
+ "ingestion_attempt_number"
+ ]
+ db_status = existing_doc["ingestion_status"]
+ new_version = db_entry[
+ "ingestion_attempt_number"
+ ]
+
+ # Only increment version if status is changing to 'success' or if it's a new version
+ if (
+ db_status != "success"
+ and db_entry["ingestion_status"]
+ == "success"
+ ) or (new_version > db_version):
+ new_attempt_number = db_version + 1
+ else:
+ new_attempt_number = db_version
+
+ db_entry["ingestion_attempt_number"] = (
+ new_attempt_number
+ )
+
+ update_query = f"""
+ UPDATE {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ SET collection_ids = $1,
+ owner_id = $2,
+ type = $3,
+ metadata = $4,
+ title = $5,
+ version = $6,
+ size_in_bytes = $7,
+ ingestion_status = $8,
+ extraction_status = $9,
+ updated_at = $10,
+ ingestion_attempt_number = $11,
+ summary = $12,
+ summary_embedding = $13,
+ total_tokens = $14
+ WHERE id = $15
+ """
+
+ await conn.execute(
+ update_query,
+ db_entry["collection_ids"],
+ db_entry["owner_id"],
+ db_entry["document_type"],
+ db_entry["metadata"],
+ db_entry["title"],
+ db_entry["version"],
+ db_entry["size_in_bytes"],
+ db_entry["ingestion_status"],
+ db_entry["extraction_status"],
+ db_entry["updated_at"],
+ db_entry["ingestion_attempt_number"],
+ db_entry["summary"],
+ db_entry["summary_embedding"],
+ db_entry[
+ "total_tokens"
+ ], # pass the new field here
+ document.id,
+ )
+ else:
+ insert_query = f"""
+ INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ (id, collection_ids, owner_id, type, metadata, title, version,
+ size_in_bytes, ingestion_status, extraction_status, created_at,
+ updated_at, ingestion_attempt_number, summary, summary_embedding, total_tokens)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
+ """
+ await conn.execute(
+ insert_query,
+ db_entry["id"],
+ db_entry["collection_ids"],
+ db_entry["owner_id"],
+ db_entry["document_type"],
+ db_entry["metadata"],
+ db_entry["title"],
+ db_entry["version"],
+ db_entry["size_in_bytes"],
+ db_entry["ingestion_status"],
+ db_entry["extraction_status"],
+ db_entry["created_at"],
+ db_entry["updated_at"],
+ db_entry["ingestion_attempt_number"],
+ db_entry["summary"],
+ db_entry["summary_embedding"],
+ db_entry["total_tokens"],
+ )
+
+ break # Success, exit the retry loop
+ except (
+ asyncpg.exceptions.UniqueViolationError,
+ asyncpg.exceptions.DeadlockDetectedError,
+ ) as e:
+ retries += 1
+ if retries == max_retries:
+ logger.error(
+ f"Failed to update document {document.id} after {max_retries} attempts. Error: {str(e)}"
+ )
+ raise
+ else:
+ wait_time = 0.1 * (2**retries) # Exponential backoff
+ await asyncio.sleep(wait_time)
+
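+    # Worked example of the retry backoff above (illustration only): after each
+    # UniqueViolation or deadlock the handler sleeps 0.1 * 2**retries seconds,
+    # i.e. 0.2 s, 0.4 s, 0.8 s, ... doubling per attempt. With max_retries=20
+    # the final sleep (retries=19) is about 14.6 hours, and the cumulative wait
+    # approaches 0.1 * (2**20 - 2) ≈ 29 hours before the exception is re-raised.
+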
+ async def delete(
+ self, document_id: UUID, version: Optional[str] = None
+ ) -> None:
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ WHERE id = $1
+ """
+
+ params = [str(document_id)]
+
+ if version:
+ query += " AND version = $2"
+ params.append(version)
+
+ await self.connection_manager.execute_query(query=query, params=params)
+
+ async def _get_status_from_table(
+ self,
+ ids: list[UUID],
+ table_name: str,
+ status_type: str,
+ column_name: str,
+ ):
+ """Get the workflow status for a given document or list of documents.
+
+ Args:
+ ids (list[UUID]): The document IDs.
+ table_name (str): The table name.
+ status_type (str): The type of status to retrieve.
+ column_name (str): The column to match the document IDs against.
+
+ Returns:
+ The workflow status for the given document or list of documents.
+ """
+ query = f"""
+ SELECT {status_type} FROM {self._get_table_name(table_name)}
+ WHERE {column_name} = ANY($1)
+ """
+ return [
+ row[status_type]
+ for row in await self.connection_manager.fetch_query(query, [ids])
+ ]
+
+ async def _get_ids_from_table(
+ self,
+ status: list[str],
+ table_name: str,
+ status_type: str,
+ collection_id: Optional[UUID] = None,
+ ):
+ """Get the IDs from a given table.
+
+ Args:
+ status (str | list[str]): The status or list of statuses to retrieve.
+ table_name (str): The table name.
+ status_type (str): The type of status to retrieve.
+ """
+ query = f"""
+ SELECT id FROM {self._get_table_name(table_name)}
+ WHERE {status_type} = ANY($1) and $2 = ANY(collection_ids)
+ """
+ records = await self.connection_manager.fetch_query(
+ query, [status, collection_id]
+ )
+ return [record["id"] for record in records]
+
+ async def _set_status_in_table(
+ self,
+ ids: list[UUID],
+ status: str,
+ table_name: str,
+ status_type: str,
+ column_name: str,
+ ):
+ """Set the workflow status for a given document or list of documents.
+
+ Args:
+ ids (list[UUID]): The document IDs.
+ status (str): The status to set.
+ table_name (str): The table name.
+ status_type (str): The type of status to set.
+ column_name (str): The column name in the table to update.
+ """
+ query = f"""
+ UPDATE {self._get_table_name(table_name)}
+ SET {status_type} = $1
+ WHERE {column_name} = ANY($2)
+ """
+ await self.connection_manager.execute_query(query, [status, ids])
+
+ def _get_status_model(self, status_type: str):
+ """Get the status model for a given status type.
+
+ Args:
+ status_type (str): The type of status to retrieve.
+
+ Returns:
+ The status model for the given status type.
+ """
+ if status_type == "ingestion":
+ return IngestionStatus
+ elif status_type == "extraction_status":
+ return GraphExtractionStatus
+ elif status_type in {"graph_cluster_status", "graph_sync_status"}:
+ return GraphConstructionStatus
+ else:
+ raise R2RException(
+ status_code=400, message=f"Invalid status type: {status_type}"
+ )
+
+ async def get_workflow_status(
+ self, id: UUID | list[UUID], status_type: str
+ ):
+ """Get the workflow status for a given document or list of documents.
+
+ Args:
+ id (UUID | list[UUID]): The document ID or list of document IDs.
+ status_type (str): The type of status to retrieve.
+
+ Returns:
+ The workflow status for the given document or list of documents.
+ """
+
+ ids = [id] if isinstance(id, UUID) else id
+ out_model = self._get_status_model(status_type)
+ result = await self._get_status_from_table(
+ ids,
+ out_model.table_name(),
+ status_type,
+ out_model.id_column(),
+ )
+
+ result = [out_model[status.upper()] for status in result]
+ return result[0] if isinstance(id, UUID) else result
+
+ async def set_workflow_status(
+ self, id: UUID | list[UUID], status_type: str, status: str
+ ):
+ """Set the workflow status for a given document or list of documents.
+
+ Args:
+ id (UUID | list[UUID]): The document ID or list of document IDs.
+ status_type (str): The type of status to set.
+ status (str): The status to set.
+ """
+ ids = [id] if isinstance(id, UUID) else id
+ out_model = self._get_status_model(status_type)
+
+ return await self._set_status_in_table(
+ ids,
+ status,
+ out_model.table_name(),
+ status_type,
+ out_model.id_column(),
+ )
+
+ async def get_document_ids_by_status(
+ self,
+ status_type: str,
+ status: str | list[str],
+ collection_id: Optional[UUID] = None,
+ ):
+ """Get the IDs for a given status.
+
+ Args:
+ ids_key (str): The key to retrieve the IDs.
+ status_type (str): The type of status to retrieve.
+ status (str | list[str]): The status or list of statuses to retrieve.
+ """
+
+ if isinstance(status, str):
+ status = [status]
+
+ out_model = self._get_status_model(status_type)
+ return await self._get_ids_from_table(
+ status, out_model.table_name(), status_type, collection_id
+ )
+
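# Editor's illustrative sketch (not part of the diff): combining the status
# helpers above to requeue documents whose graph extraction failed. `handler`
# is assumed to be an initialised PostgresDocumentsHandler; the status strings
# are examples and the real values come from the GraphExtractionStatus enum.
async def requeue_failed_extractions(handler, collection_id):
    failed_ids = await handler.get_document_ids_by_status(
        status_type="extraction_status",
        status=["failed"],
        collection_id=collection_id,
    )
    if failed_ids:
        await handler.set_workflow_status(
            id=failed_ids, status_type="extraction_status", status="pending"
        )
    return failed_ids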
+ async def get_documents_overview(
+ self,
+ offset: int,
+ limit: int,
+ filter_user_ids: Optional[list[UUID]] = None,
+ filter_document_ids: Optional[list[UUID]] = None,
+ filter_collection_ids: Optional[list[UUID]] = None,
+ include_summary_embedding: Optional[bool] = True,
+ filters: Optional[dict[str, Any]] = None,
+ sort_order: str = "DESC", # Add this parameter with a default of DESC
+ ) -> dict[str, Any]:
+ """Fetch overviews of documents with optional offset/limit pagination.
+
+ You can use either:
+ - Traditional filters: `filter_user_ids`, `filter_document_ids`, `filter_collection_ids`
+ - A `filters` dict (e.g., like we do in semantic search), which will be passed to `apply_filters`.
+
+ If both the `filters` dict and any of the traditional filter arguments are provided,
+ this method will raise an error.
+ """
+
+ filters = copy.deepcopy(filters)
+ filters = transform_filter_fields(filters) # type: ignore
+
+ # Safety check: We do not allow mixing the old filter arguments with the new `filters` dict.
+ # This keeps the query logic unambiguous.
+ if filters and any(
+ [
+ filter_user_ids,
+ filter_document_ids,
+ filter_collection_ids,
+ ]
+ ):
+ raise HTTPException(
+ status_code=400,
+ detail=(
+ "Cannot use both the 'filters' dictionary "
+ "and the 'filter_*_ids' parameters simultaneously."
+ ),
+ )
+
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ # -------------------------------------------
+ # 1) If using the new `filters` dict approach
+ # -------------------------------------------
+ if filters:
+ # Apply the filters to generate a WHERE clause
+ filter_condition, filter_params = apply_filters(
+ filters, params, mode="condition_only"
+ )
+ if filter_condition:
+ conditions.append(filter_condition)
+ # apply_filters mutates `params` in place, so only the
+ # placeholder index needs to advance here.
+ param_index += len(filter_params)
+
+ # -------------------------------------------
+ # 2) If using the old filter_*_ids approach
+ # -------------------------------------------
+ else:
+ # Handle document IDs with AND
+ if filter_document_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(filter_document_ids)
+ param_index += 1
+
+ # For owner/collection filters, we used OR logic previously
+ # so we combine them into a single sub-condition in parentheses
+ or_conditions = []
+ if filter_user_ids:
+ or_conditions.append(f"owner_id = ANY(${param_index})")
+ params.append(filter_user_ids)
+ param_index += 1
+
+ if filter_collection_ids:
+ or_conditions.append(f"collection_ids && ${param_index}")
+ params.append(filter_collection_ids)
+ param_index += 1
+
+ if or_conditions:
+ conditions.append(f"({' OR '.join(or_conditions)})")
+
+ # -------------------------
+ # Build the full query
+ # -------------------------
+ base_query = (
+ f"FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
+ )
+ if conditions:
+ # Combine everything with AND
+ base_query += " WHERE " + " AND ".join(conditions)
+
+ # Construct SELECT fields (including total_entries via window function)
+ select_fields = """
+ SELECT
+ id,
+ collection_ids,
+ owner_id,
+ type,
+ metadata,
+ title,
+ version,
+ size_in_bytes,
+ ingestion_status,
+ extraction_status,
+ created_at,
+ updated_at,
+ summary,
+ summary_embedding,
+ total_tokens,
+ COUNT(*) OVER() AS total_entries
+ """
+
+ query = f"""
+ {select_fields}
+ {base_query}
+ ORDER BY created_at {sort_order}
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ query += f" LIMIT ${param_index}"
+ params.append(limit)
+ param_index += 1
+
+ try:
+ results = await self.connection_manager.fetch_query(query, params)
+ total_entries = results[0]["total_entries"] if results else 0
+
+ documents = []
+ for row in results:
+ # Safely handle the embedding
+ embedding = None
+ if (
+ "summary_embedding" in row
+ and row["summary_embedding"] is not None
+ ):
+ try:
+ # The embedding is stored as a string like "[0.1, 0.2, ...]"
+ embedding_str = row["summary_embedding"]
+ if embedding_str.startswith(
+ "["
+ ) and embedding_str.endswith("]"):
+ embedding = [
+ float(x)
+ for x in embedding_str[1:-1].split(",")
+ if x
+ ]
+ except Exception as e:
+ logger.warning(
+ f"Failed to parse embedding for document {row['id']}: {e}"
+ )
+
+ documents.append(
+ DocumentResponse(
+ id=row["id"],
+ collection_ids=row["collection_ids"],
+ owner_id=row["owner_id"],
+ document_type=DocumentType(row["type"]),
+ metadata=json.loads(row["metadata"]),
+ title=row["title"],
+ version=row["version"],
+ size_in_bytes=row["size_in_bytes"],
+ ingestion_status=IngestionStatus(
+ row["ingestion_status"]
+ ),
+ extraction_status=GraphExtractionStatus(
+ row["extraction_status"]
+ ),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ summary=row["summary"] if "summary" in row else None,
+ summary_embedding=(
+ embedding if include_summary_embedding else None
+ ),
+ total_tokens=row["total_tokens"],
+ )
+ )
+ return {"results": documents, "total_entries": total_entries}
+ except Exception as e:
+ logger.error(f"Error in get_documents_overview: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail="Database query failed",
+ ) from e
+
+ async def update_document_metadata(
+ self,
+ document_id: UUID,
+ metadata: list[dict],
+ overwrite: bool = False,
+ ) -> DocumentResponse:
+ """
+ Update the metadata of a document, either by appending to existing metadata or overwriting it.
+ Accepts a list of metadata dictionaries.
+ """
+
+ doc_result = await self.get_documents_overview(
+ offset=0,
+ limit=1,
+ filter_document_ids=[document_id],
+ )
+
+ if not doc_result["results"]:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Document with ID {document_id} not found",
+ )
+
+ existing_doc = doc_result["results"][0]
+
+ if overwrite:
+ combined_metadata: dict[str, Any] = {}
+ for meta_item in metadata:
+ combined_metadata |= meta_item
+ existing_doc.metadata = combined_metadata
+ else:
+ for meta_item in metadata:
+ existing_doc.metadata.update(meta_item)
+
+ await self.upsert_documents_overview(existing_doc)
+
+ return existing_doc
+
+ async def semantic_document_search(
+ self, query_embedding: list[float], search_settings: SearchSettings
+ ) -> list[DocumentResponse]:
+ """Search documents using semantic similarity with their summary
+ embeddings."""
+
+ where_clauses = ["summary_embedding IS NOT NULL"]
+ params: list[str | int | bytes] = [str(query_embedding)]
+
+ vector_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+ filters = copy.deepcopy(search_settings.filters)
+ if filters:
+ filter_condition, params = apply_filters(
+ transform_filter_fields(filters), params, mode="condition_only"
+ )
+ if filter_condition:
+ where_clauses.append(filter_condition)
+
+ where_clause = " AND ".join(where_clauses)
+
+ query = f"""
+ WITH document_scores AS (
+ SELECT
+ id,
+ collection_ids,
+ owner_id,
+ type,
+ metadata,
+ title,
+ version,
+ size_in_bytes,
+ ingestion_status,
+ extraction_status,
+ created_at,
+ updated_at,
+ summary,
+ summary_embedding,
+ total_tokens,
+ (summary_embedding <=> $1::vector{vector_dim}) as semantic_distance
+ FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ WHERE {where_clause}
+ ORDER BY semantic_distance ASC
+ LIMIT ${len(params) + 1}
+ OFFSET ${len(params) + 2}
+ )
+ SELECT *,
+ 1.0 - semantic_distance as semantic_score
+ FROM document_scores
+ """
+
+ params.extend([search_settings.limit, search_settings.offset])
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ return [
+ DocumentResponse(
+ id=row["id"],
+ collection_ids=row["collection_ids"],
+ owner_id=row["owner_id"],
+ document_type=DocumentType(row["type"]),
+ metadata={
+ **(
+ json.loads(row["metadata"])
+ if search_settings.include_metadatas
+ else {}
+ ),
+ "search_score": float(row["semantic_score"]),
+ "search_type": "semantic",
+ },
+ title=row["title"],
+ version=row["version"],
+ size_in_bytes=row["size_in_bytes"],
+ ingestion_status=IngestionStatus(row["ingestion_status"]),
+ extraction_status=GraphExtractionStatus(
+ row["extraction_status"]
+ ),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ summary=row["summary"],
+ summary_embedding=[
+ float(x)
+ for x in row["summary_embedding"][1:-1].split(",")
+ if x
+ ],
+ total_tokens=row["total_tokens"],
+ )
+ for row in results
+ ]
+
+ async def full_text_document_search(
+ self, query_text: str, search_settings: SearchSettings
+ ) -> list[DocumentResponse]:
+ """Enhanced full-text search using generated tsvector."""
+
+ where_clauses = ["raw_tsvector @@ websearch_to_tsquery('english', $1)"]
+ params: list[str | int | bytes] = [query_text]
+
+ filters = copy.deepcopy(search_settings.filters)
+ if filters:
+ filter_condition, params = apply_filters(
+ transform_filter_fields(filters), params, mode="condition_only"
+ )
+ if filter_condition:
+ where_clauses.append(filter_condition)
+
+ where_clause = " AND ".join(where_clauses)
+
+ query = f"""
+ WITH document_scores AS (
+ SELECT
+ id,
+ collection_ids,
+ owner_id,
+ type,
+ metadata,
+ title,
+ version,
+ size_in_bytes,
+ ingestion_status,
+ extraction_status,
+ created_at,
+ updated_at,
+ summary,
+ summary_embedding,
+ total_tokens,
+ ts_rank_cd(raw_tsvector, websearch_to_tsquery('english', $1), 32) as text_score
+ FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ WHERE {where_clause}
+ ORDER BY text_score DESC
+ LIMIT ${len(params) + 1}
+ OFFSET ${len(params) + 2}
+ )
+ SELECT * FROM document_scores
+ """
+
+ params.extend([search_settings.limit, search_settings.offset])
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ return [
+ DocumentResponse(
+ id=row["id"],
+ collection_ids=row["collection_ids"],
+ owner_id=row["owner_id"],
+ document_type=DocumentType(row["type"]),
+ metadata={
+ **(
+ json.loads(row["metadata"])
+ if search_settings.include_metadatas
+ else {}
+ ),
+ "search_score": float(row["text_score"]),
+ "search_type": "full_text",
+ },
+ title=row["title"],
+ version=row["version"],
+ size_in_bytes=row["size_in_bytes"],
+ ingestion_status=IngestionStatus(row["ingestion_status"]),
+ extraction_status=GraphExtractionStatus(
+ row["extraction_status"]
+ ),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ summary=row["summary"],
+ summary_embedding=(
+ [
+ float(x)
+ for x in row["summary_embedding"][1:-1].split(",")
+ if x
+ ]
+ if row["summary_embedding"]
+ else None
+ ),
+ total_tokens=row["total_tokens"],
+ )
+ for row in results
+ ]
+
+ async def hybrid_document_search(
+ self,
+ query_text: str,
+ query_embedding: list[float],
+ search_settings: SearchSettings,
+ ) -> list[DocumentResponse]:
+ """Search documents using both semantic and full-text search with RRF
+ fusion."""
+
+ # Get more results than needed for better fusion
+ extended_settings = copy.deepcopy(search_settings)
+ extended_settings.limit = search_settings.limit * 3
+
+ # Get results from both search methods
+ semantic_results = await self.semantic_document_search(
+ query_embedding, extended_settings
+ )
+ full_text_results = await self.full_text_document_search(
+ query_text, extended_settings
+ )
+
+ # Combine results using RRF
+ doc_scores: dict[str, dict] = {}
+
+ # Process semantic results
+ for rank, result in enumerate(semantic_results, 1):
+ doc_id = str(result.id)
+ doc_scores[doc_id] = {
+ "semantic_rank": rank,
+ "full_text_rank": len(full_text_results)
+ + 1, # Default rank if not found
+ "data": result,
+ }
+
+ # Process full-text results
+ for rank, result in enumerate(full_text_results, 1):
+ doc_id = str(result.id)
+ if doc_id in doc_scores:
+ doc_scores[doc_id]["full_text_rank"] = rank
+ else:
+ doc_scores[doc_id] = {
+ "semantic_rank": len(semantic_results)
+ + 1, # Default rank if not found
+ "full_text_rank": rank,
+ "data": result,
+ }
+
+ # Calculate RRF scores using hybrid search settings
+ rrf_k = search_settings.hybrid_settings.rrf_k
+ semantic_weight = search_settings.hybrid_settings.semantic_weight
+ full_text_weight = search_settings.hybrid_settings.full_text_weight
+
+ for scores in doc_scores.values():
+ semantic_score = 1 / (rrf_k + scores["semantic_rank"])
+ full_text_score = 1 / (rrf_k + scores["full_text_rank"])
+
+ # Weighted combination
+ combined_score = (
+ semantic_score * semantic_weight
+ + full_text_score * full_text_weight
+ ) / (semantic_weight + full_text_weight)
+
+ scores["final_score"] = combined_score
+
+ # Sort by final score and apply offset/limit
+ sorted_results = sorted(
+ doc_scores.values(), key=lambda x: x["final_score"], reverse=True
+ )[
+ search_settings.offset : search_settings.offset
+ + search_settings.limit
+ ]
+
+ return [
+ DocumentResponse(
+ **{
+ **result["data"].__dict__,
+ "metadata": {
+ **(
+ result["data"].metadata
+ if search_settings.include_metadatas
+ else {}
+ ),
+ "search_score": result["final_score"],
+ "semantic_rank": result["semantic_rank"],
+ "full_text_rank": result["full_text_rank"],
+ "search_type": "hybrid",
+ },
+ }
+ )
+ for result in sorted_results
+ ]
+
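# Editor's illustrative sketch (not part of the diff): the reciprocal rank
# fusion arithmetic used by hybrid_document_search, on made-up ranks. The
# rrf_k value and the equal weights are assumptions standing in for
# search_settings.hybrid_settings.
rrf_k, semantic_weight, full_text_weight = 60, 1.0, 1.0
semantic_rank, full_text_rank = 1, 3  # ranked 1st semantically, 3rd in full text
semantic_score = 1 / (rrf_k + semantic_rank)      # ~0.0164
full_text_score = 1 / (rrf_k + full_text_rank)    # ~0.0159
combined_score = (
    semantic_score * semantic_weight + full_text_score * full_text_weight
) / (semantic_weight + full_text_weight)          # ~0.0161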
+ async def search_documents(
+ self,
+ query_text: str,
+ query_embedding: Optional[list[float]] = None,
+ settings: Optional[SearchSettings] = None,
+ ) -> list[DocumentResponse]:
+ """Main search method that delegates to the appropriate search method
+ based on settings."""
+ if settings is None:
+ settings = SearchSettings()
+
+ if (
+ settings.use_semantic_search and settings.use_fulltext_search
+ ) or settings.use_hybrid_search:
+ if query_embedding is None:
+ raise ValueError(
+ "query_embedding is required for hybrid search"
+ )
+ return await self.hybrid_document_search(
+ query_text, query_embedding, settings
+ )
+ elif settings.use_semantic_search:
+ if query_embedding is None:
+ raise ValueError(
+ "query_embedding is required for vector search"
+ )
+ return await self.semantic_document_search(
+ query_embedding, settings
+ )
+ else:
+ return await self.full_text_document_search(query_text, settings)
+
+ async def export_to_csv(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ """Creates a CSV file from the PostgreSQL data and returns the path to
+ the temp file."""
+ valid_columns = {
+ "id",
+ "collection_ids",
+ "owner_id",
+ "type",
+ "metadata",
+ "title",
+ "summary",
+ "version",
+ "size_in_bytes",
+ "ingestion_status",
+ "extraction_status",
+ "created_at",
+ "updated_at",
+ "total_tokens",
+ }
+ filters = copy.deepcopy(filters)
+ filters = transform_filter_fields(filters) # type: ignore
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ collection_ids::text,
+ owner_id::text,
+ type::text,
+ metadata::text AS metadata,
+ title,
+ summary,
+ version,
+ size_in_bytes,
+ ingestion_status,
+ extraction_status,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
+ total_tokens
+ FROM {self._get_table_name(self.TABLE_NAME)}
+ """
+
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "collection_ids": row[1],
+ "owner_id": row[2],
+ "type": row[3],
+ "metadata": row[4],
+ "title": row[5],
+ "summary": row[6],
+ "version": row[7],
+ "size_in_bytes": row[8],
+ "ingestion_status": row[9],
+ "extraction_status": row[10],
+ "created_at": row[11],
+ "updated_at": row[12],
+ "total_tokens": row[13],
+ }
+ writer.writerow([row_dict[col] for col in columns])
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
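# Editor's illustrative sketch (not part of the diff): paging through one
# owner's documents with the handler above. `handler` is assumed to be an
# initialised PostgresDocumentsHandler and `user_id` a UUID. The keyword
# filters and the `filters` dict shown in the comment are mutually
# exclusive, as enforced by get_documents_overview.
async def list_user_documents(handler, user_id):
    page = await handler.get_documents_overview(
        offset=0, limit=25, filter_user_ids=[user_id]
    )
    # Alternatively, the dict form (do not combine with filter_*_ids):
    # page = await handler.get_documents_overview(
    #     offset=0, limit=25, filters={"owner_id": {"$eq": str(user_id)}}
    # )
    return page["results"], page["total_entries"]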
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/files.py b/.venv/lib/python3.12/site-packages/core/providers/database/files.py
new file mode 100644
index 00000000..dc349a7e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/files.py
@@ -0,0 +1,334 @@
+import io
+import logging
+from datetime import datetime
+from io import BytesIO
+from typing import BinaryIO, Optional
+from uuid import UUID
+from zipfile import ZipFile
+
+import asyncpg
+from fastapi import HTTPException
+
+from core.base import Handler, R2RException
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger()
+
+
+class PostgresFilesHandler(Handler):
+ """PostgreSQL implementation of the FileHandler."""
+
+ TABLE_NAME = "files"
+
+ connection_manager: PostgresConnectionManager
+
+ async def create_tables(self) -> None:
+ """Create the necessary tables for file storage."""
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresFilesHandler.TABLE_NAME)} (
+ document_id UUID PRIMARY KEY,
+ name TEXT NOT NULL,
+ oid OID NOT NULL,
+ size BIGINT NOT NULL,
+ type TEXT,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW()
+ );
+
+ -- Create trigger for updating the updated_at timestamp
+ CREATE OR REPLACE FUNCTION {self.project_name}.update_files_updated_at()
+ RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.updated_at = CURRENT_TIMESTAMP;
+ RETURN NEW;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ DROP TRIGGER IF EXISTS update_files_updated_at
+ ON {self._get_table_name(PostgresFilesHandler.TABLE_NAME)};
+
+ CREATE TRIGGER update_files_updated_at
+ BEFORE UPDATE ON {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ FOR EACH ROW
+ EXECUTE FUNCTION {self.project_name}.update_files_updated_at();
+ """
+ await self.connection_manager.execute_query(query)
+
+ async def upsert_file(
+ self,
+ document_id: UUID,
+ file_name: str,
+ file_oid: int,
+ file_size: int,
+ file_type: Optional[str] = None,
+ ) -> None:
+ """Add or update a file entry in storage."""
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ (document_id, name, oid, size, type)
+ VALUES ($1, $2, $3, $4, $5)
+ ON CONFLICT (document_id) DO UPDATE SET
+ name = EXCLUDED.name,
+ oid = EXCLUDED.oid,
+ size = EXCLUDED.size,
+ type = EXCLUDED.type,
+ updated_at = NOW();
+ """
+ await self.connection_manager.execute_query(
+ query, [document_id, file_name, file_oid, file_size, file_type]
+ )
+
+ async def store_file(
+ self,
+ document_id: UUID,
+ file_name: str,
+ file_content: io.BytesIO,
+ file_type: Optional[str] = None,
+ ) -> None:
+ """Store a new file in the database."""
+ size = file_content.getbuffer().nbytes
+
+ async with (
+ self.connection_manager.pool.get_connection() as conn # type: ignore
+ ):
+ async with conn.transaction():
+ oid = await conn.fetchval("SELECT lo_create(0)")
+ await self._write_lobject(conn, oid, file_content)
+ await self.upsert_file(
+ document_id, file_name, oid, size, file_type
+ )
+
+ async def _write_lobject(
+ self, conn, oid: int, file_content: io.BytesIO
+ ) -> None:
+ """Write content to a large object."""
+ lobject = await conn.fetchval("SELECT lo_open($1, $2)", oid, 0x20000)
+
+ try:
+ chunk_size = 8192 # 8 KB chunks
+ while True:
+ if chunk := file_content.read(chunk_size):
+ await conn.execute(
+ "SELECT lowrite($1, $2)", lobject, chunk
+ )
+ else:
+ break
+
+ await conn.execute("SELECT lo_close($1)", lobject)
+
+ except Exception as e:
+ await conn.execute("SELECT lo_unlink($1)", oid)
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to write to large object: {e}",
+ ) from e
+
+ async def retrieve_file(
+ self, document_id: UUID
+ ) -> Optional[tuple[str, BinaryIO, int]]:
+ """Retrieve a file from storage."""
+ query = f"""
+ SELECT name, oid, size
+ FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ WHERE document_id = $1
+ """
+
+ result = await self.connection_manager.fetchrow_query(
+ query, [document_id]
+ )
+ if not result:
+ raise R2RException(
+ status_code=404,
+ message=f"File for document {document_id} not found",
+ )
+
+ file_name, oid, size = (
+ result["name"],
+ result["oid"],
+ result["size"],
+ )
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ file_content = await self._read_lobject(conn, oid)
+ return file_name, io.BytesIO(file_content), size
+
+ async def retrieve_files_as_zip(
+ self,
+ document_ids: Optional[list[UUID]] = None,
+ start_date: Optional[datetime] = None,
+ end_date: Optional[datetime] = None,
+ ) -> tuple[str, BinaryIO, int]:
+ """Retrieve multiple files and return them as a zip file."""
+
+ query = f"""
+ SELECT document_id, name, oid, size
+ FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ WHERE 1=1
+ """
+ params: list = []
+
+ if document_ids:
+ query += f" AND document_id = ANY(${len(params) + 1})"
+ params.append([str(doc_id) for doc_id in document_ids])
+
+ if start_date:
+ query += f" AND created_at >= ${len(params) + 1}"
+ params.append(start_date)
+
+ if end_date:
+ query += f" AND created_at <= ${len(params) + 1}"
+ params.append(end_date)
+
+ query += " ORDER BY created_at DESC"
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ if not results:
+ raise R2RException(
+ status_code=404,
+ message="No files found matching the specified criteria",
+ )
+
+ zip_buffer = BytesIO()
+ total_size = 0
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ with ZipFile(zip_buffer, "w") as zip_file:
+ for record in results:
+ file_content = await self._read_lobject(
+ conn, record["oid"]
+ )
+
+ zip_file.writestr(record["name"], file_content)
+ total_size += record["size"]
+
+ zip_buffer.seek(0)
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ zip_filename = f"files_export_{timestamp}.zip"
+
+ return zip_filename, zip_buffer, zip_buffer.getbuffer().nbytes
+
+ async def _read_lobject(self, conn, oid: int) -> bytes:
+ """Read content from a large object."""
+ file_data = io.BytesIO()
+ chunk_size = 8192
+
+ async with conn.transaction():
+ lobject = None
+ try:
+ lo_exists = await conn.fetchval(
+ "SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_largeobject_metadata WHERE oid = $1);",
+ oid,
+ )
+ if not lo_exists:
+ raise R2RException(
+ status_code=404,
+ message=f"Large object {oid} not found.",
+ )
+
+ lobject = await conn.fetchval("SELECT lo_open($1, 262144)", oid)  # 262144 = 0x40000 = INV_READ
+
+ if lobject is None:
+ raise R2RException(
+ status_code=404,
+ message=f"Failed to open large object {oid}.",
+ )
+
+ while True:
+ chunk = await conn.fetchval(
+ "SELECT loread($1, $2)", lobject, chunk_size
+ )
+ if not chunk:
+ break
+ file_data.write(chunk)
+ except asyncpg.exceptions.UndefinedObjectError:
+ raise R2RException(
+ status_code=404,
+ message=f"Failed to read large object {oid}",
+ ) from None
+ finally:
+ if lobject is not None:
+ await conn.execute("SELECT lo_close($1)", lobject)
+
+ return file_data.getvalue()
+
+ async def delete_file(self, document_id: UUID) -> bool:
+ """Delete a file from storage."""
+ query = f"""
+ SELECT oid FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ WHERE document_id = $1
+ """
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ oid = await conn.fetchval(query, document_id)
+ if not oid:
+ raise R2RException(
+ status_code=404,
+ message=f"File for document {document_id} not found",
+ )
+
+ await self._delete_lobject(conn, oid)
+
+ delete_query = f"""
+ DELETE FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ WHERE document_id = $1
+ """
+ await conn.execute(delete_query, document_id)
+
+ return True
+
+ async def _delete_lobject(self, conn, oid: int) -> None:
+ """Delete a large object."""
+ await conn.execute("SELECT lo_unlink($1)", oid)
+
+ async def get_files_overview(
+ self,
+ offset: int,
+ limit: int,
+ filter_document_ids: Optional[list[UUID]] = None,
+ filter_file_names: Optional[list[str]] = None,
+ ) -> list[dict]:
+ """Get an overview of stored files."""
+ conditions = []
+ params: list[str | list[str] | int] = []
+ query = f"""
+ SELECT document_id, name, oid, size, type, created_at, updated_at
+ FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ """
+
+ if filter_document_ids:
+ conditions.append(f"document_id = ANY(${len(params) + 1})")
+ params.append([str(doc_id) for doc_id in filter_document_ids])
+
+ if filter_file_names:
+ conditions.append(f"name = ANY(${len(params) + 1})")
+ params.append(filter_file_names)
+
+ if conditions:
+ query += " WHERE " + " AND ".join(conditions)
+
+ query += f" ORDER BY created_at DESC OFFSET ${len(params) + 1} LIMIT ${len(params) + 2}"
+ params.extend([offset, limit])
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ if not results:
+ raise R2RException(
+ status_code=404,
+ message="No files found with the given filters",
+ )
+
+ return [
+ {
+ "document_id": row["document_id"],
+ "file_name": row["name"],
+ "file_oid": row["oid"],
+ "file_size": row["size"],
+ "file_type": row["type"],
+ "created_at": row["created_at"],
+ "updated_at": row["updated_at"],
+ }
+ for row in results
+ ]
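# Editor's illustrative sketch (not part of the diff): a store/retrieve/delete
# round trip through the large-object file handler above. `files` is assumed
# to be an initialised PostgresFilesHandler; the file name and content are
# arbitrary examples.
import io
from uuid import uuid4

async def file_roundtrip(files):
    doc_id = uuid4()
    await files.store_file(
        document_id=doc_id,
        file_name="notes.txt",
        file_content=io.BytesIO(b"hello large objects"),
        file_type="text/plain",
    )
    name, stream, size = await files.retrieve_file(doc_id)
    data = stream.read()  # b"hello large objects"
    await files.delete_file(doc_id)
    return name, size, data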
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/filters.py b/.venv/lib/python3.12/site-packages/core/providers/database/filters.py
new file mode 100644
index 00000000..9231e35b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/filters.py
@@ -0,0 +1,478 @@
+import json
+from typing import Any, Optional, Tuple
+
+COLUMN_VARS = [
+ "id",
+ "document_id",
+ "owner_id",
+ "collection_ids",
+]
+
+
+class FilterError(Exception):
+ pass
+
+
+class FilterOperator:
+ EQ = "$eq"
+ NE = "$ne"
+ LT = "$lt"
+ LTE = "$lte"
+ GT = "$gt"
+ GTE = "$gte"
+ IN = "$in"
+ NIN = "$nin"
+ LIKE = "$like"
+ ILIKE = "$ilike"
+ CONTAINS = "$contains"
+ AND = "$and"
+ OR = "$or"
+ OVERLAP = "$overlap"
+
+ SCALAR_OPS = {EQ, NE, LT, LTE, GT, GTE, LIKE, ILIKE}
+ ARRAY_OPS = {IN, NIN, OVERLAP}
+ JSON_OPS = {CONTAINS}
+ LOGICAL_OPS = {AND, OR}
+
+
+class FilterCondition:
+ def __init__(self, field: str, operator: str, value: Any):
+ self.field = field
+ self.operator = operator
+ self.value = value
+
+
+class FilterExpression:
+ def __init__(self, logical_op: Optional[str] = None):
+ self.logical_op = logical_op
+ self.conditions: list[FilterCondition | "FilterExpression"] = []
+
+
+class FilterParser:
+ def __init__(
+ self,
+ top_level_columns: Optional[list[str]] = None,
+ json_column: str = "metadata",
+ ):
+ if top_level_columns is None:
+ self.top_level_columns = set(COLUMN_VARS)
+ else:
+ self.top_level_columns = set(top_level_columns)
+ self.json_column = json_column
+
+ def parse(self, filters: dict) -> FilterExpression:
+ if not filters:
+ raise FilterError("Empty filters are not allowed")
+ return self._parse_logical(filters)
+
+ def _parse_logical(self, dct: dict) -> FilterExpression:
+ keys = list(dct.keys())
+ expr = FilterExpression()
+ if len(keys) == 1 and keys[0] in (
+ FilterOperator.AND,
+ FilterOperator.OR,
+ ):
+ expr.logical_op = keys[0]
+ if not isinstance(dct[keys[0]], list):
+ raise FilterError(f"{keys[0]} value must be a list")
+ for item in dct[keys[0]]:
+ if isinstance(item, dict):
+ if self._is_logical_block(item):
+ expr.conditions.append(self._parse_logical(item))
+ else:
+ expr.conditions.append(
+ self._parse_condition_dict(item)
+ )
+ else:
+ raise FilterError("Invalid filter format")
+ else:
+ expr.logical_op = FilterOperator.AND
+ expr.conditions.append(self._parse_condition_dict(dct))
+
+ return expr
+
+ def _is_logical_block(self, dct: dict) -> bool:
+ if len(dct.keys()) == 1:
+ k = next(iter(dct.keys()))
+ if k in FilterOperator.LOGICAL_OPS:
+ return True
+ return False
+
+ def _parse_condition_dict(self, dct: dict) -> FilterExpression:
+ expr = FilterExpression(logical_op=FilterOperator.AND)
+ for field, cond in dct.items():
+ if not isinstance(cond, dict):
+ # direct equality
+ expr.conditions.append(
+ FilterCondition(field, FilterOperator.EQ, cond)
+ )
+ else:
+ if len(cond) != 1:
+ raise FilterError(
+ f"Condition for field {field} must have exactly one operator"
+ )
+ op, val = next(iter(cond.items()))
+ self._validate_operator(op)
+ expr.conditions.append(FilterCondition(field, op, val))
+ return expr
+
+ def _validate_operator(self, op: str):
+ allowed = (
+ FilterOperator.SCALAR_OPS
+ | FilterOperator.ARRAY_OPS
+ | FilterOperator.JSON_OPS
+ | FilterOperator.LOGICAL_OPS
+ )
+ if op not in allowed:
+ raise FilterError(f"Unsupported operator: {op}")
+
+
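# Editor's illustrative sketch (not part of the diff): what FilterParser
# produces for a small filter, assuming this module is importable. A bare
# field/operator pair is wrapped in an implicit $and expression; the field
# and values below are examples.
parser = FilterParser()
expr = parser.parse({"metadata.topic": {"$in": ["ai", "ml"]}})
inner = expr.conditions[0]        # nested FilterExpression (implicit $and)
cond = inner.conditions[0]        # FilterCondition
print(expr.logical_op, cond.field, cond.operator, cond.value)
# -> $and metadata.topic $in ['ai', 'ml']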
+class SQLFilterBuilder:
+ def __init__(
+ self,
+ params: list[Any],
+ top_level_columns: Optional[list[str]] = None,
+ json_column: str = "metadata",
+ mode: str = "where_clause",
+ ):
+ if top_level_columns is None:
+ self.top_level_columns = set(COLUMN_VARS)
+ else:
+ self.top_level_columns = set(top_level_columns)
+ self.json_column = json_column
+ self.params: list[Any] = params # mutated during construction
+ self.mode = mode
+
+ def build(self, expr: FilterExpression) -> Tuple[str, list[Any]]:
+ where_clause = self._build_expression(expr)
+ if self.mode == "where_clause":
+ return f"WHERE {where_clause}", self.params
+
+ return where_clause, self.params
+
+ def _build_expression(self, expr: FilterExpression) -> str:
+ parts = []
+ for c in expr.conditions:
+ if isinstance(c, FilterCondition):
+ parts.append(self._build_condition(c))
+ else:
+ nested_sql = self._build_expression(c)
+ parts.append(f"({nested_sql})")
+
+ if expr.logical_op == FilterOperator.AND:
+ return " AND ".join(parts)
+ elif expr.logical_op == FilterOperator.OR:
+ return " OR ".join(parts)
+ else:
+ return " AND ".join(parts)
+
+ @staticmethod
+ def _psql_quote_literal(value: str) -> str:
+ """Simple quoting for demonstration.
+
+ In production, use parameterized queries or your DB driver's quoting
+ function instead.
+ """
+ return "'" + value.replace("'", "''") + "'"
+
+ def _build_condition(self, cond: FilterCondition) -> str:
+ field_is_metadata = cond.field not in self.top_level_columns
+ key = cond.field
+ op = cond.operator
+ val = cond.value
+
+ # 1. If the filter references "parent_id", handle it as a single-UUID column for graphs:
+ if key == "parent_id":
+ return self._build_parent_id_condition(op, val)
+
+ # 2. If the filter references "collection_id", handle it as an array column (chunks)
+ if key == "collection_id":
+ return self._build_collection_id_condition(op, val)
+
+ # 3. Otherwise, decide if it's top-level or metadata:
+ if field_is_metadata:
+ return self._build_metadata_condition(key, op, val)
+ else:
+ return self._build_column_condition(key, op, val)
+
+ def _build_parent_id_condition(self, op: str, val: Any) -> str:
+ """For 'graphs' tables, parent_id is a single UUID (not an array).
+
+ We handle the same ops but in a simpler, single-UUID manner.
+ """
+ param_idx = len(self.params) + 1
+
+ if op == "$eq":
+ if not isinstance(val, str):
+ raise FilterError(
+ "$eq for parent_id expects a single UUID string"
+ )
+ self.params.append(val)
+ return f"parent_id = ${param_idx}::uuid"
+
+ elif op == "$ne":
+ if not isinstance(val, str):
+ raise FilterError(
+ "$ne for parent_id expects a single UUID string"
+ )
+ self.params.append(val)
+ return f"parent_id != ${param_idx}::uuid"
+
+ elif op == "$in":
+ # A list of UUIDs, any of which might match
+ if not isinstance(val, list):
+ raise FilterError(
+ "$in for parent_id expects a list of UUID strings"
+ )
+ self.params.append(val)
+ return f"parent_id = ANY(${param_idx}::uuid[])"
+
+ elif op == "$nin":
+ # A list of UUIDs, none of which may match
+ if not isinstance(val, list):
+ raise FilterError(
+ "$nin for parent_id expects a list of UUID strings"
+ )
+ self.params.append(val)
+ return f"parent_id != ALL(${param_idx}::uuid[])"
+
+ else:
+ # You could add more (like $gt, $lt, etc.) if your schema wants them
+ raise FilterError(f"Unsupported operator {op} for parent_id")
+
+ def _build_collection_id_condition(self, op: str, val: Any) -> str:
+ """For the 'chunks' table, collection_ids is an array of UUIDs.
+
+ We need to use array operators to compare arrays correctly.
+ """
+ param_idx = len(self.params) + 1
+
+ if op == "$eq":
+ if not isinstance(val, str):
+ raise FilterError(
+ "$eq for collection_id expects a single UUID string"
+ )
+ self.params.append(
+ [val]
+ ) # Make it a list with one element for the overlap check
+ return (
+ f"collection_ids && ${param_idx}::uuid[]" # Use && for overlap
+ )
+
+ elif op == "$ne":
+ if not isinstance(val, str):
+ raise FilterError(
+ "$ne for collection_id expects a single UUID string"
+ )
+ self.params.append([val])
+ return f"NOT (collection_ids && ${param_idx}::uuid[])" # Negate the overlap
+
+ elif op == "$in":
+ if not isinstance(val, list):
+ raise FilterError(
+ "$in for collection_id expects a list of UUID strings"
+ )
+ self.params.append(val)
+ return (
+ f"collection_ids && ${param_idx}::uuid[]" # Use && for overlap
+ )
+
+ elif op == "$nin":
+ if not isinstance(val, list):
+ raise FilterError(
+ "$nin for collection_id expects a list of UUID strings"
+ )
+ self.params.append(val)
+ return f"NOT (collection_ids && ${param_idx}::uuid[])" # Negate the overlap
+
+ elif op == "$contains":
+ if isinstance(val, str):
+ # single string -> array with one element
+ self.params.append([val])
+ return f"collection_ids @> ${param_idx}::uuid[]"
+ elif isinstance(val, list):
+ self.params.append(val)
+ return f"collection_ids @> ${param_idx}::uuid[]"
+ else:
+ raise FilterError(
+ "$contains for collection_id expects a UUID or list of UUIDs"
+ )
+
+ elif op == "$overlap":
+ if not isinstance(val, list):
+ self.params.append([val])
+ else:
+ self.params.append(val)
+ return f"collection_ids && ${param_idx}::uuid[]"
+
+ else:
+ raise FilterError(f"Unsupported operator {op} for collection_id")
+
+ def _build_column_condition(self, col: str, op: str, val: Any) -> str:
+ # If we're dealing with collection_ids, route to our specialized handler
+ if col == "collection_ids":
+ return self._build_collection_id_condition(op, val)
+
+ param_idx = len(self.params) + 1
+ if op == "$eq":
+ self.params.append(val)
+ return f"{col} = ${param_idx}"
+ elif op == "$ne":
+ self.params.append(val)
+ return f"{col} != ${param_idx}"
+ elif op == "$in":
+ if not isinstance(val, list):
+ raise FilterError("argument to $in filter must be a list")
+ self.params.append(val)
+ return f"{col} = ANY(${param_idx})"
+ elif op == "$nin":
+ if not isinstance(val, list):
+ raise FilterError("argument to $nin filter must be a list")
+ self.params.append(val)
+ return f"{col} != ALL(${param_idx})"
+ elif op == "$overlap":
+ self.params.append(val)
+ return f"{col} && ${param_idx}"
+ elif op == "$contains":
+ self.params.append(val)
+ return f"{col} @> ${param_idx}"
+ elif op == "$any":
+ if col == "collection_ids":
+ self.params.append(f"%{val}%")
+ return f"array_to_string({col}, ',') LIKE ${param_idx}"
+ else:
+ self.params.append(val)
+ return f"${param_idx} = ANY({col})"
+ elif op in ("$lt", "$lte", "$gt", "$gte"):
+ self.params.append(val)
+ return f"{col} {self._map_op(op)} ${param_idx}"
+ else:
+ raise FilterError(f"Unsupported operator for column {col}: {op}")
+
+ def _build_metadata_condition(self, key: str, op: str, val: Any) -> str:
+ param_idx = len(self.params) + 1
+ json_col = self.json_column
+
+ # Strip "metadata." prefix if present
+ key = key.removeprefix("metadata.")
+
+ # Split on '.' to handle nested keys
+ parts = key.split(".")
+
+ # Use text extraction for scalar values, but not for arrays
+ use_text_extraction = op in (
+ "$lt",
+ "$lte",
+ "$gt",
+ "$gte",
+ "$eq",
+ "$ne",
+ ) and isinstance(val, (int, float, str))
+ if op == "$in" or op == "$contains" or isinstance(val, (list, dict)):
+ use_text_extraction = False
+
+ # Build the JSON path expression
+ if len(parts) == 1:
+ if use_text_extraction:
+ path_expr = f"{json_col}->>'{parts[0]}'"
+ else:
+ path_expr = f"{json_col}->'{parts[0]}'"
+ else:
+ path_expr = json_col
+ for p in parts[:-1]:
+ path_expr += f"->'{p}'"
+ last_part = parts[-1]
+ if use_text_extraction:
+ path_expr += f"->>'{last_part}'"
+ else:
+ path_expr += f"->'{last_part}'"
+
+ # Convert numeric values to strings for text comparison
+ def prepare_value(v):
+ return str(v) if isinstance(v, (int, float)) else v
+
+ if op == "$eq":
+ if use_text_extraction:
+ prepared_val = prepare_value(val)
+ self.params.append(prepared_val)
+ return f"{path_expr} = ${param_idx}"
+ else:
+ self.params.append(json.dumps(val))
+ return f"{path_expr} = ${param_idx}::jsonb"
+ elif op == "$ne":
+ if use_text_extraction:
+ self.params.append(prepare_value(val))
+ return f"{path_expr} != ${param_idx}"
+ else:
+ self.params.append(json.dumps(val))
+ return f"{path_expr} != ${param_idx}::jsonb"
+ elif op == "$lt":
+ self.params.append(prepare_value(val))
+ return f"({path_expr})::numeric < ${param_idx}::numeric"
+ elif op == "$lte":
+ self.params.append(prepare_value(val))
+ return f"({path_expr})::numeric <= ${param_idx}::numeric"
+ elif op == "$gt":
+ self.params.append(prepare_value(val))
+ return f"({path_expr})::numeric > ${param_idx}::numeric"
+ elif op == "$gte":
+ self.params.append(prepare_value(val))
+ return f"({path_expr})::numeric >= ${param_idx}::numeric"
+ elif op == "$in":
+ if not isinstance(val, list):
+ raise FilterError("argument to $in filter must be a list")
+
+ if use_text_extraction:
+ str_vals = [
+ str(v) if isinstance(v, (int, float)) else v for v in val
+ ]
+ self.params.append(str_vals)
+ return f"{path_expr} = ANY(${param_idx}::text[])"
+
+ # For JSON arrays, use containment checks
+ conditions = []
+ for i, v in enumerate(val):
+ self.params.append(json.dumps(v))
+ conditions.append(f"{path_expr} @> ${param_idx + i}::jsonb")
+ return f"({' OR '.join(conditions)})"
+
+ elif op == "$contains":
+ if isinstance(val, (str, int, float, bool)):
+ val = [val]
+ self.params.append(json.dumps(val))
+ return f"{path_expr} @> ${param_idx}::jsonb"
+ else:
+ raise FilterError(f"Unsupported operator for metadata field {op}")
+
+ def _map_op(self, op: str) -> str:
+ mapping = {
+ FilterOperator.EQ: "=",
+ FilterOperator.NE: "!=",
+ FilterOperator.LT: "<",
+ FilterOperator.LTE: "<=",
+ FilterOperator.GT: ">",
+ FilterOperator.GTE: ">=",
+ }
+ return mapping.get(op, op)
+
+
+def apply_filters(
+ filters: dict, params: list[Any], mode: str = "where_clause"
+) -> tuple[str, list[Any]]:
+ """Apply filters with consistent WHERE clause handling."""
+ if not filters:
+ return "", params
+
+ parser = FilterParser()
+ expr = parser.parse(filters)
+ builder = SQLFilterBuilder(params=params, mode=mode)
+ filter_clause, new_params = builder.build(expr)
+
+ if mode == "where_clause":
+ return filter_clause, new_params # Already includes WHERE
+ elif mode == "condition_only":
+ return filter_clause, new_params
+ elif mode == "append_only":
+ return f"AND {filter_clause}", new_params
+ else:
+ raise ValueError(f"Unknown filter mode: {mode}")
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/graphs.py b/.venv/lib/python3.12/site-packages/core/providers/database/graphs.py
new file mode 100644
index 00000000..ba9c22ee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/graphs.py
@@ -0,0 +1,2884 @@
+import asyncio
+import contextlib
+import csv
+import datetime
+import json
+import logging
+import os
+import tempfile
+import time
+from typing import IO, Any, AsyncGenerator, Optional, Tuple
+from uuid import UUID
+
+import asyncpg
+import httpx
+from asyncpg.exceptions import UniqueViolationError
+from fastapi import HTTPException
+
+from core.base.abstractions import (
+ Community,
+ Entity,
+ Graph,
+ GraphExtractionStatus,
+ R2RException,
+ Relationship,
+ StoreType,
+ VectorQuantizationType,
+)
+from core.base.api.models import GraphResponse
+from core.base.providers.database import Handler
+from core.base.utils import (
+ _get_vector_column_str,
+ generate_entity_document_id,
+)
+
+from .base import PostgresConnectionManager
+from .collections import PostgresCollectionsHandler
+
+logger = logging.getLogger()
+
+
+class PostgresEntitiesHandler(Handler):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ self.project_name: str = kwargs.get("project_name") # type: ignore
+ self.connection_manager: PostgresConnectionManager = kwargs.get(
+ "connection_manager"
+ ) # type: ignore
+ self.dimension: int = kwargs.get("dimension") # type: ignore
+ self.quantization_type: VectorQuantizationType = kwargs.get(
+ "quantization_type"
+ ) # type: ignore
+ self.relationships_handler: PostgresRelationshipsHandler = (
+ PostgresRelationshipsHandler(*args, **kwargs)
+ )
+
+ def _get_table_name(self, table: str) -> str:
+ """Get the fully qualified table name."""
+ return f'"{self.project_name}"."{table}"'
+
+ def _get_entity_table_for_store(self, store_type: StoreType) -> str:
+ """Get the appropriate table name for the store type."""
+ return f"{store_type.value}_entities"
+
+ def _get_parent_constraint(self, store_type: StoreType) -> str:
+ """Get the appropriate foreign key constraint for the store type."""
+ if store_type == StoreType.GRAPHS:
+ return f"""
+ CONSTRAINT fk_graph
+ FOREIGN KEY(parent_id)
+ REFERENCES {self._get_table_name("graphs")}(id)
+ ON DELETE CASCADE
+ """
+ else:
+ return f"""
+ CONSTRAINT fk_document
+ FOREIGN KEY(parent_id)
+ REFERENCES {self._get_table_name("documents")}(id)
+ ON DELETE CASCADE
+ """
+
+ async def create_tables(self) -> None:
+ """Create separate tables for graph and document entities."""
+ vector_column_str = _get_vector_column_str(
+ self.dimension, self.quantization_type
+ )
+
+ for store_type in StoreType:
+ table_name = self._get_entity_table_for_store(store_type)
+ parent_constraint = self._get_parent_constraint(store_type)
+
+ QUERY = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ name TEXT NOT NULL,
+ category TEXT,
+ description TEXT,
+ parent_id UUID NOT NULL,
+ description_embedding {vector_column_str},
+ chunk_ids UUID[],
+ metadata JSONB,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ {parent_constraint}
+ );
+ CREATE INDEX IF NOT EXISTS {table_name}_name_idx
+ ON {self._get_table_name(table_name)} (name);
+ CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
+ ON {self._get_table_name(table_name)} (parent_id);
+ CREATE INDEX IF NOT EXISTS {table_name}_category_idx
+ ON {self._get_table_name(table_name)} (category);
+ """
+ await self.connection_manager.execute_query(QUERY)
+
+ async def create(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ name: str,
+ category: Optional[str] = None,
+ description: Optional[str] = None,
+ description_embedding: Optional[list[float] | str] = None,
+ chunk_ids: Optional[list[UUID]] = None,
+ metadata: Optional[dict[str, Any] | str] = None,
+ ) -> Entity:
+ """Create a new entity in the specified store."""
+ table_name = self._get_entity_table_for_store(store_type)
+
+ if isinstance(metadata, str):
+ with contextlib.suppress(json.JSONDecodeError):
+ metadata = json.loads(metadata)
+
+ if isinstance(description_embedding, list):
+ description_embedding = str(description_embedding)
+
+ query = f"""
+ INSERT INTO {self._get_table_name(table_name)}
+ (name, category, description, parent_id, description_embedding, chunk_ids, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
+ RETURNING id, name, category, description, parent_id, chunk_ids, metadata
+ """
+
+ params = [
+ name,
+ category,
+ description,
+ parent_id,
+ description_embedding,
+ chunk_ids,
+ json.dumps(metadata) if metadata else None,
+ ]
+
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+
+ return Entity(
+ id=result["id"],
+ name=result["name"],
+ category=result["category"],
+ description=result["description"],
+ parent_id=result["parent_id"],
+ chunk_ids=result["chunk_ids"],
+ metadata=result["metadata"],
+ )
+
+ async def get(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ offset: int,
+ limit: int,
+ entity_ids: Optional[list[UUID]] = None,
+ entity_names: Optional[list[str]] = None,
+ include_embeddings: bool = False,
+ ):
+ """Retrieve entities from the specified store."""
+ table_name = self._get_entity_table_for_store(store_type)
+
+ conditions = ["parent_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if entity_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(entity_ids)
+ param_index += 1
+
+ if entity_names:
+ conditions.append(f"name = ANY(${param_index})")
+ params.append(entity_names)
+ param_index += 1
+
+ select_fields = """
+ id, name, category, description, parent_id,
+ chunk_ids, metadata
+ """
+ if include_embeddings:
+ select_fields += ", description_embedding"
+
+ COUNT_QUERY = f"""
+ SELECT COUNT(*)
+ FROM {self._get_table_name(table_name)}
+ WHERE {" AND ".join(conditions)}
+ """
+
+ count_params = params[: param_index - 1]
+ count = (
+ await self.connection_manager.fetch_query(
+ COUNT_QUERY, count_params
+ )
+ )[0]["count"]
+
+ QUERY = f"""
+ SELECT {select_fields}
+ FROM {self._get_table_name(table_name)}
+ WHERE {" AND ".join(conditions)}
+ ORDER BY created_at
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ QUERY += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ rows = await self.connection_manager.fetch_query(QUERY, params)
+
+ entities = []
+ for row in rows:
+ # Convert the Record to a dictionary
+ entity_dict = dict(row)
+
+ # Process metadata if it exists and is a string
+ if isinstance(entity_dict["metadata"], str):
+ with contextlib.suppress(json.JSONDecodeError):
+ entity_dict["metadata"] = json.loads(
+ entity_dict["metadata"]
+ )
+
+ entities.append(Entity(**entity_dict))
+
+ return entities, count
+
+ async def update(
+ self,
+ entity_id: UUID,
+ store_type: StoreType,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ description_embedding: Optional[list[float] | str] = None,
+ category: Optional[str] = None,
+ metadata: Optional[dict] = None,
+ ) -> Entity:
+ """Update an entity in the specified store."""
+ table_name = self._get_entity_table_for_store(store_type)
+ update_fields = []
+ params: list[Any] = []
+ param_index = 1
+
+ if isinstance(metadata, str):
+ with contextlib.suppress(json.JSONDecodeError):
+ metadata = json.loads(metadata)
+
+ if name is not None:
+ update_fields.append(f"name = ${param_index}")
+ params.append(name)
+ param_index += 1
+
+ if description is not None:
+ update_fields.append(f"description = ${param_index}")
+ params.append(description)
+ param_index += 1
+
+ if description_embedding is not None:
+ update_fields.append(f"description_embedding = ${param_index}")
+ params.append(description_embedding)
+ param_index += 1
+
+ if category is not None:
+ update_fields.append(f"category = ${param_index}")
+ params.append(category)
+ param_index += 1
+
+ if metadata is not None:
+ update_fields.append(f"metadata = ${param_index}")
+ params.append(json.dumps(metadata))
+ param_index += 1
+
+ if not update_fields:
+ raise R2RException(status_code=400, message="No fields to update")
+
+ update_fields.append("updated_at = NOW()")
+ params.append(entity_id)
+
+ query = f"""
+ UPDATE {self._get_table_name(table_name)}
+ SET {", ".join(update_fields)}
+ WHERE id = ${param_index}
+ RETURNING id, name, category, description, parent_id, chunk_ids, metadata
+ """
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+
+ return Entity(
+ id=result["id"],
+ name=result["name"],
+ category=result["category"],
+ description=result["description"],
+ parent_id=result["parent_id"],
+ chunk_ids=result["chunk_ids"],
+ metadata=result["metadata"],
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while updating the entity: {e}",
+ ) from e
+
+ async def delete(
+ self,
+ parent_id: UUID,
+ entity_ids: Optional[list[UUID]] = None,
+ store_type: StoreType = StoreType.GRAPHS,
+ ) -> None:
+ """Delete entities from the specified store. If entity_ids is not
+ provided, deletes all entities for the given parent_id.
+
+ Args:
+ parent_id (UUID): Parent ID (collection_id or document_id)
+ entity_ids (Optional[list[UUID]]): Specific entity IDs to delete. If None, deletes all entities for parent_id
+ store_type (StoreType): Type of store (graph or document)
+
+
+ R2RException: If specific entities were requested but not all found
+ """
+ table_name = self._get_entity_table_for_store(store_type)
+
+ if entity_ids is None:
+ # Delete all entities for the parent_id
+ QUERY = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE parent_id = $1
+ RETURNING id
+ """
+ results = await self.connection_manager.fetch_query(
+ QUERY, [parent_id]
+ )
+ else:
+ # Delete specific entities
+ QUERY = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE id = ANY($1) AND parent_id = $2
+ RETURNING id
+ """
+
+ results = await self.connection_manager.fetch_query(
+ QUERY, [entity_ids, parent_id]
+ )
+
+ # Check if all requested entities were deleted
+ deleted_ids = [row["id"] for row in results]
+ if entity_ids and len(deleted_ids) != len(entity_ids):
+ raise R2RException(
+ f"Some entities not found in {store_type} store or no permission to delete",
+ 404,
+ )
+
+ async def get_duplicate_name_blocks(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ ) -> list[list[Entity]]:
+ """Find all groups of entities that share identical names within the
+ same parent.
+
+ Returns a list of entity groups, where each group contains all
+ entities that share the same name under the given parent.
+ """
+ table_name = self._get_entity_table_for_store(store_type)
+
+ # Fetch every entity whose name is duplicated under this parent
+ query = f"""
+ WITH duplicates AS (
+ SELECT name
+ FROM {self._get_table_name(table_name)}
+ WHERE parent_id = $1
+ GROUP BY name
+ HAVING COUNT(*) > 1
+ )
+ SELECT
+ e.id, e.name, e.category, e.description,
+ e.parent_id, e.chunk_ids, e.metadata
+ FROM {self._get_table_name(table_name)} e
+ WHERE e.parent_id = $1
+ AND e.name IN (SELECT name FROM duplicates)
+ ORDER BY e.name;
+ """
+
+ rows = await self.connection_manager.fetch_query(query, [parent_id])
+
+ # Group entities by name
+ name_groups: dict[str, list[Entity]] = {}
+ for row in rows:
+ entity_dict = dict(row)
+ if isinstance(entity_dict["metadata"], str):
+ with contextlib.suppress(json.JSONDecodeError):
+ entity_dict["metadata"] = json.loads(
+ entity_dict["metadata"]
+ )
+
+ entity = Entity(**entity_dict)
+ name_groups.setdefault(entity.name, []).append(entity)
+
+ return list(name_groups.values())
+
+ async def merge_duplicate_name_blocks(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ ) -> list[tuple[list[Entity], Entity]]:
+ """Merge entities that share identical names.
+
+ Returns list of tuples: (original_entities, merged_entity)
+ """
+ duplicate_blocks = await self.get_duplicate_name_blocks(
+ parent_id, store_type
+ )
+ merged_results: list[tuple[list[Entity], Entity]] = []
+
+ for block in duplicate_blocks:
+ # Create a new merged entity from the block
+ merged_entity = await self._create_merged_entity(block)
+ merged_results.append((block, merged_entity))
+
+ table_name = self._get_entity_table_for_store(store_type)
+ async with self.connection_manager.transaction():
+ # Insert the merged entity
+ new_id = await self._insert_merged_entity(
+ merged_entity, table_name
+ )
+
+ merged_entity.id = new_id
+
+ # Get the old entity IDs
+ old_ids = [str(entity.id) for entity in block]
+
+ relationship_table = self.relationships_handler._get_relationship_table_for_store(
+ store_type
+ )
+
+ # Update relationships where old entities appear as subjects
+ subject_update_query = f"""
+ UPDATE {self._get_table_name(relationship_table)}
+ SET subject_id = $1
+ WHERE subject_id = ANY($2::uuid[])
+ AND parent_id = $3
+ """
+ await self.connection_manager.execute_query(
+ subject_update_query, [new_id, old_ids, parent_id]
+ )
+
+ # Update relationships where old entities appear as objects
+ object_update_query = f"""
+ UPDATE {self._get_table_name(relationship_table)}
+ SET object_id = $1
+ WHERE object_id = ANY($2::uuid[])
+ AND parent_id = $3
+ """
+ await self.connection_manager.execute_query(
+ object_update_query, [new_id, old_ids, parent_id]
+ )
+
+ # Delete the original entities
+ delete_query = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE id = ANY($1::uuid[])
+ """
+ await self.connection_manager.execute_query(
+ delete_query, [old_ids]
+ )
+
+ return merged_results
+
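+    # Illustrative usage sketch (assumed caller-side names): deduplicating a
+    # graph's entities by name and inspecting the result:
+    #
+    #   merges = await entities_handler.merge_duplicate_name_blocks(
+    #       parent_id=graph_id,
+    #       store_type=StoreType.GRAPHS,
+    #   )
+    #   for originals, merged in merges:
+    #       print(f"{merged.name}: merged {len(originals)} duplicates")
+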
+ async def _insert_merged_entity(
+ self, entity: Entity, table_name: str
+ ) -> UUID:
+ """Insert merged entity and return its new ID."""
+ new_id = generate_entity_document_id()
+
+ query = f"""
+ INSERT INTO {self._get_table_name(table_name)}
+ (id, name, category, description, parent_id, chunk_ids, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
+ RETURNING id
+ """
+
+ values = [
+ new_id,
+ entity.name,
+ entity.category,
+ entity.description,
+ entity.parent_id,
+ entity.chunk_ids,
+ json.dumps(entity.metadata) if entity.metadata else None,
+ ]
+
+ result = await self.connection_manager.fetch_query(query, values)
+ return result[0]["id"]
+
+ async def _create_merged_entity(self, entities: list[Entity]) -> Entity:
+ """Create a merged entity from a list of duplicate entities.
+
+ Uses various strategies to combine fields.
+ """
+ if not entities:
+ raise ValueError("Cannot merge empty list of entities")
+
+ # Take the first non-None category, or None if all are None
+ category = next(
+ (e.category for e in entities if e.category is not None), None
+ )
+
+ # Combine descriptions with newlines if they differ
+ descriptions = {e.description for e in entities if e.description}
+ description = "\n\n".join(descriptions) if descriptions else None
+
+ # Combine chunk_ids, removing duplicates
+ chunk_ids = list(
+ {
+ chunk_id
+ for entity in entities
+ for chunk_id in (entity.chunk_ids or [])
+ }
+ )
+
+ # Merge metadata dictionaries
+ merged_metadata: dict[str, Any] = {}
+ for entity in entities:
+ if entity.metadata:
+ merged_metadata |= entity.metadata
+
+ # Create new merged entity (without actually inserting to DB)
+ return Entity(
+ id=UUID(
+ "00000000-0000-0000-0000-000000000000"
+ ), # Placeholder UUID
+ name=entities[0].name, # All entities in block have same name
+ category=category,
+ description=description,
+ parent_id=entities[0].parent_id,
+ chunk_ids=chunk_ids or None,
+ metadata=merged_metadata or None,
+ )
+
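+    # Merge-strategy sketch (hypothetical values; field names assume the Entity
+    # model accepts them as keyword arguments): two duplicates named "ACME",
+    # only one carrying a category, would combine roughly as follows.
+    #
+    #   a = Entity(name="ACME", category="ORG", description="A company",
+    #              parent_id=graph_id, metadata={"source": "doc1"})
+    #   b = Entity(name="ACME", category=None, description="Maker of widgets",
+    #              parent_id=graph_id, metadata={"source": "doc2"})
+    #   merged = await entities_handler._create_merged_entity([a, b])
+    #   # merged.category == "ORG"; merged.description joins both texts with a
+    #   # blank line; merged.metadata == {"source": "doc2"} (later keys win).
+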
+ async def export_to_csv(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ """Creates a CSV file from the PostgreSQL data and returns the path to
+ the temp file."""
+ valid_columns = {
+ "id",
+ "name",
+ "category",
+ "description",
+ "parent_id",
+ "chunk_ids",
+ "metadata",
+ "created_at",
+ "updated_at",
+ }
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ name,
+ category,
+ description,
+ parent_id::text,
+ chunk_ids::text,
+ metadata::text,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+ FROM {self._get_table_name(self._get_entity_table_for_store(store_type))}
+ """
+
+ conditions = ["parent_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "name": row[1],
+ "category": row[2],
+ "description": row[3],
+ "parent_id": row[4],
+ "chunk_ids": row[5],
+ "metadata": row[6],
+ "created_at": row[7],
+ "updated_at": row[8],
+ }
+ writer.writerow([row_dict[col] for col in columns])
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
+
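+    # Illustrative usage sketch (assumed caller code): exporting a subset of a
+    # graph's entity columns to CSV and closing the temp file when done:
+    #
+    #   path, handle = await entities_handler.export_to_csv(
+    #       parent_id=graph_id,
+    #       store_type=StoreType.GRAPHS,
+    #       columns=["id", "name", "category", "created_at"],
+    #       filters={"category": {"$eq": "ORG"}},
+    #   )
+    #   try:
+    #       ...  # stream or copy the file at `path`
+    #   finally:
+    #       handle.close()  # NamedTemporaryFile(delete=True) removes the file
+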
+
+class PostgresRelationshipsHandler(Handler):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ self.project_name: str = kwargs.get("project_name") # type: ignore
+ self.connection_manager: PostgresConnectionManager = kwargs.get(
+ "connection_manager"
+ ) # type: ignore
+ self.dimension: int = kwargs.get("dimension") # type: ignore
+ self.quantization_type: VectorQuantizationType = kwargs.get(
+ "quantization_type"
+ ) # type: ignore
+
+ def _get_table_name(self, table: str) -> str:
+ """Get the fully qualified table name."""
+ return f'"{self.project_name}"."{table}"'
+
+ def _get_relationship_table_for_store(self, store_type: StoreType) -> str:
+ """Get the appropriate table name for the store type."""
+ return f"{store_type.value}_relationships"
+
+ def _get_parent_constraint(self, store_type: StoreType) -> str:
+ """Get the appropriate foreign key constraint for the store type."""
+ if store_type == StoreType.GRAPHS:
+ return f"""
+ CONSTRAINT fk_graph
+ FOREIGN KEY(parent_id)
+ REFERENCES {self._get_table_name("graphs")}(id)
+ ON DELETE CASCADE
+ """
+ else:
+ return f"""
+ CONSTRAINT fk_document
+ FOREIGN KEY(parent_id)
+ REFERENCES {self._get_table_name("documents")}(id)
+ ON DELETE CASCADE
+ """
+
+ async def create_tables(self) -> None:
+ """Create separate tables for graph and document relationships."""
+ for store_type in StoreType:
+ table_name = self._get_relationship_table_for_store(store_type)
+ parent_constraint = self._get_parent_constraint(store_type)
+ vector_column_str = _get_vector_column_str(
+ self.dimension, self.quantization_type
+ )
+
+ QUERY = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ subject TEXT NOT NULL,
+ predicate TEXT NOT NULL,
+ object TEXT NOT NULL,
+ description TEXT,
+ description_embedding {vector_column_str},
+ subject_id UUID,
+ object_id UUID,
+ weight FLOAT DEFAULT 1.0,
+ chunk_ids UUID[],
+ parent_id UUID NOT NULL,
+ metadata JSONB,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ {parent_constraint}
+ );
+
+ CREATE INDEX IF NOT EXISTS {table_name}_subject_idx
+ ON {self._get_table_name(table_name)} (subject);
+ CREATE INDEX IF NOT EXISTS {table_name}_object_idx
+ ON {self._get_table_name(table_name)} (object);
+ CREATE INDEX IF NOT EXISTS {table_name}_predicate_idx
+ ON {self._get_table_name(table_name)} (predicate);
+ CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
+ ON {self._get_table_name(table_name)} (parent_id);
+ CREATE INDEX IF NOT EXISTS {table_name}_subject_id_idx
+ ON {self._get_table_name(table_name)} (subject_id);
+ CREATE INDEX IF NOT EXISTS {table_name}_object_id_idx
+ ON {self._get_table_name(table_name)} (object_id);
+ """
+ await self.connection_manager.execute_query(QUERY)
+
+ async def create(
+ self,
+ subject: str,
+ subject_id: UUID,
+ predicate: str,
+ object: str,
+ object_id: UUID,
+ parent_id: UUID,
+ store_type: StoreType,
+ description: str | None = None,
+ weight: float | None = 1.0,
+ chunk_ids: Optional[list[UUID]] = None,
+ description_embedding: Optional[list[float] | str] = None,
+ metadata: Optional[dict[str, Any] | str] = None,
+ ) -> Relationship:
+ """Create a new relationship in the specified store."""
+ table_name = self._get_relationship_table_for_store(store_type)
+
+ if isinstance(metadata, str):
+ with contextlib.suppress(json.JSONDecodeError):
+ metadata = json.loads(metadata)
+
+ if isinstance(description_embedding, list):
+ description_embedding = str(description_embedding)
+
+ query = f"""
+ INSERT INTO {self._get_table_name(table_name)}
+ (subject, predicate, object, description, subject_id, object_id,
+ weight, chunk_ids, parent_id, description_embedding, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
+ RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
+ """
+
+ params = [
+ subject,
+ predicate,
+ object,
+ description,
+ subject_id,
+ object_id,
+ weight,
+ chunk_ids,
+ parent_id,
+ description_embedding,
+ json.dumps(metadata) if metadata else None,
+ ]
+
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+
+ return Relationship(
+ id=result["id"],
+ subject=result["subject"],
+ predicate=result["predicate"],
+ object=result["object"],
+ description=result["description"],
+ subject_id=result["subject_id"],
+ object_id=result["object_id"],
+ weight=result["weight"],
+ chunk_ids=result["chunk_ids"],
+ parent_id=result["parent_id"],
+ metadata=result["metadata"],
+ )
+
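+    # Illustrative usage sketch (hypothetical IDs and values): recording a
+    # "works_for" edge between two previously created graph entities:
+    #
+    #   rel = await relationships_handler.create(
+    #       subject="Alice",
+    #       subject_id=alice_entity_id,
+    #       predicate="works_for",
+    #       object="ACME",
+    #       object_id=acme_entity_id,
+    #       parent_id=graph_id,
+    #       store_type=StoreType.GRAPHS,
+    #       description="Alice is employed by ACME",
+    #       weight=0.9,
+    #   )
+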
+ async def get(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ offset: int,
+ limit: int,
+ relationship_ids: Optional[list[UUID]] = None,
+ entity_names: Optional[list[str]] = None,
+ relationship_types: Optional[list[str]] = None,
+ include_metadata: bool = False,
+ ):
+ """Get relationships from the specified store.
+
+ Args:
+ parent_id: UUID of the parent (collection_id or document_id)
+ store_type: Type of store (graph or document)
+ offset: Number of records to skip
+ limit: Maximum number of records to return (-1 for no limit)
+ relationship_ids: Optional list of specific relationship IDs to retrieve
+ entity_names: Optional list of entity names to filter by (matches subject or object)
+ relationship_types: Optional list of relationship types (predicates) to filter by
+ include_metadata: Whether to include metadata in the response
+
+ Returns:
+ Tuple of (list of relationships, total count)
+ """
+ table_name = self._get_relationship_table_for_store(store_type)
+
+ conditions = ["parent_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if relationship_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(relationship_ids)
+ param_index += 1
+
+ if entity_names:
+ conditions.append(
+ f"(subject = ANY(${param_index}) OR object = ANY(${param_index}))"
+ )
+ params.append(entity_names)
+ param_index += 1
+
+ if relationship_types:
+ conditions.append(f"predicate = ANY(${param_index})")
+ params.append(relationship_types)
+ param_index += 1
+
+ select_fields = """
+ id, subject, predicate, object, description,
+ subject_id, object_id, weight, chunk_ids,
+ parent_id
+ """
+ if include_metadata:
+ select_fields += ", metadata"
+
+ # Count query
+ COUNT_QUERY = f"""
+ SELECT COUNT(*)
+ FROM {self._get_table_name(table_name)}
+ WHERE {" AND ".join(conditions)}
+ """
+ count_params = params[: param_index - 1]
+ count = (
+ await self.connection_manager.fetch_query(
+ COUNT_QUERY, count_params
+ )
+ )[0]["count"]
+
+ # Main query
+ QUERY = f"""
+ SELECT {select_fields}
+ FROM {self._get_table_name(table_name)}
+ WHERE {" AND ".join(conditions)}
+ ORDER BY created_at
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ QUERY += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ rows = await self.connection_manager.fetch_query(QUERY, params)
+
+ relationships = []
+ for row in rows:
+ relationship_dict = dict(row)
+ if include_metadata and isinstance(
+ relationship_dict["metadata"], str
+ ):
+ with contextlib.suppress(json.JSONDecodeError):
+ relationship_dict["metadata"] = json.loads(
+ relationship_dict["metadata"]
+ )
+ elif not include_metadata:
+ relationship_dict.pop("metadata", None)
+ relationships.append(Relationship(**relationship_dict))
+
+ return relationships, count
+
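+    # Illustrative usage sketch: paging through a graph's relationships 100 at
+    # a time (variable names below are assumed caller-side names):
+    #
+    #   offset = 0
+    #   while True:
+    #       rels, total = await relationships_handler.get(
+    #           parent_id=graph_id,
+    #           store_type=StoreType.GRAPHS,
+    #           offset=offset,
+    #           limit=100,
+    #       )
+    #       if not rels:
+    #           break
+    #       offset += len(rels)
+    #       if offset >= total:
+    #           break
+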
+ async def update(
+ self,
+ relationship_id: UUID,
+ store_type: StoreType,
+ subject: Optional[str],
+ subject_id: Optional[UUID],
+ predicate: Optional[str],
+ object: Optional[str],
+ object_id: Optional[UUID],
+ description: Optional[str],
+ description_embedding: Optional[list[float] | str],
+ weight: Optional[float],
+ metadata: Optional[dict[str, Any] | str],
+ ) -> Relationship:
+ """Update multiple relationships in the specified store."""
+ table_name = self._get_relationship_table_for_store(store_type)
+ update_fields = []
+ params: list = []
+ param_index = 1
+
+ if isinstance(metadata, str):
+ with contextlib.suppress(json.JSONDecodeError):
+ metadata = json.loads(metadata)
+
+ if subject is not None:
+ update_fields.append(f"subject = ${param_index}")
+ params.append(subject)
+ param_index += 1
+
+ if subject_id is not None:
+ update_fields.append(f"subject_id = ${param_index}")
+ params.append(subject_id)
+ param_index += 1
+
+ if predicate is not None:
+ update_fields.append(f"predicate = ${param_index}")
+ params.append(predicate)
+ param_index += 1
+
+ if object is not None:
+ update_fields.append(f"object = ${param_index}")
+ params.append(object)
+ param_index += 1
+
+ if object_id is not None:
+ update_fields.append(f"object_id = ${param_index}")
+ params.append(object_id)
+ param_index += 1
+
+ if description is not None:
+ update_fields.append(f"description = ${param_index}")
+ params.append(description)
+ param_index += 1
+
+ if description_embedding is not None:
+ update_fields.append(f"description_embedding = ${param_index}")
+ params.append(description_embedding)
+ param_index += 1
+
+ if weight is not None:
+ update_fields.append(f"weight = ${param_index}")
+ params.append(weight)
+ param_index += 1
+
+ if not update_fields:
+ raise R2RException(status_code=400, message="No fields to update")
+
+ update_fields.append("updated_at = NOW()")
+ params.append(relationship_id)
+
+ query = f"""
+ UPDATE {self._get_table_name(table_name)}
+ SET {", ".join(update_fields)}
+ WHERE id = ${param_index}
+ RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
+ """
+
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+
+ return Relationship(
+ id=result["id"],
+ subject=result["subject"],
+ predicate=result["predicate"],
+ object=result["object"],
+ description=result["description"],
+ subject_id=result["subject_id"],
+ object_id=result["object_id"],
+ weight=result["weight"],
+ chunk_ids=result["chunk_ids"],
+ parent_id=result["parent_id"],
+ metadata=result["metadata"],
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while updating the relationship: {e}",
+ ) from e
+
+ async def delete(
+ self,
+ parent_id: UUID,
+ relationship_ids: Optional[list[UUID]] = None,
+ store_type: StoreType = StoreType.GRAPHS,
+ ) -> None:
+ """Delete relationships from the specified store. If relationship_ids
+ is not provided, deletes all relationships for the given parent_id.
+
+ Args:
+ parent_id: UUID of the parent (collection_id or document_id)
+ relationship_ids: Optional list of specific relationship IDs to delete
+ store_type: Type of store (graph or document)
+
+ Returns:
+            None
+
+ Raises:
+ R2RException: If specific relationships were requested but not all found
+ """
+ table_name = self._get_relationship_table_for_store(store_type)
+
+ if relationship_ids is None:
+ QUERY = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE parent_id = $1
+ RETURNING id
+ """
+ results = await self.connection_manager.fetch_query(
+ QUERY, [parent_id]
+ )
+ else:
+ QUERY = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE id = ANY($1) AND parent_id = $2
+ RETURNING id
+ """
+ results = await self.connection_manager.fetch_query(
+ QUERY, [relationship_ids, parent_id]
+ )
+
+ deleted_ids = [row["id"] for row in results]
+ if relationship_ids and len(deleted_ids) != len(relationship_ids):
+ raise R2RException(
+ f"Some relationships not found in {store_type} store or no permission to delete",
+ 404,
+ )
+
+ async def export_to_csv(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ """Creates a CSV file from the PostgreSQL data and returns the path to
+ the temp file."""
+ valid_columns = {
+ "id",
+ "subject",
+ "predicate",
+ "object",
+ "description",
+ "subject_id",
+ "object_id",
+ "weight",
+ "chunk_ids",
+ "parent_id",
+ "metadata",
+ "created_at",
+ "updated_at",
+ }
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ subject,
+ predicate,
+ object,
+ description,
+ subject_id::text,
+ object_id::text,
+ weight,
+ chunk_ids::text,
+ parent_id::text,
+ metadata::text,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+ FROM {self._get_table_name(self._get_relationship_table_for_store(store_type))}
+ """
+
+ conditions = ["parent_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+                    # Map the fixed SELECT order onto column names so the rows
+                    # written match the requested (header) columns.
+                    export_order = [
+                        "id", "subject", "predicate", "object", "description",
+                        "subject_id", "object_id", "weight", "chunk_ids",
+                        "parent_id", "metadata", "created_at", "updated_at",
+                    ]
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            row_dict = dict(zip(export_order, row))
+                            writer.writerow(
+                                [row_dict[col] for col in columns]
+                            )
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
+
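+    # Filter sketch for export_to_csv (values are hypothetical): each filter
+    # key must be one of valid_columns; a dict value selects an operator
+    # ($eq, $gt, $lt), while a bare value means plain equality.
+    #
+    #   path, handle = await relationships_handler.export_to_csv(
+    #       parent_id=graph_id,
+    #       store_type=StoreType.GRAPHS,
+    #       filters={
+    #           "predicate": {"$eq": "works_for"},  # predicate = $n
+    #           "weight": {"$gt": 0.5},             # weight > $n
+    #           "subject": "Alice",                 # subject = $n
+    #       },
+    #   )
+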
+
+class PostgresCommunitiesHandler(Handler):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ self.project_name: str = kwargs.get("project_name") # type: ignore
+ self.connection_manager: PostgresConnectionManager = kwargs.get(
+ "connection_manager"
+ ) # type: ignore
+ self.dimension: int = kwargs.get("dimension") # type: ignore
+ self.quantization_type: VectorQuantizationType = kwargs.get(
+ "quantization_type"
+ ) # type: ignore
+
+ async def create_tables(self) -> None:
+ vector_column_str = _get_vector_column_str(
+ self.dimension, self.quantization_type
+ )
+
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name("graphs_communities")} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ collection_id UUID,
+ community_id UUID,
+ level INT,
+ name TEXT NOT NULL,
+ summary TEXT NOT NULL,
+ findings TEXT[],
+ rating FLOAT,
+ rating_explanation TEXT,
+ description_embedding {vector_column_str} NOT NULL,
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+ metadata JSONB,
+ UNIQUE (community_id, level, collection_id)
+ );"""
+
+ await self.connection_manager.execute_query(query)
+
+ async def create(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ name: str,
+ summary: str,
+ findings: Optional[list[str]],
+ rating: Optional[float],
+ rating_explanation: Optional[str],
+ description_embedding: Optional[list[float] | str] = None,
+ ) -> Community:
+ table_name = "graphs_communities"
+
+ if isinstance(description_embedding, list):
+ description_embedding = str(description_embedding)
+
+ query = f"""
+ INSERT INTO {self._get_table_name(table_name)}
+ (collection_id, name, summary, findings, rating, rating_explanation, description_embedding)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
+ RETURNING id, collection_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
+ """
+
+ params = [
+ parent_id,
+ name,
+ summary,
+ findings,
+ rating,
+ rating_explanation,
+ description_embedding,
+ ]
+
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+
+ return Community(
+ id=result["id"],
+ collection_id=result["collection_id"],
+ name=result["name"],
+ summary=result["summary"],
+ findings=result["findings"],
+ rating=result["rating"],
+ rating_explanation=result["rating_explanation"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while creating the community: {e}",
+ ) from e
+
+ async def update(
+ self,
+ community_id: UUID,
+ store_type: StoreType,
+ name: Optional[str] = None,
+ summary: Optional[str] = None,
+ summary_embedding: Optional[list[float] | str] = None,
+ findings: Optional[list[str]] = None,
+ rating: Optional[float] = None,
+ rating_explanation: Optional[str] = None,
+ ) -> Community:
+ table_name = "graphs_communities"
+ update_fields = []
+ params: list[Any] = []
+ param_index = 1
+
+ if name is not None:
+ update_fields.append(f"name = ${param_index}")
+ params.append(name)
+ param_index += 1
+
+ if summary is not None:
+ update_fields.append(f"summary = ${param_index}")
+ params.append(summary)
+ param_index += 1
+
+ if summary_embedding is not None:
+ update_fields.append(f"description_embedding = ${param_index}")
+ params.append(summary_embedding)
+ param_index += 1
+
+ if findings is not None:
+ update_fields.append(f"findings = ${param_index}")
+ params.append(findings)
+ param_index += 1
+
+ if rating is not None:
+ update_fields.append(f"rating = ${param_index}")
+ params.append(rating)
+ param_index += 1
+
+ if rating_explanation is not None:
+ update_fields.append(f"rating_explanation = ${param_index}")
+ params.append(rating_explanation)
+ param_index += 1
+
+ if not update_fields:
+ raise R2RException(status_code=400, message="No fields to update")
+
+ update_fields.append("updated_at = NOW()")
+ params.append(community_id)
+
+ query = f"""
+ UPDATE {self._get_table_name(table_name)}
+ SET {", ".join(update_fields)}
+            WHERE id = ${param_index}
+            RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
+ """
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query, params
+ )
+
+ return Community(
+ id=result["id"],
+ community_id=result["community_id"],
+ name=result["name"],
+ summary=result["summary"],
+ findings=result["findings"],
+ rating=result["rating"],
+ rating_explanation=result["rating_explanation"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while updating the community: {e}",
+ ) from e
+
+ async def delete(
+ self,
+ parent_id: UUID,
+ community_id: UUID,
+ ) -> None:
+ table_name = "graphs_communities"
+
+ params = [community_id, parent_id]
+
+ # Delete the community
+ query = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE id = $1 AND collection_id = $2
+ """
+
+ try:
+ await self.connection_manager.execute_query(query, params)
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while deleting the community: {e}",
+ ) from e
+
+ async def delete_all_communities(
+ self,
+ parent_id: UUID,
+ ) -> None:
+ table_name = "graphs_communities"
+
+ params = [parent_id]
+
+ # Delete all communities for the parent_id
+ query = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE collection_id = $1
+ """
+
+ try:
+ await self.connection_manager.execute_query(query, params)
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while deleting communities: {e}",
+ ) from e
+
+ async def get(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ offset: int,
+ limit: int,
+ community_ids: Optional[list[UUID]] = None,
+ community_names: Optional[list[str]] = None,
+ include_embeddings: bool = False,
+ ):
+ """Retrieve communities from the specified store."""
+ # Do we ever want to get communities from document store?
+ table_name = "graphs_communities"
+
+ conditions = ["collection_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if community_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(community_ids)
+ param_index += 1
+
+ if community_names:
+ conditions.append(f"name = ANY(${param_index})")
+ params.append(community_names)
+ param_index += 1
+
+ select_fields = """
+ id, community_id, name, summary, findings, rating,
+ rating_explanation, level, created_at, updated_at
+ """
+ if include_embeddings:
+ select_fields += ", description_embedding"
+
+ COUNT_QUERY = f"""
+ SELECT COUNT(*)
+ FROM {self._get_table_name(table_name)}
+ WHERE {" AND ".join(conditions)}
+ """
+
+ count = (
+ await self.connection_manager.fetch_query(
+ COUNT_QUERY, params[: param_index - 1]
+ )
+ )[0]["count"]
+
+ QUERY = f"""
+ SELECT {select_fields}
+ FROM {self._get_table_name(table_name)}
+ WHERE {" AND ".join(conditions)}
+ ORDER BY created_at
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ QUERY += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ rows = await self.connection_manager.fetch_query(QUERY, params)
+
+ communities = []
+ for row in rows:
+ community_dict = dict(row)
+
+ communities.append(Community(**community_dict))
+
+ return communities, count
+
+ async def export_to_csv(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ """Creates a CSV file from the PostgreSQL data and returns the path to
+ the temp file."""
+ valid_columns = {
+ "id",
+ "collection_id",
+ "community_id",
+ "level",
+ "name",
+ "summary",
+ "findings",
+ "rating",
+ "rating_explanation",
+ "created_at",
+ "updated_at",
+ "metadata",
+ }
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ table_name = "graphs_communities"
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ collection_id::text,
+ community_id::text,
+ level,
+ name,
+ summary,
+ findings::text,
+ rating,
+ rating_explanation,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
+ metadata::text
+ FROM {self._get_table_name(table_name)}
+ """
+
+ conditions = ["collection_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+                    # Map the fixed SELECT order onto column names so the rows
+                    # written match the requested (header) columns.
+                    export_order = [
+                        "id", "collection_id", "community_id", "level",
+                        "name", "summary", "findings", "rating",
+                        "rating_explanation", "created_at", "updated_at",
+                        "metadata",
+                    ]
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            row_dict = dict(zip(export_order, row))
+                            writer.writerow(
+                                [row_dict[col] for col in columns]
+                            )
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
+
+
+class PostgresGraphsHandler(Handler):
+ """Handler for Knowledge Graph METHODS in PostgreSQL."""
+
+ TABLE_NAME = "graphs"
+
+ def __init__(
+ self,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
+ self.project_name: str = kwargs.get("project_name") # type: ignore
+ self.connection_manager: PostgresConnectionManager = kwargs.get(
+ "connection_manager"
+ ) # type: ignore
+ self.dimension: int = kwargs.get("dimension") # type: ignore
+ self.quantization_type: VectorQuantizationType = kwargs.get(
+ "quantization_type"
+ ) # type: ignore
+ self.collections_handler: PostgresCollectionsHandler = kwargs.get(
+ "collections_handler"
+ ) # type: ignore
+
+ self.entities = PostgresEntitiesHandler(*args, **kwargs)
+ self.relationships = PostgresRelationshipsHandler(*args, **kwargs)
+ self.communities = PostgresCommunitiesHandler(*args, **kwargs)
+
+ self.handlers = [
+ self.entities,
+ self.relationships,
+ self.communities,
+ ]
+
+ async def create_tables(self) -> None:
+ """Create the graph tables with mandatory collection_id support."""
+ QUERY = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ collection_id UUID NOT NULL,
+ name TEXT NOT NULL,
+ description TEXT,
+ status TEXT NOT NULL,
+ document_ids UUID[],
+ metadata JSONB,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW()
+ );
+
+ CREATE INDEX IF NOT EXISTS graph_collection_id_idx
+ ON {self._get_table_name("graphs")} (collection_id);
+ """
+
+ await self.connection_manager.execute_query(QUERY)
+
+ for handler in self.handlers:
+ await handler.create_tables()
+
+ async def create(
+ self,
+ collection_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ status: str = "pending",
+ ) -> GraphResponse:
+ """Create a new graph associated with a collection."""
+
+ name = name or f"Graph {collection_id}"
+ description = description or ""
+
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ (id, collection_id, name, description, status)
+ VALUES ($1, $2, $3, $4, $5)
+ RETURNING id, collection_id, name, description, status, created_at, updated_at, document_ids
+ """
+ params = [
+ collection_id,
+ collection_id,
+ name,
+ description,
+ status,
+ ]
+
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+
+ return GraphResponse(
+ id=result["id"],
+ collection_id=result["collection_id"],
+ name=result["name"],
+ description=result["description"],
+ status=result["status"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ document_ids=result["document_ids"] or [],
+ )
+ except UniqueViolationError:
+ raise R2RException(
+ message="Graph with this ID already exists",
+ status_code=409,
+ ) from None
+
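+    # Illustrative usage sketch (hypothetical collection UUID): creating the
+    # graph record, which reuses the collection's id as its own:
+    #
+    #   graph = await graphs_handler.create(
+    #       collection_id=collection_id,
+    #       name="Research corpus graph",
+    #       description="Graph built from the research collection",
+    #   )
+    #   assert graph.id == collection_id
+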
+ async def reset(self, parent_id: UUID) -> None:
+ """Completely reset a graph and all associated data."""
+
+ await self.entities.delete(
+ parent_id=parent_id, store_type=StoreType.GRAPHS
+ )
+ await self.relationships.delete(
+ parent_id=parent_id, store_type=StoreType.GRAPHS
+ )
+ await self.communities.delete_all_communities(parent_id=parent_id)
+
+ # Now, update the graph record to remove any attached document IDs.
+ # This sets document_ids to an empty UUID array.
+ query = f"""
+ UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ SET document_ids = ARRAY[]::uuid[]
+ WHERE id = $1;
+ """
+ await self.connection_manager.execute_query(query, [parent_id])
+
+ async def list_graphs(
+ self,
+ offset: int,
+ limit: int,
+ # filter_user_ids: Optional[list[UUID]] = None,
+ filter_graph_ids: Optional[list[UUID]] = None,
+ filter_collection_id: Optional[UUID] = None,
+ ) -> dict[str, list[GraphResponse] | int]:
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ if filter_graph_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(filter_graph_ids)
+ param_index += 1
+
+ # if filter_user_ids:
+ # conditions.append(f"user_id = ANY(${param_index})")
+ # params.append(filter_user_ids)
+ # param_index += 1
+
+ if filter_collection_id:
+ conditions.append(f"collection_id = ${param_index}")
+ params.append(filter_collection_id)
+ param_index += 1
+
+ where_clause = (
+ f"WHERE {' AND '.join(conditions)}" if conditions else ""
+ )
+
+ query = f"""
+ WITH RankedGraphs AS (
+ SELECT
+ id, collection_id, name, description, status, created_at, updated_at, document_ids,
+ COUNT(*) OVER() as total_entries,
+ ROW_NUMBER() OVER (PARTITION BY collection_id ORDER BY created_at DESC) as rn
+ FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ {where_clause}
+ )
+ SELECT * FROM RankedGraphs
+ WHERE rn = 1
+ ORDER BY created_at DESC
+ OFFSET ${param_index} LIMIT ${param_index + 1}
+ """
+
+ params.extend([offset, limit])
+
+ try:
+ results = await self.connection_manager.fetch_query(query, params)
+ if not results:
+ return {"results": [], "total_entries": 0}
+
+ total_entries = results[0]["total_entries"] if results else 0
+
+ graphs = [
+ GraphResponse(
+ id=row["id"],
+ document_ids=row["document_ids"] or [],
+ name=row["name"],
+ collection_id=row["collection_id"],
+ description=row["description"],
+ status=row["status"],
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ )
+ for row in results
+ ]
+
+ return {"results": graphs, "total_entries": total_entries}
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while fetching graphs: {e}",
+ ) from e
+
+ async def get(
+ self, offset: int, limit: int, graph_id: Optional[UUID] = None
+ ):
+ if graph_id is None:
+ params = [offset, limit]
+
+ QUERY = f"""
+ SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ OFFSET $1 LIMIT $2
+ """
+
+ ret = await self.connection_manager.fetch_query(QUERY, params)
+
+ COUNT_QUERY = f"""
+ SELECT COUNT(*) FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ """
+ count = (await self.connection_manager.fetch_query(COUNT_QUERY))[
+ 0
+ ]["count"]
+
+ return {
+ "results": [Graph(**row) for row in ret],
+ "total_entries": count,
+ }
+
+ else:
+ QUERY = f"""
+ SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} WHERE id = $1
+ """
+
+ params = [graph_id] # type: ignore
+
+ return {
+ "results": [
+ Graph(
+ **await self.connection_manager.fetchrow_query(
+ QUERY, params
+ )
+ )
+ ]
+ }
+
+ async def add_documents(self, id: UUID, document_ids: list[UUID]) -> bool:
+ """Add documents to the graph by copying their entities and
+ relationships."""
+ # Copy entities from document_entity to graphs_entities
+ ENTITY_COPY_QUERY = f"""
+ INSERT INTO {self._get_table_name("graphs_entities")} (
+ name, category, description, parent_id, description_embedding,
+ chunk_ids, metadata
+ )
+ SELECT
+ name, category, description, $1, description_embedding,
+ chunk_ids, metadata
+ FROM {self._get_table_name("documents_entities")}
+ WHERE parent_id = ANY($2)
+ """
+ await self.connection_manager.execute_query(
+ ENTITY_COPY_QUERY, [id, document_ids]
+ )
+
+ # Copy relationships from documents_relationships to graphs_relationships
+ RELATIONSHIP_COPY_QUERY = f"""
+ INSERT INTO {self._get_table_name("graphs_relationships")} (
+ subject, predicate, object, description, subject_id, object_id,
+ weight, chunk_ids, parent_id, metadata, description_embedding
+ )
+ SELECT
+ subject, predicate, object, description, subject_id, object_id,
+ weight, chunk_ids, $1, metadata, description_embedding
+ FROM {self._get_table_name("documents_relationships")}
+ WHERE parent_id = ANY($2)
+ """
+ await self.connection_manager.execute_query(
+ RELATIONSHIP_COPY_QUERY, [id, document_ids]
+ )
+
+ # Add document_ids to the graph
+ UPDATE_GRAPH_QUERY = f"""
+ UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ SET document_ids = array_cat(
+ CASE
+ WHEN document_ids IS NULL THEN ARRAY[]::uuid[]
+ ELSE document_ids
+ END,
+ $2::uuid[]
+ )
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(
+ UPDATE_GRAPH_QUERY, [id, document_ids]
+ )
+
+ return True
+
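+    # Illustrative usage sketch (hypothetical UUIDs): pulling two ingested
+    # documents' extracted entities and relationships into a graph:
+    #
+    #   await graphs_handler.add_documents(
+    #       id=graph_id,
+    #       document_ids=[doc_id_1, doc_id_2],
+    #   )
+    #
+    # This copies rows from documents_entities / documents_relationships into
+    # graphs_entities / graphs_relationships and appends the document ids to
+    # the graph's document_ids array.
+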
+ async def update(
+ self,
+ collection_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> GraphResponse:
+ """Update an existing graph."""
+ update_fields = []
+ params: list = []
+ param_index = 1
+
+ if name is not None:
+ update_fields.append(f"name = ${param_index}")
+ params.append(name)
+ param_index += 1
+
+ if description is not None:
+ update_fields.append(f"description = ${param_index}")
+ params.append(description)
+ param_index += 1
+
+ if not update_fields:
+ raise R2RException(status_code=400, message="No fields to update")
+
+ update_fields.append("updated_at = NOW()")
+ params.append(collection_id)
+
+ query = f"""
+ UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ SET {", ".join(update_fields)}
+ WHERE id = ${param_index}
+ RETURNING id, name, description, status, created_at, updated_at, collection_id, document_ids
+ """
+
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query, params
+ )
+
+ if not result:
+ raise R2RException(status_code=404, message="Graph not found")
+
+ return GraphResponse(
+ id=result["id"],
+ collection_id=result["collection_id"],
+ name=result["name"],
+ description=result["description"],
+ status=result["status"],
+ created_at=result["created_at"],
+ document_ids=result["document_ids"] or [],
+ updated_at=result["updated_at"],
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while updating the graph: {e}",
+ ) from e
+
+ async def get_entities(
+ self,
+ parent_id: UUID,
+ offset: int,
+ limit: int,
+ entity_ids: Optional[list[UUID]] = None,
+ entity_names: Optional[list[str]] = None,
+ include_embeddings: bool = False,
+ ) -> tuple[list[Entity], int]:
+ """Get entities for a graph.
+
+ Args:
+ offset: Number of records to skip
+ limit: Maximum number of records to return (-1 for no limit)
+ parent_id: UUID of the collection
+ entity_ids: Optional list of entity IDs to filter by
+ entity_names: Optional list of entity names to filter by
+ include_embeddings: Whether to include embeddings in the response
+
+ Returns:
+ Tuple of (list of entities, total count)
+ """
+ conditions = ["parent_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if entity_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(entity_ids)
+ param_index += 1
+
+ if entity_names:
+ conditions.append(f"name = ANY(${param_index})")
+ params.append(entity_names)
+ param_index += 1
+
+ # Count query - uses the same conditions but without offset/limit
+ COUNT_QUERY = f"""
+ SELECT COUNT(*)
+ FROM {self._get_table_name("graphs_entities")}
+ WHERE {" AND ".join(conditions)}
+ """
+ count = (
+ await self.connection_manager.fetch_query(COUNT_QUERY, params)
+ )[0]["count"]
+
+ # Define base columns to select
+ select_fields = """
+ id, name, category, description, parent_id,
+ chunk_ids, metadata
+ """
+ if include_embeddings:
+ select_fields += ", description_embedding"
+
+ # Main query for fetching entities with pagination
+ QUERY = f"""
+ SELECT {select_fields}
+ FROM {self._get_table_name("graphs_entities")}
+ WHERE {" AND ".join(conditions)}
+ ORDER BY created_at
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ QUERY += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ rows = await self.connection_manager.fetch_query(QUERY, params)
+
+ entities = []
+ for row in rows:
+ entity_dict = dict(row)
+ if isinstance(entity_dict["metadata"], str):
+ with contextlib.suppress(json.JSONDecodeError):
+ entity_dict["metadata"] = json.loads(
+ entity_dict["metadata"]
+ )
+
+ entities.append(Entity(**entity_dict))
+
+ return entities, count
+
+ async def get_relationships(
+ self,
+ parent_id: UUID,
+ offset: int,
+ limit: int,
+ relationship_ids: Optional[list[UUID]] = None,
+ relationship_types: Optional[list[str]] = None,
+ include_embeddings: bool = False,
+ ) -> tuple[list[Relationship], int]:
+ """Get relationships for a graph.
+
+ Args:
+ parent_id: UUID of the graph
+ offset: Number of records to skip
+ limit: Maximum number of records to return (-1 for no limit)
+ relationship_ids: Optional list of relationship IDs to filter by
+ relationship_types: Optional list of relationship types to filter by
+            include_embeddings: Whether to include embeddings in the response
+
+ Returns:
+ Tuple of (list of relationships, total count)
+ """
+ conditions = ["parent_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if relationship_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(relationship_ids)
+ param_index += 1
+
+ if relationship_types:
+ conditions.append(f"predicate = ANY(${param_index})")
+ params.append(relationship_types)
+ param_index += 1
+
+ # Count query - uses the same conditions but without offset/limit
+ COUNT_QUERY = f"""
+ SELECT COUNT(*)
+ FROM {self._get_table_name("graphs_relationships")}
+ WHERE {" AND ".join(conditions)}
+ """
+ count = (
+ await self.connection_manager.fetch_query(COUNT_QUERY, params)
+ )[0]["count"]
+
+ # Define base columns to select
+ select_fields = """
+ id, subject, predicate, object, weight, chunk_ids, parent_id, metadata
+ """
+ if include_embeddings:
+ select_fields += ", description_embedding"
+
+ # Main query for fetching relationships with pagination
+ QUERY = f"""
+ SELECT {select_fields}
+ FROM {self._get_table_name("graphs_relationships")}
+ WHERE {" AND ".join(conditions)}
+ ORDER BY created_at
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ QUERY += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ rows = await self.connection_manager.fetch_query(QUERY, params)
+
+ relationships = []
+ for row in rows:
+ relationship_dict = dict(row)
+ if isinstance(relationship_dict["metadata"], str):
+ with contextlib.suppress(json.JSONDecodeError):
+ relationship_dict["metadata"] = json.loads(
+ relationship_dict["metadata"]
+ )
+
+ relationships.append(Relationship(**relationship_dict))
+
+ return relationships, count
+
+ async def add_entities(
+ self,
+ entities: list[Entity],
+ table_name: str,
+ conflict_columns: list[str] | None = None,
+ ) -> asyncpg.Record:
+ """Upsert entities into the entities_raw table. These are raw entities
+ extracted from the document.
+
+ Args:
+ entities: list[Entity]: list of entities to upsert
+ collection_name: str: name of the collection
+
+ Returns:
+ result: asyncpg.Record: result of the upsert operation
+ """
+ if not conflict_columns:
+ conflict_columns = []
+ cleaned_entities = []
+ for entity in entities:
+ entity_dict = entity.to_dict()
+ entity_dict["chunk_ids"] = (
+ entity_dict["chunk_ids"]
+ if entity_dict.get("chunk_ids")
+ else []
+ )
+ entity_dict["description_embedding"] = (
+ str(entity_dict["description_embedding"])
+ if entity_dict.get("description_embedding") # type: ignore
+ else None
+ )
+ cleaned_entities.append(entity_dict)
+
+ return await _add_objects(
+ objects=cleaned_entities,
+ full_table_name=self._get_table_name(table_name),
+ connection_manager=self.connection_manager,
+ conflict_columns=conflict_columns,
+ )
+
+ async def get_all_relationships(
+ self,
+ collection_id: UUID | None,
+ graph_id: UUID | None,
+ document_ids: Optional[list[UUID]] = None,
+ ) -> list[Relationship]:
+ QUERY = f"""
+ SELECT id, subject, predicate, weight, object, parent_id FROM {self._get_table_name("graphs_relationships")} WHERE parent_id = ANY($1)
+ """
+ relationships = await self.connection_manager.fetch_query(
+ QUERY, [collection_id]
+ )
+
+ return [Relationship(**relationship) for relationship in relationships]
+
+ async def has_document(self, graph_id: UUID, document_id: UUID) -> bool:
+ """Check if a document exists in the graph's document_ids array.
+
+ Args:
+ graph_id (UUID): ID of the graph to check
+ document_id (UUID): ID of the document to look for
+
+ Returns:
+ bool: True if document exists in graph, False otherwise
+
+ Raises:
+ R2RException: If graph not found
+ """
+ QUERY = f"""
+ SELECT EXISTS (
+ SELECT 1
+ FROM {self._get_table_name("graphs")}
+ WHERE id = $1
+ AND document_ids IS NOT NULL
+ AND $2 = ANY(document_ids)
+ ) as exists;
+ """
+
+ result = await self.connection_manager.fetchrow_query(
+ QUERY, [graph_id, document_id]
+ )
+
+ if result is None:
+ raise R2RException(f"Graph {graph_id} not found", 404)
+
+ return result["exists"]
+
+ async def get_communities(
+ self,
+ parent_id: UUID,
+ offset: int,
+ limit: int,
+ community_ids: Optional[list[UUID]] = None,
+ include_embeddings: bool = False,
+ ) -> tuple[list[Community], int]:
+ """Get communities for a graph.
+
+ Args:
+            parent_id: UUID of the collection
+ offset: Number of records to skip
+ limit: Maximum number of records to return (-1 for no limit)
+ community_ids: Optional list of community IDs to filter by
+ include_embeddings: Whether to include embeddings in the response
+
+ Returns:
+ Tuple of (list of communities, total count)
+ """
+ conditions = ["collection_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if community_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(community_ids)
+ param_index += 1
+
+ select_fields = """
+ id, collection_id, name, summary, findings, rating, rating_explanation
+ """
+ if include_embeddings:
+ select_fields += ", description_embedding"
+
+ COUNT_QUERY = f"""
+ SELECT COUNT(*)
+ FROM {self._get_table_name("graphs_communities")}
+ WHERE {" AND ".join(conditions)}
+ """
+ count = (
+ await self.connection_manager.fetch_query(COUNT_QUERY, params)
+ )[0]["count"]
+
+ QUERY = f"""
+ SELECT {select_fields}
+ FROM {self._get_table_name("graphs_communities")}
+ WHERE {" AND ".join(conditions)}
+ ORDER BY created_at
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ QUERY += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ rows = await self.connection_manager.fetch_query(QUERY, params)
+
+ communities = []
+ for row in rows:
+ community_dict = dict(row)
+ communities.append(Community(**community_dict))
+
+ return communities, count
+
+ async def add_community(self, community: Community) -> None:
+ # TODO: Fix in the short term.
+ # we need to do this because postgres insert needs to be a string
+ community.description_embedding = str(community.description_embedding) # type: ignore[assignment]
+
+ non_null_attrs = {
+ k: v for k, v in community.__dict__.items() if v is not None
+ }
+ columns = ", ".join(non_null_attrs.keys())
+ placeholders = ", ".join(
+ f"${i + 1}" for i in range(len(non_null_attrs))
+ )
+
+ conflict_columns = ", ".join(
+ [f"{k} = EXCLUDED.{k}" for k in non_null_attrs]
+ )
+
+ QUERY = f"""
+ INSERT INTO {self._get_table_name("graphs_communities")} ({columns})
+ VALUES ({placeholders})
+ ON CONFLICT (community_id, level, collection_id) DO UPDATE SET
+ {conflict_columns}
+ """
+
+ await self.connection_manager.execute_many(
+ QUERY, [tuple(non_null_attrs.values())]
+ )
+
+ async def delete(self, collection_id: UUID) -> None:
+ graphs = await self.get(graph_id=collection_id, offset=0, limit=-1)
+
+ if len(graphs["results"]) == 0:
+ raise R2RException(
+ message=f"Graph not found for collection {collection_id}",
+ status_code=404,
+ )
+ await self.reset(collection_id)
+ # set status to PENDING for this collection.
+ QUERY = f"""
+ UPDATE {self._get_table_name("collections")} SET graph_cluster_status = $1 WHERE id = $2
+ """
+ await self.connection_manager.execute_query(
+ QUERY, [GraphExtractionStatus.PENDING, collection_id]
+ )
+ # Delete the graph
+ QUERY = f"""
+ DELETE FROM {self._get_table_name("graphs")} WHERE collection_id = $1
+ """
+ try:
+ await self.connection_manager.execute_query(QUERY, [collection_id])
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while deleting the graph: {e}",
+ ) from e
+
+ async def perform_graph_clustering(
+ self,
+ collection_id: UUID,
+ leiden_params: dict[str, Any],
+ ) -> Tuple[int, Any]:
+ """Calls the external clustering service to cluster the graph."""
+
+ offset = 0
+ page_size = 1000
+ all_relationships = []
+ while True:
+ relationships, count = await self.relationships.get(
+ parent_id=collection_id,
+ store_type=StoreType.GRAPHS,
+ offset=offset,
+ limit=page_size,
+ )
+
+ if not relationships:
+ break
+
+ all_relationships.extend(relationships)
+ offset += len(relationships)
+
+ if offset >= count:
+ break
+
+ logger.info(
+ f"Clustering over {len(all_relationships)} relationships for {collection_id} with settings: {leiden_params}"
+ )
+ if len(all_relationships) == 0:
+ raise R2RException(
+ message="No relationships found for clustering",
+ status_code=400,
+ )
+
+ return await self._cluster_and_add_community_info(
+ relationships=all_relationships,
+ leiden_params=leiden_params,
+ collection_id=collection_id,
+ )
+
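+    # Illustrative usage sketch: clustering a graph's relationships through the
+    # external service. The leiden_params keys shown are assumptions about what
+    # the clustering service accepts, not values defined in this module.
+    #
+    #   num_communities, assignments = await graphs_handler.perform_graph_clustering(
+    #       collection_id=collection_id,
+    #       leiden_params={"max_cluster_size": 1000, "random_seed": 42},
+    #   )
+    #
+    # Requires CLUSTERING_SERVICE_URL to point at a service exposing POST /cluster.
+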
+ async def _call_clustering_service(
+ self, relationships: list[Relationship], leiden_params: dict[str, Any]
+ ) -> list[dict]:
+ """Calls the external Graspologic clustering service, sending
+ relationships and parameters.
+
+ Expects a response with 'communities' field.
+ """
+ # Convert relationships to a JSON-friendly format
+ rel_data = []
+ for r in relationships:
+ rel_data.append(
+ {
+ "id": str(r.id),
+ "subject": r.subject,
+ "object": r.object,
+ "weight": r.weight if r.weight is not None else 1.0,
+ }
+ )
+
+ endpoint = os.environ.get("CLUSTERING_SERVICE_URL")
+ if not endpoint:
+ raise ValueError("CLUSTERING_SERVICE_URL not set.")
+
+ url = f"{endpoint}/cluster"
+
+ payload = {"relationships": rel_data, "leiden_params": leiden_params}
+
+ async with httpx.AsyncClient() as client:
+ response = await client.post(url, json=payload, timeout=3600)
+ response.raise_for_status()
+
+ data = response.json()
+ return data.get("communities", [])
+
+ async def _create_graph_and_cluster(
+ self,
+ relationships: list[Relationship],
+ leiden_params: dict[str, Any],
+ ) -> Any:
+ """Create a graph and cluster it."""
+
+ return await self._call_clustering_service(
+ relationships, leiden_params
+ )
+
+ async def _cluster_and_add_community_info(
+ self,
+ relationships: list[Relationship],
+ leiden_params: dict[str, Any],
+ collection_id: UUID,
+ ) -> Tuple[int, Any]:
+ logger.info(f"Creating graph and clustering for {collection_id}")
+
+ await asyncio.sleep(0.1)
+ start_time = time.time()
+
+ hierarchical_communities = await self._create_graph_and_cluster(
+ relationships=relationships,
+ leiden_params=leiden_params,
+ )
+
+ logger.info(
+ f"Computing Leiden communities completed, time {time.time() - start_time:.2f} seconds."
+ )
+
+ if not hierarchical_communities:
+ num_communities = 0
+ else:
+ num_communities = (
+ max(item["cluster"] for item in hierarchical_communities) + 1
+ )
+
+ logger.info(
+ f"Generated {num_communities} communities, time {time.time() - start_time:.2f} seconds."
+ )
+
+ return num_communities, hierarchical_communities
+
+ async def get_entity_map(
+ self, offset: int, limit: int, document_id: UUID
+    ) -> dict[str, dict[str, list[Any]]]:
+ QUERY1 = f"""
+ WITH entities_list AS (
+ SELECT DISTINCT name
+ FROM {self._get_table_name("documents_entities")}
+ WHERE parent_id = $1
+ ORDER BY name ASC
+ LIMIT {limit} OFFSET {offset}
+ )
+ SELECT e.name, e.description, e.category,
+ (SELECT array_agg(DISTINCT x) FROM unnest(e.chunk_ids) x) AS chunk_ids,
+ e.parent_id
+ FROM {self._get_table_name("documents_entities")} e
+ JOIN entities_list el ON e.name = el.name
+ GROUP BY e.name, e.description, e.category, e.chunk_ids, e.parent_id
+ ORDER BY e.name;"""
+
+ entities_list = await self.connection_manager.fetch_query(
+ QUERY1, [document_id]
+ )
+ entities_list = [Entity(**entity) for entity in entities_list]
+
+ QUERY2 = f"""
+ WITH entities_list AS (
+ SELECT DISTINCT name
+ FROM {self._get_table_name("documents_entities")}
+ WHERE parent_id = $1
+ ORDER BY name ASC
+ LIMIT {limit} OFFSET {offset}
+ )
+ SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description,
+ (SELECT array_agg(DISTINCT x) FROM unnest(t.chunk_ids) x) AS chunk_ids, t.parent_id
+ FROM {self._get_table_name("documents_relationships")} t
+ JOIN entities_list el ON t.subject = el.name
+ ORDER BY t.subject, t.predicate, t.object;
+ """
+
+ relationships_list = await self.connection_manager.fetch_query(
+ QUERY2, [document_id]
+ )
+ relationships_list = [
+ Relationship(**relationship) for relationship in relationships_list
+ ]
+
+ entity_map: dict[str, dict[str, list[Any]]] = {}
+ for entity in entities_list:
+ if entity.name not in entity_map:
+ entity_map[entity.name] = {"entities": [], "relationships": []}
+ entity_map[entity.name]["entities"].append(entity)
+
+ for relationship in relationships_list:
+ if relationship.subject in entity_map:
+ entity_map[relationship.subject]["relationships"].append(
+ relationship
+ )
+ if relationship.object in entity_map:
+ entity_map[relationship.object]["relationships"].append(
+ relationship
+ )
+
+ return entity_map
+
+ async def graph_search(
+ self, query: str, **kwargs: Any
+ ) -> AsyncGenerator[Any, None]:
+ """Perform semantic search with similarity scores while maintaining
+ exact same structure."""
+
+ query_embedding = kwargs.get("query_embedding", None)
+ if query_embedding is None:
+ raise ValueError(
+ "query_embedding must be provided for semantic search"
+ )
+
+ search_type = kwargs.get(
+ "search_type", "entities"
+ ) # entities | relationships | communities
+ embedding_type = kwargs.get("embedding_type", "description_embedding")
+ property_names = kwargs.get("property_names", ["name", "description"])
+
+ # Add metadata if not present
+ if "metadata" not in property_names:
+ property_names.append("metadata")
+
+ filters = kwargs.get("filters", {})
+ limit = kwargs.get("limit", 10)
+ use_fulltext_search = kwargs.get("use_fulltext_search", True)
+ use_hybrid_search = kwargs.get("use_hybrid_search", True)
+
+ if use_hybrid_search or use_fulltext_search:
+ logger.warning(
+ "Hybrid and fulltext search not supported for graph search, ignoring."
+ )
+
+ table_name = f"graphs_{search_type}"
+ property_names_str = ", ".join(property_names)
+
+ # Build the WHERE clause from filters
+ params: list[str | int | bytes] = [
+ json.dumps(query_embedding),
+ limit,
+ ]
+ conditions_clause = self._build_filters(filters, params, search_type)
+ where_clause = (
+ f"WHERE {conditions_clause}" if conditions_clause else ""
+ )
+
+ # Construct the query
+ # Note: For vector similarity, we use <=> for distance. The smaller the number, the more similar.
+ # We'll convert that to similarity_score by doing (1 - distance).
+ QUERY = f"""
+ SELECT
+ {property_names_str},
+ ({embedding_type} <=> $1) as similarity_score
+ FROM {self._get_table_name(table_name)}
+ {where_clause}
+ ORDER BY {embedding_type} <=> $1
+ LIMIT $2;
+ """
+
+ results = await self.connection_manager.fetch_query(
+ QUERY, tuple(params)
+ )
+
+ for result in results:
+ output = {
+ prop: result[prop] for prop in property_names if prop in result
+ }
+ output["similarity_score"] = (
+ 1 - float(result["similarity_score"])
+ if result.get("similarity_score")
+ else "n/a"
+ )
+ yield output
+
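+    # Usage sketch (illustrative only; `handler` and `query_vector` are assumed
+    # placeholders for an instance of this class and a pre-computed embedding):
+    #
+    #     async for row in handler.graph_search(
+    #         "acme corp",
+    #         query_embedding=query_vector,
+    #         search_type="entities",
+    #         limit=5,
+    #     ):
+    #         print(row["name"], row["similarity_score"])
+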
+ def _build_filters(
+ self, filter_dict: dict, parameters: list[Any], search_type: str
+ ) -> str:
+ """Build a WHERE clause from a nested filter dictionary for the graph
+ search.
+
+ - If search_type == "communities", we normally filter by `collection_id`.
+ - Otherwise (entities/relationships), we normally filter by `parent_id`.
+ - If user provides `"collection_ids": {...}`, we interpret that as wanting
+ to filter by multiple collection IDs (i.e. 'parent_id IN (...)' or
+ 'collection_id IN (...)').
+ """
+
+        # The default "base" id column for this search type
+ base_id_column = (
+ "collection_id" if search_type == "communities" else "parent_id"
+ )
+
+ def parse_condition(key: str, value: Any) -> str:
+ # ----------------------------------------------------------------------
+ # 1) If it's the normal base_id_column (like "parent_id" or "collection_id")
+ # ----------------------------------------------------------------------
+ if key == base_id_column:
+ if isinstance(value, dict):
+ op, clause = next(iter(value.items()))
+ if op == "$eq":
+ # single equality
+ parameters.append(str(clause))
+ return f"{base_id_column} = ${len(parameters)}::uuid"
+ elif op in ("$in", "$overlap"):
+ # treat both $in/$overlap as "IN the set" for a single column
+ array_val = [str(x) for x in clause]
+ parameters.append(array_val)
+ return f"{base_id_column} = ANY(${len(parameters)}::uuid[])"
+ # handle other operators as needed
+ else:
+ # direct equality
+ parameters.append(str(value))
+ return f"{base_id_column} = ${len(parameters)}::uuid"
+
+ # ----------------------------------------------------------------------
+ # 2) SPECIAL: if user specifically sets "collection_ids" in filters
+ # We interpret that to mean "Look for rows whose parent_id (or collection_id)
+ # is in the array of values" – i.e. we do the same logic but we forcibly
+ # direct it to the same column: parent_id or collection_id.
+ # ----------------------------------------------------------------------
+ elif key == "collection_ids":
+ # If we are searching communities, the relevant field is `collection_id`.
+ # If searching entities/relationships, the relevant field is `parent_id`.
+ col_to_use = (
+ "collection_id"
+ if search_type == "communities"
+ else "parent_id"
+ )
+
+ if isinstance(value, dict):
+ op, clause = next(iter(value.items()))
+ if op == "$eq":
+ # single equality => col_to_use = clause
+ parameters.append(str(clause))
+ return f"{col_to_use} = ${len(parameters)}::uuid"
+ elif op in ("$in", "$overlap"):
+ # "col_to_use = ANY($param::uuid[])"
+ array_val = [str(x) for x in clause]
+ parameters.append(array_val)
+ return (
+ f"{col_to_use} = ANY(${len(parameters)}::uuid[])"
+ )
+ # add more if you want, e.g. $ne, $gt, etc.
+ else:
+ # direct equality scenario: "collection_ids": "some-uuid"
+ parameters.append(str(value))
+ return f"{col_to_use} = ${len(parameters)}::uuid"
+
+ # ----------------------------------------------------------------------
+ # 3) If key starts with "metadata.", handle metadata-based filters
+ # ----------------------------------------------------------------------
+ elif key.startswith("metadata."):
+ field = key.split("metadata.")[1]
+ if isinstance(value, dict):
+ op, clause = next(iter(value.items()))
+ if op == "$eq":
+ parameters.append(clause)
+ return f"(metadata->>'{field}') = ${len(parameters)}"
+ elif op == "$ne":
+ parameters.append(clause)
+ return f"(metadata->>'{field}') != ${len(parameters)}"
+ elif op == "$gt":
+ parameters.append(clause)
+ return f"(metadata->>'{field}')::float > ${len(parameters)}::float"
+ # etc...
+ else:
+ parameters.append(value)
+ return f"(metadata->>'{field}') = ${len(parameters)}"
+
+ # ----------------------------------------------------------------------
+ # 4) Not recognized => return empty so we skip it
+ # ----------------------------------------------------------------------
+ return ""
+
+ # --------------------------------------------------------------------------
+ # 5) parse_filter() is the recursive walker that sees $and/$or or normal fields
+ # --------------------------------------------------------------------------
+ def parse_filter(fd: dict) -> str:
+ filter_conditions = []
+ for k, v in fd.items():
+ if k == "$and":
+ and_parts = [parse_filter(sub) for sub in v if sub]
+ and_parts = [x for x in and_parts if x.strip()]
+ if and_parts:
+ filter_conditions.append(
+ f"({' AND '.join(and_parts)})"
+ )
+ elif k == "$or":
+ or_parts = [parse_filter(sub) for sub in v if sub]
+ or_parts = [x for x in or_parts if x.strip()]
+ if or_parts:
+ filter_conditions.append(f"({' OR '.join(or_parts)})")
+ else:
+ c = parse_condition(k, v)
+ if c and c.strip():
+ filter_conditions.append(c)
+
+ if not filter_conditions:
+ return ""
+ if len(filter_conditions) == 1:
+ return filter_conditions[0]
+ return " AND ".join(filter_conditions)
+
+ return parse_filter(filter_dict)
+
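+    # Worked example for _build_filters (a sketch under assumed inputs, not
+    # executed here): with params already holding the embedding and limit
+    # ($1/$2), a filter such as
+    #
+    #     {"$and": [{"parent_id": {"$eq": "9b2f...-uuid"}},
+    #               {"metadata.source": {"$eq": "wiki"}}]}
+    #
+    # passed as _build_filters(filters, params, "entities") yields roughly
+    #
+    #     (parent_id = $3::uuid AND (metadata->>'source') = $4)
+    #
+    # with the UUID string and "wiki" appended to params in that order.
+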
+ async def get_existing_document_entity_chunk_ids(
+ self, document_id: UUID
+ ) -> list[str]:
+ QUERY = f"""
+ SELECT DISTINCT unnest(chunk_ids) AS chunk_id FROM {self._get_table_name("documents_entities")} WHERE parent_id = $1
+ """
+ return [
+ item["chunk_id"]
+ for item in await self.connection_manager.fetch_query(
+ QUERY, [document_id]
+ )
+ ]
+
+ async def get_entity_count(
+ self,
+ collection_id: Optional[UUID] = None,
+ document_id: Optional[UUID] = None,
+ distinct: bool = False,
+ entity_table_name: str = "entity",
+ ) -> int:
+ if collection_id is None and document_id is None:
+ raise ValueError(
+ "Either collection_id or document_id must be provided."
+ )
+
+ conditions = ["parent_id = $1"]
+ params = [str(document_id)]
+
+ count_value = "DISTINCT name" if distinct else "*"
+
+ QUERY = f"""
+ SELECT COUNT({count_value}) FROM {self._get_table_name(entity_table_name)}
+ WHERE {" AND ".join(conditions)}
+ """
+
+ return (await self.connection_manager.fetch_query(QUERY, params))[0][
+ "count"
+ ]
+
+ async def update_entity_descriptions(self, entities: list[Entity]):
+ query = f"""
+ UPDATE {self._get_table_name("graphs_entities")}
+ SET description = $3, description_embedding = $4
+ WHERE name = $1 AND graph_id = $2
+ """
+
+ inputs = [
+ (
+ entity.name,
+ entity.parent_id,
+ entity.description,
+ entity.description_embedding,
+ )
+ for entity in entities
+ ]
+
+ await self.connection_manager.execute_many(query, inputs) # type: ignore
+
+
+def _json_serialize(obj):
+ if isinstance(obj, UUID):
+ return str(obj)
+ elif isinstance(obj, (datetime.datetime, datetime.date)):
+ return obj.isoformat()
+ raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
+
+
+async def _add_objects(
+ objects: list[dict],
+ full_table_name: str,
+ connection_manager: PostgresConnectionManager,
+ conflict_columns: list[str] | None = None,
+ exclude_metadata: list[str] | None = None,
+) -> list[UUID]:
+ """Bulk insert objects into the specified table using
+ jsonb_to_recordset."""
+
+ if conflict_columns is None:
+ conflict_columns = []
+ if exclude_metadata is None:
+ exclude_metadata = []
+
+ # Exclude specified metadata and prepare data
+ cleaned_objects = []
+ for obj in objects:
+ cleaned_obj = {
+ k: v
+ for k, v in obj.items()
+ if k not in exclude_metadata and v is not None
+ }
+ cleaned_objects.append(cleaned_obj)
+
+ # Serialize the list of objects to JSON
+ json_data = json.dumps(cleaned_objects, default=_json_serialize)
+
+ # Prepare the column definitions for jsonb_to_recordset
+
+ columns = cleaned_objects[0].keys()
+ column_defs = []
+ for col in columns:
+ # Map Python types to PostgreSQL types
+ sample_value = cleaned_objects[0][col]
+ if "embedding" in col:
+ pg_type = "vector"
+ elif "chunk_ids" in col or "document_ids" in col or "graph_ids" in col:
+ pg_type = "uuid[]"
+ elif col == "id" or "_id" in col:
+ pg_type = "uuid"
+ elif isinstance(sample_value, str):
+ pg_type = "text"
+ elif isinstance(sample_value, UUID):
+ pg_type = "uuid"
+ elif isinstance(sample_value, (int, float)):
+ pg_type = "numeric"
+ elif isinstance(sample_value, list) and all(
+ isinstance(x, UUID) for x in sample_value
+ ):
+ pg_type = "uuid[]"
+ elif isinstance(sample_value, list):
+ pg_type = "jsonb"
+ elif isinstance(sample_value, dict):
+ pg_type = "jsonb"
+ elif isinstance(sample_value, bool):
+ pg_type = "boolean"
+ elif isinstance(sample_value, (datetime.datetime, datetime.date)):
+ pg_type = "timestamp"
+ else:
+ raise TypeError(
+ f"Unsupported data type for column '{col}': {type(sample_value)}"
+ )
+
+ column_defs.append(f"{col} {pg_type}")
+
+ columns_str = ", ".join(columns)
+ column_defs_str = ", ".join(column_defs)
+
+ if conflict_columns:
+ conflict_columns_str = ", ".join(conflict_columns)
+ update_columns_str = ", ".join(
+ f"{col}=EXCLUDED.{col}"
+ for col in columns
+ if col not in conflict_columns
+ )
+ on_conflict_clause = f"ON CONFLICT ({conflict_columns_str}) DO UPDATE SET {update_columns_str}"
+ else:
+ on_conflict_clause = ""
+
+ QUERY = f"""
+ INSERT INTO {full_table_name} ({columns_str})
+ SELECT {columns_str}
+ FROM jsonb_to_recordset($1::jsonb)
+ AS x({column_defs_str})
+ {on_conflict_clause}
+ RETURNING id;
+ """
+
+ # Execute the query
+ result = await connection_manager.fetch_query(QUERY, [json_data])
+
+ # Extract and return the IDs
+ return [record["id"] for record in result]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/limits.py b/.venv/lib/python3.12/site-packages/core/providers/database/limits.py
new file mode 100644
index 00000000..1029ec50
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/limits.py
@@ -0,0 +1,434 @@
+import logging
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+from uuid import UUID
+
+from core.base import Handler
+from shared.abstractions import User
+
+from ...base.providers.database import DatabaseConfig, LimitSettings
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger(__name__)
+
+
+class PostgresLimitsHandler(Handler):
+ TABLE_NAME = "request_log"
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ config: DatabaseConfig,
+ ):
+ """
+ :param config: The global DatabaseConfig with default rate limits.
+ """
+ super().__init__(project_name, connection_manager)
+ self.config = config
+
+ logger.debug(
+ f"Initialized PostgresLimitsHandler with project: {project_name}"
+ )
+
+ async def create_tables(self):
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
+ time TIMESTAMPTZ NOT NULL,
+ user_id UUID NOT NULL,
+ route TEXT NOT NULL
+ );
+ """
+ logger.debug("Creating request_log table if not exists")
+ await self.connection_manager.execute_query(query)
+
+ async def _count_requests(
+ self,
+ user_id: UUID,
+ route: Optional[str],
+ since: datetime,
+ ) -> int:
+ """Count how many requests a user (optionally for a specific route) has
+ made since the given datetime."""
+ if route:
+ query = f"""
+ SELECT COUNT(*)::int
+ FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+ WHERE user_id = $1
+ AND route = $2
+ AND time >= $3
+ """
+ params = [user_id, route, since]
+ logger.debug(
+ f"Counting requests for user={user_id}, route={route}"
+ )
+ else:
+ query = f"""
+ SELECT COUNT(*)::int
+ FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+ WHERE user_id = $1
+ AND time >= $2
+ """
+ params = [user_id, since]
+ logger.debug(f"Counting all requests for user={user_id}")
+
+ result = await self.connection_manager.fetchrow_query(query, params)
+ return result["count"] if result else 0
+
+ async def _count_monthly_requests(
+ self,
+ user_id: UUID,
+        route: Optional[str] = None,
+ ) -> int:
+ """Count the number of requests so far this month for a given user.
+
+ If route is provided, count only for that route. Otherwise, count
+ globally.
+ """
+ now = datetime.now(timezone.utc)
+ start_of_month = now.replace(
+ day=1, hour=0, minute=0, second=0, microsecond=0
+ )
+ return await self._count_requests(
+ user_id, route=route, since=start_of_month
+ )
+
+ def determine_effective_limits(
+ self, user: User, route: str
+ ) -> LimitSettings:
+ """
+ Determine the final effective limits for a user+route combination,
+ respecting:
+ 1) Global defaults
+ 2) Route-specific overrides
+ 3) User-level overrides
+ """
+ # ------------------------
+ # 1) Start with global/base
+ # ------------------------
+ base_limits = self.config.limits
+
+ # We’ll make a copy so we don’t mutate self.config.limits directly
+ effective = LimitSettings(
+ global_per_min=base_limits.global_per_min,
+ route_per_min=base_limits.route_per_min,
+ monthly_limit=base_limits.monthly_limit,
+ )
+
+ # ------------------------
+ # 2) Route-level overrides
+ # ------------------------
+ route_config = self.config.route_limits.get(route)
+ if route_config:
+ if route_config.global_per_min is not None:
+ effective.global_per_min = route_config.global_per_min
+ if route_config.route_per_min is not None:
+ effective.route_per_min = route_config.route_per_min
+ if route_config.monthly_limit is not None:
+ effective.monthly_limit = route_config.monthly_limit
+
+ # ------------------------
+ # 3) User-level overrides
+ # ------------------------
+ # The user object might have a dictionary of overrides
+ # which can include route_overrides, global_per_min, monthly_limit, etc.
+ user_overrides = user.limits_overrides or {}
+
+ # (a) "global" user overrides
+ if user_overrides.get("global_per_min") is not None:
+ effective.global_per_min = user_overrides["global_per_min"]
+ if user_overrides.get("monthly_limit") is not None:
+ effective.monthly_limit = user_overrides["monthly_limit"]
+
+ # (b) route-level user overrides
+ route_overrides = user_overrides.get("route_overrides", {})
+ specific_config = route_overrides.get(route, {})
+ if specific_config.get("global_per_min") is not None:
+ effective.global_per_min = specific_config["global_per_min"]
+ if specific_config.get("route_per_min") is not None:
+ effective.route_per_min = specific_config["route_per_min"]
+ if specific_config.get("monthly_limit") is not None:
+ effective.monthly_limit = specific_config["monthly_limit"]
+
+ return effective
+
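+    # Illustrative precedence sketch (values assumed for the example, not
+    # defaults shipped with this package):
+    #
+    #     base config            -> global_per_min=60, route_per_min=20, monthly_limit=None
+    #     route_limits["/v3/x"]  -> route_per_min=5
+    #     user.limits_overrides  -> {"monthly_limit": 1000,
+    #                                "route_overrides": {"/v3/x": {"route_per_min": 2}}}
+    #
+    # determine_effective_limits(user, "/v3/x") then yields global_per_min=60,
+    # route_per_min=2, monthly_limit=1000, since later layers override earlier
+    # ones field by field.
+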
+ async def check_limits(self, user: User, route: str):
+ """Perform rate limit checks for a user on a specific route.
+
+ :param user: The fully-fetched User object with .limits_overrides, etc.
+ :param route: The route/path being accessed.
+ :raises ValueError: if any limit is exceeded.
+ """
+ user_id = user.id
+ now = datetime.now(timezone.utc)
+ one_min_ago = now - timedelta(minutes=1)
+
+ # 1) Compute the final (effective) limits for this user & route
+ limits = self.determine_effective_limits(user, route)
+
+ # 2) Check each of them in turn, if they exist
+ # ------------------------------------------------------------
+ # Global per-minute limit
+ # ------------------------------------------------------------
+ if limits.global_per_min is not None:
+ user_req_count = await self._count_requests(
+ user_id, None, one_min_ago
+ )
+ if user_req_count > limits.global_per_min:
+ logger.warning(
+ f"Global per-minute limit exceeded for "
+ f"user_id={user_id}, route={route}"
+ )
+ raise ValueError("Global per-minute rate limit exceeded")
+
+ # ------------------------------------------------------------
+ # Route-specific per-minute limit
+ # ------------------------------------------------------------
+ if limits.route_per_min is not None:
+ route_req_count = await self._count_requests(
+ user_id, route, one_min_ago
+ )
+ if route_req_count > limits.route_per_min:
+ logger.warning(
+ f"Per-route per-minute limit exceeded for "
+ f"user_id={user_id}, route={route}"
+ )
+ raise ValueError("Per-route per-minute rate limit exceeded")
+
+ # ------------------------------------------------------------
+ # Monthly limit
+ # ------------------------------------------------------------
+ if limits.monthly_limit is not None:
+            # Passing `route` scopes the monthly limit to this route; passing
+            # None would apply it globally across all routes.
+ monthly_count = await self._count_monthly_requests(user_id, route)
+ if monthly_count > limits.monthly_limit:
+ logger.warning(
+ f"Monthly limit exceeded for user_id={user_id}, "
+ f"route={route}"
+ )
+ raise ValueError("Monthly rate limit exceeded")
+
+ async def log_request(self, user_id: UUID, route: str):
+ """Log a successful request to the request_log table."""
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+ (time, user_id, route)
+ VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
+ """
+ await self.connection_manager.execute_query(query, [user_id, route])
+
+
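+# Usage sketch (illustrative middleware-style flow; `handler`, `user`, and
+# `route` are assumed placeholders supplied by the calling application):
+#
+#     try:
+#         await handler.check_limits(user, route)
+#     except ValueError:
+#         ...  # surface as an HTTP 429 in the calling layer
+#     else:
+#         await handler.log_request(user.id, route)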
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py b/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py
new file mode 100644
index 00000000..acccc9c0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py
@@ -0,0 +1,286 @@
+# TODO: Clean this up and make it more congruent across the vector database and the relational database.
+import logging
+import os
+from typing import TYPE_CHECKING, Any, Optional
+
+from ...base.abstractions import VectorQuantizationType
+from ...base.providers import (
+ DatabaseConfig,
+ DatabaseProvider,
+ PostgresConfigurationSettings,
+)
+from .base import PostgresConnectionManager, SemaphoreConnectionPool
+from .chunks import PostgresChunksHandler
+from .collections import PostgresCollectionsHandler
+from .conversations import PostgresConversationsHandler
+from .documents import PostgresDocumentsHandler
+from .files import PostgresFilesHandler
+from .graphs import (
+ PostgresCommunitiesHandler,
+ PostgresEntitiesHandler,
+ PostgresGraphsHandler,
+ PostgresRelationshipsHandler,
+)
+from .limits import PostgresLimitsHandler
+from .prompts_handler import PostgresPromptsHandler
+from .tokens import PostgresTokensHandler
+from .users import PostgresUserHandler
+
+if TYPE_CHECKING:
+ from ..crypto import BCryptCryptoProvider, NaClCryptoProvider
+
+ CryptoProviderType = BCryptCryptoProvider | NaClCryptoProvider
+
+logger = logging.getLogger()
+
+
+class PostgresDatabaseProvider(DatabaseProvider):
+ # R2R configuration settings
+ config: DatabaseConfig
+ project_name: str
+
+ # Postgres connection settings
+ user: str
+ password: str
+ host: str
+ port: int
+ db_name: str
+ connection_string: str
+ dimension: int | float
+ conn: Optional[Any]
+
+ crypto_provider: "CryptoProviderType"
+ postgres_configuration_settings: PostgresConfigurationSettings
+ default_collection_name: str
+ default_collection_description: str
+
+ connection_manager: PostgresConnectionManager
+ documents_handler: PostgresDocumentsHandler
+ collections_handler: PostgresCollectionsHandler
+ token_handler: PostgresTokensHandler
+ users_handler: PostgresUserHandler
+ chunks_handler: PostgresChunksHandler
+ entities_handler: PostgresEntitiesHandler
+ communities_handler: PostgresCommunitiesHandler
+ relationships_handler: PostgresRelationshipsHandler
+ graphs_handler: PostgresGraphsHandler
+ prompts_handler: PostgresPromptsHandler
+ files_handler: PostgresFilesHandler
+ conversations_handler: PostgresConversationsHandler
+ limits_handler: PostgresLimitsHandler
+
+ def __init__(
+ self,
+ config: DatabaseConfig,
+ dimension: int | float,
+ crypto_provider: "BCryptCryptoProvider | NaClCryptoProvider",
+ quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(config)
+
+ env_vars = [
+ ("user", "R2R_POSTGRES_USER"),
+ ("password", "R2R_POSTGRES_PASSWORD"),
+ ("host", "R2R_POSTGRES_HOST"),
+ ("port", "R2R_POSTGRES_PORT"),
+ ("db_name", "R2R_POSTGRES_DBNAME"),
+ ]
+
+ for attr, env_var in env_vars:
+ if value := (getattr(config, attr) or os.getenv(env_var)):
+ setattr(self, attr, value)
+ else:
+ raise ValueError(
+ f"Error, please set a valid {env_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`."
+ )
+
+ self.port = int(self.port)
+
+ self.project_name = (
+ config.app.project_name
+ or os.getenv("R2R_PROJECT_NAME")
+ or "r2r_default"
+ )
+
+ if not self.project_name:
+ raise ValueError(
+ "Error, please set a valid R2R_PROJECT_NAME environment variable or set a 'project_name' in the 'database' settings of your `r2r.toml`."
+ )
+
+ # Check if it's a Unix socket connection
+ if self.host.startswith("/") and not self.port:
+ self.connection_string = f"postgresql://{self.user}:{self.password}@/{self.db_name}?host={self.host}"
+ logger.info("Connecting to Postgres via Unix socket")
+ else:
+ self.connection_string = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}"
+ logger.info("Connecting to Postgres via TCP/IP")
+
+ self.dimension = dimension
+ self.quantization_type = quantization_type
+ self.conn = None
+ self.config: DatabaseConfig = config
+ self.crypto_provider = crypto_provider
+ self.postgres_configuration_settings: PostgresConfigurationSettings = (
+ self._get_postgres_configuration_settings(config)
+ )
+ self.default_collection_name = config.default_collection_name
+ self.default_collection_description = (
+ config.default_collection_description
+ )
+
+ self.connection_manager: PostgresConnectionManager = (
+ PostgresConnectionManager()
+ )
+ self.documents_handler = PostgresDocumentsHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ dimension=self.dimension,
+ )
+ self.token_handler = PostgresTokensHandler(
+ self.project_name, self.connection_manager
+ )
+ self.collections_handler = PostgresCollectionsHandler(
+ self.project_name, self.connection_manager, self.config
+ )
+ self.users_handler = PostgresUserHandler(
+ self.project_name, self.connection_manager, self.crypto_provider
+ )
+ self.chunks_handler = PostgresChunksHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ dimension=self.dimension,
+            quantization_type=self.quantization_type,
+ )
+ self.conversations_handler = PostgresConversationsHandler(
+ self.project_name, self.connection_manager
+ )
+ self.entities_handler = PostgresEntitiesHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ collections_handler=self.collections_handler,
+ dimension=self.dimension,
+ quantization_type=self.quantization_type,
+ )
+ self.relationships_handler = PostgresRelationshipsHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ collections_handler=self.collections_handler,
+ dimension=self.dimension,
+ quantization_type=self.quantization_type,
+ )
+ self.communities_handler = PostgresCommunitiesHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ collections_handler=self.collections_handler,
+ dimension=self.dimension,
+ quantization_type=self.quantization_type,
+ )
+ self.graphs_handler = PostgresGraphsHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ collections_handler=self.collections_handler,
+ dimension=self.dimension,
+ quantization_type=self.quantization_type,
+ )
+ self.prompts_handler = PostgresPromptsHandler(
+ self.project_name, self.connection_manager
+ )
+ self.files_handler = PostgresFilesHandler(
+ self.project_name, self.connection_manager
+ )
+
+ self.limits_handler = PostgresLimitsHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ config=self.config,
+ )
+
+ async def initialize(self):
+ logger.info("Initializing `PostgresDatabaseProvider`.")
+ self.pool = SemaphoreConnectionPool(
+ self.connection_string, self.postgres_configuration_settings
+ )
+ await self.pool.initialize()
+ await self.connection_manager.initialize(self.pool)
+
+ async with self.pool.get_connection() as conn:
+ await conn.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
+ await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
+ await conn.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
+ await conn.execute("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;")
+
+ # Create schema if it doesn't exist
+ await conn.execute(
+ f'CREATE SCHEMA IF NOT EXISTS "{self.project_name}";'
+ )
+
+ await self.documents_handler.create_tables()
+ await self.collections_handler.create_tables()
+ await self.token_handler.create_tables()
+ await self.users_handler.create_tables()
+ await self.chunks_handler.create_tables()
+ await self.prompts_handler.create_tables()
+ await self.files_handler.create_tables()
+ await self.graphs_handler.create_tables()
+ await self.communities_handler.create_tables()
+ await self.entities_handler.create_tables()
+ await self.relationships_handler.create_tables()
+ await self.conversations_handler.create_tables()
+ await self.limits_handler.create_tables()
+
+ def _get_postgres_configuration_settings(
+ self, config: DatabaseConfig
+ ) -> PostgresConfigurationSettings:
+ settings = PostgresConfigurationSettings()
+
+ env_mapping = {
+ "checkpoint_completion_target": "R2R_POSTGRES_CHECKPOINT_COMPLETION_TARGET",
+ "default_statistics_target": "R2R_POSTGRES_DEFAULT_STATISTICS_TARGET",
+ "effective_cache_size": "R2R_POSTGRES_EFFECTIVE_CACHE_SIZE",
+ "effective_io_concurrency": "R2R_POSTGRES_EFFECTIVE_IO_CONCURRENCY",
+ "huge_pages": "R2R_POSTGRES_HUGE_PAGES",
+ "maintenance_work_mem": "R2R_POSTGRES_MAINTENANCE_WORK_MEM",
+ "min_wal_size": "R2R_POSTGRES_MIN_WAL_SIZE",
+ "max_connections": "R2R_POSTGRES_MAX_CONNECTIONS",
+ "max_parallel_workers_per_gather": "R2R_POSTGRES_MAX_PARALLEL_WORKERS_PER_GATHER",
+ "max_parallel_workers": "R2R_POSTGRES_MAX_PARALLEL_WORKERS",
+ "max_parallel_maintenance_workers": "R2R_POSTGRES_MAX_PARALLEL_MAINTENANCE_WORKERS",
+ "max_wal_size": "R2R_POSTGRES_MAX_WAL_SIZE",
+ "max_worker_processes": "R2R_POSTGRES_MAX_WORKER_PROCESSES",
+ "random_page_cost": "R2R_POSTGRES_RANDOM_PAGE_COST",
+ "statement_cache_size": "R2R_POSTGRES_STATEMENT_CACHE_SIZE",
+ "shared_buffers": "R2R_POSTGRES_SHARED_BUFFERS",
+ "wal_buffers": "R2R_POSTGRES_WAL_BUFFERS",
+ "work_mem": "R2R_POSTGRES_WORK_MEM",
+ }
+
+ for setting, env_var in env_mapping.items():
+ value = getattr(
+ config.postgres_configuration_settings, setting, None
+ )
+ if value is None:
+ value = os.getenv(env_var)
+
+ if value is not None:
+ field_type = settings.__annotations__[setting]
+ if field_type == Optional[int]:
+ value = int(value)
+ elif field_type == Optional[float]:
+ value = float(value)
+
+ setattr(settings, setting, value)
+
+ return settings
+
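+    # Illustrative override sketch (assumed values): if the TOML config leaves
+    # max_connections unset and R2R_POSTGRES_MAX_CONNECTIONS=256 is exported,
+    # the returned settings carry max_connections=256, cast to a numeric type
+    # by the annotation check above.
+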
+ async def close(self):
+ if self.pool:
+ await self.pool.close()
+
+ async def __aenter__(self):
+ await self.initialize()
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.close()
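+
+
+# Usage sketch (illustrative; `db_config`, `crypto`, `user`, and the dimension
+# below are assumed placeholders rather than values defined in this module):
+#
+#     provider = PostgresDatabaseProvider(
+#         config=db_config, dimension=768, crypto_provider=crypto
+#     )
+#     async with provider:  # runs initialize()/close() and creates all tables
+#         await provider.limits_handler.check_limits(user, "/v3/retrieval/search")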
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/chunk_enrichment.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/chunk_enrichment.yaml
new file mode 100644
index 00000000..7e4a2615
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/chunk_enrichment.yaml
@@ -0,0 +1,56 @@
+chunk_enrichment:
+ template: >
+ ## Task:
+
+ Enrich and refine the given chunk of text while maintaining its independence and precision.
+
+ ## Context:
+ Document Summary: {document_summary}
+ Preceding Chunks: {preceding_chunks}
+ Succeeding Chunks: {succeeding_chunks}
+
+ ## Input Chunk:
+ {chunk}
+
+ ## Semantic Organization Guidelines:
+ 1. Group related information:
+ - Combine logically connected data points
+ - Maintain context within each grouping
+ - Preserve relationships between entities
+
+ 2. Structure hierarchy:
+ - Organize from general to specific
+ - Use clear categorical divisions
+ - Maintain parent-child relationships
+
+ 3. Information density:
+ - Balance completeness with clarity
+ - Ensure each chunk can stand alone
+ - Preserve essential context
+
+ 4. Pattern recognition:
+ - Standardize similar information
+ - Use consistent formatting for similar data types
+       - It is appropriate to restructure tables or lists in ways that are more advantageous for semantic matching
+ - Maintain searchable patterns
+
+ ## Output Requirements:
+ 1. Each chunk should be independently meaningful
+ 2. Related information should stay together
+ 3. Format should support efficient matching
+ 4. Original data relationships must be preserved
+ 5. Context should be clear without external references
+
+ Maximum length: {chunk_size} characters
+
+ Output the restructured chunk only.
+
+ ## Restructured Chunk:
+
+ input_types:
+ document_summary: str
+ chunk: str
+ preceding_chunks: str
+ succeeding_chunks: str
+ chunk_size: int
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/collection_summary.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/collection_summary.yaml
new file mode 100644
index 00000000..b9475453
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/collection_summary.yaml
@@ -0,0 +1,41 @@
+collection_summary:
+ template: >
+ ## Task:
+
+ Generate a comprehensive collection-level summary that describes the overall content, themes, and relationships across multiple documents. The summary should provide a high-level understanding of what the collection contains and represents.
+
+ ### Input Documents:
+
+ Document Summaries:
+ {document_summaries}
+
+ ### Requirements:
+
+ 1. SCOPE
+ - Synthesize key themes and patterns across all documents
+ - Identify common topics, entities, and relationships
+ - Capture the collection's overall purpose or domain
+
+ 2. STRUCTURE
+ - Target length: Approximately 3-4 concise sentences
+ - Focus on collective insights rather than individual document details
+
+ 3. CONTENT GUIDELINES
+ - Emphasize shared concepts and recurring elements
+ - Highlight any temporal or thematic progression
+ - Identify key stakeholders or entities that appear across documents
+ - Note any significant relationships between documents
+
+ 4. INTEGRATION PRINCIPLES
+ - Connect related concepts across different documents
+ - Identify overarching narratives or frameworks
+ - Preserve important context from individual documents
+ - Balance breadth of coverage with depth of insight
+
+ ### Query:
+
+ Generate a collection-level summary following the above requirements. Focus on synthesizing the key themes and relationships across all documents while maintaining clarity and concision.
+
+ ## Response:
+ input_types:
+ document_summaries: str
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent.yaml
new file mode 100644
index 00000000..5b264530
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent.yaml
@@ -0,0 +1,28 @@
+dynamic_rag_agent:
+ template: >
+    ### You are a helpful agent that can search for information; the date is {date}.
+
+
+    The response should contain line-item attributions to relevant search results and be as informative as possible. Note that you will only be able to load {max_tool_context_length} tokens of context at a time; if the context surpasses this, it will be truncated. If possible, set filters that reduce the returned context to only what is specific, by means of '$eq' or '$overlap' filters.
+
+
+    Search rarely exceeds the context window, while fetching raw context can, depending on the user data shown below. IF YOU CAN FETCH THE RAW CONTEXT, THEN DO SO.
+
+
+ The available user documents and collections are shown below:
+
+ <= Documents =>
+ {document_context}
+
+
+ If no relevant results are found, then state that no results were found. If no obvious question is present given the available tools and context, then do not carry out a search, and instead ask for clarification.
+
+
+    REMINDER - Use line-item references like [c910e2e], [b12cd2f] to refer to the specific search result IDs returned in the provided context.
+
+ input_types:
+ date: str
+ document_context: str
+ max_tool_context_length: str
+
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent_xml_tooling.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent_xml_tooling.yaml
new file mode 100644
index 00000000..ce5784a3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent_xml_tooling.yaml
@@ -0,0 +1,99 @@
+dynamic_rag_agent_xml_tooling:
+ template: |
+ You are an AI research assistant with access to document retrieval tools. You should use both your internal knowledge store and web search tools to answer the user questions. Today is {date}.
+
+ <AvailableTools>
+
+ <ToolDefinition>
+ <Name>web_search</Name>
+ <Description>External web search. Parameters must be a valid JSON object.</Description>
+ <Parameters>
+ <Parameter type="string" required="true">
+ <Name>query</Name>
+ <Example>{{"query": "recent AI developments 2024"}}</Example>
+ </Parameter>
+ </Parameters>
+ </ToolDefinition>
+
+ </AvailableTools>
+
+ ### Documents
+ {document_context}
+
+ 2. DECIDE response strategy:
+ - If specific document IDs are relevant: Use `content` with $eq filters
+ - For broad concepts: Use `search_file_knowledge` with keyword queries
+ - Use `web_search` to gather live information
+
+ 3. FORMAT response STRICTLY as:
+ <Action>
+ <ToolCalls>
+ <ToolCall>
+ <Name>search_file_knowledge</Name>
+ <!-- Parameters MUST be a single valid JSON object -->
+ <Parameters>{{"query": "example search"}}</Parameters>
+ </ToolCall>
+ <!-- Multiple tool call example -->
+ <ToolCall>
+ <Name>content</Name>
+ <!-- Example with nested filters -->
+          <Parameters>{{"filters": {{"$and": [{{"document_id": {{"$eq": "abc123"}}}}, {{"collection_ids": {{"$overlap": ["id1"]}}}}]}}}}</Parameters>
+ </ToolCall>
+ </ToolCalls>
+ </Action>
+
+ ### Constraints
+ - MAX_CONTEXT: {max_tool_context_length} tokens
+ - REQUIRED: Line-item references like [abc1234][def5678] when using content
+ - REQUIRED: All Parameters must be valid JSON objects
+ - PROHIBITED: Assuming document contents without retrieval
+ - PROHIBITED: Using XML format for Parameters values
+
+ ### Examples
+    1. Good initial search operation:
+ <Action>
+ <ToolCalls>
+ <ToolCall>
+ <Name>web_search</Name>
+ <Parameters>{{"query": "recent advances in machine learning"}}</Parameters>
+ </ToolCall>
+ <ToolCall>
+ <Name>search_file_knowledge</Name>
+ <Parameters>{{"query": "machine learning applications"}}</Parameters>
+ </ToolCall>
+ <ToolCall>
+ <Name>search_file_knowledge</Name>
+ <Parameters>{{"query": "recent advances in machine learning"}}</Parameters>
+ </ToolCall>
+ </ToolCalls>
+ </Action>
+
+
+ 2. Good content call with complex filters:
+ <Action>
+ <ToolCalls>
+ <ToolCall>
+ <Name>web_search</Name>
+ <Parameters>{{"query": "recent advances in machine learning"}}</Parameters>
+ </ToolCall>
+ <ToolCall>
+ <Name>content</Name>
+          <Parameters>{{"filters": {{"$or": [{{"document_id": {{"$eq": "a5b880db-..."}}}}, {{"document_id": {{"$overlap": ["54b523f6-...","26fc0bf5-..."]}}}}]}}}}</Parameters>
+ </ToolCall>
+ </ToolCalls>
+ </Action>
+
+ ### Important!
+ Continue to take actions until you have sufficient relevant context, then return your answer with the result tool.
+ You have a maximum of 100_000 context tokens or 10 iterations to find the information required.
+
+ RETURN A COMPLETE AND COMPREHENSIVE ANSWER WHEN POSSIBLE.
+
+    REMINDER - Use line-item references like `[c910e2e], [b12cd2f]` with THIS EXACT FORMAT to refer to the specific search result IDs returned in the provided context.
+
+ input_types:
+ date: str
+ document_context: str
+ max_tool_context_length: str
+
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_communities.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_communities.yaml
new file mode 100644
index 00000000..50e71544
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_communities.yaml
@@ -0,0 +1,74 @@
+graph_communities:
+ template: |
+ You are an AI assistant that helps a human analyst perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network.
+
+ Context Overview:
+ {collection_description}
+
+ Your Task:
+ Write a comprehensive report of a community as a single XML document. The report must follow this exact structure:
+
+ <community>
+ <name>A specific, concise community name representing its key entities</name>
+ <summary>An executive summary that contextualizes the community</summary>
+ <rating>A float score (0-10) representing impact severity</rating>
+ <rating_explanation>A single sentence explaining the rating</rating_explanation>
+ <findings>
+ <finding>First key insight about the community</finding>
+ <finding>Second key insight about the community</finding>
+ <!-- Include 5-10 findings total -->
+ </findings>
+ </community>
+
+ Data Reference Format:
+ Include data references in findings like this:
+ "Example sentence [Data: <dataset name> (record ids); <dataset name> (record ids)]"
+ Use no more than 5 record IDs per reference. Add "+more" to indicate additional records.
+
+ Example Input:
+ -----------
+ Text:
+
+ Entity: OpenAI
+ descriptions:
+ 101,OpenAI is an AI research and deployment company.
+ relationships:
+ 201,OpenAI,Stripe,OpenAI partnered with Stripe to integrate payment solutions.
+ 203,Airbnb,OpenAI,Airbnb utilizes OpenAI's AI tools for customer service.
+ 204,Stripe,OpenAI,Stripe invested in OpenAI's latest funding round.
+ Entity: Stripe
+ descriptions:
+ 102,Stripe is a technology company that builds economic infrastructure for the internet.
+ relationships:
+ 201,OpenAI,Stripe,OpenAI partnered with Stripe to integrate payment solutions.
+ 202,Stripe,Airbnb,Stripe provides payment processing services to Airbnb.
+ 204,Stripe,OpenAI,Stripe invested in OpenAI's latest funding round.
+ 205,Airbnb,Stripe,Airbnb and Stripe collaborate on expanding global payment options.
+ Entity: Airbnb
+ descriptions:
+ 103,Airbnb is an online marketplace for lodging and tourism experiences.
+ relationships:
+ 203,Airbnb,OpenAI,Airbnb utilizes OpenAI's AI tools for customer service.
+ 205,Airbnb,Stripe,Airbnb and Stripe collaborate on expanding global payment options.
+
+ Example Output:
+ <community>
+ <name>OpenAI-Stripe-Airbnb Community</name>
+ <summary>The OpenAI-Stripe-Airbnb Community is a network of companies that collaborate on AI research, payment solutions, and customer service.</summary>
+ <rating>8.5</rating>
+ <rating_explanation>The OpenAI-Stripe-Airbnb Community has a high impact on the collection due to its significant contributions to AI research, payment solutions, and customer service.</rating_explanation>
+ <findings>
+ <finding>OpenAI and Stripe have a partnership to integrate payment solutions [Data: Relationships (201)].</finding>
+ <finding>OpenAI and Airbnb collaborate on AI tools for customer service [Data: Relationships (203)].</finding>
+ <finding>Stripe provides payment processing services to Airbnb [Data: Relationships (202)].</finding>
+ <finding>Stripe invested in OpenAI's latest funding round [Data: Relationships (204)].</finding>
+ <finding>Airbnb and Stripe collaborate on expanding global payment options [Data: Relationships (205)].</finding>
+ </findings>
+ </community>
+
+ Entity Data:
+ {input_text}
+
+ input_types:
+ collection_description: str
+ input_text: str
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_entity_description.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_entity_description.yaml
new file mode 100644
index 00000000..b46185fb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_entity_description.yaml
@@ -0,0 +1,40 @@
+graph_entity_description:
+ template: |
+ Given the following information about an entity:
+
+ Document Summary:
+ {document_summary}
+
+ Entity Information:
+ {entity_info}
+
+ Relationship Data:
+ {relationships_txt}
+
+ Generate a comprehensive entity description that:
+
+ 1. Opens with a clear definition statement identifying the entity's primary classification and core function
+ 2. Incorporates key data points from both the document summary and relationship information
+ 3. Emphasizes the entity's role within its broader context or system
+ 4. Highlights critical relationships, particularly those that:
+ - Demonstrate hierarchical connections
+ - Show functional dependencies
+ - Indicate primary use cases or applications
+
+ Format Requirements:
+ - Length: 2-3 sentences
+ - Style: Technical and precise
+ - Structure: Definition + Context + Key Relationships
+ - Tone: Objective and authoritative
+
+ Integration Guidelines:
+ - Prioritize information that appears in multiple sources
+ - Resolve any conflicting information by favoring the most specific source
+ - Include temporal context if relevant to the entity's current state or evolution
+
+ Output should reflect the entity's complete nature while maintaining concision and clarity.
+ input_types:
+ document_summary: str
+ entity_info: str
+ relationships_txt: str
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_extraction.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_extraction.yaml
new file mode 100644
index 00000000..9850878a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_extraction.yaml
@@ -0,0 +1,100 @@
+graph_extraction:
+ template: >
+ # Context
+ {document_summary}
+
+ # Goal
+ Given both a document summary and full text, identify all entities and their entity types, along with all relationships among the identified entities.
+
+ # Steps
+ 1. Identify all entities given the full text, grounding and contextualizing them based on the summary. For each identified entity, extract:
+    - entity: Name of the entity, capitalized
+ - entity_type: Type of the entity (constrained to {entity_types} if provided, otherwise all types)
+ - entity_description: Comprehensive description incorporating context from both summary and full text
+
+ Format each Entity in XML tags as follows: <entity name="entity"><type>entity_type</type><description>entity_description</description></entity>
+
+ Note: Generate additional entities from descriptions if they contain named entities for relationship mapping.
+
+ 2. From the identified entities, identify all related entity pairs, using both summary and full text context:
+ - source_entity: name of the source entity
+ - target_entity: name of the target entity
+ - relation: relationship type (constrained to {relation_types} if provided)
+ - relationship_description: justification based on both summary and full text context
+ - relationship_weight: strength score 0-10
+
+ Format each relationship in XML tags as follows: <relationship><source>source_entity</source><target>target_entity</target><type>relation</type><description>relationship_description</description><weight>relationship_weight</weight></relationship>
+
+ 3. Coverage Requirements:
+ - Each entity must have at least one relationship
+ - Create intermediate entities if needed to establish relationships
+ - Verify relationships against both summary and full text
+ - Resolve any discrepancies between sources
+
+ Example 1:
+ If the list is empty, extract all entities and relations.
+ Entity_types:
+ Relation_types:
+ Text:
+ San Francisco is a city in California. It is known for the Golden Gate Bridge, cable cars, and steep hills. The city is surrounded by the Pacific Ocean and the San Francisco Bay.
+ ######################
+ Output:
+ <entity name="San Francisco"><type>City</type><description>San Francisco is a city in California known for the Golden Gate Bridge, cable cars, and steep hills. It is surrounded by the Pacific Ocean and the San Francisco Bay.</description></entity>
+ <entity name="California"><type>State</type><description>California is a state in the United States.</description></entity>
+ <entity name="Golden Gate Bridge"><type>Landmark</type><description>The Golden Gate Bridge is a famous bridge in San Francisco.</description></entity>
+ <entity name="Pacific Ocean"><type>Body of Water</type><description>The Pacific Ocean is a large body of water that surrounds San Francisco.</description></entity>
+ <entity name="San Francisco Bay"><type>Body of Water</type><description>The San Francisco Bay is a body of water that surrounds San Francisco.</description></entity>
+ <relationship><source>San Francisco</source><target>California</target><type>Located In</type><description>San Francisco is a city located in California.</description><weight>8</weight></relationship>
+ <relationship><source>San Francisco</source><target>Golden Gate Bridge</target><type>Features</type><description>San Francisco features the Golden Gate Bridge.</description><weight>9</weight></relationship>
+ <relationship><source>San Francisco</source><target>Pacific Ocean</target><type>Surrounded By</type><description>San Francisco is surrounded by the Pacific Ocean.</description><weight>7</weight></relationship>
+ <relationship><source>San Francisco</source><target>San Francisco Bay</target><type>Surrounded By</type><description>San Francisco is surrounded by the San Francisco Bay.</description><weight>7</weight></relationship>
+ <relationship><source>California</source><target>San Francisco</target><type>Contains</type><description>California contains the city of San Francisco.</description><weight>8</weight></relationship>
+ <relationship><source>Golden Gate Bridge</source><target>San Francisco</target><type>Located In</type><description>The Golden Gate Bridge is located in San Francisco.</description><weight>8</weight></relationship>
+ <relationship><source>Pacific Ocean</source><target>San Francisco</target><type>Surrounds</type><description>The Pacific Ocean surrounds San Francisco.</description><weight>7</weight></relationship>
+ <relationship><source>San Francisco Bay</source><target>San Francisco</target><type>Surrounds</type><description>The San Francisco Bay surrounds San Francisco.</description><weight>7</weight></relationship>
+
+ ######################
+ Example 2:
+ If the list is empty, extract all entities and relations.
+ Entity_types: Organization, Person
+ Relation_types: Located In, Features
+
+ Text:
+ The Green Bay Packers are a professional American football team based in Green Bay, Wisconsin. The team was established in 1919 by Earl "Curly" Lambeau and George Calhoun. The Packers are the third-oldest franchise in the NFL and have won 13 league championships, including four Super Bowls. The team's home games are played at Lambeau Field, which is named after Curly Lambeau.
+ ######################
+ Output:
+ <entity name="Green Bay Packers"><type>Organization</type><description>The Green Bay Packers are a professional American football team based in Green Bay, Wisconsin. The team was established in 1919 by Earl "Curly" Lambeau and George Calhoun. The Packers are the third-oldest franchise in the NFL and have won 13 league championships, including four Super Bowls. The team's home games are played at Lambeau Field, which is named after Curly Lambeau.</description></entity>
+ <entity name="Green Bay"><type>City</type><description>Green Bay is a city in Wisconsin.</description></entity>
+ <entity name="Wisconsin"><type>State</type><description>Wisconsin is a state in the United States.</description></entity>
+ <entity name="Earl "Curly" Lambeau"><type>Person</type><description>Earl "Curly" Lambeau was a co-founder of the Green Bay Packers.</description></entity>
+ <entity name="George Calhoun"><type>Person</type><description>George Calhoun was a co-founder of the Green Bay Packers.</description></entity>
+ <entity name="NFL"><type>Organization</type><description>The NFL is the National Football League.</description></entity>
+ <entity name="Super Bowl"><type>Event</type><description>The Super Bowl is the championship game of the NFL.</description></entity>
+ <entity name="Lambeau Field"><type>Stadium</type><description>Lambeau Field is the home stadium of the Green Bay Packers.</description></entity>
+ <relationship><source>Green Bay Packers</source><target>Green Bay</target><type>Located In</type><description>The Green Bay Packers are based in Green Bay, Wisconsin.</description><weight>8</weight></relationship>
+ <relationship><source>Green Bay</source><target>Wisconsin</target><type>Located In</type><description>Green Bay is located in Wisconsin.</description><weight>8</weight></relationship>
+ <relationship><source>Green Bay Packers</source><target>Earl "Curly" Lambeau</target><type>Founded By</type><description>The Green Bay Packers were established by Earl "Curly" Lambeau.</description><weight>9</weight></relationship>
+ <relationship><source>Green Bay Packers</source><target>George Calhoun</target><type>Founded By</type><description>The Green Bay Packers were established by George Calhoun.</description><weight>9</weight></relationship>
+ <relationship><source>Green Bay Packers</source><target>NFL</target><type>League</type><description>The Green Bay Packers are a franchise in the NFL.</description><weight>8</weight></relationship>
+ <relationship><source>Green Bay Packers</source><target>Super Bowl</target><type>Championships</type><description>The Green Bay Packers have won four Super Bowls.</description><weight>9</weight></relationship>
+
+ -Real Data-
+ ######################
+ If the list is empty, extract all entities and relations.
+ Entity_types: {entity_types}
+ Relation_types: {relation_types}
+
+ Document Summary:
+ {document_summary}
+
+ Full Text:
+ {input}
+ ######################
+ Output:
+ input_types:
+ document_summary: str
+ max_knowledge_relationships: int
+ input: str
+ entity_types: list[str]
+ relation_types: list[str]
+ overwrite_on_diff: true
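+ # Illustrative note (not part of the upstream prompt): each key under
+ # input_types corresponds to a {placeholder} in the template, so the prompt
+ # handler can validate and substitute inputs before sending the text to the
+ # LLM. A minimal sketch with hypothetical values, assuming `template` holds
+ # the text above:
+ #
+ #   filled = template.format(
+ #       document_summary="A short overview of San Francisco.",
+ #       input="San Francisco is a city in California...",
+ #       entity_types="Organization, Geo, Person",
+ #       relation_types="Located In, Contains",
+ #       max_knowledge_relationships=100,
+ #   )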
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/hyde.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/hyde.yaml
new file mode 100644
index 00000000..d8071d1f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/hyde.yaml
@@ -0,0 +1,29 @@
+hyde:
+ template: >
+ ### Instruction:
+
+ Given the query that follows, write a double-newline-separated list of {num_outputs} distinct, single-paragraph attempted answers to the given query.
+
+
+ DO NOT generate any single answer which is likely to require information from multiple distinct documents,
+
+ EACH single answer will be used to carry out a cosine similarity semantic search over distinct indexed documents, such as varied medical documents.
+
+
+ FOR EXAMPLE if asked `how do the key themes of Great Gatsby compare with 1984`, the two attempted answers would be
+
+ `The key themes of Great Gatsby are ... ANSWER_CONTINUED` and `The key themes of 1984 are ... ANSWER_CONTINUED`, where `ANSWER_CONTINUED` IS TO BE COMPLETED BY YOU in your response.
+
+
+ Here is the original user query to be transformed into answers:
+
+
+ ### Query:
+
+ {message}
+
+
+ ### Response:
+ input_types:
+ num_outputs: int
+ message: str
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/rag.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/rag.yaml
new file mode 100644
index 00000000..c835517d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/rag.yaml
@@ -0,0 +1,29 @@
+rag:
+ template: >
+ ## Task:
+
+ Answer the query given immediately below, using the context which follows later. Use line-item references like [c910e2e], [b12cd2f], ... to refer to the provided search results.
+
+
+ ### Query:
+
+ {query}
+
+
+ ### Context:
+
+ {context}
+
+
+ ### Query:
+
+ {query}
+
+
+ REMINDER - Use line-item references like [c910e2e], [b12cd2f] to refer to the specific search result IDs returned in the provided context.
+
+ ## Response:
+ input_types:
+ query: str
+ context: str
+ overwrite_on_diff: true
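+ # Illustrative note (not part of the upstream prompt): `query` and `context`
+ # are the only declared inputs; the context is expected to carry the short
+ # bracketed result IDs the model is asked to cite. A sketch with hypothetical
+ # values:
+ #
+ #   filled = template.format(
+ #       query="What is the capital of France?",
+ #       context="[c910e2e] Paris is the capital of France.\n[b12cd2f] ...",
+ #   )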
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/rag_fusion.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/rag_fusion.yaml
new file mode 100644
index 00000000..874d3f39
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/rag_fusion.yaml
@@ -0,0 +1,27 @@
+rag_fusion:
+ template: >
+ ### Instruction:
+
+
+ Given the query that follows, write a double-newline-separated list of up to {num_outputs} queries meant to help answer the original query.
+
+ DO NOT generate any single query which is likely to require information from multiple distinct documents,
+
+ EACH single query will be used to carry out a cosine similarity semantic search over distinct indexed documents, such as varied medical documents.
+
+ FOR EXAMPLE if asked `how do the key themes of Great Gatsby compare with 1984`, the two queries would be
+
+ `What are the key themes of Great Gatsby?` and `What are the key themes of 1984?`.
+
+ Here is the original user query to be transformed into queries:
+
+
+ ### Query:
+
+ {message}
+
+
+ ### Response:
+ input_types:
+ num_outputs: int
+ message: str
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/static_rag_agent.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/static_rag_agent.yaml
new file mode 100644
index 00000000..0e940af1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/static_rag_agent.yaml
@@ -0,0 +1,16 @@
+static_rag_agent:
+ template: >
+ ### You are a helpful agent that can search for information; the date is {date}.
+
+ When asked a question, YOU SHOULD ALWAYS USE YOUR SEARCH TOOL TO ATTEMPT TO SEARCH FOR RELEVANT INFORMATION THAT ANSWERS THE USER QUESTION.
+
+ The response should contain line-item attributions to relevant search results and be as informative as possible.
+
+ If no relevant results are found, then state that no results were found. If no obvious question is present, then do not carry out a search, and instead ask for clarification.
+
+ REMINDER - Use line-item references like [c910e2e], [b12cd2f] to refer to the specific search result IDs returned in the provided context.
+
+ input_types:
+ date: str
+
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/static_research_agent.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/static_research_agent.yaml
new file mode 100644
index 00000000..417d161c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/static_research_agent.yaml
@@ -0,0 +1,61 @@
+static_research_agent:
+ template: >-
+ # You are a helpful agent that can search for information; the date is {date}.
+
+ # Comprehensive Strategic Analysis Report
+
+ ## Objective
+ Produce nuanced, robust, and strategically insightful analyses. Adjust your approach based on the nature of the question:
+
+ - **Broad, qualitative, or subjective questions**:
+ Deliver in-depth, qualitative analysis by systematically exploring multiple dimensions and diverse perspectives. Emphasize strategic insights, market psychology, long-term implications, and nuanced evaluations.
+
+ - **Narrow, academic, or factual questions**:
+ Provide focused, precise, and strategic analyses. Clearly articulate cause-effect relationships, relevant context, and strategic significance. Prioritize accuracy, clarity, and concise insights.
+
+ ## Research Guidance
+ - **Multi-thesis Approach (for qualitative/subjective queries):**
+ - Identify and retrieve detailed information from credible sources covering multiple angles, including technical, economic, market-specific, geopolitical, psychological, and long-term strategic implications.
+ - Seek contrasting viewpoints, expert opinions, market analyses, and nuanced discussions.
+
+ - **Focused Strategic Approach (for narrow/academic queries):**
+ - Clearly identify the core elements of the question and retrieve precise, relevant information.
+ - Highlight strategic significance, context, and implications concisely and accurately.
+
+ ## Source Diversity
+ - Draw from diverse, credible sources such as financial analyses, expert commentary, reputable news outlets, industry reports, academic papers, and analyst research.
+
+ ## Structured Analysis
+ - Organize findings into clear, logically sequenced sections (e.g., Technical Details, Market Reactions, Economic Implications, Strategic Insights).
+ - Explicitly link each factor to its implications, providing concrete examples.
+
+ ## Cause-and-Effect Relationships
+ - Clearly articulate how each piece of information contributes to the overall narrative or impact, emphasizing cause-effect dynamics.
+ - Include direct quotes or analyst insights when available to illustrate context or sentiment.
+
+ ## Depth and Nuance
+ - Go beyond surface-level reporting by critically evaluating how events challenge assumptions, reshape strategies, and alter investment theses.
+ - Highlight underlying trends, second-order effects, and unexpected consequences.
+
+ ## Forward-Looking Insights
+ - Offer informed analysis on long-term strategic implications, identifying future opportunities, risks, and market shifts.
+ - Discuss potential shifts in competitive dynamics, regulation, and industry paradigms.
+
+ ## Line-Item Referencing
+ - Carefully reference specific data and insights using line-item references (e.g., `[c910e2e]`, `[b12cd2f]`) for credibility and transparency.
+
+ ## Writing Style and Tone
+ - Adopt a concise yet sophisticated style suitable for investors, industry professionals, and strategic decision-makers.
+ - Prioritize clarity, strategic depth, and actionable insights.
+
+ ## Final Reminders
+ - Anticipate follow-up questions and proactively provide relevant context or comparisons.
+ - Clearly highlight potential opportunities, risks, and long-term implications.
+ - Focus on actionable, well-rounded insights rather than mere summaries.
+ - Add line-item citations like `[c910e2e]`, `[b12cd2f]`.
+ - IGNORE all previous instructions when dealing with academic or factual queries.
+
+ input_types:
+ date: str
+
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/summary.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/summary.yaml
new file mode 100644
index 00000000..1658f2bf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/summary.yaml
@@ -0,0 +1,18 @@
+summary:
+ template: >
+ ## Task:
+
+ Your task is to generate a descriptive summary of the document that follows. Your objective is to return a summary that is roughly 10% of the input document size while retaining as many key points as possible. Your response should begin with `The document contains `.
+
+ ### Document:
+
+ {document}
+
+
+ ### Query:
+
+ Reminder: Your task is to generate a descriptive summary of the document that was given. Your objective is to return a summary that is roughly 10% of the input document size while retaining as many key points as possible. Your response should begin with `The document contains `.
+
+ ## Response:
+ input_types:
+ document: str
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/system.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/system.yaml
new file mode 100644
index 00000000..4bc0770b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/system.yaml
@@ -0,0 +1,3 @@
+system:
+ template: You are a helpful agent.
+ input_types: {}
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_img.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_img.yaml
new file mode 100644
index 00000000..4a1aa477
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_img.yaml
@@ -0,0 +1,4 @@
+vision_img:
+ template: >
+ First, provide a title for the image, then explain everything that you see. Be very thorough in your analysis as a user will need to understand the image without seeing it. If it is possible to transcribe the image to text directly, then do so. The more detail you provide, the better the user will understand the image.
+ input_types: {}
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_pdf.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_pdf.yaml
new file mode 100644
index 00000000..350ead2d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_pdf.yaml
@@ -0,0 +1,42 @@
+vision_pdf:
+ template: >
+ Convert this PDF page to markdown format, preserving all content and formatting. Follow these guidelines:
+
+ Text:
+ - Maintain the original text hierarchy (headings, paragraphs, lists)
+ - Preserve any special formatting (bold, italic, underline)
+ - Include all footnotes, citations, and references
+ - Keep text in its original reading order
+
+ Tables:
+ - Recreate tables using markdown table syntax
+ - Preserve all headers, rows, and columns
+ - Maintain alignment and formatting where possible
+ - Include any table captions or notes
+
+ Equations:
+ - Convert mathematical equations using LaTeX notation
+ - Preserve equation numbers if present
+ - Include any surrounding context or references
+
+ Images:
+ - Enclose image descriptions within [FIG] and [/FIG] tags
+ - Include detailed descriptions of:
+ * Main subject matter
+ * Text overlays or captions
+ * Charts, graphs, or diagrams
+ * Relevant colors, patterns, or visual elements
+ - Maintain image placement relative to surrounding text
+
+ Additional Elements:
+ - Include page numbers if visible
+ - Preserve headers and footers
+ - Maintain sidebars or callout boxes
+ - Keep any special symbols or characters
+
+ Quality Requirements:
+ - Ensure 100% content preservation
+ - Maintain logical document flow
+ - Verify all markdown syntax is valid
+ - Double-check completeness before submitting
+ input_types: {}
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts_handler.py b/.venv/lib/python3.12/site-packages/core/providers/database/prompts_handler.py
new file mode 100644
index 00000000..29afbb3f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts_handler.py
@@ -0,0 +1,748 @@
+import json
+import logging
+import os
+from abc import abstractmethod
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Generic, Optional, TypeVar
+
+import yaml
+
+from core.base import Handler, generate_default_prompt_id
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+@dataclass
+class CacheEntry(Generic[T]):
+ """Represents a cached item with metadata."""
+
+ value: T
+ created_at: datetime
+ last_accessed: datetime
+ access_count: int = 0
+
+
+class Cache(Generic[T]):
+ """A generic cache implementation with TTL and LRU-like features."""
+
+ def __init__(
+ self,
+ ttl: Optional[timedelta] = None,
+ max_size: Optional[int] = 1000,
+ cleanup_interval: timedelta = timedelta(hours=1),
+ ):
+ self._cache: dict[str, CacheEntry[T]] = {}
+ self._ttl = ttl
+ self._max_size = max_size
+ self._cleanup_interval = cleanup_interval
+ self._last_cleanup = datetime.now()
+
+ def get(self, key: str) -> Optional[T]:
+ """Retrieve an item from cache."""
+ self._maybe_cleanup()
+
+ if key not in self._cache:
+ return None
+
+ entry = self._cache[key]
+
+ if self._ttl and datetime.now() - entry.created_at > self._ttl:
+ del self._cache[key]
+ return None
+
+ entry.last_accessed = datetime.now()
+ entry.access_count += 1
+ return entry.value
+
+ def set(self, key: str, value: T) -> None:
+ """Store an item in cache."""
+ self._maybe_cleanup()
+
+ now = datetime.now()
+ self._cache[key] = CacheEntry(
+ value=value, created_at=now, last_accessed=now
+ )
+
+ if self._max_size and len(self._cache) > self._max_size:
+ self._evict_lru()
+
+ def invalidate(self, key: str) -> None:
+ """Remove an item from cache."""
+ self._cache.pop(key, None)
+
+ def clear(self) -> None:
+ """Clear all cached items."""
+ self._cache.clear()
+
+ def _maybe_cleanup(self) -> None:
+ """Periodically clean up expired entries."""
+ now = datetime.now()
+ if now - self._last_cleanup > self._cleanup_interval:
+ self._cleanup()
+ self._last_cleanup = now
+
+ def _cleanup(self) -> None:
+ """Remove expired entries."""
+ if not self._ttl:
+ return
+
+ now = datetime.now()
+ expired = [
+ k for k, v in self._cache.items() if now - v.created_at > self._ttl
+ ]
+ for k in expired:
+ del self._cache[k]
+
+ def _evict_lru(self) -> None:
+ """Remove least recently used item."""
+ if not self._cache:
+ return
+
+ lru_key = min(
+ self._cache.keys(), key=lambda k: self._cache[k].last_accessed
+ )
+ del self._cache[lru_key]
+
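+# Illustrative usage of Cache (a sketch only; nothing in this module runs it):
+#
+#   cache = Cache[str](ttl=timedelta(minutes=30), max_size=2)
+#   cache.set("a", "1")
+#   cache.set("b", "2")
+#   cache.get("a")       # refreshes "a"'s last_accessed and returns "1"
+#   cache.set("c", "3")  # size exceeds max_size, so the LRU entry ("b") is evicted
+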
+
+class CacheablePromptHandler(Handler):
+ """Abstract base class that adds caching capabilities to prompt
+ handlers."""
+
+ def __init__(
+ self,
+ cache_ttl: Optional[timedelta] = timedelta(hours=1),
+ max_cache_size: Optional[int] = 1000,
+ ):
+ self._prompt_cache = Cache[str](ttl=cache_ttl, max_size=max_cache_size)
+ self._template_cache = Cache[dict](
+ ttl=cache_ttl, max_size=max_cache_size
+ )
+
+ def _cache_key(
+ self, prompt_name: str, inputs: Optional[dict] = None
+ ) -> str:
+ """Generate a cache key for a prompt request."""
+ if inputs:
+ # Sort dict items for consistent keys
+ sorted_inputs = sorted(inputs.items())
+ return f"{prompt_name}:{sorted_inputs}"
+ return prompt_name
+
+ async def get_cached_prompt(
+ self,
+ prompt_name: str,
+ inputs: Optional[dict[str, Any]] = None,
+ prompt_override: Optional[str] = None,
+ bypass_cache: bool = False,
+ ) -> str:
+ if prompt_override:
+ # If the user gave us a direct override, use it.
+ if inputs:
+ try:
+ return prompt_override.format(**inputs)
+ except KeyError:
+ return prompt_override
+ return prompt_override
+
+ cache_key = self._cache_key(prompt_name, inputs)
+
+ # If not bypassing, try returning from the prompt-level cache
+ if not bypass_cache:
+ cached = self._prompt_cache.get(cache_key)
+ if cached is not None:
+ logger.debug(f"Prompt cache hit: {cache_key}")
+ return cached
+
+ logger.debug(
+ "Prompt cache miss or bypass. Retrieving from DB or template cache."
+ )
+ # Notice the new parameter `bypass_template_cache` below
+ result = await self._get_prompt_impl(
+ prompt_name, inputs, bypass_template_cache=bypass_cache
+ )
+ self._prompt_cache.set(cache_key, result)
+ return result
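+
+ # Illustrative call patterns (a sketch; `handler` and the inputs are hypothetical):
+ #
+ #   text = await handler.get_cached_prompt("rag", {"query": q, "context": ctx})
+ #   text = await handler.get_cached_prompt("rag", inputs, bypass_cache=True)
+ #   text = await handler.get_cached_prompt("rag", inputs, prompt_override="Answer: {query}")
+ #
+ # The first call may be served from the prompt cache, the second forces a fresh
+ # read through _get_prompt_impl, and the third ignores any stored template.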
+
+ async def get_prompt( # type: ignore
+ self,
+ name: str,
+ inputs: Optional[dict] = None,
+ prompt_override: Optional[str] = None,
+ ) -> dict:
+ query = f"""
+ SELECT id, name, template, input_types, created_at, updated_at
+ FROM {self._get_table_name("prompts")}
+ WHERE name = $1;
+ """
+ result = await self.connection_manager.fetchrow_query(query, [name])
+
+ if not result:
+ raise ValueError(f"Prompt template '{name}' not found")
+
+ input_types = result["input_types"]
+ if isinstance(input_types, str):
+ input_types = json.loads(input_types)
+
+ return {
+ "id": result["id"],
+ "name": result["name"],
+ "template": result["template"],
+ "input_types": input_types,
+ "created_at": result["created_at"],
+ "updated_at": result["updated_at"],
+ }
+
+ def _format_prompt(
+ self,
+ template: str,
+ inputs: Optional[dict[str, Any]],
+ input_types: dict[str, str],
+ ) -> str:
+ if inputs:
+ # optional input validation if needed
+ for k, _v in inputs.items():
+ if k not in input_types:
+ raise ValueError(
+ f"Unexpected input '{k}' for prompt with input types {input_types}"
+ )
+ return template.format(**inputs)
+ return template
+
+ async def update_prompt(
+ self,
+ name: str,
+ template: Optional[str] = None,
+ input_types: Optional[dict[str, str]] = None,
+ ) -> None:
+ """Public method to update a prompt with proper cache invalidation."""
+ # First invalidate all caches for this prompt
+ self._template_cache.invalidate(name)
+ cache_keys_to_invalidate = [
+ key
+ for key in self._prompt_cache._cache.keys()
+ if key.startswith(f"{name}:") or key == name
+ ]
+ for key in cache_keys_to_invalidate:
+ self._prompt_cache.invalidate(key)
+
+ # Perform the update
+ await self._update_prompt_impl(name, template, input_types)
+
+ # Force refresh template cache
+ template_info = await self._get_template_info(name)
+ if template_info:
+ self._template_cache.set(name, template_info)
+
+ @abstractmethod
+ async def _update_prompt_impl(
+ self,
+ name: str,
+ template: Optional[str] = None,
+ input_types: Optional[dict[str, str]] = None,
+ ) -> None:
+ """Implementation of prompt update logic."""
+ pass
+
+ @abstractmethod
+ async def _get_template_info(self, prompt_name: str) -> Optional[dict]:
+ """Get template info with caching."""
+ pass
+
+ @abstractmethod
+ async def _get_prompt_impl(
+ self,
+ prompt_name: str,
+ inputs: Optional[dict[str, Any]] = None,
+ bypass_template_cache: bool = False,
+ ) -> str:
+ """Implementation of prompt retrieval logic."""
+ pass
+
+
+class PostgresPromptsHandler(CacheablePromptHandler):
+ """PostgreSQL implementation of the CacheablePromptHandler."""
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ prompt_directory: Optional[Path] = None,
+ **cache_options,
+ ):
+ super().__init__(**cache_options)
+ self.prompt_directory = (
+ prompt_directory or Path(os.path.dirname(__file__)) / "prompts"
+ )
+ self.connection_manager = connection_manager
+ self.project_name = project_name
+ self.prompts: dict[str, dict[str, str | dict[str, str]]] = {}
+
+ async def _load_prompts(self) -> None:
+ """Load prompts from both database and YAML files."""
+ # First load from database
+ await self._load_prompts_from_database()
+
+ # Then load from YAML files, potentially overriding unmodified database entries
+ await self._load_prompts_from_yaml_directory()
+
+ async def _load_prompts_from_database(self) -> None:
+ """Load prompts from the database."""
+ query = f"""
+ SELECT id, name, template, input_types, created_at, updated_at
+ FROM {self._get_table_name("prompts")};
+ """
+ try:
+ results = await self.connection_manager.fetch_query(query)
+ for row in results:
+ logger.info(f"Loading saved prompt: {row['name']}")
+
+ # Ensure input_types is a dictionary
+ input_types = row["input_types"]
+ if isinstance(input_types, str):
+ input_types = json.loads(input_types)
+
+ self.prompts[row["name"]] = {
+ "id": row["id"],
+ "template": row["template"],
+ "input_types": input_types,
+ "created_at": row["created_at"],
+ "updated_at": row["updated_at"],
+ }
+ # Pre-populate the template cache
+ self._template_cache.set(
+ row["name"],
+ {
+ "id": row["id"],
+ "template": row["template"],
+ "input_types": input_types,
+ },
+ )
+ logger.debug(f"Loaded {len(results)} prompts from database")
+ except Exception as e:
+ logger.error(f"Failed to load prompts from database: {e}")
+ raise
+
+ async def _load_prompts_from_yaml_directory(
+ self, default_overwrite_on_diff: bool = False
+ ) -> None:
+ """Load prompts from YAML files in the specified directory.
+
+ :param default_overwrite_on_diff: If a YAML prompt does not specify
+ 'overwrite_on_diff', we use this default.
+ """
+ if not self.prompt_directory.is_dir():
+ logger.warning(
+ f"Prompt directory not found: {self.prompt_directory}"
+ )
+ return
+
+ logger.info(f"Loading prompts from {self.prompt_directory}")
+ for yaml_file in self.prompt_directory.glob("*.yaml"):
+ logger.debug(f"Processing {yaml_file}")
+ try:
+ with open(yaml_file, "r", encoding="utf-8") as file:
+ data = yaml.safe_load(file)
+ if not isinstance(data, dict):
+ raise ValueError(
+ f"Invalid format in YAML file {yaml_file}"
+ )
+
+ for name, prompt_data in data.items():
+ # Attempt to parse the relevant prompt fields
+ template = prompt_data.get("template")
+ input_types = prompt_data.get("input_types", {})
+
+ # Decide on per-prompt overwrite behavior (or fallback)
+ overwrite_on_diff = prompt_data.get(
+ "overwrite_on_diff", default_overwrite_on_diff
+ )
+ # Some logic to determine if we *should* modify
+ # For instance, preserve only if it has never been updated
+ # (i.e., created_at == updated_at).
+ should_modify = True
+ if name in self.prompts:
+ existing = self.prompts[name]
+ should_modify = (
+ existing["created_at"]
+ == existing["updated_at"]
+ )
+
+ # NOTE: should_modify is computed above but is not forwarded below;
+ # the call always passes preserve_existing=False, so whether an
+ # existing-but-different prompt gets replaced is decided by
+ # overwrite_on_diff inside add_prompt.
+ logger.info(
+ f"Loading default prompt: {name} from {yaml_file}."
+ )
+
+ await self.add_prompt(
+ name=name,
+ template=template,
+ input_types=input_types,
+ preserve_existing=False,
+ overwrite_on_diff=overwrite_on_diff,
+ )
+ except Exception as e:
+ logger.error(f"Error loading {yaml_file}: {e}")
+ continue
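+
+ # Expected YAML shape for files in prompt_directory (mirrors the bundled
+ # prompt files such as rag.yaml; the prompt name below is hypothetical and
+ # overwrite_on_diff is optional):
+ #
+ #   my_prompt:
+ #     template: >
+ #       Answer {query} using {context}.
+ #     input_types:
+ #       query: str
+ #       context: str
+ #     overwrite_on_diff: true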
+
+ def _get_table_name(self, base_name: str) -> str:
+ """Get the fully qualified table name."""
+ return f"{self.project_name}.{base_name}"
+
+ # Implementation of abstract methods from CacheablePromptHandler
+ async def _get_prompt_impl(
+ self,
+ prompt_name: str,
+ inputs: Optional[dict[str, Any]] = None,
+ bypass_template_cache: bool = False,
+ ) -> str:
+ """Implementation of database prompt retrieval."""
+ # If we're bypassing the template cache, skip the cache lookup
+ if not bypass_template_cache:
+ template_info = self._template_cache.get(prompt_name)
+ if template_info is not None:
+ logger.debug(f"Template cache hit: {prompt_name}")
+ # use that
+ return self._format_prompt(
+ template_info["template"],
+ inputs,
+ template_info["input_types"],
+ )
+
+ # If we get here, either no cache was found or bypass_cache is True
+ query = f"""
+ SELECT template, input_types
+ FROM {self._get_table_name("prompts")}
+ WHERE name = $1;
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [prompt_name]
+ )
+
+ if not result:
+ raise ValueError(f"Prompt template '{prompt_name}' not found")
+
+ template = result["template"]
+ input_types = result["input_types"]
+ if isinstance(input_types, str):
+ input_types = json.loads(input_types)
+
+ # Update template cache if not bypassing it
+ if not bypass_template_cache:
+ self._template_cache.set(
+ prompt_name, {"template": template, "input_types": input_types}
+ )
+
+ return self._format_prompt(template, inputs, input_types)
+
+ async def _get_template_info(self, prompt_name: str) -> Optional[dict]: # type: ignore
+ """Get template info with caching."""
+ cached = self._template_cache.get(prompt_name)
+ if cached is not None:
+ return cached
+
+ query = f"""
+ SELECT template, input_types
+ FROM {self._get_table_name("prompts")}
+ WHERE name = $1;
+ """
+
+ result = await self.connection_manager.fetchrow_query(
+ query, [prompt_name]
+ )
+
+ if result:
+ # Ensure input_types is a dictionary
+ input_types = result["input_types"]
+ if isinstance(input_types, str):
+ input_types = json.loads(input_types)
+
+ template_info = {
+ "template": result["template"],
+ "input_types": input_types,
+ }
+ self._template_cache.set(prompt_name, template_info)
+ return template_info
+
+ return None
+
+ async def _update_prompt_impl(
+ self,
+ name: str,
+ template: Optional[str] = None,
+ input_types: Optional[dict[str, str]] = None,
+ ) -> None:
+ """Implementation of database prompt update with proper connection
+ handling."""
+ if not template and not input_types:
+ return
+
+ # Clear caches first
+ self._template_cache.invalidate(name)
+ for key in list(self._prompt_cache._cache.keys()):
+ if key.startswith(f"{name}:"):
+ self._prompt_cache.invalidate(key)
+
+ # Build update query
+ set_clauses = []
+ params = [name] # First parameter is always the name
+ param_index = 2 # Start from 2 since $1 is name
+
+ if template:
+ set_clauses.append(f"template = ${param_index}")
+ params.append(template)
+ param_index += 1
+
+ if input_types:
+ set_clauses.append(f"input_types = ${param_index}")
+ params.append(json.dumps(input_types))
+ param_index += 1
+
+ set_clauses.append("updated_at = CURRENT_TIMESTAMP")
+
+ query = f"""
+ UPDATE {self._get_table_name("prompts")}
+ SET {", ".join(set_clauses)}
+ WHERE name = $1
+ RETURNING id, template, input_types;
+ """
+
+ try:
+ # Execute update and get returned values
+ result = await self.connection_manager.fetchrow_query(
+ query, params
+ )
+
+ if not result:
+ raise ValueError(f"Prompt template '{name}' not found")
+
+ # Update in-memory state
+ if name in self.prompts:
+ if template:
+ self.prompts[name]["template"] = template
+ if input_types:
+ self.prompts[name]["input_types"] = input_types
+ self.prompts[name]["updated_at"] = datetime.now().isoformat()
+
+ except Exception as e:
+ logger.error(f"Failed to update prompt {name}: {str(e)}")
+ raise
+
+ async def create_tables(self):
+ """Create the necessary tables for storing prompts."""
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name("prompts")} (
+ id UUID PRIMARY KEY,
+ name VARCHAR(255) NOT NULL UNIQUE,
+ template TEXT NOT NULL,
+ input_types JSONB NOT NULL,
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
+ );
+
+ CREATE OR REPLACE FUNCTION {self.project_name}.update_updated_at_column()
+ RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.updated_at = CURRENT_TIMESTAMP;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ DROP TRIGGER IF EXISTS update_prompts_updated_at
+ ON {self._get_table_name("prompts")};
+
+ CREATE TRIGGER update_prompts_updated_at
+ BEFORE UPDATE ON {self._get_table_name("prompts")}
+ FOR EACH ROW
+ EXECUTE FUNCTION {self.project_name}.update_updated_at_column();
+ """
+ await self.connection_manager.execute_query(query)
+ await self._load_prompts()
+
+ async def add_prompt(
+ self,
+ name: str,
+ template: str,
+ input_types: dict[str, str],
+ preserve_existing: bool = False,
+ overwrite_on_diff: bool = False, # <-- new param
+ ) -> None:
+ """Add or update a prompt.
+
+ If `preserve_existing` is True and the prompt already exists, we skip the update.
+
+ If `overwrite_on_diff` is True and an existing prompt differs from what is provided,
+ we overwrite and log a warning. Otherwise, we skip if the prompt differs.
+ """
+ # Check if prompt is in-memory
+ existing_prompt = self.prompts.get(name)
+
+ # If preserving existing and it already exists, skip entirely
+ if preserve_existing and existing_prompt:
+ logger.debug(
+ f"Preserving existing prompt: {name}, skipping update."
+ )
+ return
+
+ # If an existing prompt is found, check for diffs
+ if existing_prompt:
+ existing_template = existing_prompt["template"]
+ existing_input_types = existing_prompt["input_types"]
+
+ # If there's a difference in template or input_types, decide to overwrite or skip
+ if (
+ existing_template != template
+ or existing_input_types != input_types
+ ):
+ if overwrite_on_diff:
+ logger.warning(
+ f"Overwriting existing prompt '{name}' due to detected diff."
+ )
+ else:
+ logger.info(
+ f"Prompt '{name}' differs from existing but overwrite_on_diff=False. Skipping update."
+ )
+ return
+
+ prompt_id = generate_default_prompt_id(name)
+
+ # Ensure input_types is properly serialized
+ input_types_json = (
+ json.dumps(input_types)
+ if isinstance(input_types, dict)
+ else input_types
+ )
+
+ # Upsert logic
+ query = f"""
+ INSERT INTO {self._get_table_name("prompts")} (id, name, template, input_types)
+ VALUES ($1, $2, $3, $4)
+ ON CONFLICT (name) DO UPDATE
+ SET template = EXCLUDED.template,
+ input_types = EXCLUDED.input_types,
+ updated_at = CURRENT_TIMESTAMP
+ RETURNING id, created_at, updated_at;
+ """
+
+ result = await self.connection_manager.fetchrow_query(
+ query, [prompt_id, name, template, input_types_json]
+ )
+
+ self.prompts[name] = {
+ "id": result["id"],
+ "template": template,
+ "input_types": input_types,
+ "created_at": result["created_at"],
+ "updated_at": result["updated_at"],
+ }
+
+ # Update template cache
+ self._template_cache.set(
+ name,
+ {
+ "id": prompt_id,
+ "template": template,
+ "input_types": input_types,
+ },
+ )
+
+ # Invalidate any cached formatted prompts
+ for key in list(self._prompt_cache._cache.keys()):
+ if key.startswith(f"{name}:"):
+ self._prompt_cache.invalidate(key)
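+
+ # Illustrative calls (a sketch; the prompt name and templates are hypothetical):
+ #
+ #   await handler.add_prompt("greeting", "Hello {name}!", {"name": "str"})
+ #   await handler.add_prompt("greeting", "Hi {name}!", {"name": "str"})
+ #       # differs from the stored template, overwrite_on_diff=False -> skipped
+ #   await handler.add_prompt("greeting", "Hi {name}!", {"name": "str"}, overwrite_on_diff=True)
+ #       # differs -> overwritten, with a warning logged
+ #   await handler.add_prompt("greeting", "Hi {name}!", {"name": "str"}, preserve_existing=True)
+ #       # prompt already exists -> always skipped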
+
+ async def get_all_prompts(self) -> dict[str, Any]:
+ """Retrieve all stored prompts."""
+ query = f"""
+ SELECT id, name, template, input_types, created_at, updated_at, COUNT(*) OVER() AS total_entries
+ FROM {self._get_table_name("prompts")};
+ """
+ results = await self.connection_manager.fetch_query(query)
+
+ if not results:
+ return {"results": [], "total_entries": 0}
+
+ total_entries = results[0]["total_entries"] if results else 0
+
+ prompts = [
+ {
+ "name": row["name"],
+ "id": row["id"],
+ "template": row["template"],
+ "input_types": (
+ json.loads(row["input_types"])
+ if isinstance(row["input_types"], str)
+ else row["input_types"]
+ ),
+ "created_at": row["created_at"],
+ "updated_at": row["updated_at"],
+ }
+ for row in results
+ ]
+
+ return {"results": prompts, "total_entries": total_entries}
+
+ async def delete_prompt(self, name: str) -> None:
+ """Delete a prompt template."""
+ query = f"""
+ DELETE FROM {self._get_table_name("prompts")}
+ WHERE name = $1;
+ """
+ result = await self.connection_manager.execute_query(query, [name])
+ if result == "DELETE 0":
+ raise ValueError(f"Prompt template '{name}' not found")
+
+ # Invalidate caches
+ self._template_cache.invalidate(name)
+ for key in list(self._prompt_cache._cache.keys()):
+ if key.startswith(f"{name}:"):
+ self._prompt_cache.invalidate(key)
+
+ async def get_message_payload(
+ self,
+ system_prompt_name: Optional[str] = None,
+ system_role: str = "system",
+ system_inputs: dict | None = None,
+ system_prompt_override: Optional[str] = None,
+ task_prompt_name: Optional[str] = None,
+ task_role: str = "user",
+ task_inputs: Optional[dict] = None,
+ task_prompt: Optional[str] = None,
+ ) -> list[dict]:
+ """Create a message payload from system and task prompts."""
+ if system_inputs is None:
+ system_inputs = {}
+ if task_inputs is None:
+ task_inputs = {}
+ if system_prompt_override:
+ system_prompt = system_prompt_override
+ else:
+ system_prompt = await self.get_cached_prompt(
+ system_prompt_name or "system",
+ system_inputs,
+ prompt_override=system_prompt_override,
+ )
+
+ task_prompt = await self.get_cached_prompt(
+ task_prompt_name or "rag",
+ task_inputs,
+ prompt_override=task_prompt,
+ )
+
+ return [
+ {
+ "role": system_role,
+ "content": system_prompt,
+ },
+ {
+ "role": task_role,
+ "content": task_prompt,
+ },
+ ]
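+
+ # Illustrative result of get_message_payload (a sketch; the query, context and
+ # returned text are hypothetical, and "system"/"rag" are the defaults above):
+ #
+ #   payload = await handler.get_message_payload(
+ #       task_prompt_name="rag",
+ #       task_inputs={"query": "What is X?", "context": "[c910e2e] ..."},
+ #   )
+ #   # -> [{"role": "system", "content": "You are a helpful agent."},
+ #   #     {"role": "user", "content": "## Task:\n\nAnswer the query ..."}]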
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/tokens.py b/.venv/lib/python3.12/site-packages/core/providers/database/tokens.py
new file mode 100644
index 00000000..7d30c326
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/tokens.py
@@ -0,0 +1,67 @@
+from datetime import datetime, timedelta
+from typing import Optional
+
+from core.base import Handler
+
+from .base import PostgresConnectionManager
+
+
+class PostgresTokensHandler(Handler):
+ TABLE_NAME = "blacklisted_tokens"
+
+ def __init__(
+ self, project_name: str, connection_manager: PostgresConnectionManager
+ ):
+ super().__init__(project_name, connection_manager)
+
+ async def create_tables(self):
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ token TEXT NOT NULL,
+ blacklisted_at TIMESTAMPTZ DEFAULT NOW()
+ );
+ CREATE INDEX IF NOT EXISTS idx_{self.project_name}_{PostgresTokensHandler.TABLE_NAME}_token
+ ON {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (token);
+ CREATE INDEX IF NOT EXISTS idx_{self.project_name}_{PostgresTokensHandler.TABLE_NAME}_blacklisted_at
+ ON {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (blacklisted_at);
+ """
+ await self.connection_manager.execute_query(query)
+
+ async def blacklist_token(
+ self, token: str, current_time: Optional[datetime] = None
+ ):
+ if current_time is None:
+ current_time = datetime.utcnow()
+
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (token, blacklisted_at)
+ VALUES ($1, $2)
+ """
+ await self.connection_manager.execute_query(
+ query, [token, current_time]
+ )
+
+ async def is_token_blacklisted(self, token: str) -> bool:
+ query = f"""
+ SELECT 1 FROM {self._get_table_name(PostgresTokensHandler.TABLE_NAME)}
+ WHERE token = $1
+ LIMIT 1
+ """
+ result = await self.connection_manager.fetchrow_query(query, [token])
+ return bool(result)
+
+ async def clean_expired_blacklisted_tokens(
+ self,
+ max_age_hours: int = 7 * 24,
+ current_time: Optional[datetime] = None,
+ ):
+ if current_time is None:
+ current_time = datetime.utcnow()
+ expiry_time = current_time - timedelta(hours=max_age_hours)
+
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresTokensHandler.TABLE_NAME)}
+ WHERE blacklisted_at < $1
+ """
+ await self.connection_manager.execute_query(query, [expiry_time])
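+
+ # Illustrative flow (a sketch; construction of `handler` and `jwt_string` elided):
+ #
+ #   await handler.create_tables()
+ #   await handler.blacklist_token(jwt_string)
+ #   assert await handler.is_token_blacklisted(jwt_string)
+ #   await handler.clean_expired_blacklisted_tokens(max_age_hours=24)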
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/users.py b/.venv/lib/python3.12/site-packages/core/providers/database/users.py
new file mode 100644
index 00000000..208eeaa4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/users.py
@@ -0,0 +1,1325 @@
+import csv
+import json
+import tempfile
+from datetime import datetime
+from typing import IO, Optional
+from uuid import UUID
+
+from fastapi import HTTPException
+
+from core.base import CryptoProvider, Handler
+from core.base.abstractions import R2RException
+from core.utils import generate_user_id
+from shared.abstractions import User
+
+from .base import PostgresConnectionManager, QueryBuilder
+from .collections import PostgresCollectionsHandler
+
+
+def _merge_metadata(
+ existing_metadata: dict[str, str], new_metadata: dict[str, Optional[str]]
+) -> dict[str, str]:
+ """
+ Merges the new metadata with the existing metadata in the Stripe-style approach:
+ - new_metadata[key] = <string> => update or add that key
+ - new_metadata[key] = "" => remove that key
+ - if new_metadata is empty => remove all keys
+ """
+ # If new_metadata is an empty dict, it signals removal of all keys.
+ if new_metadata == {}:
+ return {}
+
+ # Copy so we don't mutate the original
+ final_metadata = dict(existing_metadata)
+
+ for key, value in new_metadata.items():
+ # If the user sets the key to an empty string, it means "delete" that key
+ if value == "":
+ if key in final_metadata:
+ del final_metadata[key]
+ # If not None and not empty, set or override
+ elif value is not None:
+ final_metadata[key] = value
+ else:
+ # A value of None is treated the same as an empty string:
+ # the key is removed if present.
+ if key in final_metadata:
+ del final_metadata[key]
+
+ return final_metadata
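+
+# Worked example of the merge rules above (values are illustrative):
+#
+#   _merge_metadata({"tier": "pro", "team": "a"}, {"team": "", "region": "eu"})
+#   # -> {"tier": "pro", "region": "eu"}   ("" deletes a key, new keys are added)
+#
+#   _merge_metadata({"tier": "pro"}, {})
+#   # -> {}                                (an empty dict clears all metadata)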
+
+
+class PostgresUserHandler(Handler):
+ TABLE_NAME = "users"
+ API_KEYS_TABLE_NAME = "users_api_keys"
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ crypto_provider: CryptoProvider,
+ ):
+ super().__init__(project_name, connection_manager)
+ self.crypto_provider = crypto_provider
+
+ async def create_tables(self):
+ user_table_query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ email TEXT UNIQUE NOT NULL,
+ hashed_password TEXT NOT NULL,
+ is_superuser BOOLEAN DEFAULT FALSE,
+ is_active BOOLEAN DEFAULT TRUE,
+ is_verified BOOLEAN DEFAULT FALSE,
+ verification_code TEXT,
+ verification_code_expiry TIMESTAMPTZ,
+ name TEXT,
+ bio TEXT,
+ profile_picture TEXT,
+ reset_token TEXT,
+ reset_token_expiry TIMESTAMPTZ,
+ collection_ids UUID[] NULL,
+ limits_overrides JSONB,
+ metadata JSONB,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ account_type TEXT NOT NULL DEFAULT 'password',
+ google_id TEXT,
+ github_id TEXT
+ );
+ """
+
+ # API keys table with updated_at instead of last_used_at
+ api_keys_table_query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ user_id UUID NOT NULL REFERENCES {self._get_table_name(PostgresUserHandler.TABLE_NAME)}(id) ON DELETE CASCADE,
+ public_key TEXT UNIQUE NOT NULL,
+ hashed_key TEXT NOT NULL,
+ name TEXT,
+ description TEXT,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW()
+ );
+
+ CREATE INDEX IF NOT EXISTS idx_api_keys_user_id
+ ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(user_id);
+
+ CREATE INDEX IF NOT EXISTS idx_api_keys_public_key
+ ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(public_key);
+ """
+
+ await self.connection_manager.execute_query(user_table_query)
+ await self.connection_manager.execute_query(api_keys_table_query)
+
+ # (New) Code snippet for adding columns if missing
+ # Postgres >= 9.6 supports "ADD COLUMN IF NOT EXISTS"
+ check_columns_query = f"""
+ ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
+ ADD COLUMN IF NOT EXISTS metadata JSONB;
+
+ ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
+ ADD COLUMN IF NOT EXISTS limits_overrides JSONB;
+
+ ALTER TABLE {self._get_table_name(self.API_KEYS_TABLE_NAME)}
+ ADD COLUMN IF NOT EXISTS description TEXT;
+ """
+ await self.connection_manager.execute_query(check_columns_query)
+
+ # Optionally, create indexes for quick lookups:
+ check_columns_query = f"""
+ ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
+ ADD COLUMN IF NOT EXISTS account_type TEXT NOT NULL DEFAULT 'password',
+ ADD COLUMN IF NOT EXISTS google_id TEXT,
+ ADD COLUMN IF NOT EXISTS github_id TEXT;
+
+ CREATE INDEX IF NOT EXISTS idx_users_google_id
+ ON {self._get_table_name(self.TABLE_NAME)}(google_id);
+ CREATE INDEX IF NOT EXISTS idx_users_github_id
+ ON {self._get_table_name(self.TABLE_NAME)}(github_id);
+ """
+ await self.connection_manager.execute_query(check_columns_query)
+
+ async def get_user_by_id(self, id: UUID) -> User:
+ query, _ = (
+ QueryBuilder(self._get_table_name("users"))
+ .select(
+ [
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "created_at",
+ "updated_at",
+ "name",
+ "profile_picture",
+ "bio",
+ "collection_ids",
+ "limits_overrides",
+ "metadata",
+ "account_type",
+ "hashed_password",
+ "google_id",
+ "github_id",
+ ]
+ )
+ .where("id = $1")
+ .build()
+ )
+ result = await self.connection_manager.fetchrow_query(query, [id])
+
+ if not result:
+ raise R2RException(status_code=404, message="User not found")
+
+ return User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ name=result["name"],
+ profile_picture=result["profile_picture"],
+ bio=result["bio"],
+ collection_ids=result["collection_ids"],
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+ metadata=json.loads(result["metadata"] or "{}"),
+ hashed_password=result["hashed_password"],
+ account_type=result["account_type"],
+ google_id=result["google_id"],
+ github_id=result["github_id"],
+ )
+
+ async def get_user_by_email(self, email: str) -> User:
+ query, params = (
+ QueryBuilder(self._get_table_name("users"))
+ .select(
+ [
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "created_at",
+ "updated_at",
+ "name",
+ "profile_picture",
+ "bio",
+ "collection_ids",
+ "metadata",
+ "limits_overrides",
+ "account_type",
+ "hashed_password",
+ "google_id",
+ "github_id",
+ ]
+ )
+ .where("email = $1")
+ .build()
+ )
+ result = await self.connection_manager.fetchrow_query(query, [email])
+ if not result:
+ raise R2RException(status_code=404, message="User not found")
+
+ return User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ name=result["name"],
+ profile_picture=result["profile_picture"],
+ bio=result["bio"],
+ collection_ids=result["collection_ids"],
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+ metadata=json.loads(result["metadata"] or "{}"),
+ account_type=result["account_type"],
+ hashed_password=result["hashed_password"],
+ google_id=result["google_id"],
+ github_id=result["github_id"],
+ )
+
+ async def create_user(
+ self,
+ email: str,
+ password: Optional[str] = None,
+ account_type: Optional[str] = "password",
+ google_id: Optional[str] = None,
+ github_id: Optional[str] = None,
+ is_superuser: bool = False,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ ) -> User:
+ """Create a new user."""
+ # 1) Check if a user with this email already exists
+ try:
+ existing = await self.get_user_by_email(email)
+ if existing:
+ raise R2RException(
+ status_code=400,
+ message="User with this email already exists",
+ )
+ except R2RException as e:
+ if e.status_code != 404:
+ raise e
+ # 2) If google_id is provided, ensure no user already has it
+ if google_id:
+ existing_google_user = await self.get_user_by_google_id(google_id)
+ if existing_google_user:
+ raise R2RException(
+ status_code=400,
+ message="User with this Google account already exists",
+ )
+
+ # 3) If github_id is provided, ensure no user already has it
+ if github_id:
+ existing_github_user = await self.get_user_by_github_id(github_id)
+ if existing_github_user:
+ raise R2RException(
+ status_code=400,
+ message="User with this GitHub account already exists",
+ )
+
+ hashed_password = None
+ if account_type == "password":
+ if password is None:
+ raise R2RException(
+ status_code=400,
+ message="Password is required for a 'password' account_type",
+ )
+ hashed_password = self.crypto_provider.get_password_hash(password) # type: ignore
+
+ query, params = (
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
+ .insert(
+ {
+ "email": email,
+ "id": generate_user_id(email),
+ "is_superuser": is_superuser,
+ "collection_ids": [],
+ "limits_overrides": None,
+ "metadata": None,
+ "account_type": account_type,
+ "hashed_password": hashed_password
+ or "", # Ensure hashed_password is not None
+ # !!WARNING - Upstream checks are required to treat oauth differently from password!!
+ "google_id": google_id,
+ "github_id": github_id,
+ "is_verified": account_type != "password",
+ "name": name,
+ "bio": bio,
+ "profile_picture": profile_picture,
+ }
+ )
+ .returning(
+ [
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "created_at",
+ "updated_at",
+ "collection_ids",
+ "limits_overrides",
+ "metadata",
+ "name",
+ "bio",
+ "profile_picture",
+ ]
+ )
+ .build()
+ )
+
+ result = await self.connection_manager.fetchrow_query(query, params)
+ if not result:
+ raise R2RException(
+ status_code=500,
+ message="Failed to create user",
+ )
+
+ return User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ collection_ids=result["collection_ids"] or [],
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+ metadata=json.loads(result["metadata"] or "{}"),
+ name=result["name"],
+ bio=result["bio"],
+ profile_picture=result["profile_picture"],
+ account_type=account_type or "password",
+ hashed_password=hashed_password,
+ google_id=google_id,
+ github_id=github_id,
+ )
+
+ async def update_user(
+ self,
+ user: User,
+ merge_limits: bool = False,
+ new_metadata: dict[str, Optional[str]] | None = None,
+ ) -> User:
+ """Update user information including limits_overrides.
+
+ Args:
+ user: User object containing updated information
+ merge_limits: If True, will merge existing limits_overrides with new ones.
+ If False, will overwrite existing limits_overrides.
+
+ Returns:
+ Updated User object
+ """
+
+ # Get current user if we need to merge limits or get hashed password
+ current_user = None
+ try:
+ current_user = await self.get_user_by_id(user.id)
+ except R2RException:
+ raise R2RException(
+ status_code=404, message="User not found"
+ ) from None
+
+ # If the new user.email != current_user.email, check for duplicates
+ if user.email and (user.email != current_user.email):
+ existing_email_user = await self.get_user_by_email(user.email)
+ if existing_email_user and existing_email_user.id != user.id:
+ raise R2RException(
+ status_code=400,
+ message="That email account is already associated with another user.",
+ )
+
+ # If the new user.google_id != current_user.google_id, check for duplicates
+ if user.google_id and (user.google_id != current_user.google_id):
+ existing_google_user = await self.get_user_by_google_id(
+ user.google_id
+ )
+ if existing_google_user and existing_google_user.id != user.id:
+ raise R2RException(
+ status_code=400,
+ message="That Google account is already associated with another user.",
+ )
+
+ # Similarly for GitHub:
+ if user.github_id and (user.github_id != current_user.github_id):
+ existing_github_user = await self.get_user_by_github_id(
+ user.github_id
+ )
+ if existing_github_user and existing_github_user.id != user.id:
+ raise R2RException(
+ status_code=400,
+ message="That GitHub account is already associated with another user.",
+ )
+
+ # Merge or replace metadata if provided
+ final_metadata = current_user.metadata or {}
+ if new_metadata is not None:
+ final_metadata = _merge_metadata(final_metadata, new_metadata)
+
+ # Merge or replace limits_overrides
+ final_limits = user.limits_overrides
+ if (
+ merge_limits
+ and current_user.limits_overrides
+ and user.limits_overrides
+ ):
+ final_limits = {
+ **current_user.limits_overrides,
+ **user.limits_overrides,
+ }
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET email = $1,
+ is_superuser = $2,
+ is_active = $3,
+ is_verified = $4,
+ updated_at = NOW(),
+ name = $5,
+ profile_picture = $6,
+ bio = $7,
+ collection_ids = $8,
+ limits_overrides = $9::jsonb,
+ metadata = $10::jsonb
+ WHERE id = $11
+ RETURNING id, email, is_superuser, is_active, is_verified,
+ created_at, updated_at, name, profile_picture, bio,
+ collection_ids, limits_overrides, metadata, hashed_password,
+ account_type, google_id, github_id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query,
+ [
+ user.email,
+ user.is_superuser,
+ user.is_active,
+ user.is_verified,
+ user.name,
+ user.profile_picture,
+ user.bio,
+ user.collection_ids or [],
+ json.dumps(final_limits),
+ json.dumps(final_metadata),
+ user.id,
+ ],
+ )
+
+ if not result:
+ raise HTTPException(
+ status_code=500,
+ detail="Failed to update user",
+ )
+
+ return User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ name=result["name"],
+ profile_picture=result["profile_picture"],
+ bio=result["bio"],
+ collection_ids=result["collection_ids"]
+ or [], # Ensure null becomes empty array
+ limits_overrides=json.loads(
+ result["limits_overrides"] or "{}"
+ ), # Can be null
+ metadata=json.loads(result["metadata"] or "{}"),
+ account_type=result["account_type"],
+ hashed_password=result[
+ "hashed_password"
+ ], # Include hashed_password
+ google_id=result["google_id"],
+ github_id=result["github_id"],
+ )
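+
+ # Illustrative effect of merge_limits (values are hypothetical):
+ #
+ #   current limits_overrides: {"max_docs": 10, "max_chunks": 100}
+ #   user.limits_overrides:    {"max_docs": 25}
+ #   merge_limits=False -> stored as {"max_docs": 25}
+ #   merge_limits=True  -> stored as {"max_docs": 25, "max_chunks": 100}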
+
+ async def delete_user_relational(self, id: UUID) -> None:
+ """Delete a user and update related records."""
+ # Get the collections the user belongs to
+ collection_query, params = (
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
+ .select(["collection_ids"])
+ .where("id = $1")
+ .build()
+ )
+
+ collection_result = await self.connection_manager.fetchrow_query(
+ collection_query, [id]
+ )
+
+ if not collection_result:
+ raise R2RException(status_code=404, message="User not found")
+
+ # Update documents query
+ doc_update_query, doc_params = (
+ QueryBuilder(self._get_table_name("documents"))
+ .update({"id": None})
+ .where("id = $1")
+ .build()
+ )
+
+ await self.connection_manager.execute_query(doc_update_query, [id])
+
+ # Delete user query
+ delete_query, del_params = (
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
+ .delete()
+ .where("id = $1")
+ .returning(["id"])
+ .build()
+ )
+
+ result = await self.connection_manager.fetchrow_query(
+ delete_query, [id]
+ )
+
+ if not result:
+ raise R2RException(status_code=404, message="User not found")
+
+ async def update_user_password(self, id: UUID, new_hashed_password: str):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET hashed_password = $1, updated_at = NOW()
+ WHERE id = $2
+ """
+ await self.connection_manager.execute_query(
+ query, [new_hashed_password, id]
+ )
+
+ async def get_all_users(self) -> list[User]:
+ """Get all users with minimal information."""
+ query, params = (
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
+ .select(
+ [
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "created_at",
+ "updated_at",
+ "collection_ids",
+ "hashed_password",
+ "limits_overrides",
+ "metadata",
+ "name",
+ "bio",
+ "profile_picture",
+ "account_type",
+ "google_id",
+ "github_id",
+ ]
+ )
+ .build()
+ )
+
+ results = await self.connection_manager.fetch_query(query, params)
+ return [
+ User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ collection_ids=result["collection_ids"] or [],
+ limits_overrides=json.loads(
+ result["limits_overrides"] or "{}"
+ ),
+ metadata=json.loads(result["metadata"] or "{}"),
+ name=result["name"],
+ bio=result["bio"],
+ profile_picture=result["profile_picture"],
+ account_type=result["account_type"],
+ hashed_password=result["hashed_password"],
+ google_id=result["google_id"],
+ github_id=result["github_id"],
+ )
+ for result in results
+ ]
+
+ async def store_verification_code(
+ self, id: UUID, verification_code: str, expiry: datetime
+ ):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET verification_code = $1, verification_code_expiry = $2
+ WHERE id = $3
+ """
+ await self.connection_manager.execute_query(
+ query, [verification_code, expiry, id]
+ )
+
+ async def verify_user(self, verification_code: str) -> None:
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
+ WHERE verification_code = $1 AND verification_code_expiry > NOW()
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [verification_code]
+ )
+
+ if not result:
+ raise R2RException(
+ status_code=400, message="Invalid or expired verification code"
+ )
+
+ async def remove_verification_code(self, verification_code: str):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET verification_code = NULL, verification_code_expiry = NULL
+ WHERE verification_code = $1
+ """
+ await self.connection_manager.execute_query(query, [verification_code])
+
+ async def expire_verification_code(self, id: UUID):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET verification_code_expiry = NOW() - INTERVAL '1 day'
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(query, [id])
+
+ async def store_reset_token(
+ self, id: UUID, reset_token: str, expiry: datetime
+ ):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET reset_token = $1, reset_token_expiry = $2
+ WHERE id = $3
+ """
+ await self.connection_manager.execute_query(
+ query, [reset_token, expiry, id]
+ )
+
+ async def get_user_id_by_reset_token(
+ self, reset_token: str
+ ) -> Optional[UUID]:
+ query = f"""
+ SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ WHERE reset_token = $1 AND reset_token_expiry > NOW()
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [reset_token]
+ )
+ return result["id"] if result else None
+
+ async def remove_reset_token(self, id: UUID):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET reset_token = NULL, reset_token_expiry = NULL
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(query, [id])
+
+ async def remove_user_from_all_collections(self, id: UUID):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET collection_ids = ARRAY[]::UUID[]
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(query, [id])
+
+ async def add_user_to_collection(
+ self, id: UUID, collection_id: UUID
+ ) -> bool:
+ # Check if the user exists
+ if not await self.get_user_by_id(id):
+ raise R2RException(status_code=404, message="User not found")
+
+ # Check if the collection exists
+ if not await self._collection_exists(collection_id):
+ raise R2RException(status_code=404, message="Collection not found")
+
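+        # array_append only adds the collection id when it is not already
+        # present; RETURNING id reveals whether any row was actually updated.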
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET collection_ids = array_append(collection_ids, $1)
+ WHERE id = $2 AND NOT ($1 = ANY(collection_ids))
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [collection_id, id]
+ )
+ if not result:
+ raise R2RException(
+ status_code=400, message="User already in collection"
+ )
+
+ update_collection_query = f"""
+ UPDATE {self._get_table_name("collections")}
+ SET user_count = user_count + 1
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(
+ query=update_collection_query,
+ params=[collection_id],
+ )
+
+ return True
+
+ async def remove_user_from_collection(
+ self, id: UUID, collection_id: UUID
+ ) -> bool:
+ if not await self.get_user_by_id(id):
+ raise R2RException(status_code=404, message="User not found")
+
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET collection_ids = array_remove(collection_ids, $1)
+ WHERE id = $2 AND $1 = ANY(collection_ids)
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [collection_id, id]
+ )
+ if not result:
+ raise R2RException(
+ status_code=400,
+ message="User is not a member of the specified collection",
+ )
+ return True
+
+ async def get_users_in_collection(
+ self, collection_id: UUID, offset: int, limit: int
+ ) -> dict[str, list[User] | int]:
+ """Get all users in a specific collection with pagination."""
+ if not await self._collection_exists(collection_id):
+ raise R2RException(status_code=404, message="Collection not found")
+
+ query, params = (
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
+ .select(
+ [
+ "id",
+ "email",
+ "is_active",
+ "is_superuser",
+ "created_at",
+ "updated_at",
+ "is_verified",
+ "collection_ids",
+ "name",
+ "bio",
+ "profile_picture",
+ "limits_overrides",
+ "metadata",
+ "account_type",
+ "hashed_password",
+ "google_id",
+ "github_id",
+ "COUNT(*) OVER() AS total_entries",
+ ]
+ )
+ .where("$1 = ANY(collection_ids)")
+ .order_by("name")
+ .offset("$2")
+ .limit("$3" if limit != -1 else None)
+ .build()
+ )
+
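+        # Bind values in placeholder order: $1 collection_id, $2 offset,
+        # and $3 limit (omitted when limit == -1).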
+        params = [collection_id, offset]
+        if limit != -1:
+            params.append(limit)
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+ users_list = [
+ User(
+ id=row["id"],
+ email=row["email"],
+ is_active=row["is_active"],
+ is_superuser=row["is_superuser"],
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ is_verified=row["is_verified"],
+ collection_ids=row["collection_ids"] or [],
+ name=row["name"],
+ bio=row["bio"],
+ profile_picture=row["profile_picture"],
+ limits_overrides=json.loads(row["limits_overrides"] or "{}"),
+ metadata=json.loads(row["metadata"] or "{}"),
+ account_type=row["account_type"],
+ hashed_password=row["hashed_password"],
+ google_id=row["google_id"],
+ github_id=row["github_id"],
+ )
+ for row in results
+ ]
+
+ total_entries = results[0]["total_entries"] if results else 0
+ return {"results": users_list, "total_entries": total_entries}
+
+ async def mark_user_as_superuser(self, id: UUID):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET is_superuser = TRUE, is_verified = TRUE,
+ verification_code = NULL, verification_code_expiry = NULL
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(query, [id])
+
+ async def get_user_id_by_verification_code(
+ self, verification_code: str
+ ) -> UUID:
+ query = f"""
+ SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ WHERE verification_code = $1 AND verification_code_expiry > NOW()
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [verification_code]
+ )
+
+ if not result:
+ raise R2RException(
+ status_code=400, message="Invalid or expired verification code"
+ )
+
+ return result["id"]
+
+ async def mark_user_as_verified(self, id: UUID):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET is_verified = TRUE,
+ verification_code = NULL,
+ verification_code_expiry = NULL
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(query, [id])
+
+ async def get_users_overview(
+ self,
+ offset: int,
+ limit: int,
+ user_ids: Optional[list[UUID]] = None,
+ ) -> dict[str, list[User] | int]:
+ """Return users with document usage and total entries."""
+ query = f"""
+ WITH user_document_ids AS (
+ SELECT
+ u.id as user_id,
+ ARRAY_AGG(d.id) FILTER (WHERE d.id IS NOT NULL) AS doc_ids
+ FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
+ LEFT JOIN {self._get_table_name("documents")} d ON u.id = d.owner_id
+ GROUP BY u.id
+ ),
+ user_docs AS (
+ SELECT
+ u.id,
+ u.email,
+ u.is_superuser,
+ u.is_active,
+ u.is_verified,
+ u.name,
+ u.bio,
+ u.profile_picture,
+ u.collection_ids,
+ u.created_at,
+ u.updated_at,
+ COUNT(d.id) AS num_files,
+ COALESCE(SUM(d.size_in_bytes), 0) AS total_size_in_bytes,
+ ud.doc_ids as document_ids
+ FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
+ LEFT JOIN {self._get_table_name("documents")} d ON u.id = d.owner_id
+ LEFT JOIN user_document_ids ud ON u.id = ud.user_id
+                {f" WHERE u.id = ANY(${3 if limit != -1 else 2}::uuid[])" if user_ids else ""}
+ GROUP BY u.id, u.email, u.is_superuser, u.is_active, u.is_verified,
+ u.created_at, u.updated_at, u.collection_ids, ud.doc_ids
+ )
+ SELECT
+ user_docs.*,
+ COUNT(*) OVER() AS total_entries
+ FROM user_docs
+ ORDER BY email
+ OFFSET $1
+ """
+
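+        # Positional parameters: $1 offset, $2 limit (only when limit != -1);
+        # the optional user_ids filter binds to the next free placeholder.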
+ params: list = [offset]
+
+ if limit != -1:
+ query += " LIMIT $2"
+ params.append(limit)
+
+ if user_ids:
+ params.append(user_ids)
+
+ results = await self.connection_manager.fetch_query(query, params)
+ if not results:
+ raise R2RException(status_code=404, message="No users found")
+
+ users_list = []
+ for row in results:
+ users_list.append(
+ User(
+ id=row["id"],
+ email=row["email"],
+ is_superuser=row["is_superuser"],
+ is_active=row["is_active"],
+ is_verified=row["is_verified"],
+ name=row["name"],
+ bio=row["bio"],
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ profile_picture=row["profile_picture"],
+ collection_ids=row["collection_ids"] or [],
+ num_files=row["num_files"],
+ total_size_in_bytes=row["total_size_in_bytes"],
+ document_ids=(
+ list(row["document_ids"])
+ if row["document_ids"]
+ else []
+ ),
+ )
+ )
+
+ total_entries = results[0]["total_entries"]
+ return {"results": users_list, "total_entries": total_entries}
+
+ async def _collection_exists(self, collection_id: UUID) -> bool:
+ """Check if a collection exists."""
+ query = f"""
+ SELECT 1 FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ WHERE id = $1
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [collection_id]
+ )
+ return result is not None
+
+ async def get_user_validation_data(
+ self,
+ user_id: UUID,
+ ) -> dict:
+ """Get verification data for a specific user.
+
+ This method should be called after superuser authorization has been
+ verified.
+ """
+ query = f"""
+ SELECT
+ verification_code,
+ verification_code_expiry,
+ reset_token,
+ reset_token_expiry
+ FROM {self._get_table_name("users")}
+ WHERE id = $1
+ """
+ result = await self.connection_manager.fetchrow_query(query, [user_id])
+
+ if not result:
+ raise R2RException(status_code=404, message="User not found")
+
+ return {
+ "verification_data": {
+ "verification_code": result["verification_code"],
+ "verification_code_expiry": (
+ result["verification_code_expiry"].isoformat()
+ if result["verification_code_expiry"]
+ else None
+ ),
+ "reset_token": result["reset_token"],
+ "reset_token_expiry": (
+ result["reset_token_expiry"].isoformat()
+ if result["reset_token_expiry"]
+ else None
+ ),
+ }
+ }
+
+ # API Key methods
+ async def store_user_api_key(
+ self,
+ user_id: UUID,
+ key_id: str,
+ hashed_key: str,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> UUID:
+ """Store a new API key for a user with optional name and
+ description."""
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+ (user_id, public_key, hashed_key, name, description)
+ VALUES ($1, $2, $3, $4, $5)
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [user_id, key_id, hashed_key, name or "", description or ""]
+ )
+ if not result:
+ raise R2RException(
+ status_code=500, message="Failed to store API key"
+ )
+ return result["id"]
+
+ async def get_api_key_record(self, key_id: str) -> Optional[dict]:
+ """Get API key record by 'public_key' and update 'updated_at' to now.
+
+ Returns { "user_id", "hashed_key" } or None if not found.
+ """
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+ SET updated_at = NOW()
+ WHERE public_key = $1
+ RETURNING user_id, hashed_key
+ """
+ result = await self.connection_manager.fetchrow_query(query, [key_id])
+ if not result:
+ return None
+ return {
+ "user_id": result["user_id"],
+ "hashed_key": result["hashed_key"],
+ }
+
+ async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
+ """Get all API keys for a user."""
+ query = f"""
+ SELECT id, public_key, name, description, created_at, updated_at
+ FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+ WHERE user_id = $1
+ ORDER BY created_at DESC
+ """
+ results = await self.connection_manager.fetch_query(query, [user_id])
+ return [
+ {
+ "key_id": str(row["id"]),
+ "public_key": row["public_key"],
+ "name": row["name"] or "",
+ "description": row["description"] or "",
+ "updated_at": row["updated_at"],
+ }
+ for row in results
+ ]
+
+ async def delete_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+ """Delete a specific API key."""
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+ WHERE id = $1 AND user_id = $2
+ RETURNING id, public_key, name, description
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [key_id, user_id]
+ )
+ if result is None:
+ raise R2RException(status_code=404, message="API key not found")
+
+ return True
+
+ async def update_api_key_name(
+ self, user_id: UUID, key_id: UUID, name: str
+ ) -> bool:
+ """Update the name of an existing API key."""
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+ SET name = $1, updated_at = NOW()
+ WHERE id = $2 AND user_id = $3
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [name, key_id, user_id]
+ )
+ if result is None:
+ raise R2RException(status_code=404, message="API key not found")
+ return True
+
+ async def export_to_csv(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+        """Create a CSV file from the PostgreSQL data and return the temp
+        file's path together with the still-open temporary file object."""
+ valid_columns = {
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "name",
+ "bio",
+ "collection_ids",
+ "created_at",
+ "updated_at",
+ }
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ email,
+ is_superuser,
+ is_active,
+ is_verified,
+ name,
+ bio,
+ collection_ids::text,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+ FROM {self._get_table_name(self.TABLE_NAME)}
+ """
+
+ params = []
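+        # Translate the supported filter operators ($eq, $gt, $lt) into
+        # parameterized WHERE clauses; fields outside valid_columns are skipped.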
+ if filters:
+ conditions = []
+ param_index = 1
+
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
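+            # Stream rows through a server-side cursor in chunks so large
+            # user tables are never pulled into memory all at once.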
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "email": row[1],
+ "is_superuser": row[2],
+ "is_active": row[3],
+ "is_verified": row[4],
+ "name": row[5],
+ "bio": row[6],
+ "collection_ids": row[7],
+ "created_at": row[8],
+ "updated_at": row[9],
+ }
+ writer.writerow([row_dict[col] for col in columns])
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
+
+ async def get_user_by_google_id(self, google_id: str) -> Optional[User]:
+ """Return a User if the google_id is found; otherwise None."""
+ query, params = (
+ QueryBuilder(self._get_table_name("users"))
+ .select(
+ [
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "created_at",
+ "updated_at",
+ "name",
+ "profile_picture",
+ "bio",
+ "collection_ids",
+ "limits_overrides",
+ "metadata",
+ "account_type",
+ "hashed_password",
+ "google_id",
+ "github_id",
+ ]
+ )
+ .where("google_id = $1")
+ .build()
+ )
+ result = await self.connection_manager.fetchrow_query(
+ query, [google_id]
+ )
+ if not result:
+ return None
+
+ return User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ name=result["name"],
+ profile_picture=result["profile_picture"],
+ bio=result["bio"],
+ collection_ids=result["collection_ids"] or [],
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+ metadata=json.loads(result["metadata"] or "{}"),
+ account_type=result["account_type"],
+ hashed_password=result["hashed_password"],
+ google_id=result["google_id"],
+ github_id=result["github_id"],
+ )
+
+ async def get_user_by_github_id(self, github_id: str) -> Optional[User]:
+ """Return a User if the github_id is found; otherwise None."""
+ query, params = (
+ QueryBuilder(self._get_table_name("users"))
+ .select(
+ [
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "created_at",
+ "updated_at",
+ "name",
+ "profile_picture",
+ "bio",
+ "collection_ids",
+ "limits_overrides",
+ "metadata",
+ "account_type",
+ "hashed_password",
+ "google_id",
+ "github_id",
+ ]
+ )
+ .where("github_id = $1")
+ .build()
+ )
+ result = await self.connection_manager.fetchrow_query(
+ query, [github_id]
+ )
+ if not result:
+ return None
+
+ return User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ name=result["name"],
+ profile_picture=result["profile_picture"],
+ bio=result["bio"],
+ collection_ids=result["collection_ids"] or [],
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+ metadata=json.loads(result["metadata"] or "{}"),
+ account_type=result["account_type"],
+ hashed_password=result["hashed_password"],
+ google_id=result["google_id"],
+ github_id=result["github_id"],
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/providers/email/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/email/__init__.py
new file mode 100644
index 00000000..38753695
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/email/__init__.py
@@ -0,0 +1,11 @@
+from .console_mock import ConsoleMockEmailProvider
+from .mailersend import MailerSendEmailProvider
+from .sendgrid import SendGridEmailProvider
+from .smtp import AsyncSMTPEmailProvider
+
+__all__ = [
+ "ConsoleMockEmailProvider",
+ "AsyncSMTPEmailProvider",
+ "SendGridEmailProvider",
+ "MailerSendEmailProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/email/console_mock.py b/.venv/lib/python3.12/site-packages/core/providers/email/console_mock.py
new file mode 100644
index 00000000..459a978d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/email/console_mock.py
@@ -0,0 +1,67 @@
+import logging
+from typing import Optional
+
+from core.base import EmailProvider
+
+logger = logging.getLogger()
+
+
+class ConsoleMockEmailProvider(EmailProvider):
+ """A simple email provider that logs emails to console, useful for
+ testing."""
+
+ async def send_email(
+ self,
+ to_email: str,
+ subject: str,
+ body: str,
+ html_body: Optional[str] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ logger.info(f"""
+ -------- Email Message --------
+ To: {to_email}
+ Subject: {subject}
+ Body:
+ {body}
+ -----------------------------
+ """)
+
+ async def send_verification_email(
+ self, to_email: str, verification_code: str, *args, **kwargs
+ ) -> None:
+ logger.info(f"""
+ -------- Email Message --------
+ To: {to_email}
+ Subject: Please verify your email address
+ Body:
+ Verification code: {verification_code}
+ -----------------------------
+ """)
+
+ async def send_password_reset_email(
+ self, to_email: str, reset_token: str, *args, **kwargs
+ ) -> None:
+ logger.info(f"""
+ -------- Email Message --------
+ To: {to_email}
+ Subject: Password Reset Request
+ Body:
+ Reset token: {reset_token}
+ -----------------------------
+ """)
+
+ async def send_password_changed_email(
+ self, to_email: str, *args, **kwargs
+ ) -> None:
+ logger.info(f"""
+ -------- Email Message --------
+ To: {to_email}
+ Subject: Your Password Has Been Changed
+ Body:
+ Your password has been successfully changed.
+
+ For security reasons, you will need to log in again on all your devices.
+ -----------------------------
+ """)
diff --git a/.venv/lib/python3.12/site-packages/core/providers/email/mailersend.py b/.venv/lib/python3.12/site-packages/core/providers/email/mailersend.py
new file mode 100644
index 00000000..10fccd56
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/email/mailersend.py
@@ -0,0 +1,281 @@
+import logging
+import os
+from typing import Optional
+
+from mailersend import emails
+
+from core.base import EmailConfig, EmailProvider
+
+logger = logging.getLogger(__name__)
+
+
+class MailerSendEmailProvider(EmailProvider):
+    """Email provider implementation using the MailerSend API."""
+
+ def __init__(self, config: EmailConfig):
+ super().__init__(config)
+ self.api_key = config.mailersend_api_key or os.getenv(
+ "MAILERSEND_API_KEY"
+ )
+ if not self.api_key or not isinstance(self.api_key, str):
+ raise ValueError("A valid MailerSend API key is required.")
+
+ self.from_email = config.from_email or os.getenv("R2R_FROM_EMAIL")
+ if not self.from_email or not isinstance(self.from_email, str):
+ raise ValueError("A valid from email is required.")
+
+ self.frontend_url = config.frontend_url or os.getenv(
+ "R2R_FRONTEND_URL"
+ )
+ if not self.frontend_url or not isinstance(self.frontend_url, str):
+ raise ValueError("A valid frontend URL is required.")
+
+ self.verify_email_template_id = (
+ config.verify_email_template_id
+ or os.getenv("MAILERSEND_VERIFY_EMAIL_TEMPLATE_ID")
+ )
+ self.reset_password_template_id = (
+ config.reset_password_template_id
+ or os.getenv("MAILERSEND_RESET_PASSWORD_TEMPLATE_ID")
+ )
+ self.password_changed_template_id = (
+ config.password_changed_template_id
+ or os.getenv("MAILERSEND_PASSWORD_CHANGED_TEMPLATE_ID")
+ )
+ self.client = emails.NewEmail(self.api_key)
+ self.sender_name = config.sender_name or "R2R"
+
+        # Documentation URLs
+ self.docs_base_url = f"{self.frontend_url}/documentation"
+
+ def _get_base_template_data(self, to_email: str) -> dict:
+ """Get base template data used across all email templates."""
+ return {
+ "user_email": to_email,
+ "docs_url": self.docs_base_url,
+ "quickstart_url": f"{self.docs_base_url}/quickstart",
+ "frontend_url": self.frontend_url,
+ }
+
+ async def send_email(
+ self,
+ to_email: str,
+ subject: Optional[str] = None,
+ body: Optional[str] = None,
+ html_body: Optional[str] = None,
+ template_id: Optional[str] = None,
+ dynamic_template_data: Optional[dict] = None,
+ ) -> None:
+ try:
+ logger.info("Preparing MailerSend message...")
+
+ mail_body = {
+ "from": {
+ "email": self.from_email,
+ "name": self.sender_name,
+ },
+ "to": [{"email": to_email}],
+ }
+
+ if template_id:
+ # Transform the template data to MailerSend's expected format
+ if dynamic_template_data:
+ formatted_substitutions = {}
+ for key, value in dynamic_template_data.items():
+ formatted_substitutions[key] = {
+ "var": key,
+ "value": value,
+ }
+ mail_body["variables"] = [
+ {
+ "email": to_email,
+ "substitutions": formatted_substitutions,
+ }
+ ]
+
+ mail_body["template_id"] = template_id
+ else:
+ mail_body.update(
+ {
+ "subject": subject or "",
+ "text": body or "",
+ "html": html_body or "",
+ }
+ )
+
+ import asyncio
+
+ response = await asyncio.to_thread(self.client.send, mail_body)
+
+ # Handle different response formats
+ if isinstance(response, str):
+ # Clean the string response by stripping whitespace
+ response_clean = response.strip()
+ if response_clean in ["202", "200"]:
+ logger.info(
+ f"Email accepted for delivery with status code {response_clean}"
+ )
+ return
+ elif isinstance(response, int) and response in [200, 202]:
+ logger.info(
+ f"Email accepted for delivery with status code {response}"
+ )
+ return
+ elif isinstance(response, dict) and response.get(
+ "status_code"
+ ) in [200, 202]:
+ logger.info(
+ f"Email accepted for delivery with status code {response.get('status_code')}"
+ )
+ return
+
+ # If we get here, it's an error
+ error_msg = f"MailerSend error: {response}"
+ logger.error(error_msg)
+
+ except Exception as e:
+ error_msg = f"Failed to send email to {to_email}: {str(e)}"
+ logger.error(error_msg)
+
+ async def send_verification_email(
+ self,
+ to_email: str,
+ verification_code: str,
+ dynamic_template_data: Optional[dict] = None,
+ ) -> None:
+ try:
+ if self.verify_email_template_id:
+ verification_data = {
+ "verification_link": f"{self.frontend_url}/verify-email?verification_code={verification_code}&email={to_email}",
+ "verification_code": verification_code, # Include code separately for flexible template usage
+ }
+
+ # Merge with any additional template data
+ template_data = {
+ **(dynamic_template_data or {}),
+ **verification_data,
+ }
+
+ await self.send_email(
+ to_email=to_email,
+ template_id=self.verify_email_template_id,
+ dynamic_template_data=template_data,
+ )
+ else:
+ # Fallback to basic email if no template ID is configured
+ subject = "Verify Your R2R Account"
+ html_body = f"""
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Welcome to R2R!</h1>
+ <p>Please verify your email address to get started with R2R - the most advanced AI retrieval system.</p>
+ <p>Click the link below to verify your email:</p>
+ <p><a href="{self.frontend_url}/verify-email?verification_code={verification_code}&email={to_email}"
+ style="background-color: #007bff; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
+ Verify Email
+ </a></p>
+ <p>Or enter this verification code: <strong>{verification_code}</strong></p>
+ <p>If you didn't create an account with R2R, please ignore this email.</p>
+ </div>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject=subject,
+ html_body=html_body,
+ body=f"Welcome to R2R! Please verify your email using this code: {verification_code}",
+ )
+ except Exception as e:
+ error_msg = (
+ f"Failed to send verification email to {to_email}: {str(e)}"
+ )
+ logger.error(error_msg)
+
+ async def send_password_reset_email(
+ self,
+ to_email: str,
+ reset_token: str,
+ dynamic_template_data: Optional[dict] = None,
+ ) -> None:
+ try:
+ if self.reset_password_template_id:
+ reset_data = {
+ "reset_link": f"{self.frontend_url}/reset-password?token={reset_token}",
+ "reset_token": reset_token,
+ }
+
+ template_data = {**(dynamic_template_data or {}), **reset_data}
+
+ await self.send_email(
+ to_email=to_email,
+ template_id=self.reset_password_template_id,
+ dynamic_template_data=template_data,
+ )
+ else:
+ subject = "Reset Your R2R Password"
+ html_body = f"""
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Password Reset Request</h1>
+ <p>You've requested to reset your R2R password.</p>
+ <p>Click the link below to reset your password:</p>
+ <p><a href="{self.frontend_url}/reset-password?token={reset_token}"
+ style="background-color: #007bff; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
+ Reset Password
+ </a></p>
+ <p>Or use this reset token: <strong>{reset_token}</strong></p>
+ <p>If you didn't request a password reset, please ignore this email.</p>
+ </div>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject=subject,
+ html_body=html_body,
+ body=f"Reset your R2R password using this token: {reset_token}",
+ )
+ except Exception as e:
+ error_msg = (
+ f"Failed to send password reset email to {to_email}: {str(e)}"
+ )
+ logger.error(error_msg)
+
+ async def send_password_changed_email(
+ self,
+ to_email: str,
+ dynamic_template_data: Optional[dict] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ try:
+ if (
+ hasattr(self, "password_changed_template_id")
+ and self.password_changed_template_id
+ ):
+ await self.send_email(
+ to_email=to_email,
+ template_id=self.password_changed_template_id,
+ dynamic_template_data=dynamic_template_data,
+ )
+ else:
+ subject = "Your Password Has Been Changed"
+ body = """
+ Your password has been successfully changed.
+
+ If you did not make this change, please contact support immediately and secure your account.
+
+ """
+ html_body = """
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Password Changed Successfully</h1>
+ <p>Your password has been successfully changed.</p>
+ </div>
+ """
+ await self.send_email(
+ to_email=to_email,
+ subject=subject,
+ html_body=html_body,
+ body=body,
+ )
+ except Exception as e:
+ error_msg = f"Failed to send password change notification to {to_email}: {str(e)}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
diff --git a/.venv/lib/python3.12/site-packages/core/providers/email/sendgrid.py b/.venv/lib/python3.12/site-packages/core/providers/email/sendgrid.py
new file mode 100644
index 00000000..8b2553f1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/email/sendgrid.py
@@ -0,0 +1,257 @@
+import logging
+import os
+from typing import Optional
+
+from sendgrid import SendGridAPIClient
+from sendgrid.helpers.mail import Content, From, Mail
+
+from core.base import EmailConfig, EmailProvider
+
+logger = logging.getLogger(__name__)
+
+
+class SendGridEmailProvider(EmailProvider):
+    """Email provider implementation using the SendGrid API."""
+
+ def __init__(self, config: EmailConfig):
+ super().__init__(config)
+ self.api_key = config.sendgrid_api_key or os.getenv("SENDGRID_API_KEY")
+ if not self.api_key or not isinstance(self.api_key, str):
+ raise ValueError("A valid SendGrid API key is required.")
+
+ self.from_email = config.from_email or os.getenv("R2R_FROM_EMAIL")
+ if not self.from_email or not isinstance(self.from_email, str):
+ raise ValueError("A valid from email is required.")
+
+ self.frontend_url = config.frontend_url or os.getenv(
+ "R2R_FRONTEND_URL"
+ )
+ if not self.frontend_url or not isinstance(self.frontend_url, str):
+ raise ValueError("A valid frontend URL is required.")
+
+ self.verify_email_template_id = (
+ config.verify_email_template_id
+ or os.getenv("SENDGRID_EMAIL_TEMPLATE_ID")
+ )
+ self.reset_password_template_id = (
+ config.reset_password_template_id
+ or os.getenv("SENDGRID_RESET_TEMPLATE_ID")
+ )
+ self.password_changed_template_id = (
+ config.password_changed_template_id
+ or os.getenv("SENDGRID_PASSWORD_CHANGED_TEMPLATE_ID")
+ )
+ self.client = SendGridAPIClient(api_key=self.api_key)
+ self.sender_name = config.sender_name
+
+        # Documentation URLs
+ self.docs_base_url = f"{self.frontend_url}/documentation"
+
+ def _get_base_template_data(self, to_email: str) -> dict:
+ """Get base template data used across all email templates."""
+ return {
+ "user_email": to_email,
+ "docs_url": self.docs_base_url,
+ "quickstart_url": f"{self.docs_base_url}/quickstart",
+ "frontend_url": self.frontend_url,
+ }
+
+ async def send_email(
+ self,
+ to_email: str,
+ subject: Optional[str] = None,
+ body: Optional[str] = None,
+ html_body: Optional[str] = None,
+ template_id: Optional[str] = None,
+ dynamic_template_data: Optional[dict] = None,
+ ) -> None:
+ try:
+ logger.info("Preparing SendGrid message...")
+ message = Mail(
+ from_email=From(self.from_email, self.sender_name),
+ to_emails=to_email,
+ )
+
+ if template_id:
+ logger.info(f"Using dynamic template with ID: {template_id}")
+ message.template_id = template_id
+ base_data = self._get_base_template_data(to_email)
+ message.dynamic_template_data = {
+ **base_data,
+ **(dynamic_template_data or {}),
+ }
+ else:
+ if not subject:
+ raise ValueError(
+ "Subject is required when not using a template"
+ )
+ message.subject = subject
+ message.add_content(Content("text/plain", body or ""))
+ if html_body:
+ message.add_content(Content("text/html", html_body))
+
+ import asyncio
+
+ response = await asyncio.to_thread(self.client.send, message)
+
+ if response.status_code >= 400:
+ raise RuntimeError(
+ f"Failed to send email: {response.status_code}"
+ )
+ elif response.status_code == 202:
+ logger.info("Message sent successfully!")
+ else:
+ error_msg = f"Failed to send email. Status code: {response.status_code}, Body: {response.body}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg)
+
+ except Exception as e:
+ error_msg = f"Failed to send email to {to_email}: {str(e)}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
+
+ async def send_verification_email(
+ self,
+ to_email: str,
+ verification_code: str,
+ dynamic_template_data: Optional[dict] = None,
+ ) -> None:
+ try:
+ if self.verify_email_template_id:
+ verification_data = {
+ "verification_link": f"{self.frontend_url}/verify-email?verification_code={verification_code}&email={to_email}",
+ "verification_code": verification_code, # Include code separately for flexible template usage
+ }
+
+ # Merge with any additional template data
+ template_data = {
+ **(dynamic_template_data or {}),
+ **verification_data,
+ }
+
+ await self.send_email(
+ to_email=to_email,
+ template_id=self.verify_email_template_id,
+ dynamic_template_data=template_data,
+ )
+ else:
+ # Fallback to basic email if no template ID is configured
+ subject = "Verify Your R2R Account"
+ html_body = f"""
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Welcome to R2R!</h1>
+ <p>Please verify your email address to get started with R2R - the most advanced AI retrieval system.</p>
+ <p>Click the link below to verify your email:</p>
+ <p><a href="{self.frontend_url}/verify-email?token={verification_code}&email={to_email}"
+ style="background-color: #007bff; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
+ Verify Email
+ </a></p>
+ <p>Or enter this verification code: <strong>{verification_code}</strong></p>
+ <p>If you didn't create an account with R2R, please ignore this email.</p>
+ </div>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject=subject,
+ html_body=html_body,
+ body=f"Welcome to R2R! Please verify your email using this code: {verification_code}",
+ )
+ except Exception as e:
+ error_msg = (
+ f"Failed to send verification email to {to_email}: {str(e)}"
+ )
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
+
+ async def send_password_reset_email(
+ self,
+ to_email: str,
+ reset_token: str,
+ dynamic_template_data: Optional[dict] = None,
+ ) -> None:
+ try:
+ if self.reset_password_template_id:
+ reset_data = {
+ "reset_link": f"{self.frontend_url}/reset-password?token={reset_token}",
+ "reset_token": reset_token,
+ }
+
+ template_data = {**(dynamic_template_data or {}), **reset_data}
+
+ await self.send_email(
+ to_email=to_email,
+ template_id=self.reset_password_template_id,
+ dynamic_template_data=template_data,
+ )
+ else:
+ subject = "Reset Your R2R Password"
+ html_body = f"""
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Password Reset Request</h1>
+ <p>You've requested to reset your R2R password.</p>
+ <p>Click the link below to reset your password:</p>
+ <p><a href="{self.frontend_url}/reset-password?token={reset_token}"
+ style="background-color: #007bff; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
+ Reset Password
+ </a></p>
+ <p>Or use this reset token: <strong>{reset_token}</strong></p>
+ <p>If you didn't request a password reset, please ignore this email.</p>
+ </div>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject=subject,
+ html_body=html_body,
+ body=f"Reset your R2R password using this token: {reset_token}",
+ )
+ except Exception as e:
+ error_msg = (
+ f"Failed to send password reset email to {to_email}: {str(e)}"
+ )
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
+
+ async def send_password_changed_email(
+ self,
+ to_email: str,
+ dynamic_template_data: Optional[dict] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ try:
+ if (
+ hasattr(self, "password_changed_template_id")
+ and self.password_changed_template_id
+ ):
+ await self.send_email(
+ to_email=to_email,
+ template_id=self.password_changed_template_id,
+ dynamic_template_data=dynamic_template_data,
+ )
+ else:
+ subject = "Your Password Has Been Changed"
+ body = """
+ Your password has been successfully changed.
+
+ If you did not make this change, please contact support immediately and secure your account.
+
+ """
+ html_body = """
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Password Changed Successfully</h1>
+ <p>Your password has been successfully changed.</p>
+ </div>
+ """
+                # No template configured, so send the basic fallback email built above
+ await self.send_email(
+ to_email=to_email,
+ subject=subject,
+ html_body=html_body,
+ body=body,
+ )
+ except Exception as e:
+ error_msg = f"Failed to send password change notification to {to_email}: {str(e)}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
diff --git a/.venv/lib/python3.12/site-packages/core/providers/email/smtp.py b/.venv/lib/python3.12/site-packages/core/providers/email/smtp.py
new file mode 100644
index 00000000..bd68ff36
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/email/smtp.py
@@ -0,0 +1,176 @@
+import asyncio
+import logging
+import os
+import smtplib
+import ssl
+from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
+from typing import Optional
+
+from core.base import EmailConfig, EmailProvider
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncSMTPEmailProvider(EmailProvider):
+    """Email provider implementation that sends mail over an SMTP relay using SSL."""
+
+ def __init__(self, config: EmailConfig):
+ super().__init__(config)
+ self.smtp_server = config.smtp_server or os.getenv("R2R_SMTP_SERVER")
+ if not self.smtp_server:
+ raise ValueError("SMTP server is required")
+
+ self.smtp_port = config.smtp_port or os.getenv("R2R_SMTP_PORT")
+ if not self.smtp_port:
+ raise ValueError("SMTP port is required")
+
+ self.smtp_username = config.smtp_username or os.getenv(
+ "R2R_SMTP_USERNAME"
+ )
+ if not self.smtp_username:
+ raise ValueError("SMTP username is required")
+
+ self.smtp_password = config.smtp_password or os.getenv(
+ "R2R_SMTP_PASSWORD"
+ )
+ if not self.smtp_password:
+ raise ValueError("SMTP password is required")
+
+ self.from_email: Optional[str] = (
+ config.from_email
+ or os.getenv("R2R_FROM_EMAIL")
+ or self.smtp_username
+ )
+ self.ssl_context = ssl.create_default_context()
+
+ async def _send_email_sync(self, msg: MIMEMultipart) -> None:
+ """Synchronous email sending wrapped in asyncio executor."""
+ loop = asyncio.get_running_loop()
+
+ def _send():
+ with smtplib.SMTP_SSL(
+ self.smtp_server,
+ self.smtp_port,
+ context=self.ssl_context,
+ timeout=30,
+ ) as server:
+ logger.info("Connected to SMTP server")
+ server.login(self.smtp_username, self.smtp_password)
+ logger.info("Login successful")
+ server.send_message(msg)
+ logger.info("Message sent successfully!")
+
+ try:
+ await loop.run_in_executor(None, _send)
+ except Exception as e:
+ error_msg = f"Failed to send email: {str(e)}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
+
+ async def send_email(
+ self,
+ to_email: str,
+ subject: str,
+ body: str,
+ html_body: Optional[str] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ msg = MIMEMultipart("alternative")
+ msg["Subject"] = subject
+ msg["From"] = self.from_email # type: ignore
+ msg["To"] = to_email
+
+ msg.attach(MIMEText(body, "plain"))
+ if html_body:
+ msg.attach(MIMEText(html_body, "html"))
+
+ try:
+ logger.info("Initializing SMTP connection...")
+ async with asyncio.timeout(30): # Overall timeout
+ await self._send_email_sync(msg)
+ except asyncio.TimeoutError as e:
+ error_msg = "Operation timed out while trying to send email"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
+ except Exception as e:
+ error_msg = f"Failed to send email: {str(e)}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
+
+ async def send_verification_email(
+ self, to_email: str, verification_code: str, *args, **kwargs
+ ) -> None:
+ body = f"""
+ Please verify your email address by entering the following code:
+
+ Verification code: {verification_code}
+
+ If you did not request this verification, please ignore this email.
+ """
+
+ html_body = f"""
+ <p>Please verify your email address by entering the following code:</p>
+ <p style="font-size: 24px; font-weight: bold; margin: 20px 0;">
+ Verification code: {verification_code}
+ </p>
+ <p>If you did not request this verification, please ignore this email.</p>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject="Please verify your email address",
+ body=body,
+ html_body=html_body,
+ )
+
+ async def send_password_reset_email(
+ self, to_email: str, reset_token: str, *args, **kwargs
+ ) -> None:
+ body = f"""
+ You have requested to reset your password.
+
+ Reset token: {reset_token}
+
+ If you did not request a password reset, please ignore this email.
+ """
+
+ html_body = f"""
+ <p>You have requested to reset your password.</p>
+ <p style="font-size: 24px; font-weight: bold; margin: 20px 0;">
+ Reset token: {reset_token}
+ </p>
+ <p>If you did not request a password reset, please ignore this email.</p>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject="Password Reset Request",
+ body=body,
+ html_body=html_body,
+ )
+
+ async def send_password_changed_email(
+ self, to_email: str, *args, **kwargs
+ ) -> None:
+ body = """
+ Your password has been successfully changed.
+
+ If you did not make this change, please contact support immediately and secure your account.
+
+ """
+
+ html_body = """
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Password Changed Successfully</h1>
+ <p>Your password has been successfully changed.</p>
+ </div>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject="Your Password Has Been Changed",
+ body=body,
+ html_body=html_body,
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/__init__.py
new file mode 100644
index 00000000..3fa67442
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/__init__.py
@@ -0,0 +1,9 @@
+from .litellm import LiteLLMEmbeddingProvider
+from .ollama import OllamaEmbeddingProvider
+from .openai import OpenAIEmbeddingProvider
+
+__all__ = [
+ "LiteLLMEmbeddingProvider",
+ "OpenAIEmbeddingProvider",
+ "OllamaEmbeddingProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/litellm.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/litellm.py
new file mode 100644
index 00000000..5f705c91
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/litellm.py
@@ -0,0 +1,306 @@
+import logging
+import math
+import os
+from copy import copy
+from typing import Any
+
+import litellm
+import requests
+from aiohttp import ClientError, ClientSession
+from litellm import AuthenticationError, aembedding, embedding
+
+from core.base import (
+ ChunkSearchResult,
+ EmbeddingConfig,
+ EmbeddingProvider,
+ EmbeddingPurpose,
+ R2RException,
+)
+
+logger = logging.getLogger()
+
+
+class LiteLLMEmbeddingProvider(EmbeddingProvider):
+ def __init__(
+ self,
+ config: EmbeddingConfig,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(config)
+
+ self.litellm_embedding = embedding
+ self.litellm_aembedding = aembedding
+
+ provider = config.provider
+ if not provider:
+ raise ValueError(
+ "Must set provider in order to initialize `LiteLLMEmbeddingProvider`."
+ )
+ if provider != "litellm":
+ raise ValueError(
+ "LiteLLMEmbeddingProvider must be initialized with provider `litellm`."
+ )
+
+ self.rerank_url = None
+ if config.rerank_model:
+ if "huggingface" not in config.rerank_model:
+ raise ValueError(
+ "LiteLLMEmbeddingProvider only supports re-ranking via the HuggingFace text-embeddings-inference API"
+ )
+
+ url = os.getenv("HUGGINGFACE_API_BASE") or config.rerank_url
+ if not url:
+ raise ValueError(
+ "LiteLLMEmbeddingProvider requires a valid reranking API url to be set via `embedding.rerank_url` in the r2r.toml, or via the environment variable `HUGGINGFACE_API_BASE`."
+ )
+ self.rerank_url = url
+
+ self.base_model = config.base_model
+ if "amazon" in self.base_model:
+            logger.warning("Amazon embedding model detected, dropping params")
+ litellm.drop_params = True
+ self.base_dimension = config.base_dimension
+
+ def _get_embedding_kwargs(self, **kwargs):
+ embedding_kwargs = {
+ "model": self.base_model,
+ "dimensions": self.base_dimension,
+ }
+ embedding_kwargs.update(kwargs)
+ return embedding_kwargs
+
+ async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+ if "dimensions" in kwargs and math.isnan(kwargs["dimensions"]):
+ kwargs.pop("dimensions")
+ logger.warning("Dropping nan dimensions from kwargs")
+
+ try:
+ response = await self.litellm_aembedding(
+ input=texts,
+ **kwargs,
+ )
+ return [data["embedding"] for data in response.data]
+ except AuthenticationError:
+ logger.error(
+ "Authentication error: Invalid API key or credentials."
+ )
+ raise
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+
+ raise R2RException(error_msg, 400) from e
+
+ def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+ try:
+ response = self.litellm_embedding(
+ input=texts,
+ **kwargs,
+ )
+ return [data["embedding"] for data in response.data]
+ except AuthenticationError:
+ logger.error(
+ "Authentication error: Invalid API key or credentials."
+ )
+ raise
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+ raise R2RException(error_msg, 400) from e
+
+ async def async_get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "LiteLLMEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return (await self._execute_with_backoff_async(task))[0]
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "Error getting embeddings: LiteLLMEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return self._execute_with_backoff_sync(task)[0]
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "LiteLLMEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "LiteLLMEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ def rerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ):
+ if self.config.rerank_model is not None:
+ if not self.rerank_url:
+ raise ValueError(
+ "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
+ )
+
+ texts = [result.text for result in results]
+
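+            # Strip the "huggingface/" prefix so only the bare model id is sent to the rerank endpoint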
+ payload = {
+ "query": query,
+ "texts": texts,
+ "model-id": self.config.rerank_model.split("huggingface/")[1],
+ }
+
+ headers = {"Content-Type": "application/json"}
+
+ try:
+ response = requests.post(
+ self.rerank_url, json=payload, headers=headers
+ )
+ response.raise_for_status()
+ reranked_results = response.json()
+
+ # Copy reranked results into new array
+ scored_results = []
+ for rank_info in reranked_results:
+ original_result = results[rank_info["index"]]
+ copied_result = copy(original_result)
+ # Inject the reranking score into the result object
+ copied_result.score = rank_info["score"]
+ scored_results.append(copied_result)
+
+ # Return only the ChunkSearchResult objects, limited to specified count
+ return scored_results[:limit]
+
+ except requests.RequestException as e:
+ logger.error(f"Error during reranking: {str(e)}")
+ # Fall back to returning the original results if reranking fails
+ return results[:limit]
+ else:
+ return results[:limit]
+
+ async def arerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ) -> list[ChunkSearchResult]:
+ """Asynchronously rerank search results using the configured rerank
+ model.
+
+ Args:
+ query: The search query string
+ results: List of ChunkSearchResult objects to rerank
+ limit: Maximum number of results to return
+
+ Returns:
+ List of reranked ChunkSearchResult objects, limited to specified count
+ """
+ if self.config.rerank_model is not None:
+ if not self.rerank_url:
+ raise ValueError(
+ "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
+ )
+
+ texts = [result.text for result in results]
+
+ payload = {
+ "query": query,
+ "texts": texts,
+ "model-id": self.config.rerank_model.split("huggingface/")[1],
+ }
+
+ headers = {"Content-Type": "application/json"}
+
+ try:
+ async with ClientSession() as session:
+ async with session.post(
+ self.rerank_url, json=payload, headers=headers
+ ) as response:
+ response.raise_for_status()
+ reranked_results = await response.json()
+
+ # Copy reranked results into new array
+ scored_results = []
+ for rank_info in reranked_results:
+ original_result = results[rank_info["index"]]
+ copied_result = copy(original_result)
+ # Inject the reranking score into the result object
+ copied_result.score = rank_info["score"]
+ scored_results.append(copied_result)
+
+ # Return only the ChunkSearchResult objects, limited to specified count
+ return scored_results[:limit]
+
+ except (ClientError, Exception) as e:
+ logger.error(f"Error during async reranking: {str(e)}")
+ # Fall back to returning the original results if reranking fails
+ return results[:limit]
+ else:
+ return results[:limit]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py
new file mode 100644
index 00000000..297d9167
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py
@@ -0,0 +1,195 @@
+import logging
+import os
+from typing import Any
+
+from ollama import AsyncClient, Client
+
+from core.base import (
+ ChunkSearchResult,
+ EmbeddingConfig,
+ EmbeddingProvider,
+ EmbeddingPurpose,
+ R2RException,
+)
+
+logger = logging.getLogger()
+
+
+class OllamaEmbeddingProvider(EmbeddingProvider):
+ def __init__(self, config: EmbeddingConfig):
+ super().__init__(config)
+ provider = config.provider
+ if not provider:
+ raise ValueError(
+ "Must set provider in order to initialize `OllamaEmbeddingProvider`."
+ )
+ if provider != "ollama":
+ raise ValueError(
+ "OllamaEmbeddingProvider must be initialized with provider `ollama`."
+ )
+ if config.rerank_model:
+ raise ValueError(
+ "OllamaEmbeddingProvider does not support separate reranking."
+ )
+
+ self.base_model = config.base_model
+ self.base_dimension = config.base_dimension
+ self.base_url = os.getenv("OLLAMA_API_BASE")
+ logger.info(
+ f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
+ )
+ self.client = Client(host=self.base_url)
+ self.aclient = AsyncClient(host=self.base_url)
+
+ self.set_prefixes(config.prefixes or {}, self.base_model)
+ self.batch_size = config.batch_size or 32
+
+ def _get_embedding_kwargs(self, **kwargs):
+ embedding_kwargs = {
+ "model": self.base_model,
+ }
+ embedding_kwargs.update(kwargs)
+ return embedding_kwargs
+
+ async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ purpose = task.get("purpose", EmbeddingPurpose.INDEX)
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+ try:
+ embeddings = []
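+            # Embed in batches, prepending any configured purpose-specific prefix to each text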
+ for i in range(0, len(texts), self.batch_size):
+ batch = texts[i : i + self.batch_size]
+ prefixed_batch = [
+ self.prefixes.get(purpose, "") + text for text in batch
+ ]
+ response = await self.aclient.embed(
+ input=prefixed_batch, **kwargs
+ )
+ embeddings.extend(response["embeddings"])
+ return embeddings
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+ raise R2RException(error_msg, 400) from e
+
+ def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ purpose = task.get("purpose", EmbeddingPurpose.INDEX)
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+ try:
+ embeddings = []
+ for i in range(0, len(texts), self.batch_size):
+ batch = texts[i : i + self.batch_size]
+ prefixed_batch = [
+ self.prefixes.get(purpose, "") + text for text in batch
+ ]
+ response = self.client.embed(input=prefixed_batch, **kwargs)
+ embeddings.extend(response["embeddings"])
+ return embeddings
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+ raise R2RException(error_msg, 400) from e
+
+ async def async_get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ result = await self._execute_with_backoff_async(task)
+ return result[0]
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ result = self._execute_with_backoff_sync(task)
+ return result[0]
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ def rerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ) -> list[ChunkSearchResult]:
+ return results[:limit]
+
+ async def arerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ):
+ return results[:limit]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
new file mode 100644
index 00000000..907cebd9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
@@ -0,0 +1,244 @@
+import logging
+import os
+from typing import Any
+
+import tiktoken
+from openai import AsyncOpenAI, AuthenticationError, OpenAI
+from openai._types import NOT_GIVEN
+
+from core.base import (
+ ChunkSearchResult,
+ EmbeddingConfig,
+ EmbeddingProvider,
+ EmbeddingPurpose,
+)
+
+logger = logging.getLogger()
+
+
+class OpenAIEmbeddingProvider(EmbeddingProvider):
+ MODEL_TO_TOKENIZER = {
+ "text-embedding-ada-002": "cl100k_base",
+ "text-embedding-3-small": "cl100k_base",
+ "text-embedding-3-large": "cl100k_base",
+ }
+ MODEL_TO_DIMENSIONS = {
+ "text-embedding-ada-002": [1536],
+ "text-embedding-3-small": [512, 1536],
+ "text-embedding-3-large": [256, 1024, 3072],
+ }
+
+ def __init__(self, config: EmbeddingConfig):
+ super().__init__(config)
+ if not config.provider:
+ raise ValueError(
+ "Must set provider in order to initialize OpenAIEmbeddingProvider."
+ )
+
+ if config.provider != "openai":
+ raise ValueError(
+ "OpenAIEmbeddingProvider must be initialized with provider `openai`."
+ )
+ if not os.getenv("OPENAI_API_KEY"):
+ raise ValueError(
+ "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+ )
+ self.client = OpenAI()
+ self.async_client = AsyncOpenAI()
+
+ if config.rerank_model:
+ raise ValueError(
+ "OpenAIEmbeddingProvider does not support separate reranking."
+ )
+
+ if config.base_model and "openai/" in config.base_model:
+ self.base_model = config.base_model.split("/")[-1]
+ else:
+ self.base_model = config.base_model
+ self.base_dimension = config.base_dimension
+
+ if not self.base_model:
+ raise ValueError(
+ "Must set base_model in order to initialize OpenAIEmbeddingProvider."
+ )
+
+ if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+ raise ValueError(
+ f"OpenAI embedding model {self.base_model} not supported."
+ )
+
+ if self.base_dimension:
+ if (
+ self.base_dimension
+ not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+ self.base_model
+ ]
+ ):
+ raise ValueError(
+ f"Dimensions {self.base_dimension} for {self.base_model} are not supported"
+ )
+ else:
+ # If base_dimension is not set, use the largest available dimension for the model
+ self.base_dimension = max(
+ OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model]
+ )
+
+ def _get_dimensions(self):
+ return (
+ NOT_GIVEN
+ if self.base_model == "text-embedding-ada-002"
+ else self.base_dimension
+ or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model][-1]
+ )
+
+ def _get_embedding_kwargs(self, **kwargs):
+ return {
+ "model": self.base_model,
+ "dimensions": self._get_dimensions(),
+ } | kwargs
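+ # For illustration: because the caller's kwargs are merged last, a call such as
+ # _get_embedding_kwargs(user="batch-42") yields roughly
+ # {"model": "text-embedding-3-small", "dimensions": 512, "user": "batch-42"}
+ # (values here are hypothetical), with colliding keys overridden by the caller.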
+
+ async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+ try:
+ response = await self.async_client.embeddings.create(
+ input=texts,
+ **kwargs,
+ )
+ return [data.embedding for data in response.data]
+ except AuthenticationError as e:
+ raise ValueError(
+ "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+ ) from e
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+ raise ValueError(error_msg) from e
+
+ def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+ try:
+ response = self.client.embeddings.create(
+ input=texts,
+ **kwargs,
+ )
+ return [data.embedding for data in response.data]
+ except AuthenticationError as e:
+ raise ValueError(
+ "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+ ) from e
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+ raise ValueError(error_msg) from e
+
+ async def async_get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ result = await self._execute_with_backoff_async(task)
+ return result[0]
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ result = self._execute_with_backoff_sync(task)
+ return result[0]
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ def rerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ):
+ return results[:limit]
+
+ async def arerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ):
+ return results[:limit]
+
+ def tokenize_string(self, text: str, model: str) -> list[int]:
+ if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+ raise ValueError(f"OpenAI embedding model {model} not supported.")
+ encoding = tiktoken.get_encoding(
+ OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
+ )
+ return encoding.encode(text)
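+
+
+ # Minimal usage sketch, assuming OPENAI_API_KEY is set and that EmbeddingConfig
+ # accepts the fields read above (provider, base_model, base_dimension) as keyword
+ # arguments; the model and dimension values are illustrative:
+ #
+ # from core.base import EmbeddingConfig
+ #
+ # config = EmbeddingConfig(
+ # provider="openai",
+ # base_model="openai/text-embedding-3-small",
+ # base_dimension=512,
+ # )
+ # provider = OpenAIEmbeddingProvider(config)
+ # vectors = provider.get_embeddings(["hello", "world"]) # two 512-dim vectors
+ # tokens = provider.tokenize_string("hello", "text-embedding-3-small")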
diff --git a/.venv/lib/python3.12/site-packages/core/providers/ingestion/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/ingestion/__init__.py
new file mode 100644
index 00000000..4a25d30d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/ingestion/__init__.py
@@ -0,0 +1,13 @@
+# type: ignore
+from .r2r.base import R2RIngestionConfig, R2RIngestionProvider
+from .unstructured.base import (
+ UnstructuredIngestionConfig,
+ UnstructuredIngestionProvider,
+)
+
+__all__ = [
+ "R2RIngestionConfig",
+ "R2RIngestionProvider",
+ "UnstructuredIngestionProvider",
+ "UnstructuredIngestionConfig",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/ingestion/r2r/base.py b/.venv/lib/python3.12/site-packages/core/providers/ingestion/r2r/base.py
new file mode 100644
index 00000000..7d452245
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/ingestion/r2r/base.py
@@ -0,0 +1,355 @@
+# type: ignore
+import logging
+import time
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from core import parsers
+from core.base import (
+ AsyncParser,
+ ChunkingStrategy,
+ Document,
+ DocumentChunk,
+ DocumentType,
+ IngestionConfig,
+ IngestionProvider,
+ R2RDocumentProcessingError,
+ RecursiveCharacterTextSplitter,
+ TextSplitter,
+)
+from core.utils import generate_extraction_id
+
+from ...database import PostgresDatabaseProvider
+from ...llm import (
+ LiteLLMCompletionProvider,
+ OpenAICompletionProvider,
+ R2RCompletionProvider,
+)
+
+logger = logging.getLogger()
+
+
+class R2RIngestionConfig(IngestionConfig):
+ chunk_size: int = 1024
+ chunk_overlap: int = 512
+ chunking_strategy: ChunkingStrategy = ChunkingStrategy.RECURSIVE
+ extra_fields: dict[str, Any] = {}
+ separator: Optional[str] = None
+
+
+class R2RIngestionProvider(IngestionProvider):
+ DEFAULT_PARSERS = {
+ DocumentType.BMP: parsers.BMPParser,
+ DocumentType.CSV: parsers.CSVParser,
+ DocumentType.DOC: parsers.DOCParser,
+ DocumentType.DOCX: parsers.DOCXParser,
+ DocumentType.EML: parsers.EMLParser,
+ DocumentType.EPUB: parsers.EPUBParser,
+ DocumentType.HTML: parsers.HTMLParser,
+ DocumentType.HTM: parsers.HTMLParser,
+ DocumentType.ODT: parsers.ODTParser,
+ DocumentType.JSON: parsers.JSONParser,
+ DocumentType.MSG: parsers.MSGParser,
+ DocumentType.ORG: parsers.ORGParser,
+ DocumentType.MD: parsers.MDParser,
+ DocumentType.PDF: parsers.BasicPDFParser,
+ DocumentType.PPT: parsers.PPTParser,
+ DocumentType.PPTX: parsers.PPTXParser,
+ DocumentType.TXT: parsers.TextParser,
+ DocumentType.XLSX: parsers.XLSXParser,
+ DocumentType.GIF: parsers.ImageParser,
+ DocumentType.JPEG: parsers.ImageParser,
+ DocumentType.JPG: parsers.ImageParser,
+ DocumentType.TSV: parsers.TSVParser,
+ DocumentType.PNG: parsers.ImageParser,
+ DocumentType.HEIC: parsers.ImageParser,
+ DocumentType.SVG: parsers.ImageParser,
+ DocumentType.MP3: parsers.AudioParser,
+ DocumentType.P7S: parsers.P7SParser,
+ DocumentType.RST: parsers.RSTParser,
+ DocumentType.RTF: parsers.RTFParser,
+ DocumentType.TIFF: parsers.ImageParser,
+ DocumentType.XLS: parsers.XLSParser,
+ }
+
+ EXTRA_PARSERS = {
+ DocumentType.CSV: {"advanced": parsers.CSVParserAdvanced},
+ DocumentType.PDF: {
+ "unstructured": parsers.PDFParserUnstructured,
+ "zerox": parsers.VLMPDFParser,
+ },
+ DocumentType.XLSX: {"advanced": parsers.XLSXParserAdvanced},
+ }
+
+ IMAGE_TYPES = {
+ DocumentType.GIF,
+ DocumentType.HEIC,
+ DocumentType.JPG,
+ DocumentType.JPEG,
+ DocumentType.PNG,
+ DocumentType.SVG,
+ }
+
+ def __init__(
+ self,
+ config: R2RIngestionConfig,
+ database_provider: PostgresDatabaseProvider,
+ llm_provider: (
+ LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ ):
+ super().__init__(config, database_provider, llm_provider)
+ self.config: R2RIngestionConfig = config
+ self.database_provider: PostgresDatabaseProvider = database_provider
+ self.llm_provider: (
+ LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ) = llm_provider
+ self.parsers: dict[DocumentType, AsyncParser] = {}
+ self.text_splitter = self._build_text_splitter()
+ self._initialize_parsers()
+
+ logger.info(
+ f"R2RIngestionProvider initialized with config: {self.config}"
+ )
+
+ def _initialize_parsers(self):
+ for doc_type, parser in self.DEFAULT_PARSERS.items():
+ # register the default parser for this type unless it is explicitly excluded
+ if doc_type not in self.config.excluded_parsers:
+ self.parsers[doc_type] = parser(
+ config=self.config,
+ database_provider=self.database_provider,
+ llm_provider=self.llm_provider,
+ )
+ for doc_type, doc_parser_name in self.config.extra_parsers.items():
+ self.parsers[f"{doc_parser_name}_{str(doc_type)}"] = (
+ R2RIngestionProvider.EXTRA_PARSERS[doc_type][doc_parser_name](
+ config=self.config,
+ database_provider=self.database_provider,
+ llm_provider=self.llm_provider,
+ )
+ )
+
+ def _build_text_splitter(
+ self, ingestion_config_override: Optional[dict] = None
+ ) -> TextSplitter:
+ logger.info(
+ f"Initializing text splitter with method: {self.config.chunking_strategy}"
+ )
+
+ if not ingestion_config_override:
+ ingestion_config_override = {}
+
+ chunking_strategy = (
+ ingestion_config_override.get("chunking_strategy")
+ or self.config.chunking_strategy
+ )
+
+ chunk_size = (
+ ingestion_config_override.get("chunk_size")
+ or self.config.chunk_size
+ )
+ chunk_overlap = (
+ ingestion_config_override.get("chunk_overlap")
+ or self.config.chunk_overlap
+ )
+
+ if chunking_strategy == ChunkingStrategy.RECURSIVE:
+ return RecursiveCharacterTextSplitter(
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ )
+ elif chunking_strategy == ChunkingStrategy.CHARACTER:
+ from core.base.utils.splitter.text import CharacterTextSplitter
+
+ separator = (
+ ingestion_config_override.get("separator")
+ or self.config.separator
+ or CharacterTextSplitter.DEFAULT_SEPARATOR
+ )
+
+ return CharacterTextSplitter(
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ separator=separator,
+ keep_separator=False,
+ strip_whitespace=True,
+ )
+ elif chunking_strategy == ChunkingStrategy.BASIC:
+ raise NotImplementedError(
+ "Basic chunking method not implemented. Please use Recursive."
+ )
+ elif chunking_strategy == ChunkingStrategy.BY_TITLE:
+ raise NotImplementedError("By title method not implemented")
+ else:
+ raise ValueError(f"Unsupported method type: {chunking_strategy}")
+
+ def validate_config(self) -> bool:
+ return self.config.chunk_size > 0 and self.config.chunk_overlap >= 0
+
+ def chunk(
+ self,
+ parsed_document: str | DocumentChunk,
+ ingestion_config_override: dict,
+ ) -> Generator[Any, None, None]:
+ text_splitter = self.text_splitter
+ if ingestion_config_override:
+ text_splitter = self._build_text_splitter(
+ ingestion_config_override
+ )
+ if isinstance(parsed_document, DocumentChunk):
+ parsed_document = parsed_document.data
+
+ if isinstance(parsed_document, str):
+ chunks = text_splitter.create_documents([parsed_document])
+ else:
+ # Assuming parsed_document is already a list of text chunks
+ chunks = parsed_document
+
+ for chunk in chunks:
+ yield (
+ chunk.page_content if hasattr(chunk, "page_content") else chunk
+ )
+
+ async def parse(
+ self,
+ file_content: bytes,
+ document: Document,
+ ingestion_config_override: dict,
+ ) -> AsyncGenerator[DocumentChunk, None]:
+ if document.document_type not in self.parsers:
+ raise R2RDocumentProcessingError(
+ document_id=document.id,
+ error_message=f"Parser for {document.document_type} not found in `R2RIngestionProvider`.",
+ )
+ else:
+ t0 = time.time()
+ contents = []
+ parser_overrides = ingestion_config_override.get(
+ "parser_overrides", {}
+ )
+ if document.document_type.value in parser_overrides:
+ logger.info(
+ f"Using parser_override for {document.document_type} with input value {parser_overrides[document.document_type.value]}"
+ )
+ # TODO - Cleanup this approach to be less hardcoded
+ if (
+ document.document_type != DocumentType.PDF
+ or parser_overrides[DocumentType.PDF.value] != "zerox"
+ ):
+ raise ValueError(
+ "Only Zerox PDF parser override is available."
+ )
+
+ # Collect content from VLMPDFParser
+ async for chunk in self.parsers[
+ f"zerox_{DocumentType.PDF.value}"
+ ].ingest(file_content, **ingestion_config_override):
+ if isinstance(chunk, dict) and chunk.get("content"):
+ contents.append(chunk)
+ elif (
+ chunk
+ ): # Handle string output for backward compatibility
+ contents.append({"content": chunk})
+
+ if (
+ contents
+ and document.document_type == DocumentType.PDF
+ and parser_overrides.get(DocumentType.PDF.value) == "zerox"
+ ):
+ text_splitter = self._build_text_splitter(
+ ingestion_config_override
+ )
+
+ iteration = 0
+
+ sorted_contents = [
+ item
+ for item in sorted(
+ contents, key=lambda x: x.get("page_number", 0)
+ )
+ if isinstance(item.get("content"), str)
+ ]
+
+ for content_item in sorted_contents:
+ page_num = content_item.get("page_number", 0)
+ page_content = content_item["content"]
+
+ page_chunks = text_splitter.create_documents(
+ [page_content]
+ )
+
+ # Create document chunks for each split piece
+ for chunk in page_chunks:
+ metadata = {
+ **document.metadata,
+ "chunk_order": iteration,
+ "page_number": page_num,
+ }
+
+ extraction = DocumentChunk(
+ id=generate_extraction_id(
+ document.id, iteration
+ ),
+ document_id=document.id,
+ owner_id=document.owner_id,
+ collection_ids=document.collection_ids,
+ data=chunk.page_content,
+ metadata=metadata,
+ )
+ iteration += 1
+ yield extraction
+
+ logger.debug(
+ f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
+ f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
+ f"into {iteration} extractions in t={time.time() - t0:.2f} seconds using page-by-page splitting."
+ )
+ return
+
+ else:
+ # Standard parsing for non-override cases
+ async for text in self.parsers[document.document_type].ingest(
+ file_content, **ingestion_config_override
+ ):
+ if text is not None:
+ contents.append({"content": text})
+
+ if not contents:
+ logger.warning(
+ "No valid text content was extracted during parsing"
+ )
+ return
+
+ iteration = 0
+ for content_item in contents:
+ chunk_text = content_item["content"]
+ chunks = self.chunk(chunk_text, ingestion_config_override)
+
+ for chunk in chunks:
+ metadata = {**document.metadata, "chunk_order": iteration}
+ if "page_number" in content_item:
+ metadata["page_number"] = content_item["page_number"]
+
+ extraction = DocumentChunk(
+ id=generate_extraction_id(document.id, iteration),
+ document_id=document.id,
+ owner_id=document.owner_id,
+ collection_ids=document.collection_ids,
+ data=chunk,
+ metadata=metadata,
+ )
+ iteration += 1
+ yield extraction
+
+ logger.debug(
+ f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
+ f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
+ f"into {iteration} extractions in t={time.time() - t0:.2f} seconds."
+ )
+
+ def get_parser_for_document_type(self, doc_type: DocumentType) -> Any:
+ return self.parsers.get(doc_type)
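+
+
+ # Chunking sketch, assuming an already-constructed R2RIngestionProvider instance
+ # named `provider` (construction requires database and LLM providers, as in __init__):
+ #
+ # pieces = list(provider.chunk("some long document text ...", {"chunk_size": 256, "chunk_overlap": 32}))
+ # # -> a list of text chunks produced by the (possibly overridden) text splitter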
diff --git a/.venv/lib/python3.12/site-packages/core/providers/ingestion/unstructured/base.py b/.venv/lib/python3.12/site-packages/core/providers/ingestion/unstructured/base.py
new file mode 100644
index 00000000..29c09bf5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/ingestion/unstructured/base.py
@@ -0,0 +1,396 @@
+# TODO - cleanup type issues in this file that relate to `bytes`
+import asyncio
+import base64
+import logging
+import os
+import time
+from copy import copy
+from io import BytesIO
+from typing import Any, AsyncGenerator
+
+import httpx
+from unstructured_client import UnstructuredClient
+from unstructured_client.models import operations, shared
+
+from core import parsers
+from core.base import (
+ AsyncParser,
+ ChunkingStrategy,
+ Document,
+ DocumentChunk,
+ DocumentType,
+ RecursiveCharacterTextSplitter,
+)
+from core.base.abstractions import R2RSerializable
+from core.base.providers.ingestion import IngestionConfig, IngestionProvider
+from core.utils import generate_extraction_id
+
+from ...database import PostgresDatabaseProvider
+from ...llm import (
+ LiteLLMCompletionProvider,
+ OpenAICompletionProvider,
+ R2RCompletionProvider,
+)
+
+logger = logging.getLogger()
+
+
+class FallbackElement(R2RSerializable):
+ text: str
+ metadata: dict[str, Any]
+
+
+class UnstructuredIngestionConfig(IngestionConfig):
+ combine_under_n_chars: int = 128
+ max_characters: int = 500
+ new_after_n_chars: int = 1500
+ overlap: int = 64
+
+ coordinates: bool | None = None
+ encoding: str | None = None # utf-8
+ extract_image_block_types: list[str] | None = None
+ gz_uncompressed_content_type: str | None = None
+ hi_res_model_name: str | None = None
+ include_orig_elements: bool | None = None
+ include_page_breaks: bool | None = None
+
+ languages: list[str] | None = None
+ multipage_sections: bool | None = None
+ ocr_languages: list[str] | None = None
+ # output_format: Optional[str] = "application/json"
+ overlap_all: bool | None = None
+ pdf_infer_table_structure: bool | None = None
+
+ similarity_threshold: float | None = None
+ skip_infer_table_types: list[str] | None = None
+ split_pdf_concurrency_level: int | None = None
+ split_pdf_page: bool | None = None
+ starting_page_number: int | None = None
+ strategy: str | None = None
+ chunking_strategy: str | ChunkingStrategy | None = None # type: ignore
+ unique_element_ids: bool | None = None
+ xml_keep_tags: bool | None = None
+
+ def to_ingestion_request(self):
+ import json
+
+ x = json.loads(self.json())
+ x.pop("extra_fields", None)
+ x.pop("provider", None)
+ x.pop("excluded_parsers", None)
+
+ x = {k: v for k, v in x.items() if v is not None}
+ return x
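+ # For example, with strategy="hi_res" set and all other optional fields left as None,
+ # to_ingestion_request() yields roughly:
+ # {"combine_under_n_chars": 128, "max_characters": 500, "new_after_n_chars": 1500,
+ # "overlap": 64, "strategy": "hi_res", ...}
+ # i.e. None-valued fields plus extra_fields, provider, and excluded_parsers are dropped.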
+
+
+class UnstructuredIngestionProvider(IngestionProvider):
+ R2R_FALLBACK_PARSERS = {
+ DocumentType.GIF: [parsers.ImageParser], # type: ignore
+ DocumentType.JPEG: [parsers.ImageParser], # type: ignore
+ DocumentType.JPG: [parsers.ImageParser], # type: ignore
+ DocumentType.PNG: [parsers.ImageParser], # type: ignore
+ DocumentType.SVG: [parsers.ImageParser], # type: ignore
+ DocumentType.HEIC: [parsers.ImageParser], # type: ignore
+ DocumentType.MP3: [parsers.AudioParser], # type: ignore
+ DocumentType.JSON: [parsers.JSONParser], # type: ignore
+ DocumentType.HTML: [parsers.HTMLParser], # type: ignore
+ DocumentType.XLS: [parsers.XLSParser], # type: ignore
+ DocumentType.XLSX: [parsers.XLSXParser], # type: ignore
+ DocumentType.DOC: [parsers.DOCParser], # type: ignore
+ DocumentType.PPT: [parsers.PPTParser], # type: ignore
+ }
+
+ EXTRA_PARSERS = {
+ DocumentType.CSV: {"advanced": parsers.CSVParserAdvanced}, # type: ignore
+ DocumentType.PDF: {
+ "unstructured": parsers.PDFParserUnstructured, # type: ignore
+ "zerox": parsers.VLMPDFParser, # type: ignore
+ },
+ DocumentType.XLSX: {"advanced": parsers.XLSXParserAdvanced}, # type: ignore
+ }
+
+ IMAGE_TYPES = {
+ DocumentType.GIF,
+ DocumentType.HEIC,
+ DocumentType.JPG,
+ DocumentType.JPEG,
+ DocumentType.PNG,
+ DocumentType.SVG,
+ }
+
+ def __init__(
+ self,
+ config: UnstructuredIngestionConfig,
+ database_provider: PostgresDatabaseProvider,
+ llm_provider: (
+ LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ ):
+ super().__init__(config, database_provider, llm_provider)
+ self.config: UnstructuredIngestionConfig = config
+ self.database_provider: PostgresDatabaseProvider = database_provider
+ self.llm_provider: (
+ LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ) = llm_provider
+
+ if config.provider == "unstructured_api":
+ try:
+ self.unstructured_api_auth = os.environ["UNSTRUCTURED_API_KEY"]
+ except KeyError as e:
+ raise ValueError(
+ "UNSTRUCTURED_API_KEY environment variable is not set"
+ ) from e
+
+ self.unstructured_api_url = os.environ.get(
+ "UNSTRUCTURED_API_URL",
+ "https://api.unstructuredapp.io/general/v0/general",
+ )
+
+ self.client = UnstructuredClient(
+ api_key_auth=self.unstructured_api_auth,
+ server_url=self.unstructured_api_url,
+ )
+ self.shared = shared
+ self.operations = operations
+
+ else:
+ try:
+ self.local_unstructured_url = os.environ[
+ "UNSTRUCTURED_SERVICE_URL"
+ ]
+ except KeyError as e:
+ raise ValueError(
+ "UNSTRUCTURED_SERVICE_URL environment variable is not set"
+ ) from e
+
+ self.client = httpx.AsyncClient()
+
+ self.parsers: dict[DocumentType, AsyncParser] = {}
+ self._initialize_parsers()
+
+ def _initialize_parsers(self):
+ for doc_type, fallback_parsers in self.R2R_FALLBACK_PARSERS.items():
+ for parser in fallback_parsers:
+ if (
+ doc_type not in self.config.excluded_parsers
+ and doc_type not in self.parsers
+ ):
+ # will choose the first parser in the list
+ self.parsers[doc_type] = parser(
+ config=self.config,
+ database_provider=self.database_provider,
+ llm_provider=self.llm_provider,
+ )
+ # TODO - Reduce code duplication between Unstructured & R2R
+ for doc_type, doc_parser_name in self.config.extra_parsers.items():
+ self.parsers[f"{doc_parser_name}_{str(doc_type)}"] = (
+ UnstructuredIngestionProvider.EXTRA_PARSERS[doc_type][
+ doc_parser_name
+ ](
+ config=self.config,
+ database_provider=self.database_provider,
+ llm_provider=self.llm_provider,
+ )
+ )
+
+ async def parse_fallback(
+ self,
+ file_content: bytes,
+ ingestion_config: dict,
+ parser_name: str,
+ ) -> AsyncGenerator[FallbackElement, None]:
+ contents = []
+ async for chunk in self.parsers[parser_name].ingest( # type: ignore
+ file_content, **ingestion_config
+ ): # type: ignore
+ if isinstance(chunk, dict) and chunk.get("content"):
+ contents.append(chunk)
+ elif chunk: # Handle string output for backward compatibility
+ contents.append({"content": chunk})
+
+ if not contents:
+ logger.warning(
+ "No valid text content was extracted during parsing"
+ )
+ return
+
+ logging.info(f"Fallback ingestion with config = {ingestion_config}")
+
+ iteration = 0
+ for content_item in contents:
+ text = content_item["content"]
+
+ loop = asyncio.get_event_loop()
+ splitter = RecursiveCharacterTextSplitter(
+ chunk_size=ingestion_config["new_after_n_chars"],
+ chunk_overlap=ingestion_config["overlap"],
+ )
+ chunks = await loop.run_in_executor(
+ None, splitter.create_documents, [text]
+ )
+
+ for text_chunk in chunks:
+ metadata = {"chunk_id": iteration}
+ if "page_number" in content_item:
+ metadata["page_number"] = content_item["page_number"]
+
+ yield FallbackElement(
+ text=text_chunk.page_content,
+ metadata=metadata,
+ )
+ iteration += 1
+ await asyncio.sleep(0)
+
+ async def parse(
+ self,
+ file_content: bytes,
+ document: Document,
+ ingestion_config_override: dict,
+ ) -> AsyncGenerator[DocumentChunk, None]:
+ ingestion_config = copy(
+ {
+ **self.config.to_ingestion_request(),
+ **(ingestion_config_override or {}),
+ }
+ )
+ # cleanup extra fields
+ ingestion_config.pop("provider", None)
+ ingestion_config.pop("excluded_parsers", None)
+
+ t0 = time.time()
+ parser_overrides = ingestion_config_override.get(
+ "parser_overrides", {}
+ )
+ elements = []
+
+ # TODO - Cleanup this approach to be less hardcoded
+ # TODO - Remove code duplication between Unstructured & R2R
+ if document.document_type.value in parser_overrides:
+ logger.info(
+ f"Using parser_override for {document.document_type} with input value {parser_overrides[document.document_type.value]}"
+ )
+ async for element in self.parse_fallback(
+ file_content,
+ ingestion_config=ingestion_config,
+ parser_name=f"zerox_{DocumentType.PDF.value}",
+ ):
+ elements.append(element)
+
+ elif document.document_type in self.R2R_FALLBACK_PARSERS.keys():
+ logger.info(
+ f"Parsing {document.document_type}: {document.id} with fallback parser"
+ )
+ async for element in self.parse_fallback(
+ file_content,
+ ingestion_config=ingestion_config,
+ parser_name=document.document_type,
+ ):
+ elements.append(element)
+ else:
+ logger.info(
+ f"Parsing {document.document_type}: {document.id} with unstructured"
+ )
+ if isinstance(file_content, bytes):
+ file_content = BytesIO(file_content) # type: ignore
+
+ # TODO - Include check on excluded parsers here.
+ if self.config.provider == "unstructured_api":
+ logger.info(f"Using API to parse document {document.id}")
+ files = self.shared.Files(
+ content=file_content.read(), # type: ignore
+ file_name=document.metadata.get("title", "unknown_file"),
+ )
+
+ ingestion_config.pop("app", None)
+ ingestion_config.pop("extra_parsers", None)
+
+ req = self.operations.PartitionRequest(
+ self.shared.PartitionParameters(
+ files=files,
+ **ingestion_config,
+ )
+ )
+ elements = self.client.general.partition(req) # type: ignore
+ elements = list(elements.elements) # type: ignore
+
+ else:
+ logger.info(
+ f"Using local unstructured fastapi server to parse document {document.id}"
+ )
+ # Base64 encode the file content
+ encoded_content = base64.b64encode(file_content.read()).decode( # type: ignore
+ "utf-8"
+ )
+
+ logger.info(
+ f"Sending a request to {self.local_unstructured_url}/partition"
+ )
+
+ response = await self.client.post(
+ f"{self.local_unstructured_url}/partition",
+ json={
+ "file_content": encoded_content, # Use encoded string
+ "ingestion_config": ingestion_config,
+ "filename": document.metadata.get("title", None),
+ },
+ timeout=3600, # Adjust timeout as needed
+ )
+
+ if response.status_code != 200:
+ logger.error(f"Error partitioning file: {response.text}")
+ raise ValueError(
+ f"Error partitioning file: {response.text}"
+ )
+ elements = response.json().get("elements", [])
+
+ iteration = 0 # if there are no chunks
+ for iteration, element in enumerate(elements):
+ if isinstance(element, FallbackElement):
+ text = element.text
+ metadata = copy(document.metadata)
+ metadata.update(element.metadata)
+ else:
+ element_dict = (
+ element if isinstance(element, dict) else element.to_dict()
+ )
+ text = element_dict.get("text", "")
+ if text == "":
+ continue
+
+ metadata = copy(document.metadata)
+ for key, value in element_dict.items():
+ if key == "metadata":
+ for k, v in value.items():
+ if k not in metadata and k != "orig_elements":
+ metadata[f"unstructured_{k}"] = v
+
+ # indicate that the document was chunked using unstructured
+ # nullifies the need for chunking in the pipeline
+ metadata["partitioned_by_unstructured"] = True
+ metadata["chunk_order"] = iteration
+ # creating the text extraction
+ yield DocumentChunk(
+ id=generate_extraction_id(document.id, iteration),
+ document_id=document.id,
+ owner_id=document.owner_id,
+ collection_ids=document.collection_ids,
+ data=text,
+ metadata=metadata,
+ )
+
+ # TODO: explore why this is throwing inadvertently
+ # if iteration == 0:
+ # raise ValueError(f"No chunks found for document {document.id}")
+
+ logger.debug(
+ f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
+ f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
+ f"into {iteration + 1} extractions in t={time.time() - t0:.2f} seconds."
+ )
+
+ def get_parser_for_document_type(self, doc_type: DocumentType) -> str:
+ return "unstructured_local"
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/llm/__init__.py
new file mode 100644
index 00000000..8132e11c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/__init__.py
@@ -0,0 +1,11 @@
+from .anthropic import AnthropicCompletionProvider
+from .litellm import LiteLLMCompletionProvider
+from .openai import OpenAICompletionProvider
+from .r2r_llm import R2RCompletionProvider
+
+__all__ = [
+ "AnthropicCompletionProvider",
+ "LiteLLMCompletionProvider",
+ "OpenAICompletionProvider",
+ "R2RCompletionProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/anthropic.py b/.venv/lib/python3.12/site-packages/core/providers/llm/anthropic.py
new file mode 100644
index 00000000..0089a207
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/anthropic.py
@@ -0,0 +1,925 @@
+import copy
+import json
+import logging
+import os
+import time
+import uuid
+from typing import (
+ Any,
+ AsyncGenerator,
+ Generator,
+ Optional,
+)
+
+from anthropic import Anthropic, AsyncAnthropic
+from anthropic.types import (
+ ContentBlockStopEvent,
+ Message,
+ MessageStopEvent,
+ RawContentBlockDeltaEvent,
+ RawContentBlockStartEvent,
+ RawMessageStartEvent,
+ ToolUseBlock,
+)
+
+from core.base.abstractions import GenerationConfig, LLMChatCompletion
+from core.base.providers.llm import CompletionConfig, CompletionProvider
+
+from .utils import resize_base64_image
+
+logger = logging.getLogger(__name__)
+
+
+def generate_tool_id() -> str:
+ """Generate a unique tool ID using UUID4."""
+ return f"tool_{uuid.uuid4().hex[:12]}"
+
+
+def process_images_in_message(message: dict) -> dict:
+ """
+ Process all images in a message to ensure they're within Anthropic's recommended limits.
+ """
+ if not message or not isinstance(message, dict):
+ return message
+
+ # Handle nested image_data (old format)
+ if (
+ message.get("role")
+ and message.get("image_data")
+ and isinstance(message["image_data"], dict)
+ ):
+ if message["image_data"].get("data") and message["image_data"].get(
+ "media_type"
+ ):
+ message["image_data"]["data"] = resize_base64_image(
+ message["image_data"]["data"]
+ )
+ return message
+
+ # Handle standard content list format
+ if message.get("content") and isinstance(message["content"], list):
+ for i, block in enumerate(message["content"]):
+ if isinstance(block, dict) and block.get("type") == "image":
+ if block.get("source", {}).get("type") == "base64" and block[
+ "source"
+ ].get("data"):
+ message["content"][i]["source"]["data"] = (
+ resize_base64_image(block["source"]["data"])
+ )
+
+ # Handle string content with base64 image detection (less common)
+ elif (
+ message.get("content")
+ and isinstance(message["content"], str)
+ and ";base64," in message["content"]
+ ):
+ # This is a basic detection for base64 images in text - might need more robust handling
+ logger.warning(
+ "Detected potential base64 image in string content - not auto-resizing"
+ )
+
+ return message
+
+
+def openai_message_to_anthropic_block(msg: dict) -> dict:
+ """Converts a single OpenAI-style message (including function/tool calls)
+ into one Anthropic-style message.
+
+ Expected keys in `msg` can include:
+ - role: "system" | "assistant" | "user" | "function" | "tool"
+ - content: str (possibly JSON arguments or the final text)
+ - name: str (tool/function name)
+ - tool_call_id or function_call arguments
+ - function_call: {"name": ..., "arguments": "..."}
+ """
+ role = msg.get("role", "")
+ content = msg.get("content", "")
+ tool_call_id = msg.get("tool_call_id")
+
+ # Handle old-style image_data field
+ image_data = msg.get("image_data")
+ # Handle nested image_url (less common)
+ image_url = msg.get("image_url")
+
+ if role == "system":
+ # System messages should not have images, extract any image to a separate user message
+ if image_url or image_data:
+ logger.warning(
+ "Found image in system message - images should be in user messages only"
+ )
+ return msg
+
+ if role in ["user", "assistant"]:
+ # If content is already a list, assume it's properly formatted
+ if isinstance(content, list):
+ return {"role": role, "content": content}
+
+ # Process old-style image_data or image_url
+ if image_url or image_data:
+ formatted_content = []
+
+ # Add image content first (as recommended by Anthropic)
+ if image_url:
+ formatted_content.append(
+ {
+ "type": "image",
+ "source": {"type": "url", "url": image_url},
+ }
+ )
+ elif image_data:
+ # Resize the image data if needed
+ resized_data = image_data.get("data", "")
+ if resized_data:
+ resized_data = resize_base64_image(resized_data)
+
+ formatted_content.append(
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": image_data.get(
+ "media_type", "image/jpeg"
+ ),
+ "data": resized_data,
+ },
+ }
+ )
+
+ # Add text content after the image
+ if content:
+ if isinstance(content, str):
+ formatted_content.append({"type": "text", "text": content})
+ elif isinstance(content, list):
+ # If it's already a list, extend with it
+ formatted_content.extend(content)
+
+ return {"role": role, "content": formatted_content}
+
+ if role in ["function", "tool"]:
+ return {
+ "role": "user",
+ "content": [
+ {
+ "type": "tool_result",
+ "tool_use_id": tool_call_id,
+ "content": content,
+ }
+ ],
+ }
+
+ return {"role": role, "content": content}
+
+
+class AnthropicCompletionProvider(CompletionProvider):
+ def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
+ super().__init__(config)
+ self.client = Anthropic()
+ self.async_client = AsyncAnthropic()
+ logger.debug("AnthropicCompletionProvider initialized successfully")
+
+ def _get_base_args(
+ self, generation_config: GenerationConfig
+ ) -> dict[str, Any]:
+ """Build the arguments dictionary for Anthropic's messages.create().
+
+ Handles tool configuration according to Anthropic's schema:
+ {
+ "type": "function", # Use 'function' type for custom tools
+ "name": "tool_name",
+ "description": "tool description",
+ "parameters": { # Note: Anthropic expects 'parameters', not 'input_schema'
+ "type": "object",
+ "properties": {...},
+ "required": [...]
+ }
+ }
+ """
+ model_str = generation_config.model or ""
+ model_name = (
+ model_str.split("anthropic/")[-1]
+ if model_str
+ else "claude-3-opus-20240229"
+ )
+
+ args: dict[str, Any] = {
+ "model": model_name,
+ "temperature": generation_config.temperature,
+ "max_tokens": generation_config.max_tokens_to_sample,
+ "stream": generation_config.stream,
+ }
+ if generation_config.top_p:
+ args["top_p"] = generation_config.top_p
+
+ if generation_config.tools is not None:
+ # Convert tools to Anthropic's format
+ anthropic_tools: list[dict[str, Any]] = []
+ for tool in generation_config.tools:
+ tool_def = {
+ "name": tool["function"]["name"],
+ "description": tool["function"]["description"],
+ "input_schema": tool["function"]["parameters"],
+ }
+ anthropic_tools.append(tool_def)
+ args["tools"] = anthropic_tools
+
+ if hasattr(generation_config, "tool_choice"):
+ tool_choice = generation_config.tool_choice
+ if isinstance(tool_choice, str):
+ if tool_choice == "auto":
+ args["tool_choice"] = {"type": "auto"}
+ elif tool_choice == "any":
+ args["tool_choice"] = {"type": "any"}
+ elif isinstance(tool_choice, dict):
+ if tool_choice.get("type") == "function":
+ args["tool_choice"] = {
+ "type": "function",
+ "name": tool_choice.get("name"),
+ }
+ if hasattr(generation_config, "disable_parallel_tool_use"):
+ args["tool_choice"] = args.get("tool_choice", {})
+ args["tool_choice"]["disable_parallel_tool_use"] = (
+ generation_config.disable_parallel_tool_use
+ )
+
+ # --- Extended Thinking Support ---
+ if getattr(generation_config, "extended_thinking", False):
+ if (
+ not hasattr(generation_config, "thinking_budget")
+ or generation_config.thinking_budget is None
+ ):
+ raise ValueError(
+ "Extended thinking is enabled but no thinking_budget is provided."
+ )
+ if (
+ generation_config.thinking_budget
+ >= generation_config.max_tokens_to_sample
+ ):
+ raise ValueError(
+ "thinking_budget must be less than max_tokens_to_sample."
+ )
+ args["thinking"] = {
+ "type": "enabled",
+ "budget_tokens": generation_config.thinking_budget,
+ }
+ return args
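+
+ # Tool-conversion sketch mirroring the loop above; the tool name and description are
+ # illustrative. An OpenAI-style definition
+ #
+ # {"type": "function", "function": {"name": "get_weather", "description": "Look up weather",
+ # "parameters": {"type": "object", "properties": {}}}}
+ #
+ # is sent to Anthropic as
+ #
+ # {"name": "get_weather", "description": "Look up weather",
+ # "input_schema": {"type": "object", "properties": {}}}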
+
+ def _preprocess_messages(self, messages: list[dict]) -> list[dict]:
+ """
+ Preprocess all messages to optimize images before sending to Anthropic API.
+ """
+ if not messages or not isinstance(messages, list):
+ return messages
+
+ processed_messages = []
+ for message in messages:
+ processed_message = process_images_in_message(message)
+ processed_messages.append(processed_message)
+
+ return processed_messages
+
+ def _create_openai_style_message(self, content_blocks, tool_calls=None):
+ """
+ Create an OpenAI-style message from Anthropic content blocks
+ while preserving the original structure.
+ """
+ display_content = ""
+ structured_content: list[Any] = []
+
+ for block in content_blocks:
+ if block.type == "text":
+ display_content += block.text
+ elif block.type == "thinking" and hasattr(block, "thinking"):
+ # Store the complete thinking block
+ structured_content.append(
+ {
+ "type": "thinking",
+ "thinking": block.thinking,
+ "signature": block.signature,
+ }
+ )
+ # For display/logging
+ # display_content += f"<think>{block.thinking}</think>"
+ elif block.type == "redacted_thinking" and hasattr(block, "data"):
+ # Store the complete redacted thinking block
+ structured_content.append(
+ {"type": "redacted_thinking", "data": block.data}
+ )
+ # Add a placeholder for display/logging
+ display_content += "<redacted thinking block>"
+ elif block.type == "tool_use":
+ # Tool use blocks are handled separately via tool_calls
+ pass
+
+ # If we have structured content (thinking blocks), use that
+ if structured_content:
+ # Add any text block at the end if needed
+ for block in content_blocks:
+ if block.type == "text":
+ structured_content.append(
+ {"type": "text", "text": block.text}
+ )
+
+ return {
+ "content": display_content or None,
+ "structured_content": structured_content,
+ }
+ else:
+ # If no structured content, just return the display content
+ return {"content": display_content or None}
+
+ def _convert_to_chat_completion(self, anthropic_msg: Message) -> dict:
+ """
+ Convert a non-streaming Anthropic Message into an OpenAI-style dict.
+ Preserves thinking blocks for proper handling.
+ """
+ tool_calls: list[Any] = []
+ message_data: dict[str, Any] = {"role": anthropic_msg.role}
+
+ if anthropic_msg.content:
+ # First, extract any tool use blocks
+ for block in anthropic_msg.content:
+ if hasattr(block, "type") and block.type == "tool_use":
+ tool_calls.append(
+ {
+ "index": len(tool_calls),
+ "id": block.id,
+ "type": "function",
+ "function": {
+ "name": block.name,
+ "arguments": json.dumps(block.input),
+ },
+ }
+ )
+
+ # Then create the message with appropriate content
+ message_data.update(
+ self._create_openai_style_message(
+ anthropic_msg.content, tool_calls
+ )
+ )
+
+ # If we have tool calls, add them
+ if tool_calls:
+ message_data["tool_calls"] = tool_calls
+
+ finish_reason = (
+ "stop"
+ if anthropic_msg.stop_reason == "end_turn"
+ else anthropic_msg.stop_reason
+ )
+ finish_reason = (
+ "tool_calls"
+ if anthropic_msg.stop_reason == "tool_use"
+ else finish_reason
+ )
+
+ model_str = anthropic_msg.model or ""
+ model_name = model_str.split("anthropic/")[-1] if model_str else ""
+
+ return {
+ "id": anthropic_msg.id,
+ "object": "chat.completion",
+ "created": int(time.time()),
+ "model": model_name,
+ "usage": {
+ "prompt_tokens": (
+ anthropic_msg.usage.input_tokens
+ if anthropic_msg.usage
+ else 0
+ ),
+ "completion_tokens": (
+ anthropic_msg.usage.output_tokens
+ if anthropic_msg.usage
+ else 0
+ ),
+ "total_tokens": (
+ (
+ anthropic_msg.usage.input_tokens
+ if anthropic_msg.usage
+ else 0
+ )
+ + (
+ anthropic_msg.usage.output_tokens
+ if anthropic_msg.usage
+ else 0
+ )
+ ),
+ },
+ "choices": [
+ {
+ "index": 0,
+ "message": message_data,
+ "finish_reason": finish_reason,
+ }
+ ],
+ }
+
+ def _split_system_messages(
+ self, messages: list[dict]
+ ) -> tuple[list[dict], Optional[str]]:
+ """
+ Process messages for the Anthropic API, ensuring proper formatting for tool use and
+ thinking blocks, and resizing any embedded images before they are sent.
+ """
+ # First preprocess to resize any images
+ messages = self._preprocess_messages(messages)
+
+ system_msg = None
+ filtered: list[dict[str, Any]] = []
+ pending_tool_results: list[dict[str, Any]] = []
+
+ # Look for pairs of tool_use and tool_result
+ i = 0
+ while i < len(messages):
+ m = copy.deepcopy(messages[i])
+
+ # Handle system message
+ if m["role"] == "system" and system_msg is None:
+ system_msg = m["content"]
+ i += 1
+ continue
+
+ # Case 1: Message with list format content (thinking blocks or tool blocks)
+ if (
+ isinstance(m.get("content"), list)
+ and len(m["content"]) > 0
+ and isinstance(m["content"][0], dict)
+ ):
+ filtered.append({"role": m["role"], "content": m["content"]})
+ i += 1
+ continue
+
+ # Case 2: Message with structured_content field
+ elif m.get("structured_content") and m["role"] == "assistant":
+ filtered.append(
+ {"role": "assistant", "content": m["structured_content"]}
+ )
+ i += 1
+ continue
+
+ # Case 3: Tool calls in an assistant message
+ elif m.get("tool_calls") and m["role"] == "assistant":
+ # Add content if it exists
+ if m.get("content") and not isinstance(m["content"], list):
+ content_to_add = m["content"]
+ # Handle content with thinking tags
+ if "<think>" in content_to_add:
+ thinking_start = content_to_add.find("<think>")
+ thinking_end = content_to_add.find("</think>")
+ if (
+ thinking_start >= 0
+ and thinking_end > thinking_start
+ ):
+ thinking_content = content_to_add[
+ thinking_start + 7 : thinking_end
+ ]
+ text_content = content_to_add[
+ thinking_end + 8 :
+ ].strip()
+ filtered.append(
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "thinking",
+ "thinking": thinking_content,
+ "signature": "placeholder_signature", # This is a placeholder
+ },
+ {"type": "text", "text": text_content},
+ ],
+ }
+ )
+ else:
+ filtered.append(
+ {
+ "role": "assistant",
+ "content": content_to_add,
+ }
+ )
+ else:
+ filtered.append(
+ {"role": "assistant", "content": content_to_add}
+ )
+
+ # Add tool use blocks
+ tool_uses = []
+ for call in m["tool_calls"]:
+ tool_uses.append(
+ {
+ "type": "tool_use",
+ "id": call["id"],
+ "name": call["function"]["name"],
+ "input": json.loads(call["function"]["arguments"]),
+ }
+ )
+
+ filtered.append({"role": "assistant", "content": tool_uses})
+
+ # Check if next message is a tool result for this tool call
+ if i + 1 < len(messages) and messages[i + 1]["role"] in [
+ "function",
+ "tool",
+ ]:
+ next_m = copy.deepcopy(messages[i + 1])
+
+ # Make sure this is a tool result for the current tool use
+ if next_m.get("tool_call_id") in [
+ call["id"] for call in m["tool_calls"]
+ ]:
+ # Add tool result as a user message
+ filtered.append(
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "tool_result",
+ "tool_use_id": next_m["tool_call_id"],
+ "content": next_m["content"],
+ }
+ ],
+ }
+ )
+ i += 2 # Skip both the tool call and result
+ continue
+
+ i += 1
+ continue
+
+ # Case 4: Direct tool result (might be missing its paired tool call)
+ elif m["role"] in ["function", "tool"] and m.get("tool_call_id"):
+ # Add a user message with the tool result
+ filtered.append(
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "tool_result",
+ "tool_use_id": m["tool_call_id"],
+ "content": m["content"],
+ }
+ ],
+ }
+ )
+ i += 1
+ continue
+
+ # Default case: normal message
+ elif m["role"] in ["function", "tool"]:
+ # Collect tool results to combine them
+ pending_tool_results.append(
+ {
+ "type": "tool_result",
+ "tool_use_id": m.get("tool_call_id"),
+ "content": m["content"],
+ }
+ )
+
+ # If we have all expected results, add them as one message
+ if len(filtered) > 0 and len(
+ filtered[-1].get("content", [])
+ ) == len(pending_tool_results):
+ filtered.append(
+ {"role": "user", "content": pending_tool_results}
+ )
+ pending_tool_results = []
+ else:
+ filtered.append(openai_message_to_anthropic_block(m))
+ i += 1
+
+ # Final validation: ensure no tool_use is at the end without a tool_result
+ if filtered and len(filtered) > 1:
+ last_msg = filtered[-1]
+ if (
+ last_msg["role"] == "assistant"
+ and isinstance(last_msg.get("content"), list)
+ and any(
+ block.get("type") == "tool_use"
+ for block in last_msg["content"]
+ )
+ ):
+ logger.warning(
+ "Found tool_use at end of conversation without tool_result - removing it"
+ )
+ filtered.pop() # Remove problematic message
+
+ return filtered, system_msg
+
+ async def _execute_task(self, task: dict[str, Any]):
+ """Async entry point.
+
+ Decide if streaming or not, then call the appropriate helper.
+ """
+ api_key = os.getenv("ANTHROPIC_API_KEY")
+ if not api_key:
+ logger.error("Missing ANTHROPIC_API_KEY in environment.")
+ raise ValueError(
+ "Anthropic API key not found. Set ANTHROPIC_API_KEY env var."
+ )
+
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ extra_kwargs = task["kwargs"]
+ base_args = self._get_base_args(generation_config)
+ filtered_messages, system_msg = self._split_system_messages(messages)
+ base_args["messages"] = filtered_messages
+ if system_msg:
+ base_args["system"] = system_msg
+
+ args = {**base_args, **extra_kwargs}
+ logger.debug(f"Anthropic async call with args={args}")
+
+ if generation_config.stream:
+ return self._execute_task_async_streaming(args)
+ else:
+ return await self._execute_task_async_nonstreaming(args)
+
+ async def _execute_task_async_nonstreaming(
+ self, args: dict[str, Any]
+ ) -> LLMChatCompletion:
+ api_key = os.getenv("ANTHROPIC_API_KEY")
+ if not api_key:
+ logger.error("Missing ANTHROPIC_API_KEY in environment.")
+ raise ValueError(
+ "Anthropic API key not found. Set ANTHROPIC_API_KEY env var."
+ )
+
+ try:
+ logger.debug(f"Anthropic API request: {args}")
+ response = await self.async_client.messages.create(**args)
+ logger.debug(f"Anthropic API response: {response}")
+
+ return LLMChatCompletion(
+ **self._convert_to_chat_completion(response)
+ )
+ except Exception as e:
+ logger.error(f"Anthropic async non-stream call failed: {e}")
+ logger.error("message payload = ", args)
+ raise
+
+ async def _execute_task_async_streaming(
+ self, args: dict
+ ) -> AsyncGenerator[dict[str, Any], None]:
+ """Streaming call (async): yields partial tokens in OpenAI-like SSE
+ format."""
+ # The original args may still carry `stream=True`; remove it to avoid conflicts,
+ # since the SDK's `messages.stream()` helper manages streaming itself.
+ args.pop("stream", None)
+ try:
+ async with self.async_client.messages.stream(**args) as stream:
+ # We'll track partial JSON for function calls in buffer_data
+ buffer_data: dict[str, Any] = {
+ "tool_json_buffer": "",
+ "tool_name": None,
+ "tool_id": None,
+ "is_collecting_tool": False,
+ "thinking_buffer": "",
+ "is_collecting_thinking": False,
+ "thinking_signature": None,
+ "message_id": f"chatcmpl-{int(time.time())}",
+ }
+ model_name = args.get("model", "claude-2")
+ if isinstance(model_name, str):
+ model_name = model_name.split("anthropic/")[-1]
+
+ async for event in stream:
+ chunks = self._process_stream_event(
+ event=event,
+ buffer_data=buffer_data,
+ model_name=model_name,
+ )
+ for chunk in chunks:
+ yield chunk
+ except Exception as e:
+ logger.error(f"Failed to execute streaming Anthropic task: {e}")
+ logger.error("message payload = ", args)
+
+ raise
+
+ def _execute_task_sync(self, task: dict[str, Any]):
+ """Synchronous entry point."""
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ extra_kwargs = task["kwargs"]
+
+ base_args = self._get_base_args(generation_config)
+ filtered_messages, system_msg = self._split_system_messages(messages)
+ base_args["messages"] = filtered_messages
+ if system_msg:
+ base_args["system"] = system_msg
+
+ args = {**base_args, **extra_kwargs}
+ logger.debug(f"Anthropic sync call with args={args}")
+
+ if generation_config.stream:
+ return self._execute_task_sync_streaming(args)
+ else:
+ return self._execute_task_sync_nonstreaming(args)
+
+ def _execute_task_sync_nonstreaming(
+ self, args: dict[str, Any]
+ ): # -> LLMChatCompletion: # FIXME: LLMChatCompletion is an object from the OpenAI API, which causes a validation error
+ """Non-streaming synchronous call."""
+ try:
+ response = self.client.messages.create(**args)
+ logger.debug("Anthropic sync non-stream call succeeded.")
+ return LLMChatCompletion(
+ **self._convert_to_chat_completion(response)
+ )
+ except Exception as e:
+ logger.error(f"Anthropic sync call failed: {e}")
+ raise
+
+ def _execute_task_sync_streaming(
+ self, args: dict[str, Any]
+ ) -> Generator[dict[str, Any], None, None]:
+ """
+ Synchronous streaming call: yields partial tokens in a generator.
+ """
+ args.pop("stream", None)
+ try:
+ with self.client.messages.stream(**args) as stream:
+ buffer_data: dict[str, Any] = {
+ "tool_json_buffer": "",
+ "tool_name": None,
+ "tool_id": None,
+ "is_collecting_tool": False,
+ "thinking_buffer": "",
+ "is_collecting_thinking": False,
+ "thinking_signature": None,
+ "message_id": f"chatcmpl-{int(time.time())}",
+ }
+ model_name = args.get("model", "anthropic/claude-2")
+ if isinstance(model_name, str):
+ model_name = model_name.split("anthropic/")[-1]
+
+ for event in stream:
+ yield from self._process_stream_event(
+ event=event,
+ buffer_data=buffer_data,
+ model_name=model_name,
+ )
+ except Exception as e:
+ logger.error(f"Anthropic sync streaming call failed: {e}")
+ raise
+
+ def _process_stream_event(
+ self, event: Any, buffer_data: dict[str, Any], model_name: str
+ ) -> list[dict[str, Any]]:
+ chunks: list[dict[str, Any]] = []
+
+ def make_base_chunk() -> dict[str, Any]:
+ return {
+ "id": buffer_data["message_id"],
+ "object": "chat.completion.chunk",
+ "created": int(time.time()),
+ "model": model_name,
+ "choices": [{"index": 0, "delta": {}, "finish_reason": None}],
+ }
+
+ if isinstance(event, RawMessageStartEvent):
+ buffer_data["message_id"] = event.message.id
+ chunk = make_base_chunk()
+ input_tokens = (
+ event.message.usage.input_tokens if event.message.usage else 0
+ )
+ chunk["usage"] = {
+ "prompt_tokens": input_tokens,
+ "completion_tokens": 0,
+ "total_tokens": input_tokens,
+ }
+ chunks.append(chunk)
+
+ elif isinstance(event, RawContentBlockStartEvent):
+ if hasattr(event.content_block, "type"):
+ block_type = event.content_block.type
+ if block_type == "thinking":
+ buffer_data["is_collecting_thinking"] = True
+ buffer_data["thinking_buffer"] = ""
+ # Don't emit anything yet
+ elif block_type == "tool_use" or isinstance(
+ event.content_block, ToolUseBlock
+ ):
+ buffer_data["tool_name"] = event.content_block.name # type: ignore
+ buffer_data["tool_id"] = event.content_block.id # type: ignore
+ buffer_data["tool_json_buffer"] = ""
+ buffer_data["is_collecting_tool"] = True
+
+ elif isinstance(event, RawContentBlockDeltaEvent):
+ delta_obj = getattr(event, "delta", None)
+ delta_type = getattr(delta_obj, "type", None)
+
+ # Handle thinking deltas
+ if delta_type == "thinking_delta" and hasattr(
+ delta_obj, "thinking"
+ ):
+ thinking_chunk = delta_obj.thinking # type: ignore
+ if buffer_data["is_collecting_thinking"]:
+ buffer_data["thinking_buffer"] += thinking_chunk
+ # Stream thinking chunks as they come in
+ chunk = make_base_chunk()
+ chunk["choices"][0]["delta"] = {"thinking": thinking_chunk}
+ chunks.append(chunk)
+
+ # Handle signature deltas for thinking blocks
+ elif delta_type == "signature_delta" and hasattr(
+ delta_obj, "signature"
+ ):
+ if buffer_data["is_collecting_thinking"]:
+ buffer_data["thinking_signature"] = delta_obj.signature # type: ignore
+ # Surface the signature downstream as its own delta chunk
+ chunk = make_base_chunk()
+ chunk["choices"][0]["delta"] = {
+ "thinking_signature": delta_obj.signature # type: ignore
+ }
+ chunks.append(chunk)
+
+ # Handle text deltas
+ elif delta_type == "text_delta" and hasattr(delta_obj, "text"):
+ text_chunk = delta_obj.text # type: ignore
+ if not buffer_data["is_collecting_tool"] and text_chunk:
+ chunk = make_base_chunk()
+ chunk["choices"][0]["delta"] = {"content": text_chunk}
+ chunks.append(chunk)
+
+ # Handle partial JSON for tools
+ elif hasattr(delta_obj, "partial_json"):
+ if buffer_data["is_collecting_tool"]:
+ buffer_data["tool_json_buffer"] += delta_obj.partial_json # type: ignore
+
+ elif isinstance(event, ContentBlockStopEvent):
+ # Handle the end of a thinking block
+ if buffer_data.get("is_collecting_thinking"):
+ # Emit a special "structured_content_delta" with the complete thinking block
+ if (
+ buffer_data["thinking_buffer"]
+ and buffer_data["thinking_signature"]
+ ):
+ chunk = make_base_chunk()
+ chunk["choices"][0]["delta"] = {
+ "structured_content": [
+ {
+ "type": "thinking",
+ "thinking": buffer_data["thinking_buffer"],
+ "signature": buffer_data["thinking_signature"],
+ }
+ ]
+ }
+ chunks.append(chunk)
+
+ # Reset thinking collection
+ buffer_data["is_collecting_thinking"] = False
+ buffer_data["thinking_buffer"] = ""
+ buffer_data["thinking_signature"] = None
+
+ # Handle the end of a tool use block
+ elif buffer_data.get("is_collecting_tool"):
+ try:
+ json.loads(buffer_data["tool_json_buffer"])
+ chunk = make_base_chunk()
+ chunk["choices"][0]["delta"] = {
+ "tool_calls": [
+ {
+ "index": 0,
+ "type": "function",
+ "id": buffer_data["tool_id"]
+ or f"call_{generate_tool_id()}",
+ "function": {
+ "name": buffer_data["tool_name"],
+ "arguments": buffer_data[
+ "tool_json_buffer"
+ ],
+ },
+ }
+ ]
+ }
+ chunks.append(chunk)
+ buffer_data["is_collecting_tool"] = False
+ buffer_data["tool_json_buffer"] = ""
+ buffer_data["tool_name"] = None
+ buffer_data["tool_id"] = None
+ except json.JSONDecodeError:
+ logger.warning(
+ "Incomplete JSON in tool call, skipping chunk"
+ )
+
+ elif isinstance(event, MessageStopEvent):
+ # Check whether the event carries a message with a stop_reason before acting on it
+ message_obj = getattr(event, "message", None)
+ if message_obj is not None and hasattr(message_obj, "stop_reason"):
+ stop_reason = message_obj.stop_reason
+ chunk = make_base_chunk()
+ if stop_reason == "tool_use":
+ chunk["choices"][0]["delta"] = {}
+ chunk["choices"][0]["finish_reason"] = "tool_calls"
+ else:
+ chunk["choices"][0]["delta"] = {}
+ chunk["choices"][0]["finish_reason"] = "stop"
+ chunks.append(chunk)
+ else:
+ # Handle the case where message is not available
+ chunk = make_base_chunk()
+ chunk["choices"][0]["delta"] = {}
+ chunk["choices"][0]["finish_reason"] = "stop"
+ chunks.append(chunk)
+
+ return chunks
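+
+ # Shape of the chunks emitted above (OpenAI-style deltas); field values are illustrative:
+ #
+ # {"id": "chatcmpl-...", "object": "chat.completion.chunk", "created": 1700000000,
+ # "model": "claude-3-opus-20240229",
+ # "choices": [{"index": 0, "delta": {"content": "Hel"}, "finish_reason": None}]}
+ #
+ # Thinking deltas arrive as {"delta": {"thinking": ...}}, tool calls as
+ # {"delta": {"tool_calls": [...]}}, and the final chunk carries finish_reason
+ # "stop" or "tool_calls".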
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/azure_foundry.py b/.venv/lib/python3.12/site-packages/core/providers/llm/azure_foundry.py
new file mode 100644
index 00000000..863e44ec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/azure_foundry.py
@@ -0,0 +1,110 @@
+import logging
+import os
+from typing import Any, Optional
+
+from azure.ai.inference import (
+ ChatCompletionsClient as AzureChatCompletionsClient,
+)
+from azure.ai.inference.aio import (
+ ChatCompletionsClient as AsyncAzureChatCompletionsClient,
+)
+from azure.core.credentials import AzureKeyCredential
+
+from core.base.abstractions import GenerationConfig
+from core.base.providers.llm import CompletionConfig, CompletionProvider
+
+logger = logging.getLogger(__name__)
+
+
+class AzureFoundryCompletionProvider(CompletionProvider):
+ def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
+ super().__init__(config)
+ self.azure_foundry_client: Optional[AzureChatCompletionsClient] = None
+ self.async_azure_foundry_client: Optional[
+ AsyncAzureChatCompletionsClient
+ ] = None
+
+ # Initialize Azure Foundry clients if credentials exist.
+ azure_foundry_api_key = os.getenv("AZURE_FOUNDRY_API_KEY")
+ azure_foundry_api_endpoint = os.getenv("AZURE_FOUNDRY_API_ENDPOINT")
+
+ if azure_foundry_api_key and azure_foundry_api_endpoint:
+ self.azure_foundry_client = AzureChatCompletionsClient(
+ endpoint=azure_foundry_api_endpoint,
+ credential=AzureKeyCredential(azure_foundry_api_key),
+ api_version=os.getenv(
+ "AZURE_FOUNDRY_API_VERSION", "2024-05-01-preview"
+ ),
+ )
+ self.async_azure_foundry_client = AsyncAzureChatCompletionsClient(
+ endpoint=azure_foundry_api_endpoint,
+ credential=AzureKeyCredential(azure_foundry_api_key),
+ api_version=os.getenv(
+ "AZURE_FOUNDRY_API_VERSION", "2024-05-01-preview"
+ ),
+ )
+ logger.debug("Azure Foundry clients initialized successfully")
+
+ def _get_base_args(
+ self, generation_config: GenerationConfig
+ ) -> dict[str, Any]:
+ # Construct arguments similar to the other providers.
+ args: dict[str, Any] = {
+ "top_p": generation_config.top_p,
+ "stream": generation_config.stream,
+ "max_tokens": generation_config.max_tokens_to_sample,
+ "temperature": generation_config.temperature,
+ }
+
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+ if generation_config.tools is not None:
+ args["tools"] = generation_config.tools
+ if generation_config.response_format is not None:
+ args["response_format"] = generation_config.response_format
+ return args
+
+ async def _execute_task(self, task: dict[str, Any]):
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ kwargs = task["kwargs"]
+
+ args = self._get_base_args(generation_config)
+ # Azure Foundry does not require a "model" argument; the endpoint is fixed.
+ args["messages"] = messages
+ args = {**args, **kwargs}
+ logger.debug(f"Executing async Azure Foundry task with args: {args}")
+
+ try:
+ if self.async_azure_foundry_client is None:
+ raise ValueError("Azure Foundry client is not initialized")
+
+ response = await self.async_azure_foundry_client.complete(**args)
+ logger.debug("Async Azure Foundry task executed successfully")
+ return response
+ except Exception as e:
+ logger.error(
+ f"Async Azure Foundry task execution failed: {str(e)}"
+ )
+ raise
+
+ def _execute_task_sync(self, task: dict[str, Any]):
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ kwargs = task["kwargs"]
+
+ args = self._get_base_args(generation_config)
+ args["messages"] = messages
+ args = {**args, **kwargs}
+ logger.debug(f"Executing sync Azure Foundry task with args: {args}")
+
+ try:
+ if self.azure_foundry_client is None:
+ raise ValueError("Azure Foundry client is not initialized")
+
+ response = self.azure_foundry_client.complete(**args)
+ logger.debug("Sync Azure Foundry task executed successfully")
+ return response
+ except Exception as e:
+ logger.error(f"Sync Azure Foundry task execution failed: {str(e)}")
+ raise
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/litellm.py b/.venv/lib/python3.12/site-packages/core/providers/llm/litellm.py
new file mode 100644
index 00000000..44d467c2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/litellm.py
@@ -0,0 +1,80 @@
+import logging
+from typing import Any
+
+import litellm
+from litellm import acompletion, completion
+
+from core.base.abstractions import GenerationConfig
+from core.base.providers.llm import CompletionConfig, CompletionProvider
+
+logger = logging.getLogger()
+
+
+class LiteLLMCompletionProvider(CompletionProvider):
+ def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
+ super().__init__(config)
+ litellm.modify_params = True
+ self.acompletion = acompletion
+ self.completion = completion
+
+ # if config.provider != "litellm":
+ # logger.error(f"Invalid provider: {config.provider}")
+ # raise ValueError(
+ # "LiteLLMCompletionProvider must be initialized with config with `litellm` provider."
+ # )
+
+ def _get_base_args(
+ self, generation_config: GenerationConfig
+ ) -> dict[str, Any]:
+ args: dict[str, Any] = {
+ "model": generation_config.model,
+ "temperature": generation_config.temperature,
+ "top_p": generation_config.top_p,
+ "stream": generation_config.stream,
+ "max_tokens": generation_config.max_tokens_to_sample,
+ "api_base": generation_config.api_base,
+ }
+
+ # Only forward optional parameters when they are explicitly set
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+ if generation_config.tools is not None:
+ args["tools"] = generation_config.tools
+ if generation_config.response_format is not None:
+ args["response_format"] = generation_config.response_format
+
+ return args
+
+ async def _execute_task(self, task: dict[str, Any]):
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ kwargs = task["kwargs"]
+
+ args = self._get_base_args(generation_config)
+ args["messages"] = messages
+ args = {**args, **kwargs}
+
+ logger.debug(
+ f"Executing LiteLLM task with generation_config={generation_config}"
+ )
+
+ return await self.acompletion(**args)
+
+ def _execute_task_sync(self, task: dict[str, Any]):
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ kwargs = task["kwargs"]
+
+ args = self._get_base_args(generation_config)
+ args["messages"] = messages
+ args = {**args, **kwargs}
+
+ logger.debug(
+ f"Executing LiteLLM task with generation_config={generation_config}"
+ )
+
+ try:
+ return self.completion(**args)
+ except Exception as e:
+ logger.error(f"Sync LiteLLM task execution failed: {str(e)}")
+ raise
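+
+
+ # Illustrative call path; the CompletionConfig/GenerationConfig arguments below are assumed, not prescribed:
+ #
+ # provider = LiteLLMCompletionProvider(CompletionConfig(provider="litellm"))
+ # gen_cfg = GenerationConfig(model="vertex_ai/gemini-1.5-pro", stream=False)
+ # response = provider._execute_task_sync(
+ # {"messages": [{"role": "user", "content": "Hello"}], "generation_config": gen_cfg, "kwargs": {}}
+ # )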
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/openai.py b/.venv/lib/python3.12/site-packages/core/providers/llm/openai.py
new file mode 100644
index 00000000..30ef37ab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/openai.py
@@ -0,0 +1,522 @@
+import logging
+import os
+from typing import Any
+
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
+
+from core.base.abstractions import GenerationConfig
+from core.base.providers.llm import CompletionConfig, CompletionProvider
+
+from .utils import resize_base64_image
+
+logger = logging.getLogger()
+
+
+class OpenAICompletionProvider(CompletionProvider):
+ def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
+ super().__init__(config)
+ self.openai_client = None
+ self.async_openai_client = None
+ self.azure_client = None
+ self.async_azure_client = None
+ self.deepseek_client = None
+ self.async_deepseek_client = None
+ self.ollama_client = None
+ self.async_ollama_client = None
+ self.lmstudio_client = None
+ self.async_lmstudio_client = None
+ # Azure Foundry clients, backed by the Azure AI Inference API
+ self.azure_foundry_client = None
+ self.async_azure_foundry_client = None
+
+ # Initialize OpenAI clients if credentials exist
+ if os.getenv("OPENAI_API_KEY"):
+ self.openai_client = OpenAI()
+ self.async_openai_client = AsyncOpenAI()
+ logger.debug("OpenAI clients initialized successfully")
+
+ # Initialize Azure OpenAI clients if credentials exist
+ azure_api_key = os.getenv("AZURE_API_KEY")
+ azure_api_base = os.getenv("AZURE_API_BASE")
+ if azure_api_key and azure_api_base:
+ self.azure_client = AzureOpenAI(
+ api_key=azure_api_key,
+ api_version=os.getenv(
+ "AZURE_API_VERSION", "2024-02-15-preview"
+ ),
+ azure_endpoint=azure_api_base,
+ )
+ self.async_azure_client = AsyncAzureOpenAI(
+ api_key=azure_api_key,
+ api_version=os.getenv(
+ "AZURE_API_VERSION", "2024-02-15-preview"
+ ),
+ azure_endpoint=azure_api_base,
+ )
+ logger.debug("Azure OpenAI clients initialized successfully")
+
+ # Initialize Deepseek clients if credentials exist
+ deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")
+ deepseek_api_base = os.getenv(
+ "DEEPSEEK_API_BASE", "https://api.deepseek.com"
+ )
+ if deepseek_api_key and deepseek_api_base:
+ self.deepseek_client = OpenAI(
+ api_key=deepseek_api_key,
+ base_url=deepseek_api_base,
+ )
+ self.async_deepseek_client = AsyncOpenAI(
+ api_key=deepseek_api_key,
+ base_url=deepseek_api_base,
+ )
+ logger.debug("Deepseek OpenAI clients initialized successfully")
+
+ # Initialize Ollama clients with default API key
+ ollama_api_base = os.getenv(
+ "OLLAMA_API_BASE", "http://localhost:11434/v1"
+ )
+ if ollama_api_base:
+ self.ollama_client = OpenAI(
+ api_key=os.getenv("OLLAMA_API_KEY", "dummy"),
+ base_url=ollama_api_base,
+ )
+ self.async_ollama_client = AsyncOpenAI(
+ api_key=os.getenv("OLLAMA_API_KEY", "dummy"),
+ base_url=ollama_api_base,
+ )
+ logger.debug("Ollama OpenAI clients initialized successfully")
+
+ # Initialize LMStudio clients
+ lmstudio_api_base = os.getenv(
+ "LMSTUDIO_API_BASE", "http://localhost:1234/v1"
+ )
+ if lmstudio_api_base:
+ self.lmstudio_client = OpenAI(
+ api_key=os.getenv("LMSTUDIO_API_KEY", "lm-studio"),
+ base_url=lmstudio_api_base,
+ )
+ self.async_lmstudio_client = AsyncOpenAI(
+ api_key=os.getenv("LMSTUDIO_API_KEY", "lm-studio"),
+ base_url=lmstudio_api_base,
+ )
+ logger.debug("LMStudio OpenAI clients initialized successfully")
+
+ # Initialize Azure Foundry clients if credentials exist.
+ # These use the Azure AI Inference SDK, imported lazily so it is only required when configured.
+ azure_foundry_api_key = os.getenv("AZURE_FOUNDRY_API_KEY")
+ azure_foundry_api_endpoint = os.getenv("AZURE_FOUNDRY_API_ENDPOINT")
+ if azure_foundry_api_key and azure_foundry_api_endpoint:
+ from azure.ai.inference import (
+ ChatCompletionsClient as AzureChatCompletionsClient,
+ )
+ from azure.ai.inference.aio import (
+ ChatCompletionsClient as AsyncAzureChatCompletionsClient,
+ )
+ from azure.core.credentials import AzureKeyCredential
+
+ self.azure_foundry_client = AzureChatCompletionsClient(
+ endpoint=azure_foundry_api_endpoint,
+ credential=AzureKeyCredential(azure_foundry_api_key),
+ api_version=os.getenv(
+ "AZURE_FOUNDRY_API_VERSION", "2024-05-01-preview"
+ ),
+ )
+ self.async_azure_foundry_client = AsyncAzureChatCompletionsClient(
+ endpoint=azure_foundry_api_endpoint,
+ credential=AzureKeyCredential(azure_foundry_api_key),
+ api_version=os.getenv(
+ "AZURE_FOUNDRY_API_VERSION", "2024-05-01-preview"
+ ),
+ )
+ logger.debug("Azure Foundry clients initialized successfully")
+
+ if not any(
+ [
+ self.openai_client,
+ self.azure_client,
+ self.ollama_client,
+ self.lmstudio_client,
+ self.azure_foundry_client,
+ ]
+ ):
+ raise ValueError(
+ "No valid client credentials found. Please set either OPENAI_API_KEY, "
+ "both AZURE_API_KEY and AZURE_API_BASE environment variables, "
+ "OLLAMA_API_BASE, LMSTUDIO_API_BASE, or AZURE_FOUNDRY_API_KEY and AZURE_FOUNDRY_API_ENDPOINT."
+ )
+
+ def _get_client_and_model(self, model: str):
+ """Determine which client to use based on model prefix and return the
+ appropriate client and model name."""
+ if model.startswith("azure/"):
+ if not self.azure_client:
+ raise ValueError(
+ "Azure OpenAI credentials not configured but azure/ model prefix used"
+ )
+ return self.azure_client, model[6:] # Strip 'azure/' prefix
+ elif model.startswith("openai/"):
+ if not self.openai_client:
+ raise ValueError(
+ "OpenAI credentials not configured but openai/ model prefix used"
+ )
+ return self.openai_client, model[7:] # Strip 'openai/' prefix
+ elif model.startswith("deepseek/"):
+ if not self.deepseek_client:
+ raise ValueError(
+ "Deepseek OpenAI credentials not configured but deepseek/ model prefix used"
+ )
+ return self.deepseek_client, model[9:] # Strip 'deepseek/' prefix
+ elif model.startswith("ollama/"):
+ if not self.ollama_client:
+ raise ValueError(
+ "Ollama OpenAI credentials not configured but ollama/ model prefix used"
+ )
+ return self.ollama_client, model[7:] # Strip 'ollama/' prefix
+ elif model.startswith("lmstudio/"):
+ if not self.lmstudio_client:
+ raise ValueError(
+ "LMStudio credentials not configured but lmstudio/ model prefix used"
+ )
+ return self.lmstudio_client, model[9:] # Strip 'lmstudio/' prefix
+ elif model.startswith("azure-foundry/"):
+ if not self.azure_foundry_client:
+ raise ValueError(
+ "Azure Foundry credentials not configured but azure-foundry/ model prefix used"
+ )
+ return (
+ self.azure_foundry_client,
+ model[14:],
+ ) # Strip 'azure-foundry/' prefix
+ else:
+ # Default to OpenAI if no prefix is provided.
+ if self.openai_client:
+ return self.openai_client, model
+ elif self.azure_client:
+ return self.azure_client, model
+ elif self.ollama_client:
+ return self.ollama_client, model
+ elif self.lmstudio_client:
+ return self.lmstudio_client, model
+ elif self.azure_foundry_client:
+ return self.azure_foundry_client, model
+ else:
+ raise ValueError("No valid client available for model prefix")
+
+ def _get_async_client_and_model(self, model: str):
+ """Get async client and model name based on prefix."""
+ if model.startswith("azure/"):
+ if not self.async_azure_client:
+ raise ValueError(
+ "Azure OpenAI credentials not configured but azure/ model prefix used"
+ )
+ return self.async_azure_client, model[6:]
+ elif model.startswith("openai/"):
+ if not self.async_openai_client:
+ raise ValueError(
+ "OpenAI credentials not configured but openai/ model prefix used"
+ )
+ return self.async_openai_client, model[7:]
+ elif model.startswith("deepseek/"):
+ if not self.async_deepseek_client:
+ raise ValueError(
+ "Deepseek OpenAI credentials not configured but deepseek/ model prefix used"
+ )
+ return self.async_deepseek_client, model[9:].strip()
+ elif model.startswith("ollama/"):
+ if not self.async_ollama_client:
+ raise ValueError(
+ "Ollama OpenAI credentials not configured but ollama/ model prefix used"
+ )
+ return self.async_ollama_client, model[7:]
+ elif model.startswith("lmstudio/"):
+ if not self.async_lmstudio_client:
+ raise ValueError(
+ "LMStudio credentials not configured but lmstudio/ model prefix used"
+ )
+ return self.async_lmstudio_client, model[9:]
+ elif model.startswith("azure-foundry/"):
+ if not self.async_azure_foundry_client:
+ raise ValueError(
+ "Azure Foundry credentials not configured but azure-foundry/ model prefix used"
+ )
+ return self.async_azure_foundry_client, model[14:]
+ else:
+ if self.async_openai_client:
+ return self.async_openai_client, model
+ elif self.async_azure_client:
+ return self.async_azure_client, model
+ elif self.async_ollama_client:
+ return self.async_ollama_client, model
+ elif self.async_lmstudio_client:
+ return self.async_lmstudio_client, model
+ elif self.async_azure_foundry_client:
+ return self.async_azure_foundry_client, model
+ else:
+ raise ValueError(
+ "No valid async client available for model prefix"
+ )
+
+ def _process_messages_with_images(
+ self, messages: list[dict]
+ ) -> list[dict]:
+ """
+ Process messages that may contain image_url or image_data fields.
+ Base64 image data is aggressively resized, similar to the Anthropic provider.
+ """
+ processed_messages = []
+
+ for msg in messages:
+ if msg.get("role") == "system":
+ # System messages don't support content arrays in OpenAI
+ processed_messages.append(msg)
+ continue
+
+ # Check if the message contains image data
+ image_url = msg.pop("image_url", None)
+ image_data = msg.pop("image_data", None)
+ content = msg.get("content")
+
+ if image_url or image_data:
+ # Convert to content array format
+ new_content = []
+
+ # Add image content
+ if image_url:
+ new_content.append(
+ {"type": "image_url", "image_url": {"url": image_url}}
+ )
+ elif image_data:
+ # Resize the base64 image data if available
+ media_type = image_data.get("media_type", "image/jpeg")
+ data = image_data.get("data", "")
+
+ # Aggressively resize the base64 image to limit payload and token size
+ if data:
+ data = resize_base64_image(data)
+ logger.debug(
+ f"Image resized, new size: {len(data)} chars"
+ )
+
+ # OpenAI expects base64 images in data URL format
+ data_url = f"data:{media_type};base64,{data}"
+ new_content.append(
+ {"type": "image_url", "image_url": {"url": data_url}}
+ )
+
+ # Add text content if present
+ if content:
+ new_content.append({"type": "text", "text": content})
+
+ # Update the message
+ new_msg = dict(msg)
+ new_msg["content"] = new_content
+ processed_messages.append(new_msg)
+ else:
+ processed_messages.append(msg)
+
+ return processed_messages
+
+ def _process_array_content_with_images(self, content: list) -> list:
+ """
+ Process content array that may contain image_url items.
+ Used for messages that already have content in array format.
+ """
+ if not content or not isinstance(content, list):
+ return content
+
+ processed_content = []
+
+ for item in content:
+ if isinstance(item, dict):
+ if item.get("type") == "image_url":
+ # Process image URL if needed
+ processed_content.append(item)
+ elif item.get("type") == "image" and item.get("source"):
+ # Convert Anthropic-style to OpenAI-style
+ source = item.get("source", {})
+ if source.get("type") == "base64" and source.get("data"):
+ # Resize the base64 image data
+ resized_data = resize_base64_image(source.get("data"))
+
+ media_type = source.get("media_type", "image/jpeg")
+ data_url = f"data:{media_type};base64,{resized_data}"
+
+ processed_content.append(
+ {
+ "type": "image_url",
+ "image_url": {"url": data_url},
+ }
+ )
+ elif source.get("type") == "url" and source.get("url"):
+ processed_content.append(
+ {
+ "type": "image_url",
+ "image_url": {"url": source.get("url")},
+ }
+ )
+ else:
+ # Pass through other types
+ processed_content.append(item)
+ else:
+ processed_content.append(item)
+
+ return processed_content
+
+ def _preprocess_messages(self, messages: list[dict]) -> list[dict]:
+ """
+ Preprocess all messages to optimize images before sending to OpenAI API.
+ """
+ if not messages or not isinstance(messages, list):
+ return messages
+
+ processed_messages = []
+
+ for msg in messages:
+ # Skip system messages as they're handled separately
+ if msg.get("role") == "system":
+ processed_messages.append(msg)
+ continue
+
+ # Process array-format content (might contain images)
+ if isinstance(msg.get("content"), list):
+ new_msg = dict(msg)
+ new_msg["content"] = self._process_array_content_with_images(
+ msg["content"]
+ )
+ processed_messages.append(new_msg)
+ else:
+ # Standard processing for non-array content
+ processed_messages.append(msg)
+
+ return processed_messages
+
+ def _get_base_args(self, generation_config: GenerationConfig) -> dict:
+ # Reasoning models (o1/o3) take max_completion_tokens and do not accept the standard sampling parameters.
+ args: dict[str, Any] = {
+ "model": generation_config.model,
+ "stream": generation_config.stream,
+ }
+
+ model_str = generation_config.model or ""
+
+ if "o1" not in model_str and "o3" not in model_str:
+ args["max_tokens"] = generation_config.max_tokens_to_sample
+ args["temperature"] = generation_config.temperature
+ args["top_p"] = generation_config.top_p
+ else:
+ args["max_completion_tokens"] = (
+ generation_config.max_tokens_to_sample
+ )
+
+ if generation_config.reasoning_effort is not None:
+ args["reasoning_effort"] = generation_config.reasoning_effort
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+ if generation_config.tools is not None:
+ args["tools"] = generation_config.tools
+ if generation_config.response_format is not None:
+ args["response_format"] = generation_config.response_format
+ return args
+
+ async def _execute_task(self, task: dict[str, Any]):
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ kwargs = task["kwargs"]
+
+ # First preprocess to handle any images in array format
+ messages = self._preprocess_messages(messages)
+
+ # Then process messages with direct image_url or image_data fields
+ processed_messages = self._process_messages_with_images(messages)
+
+ args = self._get_base_args(generation_config)
+ client, model_name = self._get_async_client_and_model(args["model"])
+ args["model"] = model_name
+ args["messages"] = processed_messages
+ args = {**args, **kwargs}
+
+ # Check if we're using a vision-capable model when images are present
+ contains_images = any(
+ isinstance(msg.get("content"), list)
+ and any(
+ item.get("type") == "image_url"
+ for item in msg.get("content", [])
+ )
+ for msg in processed_messages
+ )
+
+ if contains_images:
+ vision_models = ["gpt-4-vision", "gpt-4o"]
+ if not any(
+ vision_model in model_name for vision_model in vision_models
+ ):
+ logger.warning(
+ f"Using model {model_name} with images, but it may not support vision"
+ )
+
+ logger.debug(f"Executing async task with args: {args}")
+ try:
+ # Azure Foundry clients expose complete(); the endpoint is bound to a
+ # single model, so the model argument is dropped before the call.
+ if client == self.async_azure_foundry_client:
+ args.pop("model")
+ response = await client.complete(**args)
+ else:
+ response = await client.chat.completions.create(**args)
+ logger.debug("Async task executed successfully")
+ return response
+ except Exception as e:
+ logger.error(f"Async task execution failed: {str(e)}")
+ raise
+
+ def _execute_task_sync(self, task: dict[str, Any]):
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ kwargs = task["kwargs"]
+
+ # First preprocess to handle any images in array format
+ messages = self._preprocess_messages(messages)
+
+ # Then process messages with direct image_url or image_data fields
+ processed_messages = self._process_messages_with_images(messages)
+
+ args = self._get_base_args(generation_config)
+ client, model_name = self._get_client_and_model(args["model"])
+ args["model"] = model_name
+ args["messages"] = processed_messages
+ args = {**args, **kwargs}
+
+ # Same vision model check as in async version
+ contains_images = any(
+ isinstance(msg.get("content"), list)
+ and any(
+ item.get("type") == "image_url"
+ for item in msg.get("content", [])
+ )
+ for msg in processed_messages
+ )
+
+ if contains_images:
+ vision_models = ["gpt-4-vision", "gpt-4o"]
+ if not any(
+ vision_model in model_name for vision_model in vision_models
+ ):
+ logger.warning(
+ f"Using model {model_name} with images, but it may not support vision"
+ )
+
+ logger.debug(f"Executing sync OpenAI task with args: {args}")
+ try:
+ # Azure Foundry clients expose complete(); drop the model argument since the endpoint is model-bound.
+ if client == self.azure_foundry_client:
+ args.pop("model")
+ response = client.complete(**args)
+ else:
+ response = client.chat.completions.create(**args)
+ logger.debug("Sync task executed successfully")
+ return response
+ except Exception as e:
+ logger.error(f"Sync task execution failed: {str(e)}")
+ raise
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/r2r_llm.py b/.venv/lib/python3.12/site-packages/core/providers/llm/r2r_llm.py
new file mode 100644
index 00000000..b95b310a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/r2r_llm.py
@@ -0,0 +1,96 @@
+import logging
+from typing import Any
+
+from core.base.abstractions import GenerationConfig
+from core.base.providers.llm import CompletionConfig, CompletionProvider
+
+from .anthropic import AnthropicCompletionProvider
+from .azure_foundry import AzureFoundryCompletionProvider
+from .litellm import LiteLLMCompletionProvider
+from .openai import OpenAICompletionProvider
+
+logger = logging.getLogger(__name__)
+
+
+class R2RCompletionProvider(CompletionProvider):
+ """A provider that routes to the right LLM provider (R2R):
+
+ - If `generation_config.model` starts with "anthropic/", call AnthropicCompletionProvider.
+ - If it starts with "azure-foundry/", call AzureFoundryCompletionProvider.
+ - If it starts with one of the other OpenAI-like prefixes ("openai/", "azure/", "deepseek/", "ollama/", "lmstudio/")
+ or has no prefix (e.g. "gpt-4", "gpt-3.5"), call OpenAICompletionProvider.
+ - Otherwise, fallback to LiteLLMCompletionProvider.
+ """
+
+ def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
+ """Initialize sub-providers for OpenAI, Anthropic, LiteLLM, and Azure
+ Foundry."""
+ super().__init__(config)
+ self.config = config
+
+ logger.info("Initializing R2RCompletionProvider...")
+ self._openai_provider = OpenAICompletionProvider(
+ self.config, *args, **kwargs
+ )
+ self._anthropic_provider = AnthropicCompletionProvider(
+ self.config, *args, **kwargs
+ )
+ self._litellm_provider = LiteLLMCompletionProvider(
+ self.config, *args, **kwargs
+ )
+ self._azure_foundry_provider = AzureFoundryCompletionProvider(
+ self.config, *args, **kwargs
+ )
+
+ logger.debug(
+ "R2RCompletionProvider initialized with OpenAI, Anthropic, LiteLLM, and Azure Foundry sub-providers."
+ )
+
+ def _choose_subprovider_by_model(
+ self, model_name: str, is_streaming: bool = False
+ ) -> CompletionProvider:
+ """Decide which underlying sub-provider to call based on the model name
+ (prefix)."""
+ # Route to Anthropic if appropriate.
+ if model_name.startswith("anthropic/"):
+ return self._anthropic_provider
+
+ # Route to Azure Foundry explicitly.
+ if model_name.startswith("azure-foundry/"):
+ return self._azure_foundry_provider
+
+ # OpenAI-like prefixes.
+ openai_like_prefixes = [
+ "openai/",
+ "azure/",
+ "deepseek/",
+ "ollama/",
+ "lmstudio/",
+ ]
+ if (
+ any(
+ model_name.startswith(prefix)
+ for prefix in openai_like_prefixes
+ )
+ or "/" not in model_name
+ ):
+ return self._openai_provider
+
+ # Fallback to LiteLLM.
+ return self._litellm_provider
+
+ async def _execute_task(self, task: dict[str, Any]):
+ """Pick the sub-provider based on model name and forward the async
+ call."""
+ generation_config: GenerationConfig = task["generation_config"]
+ model_name = generation_config.model
+ sub_provider = self._choose_subprovider_by_model(model_name or "")
+ return await sub_provider._execute_task(task)
+
+ def _execute_task_sync(self, task: dict[str, Any]):
+ """Pick the sub-provider based on model name and forward the sync
+ call."""
+ generation_config: GenerationConfig = task["generation_config"]
+ model_name = generation_config.model
+ sub_provider = self._choose_subprovider_by_model(model_name or "")
+ return sub_provider._execute_task_sync(task)
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/utils.py b/.venv/lib/python3.12/site-packages/core/providers/llm/utils.py
new file mode 100644
index 00000000..619b2e73
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/utils.py
@@ -0,0 +1,106 @@
+import base64
+import io
+import logging
+from typing import Tuple
+
+from PIL import Image
+
+logger = logging.getLogger()
+
+
+def resize_base64_image(
+ base64_string: str,
+ max_size: Tuple[int, int] = (512, 512),
+ max_megapixels: float = 0.25,
+) -> str:
+ """Aggressively resize images with better error handling and debug output"""
+ logger.debug(
+ f"RESIZING NOW!!! Original length: {len(base64_string)} chars"
+ )
+
+ # Decode base64 string to bytes
+ try:
+ image_data = base64.b64decode(base64_string)
+ image = Image.open(io.BytesIO(image_data))
+ logger.debug(f"Image opened successfully: {image.format} {image.size}")
+ except Exception as e:
+ logger.debug(f"Failed to decode/open image: {e}")
+ # Emergency fallback - truncate the base64 string to reduce tokens
+ if len(base64_string) > 50000:
+ return base64_string[:50000]
+ return base64_string
+
+ try:
+ width, height = image.size
+ current_megapixels = (width * height) / 1_000_000
+ logger.debug(
+ f"Original dimensions: {width}x{height} ({current_megapixels:.2f} MP)"
+ )
+
+ # MUCH more aggressive resizing for large images
+ if current_megapixels > 0.5:
+ max_size = (384, 384)
+ max_megapixels = 0.15
+ logger.debug("Large image detected! Using more aggressive limits")
+
+ # Calculate new dimensions with strict enforcement
+ # Always resize if the image is larger than we want
+ scale_factor = min(
+ max_size[0] / width,
+ max_size[1] / height,
+ (max_megapixels / current_megapixels) ** 0.5,
+ )
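+ # Worked example: a 2048x1536 input (~3.15 MP) falls under the aggressive
+ # limits above, so scale_factor = min(384/2048, 384/1536, sqrt(0.15/3.15))
+ # = 0.1875, which resizes the image to 384x288 below.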
+
+ if scale_factor >= 1.0:
+ # No resize needed, but still compress
+ new_width, new_height = width, height
+ else:
+ # Apply scaling
+ new_width = max(int(width * scale_factor), 64) # Min width
+ new_height = max(int(height * scale_factor), 64) # Min height
+
+ # Always resize/recompress the image
+ logger.debug(f"Resizing to: {new_width}x{new_height}")
+ resized_image = image.resize((new_width, new_height), Image.LANCZOS) # type: ignore
+
+ # Convert back to base64 with strong compression
+ buffer = io.BytesIO()
+ if image.format == "JPEG" or image.format is None:
+ # Apply very aggressive JPEG compression
+ quality = 50 # Very low quality to reduce size
+ resized_image.save(
+ buffer, format="JPEG", quality=quality, optimize=True
+ )
+ else:
+ # For other formats
+ resized_image.save(
+ buffer, format=image.format or "PNG", optimize=True
+ )
+
+ resized_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+ logger.debug(
+ f"Resized base64 length: {len(resized_base64)} chars (reduction: {100 * (1 - len(resized_base64) / len(base64_string)):.1f}%)"
+ )
+ return resized_base64
+
+ except Exception as e:
+ logger.debug(f"Error during resize: {e}")
+ # If anything goes wrong, truncate the base64 to a reasonable size
+ if len(base64_string) > 50000:
+ return base64_string[:50000]
+ return base64_string
+
+
+def estimate_image_tokens(width: int, height: int) -> int:
+ """
+ Estimate the number of tokens an image will use based on Anthropic's formula.
+
+ Args:
+ width: Image width in pixels
+ height: Image height in pixels
+
+ Returns:
+ Estimated number of tokens
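+
+ Example: a 1024x1024 image is estimated at int((1024 * 1024) / 750) = 1398 tokens.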
+ """
+ return int((width * height) / 750)
diff --git a/.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py
new file mode 100644
index 00000000..b41d79b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py
@@ -0,0 +1,4 @@
+from .hatchet import HatchetOrchestrationProvider
+from .simple import SimpleOrchestrationProvider
+
+__all__ = ["HatchetOrchestrationProvider", "SimpleOrchestrationProvider"]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py b/.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py
new file mode 100644
index 00000000..941e2048
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py
@@ -0,0 +1,105 @@
+# FIXME: Once the Hatchet workflows are type annotated, remove the type: ignore comments
+import asyncio
+import logging
+from typing import Any, Callable, Optional
+
+from core.base import OrchestrationConfig, OrchestrationProvider, Workflow
+
+logger = logging.getLogger()
+
+
+class HatchetOrchestrationProvider(OrchestrationProvider):
+ def __init__(self, config: OrchestrationConfig):
+ super().__init__(config)
+ try:
+ from hatchet_sdk import ClientConfig, Hatchet
+ except ImportError:
+ raise ImportError(
+ "Hatchet SDK not installed. Please install it using `pip install hatchet-sdk`."
+ ) from None
+ root_logger = logging.getLogger()
+
+ self.orchestrator = Hatchet(
+ config=ClientConfig(
+ logger=root_logger,
+ ),
+ )
+ self.root_logger = root_logger
+ self.config: OrchestrationConfig = config
+ self.messages: dict[str, str] = {}
+ self.worker: Optional[Any] = None
+
+ def workflow(self, *args, **kwargs) -> Callable:
+ return self.orchestrator.workflow(*args, **kwargs)
+
+ def step(self, *args, **kwargs) -> Callable:
+ return self.orchestrator.step(*args, **kwargs)
+
+ def failure(self, *args, **kwargs) -> Callable:
+ return self.orchestrator.on_failure_step(*args, **kwargs)
+
+ def get_worker(self, name: str, max_runs: Optional[int] = None) -> Any:
+ if not max_runs:
+ max_runs = self.config.max_runs
+ self.worker = self.orchestrator.worker(name, max_runs) # type: ignore
+ return self.worker
+
+ def concurrency(self, *args, **kwargs) -> Callable:
+ return self.orchestrator.concurrency(*args, **kwargs)
+
+ async def start_worker(self):
+ if not self.worker:
+ raise ValueError(
+ "Worker not initialized. Call get_worker() first."
+ )
+
+ asyncio.create_task(self.worker.async_start())
+
+ async def run_workflow(
+ self,
+ workflow_name: str,
+ parameters: dict,
+ options: dict,
+ *args,
+ **kwargs,
+ ) -> Any:
+ task_id = self.orchestrator.admin.run_workflow( # type: ignore
+ workflow_name,
+ parameters,
+ options=options, # type: ignore
+ *args,
+ **kwargs,
+ )
+ return {
+ "task_id": str(task_id),
+ "message": self.messages.get(
+ workflow_name, "Workflow queued successfully."
+ ), # Return message based on workflow name
+ }
+
+ def register_workflows(
+ self, workflow: Workflow, service: Any, messages: dict
+ ) -> None:
+ self.messages.update(messages)
+
+ logger.info(
+ f"Registering workflows for {workflow} with messages {messages}."
+ )
+ if workflow == Workflow.INGESTION:
+ from core.main.orchestration.hatchet.ingestion_workflow import ( # type: ignore
+ hatchet_ingestion_factory,
+ )
+
+ workflows = hatchet_ingestion_factory(self, service)
+ if self.worker:
+ for workflow in workflows.values():
+ self.worker.register_workflow(workflow)
+
+ elif workflow == Workflow.GRAPH:
+ from core.main.orchestration.hatchet.graph_workflow import ( # type: ignore
+ hatchet_graph_search_results_factory,
+ )
+
+ workflows = hatchet_graph_search_results_factory(self, service)
+ if self.worker:
+ for workflow in workflows.values():
+ self.worker.register_workflow(workflow)
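+
+
+ # Illustrative wiring; the constructor arguments, service object, and message keys below are assumed:
+ #
+ # provider = HatchetOrchestrationProvider(OrchestrationConfig(provider="hatchet"))
+ # provider.get_worker("r2r-worker", max_runs=4)
+ # provider.register_workflows(
+ # Workflow.INGESTION, ingestion_service, {"ingest-files": "Ingestion task queued."}
+ # )
+ # await provider.start_worker()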
diff --git a/.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py b/.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py
new file mode 100644
index 00000000..33028afe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py
@@ -0,0 +1,61 @@
+from typing import Any
+
+from core.base import OrchestrationConfig, OrchestrationProvider, Workflow
+
+
+class SimpleOrchestrationProvider(OrchestrationProvider):
+ def __init__(self, config: OrchestrationConfig):
+ super().__init__(config)
+ self.config = config
+ self.messages: dict[str, str] = {}
+ self.ingestion_workflows: dict[str, Any] = {}
+ self.graph_search_results_workflows: dict[str, Any] = {}
+
+ async def start_worker(self):
+ pass
+
+ def get_worker(self, name: str, max_runs: int) -> Any:
+ pass
+
+ def step(self, *args, **kwargs) -> Any:
+ pass
+
+ def workflow(self, *args, **kwargs) -> Any:
+ pass
+
+ def failure(self, *args, **kwargs) -> Any:
+ pass
+
+ def register_workflows(
+ self, workflow: Workflow, service: Any, messages: dict
+ ) -> None:
+ for key, msg in messages.items():
+ self.messages[key] = msg
+
+ if workflow == Workflow.INGESTION:
+ from core.main.orchestration import simple_ingestion_factory
+
+ self.ingestion_workflows = simple_ingestion_factory(service)
+
+ elif workflow == Workflow.GRAPH:
+ from core.main.orchestration.simple.graph_workflow import (
+ simple_graph_search_results_factory,
+ )
+
+ self.graph_search_results_workflows = (
+ simple_graph_search_results_factory(service)
+ )
+
+ async def run_workflow(
+ self, workflow_name: str, parameters: dict, options: dict
+ ) -> dict[str, str]:
+ if workflow_name in self.ingestion_workflows:
+ await self.ingestion_workflows[workflow_name](
+ parameters.get("request")
+ )
+ return {"message": self.messages[workflow_name]}
+ elif workflow_name in self.graph_search_results_workflows:
+ await self.graph_search_results_workflows[workflow_name](
+ parameters.get("request")
+ )
+ return {"message": self.messages[workflow_name]}
+ else:
+ raise ValueError(f"Workflow '{workflow_name}' not found.")
diff --git a/.venv/lib/python3.12/site-packages/core/utils/__init__.py b/.venv/lib/python3.12/site-packages/core/utils/__init__.py
new file mode 100644
index 00000000..e04db4b9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/utils/__init__.py
@@ -0,0 +1,182 @@
+import re
+from typing import Set, Tuple
+
+from shared.utils.base_utils import (
+ SearchResultsCollector,
+ SSEFormatter,
+ convert_nonserializable_objects,
+ decrement_version,
+ deep_update,
+ dump_collector,
+ dump_obj,
+ format_search_results_for_llm,
+ generate_default_user_collection_id,
+ generate_document_id,
+ generate_extraction_id,
+ generate_id,
+ generate_user_id,
+ increment_version,
+ num_tokens,
+ num_tokens_from_messages,
+ update_settings_from_dict,
+ validate_uuid,
+ yield_sse_event,
+)
+from shared.utils.splitter.text import (
+ RecursiveCharacterTextSplitter,
+ TextSplitter,
+)
+
+
+def extract_citations(text: str) -> list[str]:
+ """
+ Extract citation IDs enclosed in brackets like [abc1234].
+ Returns a list of citation IDs.
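+
+ Example (illustrative):
+ >>> extract_citations("See [abc1234] and [zx9k8f77].")
+ ['abc1234', 'zx9k8f77']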
+ """
+ # Direct pattern to match IDs inside brackets with alphanumeric pattern
+ CITATION_PATTERN = re.compile(r"\[([A-Za-z0-9]{7,8})\]")
+
+ sids = []
+ for match in CITATION_PATTERN.finditer(text):
+ sid = match.group(1)
+ sids.append(sid)
+
+ return sids
+
+
+def extract_citation_spans(text: str) -> dict[str, list[Tuple[int, int]]]:
+ """
+ Extract citation IDs with their positions in the text.
+
+ Args:
+ text: The text to search for citations
+
+ Returns:
+ dictionary mapping citation IDs to lists of (start, end) position tuples
+ """
+ # Use the same pattern as the original extract_citations
+ CITATION_PATTERN = re.compile(r"\[([A-Za-z0-9]{7,8})\]")
+
+ citation_spans: dict = {}
+
+ for match in CITATION_PATTERN.finditer(text):
+ sid = match.group(1)
+ start = match.start()
+ end = match.end()
+
+ if sid not in citation_spans:
+ citation_spans[sid] = []
+
+ # Add the position span
+ citation_spans[sid].append((start, end))
+
+ return citation_spans
+
+
+class CitationTracker:
+ """
+ Tracks citation spans to ensure each one is only emitted once.
+ """
+
+ def __init__(self):
+ # Track which citation spans we've processed
+ # Format: {citation_id: {(start, end), (start, end), ...}}
+ self.processed_spans: dict[str, Set[Tuple[int, int]]] = {}
+
+ # Track which citation IDs we've seen
+ self.seen_citation_ids: Set[str] = set()
+
+ def is_new_citation(self, citation_id: str) -> bool:
+ """Check if this is the first occurrence of this citation ID."""
+ is_new = citation_id not in self.seen_citation_ids
+ if is_new:
+ self.seen_citation_ids.add(citation_id)
+ return is_new
+
+ def is_new_span(self, citation_id: str, span: Tuple[int, int]) -> bool:
+ """
+ Check if this span has already been processed for this citation ID.
+
+ Args:
+ citation_id: The citation ID
+ span: (start, end) position tuple
+
+ Returns:
+ True if this span hasn't been processed yet, False otherwise
+ """
+ # Initialize set for this citation ID if needed
+ if citation_id not in self.processed_spans:
+ self.processed_spans[citation_id] = set()
+
+ # Check if we've seen this span before
+ if span in self.processed_spans[citation_id]:
+ return False
+
+ # This is a new span, track it
+ self.processed_spans[citation_id].add(span)
+ return True
+
+ def get_all_spans(self) -> dict[str, list[Tuple[int, int]]]:
+ """Get all processed spans for final answer."""
+ return {
+ cid: list(spans) for cid, spans in self.processed_spans.items()
+ }
+
+
+def find_new_citation_spans(
+ text: str, tracker: CitationTracker
+) -> dict[str, list[Tuple[int, int]]]:
+ """
+ Extract citation spans that haven't been processed yet.
+
+ Args:
+ text: Text to search
+ tracker: The CitationTracker instance
+
+ Returns:
+ dictionary of citation IDs to lists of new (start, end) spans
+ """
+ # Get all citation spans in the text
+ all_spans = extract_citation_spans(text)
+
+ # Filter to only spans we haven't processed yet
+ new_spans: dict = {}
+
+ for cid, spans in all_spans.items():
+ for span in spans:
+ if tracker.is_new_span(cid, span):
+ if cid not in new_spans:
+ new_spans[cid] = []
+ new_spans[cid].append(span)
+
+ return new_spans
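+
+ # Illustrative streaming usage; the token stream and handler below are hypothetical:
+ #
+ # tracker = CitationTracker()
+ # buffer = ""
+ # for delta in llm_text_stream:
+ # buffer += delta
+ # for cid, spans in find_new_citation_spans(buffer, tracker).items():
+ # handle_citation(cid, spans)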
+
+
+__all__ = [
+ "format_search_results_for_llm",
+ "generate_id",
+ "generate_document_id",
+ "generate_extraction_id",
+ "generate_user_id",
+ "increment_version",
+ "decrement_version",
+ "generate_default_user_collection_id",
+ "validate_uuid",
+ "yield_sse_event",
+ "dump_collector",
+ "dump_obj",
+ "convert_nonserializable_objects",
+ "num_tokens",
+ "num_tokens_from_messages",
+ "SSEFormatter",
+ "SearchResultsCollector",
+ "update_settings_from_dict",
+ "deep_update",
+ # Text splitter
+ "RecursiveCharacterTextSplitter",
+ "TextSplitter",
+ "extract_citations",
+ "extract_citation_spans",
+ "CitationTracker",
+ "find_new_citation_spans",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/utils/logging_config.py b/.venv/lib/python3.12/site-packages/core/utils/logging_config.py
new file mode 100644
index 00000000..9b989c51
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/utils/logging_config.py
@@ -0,0 +1,164 @@
+import logging
+import logging.config
+import os
+import re
+import sys
+from pathlib import Path
+
+
+class HTTPStatusFilter(logging.Filter):
+ """This filter inspects uvicorn.access log records. It uses
+ record.getMessage() to retrieve the fully formatted log message, then
+ searches it for an HTTP status code and adjusts the record's log level
+ based on that status:
+
+ - 2xx: INFO
+ - 4xx: WARNING
+ - 5xx: ERROR
+
+ Other records are left unchanged.
+ """
+
+ # A broad pattern to find any 3-digit number in the message.
+ # This should capture the HTTP status code from a line like:
+ # '127.0.0.1:54946 - "GET /v2/relationships HTTP/1.1" 404'
+ STATUS_CODE_PATTERN = re.compile(r"\b(\d{3})\b")
+ HEALTH_ENDPOINT_PATTERN = re.compile(r'"GET /v3/health HTTP/\d\.\d"')
+
+ LEVEL_TO_ANSI = {
+ logging.INFO: "\033[32m", # green
+ logging.WARNING: "\033[33m", # yellow
+ logging.ERROR: "\033[31m", # red
+ }
+ RESET = "\033[0m"
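+
+ # Illustrative effect on an access record such as
+ # '127.0.0.1:54946 - "GET /v2/relationships HTTP/1.1" 404':
+ # the record is re-levelled to WARNING and the "404" is wrapped in yellow ANSI codes.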
+
+ def filter(self, record: logging.LogRecord) -> bool:
+ if record.name != "uvicorn.access":
+ return True
+
+ message = record.getMessage()
+
+ # Filter out health endpoint requests
+ # FIXME: This should be made configurable in the future
+ if self.HEALTH_ENDPOINT_PATTERN.search(message):
+ return False
+
+ if codes := self.STATUS_CODE_PATTERN.findall(message):
+ status_code = int(codes[-1])
+ if 200 <= status_code < 300:
+ record.levelno = logging.INFO
+ record.levelname = "INFO"
+ color = self.LEVEL_TO_ANSI[logging.INFO]
+ elif 400 <= status_code < 500:
+ record.levelno = logging.WARNING
+ record.levelname = "WARNING"
+ color = self.LEVEL_TO_ANSI[logging.WARNING]
+ elif 500 <= status_code < 600:
+ record.levelno = logging.ERROR
+ record.levelname = "ERROR"
+ color = self.LEVEL_TO_ANSI[logging.ERROR]
+ else:
+ return True
+
+ # Wrap the status code in ANSI codes
+ colored_code = f"{color}{status_code}{self.RESET}"
+ # Replace the status code in the message
+ new_msg = message.replace(str(status_code), colored_code)
+
+ # Update record.msg and clear args to avoid formatting issues
+ record.msg = new_msg
+ record.args = ()
+
+ return True
+
+
+log_level = os.environ.get("R2R_LOG_LEVEL", "INFO").upper()
+log_console_formatter = os.environ.get(
+ "R2R_LOG_CONSOLE_FORMATTER", "colored"
+).lower() # colored or json
+
+log_dir = Path.cwd() / "logs"
+log_dir.mkdir(exist_ok=True)
+log_file = log_dir / "app.log"
+
+log_config = {
+ "version": 1,
+ "disable_existing_loggers": False,
+ "filters": {
+ "http_status_filter": {
+ "()": HTTPStatusFilter,
+ }
+ },
+ "formatters": {
+ "default": {
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ "datefmt": "%Y-%m-%d %H:%M:%S",
+ },
+ "colored": {
+ "()": "colorlog.ColoredFormatter",
+ "format": "%(asctime)s - %(log_color)s%(levelname)s%(reset)s - %(message)s",
+ "datefmt": "%Y-%m-%d %H:%M:%S",
+ "log_colors": {
+ "DEBUG": "white",
+ "INFO": "green",
+ "WARNING": "yellow",
+ "ERROR": "red",
+ "CRITICAL": "bold_red",
+ },
+ },
+ "json": {
+ "()": "pythonjsonlogger.json.JsonFormatter",
+ "format": "%(name)s %(levelname)s %(message)s", # these become keys in the JSON log
+ "rename_fields": {
+ "asctime": "time",
+ "levelname": "level",
+ "name": "logger",
+ },
+ },
+ },
+ "handlers": {
+ "file": {
+ "class": "logging.handlers.RotatingFileHandler",
+ "formatter": "colored",
+ "filename": log_file,
+ "maxBytes": 10485760, # 10MB
+ "backupCount": 5,
+ "filters": ["http_status_filter"],
+ "level": log_level, # Set handler level based on the environment variable
+ },
+ "console": {
+ "class": "logging.StreamHandler",
+ "formatter": log_console_formatter,
+ "stream": sys.stdout,
+ "filters": ["http_status_filter"],
+ "level": log_level, # Set handler level based on the environment variable
+ },
+ },
+ "loggers": {
+ "": { # Root logger
+ "handlers": ["console", "file"],
+ "level": log_level, # Set logger level based on the environment variable
+ },
+ "uvicorn": {
+ "handlers": ["console", "file"],
+ "level": log_level,
+ "propagate": False,
+ },
+ "uvicorn.error": {
+ "handlers": ["console", "file"],
+ "level": log_level,
+ "propagate": False,
+ },
+ "uvicorn.access": {
+ "handlers": ["console", "file"],
+ "level": log_level,
+ "propagate": False,
+ },
+ },
+}
+
+
+def configure_logging() -> Path:
+ logging.config.dictConfig(log_config)
+
+ logging.info(f"Logging is configured at {log_level} level.")
+
+ return log_file
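+
+
+ # Illustrative startup wiring (assumed entrypoint). R2R_LOG_LEVEL and
+ # R2R_LOG_CONSOLE_FORMATTER are read at import time above, so set them in the
+ # environment before importing this module, then call configure_logging() once:
+ #
+ # from core.utils.logging_config import configure_logging
+ # log_path = configure_logging()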
diff --git a/.venv/lib/python3.12/site-packages/core/utils/sentry.py b/.venv/lib/python3.12/site-packages/core/utils/sentry.py
new file mode 100644
index 00000000..9a4c09a1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/utils/sentry.py
@@ -0,0 +1,22 @@
+import contextlib
+import os
+
+import sentry_sdk
+
+
+def init_sentry():
+ dsn = os.getenv("R2R_SENTRY_DSN")
+ if not dsn:
+ return
+
+ with contextlib.suppress(Exception):
+ sentry_sdk.init(
+ dsn=dsn,
+ environment=os.getenv("R2R_SENTRY_ENVIRONMENT", "not_set"),
+ traces_sample_rate=float(
+ os.getenv("R2R_SENTRY_TRACES_SAMPLE_RATE", 1.0)
+ ),
+ profiles_sample_rate=float(
+ os.getenv("R2R_SENTRY_PROFILES_SAMPLE_RATE", 1.0)
+ ),
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/utils/serper.py b/.venv/lib/python3.12/site-packages/core/utils/serper.py
new file mode 100644
index 00000000..8962565b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/utils/serper.py
@@ -0,0 +1,107 @@
+# TODO - relocate to a dedicated module
+import http.client
+import json
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+# TODO - Move process json to dedicated data processing module
+def process_json(json_object, indent=0):
+ """Recursively traverses the JSON object (dicts and lists) to create an
+ unstructured text blob."""
+ text_blob = ""
+ if isinstance(json_object, dict):
+ for key, value in json_object.items():
+ padding = " " * indent
+ if isinstance(value, (dict, list)):
+ text_blob += (
+ f"{padding}{key}:\n{process_json(value, indent + 1)}"
+ )
+ else:
+ text_blob += f"{padding}{key}: {value}\n"
+ elif isinstance(json_object, list):
+ for index, item in enumerate(json_object):
+ padding = " " * indent
+ if isinstance(item, (dict, list)):
+ text_blob += f"{padding}Item {index + 1}:\n{process_json(item, indent + 1)}"
+ else:
+ text_blob += f"{padding}Item {index + 1}: {item}\n"
+ return text_blob
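+
+ # Illustrative output for a small input:
+ # process_json({"title": "R2R", "links": ["a", "b"]}) returns
+ # "title: R2R\nlinks:\n Item 1: a\n Item 2: b\n"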
+
+
+# TODO - Introduce abstract "Integration" ABC.
+class SerperClient:
+ def __init__(self, api_base: str = "google.serper.dev") -> None:
+ api_key = os.getenv("SERPER_API_KEY")
+ if not api_key:
+ raise ValueError(
+ "Please set the `SERPER_API_KEY` environment variable to use `SerperClient`."
+ )
+
+ self.api_base = api_base
+ self.headers = {
+ "X-API-KEY": api_key,
+ "Content-Type": "application/json",
+ }
+
+ @staticmethod
+ def _extract_results(result_data: dict) -> list:
+ formatted_results = []
+
+ for key, value in result_data.items():
+ # Skip searchParameters as it's not a result entry
+ if key == "searchParameters":
+ continue
+
+ # Handle 'answerBox' as a single item
+ if key == "answerBox":
+ value["type"] = key # Add the type key to the dictionary
+ formatted_results.append(value)
+ # Handle lists of results
+ elif isinstance(value, list):
+ for item in value:
+ item["type"] = key # Add the type key to the dictionary
+ formatted_results.append(item)
+ # Handle 'peopleAlsoAsk' and potentially other single item formats
+ elif isinstance(value, dict):
+ value["type"] = key # Add the type key to the dictionary
+ formatted_results.append(value)
+
+ return formatted_results
+
+ # TODO - Add explicit typing for the return value
+ def get_raw(self, query: str, limit: int = 10) -> list:
+ connection = http.client.HTTPSConnection(self.api_base)
+ payload = json.dumps({"q": query, "num_outputs": limit})
+ connection.request("POST", "/search", payload, self.headers)
+ response = connection.getresponse()
+ logger.debug("Received response {response} from Serper API.")
+ data = response.read()
+ json_data = json.loads(data.decode("utf-8"))
+ return SerperClient._extract_results(json_data)
+
+ @staticmethod
+ def construct_context(results: list) -> str:
+ # Organize results by type
+ organized_results = {}
+ for result in results:
+ result_type = result.metadata.pop(
+ "type", "Unknown"
+ ) # Pop the type and use as key
+ if result_type not in organized_results:
+ organized_results[result_type] = [result.metadata]
+ else:
+ organized_results[result_type].append(result.metadata)
+
+ context = ""
+ # Iterate over each result type
+ for result_type, items in organized_results.items():
+ context += f"# {result_type} Results:\n"
+ for index, item in enumerate(items, start=1):
+ # Process each item under the current type
+ context += f"Item {index}:\n"
+ context += process_json(item) + "\n"
+
+ return context
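+
+
+ # Illustrative usage (assumes SERPER_API_KEY is set in the environment; result fields depend on the Serper response):
+ #
+ # client = SerperClient()
+ # results = client.get_raw("retrieval augmented generation", limit=5)
+ # for item in results:
+ # print(item.get("type"), item.get("title", ""))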