aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/base
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base')
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/__init__.py130
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/abstractions/__init__.py154
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/agent/__init__.py17
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/agent/agent.py291
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/agent/base.py22
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/api/models/__init__.py208
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/parsers/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/parsers/base_parser.py12
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/__init__.py59
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/auth.py231
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/base.py135
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/crypto.py120
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/database.py197
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/email.py96
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/embedding.py197
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py172
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/llm.py200
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py70
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/utils/__init__.py43
19 files changed, 2359 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/__init__.py b/.venv/lib/python3.12/site-packages/core/base/__init__.py
new file mode 100644
index 00000000..1e872799
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/__init__.py
@@ -0,0 +1,130 @@
+from .abstractions import *
+from .agent import *
+from .api.models import *
+from .parsers import *
+from .providers import *
+from .utils import *
+
+__all__ = [
+ "ThinkingEvent",
+ "ToolCallEvent",
+ "ToolResultEvent",
+ "CitationEvent",
+ "Citation",
+ ## ABSTRACTIONS
+ # Base abstractions
+ "AsyncSyncMeta",
+ "syncable",
+ # Completion abstractions
+ "MessageType",
+ # Document abstractions
+ "Document",
+ "DocumentChunk",
+ "DocumentResponse",
+ "IngestionStatus",
+ "GraphExtractionStatus",
+ "GraphConstructionStatus",
+ "DocumentType",
+ # Embedding abstractions
+ "EmbeddingPurpose",
+ "default_embedding_prefixes",
+ # Exception abstractions
+ "R2RDocumentProcessingError",
+ "R2RException",
+ # Graph abstractions
+ "Entity",
+ "GraphExtraction",
+ "Relationship",
+ "Community",
+ "GraphCreationSettings",
+ "GraphEnrichmentSettings",
+ # LLM abstractions
+ "GenerationConfig",
+ "LLMChatCompletion",
+ "LLMChatCompletionChunk",
+ "RAGCompletion",
+ # Prompt abstractions
+ "Prompt",
+ # Search abstractions
+ "AggregateSearchResult",
+ "WebSearchResult",
+ "GraphSearchResult",
+ "GraphSearchSettings",
+ "ChunkSearchSettings",
+ "ChunkSearchResult",
+ "WebPageSearchResult",
+ "SearchSettings",
+ "select_search_filters",
+ "SearchMode",
+ "HybridSearchSettings",
+ # User abstractions
+ "Token",
+ "TokenData",
+ # Vector abstractions
+ "Vector",
+ "VectorEntry",
+ "VectorType",
+ "StorageResult",
+ "IndexConfig",
+ ## AGENT
+ # Agent abstractions
+ "Agent",
+ "AgentConfig",
+ "Conversation",
+ "Message",
+ "Tool",
+ "ToolResult",
+ ## API
+ # Auth Responses
+ "TokenResponse",
+ "User",
+ ## PARSERS
+ # Base parser
+ "AsyncParser",
+ ## PROVIDERS
+ # Base provider classes
+ "AppConfig",
+ "Provider",
+ "ProviderConfig",
+ # Auth provider
+ "AuthConfig",
+ "AuthProvider",
+ # Crypto provider
+ "CryptoConfig",
+ "CryptoProvider",
+ # Email provider
+ "EmailConfig",
+ "EmailProvider",
+ # Database providers
+ "LimitSettings",
+ "DatabaseConfig",
+ "DatabaseProvider",
+ "Handler",
+ "PostgresConfigurationSettings",
+ # Embedding provider
+ "EmbeddingConfig",
+ "EmbeddingProvider",
+ # Ingestion provider
+ "IngestionMode",
+ "IngestionConfig",
+ "IngestionProvider",
+ "ChunkingStrategy",
+ # LLM provider
+ "CompletionConfig",
+ "CompletionProvider",
+ ## UTILS
+ "RecursiveCharacterTextSplitter",
+ "TextSplitter",
+ "format_search_results_for_llm",
+ "validate_uuid",
+ # ID generation
+ "generate_id",
+ "generate_document_id",
+ "generate_extraction_id",
+ "generate_default_user_collection_id",
+ "generate_user_id",
+ "increment_version",
+ "yield_sse_event",
+ "dump_collector",
+ "dump_obj",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/abstractions/__init__.py b/.venv/lib/python3.12/site-packages/core/base/abstractions/__init__.py
new file mode 100644
index 00000000..bb1363fe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/abstractions/__init__.py
@@ -0,0 +1,154 @@
+from shared.abstractions.base import AsyncSyncMeta, R2RSerializable, syncable
+from shared.abstractions.document import (
+ ChunkEnrichmentSettings,
+ Document,
+ DocumentChunk,
+ DocumentResponse,
+ DocumentType,
+ GraphConstructionStatus,
+ GraphExtractionStatus,
+ IngestionStatus,
+ RawChunk,
+ UnprocessedChunk,
+ UpdateChunk,
+)
+from shared.abstractions.embedding import (
+ EmbeddingPurpose,
+ default_embedding_prefixes,
+)
+from shared.abstractions.exception import (
+ R2RDocumentProcessingError,
+ R2RException,
+)
+from shared.abstractions.graph import (
+ Community,
+ Entity,
+ Graph,
+ GraphCommunitySettings,
+ GraphCreationSettings,
+ GraphEnrichmentSettings,
+ GraphExtraction,
+ Relationship,
+ StoreType,
+)
+from shared.abstractions.llm import (
+ GenerationConfig,
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+ Message,
+ MessageType,
+ RAGCompletion,
+)
+from shared.abstractions.prompt import Prompt
+from shared.abstractions.search import (
+ AggregateSearchResult,
+ ChunkSearchResult,
+ ChunkSearchSettings,
+ GraphCommunityResult,
+ GraphEntityResult,
+ GraphRelationshipResult,
+ GraphSearchResult,
+ GraphSearchResultType,
+ GraphSearchSettings,
+ HybridSearchSettings,
+ SearchMode,
+ SearchSettings,
+ WebPageSearchResult,
+ WebSearchResult,
+ select_search_filters,
+)
+from shared.abstractions.user import Token, TokenData, User
+from shared.abstractions.vector import (
+ IndexArgsHNSW,
+ IndexArgsIVFFlat,
+ IndexConfig,
+ IndexMeasure,
+ IndexMethod,
+ StorageResult,
+ Vector,
+ VectorEntry,
+ VectorQuantizationSettings,
+ VectorQuantizationType,
+ VectorTableName,
+ VectorType,
+)
+
+__all__ = [
+ # Base abstractions
+ "R2RSerializable",
+ "AsyncSyncMeta",
+ "syncable",
+ # Completion abstractions
+ "MessageType",
+ # Document abstractions
+ "Document",
+ "DocumentChunk",
+ "DocumentResponse",
+ "DocumentType",
+ "IngestionStatus",
+ "GraphExtractionStatus",
+ "GraphConstructionStatus",
+ "RawChunk",
+ "UnprocessedChunk",
+ "UpdateChunk",
+ # Embedding abstractions
+ "EmbeddingPurpose",
+ "default_embedding_prefixes",
+ # Exception abstractions
+ "R2RDocumentProcessingError",
+ "R2RException",
+ # Graph abstractions
+ "Entity",
+ "Graph",
+ "Community",
+ "StoreType",
+ "GraphExtraction",
+ "Relationship",
+ # Index abstractions
+ "IndexConfig",
+ # LLM abstractions
+ "GenerationConfig",
+ "LLMChatCompletion",
+ "LLMChatCompletionChunk",
+ "Message",
+ "RAGCompletion",
+ # Prompt abstractions
+ "Prompt",
+ # Search abstractions
+ "WebSearchResult",
+ "AggregateSearchResult",
+ "GraphSearchResult",
+ "GraphSearchResultType",
+ "GraphEntityResult",
+ "GraphRelationshipResult",
+ "GraphCommunityResult",
+ "GraphSearchSettings",
+ "ChunkSearchSettings",
+ "ChunkSearchResult",
+ "WebPageSearchResult",
+ "SearchSettings",
+ "select_search_filters",
+ "SearchMode",
+ "HybridSearchSettings",
+ # Graph abstractions
+ "GraphCreationSettings",
+ "GraphEnrichmentSettings",
+ "GraphCommunitySettings",
+ # User abstractions
+ "Token",
+ "TokenData",
+ "User",
+ # Vector abstractions
+ "Vector",
+ "VectorEntry",
+ "VectorType",
+ "IndexMeasure",
+ "IndexMethod",
+ "VectorTableName",
+ "IndexArgsHNSW",
+ "IndexArgsIVFFlat",
+ "VectorQuantizationSettings",
+ "VectorQuantizationType",
+ "StorageResult",
+ "ChunkEnrichmentSettings",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py b/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py
new file mode 100644
index 00000000..815b9ae7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py
@@ -0,0 +1,17 @@
+# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
+from .agent import ( # type: ignore
+ Agent,
+ AgentConfig,
+ Conversation,
+ Tool,
+ ToolResult,
+)
+
+__all__ = [
+ # Agent abstractions
+ "Agent",
+ "AgentConfig",
+ "Conversation",
+ "Tool",
+ "ToolResult",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/agent.py b/.venv/lib/python3.12/site-packages/core/base/agent/agent.py
new file mode 100644
index 00000000..6813dd21
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/agent/agent.py
@@ -0,0 +1,291 @@
+# type: ignore
+import asyncio
+import json
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from json import JSONDecodeError
+from typing import Any, AsyncGenerator, Optional, Type
+
+from pydantic import BaseModel
+
+from core.base.abstractions import (
+ GenerationConfig,
+ LLMChatCompletion,
+ Message,
+)
+from core.base.providers import CompletionProvider, DatabaseProvider
+
+from .base import Tool, ToolResult
+
+logger = logging.getLogger()
+
+
+class Conversation:
+    """Append-only message log whose reads and writes share one asyncio lock."""
+
+    def __init__(self):
+        # A single lock serialises appends and snapshot reads.
+        self._lock = asyncio.Lock()
+        self.messages: list[Message] = []
+
+    async def add_message(self, message):
+        """Append ``message`` to the log while holding the lock."""
+        async with self._lock:
+            self.messages.append(message)
+
+    async def get_messages(self) -> list[dict[str, Any]]:
+        """Return all stored messages as plain dicts.
+
+        Each dict is ``model_dump(exclude_none=True)`` of the message with
+        its ``"role"`` entry set to ``str(msg.role)``.
+        """
+        async with self._lock:
+            snapshot: list[dict[str, Any]] = []
+            for entry in self.messages:
+                payload = entry.model_dump(exclude_none=True)
+                payload["role"] = str(entry.role)
+                snapshot.append(payload)
+            return snapshot
+
+
+# TODO - Move agents to provider pattern
+class AgentConfig(BaseModel):
+    """Base agent settings: prompt names, streaming, tool use and loop cap."""
+
+    # Prompt names; the static one is looked up by Agent._setup.
+    rag_rag_agent_static_prompt: str = "static_rag_agent"
+    rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
+    stream: bool = False
+    include_tools: bool = True
+    max_iterations: int = 10
+
+    @classmethod
+    def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
+        """Build a config from loose keyword arguments.
+
+        Keys that are not declared model fields are dropped, and the literal
+        string ``"None"`` is converted to ``None``.
+        """
+        declared = cls.model_fields.keys()
+        accepted: dict[str, Any] = {}
+        for key, value in kwargs.items():
+            if key not in declared:
+                continue
+            accepted[key] = value if value != "None" else None
+        return cls(**accepted)  # type: ignore
+
+
+class Agent(ABC):
+    """Abstract base for agents that converse with an LLM and call tools.
+
+    Subclasses register their tools in ``_register_tools`` and implement
+    ``arun`` and ``process_llm_response``. This base class owns the
+    conversation log, the tool registry and the tool-invocation plumbing.
+    """
+
+    def __init__(
+        self,
+        llm_provider: CompletionProvider,
+        database_provider: DatabaseProvider,
+        config: AgentConfig,
+        rag_generation_config: GenerationConfig,
+    ):
+        self.llm_provider = llm_provider
+        self.database_provider: DatabaseProvider = database_provider
+        self.config = config
+        self.conversation = Conversation()
+        self._completed = False
+        self._tools: list[Tool] = []
+        # One {"name", "args"} entry per tool call that was actually run.
+        self.tool_calls: list[dict] = []
+        self.rag_generation_config = rag_generation_config
+        # Called last so subclasses can rely on every attribute above.
+        self._register_tools()
+
+    @abstractmethod
+    def _register_tools(self):
+        """Populate the tool registry; invoked at the end of ``__init__``."""
+        pass
+
+    async def _setup(
+        self, system_instruction: Optional[str] = None, *args, **kwargs
+    ):
+        """Seed the conversation with the system message.
+
+        A truthy ``system_instruction`` is used verbatim. Otherwise the
+        cached static prompt named in the config is fetched (with today's
+        date as input) and a note about the iteration limit is appended.
+        """
+        await self.conversation.add_message(
+            Message(
+                role="system",
+                content=system_instruction
+                or (
+                    await self.database_provider.prompts_handler.get_cached_prompt(
+                        self.config.rag_rag_agent_static_prompt,
+                        inputs={
+                            "date": str(datetime.now().strftime("%m/%d/%Y"))
+                        },
+                    )
+                    + f"\n Note,you only have {self.config.max_iterations} iterations or tool calls to reach a conclusion before your operation terminates."
+                ),
+            )
+        )
+
+    @property
+    def tools(self) -> list[Tool]:
+        """The tools currently registered on this agent."""
+        return self._tools
+
+    @tools.setter
+    def tools(self, tools: list[Tool]):
+        self._tools = tools
+
+    @abstractmethod
+    async def arun(
+        self,
+        system_instruction: Optional[str] = None,
+        messages: Optional[list[Message]] = None,
+        *args,
+        **kwargs,
+    ) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]:
+        """Run the agent; concrete subclasses define the loop."""
+        pass
+
+    @abstractmethod
+    async def process_llm_response(
+        self,
+        response: Any,
+        *args,
+        **kwargs,
+    ) -> None | AsyncGenerator[str, None]:
+        """Handle one LLM response; concrete subclasses define how."""
+        pass
+
+    async def execute_tool(self, tool_name: str, *args, **kwargs) -> str:
+        """Await the named tool's ``results_function`` and return its result.
+
+        Returns an error string instead of raising when no registered tool
+        has that name.
+        """
+        if tool := next((t for t in self.tools if t.name == tool_name), None):
+            return await tool.results_function(*args, **kwargs)
+        else:
+            return f"Error: Tool {tool_name} not found."
+
+    def get_generation_config(
+        self, last_message: dict, stream: bool = False
+    ) -> GenerationConfig:
+        """Derive the generation config for the next LLM call.
+
+        Tool definitions are left out when tools are disabled in the agent
+        config, or when the last message is a non-empty tool/function result
+        and the model name contains ``"ollama"``. Otherwise every registered
+        tool is advertised (``None`` if there are none).
+        """
+        # Named explicitly: the original condition relied on ``and`` binding
+        # tighter than ``or``.
+        tool_reply_for_ollama = (
+            last_message["role"] in ["tool", "function"]
+            and last_message["content"] != ""
+            and "ollama" in self.rag_generation_config.model
+        )
+        omit_tools = tool_reply_for_ollama or not self.config.include_tools
+
+        base_settings = self.rag_generation_config.model_dump(
+            exclude={"functions", "tools", "stream"}
+        )
+        if omit_tools:
+            return GenerationConfig(**base_settings, stream=stream)
+
+        return GenerationConfig(
+            **base_settings,
+            # FIXME: Use tools instead of functions
+            # TODO - Investigate why `tools` fails with OpenAI+LiteLLM
+            tools=(
+                [
+                    {
+                        "function": {
+                            "name": tool.name,
+                            "description": tool.description,
+                            "parameters": tool.parameters,
+                        },
+                        "type": "function",
+                        "name": tool.name,
+                    }
+                    for tool in self.tools
+                ]
+                if self.tools
+                else None
+            ),
+            stream=stream,
+        )
+
+    async def _fail_tool_call(
+        self,
+        function_name: str,
+        error_message: str,
+        tool_id: Optional[str],
+        save_messages: bool,
+    ) -> ToolResult:
+        """Log a tool-call failure, optionally record it, and wrap it."""
+        logger.error(error_message)
+        if save_messages:
+            await self.conversation.add_message(
+                Message(
+                    role="tool" if tool_id else "function",
+                    content=error_message,
+                    name=function_name,
+                    tool_call_id=tool_id,
+                )
+            )
+        return ToolResult(
+            raw_result=error_message,
+            llm_formatted_result=error_message,
+        )
+
+    async def handle_function_or_tool_call(
+        self,
+        function_name: str,
+        function_arguments: str,
+        tool_id: Optional[str] = None,
+        save_messages: bool = True,
+        *args,
+        **kwargs,
+    ) -> ToolResult:
+        """Run the tool the LLM asked for and wrap the outcome.
+
+        Args:
+            function_name: Name of the requested tool.
+            function_arguments: JSON text holding the tool's keyword
+                arguments.
+            tool_id: Tool-call id, if any. When set, the recorded message
+                uses role ``"tool"``; otherwise ``"function"``.
+            save_messages: Whether to append the outcome to the conversation.
+            *args, **kwargs: Forwarded to the tool; parsed arguments override
+                ``kwargs`` on key clashes.
+
+        Returns:
+            A ``ToolResult``. Failures are reported in the result text
+            rather than raised: an unknown tool, arguments that are not
+            valid JSON, or an exception while running the tool.
+        """
+        logger.debug(
+            f"Calling function: {function_name}, args: {function_arguments}, tool_id: {tool_id}"
+        )
+        tool = next((t for t in self.tools if t.name == function_name), None)
+        if not tool:
+            # Previously this path hit ``return tool_result`` with the name
+            # unbound and raised UnboundLocalError.
+            return await self._fail_tool_call(
+                function_name,
+                f"Error: Tool {function_name} not found.",
+                tool_id,
+                save_messages,
+            )
+
+        try:
+            function_args = json.loads(function_arguments)
+        except JSONDecodeError:
+            # Previously execution continued with ``function_args`` unbound
+            # and raised UnboundLocalError; stop here instead.
+            return await self._fail_tool_call(
+                function_name,
+                f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with `JSONDecodeError`.",
+                tool_id,
+                save_messages,
+            )
+
+        try:
+            # Merged inside the guard so a non-object JSON payload is
+            # reported as a tool failure instead of raising TypeError.
+            merged_kwargs = {**kwargs, **function_args}
+            raw_result = await tool.results_function(*args, **merged_kwargs)
+            llm_formatted_result = tool.llm_format_function(raw_result)
+        except Exception as e:
+            raw_result = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with an exception: {e}."
+            logger.error(raw_result)
+            llm_formatted_result = raw_result
+
+        tool_result = ToolResult(
+            raw_result=raw_result,
+            llm_formatted_result=llm_formatted_result,
+        )
+        if tool.stream_function:
+            tool_result.stream_result = tool.stream_function(raw_result)
+
+        if save_messages:
+            await self.conversation.add_message(
+                Message(
+                    role="tool" if tool_id else "function",
+                    content=str(tool_result.llm_formatted_result),
+                    name=function_name,
+                    tool_call_id=tool_id,
+                )
+            )
+            # HACK - to fix issues with claude thinking + tool use [https://github.com/anthropics/anthropic-cookbook/blob/main/extended_thinking/extended_thinking_with_tool_use.ipynb]
+            if self.rag_generation_config.extended_thinking:
+                await self.conversation.add_message(
+                    Message(
+                        role="user",
+                        content="Continue...",
+                    )
+                )
+
+        # Only calls that reached the tool are recorded here.
+        self.tool_calls.append(
+            {
+                "name": function_name,
+                "args": function_arguments,
+            }
+        )
+        return tool_result
+
+
+# TODO - Move agents to provider pattern
+class RAGAgentConfig(AgentConfig):
+    """Agent settings plus the default RAG and research tool names."""
+
+    rag_rag_agent_static_prompt: str = "static_rag_agent"
+    rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
+    stream: bool = False
+    include_tools: bool = True
+    max_iterations: int = 10
+    # tools: list[str] = []  # HACK - unused variable.
+
+    # Default RAG tools
+    rag_tools: list[str] = [
+        "search_file_descriptions",
+        "search_file_knowledge",
+        "get_file_content",
+    ]
+
+    # Default Research tools
+    research_tools: list[str] = [
+        "rag",
+        "reasoning",
+        # DISABLED by default
+        # NOTE(review): the two entries below are active list members
+        # despite the marker above - confirm which is intended.
+        "critique",
+        "python_executor",
+    ]
+
+    @classmethod
+    def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
+        """Build a config from loose keyword arguments.
+
+        Undeclared keys are dropped and the string ``"None"`` becomes
+        ``None``. A ``tools`` entry is then always set, taken from ``tools``
+        or else ``tool_names``.
+        """
+        declared = cls.model_fields.keys()
+        selected: dict[str, Any] = {}
+        for key, value in kwargs.items():
+            if key in declared:
+                selected[key] = value if value != "None" else None
+        # NOTE(review): ``tools`` is not a declared field on this class (its
+        # declaration is commented out), so whether the model keeps this
+        # value depends on its extra-field handling - confirm.
+        selected["tools"] = kwargs.get("tools", None) or kwargs.get(
+            "tool_names", None
+        )
+        return cls(**selected)  # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/base.py b/.venv/lib/python3.12/site-packages/core/base/agent/base.py
new file mode 100644
index 00000000..0d8f15ee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/agent/base.py
@@ -0,0 +1,22 @@
+from typing import Any, Callable, Optional
+
+from ..abstractions import R2RSerializable
+
+
+class Tool(R2RSerializable):
+    """Description of one tool an agent can offer to the LLM."""
+
+    name: str
+    description: str
+    # Awaited with the call's arguments to produce the raw result.
+    results_function: Callable
+    # Applied to the raw result to get the text recorded for the LLM.
+    llm_format_function: Callable
+    # Optional; when set, applied to the raw result to fill
+    # ToolResult.stream_result.
+    stream_function: Optional[Callable] = None
+    # Advertised to the LLM as the function's "parameters" entry.
+    parameters: Optional[dict[str, Any]] = None
+
+    class Config:
+        populate_by_name = True
+        # Needed because the fields hold bare callables.
+        arbitrary_types_allowed = True
+
+
+class ToolResult(R2RSerializable):
+    """Outcome of one tool invocation."""
+
+    # Whatever the tool returned, or an error string if the call failed.
+    raw_result: Any
+    # Text form of the result that is recorded in the conversation.
+    llm_formatted_result: str
+    # Set from Tool.stream_function when the tool defines one.
+    stream_result: Optional[str] = None
diff --git a/.venv/lib/python3.12/site-packages/core/base/api/models/__init__.py b/.venv/lib/python3.12/site-packages/core/base/api/models/__init__.py
new file mode 100644
index 00000000..dc0b041f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/api/models/__init__.py
@@ -0,0 +1,208 @@
+from shared.api.models.auth.responses import (
+ TokenResponse,
+ WrappedTokenResponse,
+)
+from shared.api.models.base import (
+ GenericBooleanResponse,
+ GenericMessageResponse,
+ PaginatedR2RResult,
+ R2RResults,
+ WrappedBooleanResponse,
+ WrappedGenericMessageResponse,
+)
+from shared.api.models.graph.responses import ( # TODO: Need to review anything above this
+ Community,
+ Entity,
+ GraphResponse,
+ Relationship,
+ WrappedCommunitiesResponse,
+ WrappedCommunityResponse,
+ WrappedEntitiesResponse,
+ WrappedEntityResponse,
+ WrappedGraphResponse,
+ WrappedGraphsResponse,
+ WrappedRelationshipResponse,
+ WrappedRelationshipsResponse,
+)
+from shared.api.models.ingestion.responses import (
+ IngestionResponse,
+ UpdateResponse,
+ VectorIndexResponse,
+ VectorIndicesResponse,
+ WrappedIngestionResponse,
+ WrappedMetadataUpdateResponse,
+ WrappedUpdateResponse,
+ WrappedVectorIndexResponse,
+ WrappedVectorIndicesResponse,
+)
+from shared.api.models.management.responses import ( # Document Responses; Prompt Responses; Chunk Responses; Conversation Responses; User Responses; TODO: anything below this hasn't been reviewed
+ ChunkResponse,
+ CollectionResponse,
+ ConversationResponse,
+ MessageResponse,
+ PromptResponse,
+ ServerStats,
+ SettingsResponse,
+ User,
+ WrappedAPIKeyResponse,
+ WrappedAPIKeysResponse,
+ WrappedChunkResponse,
+ WrappedChunksResponse,
+ WrappedCollectionResponse,
+ WrappedCollectionsResponse,
+ WrappedConversationMessagesResponse,
+ WrappedConversationResponse,
+ WrappedConversationsResponse,
+ WrappedDocumentResponse,
+ WrappedDocumentsResponse,
+ WrappedLimitsResponse,
+ WrappedLoginResponse,
+ WrappedMessageResponse,
+ WrappedMessagesResponse,
+ WrappedPromptResponse,
+ WrappedPromptsResponse,
+ WrappedServerStatsResponse,
+ WrappedSettingsResponse,
+ WrappedUserResponse,
+ WrappedUsersResponse,
+)
+from shared.api.models.retrieval.responses import (
+ AgentEvent,
+ AgentResponse,
+ Citation,
+ CitationData,
+ CitationEvent,
+ Delta,
+ DeltaPayload,
+ FinalAnswerData,
+ FinalAnswerEvent,
+ MessageData,
+ MessageDelta,
+ MessageEvent,
+ RAGEvent,
+ RAGResponse,
+ SearchResultsData,
+ SearchResultsEvent,
+ SSEEventBase,
+ ThinkingData,
+ ThinkingEvent,
+ ToolCallData,
+ ToolCallEvent,
+ ToolResultData,
+ ToolResultEvent,
+ UnknownEvent,
+ WrappedAgentResponse,
+ WrappedCompletionResponse,
+ WrappedDocumentSearchResponse,
+ WrappedEmbeddingResponse,
+ WrappedLLMChatCompletion,
+ WrappedRAGResponse,
+ WrappedSearchResponse,
+ WrappedVectorSearchResponse,
+)
+
+# Public names re-exported by this package. Each name appears exactly once;
+# "WrappedGenericMessageResponse" is listed only under Base Responses, which
+# matches the module it is imported from.
+__all__ = [
+    # Auth Responses
+    "TokenResponse",
+    "WrappedTokenResponse",
+    # Ingestion Responses
+    "IngestionResponse",
+    "WrappedIngestionResponse",
+    "WrappedUpdateResponse",
+    "WrappedMetadataUpdateResponse",
+    "WrappedVectorIndexResponse",
+    "WrappedVectorIndicesResponse",
+    "UpdateResponse",
+    "VectorIndexResponse",
+    "VectorIndicesResponse",
+    # Knowledge Graph Responses
+    "Entity",
+    "Relationship",
+    "Community",
+    "WrappedEntityResponse",
+    "WrappedEntitiesResponse",
+    "WrappedRelationshipResponse",
+    "WrappedRelationshipsResponse",
+    "WrappedCommunityResponse",
+    "WrappedCommunitiesResponse",
+    # TODO: Need to review anything above this
+    "GraphResponse",
+    "WrappedGraphResponse",
+    "WrappedGraphsResponse",
+    # Management Responses
+    "PromptResponse",
+    "ServerStats",
+    "SettingsResponse",
+    "ChunkResponse",
+    "CollectionResponse",
+    "WrappedServerStatsResponse",
+    "WrappedSettingsResponse",
+    "WrappedDocumentResponse",
+    "WrappedDocumentsResponse",
+    "WrappedCollectionResponse",
+    "WrappedCollectionsResponse",
+    # Conversation Responses
+    "ConversationResponse",
+    "WrappedConversationMessagesResponse",
+    "WrappedConversationResponse",
+    "WrappedConversationsResponse",
+    # Prompt Responses
+    "WrappedPromptResponse",
+    "WrappedPromptsResponse",
+    # Message Responses
+    "MessageResponse",
+    "WrappedMessageResponse",
+    "WrappedMessagesResponse",
+    # Chunk Responses
+    "WrappedChunkResponse",
+    "WrappedChunksResponse",
+    # User Responses
+    "User",
+    "WrappedUserResponse",
+    "WrappedUsersResponse",
+    "WrappedAPIKeyResponse",
+    "WrappedLimitsResponse",
+    "WrappedAPIKeysResponse",
+    "WrappedLoginResponse",
+    # Base Responses
+    "PaginatedR2RResult",
+    "R2RResults",
+    "GenericBooleanResponse",
+    "GenericMessageResponse",
+    "WrappedBooleanResponse",
+    "WrappedGenericMessageResponse",
+    # Retrieval Responses
+    "SSEEventBase",
+    "SearchResultsData",
+    "SearchResultsEvent",
+    "MessageDelta",
+    "MessageData",
+    "MessageEvent",
+    "DeltaPayload",
+    "Delta",
+    "CitationData",
+    "CitationEvent",
+    "FinalAnswerData",
+    "FinalAnswerEvent",
+    "ToolCallData",
+    "ToolCallEvent",
+    "ToolResultData",
+    "ToolResultEvent",
+    "ThinkingData",
+    "ThinkingEvent",
+    "RAGEvent",
+    "AgentEvent",
+    "UnknownEvent",
+    "RAGResponse",
+    "Citation",
+    "AgentResponse",
+    "WrappedDocumentSearchResponse",
+    "WrappedSearchResponse",
+    "WrappedVectorSearchResponse",
+    "WrappedCompletionResponse",
+    "WrappedRAGResponse",
+    "WrappedAgentResponse",
+    "WrappedLLMChatCompletion",
+    "WrappedEmbeddingResponse",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/parsers/__init__.py b/.venv/lib/python3.12/site-packages/core/base/parsers/__init__.py
new file mode 100644
index 00000000..d7696202
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/parsers/__init__.py
@@ -0,0 +1,5 @@
+from .base_parser import AsyncParser
+
+__all__ = [
+ "AsyncParser",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/parsers/base_parser.py b/.venv/lib/python3.12/site-packages/core/base/parsers/base_parser.py
new file mode 100644
index 00000000..fb40d767
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/parsers/base_parser.py
@@ -0,0 +1,12 @@
+"""Abstract base class for parsers."""
+
+from abc import ABC, abstractmethod
+from typing import AsyncGenerator, Generic, TypeVar
+
+T = TypeVar("T")
+
+
+class AsyncParser(ABC, Generic[T]):
+    """Interface for parsers that ingest input of type ``T`` asynchronously."""
+
+    @abstractmethod
+    async def ingest(self, data: T, **kwargs) -> AsyncGenerator[str, None]:
+        """Process ``data``; concrete parsers must override this.
+
+        The annotation promises an async generator of strings. This base
+        body produces nothing.
+        """
+        pass
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py b/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py
new file mode 100644
index 00000000..b902944d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py
@@ -0,0 +1,59 @@
+from .auth import AuthConfig, AuthProvider
+from .base import AppConfig, Provider, ProviderConfig
+from .crypto import CryptoConfig, CryptoProvider
+from .database import (
+ DatabaseConfig,
+ DatabaseConnectionManager,
+ DatabaseProvider,
+ Handler,
+ LimitSettings,
+ PostgresConfigurationSettings,
+)
+from .email import EmailConfig, EmailProvider
+from .embedding import EmbeddingConfig, EmbeddingProvider
+from .ingestion import (
+ ChunkingStrategy,
+ IngestionConfig,
+ IngestionMode,
+ IngestionProvider,
+)
+from .llm import CompletionConfig, CompletionProvider
+from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow
+
+__all__ = [
+ # Auth provider
+ "AuthConfig",
+ "AuthProvider",
+ # Base provider classes
+ "AppConfig",
+ "Provider",
+ "ProviderConfig",
+ # Ingestion provider
+ "IngestionMode",
+ "IngestionConfig",
+ "IngestionProvider",
+ "ChunkingStrategy",
+ # Crypto provider
+ "CryptoConfig",
+ "CryptoProvider",
+ # Email provider
+ "EmailConfig",
+ "EmailProvider",
+ # Database providers
+ "DatabaseConnectionManager",
+ "DatabaseConfig",
+ "LimitSettings",
+ "PostgresConfigurationSettings",
+ "DatabaseProvider",
+ "Handler",
+ # Embedding provider
+ "EmbeddingConfig",
+ "EmbeddingProvider",
+ # LLM provider
+ "CompletionConfig",
+ "CompletionProvider",
+ # Orchestration provider
+ "OrchestrationConfig",
+ "OrchestrationProvider",
+ "Workflow",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/auth.py b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
new file mode 100644
index 00000000..352c3331
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
@@ -0,0 +1,231 @@
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import TYPE_CHECKING, Optional
+
+from fastapi import Security
+from fastapi.security import (
+ APIKeyHeader,
+ HTTPAuthorizationCredentials,
+ HTTPBearer,
+)
+
+from ..abstractions import R2RException, Token, TokenData
+from ..api.models import User
+from .base import Provider, ProviderConfig
+from .crypto import CryptoProvider
+from .email import EmailProvider
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+ from core.providers.database import PostgresDatabaseProvider
+
+api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
+
+
+class AuthConfig(ProviderConfig):
+    """Settings consumed by authentication providers."""
+
+    secret_key: Optional[str] = None
+    # When False, requests without credentials are served as the default
+    # admin user (see AuthProvider.auth_wrapper).
+    require_authentication: bool = False
+    require_email_verification: bool = False
+    # Default admin credentials; AuthProvider copies them at construction.
+    default_admin_email: str = "admin@example.com"
+    default_admin_password: str = "change_me_immediately"
+    access_token_lifetime_in_minutes: Optional[int] = None
+    refresh_token_lifetime_in_days: Optional[int] = None
+
+    @property
+    def supported_providers(self) -> list[str]:
+        """Provider names accepted by this config; only ``"r2r"``."""
+        return ["r2r"]
+
+    def validate_config(self) -> None:
+        """Perform no checks; auth settings are accepted as given."""
+        pass
+
+
+class AuthProvider(Provider, ABC):
+    """Abstract authentication provider.
+
+    Concrete subclasses implement token handling, registration, login and
+    password management. This base class supplies ``auth_wrapper``, which
+    builds the FastAPI dependency that resolves a request's credentials
+    (Bearer token or ``X-API-Key`` header) to a ``User``.
+    """
+
+    security = HTTPBearer(auto_error=False)
+    crypto_provider: CryptoProvider
+    email_provider: EmailProvider
+    database_provider: "PostgresDatabaseProvider"
+
+    def __init__(
+        self,
+        config: AuthConfig,
+        crypto_provider: CryptoProvider,
+        database_provider: "PostgresDatabaseProvider",
+        email_provider: EmailProvider,
+    ):
+        """Store collaborators and the default admin credentials.
+
+        Raises:
+            ValueError: If ``config`` is not an ``AuthConfig``.
+        """
+        if not isinstance(config, AuthConfig):
+            raise ValueError(
+                "AuthProvider must be initialized with an AuthConfig"
+            )
+        self.config = config
+        self.admin_email = config.default_admin_email
+        self.admin_password = config.default_admin_password
+        self.crypto_provider = crypto_provider
+        self.database_provider = database_provider
+        self.email_provider = email_provider
+        super().__init__(config)
+        # NOTE(review): re-assigned after the base initialiser, presumably
+        # to pin the narrower types; Provider.__init__ is not visible here.
+        self.config: AuthConfig = config
+        self.database_provider: "PostgresDatabaseProvider" = database_provider
+
+    async def _get_default_admin_user(self) -> User:
+        """Look up the user whose email is the configured admin email."""
+        return await self.database_provider.users_handler.get_user_by_email(
+            self.admin_email
+        )
+
+    @abstractmethod
+    def create_access_token(self, data: dict) -> str:
+        pass
+
+    @abstractmethod
+    def create_refresh_token(self, data: dict) -> str:
+        pass
+
+    @abstractmethod
+    async def decode_token(self, token: str) -> TokenData:
+        pass
+
+    @abstractmethod
+    async def user(self, token: str) -> User:
+        pass
+
+    @abstractmethod
+    def get_current_active_user(self, current_user: User) -> User:
+        pass
+
+    @abstractmethod
+    async def register(self, email: str, password: str) -> User:
+        pass
+
+    @abstractmethod
+    async def send_verification_email(
+        self, email: str, user: Optional[User] = None
+    ) -> tuple[str, datetime]:
+        pass
+
+    @abstractmethod
+    async def verify_email(
+        self, email: str, verification_code: str
+    ) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def login(self, email: str, password: str) -> dict[str, Token]:
+        pass
+
+    @abstractmethod
+    async def refresh_access_token(
+        self, refresh_token: str
+    ) -> dict[str, Token]:
+        pass
+
+    def auth_wrapper(
+        self,
+        public: bool = False,
+    ):
+        """Build the FastAPI dependency that authenticates a request.
+
+        Args:
+            public: When True, a request with no credentials is served as
+                the default admin user even if authentication is required.
+
+        Returns:
+            An async callable that resolves to a ``User``. It raises
+            ``R2RException`` with status 401 when credentials are missing
+            or invalid, and 400 when both a Bearer token and an API key
+            are supplied.
+        """
+
+        async def _user_from_api_key(candidate: str) -> Optional[User]:
+            """Resolve a ``key_id.raw_api_key`` string to an active user.
+
+            Returns None if the format, key record, hash check or user
+            lookup fails. Shared by the Bearer and X-API-Key paths.
+            """
+            if "." not in candidate:
+                return None
+            key_id, raw_api_key = candidate.split(".", 1)
+            api_key_record = await self.database_provider.users_handler.get_api_key_record(
+                key_id
+            )
+            if api_key_record is None:
+                return None
+            hashed_key = api_key_record["hashed_key"]
+            if not self.crypto_provider.verify_api_key(
+                raw_api_key, hashed_key
+            ):
+                return None
+            user = await self.database_provider.users_handler.get_user_by_id(
+                api_key_record["user_id"]
+            )
+            if user is not None and user.is_active:
+                return user
+            return None
+
+        async def _auth_wrapper(
+            auth: Optional[HTTPAuthorizationCredentials] = Security(
+                self.security
+            ),
+            api_key: Optional[str] = Security(api_key_header),
+        ) -> User:
+            # If authentication is not required and no credentials are provided, return the default admin user
+            if (
+                ((not self.config.require_authentication) or public)
+                and auth is None
+                and api_key is None
+            ):
+                return await self._get_default_admin_user()
+            if not auth and not api_key:
+                raise R2RException(
+                    message="No credentials provided. Create an account at https://app.sciphi.ai and set your API key using `r2r configure key` OR change your base URL to a custom deployment.",
+                    status_code=401,
+                )
+            if auth and api_key:
+                raise R2RException(
+                    message="Cannot have both Bearer token and API key",
+                    status_code=400,
+                )
+            # 1. Try JWT if `auth` is present (Bearer token)
+            if auth is not None:
+                credentials = auth.credentials
+                try:
+                    token_data = await self.decode_token(credentials)
+                    user = await self.database_provider.users_handler.get_user_by_email(
+                        token_data.email
+                    )
+                    if user is not None:
+                        return user
+                except R2RException:
+                    # JWT decoding failed for logical reasons (invalid token)
+                    pass
+                except Exception as e:
+                    # JWT decoding failed unexpectedly, log and continue
+                    logger.debug("JWT verification failed: %s", e)
+
+                # 2. If JWT failed, try API key from Bearer token
+                # Expected format: key_id.raw_api_key
+                user = await _user_from_api_key(credentials)
+                if user is not None:
+                    return user
+
+            # 3. If no Bearer token worked, try the X-API-Key header
+            if api_key is not None:
+                user = await _user_from_api_key(api_key)
+                if user is not None:
+                    return user
+
+            # If we reach here, both JWT and API key auth failed
+            raise R2RException(
+                message="Invalid token or API key",
+                status_code=401,
+            )
+
+        return _auth_wrapper
+
+    @abstractmethod
+    async def change_password(
+        self, user: User, current_password: str, new_password: str
+    ) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def request_password_reset(self, email: str) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def confirm_password_reset(
+        self, reset_token: str, new_password: str
+    ) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def logout(self, token: str) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def send_reset_email(self, email: str) -> dict[str, str]:
+        pass
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/base.py b/.venv/lib/python3.12/site-packages/core/base/providers/base.py
new file mode 100644
index 00000000..3f72a5ea
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/base.py
@@ -0,0 +1,135 @@
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Type
+
+from pydantic import BaseModel
+
+
class InnerConfig(BaseModel, ABC):
    """Base model for nested configuration sections.

    Keyword arguments that do not correspond to a declared field are not
    rejected by :meth:`create`; they are collected in ``extra_fields``.
    """

    # Catch-all for keyword arguments that match no declared field.
    extra_fields: dict[str, Any] = {}

    class Config:
        populate_by_name = True
        arbitrary_types_allowed = True
        ignore_extra = True

    @classmethod
    def create(cls: Type["InnerConfig"], **kwargs: Any) -> "InnerConfig":
        """Build an instance from loosely-typed keyword arguments.

        Args:
            **kwargs: Candidate field values. The literal string
                ``"None"`` is translated to ``None`` for declared fields.

        Returns:
            The new instance, with every undeclared keyword stored
            untouched in ``extra_fields``.
        """
        declared = cls.model_fields.keys()
        known: dict[str, Any] = {}
        unknown: dict[str, Any] = {}
        for name, value in kwargs.items():
            if name in declared:
                # Config sources may spell a null value as the text "None".
                known[name] = value if value != "None" else None
            else:
                unknown[name] = value

        instance = cls(**known)  # type: ignore
        instance.extra_fields.update(unknown)
        return instance
+
+
class AppConfig(InnerConfig):
    """Application-level settings: per-user quotas, upload limits and the
    names of the models assigned to each role."""

    project_name: Optional[str] = None

    # Default per-user quotas.
    default_max_documents_per_user: Optional[int] = 100
    default_max_chunks_per_user: Optional[int] = 10_000
    default_max_collections_per_user: Optional[int] = 5

    # Fallback upload limit (roughly 2 MB).
    default_max_upload_size: int = 2_000_000

    # Model identifiers per role; all unset by default.
    quality_llm: Optional[str] = None
    fast_llm: Optional[str] = None
    vlm: Optional[str] = None
    audio_lm: Optional[str] = None
    reasoning_llm: Optional[str] = None
    planning_llm: Optional[str] = None

    # Maximum upload size keyed by file extension (no leading dot).
    # The values are examples and are meant to be tuned per deployment.
    max_upload_size_by_type: dict[str, int] = dict(
        # Common text-based formats
        txt=2_000_000,
        md=2_000_000,
        tsv=2_000_000,
        csv=5_000_000,
        xml=2_000_000,
        html=5_000_000,
        # Office docs
        doc=10_000_000,
        docx=10_000_000,
        ppt=20_000_000,
        pptx=20_000_000,
        xls=10_000_000,
        xlsx=10_000_000,
        odt=5_000_000,
        # PDFs can expand quite a bit when converted to text
        pdf=30_000_000,
        # E-mail
        eml=5_000_000,
        msg=5_000_000,
        p7s=5_000_000,
        # Images
        bmp=5_000_000,
        heic=5_000_000,
        jpeg=5_000_000,
        jpg=5_000_000,
        png=5_000_000,
        tiff=5_000_000,
        # Others
        epub=10_000_000,
        rtf=5_000_000,
        rst=5_000_000,
        org=5_000_000,
    )
+
+
class ProviderConfig(BaseModel, ABC):
    """Abstract base for every provider configuration.

    Subclasses must implement :meth:`validate_config` and the
    ``supported_providers`` property.
    """

    app: AppConfig  # application-level settings attached to this config
    extra_fields: dict[str, Any] = {}  # keywords that match no field
    provider: Optional[str] = None

    class Config:
        populate_by_name = True
        arbitrary_types_allowed = True
        ignore_extra = True

    @abstractmethod
    def validate_config(self) -> None:
        """Check the configuration; implementations raise when invalid."""

    @classmethod
    def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
        """Build an instance from loosely-typed keyword arguments.

        Declared fields are passed to the constructor, with the literal
        string ``"None"`` translated to ``None``. Every other keyword is
        stored untouched in ``extra_fields``.
        """
        declared = cls.model_fields.keys()
        known: dict[str, Any] = {}
        unknown: dict[str, Any] = {}
        for name, value in kwargs.items():
            if name in declared:
                known[name] = value if value != "None" else None
            else:
                unknown[name] = value

        instance = cls(**known)  # type: ignore
        instance.extra_fields.update(unknown)
        return instance

    @property
    @abstractmethod
    def supported_providers(self) -> list[str]:
        """Define a list of supported providers."""

    @classmethod
    def from_dict(
        cls: Type["ProviderConfig"], data: dict[str, Any]
    ) -> "ProviderConfig":
        """Create a new instance of the config from a dictionary."""
        instance = cls.create(**data)
        return instance
+
+
class Provider(ABC):
    """Common base class shared by all providers."""

    def __init__(self, config: ProviderConfig, *args, **kwargs):
        """Validate ``config`` (when truthy) and keep it on the instance.

        Extra positional and keyword arguments are accepted and ignored.
        """
        should_validate = bool(config)
        if should_validate:
            # A falsy config (e.g. None) is stored without validation.
            config.validate_config()
        self.config = config
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py b/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py
new file mode 100644
index 00000000..bdf794b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py
@@ -0,0 +1,120 @@
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import Optional, Tuple
+
+from .base import Provider, ProviderConfig
+
+
class CryptoConfig(ProviderConfig):
    """Configuration for a crypto provider."""

    provider: Optional[str] = None

    def validate_config(self) -> None:
        """Raise ``ValueError`` unless ``provider`` is a supported name."""
        if self.provider in self.supported_providers:
            return
        raise ValueError(f"Unsupported crypto provider: {self.provider}")

    @property
    def supported_providers(self) -> list[str]:
        """Names accepted for ``provider``."""
        return ["bcrypt", "nacl"]
+
+
class CryptoProvider(Provider, ABC):
    """Abstract interface for password hashing, API keys, request signing
    and signed tokens.

    The base class implements nothing beyond checking the config type.
    """

    def __init__(self, config: CryptoConfig):
        """Store ``config`` after checking it is a ``CryptoConfig``.

        Raises:
            ValueError: if ``config`` is not a ``CryptoConfig``.
        """
        if isinstance(config, CryptoConfig):
            super().__init__(config)
            return
        raise ValueError(
            "CryptoProvider must be initialized with a CryptoConfig"
        )

    # -- passwords --------------------------------------------------------

    @abstractmethod
    def get_password_hash(self, password: str) -> str:
        """Return a hash of ``password`` produced by a secure password
        hashing algorithm (e.g., Argon2i)."""

    @abstractmethod
    def verify_password(
        self, plain_password: str, hashed_password: str
    ) -> bool:
        """Tell whether ``plain_password`` matches ``hashed_password``."""

    @abstractmethod
    def generate_verification_code(self, length: int = 32) -> str:
        """Return a random code for e-mail verification or reset tokens."""

    # -- request signing --------------------------------------------------

    @abstractmethod
    def generate_signing_keypair(self) -> Tuple[str, str, str]:
        """Create a fresh Ed25519 keypair used for signing requests.

        Returns:
            ``(key_id, private_key, public_key)`` where ``key_id``
            uniquely identifies the pair and both keys are Base64
            encoded Ed25519 keys.
        """

    @abstractmethod
    def sign_request(self, private_key: str, data: str) -> str:
        """Return the signature of ``data`` made with the Ed25519
        ``private_key``."""

    @abstractmethod
    def verify_request_signature(
        self, public_key: str, signature: str, data: str
    ) -> bool:
        """Check ``signature`` over ``data`` against the matching Ed25519
        ``public_key``."""

    # -- API keys ---------------------------------------------------------

    @abstractmethod
    def generate_api_key(self) -> Tuple[str, str]:
        """Create a new API key for a user.

        Returns:
            ``(key_id, raw_api_key)``: the key's unique identifier and the
            plaintext key that is handed to the user.
        """

    @abstractmethod
    def hash_api_key(self, raw_api_key: str) -> str:
        """Return a hash of ``raw_api_key`` fit for database storage.

        Implementations should use strong parameters suitable for
        long-term secrets.
        """

    @abstractmethod
    def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
        """Tell whether ``raw_api_key`` matches the stored ``hashed_key``."""

    # -- signed tokens ----------------------------------------------------

    @abstractmethod
    def generate_secure_token(self, data: dict, expiry: datetime) -> str:
        """Create a signed token (e.g., a JWT) that embeds claims.

        Args:
            data: Claims to place in the token.
            expiry: Moment at which the token stops being valid.

        Returns:
            The token string, signed with a secret key.
        """

    @abstractmethod
    def verify_secure_token(self, token: str) -> Optional[dict]:
        """Validate a signed token (e.g., a JWT).

        Args:
            token: Token string to check.

        Returns:
            The payload when the token is valid, ``None`` otherwise.
        """
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/database.py b/.venv/lib/python3.12/site-packages/core/base/providers/database.py
new file mode 100644
index 00000000..845a8109
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/database.py
@@ -0,0 +1,197 @@
+"""Base classes for database providers."""
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Sequence, cast
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from core.base.abstractions import (
+ GraphCreationSettings,
+ GraphEnrichmentSettings,
+ GraphSearchSettings,
+)
+
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class DatabaseConnectionManager(ABC):
+ @abstractmethod
+ def execute_query(
+ self,
+ query: str,
+ params: Optional[dict[str, Any] | Sequence[Any]] = None,
+ isolation_level: Optional[str] = None,
+ ):
+ pass
+
+ @abstractmethod
+ async def execute_many(self, query, params=None, batch_size=1000):
+ pass
+
+ @abstractmethod
+ def fetch_query(
+ self,
+ query: str,
+ params: Optional[dict[str, Any] | Sequence[Any]] = None,
+ ):
+ pass
+
+ @abstractmethod
+ def fetchrow_query(
+ self,
+ query: str,
+ params: Optional[dict[str, Any] | Sequence[Any]] = None,
+ ):
+ pass
+
+ @abstractmethod
+ async def initialize(self, pool: Any):
+ pass
+
+
class Handler(ABC):
    """Base class for objects that manage tables of one project through a
    shared connection manager."""

    def __init__(
        self,
        project_name: str,
        connection_manager: DatabaseConnectionManager,
    ):
        """Remember the project name and the connection manager."""
        self.connection_manager = connection_manager
        self.project_name = project_name

    def _get_table_name(self, base_name: str) -> str:
        """Return ``base_name`` qualified as ``<project_name>.<base_name>``."""
        return "{}.{}".format(self.project_name, base_name)

    @abstractmethod
    def create_tables(self):
        """Create the tables this handler is responsible for."""
+
+
class PostgresConfigurationSettings(BaseModel):
    """Postgres tuning parameters, defaulting to the values defined by the
    PGVector docker image.

    The settings help manage connections to the database. For values
    tuned to a particular deployment, see https://pgtune.leopard.in.ua/
    """

    # NOTE(review): units are not stated here; they presumably follow the
    # Postgres settings of the same names — confirm before changing.
    checkpoint_completion_target: Optional[float] = 0.9
    default_statistics_target: Optional[int] = 100
    effective_io_concurrency: Optional[int] = 1
    effective_cache_size: Optional[int] = 524288
    huge_pages: Optional[str] = "try"
    maintenance_work_mem: Optional[int] = 65536
    max_connections: Optional[int] = 256
    max_parallel_workers_per_gather: Optional[int] = 2
    max_parallel_workers: Optional[int] = 8
    max_parallel_maintenance_workers: Optional[int] = 2
    max_wal_size: Optional[int] = 1024
    max_worker_processes: Optional[int] = 8
    min_wal_size: Optional[int] = 80
    shared_buffers: Optional[int] = 16384
    statement_cache_size: Optional[int] = 100
    random_page_cost: Optional[float] = 4
    wal_buffers: Optional[int] = 512
    work_mem: Optional[int] = 4096
+
+
class LimitSettings(BaseModel):
    """Rate-limit settings; ``None`` means "not set at this level"."""

    global_per_min: Optional[int] = None
    route_per_min: Optional[int] = None
    monthly_limit: Optional[int] = None

    def merge_with_defaults(
        self, defaults: "LimitSettings"
    ) -> "LimitSettings":
        """Return a new ``LimitSettings`` with unset values filled in.

        A field takes this instance's value unless that value is ``None``,
        in which case the value from ``defaults`` is used. An explicit
        ``0`` is kept: only ``None`` counts as unset, so a zero limit is
        not silently replaced by the default.

        Args:
            defaults: Settings supplying the fallback values.

        Returns:
            A new instance; neither ``self`` nor ``defaults`` is modified.
        """

        def pick(own: Optional[int], fallback: Optional[int]) -> Optional[int]:
            # `own or fallback` would discard a legitimate limit of 0.
            return fallback if own is None else own

        return LimitSettings(
            global_per_min=pick(self.global_per_min, defaults.global_per_min),
            route_per_min=pick(self.route_per_min, defaults.route_per_min),
            monthly_limit=pick(self.monthly_limit, defaults.monthly_limit),
        )
+
+
class DatabaseConfig(ProviderConfig):
    """Configuration for a database provider (only ``postgres`` is
    accepted by :meth:`validate_config`)."""

    # Connection
    provider: str = "postgres"
    user: Optional[str] = None
    password: Optional[str] = None
    host: Optional[str] = None
    port: Optional[int] = None
    db_name: Optional[str] = None
    project_name: Optional[str] = None
    postgres_configuration_settings: Optional[
        PostgresConfigurationSettings
    ] = None

    # Collections
    default_collection_name: str = "Default"
    default_collection_description: str = "Your default collection."
    collection_summary_system_prompt: str = "system"
    collection_summary_prompt: str = "collection_summary"
    enable_fts: bool = False

    # Graph settings
    batch_size: Optional[int] = 1
    graph_search_results_store_path: Optional[str] = None
    graph_enrichment_settings: GraphEnrichmentSettings = (
        GraphEnrichmentSettings()
    )
    graph_creation_settings: GraphCreationSettings = GraphCreationSettings()
    graph_search_settings: GraphSearchSettings = GraphSearchSettings()

    # Rate limits
    limits: LimitSettings = LimitSettings(
        global_per_min=60, route_per_min=20, monthly_limit=10000
    )
    route_limits: dict[str, LimitSettings] = {}
    user_limits: dict[UUID, LimitSettings] = {}

    @property
    def supported_providers(self) -> list[str]:
        """Names accepted for ``provider``."""
        return ["postgres"]

    def validate_config(self) -> None:
        """Raise ``ValueError`` unless ``provider`` is supported."""
        if self.provider in self.supported_providers:
            return
        raise ValueError(f"Provider '{self.provider}' is not supported.")

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
        """Build a config from ``data``, resolving the rate-limit section.

        After the generic ``create`` step, ``limits`` is rebuilt from
        ``data["limits"]`` (missing values fall back to 60 / 20 / 10000)
        and every entry under ``data["limits"]["routes"]`` becomes a
        ``LimitSettings`` in ``route_limits``.
        """
        config = cast(DatabaseConfig, cls.create(**data))

        limits_section = data.get("limits", {})
        config.limits = LimitSettings(
            global_per_min=limits_section.get("global_per_min", 60),
            route_per_min=limits_section.get("route_per_min", 20),
            monthly_limit=limits_section.get("monthly_limit", 10000),
        )

        per_route = limits_section.get("routes", {})
        for route, overrides in per_route.items():
            config.route_limits[route] = LimitSettings(**overrides)

        return config
+
+
class DatabaseProvider(Provider):
    """Base class for database providers, usable as an async context
    manager (both hooks are abstract)."""

    connection_manager: DatabaseConnectionManager
    config: DatabaseConfig
    project_name: str

    def __init__(self, config: DatabaseConfig):
        """Log the (redacted) configuration and run base initialization.

        The database password is masked before logging so that
        credentials never reach the log output; the config stored on the
        instance is the original, unmodified object.
        """
        loggable = config
        if getattr(config, "password", None):
            # Copy for logging only; the real config is left untouched.
            loggable = config.model_copy(update={"password": "********"})
        logger.info(f"Initializing DatabaseProvider with config {loggable}.")
        super().__init__(config)

    @abstractmethod
    async def __aenter__(self):
        """Enter the async context (implemented by subclasses)."""

    @abstractmethod
    async def __aexit__(self, exc_type, exc, tb):
        """Leave the async context (implemented by subclasses)."""
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/email.py b/.venv/lib/python3.12/site-packages/core/base/providers/email.py
new file mode 100644
index 00000000..73f88162
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/email.py
@@ -0,0 +1,96 @@
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from .base import Provider, ProviderConfig
+
+
class EmailConfig(ProviderConfig):
    """Configuration for an e-mail provider."""

    # SMTP
    smtp_server: Optional[str] = None
    smtp_port: Optional[int] = None
    smtp_username: Optional[str] = None
    smtp_password: Optional[str] = None
    from_email: Optional[str] = None
    use_tls: Optional[bool] = True

    # Hosted services
    sendgrid_api_key: Optional[str] = None
    mailersend_api_key: Optional[str] = None

    # Templates and presentation
    verify_email_template_id: Optional[str] = None
    reset_password_template_id: Optional[str] = None
    password_changed_template_id: Optional[str] = None
    frontend_url: Optional[str] = None
    sender_name: Optional[str] = None

    @property
    def supported_providers(self) -> list[str]:
        """Names of the known e-mail backends."""
        # Further backends (e.g. AWS SES) could be appended here.
        return ["smtp", "console", "sendgrid", "mailersend"]

    def validate_config(self) -> None:
        """Require an API key for the hosted backends.

        For ``sendgrid`` and ``mailersend`` the key may come from the
        config field or from the ``SENDGRID_API_KEY`` /
        ``MAILERSEND_API_KEY`` environment variable. Membership of
        ``provider`` in ``supported_providers`` is not checked here.

        Raises:
            ValueError: if the selected hosted backend has no API key.
        """
        if self.provider == "sendgrid":
            if not (self.sendgrid_api_key or os.getenv("SENDGRID_API_KEY")):
                raise ValueError(
                    "SendGrid API key is required when using SendGrid provider"
                )

        if self.provider == "mailersend":
            if not (
                self.mailersend_api_key or os.getenv("MAILERSEND_API_KEY")
            ):
                raise ValueError(
                    "MailerSend API key is required when using MailerSend provider"
                )
+
+
+logger = logging.getLogger(__name__)
+
+
class EmailProvider(Provider, ABC):
    """Abstract interface for sending e-mail; every send method is left
    to subclasses."""

    def __init__(self, config: EmailConfig):
        """Store ``config`` after checking it is an ``EmailConfig``.

        Raises:
            ValueError: if ``config`` is not an ``EmailConfig``.
        """
        config_ok = isinstance(config, EmailConfig)
        if not config_ok:
            raise ValueError(
                "EmailProvider must be initialized with an EmailConfig"
            )
        super().__init__(config)
        self.config: EmailConfig = config

    @abstractmethod
    async def send_email(
        self,
        to_email: str,
        subject: str,
        body: str,
        html_body: Optional[str] = None,
        *args,
        **kwargs,
    ) -> None:
        """Send a message with ``subject`` and ``body`` to ``to_email``;
        ``html_body`` is optional."""

    @abstractmethod
    async def send_verification_email(
        self, to_email: str, verification_code: str, *args, **kwargs
    ) -> None:
        """Send ``verification_code`` to ``to_email``."""

    @abstractmethod
    async def send_password_reset_email(
        self, to_email: str, reset_token: str, *args, **kwargs
    ) -> None:
        """Send ``reset_token`` to ``to_email``."""

    @abstractmethod
    async def send_password_changed_email(
        self,
        to_email: str,
        *args,
        **kwargs,
    ) -> None:
        """Send a password-changed message to ``to_email``."""
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
new file mode 100644
index 00000000..d1f9f9d6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py
@@ -0,0 +1,197 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import VectorQuantizationSettings
+
+from ..abstractions import (
+ ChunkSearchResult,
+ EmbeddingPurpose,
+ default_embedding_prefixes,
+)
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
class EmbeddingConfig(ProviderConfig):
    """Configuration for an embedding provider."""

    # Model selection
    provider: str
    base_model: str
    base_dimension: int | float
    rerank_model: Optional[str] = None
    rerank_url: Optional[str] = None

    # Request shaping
    batch_size: int = 1
    prefixes: Optional[dict[str, str]] = None
    add_title_as_prefix: bool = True
    concurrent_request_limit: int = 256

    # Retry policy
    max_retries: int = 3
    initial_backoff: float = 1
    max_backoff: float = 64.0

    quantization_settings: VectorQuantizationSettings = (
        VectorQuantizationSettings()
    )

    ## deprecated
    rerank_dimension: Optional[int] = None
    rerank_transformer_type: Optional[str] = None

    @property
    def supported_providers(self) -> list[str]:
        """Names accepted for ``provider``."""
        return ["litellm", "openai", "ollama"]

    def validate_config(self) -> None:
        """Raise ``ValueError`` unless ``provider`` is supported."""
        if self.provider in self.supported_providers:
            return
        raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+
class EmbeddingProvider(Provider):
    """Base class for embedding providers.

    Provides retry-with-backoff wrappers around the abstract
    ``_execute_task`` / ``_execute_task_sync`` hooks and thin public
    helpers that package their arguments into a task dict.
    """

    class Step(Enum):
        BASE = 1
        RERANK = 2

    def __init__(self, config: EmbeddingConfig):
        """Check the config type, log it and set up concurrency state.

        Raises:
            ValueError: if ``config`` is not an ``EmbeddingConfig``.
        """
        if not isinstance(config, EmbeddingConfig):
            raise ValueError(
                "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
            )
        logger.info(f"Initializing EmbeddingProvider with config {config}.")

        super().__init__(config)
        self.config: EmbeddingConfig = config
        # Caps the number of tasks executing concurrently on the async path.
        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
        self.current_requests = 0

    async def _execute_with_backoff_async(self, task: dict[str, Any]):
        """Run ``_execute_task`` with jittered exponential backoff.

        Authentication errors propagate at once. Any other exception is
        logged and retried until ``max_retries`` attempts have failed,
        at which point the last exception is re-raised.
        """
        attempt = 0
        delay = self.config.initial_backoff
        while attempt < self.config.max_retries:
            try:
                async with self.semaphore:
                    return await self._execute_task(task)
            except AuthenticationError:
                raise
            except Exception as exc:
                attempt += 1
                logger.warning(
                    f"Request failed (attempt {attempt}): {str(exc)}"
                )
                if attempt == self.config.max_retries:
                    raise
                # The semaphore is already released while we wait.
                await asyncio.sleep(random.uniform(0, delay))
                delay = min(delay * 2, self.config.max_backoff)

    def _execute_with_backoff_sync(self, task: dict[str, Any]):
        """Blocking twin of ``_execute_with_backoff_async`` (no semaphore)."""
        attempt = 0
        delay = self.config.initial_backoff
        while attempt < self.config.max_retries:
            try:
                return self._execute_task_sync(task)
            except AuthenticationError:
                raise
            except Exception as exc:
                attempt += 1
                logger.warning(
                    f"Request failed (attempt {attempt}): {str(exc)}"
                )
                if attempt == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, delay))
                delay = min(delay * 2, self.config.max_backoff)

    @abstractmethod
    async def _execute_task(self, task: dict[str, Any]):
        """Carry out one embedding task asynchronously (subclass hook)."""

    @abstractmethod
    def _execute_task_sync(self, task: dict[str, Any]):
        """Carry out one embedding task synchronously (subclass hook)."""

    async def async_get_embedding(
        self,
        text: str,
        stage: Step = Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
    ):
        """Submit a single-text task through the async retry wrapper."""
        request = dict(text=text, stage=stage, purpose=purpose)
        return await self._execute_with_backoff_async(request)

    def get_embedding(
        self,
        text: str,
        stage: Step = Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
    ):
        """Submit a single-text task through the sync retry wrapper."""
        request = dict(text=text, stage=stage, purpose=purpose)
        return self._execute_with_backoff_sync(request)

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: Step = Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
    ):
        """Submit a multi-text task through the async retry wrapper."""
        request = dict(texts=texts, stage=stage, purpose=purpose)
        return await self._execute_with_backoff_async(request)

    def get_embeddings(
        self,
        texts: list[str],
        stage: Step = Step.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
    ) -> list[list[float]]:
        """Submit a multi-text task through the sync retry wrapper."""
        request = dict(texts=texts, stage=stage, purpose=purpose)
        return self._execute_with_backoff_sync(request)

    @abstractmethod
    def rerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: Step = Step.RERANK,
        limit: int = 10,
    ):
        """Rerank ``results`` for ``query`` (subclass hook, synchronous)."""

    @abstractmethod
    async def arerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: Step = Step.RERANK,
        limit: int = 10,
    ):
        """Rerank ``results`` for ``query`` (subclass hook, asynchronous)."""

    def set_prefixes(self, config_prefixes: dict[str, str], base_model: str):
        """Populate ``self.prefixes`` from config, then model defaults.

        Keys of ``config_prefixes`` are lower-cased and converted to
        ``EmbeddingPurpose``. Defaults registered for ``base_model`` in
        ``default_embedding_prefixes`` only fill purposes the config did
        not set.
        """
        resolved = self.prefixes = {}

        for name, prefix in config_prefixes.items():
            resolved[EmbeddingPurpose(name.lower())] = prefix

        if base_model in default_embedding_prefixes:
            model_defaults = default_embedding_prefixes[base_model]
            for purpose, prefix in model_defaults.items():
                resolved.setdefault(purpose, prefix)
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py b/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py
new file mode 100644
index 00000000..70d0d3a0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py
@@ -0,0 +1,172 @@
+import logging
+from abc import ABC
+from enum import Enum
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
+
+from pydantic import Field
+
+from core.base.abstractions import ChunkEnrichmentSettings
+
+from .base import AppConfig, Provider, ProviderConfig
+from .llm import CompletionProvider
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+ from core.providers.database import PostgresDatabaseProvider
+
+
class ChunkingStrategy(str, Enum):
    """Names of the available chunking strategies.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    RECURSIVE = "recursive"
    CHARACTER = "character"
    BASIC = "basic"
    BY_TITLE = "by_title"
+
+
class IngestionMode(str, Enum):
    """Named ingestion presets.

    Members are also ``str`` instances; note that ``hi_res`` carries the
    hyphenated value ``"hi-res"``.
    """

    hi_res = "hi-res"
    fast = "fast"
    custom = "custom"
+
+
def _ingestion_default(key: str) -> Any:
    """Return an independent copy of the current default for ``key``.

    The copy keeps instances from sharing (and accidentally mutating) the
    list, dict and settings objects stored in ``IngestionConfig._defaults``.
    """
    import copy  # local import keeps the module's import block unchanged

    return copy.deepcopy(IngestionConfig._defaults[key])


class IngestionConfig(ProviderConfig):
    """Configuration for an ingestion provider.

    Field defaults are looked up in the class-level ``_defaults`` table
    at instantiation time, so :meth:`set_default` affects every instance
    created afterwards. Each instance receives its own deep copy of the
    default value.
    """

    _defaults: ClassVar[dict] = {
        "app": AppConfig(),
        "provider": "r2r",
        "excluded_parsers": ["mp4"],
        "chunking_strategy": "recursive",
        "chunk_size": 1024,
        "chunk_enrichment_settings": ChunkEnrichmentSettings(),
        "extra_parsers": {},
        "audio_transcription_model": None,
        "vision_img_prompt_name": "vision_img",
        "vision_pdf_prompt_name": "vision_pdf",
        "skip_document_summary": False,
        "document_summary_system_prompt": "system",
        "document_summary_task_prompt": "summary",
        "document_summary_max_length": 100_000,
        "chunks_for_document_summary": 128,
        "document_summary_model": None,
        "parser_overrides": {},
        "extra_fields": {},
        "automatic_extraction": False,
    }

    provider: str = Field(
        default_factory=lambda: _ingestion_default("provider")
    )
    excluded_parsers: list[str] = Field(
        default_factory=lambda: _ingestion_default("excluded_parsers")
    )
    chunking_strategy: str | ChunkingStrategy = Field(
        default_factory=lambda: _ingestion_default("chunking_strategy")
    )
    chunk_size: int = Field(
        default_factory=lambda: _ingestion_default("chunk_size")
    )
    chunk_enrichment_settings: ChunkEnrichmentSettings = Field(
        default_factory=lambda: _ingestion_default(
            "chunk_enrichment_settings"
        )
    )
    extra_parsers: dict[str, Any] = Field(
        default_factory=lambda: _ingestion_default("extra_parsers")
    )
    audio_transcription_model: Optional[str] = Field(
        default_factory=lambda: _ingestion_default(
            "audio_transcription_model"
        )
    )
    vision_img_prompt_name: str = Field(
        default_factory=lambda: _ingestion_default("vision_img_prompt_name")
    )
    vision_pdf_prompt_name: str = Field(
        default_factory=lambda: _ingestion_default("vision_pdf_prompt_name")
    )
    skip_document_summary: bool = Field(
        default_factory=lambda: _ingestion_default("skip_document_summary")
    )
    document_summary_system_prompt: str = Field(
        default_factory=lambda: _ingestion_default(
            "document_summary_system_prompt"
        )
    )
    document_summary_task_prompt: str = Field(
        default_factory=lambda: _ingestion_default(
            "document_summary_task_prompt"
        )
    )
    chunks_for_document_summary: int = Field(
        default_factory=lambda: _ingestion_default(
            "chunks_for_document_summary"
        )
    )
    document_summary_model: Optional[str] = Field(
        default_factory=lambda: _ingestion_default("document_summary_model")
    )
    parser_overrides: dict[str, str] = Field(
        default_factory=lambda: _ingestion_default("parser_overrides")
    )
    automatic_extraction: bool = Field(
        default_factory=lambda: _ingestion_default("automatic_extraction")
    )
    document_summary_max_length: int = Field(
        default_factory=lambda: _ingestion_default(
            "document_summary_max_length"
        )
    )

    @classmethod
    def set_default(cls, **kwargs):
        """Replace class-wide defaults used by later instances.

        Raises:
            AttributeError: if a keyword is not a known default. Keys
                preceding the unknown one have already been applied.
        """
        for key, value in kwargs.items():
            if key in cls._defaults:
                cls._defaults[key] = value
            else:
                raise AttributeError(
                    f"No default attribute '{key}' in IngestionConfig"
                )

    @property
    def supported_providers(self) -> list[str]:
        """Names accepted for ``provider``."""
        return ["r2r", "unstructured_local", "unstructured_api"]

    def validate_config(self) -> None:
        """Raise ``ValueError`` unless ``provider`` is supported."""
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider {self.provider} is not supported.")

    @classmethod
    def get_default(cls, mode: str, app) -> "IngestionConfig":
        """Return default ingestion configuration for a given mode.

        ``"hi-res"`` overrides the PDF parser with ``"zerox"``, ``"fast"``
        skips the document summary, and any other mode uses the plain
        defaults.
        """
        if mode == "hi-res":
            return cls(app=app, parser_overrides={"pdf": "zerox"})
        if mode == "fast":
            return cls(app=app, skip_document_summary=True)
        return cls(app=app)
+
+
class IngestionProvider(Provider, ABC):
    """Base class for ingestion providers; holds the config together with
    the database and LLM providers it was given."""

    config: IngestionConfig
    database_provider: "PostgresDatabaseProvider"
    llm_provider: CompletionProvider

    def __init__(
        self,
        config: IngestionConfig,
        database_provider: "PostgresDatabaseProvider",
        llm_provider: CompletionProvider,
    ):
        """Run base initialization, then keep the collaborating providers."""
        super().__init__(config)
        self.config: IngestionConfig = config
        self.database_provider: "PostgresDatabaseProvider" = database_provider
        self.llm_provider = llm_provider
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/llm.py b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py
new file mode 100644
index 00000000..669dfc4f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py
@@ -0,0 +1,200 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import (
+ GenerationConfig,
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+)
+
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
class CompletionConfig(ProviderConfig):
    """Configuration for a completion (LLM) provider."""

    provider: Optional[str] = None
    generation_config: Optional[GenerationConfig] = None
    concurrent_request_limit: int = 256

    # Retry policy
    max_retries: int = 3
    initial_backoff: float = 1.0
    max_backoff: float = 64.0

    @property
    def supported_providers(self) -> list[str]:
        """Names accepted for ``provider``."""
        return ["anthropic", "litellm", "openai", "r2r"]

    def validate_config(self) -> None:
        """Raise ``ValueError`` if ``provider`` is missing or unsupported."""
        chosen = self.provider
        if not chosen:
            raise ValueError("Provider must be set.")
        if chosen not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+
class CompletionProvider(Provider):
    """Base class for LLM completion providers.

    Wraps the abstract ``_execute_task`` / ``_execute_task_sync`` hooks
    with jittered exponential backoff. Authentication errors are never
    retried, and a stream that has already delivered chunks is never
    restarted, because the consumer would then receive duplicated output.
    """

    def __init__(self, config: CompletionConfig) -> None:
        """Check the config type, log it and set up concurrency helpers.

        Raises:
            ValueError: if ``config`` is not a ``CompletionConfig``.
        """
        if not isinstance(config, CompletionConfig):
            raise ValueError(
                "CompletionProvider must be initialized with a `CompletionConfig`."
            )
        logger.info(f"Initializing CompletionProvider with config: {config}")
        super().__init__(config)
        self.config: CompletionConfig = config
        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
        self.thread_pool = ThreadPoolExecutor(
            max_workers=config.concurrent_request_limit
        )

    async def _execute_with_backoff_async(self, task: dict[str, Any]):
        """Run ``_execute_task`` under the semaphore, retrying failures.

        Authentication errors propagate at once; other exceptions are
        logged and retried until ``max_retries`` attempts have failed.
        """
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                async with self.semaphore:
                    return await self._execute_task(task)
            except AuthenticationError:
                raise
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    async def _execute_with_backoff_async_stream(
        self, task: dict[str, Any]
    ) -> AsyncGenerator[Any, None]:
        """Stream chunks from ``_execute_task``, retrying failed starts.

        A failure is retried only while nothing has been yielded yet.
        Once a chunk has reached the consumer, the error is re-raised
        instead of replaying the stream from the beginning.
        """
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            emitted = False
            try:
                async with self.semaphore:
                    async for chunk in await self._execute_task(task):
                        emitted = True
                        yield chunk
                return  # Successful completion of the stream
            except AuthenticationError:
                raise
            except Exception as e:
                logger.warning(
                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if emitted or retries == self.config.max_retries:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync(self, task: dict[str, Any]):
        """Blocking twin of ``_execute_with_backoff_async`` (no semaphore)."""
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                return self._execute_task_sync(task)
            except AuthenticationError:
                # Retrying cannot fix bad credentials; match the async path.
                raise
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync_stream(
        self, task: dict[str, Any]
    ) -> Generator[Any, None, None]:
        """Blocking twin of ``_execute_with_backoff_async_stream``."""
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            emitted = False
            try:
                for chunk in self._execute_task_sync(task):
                    emitted = True
                    yield chunk
                return  # Successful completion of the stream
            except AuthenticationError:
                raise
            except Exception as e:
                logger.warning(
                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if emitted or retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    @abstractmethod
    async def _execute_task(self, task: dict[str, Any]):
        """Carry out one completion task asynchronously (subclass hook)."""

    @abstractmethod
    def _execute_task_sync(self, task: dict[str, Any]):
        """Carry out one completion task synchronously (subclass hook)."""

    async def aget_completion(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> LLMChatCompletion:
        """Request a completion and wrap it in ``LLMChatCompletion``."""
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        response = await self._execute_with_backoff_async(task)
        return LLMChatCompletion(**response.dict())

    async def aget_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> AsyncGenerator[LLMChatCompletionChunk, None]:
        """Stream a completion as ``LLMChatCompletionChunk`` objects.

        Dict chunks are wrapped directly. For object chunks the first
        choice's ``finish_reason`` is normalized (``""`` becomes ``None``,
        ``"eos"`` becomes ``"stop"``) before conversion.
        """
        # NOTE: this sets `stream` on the caller's config object in place.
        generation_config.stream = True
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        async for chunk in self._execute_with_backoff_async_stream(task):
            if isinstance(chunk, dict):
                yield LLMChatCompletionChunk(**chunk)
                continue

            choice = chunk.choices[0]
            if choice.finish_reason == "":
                # handle error output conventions
                choice.finish_reason = None
            elif choice.finish_reason == "eos":
                # hardcode `eos` to `stop` for consistency
                choice.finish_reason = "stop"

            # Convert inside the try, yield outside it, so exceptions
            # thrown into the generator are not mistaken for parse errors.
            try:
                parsed = LLMChatCompletionChunk(**(chunk.dict()))
            except Exception as e:
                logger.error(f"Error parsing chunk: {e}")
                parsed = LLMChatCompletionChunk(**(chunk.as_dict()))
            yield parsed

    def get_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> Generator[LLMChatCompletionChunk, None, None]:
        """Blocking streaming variant: yield ``LLMChatCompletionChunk``s."""
        # NOTE: this sets `stream` on the caller's config object in place.
        generation_config.stream = True
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        for chunk in self._execute_with_backoff_sync_stream(task):
            yield LLMChatCompletionChunk(**chunk.dict())
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py b/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py
new file mode 100644
index 00000000..c3105f30
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py
@@ -0,0 +1,70 @@
+from abc import abstractmethod
+from enum import Enum
+from typing import Any
+
+from .base import Provider, ProviderConfig
+
+
+ class Workflow(Enum):
+     """Workflow kinds, each identified by a lowercase string value."""
+
+     # Member order is part of the enum's iteration order; keep it stable.
+     INGESTION = "ingestion"
+     GRAPH = "graph"
+
+
+ class OrchestrationConfig(ProviderConfig):
+     # Configuration for an orchestration provider. `provider` has no
+     # default and must be supplied; the remaining fields fall back to the
+     # integer defaults below.
+     # NOTE(review): the three *_concurrency_limit fields presumably cap
+     # parallel work per workflow type -- confirm against their consumers.
+     provider: str
+     max_runs: int = 2_048
+     graph_search_results_creation_concurrency_limit: int = 32
+     ingestion_concurrency_limit: int = 16
+     graph_search_results_concurrency_limit: int = 8
+
+     def validate_config(self) -> None:
+         """Raise ``ValueError`` unless ``provider`` is a supported name."""
+         if self.provider in self.supported_providers:
+             return
+         raise ValueError(f"Provider {self.provider} is not supported.")
+
+     @property
+     def supported_providers(self) -> list[str]:
+         """Provider names that ``validate_config`` accepts."""
+         return ["hatchet", "simple"]
+
+
+ class OrchestrationProvider(Provider):
+     """Base interface for orchestration back ends.
+
+     Every method other than ``__init__`` is declared with
+     ``@abstractmethod`` and has no implementation here; concrete
+     providers are expected to supply them.
+     """
+
+     def __init__(self, config: OrchestrationConfig):
+         """Initialise the base provider, then store ``config``.
+
+         ``worker`` starts as ``None``.
+         """
+         super().__init__(config)
+         self.config = config
+         self.worker = None
+
+     @abstractmethod
+     async def start_worker(self):
+         """Start the worker (abstract; no behaviour defined here)."""
+
+     @abstractmethod
+     def get_worker(self, name: str, max_runs: int) -> Any:
+         """Return a worker for ``name`` and ``max_runs`` (abstract)."""
+
+     @abstractmethod
+     def step(self, *args, **kwargs) -> Any:
+         """Provider-specific ``step`` hook (abstract)."""
+
+     @abstractmethod
+     def workflow(self, *args, **kwargs) -> Any:
+         """Provider-specific ``workflow`` hook (abstract)."""
+
+     @abstractmethod
+     def failure(self, *args, **kwargs) -> Any:
+         """Provider-specific ``failure`` hook (abstract)."""
+
+     @abstractmethod
+     def register_workflows(
+         self,
+         workflow: Workflow,
+         service: Any,
+         messages: dict,
+     ) -> None:
+         """Register workflows of kind ``workflow`` (abstract)."""
+
+     @abstractmethod
+     async def run_workflow(
+         self,
+         workflow_name: str,
+         parameters: dict,
+         options: dict,
+         *args,
+         **kwargs,
+     ) -> dict[str, str]:
+         """Run the workflow called ``workflow_name`` (abstract).
+
+         Implementations return a ``dict[str, str]``.
+         """
diff --git a/.venv/lib/python3.12/site-packages/core/base/utils/__init__.py b/.venv/lib/python3.12/site-packages/core/base/utils/__init__.py
new file mode 100644
index 00000000..948a1069
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/utils/__init__.py
@@ -0,0 +1,43 @@
+from shared.utils import (
+ RecursiveCharacterTextSplitter,
+ TextSplitter,
+ _decorate_vector_type,
+ _get_vector_column_str,
+ decrement_version,
+ deep_update,
+ dump_collector,
+ dump_obj,
+ format_search_results_for_llm,
+ generate_default_prompt_id,
+ generate_default_user_collection_id,
+ generate_document_id,
+ generate_entity_document_id,
+ generate_extraction_id,
+ generate_id,
+ generate_user_id,
+ increment_version,
+ validate_uuid,
+ yield_sse_event,
+)
+
+ # Public re-export list for this package. It names every symbol imported
+ # from `shared.utils` above, one-for-one. The two underscore-prefixed
+ # helpers are listed explicitly, which is what makes a star-import of this
+ # package include them despite their leading underscore.
+ __all__ = [
+ "format_search_results_for_llm",
+ "generate_id",
+ "generate_default_user_collection_id",
+ "increment_version",
+ "decrement_version",
+ "generate_document_id",
+ "generate_extraction_id",
+ "generate_user_id",
+ "generate_entity_document_id",
+ "generate_default_prompt_id",
+ "RecursiveCharacterTextSplitter",
+ "TextSplitter",
+ "validate_uuid",
+ "deep_update",
+ "_decorate_vector_type",
+ "_get_vector_column_str",
+ "yield_sse_event",
+ "dump_collector",
+ "dump_obj",
+ ]