diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base')
19 files changed, 2359 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/__init__.py b/.venv/lib/python3.12/site-packages/core/base/__init__.py new file mode 100644 index 00000000..1e872799 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/__init__.py @@ -0,0 +1,130 @@ +from .abstractions import * +from .agent import * +from .api.models import * +from .parsers import * +from .providers import * +from .utils import * + +__all__ = [ + "ThinkingEvent", + "ToolCallEvent", + "ToolResultEvent", + "CitationEvent", + "Citation", + ## ABSTRACTIONS + # Base abstractions + "AsyncSyncMeta", + "syncable", + # Completion abstractions + "MessageType", + # Document abstractions + "Document", + "DocumentChunk", + "DocumentResponse", + "IngestionStatus", + "GraphExtractionStatus", + "GraphConstructionStatus", + "DocumentType", + # Embedding abstractions + "EmbeddingPurpose", + "default_embedding_prefixes", + # Exception abstractions + "R2RDocumentProcessingError", + "R2RException", + # Graph abstractions + "Entity", + "GraphExtraction", + "Relationship", + "Community", + "GraphCreationSettings", + "GraphEnrichmentSettings", + # LLM abstractions + "GenerationConfig", + "LLMChatCompletion", + "LLMChatCompletionChunk", + "RAGCompletion", + # Prompt abstractions + "Prompt", + # Search abstractions + "AggregateSearchResult", + "WebSearchResult", + "GraphSearchResult", + "GraphSearchSettings", + "ChunkSearchSettings", + "ChunkSearchResult", + "WebPageSearchResult", + "SearchSettings", + "select_search_filters", + "SearchMode", + "HybridSearchSettings", + # User abstractions + "Token", + "TokenData", + # Vector abstractions + "Vector", + "VectorEntry", + "VectorType", + "StorageResult", + "IndexConfig", + ## AGENT + # Agent abstractions + "Agent", + "AgentConfig", + "Conversation", + "Message", + "Tool", + "ToolResult", + ## API + # Auth Responses + "TokenResponse", + "User", + ## PARSERS + # Base parser + "AsyncParser", + ## PROVIDERS + # Base provider classes + "AppConfig", + "Provider", + "ProviderConfig", + # Auth provider + "AuthConfig", + "AuthProvider", + # Crypto provider + "CryptoConfig", + "CryptoProvider", + # Email provider + "EmailConfig", + "EmailProvider", + # Database providers + "LimitSettings", + "DatabaseConfig", + "DatabaseProvider", + "Handler", + "PostgresConfigurationSettings", + # Embedding provider + "EmbeddingConfig", + "EmbeddingProvider", + # Ingestion provider + "IngestionMode", + "IngestionConfig", + "IngestionProvider", + "ChunkingStrategy", + # LLM provider + "CompletionConfig", + "CompletionProvider", + ## UTILS + "RecursiveCharacterTextSplitter", + "TextSplitter", + "format_search_results_for_llm", + "validate_uuid", + # ID generation + "generate_id", + "generate_document_id", + "generate_extraction_id", + "generate_default_user_collection_id", + "generate_user_id", + "increment_version", + "yield_sse_event", + "dump_collector", + "dump_obj", +] diff --git a/.venv/lib/python3.12/site-packages/core/base/abstractions/__init__.py b/.venv/lib/python3.12/site-packages/core/base/abstractions/__init__.py new file mode 100644 index 00000000..bb1363fe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/abstractions/__init__.py @@ -0,0 +1,154 @@ +from shared.abstractions.base import AsyncSyncMeta, R2RSerializable, syncable +from shared.abstractions.document import ( + ChunkEnrichmentSettings, + Document, + DocumentChunk, + DocumentResponse, + DocumentType, + GraphConstructionStatus, + GraphExtractionStatus, + IngestionStatus, + RawChunk, + UnprocessedChunk, + UpdateChunk, +) +from shared.abstractions.embedding import ( + EmbeddingPurpose, + default_embedding_prefixes, +) +from shared.abstractions.exception import ( + R2RDocumentProcessingError, + R2RException, +) +from shared.abstractions.graph import ( + Community, + Entity, + Graph, + GraphCommunitySettings, + GraphCreationSettings, + GraphEnrichmentSettings, + GraphExtraction, + Relationship, + StoreType, +) +from shared.abstractions.llm import ( + GenerationConfig, + LLMChatCompletion, + LLMChatCompletionChunk, + Message, + MessageType, + RAGCompletion, +) +from shared.abstractions.prompt import Prompt +from shared.abstractions.search import ( + AggregateSearchResult, + ChunkSearchResult, + ChunkSearchSettings, + GraphCommunityResult, + GraphEntityResult, + GraphRelationshipResult, + GraphSearchResult, + GraphSearchResultType, + GraphSearchSettings, + HybridSearchSettings, + SearchMode, + SearchSettings, + WebPageSearchResult, + WebSearchResult, + select_search_filters, +) +from shared.abstractions.user import Token, TokenData, User +from shared.abstractions.vector import ( + IndexArgsHNSW, + IndexArgsIVFFlat, + IndexConfig, + IndexMeasure, + IndexMethod, + StorageResult, + Vector, + VectorEntry, + VectorQuantizationSettings, + VectorQuantizationType, + VectorTableName, + VectorType, +) + +__all__ = [ + # Base abstractions + "R2RSerializable", + "AsyncSyncMeta", + "syncable", + # Completion abstractions + "MessageType", + # Document abstractions + "Document", + "DocumentChunk", + "DocumentResponse", + "DocumentType", + "IngestionStatus", + "GraphExtractionStatus", + "GraphConstructionStatus", + "RawChunk", + "UnprocessedChunk", + "UpdateChunk", + # Embedding abstractions + "EmbeddingPurpose", + "default_embedding_prefixes", + # Exception abstractions + "R2RDocumentProcessingError", + "R2RException", + # Graph abstractions + "Entity", + "Graph", + "Community", + "StoreType", + "GraphExtraction", + "Relationship", + # Index abstractions + "IndexConfig", + # LLM abstractions + "GenerationConfig", + "LLMChatCompletion", + "LLMChatCompletionChunk", + "Message", + "RAGCompletion", + # Prompt abstractions + "Prompt", + # Search abstractions + "WebSearchResult", + "AggregateSearchResult", + "GraphSearchResult", + "GraphSearchResultType", + "GraphEntityResult", + "GraphRelationshipResult", + "GraphCommunityResult", + "GraphSearchSettings", + "ChunkSearchSettings", + "ChunkSearchResult", + "WebPageSearchResult", + "SearchSettings", + "select_search_filters", + "SearchMode", + "HybridSearchSettings", + # Graph abstractions + "GraphCreationSettings", + "GraphEnrichmentSettings", + "GraphCommunitySettings", + # User abstractions + "Token", + "TokenData", + "User", + # Vector abstractions + "Vector", + "VectorEntry", + "VectorType", + "IndexMeasure", + "IndexMethod", + "VectorTableName", + "IndexArgsHNSW", + "IndexArgsIVFFlat", + "VectorQuantizationSettings", + "VectorQuantizationType", + "StorageResult", + "ChunkEnrichmentSettings", +] diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py b/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py new file mode 100644 index 00000000..815b9ae7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/agent/__init__.py @@ -0,0 +1,17 @@ +# FIXME: Once the agent is properly type annotated, remove the type: ignore comments +from .agent import ( # type: ignore + Agent, + AgentConfig, + Conversation, + Tool, + ToolResult, +) + +__all__ = [ + # Agent abstractions + "Agent", + "AgentConfig", + "Conversation", + "Tool", + "ToolResult", +] diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/agent.py b/.venv/lib/python3.12/site-packages/core/base/agent/agent.py new file mode 100644 index 00000000..6813dd21 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/agent/agent.py @@ -0,0 +1,291 @@ +# type: ignore +import asyncio +import json +import logging +from abc import ABC, abstractmethod +from datetime import datetime +from json import JSONDecodeError +from typing import Any, AsyncGenerator, Optional, Type + +from pydantic import BaseModel + +from core.base.abstractions import ( + GenerationConfig, + LLMChatCompletion, + Message, +) +from core.base.providers import CompletionProvider, DatabaseProvider + +from .base import Tool, ToolResult + +logger = logging.getLogger() + + +class Conversation: + def __init__(self): + self.messages: list[Message] = [] + self._lock = asyncio.Lock() + + async def add_message(self, message): + async with self._lock: + self.messages.append(message) + + async def get_messages(self) -> list[dict[str, Any]]: + async with self._lock: + return [ + {**msg.model_dump(exclude_none=True), "role": str(msg.role)} + for msg in self.messages + ] + + +# TODO - Move agents to provider pattern +class AgentConfig(BaseModel): + rag_rag_agent_static_prompt: str = "static_rag_agent" + rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted" + stream: bool = False + include_tools: bool = True + max_iterations: int = 10 + + @classmethod + def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig": + base_args = cls.model_fields.keys() + filtered_kwargs = { + k: v if v != "None" else None + for k, v in kwargs.items() + if k in base_args + } + return cls(**filtered_kwargs) # type: ignore + + +class Agent(ABC): + def __init__( + self, + llm_provider: CompletionProvider, + database_provider: DatabaseProvider, + config: AgentConfig, + rag_generation_config: GenerationConfig, + ): + self.llm_provider = llm_provider + self.database_provider: DatabaseProvider = database_provider + self.config = config + self.conversation = Conversation() + self._completed = False + self._tools: list[Tool] = [] + self.tool_calls: list[dict] = [] + self.rag_generation_config = rag_generation_config + self._register_tools() + + @abstractmethod + def _register_tools(self): + pass + + async def _setup( + self, system_instruction: Optional[str] = None, *args, **kwargs + ): + await self.conversation.add_message( + Message( + role="system", + content=system_instruction + or ( + await self.database_provider.prompts_handler.get_cached_prompt( + self.config.rag_rag_agent_static_prompt, + inputs={ + "date": str(datetime.now().strftime("%m/%d/%Y")) + }, + ) + + f"\n Note,you only have {self.config.max_iterations} iterations or tool calls to reach a conclusion before your operation terminates." + ), + ) + ) + + @property + def tools(self) -> list[Tool]: + return self._tools + + @tools.setter + def tools(self, tools: list[Tool]): + self._tools = tools + + @abstractmethod + async def arun( + self, + system_instruction: Optional[str] = None, + messages: Optional[list[Message]] = None, + *args, + **kwargs, + ) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]: + pass + + @abstractmethod + async def process_llm_response( + self, + response: Any, + *args, + **kwargs, + ) -> None | AsyncGenerator[str, None]: + pass + + async def execute_tool(self, tool_name: str, *args, **kwargs) -> str: + if tool := next((t for t in self.tools if t.name == tool_name), None): + return await tool.results_function(*args, **kwargs) + else: + return f"Error: Tool {tool_name} not found." + + def get_generation_config( + self, last_message: dict, stream: bool = False + ) -> GenerationConfig: + if ( + last_message["role"] in ["tool", "function"] + and last_message["content"] != "" + and "ollama" in self.rag_generation_config.model + or not self.config.include_tools + ): + return GenerationConfig( + **self.rag_generation_config.model_dump( + exclude={"functions", "tools", "stream"} + ), + stream=stream, + ) + + return GenerationConfig( + **self.rag_generation_config.model_dump( + exclude={"functions", "tools", "stream"} + ), + # FIXME: Use tools instead of functions + # TODO - Investigate why `tools` fails with OpenAI+LiteLLM + tools=( + [ + { + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + }, + "type": "function", + "name": tool.name, + } + for tool in self.tools + ] + if self.tools + else None + ), + stream=stream, + ) + + async def handle_function_or_tool_call( + self, + function_name: str, + function_arguments: str, + tool_id: Optional[str] = None, + save_messages: bool = True, + *args, + **kwargs, + ) -> ToolResult: + logger.debug( + f"Calling function: {function_name}, args: {function_arguments}, tool_id: {tool_id}" + ) + if tool := next( + (t for t in self.tools if t.name == function_name), None + ): + try: + function_args = json.loads(function_arguments) + + except JSONDecodeError as e: + error_message = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with `JSONDecodeError`." + if save_messages: + await self.conversation.add_message( + Message( + role="tool" if tool_id else "function", + content=error_message, + name=function_name, + tool_call_id=tool_id, + ) + ) + + # raise R2RException( + # message=f"Error parsing function arguments: {e}, agent likely produced invalid tool inputs.", + # status_code=400, + # ) + + merged_kwargs = {**kwargs, **function_args} + try: + raw_result = await tool.results_function( + *args, **merged_kwargs + ) + llm_formatted_result = tool.llm_format_function(raw_result) + except Exception as e: + raw_result = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with an exception: {e}." + logger.error(raw_result) + llm_formatted_result = raw_result + + tool_result = ToolResult( + raw_result=raw_result, + llm_formatted_result=llm_formatted_result, + ) + if tool.stream_function: + tool_result.stream_result = tool.stream_function(raw_result) + + if save_messages: + await self.conversation.add_message( + Message( + role="tool" if tool_id else "function", + content=str(tool_result.llm_formatted_result), + name=function_name, + tool_call_id=tool_id, + ) + ) + # HACK - to fix issues with claude thinking + tool use [https://github.com/anthropics/anthropic-cookbook/blob/main/extended_thinking/extended_thinking_with_tool_use.ipynb] + if self.rag_generation_config.extended_thinking: + await self.conversation.add_message( + Message( + role="user", + content="Continue...", + ) + ) + + self.tool_calls.append( + { + "name": function_name, + "args": function_arguments, + } + ) + return tool_result + + +# TODO - Move agents to provider pattern +class RAGAgentConfig(AgentConfig): + rag_rag_agent_static_prompt: str = "static_rag_agent" + rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted" + stream: bool = False + include_tools: bool = True + max_iterations: int = 10 + # tools: list[str] = [] # HACK - unused variable. + + # Default RAG tools + rag_tools: list[str] = [ + "search_file_descriptions", + "search_file_knowledge", + "get_file_content", + ] + + # Default Research tools + research_tools: list[str] = [ + "rag", + "reasoning", + # DISABLED by default + "critique", + "python_executor", + ] + + @classmethod + def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig": + base_args = cls.model_fields.keys() + filtered_kwargs = { + k: v if v != "None" else None + for k, v in kwargs.items() + if k in base_args + } + filtered_kwargs["tools"] = kwargs.get("tools", None) or kwargs.get( + "tool_names", None + ) + return cls(**filtered_kwargs) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/core/base/agent/base.py b/.venv/lib/python3.12/site-packages/core/base/agent/base.py new file mode 100644 index 00000000..0d8f15ee --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/agent/base.py @@ -0,0 +1,22 @@ +from typing import Any, Callable, Optional + +from ..abstractions import R2RSerializable + + +class Tool(R2RSerializable): + name: str + description: str + results_function: Callable + llm_format_function: Callable + stream_function: Optional[Callable] = None + parameters: Optional[dict[str, Any]] = None + + class Config: + populate_by_name = True + arbitrary_types_allowed = True + + +class ToolResult(R2RSerializable): + raw_result: Any + llm_formatted_result: str + stream_result: Optional[str] = None diff --git a/.venv/lib/python3.12/site-packages/core/base/api/models/__init__.py b/.venv/lib/python3.12/site-packages/core/base/api/models/__init__.py new file mode 100644 index 00000000..dc0b041f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/api/models/__init__.py @@ -0,0 +1,208 @@ +from shared.api.models.auth.responses import ( + TokenResponse, + WrappedTokenResponse, +) +from shared.api.models.base import ( + GenericBooleanResponse, + GenericMessageResponse, + PaginatedR2RResult, + R2RResults, + WrappedBooleanResponse, + WrappedGenericMessageResponse, +) +from shared.api.models.graph.responses import ( # TODO: Need to review anything above this + Community, + Entity, + GraphResponse, + Relationship, + WrappedCommunitiesResponse, + WrappedCommunityResponse, + WrappedEntitiesResponse, + WrappedEntityResponse, + WrappedGraphResponse, + WrappedGraphsResponse, + WrappedRelationshipResponse, + WrappedRelationshipsResponse, +) +from shared.api.models.ingestion.responses import ( + IngestionResponse, + UpdateResponse, + VectorIndexResponse, + VectorIndicesResponse, + WrappedIngestionResponse, + WrappedMetadataUpdateResponse, + WrappedUpdateResponse, + WrappedVectorIndexResponse, + WrappedVectorIndicesResponse, +) +from shared.api.models.management.responses import ( # Document Responses; Prompt Responses; Chunk Responses; Conversation Responses; User Responses; TODO: anything below this hasn't been reviewed + ChunkResponse, + CollectionResponse, + ConversationResponse, + MessageResponse, + PromptResponse, + ServerStats, + SettingsResponse, + User, + WrappedAPIKeyResponse, + WrappedAPIKeysResponse, + WrappedChunkResponse, + WrappedChunksResponse, + WrappedCollectionResponse, + WrappedCollectionsResponse, + WrappedConversationMessagesResponse, + WrappedConversationResponse, + WrappedConversationsResponse, + WrappedDocumentResponse, + WrappedDocumentsResponse, + WrappedLimitsResponse, + WrappedLoginResponse, + WrappedMessageResponse, + WrappedMessagesResponse, + WrappedPromptResponse, + WrappedPromptsResponse, + WrappedServerStatsResponse, + WrappedSettingsResponse, + WrappedUserResponse, + WrappedUsersResponse, +) +from shared.api.models.retrieval.responses import ( + AgentEvent, + AgentResponse, + Citation, + CitationData, + CitationEvent, + Delta, + DeltaPayload, + FinalAnswerData, + FinalAnswerEvent, + MessageData, + MessageDelta, + MessageEvent, + RAGEvent, + RAGResponse, + SearchResultsData, + SearchResultsEvent, + SSEEventBase, + ThinkingData, + ThinkingEvent, + ToolCallData, + ToolCallEvent, + ToolResultData, + ToolResultEvent, + UnknownEvent, + WrappedAgentResponse, + WrappedCompletionResponse, + WrappedDocumentSearchResponse, + WrappedEmbeddingResponse, + WrappedLLMChatCompletion, + WrappedRAGResponse, + WrappedSearchResponse, + WrappedVectorSearchResponse, +) + +__all__ = [ + # Auth Responses + "TokenResponse", + "WrappedTokenResponse", + "WrappedGenericMessageResponse", + # Ingestion Responses + "IngestionResponse", + "WrappedIngestionResponse", + "WrappedUpdateResponse", + "WrappedMetadataUpdateResponse", + "WrappedVectorIndexResponse", + "WrappedVectorIndicesResponse", + "UpdateResponse", + "VectorIndexResponse", + "VectorIndicesResponse", + # Knowledge Graph Responses + "Entity", + "Relationship", + "Community", + "WrappedEntityResponse", + "WrappedEntitiesResponse", + "WrappedRelationshipResponse", + "WrappedRelationshipsResponse", + "WrappedCommunityResponse", + "WrappedCommunitiesResponse", + # TODO: Need to review anything above this + "GraphResponse", + "WrappedGraphResponse", + "WrappedGraphsResponse", + # Management Responses + "PromptResponse", + "ServerStats", + "SettingsResponse", + "ChunkResponse", + "CollectionResponse", + "WrappedServerStatsResponse", + "WrappedSettingsResponse", + "WrappedDocumentResponse", + "WrappedDocumentsResponse", + "WrappedCollectionResponse", + "WrappedCollectionsResponse", + # Conversation Responses + "ConversationResponse", + "WrappedConversationMessagesResponse", + "WrappedConversationResponse", + "WrappedConversationsResponse", + # Prompt Responses + "WrappedPromptResponse", + "WrappedPromptsResponse", + # Conversation Responses + "MessageResponse", + "WrappedMessageResponse", + "WrappedMessagesResponse", + # Chunk Responses + "WrappedChunkResponse", + "WrappedChunksResponse", + # User Responses + "User", + "WrappedUserResponse", + "WrappedUsersResponse", + "WrappedAPIKeyResponse", + "WrappedLimitsResponse", + "WrappedAPIKeysResponse", + "WrappedLoginResponse", + # Base Responses + "PaginatedR2RResult", + "R2RResults", + "GenericBooleanResponse", + "GenericMessageResponse", + "WrappedBooleanResponse", + "WrappedGenericMessageResponse", + # Retrieval Responses + "SSEEventBase", + "SearchResultsData", + "SearchResultsEvent", + "MessageDelta", + "MessageData", + "MessageEvent", + "DeltaPayload", + "Delta", + "CitationData", + "CitationEvent", + "FinalAnswerData", + "FinalAnswerEvent", + "ToolCallData", + "ToolCallEvent", + "ToolResultData", + "ToolResultEvent", + "ThinkingData", + "ThinkingEvent", + "RAGEvent", + "AgentEvent", + "UnknownEvent", + "RAGResponse", + "Citation", + "AgentResponse", + "WrappedDocumentSearchResponse", + "WrappedSearchResponse", + "WrappedVectorSearchResponse", + "WrappedCompletionResponse", + "WrappedRAGResponse", + "WrappedAgentResponse", + "WrappedLLMChatCompletion", + "WrappedEmbeddingResponse", +] diff --git a/.venv/lib/python3.12/site-packages/core/base/parsers/__init__.py b/.venv/lib/python3.12/site-packages/core/base/parsers/__init__.py new file mode 100644 index 00000000..d7696202 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/parsers/__init__.py @@ -0,0 +1,5 @@ +from .base_parser import AsyncParser + +__all__ = [ + "AsyncParser", +] diff --git a/.venv/lib/python3.12/site-packages/core/base/parsers/base_parser.py b/.venv/lib/python3.12/site-packages/core/base/parsers/base_parser.py new file mode 100644 index 00000000..fb40d767 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/parsers/base_parser.py @@ -0,0 +1,12 @@ +"""Abstract base class for parsers.""" + +from abc import ABC, abstractmethod +from typing import AsyncGenerator, Generic, TypeVar + +T = TypeVar("T") + + +class AsyncParser(ABC, Generic[T]): + @abstractmethod + async def ingest(self, data: T, **kwargs) -> AsyncGenerator[str, None]: + pass diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py b/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py new file mode 100644 index 00000000..b902944d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/__init__.py @@ -0,0 +1,59 @@ +from .auth import AuthConfig, AuthProvider +from .base import AppConfig, Provider, ProviderConfig +from .crypto import CryptoConfig, CryptoProvider +from .database import ( + DatabaseConfig, + DatabaseConnectionManager, + DatabaseProvider, + Handler, + LimitSettings, + PostgresConfigurationSettings, +) +from .email import EmailConfig, EmailProvider +from .embedding import EmbeddingConfig, EmbeddingProvider +from .ingestion import ( + ChunkingStrategy, + IngestionConfig, + IngestionMode, + IngestionProvider, +) +from .llm import CompletionConfig, CompletionProvider +from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow + +__all__ = [ + # Auth provider + "AuthConfig", + "AuthProvider", + # Base provider classes + "AppConfig", + "Provider", + "ProviderConfig", + # Ingestion provider + "IngestionMode", + "IngestionConfig", + "IngestionProvider", + "ChunkingStrategy", + # Crypto provider + "CryptoConfig", + "CryptoProvider", + # Email provider + "EmailConfig", + "EmailProvider", + # Database providers + "DatabaseConnectionManager", + "DatabaseConfig", + "LimitSettings", + "PostgresConfigurationSettings", + "DatabaseProvider", + "Handler", + # Embedding provider + "EmbeddingConfig", + "EmbeddingProvider", + # LLM provider + "CompletionConfig", + "CompletionProvider", + # Orchestration provider + "OrchestrationConfig", + "OrchestrationProvider", + "Workflow", +] diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/auth.py b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py new file mode 100644 index 00000000..352c3331 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py @@ -0,0 +1,231 @@ +import logging +from abc import ABC, abstractmethod +from datetime import datetime +from typing import TYPE_CHECKING, Optional + +from fastapi import Security +from fastapi.security import ( + APIKeyHeader, + HTTPAuthorizationCredentials, + HTTPBearer, +) + +from ..abstractions import R2RException, Token, TokenData +from ..api.models import User +from .base import Provider, ProviderConfig +from .crypto import CryptoProvider +from .email import EmailProvider + +logger = logging.getLogger() + +if TYPE_CHECKING: + from core.providers.database import PostgresDatabaseProvider + +api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + +class AuthConfig(ProviderConfig): + secret_key: Optional[str] = None + require_authentication: bool = False + require_email_verification: bool = False + default_admin_email: str = "admin@example.com" + default_admin_password: str = "change_me_immediately" + access_token_lifetime_in_minutes: Optional[int] = None + refresh_token_lifetime_in_days: Optional[int] = None + + @property + def supported_providers(self) -> list[str]: + return ["r2r"] + + def validate_config(self) -> None: + pass + + +class AuthProvider(Provider, ABC): + security = HTTPBearer(auto_error=False) + crypto_provider: CryptoProvider + email_provider: EmailProvider + database_provider: "PostgresDatabaseProvider" + + def __init__( + self, + config: AuthConfig, + crypto_provider: CryptoProvider, + database_provider: "PostgresDatabaseProvider", + email_provider: EmailProvider, + ): + if not isinstance(config, AuthConfig): + raise ValueError( + "AuthProvider must be initialized with an AuthConfig" + ) + self.config = config + self.admin_email = config.default_admin_email + self.admin_password = config.default_admin_password + self.crypto_provider = crypto_provider + self.database_provider = database_provider + self.email_provider = email_provider + super().__init__(config) + self.config: AuthConfig = config + self.database_provider: "PostgresDatabaseProvider" = database_provider + + async def _get_default_admin_user(self) -> User: + return await self.database_provider.users_handler.get_user_by_email( + self.admin_email + ) + + @abstractmethod + def create_access_token(self, data: dict) -> str: + pass + + @abstractmethod + def create_refresh_token(self, data: dict) -> str: + pass + + @abstractmethod + async def decode_token(self, token: str) -> TokenData: + pass + + @abstractmethod + async def user(self, token: str) -> User: + pass + + @abstractmethod + def get_current_active_user(self, current_user: User) -> User: + pass + + @abstractmethod + async def register(self, email: str, password: str) -> User: + pass + + @abstractmethod + async def send_verification_email( + self, email: str, user: Optional[User] = None + ) -> tuple[str, datetime]: + pass + + @abstractmethod + async def verify_email( + self, email: str, verification_code: str + ) -> dict[str, str]: + pass + + @abstractmethod + async def login(self, email: str, password: str) -> dict[str, Token]: + pass + + @abstractmethod + async def refresh_access_token( + self, refresh_token: str + ) -> dict[str, Token]: + pass + + def auth_wrapper( + self, + public: bool = False, + ): + async def _auth_wrapper( + auth: Optional[HTTPAuthorizationCredentials] = Security( + self.security + ), + api_key: Optional[str] = Security(api_key_header), + ) -> User: + # If authentication is not required and no credentials are provided, return the default admin user + if ( + ((not self.config.require_authentication) or public) + and auth is None + and api_key is None + ): + return await self._get_default_admin_user() + if not auth and not api_key: + raise R2RException( + message="No credentials provided. Create an account at https://app.sciphi.ai and set your API key using `r2r configure key` OR change your base URL to a custom deployment.", + status_code=401, + ) + if auth and api_key: + raise R2RException( + message="Cannot have both Bearer token and API key", + status_code=400, + ) + # 1. Try JWT if `auth` is present (Bearer token) + if auth is not None: + credentials = auth.credentials + try: + token_data = await self.decode_token(credentials) + user = await self.database_provider.users_handler.get_user_by_email( + token_data.email + ) + if user is not None: + return user + except R2RException: + # JWT decoding failed for logical reasons (invalid token) + pass + except Exception as e: + # JWT decoding failed unexpectedly, log and continue + logger.debug(f"JWT verification failed: {e}") + + # 2. If JWT failed, try API key from Bearer token + # Expected format: key_id.raw_api_key + if "." in credentials: + key_id, raw_api_key = credentials.split(".", 1) + api_key_record = await self.database_provider.users_handler.get_api_key_record( + key_id + ) + if api_key_record is not None: + hashed_key = api_key_record["hashed_key"] + if self.crypto_provider.verify_api_key( + raw_api_key, hashed_key + ): + user = await self.database_provider.users_handler.get_user_by_id( + api_key_record["user_id"] + ) + if user is not None and user.is_active: + return user + + # 3. If no Bearer token worked, try the X-API-Key header + if api_key is not None and "." in api_key: + key_id, raw_api_key = api_key.split(".", 1) + api_key_record = await self.database_provider.users_handler.get_api_key_record( + key_id + ) + if api_key_record is not None: + hashed_key = api_key_record["hashed_key"] + if self.crypto_provider.verify_api_key( + raw_api_key, hashed_key + ): + user = await self.database_provider.users_handler.get_user_by_id( + api_key_record["user_id"] + ) + if user is not None and user.is_active: + return user + + # If we reach here, both JWT and API key auth failed + raise R2RException( + message="Invalid token or API key", + status_code=401, + ) + + return _auth_wrapper + + @abstractmethod + async def change_password( + self, user: User, current_password: str, new_password: str + ) -> dict[str, str]: + pass + + @abstractmethod + async def request_password_reset(self, email: str) -> dict[str, str]: + pass + + @abstractmethod + async def confirm_password_reset( + self, reset_token: str, new_password: str + ) -> dict[str, str]: + pass + + @abstractmethod + async def logout(self, token: str) -> dict[str, str]: + pass + + @abstractmethod + async def send_reset_email(self, email: str) -> dict[str, str]: + pass diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/base.py b/.venv/lib/python3.12/site-packages/core/base/providers/base.py new file mode 100644 index 00000000..3f72a5ea --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/base.py @@ -0,0 +1,135 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional, Type + +from pydantic import BaseModel + + +class InnerConfig(BaseModel, ABC): + """A base provider configuration class.""" + + extra_fields: dict[str, Any] = {} + + class Config: + populate_by_name = True + arbitrary_types_allowed = True + ignore_extra = True + + @classmethod + def create(cls: Type["InnerConfig"], **kwargs: Any) -> "InnerConfig": + base_args = cls.model_fields.keys() + filtered_kwargs = { + k: v if v != "None" else None + for k, v in kwargs.items() + if k in base_args + } + instance = cls(**filtered_kwargs) # type: ignore + for k, v in kwargs.items(): + if k not in base_args: + instance.extra_fields[k] = v + return instance + + +class AppConfig(InnerConfig): + project_name: Optional[str] = None + default_max_documents_per_user: Optional[int] = 100 + default_max_chunks_per_user: Optional[int] = 10_000 + default_max_collections_per_user: Optional[int] = 5 + default_max_upload_size: int = 2_000_000 # e.g. ~2 MB + quality_llm: Optional[str] = None + fast_llm: Optional[str] = None + vlm: Optional[str] = None + audio_lm: Optional[str] = None + reasoning_llm: Optional[str] = None + planning_llm: Optional[str] = None + + # File extension to max-size mapping + # These are examples; adjust sizes as needed. + max_upload_size_by_type: dict[str, int] = { + # Common text-based formats + "txt": 2_000_000, + "md": 2_000_000, + "tsv": 2_000_000, + "csv": 5_000_000, + "xml": 2_000_000, + "html": 5_000_000, + # Office docs + "doc": 10_000_000, + "docx": 10_000_000, + "ppt": 20_000_000, + "pptx": 20_000_000, + "xls": 10_000_000, + "xlsx": 10_000_000, + "odt": 5_000_000, + # PDFs can expand quite a bit when converted to text + "pdf": 30_000_000, + # E-mail + "eml": 5_000_000, + "msg": 5_000_000, + "p7s": 5_000_000, + # Images + "bmp": 5_000_000, + "heic": 5_000_000, + "jpeg": 5_000_000, + "jpg": 5_000_000, + "png": 5_000_000, + "tiff": 5_000_000, + # Others + "epub": 10_000_000, + "rtf": 5_000_000, + "rst": 5_000_000, + "org": 5_000_000, + } + + +class ProviderConfig(BaseModel, ABC): + """A base provider configuration class.""" + + app: AppConfig # Add an app_config field + extra_fields: dict[str, Any] = {} + provider: Optional[str] = None + + class Config: + populate_by_name = True + arbitrary_types_allowed = True + ignore_extra = True + + @abstractmethod + def validate_config(self) -> None: + pass + + @classmethod + def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig": + base_args = cls.model_fields.keys() + filtered_kwargs = { + k: v if v != "None" else None + for k, v in kwargs.items() + if k in base_args + } + instance = cls(**filtered_kwargs) # type: ignore + for k, v in kwargs.items(): + if k not in base_args: + instance.extra_fields[k] = v + return instance + + @property + @abstractmethod + def supported_providers(self) -> list[str]: + """Define a list of supported providers.""" + pass + + @classmethod + def from_dict( + cls: Type["ProviderConfig"], data: dict[str, Any] + ) -> "ProviderConfig": + """Create a new instance of the config from a dictionary.""" + return cls.create(**data) + + +class Provider(ABC): + """A base provider class to provide a common interface for all + providers.""" + + def __init__(self, config: ProviderConfig, *args, **kwargs): + if config: + config.validate_config() + self.config = config diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py b/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py new file mode 100644 index 00000000..bdf794b0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/crypto.py @@ -0,0 +1,120 @@ +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Optional, Tuple + +from .base import Provider, ProviderConfig + + +class CryptoConfig(ProviderConfig): + provider: Optional[str] = None + + @property + def supported_providers(self) -> list[str]: + return ["bcrypt", "nacl"] + + def validate_config(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Unsupported crypto provider: {self.provider}") + + +class CryptoProvider(Provider, ABC): + def __init__(self, config: CryptoConfig): + if not isinstance(config, CryptoConfig): + raise ValueError( + "CryptoProvider must be initialized with a CryptoConfig" + ) + super().__init__(config) + + @abstractmethod + def get_password_hash(self, password: str) -> str: + """Hash a plaintext password using a secure password hashing algorithm + (e.g., Argon2i).""" + pass + + @abstractmethod + def verify_password( + self, plain_password: str, hashed_password: str + ) -> bool: + """Verify that a plaintext password matches the given hashed + password.""" + pass + + @abstractmethod + def generate_verification_code(self, length: int = 32) -> str: + """Generate a random code for email verification or reset tokens.""" + pass + + @abstractmethod + def generate_signing_keypair(self) -> Tuple[str, str, str]: + """Generate a new Ed25519 signing keypair for request signing. + + Returns: + A tuple of (key_id, private_key, public_key). + - key_id: A unique identifier for this keypair. + - private_key: Base64 encoded Ed25519 private key. + - public_key: Base64 encoded Ed25519 public key. + """ + pass + + @abstractmethod + def sign_request(self, private_key: str, data: str) -> str: + """Sign request data with an Ed25519 private key, returning the + signature.""" + pass + + @abstractmethod + def verify_request_signature( + self, public_key: str, signature: str, data: str + ) -> bool: + """Verify a request signature using the corresponding Ed25519 public + key.""" + pass + + @abstractmethod + def generate_api_key(self) -> Tuple[str, str]: + """Generate a new API key for a user. + + Returns: + A tuple (key_id, raw_api_key): + - key_id: A unique identifier for the API key. + - raw_api_key: The plaintext API key to provide to the user. + """ + pass + + @abstractmethod + def hash_api_key(self, raw_api_key: str) -> str: + """Hash a raw API key for secure storage in the database. + + Use strong parameters suitable for long-term secrets. + """ + pass + + @abstractmethod + def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool: + """Verify that a provided API key matches the stored hashed version.""" + pass + + @abstractmethod + def generate_secure_token(self, data: dict, expiry: datetime) -> str: + """Generate a secure, signed token (e.g., JWT) embedding claims. + + Args: + data: The claims to include in the token. + expiry: A datetime at which the token expires. + + Returns: + A JWT string signed with a secret key. + """ + pass + + @abstractmethod + def verify_secure_token(self, token: str) -> Optional[dict]: + """Verify a secure token (e.g., JWT). + + Args: + token: The token string to verify. + + Returns: + The token payload if valid, otherwise None. + """ + pass diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/database.py b/.venv/lib/python3.12/site-packages/core/base/providers/database.py new file mode 100644 index 00000000..845a8109 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/database.py @@ -0,0 +1,197 @@ +"""Base classes for database providers.""" + +import logging +from abc import ABC, abstractmethod +from typing import Any, Optional, Sequence, cast +from uuid import UUID + +from pydantic import BaseModel + +from core.base.abstractions import ( + GraphCreationSettings, + GraphEnrichmentSettings, + GraphSearchSettings, +) + +from .base import Provider, ProviderConfig + +logger = logging.getLogger() + + +class DatabaseConnectionManager(ABC): + @abstractmethod + def execute_query( + self, + query: str, + params: Optional[dict[str, Any] | Sequence[Any]] = None, + isolation_level: Optional[str] = None, + ): + pass + + @abstractmethod + async def execute_many(self, query, params=None, batch_size=1000): + pass + + @abstractmethod + def fetch_query( + self, + query: str, + params: Optional[dict[str, Any] | Sequence[Any]] = None, + ): + pass + + @abstractmethod + def fetchrow_query( + self, + query: str, + params: Optional[dict[str, Any] | Sequence[Any]] = None, + ): + pass + + @abstractmethod + async def initialize(self, pool: Any): + pass + + +class Handler(ABC): + def __init__( + self, + project_name: str, + connection_manager: DatabaseConnectionManager, + ): + self.project_name = project_name + self.connection_manager = connection_manager + + def _get_table_name(self, base_name: str) -> str: + return f"{self.project_name}.{base_name}" + + @abstractmethod + def create_tables(self): + pass + + +class PostgresConfigurationSettings(BaseModel): + """Configuration settings with defaults defined by the PGVector docker + image. + + These settings are helpful in managing the connections to the database. To + tune these settings for a specific deployment, see + https://pgtune.leopard.in.ua/ + """ + + checkpoint_completion_target: Optional[float] = 0.9 + default_statistics_target: Optional[int] = 100 + effective_io_concurrency: Optional[int] = 1 + effective_cache_size: Optional[int] = 524288 + huge_pages: Optional[str] = "try" + maintenance_work_mem: Optional[int] = 65536 + max_connections: Optional[int] = 256 + max_parallel_workers_per_gather: Optional[int] = 2 + max_parallel_workers: Optional[int] = 8 + max_parallel_maintenance_workers: Optional[int] = 2 + max_wal_size: Optional[int] = 1024 + max_worker_processes: Optional[int] = 8 + min_wal_size: Optional[int] = 80 + shared_buffers: Optional[int] = 16384 + statement_cache_size: Optional[int] = 100 + random_page_cost: Optional[float] = 4 + wal_buffers: Optional[int] = 512 + work_mem: Optional[int] = 4096 + + +class LimitSettings(BaseModel): + global_per_min: Optional[int] = None + route_per_min: Optional[int] = None + monthly_limit: Optional[int] = None + + def merge_with_defaults( + self, defaults: "LimitSettings" + ) -> "LimitSettings": + return LimitSettings( + global_per_min=self.global_per_min or defaults.global_per_min, + route_per_min=self.route_per_min or defaults.route_per_min, + monthly_limit=self.monthly_limit or defaults.monthly_limit, + ) + + +class DatabaseConfig(ProviderConfig): + """A base database configuration class.""" + + provider: str = "postgres" + user: Optional[str] = None + password: Optional[str] = None + host: Optional[str] = None + port: Optional[int] = None + db_name: Optional[str] = None + project_name: Optional[str] = None + postgres_configuration_settings: Optional[ + PostgresConfigurationSettings + ] = None + default_collection_name: str = "Default" + default_collection_description: str = "Your default collection." + collection_summary_system_prompt: str = "system" + collection_summary_prompt: str = "collection_summary" + enable_fts: bool = False + + # Graph settings + batch_size: Optional[int] = 1 + graph_search_results_store_path: Optional[str] = None + graph_enrichment_settings: GraphEnrichmentSettings = ( + GraphEnrichmentSettings() + ) + graph_creation_settings: GraphCreationSettings = GraphCreationSettings() + graph_search_settings: GraphSearchSettings = GraphSearchSettings() + + # Rate limits + limits: LimitSettings = LimitSettings( + global_per_min=60, route_per_min=20, monthly_limit=10000 + ) + route_limits: dict[str, LimitSettings] = {} + user_limits: dict[UUID, LimitSettings] = {} + + def validate_config(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return ["postgres"] + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig": + instance = cls.create(**data) + + instance = cast(DatabaseConfig, instance) + + limits_data = data.get("limits", {}) + default_limits = LimitSettings( + global_per_min=limits_data.get("global_per_min", 60), + route_per_min=limits_data.get("route_per_min", 20), + monthly_limit=limits_data.get("monthly_limit", 10000), + ) + + instance.limits = default_limits + + route_limits_data = limits_data.get("routes", {}) + for route_str, route_cfg in route_limits_data.items(): + instance.route_limits[route_str] = LimitSettings(**route_cfg) + + return instance + + +class DatabaseProvider(Provider): + connection_manager: DatabaseConnectionManager + config: DatabaseConfig + project_name: str + + def __init__(self, config: DatabaseConfig): + logger.info(f"Initializing DatabaseProvider with config {config}.") + super().__init__(config) + + @abstractmethod + async def __aenter__(self): + pass + + @abstractmethod + async def __aexit__(self, exc_type, exc, tb): + pass diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/email.py b/.venv/lib/python3.12/site-packages/core/base/providers/email.py new file mode 100644 index 00000000..73f88162 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/email.py @@ -0,0 +1,96 @@ +import logging +import os +from abc import ABC, abstractmethod +from typing import Optional + +from .base import Provider, ProviderConfig + + +class EmailConfig(ProviderConfig): + smtp_server: Optional[str] = None + smtp_port: Optional[int] = None + smtp_username: Optional[str] = None + smtp_password: Optional[str] = None + from_email: Optional[str] = None + use_tls: Optional[bool] = True + sendgrid_api_key: Optional[str] = None + mailersend_api_key: Optional[str] = None + verify_email_template_id: Optional[str] = None + reset_password_template_id: Optional[str] = None + password_changed_template_id: Optional[str] = None + frontend_url: Optional[str] = None + sender_name: Optional[str] = None + + @property + def supported_providers(self) -> list[str]: + return [ + "smtp", + "console", + "sendgrid", + "mailersend", + ] # Could add more providers like AWS SES, SendGrid etc. + + def validate_config(self) -> None: + if ( + self.provider == "sendgrid" + and not self.sendgrid_api_key + and not os.getenv("SENDGRID_API_KEY") + ): + raise ValueError( + "SendGrid API key is required when using SendGrid provider" + ) + + if ( + self.provider == "mailersend" + and not self.mailersend_api_key + and not os.getenv("MAILERSEND_API_KEY") + ): + raise ValueError( + "MailerSend API key is required when using MailerSend provider" + ) + + +logger = logging.getLogger(__name__) + + +class EmailProvider(Provider, ABC): + def __init__(self, config: EmailConfig): + if not isinstance(config, EmailConfig): + raise ValueError( + "EmailProvider must be initialized with an EmailConfig" + ) + super().__init__(config) + self.config: EmailConfig = config + + @abstractmethod + async def send_email( + self, + to_email: str, + subject: str, + body: str, + html_body: Optional[str] = None, + *args, + **kwargs, + ) -> None: + pass + + @abstractmethod + async def send_verification_email( + self, to_email: str, verification_code: str, *args, **kwargs + ) -> None: + pass + + @abstractmethod + async def send_password_reset_email( + self, to_email: str, reset_token: str, *args, **kwargs + ) -> None: + pass + + @abstractmethod + async def send_password_changed_email( + self, + to_email: str, + *args, + **kwargs, + ) -> None: + pass diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py new file mode 100644 index 00000000..d1f9f9d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/embedding.py @@ -0,0 +1,197 @@ +import asyncio +import logging +import random +import time +from abc import abstractmethod +from enum import Enum +from typing import Any, Optional + +from litellm import AuthenticationError + +from core.base.abstractions import VectorQuantizationSettings + +from ..abstractions import ( + ChunkSearchResult, + EmbeddingPurpose, + default_embedding_prefixes, +) +from .base import Provider, ProviderConfig + +logger = logging.getLogger() + + +class EmbeddingConfig(ProviderConfig): + provider: str + base_model: str + base_dimension: int | float + rerank_model: Optional[str] = None + rerank_url: Optional[str] = None + batch_size: int = 1 + prefixes: Optional[dict[str, str]] = None + add_title_as_prefix: bool = True + concurrent_request_limit: int = 256 + max_retries: int = 3 + initial_backoff: float = 1 + max_backoff: float = 64.0 + quantization_settings: VectorQuantizationSettings = ( + VectorQuantizationSettings() + ) + + ## deprecated + rerank_dimension: Optional[int] = None + rerank_transformer_type: Optional[str] = None + + def validate_config(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return ["litellm", "openai", "ollama"] + + +class EmbeddingProvider(Provider): + class Step(Enum): + BASE = 1 + RERANK = 2 + + def __init__(self, config: EmbeddingConfig): + if not isinstance(config, EmbeddingConfig): + raise ValueError( + "EmbeddingProvider must be initialized with a `EmbeddingConfig`." + ) + logger.info(f"Initializing EmbeddingProvider with config {config}.") + + super().__init__(config) + self.config: EmbeddingConfig = config + self.semaphore = asyncio.Semaphore(config.concurrent_request_limit) + self.current_requests = 0 + + async def _execute_with_backoff_async(self, task: dict[str, Any]): + retries = 0 + backoff = self.config.initial_backoff + while retries < self.config.max_retries: + try: + async with self.semaphore: + return await self._execute_task(task) + except AuthenticationError: + raise + except Exception as e: + logger.warning( + f"Request failed (attempt {retries + 1}): {str(e)}" + ) + retries += 1 + if retries == self.config.max_retries: + raise + await asyncio.sleep(random.uniform(0, backoff)) + backoff = min(backoff * 2, self.config.max_backoff) + + def _execute_with_backoff_sync(self, task: dict[str, Any]): + retries = 0 + backoff = self.config.initial_backoff + while retries < self.config.max_retries: + try: + return self._execute_task_sync(task) + except AuthenticationError: + raise + except Exception as e: + logger.warning( + f"Request failed (attempt {retries + 1}): {str(e)}" + ) + retries += 1 + if retries == self.config.max_retries: + raise + time.sleep(random.uniform(0, backoff)) + backoff = min(backoff * 2, self.config.max_backoff) + + @abstractmethod + async def _execute_task(self, task: dict[str, Any]): + pass + + @abstractmethod + def _execute_task_sync(self, task: dict[str, Any]): + pass + + async def async_get_embedding( + self, + text: str, + stage: Step = Step.BASE, + purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX, + ): + task = { + "text": text, + "stage": stage, + "purpose": purpose, + } + return await self._execute_with_backoff_async(task) + + def get_embedding( + self, + text: str, + stage: Step = Step.BASE, + purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX, + ): + task = { + "text": text, + "stage": stage, + "purpose": purpose, + } + return self._execute_with_backoff_sync(task) + + async def async_get_embeddings( + self, + texts: list[str], + stage: Step = Step.BASE, + purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX, + ): + task = { + "texts": texts, + "stage": stage, + "purpose": purpose, + } + return await self._execute_with_backoff_async(task) + + def get_embeddings( + self, + texts: list[str], + stage: Step = Step.BASE, + purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX, + ) -> list[list[float]]: + task = { + "texts": texts, + "stage": stage, + "purpose": purpose, + } + return self._execute_with_backoff_sync(task) + + @abstractmethod + def rerank( + self, + query: str, + results: list[ChunkSearchResult], + stage: Step = Step.RERANK, + limit: int = 10, + ): + pass + + @abstractmethod + async def arerank( + self, + query: str, + results: list[ChunkSearchResult], + stage: Step = Step.RERANK, + limit: int = 10, + ): + pass + + def set_prefixes(self, config_prefixes: dict[str, str], base_model: str): + self.prefixes = {} + + for t, p in config_prefixes.items(): + purpose = EmbeddingPurpose(t.lower()) + self.prefixes[purpose] = p + + if base_model in default_embedding_prefixes: + for t, p in default_embedding_prefixes[base_model].items(): + if t not in self.prefixes: + self.prefixes[t] = p diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py b/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py new file mode 100644 index 00000000..70d0d3a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/ingestion.py @@ -0,0 +1,172 @@ +import logging +from abc import ABC +from enum import Enum +from typing import TYPE_CHECKING, Any, ClassVar, Optional + +from pydantic import Field + +from core.base.abstractions import ChunkEnrichmentSettings + +from .base import AppConfig, Provider, ProviderConfig +from .llm import CompletionProvider + +logger = logging.getLogger() + +if TYPE_CHECKING: + from core.providers.database import PostgresDatabaseProvider + + +class ChunkingStrategy(str, Enum): + RECURSIVE = "recursive" + CHARACTER = "character" + BASIC = "basic" + BY_TITLE = "by_title" + + +class IngestionMode(str, Enum): + hi_res = "hi-res" + fast = "fast" + custom = "custom" + + +class IngestionConfig(ProviderConfig): + _defaults: ClassVar[dict] = { + "app": AppConfig(), + "provider": "r2r", + "excluded_parsers": ["mp4"], + "chunking_strategy": "recursive", + "chunk_size": 1024, + "chunk_enrichment_settings": ChunkEnrichmentSettings(), + "extra_parsers": {}, + "audio_transcription_model": None, + "vision_img_prompt_name": "vision_img", + "vision_pdf_prompt_name": "vision_pdf", + "skip_document_summary": False, + "document_summary_system_prompt": "system", + "document_summary_task_prompt": "summary", + "document_summary_max_length": 100_000, + "chunks_for_document_summary": 128, + "document_summary_model": None, + "parser_overrides": {}, + "extra_fields": {}, + "automatic_extraction": False, + } + + provider: str = Field( + default_factory=lambda: IngestionConfig._defaults["provider"] + ) + excluded_parsers: list[str] = Field( + default_factory=lambda: IngestionConfig._defaults["excluded_parsers"] + ) + chunking_strategy: str | ChunkingStrategy = Field( + default_factory=lambda: IngestionConfig._defaults["chunking_strategy"] + ) + chunk_size: int = Field( + default_factory=lambda: IngestionConfig._defaults["chunk_size"] + ) + chunk_enrichment_settings: ChunkEnrichmentSettings = Field( + default_factory=lambda: IngestionConfig._defaults[ + "chunk_enrichment_settings" + ] + ) + extra_parsers: dict[str, Any] = Field( + default_factory=lambda: IngestionConfig._defaults["extra_parsers"] + ) + audio_transcription_model: Optional[str] = Field( + default_factory=lambda: IngestionConfig._defaults[ + "audio_transcription_model" + ] + ) + vision_img_prompt_name: str = Field( + default_factory=lambda: IngestionConfig._defaults[ + "vision_img_prompt_name" + ] + ) + vision_pdf_prompt_name: str = Field( + default_factory=lambda: IngestionConfig._defaults[ + "vision_pdf_prompt_name" + ] + ) + skip_document_summary: bool = Field( + default_factory=lambda: IngestionConfig._defaults[ + "skip_document_summary" + ] + ) + document_summary_system_prompt: str = Field( + default_factory=lambda: IngestionConfig._defaults[ + "document_summary_system_prompt" + ] + ) + document_summary_task_prompt: str = Field( + default_factory=lambda: IngestionConfig._defaults[ + "document_summary_task_prompt" + ] + ) + chunks_for_document_summary: int = Field( + default_factory=lambda: IngestionConfig._defaults[ + "chunks_for_document_summary" + ] + ) + document_summary_model: Optional[str] = Field( + default_factory=lambda: IngestionConfig._defaults[ + "document_summary_model" + ] + ) + parser_overrides: dict[str, str] = Field( + default_factory=lambda: IngestionConfig._defaults["parser_overrides"] + ) + automatic_extraction: bool = Field( + default_factory=lambda: IngestionConfig._defaults[ + "automatic_extraction" + ] + ) + document_summary_max_length: int = Field( + default_factory=lambda: IngestionConfig._defaults[ + "document_summary_max_length" + ] + ) + + @classmethod + def set_default(cls, **kwargs): + for key, value in kwargs.items(): + if key in cls._defaults: + cls._defaults[key] = value + else: + raise AttributeError( + f"No default attribute '{key}' in IngestionConfig" + ) + + @property + def supported_providers(self) -> list[str]: + return ["r2r", "unstructured_local", "unstructured_api"] + + def validate_config(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider {self.provider} is not supported.") + + @classmethod + def get_default(cls, mode: str, app) -> "IngestionConfig": + """Return default ingestion configuration for a given mode.""" + if mode == "hi-res": + return cls(app=app, parser_overrides={"pdf": "zerox"}) + if mode == "fast": + return cls(app=app, skip_document_summary=True) + else: + return cls(app=app) + + +class IngestionProvider(Provider, ABC): + config: IngestionConfig + database_provider: "PostgresDatabaseProvider" + llm_provider: CompletionProvider + + def __init__( + self, + config: IngestionConfig, + database_provider: "PostgresDatabaseProvider", + llm_provider: CompletionProvider, + ): + super().__init__(config) + self.config: IngestionConfig = config + self.llm_provider = llm_provider + self.database_provider: "PostgresDatabaseProvider" = database_provider diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/llm.py b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py new file mode 100644 index 00000000..669dfc4f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/llm.py @@ -0,0 +1,200 @@ +import asyncio +import logging +import random +import time +from abc import abstractmethod +from concurrent.futures import ThreadPoolExecutor +from typing import Any, AsyncGenerator, Generator, Optional + +from litellm import AuthenticationError + +from core.base.abstractions import ( + GenerationConfig, + LLMChatCompletion, + LLMChatCompletionChunk, +) + +from .base import Provider, ProviderConfig + +logger = logging.getLogger() + + +class CompletionConfig(ProviderConfig): + provider: Optional[str] = None + generation_config: Optional[GenerationConfig] = None + concurrent_request_limit: int = 256 + max_retries: int = 3 + initial_backoff: float = 1.0 + max_backoff: float = 64.0 + + def validate_config(self) -> None: + if not self.provider: + raise ValueError("Provider must be set.") + if self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return ["anthropic", "litellm", "openai", "r2r"] + + +class CompletionProvider(Provider): + def __init__(self, config: CompletionConfig) -> None: + if not isinstance(config, CompletionConfig): + raise ValueError( + "CompletionProvider must be initialized with a `CompletionConfig`." + ) + logger.info(f"Initializing CompletionProvider with config: {config}") + super().__init__(config) + self.config: CompletionConfig = config + self.semaphore = asyncio.Semaphore(config.concurrent_request_limit) + self.thread_pool = ThreadPoolExecutor( + max_workers=config.concurrent_request_limit + ) + + async def _execute_with_backoff_async(self, task: dict[str, Any]): + retries = 0 + backoff = self.config.initial_backoff + while retries < self.config.max_retries: + try: + async with self.semaphore: + return await self._execute_task(task) + except AuthenticationError: + raise + except Exception as e: + logger.warning( + f"Request failed (attempt {retries + 1}): {str(e)}" + ) + retries += 1 + if retries == self.config.max_retries: + raise + await asyncio.sleep(random.uniform(0, backoff)) + backoff = min(backoff * 2, self.config.max_backoff) + + async def _execute_with_backoff_async_stream( + self, task: dict[str, Any] + ) -> AsyncGenerator[Any, None]: + retries = 0 + backoff = self.config.initial_backoff + while retries < self.config.max_retries: + try: + async with self.semaphore: + async for chunk in await self._execute_task(task): + yield chunk + return # Successful completion of the stream + except AuthenticationError: + raise + except Exception as e: + logger.warning( + f"Streaming request failed (attempt {retries + 1}): {str(e)}" + ) + retries += 1 + if retries == self.config.max_retries: + raise + await asyncio.sleep(random.uniform(0, backoff)) + backoff = min(backoff * 2, self.config.max_backoff) + + def _execute_with_backoff_sync(self, task: dict[str, Any]): + retries = 0 + backoff = self.config.initial_backoff + while retries < self.config.max_retries: + try: + return self._execute_task_sync(task) + except Exception as e: + logger.warning( + f"Request failed (attempt {retries + 1}): {str(e)}" + ) + retries += 1 + if retries == self.config.max_retries: + raise + time.sleep(random.uniform(0, backoff)) + backoff = min(backoff * 2, self.config.max_backoff) + + def _execute_with_backoff_sync_stream( + self, task: dict[str, Any] + ) -> Generator[Any, None, None]: + retries = 0 + backoff = self.config.initial_backoff + while retries < self.config.max_retries: + try: + yield from self._execute_task_sync(task) + return # Successful completion of the stream + except Exception as e: + logger.warning( + f"Streaming request failed (attempt {retries + 1}): {str(e)}" + ) + retries += 1 + if retries == self.config.max_retries: + raise + time.sleep(random.uniform(0, backoff)) + backoff = min(backoff * 2, self.config.max_backoff) + + @abstractmethod + async def _execute_task(self, task: dict[str, Any]): + pass + + @abstractmethod + def _execute_task_sync(self, task: dict[str, Any]): + pass + + async def aget_completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> LLMChatCompletion: + task = { + "messages": messages, + "generation_config": generation_config, + "kwargs": kwargs, + } + response = await self._execute_with_backoff_async(task) + return LLMChatCompletion(**response.dict()) + + async def aget_completion_stream( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> AsyncGenerator[LLMChatCompletionChunk, None]: + generation_config.stream = True + task = { + "messages": messages, + "generation_config": generation_config, + "kwargs": kwargs, + } + async for chunk in self._execute_with_backoff_async_stream(task): + if isinstance(chunk, dict): + yield LLMChatCompletionChunk(**chunk) + continue + + chunk.choices[0].finish_reason = ( + chunk.choices[0].finish_reason + if chunk.choices[0].finish_reason != "" + else None + ) # handle error output conventions + chunk.choices[0].finish_reason = ( + chunk.choices[0].finish_reason + if chunk.choices[0].finish_reason != "eos" + else "stop" + ) # hardcode `eos` to `stop` for consistency + try: + yield LLMChatCompletionChunk(**(chunk.dict())) + except Exception as e: + logger.error(f"Error parsing chunk: {e}") + yield LLMChatCompletionChunk(**(chunk.as_dict())) + + def get_completion_stream( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> Generator[LLMChatCompletionChunk, None, None]: + generation_config.stream = True + task = { + "messages": messages, + "generation_config": generation_config, + "kwargs": kwargs, + } + for chunk in self._execute_with_backoff_sync_stream(task): + yield LLMChatCompletionChunk(**chunk.dict()) diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py b/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py new file mode 100644 index 00000000..c3105f30 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/orchestration.py @@ -0,0 +1,70 @@ +from abc import abstractmethod +from enum import Enum +from typing import Any + +from .base import Provider, ProviderConfig + + +class Workflow(Enum): + INGESTION = "ingestion" + GRAPH = "graph" + + +class OrchestrationConfig(ProviderConfig): + provider: str + max_runs: int = 2_048 + graph_search_results_creation_concurrency_limit: int = 32 + ingestion_concurrency_limit: int = 16 + graph_search_results_concurrency_limit: int = 8 + + def validate_config(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider {self.provider} is not supported.") + + @property + def supported_providers(self) -> list[str]: + return ["hatchet", "simple"] + + +class OrchestrationProvider(Provider): + def __init__(self, config: OrchestrationConfig): + super().__init__(config) + self.config = config + self.worker = None + + @abstractmethod + async def start_worker(self): + pass + + @abstractmethod + def get_worker(self, name: str, max_runs: int) -> Any: + pass + + @abstractmethod + def step(self, *args, **kwargs) -> Any: + pass + + @abstractmethod + def workflow(self, *args, **kwargs) -> Any: + pass + + @abstractmethod + def failure(self, *args, **kwargs) -> Any: + pass + + @abstractmethod + def register_workflows( + self, workflow: Workflow, service: Any, messages: dict + ) -> None: + pass + + @abstractmethod + async def run_workflow( + self, + workflow_name: str, + parameters: dict, + options: dict, + *args, + **kwargs, + ) -> dict[str, str]: + pass diff --git a/.venv/lib/python3.12/site-packages/core/base/utils/__init__.py b/.venv/lib/python3.12/site-packages/core/base/utils/__init__.py new file mode 100644 index 00000000..948a1069 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/utils/__init__.py @@ -0,0 +1,43 @@ +from shared.utils import ( + RecursiveCharacterTextSplitter, + TextSplitter, + _decorate_vector_type, + _get_vector_column_str, + decrement_version, + deep_update, + dump_collector, + dump_obj, + format_search_results_for_llm, + generate_default_prompt_id, + generate_default_user_collection_id, + generate_document_id, + generate_entity_document_id, + generate_extraction_id, + generate_id, + generate_user_id, + increment_version, + validate_uuid, + yield_sse_event, +) + +__all__ = [ + "format_search_results_for_llm", + "generate_id", + "generate_default_user_collection_id", + "increment_version", + "decrement_version", + "generate_document_id", + "generate_extraction_id", + "generate_user_id", + "generate_entity_document_id", + "generate_default_prompt_id", + "RecursiveCharacterTextSplitter", + "TextSplitter", + "validate_uuid", + "deep_update", + "_decorate_vector_type", + "_get_vector_column_str", + "yield_sse_event", + "dump_collector", + "dump_obj", +] |