Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers')
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/__init__.py  77
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/auth/__init__.py  11
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/auth/clerk.py  133
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/auth/jwt.py  166
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py  701
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/auth/supabase.py  249
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/crypto/__init__.py  9
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/crypto/bcrypt.py  195
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/crypto/nacl.py  181
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/base.py  247
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/chunks.py  1316
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/collections.py  701
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/conversations.py  858
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/documents.py  1172
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/files.py  334
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/filters.py  478
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/graphs.py  2884
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/limits.py  434
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/postgres.py  286
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/__init__.py  0
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/chunk_enrichment.yaml  56
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/collection_summary.yaml  41
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent.yaml  28
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent_xml_tooling.yaml  99
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_communities.yaml  74
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_entity_description.yaml  40
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_extraction.yaml  100
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/hyde.yaml  29
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/rag.yaml  29
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/rag_fusion.yaml  27
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/static_rag_agent.yaml  16
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/static_research_agent.yaml  61
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/summary.yaml  18
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/system.yaml  3
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_img.yaml  4
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_pdf.yaml  42
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/prompts_handler.py  748
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/tokens.py  67
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/users.py  1325
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/email/__init__.py  11
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/email/console_mock.py  67
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/email/mailersend.py  281
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/email/sendgrid.py  257
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/email/smtp.py  176
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/embeddings/__init__.py  9
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/embeddings/litellm.py  305
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py  194
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py  243
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/ingestion/__init__.py  13
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/ingestion/r2r/base.py  355
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/ingestion/unstructured/base.py  396
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/__init__.py  11
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/anthropic.py  925
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/azure_foundry.py  110
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/litellm.py  80
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/openai.py  522
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/r2r_llm.py  96
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/llm/utils.py  106
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py  4
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py  105
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py  61
62 files changed, 17571 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/__init__.py
new file mode 100644
index 00000000..7cfa82eb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/__init__.py
@@ -0,0 +1,77 @@
+from .auth import (
+ ClerkAuthProvider,
+ JwtAuthProvider,
+ R2RAuthProvider,
+ SupabaseAuthProvider,
+)
+from .crypto import (
+ BcryptCryptoConfig,
+ BCryptCryptoProvider,
+ NaClCryptoConfig,
+ NaClCryptoProvider,
+)
+from .database import PostgresDatabaseProvider
+from .email import (
+ AsyncSMTPEmailProvider,
+ ConsoleMockEmailProvider,
+ MailerSendEmailProvider,
+ SendGridEmailProvider,
+)
+from .embeddings import (
+ LiteLLMEmbeddingProvider,
+ OllamaEmbeddingProvider,
+ OpenAIEmbeddingProvider,
+)
+from .ingestion import ( # type: ignore
+ R2RIngestionConfig,
+ R2RIngestionProvider,
+ UnstructuredIngestionConfig,
+ UnstructuredIngestionProvider,
+)
+from .llm import (
+ AnthropicCompletionProvider,
+ LiteLLMCompletionProvider,
+ OpenAICompletionProvider,
+ R2RCompletionProvider,
+)
+from .orchestration import (
+ HatchetOrchestrationProvider,
+ SimpleOrchestrationProvider,
+)
+
+__all__ = [
+ # Auth
+ "R2RAuthProvider",
+ "SupabaseAuthProvider",
+ "JwtAuthProvider",
+ "ClerkAuthProvider",
+ # Ingestion
+ "R2RIngestionProvider",
+ "R2RIngestionConfig",
+ "UnstructuredIngestionProvider",
+ "UnstructuredIngestionConfig",
+ # Crypto
+ "BCryptCryptoProvider",
+ "BcryptCryptoConfig",
+ "NaClCryptoConfig",
+ "NaClCryptoProvider",
+ # Embeddings
+ "LiteLLMEmbeddingProvider",
+ "OllamaEmbeddingProvider",
+ "OpenAIEmbeddingProvider",
+ # Database
+ "PostgresDatabaseProvider",
+ # Email
+ "AsyncSMTPEmailProvider",
+ "ConsoleMockEmailProvider",
+ "SendGridEmailProvider",
+ "MailerSendEmailProvider",
+ # Orchestration
+ "HatchetOrchestrationProvider",
+ "SimpleOrchestrationProvider",
+ # LLM
+ "AnthropicCompletionProvider",
+ "OpenAICompletionProvider",
+ "R2RCompletionProvider",
+ "LiteLLMCompletionProvider",
+]
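
For orientation, a minimal import sketch (illustrative only, not part of the packaged source) showing how the re-exports above are typically consumed; it assumes the `core` package is importable and only prints the class names:

    # Illustrative only: exercise the package-level re-exports listed above.
    from core.providers import (
        NaClCryptoProvider,
        PostgresDatabaseProvider,
        R2RAuthProvider,
    )

    for provider_cls in (R2RAuthProvider, PostgresDatabaseProvider, NaClCryptoProvider):
        print(provider_cls.__name__)
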
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py
new file mode 100644
index 00000000..9f116ffa
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py
@@ -0,0 +1,11 @@
+from .clerk import ClerkAuthProvider
+from .jwt import JwtAuthProvider
+from .r2r_auth import R2RAuthProvider
+from .supabase import SupabaseAuthProvider
+
+__all__ = [
+ "R2RAuthProvider",
+ "SupabaseAuthProvider",
+ "JwtAuthProvider",
+ "ClerkAuthProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py b/.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py
new file mode 100644
index 00000000..0db665e0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py
@@ -0,0 +1,133 @@
+import logging
+import os
+from datetime import datetime
+
+from core.base import (
+ AuthConfig,
+ CryptoProvider,
+ EmailProvider,
+ R2RException,
+ TokenData,
+)
+
+from ..database import PostgresDatabaseProvider
+from .jwt import JwtAuthProvider
+
+logger = logging.getLogger(__name__)
+
+
+class ClerkAuthProvider(JwtAuthProvider):
+ """
+ ClerkAuthProvider extends JwtAuthProvider to support token verification with Clerk.
+ It uses Clerk's SDK to verify the JWT token and extract user information.
+ """
+
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: EmailProvider,
+ ):
+ super().__init__(
+ config=config,
+ crypto_provider=crypto_provider,
+ database_provider=database_provider,
+ email_provider=email_provider,
+ )
+ try:
+ from clerk_backend_api.jwks_helpers.verifytoken import (
+ VerifyTokenOptions,
+ verify_token,
+ )
+
+ self.verify_token = verify_token
+ self.VerifyTokenOptions = VerifyTokenOptions
+ except ImportError as e:
+ raise R2RException(
+ status_code=500,
+ message="Clerk SDK is not installed. Run `pip install clerk-backend-api`",
+ ) from e
+
+ async def decode_token(self, token: str) -> TokenData:
+ """
+ Decode and verify the JWT token using Clerk's verify_token function.
+
+ Args:
+ token: The JWT token to decode
+
+ Returns:
+ TokenData: The decoded token data with user information
+
+ Raises:
+ R2RException: If the token is invalid or verification fails
+ """
+ clerk_secret_key = os.getenv("CLERK_SECRET_KEY")
+ if not clerk_secret_key:
+ raise R2RException(
+ status_code=500,
+ message="CLERK_SECRET_KEY environment variable is not set",
+ )
+
+ try:
+ # Configure verification options
+ options = self.VerifyTokenOptions(
+ secret_key=clerk_secret_key,
+ # Optional: specify audience if needed
+ # audience="your-audience",
+ # Optional: specify authorized parties if needed
+ # authorized_parties=["https://your-domain.com"]
+ )
+
+ # Verify the token using Clerk's SDK
+ payload = self.verify_token(token, options)
+
+ # Check for the expected claims in the token payload
+ if not payload.get("sub") or not payload.get("email"):
+ raise R2RException(
+ status_code=401,
+ message="Invalid token: missing required claims",
+ )
+
+ # Create user in database if not exists
+ try:
+ await self.database_provider.users_handler.get_user_by_email(
+ payload.get("email")
+ )
+ # TODO do we want to update user info here based on what's in the token?
+ except Exception:
+ # user doesn't exist, create in db
+ logger.debug(f"Creating new user: {payload.get('email')}")
+ try:
+ # Construct name from first_name and last_name if available
+ first_name = payload.get("first_name", "")
+ last_name = payload.get("last_name", "")
+ name = payload.get("name")
+
+ # If name not directly provided, try to build it from first and last names
+ if not name and (first_name or last_name):
+ name = f"{first_name} {last_name}".strip()
+
+ await self.database_provider.users_handler.create_user(
+ email=payload.get("email"),
+ account_type="external",
+ name=name,
+ )
+ except Exception as e:
+ logger.error(f"Error creating user: {e}")
+ raise R2RException(
+ status_code=500, message="Failed to create user"
+ ) from e
+
+ # Return the token data
+ return TokenData(
+ email=payload.get("email"),
+ token_type="bearer",
+ exp=datetime.fromtimestamp(payload.get("exp")),
+ )
+
+ except Exception as e:
+ logger.info(f"Clerk token verification failed: {e}")
+ raise R2RException(
+ status_code=401, message="Invalid token", detail=str(e)
+ ) from e
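
A hedged usage sketch of the `decode_token` flow above; it assumes `CLERK_SECRET_KEY` is exported and that a `ClerkAuthProvider` instance named `provider` has already been constructed with its config, crypto, database, and email providers (construction omitted here):

    import asyncio

    async def whoami(provider, session_token: str) -> str:
        # Verifies the Clerk-issued JWT and returns the email claim,
        # creating the user row on first sight, as implemented above.
        token_data = await provider.decode_token(session_token)
        return token_data.email

    # asyncio.run(whoami(provider, "<clerk-session-jwt>"))
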
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py b/.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py
new file mode 100644
index 00000000..08f85e6d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py
@@ -0,0 +1,166 @@
+import logging
+import os
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+import jwt
+from fastapi import Depends
+
+from core.base import (
+ AuthConfig,
+ AuthProvider,
+ CryptoProvider,
+ EmailProvider,
+ R2RException,
+ Token,
+ TokenData,
+)
+from core.base.api.models import User
+
+from ..database import PostgresDatabaseProvider
+
+logger = logging.getLogger()
+
+
+class JwtAuthProvider(AuthProvider):
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: EmailProvider,
+ ):
+ super().__init__(
+ config, crypto_provider, database_provider, email_provider
+ )
+
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ raise NotImplementedError("Not implemented")
+
+ async def oauth_callback(self, code: str) -> dict[str, Token]:
+ raise NotImplementedError("Not implemented")
+
+ async def user(self, token: str) -> User:
+ raise NotImplementedError("Not implemented")
+
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ def create_access_token(self, data: dict) -> str:
+ raise NotImplementedError("Not implemented")
+
+ def create_refresh_token(self, data: dict) -> str:
+ raise NotImplementedError("Not implemented")
+
+ async def decode_token(self, token: str) -> TokenData:
+ # use JWT library to validate and decode JWT token
+ jwtSecret = os.getenv("JWT_SECRET")
+ if jwtSecret is None:
+ raise R2RException(
+ status_code=500,
+ message="JWT_SECRET environment variable is not set",
+ )
+ try:
+ user = jwt.decode(token, jwtSecret, algorithms=["HS256"])
+ except Exception as e:
+ logger.info(f"JWT verification failed: {e}")
+ raise R2RException(
+ status_code=401, message="Invalid JWT token", detail=e
+ ) from e
+ if user:
+ # Create user in database if not exists
+ try:
+ await self.database_provider.users_handler.get_user_by_email(
+ user.get("email")
+ )
+ # TODO do we want to update user info here based on what's in the token?
+ except Exception:
+ # user doesn't exist, create in db
+ logger.debug(f"Creating new user: {user.get('email')}")
+ try:
+ await self.database_provider.users_handler.create_user(
+ email=user.get("email"),
+ account_type="external",
+ name=user.get("name"),
+ )
+ except Exception as e:
+ logger.error(f"Error creating user: {e}")
+ raise R2RException(
+ status_code=500, message="Failed to create user"
+ ) from e
+ return TokenData(
+ email=user.get("email"),
+ token_type="bearer",
+ exp=user.get("exp"),
+ )
+ else:
+ raise R2RException(status_code=401, message="Invalid JWT token")
+
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ raise NotImplementedError("Not implemented")
+
+ def get_current_active_user(
+ self, current_user: User = Depends(user)
+ ) -> User:
+ # Check if user is active
+ if not current_user.is_active:
+ raise R2RException(status_code=400, message="Inactive user")
+ return current_user
+
+ async def logout(self, token: str) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def register(
+ self,
+ email: str,
+ password: str,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ ) -> User: # type: ignore
+ raise NotImplementedError("Not implemented")
+
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def send_reset_email(self, email: str) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def create_user_api_key(
+ self,
+ user_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def send_verification_email(
+ self, email: str, user: Optional[User] = None
+ ) -> tuple[str, datetime]:
+ raise NotImplementedError("Not implemented")
+
+ async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
+ raise NotImplementedError("Not implemented")
+
+ async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+ raise NotImplementedError("Not implemented")
+
+ async def oauth_callback_handler(
+ self, provider: str, oauth_id: str, email: str
+ ) -> dict[str, Token]:
+ raise NotImplementedError("Not implemented")
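
Since `decode_token` above validates an HS256 token signed with `JWT_SECRET` and reads the `email`, `name`, and `exp` claims, a compatible token can be minted with PyJWT as sketched below (the secret fallback and claim values are illustrative only):

    import datetime
    import os

    import jwt  # PyJWT, the same library imported above

    secret = os.environ.get("JWT_SECRET", "example-secret")  # illustrative fallback
    claims = {
        "email": "user@example.com",
        "name": "Example User",
        "exp": datetime.datetime.now(datetime.timezone.utc)
        + datetime.timedelta(hours=1),
    }
    token = jwt.encode(claims, secret, algorithm="HS256")
    print(token)
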
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py b/.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py
new file mode 100644
index 00000000..762884ce
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py
@@ -0,0 +1,701 @@
+import logging
+import os
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Depends, HTTPException
+from fastapi.security import OAuth2PasswordBearer
+
+from core.base import (
+ AuthConfig,
+ AuthProvider,
+ CollectionResponse,
+ CryptoProvider,
+ EmailProvider,
+ R2RException,
+ Token,
+ TokenData,
+)
+from core.base.api.models import User
+
+from ..database import PostgresDatabaseProvider
+
+DEFAULT_ACCESS_LIFETIME_IN_MINUTES = 3600
+DEFAULT_REFRESH_LIFETIME_IN_DAYS = 7
+
+logger = logging.getLogger()
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+def normalize_email(email: str) -> str:
+ """Normalizes an email address by converting it to lowercase. This ensures
+ consistent email handling throughout the application.
+
+ Args:
+ email: The email address to normalize
+
+ Returns:
+ The normalized (lowercase) email address
+ """
+ return email.lower() if email else ""
+
+
+class R2RAuthProvider(AuthProvider):
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: EmailProvider,
+ ):
+ super().__init__(
+ config, crypto_provider, database_provider, email_provider
+ )
+ self.database_provider: PostgresDatabaseProvider = database_provider
+ logger.debug(f"Initializing R2RAuthProvider with config: {config}")
+
+ # We no longer use a local secret_key or defaults here.
+ # All key handling is done in the crypto_provider.
+ self.access_token_lifetime_in_minutes = (
+ config.access_token_lifetime_in_minutes
+ or os.getenv("R2R_ACCESS_LIFE_IN_MINUTES")
+ or DEFAULT_ACCESS_LIFETIME_IN_MINUTES
+ )
+ self.refresh_token_lifetime_in_days = (
+ config.refresh_token_lifetime_in_days
+ or os.getenv("R2R_REFRESH_LIFE_IN_DAYS")
+ or DEFAULT_REFRESH_LIFETIME_IN_DAYS
+ )
+ self.config: AuthConfig = config
+
+ async def initialize(self):
+ try:
+ user = await self.register(
+ email=normalize_email(self.admin_email),
+ password=self.admin_password,
+ is_superuser=True,
+ )
+ await self.database_provider.users_handler.mark_user_as_superuser(
+ id=user.id
+ )
+ except R2RException:
+ logger.info("Default admin user already exists.")
+
+ def create_access_token(self, data: dict) -> str:
+ expire = datetime.now(timezone.utc) + timedelta(
+ minutes=float(self.access_token_lifetime_in_minutes)
+ )
+ # Add token_type and pass data/expiry to crypto_provider
+ data_with_type = {**data, "token_type": "access"}
+ return self.crypto_provider.generate_secure_token(
+ data=data_with_type,
+ expiry=expire,
+ )
+
+ def create_refresh_token(self, data: dict) -> str:
+ expire = datetime.now(timezone.utc) + timedelta(
+ days=float(self.refresh_token_lifetime_in_days)
+ )
+ data_with_type = {**data, "token_type": "refresh"}
+ return self.crypto_provider.generate_secure_token(
+ data=data_with_type,
+ expiry=expire,
+ )
+
+ async def decode_token(self, token: str) -> TokenData:
+ if "token=" in token:
+ token = token.split("token=")[1]
+ if "&tokenType=refresh" in token:
+ token = token.split("&tokenType=refresh")[0]
+ # First, check if the token is blacklisted
+ if await self.database_provider.token_handler.is_token_blacklisted(
+ token=token
+ ):
+ raise R2RException(
+ status_code=401, message="Token has been invalidated"
+ )
+
+ # Verify token using crypto_provider
+ payload = self.crypto_provider.verify_secure_token(token=token)
+ if payload is None:
+ raise R2RException(
+ status_code=401, message="Invalid or expired token"
+ )
+
+ email = payload.get("sub")
+ token_type = payload.get("token_type")
+ exp = payload.get("exp")
+
+ if email is None or token_type is None or exp is None:
+ raise R2RException(status_code=401, message="Invalid token claims")
+
+ email_str: str = email
+ token_type_str: str = token_type
+ exp_float: float = exp
+
+ exp_datetime = datetime.fromtimestamp(exp_float, tz=timezone.utc)
+ if exp_datetime < datetime.now(timezone.utc):
+ raise R2RException(status_code=401, message="Token has expired")
+
+ return TokenData(
+ email=normalize_email(email_str),
+ token_type=token_type_str,
+ exp=exp_datetime,
+ )
+
+ async def authenticate_api_key(self, api_key: str) -> User:
+ """Authenticate using an API key of the form "public_key.raw_key".
+
+ Returns a User if successful, or raises R2RException if not.
+ """
+ try:
+ key_id, raw_key = api_key.split(".", 1)
+ except ValueError as e:
+ raise R2RException(
+ status_code=401, message="Invalid API key format"
+ ) from e
+
+ key_record = (
+ await self.database_provider.users_handler.get_api_key_record(
+ key_id=key_id
+ )
+ )
+ if not key_record:
+ raise R2RException(status_code=401, message="Invalid API key")
+
+ if not self.crypto_provider.verify_api_key(
+ raw_api_key=raw_key, hashed_key=key_record["hashed_key"]
+ ):
+ raise R2RException(status_code=401, message="Invalid API key")
+
+ user = await self.database_provider.users_handler.get_user_by_id(
+ id=key_record["user_id"]
+ )
+ if not user.is_active:
+ raise R2RException(
+ status_code=401, message="User account is inactive"
+ )
+
+ return user
+
+ async def user(self, token: str = Depends(oauth2_scheme)) -> User:
+ """Attempt to authenticate via JWT first, then fallback to API key."""
+ # Try JWT auth
+ try:
+ token_data = await self.decode_token(token=token)
+ if not token_data.email:
+ raise R2RException(
+ status_code=401, message="Could not validate credentials"
+ )
+ user = (
+ await self.database_provider.users_handler.get_user_by_email(
+ email=normalize_email(token_data.email)
+ )
+ )
+ if user is None:
+ raise R2RException(
+ status_code=401,
+ message="Invalid authentication credentials",
+ )
+ return user
+ except R2RException:
+ # If JWT fails, try API key auth
+ # OAuth2PasswordBearer provides token as "Bearer xxx", strip it if needed
+ token = token.removeprefix("Bearer ")
+ return await self.authenticate_api_key(api_key=token)
+
+ def get_current_active_user(
+ self, current_user: User = Depends(user)
+ ) -> User:
+ if not current_user.is_active:
+ raise R2RException(status_code=400, message="Inactive user")
+ return current_user
+
+ async def register(
+ self,
+ email: str,
+ password: Optional[str] = None,
+ is_superuser: bool = False,
+ account_type: str = "password",
+ github_id: Optional[str] = None,
+ google_id: Optional[str] = None,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ ) -> User:
+ if account_type == "password":
+ if not password:
+ raise R2RException(
+ status_code=400,
+ message="Password is required for password accounts",
+ )
+ else:
+ if github_id and google_id:
+ raise R2RException(
+ status_code=400,
+ message="Cannot register OAuth with both GitHub and Google IDs",
+ )
+ if not github_id and not google_id:
+ raise R2RException(
+ status_code=400,
+ message="Invalid OAuth specification without GitHub or Google ID",
+ )
+ new_user = await self.database_provider.users_handler.create_user(
+ email=normalize_email(email),
+ password=password,
+ is_superuser=is_superuser,
+ account_type=account_type,
+ github_id=github_id,
+ google_id=google_id,
+ name=name,
+ bio=bio,
+ profile_picture=profile_picture,
+ )
+ default_collection: CollectionResponse = (
+ await self.database_provider.collections_handler.create_collection(
+ owner_id=new_user.id,
+ )
+ )
+ await self.database_provider.graphs_handler.create(
+ collection_id=default_collection.id,
+ name=default_collection.name,
+ description=default_collection.description,
+ )
+
+ await self.database_provider.users_handler.add_user_to_collection(
+ new_user.id, default_collection.id
+ )
+
+ new_user = await self.database_provider.users_handler.get_user_by_id(
+ new_user.id
+ )
+
+ if self.config.require_email_verification:
+ verification_code, _ = await self.send_verification_email(
+ email=normalize_email(email), user=new_user
+ )
+ else:
+ expiry = datetime.now(timezone.utc) + timedelta(hours=366 * 10)
+ await self.database_provider.users_handler.store_verification_code(
+ id=new_user.id,
+ verification_code=str(-1),
+ expiry=expiry,
+ )
+ await self.database_provider.users_handler.mark_user_as_verified(
+ id=new_user.id
+ )
+
+ return new_user
+
+ async def send_verification_email(
+ self, email: str, user: Optional[User] = None
+ ) -> tuple[str, datetime]:
+ if user is None:
+ user = (
+ await self.database_provider.users_handler.get_user_by_email(
+ email=normalize_email(email)
+ )
+ )
+ if not user:
+ raise R2RException(status_code=404, message="User not found")
+
+ verification_code = self.crypto_provider.generate_verification_code()
+ expiry = datetime.now(timezone.utc) + timedelta(hours=24)
+
+ await self.database_provider.users_handler.store_verification_code(
+ id=user.id,
+ verification_code=verification_code,
+ expiry=expiry,
+ )
+
+ if hasattr(user, "verification_code_expiry"):
+ user.verification_code_expiry = expiry
+
+ first_name = (
+ user.name.split(" ")[0] if user.name else email.split("@")[0]
+ )
+
+ await self.email_provider.send_verification_email(
+ to_email=user.email,
+ verification_code=verification_code,
+ dynamic_template_data={"first_name": first_name},
+ )
+
+ return verification_code, expiry
+
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ user_id = await self.database_provider.users_handler.get_user_id_by_verification_code(
+ verification_code=verification_code
+ )
+ await self.database_provider.users_handler.mark_user_as_verified(
+ id=user_id
+ )
+ await self.database_provider.users_handler.remove_verification_code(
+ verification_code=verification_code
+ )
+ return {"message": "Email verified successfully"}
+
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ logger.debug(f"Attempting login for email: {email}")
+ user = await self.database_provider.users_handler.get_user_by_email(
+ email=normalize_email(email)
+ )
+
+ if user.account_type != "password":
+ logger.warning(
+ f"Password login not allowed for {user.account_type} accounts: {email}"
+ )
+ raise R2RException(
+ status_code=401,
+ message=f"This account is configured for {user.account_type} login, not password.",
+ )
+
+ logger.debug(f"User found: {user}")
+
+ if not isinstance(user.hashed_password, str):
+ logger.error(
+ f"Invalid hashed_password type: {type(user.hashed_password)}"
+ )
+ raise HTTPException(
+ status_code=500,
+ detail="Invalid password hash in database",
+ )
+
+ try:
+ password_verified = self.crypto_provider.verify_password(
+ plain_password=password,
+ hashed_password=user.hashed_password,
+ )
+ except Exception as e:
+ logger.error(f"Error during password verification: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail="Error during password verification",
+ ) from e
+
+ if not password_verified:
+ logger.warning(f"Invalid password for user: {email}")
+ raise R2RException(
+ status_code=401, message="Incorrect email or password"
+ )
+
+ if not user.is_verified and self.config.require_email_verification:
+ logger.warning(f"Unverified user attempted login: {email}")
+ raise R2RException(status_code=401, message="Email not verified")
+
+ access_token = self.create_access_token(
+ data={"sub": normalize_email(user.email)}
+ )
+ refresh_token = self.create_refresh_token(
+ data={"sub": normalize_email(user.email)}
+ )
+ return {
+ "access_token": Token(token=access_token, token_type="access"),
+ "refresh_token": Token(token=refresh_token, token_type="refresh"),
+ }
+
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ token_data = await self.decode_token(refresh_token)
+ if token_data.token_type != "refresh":
+ raise R2RException(
+ status_code=401, message="Invalid refresh token"
+ )
+
+ # Invalidate the old refresh token and create a new one
+ await self.database_provider.token_handler.blacklist_token(
+ token=refresh_token
+ )
+
+ new_access_token = self.create_access_token(
+ data={"sub": normalize_email(token_data.email)}
+ )
+ new_refresh_token = self.create_refresh_token(
+ data={"sub": normalize_email(token_data.email)}
+ )
+ return {
+ "access_token": Token(token=new_access_token, token_type="access"),
+ "refresh_token": Token(
+ token=new_refresh_token, token_type="refresh"
+ ),
+ }
+
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ if not isinstance(user.hashed_password, str):
+ logger.error(
+ f"Invalid hashed_password type: {type(user.hashed_password)}"
+ )
+ raise HTTPException(
+ status_code=500,
+ detail="Invalid password hash in database",
+ )
+
+ if not self.crypto_provider.verify_password(
+ plain_password=current_password,
+ hashed_password=user.hashed_password,
+ ):
+ raise R2RException(
+ status_code=400, message="Incorrect current password"
+ )
+
+ hashed_new_password = self.crypto_provider.get_password_hash(
+ password=new_password
+ )
+ await self.database_provider.users_handler.update_user_password(
+ id=user.id,
+ new_hashed_password=hashed_new_password,
+ )
+ try:
+ await self.email_provider.send_password_changed_email(
+ to_email=normalize_email(user.email),
+ dynamic_template_data={
+ "first_name": (
+ user.name.split(" ")[0] or "User"
+ if user.name
+ else "User"
+ )
+ },
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to send password change notification: {str(e)}"
+ )
+
+ return {"message": "Password changed successfully"}
+
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ try:
+ user = (
+ await self.database_provider.users_handler.get_user_by_email(
+ email=normalize_email(email)
+ )
+ )
+
+ reset_token = self.crypto_provider.generate_verification_code()
+ expiry = datetime.now(timezone.utc) + timedelta(hours=1)
+ await self.database_provider.users_handler.store_reset_token(
+ id=user.id,
+ reset_token=reset_token,
+ expiry=expiry,
+ )
+
+ first_name = (
+ user.name.split(" ")[0] if user.name else email.split("@")[0]
+ )
+ await self.email_provider.send_password_reset_email(
+ to_email=normalize_email(email),
+ reset_token=reset_token,
+ dynamic_template_data={"first_name": first_name},
+ )
+
+ return {
+ "message": "If the email exists, a reset link has been sent"
+ }
+ except R2RException as e:
+ if e.status_code == 404:
+ # User doesn't exist; return a success message anyway
+ return {
+ "message": "If the email exists, a reset link has been sent"
+ }
+ else:
+ raise
+
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ user_id = await self.database_provider.users_handler.get_user_id_by_reset_token(
+ reset_token=reset_token
+ )
+ if not user_id:
+ raise R2RException(
+ status_code=400, message="Invalid or expired reset token"
+ )
+
+ hashed_new_password = self.crypto_provider.get_password_hash(
+ password=new_password
+ )
+ await self.database_provider.users_handler.update_user_password(
+ id=user_id,
+ new_hashed_password=hashed_new_password,
+ )
+ await self.database_provider.users_handler.remove_reset_token(
+ id=user_id
+ )
+ # Get the user information
+ user = await self.database_provider.users_handler.get_user_by_id(
+ id=user_id
+ )
+
+ try:
+ await self.email_provider.send_password_changed_email(
+ to_email=normalize_email(user.email),
+ dynamic_template_data={
+ "first_name": (
+ user.name.split(" ")[0] or "User"
+ if user.name
+ else "User"
+ )
+ },
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to send password change notification: {str(e)}"
+ )
+
+ return {"message": "Password reset successfully"}
+
+ async def logout(self, token: str) -> dict[str, str]:
+ await self.database_provider.token_handler.blacklist_token(token=token)
+ return {"message": "Logged out successfully"}
+
+ async def clean_expired_blacklisted_tokens(self):
+ await self.database_provider.token_handler.clean_expired_blacklisted_tokens()
+
+ async def send_reset_email(self, email: str) -> dict:
+ verification_code, expiry = await self.send_verification_email(
+ email=normalize_email(email)
+ )
+
+ return {
+ "verification_code": verification_code,
+ "expiry": expiry,
+ "message": f"Verification email sent successfully to {email}",
+ }
+
+ async def create_user_api_key(
+ self,
+ user_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> dict[str, str]:
+ key_id, raw_api_key = self.crypto_provider.generate_api_key()
+ hashed_key = self.crypto_provider.hash_api_key(raw_api_key)
+
+ api_key_uuid = (
+ await self.database_provider.users_handler.store_user_api_key(
+ user_id=user_id,
+ key_id=key_id,
+ hashed_key=hashed_key,
+ name=name,
+ description=description,
+ )
+ )
+
+ return {
+ "api_key": f"{key_id}.{raw_api_key}",
+ "key_id": str(api_key_uuid),
+ "public_key": key_id,
+ "name": name or "",
+ }
+
+ async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
+ return await self.database_provider.users_handler.get_user_api_keys(
+ user_id=user_id
+ )
+
+ async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+ return await self.database_provider.users_handler.delete_api_key(
+ user_id=user_id,
+ key_id=key_id,
+ )
+
+ async def rename_api_key(
+ self, user_id: UUID, key_id: UUID, new_name: str
+ ) -> bool:
+ return await self.database_provider.users_handler.update_api_key_name(
+ user_id=user_id,
+ key_id=key_id,
+ name=new_name,
+ )
+
+ async def oauth_callback_handler(
+ self, provider: str, oauth_id: str, email: str
+ ) -> dict[str, Token]:
+ """Handles a login/registration flow for OAuth providers (e.g., Google
+ or GitHub).
+
+ :param provider: "google" or "github"
+ :param oauth_id: The unique ID from the OAuth provider (e.g. Google's
+ 'sub')
+ :param email: The user's email from the provider, if available.
+ :return: dict with access_token and refresh_token
+ """
+ # 1) Attempt to find user by google_id or github_id, or by email
+ # The logic depends on your preference. We'll assume "google" => google_id, etc.
+ try:
+ if provider == "google":
+ try:
+ user = await self.database_provider.users_handler.get_user_by_email(
+ normalize_email(email)
+ )
+ # If the user exists but is not yet linked to a Google account, reject the login
+ if user and not user.google_id:
+ raise R2RException(
+ status_code=401,
+ message="User already exists and is not linked to Google account",
+ )
+ except Exception:
+ # Create new user
+ user = await self.register(
+ email=normalize_email(email)
+ or f"{oauth_id}@google_oauth.fake", # fallback
+ password=None, # no password
+ account_type="oauth",
+ google_id=oauth_id,
+ )
+ elif provider == "github":
+ try:
+ user = await self.database_provider.users_handler.get_user_by_email(
+ normalize_email(email)
+ )
+ # If the user exists but is not yet linked to a GitHub account, reject the login
+ if user and not user.github_id:
+ raise R2RException(
+ status_code=401,
+ message="User already exists and is not linked to Github account",
+ )
+ except Exception:
+ # Create new user
+ user = await self.register(
+ email=normalize_email(email)
+ or f"{oauth_id}@github_oauth.fake", # fallback
+ password=None, # no password
+ account_type="oauth",
+ github_id=oauth_id,
+ )
+ # else handle other providers
+
+ except R2RException:
+ # If no user found or creation fails
+ raise R2RException(
+ status_code=401, message="Could not create or fetch user"
+ ) from None
+
+ # If user is inactive, etc.
+ if not user.is_active:
+ raise R2RException(
+ status_code=401, message="User account is inactive"
+ )
+
+ # Possibly mark user as verified if you trust the OAuth provider's email
+ user.is_verified = True
+ await self.database_provider.users_handler.update_user(user)
+
+ # 2) Generate tokens
+ access_token = self.create_access_token(
+ data={"sub": normalize_email(user.email)}
+ )
+ refresh_token = self.create_refresh_token(
+ data={"sub": normalize_email(user.email)}
+ )
+
+ return {
+ "access_token": Token(token=access_token, token_type="access"),
+ "refresh_token": Token(token=refresh_token, token_type="refresh"),
+ }
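
A hedged sketch of the login and refresh-rotation flow implemented above; `auth` is assumed to be an already initialized `R2RAuthProvider`, and the credentials are illustrative:

    import asyncio

    async def demo_login(auth) -> None:
        tokens = await auth.login(email="admin@example.com", password="change-me")
        access = tokens["access_token"].token
        refresh = tokens["refresh_token"].token

        # refresh_access_token blacklists the old refresh token and issues a new pair.
        rotated = await auth.refresh_access_token(refresh_token=refresh)
        print(access != rotated["access_token"].token)

    # asyncio.run(demo_login(auth))
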
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py b/.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py
new file mode 100644
index 00000000..5fc0e0bf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py
@@ -0,0 +1,249 @@
+import logging
+import os
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Depends, HTTPException
+from fastapi.security import OAuth2PasswordBearer
+from supabase import Client, create_client
+
+from core.base import (
+ AuthConfig,
+ AuthProvider,
+ CryptoProvider,
+ EmailProvider,
+ R2RException,
+ Token,
+ TokenData,
+)
+from core.base.api.models import User
+
+from ..database import PostgresDatabaseProvider
+
+logger = logging.getLogger()
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+class SupabaseAuthProvider(AuthProvider):
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: EmailProvider,
+ ):
+ super().__init__(
+ config, crypto_provider, database_provider, email_provider
+ )
+ self.supabase_url = config.extra_fields.get(
+ "supabase_url", None
+ ) or os.getenv("SUPABASE_URL")
+ self.supabase_key = config.extra_fields.get(
+ "supabase_key", None
+ ) or os.getenv("SUPABASE_KEY")
+ if not self.supabase_url or not self.supabase_key:
+ raise HTTPException(
+ status_code=500,
+ detail="Supabase URL and key must be provided",
+ )
+ self.supabase: Client = create_client(
+ self.supabase_url, self.supabase_key
+ )
+
+ async def initialize(self):
+ # No initialization needed for Supabase
+ pass
+
+ def create_access_token(self, data: dict) -> str:
+ raise NotImplementedError(
+ "create_access_token is not used with Supabase authentication"
+ )
+
+ def create_refresh_token(self, data: dict) -> str:
+ raise NotImplementedError(
+ "create_refresh_token is not used with Supabase authentication"
+ )
+
+ async def decode_token(self, token: str) -> TokenData:
+ raise NotImplementedError(
+ "decode_token is not used with Supabase authentication"
+ )
+
+ async def register(
+ self,
+ email: str,
+ password: str,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ ) -> User: # type: ignore
+ # Use Supabase client to create a new user
+
+ if self.supabase.auth.sign_up(email=email, password=password):
+ raise R2RException(
+ status_code=400,
+ message="Supabase provider implementation is still under construction",
+ )
+ else:
+ raise R2RException(
+ status_code=400, message="User registration failed"
+ )
+
+ async def send_verification_email(
+ self, email: str, user: Optional[User] = None
+ ) -> tuple[str, datetime]:
+ raise NotImplementedError(
+ "send_verification_email is not used with Supabase"
+ )
+
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ # Use Supabase client to verify email
+ if self.supabase.auth.verify_email(email, verification_code):
+ return {"message": "Email verified successfully"}
+ else:
+ raise R2RException(
+ status_code=400, message="Invalid or expired verification code"
+ )
+
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ # Use Supabase client to authenticate user and get tokens
+ if response := self.supabase.auth.sign_in(
+ email=email, password=password
+ ):
+ access_token = response.access_token
+ refresh_token = response.refresh_token
+ return {
+ "access_token": Token(token=access_token, token_type="access"),
+ "refresh_token": Token(
+ token=refresh_token, token_type="refresh"
+ ),
+ }
+ else:
+ raise R2RException(
+ status_code=401, message="Invalid email or password"
+ )
+
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ # Use Supabase client to refresh access token
+ if response := self.supabase.auth.refresh_access_token(refresh_token):
+ new_access_token = response.access_token
+ new_refresh_token = response.refresh_token
+ return {
+ "access_token": Token(
+ token=new_access_token, token_type="access"
+ ),
+ "refresh_token": Token(
+ token=new_refresh_token, token_type="refresh"
+ ),
+ }
+ else:
+ raise R2RException(
+ status_code=401, message="Invalid refresh token"
+ )
+
+ async def user(self, token: str = Depends(oauth2_scheme)) -> User:
+ # Use Supabase client to get user details from token
+ if user := self.supabase.auth.get_user(token).user:
+ return User(
+ id=user.id,
+ email=user.email,
+ is_active=True, # Assuming active if exists in Supabase
+ is_superuser=False, # Default to False unless explicitly set
+ created_at=user.created_at,
+ updated_at=user.updated_at,
+ is_verified=user.email_confirmed_at is not None,
+ name=user.user_metadata.get("full_name"),
+ # Set other optional fields if available in user metadata
+ )
+
+ else:
+ raise R2RException(status_code=401, message="Invalid token")
+
+ def get_current_active_user(
+ self, current_user: User = Depends(user)
+ ) -> User:
+ # Check if user is active
+ if not current_user.is_active:
+ raise R2RException(status_code=400, message="Inactive user")
+ return current_user
+
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ # Use Supabase client to update user password
+ if self.supabase.auth.update(user.id, {"password": new_password}):
+ return {"message": "Password changed successfully"}
+ else:
+ raise R2RException(
+ status_code=400, message="Failed to change password"
+ )
+
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ # Use Supabase client to send password reset email
+ if self.supabase.auth.send_password_reset_email(email):
+ return {
+ "message": "If the email exists, a reset link has been sent"
+ }
+ else:
+ raise R2RException(
+ status_code=400, message="Failed to send password reset email"
+ )
+
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ # Use Supabase client to reset password with token
+ if self.supabase.auth.reset_password_for_email(
+ reset_token, new_password
+ ):
+ return {"message": "Password reset successfully"}
+ else:
+ raise R2RException(
+ status_code=400, message="Invalid or expired reset token"
+ )
+
+ async def logout(self, token: str) -> dict[str, str]:
+ # Use Supabase client to logout user and revoke token
+ self.supabase.auth.sign_out(token)
+ return {"message": "Logged out successfully"}
+
+ async def clean_expired_blacklisted_tokens(self):
+ # Not applicable for Supabase, tokens are managed by Supabase
+ pass
+
+ async def send_reset_email(self, email: str) -> dict[str, str]:
+ raise NotImplementedError("send_reset_email is not used with Supabase")
+
+ async def create_user_api_key(
+ self,
+ user_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> dict[str, str]:
+ raise NotImplementedError(
+ "API key management is not supported with Supabase authentication"
+ )
+
+ async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
+ raise NotImplementedError(
+ "API key management is not supported with Supabase authentication"
+ )
+
+ async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+ raise NotImplementedError(
+ "API key management is not supported with Supabase authentication"
+ )
+
+ async def oauth_callback_handler(
+ self, provider: str, oauth_id: str, email: str
+ ) -> dict[str, Token]:
+ raise NotImplementedError(
+ "API key management is not supported with Supabase authentication"
+ )
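
The provider above wraps the `supabase` client; below is a minimal construction sketch (the URL and key values are illustrative placeholders):

    import os

    from supabase import create_client

    url = os.getenv("SUPABASE_URL", "https://example.supabase.co")
    key = os.getenv("SUPABASE_KEY", "example-service-key")
    client = create_client(url, key)
    print(type(client).__name__)
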
diff --git a/.venv/lib/python3.12/site-packages/core/providers/crypto/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/crypto/__init__.py
new file mode 100644
index 00000000..e509f990
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/crypto/__init__.py
@@ -0,0 +1,9 @@
+from .bcrypt import BcryptCryptoConfig, BCryptCryptoProvider
+from .nacl import NaClCryptoConfig, NaClCryptoProvider
+
+__all__ = [
+ "BCryptCryptoProvider",
+ "BcryptCryptoConfig",
+ "NaClCryptoConfig",
+ "NaClCryptoProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/crypto/bcrypt.py b/.venv/lib/python3.12/site-packages/core/providers/crypto/bcrypt.py
new file mode 100644
index 00000000..9c39977c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/crypto/bcrypt.py
@@ -0,0 +1,195 @@
+import base64
+import logging
+import os
+from abc import ABC
+from datetime import datetime, timezone
+from typing import Optional, Tuple
+
+import bcrypt
+import jwt
+import nacl.encoding
+import nacl.exceptions
+import nacl.signing
+import nacl.utils
+
+from core.base import CryptoConfig, CryptoProvider
+
+DEFAULT_BCRYPT_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM" # Replace or load from env or secrets manager
+
+
+class BcryptCryptoConfig(CryptoConfig):
+ provider: str = "bcrypt"
+ # Number of rounds for bcrypt (increasing this makes hashing slower but more secure)
+ bcrypt_rounds: int = 12
+ secret_key: Optional[str] = None
+ api_key_bytes: int = 32 # Length of raw API keys
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["bcrypt"]
+
+ def validate_config(self) -> None:
+ super().validate_config()
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Unsupported crypto provider: {self.provider}")
+ if self.bcrypt_rounds < 4 or self.bcrypt_rounds > 31:
+ raise ValueError("bcrypt_rounds must be between 4 and 31")
+
+ def verify_password(
+ self, plain_password: str, hashed_password: str
+ ) -> bool:
+ try:
+ # First try to decode as base64 (new format)
+ stored_hash = base64.b64decode(hashed_password.encode("utf-8"))
+ except Exception:
+ # If that fails, treat as raw bcrypt hash (old format)
+ stored_hash = hashed_password.encode("utf-8")
+
+ return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash)
+
+
+class BCryptCryptoProvider(CryptoProvider, ABC):
+ def __init__(self, config: BcryptCryptoConfig):
+ if not isinstance(config, BcryptCryptoConfig):
+ raise ValueError(
+ "BcryptCryptoProvider must be initialized with a BcryptCryptoConfig"
+ )
+ logging.info("Initializing BcryptCryptoProvider")
+ super().__init__(config)
+ self.config: BcryptCryptoConfig = config
+
+ # Load the secret key for JWT signing.
+ # Priority: config.secret_key > R2R_SECRET_KEY env var > packaged default.
+ self.secret_key = (
+ config.secret_key
+ or os.getenv("R2R_SECRET_KEY")
+ or DEFAULT_BCRYPT_SECRET_KEY
+ )
+ if not self.secret_key:
+ raise ValueError(
+ "No secret key provided for BcryptCryptoProvider."
+ )
+
+ def get_password_hash(self, password: str) -> str:
+ # Bcrypt expects bytes
+ password_bytes = password.encode("utf-8")
+ hashed = bcrypt.hashpw(
+ password_bytes, bcrypt.gensalt(rounds=self.config.bcrypt_rounds)
+ )
+ return base64.b64encode(hashed).decode("utf-8")
+
+ def verify_password(
+ self, plain_password: str, hashed_password: str
+ ) -> bool:
+ try:
+ # First try to decode as base64 (new format)
+ stored_hash = base64.b64decode(hashed_password.encode("utf-8"))
+ if not stored_hash.startswith(b"$2b$"): # Valid bcrypt hash prefix
+ stored_hash = hashed_password.encode("utf-8")
+ except Exception:
+ # Otherwise raw bcrypt hash (old format)
+ stored_hash = hashed_password.encode("utf-8")
+
+ try:
+ return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash)
+ except ValueError as e:
+ if "Invalid salt" in str(e):
+ # If it's an invalid salt, the hash format is wrong - try the other format
+ try:
+ stored_hash = (
+ hashed_password
+ if isinstance(hashed_password, bytes)
+ else hashed_password.encode("utf-8")
+ )
+ return bcrypt.checkpw(
+ plain_password.encode("utf-8"), stored_hash
+ )
+ except ValueError:
+ return False
+ raise
+
+ def generate_verification_code(self, length: int = 32) -> str:
+ random_bytes = nacl.utils.random(length)
+ return base64.urlsafe_b64encode(random_bytes)[:length].decode("utf-8")
+
+ def generate_signing_keypair(self) -> Tuple[str, str, str]:
+ signing_key = nacl.signing.SigningKey.generate()
+ verify_key = signing_key.verify_key
+
+ # Generate unique key_id
+ key_entropy = nacl.utils.random(16)
+ key_id = f"sk_{base64.urlsafe_b64encode(key_entropy).decode()}"
+
+ private_key = base64.b64encode(bytes(signing_key)).decode()
+ public_key = base64.b64encode(bytes(verify_key)).decode()
+ return key_id, private_key, public_key
+
+ def sign_request(self, private_key: str, data: str) -> str:
+ try:
+ key_bytes = base64.b64decode(private_key)
+ signing_key = nacl.signing.SigningKey(key_bytes)
+ signature = signing_key.sign(data.encode())
+ return base64.b64encode(signature.signature).decode()
+ except Exception as e:
+ raise ValueError(
+ f"Invalid private key or signing error: {str(e)}"
+ ) from e
+
+ def verify_request_signature(
+ self, public_key: str, signature: str, data: str
+ ) -> bool:
+ try:
+ key_bytes = base64.b64decode(public_key)
+ verify_key = nacl.signing.VerifyKey(key_bytes)
+ signature_bytes = base64.b64decode(signature)
+ verify_key.verify(data.encode(), signature_bytes)
+ return True
+ except (nacl.exceptions.BadSignatureError, ValueError):
+ return False
+
+ def generate_api_key(self) -> Tuple[str, str]:
+ # Similar approach as with NaCl provider:
+ key_id_bytes = nacl.utils.random(16)
+ key_id = f"key_{base64.urlsafe_b64encode(key_id_bytes).decode()}"
+
+ # Generate raw API key
+ raw_api_key = base64.urlsafe_b64encode(
+ nacl.utils.random(self.config.api_key_bytes)
+ ).decode()
+ return key_id, raw_api_key
+
+ def hash_api_key(self, raw_api_key: str) -> str:
+ # Hash with bcrypt
+ hashed = bcrypt.hashpw(
+ raw_api_key.encode("utf-8"),
+ bcrypt.gensalt(rounds=self.config.bcrypt_rounds),
+ )
+ return base64.b64encode(hashed).decode("utf-8")
+
+ def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
+ stored_hash = base64.b64decode(hashed_key.encode("utf-8"))
+ return bcrypt.checkpw(raw_api_key.encode("utf-8"), stored_hash)
+
+ def generate_secure_token(self, data: dict, expiry: datetime) -> str:
+ now = datetime.now(timezone.utc)
+ to_encode = {
+ **data,
+ "exp": expiry.timestamp(),
+ "iat": now.timestamp(),
+ "nbf": now.timestamp(),
+ "jti": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
+ "nonce": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
+ }
+ return jwt.encode(to_encode, self.secret_key, algorithm="HS256")
+
+ def verify_secure_token(self, token: str) -> Optional[dict]:
+ try:
+ payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
+ exp = payload.get("exp")
+ if exp is None or datetime.fromtimestamp(
+ exp, tz=timezone.utc
+ ) < datetime.now(timezone.utc):
+ return None
+ return payload
+ except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
+ return None
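
The hashing scheme above stores bcrypt output wrapped in base64; a standalone sketch of the same round trip (the round count and password are illustrative):

    import base64

    import bcrypt

    # Hash, then wrap in base64 as get_password_hash does above.
    hashed = bcrypt.hashpw(b"hunter2", bcrypt.gensalt(rounds=12))
    stored = base64.b64encode(hashed).decode("utf-8")

    # Verify by unwrapping the base64 layer first, as verify_password does.
    print(bcrypt.checkpw(b"hunter2", base64.b64decode(stored)))
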
diff --git a/.venv/lib/python3.12/site-packages/core/providers/crypto/nacl.py b/.venv/lib/python3.12/site-packages/core/providers/crypto/nacl.py
new file mode 100644
index 00000000..63232565
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/crypto/nacl.py
@@ -0,0 +1,181 @@
+import base64
+import logging
+import os
+import string
+from datetime import datetime, timezone
+from typing import Optional, Tuple
+
+import jwt
+import nacl.encoding
+import nacl.exceptions
+import nacl.pwhash
+import nacl.signing
+from nacl.exceptions import BadSignatureError
+from nacl.pwhash import argon2i
+
+from core.base import CryptoConfig, CryptoProvider
+
+DEFAULT_NACL_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM" # Replace or load from env or secrets manager
+
+
+def encode_bytes_readable(random_bytes: bytes, chars: str) -> str:
+ """Convert random bytes to a readable string using the given character
+ set."""
+ # Each byte gives us 8 bits of randomness
+ # We use modulo to map each byte to our character set
+ result = []
+ for byte in random_bytes:
+ # Use modulo to map the byte (0-255) to our character set length
+ idx = byte % len(chars)
+ result.append(chars[idx])
+ return "".join(result)
+
+
+class NaClCryptoConfig(CryptoConfig):
+ provider: str = "nacl"
+ # Interactive parameters for password ops (fast)
+ ops_limit: int = argon2i.OPSLIMIT_MIN
+ mem_limit: int = argon2i.MEMLIMIT_MIN
+ # Sensitive parameters for API key generation (slow but more secure)
+ api_ops_limit: int = argon2i.OPSLIMIT_INTERACTIVE
+ api_mem_limit: int = argon2i.MEMLIMIT_INTERACTIVE
+ api_key_bytes: int = 32
+ secret_key: Optional[str] = None
+
+
+class NaClCryptoProvider(CryptoProvider):
+ def __init__(self, config: NaClCryptoConfig):
+ if not isinstance(config, NaClCryptoConfig):
+ raise ValueError(
+ "NaClCryptoProvider must be initialized with a NaClCryptoConfig"
+ )
+ super().__init__(config)
+ self.config: NaClCryptoConfig = config
+ logging.info("Initializing NaClCryptoProvider")
+
+ # Securely load the secret key for JWT
+ # Priority: config.secret_key > environment variable > default
+ self.secret_key = (
+ config.secret_key
+ or os.getenv("R2R_SECRET_KEY")
+ or DEFAULT_NACL_SECRET_KEY
+ )
+
+ def get_password_hash(self, password: str) -> str:
+ password_bytes = password.encode("utf-8")
+ hashed = nacl.pwhash.argon2i.str(
+ password_bytes,
+ opslimit=self.config.ops_limit,
+ memlimit=self.config.mem_limit,
+ )
+ return base64.b64encode(hashed).decode("utf-8")
+
+ def verify_password(
+ self, plain_password: str, hashed_password: str
+ ) -> bool:
+ try:
+ stored_hash = base64.b64decode(hashed_password.encode("utf-8"))
+ nacl.pwhash.verify(stored_hash, plain_password.encode("utf-8"))
+ return True
+ except nacl.exceptions.InvalidkeyError:
+ return False
+
+ def generate_verification_code(self, length: int = 32) -> str:
+ random_bytes = nacl.utils.random(length)
+ return base64.urlsafe_b64encode(random_bytes)[:length].decode("utf-8")
+
+ def generate_api_key(self) -> Tuple[str, str]:
+ # Define our character set (excluding ambiguous characters)
+ chars = string.ascii_letters.replace("l", "").replace("I", "").replace(
+ "O", ""
+ ) + string.digits.replace("0", "").replace("1", "")
+
+ # Generate a unique key_id
+ key_id_bytes = nacl.utils.random(16) # 16 random bytes
+ key_id = f"pk_{encode_bytes_readable(key_id_bytes, chars)}"
+
+ # Generate a high-entropy API key
+ raw_api_key = f"sk_{encode_bytes_readable(nacl.utils.random(self.config.api_key_bytes), chars)}"
+
+ # The caller will store the hashed version in the database
+ return key_id, raw_api_key
+
+ def hash_api_key(self, raw_api_key: str) -> str:
+ hashed = nacl.pwhash.argon2i.str(
+ raw_api_key.encode("utf-8"),
+ opslimit=self.config.api_ops_limit,
+ memlimit=self.config.api_mem_limit,
+ )
+ return base64.b64encode(hashed).decode("utf-8")
+
+ def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
+ try:
+ stored_hash = base64.b64decode(hashed_key.encode("utf-8"))
+ nacl.pwhash.verify(stored_hash, raw_api_key.encode("utf-8"))
+ return True
+ except nacl.exceptions.InvalidkeyError:
+ return False
+
+ def sign_request(self, private_key: str, data: str) -> str:
+ try:
+ key_bytes = base64.b64decode(private_key)
+ signing_key = nacl.signing.SigningKey(key_bytes)
+ signature = signing_key.sign(data.encode())
+ return base64.b64encode(signature.signature).decode()
+ except Exception as e:
+ raise ValueError(
+ f"Invalid private key or signing error: {str(e)}"
+ ) from e
+
+ def verify_request_signature(
+ self, public_key: str, signature: str, data: str
+ ) -> bool:
+ try:
+ key_bytes = base64.b64decode(public_key)
+ verify_key = nacl.signing.VerifyKey(key_bytes)
+ signature_bytes = base64.b64decode(signature)
+ verify_key.verify(data.encode(), signature_bytes)
+ return True
+ except (BadSignatureError, ValueError):
+ return False
+
+ def generate_secure_token(self, data: dict, expiry: datetime) -> str:
+ """Generate a secure token using JWT with HS256.
+
+        The secret_key is used for symmetric signing.
+ """
+ now = datetime.now(timezone.utc)
+ to_encode = {
+ **data,
+ "exp": expiry.timestamp(),
+ "iat": now.timestamp(),
+ "nbf": now.timestamp(),
+ "jti": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
+ "nonce": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
+ }
+
+ return jwt.encode(to_encode, self.secret_key, algorithm="HS256")
+
+ def verify_secure_token(self, token: str) -> Optional[dict]:
+ """Verify a secure token using the shared secret_key and JWT."""
+ try:
+ payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
+ exp = payload.get("exp")
+ if exp is None or datetime.fromtimestamp(
+ exp, tz=timezone.utc
+ ) < datetime.now(timezone.utc):
+ return None
+ return payload
+ except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
+ return None
+
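+    # A minimal token round-trip sketch (illustrative only; `provider` is assumed
+    # to be an initialized NaClCryptoProvider, and `timedelta` would need importing):
+    #
+    #     token = provider.generate_secure_token(
+    #         {"sub": "user-123"},
+    #         expiry=datetime.now(timezone.utc) + timedelta(hours=1),
+    #     )
+    #     payload = provider.verify_secure_token(token)  # None if invalid or expired
+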
+ def generate_signing_keypair(self) -> Tuple[str, str, str]:
+ signing_key = nacl.signing.SigningKey.generate()
+ private_key_b64 = base64.b64encode(signing_key.encode()).decode()
+ public_key_b64 = base64.b64encode(
+ signing_key.verify_key.encode()
+ ).decode()
+ # Generate a unique key_id
+ key_id_bytes = nacl.utils.random(16)
+ key_id = f"sign_{base64.urlsafe_b64encode(key_id_bytes).decode()}"
+ return (key_id, private_key_b64, public_key_b64)
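+
+# A minimal signing round-trip sketch (illustrative only; `provider` is assumed
+# to be an initialized NaClCryptoProvider):
+#
+#     key_id, private_b64, public_b64 = provider.generate_signing_keypair()
+#     sig = provider.sign_request(private_b64, "payload-to-sign")
+#     assert provider.verify_request_signature(public_b64, sig, "payload-to-sign")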
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/database/__init__.py
new file mode 100644
index 00000000..72e6cba8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/__init__.py
@@ -0,0 +1,5 @@
+from .postgres import PostgresDatabaseProvider
+
+__all__ = [
+ "PostgresDatabaseProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/base.py b/.venv/lib/python3.12/site-packages/core/providers/database/base.py
new file mode 100644
index 00000000..c70c1352
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/base.py
@@ -0,0 +1,247 @@
+import asyncio
+import logging
+import textwrap
+from contextlib import asynccontextmanager
+from typing import Optional
+
+import asyncpg
+
+from core.base.providers import DatabaseConnectionManager
+
+logger = logging.getLogger()
+
+
+class SemaphoreConnectionPool:
+ def __init__(self, connection_string, postgres_configuration_settings):
+ self.connection_string = connection_string
+ self.postgres_configuration_settings = postgres_configuration_settings
+
+ async def initialize(self):
+ try:
+ logger.info(
+ f"Connecting with {int(self.postgres_configuration_settings.max_connections * 0.9)} connections to `asyncpg.create_pool`."
+ )
+
+ self.semaphore = asyncio.Semaphore(
+ int(self.postgres_configuration_settings.max_connections * 0.9)
+ )
+
+ self.pool = await asyncpg.create_pool(
+ self.connection_string,
+ max_size=self.postgres_configuration_settings.max_connections,
+ statement_cache_size=self.postgres_configuration_settings.statement_cache_size,
+ )
+
+ logger.info(
+ "Successfully connected to Postgres database and created connection pool."
+ )
+ except Exception as e:
+ raise ValueError(
+ f"Error {e} occurred while attempting to connect to relational database."
+ ) from e
+
+ @asynccontextmanager
+ async def get_connection(self):
+ async with self.semaphore:
+ async with self.pool.acquire() as conn:
+ yield conn
+
+ async def close(self):
+ await self.pool.close()
+
+
+class QueryBuilder:
+ def __init__(self, table_name: str):
+ self.table_name = table_name
+ self.conditions: list[str] = []
+ self.params: list = []
+ self.select_fields = "*"
+ self.operation = "SELECT"
+ self.limit_value: Optional[int] = None
+ self.offset_value: Optional[int] = None
+ self.order_by_fields: Optional[str] = None
+ self.returning_fields: Optional[list[str]] = None
+ self.insert_data: Optional[dict] = None
+ self.update_data: Optional[dict] = None
+ self.param_counter = 1
+
+ def select(self, fields: list[str]):
+ self.select_fields = ", ".join(fields)
+ return self
+
+ def insert(self, data: dict):
+ self.operation = "INSERT"
+ self.insert_data = data
+ return self
+
+ def update(self, data: dict):
+ self.operation = "UPDATE"
+ self.update_data = data
+ return self
+
+ def delete(self):
+ self.operation = "DELETE"
+ return self
+
+ def where(self, condition: str):
+ self.conditions.append(condition)
+ return self
+
+ def limit(self, value: Optional[int]):
+ self.limit_value = value
+ return self
+
+ def offset(self, value: int):
+ self.offset_value = value
+ return self
+
+ def order_by(self, fields: str):
+ self.order_by_fields = fields
+ return self
+
+ def returning(self, fields: list[str]):
+ self.returning_fields = fields
+ return self
+
+ def build(self):
+ if self.operation == "SELECT":
+ query = f"SELECT {self.select_fields} FROM {self.table_name}"
+
+ elif self.operation == "INSERT":
+ columns = ", ".join(self.insert_data.keys())
+ placeholders = ", ".join(
+ f"${i}" for i in range(1, len(self.insert_data) + 1)
+ )
+ query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
+ self.params.extend(list(self.insert_data.values()))
+
+ elif self.operation == "UPDATE":
+ set_clauses = []
+ for i, (key, value) in enumerate(
+ self.update_data.items(), start=len(self.params) + 1
+ ):
+ set_clauses.append(f"{key} = ${i}")
+ self.params.append(value)
+ query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"
+
+ elif self.operation == "DELETE":
+ query = f"DELETE FROM {self.table_name}"
+
+ else:
+ raise ValueError(f"Unsupported operation: {self.operation}")
+
+ if self.conditions:
+ query += " WHERE " + " AND ".join(self.conditions)
+
+ if self.order_by_fields and self.operation == "SELECT":
+ query += f" ORDER BY {self.order_by_fields}"
+
+ if self.offset_value is not None:
+ query += f" OFFSET {self.offset_value}"
+
+ if self.limit_value is not None:
+ query += f" LIMIT {self.limit_value}"
+
+ if self.returning_fields:
+ query += f" RETURNING {', '.join(self.returning_fields)}"
+
+ return query, self.params
+
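+# A minimal QueryBuilder usage sketch (illustrative only; the table and column
+# names are hypothetical, and placeholder values such as $1 are supplied by the caller):
+#
+#     query, params = (
+#         QueryBuilder("my_schema.documents")
+#         .select(["id", "title"])
+#         .where("owner_id = $1")
+#         .order_by("created_at DESC")
+#         .limit(10)
+#         .build()
+#     )
+#     # query  -> "SELECT id, title FROM my_schema.documents WHERE owner_id = $1
+#     #            ORDER BY created_at DESC LIMIT 10"
+#     # params -> [] here; the caller passes the owner_id value alongside the query.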
+
+class PostgresConnectionManager(DatabaseConnectionManager):
+ def __init__(self):
+ self.pool: Optional[SemaphoreConnectionPool] = None
+
+ async def initialize(self, pool: SemaphoreConnectionPool):
+ self.pool = pool
+
+ async def execute_query(self, query, params=None, isolation_level=None):
+ if not self.pool:
+ raise ValueError("PostgresConnectionManager is not initialized.")
+ async with self.pool.get_connection() as conn:
+ if isolation_level:
+ async with conn.transaction(isolation=isolation_level):
+ if params:
+ return await conn.execute(query, *params)
+ else:
+ return await conn.execute(query)
+ else:
+ if params:
+ return await conn.execute(query, *params)
+ else:
+ return await conn.execute(query)
+
+ async def execute_many(self, query, params=None, batch_size=1000):
+ if not self.pool:
+ raise ValueError("PostgresConnectionManager is not initialized.")
+ async with self.pool.get_connection() as conn:
+ async with conn.transaction():
+ if params:
+ results = []
+ for i in range(0, len(params), batch_size):
+ param_batch = params[i : i + batch_size]
+ result = await conn.executemany(query, param_batch)
+ results.append(result)
+ return results
+                else:
+                    # asyncpg's executemany requires an args iterable; pass an
+                    # empty list when no parameters are provided.
+                    return await conn.executemany(query, [])
+
+ async def fetch_query(self, query, params=None):
+ if not self.pool:
+ raise ValueError("PostgresConnectionManager is not initialized.")
+ try:
+ async with self.pool.get_connection() as conn:
+ async with conn.transaction():
+ return (
+ await conn.fetch(query, *params)
+ if params
+ else await conn.fetch(query)
+ )
+ except asyncpg.exceptions.DuplicatePreparedStatementError:
+ error_msg = textwrap.dedent("""
+ Database Configuration Error
+
+ Your database provider does not support statement caching.
+
+ To fix this, either:
+ • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
+ • Add statement_cache_size = 0 to your database configuration:
+
+ [database.postgres_configuration_settings]
+ statement_cache_size = 0
+
+ This is required when using connection poolers like PgBouncer or
+ managed database services like Supabase.
+ """).strip()
+ raise ValueError(error_msg) from None
+
+ async def fetchrow_query(self, query, params=None):
+ if not self.pool:
+ raise ValueError("PostgresConnectionManager is not initialized.")
+ async with self.pool.get_connection() as conn:
+ async with conn.transaction():
+ if params:
+ return await conn.fetchrow(query, *params)
+ else:
+ return await conn.fetchrow(query)
+
+ @asynccontextmanager
+ async def transaction(self, isolation_level=None):
+ """Async context manager for database transactions.
+
+ Args:
+ isolation_level: Optional isolation level for the transaction
+
+ Yields:
+ The connection manager instance for use within the transaction
+ """
+ if not self.pool:
+ raise ValueError("PostgresConnectionManager is not initialized.")
+
+ async with self.pool.get_connection() as conn:
+ async with conn.transaction(isolation=isolation_level):
+ try:
+ yield self
+ except Exception as e:
+ logger.error(f"Transaction failed: {str(e)}")
+ raise
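+
+# A minimal wiring sketch (illustrative only; the connection string and settings
+# object are assumed to exist, and the calls below run inside an async context):
+#
+#     pool = SemaphoreConnectionPool(connection_string, postgres_settings)
+#     await pool.initialize()
+#     manager = PostgresConnectionManager()
+#     await manager.initialize(pool)
+#     rows = await manager.fetch_query("SELECT 1 AS one")
+#     await pool.close()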
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/chunks.py b/.venv/lib/python3.12/site-packages/core/providers/database/chunks.py
new file mode 100644
index 00000000..177f3395
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/chunks.py
@@ -0,0 +1,1316 @@
+import copy
+import json
+import logging
+import math
+import time
+import uuid
+from typing import Any, Optional, TypedDict
+from uuid import UUID
+
+import numpy as np
+
+from core.base import (
+ ChunkSearchResult,
+ Handler,
+ IndexArgsHNSW,
+ IndexArgsIVFFlat,
+ IndexMeasure,
+ IndexMethod,
+ R2RException,
+ SearchSettings,
+ VectorEntry,
+ VectorQuantizationType,
+ VectorTableName,
+)
+from core.base.utils import _decorate_vector_type
+
+from .base import PostgresConnectionManager
+from .filters import apply_filters
+
+logger = logging.getLogger()
+
+
+def psql_quote_literal(value: str) -> str:
+ """Safely quote a string literal for PostgreSQL to prevent SQL injection.
+
+ This is a simple implementation - in production, you should use proper parameterization
+ or your database driver's quoting functions.
+ """
+ return "'" + value.replace("'", "''") + "'"
+
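+# A quick worked example (illustrative only):
+#     psql_quote_literal("O'Reilly")  # -> "'O''Reilly'"
+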
+
+def index_measure_to_ops(
+ measure: IndexMeasure,
+ quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
+):
+ return _decorate_vector_type(measure.ops, quantization_type)
+
+
+def quantize_vector_to_binary(
+ vector: list[float] | np.ndarray,
+ threshold: float = 0.0,
+) -> bytes:
+ """Quantizes a float vector to a binary vector string for PostgreSQL bit
+ type. Used when quantization_type is INT1.
+
+ Args:
+ vector (List[float] | np.ndarray): Input vector of floats
+ threshold (float, optional): Threshold for binarization. Defaults to 0.0.
+
+ Returns:
+        bytes: ASCII-encoded string of 1s and 0s for the PostgreSQL bit type
+ """
+ # Convert input to numpy array if it isn't already
+ if not isinstance(vector, np.ndarray):
+ vector = np.array(vector)
+
+ # Convert to binary (1 where value > threshold, 0 otherwise)
+ binary_vector = (vector > threshold).astype(int)
+
+    # Convert to a string of 1s and 0s, then encode to ASCII bytes
+ binary_string = "".join(map(str, binary_vector))
+ return binary_string.encode("ascii")
+
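+# A quick worked example (illustrative only): with the default threshold of 0.0,
+#     quantize_vector_to_binary([0.2, -0.1, 0.0, 3.5])  # -> b"1001"
+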
+
+class HybridSearchIntermediateResult(TypedDict):
+ semantic_rank: int
+ full_text_rank: int
+ data: ChunkSearchResult
+ rrf_score: float
+
+
+class PostgresChunksHandler(Handler):
+ TABLE_NAME = VectorTableName.CHUNKS
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ dimension: int | float,
+ quantization_type: VectorQuantizationType,
+ ):
+ super().__init__(project_name, connection_manager)
+ self.dimension = dimension
+ self.quantization_type = quantization_type
+
+ async def create_tables(self):
+ # First check if table already exists and validate dimensions
+ table_exists_query = """
+ SELECT EXISTS (
+ SELECT FROM pg_tables
+ WHERE schemaname = $1
+ AND tablename = $2
+ );
+ """
+ table_name = VectorTableName.CHUNKS
+ table_exists = await self.connection_manager.fetch_query(
+ table_exists_query, (self.project_name, table_name)
+ )
+
+ if len(table_exists) > 0 and table_exists[0]["exists"]:
+ # Table exists, check vector dimension
+ vector_dim_query = """
+ SELECT a.atttypmod as dimension
+ FROM pg_attribute a
+ JOIN pg_class c ON a.attrelid = c.oid
+ JOIN pg_namespace n ON c.relnamespace = n.oid
+ WHERE n.nspname = $1
+ AND c.relname = $2
+ AND a.attname = 'vec';
+ """
+
+ vector_dim_result = await self.connection_manager.fetch_query(
+ vector_dim_query, (self.project_name, table_name)
+ )
+
+ if vector_dim_result and len(vector_dim_result) > 0:
+ existing_dimension = vector_dim_result[0]["dimension"]
+                # For pgvector columns, atttypmod stores the declared dimension directly (-1 if unspecified)
+ if existing_dimension > 0: # If it has a specific dimension
+ # Compare with provided dimension
+ if (
+ self.dimension > 0
+ and existing_dimension != self.dimension
+ ):
+ raise ValueError(
+ f"Dimension mismatch: Table '{self.project_name}.{table_name}' was created with "
+ f"dimension {existing_dimension}, but {self.dimension} was provided. "
+ f"You must use the same dimension for existing tables."
+ )
+
+ # Check for old table name
+ check_query = """
+ SELECT EXISTS (
+ SELECT FROM pg_tables
+ WHERE schemaname = $1
+ AND tablename = $2
+ );
+ """
+ old_table_exists = await self.connection_manager.fetch_query(
+ check_query, (self.project_name, self.project_name)
+ )
+
+ if len(old_table_exists) > 0 and old_table_exists[0]["exists"]:
+ raise ValueError(
+ f"Found old vector table '{self.project_name}.{self.project_name}'. "
+ "Please run `r2r db upgrade` with the CLI, or to run manually, "
+ "run in R2R/py/migrations with 'alembic upgrade head' to update "
+ "your database schema to the new version."
+ )
+
+ binary_col = (
+ ""
+ if self.quantization_type != VectorQuantizationType.INT1
+ else f"vec_binary bit({self.dimension}),"
+ )
+
+ if self.dimension > 0:
+ vector_col = f"vec vector({self.dimension})"
+ else:
+ vector_col = "vec vector"
+
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY,
+ document_id UUID,
+ owner_id UUID,
+ collection_ids UUID[],
+ {vector_col},
+ {binary_col}
+ text TEXT,
+ metadata JSONB,
+ fts tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
+ );
+ CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (document_id);
+ CREATE INDEX IF NOT EXISTS idx_vectors_owner_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (owner_id);
+ CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (collection_ids);
+ CREATE INDEX IF NOT EXISTS idx_vectors_text ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (to_tsvector('english', text));
+ """
+
+ await self.connection_manager.execute_query(query)
+
+ async def upsert(self, entry: VectorEntry) -> None:
+ """Upsert function that handles vector quantization only when
+ quantization_type is INT1.
+
+ Matches the table schema where vec_binary column only exists for INT1
+ quantization.
+ """
+ # Check the quantization type to determine which columns to use
+ if self.quantization_type == VectorQuantizationType.INT1:
+ bit_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+
+ # For quantized vectors, use vec_binary column
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6::bit({bit_dim}), $7, $8)
+ ON CONFLICT (id) DO UPDATE SET
+ document_id = EXCLUDED.document_id,
+ owner_id = EXCLUDED.owner_id,
+ collection_ids = EXCLUDED.collection_ids,
+ vec = EXCLUDED.vec,
+ vec_binary = EXCLUDED.vec_binary,
+ text = EXCLUDED.text,
+ metadata = EXCLUDED.metadata;
+ """
+ await self.connection_manager.execute_query(
+ query,
+ (
+ entry.id,
+ entry.document_id,
+ entry.owner_id,
+ entry.collection_ids,
+ str(entry.vector.data),
+ quantize_vector_to_binary(
+ entry.vector.data
+ ), # Convert to binary
+ entry.text,
+ json.dumps(entry.metadata),
+ ),
+ )
+ else:
+ # For regular vectors, use vec column only
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ (id, document_id, owner_id, collection_ids, vec, text, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
+ ON CONFLICT (id) DO UPDATE SET
+ document_id = EXCLUDED.document_id,
+ owner_id = EXCLUDED.owner_id,
+ collection_ids = EXCLUDED.collection_ids,
+ vec = EXCLUDED.vec,
+ text = EXCLUDED.text,
+ metadata = EXCLUDED.metadata;
+ """
+
+ await self.connection_manager.execute_query(
+ query,
+ (
+ entry.id,
+ entry.document_id,
+ entry.owner_id,
+ entry.collection_ids,
+ str(entry.vector.data),
+ entry.text,
+ json.dumps(entry.metadata),
+ ),
+ )
+
+ async def upsert_entries(self, entries: list[VectorEntry]) -> None:
+ """Batch upsert function that handles vector quantization only when
+ quantization_type is INT1.
+
+ Matches the table schema where vec_binary column only exists for INT1
+ quantization.
+ """
+ if self.quantization_type == VectorQuantizationType.INT1:
+ bit_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+
+ # For quantized vectors, use vec_binary column
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6::bit({bit_dim}), $7, $8)
+ ON CONFLICT (id) DO UPDATE SET
+ document_id = EXCLUDED.document_id,
+ owner_id = EXCLUDED.owner_id,
+ collection_ids = EXCLUDED.collection_ids,
+ vec = EXCLUDED.vec,
+ vec_binary = EXCLUDED.vec_binary,
+ text = EXCLUDED.text,
+ metadata = EXCLUDED.metadata;
+ """
+ bin_params = [
+ (
+ entry.id,
+ entry.document_id,
+ entry.owner_id,
+ entry.collection_ids,
+ str(entry.vector.data),
+ quantize_vector_to_binary(
+ entry.vector.data
+ ), # Convert to binary
+ entry.text,
+ json.dumps(entry.metadata),
+ )
+ for entry in entries
+ ]
+ await self.connection_manager.execute_many(query, bin_params)
+
+ else:
+ # For regular vectors, use vec column only
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ (id, document_id, owner_id, collection_ids, vec, text, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
+ ON CONFLICT (id) DO UPDATE SET
+ document_id = EXCLUDED.document_id,
+ owner_id = EXCLUDED.owner_id,
+ collection_ids = EXCLUDED.collection_ids,
+ vec = EXCLUDED.vec,
+ text = EXCLUDED.text,
+ metadata = EXCLUDED.metadata;
+ """
+ params = [
+ (
+ entry.id,
+ entry.document_id,
+ entry.owner_id,
+ entry.collection_ids,
+ str(entry.vector.data),
+ entry.text,
+ json.dumps(entry.metadata),
+ )
+ for entry in entries
+ ]
+
+ await self.connection_manager.execute_many(query, params)
+
+ async def semantic_search(
+ self, query_vector: list[float], search_settings: SearchSettings
+ ) -> list[ChunkSearchResult]:
+ try:
+ imeasure_obj = IndexMeasure(
+ search_settings.chunk_settings.index_measure
+ )
+ except ValueError:
+ raise ValueError("Invalid index measure") from None
+
+ table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
+ cols = [
+ f"{table_name}.id",
+ f"{table_name}.document_id",
+ f"{table_name}.owner_id",
+ f"{table_name}.collection_ids",
+ f"{table_name}.text",
+ ]
+
+ params: list[str | int | bytes] = []
+
+ # For binary vectors (INT1), implement two-stage search
+ if self.quantization_type == VectorQuantizationType.INT1:
+ # Convert query vector to binary format
+ binary_query = quantize_vector_to_binary(query_vector)
+ # TODO - Put depth multiplier in config / settings
+ extended_limit = (
+ search_settings.limit * 20
+ ) # Get 20x candidates for re-ranking
+
+ if (
+ imeasure_obj == IndexMeasure.hamming_distance
+ or imeasure_obj == IndexMeasure.jaccard_distance
+ ):
+ binary_search_measure_repr = imeasure_obj.pgvector_repr
+ else:
+ binary_search_measure_repr = (
+ IndexMeasure.hamming_distance.pgvector_repr
+ )
+
+ # Use binary column and binary-specific distance measures for first stage
+ bit_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+ stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit{bit_dim}"
+ stage1_param = binary_query
+
+ cols.append(
+ f"{table_name}.vec"
+ ) # Need original vector for re-ranking
+ if search_settings.include_metadatas:
+ cols.append(f"{table_name}.metadata")
+
+ select_clause = ", ".join(cols)
+ where_clause = ""
+ params.append(stage1_param)
+
+ if search_settings.filters:
+ where_clause, params = apply_filters(
+ search_settings.filters, params, mode="where_clause"
+ )
+
+ vector_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+
+ # First stage: Get candidates using binary search
+ query = f"""
+ WITH candidates AS (
+ SELECT {select_clause},
+ ({stage1_distance}) as binary_distance
+ FROM {table_name}
+ {where_clause}
+ ORDER BY {stage1_distance}
+ LIMIT ${len(params) + 1}
+ OFFSET ${len(params) + 2}
+ )
+ -- Second stage: Re-rank using original vectors
+ SELECT
+ id,
+ document_id,
+ owner_id,
+ collection_ids,
+ text,
+ {"metadata," if search_settings.include_metadatas else ""}
+ (vec <=> ${len(params) + 4}::vector{vector_dim}) as distance
+ FROM candidates
+ ORDER BY distance
+ LIMIT ${len(params) + 3}
+ """
+
+ params.extend(
+ [
+ extended_limit, # First stage limit
+ search_settings.offset,
+ search_settings.limit, # Final limit
+ str(query_vector), # For re-ranking
+ ]
+ )
+
+ else:
+ # Standard float vector handling
+ vector_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+ distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector{vector_dim}"
+ query_param = str(query_vector)
+
+ if search_settings.include_scores:
+ cols.append(f"({distance_calc}) AS distance")
+ if search_settings.include_metadatas:
+ cols.append(f"{table_name}.metadata")
+
+ select_clause = ", ".join(cols)
+ where_clause = ""
+ params.append(query_param)
+
+ if search_settings.filters:
+ where_clause, new_params = apply_filters(
+ search_settings.filters,
+ params,
+ mode="where_clause", # Get just conditions without WHERE
+ )
+ params = new_params
+
+ query = f"""
+ SELECT {select_clause}
+ FROM {table_name}
+ {where_clause}
+ ORDER BY {distance_calc}
+ LIMIT ${len(params) + 1}
+ OFFSET ${len(params) + 2}
+ """
+ params.extend([search_settings.limit, search_settings.offset])
+ results = await self.connection_manager.fetch_query(query, params)
+
+ return [
+ ChunkSearchResult(
+ id=UUID(str(result["id"])),
+ document_id=UUID(str(result["document_id"])),
+ owner_id=UUID(str(result["owner_id"])),
+ collection_ids=result["collection_ids"],
+ text=result["text"],
+ score=(
+ (1 - float(result["distance"]))
+ if "distance" in result
+ else -1
+ ),
+ metadata=(
+ json.loads(result["metadata"])
+ if search_settings.include_metadatas
+ else {}
+ ),
+ )
+ for result in results
+ ]
+
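+    # A minimal invocation sketch (illustrative only; `chunks_handler` is assumed to be
+    # an initialized PostgresChunksHandler and `embedding` a query vector of matching
+    # dimension, called from an async context):
+    #
+    #     results = await chunks_handler.semantic_search(
+    #         query_vector=embedding,
+    #         search_settings=SearchSettings(limit=10, include_scores=True),
+    #     )
+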
+ async def full_text_search(
+ self, query_text: str, search_settings: SearchSettings
+ ) -> list[ChunkSearchResult]:
+ conditions = []
+ params: list[str | int | bytes] = [query_text]
+
+ conditions.append("fts @@ websearch_to_tsquery('english', $1)")
+
+ if search_settings.filters:
+ filter_condition, params = apply_filters(
+ search_settings.filters, params, mode="condition_only"
+ )
+ if filter_condition:
+ conditions.append(filter_condition)
+
+ where_clause = "WHERE " + " AND ".join(conditions)
+
+ query = f"""
+ SELECT
+ id,
+ document_id,
+ owner_id,
+ collection_ids,
+ text,
+ metadata,
+ ts_rank(fts, websearch_to_tsquery('english', $1), 32) as rank
+ FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ {where_clause}
+ ORDER BY rank DESC
+ OFFSET ${len(params) + 1}
+ LIMIT ${len(params) + 2}
+ """
+
+ params.extend(
+ [
+ search_settings.offset,
+ search_settings.hybrid_settings.full_text_limit,
+ ]
+ )
+
+ results = await self.connection_manager.fetch_query(query, params)
+ return [
+ ChunkSearchResult(
+ id=UUID(str(r["id"])),
+ document_id=UUID(str(r["document_id"])),
+ owner_id=UUID(str(r["owner_id"])),
+ collection_ids=r["collection_ids"],
+ text=r["text"],
+ score=float(r["rank"]),
+ metadata=json.loads(r["metadata"]),
+ )
+ for r in results
+ ]
+
+ async def hybrid_search(
+ self,
+ query_text: str,
+ query_vector: list[float],
+ search_settings: SearchSettings,
+ *args,
+ **kwargs,
+ ) -> list[ChunkSearchResult]:
+ if search_settings.hybrid_settings is None:
+ raise ValueError(
+ "Please provide a valid `hybrid_settings` in the `search_settings`."
+ )
+ if (
+ search_settings.hybrid_settings.full_text_limit
+ < search_settings.limit
+ ):
+ raise ValueError(
+ "The `full_text_limit` must be greater than or equal to the `limit`."
+ )
+
+ semantic_settings = copy.deepcopy(search_settings)
+ semantic_settings.limit += search_settings.offset
+
+ full_text_settings = copy.deepcopy(search_settings)
+ full_text_settings.hybrid_settings.full_text_limit += (
+ search_settings.offset
+ )
+
+ semantic_results: list[ChunkSearchResult] = await self.semantic_search(
+ query_vector, semantic_settings
+ )
+ full_text_results: list[
+ ChunkSearchResult
+ ] = await self.full_text_search(query_text, full_text_settings)
+
+ semantic_limit = search_settings.limit
+ full_text_limit = search_settings.hybrid_settings.full_text_limit
+ semantic_weight = search_settings.hybrid_settings.semantic_weight
+ full_text_weight = search_settings.hybrid_settings.full_text_weight
+ rrf_k = search_settings.hybrid_settings.rrf_k
+
+ combined_results: dict[uuid.UUID, HybridSearchIntermediateResult] = {}
+
+ for rank, result in enumerate(semantic_results, 1):
+ combined_results[result.id] = {
+ "semantic_rank": rank,
+ "full_text_rank": full_text_limit,
+ "data": result,
+ "rrf_score": 0.0, # Initialize with 0, will be calculated later
+ }
+
+ for rank, result in enumerate(full_text_results, 1):
+ if result.id in combined_results:
+ combined_results[result.id]["full_text_rank"] = rank
+ else:
+ combined_results[result.id] = {
+ "semantic_rank": semantic_limit,
+ "full_text_rank": rank,
+ "data": result,
+ "rrf_score": 0.0, # Initialize with 0, will be calculated later
+ }
+
+ combined_results = {
+ k: v
+ for k, v in combined_results.items()
+ if v["semantic_rank"] <= semantic_limit * 2
+ and v["full_text_rank"] <= full_text_limit * 2
+ }
+
+ for hyb_result in combined_results.values():
+ semantic_score = 1 / (rrf_k + hyb_result["semantic_rank"])
+ full_text_score = 1 / (rrf_k + hyb_result["full_text_rank"])
+ hyb_result["rrf_score"] = (
+ semantic_score * semantic_weight
+ + full_text_score * full_text_weight
+ ) / (semantic_weight + full_text_weight)
+
+ sorted_results = sorted(
+ combined_results.values(),
+ key=lambda x: x["rrf_score"],
+ reverse=True,
+ )
+ offset_results = sorted_results[
+ search_settings.offset : search_settings.offset
+ + search_settings.limit
+ ]
+
+ return [
+ ChunkSearchResult(
+ id=result["data"].id,
+ document_id=result["data"].document_id,
+ owner_id=result["data"].owner_id,
+ collection_ids=result["data"].collection_ids,
+ text=result["data"].text,
+ score=result["rrf_score"],
+ metadata={
+ **result["data"].metadata,
+ "semantic_rank": result["semantic_rank"],
+ "full_text_rank": result["full_text_rank"],
+ },
+ )
+ for result in offset_results
+ ]
+
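+    # A worked reciprocal-rank-fusion example (illustrative values, not defaults):
+    # with rrf_k=60, semantic_weight=5.0 and full_text_weight=1.0, a chunk ranked
+    # 1st semantically and 3rd in full-text scores
+    #     (5.0 * 1/(60 + 1) + 1.0 * 1/(60 + 3)) / (5.0 + 1.0) ≈ 0.0163
+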
+ async def delete(
+ self, filters: dict[str, Any]
+ ) -> dict[str, dict[str, str]]:
+ params: list[str | int | bytes] = []
+ where_clause, params = apply_filters(
+ filters, params, mode="condition_only"
+ )
+
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ WHERE {where_clause}
+ RETURNING id, document_id, text;
+ """
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ return {
+ str(result["id"]): {
+ "status": "deleted",
+ "id": str(result["id"]),
+ "document_id": str(result["document_id"]),
+ "text": result["text"],
+ }
+ for result in results
+ }
+
+ async def assign_document_chunks_to_collection(
+ self, document_id: UUID, collection_id: UUID
+ ) -> None:
+ query = f"""
+ UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ SET collection_ids = array_append(collection_ids, $1)
+ WHERE document_id = $2 AND NOT ($1 = ANY(collection_ids));
+ """
+ return await self.connection_manager.execute_query(
+ query, (str(collection_id), str(document_id))
+ )
+
+ async def remove_document_from_collection_vector(
+ self, document_id: UUID, collection_id: UUID
+ ) -> None:
+ query = f"""
+ UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ SET collection_ids = array_remove(collection_ids, $1)
+ WHERE document_id = $2;
+ """
+ await self.connection_manager.execute_query(
+ query, (collection_id, document_id)
+ )
+
+ async def delete_user_vector(self, owner_id: UUID) -> None:
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ WHERE owner_id = $1;
+ """
+ await self.connection_manager.execute_query(query, (owner_id,))
+
+ async def delete_collection_vector(self, collection_id: UUID) -> None:
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ WHERE $1 = ANY(collection_ids)
+ RETURNING collection_ids
+ """
+ await self.connection_manager.fetchrow_query(query, (collection_id,))
+ return None
+
+ async def list_document_chunks(
+ self,
+ document_id: UUID,
+ offset: int,
+ limit: int,
+ include_vectors: bool = False,
+ ) -> dict[str, Any]:
+ vector_select = ", vec" if include_vectors else ""
+ limit_clause = f"LIMIT {limit}" if limit > -1 else ""
+
+ query = f"""
+ SELECT id, document_id, owner_id, collection_ids, text, metadata{vector_select}, COUNT(*) OVER() AS total
+ FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ WHERE document_id = $1
+ ORDER BY (metadata->>'chunk_order')::integer
+ OFFSET $2
+ {limit_clause};
+ """
+
+ params = [document_id, offset]
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ chunks = []
+ total = 0
+ if results:
+ total = results[0].get("total", 0)
+ chunks = [
+ {
+ "id": result["id"],
+ "document_id": result["document_id"],
+ "owner_id": result["owner_id"],
+ "collection_ids": result["collection_ids"],
+ "text": result["text"],
+ "metadata": json.loads(result["metadata"]),
+ "vector": (
+ json.loads(result["vec"]) if include_vectors else None
+ ),
+ }
+ for result in results
+ ]
+
+ return {"results": chunks, "total_entries": total}
+
+ async def get_chunk(self, id: UUID) -> dict:
+ query = f"""
+ SELECT id, document_id, owner_id, collection_ids, text, metadata
+ FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ WHERE id = $1;
+ """
+
+ result = await self.connection_manager.fetchrow_query(query, (id,))
+
+ if result:
+ return {
+ "id": result["id"],
+ "document_id": result["document_id"],
+ "owner_id": result["owner_id"],
+ "collection_ids": result["collection_ids"],
+ "text": result["text"],
+ "metadata": json.loads(result["metadata"]),
+ }
+ raise R2RException(
+ message=f"Chunk with ID {id} not found", status_code=404
+ )
+
+ async def create_index(
+ self,
+ table_name: Optional[VectorTableName] = None,
+ index_measure: IndexMeasure = IndexMeasure.cosine_distance,
+ index_method: IndexMethod = IndexMethod.auto,
+ index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
+ index_name: Optional[str] = None,
+ index_column: Optional[str] = None,
+ concurrently: bool = True,
+ ) -> None:
+ """Creates an index for the collection.
+
+ Note:
+ When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step
+            process that enables performant indexes to be built for large collections with low-end
+ database hardware.
+
+ Those steps are:
+
+ - Creates a new table with a different name
+ - Randomly selects records from the existing table
+ - Inserts the random records from the existing table into the new table
+ - Creates the requested vector index on the new table
+ - Upserts all data from the existing table into the new table
+ - Drops the existing table
+            - Renames the new table to the existing table's name
+
+            If you create dependencies (like views) on the table that underpins
+            a `vecs.Collection`, the `create_index` step may require you to drop those dependencies before
+ it will succeed.
+
+        Args:
+            table_name (VectorTableName, optional): The table to index. Defaults to None.
+            index_measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
+            index_method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
+            index_arguments (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index-type-specific build arguments.
+            index_name (str, optional): The name of the index to create. Defaults to None.
+            index_column (str, optional): The column to index. Defaults to a measure-appropriate vector column.
+            concurrently (bool, optional): Whether to create the index concurrently. Defaults to True.
+        Raises:
+            ValueError: If an invalid table name or index method is used, or if index build
+                arguments do not match the chosen index method.
+ """
+
+ if table_name == VectorTableName.CHUNKS:
+ table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}" # TODO - Fix bug in vector table naming convention
+ if index_column:
+ col_name = index_column
+ else:
+ col_name = (
+ "vec"
+ if (
+ index_measure != IndexMeasure.hamming_distance
+ and index_measure != IndexMeasure.jaccard_distance
+ )
+ else "vec_binary"
+ )
+ elif table_name == VectorTableName.ENTITIES_DOCUMENT:
+ table_name_str = (
+ f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
+ )
+ col_name = "description_embedding"
+ elif table_name == VectorTableName.GRAPHS_ENTITIES:
+ table_name_str = (
+ f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
+ )
+ col_name = "description_embedding"
+ elif table_name == VectorTableName.COMMUNITIES:
+ table_name_str = (
+ f"{self.project_name}.{VectorTableName.COMMUNITIES}"
+ )
+ col_name = "embedding"
+ else:
+ raise ValueError("invalid table name")
+
+ if index_method not in (
+ IndexMethod.ivfflat,
+ IndexMethod.hnsw,
+ IndexMethod.auto,
+ ):
+ raise ValueError("invalid index method")
+
+ if index_arguments:
+ # Disallow case where user submits index arguments but uses the
+ # IndexMethod.auto index (index build arguments should only be
+ # used with a specific index)
+ if index_method == IndexMethod.auto:
+ raise ValueError(
+ "Index build parameters are not allowed when using the IndexMethod.auto index."
+ )
+ # Disallow case where user specifies one index type but submits
+ # index build arguments for the other index type
+ if (
+ isinstance(index_arguments, IndexArgsHNSW)
+ and index_method != IndexMethod.hnsw
+ ) or (
+ isinstance(index_arguments, IndexArgsIVFFlat)
+ and index_method != IndexMethod.ivfflat
+ ):
+ raise ValueError(
+ f"{index_arguments.__class__.__name__} build parameters were supplied but {index_method} index was specified."
+ )
+
+ if index_method == IndexMethod.auto:
+ index_method = IndexMethod.hnsw
+
+ ops = index_measure_to_ops(
+ index_measure # , quantization_type=self.quantization_type
+ )
+
+ if ops is None:
+ raise ValueError("Unknown index measure")
+
+ concurrently_sql = "CONCURRENTLY" if concurrently else ""
+
+ index_name = (
+ index_name
+ or f"ix_{ops}_{index_method}__{col_name}_{time.strftime('%Y%m%d%H%M%S')}"
+ )
+
+ create_index_sql = f"""
+ CREATE INDEX {concurrently_sql} {index_name}
+ ON {table_name_str}
+ USING {index_method} ({col_name} {ops}) {self._get_index_options(index_method, index_arguments)};
+ """
+
+ try:
+ if concurrently:
+ async with (
+ self.connection_manager.pool.get_connection() as conn # type: ignore
+ ):
+ # Disable automatic transaction management
+ await conn.execute(
+ "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
+ )
+ await conn.execute(create_index_sql)
+ else:
+ # Non-concurrent index creation can use normal query execution
+ await self.connection_manager.execute_query(create_index_sql)
+ except Exception as e:
+ raise Exception(f"Failed to create index: {e}") from e
+ return None
+
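+    # A minimal invocation sketch (illustrative only; `chunks_handler` is assumed to be
+    # an initialized PostgresChunksHandler, called from an async context):
+    #
+    #     await chunks_handler.create_index(
+    #         table_name=VectorTableName.CHUNKS,
+    #         index_method=IndexMethod.hnsw,
+    #         index_arguments=IndexArgsHNSW(m=16, ef_construction=64),
+    #     )
+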
+ async def list_indices(
+ self,
+ offset: int,
+ limit: int,
+ filters: Optional[dict[str, Any]] = None,
+ ) -> dict:
+ where_clauses = []
+ params: list[Any] = [self.project_name] # Start with schema name
+ param_count = 1
+
+ # Handle filtering
+ if filters:
+ if "table_name" in filters:
+ where_clauses.append(f"i.tablename = ${param_count + 1}")
+ params.append(filters["table_name"])
+ param_count += 1
+ if "index_method" in filters:
+ where_clauses.append(f"am.amname = ${param_count + 1}")
+ params.append(filters["index_method"])
+ param_count += 1
+ if "index_name" in filters:
+ where_clauses.append(
+ f"LOWER(i.indexname) LIKE LOWER(${param_count + 1})"
+ )
+ params.append(f"%{filters['index_name']}%")
+ param_count += 1
+
+ where_clause = " AND ".join(where_clauses) if where_clauses else ""
+ if where_clause:
+ where_clause = f"AND {where_clause}"
+
+ query = f"""
+ WITH index_info AS (
+ SELECT
+ i.indexname as name,
+ i.tablename as table_name,
+ i.indexdef as definition,
+ am.amname as method,
+ pg_relation_size(c.oid) as size_in_bytes,
+ c.reltuples::bigint as row_estimate,
+ COALESCE(psat.idx_scan, 0) as number_of_scans,
+ COALESCE(psat.idx_tup_read, 0) as tuples_read,
+ COALESCE(psat.idx_tup_fetch, 0) as tuples_fetched,
+ COUNT(*) OVER() as total_count
+ FROM pg_indexes i
+ JOIN pg_class c ON c.relname = i.indexname
+ JOIN pg_am am ON c.relam = am.oid
+ LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelname = i.indexname
+ AND psat.schemaname = i.schemaname
+ WHERE i.schemaname = $1
+ AND i.indexdef LIKE '%vector%'
+ {where_clause}
+ )
+ SELECT *
+ FROM index_info
+ ORDER BY name
+ LIMIT ${param_count + 1}
+ OFFSET ${param_count + 2}
+ """
+
+ # Add limit and offset to params
+ params.extend([limit, offset])
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ indices = []
+ total_entries = 0
+
+ if results:
+ total_entries = results[0]["total_count"]
+ for result in results:
+ index_info = {
+ "name": result["name"],
+ "table_name": result["table_name"],
+ "definition": result["definition"],
+ "size_in_bytes": result["size_in_bytes"],
+ "row_estimate": result["row_estimate"],
+ "number_of_scans": result["number_of_scans"],
+ "tuples_read": result["tuples_read"],
+ "tuples_fetched": result["tuples_fetched"],
+ }
+ indices.append(index_info)
+
+ return {"indices": indices, "total_entries": total_entries}
+
+ async def delete_index(
+ self,
+ index_name: str,
+ table_name: Optional[VectorTableName] = None,
+ concurrently: bool = True,
+ ) -> None:
+ """Deletes a vector index.
+
+ Args:
+ index_name (str): Name of the index to delete
+ table_name (VectorTableName, optional): Table the index belongs to
+ concurrently (bool): Whether to drop the index concurrently
+
+ Raises:
+ ValueError: If table name is invalid or index doesn't exist
+ Exception: If index deletion fails
+ """
+ # Validate table name and get column name
+ if table_name == VectorTableName.CHUNKS:
+ table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"
+ col_name = "vec"
+ elif table_name == VectorTableName.ENTITIES_DOCUMENT:
+ table_name_str = (
+ f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
+ )
+ col_name = "description_embedding"
+ elif table_name == VectorTableName.GRAPHS_ENTITIES:
+ table_name_str = (
+ f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
+ )
+ col_name = "description_embedding"
+ elif table_name == VectorTableName.COMMUNITIES:
+ table_name_str = (
+ f"{self.project_name}.{VectorTableName.COMMUNITIES}"
+ )
+ col_name = "description_embedding"
+ else:
+ raise ValueError("invalid table name")
+
+ # Extract schema and base table name
+ schema_name, base_table_name = table_name_str.split(".")
+
+ # Verify index exists and is a vector index
+ query = """
+ SELECT indexdef
+ FROM pg_indexes
+ WHERE indexname = $1
+ AND schemaname = $2
+ AND tablename = $3
+ AND indexdef LIKE $4
+ """
+
+ result = await self.connection_manager.fetchrow_query(
+ query, (index_name, schema_name, base_table_name, f"%({col_name}%")
+ )
+
+ if not result:
+ raise ValueError(
+ f"Vector index '{index_name}' does not exist on table {table_name_str}"
+ )
+
+ # Drop the index
+ concurrently_sql = "CONCURRENTLY" if concurrently else ""
+ drop_query = (
+ f"DROP INDEX {concurrently_sql} {schema_name}.{index_name}"
+ )
+
+ try:
+ if concurrently:
+ async with (
+ self.connection_manager.pool.get_connection() as conn # type: ignore
+ ):
+ # Disable automatic transaction management
+ await conn.execute(
+ "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
+ )
+ await conn.execute(drop_query)
+ else:
+ await self.connection_manager.execute_query(drop_query)
+ except Exception as e:
+ raise Exception(f"Failed to delete index: {e}") from e
+
+ async def list_chunks(
+ self,
+ offset: int,
+ limit: int,
+ filters: Optional[dict[str, Any]] = None,
+ include_vectors: bool = False,
+ ) -> dict[str, Any]:
+ """List chunks with pagination support.
+
+ Args:
+            offset (int): Number of records to skip.
+            limit (int): Maximum number of records to return.
+ filters (dict, optional): Dictionary of filters to apply. Defaults to None.
+ include_vectors (bool, optional): Whether to include vector data. Defaults to False.
+
+ Returns:
+ dict: Dictionary containing:
+ - results: List of chunk records
+ - total_entries: Total number of chunks matching the filters
+ """
+ vector_select = ", vec" if include_vectors else ""
+ select_clause = f"""
+ id, document_id, owner_id, collection_ids,
+ text, metadata{vector_select}, COUNT(*) OVER() AS total_entries
+ """
+
+ params: list[str | int | bytes] = []
+ where_clause = ""
+ if filters:
+ where_clause, params = apply_filters(
+ filters, params, mode="where_clause"
+ )
+
+ query = f"""
+ SELECT {select_clause}
+ FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ {where_clause}
+ LIMIT ${len(params) + 1}
+ OFFSET ${len(params) + 2}
+ """
+
+ params.extend([limit, offset])
+
+ # Execute the query
+ results = await self.connection_manager.fetch_query(query, params)
+
+ # Process results
+ chunks = []
+ total_entries = 0
+ if results:
+ total_entries = results[0].get("total_entries", 0)
+ chunks = [
+ {
+ "id": str(result["id"]),
+ "document_id": str(result["document_id"]),
+ "owner_id": str(result["owner_id"]),
+ "collection_ids": result["collection_ids"],
+ "text": result["text"],
+ "metadata": json.loads(result["metadata"]),
+ "vector": (
+ json.loads(result["vec"]) if include_vectors else None
+ ),
+ }
+ for result in results
+ ]
+
+ return {"results": chunks, "total_entries": total_entries}
+
+ async def search_documents(
+ self,
+ query_text: str,
+ settings: SearchSettings,
+ ) -> list[dict[str, Any]]:
+ """Search for documents based on their metadata fields and/or body
+ text. Joins with documents table to get complete document metadata.
+
+ Args:
+ query_text (str): The search query text
+ settings (SearchSettings): Search settings including search preferences and filters
+
+ Returns:
+ list[dict[str, Any]]: List of documents with their search scores and complete metadata
+ """
+ where_clauses = []
+ params: list[str | int | bytes] = [query_text]
+
+ search_over_body = getattr(settings, "search_over_body", True)
+ search_over_metadata = getattr(settings, "search_over_metadata", True)
+ metadata_weight = getattr(settings, "metadata_weight", 3.0)
+ title_weight = getattr(settings, "title_weight", 1.0)
+ metadata_keys = getattr(
+ settings, "metadata_keys", ["title", "description"]
+ )
+
+ # Build the dynamic metadata field search expression
+ metadata_fields_expr = " || ' ' || ".join(
+ [
+ f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')"
+ for key in metadata_keys # type: ignore
+ ]
+ )
+
+ query = f"""
+ WITH
+ -- Metadata search scores
+ metadata_scores AS (
+ SELECT DISTINCT ON (v.document_id)
+ v.document_id,
+ d.metadata as doc_metadata,
+ CASE WHEN $1 = '' THEN 0.0
+ ELSE
+ ts_rank_cd(
+ setweight(to_tsvector('english', {metadata_fields_expr}), 'A'),
+ websearch_to_tsquery('english', $1),
+ 32
+ )
+ END as metadata_rank
+ FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} v
+ LEFT JOIN {self._get_table_name("documents")} d ON v.document_id = d.id
+ WHERE v.metadata IS NOT NULL
+ ),
+ -- Body search scores
+ body_scores AS (
+ SELECT
+ document_id,
+ AVG(
+ ts_rank_cd(
+ setweight(to_tsvector('english', COALESCE(text, '')), 'B'),
+ websearch_to_tsquery('english', $1),
+ 32
+ )
+ ) as body_rank
+ FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+ WHERE $1 != ''
+ {"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if search_over_body else ""}
+ GROUP BY document_id
+ ),
+ -- Combined scores with document metadata
+ combined_scores AS (
+ SELECT
+ COALESCE(m.document_id, b.document_id) as document_id,
+ m.doc_metadata as metadata,
+ COALESCE(m.metadata_rank, 0) as debug_metadata_rank,
+ COALESCE(b.body_rank, 0) as debug_body_rank,
+ CASE
+ WHEN {str(search_over_metadata).lower()} AND {str(search_over_body).lower()} THEN
+ COALESCE(m.metadata_rank, 0) * {metadata_weight} + COALESCE(b.body_rank, 0) * {title_weight}
+ WHEN {str(search_over_metadata).lower()} THEN
+ COALESCE(m.metadata_rank, 0)
+ WHEN {str(search_over_body).lower()} THEN
+ COALESCE(b.body_rank, 0)
+ ELSE 0
+ END as rank
+ FROM metadata_scores m
+ FULL OUTER JOIN body_scores b ON m.document_id = b.document_id
+ WHERE (
+ ($1 = '') OR
+ ({str(search_over_metadata).lower()} AND m.metadata_rank > 0) OR
+ ({str(search_over_body).lower()} AND b.body_rank > 0)
+ )
+ """
+
+ # Add any additional filters
+ if settings.filters:
+ filter_clause, params = apply_filters(settings.filters, params)
+ where_clauses.append(filter_clause)
+
+ if where_clauses:
+ query += f" AND {' AND '.join(where_clauses)}"
+
+ query += """
+ )
+ SELECT
+ document_id,
+ metadata,
+ rank as score,
+ debug_metadata_rank,
+ debug_body_rank
+ FROM combined_scores
+ WHERE rank > 0
+ ORDER BY rank DESC
+ OFFSET ${offset_param} LIMIT ${limit_param}
+ """.format(
+ offset_param=len(params) + 1,
+ limit_param=len(params) + 2,
+ )
+
+ # Add offset and limit to params
+ params.extend([settings.offset, settings.limit])
+
+ # Execute query
+ results = await self.connection_manager.fetch_query(query, params)
+
+ # Format results with complete document metadata
+ return [
+ {
+ "document_id": str(r["document_id"]),
+ "metadata": (
+ json.loads(r["metadata"])
+ if isinstance(r["metadata"], str)
+ else r["metadata"]
+ ),
+ "score": float(r["score"]),
+ "debug_metadata_rank": float(r["debug_metadata_rank"]),
+ "debug_body_rank": float(r["debug_body_rank"]),
+ }
+ for r in results
+ ]
+
+ def _get_index_options(
+ self,
+ method: IndexMethod,
+ index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW],
+ ) -> str:
+ if method == IndexMethod.ivfflat:
+ if isinstance(index_arguments, IndexArgsIVFFlat):
+ return f"WITH (lists={index_arguments.n_lists})"
+ else:
+ # Default value if no arguments provided
+ return "WITH (lists=100)"
+ elif method == IndexMethod.hnsw:
+ if isinstance(index_arguments, IndexArgsHNSW):
+ return f"WITH (m={index_arguments.m}, ef_construction={index_arguments.ef_construction})"
+ else:
+ # Default values if no arguments provided
+ return "WITH (m=16, ef_construction=64)"
+ else:
+ return "" # No options for other methods
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/collections.py b/.venv/lib/python3.12/site-packages/core/providers/database/collections.py
new file mode 100644
index 00000000..72adaff2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/collections.py
@@ -0,0 +1,701 @@
+import csv
+import json
+import logging
+import tempfile
+from typing import IO, Any, Optional
+from uuid import UUID, uuid4
+
+from asyncpg.exceptions import UniqueViolationError
+from fastapi import HTTPException
+
+from core.base import (
+ DatabaseConfig,
+ GraphExtractionStatus,
+ Handler,
+ R2RException,
+ generate_default_user_collection_id,
+)
+from core.base.abstractions import (
+ DocumentResponse,
+ DocumentType,
+ IngestionStatus,
+)
+from core.base.api.models import CollectionResponse
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger()
+
+
+class PostgresCollectionsHandler(Handler):
+ TABLE_NAME = "collections"
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ config: DatabaseConfig,
+ ):
+ self.config = config
+ super().__init__(project_name, connection_manager)
+
+ async def create_tables(self) -> None:
+ # 1. Create the table if it does not exist.
+ create_table_query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ owner_id UUID,
+ name TEXT NOT NULL,
+ description TEXT,
+ graph_sync_status TEXT DEFAULT 'pending',
+ graph_cluster_status TEXT DEFAULT 'pending',
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ user_count INT DEFAULT 0,
+ document_count INT DEFAULT 0
+ );
+ """
+ await self.connection_manager.execute_query(create_table_query)
+
+ # 2. Check for duplicate rows that would violate the uniqueness constraint.
+ check_duplicates_query = f"""
+ SELECT owner_id, name, COUNT(*) AS cnt
+ FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ GROUP BY owner_id, name
+ HAVING COUNT(*) > 1
+ """
+ duplicates = await self.connection_manager.fetch_query(
+ check_duplicates_query
+ )
+ if duplicates:
+ logger.warning(
+ "Cannot add unique constraint (owner_id, name) because duplicates exist. "
+ "Please resolve duplicates first. Found duplicates: %s",
+ duplicates,
+ )
+ return # or raise an exception, depending on your use case
+
+ # 3. Parse the qualified table name into schema and table.
+ qualified_table = self._get_table_name(
+ PostgresCollectionsHandler.TABLE_NAME
+ )
+ if "." in qualified_table:
+ schema, table = qualified_table.split(".", 1)
+ else:
+ schema = "public"
+ table = qualified_table
+
+ # 4. Add the unique constraint if it does not already exist.
+ alter_table_constraint = f"""
+ DO $$
+ BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint c
+ JOIN pg_class t ON c.conrelid = t.oid
+ JOIN pg_namespace n ON n.oid = t.relnamespace
+ WHERE t.relname = '{table}'
+ AND n.nspname = '{schema}'
+ AND c.conname = 'unique_owner_collection_name'
+ ) THEN
+ ALTER TABLE {qualified_table}
+ ADD CONSTRAINT unique_owner_collection_name
+ UNIQUE (owner_id, name);
+ END IF;
+ END;
+ $$;
+ """
+ await self.connection_manager.execute_query(alter_table_constraint)
+
+ async def collection_exists(self, collection_id: UUID) -> bool:
+ """Check if a collection exists."""
+ query = f"""
+ SELECT 1 FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ WHERE id = $1
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [collection_id]
+ )
+ return result is not None
+
+ async def create_collection(
+ self,
+ owner_id: UUID,
+ name: Optional[str] = None,
+ description: str | None = None,
+ collection_id: Optional[UUID] = None,
+ ) -> CollectionResponse:
+ if not name and not collection_id:
+ name = self.config.default_collection_name
+ collection_id = generate_default_user_collection_id(owner_id)
+
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ (id, owner_id, name, description)
+ VALUES ($1, $2, $3, $4)
+ RETURNING id, owner_id, name, description, graph_sync_status, graph_cluster_status, created_at, updated_at
+ """
+ params = [
+ collection_id or uuid4(),
+ owner_id,
+ name,
+ description,
+ ]
+
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+ if not result:
+ raise R2RException(
+ status_code=404, message="Collection not found"
+ )
+
+ return CollectionResponse(
+ id=result["id"],
+ owner_id=result["owner_id"],
+ name=result["name"],
+ description=result["description"],
+ graph_cluster_status=result["graph_cluster_status"],
+ graph_sync_status=result["graph_sync_status"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ user_count=0,
+ document_count=0,
+ )
+ except UniqueViolationError:
+ raise R2RException(
+ message="Collection with this ID already exists",
+ status_code=409,
+ ) from None
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while creating the collection: {e}",
+ ) from e
+
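+    # A minimal invocation sketch (illustrative only; `handler` is assumed to be an
+    # initialized PostgresCollectionsHandler and `some_user_id` an existing user's UUID,
+    # called from an async context):
+    #
+    #     collection = await handler.create_collection(
+    #         owner_id=some_user_id,
+    #         name="Research Papers",
+    #         description="Papers ingested for the RAG demo",
+    #     )
+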
+ async def update_collection(
+ self,
+ collection_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> CollectionResponse:
+ """Update an existing collection."""
+ if not await self.collection_exists(collection_id):
+ raise R2RException(status_code=404, message="Collection not found")
+
+ update_fields = []
+ params: list = []
+ param_index = 1
+
+ if name is not None:
+ update_fields.append(f"name = ${param_index}")
+ params.append(name)
+ param_index += 1
+
+ if description is not None:
+ update_fields.append(f"description = ${param_index}")
+ params.append(description)
+ param_index += 1
+
+ if not update_fields:
+ raise R2RException(status_code=400, message="No fields to update")
+
+ update_fields.append("updated_at = NOW()")
+ params.append(collection_id)
+
+ query = f"""
+ WITH updated_collection AS (
+ UPDATE {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ SET {", ".join(update_fields)}
+ WHERE id = ${param_index}
+ RETURNING id, owner_id, name, description, graph_sync_status, graph_cluster_status, created_at, updated_at
+ )
+ SELECT
+ uc.*,
+ COUNT(DISTINCT u.id) FILTER (WHERE u.id IS NOT NULL) as user_count,
+ COUNT(DISTINCT d.id) FILTER (WHERE d.id IS NOT NULL) as document_count
+ FROM updated_collection uc
+ LEFT JOIN {self._get_table_name("users")} u ON uc.id = ANY(u.collection_ids)
+ LEFT JOIN {self._get_table_name("documents")} d ON uc.id = ANY(d.collection_ids)
+ GROUP BY uc.id, uc.owner_id, uc.name, uc.description, uc.graph_sync_status, uc.graph_cluster_status, uc.created_at, uc.updated_at
+ """
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query, params
+ )
+ if not result:
+ raise R2RException(
+ status_code=404, message="Collection not found"
+ )
+
+ return CollectionResponse(
+ id=result["id"],
+ owner_id=result["owner_id"],
+ name=result["name"],
+ description=result["description"],
+ graph_sync_status=result["graph_sync_status"],
+ graph_cluster_status=result["graph_cluster_status"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ user_count=result["user_count"],
+ document_count=result["document_count"],
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while updating the collection: {e}",
+ ) from e
+
+ async def delete_collection_relational(self, collection_id: UUID) -> None:
+ # Remove collection_id from users
+ user_update_query = f"""
+ UPDATE {self._get_table_name("users")}
+ SET collection_ids = array_remove(collection_ids, $1)
+ WHERE $1 = ANY(collection_ids)
+ """
+ await self.connection_manager.execute_query(
+ user_update_query, [collection_id]
+ )
+
+ # Remove collection_id from documents
+ document_update_query = f"""
+ WITH updated AS (
+ UPDATE {self._get_table_name("documents")}
+ SET collection_ids = array_remove(collection_ids, $1)
+ WHERE $1 = ANY(collection_ids)
+ RETURNING 1
+ )
+ SELECT COUNT(*) AS affected_rows FROM updated
+ """
+ await self.connection_manager.fetchrow_query(
+ document_update_query, [collection_id]
+ )
+
+ # Delete the collection
+ delete_query = f"""
+ DELETE FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ WHERE id = $1
+ RETURNING id
+ """
+ deleted = await self.connection_manager.fetchrow_query(
+ delete_query, [collection_id]
+ )
+
+ if not deleted:
+ raise R2RException(status_code=404, message="Collection not found")
+
+ async def documents_in_collection(
+ self, collection_id: UUID, offset: int, limit: int
+ ) -> dict[str, list[DocumentResponse] | int]:
+ """Get all documents in a specific collection with pagination.
+
+ Args:
+ collection_id (UUID): The ID of the collection to get documents from.
+ offset (int): The number of documents to skip.
+ limit (int): The maximum number of documents to return.
+ Returns:
+            dict: A dict with "results" (a list of DocumentResponse objects) and "total_entries" (the total match count).
+ Raises:
+ R2RException: If the collection doesn't exist.
+ """
+ if not await self.collection_exists(collection_id):
+ raise R2RException(status_code=404, message="Collection not found")
+ query = f"""
+ SELECT d.id, d.owner_id, d.type, d.metadata, d.title, d.version,
+ d.size_in_bytes, d.ingestion_status, d.extraction_status, d.created_at, d.updated_at, d.summary,
+ COUNT(*) OVER() AS total_entries
+ FROM {self._get_table_name("documents")} d
+ WHERE $1 = ANY(d.collection_ids)
+ ORDER BY d.created_at DESC
+ OFFSET $2
+ """
+
+        params = [collection_id, offset]
+        if limit != -1:
+            query += " LIMIT $3"
+            params.append(limit)
+
+        results = await self.connection_manager.fetch_query(query, params)
+ documents = [
+ DocumentResponse(
+ id=row["id"],
+ collection_ids=[collection_id],
+ owner_id=row["owner_id"],
+ document_type=DocumentType(row["type"]),
+ metadata=json.loads(row["metadata"]),
+ title=row["title"],
+ version=row["version"],
+ size_in_bytes=row["size_in_bytes"],
+ ingestion_status=IngestionStatus(row["ingestion_status"]),
+ extraction_status=GraphExtractionStatus(
+ row["extraction_status"]
+ ),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ summary=row["summary"],
+ )
+ for row in results
+ ]
+ total_entries = results[0]["total_entries"] if results else 0
+
+ return {"results": documents, "total_entries": total_entries}
+
+ async def get_collections_overview(
+ self,
+ offset: int,
+ limit: int,
+ filter_user_ids: Optional[list[UUID]] = None,
+ filter_document_ids: Optional[list[UUID]] = None,
+ filter_collection_ids: Optional[list[UUID]] = None,
+ ) -> dict[str, list[CollectionResponse] | int]:
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ if filter_user_ids:
+ conditions.append(f"""
+ c.id IN (
+ SELECT unnest(collection_ids)
+ FROM {self.project_name}.users
+ WHERE id = ANY(${param_index})
+ )
+ """)
+ params.append(filter_user_ids)
+ param_index += 1
+
+ if filter_document_ids:
+ conditions.append(f"""
+ c.id IN (
+ SELECT unnest(collection_ids)
+ FROM {self.project_name}.documents
+ WHERE id = ANY(${param_index})
+ )
+ """)
+ params.append(filter_document_ids)
+ param_index += 1
+
+ if filter_collection_ids:
+ conditions.append(f"c.id = ANY(${param_index})")
+ params.append(filter_collection_ids)
+ param_index += 1
+
+ where_clause = (
+ f"WHERE {' AND '.join(conditions)}" if conditions else ""
+ )
+
+ query = f"""
+ SELECT
+ c.*,
+ COUNT(*) OVER() as total_entries
+ FROM {self.project_name}.collections c
+ {where_clause}
+ ORDER BY created_at DESC
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ query += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ try:
+ results = await self.connection_manager.fetch_query(query, params)
+
+ if not results:
+ return {"results": [], "total_entries": 0}
+
+ total_entries = results[0]["total_entries"] if results else 0
+
+ collections = [CollectionResponse(**row) for row in results]
+
+ return {"results": collections, "total_entries": total_entries}
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while fetching collections: {e}",
+ ) from e
+
+ async def assign_document_to_collection_relational(
+ self,
+ document_id: UUID,
+ collection_id: UUID,
+ ) -> UUID:
+ """Assign a document to a collection.
+
+ Args:
+ document_id (UUID): The ID of the document to assign.
+ collection_id (UUID): The ID of the collection to assign the document to.
+
+        Returns:
+            UUID: The ID of the collection the document was assigned to.
+
+        Raises:
+            R2RException: If the collection or document doesn't exist, or if the
+                document is already assigned to the collection.
+ """
+ try:
+ if not await self.collection_exists(collection_id):
+ raise R2RException(
+ status_code=404, message="Collection not found"
+ )
+
+ # First, check if the document exists
+ document_check_query = f"""
+ SELECT 1 FROM {self._get_table_name("documents")}
+ WHERE id = $1
+ """
+ document_exists = await self.connection_manager.fetchrow_query(
+ document_check_query, [document_id]
+ )
+
+ if not document_exists:
+ raise R2RException(
+ status_code=404, message="Document not found"
+ )
+
+ # If document exists, proceed with the assignment
+ assign_query = f"""
+ UPDATE {self._get_table_name("documents")}
+ SET collection_ids = array_append(collection_ids, $1)
+ WHERE id = $2 AND NOT ($1 = ANY(collection_ids))
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ assign_query, [collection_id, document_id]
+ )
+
+ if not result:
+ # Document exists but was already assigned to the collection
+ raise R2RException(
+ status_code=409,
+ message="Document is already assigned to the collection",
+ )
+
+ update_collection_query = f"""
+ UPDATE {self._get_table_name("collections")}
+ SET document_count = document_count + 1
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(
+ query=update_collection_query, params=[collection_id]
+ )
+
+ return collection_id
+
+ except R2RException:
+ # Re-raise R2RExceptions as they are already handled
+ raise
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+                detail=f"An error occurred while assigning the document to the collection: {e}",
+ ) from e
+
+ async def remove_document_from_collection_relational(
+ self, document_id: UUID, collection_id: UUID
+ ) -> None:
+ """Remove a document from a collection.
+
+ Args:
+ document_id (UUID): The ID of the document to remove.
+ collection_id (UUID): The ID of the collection to remove the document from.
+
+ Raises:
+ R2RException: If the collection doesn't exist or if the document is not in the collection.
+ """
+ if not await self.collection_exists(collection_id):
+ raise R2RException(status_code=404, message="Collection not found")
+
+ query = f"""
+ UPDATE {self._get_table_name("documents")}
+ SET collection_ids = array_remove(collection_ids, $1)
+ WHERE id = $2 AND $1 = ANY(collection_ids)
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [collection_id, document_id]
+ )
+
+ if not result:
+ raise R2RException(
+ status_code=404,
+ message="Document not found in the specified collection",
+ )
+
+ await self.decrement_collection_document_count(
+ collection_id=collection_id
+ )
+
+ async def decrement_collection_document_count(
+ self, collection_id: UUID, decrement_by: int = 1
+ ) -> None:
+ """Decrement the document count for a collection.
+
+ Args:
+ collection_id (UUID): The ID of the collection to update
+ decrement_by (int): Number to decrease the count by (default: 1)
+ """
+ collection_query = f"""
+ UPDATE {self._get_table_name("collections")}
+ SET document_count = document_count - $1
+ WHERE id = $2
+ """
+ await self.connection_manager.execute_query(
+ collection_query, [decrement_by, collection_id]
+ )
+
+ async def export_to_csv(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ """Creates a CSV file from the PostgreSQL data and returns the path to
+ the temp file."""
+ valid_columns = {
+ "id",
+ "owner_id",
+ "name",
+ "description",
+ "graph_sync_status",
+ "graph_cluster_status",
+ "created_at",
+ "updated_at",
+ "user_count",
+ "document_count",
+ }
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ owner_id::text,
+ name,
+ description,
+ graph_sync_status,
+ graph_cluster_status,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
+ user_count,
+ document_count
+ FROM {self._get_table_name(self.TABLE_NAME)}
+ """
+
+ params = []
+ if filters:
+ conditions = []
+ param_index = 1
+
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "owner_id": row[1],
+ "name": row[2],
+ "description": row[3],
+ "graph_sync_status": row[4],
+ "graph_cluster_status": row[5],
+ "created_at": row[6],
+ "updated_at": row[7],
+ "user_count": row[8],
+ "document_count": row[9],
+ }
+ writer.writerow([row_dict[col] for col in columns])
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
+
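+    # Illustrative usage of the export above (a sketch; the handler name and
+    # filter values are assumptions, not part of this module):
+    #
+    #     path, fh = await collections_handler.export_to_csv(
+    #         columns=["id", "name", "document_count"],
+    #         filters={"created_at": {"$gt": "2024-01-01 00:00:00"}},
+    #     )
+    #
+    # Filter values may be plain equalities or use the $eq / $gt / $lt
+    # operators; fields outside valid_columns are silently skipped.
+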
+ async def get_collection_by_name(
+ self, owner_id: UUID, name: str
+ ) -> Optional[CollectionResponse]:
+        """Fetch a collection by owner_id + name combination.
+
+        Raises an R2RException (404) if no matching collection exists.
+        """
+ query = f"""
+ SELECT
+ id, owner_id, name, description, graph_sync_status,
+ graph_cluster_status, created_at, updated_at, user_count, document_count
+ FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ WHERE owner_id = $1 AND name = $2
+ LIMIT 1
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [owner_id, name]
+ )
+ if not result:
+ raise R2RException(
+ status_code=404,
+ message="No collection found with the specified name",
+ )
+ return CollectionResponse(
+ id=result["id"],
+ owner_id=result["owner_id"],
+ name=result["name"],
+ description=result["description"],
+ graph_sync_status=result["graph_sync_status"],
+ graph_cluster_status=result["graph_cluster_status"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ user_count=result["user_count"],
+ document_count=result["document_count"],
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/conversations.py b/.venv/lib/python3.12/site-packages/core/providers/database/conversations.py
new file mode 100644
index 00000000..2be2356c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/conversations.py
@@ -0,0 +1,858 @@
+import csv
+import json
+import logging
+import tempfile
+from datetime import datetime
+from typing import IO, Any, Optional
+from uuid import UUID, uuid4
+
+from fastapi import HTTPException
+
+from core.base import Handler, Message, R2RException
+from shared.api.models.management.responses import (
+ ConversationResponse,
+ MessageResponse,
+)
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger(__name__)
+
+
+def _validate_image_size(
+ message: Message, max_size_bytes: int = 5 * 1024 * 1024
+) -> None:
+ """
+ Validates that images in a message don't exceed the maximum allowed size.
+
+ Args:
+ message: Message object to validate
+ max_size_bytes: Maximum allowed size for base64-encoded images (default: 5MB)
+
+ Raises:
+ R2RException: If image is too large
+ """
+ if (
+ hasattr(message, "image_data")
+ and message.image_data
+ and "data" in message.image_data
+ ):
+ base64_data = message.image_data["data"]
+
+ # Calculate approximate decoded size (base64 increases size by ~33%)
+ # The formula is: decoded_size = encoded_size * 3/4
+ estimated_size_bytes = len(base64_data) * 0.75
+
+ if estimated_size_bytes > max_size_bytes:
+ raise R2RException(
+ status_code=413, # Payload Too Large
+ message=f"Image too large: {estimated_size_bytes / 1024 / 1024:.2f}MB exceeds the maximum allowed size of {max_size_bytes / 1024 / 1024:.2f}MB",
+ )
+
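+# Worked example for the estimate above (illustrative numbers only): a payload
+# of 8_000_000 base64 characters is estimated at 8_000_000 * 0.75 = 6_000_000
+# bytes (~5.72MB), which exceeds the default 5 * 1024 * 1024 byte limit and
+# would raise a 413 R2RException.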
+
+def _json_default(obj: Any) -> str:
+ """Default handler for objects not serializable by the standard json
+ encoder."""
+ if isinstance(obj, datetime):
+ # Return ISO8601 string
+ return obj.isoformat()
+ elif isinstance(obj, UUID):
+ # Convert UUID to string
+ return str(obj)
+ # If you have other special types, handle them here...
+ # e.g. decimal.Decimal -> str(obj)
+
+ # If we get here, raise an error or just default to string:
+ raise TypeError(f"Type {type(obj)} not serializable")
+
+
+def safe_dumps(obj: Any) -> str:
+ """Wrap `json.dumps` with a default that serializes UUID and datetime."""
+ return json.dumps(obj, default=_json_default)
+
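+# Example (illustrative): safe_dumps({"id": uuid4(), "at": datetime.now()})
+# produces a JSON string with the UUID rendered via str() and the datetime as
+# an ISO8601 timestamp; any other unsupported type still raises TypeError from
+# _json_default.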
+
+class PostgresConversationsHandler(Handler):
+ def __init__(
+ self, project_name: str, connection_manager: PostgresConnectionManager
+ ):
+ self.project_name = project_name
+ self.connection_manager = connection_manager
+
+ async def create_tables(self):
+ create_conversations_query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name("conversations")} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ user_id UUID,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ name TEXT
+ );
+ """
+
+ create_messages_query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name("messages")} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ conversation_id UUID NOT NULL,
+ parent_id UUID,
+ content JSONB,
+ metadata JSONB,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ FOREIGN KEY (conversation_id) REFERENCES {self._get_table_name("conversations")}(id),
+ FOREIGN KEY (parent_id) REFERENCES {self._get_table_name("messages")}(id)
+ );
+ """
+ await self.connection_manager.execute_query(create_conversations_query)
+ await self.connection_manager.execute_query(create_messages_query)
+
+ async def create_conversation(
+ self,
+ user_id: Optional[UUID] = None,
+ name: Optional[str] = None,
+ ) -> ConversationResponse:
+ query = f"""
+ INSERT INTO {self._get_table_name("conversations")} (user_id, name)
+ VALUES ($1, $2)
+ RETURNING id, extract(epoch from created_at) as created_at_epoch
+ """
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query, [user_id, name]
+ )
+
+ return ConversationResponse(
+ id=result["id"],
+ created_at=result["created_at_epoch"],
+ user_id=user_id or None,
+ name=name or None,
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to create conversation: {str(e)}",
+ ) from e
+
+ async def get_conversations_overview(
+ self,
+ offset: int,
+ limit: int,
+ filter_user_ids: Optional[list[UUID]] = None,
+ conversation_ids: Optional[list[UUID]] = None,
+ ) -> dict[str, Any]:
+ conditions = []
+ params: list = []
+ param_index = 1
+
+ if filter_user_ids:
+ conditions.append(f"""
+ c.user_id IN (
+ SELECT id
+ FROM {self.project_name}.users
+ WHERE id = ANY(${param_index})
+ )
+ """)
+ params.append(filter_user_ids)
+ param_index += 1
+
+ if conversation_ids:
+ conditions.append(f"c.id = ANY(${param_index})")
+ params.append(conversation_ids)
+ param_index += 1
+
+ where_clause = (
+ "WHERE " + " AND ".join(conditions) if conditions else ""
+ )
+
+ query = f"""
+ WITH conversation_overview AS (
+ SELECT c.id,
+ extract(epoch from c.created_at) as created_at_epoch,
+ c.user_id,
+ c.name
+ FROM {self._get_table_name("conversations")} c
+ {where_clause}
+ ),
+ counted_overview AS (
+ SELECT *,
+ COUNT(*) OVER() AS total_entries
+ FROM conversation_overview
+ )
+ SELECT * FROM counted_overview
+ ORDER BY created_at_epoch DESC
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ query += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ if not results:
+ return {"results": [], "total_entries": 0}
+
+ total_entries = results[0]["total_entries"]
+ conversations = [
+ {
+ "id": str(row["id"]),
+ "created_at": row["created_at_epoch"],
+ "user_id": str(row["user_id"]) if row["user_id"] else None,
+ "name": row["name"] or None,
+ }
+ for row in results
+ ]
+
+ return {"results": conversations, "total_entries": total_entries}
+
+ async def add_message(
+ self,
+ conversation_id: UUID,
+ content: Message,
+ parent_id: Optional[UUID] = None,
+ metadata: Optional[dict] = None,
+ max_image_size_bytes: int = 5 * 1024 * 1024, # 5MB default
+ ) -> MessageResponse:
+ # Validate image size
+ try:
+ _validate_image_size(content, max_image_size_bytes)
+ except R2RException:
+ # Re-raise validation exceptions
+ raise
+ except Exception as e:
+ # Handle unexpected errors during validation
+ logger.error(f"Error validating image: {str(e)}")
+ raise R2RException(
+ status_code=400, message=f"Invalid image data: {str(e)}"
+ ) from e
+
+ # 1) Validate that conversation and parent exist (existing code)
+ conv_check_query = f"""
+ SELECT 1 FROM {self._get_table_name("conversations")}
+ WHERE id = $1
+ """
+ conv_row = await self.connection_manager.fetchrow_query(
+ conv_check_query, [conversation_id]
+ )
+ if not conv_row:
+ raise R2RException(
+ status_code=404,
+ message=f"Conversation {conversation_id} not found.",
+ )
+
+ if parent_id:
+ parent_check_query = f"""
+ SELECT 1 FROM {self._get_table_name("messages")}
+ WHERE id = $1 AND conversation_id = $2
+ """
+ parent_row = await self.connection_manager.fetchrow_query(
+ parent_check_query, [parent_id, conversation_id]
+ )
+ if not parent_row:
+ raise R2RException(
+ status_code=404,
+ message=f"Parent message {parent_id} not found in conversation {conversation_id}.",
+ )
+
+ # 2) Add image info to metadata for tracking/analytics if images are present
+ metadata = metadata or {}
+ if hasattr(content, "image_url") and content.image_url:
+ metadata["has_image"] = True
+ metadata["image_type"] = "url"
+ elif hasattr(content, "image_data") and content.image_data:
+ metadata["has_image"] = True
+ metadata["image_type"] = "base64"
+ # Don't store the actual base64 data in metadata as it would be redundant
+
+ # 3) Convert the content & metadata to JSON strings
+ message_id = uuid4()
+ # Using safe_dumps to handle any type of serialization
+ content_str = safe_dumps(content.model_dump())
+ metadata_str = safe_dumps(metadata)
+
+ # 4) Insert the message (existing code)
+ query = f"""
+ INSERT INTO {self._get_table_name("messages")}
+ (id, conversation_id, parent_id, content, created_at, metadata)
+ VALUES ($1, $2, $3, $4::jsonb, NOW(), $5::jsonb)
+ RETURNING id
+ """
+ inserted = await self.connection_manager.fetchrow_query(
+ query,
+ [
+ message_id,
+ conversation_id,
+ parent_id,
+ content_str,
+ metadata_str,
+ ],
+ )
+ if not inserted:
+ raise R2RException(
+ status_code=500, message="Failed to insert message."
+ )
+
+ return MessageResponse(id=message_id, message=content)
+
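+    # Typical call shape (a sketch; the handler and conversation identifiers
+    # are assumptions):
+    #
+    #     await conversations_handler.add_message(
+    #         conversation_id=conv_id,
+    #         content=Message(role="user", content="Hello"),
+    #     )
+    #
+    # Image-bearing messages are size-checked by _validate_image_size before
+    # anything is written, and image presence is mirrored into the metadata.
+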
+ async def edit_message(
+ self,
+ message_id: UUID,
+ new_content: str | None = None,
+ additional_metadata: dict | None = None,
+ ) -> dict[str, Any]:
+ # Get the original message
+ query = f"""
+ SELECT conversation_id, parent_id, content, metadata, created_at
+ FROM {self._get_table_name("messages")}
+ WHERE id = $1
+ """
+ row = await self.connection_manager.fetchrow_query(query, [message_id])
+ if not row:
+ raise R2RException(
+ status_code=404,
+ message=f"Message {message_id} not found.",
+ )
+
+ old_content = json.loads(row["content"])
+ old_metadata = json.loads(row["metadata"])
+
+ if new_content is not None:
+ old_message = Message(**old_content)
+ edited_message = Message(
+ role=old_message.role,
+ content=new_content,
+ name=old_message.name,
+ function_call=old_message.function_call,
+ tool_calls=old_message.tool_calls,
+ # Preserve image content if it exists
+ image_url=getattr(old_message, "image_url", None),
+ image_data=getattr(old_message, "image_data", None),
+ )
+ content_to_save = edited_message.model_dump()
+ else:
+ content_to_save = old_content
+
+ additional_metadata = additional_metadata or {}
+
+ new_metadata = {
+ **old_metadata,
+ **additional_metadata,
+ "edited": (
+ True
+ if new_content is not None
+ else old_metadata.get("edited", False)
+ ),
+ }
+
+ # Update message without changing the timestamp
+ update_query = f"""
+ UPDATE {self._get_table_name("messages")}
+ SET content = $1::jsonb,
+ metadata = $2::jsonb,
+ created_at = $3
+ WHERE id = $4
+ RETURNING id
+ """
+ updated = await self.connection_manager.fetchrow_query(
+ update_query,
+ [
+ json.dumps(content_to_save),
+ json.dumps(new_metadata),
+ row["created_at"],
+ message_id,
+ ],
+ )
+ if not updated:
+ raise R2RException(
+ status_code=500, message="Failed to update message."
+ )
+
+ return {
+ "id": str(message_id),
+ "message": (
+ Message(**content_to_save)
+ if isinstance(content_to_save, dict)
+ else content_to_save
+ ),
+ "metadata": new_metadata,
+ }
+
+ async def update_message_metadata(
+ self, message_id: UUID, metadata: dict
+ ) -> None:
+ # Fetch current metadata
+ query = f"""
+ SELECT metadata FROM {self._get_table_name("messages")}
+ WHERE id = $1
+ """
+ row = await self.connection_manager.fetchrow_query(query, [message_id])
+ if not row:
+ raise R2RException(
+ status_code=404, message=f"Message {message_id} not found."
+ )
+
+ current_metadata = json.loads(row["metadata"]) or {}
+ updated_metadata = {**current_metadata, **metadata}
+
+ update_query = f"""
+ UPDATE {self._get_table_name("messages")}
+ SET metadata = $1::jsonb
+ WHERE id = $2
+ """
+ await self.connection_manager.execute_query(
+ update_query, [json.dumps(updated_metadata), message_id]
+ )
+
+ async def get_conversation(
+ self,
+ conversation_id: UUID,
+ filter_user_ids: Optional[list[UUID]] = None,
+ ) -> list[MessageResponse]:
+ # Existing validation code remains the same
+ conditions = ["c.id = $1"]
+ params: list = [conversation_id]
+
+ if filter_user_ids:
+ param_index = 2
+ conditions.append(f"""
+ c.user_id IN (
+ SELECT id
+ FROM {self.project_name}.users
+ WHERE id = ANY(${param_index})
+ )
+ """)
+ params.append(filter_user_ids)
+
+ query = f"""
+ SELECT c.id, extract(epoch from c.created_at) AS created_at_epoch
+ FROM {self._get_table_name("conversations")} c
+ WHERE {" AND ".join(conditions)}
+ """
+
+ conv_row = await self.connection_manager.fetchrow_query(query, params)
+ if not conv_row:
+ raise R2RException(
+ status_code=404,
+ message=f"Conversation {conversation_id} not found.",
+ )
+
+ # Retrieve messages in chronological order
+ msg_query = f"""
+ SELECT id, content, metadata
+ FROM {self._get_table_name("messages")}
+ WHERE conversation_id = $1
+ ORDER BY created_at ASC
+ """
+ results = await self.connection_manager.fetch_query(
+ msg_query, [conversation_id]
+ )
+
+ response_messages = []
+ for row in results:
+ try:
+ # Parse the message content
+ content_json = json.loads(row["content"])
+ # Create a Message object with the parsed content
+ message = Message(**content_json)
+ # Create a MessageResponse
+ response_messages.append(
+ MessageResponse(
+ id=row["id"],
+ message=message,
+ metadata=json.loads(row["metadata"]),
+ )
+ )
+ except Exception as e:
+ # If there's an error parsing the message (e.g., due to version mismatch),
+ # log it and create a fallback message
+                logger.warning(f"Error parsing message {row['id']}: {str(e)}")
+                # Re-parse defensively: json.loads itself may have failed above.
+                try:
+                    raw_content = json.loads(row["content"])
+                except (json.JSONDecodeError, TypeError):
+                    raw_content = {}
+                fallback_content = raw_content.get(
+                    "content", "Message could not be loaded"
+                )
+                fallback_role = raw_content.get("role", "assistant")
+
+ # Create a basic fallback message
+ fallback_message = Message(
+ role=fallback_role,
+ content=f"[Message format incompatible: {fallback_content}]",
+ )
+
+ response_messages.append(
+ MessageResponse(
+ id=row["id"],
+ message=fallback_message,
+ metadata=json.loads(row["metadata"]),
+ )
+ )
+
+ return response_messages
+
+ async def update_conversation(
+ self, conversation_id: UUID, name: str
+ ) -> ConversationResponse:
+ try:
+ # Check if conversation exists
+ conv_query = f"SELECT 1 FROM {self._get_table_name('conversations')} WHERE id = $1"
+ conv_row = await self.connection_manager.fetchrow_query(
+ conv_query, [conversation_id]
+ )
+ if not conv_row:
+ raise R2RException(
+ status_code=404,
+ message=f"Conversation {conversation_id} not found.",
+ )
+
+ update_query = f"""
+ UPDATE {self._get_table_name("conversations")}
+ SET name = $1 WHERE id = $2
+ RETURNING user_id, extract(epoch from created_at) as created_at_epoch
+ """
+ updated_row = await self.connection_manager.fetchrow_query(
+ update_query, [name, conversation_id]
+ )
+ return ConversationResponse(
+ id=conversation_id,
+ created_at=updated_row["created_at_epoch"],
+ user_id=updated_row["user_id"] or None,
+ name=name,
+ )
+        except R2RException:
+            raise
+        except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to update conversation: {str(e)}",
+ ) from e
+
+ async def delete_conversation(
+ self,
+ conversation_id: UUID,
+ filter_user_ids: Optional[list[UUID]] = None,
+ ) -> None:
+ conditions = ["c.id = $1"]
+ params: list = [conversation_id]
+
+ if filter_user_ids:
+ param_index = 2
+ conditions.append(f"""
+ c.user_id IN (
+ SELECT id
+ FROM {self.project_name}.users
+ WHERE id = ANY(${param_index})
+ )
+ """)
+ params.append(filter_user_ids)
+
+ conv_query = f"""
+ SELECT 1
+ FROM {self._get_table_name("conversations")} c
+ WHERE {" AND ".join(conditions)}
+ """
+ conv_row = await self.connection_manager.fetchrow_query(
+ conv_query, params
+ )
+ if not conv_row:
+ raise R2RException(
+ status_code=404,
+ message=f"Conversation {conversation_id} not found.",
+ )
+
+ # Delete all messages
+ del_messages_query = f"DELETE FROM {self._get_table_name('messages')} WHERE conversation_id = $1"
+ await self.connection_manager.execute_query(
+ del_messages_query, [conversation_id]
+ )
+
+ # Delete conversation
+ del_conv_query = f"DELETE FROM {self._get_table_name('conversations')} WHERE id = $1"
+ await self.connection_manager.execute_query(
+ del_conv_query, [conversation_id]
+ )
+
+ async def export_conversations_to_csv(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ """Creates a CSV file from the PostgreSQL data and returns the path to
+ the temp file."""
+ valid_columns = {
+ "id",
+ "user_id",
+ "created_at",
+ "name",
+ }
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ user_id::text,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ name
+ FROM {self._get_table_name("conversations")}
+ """
+
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "user_id": row[1],
+ "created_at": row[2],
+ "name": row[3],
+ }
+ writer.writerow([row_dict[col] for col in columns])
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
+
+ async def export_messages_to_csv(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ handle_images: str = "metadata_only", # Options: "full", "metadata_only", "exclude"
+ ) -> tuple[str, IO]:
+ """
+ Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
+
+ Args:
+ columns: List of columns to include in export
+ filters: Filter criteria for messages
+ include_header: Whether to include header row
+ handle_images: How to handle image data in exports:
+ - "full": Include complete image data (warning: may create large files)
+ - "metadata_only": Replace image data with metadata only
+ - "exclude": Remove image data completely
+ """
+ valid_columns = {
+ "id",
+ "conversation_id",
+ "parent_id",
+ "content",
+ "metadata",
+ "created_at",
+ "has_image", # New virtual column to indicate image presence
+ }
+
+ if not columns:
+ columns = list(valid_columns - {"has_image"})
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ # Add virtual column for image presence
+ virtual_columns = []
+ has_image_column = False
+
+ if "has_image" in columns:
+ virtual_columns.append(
+ "(content->>'image_url' IS NOT NULL OR content->>'image_data' IS NOT NULL) as has_image"
+ )
+ columns.remove("has_image")
+ has_image_column = True
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ conversation_id::text,
+ parent_id::text,
+ content::text,
+ metadata::text,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at
+ {", " + ", ".join(virtual_columns) if virtual_columns else ""}
+ FROM {self._get_table_name("messages")}
+ """
+
+ # Keep existing filter conditions setup
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns or field == "has_image":
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ # Special filter for has_image
+ if filters and "has_image" in filters:
+ if filters["has_image"]:
+ conditions.append(
+ "(content->>'image_url' IS NOT NULL OR content->>'image_data' IS NOT NULL)"
+ )
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ # Prepare export columns
+ export_columns = list(columns)
+ if has_image_column:
+ export_columns.append("has_image")
+
+ if include_header:
+ writer.writerow(export_columns)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "conversation_id": row[1],
+ "parent_id": row[2],
+ "content": row[3],
+ "metadata": row[4],
+ "created_at": row[5],
+ }
+
+ # Add virtual column if present
+ if has_image_column:
+ row_dict["has_image"] = (
+ "true" if row[6] else "false"
+ )
+
+ # Process image data based on handle_images setting
+ if (
+ "content" in columns
+ and handle_images != "full"
+ ):
+ try:
+ content_json = json.loads(
+ row_dict["content"]
+ )
+
+ if (
+ "image_data" in content_json
+ and content_json["image_data"]
+ ):
+ media_type = content_json[
+ "image_data"
+ ].get("media_type", "image/jpeg")
+
+ if handle_images == "metadata_only":
+ content_json["image_data"] = {
+ "media_type": media_type,
+ "data": "[BASE64_DATA_EXCLUDED_FROM_EXPORT]",
+ }
+ elif handle_images == "exclude":
+ content_json.pop(
+ "image_data", None
+ )
+
+ row_dict["content"] = json.dumps(
+ content_json
+ )
+ except (json.JSONDecodeError, TypeError) as e:
+ logger.warning(
+ f"Error processing message content for export: {e}"
+ )
+
+ writer.writerow(
+ [row_dict[col] for col in export_columns]
+ )
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
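+
+    # Illustrative export call (a sketch; the handler name is an assumption):
+    #
+    #     path, fh = await conversations_handler.export_messages_to_csv(
+    #         columns=["id", "conversation_id", "content", "has_image"],
+    #         filters={"has_image": True},
+    #         handle_images="metadata_only",
+    #     )
+    #
+    # "metadata_only" keeps the media type but strips the base64 payload,
+    # "exclude" drops image_data entirely, and "full" exports it verbatim.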
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/documents.py b/.venv/lib/python3.12/site-packages/core/providers/database/documents.py
new file mode 100644
index 00000000..19781037
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/documents.py
@@ -0,0 +1,1172 @@
+import asyncio
+import copy
+import csv
+import json
+import logging
+import math
+import tempfile
+from typing import IO, Any, Optional
+from uuid import UUID
+
+import asyncpg
+from fastapi import HTTPException
+
+from core.base import (
+ DocumentResponse,
+ DocumentType,
+ GraphConstructionStatus,
+ GraphExtractionStatus,
+ Handler,
+ IngestionStatus,
+ R2RException,
+ SearchSettings,
+)
+
+from .base import PostgresConnectionManager
+from .filters import apply_filters
+
+logger = logging.getLogger()
+
+
+def transform_filter_fields(filters: dict[str, Any]) -> dict[str, Any]:
+ """Recursively transform filter field names by replacing 'document_id' with
+ 'id'. Handles nested logical operators like $and, $or, etc.
+
+ Args:
+ filters (dict[str, Any]): The original filters dictionary
+
+ Returns:
+ dict[str, Any]: A new dictionary with transformed field names
+ """
+ if not filters:
+ return {}
+
+ transformed = {}
+
+ for key, value in filters.items():
+ # Handle logical operators recursively
+ if key in ("$and", "$or", "$not"):
+ if isinstance(value, list):
+ transformed[key] = [
+ transform_filter_fields(item) for item in value
+ ]
+ else:
+ transformed[key] = transform_filter_fields(value) # type: ignore
+ continue
+
+ # Replace 'document_id' with 'id'
+ new_key = "id" if key == "document_id" else key
+
+ # Handle nested dictionary cases (e.g., for operators like $eq, $gt, etc.)
+ if isinstance(value, dict):
+ transformed[new_key] = transform_filter_fields(value) # type: ignore
+ else:
+ transformed[new_key] = value
+
+ logger.debug(f"Transformed filters from {filters} to {transformed}")
+ return transformed
+
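+# Example of the transformation above (illustrative): a filter such as
+#     {"$and": [{"document_id": {"$eq": doc_id}}, {"owner_id": owner}]}
+# becomes
+#     {"$and": [{"id": {"$eq": doc_id}}, {"owner_id": owner}]}
+# so callers can keep filtering on "document_id" while the table column is "id".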
+
+class PostgresDocumentsHandler(Handler):
+ TABLE_NAME = "documents"
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ dimension: int | float,
+ ):
+ self.dimension = dimension
+ super().__init__(project_name, connection_manager)
+
+ async def create_tables(self):
+ logger.info(
+ f"Creating table, if it does not exist: {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
+ )
+
+ vector_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+ vector_type = f"vector{vector_dim}"
+
+ try:
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY,
+ collection_ids UUID[],
+ owner_id UUID,
+ type TEXT,
+ metadata JSONB,
+ title TEXT,
+ summary TEXT NULL,
+ summary_embedding {vector_type} NULL,
+ version TEXT,
+ size_in_bytes INT,
+ ingestion_status TEXT DEFAULT 'pending',
+ extraction_status TEXT DEFAULT 'pending',
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ ingestion_attempt_number INT DEFAULT 0,
+ raw_tsvector tsvector GENERATED ALWAYS AS (
+ setweight(to_tsvector('english', COALESCE(title, '')), 'A') ||
+ setweight(to_tsvector('english', COALESCE(summary, '')), 'B') ||
+ setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C')
+ ) STORED,
+ total_tokens INT DEFAULT 0
+ );
+ CREATE INDEX IF NOT EXISTS idx_collection_ids_{self.project_name}
+ ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} USING GIN (collection_ids);
+
+ -- Full text search index
+ CREATE INDEX IF NOT EXISTS idx_doc_search_{self.project_name}
+ ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ USING GIN (raw_tsvector);
+ """
+ await self.connection_manager.execute_query(query)
+
+ # ---------------------------------------------------------------
+ # Now check if total_tokens column exists in the 'documents' table
+ # ---------------------------------------------------------------
+ # 1) See what columns exist
+ # column_check_query = f"""
+ # SELECT column_name
+ # FROM information_schema.columns
+ # WHERE table_name = '{self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}'
+ # AND table_schema = CURRENT_SCHEMA()
+ # """
+ # existing_columns = await self.connection_manager.fetch_query(column_check_query)
+ # 2) Parse the table name for schema checks
+ table_full_name = self._get_table_name(
+ PostgresDocumentsHandler.TABLE_NAME
+ )
+ parsed_schema = "public"
+ parsed_table_name = table_full_name
+ if "." in table_full_name:
+ parts = table_full_name.split(".", maxsplit=1)
+ parsed_schema = parts[0].replace('"', "").strip()
+ parsed_table_name = parts[1].replace('"', "").strip()
+ else:
+ parsed_table_name = parsed_table_name.replace('"', "").strip()
+
+ # 3) Check columns
+ column_check_query = f"""
+ SELECT column_name
+ FROM information_schema.columns
+ WHERE table_name = '{parsed_table_name}'
+ AND table_schema = '{parsed_schema}'
+ """
+ existing_columns = await self.connection_manager.fetch_query(
+ column_check_query
+ )
+
+ existing_column_names = {
+ row["column_name"] for row in existing_columns
+ }
+
+ if "total_tokens" not in existing_column_names:
+ # 2) If missing, see if the table already has data
+ # doc_count_query = f"SELECT COUNT(*) FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
+ # doc_count = await self.connection_manager.fetchval(doc_count_query)
+ doc_count_query = f"SELECT COUNT(*) AS doc_count FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
+ row = await self.connection_manager.fetchrow_query(
+ doc_count_query
+ )
+ if row is None:
+ doc_count = 0
+ else:
+ doc_count = row[
+ "doc_count"
+ ] # or row[0] if you prefer positional indexing
+
+ if doc_count > 0:
+                    # The table already has documents but is missing the
+                    # total_tokens column, so warn before altering it in place.
+ logger.warning(
+ "Adding the missing 'total_tokens' column to the 'documents' table, this will impact existing files."
+ )
+
+ create_tokens_col = f"""
+ ALTER TABLE {table_full_name}
+ ADD COLUMN total_tokens INT DEFAULT 0
+ """
+ await self.connection_manager.execute_query(create_tokens_col)
+
+ except Exception as e:
+ logger.warning(f"Error {e} when creating document table.")
+ raise e
+
+ async def upsert_documents_overview(
+ self, documents_overview: DocumentResponse | list[DocumentResponse]
+ ) -> None:
+ if isinstance(documents_overview, DocumentResponse):
+ documents_overview = [documents_overview]
+
+ # TODO: make this an arg
+ max_retries = 20
+ for document in documents_overview:
+ retries = 0
+ while retries < max_retries:
+ try:
+ async with (
+ self.connection_manager.pool.get_connection() as conn # type: ignore
+ ):
+ async with conn.transaction():
+ # Lock the row for update
+ check_query = f"""
+ SELECT ingestion_attempt_number, ingestion_status FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ WHERE id = $1 FOR UPDATE
+ """
+ existing_doc = await conn.fetchrow(
+ check_query, document.id
+ )
+
+ db_entry = document.convert_to_db_entry()
+
+ if existing_doc:
+ db_version = existing_doc[
+ "ingestion_attempt_number"
+ ]
+ db_status = existing_doc["ingestion_status"]
+ new_version = db_entry[
+ "ingestion_attempt_number"
+ ]
+
+ # Only increment version if status is changing to 'success' or if it's a new version
+ if (
+ db_status != "success"
+ and db_entry["ingestion_status"]
+ == "success"
+ ) or (new_version > db_version):
+ new_attempt_number = db_version + 1
+ else:
+ new_attempt_number = db_version
+
+ db_entry["ingestion_attempt_number"] = (
+ new_attempt_number
+ )
+
+ update_query = f"""
+ UPDATE {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ SET collection_ids = $1,
+ owner_id = $2,
+ type = $3,
+ metadata = $4,
+ title = $5,
+ version = $6,
+ size_in_bytes = $7,
+ ingestion_status = $8,
+ extraction_status = $9,
+ updated_at = $10,
+ ingestion_attempt_number = $11,
+ summary = $12,
+ summary_embedding = $13,
+ total_tokens = $14
+ WHERE id = $15
+ """
+
+ await conn.execute(
+ update_query,
+ db_entry["collection_ids"],
+ db_entry["owner_id"],
+ db_entry["document_type"],
+ db_entry["metadata"],
+ db_entry["title"],
+ db_entry["version"],
+ db_entry["size_in_bytes"],
+ db_entry["ingestion_status"],
+ db_entry["extraction_status"],
+ db_entry["updated_at"],
+ db_entry["ingestion_attempt_number"],
+ db_entry["summary"],
+ db_entry["summary_embedding"],
+ db_entry[
+ "total_tokens"
+ ], # pass the new field here
+ document.id,
+ )
+ else:
+ insert_query = f"""
+ INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ (id, collection_ids, owner_id, type, metadata, title, version,
+ size_in_bytes, ingestion_status, extraction_status, created_at,
+ updated_at, ingestion_attempt_number, summary, summary_embedding, total_tokens)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
+ """
+ await conn.execute(
+ insert_query,
+ db_entry["id"],
+ db_entry["collection_ids"],
+ db_entry["owner_id"],
+ db_entry["document_type"],
+ db_entry["metadata"],
+ db_entry["title"],
+ db_entry["version"],
+ db_entry["size_in_bytes"],
+ db_entry["ingestion_status"],
+ db_entry["extraction_status"],
+ db_entry["created_at"],
+ db_entry["updated_at"],
+ db_entry["ingestion_attempt_number"],
+ db_entry["summary"],
+ db_entry["summary_embedding"],
+ db_entry["total_tokens"],
+ )
+
+ break # Success, exit the retry loop
+ except (
+ asyncpg.exceptions.UniqueViolationError,
+ asyncpg.exceptions.DeadlockDetectedError,
+ ) as e:
+ retries += 1
+ if retries == max_retries:
+ logger.error(
+ f"Failed to update document {document.id} after {max_retries} attempts. Error: {str(e)}"
+ )
+ raise
+ else:
+ wait_time = 0.1 * (2**retries) # Exponential backoff
+ await asyncio.sleep(wait_time)
+
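+    # Retry behavior sketch: concurrent upserts that hit a unique-violation or
+    # deadlock error are retried up to max_retries (currently hard-coded to 20)
+    # times with exponential backoff (0.2s, 0.4s, 0.8s, ... for retries 1, 2,
+    # 3, ...) before the error is logged and re-raised.
+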
+ async def delete(
+ self, document_id: UUID, version: Optional[str] = None
+ ) -> None:
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ WHERE id = $1
+ """
+
+ params = [str(document_id)]
+
+ if version:
+ query += " AND version = $2"
+ params.append(version)
+
+ await self.connection_manager.execute_query(query=query, params=params)
+
+ async def _get_status_from_table(
+ self,
+ ids: list[UUID],
+ table_name: str,
+ status_type: str,
+ column_name: str,
+ ):
+ """Get the workflow status for a given document or list of documents.
+
+ Args:
+ ids (list[UUID]): The document IDs.
+ table_name (str): The table name.
+ status_type (str): The type of status to retrieve.
+
+ Returns:
+ The workflow status for the given document or list of documents.
+ """
+ query = f"""
+ SELECT {status_type} FROM {self._get_table_name(table_name)}
+ WHERE {column_name} = ANY($1)
+ """
+ return [
+ row[status_type]
+ for row in await self.connection_manager.fetch_query(query, [ids])
+ ]
+
+ async def _get_ids_from_table(
+ self,
+ status: list[str],
+ table_name: str,
+ status_type: str,
+ collection_id: Optional[UUID] = None,
+ ):
+ """Get the IDs from a given table.
+
+ Args:
+ status (str | list[str]): The status or list of statuses to retrieve.
+ table_name (str): The table name.
+ status_type (str): The type of status to retrieve.
+ """
+ query = f"""
+ SELECT id FROM {self._get_table_name(table_name)}
+ WHERE {status_type} = ANY($1) and $2 = ANY(collection_ids)
+ """
+ records = await self.connection_manager.fetch_query(
+ query, [status, collection_id]
+ )
+ return [record["id"] for record in records]
+
+ async def _set_status_in_table(
+ self,
+ ids: list[UUID],
+ status: str,
+ table_name: str,
+ status_type: str,
+ column_name: str,
+ ):
+ """Set the workflow status for a given document or list of documents.
+
+ Args:
+ ids (list[UUID]): The document IDs.
+ status (str): The status to set.
+ table_name (str): The table name.
+ status_type (str): The type of status to set.
+ column_name (str): The column name in the table to update.
+ """
+ query = f"""
+ UPDATE {self._get_table_name(table_name)}
+ SET {status_type} = $1
+ WHERE {column_name} = Any($2)
+ """
+ await self.connection_manager.execute_query(query, [status, ids])
+
+ def _get_status_model(self, status_type: str):
+ """Get the status model for a given status type.
+
+ Args:
+ status_type (str): The type of status to retrieve.
+
+ Returns:
+ The status model for the given status type.
+ """
+ if status_type == "ingestion":
+ return IngestionStatus
+ elif status_type == "extraction_status":
+ return GraphExtractionStatus
+ elif status_type in {"graph_cluster_status", "graph_sync_status"}:
+ return GraphConstructionStatus
+ else:
+ raise R2RException(
+ status_code=400, message=f"Invalid status type: {status_type}"
+ )
+
+ async def get_workflow_status(
+ self, id: UUID | list[UUID], status_type: str
+ ):
+ """Get the workflow status for a given document or list of documents.
+
+ Args:
+ id (UUID | list[UUID]): The document ID or list of document IDs.
+ status_type (str): The type of status to retrieve.
+
+ Returns:
+ The workflow status for the given document or list of documents.
+ """
+
+ ids = [id] if isinstance(id, UUID) else id
+ out_model = self._get_status_model(status_type)
+ result = await self._get_status_from_table(
+ ids,
+ out_model.table_name(),
+ status_type,
+ out_model.id_column(),
+ )
+
+ result = [out_model[status.upper()] for status in result]
+ return result[0] if isinstance(id, UUID) else result
+
+ async def set_workflow_status(
+ self, id: UUID | list[UUID], status_type: str, status: str
+ ):
+ """Set the workflow status for a given document or list of documents.
+
+ Args:
+ id (UUID | list[UUID]): The document ID or list of document IDs.
+ status_type (str): The type of status to set.
+ status (str): The status to set.
+ """
+ ids = [id] if isinstance(id, UUID) else id
+ out_model = self._get_status_model(status_type)
+
+ return await self._set_status_in_table(
+ ids,
+ status,
+ out_model.table_name(),
+ status_type,
+ out_model.id_column(),
+ )
+
+ async def get_document_ids_by_status(
+ self,
+ status_type: str,
+ status: str | list[str],
+ collection_id: Optional[UUID] = None,
+ ):
+ """Get the IDs for a given status.
+
+ Args:
+            status_type (str): The type of status to retrieve.
+            status (str | list[str]): The status or list of statuses to retrieve.
+            collection_id (Optional[UUID]): If provided, restrict results to this collection.
+ """
+
+ if isinstance(status, str):
+ status = [status]
+
+ out_model = self._get_status_model(status_type)
+ return await self._get_ids_from_table(
+ status, out_model.table_name(), status_type, collection_id
+ )
+
+ async def get_documents_overview(
+ self,
+ offset: int,
+ limit: int,
+ filter_user_ids: Optional[list[UUID]] = None,
+ filter_document_ids: Optional[list[UUID]] = None,
+ filter_collection_ids: Optional[list[UUID]] = None,
+ include_summary_embedding: Optional[bool] = True,
+ filters: Optional[dict[str, Any]] = None,
+ sort_order: str = "DESC", # Add this parameter with a default of DESC
+ ) -> dict[str, Any]:
+ """Fetch overviews of documents with optional offset/limit pagination.
+
+ You can use either:
+ - Traditional filters: `filter_user_ids`, `filter_document_ids`, `filter_collection_ids`
+ - A `filters` dict (e.g., like we do in semantic search), which will be passed to `apply_filters`.
+
+ If both the `filters` dict and any of the traditional filter arguments are provided,
+ this method will raise an error.
+ """
+
+ filters = copy.deepcopy(filters)
+ filters = transform_filter_fields(filters) # type: ignore
+
+ # Safety check: We do not allow mixing the old filter arguments with the new `filters` dict.
+ # This keeps the query logic unambiguous.
+ if filters and any(
+ [
+ filter_user_ids,
+ filter_document_ids,
+ filter_collection_ids,
+ ]
+ ):
+ raise HTTPException(
+ status_code=400,
+ detail=(
+ "Cannot use both the 'filters' dictionary "
+ "and the 'filter_*_ids' parameters simultaneously."
+ ),
+ )
+
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ # -------------------------------------------
+ # 1) If using the new `filters` dict approach
+ # -------------------------------------------
+ if filters:
+ # Apply the filters to generate a WHERE clause
+ filter_condition, filter_params = apply_filters(
+ filters, params, mode="condition_only"
+ )
+ if filter_condition:
+ conditions.append(filter_condition)
+ # Make sure we keep adding to the same params list
+ # params.extend(filter_params)
+ param_index += len(filter_params)
+
+ # -------------------------------------------
+ # 2) If using the old filter_*_ids approach
+ # -------------------------------------------
+ else:
+ # Handle document IDs with AND
+ if filter_document_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(filter_document_ids)
+ param_index += 1
+
+ # For owner/collection filters, we used OR logic previously
+ # so we combine them into a single sub-condition in parentheses
+ or_conditions = []
+ if filter_user_ids:
+ or_conditions.append(f"owner_id = ANY(${param_index})")
+ params.append(filter_user_ids)
+ param_index += 1
+
+ if filter_collection_ids:
+ or_conditions.append(f"collection_ids && ${param_index}")
+ params.append(filter_collection_ids)
+ param_index += 1
+
+ if or_conditions:
+ conditions.append(f"({' OR '.join(or_conditions)})")
+
+ # -------------------------
+ # Build the full query
+ # -------------------------
+ base_query = (
+ f"FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
+ )
+ if conditions:
+ # Combine everything with AND
+ base_query += " WHERE " + " AND ".join(conditions)
+
+ # Construct SELECT fields (including total_entries via window function)
+ select_fields = """
+ SELECT
+ id,
+ collection_ids,
+ owner_id,
+ type,
+ metadata,
+ title,
+ version,
+ size_in_bytes,
+ ingestion_status,
+ extraction_status,
+ created_at,
+ updated_at,
+ summary,
+ summary_embedding,
+ total_tokens,
+ COUNT(*) OVER() AS total_entries
+ """
+
+ query = f"""
+ {select_fields}
+ {base_query}
+ ORDER BY created_at {sort_order}
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ query += f" LIMIT ${param_index}"
+ params.append(limit)
+ param_index += 1
+
+ try:
+ results = await self.connection_manager.fetch_query(query, params)
+ total_entries = results[0]["total_entries"] if results else 0
+
+ documents = []
+ for row in results:
+ # Safely handle the embedding
+ embedding = None
+ if (
+ "summary_embedding" in row
+ and row["summary_embedding"] is not None
+ ):
+ try:
+ # The embedding is stored as a string like "[0.1, 0.2, ...]"
+ embedding_str = row["summary_embedding"]
+ if embedding_str.startswith(
+ "["
+ ) and embedding_str.endswith("]"):
+ embedding = [
+ float(x)
+ for x in embedding_str[1:-1].split(",")
+ if x
+ ]
+ except Exception as e:
+ logger.warning(
+ f"Failed to parse embedding for document {row['id']}: {e}"
+ )
+
+ documents.append(
+ DocumentResponse(
+ id=row["id"],
+ collection_ids=row["collection_ids"],
+ owner_id=row["owner_id"],
+ document_type=DocumentType(row["type"]),
+ metadata=json.loads(row["metadata"]),
+ title=row["title"],
+ version=row["version"],
+ size_in_bytes=row["size_in_bytes"],
+ ingestion_status=IngestionStatus(
+ row["ingestion_status"]
+ ),
+ extraction_status=GraphExtractionStatus(
+ row["extraction_status"]
+ ),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ summary=row["summary"] if "summary" in row else None,
+ summary_embedding=(
+ embedding if include_summary_embedding else None
+ ),
+ total_tokens=row["total_tokens"],
+ )
+ )
+ return {"results": documents, "total_entries": total_entries}
+ except Exception as e:
+ logger.error(f"Error in get_documents_overview: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail="Database query failed",
+ ) from e
+
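+    # Two mutually exclusive ways to call the overview above (a sketch; the
+    # handler and ID names are assumptions):
+    #
+    #     await documents_handler.get_documents_overview(
+    #         offset=0, limit=10, filter_collection_ids=[collection_id]
+    #     )
+    #     await documents_handler.get_documents_overview(
+    #         offset=0, limit=10, filters={"document_id": {"$eq": doc_id}}
+    #     )
+    #
+    # Passing both a filters dict and any filter_*_ids argument raises a 400.
+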
+ async def update_document_metadata(
+ self,
+ document_id: UUID,
+ metadata: list[dict],
+ overwrite: bool = False,
+ ) -> DocumentResponse:
+ """
+ Update the metadata of a document, either by appending to existing metadata or overwriting it.
+ Accepts a list of metadata dictionaries.
+ """
+
+ doc_result = await self.get_documents_overview(
+ offset=0,
+ limit=1,
+ filter_document_ids=[document_id],
+ )
+
+ if not doc_result["results"]:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Document with ID {document_id} not found",
+ )
+
+ existing_doc = doc_result["results"][0]
+
+ if overwrite:
+ combined_metadata: dict[str, Any] = {}
+ for meta_item in metadata:
+ combined_metadata |= meta_item
+ existing_doc.metadata = combined_metadata
+ else:
+ for meta_item in metadata:
+ existing_doc.metadata.update(meta_item)
+
+ await self.upsert_documents_overview(existing_doc)
+
+ return existing_doc
+
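+    # Merge semantics of the update above (illustrative): with overwrite=False
+    # the dicts in `metadata` are merged into the stored metadata in order,
+    # e.g. [{"a": 1}, {"b": 2}] adds or replaces keys "a" and "b"; with
+    # overwrite=True the same list is first flattened into {"a": 1, "b": 2}
+    # and then replaces the stored metadata wholesale.
+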
+ async def semantic_document_search(
+ self, query_embedding: list[float], search_settings: SearchSettings
+ ) -> list[DocumentResponse]:
+ """Search documents using semantic similarity with their summary
+ embeddings."""
+
+ where_clauses = ["summary_embedding IS NOT NULL"]
+ params: list[str | int | bytes] = [str(query_embedding)]
+
+ vector_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+ filters = copy.deepcopy(search_settings.filters)
+ if filters:
+ filter_condition, params = apply_filters(
+ transform_filter_fields(filters), params, mode="condition_only"
+ )
+ if filter_condition:
+ where_clauses.append(filter_condition)
+
+ where_clause = " AND ".join(where_clauses)
+
+ query = f"""
+ WITH document_scores AS (
+ SELECT
+ id,
+ collection_ids,
+ owner_id,
+ type,
+ metadata,
+ title,
+ version,
+ size_in_bytes,
+ ingestion_status,
+ extraction_status,
+ created_at,
+ updated_at,
+ summary,
+ summary_embedding,
+ total_tokens,
+                    (summary_embedding <=> $1::vector{vector_dim}) as semantic_distance
+ FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ WHERE {where_clause}
+ ORDER BY semantic_distance ASC
+ LIMIT ${len(params) + 1}
+ OFFSET ${len(params) + 2}
+ )
+ SELECT *,
+ 1.0 - semantic_distance as semantic_score
+ FROM document_scores
+ """
+
+ params.extend([search_settings.limit, search_settings.offset])
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ return [
+ DocumentResponse(
+ id=row["id"],
+ collection_ids=row["collection_ids"],
+ owner_id=row["owner_id"],
+ document_type=DocumentType(row["type"]),
+ metadata={
+ **(
+ json.loads(row["metadata"])
+ if search_settings.include_metadatas
+ else {}
+ ),
+ "search_score": float(row["semantic_score"]),
+ "search_type": "semantic",
+ },
+ title=row["title"],
+ version=row["version"],
+ size_in_bytes=row["size_in_bytes"],
+ ingestion_status=IngestionStatus(row["ingestion_status"]),
+ extraction_status=GraphExtractionStatus(
+ row["extraction_status"]
+ ),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ summary=row["summary"],
+ summary_embedding=[
+ float(x)
+ for x in row["summary_embedding"][1:-1].split(",")
+ if x
+ ],
+ total_tokens=row["total_tokens"],
+ )
+ for row in results
+ ]
+
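+    # Scoring note (illustrative): `<=>` is pgvector's cosine-distance operator,
+    # so semantic_score = 1.0 - distance places identical summary embeddings at
+    # 1.0 and orthogonal ones at 0.0.
+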
+ async def full_text_document_search(
+ self, query_text: str, search_settings: SearchSettings
+ ) -> list[DocumentResponse]:
+ """Enhanced full-text search using generated tsvector."""
+
+ where_clauses = ["raw_tsvector @@ websearch_to_tsquery('english', $1)"]
+ params: list[str | int | bytes] = [query_text]
+
+ filters = copy.deepcopy(search_settings.filters)
+ if filters:
+ filter_condition, params = apply_filters(
+ transform_filter_fields(filters), params, mode="condition_only"
+ )
+ if filter_condition:
+ where_clauses.append(filter_condition)
+
+ where_clause = " AND ".join(where_clauses)
+
+ query = f"""
+ WITH document_scores AS (
+ SELECT
+ id,
+ collection_ids,
+ owner_id,
+ type,
+ metadata,
+ title,
+ version,
+ size_in_bytes,
+ ingestion_status,
+ extraction_status,
+ created_at,
+ updated_at,
+ summary,
+ summary_embedding,
+ total_tokens,
+ ts_rank_cd(raw_tsvector, websearch_to_tsquery('english', $1), 32) as text_score
+ FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ WHERE {where_clause}
+ ORDER BY text_score DESC
+ LIMIT ${len(params) + 1}
+ OFFSET ${len(params) + 2}
+ )
+ SELECT * FROM document_scores
+ """
+
+ params.extend([search_settings.limit, search_settings.offset])
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ return [
+ DocumentResponse(
+ id=row["id"],
+ collection_ids=row["collection_ids"],
+ owner_id=row["owner_id"],
+ document_type=DocumentType(row["type"]),
+ metadata={
+ **(
+ json.loads(row["metadata"])
+ if search_settings.include_metadatas
+ else {}
+ ),
+ "search_score": float(row["text_score"]),
+ "search_type": "full_text",
+ },
+ title=row["title"],
+ version=row["version"],
+ size_in_bytes=row["size_in_bytes"],
+ ingestion_status=IngestionStatus(row["ingestion_status"]),
+ extraction_status=GraphExtractionStatus(
+ row["extraction_status"]
+ ),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ summary=row["summary"],
+ summary_embedding=(
+ [
+ float(x)
+ for x in row["summary_embedding"][1:-1].split(",")
+ if x
+ ]
+ if row["summary_embedding"]
+ else None
+ ),
+ total_tokens=row["total_tokens"],
+ )
+ for row in results
+ ]
+
+ async def hybrid_document_search(
+ self,
+ query_text: str,
+ query_embedding: list[float],
+ search_settings: SearchSettings,
+ ) -> list[DocumentResponse]:
+ """Search documents using both semantic and full-text search with RRF
+ fusion."""
+
+ # Get more results than needed for better fusion
+ extended_settings = copy.deepcopy(search_settings)
+ extended_settings.limit = search_settings.limit * 3
+
+ # Get results from both search methods
+ semantic_results = await self.semantic_document_search(
+ query_embedding, extended_settings
+ )
+ full_text_results = await self.full_text_document_search(
+ query_text, extended_settings
+ )
+
+ # Combine results using RRF
+ doc_scores: dict[str, dict] = {}
+
+ # Process semantic results
+ for rank, result in enumerate(semantic_results, 1):
+ doc_id = str(result.id)
+ doc_scores[doc_id] = {
+ "semantic_rank": rank,
+ "full_text_rank": len(full_text_results)
+ + 1, # Default rank if not found
+ "data": result,
+ }
+
+ # Process full-text results
+ for rank, result in enumerate(full_text_results, 1):
+ doc_id = str(result.id)
+ if doc_id in doc_scores:
+ doc_scores[doc_id]["full_text_rank"] = rank
+ else:
+ doc_scores[doc_id] = {
+ "semantic_rank": len(semantic_results)
+ + 1, # Default rank if not found
+ "full_text_rank": rank,
+ "data": result,
+ }
+
+ # Calculate RRF scores using hybrid search settings
+ rrf_k = search_settings.hybrid_settings.rrf_k
+ semantic_weight = search_settings.hybrid_settings.semantic_weight
+ full_text_weight = search_settings.hybrid_settings.full_text_weight
+
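+        # Reciprocal Rank Fusion: each result contributes weight / (rrf_k + rank)
+        # for each search method, and the weighted sum is normalized by the total
+        # weight so the final score stays in a comparable range.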
+ for scores in doc_scores.values():
+ semantic_score = 1 / (rrf_k + scores["semantic_rank"])
+ full_text_score = 1 / (rrf_k + scores["full_text_rank"])
+
+ # Weighted combination
+ combined_score = (
+ semantic_score * semantic_weight
+ + full_text_score * full_text_weight
+ ) / (semantic_weight + full_text_weight)
+
+ scores["final_score"] = combined_score
+
+ # Sort by final score and apply offset/limit
+ sorted_results = sorted(
+ doc_scores.values(), key=lambda x: x["final_score"], reverse=True
+ )[
+ search_settings.offset : search_settings.offset
+ + search_settings.limit
+ ]
+
+ return [
+ DocumentResponse(
+ **{
+ **result["data"].__dict__,
+ "metadata": {
+ **(
+ result["data"].metadata
+ if search_settings.include_metadatas
+ else {}
+ ),
+ "search_score": result["final_score"],
+ "semantic_rank": result["semantic_rank"],
+ "full_text_rank": result["full_text_rank"],
+ "search_type": "hybrid",
+ },
+ }
+ )
+ for result in sorted_results
+ ]
+
+ async def search_documents(
+ self,
+ query_text: str,
+ query_embedding: Optional[list[float]] = None,
+ settings: Optional[SearchSettings] = None,
+ ) -> list[DocumentResponse]:
+ """Main search method that delegates to the appropriate search method
+ based on settings."""
+ if settings is None:
+ settings = SearchSettings()
+
+ if (
+ settings.use_semantic_search and settings.use_fulltext_search
+ ) or settings.use_hybrid_search:
+ if query_embedding is None:
+ raise ValueError(
+ "query_embedding is required for hybrid search"
+ )
+ return await self.hybrid_document_search(
+ query_text, query_embedding, settings
+ )
+ elif settings.use_semantic_search:
+ if query_embedding is None:
+ raise ValueError(
+ "query_embedding is required for vector search"
+ )
+ return await self.semantic_document_search(
+ query_embedding, settings
+ )
+ else:
+ return await self.full_text_document_search(query_text, settings)
+
+ async def export_to_csv(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ """Creates a CSV file from the PostgreSQL data and returns the path to
+ the temp file."""
+ valid_columns = {
+ "id",
+ "collection_ids",
+ "owner_id",
+ "type",
+ "metadata",
+ "title",
+ "summary",
+ "version",
+ "size_in_bytes",
+ "ingestion_status",
+ "extraction_status",
+ "created_at",
+ "updated_at",
+ "total_tokens",
+ }
+ filters = copy.deepcopy(filters)
+ filters = transform_filter_fields(filters) # type: ignore
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ collection_ids::text,
+ owner_id::text,
+ type::text,
+ metadata::text AS metadata,
+ title,
+ summary,
+ version,
+ size_in_bytes,
+ ingestion_status,
+ extraction_status,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
+ total_tokens
+ FROM {self._get_table_name(self.TABLE_NAME)}
+ """
+
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "collection_ids": row[1],
+ "owner_id": row[2],
+ "type": row[3],
+ "metadata": row[4],
+ "title": row[5],
+ "summary": row[6],
+ "version": row[7],
+ "size_in_bytes": row[8],
+ "ingestion_status": row[9],
+ "extraction_status": row[10],
+ "created_at": row[11],
+ "updated_at": row[12],
+ "total_tokens": row[13],
+ }
+ writer.writerow([row_dict[col] for col in columns])
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
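+
+
+# --- Illustrative usage sketch ---------------------------------------------
+# A minimal sketch of how the search entry point above is driven. It assumes a
+# PostgresDocumentsHandler that has already been constructed and connected
+# elsewhere; the query embedding would normally come from the configured
+# embedding provider.
+async def _example_document_search(
+    handler: PostgresDocumentsHandler,
+    query_text: str,
+    query_embedding: list[float],
+) -> list[DocumentResponse]:
+    settings = SearchSettings(
+        use_hybrid_search=True,  # semantic + full-text results fused via RRF
+        limit=10,
+        offset=0,
+    )
+    # Hybrid (and semantic) search requires an embedding of the query text;
+    # pure full-text search needs only the raw query string.
+    return await handler.search_documents(
+        query_text=query_text,
+        query_embedding=query_embedding,
+        settings=settings,
+    )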
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/files.py b/.venv/lib/python3.12/site-packages/core/providers/database/files.py
new file mode 100644
index 00000000..dc349a7e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/files.py
@@ -0,0 +1,334 @@
+import io
+import logging
+from datetime import datetime
+from io import BytesIO
+from typing import BinaryIO, Optional
+from uuid import UUID
+from zipfile import ZipFile
+
+import asyncpg
+from fastapi import HTTPException
+
+from core.base import Handler, R2RException
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger()
+
+
+class PostgresFilesHandler(Handler):
+ """PostgreSQL implementation of the FileHandler."""
+
+ TABLE_NAME = "files"
+
+ connection_manager: PostgresConnectionManager
+
+ async def create_tables(self) -> None:
+ """Create the necessary tables for file storage."""
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresFilesHandler.TABLE_NAME)} (
+ document_id UUID PRIMARY KEY,
+ name TEXT NOT NULL,
+ oid OID NOT NULL,
+ size BIGINT NOT NULL,
+ type TEXT,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW()
+ );
+
+ -- Create trigger for updating the updated_at timestamp
+ CREATE OR REPLACE FUNCTION {self.project_name}.update_files_updated_at()
+ RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.updated_at = CURRENT_TIMESTAMP;
+ RETURN NEW;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ DROP TRIGGER IF EXISTS update_files_updated_at
+ ON {self._get_table_name(PostgresFilesHandler.TABLE_NAME)};
+
+ CREATE TRIGGER update_files_updated_at
+ BEFORE UPDATE ON {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ FOR EACH ROW
+ EXECUTE FUNCTION {self.project_name}.update_files_updated_at();
+ """
+ await self.connection_manager.execute_query(query)
+
+ async def upsert_file(
+ self,
+ document_id: UUID,
+ file_name: str,
+ file_oid: int,
+ file_size: int,
+ file_type: Optional[str] = None,
+ ) -> None:
+ """Add or update a file entry in storage."""
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ (document_id, name, oid, size, type)
+ VALUES ($1, $2, $3, $4, $5)
+ ON CONFLICT (document_id) DO UPDATE SET
+ name = EXCLUDED.name,
+ oid = EXCLUDED.oid,
+ size = EXCLUDED.size,
+ type = EXCLUDED.type,
+ updated_at = NOW();
+ """
+ await self.connection_manager.execute_query(
+ query, [document_id, file_name, file_oid, file_size, file_type]
+ )
+
+ async def store_file(
+ self,
+ document_id: UUID,
+ file_name: str,
+ file_content: io.BytesIO,
+ file_type: Optional[str] = None,
+ ) -> None:
+ """Store a new file in the database."""
+ size = file_content.getbuffer().nbytes
+
+ async with (
+ self.connection_manager.pool.get_connection() as conn # type: ignore
+ ):
+ async with conn.transaction():
+ oid = await conn.fetchval("SELECT lo_create(0)")
+ await self._write_lobject(conn, oid, file_content)
+ await self.upsert_file(
+ document_id, file_name, oid, size, file_type
+ )
+
+ async def _write_lobject(
+ self, conn, oid: int, file_content: io.BytesIO
+ ) -> None:
+ """Write content to a large object."""
+ lobject = await conn.fetchval("SELECT lo_open($1, $2)", oid, 0x20000)
+
+ try:
+ chunk_size = 8192 # 8 KB chunks
+ while True:
+ if chunk := file_content.read(chunk_size):
+ await conn.execute(
+ "SELECT lowrite($1, $2)", lobject, chunk
+ )
+ else:
+ break
+
+ await conn.execute("SELECT lo_close($1)", lobject)
+
+ except Exception as e:
+ await conn.execute("SELECT lo_unlink($1)", oid)
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to write to large object: {e}",
+ ) from e
+
+ async def retrieve_file(
+ self, document_id: UUID
+ ) -> Optional[tuple[str, BinaryIO, int]]:
+ """Retrieve a file from storage."""
+ query = f"""
+ SELECT name, oid, size
+ FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ WHERE document_id = $1
+ """
+
+ result = await self.connection_manager.fetchrow_query(
+ query, [document_id]
+ )
+ if not result:
+ raise R2RException(
+ status_code=404,
+ message=f"File for document {document_id} not found",
+ )
+
+ file_name, oid, size = (
+ result["name"],
+ result["oid"],
+ result["size"],
+ )
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ file_content = await self._read_lobject(conn, oid)
+ return file_name, io.BytesIO(file_content), size
+
+ async def retrieve_files_as_zip(
+ self,
+ document_ids: Optional[list[UUID]] = None,
+ start_date: Optional[datetime] = None,
+ end_date: Optional[datetime] = None,
+ ) -> tuple[str, BinaryIO, int]:
+ """Retrieve multiple files and return them as a zip file."""
+
+ query = f"""
+ SELECT document_id, name, oid, size
+ FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ WHERE 1=1
+ """
+ params: list = []
+
+ if document_ids:
+ query += f" AND document_id = ANY(${len(params) + 1})"
+ params.append([str(doc_id) for doc_id in document_ids])
+
+ if start_date:
+ query += f" AND created_at >= ${len(params) + 1}"
+ params.append(start_date)
+
+ if end_date:
+ query += f" AND created_at <= ${len(params) + 1}"
+ params.append(end_date)
+
+ query += " ORDER BY created_at DESC"
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ if not results:
+ raise R2RException(
+ status_code=404,
+ message="No files found matching the specified criteria",
+ )
+
+ zip_buffer = BytesIO()
+ total_size = 0
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ with ZipFile(zip_buffer, "w") as zip_file:
+ for record in results:
+ file_content = await self._read_lobject(
+ conn, record["oid"]
+ )
+
+ zip_file.writestr(record["name"], file_content)
+ total_size += record["size"]
+
+ zip_buffer.seek(0)
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ zip_filename = f"files_export_{timestamp}.zip"
+
+ return zip_filename, zip_buffer, zip_buffer.getbuffer().nbytes
+
+ async def _read_lobject(self, conn, oid: int) -> bytes:
+ """Read content from a large object."""
+ file_data = io.BytesIO()
+ chunk_size = 8192
+
+        async with conn.transaction():
+            lobject = None  # set before the try so the finally can check it
+            try:
+ lo_exists = await conn.fetchval(
+ "SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_largeobject_metadata WHERE oid = $1);",
+ oid,
+ )
+ if not lo_exists:
+ raise R2RException(
+ status_code=404,
+ message=f"Large object {oid} not found.",
+ )
+
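+                # 262144 == 0x40000, libpq's INV_READ flag (open read-only).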
+ lobject = await conn.fetchval(
+ "SELECT lo_open($1, 262144)", oid
+ )
+
+ if lobject is None:
+ raise R2RException(
+ status_code=404,
+ message=f"Failed to open large object {oid}.",
+ )
+
+ while True:
+ chunk = await conn.fetchval(
+ "SELECT loread($1, $2)", lobject, chunk_size
+ )
+ if not chunk:
+ break
+ file_data.write(chunk)
+ except asyncpg.exceptions.UndefinedObjectError:
+ raise R2RException(
+ status_code=404,
+ message=f"Failed to read large object {oid}",
+ ) from None
+            finally:
+                if lobject is not None:
+                    await conn.execute("SELECT lo_close($1)", lobject)
+
+ return file_data.getvalue()
+
+ async def delete_file(self, document_id: UUID) -> bool:
+ """Delete a file from storage."""
+ query = f"""
+ SELECT oid FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ WHERE document_id = $1
+ """
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ oid = await conn.fetchval(query, document_id)
+ if not oid:
+ raise R2RException(
+ status_code=404,
+ message=f"File for document {document_id} not found",
+ )
+
+ await self._delete_lobject(conn, oid)
+
+ delete_query = f"""
+ DELETE FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ WHERE document_id = $1
+ """
+ await conn.execute(delete_query, document_id)
+
+ return True
+
+ async def _delete_lobject(self, conn, oid: int) -> None:
+ """Delete a large object."""
+ await conn.execute("SELECT lo_unlink($1)", oid)
+
+ async def get_files_overview(
+ self,
+ offset: int,
+ limit: int,
+ filter_document_ids: Optional[list[UUID]] = None,
+ filter_file_names: Optional[list[str]] = None,
+ ) -> list[dict]:
+ """Get an overview of stored files."""
+ conditions = []
+ params: list[str | list[str] | int] = []
+ query = f"""
+ SELECT document_id, name, oid, size, type, created_at, updated_at
+ FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ """
+
+ if filter_document_ids:
+ conditions.append(f"document_id = ANY(${len(params) + 1})")
+ params.append([str(doc_id) for doc_id in filter_document_ids])
+
+ if filter_file_names:
+ conditions.append(f"name = ANY(${len(params) + 1})")
+ params.append(filter_file_names)
+
+ if conditions:
+ query += " WHERE " + " AND ".join(conditions)
+
+ query += f" ORDER BY created_at DESC OFFSET ${len(params) + 1} LIMIT ${len(params) + 2}"
+ params.extend([offset, limit])
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ if not results:
+ raise R2RException(
+ status_code=404,
+ message="No files found with the given filters",
+ )
+
+ return [
+ {
+ "document_id": row["document_id"],
+ "file_name": row["name"],
+ "file_oid": row["oid"],
+ "file_size": row["size"],
+ "file_type": row["type"],
+ "created_at": row["created_at"],
+ "updated_at": row["updated_at"],
+ }
+ for row in results
+ ]
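+
+
+# --- Illustrative usage sketch ---------------------------------------------
+# A minimal sketch of the store/retrieve round trip above. It assumes a
+# PostgresFilesHandler that has already been constructed and connected
+# elsewhere; the document id and payload are placeholders.
+async def _example_file_round_trip(
+    handler: PostgresFilesHandler, document_id: UUID
+) -> bytes:
+    payload = io.BytesIO(b"hello, large objects")
+
+    # The content is streamed into a PostgreSQL large object in 8 KB chunks
+    # and referenced from the files table by its OID.
+    await handler.store_file(
+        document_id=document_id,
+        file_name="hello.txt",
+        file_content=payload,
+        file_type="text/plain",
+    )
+
+    # retrieve_file raises R2RException(404) if no row exists for this id.
+    file_name, stream, size = await handler.retrieve_file(document_id)
+    logger.info("Retrieved %s (%d bytes)", file_name, size)
+    return stream.read()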
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/filters.py b/.venv/lib/python3.12/site-packages/core/providers/database/filters.py
new file mode 100644
index 00000000..9231e35b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/filters.py
@@ -0,0 +1,478 @@
+import json
+from typing import Any, Optional, Tuple
+
+COLUMN_VARS = [
+ "id",
+ "document_id",
+ "owner_id",
+ "collection_ids",
+]
+
+
+class FilterError(Exception):
+ pass
+
+
+class FilterOperator:
+ EQ = "$eq"
+ NE = "$ne"
+ LT = "$lt"
+ LTE = "$lte"
+ GT = "$gt"
+ GTE = "$gte"
+ IN = "$in"
+ NIN = "$nin"
+ LIKE = "$like"
+ ILIKE = "$ilike"
+ CONTAINS = "$contains"
+ AND = "$and"
+ OR = "$or"
+ OVERLAP = "$overlap"
+
+ SCALAR_OPS = {EQ, NE, LT, LTE, GT, GTE, LIKE, ILIKE}
+ ARRAY_OPS = {IN, NIN, OVERLAP}
+ JSON_OPS = {CONTAINS}
+ LOGICAL_OPS = {AND, OR}
+
+
+class FilterCondition:
+ def __init__(self, field: str, operator: str, value: Any):
+ self.field = field
+ self.operator = operator
+ self.value = value
+
+
+class FilterExpression:
+ def __init__(self, logical_op: Optional[str] = None):
+ self.logical_op = logical_op
+ self.conditions: list[FilterCondition | "FilterExpression"] = []
+
+
+class FilterParser:
+ def __init__(
+ self,
+ top_level_columns: Optional[list[str]] = None,
+ json_column: str = "metadata",
+ ):
+ if top_level_columns is None:
+ self.top_level_columns = set(COLUMN_VARS)
+ else:
+ self.top_level_columns = set(top_level_columns)
+ self.json_column = json_column
+
+ def parse(self, filters: dict) -> FilterExpression:
+ if not filters:
+ raise FilterError("Empty filters are not allowed")
+ return self._parse_logical(filters)
+
+ def _parse_logical(self, dct: dict) -> FilterExpression:
+ keys = list(dct.keys())
+ expr = FilterExpression()
+ if len(keys) == 1 and keys[0] in (
+ FilterOperator.AND,
+ FilterOperator.OR,
+ ):
+ expr.logical_op = keys[0]
+ if not isinstance(dct[keys[0]], list):
+ raise FilterError(f"{keys[0]} value must be a list")
+ for item in dct[keys[0]]:
+ if isinstance(item, dict):
+ if self._is_logical_block(item):
+ expr.conditions.append(self._parse_logical(item))
+ else:
+ expr.conditions.append(
+ self._parse_condition_dict(item)
+ )
+ else:
+ raise FilterError("Invalid filter format")
+ else:
+ expr.logical_op = FilterOperator.AND
+ expr.conditions.append(self._parse_condition_dict(dct))
+
+ return expr
+
+ def _is_logical_block(self, dct: dict) -> bool:
+ if len(dct.keys()) == 1:
+ k = next(iter(dct.keys()))
+ if k in FilterOperator.LOGICAL_OPS:
+ return True
+ return False
+
+ def _parse_condition_dict(self, dct: dict) -> FilterExpression:
+ expr = FilterExpression(logical_op=FilterOperator.AND)
+ for field, cond in dct.items():
+ if not isinstance(cond, dict):
+ # direct equality
+ expr.conditions.append(
+ FilterCondition(field, FilterOperator.EQ, cond)
+ )
+ else:
+ if len(cond) != 1:
+ raise FilterError(
+ f"Condition for field {field} must have exactly one operator"
+ )
+ op, val = next(iter(cond.items()))
+ self._validate_operator(op)
+ expr.conditions.append(FilterCondition(field, op, val))
+ return expr
+
+ def _validate_operator(self, op: str):
+ allowed = (
+ FilterOperator.SCALAR_OPS
+ | FilterOperator.ARRAY_OPS
+ | FilterOperator.JSON_OPS
+ | FilterOperator.LOGICAL_OPS
+ )
+ if op not in allowed:
+ raise FilterError(f"Unsupported operator: {op}")
+
+
+class SQLFilterBuilder:
+ def __init__(
+ self,
+ params: list[Any],
+ top_level_columns: Optional[list[str]] = None,
+ json_column: str = "metadata",
+ mode: str = "where_clause",
+ ):
+ if top_level_columns is None:
+ self.top_level_columns = set(COLUMN_VARS)
+ else:
+ self.top_level_columns = set(top_level_columns)
+ self.json_column = json_column
+ self.params: list[Any] = params # mutated during construction
+ self.mode = mode
+
+ def build(self, expr: FilterExpression) -> Tuple[str, list[Any]]:
+ where_clause = self._build_expression(expr)
+ if self.mode == "where_clause":
+ return f"WHERE {where_clause}", self.params
+
+ return where_clause, self.params
+
+ def _build_expression(self, expr: FilterExpression) -> str:
+ parts = []
+ for c in expr.conditions:
+ if isinstance(c, FilterCondition):
+ parts.append(self._build_condition(c))
+ else:
+ nested_sql = self._build_expression(c)
+ parts.append(f"({nested_sql})")
+
+ if expr.logical_op == FilterOperator.AND:
+ return " AND ".join(parts)
+ elif expr.logical_op == FilterOperator.OR:
+ return " OR ".join(parts)
+ else:
+ return " AND ".join(parts)
+
+ @staticmethod
+ def _psql_quote_literal(value: str) -> str:
+ """Simple quoting for demonstration.
+
+ In production, use parameterized queries or your DB driver's quoting
+ function instead.
+ """
+ return "'" + value.replace("'", "''") + "'"
+
+ def _build_condition(self, cond: FilterCondition) -> str:
+ field_is_metadata = cond.field not in self.top_level_columns
+ key = cond.field
+ op = cond.operator
+ val = cond.value
+
+ # 1. If the filter references "parent_id", handle it as a single-UUID column for graphs:
+ if key == "parent_id":
+ return self._build_parent_id_condition(op, val)
+
+ # 2. If the filter references "collection_id", handle it as an array column (chunks)
+ if key == "collection_id":
+ return self._build_collection_id_condition(op, val)
+
+ # 3. Otherwise, decide if it's top-level or metadata:
+ if field_is_metadata:
+ return self._build_metadata_condition(key, op, val)
+ else:
+ return self._build_column_condition(key, op, val)
+
+ def _build_parent_id_condition(self, op: str, val: Any) -> str:
+ """For 'graphs' tables, parent_id is a single UUID (not an array).
+
+ We handle the same ops but in a simpler, single-UUID manner.
+ """
+ param_idx = len(self.params) + 1
+
+ if op == "$eq":
+ if not isinstance(val, str):
+ raise FilterError(
+ "$eq for parent_id expects a single UUID string"
+ )
+ self.params.append(val)
+ return f"parent_id = ${param_idx}::uuid"
+
+ elif op == "$ne":
+ if not isinstance(val, str):
+ raise FilterError(
+ "$ne for parent_id expects a single UUID string"
+ )
+ self.params.append(val)
+ return f"parent_id != ${param_idx}::uuid"
+
+ elif op == "$in":
+ # A list of UUIDs, any of which might match
+ if not isinstance(val, list):
+ raise FilterError(
+ "$in for parent_id expects a list of UUID strings"
+ )
+ self.params.append(val)
+ return f"parent_id = ANY(${param_idx}::uuid[])"
+
+ elif op == "$nin":
+ # A list of UUIDs, none of which may match
+ if not isinstance(val, list):
+ raise FilterError(
+ "$nin for parent_id expects a list of UUID strings"
+ )
+ self.params.append(val)
+ return f"parent_id != ALL(${param_idx}::uuid[])"
+
+ else:
+ # You could add more (like $gt, $lt, etc.) if your schema wants them
+ raise FilterError(f"Unsupported operator {op} for parent_id")
+
+ def _build_collection_id_condition(self, op: str, val: Any) -> str:
+ """For the 'chunks' table, collection_ids is an array of UUIDs.
+
+ We need to use array operators to compare arrays correctly.
+ """
+ param_idx = len(self.params) + 1
+
+ if op == "$eq":
+ if not isinstance(val, str):
+ raise FilterError(
+ "$eq for collection_id expects a single UUID string"
+ )
+ self.params.append(
+ [val]
+ ) # Make it a list with one element for the overlap check
+ return (
+ f"collection_ids && ${param_idx}::uuid[]" # Use && for overlap
+ )
+
+ elif op == "$ne":
+ if not isinstance(val, str):
+ raise FilterError(
+ "$ne for collection_id expects a single UUID string"
+ )
+ self.params.append([val])
+ return f"NOT (collection_ids && ${param_idx}::uuid[])" # Negate the overlap
+
+ elif op == "$in":
+ if not isinstance(val, list):
+ raise FilterError(
+ "$in for collection_id expects a list of UUID strings"
+ )
+ self.params.append(val)
+ return (
+ f"collection_ids && ${param_idx}::uuid[]" # Use && for overlap
+ )
+
+ elif op == "$nin":
+ if not isinstance(val, list):
+ raise FilterError(
+ "$nin for collection_id expects a list of UUID strings"
+ )
+ self.params.append(val)
+ return f"NOT (collection_ids && ${param_idx}::uuid[])" # Negate the overlap
+
+ elif op == "$contains":
+ if isinstance(val, str):
+ # single string -> array with one element
+ self.params.append([val])
+ return f"collection_ids @> ${param_idx}::uuid[]"
+ elif isinstance(val, list):
+ self.params.append(val)
+ return f"collection_ids @> ${param_idx}::uuid[]"
+ else:
+ raise FilterError(
+ "$contains for collection_id expects a UUID or list of UUIDs"
+ )
+
+ elif op == "$overlap":
+ if not isinstance(val, list):
+ self.params.append([val])
+ else:
+ self.params.append(val)
+ return f"collection_ids && ${param_idx}::uuid[]"
+
+ else:
+ raise FilterError(f"Unsupported operator {op} for collection_id")
+
+ def _build_column_condition(self, col: str, op: str, val: Any) -> str:
+ # If we're dealing with collection_ids, route to our specialized handler
+ if col == "collection_ids":
+ return self._build_collection_id_condition(op, val)
+
+ param_idx = len(self.params) + 1
+ if op == "$eq":
+ self.params.append(val)
+ return f"{col} = ${param_idx}"
+ elif op == "$ne":
+ self.params.append(val)
+ return f"{col} != ${param_idx}"
+ elif op == "$in":
+ if not isinstance(val, list):
+ raise FilterError("argument to $in filter must be a list")
+ self.params.append(val)
+ return f"{col} = ANY(${param_idx})"
+ elif op == "$nin":
+ if not isinstance(val, list):
+ raise FilterError("argument to $nin filter must be a list")
+ self.params.append(val)
+ return f"{col} != ALL(${param_idx})"
+ elif op == "$overlap":
+ self.params.append(val)
+ return f"{col} && ${param_idx}"
+ elif op == "$contains":
+ self.params.append(val)
+ return f"{col} @> ${param_idx}"
+ elif op == "$any":
+ if col == "collection_ids":
+ self.params.append(f"%{val}%")
+ return f"array_to_string({col}, ',') LIKE ${param_idx}"
+ else:
+ self.params.append(val)
+ return f"${param_idx} = ANY({col})"
+ elif op in ("$lt", "$lte", "$gt", "$gte"):
+ self.params.append(val)
+ return f"{col} {self._map_op(op)} ${param_idx}"
+ else:
+ raise FilterError(f"Unsupported operator for column {col}: {op}")
+
+ def _build_metadata_condition(self, key: str, op: str, val: Any) -> str:
+ param_idx = len(self.params) + 1
+ json_col = self.json_column
+
+ # Strip "metadata." prefix if present
+ key = key.removeprefix("metadata.")
+
+ # Split on '.' to handle nested keys
+ parts = key.split(".")
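+        # e.g. "metadata.user.age" with {"$gt": 30} becomes
+        #   (metadata->'user'->>'age')::numeric > $N::numeric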
+
+ # Use text extraction for scalar values, but not for arrays
+ use_text_extraction = op in (
+ "$lt",
+ "$lte",
+ "$gt",
+ "$gte",
+ "$eq",
+ "$ne",
+ ) and isinstance(val, (int, float, str))
+ if op == "$in" or op == "$contains" or isinstance(val, (list, dict)):
+ use_text_extraction = False
+
+ # Build the JSON path expression
+ if len(parts) == 1:
+ if use_text_extraction:
+ path_expr = f"{json_col}->>'{parts[0]}'"
+ else:
+ path_expr = f"{json_col}->'{parts[0]}'"
+ else:
+ path_expr = json_col
+ for p in parts[:-1]:
+ path_expr += f"->'{p}'"
+ last_part = parts[-1]
+ if use_text_extraction:
+ path_expr += f"->>'{last_part}'"
+ else:
+ path_expr += f"->'{last_part}'"
+
+ # Convert numeric values to strings for text comparison
+ def prepare_value(v):
+ return str(v) if isinstance(v, (int, float)) else v
+
+ if op == "$eq":
+ if use_text_extraction:
+ prepared_val = prepare_value(val)
+ self.params.append(prepared_val)
+ return f"{path_expr} = ${param_idx}"
+ else:
+ self.params.append(json.dumps(val))
+ return f"{path_expr} = ${param_idx}::jsonb"
+ elif op == "$ne":
+ if use_text_extraction:
+ self.params.append(prepare_value(val))
+ return f"{path_expr} != ${param_idx}"
+ else:
+ self.params.append(json.dumps(val))
+ return f"{path_expr} != ${param_idx}::jsonb"
+ elif op == "$lt":
+ self.params.append(prepare_value(val))
+ return f"({path_expr})::numeric < ${param_idx}::numeric"
+ elif op == "$lte":
+ self.params.append(prepare_value(val))
+ return f"({path_expr})::numeric <= ${param_idx}::numeric"
+ elif op == "$gt":
+ self.params.append(prepare_value(val))
+ return f"({path_expr})::numeric > ${param_idx}::numeric"
+ elif op == "$gte":
+ self.params.append(prepare_value(val))
+ return f"({path_expr})::numeric >= ${param_idx}::numeric"
+ elif op == "$in":
+ if not isinstance(val, list):
+ raise FilterError("argument to $in filter must be a list")
+
+ if use_text_extraction:
+ str_vals = [
+ str(v) if isinstance(v, (int, float)) else v for v in val
+ ]
+ self.params.append(str_vals)
+ return f"{path_expr} = ANY(${param_idx}::text[])"
+
+ # For JSON arrays, use containment checks
+ conditions = []
+ for i, v in enumerate(val):
+ self.params.append(json.dumps(v))
+ conditions.append(f"{path_expr} @> ${param_idx + i}::jsonb")
+ return f"({' OR '.join(conditions)})"
+
+ elif op == "$contains":
+ if isinstance(val, (str, int, float, bool)):
+ val = [val]
+ self.params.append(json.dumps(val))
+ return f"{path_expr} @> ${param_idx}::jsonb"
+ else:
+ raise FilterError(f"Unsupported operator for metadata field {op}")
+
+ def _map_op(self, op: str) -> str:
+ mapping = {
+ FilterOperator.EQ: "=",
+ FilterOperator.NE: "!=",
+ FilterOperator.LT: "<",
+ FilterOperator.LTE: "<=",
+ FilterOperator.GT: ">",
+ FilterOperator.GTE: ">=",
+ }
+ return mapping.get(op, op)
+
+
+def apply_filters(
+ filters: dict, params: list[Any], mode: str = "where_clause"
+) -> tuple[str, list[Any]]:
+ """Apply filters with consistent WHERE clause handling."""
+ if not filters:
+ return "", params
+
+ parser = FilterParser()
+ expr = parser.parse(filters)
+ builder = SQLFilterBuilder(params=params, mode=mode)
+ filter_clause, new_params = builder.build(expr)
+
+ if mode == "where_clause":
+ return filter_clause, new_params # Already includes WHERE
+ elif mode == "condition_only":
+ return filter_clause, new_params
+ elif mode == "append_only":
+ return f"AND {filter_clause}", new_params
+ else:
+ raise ValueError(f"Unknown filter mode: {mode}")
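+
+
+# --- Illustrative usage sketch ---------------------------------------------
+# A small, self-contained example of the filter DSL handled above. The filter
+# values are placeholders; top-level columns (id, document_id, owner_id,
+# collection_ids) are compared directly, while any other field is treated as a
+# key inside the JSONB metadata column.
+def _example_apply_filters() -> tuple[str, list[Any]]:
+    filters = {
+        "$and": [
+            {"owner_id": {"$eq": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1"}},
+            {"metadata.rating": {"$gte": 4}},
+        ]
+    }
+    # mode="where_clause" prefixes the result with WHERE; "condition_only"
+    # returns just the boolean expression for embedding in a larger query.
+    clause, params = apply_filters(filters, params=[], mode="where_clause")
+    # clause -> "WHERE (owner_id = $1) AND ((metadata->>'rating')::numeric >= $2::numeric)"
+    # params -> ["9fbe403b-c11c-5aae-8ade-ef22980c3ad1", "4"]
+    return clause, params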
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/graphs.py b/.venv/lib/python3.12/site-packages/core/providers/database/graphs.py
new file mode 100644
index 00000000..ba9c22ee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/graphs.py
@@ -0,0 +1,2884 @@
+import asyncio
+import contextlib
+import csv
+import datetime
+import json
+import logging
+import os
+import tempfile
+import time
+from typing import IO, Any, AsyncGenerator, Optional, Tuple
+from uuid import UUID
+
+import asyncpg
+import httpx
+from asyncpg.exceptions import UniqueViolationError
+from fastapi import HTTPException
+
+from core.base.abstractions import (
+ Community,
+ Entity,
+ Graph,
+ GraphExtractionStatus,
+ R2RException,
+ Relationship,
+ StoreType,
+ VectorQuantizationType,
+)
+from core.base.api.models import GraphResponse
+from core.base.providers.database import Handler
+from core.base.utils import (
+ _get_vector_column_str,
+ generate_entity_document_id,
+)
+
+from .base import PostgresConnectionManager
+from .collections import PostgresCollectionsHandler
+
+logger = logging.getLogger()
+
+
+class PostgresEntitiesHandler(Handler):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ self.project_name: str = kwargs.get("project_name") # type: ignore
+ self.connection_manager: PostgresConnectionManager = kwargs.get(
+ "connection_manager"
+ ) # type: ignore
+ self.dimension: int = kwargs.get("dimension") # type: ignore
+ self.quantization_type: VectorQuantizationType = kwargs.get(
+ "quantization_type"
+ ) # type: ignore
+ self.relationships_handler: PostgresRelationshipsHandler = (
+ PostgresRelationshipsHandler(*args, **kwargs)
+ )
+
+ def _get_table_name(self, table: str) -> str:
+ """Get the fully qualified table name."""
+ return f'"{self.project_name}"."{table}"'
+
+ def _get_entity_table_for_store(self, store_type: StoreType) -> str:
+ """Get the appropriate table name for the store type."""
+ return f"{store_type.value}_entities"
+
+ def _get_parent_constraint(self, store_type: StoreType) -> str:
+ """Get the appropriate foreign key constraint for the store type."""
+ if store_type == StoreType.GRAPHS:
+ return f"""
+ CONSTRAINT fk_graph
+ FOREIGN KEY(parent_id)
+ REFERENCES {self._get_table_name("graphs")}(id)
+ ON DELETE CASCADE
+ """
+ else:
+ return f"""
+ CONSTRAINT fk_document
+ FOREIGN KEY(parent_id)
+ REFERENCES {self._get_table_name("documents")}(id)
+ ON DELETE CASCADE
+ """
+
+ async def create_tables(self) -> None:
+ """Create separate tables for graph and document entities."""
+ vector_column_str = _get_vector_column_str(
+ self.dimension, self.quantization_type
+ )
+
+ for store_type in StoreType:
+ table_name = self._get_entity_table_for_store(store_type)
+ parent_constraint = self._get_parent_constraint(store_type)
+
+ QUERY = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ name TEXT NOT NULL,
+ category TEXT,
+ description TEXT,
+ parent_id UUID NOT NULL,
+ description_embedding {vector_column_str},
+ chunk_ids UUID[],
+ metadata JSONB,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ {parent_constraint}
+ );
+ CREATE INDEX IF NOT EXISTS {table_name}_name_idx
+ ON {self._get_table_name(table_name)} (name);
+ CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
+ ON {self._get_table_name(table_name)} (parent_id);
+ CREATE INDEX IF NOT EXISTS {table_name}_category_idx
+ ON {self._get_table_name(table_name)} (category);
+ """
+ await self.connection_manager.execute_query(QUERY)
+
+ async def create(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ name: str,
+ category: Optional[str] = None,
+ description: Optional[str] = None,
+ description_embedding: Optional[list[float] | str] = None,
+ chunk_ids: Optional[list[UUID]] = None,
+ metadata: Optional[dict[str, Any] | str] = None,
+ ) -> Entity:
+ """Create a new entity in the specified store."""
+ table_name = self._get_entity_table_for_store(store_type)
+
+ if isinstance(metadata, str):
+ with contextlib.suppress(json.JSONDecodeError):
+ metadata = json.loads(metadata)
+
+ if isinstance(description_embedding, list):
+ description_embedding = str(description_embedding)
+
+ query = f"""
+ INSERT INTO {self._get_table_name(table_name)}
+ (name, category, description, parent_id, description_embedding, chunk_ids, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
+ RETURNING id, name, category, description, parent_id, chunk_ids, metadata
+ """
+
+ params = [
+ name,
+ category,
+ description,
+ parent_id,
+ description_embedding,
+ chunk_ids,
+ json.dumps(metadata) if metadata else None,
+ ]
+
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+
+ return Entity(
+ id=result["id"],
+ name=result["name"],
+ category=result["category"],
+ description=result["description"],
+ parent_id=result["parent_id"],
+ chunk_ids=result["chunk_ids"],
+ metadata=result["metadata"],
+ )
+
+ async def get(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ offset: int,
+ limit: int,
+ entity_ids: Optional[list[UUID]] = None,
+ entity_names: Optional[list[str]] = None,
+ include_embeddings: bool = False,
+ ):
+ """Retrieve entities from the specified store."""
+ table_name = self._get_entity_table_for_store(store_type)
+
+ conditions = ["parent_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if entity_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(entity_ids)
+ param_index += 1
+
+ if entity_names:
+ conditions.append(f"name = ANY(${param_index})")
+ params.append(entity_names)
+ param_index += 1
+
+ select_fields = """
+ id, name, category, description, parent_id,
+ chunk_ids, metadata
+ """
+ if include_embeddings:
+ select_fields += ", description_embedding"
+
+ COUNT_QUERY = f"""
+ SELECT COUNT(*)
+ FROM {self._get_table_name(table_name)}
+ WHERE {" AND ".join(conditions)}
+ """
+
+ count_params = params[: param_index - 1]
+ count = (
+ await self.connection_manager.fetch_query(
+ COUNT_QUERY, count_params
+ )
+ )[0]["count"]
+
+ QUERY = f"""
+ SELECT {select_fields}
+ FROM {self._get_table_name(table_name)}
+ WHERE {" AND ".join(conditions)}
+ ORDER BY created_at
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ QUERY += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ rows = await self.connection_manager.fetch_query(QUERY, params)
+
+ entities = []
+ for row in rows:
+ # Convert the Record to a dictionary
+ entity_dict = dict(row)
+
+ # Process metadata if it exists and is a string
+ if isinstance(entity_dict["metadata"], str):
+ with contextlib.suppress(json.JSONDecodeError):
+ entity_dict["metadata"] = json.loads(
+ entity_dict["metadata"]
+ )
+
+ entities.append(Entity(**entity_dict))
+
+ return entities, count
+
+ async def update(
+ self,
+ entity_id: UUID,
+ store_type: StoreType,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ description_embedding: Optional[list[float] | str] = None,
+ category: Optional[str] = None,
+ metadata: Optional[dict] = None,
+ ) -> Entity:
+ """Update an entity in the specified store."""
+ table_name = self._get_entity_table_for_store(store_type)
+ update_fields = []
+ params: list[Any] = []
+ param_index = 1
+
+ if isinstance(metadata, str):
+ with contextlib.suppress(json.JSONDecodeError):
+ metadata = json.loads(metadata)
+
+ if name is not None:
+ update_fields.append(f"name = ${param_index}")
+ params.append(name)
+ param_index += 1
+
+ if description is not None:
+ update_fields.append(f"description = ${param_index}")
+ params.append(description)
+ param_index += 1
+
+ if description_embedding is not None:
+ update_fields.append(f"description_embedding = ${param_index}")
+ params.append(description_embedding)
+ param_index += 1
+
+ if category is not None:
+ update_fields.append(f"category = ${param_index}")
+ params.append(category)
+ param_index += 1
+
+ if metadata is not None:
+ update_fields.append(f"metadata = ${param_index}")
+ params.append(json.dumps(metadata))
+ param_index += 1
+
+ if not update_fields:
+ raise R2RException(status_code=400, message="No fields to update")
+
+ update_fields.append("updated_at = NOW()")
+ params.append(entity_id)
+
+ query = f"""
+ UPDATE {self._get_table_name(table_name)}
+ SET {", ".join(update_fields)}
+            WHERE id = ${param_index}
+ RETURNING id, name, category, description, parent_id, chunk_ids, metadata
+ """
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+
+ return Entity(
+ id=result["id"],
+ name=result["name"],
+ category=result["category"],
+ description=result["description"],
+ parent_id=result["parent_id"],
+ chunk_ids=result["chunk_ids"],
+ metadata=result["metadata"],
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while updating the entity: {e}",
+ ) from e
+
+ async def delete(
+ self,
+ parent_id: UUID,
+ entity_ids: Optional[list[UUID]] = None,
+ store_type: StoreType = StoreType.GRAPHS,
+ ) -> None:
+ """Delete entities from the specified store. If entity_ids is not
+ provided, deletes all entities for the given parent_id.
+
+ Args:
+ parent_id (UUID): Parent ID (collection_id or document_id)
+ entity_ids (Optional[list[UUID]]): Specific entity IDs to delete. If None, deletes all entities for parent_id
+ store_type (StoreType): Type of store (graph or document)
+
+        Returns:
+            None
+
+ Raises:
+ R2RException: If specific entities were requested but not all found
+ """
+ table_name = self._get_entity_table_for_store(store_type)
+
+ if entity_ids is None:
+ # Delete all entities for the parent_id
+ QUERY = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE parent_id = $1
+ RETURNING id
+ """
+ results = await self.connection_manager.fetch_query(
+ QUERY, [parent_id]
+ )
+ else:
+ # Delete specific entities
+ QUERY = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE id = ANY($1) AND parent_id = $2
+ RETURNING id
+ """
+
+ results = await self.connection_manager.fetch_query(
+ QUERY, [entity_ids, parent_id]
+ )
+
+ # Check if all requested entities were deleted
+ deleted_ids = [row["id"] for row in results]
+ if entity_ids and len(deleted_ids) != len(entity_ids):
+ raise R2RException(
+ f"Some entities not found in {store_type} store or no permission to delete",
+ 404,
+ )
+
+ async def get_duplicate_name_blocks(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ ) -> list[list[Entity]]:
+ """Find all groups of entities that share identical names within the
+ same parent.
+
+        Returns a list of entity groups, where each group contains all the
+        entities that share a single name under the given parent.
+ """
+ table_name = self._get_entity_table_for_store(store_type)
+
+ # First get the duplicate names and their descriptions with embeddings
+ query = f"""
+ WITH duplicates AS (
+ SELECT name
+ FROM {self._get_table_name(table_name)}
+ WHERE parent_id = $1
+ GROUP BY name
+ HAVING COUNT(*) > 1
+ )
+ SELECT
+ e.id, e.name, e.category, e.description,
+ e.parent_id, e.chunk_ids, e.metadata
+ FROM {self._get_table_name(table_name)} e
+ WHERE e.parent_id = $1
+ AND e.name IN (SELECT name FROM duplicates)
+ ORDER BY e.name;
+ """
+
+ rows = await self.connection_manager.fetch_query(query, [parent_id])
+
+ # Group entities by name
+ name_groups: dict[str, list[Entity]] = {}
+ for row in rows:
+ entity_dict = dict(row)
+ if isinstance(entity_dict["metadata"], str):
+ with contextlib.suppress(json.JSONDecodeError):
+ entity_dict["metadata"] = json.loads(
+ entity_dict["metadata"]
+ )
+
+ entity = Entity(**entity_dict)
+ name_groups.setdefault(entity.name, []).append(entity)
+
+ return list(name_groups.values())
+
+ async def merge_duplicate_name_blocks(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ ) -> list[tuple[list[Entity], Entity]]:
+ """Merge entities that share identical names.
+
+        Returns a list of (original_entities, merged_entity) tuples; a usage
+        sketch follows this class.
+ """
+ duplicate_blocks = await self.get_duplicate_name_blocks(
+ parent_id, store_type
+ )
+ merged_results: list[tuple[list[Entity], Entity]] = []
+
+ for block in duplicate_blocks:
+ # Create a new merged entity from the block
+ merged_entity = await self._create_merged_entity(block)
+ merged_results.append((block, merged_entity))
+
+ table_name = self._get_entity_table_for_store(store_type)
+ async with self.connection_manager.transaction():
+ # Insert the merged entity
+ new_id = await self._insert_merged_entity(
+ merged_entity, table_name
+ )
+
+ merged_entity.id = new_id
+
+ # Get the old entity IDs
+ old_ids = [str(entity.id) for entity in block]
+
+ relationship_table = self.relationships_handler._get_relationship_table_for_store(
+ store_type
+ )
+
+ # Update relationships where old entities appear as subjects
+ subject_update_query = f"""
+ UPDATE {self._get_table_name(relationship_table)}
+ SET subject_id = $1
+ WHERE subject_id = ANY($2::uuid[])
+ AND parent_id = $3
+ """
+ await self.connection_manager.execute_query(
+ subject_update_query, [new_id, old_ids, parent_id]
+ )
+
+ # Update relationships where old entities appear as objects
+ object_update_query = f"""
+ UPDATE {self._get_table_name(relationship_table)}
+ SET object_id = $1
+ WHERE object_id = ANY($2::uuid[])
+ AND parent_id = $3
+ """
+ await self.connection_manager.execute_query(
+ object_update_query, [new_id, old_ids, parent_id]
+ )
+
+ # Delete the original entities
+ delete_query = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE id = ANY($1::uuid[])
+ """
+ await self.connection_manager.execute_query(
+ delete_query, [old_ids]
+ )
+
+ return merged_results
+
+ async def _insert_merged_entity(
+ self, entity: Entity, table_name: str
+ ) -> UUID:
+ """Insert merged entity and return its new ID."""
+ new_id = generate_entity_document_id()
+
+ query = f"""
+ INSERT INTO {self._get_table_name(table_name)}
+ (id, name, category, description, parent_id, chunk_ids, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
+ RETURNING id
+ """
+
+ values = [
+ new_id,
+ entity.name,
+ entity.category,
+ entity.description,
+ entity.parent_id,
+ entity.chunk_ids,
+ json.dumps(entity.metadata) if entity.metadata else None,
+ ]
+
+ result = await self.connection_manager.fetch_query(query, values)
+ return result[0]["id"]
+
+ async def _create_merged_entity(self, entities: list[Entity]) -> Entity:
+ """Create a merged entity from a list of duplicate entities.
+
+ Uses various strategies to combine fields.
+ """
+ if not entities:
+ raise ValueError("Cannot merge empty list of entities")
+
+ # Take the first non-None category, or None if all are None
+ category = next(
+ (e.category for e in entities if e.category is not None), None
+ )
+
+ # Combine descriptions with newlines if they differ
+ descriptions = {e.description for e in entities if e.description}
+ description = "\n\n".join(descriptions) if descriptions else None
+
+ # Combine chunk_ids, removing duplicates
+ chunk_ids = list(
+ {
+ chunk_id
+ for entity in entities
+ for chunk_id in (entity.chunk_ids or [])
+ }
+ )
+
+ # Merge metadata dictionaries
+ merged_metadata: dict[str, Any] = {}
+ for entity in entities:
+ if entity.metadata:
+ merged_metadata |= entity.metadata
+
+ # Create new merged entity (without actually inserting to DB)
+ return Entity(
+ id=UUID(
+ "00000000-0000-0000-0000-000000000000"
+ ), # Placeholder UUID
+ name=entities[0].name, # All entities in block have same name
+ category=category,
+ description=description,
+ parent_id=entities[0].parent_id,
+ chunk_ids=chunk_ids or None,
+ metadata=merged_metadata or None,
+ )
+
+ async def export_to_csv(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ """Creates a CSV file from the PostgreSQL data and returns the path to
+ the temp file."""
+ valid_columns = {
+ "id",
+ "name",
+ "category",
+ "description",
+ "parent_id",
+ "chunk_ids",
+ "metadata",
+ "created_at",
+ "updated_at",
+ }
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ name,
+ category,
+ description,
+ parent_id::text,
+ chunk_ids::text,
+ metadata::text,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+ FROM {self._get_table_name(self._get_entity_table_for_store(store_type))}
+ """
+
+ conditions = ["parent_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "name": row[1],
+ "category": row[2],
+ "description": row[3],
+ "parent_id": row[4],
+ "chunk_ids": row[5],
+ "metadata": row[6],
+ "created_at": row[7],
+ "updated_at": row[8],
+ }
+ writer.writerow([row_dict[col] for col in columns])
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
+
+
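+# --- Illustrative usage sketch ---------------------------------------------
+# A minimal sketch of the duplicate-entity merge flow above: entities that
+# share a name under one parent are collapsed into a single row, relationships
+# are re-pointed at the surviving entity, and the originals are deleted. The
+# handler is assumed to be constructed and connected elsewhere; the collection
+# id is a placeholder.
+async def _example_merge_duplicate_entities(
+    entities: PostgresEntitiesHandler, collection_id: UUID
+) -> None:
+    merged = await entities.merge_duplicate_name_blocks(
+        parent_id=collection_id,
+        store_type=StoreType.GRAPHS,
+    )
+    for originals, survivor in merged:
+        logger.info(
+            "Merged %d '%s' entities into %s",
+            len(originals),
+            survivor.name,
+            survivor.id,
+        )
+
+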
+class PostgresRelationshipsHandler(Handler):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ self.project_name: str = kwargs.get("project_name") # type: ignore
+ self.connection_manager: PostgresConnectionManager = kwargs.get(
+ "connection_manager"
+ ) # type: ignore
+ self.dimension: int = kwargs.get("dimension") # type: ignore
+ self.quantization_type: VectorQuantizationType = kwargs.get(
+ "quantization_type"
+ ) # type: ignore
+
+ def _get_table_name(self, table: str) -> str:
+ """Get the fully qualified table name."""
+ return f'"{self.project_name}"."{table}"'
+
+ def _get_relationship_table_for_store(self, store_type: StoreType) -> str:
+ """Get the appropriate table name for the store type."""
+ return f"{store_type.value}_relationships"
+
+ def _get_parent_constraint(self, store_type: StoreType) -> str:
+ """Get the appropriate foreign key constraint for the store type."""
+ if store_type == StoreType.GRAPHS:
+ return f"""
+ CONSTRAINT fk_graph
+ FOREIGN KEY(parent_id)
+ REFERENCES {self._get_table_name("graphs")}(id)
+ ON DELETE CASCADE
+ """
+ else:
+ return f"""
+ CONSTRAINT fk_document
+ FOREIGN KEY(parent_id)
+ REFERENCES {self._get_table_name("documents")}(id)
+ ON DELETE CASCADE
+ """
+
+ async def create_tables(self) -> None:
+ """Create separate tables for graph and document relationships."""
+ for store_type in StoreType:
+ table_name = self._get_relationship_table_for_store(store_type)
+ parent_constraint = self._get_parent_constraint(store_type)
+ vector_column_str = _get_vector_column_str(
+ self.dimension, self.quantization_type
+ )
+
+ QUERY = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ subject TEXT NOT NULL,
+ predicate TEXT NOT NULL,
+ object TEXT NOT NULL,
+ description TEXT,
+ description_embedding {vector_column_str},
+ subject_id UUID,
+ object_id UUID,
+ weight FLOAT DEFAULT 1.0,
+ chunk_ids UUID[],
+ parent_id UUID NOT NULL,
+ metadata JSONB,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ {parent_constraint}
+ );
+
+ CREATE INDEX IF NOT EXISTS {table_name}_subject_idx
+ ON {self._get_table_name(table_name)} (subject);
+ CREATE INDEX IF NOT EXISTS {table_name}_object_idx
+ ON {self._get_table_name(table_name)} (object);
+ CREATE INDEX IF NOT EXISTS {table_name}_predicate_idx
+ ON {self._get_table_name(table_name)} (predicate);
+ CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
+ ON {self._get_table_name(table_name)} (parent_id);
+ CREATE INDEX IF NOT EXISTS {table_name}_subject_id_idx
+ ON {self._get_table_name(table_name)} (subject_id);
+ CREATE INDEX IF NOT EXISTS {table_name}_object_id_idx
+ ON {self._get_table_name(table_name)} (object_id);
+ """
+ await self.connection_manager.execute_query(QUERY)
+
+ async def create(
+ self,
+ subject: str,
+ subject_id: UUID,
+ predicate: str,
+ object: str,
+ object_id: UUID,
+ parent_id: UUID,
+ store_type: StoreType,
+ description: str | None = None,
+ weight: float | None = 1.0,
+ chunk_ids: Optional[list[UUID]] = None,
+ description_embedding: Optional[list[float] | str] = None,
+ metadata: Optional[dict[str, Any] | str] = None,
+ ) -> Relationship:
+ """Create a new relationship in the specified store."""
+ table_name = self._get_relationship_table_for_store(store_type)
+
+ if isinstance(metadata, str):
+ with contextlib.suppress(json.JSONDecodeError):
+ metadata = json.loads(metadata)
+
+ if isinstance(description_embedding, list):
+ description_embedding = str(description_embedding)
+
+ query = f"""
+ INSERT INTO {self._get_table_name(table_name)}
+ (subject, predicate, object, description, subject_id, object_id,
+ weight, chunk_ids, parent_id, description_embedding, metadata)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
+ RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
+ """
+
+ params = [
+ subject,
+ predicate,
+ object,
+ description,
+ subject_id,
+ object_id,
+ weight,
+ chunk_ids,
+ parent_id,
+ description_embedding,
+ json.dumps(metadata) if metadata else None,
+ ]
+
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+
+ return Relationship(
+ id=result["id"],
+ subject=result["subject"],
+ predicate=result["predicate"],
+ object=result["object"],
+ description=result["description"],
+ subject_id=result["subject_id"],
+ object_id=result["object_id"],
+ weight=result["weight"],
+ chunk_ids=result["chunk_ids"],
+ parent_id=result["parent_id"],
+ metadata=result["metadata"],
+ )
+
+ async def get(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ offset: int,
+ limit: int,
+ relationship_ids: Optional[list[UUID]] = None,
+ entity_names: Optional[list[str]] = None,
+ relationship_types: Optional[list[str]] = None,
+ include_metadata: bool = False,
+ ):
+ """Get relationships from the specified store.
+
+ Args:
+ parent_id: UUID of the parent (collection_id or document_id)
+ store_type: Type of store (graph or document)
+ offset: Number of records to skip
+ limit: Maximum number of records to return (-1 for no limit)
+ relationship_ids: Optional list of specific relationship IDs to retrieve
+ entity_names: Optional list of entity names to filter by (matches subject or object)
+ relationship_types: Optional list of relationship types (predicates) to filter by
+ include_metadata: Whether to include metadata in the response
+
+ Returns:
+ Tuple of (list of relationships, total count)
+ """
+ table_name = self._get_relationship_table_for_store(store_type)
+
+ conditions = ["parent_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if relationship_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(relationship_ids)
+ param_index += 1
+
+ if entity_names:
+ conditions.append(
+ f"(subject = ANY(${param_index}) OR object = ANY(${param_index}))"
+ )
+ params.append(entity_names)
+ param_index += 1
+
+ if relationship_types:
+ conditions.append(f"predicate = ANY(${param_index})")
+ params.append(relationship_types)
+ param_index += 1
+
+ select_fields = """
+ id, subject, predicate, object, description,
+ subject_id, object_id, weight, chunk_ids,
+ parent_id
+ """
+ if include_metadata:
+ select_fields += ", metadata"
+
+ # Count query
+ COUNT_QUERY = f"""
+ SELECT COUNT(*)
+ FROM {self._get_table_name(table_name)}
+ WHERE {" AND ".join(conditions)}
+ """
+ count_params = params[: param_index - 1]
+ count = (
+ await self.connection_manager.fetch_query(
+ COUNT_QUERY, count_params
+ )
+ )[0]["count"]
+
+ # Main query
+ QUERY = f"""
+ SELECT {select_fields}
+ FROM {self._get_table_name(table_name)}
+ WHERE {" AND ".join(conditions)}
+ ORDER BY created_at
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ QUERY += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ rows = await self.connection_manager.fetch_query(QUERY, params)
+
+ relationships = []
+ for row in rows:
+ relationship_dict = dict(row)
+ if include_metadata and isinstance(
+ relationship_dict["metadata"], str
+ ):
+ with contextlib.suppress(json.JSONDecodeError):
+ relationship_dict["metadata"] = json.loads(
+ relationship_dict["metadata"]
+ )
+ elif not include_metadata:
+ relationship_dict.pop("metadata", None)
+ relationships.append(Relationship(**relationship_dict))
+
+ return relationships, count
+
+ async def update(
+ self,
+ relationship_id: UUID,
+ store_type: StoreType,
+ subject: Optional[str],
+ subject_id: Optional[UUID],
+ predicate: Optional[str],
+ object: Optional[str],
+ object_id: Optional[UUID],
+ description: Optional[str],
+ description_embedding: Optional[list[float] | str],
+ weight: Optional[float],
+ metadata: Optional[dict[str, Any] | str],
+ ) -> Relationship:
+ """Update multiple relationships in the specified store."""
+ table_name = self._get_relationship_table_for_store(store_type)
+ update_fields = []
+ params: list = []
+ param_index = 1
+
+ if isinstance(metadata, str):
+ with contextlib.suppress(json.JSONDecodeError):
+ metadata = json.loads(metadata)
+
+ if subject is not None:
+ update_fields.append(f"subject = ${param_index}")
+ params.append(subject)
+ param_index += 1
+
+ if subject_id is not None:
+ update_fields.append(f"subject_id = ${param_index}")
+ params.append(subject_id)
+ param_index += 1
+
+ if predicate is not None:
+ update_fields.append(f"predicate = ${param_index}")
+ params.append(predicate)
+ param_index += 1
+
+ if object is not None:
+ update_fields.append(f"object = ${param_index}")
+ params.append(object)
+ param_index += 1
+
+ if object_id is not None:
+ update_fields.append(f"object_id = ${param_index}")
+ params.append(object_id)
+ param_index += 1
+
+ if description is not None:
+ update_fields.append(f"description = ${param_index}")
+ params.append(description)
+ param_index += 1
+
+ if description_embedding is not None:
+ update_fields.append(f"description_embedding = ${param_index}")
+ params.append(description_embedding)
+ param_index += 1
+
+ if weight is not None:
+ update_fields.append(f"weight = ${param_index}")
+ params.append(weight)
+ param_index += 1
+
+ if not update_fields:
+ raise R2RException(status_code=400, message="No fields to update")
+
+ update_fields.append("updated_at = NOW()")
+ params.append(relationship_id)
+
+ query = f"""
+ UPDATE {self._get_table_name(table_name)}
+ SET {", ".join(update_fields)}
+ WHERE id = ${param_index}
+ RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
+ """
+
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+
+ return Relationship(
+ id=result["id"],
+ subject=result["subject"],
+ predicate=result["predicate"],
+ object=result["object"],
+ description=result["description"],
+ subject_id=result["subject_id"],
+ object_id=result["object_id"],
+ weight=result["weight"],
+ chunk_ids=result["chunk_ids"],
+ parent_id=result["parent_id"],
+ metadata=result["metadata"],
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while updating the relationship: {e}",
+ ) from e
+
+ async def delete(
+ self,
+ parent_id: UUID,
+ relationship_ids: Optional[list[UUID]] = None,
+ store_type: StoreType = StoreType.GRAPHS,
+ ) -> None:
+ """Delete relationships from the specified store. If relationship_ids
+ is not provided, deletes all relationships for the given parent_id.
+
+ Args:
+ parent_id: UUID of the parent (collection_id or document_id)
+ relationship_ids: Optional list of specific relationship IDs to delete
+ store_type: Type of store (graph or document)
+
+ Returns:
+ None
+
+ Raises:
+ R2RException: If specific relationships were requested but not all of them were found
+ """
+ table_name = self._get_relationship_table_for_store(store_type)
+
+ if relationship_ids is None:
+ QUERY = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE parent_id = $1
+ RETURNING id
+ """
+ results = await self.connection_manager.fetch_query(
+ QUERY, [parent_id]
+ )
+ else:
+ QUERY = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE id = ANY($1) AND parent_id = $2
+ RETURNING id
+ """
+ results = await self.connection_manager.fetch_query(
+ QUERY, [relationship_ids, parent_id]
+ )
+
+ deleted_ids = [row["id"] for row in results]
+ if relationship_ids and len(deleted_ids) != len(relationship_ids):
+ raise R2RException(
+ f"Some relationships not found in {store_type} store or no permission to delete",
+ 404,
+ )
+
+ async def export_to_csv(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ """Creates a CSV file from the PostgreSQL data and returns the path to
+ the temp file."""
+ valid_columns = {
+ "id",
+ "subject",
+ "predicate",
+ "object",
+ "description",
+ "subject_id",
+ "object_id",
+ "weight",
+ "chunk_ids",
+ "parent_id",
+ "metadata",
+ "created_at",
+ "updated_at",
+ }
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ subject,
+ predicate,
+ object,
+ description,
+ subject_id::text,
+ object_id::text,
+ weight,
+ chunk_ids::text,
+ parent_id::text,
+ metadata::text,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+ FROM {self._get_table_name(self._get_relationship_table_for_store(store_type))}
+ """
+
+ conditions = ["parent_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
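+ # The filters dict (illustrative) might look like:
+ # {"predicate": "mentions", "weight": {"$gt": 0.5}}
+ # Only direct equality and the $eq / $gt / $lt operators are honored below.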
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ writer.writerow(row)
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
+
+
+class PostgresCommunitiesHandler(Handler):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ self.project_name: str = kwargs.get("project_name") # type: ignore
+ self.connection_manager: PostgresConnectionManager = kwargs.get(
+ "connection_manager"
+ ) # type: ignore
+ self.dimension: int = kwargs.get("dimension") # type: ignore
+ self.quantization_type: VectorQuantizationType = kwargs.get(
+ "quantization_type"
+ ) # type: ignore
+
+ async def create_tables(self) -> None:
+ vector_column_str = _get_vector_column_str(
+ self.dimension, self.quantization_type
+ )
+
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name("graphs_communities")} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ collection_id UUID,
+ community_id UUID,
+ level INT,
+ name TEXT NOT NULL,
+ summary TEXT NOT NULL,
+ findings TEXT[],
+ rating FLOAT,
+ rating_explanation TEXT,
+ description_embedding {vector_column_str} NOT NULL,
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+ metadata JSONB,
+ UNIQUE (community_id, level, collection_id)
+ );"""
+
+ await self.connection_manager.execute_query(query)
+
+ async def create(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ name: str,
+ summary: str,
+ findings: Optional[list[str]],
+ rating: Optional[float],
+ rating_explanation: Optional[str],
+ description_embedding: Optional[list[float] | str] = None,
+ ) -> Community:
+ table_name = "graphs_communities"
+
+ if isinstance(description_embedding, list):
+ description_embedding = str(description_embedding)
+
+ query = f"""
+ INSERT INTO {self._get_table_name(table_name)}
+ (collection_id, name, summary, findings, rating, rating_explanation, description_embedding)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
+ RETURNING id, collection_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
+ """
+
+ params = [
+ parent_id,
+ name,
+ summary,
+ findings,
+ rating,
+ rating_explanation,
+ description_embedding,
+ ]
+
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+
+ return Community(
+ id=result["id"],
+ collection_id=result["collection_id"],
+ name=result["name"],
+ summary=result["summary"],
+ findings=result["findings"],
+ rating=result["rating"],
+ rating_explanation=result["rating_explanation"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while creating the community: {e}",
+ ) from e
+
+ async def update(
+ self,
+ community_id: UUID,
+ store_type: StoreType,
+ name: Optional[str] = None,
+ summary: Optional[str] = None,
+ summary_embedding: Optional[list[float] | str] = None,
+ findings: Optional[list[str]] = None,
+ rating: Optional[float] = None,
+ rating_explanation: Optional[str] = None,
+ ) -> Community:
+ table_name = "graphs_communities"
+ update_fields = []
+ params: list[Any] = []
+ param_index = 1
+
+ if name is not None:
+ update_fields.append(f"name = ${param_index}")
+ params.append(name)
+ param_index += 1
+
+ if summary is not None:
+ update_fields.append(f"summary = ${param_index}")
+ params.append(summary)
+ param_index += 1
+
+ if summary_embedding is not None:
+ update_fields.append(f"description_embedding = ${param_index}")
+ params.append(summary_embedding)
+ param_index += 1
+
+ if findings is not None:
+ update_fields.append(f"findings = ${param_index}")
+ params.append(findings)
+ param_index += 1
+
+ if rating is not None:
+ update_fields.append(f"rating = ${param_index}")
+ params.append(rating)
+ param_index += 1
+
+ if rating_explanation is not None:
+ update_fields.append(f"rating_explanation = ${param_index}")
+ params.append(rating_explanation)
+ param_index += 1
+
+ if not update_fields:
+ raise R2RException(status_code=400, message="No fields to update")
+
+ update_fields.append("updated_at = NOW()")
+ params.append(community_id)
+
+ query = f"""
+ UPDATE {self._get_table_name(table_name)}
+ SET {", ".join(update_fields)}
+ WHERE id = ${param_index}
+ RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
+ """
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query, params
+ )
+
+ return Community(
+ id=result["id"],
+ community_id=result["community_id"],
+ name=result["name"],
+ summary=result["summary"],
+ findings=result["findings"],
+ rating=result["rating"],
+ rating_explanation=result["rating_explanation"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while updating the community: {e}",
+ ) from e
+
+ async def delete(
+ self,
+ parent_id: UUID,
+ community_id: UUID,
+ ) -> None:
+ table_name = "graphs_communities"
+
+ params = [community_id, parent_id]
+
+ # Delete the community
+ query = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE id = $1 AND collection_id = $2
+ """
+
+ try:
+ await self.connection_manager.execute_query(query, params)
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while deleting the community: {e}",
+ ) from e
+
+ async def delete_all_communities(
+ self,
+ parent_id: UUID,
+ ) -> None:
+ table_name = "graphs_communities"
+
+ params = [parent_id]
+
+ # Delete all communities for the parent_id
+ query = f"""
+ DELETE FROM {self._get_table_name(table_name)}
+ WHERE collection_id = $1
+ """
+
+ try:
+ await self.connection_manager.execute_query(query, params)
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while deleting communities: {e}",
+ ) from e
+
+ async def get(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ offset: int,
+ limit: int,
+ community_ids: Optional[list[UUID]] = None,
+ community_names: Optional[list[str]] = None,
+ include_embeddings: bool = False,
+ ):
+ """Retrieve communities from the specified store."""
+ # Do we ever want to get communities from document store?
+ table_name = "graphs_communities"
+
+ conditions = ["collection_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if community_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(community_ids)
+ param_index += 1
+
+ if community_names:
+ conditions.append(f"name = ANY(${param_index})")
+ params.append(community_names)
+ param_index += 1
+
+ select_fields = """
+ id, community_id, name, summary, findings, rating,
+ rating_explanation, level, created_at, updated_at
+ """
+ if include_embeddings:
+ select_fields += ", description_embedding"
+
+ COUNT_QUERY = f"""
+ SELECT COUNT(*)
+ FROM {self._get_table_name(table_name)}
+ WHERE {" AND ".join(conditions)}
+ """
+
+ count = (
+ await self.connection_manager.fetch_query(
+ COUNT_QUERY, params[: param_index - 1]
+ )
+ )[0]["count"]
+
+ QUERY = f"""
+ SELECT {select_fields}
+ FROM {self._get_table_name(table_name)}
+ WHERE {" AND ".join(conditions)}
+ ORDER BY created_at
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ QUERY += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ rows = await self.connection_manager.fetch_query(QUERY, params)
+
+ communities = []
+ for row in rows:
+ community_dict = dict(row)
+
+ communities.append(Community(**community_dict))
+
+ return communities, count
+
+ async def export_to_csv(
+ self,
+ parent_id: UUID,
+ store_type: StoreType,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ """Creates a CSV file from the PostgreSQL data and returns the path to
+ the temp file."""
+ valid_columns = {
+ "id",
+ "collection_id",
+ "community_id",
+ "level",
+ "name",
+ "summary",
+ "findings",
+ "rating",
+ "rating_explanation",
+ "created_at",
+ "updated_at",
+ "metadata",
+ }
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ table_name = "graphs_communities"
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ collection_id::text,
+ community_id::text,
+ level,
+ name,
+ summary,
+ findings::text,
+ rating,
+ rating_explanation,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
+ metadata::text
+ FROM {self._get_table_name(table_name)}
+ """
+
+ conditions = ["collection_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ writer.writerow(row)
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
+
+
+class PostgresGraphsHandler(Handler):
+ """Handler for Knowledge Graph METHODS in PostgreSQL."""
+
+ TABLE_NAME = "graphs"
+
+ def __init__(
+ self,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
+ self.project_name: str = kwargs.get("project_name") # type: ignore
+ self.connection_manager: PostgresConnectionManager = kwargs.get(
+ "connection_manager"
+ ) # type: ignore
+ self.dimension: int = kwargs.get("dimension") # type: ignore
+ self.quantization_type: VectorQuantizationType = kwargs.get(
+ "quantization_type"
+ ) # type: ignore
+ self.collections_handler: PostgresCollectionsHandler = kwargs.get(
+ "collections_handler"
+ ) # type: ignore
+
+ self.entities = PostgresEntitiesHandler(*args, **kwargs)
+ self.relationships = PostgresRelationshipsHandler(*args, **kwargs)
+ self.communities = PostgresCommunitiesHandler(*args, **kwargs)
+
+ self.handlers = [
+ self.entities,
+ self.relationships,
+ self.communities,
+ ]
+
+ async def create_tables(self) -> None:
+ """Create the graph tables with mandatory collection_id support."""
+ QUERY = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ collection_id UUID NOT NULL,
+ name TEXT NOT NULL,
+ description TEXT,
+ status TEXT NOT NULL,
+ document_ids UUID[],
+ metadata JSONB,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW()
+ );
+
+ CREATE INDEX IF NOT EXISTS graph_collection_id_idx
+ ON {self._get_table_name("graphs")} (collection_id);
+ """
+
+ await self.connection_manager.execute_query(QUERY)
+
+ for handler in self.handlers:
+ await handler.create_tables()
+
+ async def create(
+ self,
+ collection_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ status: str = "pending",
+ ) -> GraphResponse:
+ """Create a new graph associated with a collection."""
+
+ name = name or f"Graph {collection_id}"
+ description = description or ""
+
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ (id, collection_id, name, description, status)
+ VALUES ($1, $2, $3, $4, $5)
+ RETURNING id, collection_id, name, description, status, created_at, updated_at, document_ids
+ """
+ params = [
+ collection_id,
+ collection_id,
+ name,
+ description,
+ status,
+ ]
+
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query=query,
+ params=params,
+ )
+
+ return GraphResponse(
+ id=result["id"],
+ collection_id=result["collection_id"],
+ name=result["name"],
+ description=result["description"],
+ status=result["status"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ document_ids=result["document_ids"] or [],
+ )
+ except UniqueViolationError:
+ raise R2RException(
+ message="Graph with this ID already exists",
+ status_code=409,
+ ) from None
+
+ async def reset(self, parent_id: UUID) -> None:
+ """Completely reset a graph and all associated data."""
+
+ await self.entities.delete(
+ parent_id=parent_id, store_type=StoreType.GRAPHS
+ )
+ await self.relationships.delete(
+ parent_id=parent_id, store_type=StoreType.GRAPHS
+ )
+ await self.communities.delete_all_communities(parent_id=parent_id)
+
+ # Now, update the graph record to remove any attached document IDs.
+ # This sets document_ids to an empty UUID array.
+ query = f"""
+ UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ SET document_ids = ARRAY[]::uuid[]
+ WHERE id = $1;
+ """
+ await self.connection_manager.execute_query(query, [parent_id])
+
+ async def list_graphs(
+ self,
+ offset: int,
+ limit: int,
+ # filter_user_ids: Optional[list[UUID]] = None,
+ filter_graph_ids: Optional[list[UUID]] = None,
+ filter_collection_id: Optional[UUID] = None,
+ ) -> dict[str, list[GraphResponse] | int]:
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ if filter_graph_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(filter_graph_ids)
+ param_index += 1
+
+ # if filter_user_ids:
+ # conditions.append(f"user_id = ANY(${param_index})")
+ # params.append(filter_user_ids)
+ # param_index += 1
+
+ if filter_collection_id:
+ conditions.append(f"collection_id = ${param_index}")
+ params.append(filter_collection_id)
+ param_index += 1
+
+ where_clause = (
+ f"WHERE {' AND '.join(conditions)}" if conditions else ""
+ )
+
+ query = f"""
+ WITH RankedGraphs AS (
+ SELECT
+ id, collection_id, name, description, status, created_at, updated_at, document_ids,
+ COUNT(*) OVER() as total_entries,
+ ROW_NUMBER() OVER (PARTITION BY collection_id ORDER BY created_at DESC) as rn
+ FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ {where_clause}
+ )
+ SELECT * FROM RankedGraphs
+ WHERE rn = 1
+ ORDER BY created_at DESC
+ OFFSET ${param_index} LIMIT ${param_index + 1}
+ """
+
+ params.extend([offset, limit])
+
+ try:
+ results = await self.connection_manager.fetch_query(query, params)
+ if not results:
+ return {"results": [], "total_entries": 0}
+
+ total_entries = results[0]["total_entries"] if results else 0
+
+ graphs = [
+ GraphResponse(
+ id=row["id"],
+ document_ids=row["document_ids"] or [],
+ name=row["name"],
+ collection_id=row["collection_id"],
+ description=row["description"],
+ status=row["status"],
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ )
+ for row in results
+ ]
+
+ return {"results": graphs, "total_entries": total_entries}
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while fetching graphs: {e}",
+ ) from e
+
+ async def get(
+ self, offset: int, limit: int, graph_id: Optional[UUID] = None
+ ):
+ if graph_id is None:
+ params = [offset, limit]
+
+ QUERY = f"""
+ SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ OFFSET $1 LIMIT $2
+ """
+
+ ret = await self.connection_manager.fetch_query(QUERY, params)
+
+ COUNT_QUERY = f"""
+ SELECT COUNT(*) FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ """
+ count = (await self.connection_manager.fetch_query(COUNT_QUERY))[
+ 0
+ ]["count"]
+
+ return {
+ "results": [Graph(**row) for row in ret],
+ "total_entries": count,
+ }
+
+ else:
+ QUERY = f"""
+ SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} WHERE id = $1
+ """
+
+ params = [graph_id] # type: ignore
+
+ return {
+ "results": [
+ Graph(
+ **await self.connection_manager.fetchrow_query(
+ QUERY, params
+ )
+ )
+ ]
+ }
+
+ async def add_documents(self, id: UUID, document_ids: list[UUID]) -> bool:
+ """Add documents to the graph by copying their entities and
+ relationships."""
+ # Copy entities from documents_entities to graphs_entities
+ ENTITY_COPY_QUERY = f"""
+ INSERT INTO {self._get_table_name("graphs_entities")} (
+ name, category, description, parent_id, description_embedding,
+ chunk_ids, metadata
+ )
+ SELECT
+ name, category, description, $1, description_embedding,
+ chunk_ids, metadata
+ FROM {self._get_table_name("documents_entities")}
+ WHERE parent_id = ANY($2)
+ """
+ await self.connection_manager.execute_query(
+ ENTITY_COPY_QUERY, [id, document_ids]
+ )
+
+ # Copy relationships from documents_relationships to graphs_relationships
+ RELATIONSHIP_COPY_QUERY = f"""
+ INSERT INTO {self._get_table_name("graphs_relationships")} (
+ subject, predicate, object, description, subject_id, object_id,
+ weight, chunk_ids, parent_id, metadata, description_embedding
+ )
+ SELECT
+ subject, predicate, object, description, subject_id, object_id,
+ weight, chunk_ids, $1, metadata, description_embedding
+ FROM {self._get_table_name("documents_relationships")}
+ WHERE parent_id = ANY($2)
+ """
+ await self.connection_manager.execute_query(
+ RELATIONSHIP_COPY_QUERY, [id, document_ids]
+ )
+
+ # Add document_ids to the graph
+ UPDATE_GRAPH_QUERY = f"""
+ UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ SET document_ids = array_cat(
+ CASE
+ WHEN document_ids IS NULL THEN ARRAY[]::uuid[]
+ ELSE document_ids
+ END,
+ $2::uuid[]
+ )
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(
+ UPDATE_GRAPH_QUERY, [id, document_ids]
+ )
+
+ return True
+
+ async def update(
+ self,
+ collection_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> GraphResponse:
+ """Update an existing graph."""
+ update_fields = []
+ params: list = []
+ param_index = 1
+
+ if name is not None:
+ update_fields.append(f"name = ${param_index}")
+ params.append(name)
+ param_index += 1
+
+ if description is not None:
+ update_fields.append(f"description = ${param_index}")
+ params.append(description)
+ param_index += 1
+
+ if not update_fields:
+ raise R2RException(status_code=400, message="No fields to update")
+
+ update_fields.append("updated_at = NOW()")
+ params.append(collection_id)
+
+ query = f"""
+ UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+ SET {", ".join(update_fields)}
+ WHERE id = ${param_index}
+ RETURNING id, name, description, status, created_at, updated_at, collection_id, document_ids
+ """
+
+ try:
+ result = await self.connection_manager.fetchrow_query(
+ query, params
+ )
+
+ if not result:
+ raise R2RException(status_code=404, message="Graph not found")
+
+ return GraphResponse(
+ id=result["id"],
+ collection_id=result["collection_id"],
+ name=result["name"],
+ description=result["description"],
+ status=result["status"],
+ created_at=result["created_at"],
+ document_ids=result["document_ids"] or [],
+ updated_at=result["updated_at"],
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while updating the graph: {e}",
+ ) from e
+
+ async def get_entities(
+ self,
+ parent_id: UUID,
+ offset: int,
+ limit: int,
+ entity_ids: Optional[list[UUID]] = None,
+ entity_names: Optional[list[str]] = None,
+ include_embeddings: bool = False,
+ ) -> tuple[list[Entity], int]:
+ """Get entities for a graph.
+
+ Args:
+ offset: Number of records to skip
+ limit: Maximum number of records to return (-1 for no limit)
+ parent_id: UUID of the collection
+ entity_ids: Optional list of entity IDs to filter by
+ entity_names: Optional list of entity names to filter by
+ include_embeddings: Whether to include embeddings in the response
+
+ Returns:
+ Tuple of (list of entities, total count)
+ """
+ conditions = ["parent_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if entity_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(entity_ids)
+ param_index += 1
+
+ if entity_names:
+ conditions.append(f"name = ANY(${param_index})")
+ params.append(entity_names)
+ param_index += 1
+
+ # Count query - uses the same conditions but without offset/limit
+ COUNT_QUERY = f"""
+ SELECT COUNT(*)
+ FROM {self._get_table_name("graphs_entities")}
+ WHERE {" AND ".join(conditions)}
+ """
+ count = (
+ await self.connection_manager.fetch_query(COUNT_QUERY, params)
+ )[0]["count"]
+
+ # Define base columns to select
+ select_fields = """
+ id, name, category, description, parent_id,
+ chunk_ids, metadata
+ """
+ if include_embeddings:
+ select_fields += ", description_embedding"
+
+ # Main query for fetching entities with pagination
+ QUERY = f"""
+ SELECT {select_fields}
+ FROM {self._get_table_name("graphs_entities")}
+ WHERE {" AND ".join(conditions)}
+ ORDER BY created_at
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ QUERY += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ rows = await self.connection_manager.fetch_query(QUERY, params)
+
+ entities = []
+ for row in rows:
+ entity_dict = dict(row)
+ if isinstance(entity_dict["metadata"], str):
+ with contextlib.suppress(json.JSONDecodeError):
+ entity_dict["metadata"] = json.loads(
+ entity_dict["metadata"]
+ )
+
+ entities.append(Entity(**entity_dict))
+
+ return entities, count
+
+ async def get_relationships(
+ self,
+ parent_id: UUID,
+ offset: int,
+ limit: int,
+ relationship_ids: Optional[list[UUID]] = None,
+ relationship_types: Optional[list[str]] = None,
+ include_embeddings: bool = False,
+ ) -> tuple[list[Relationship], int]:
+ """Get relationships for a graph.
+
+ Args:
+ parent_id: UUID of the graph
+ offset: Number of records to skip
+ limit: Maximum number of records to return (-1 for no limit)
+ relationship_ids: Optional list of relationship IDs to filter by
+ relationship_types: Optional list of relationship types to filter by
+ include_embeddings: Whether to include embeddings in the response
+
+ Returns:
+ Tuple of (list of relationships, total count)
+ """
+ conditions = ["parent_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if relationship_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(relationship_ids)
+ param_index += 1
+
+ if relationship_types:
+ conditions.append(f"predicate = ANY(${param_index})")
+ params.append(relationship_types)
+ param_index += 1
+
+ # Count query - uses the same conditions but without offset/limit
+ COUNT_QUERY = f"""
+ SELECT COUNT(*)
+ FROM {self._get_table_name("graphs_relationships")}
+ WHERE {" AND ".join(conditions)}
+ """
+ count = (
+ await self.connection_manager.fetch_query(COUNT_QUERY, params)
+ )[0]["count"]
+
+ # Define base columns to select
+ select_fields = """
+ id, subject, predicate, object, weight, chunk_ids, parent_id, metadata
+ """
+ if include_embeddings:
+ select_fields += ", description_embedding"
+
+ # Main query for fetching relationships with pagination
+ QUERY = f"""
+ SELECT {select_fields}
+ FROM {self._get_table_name("graphs_relationships")}
+ WHERE {" AND ".join(conditions)}
+ ORDER BY created_at
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ QUERY += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ rows = await self.connection_manager.fetch_query(QUERY, params)
+
+ relationships = []
+ for row in rows:
+ relationship_dict = dict(row)
+ if isinstance(relationship_dict["metadata"], str):
+ with contextlib.suppress(json.JSONDecodeError):
+ relationship_dict["metadata"] = json.loads(
+ relationship_dict["metadata"]
+ )
+
+ relationships.append(Relationship(**relationship_dict))
+
+ return relationships, count
+
+ async def add_entities(
+ self,
+ entities: list[Entity],
+ table_name: str,
+ conflict_columns: list[str] | None = None,
+ ) -> list[UUID]:
+ """Upsert entities into the given entities table. These are raw
+ entities extracted from the document.
+
+ Args:
+ entities: list[Entity]: list of entities to upsert
+ table_name: str: name of the table to upsert into
+ conflict_columns: optional list of columns used for upsert conflict resolution
+
+ Returns:
+ list[UUID]: IDs of the upserted rows
+ """
+ if not conflict_columns:
+ conflict_columns = []
+ cleaned_entities = []
+ for entity in entities:
+ entity_dict = entity.to_dict()
+ entity_dict["chunk_ids"] = (
+ entity_dict["chunk_ids"]
+ if entity_dict.get("chunk_ids")
+ else []
+ )
+ entity_dict["description_embedding"] = (
+ str(entity_dict["description_embedding"])
+ if entity_dict.get("description_embedding") # type: ignore
+ else None
+ )
+ cleaned_entities.append(entity_dict)
+
+ return await _add_objects(
+ objects=cleaned_entities,
+ full_table_name=self._get_table_name(table_name),
+ connection_manager=self.connection_manager,
+ conflict_columns=conflict_columns,
+ )
+
+ async def get_all_relationships(
+ self,
+ collection_id: UUID | None,
+ graph_id: UUID | None,
+ document_ids: Optional[list[UUID]] = None,
+ ) -> list[Relationship]:
+ QUERY = f"""
+ SELECT id, subject, predicate, weight, object, parent_id FROM {self._get_table_name("graphs_relationships")} WHERE parent_id = ANY($1)
+ """
+ relationships = await self.connection_manager.fetch_query(
+ QUERY, [collection_id]
+ )
+
+ return [Relationship(**relationship) for relationship in relationships]
+
+ async def has_document(self, graph_id: UUID, document_id: UUID) -> bool:
+ """Check if a document exists in the graph's document_ids array.
+
+ Args:
+ graph_id (UUID): ID of the graph to check
+ document_id (UUID): ID of the document to look for
+
+ Returns:
+ bool: True if document exists in graph, False otherwise
+
+ Raises:
+ R2RException: If graph not found
+ """
+ QUERY = f"""
+ SELECT EXISTS (
+ SELECT 1
+ FROM {self._get_table_name("graphs")}
+ WHERE id = $1
+ AND document_ids IS NOT NULL
+ AND $2 = ANY(document_ids)
+ ) as exists;
+ """
+
+ result = await self.connection_manager.fetchrow_query(
+ QUERY, [graph_id, document_id]
+ )
+
+ if result is None:
+ raise R2RException(f"Graph {graph_id} not found", 404)
+
+ return result["exists"]
+
+ async def get_communities(
+ self,
+ parent_id: UUID,
+ offset: int,
+ limit: int,
+ community_ids: Optional[list[UUID]] = None,
+ include_embeddings: bool = False,
+ ) -> tuple[list[Community], int]:
+ """Get communities for a graph.
+
+ Args:
+ parent_id: UUID of the collection / graph
+ offset: Number of records to skip
+ limit: Maximum number of records to return (-1 for no limit)
+ community_ids: Optional list of community IDs to filter by
+ include_embeddings: Whether to include embeddings in the response
+
+ Returns:
+ Tuple of (list of communities, total count)
+ """
+ conditions = ["collection_id = $1"]
+ params: list[Any] = [parent_id]
+ param_index = 2
+
+ if community_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(community_ids)
+ param_index += 1
+
+ select_fields = """
+ id, collection_id, name, summary, findings, rating, rating_explanation
+ """
+ if include_embeddings:
+ select_fields += ", description_embedding"
+
+ COUNT_QUERY = f"""
+ SELECT COUNT(*)
+ FROM {self._get_table_name("graphs_communities")}
+ WHERE {" AND ".join(conditions)}
+ """
+ count = (
+ await self.connection_manager.fetch_query(COUNT_QUERY, params)
+ )[0]["count"]
+
+ QUERY = f"""
+ SELECT {select_fields}
+ FROM {self._get_table_name("graphs_communities")}
+ WHERE {" AND ".join(conditions)}
+ ORDER BY created_at
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ QUERY += f" LIMIT ${param_index}"
+ params.append(limit)
+
+ rows = await self.connection_manager.fetch_query(QUERY, params)
+
+ communities = []
+ for row in rows:
+ community_dict = dict(row)
+ communities.append(Community(**community_dict))
+
+ return communities, count
+
+ async def add_community(self, community: Community) -> None:
+ # TODO: Fix in the short term.
+ # We convert the embedding to a string because the Postgres insert path expects a string representation of the vector.
+ community.description_embedding = str(community.description_embedding) # type: ignore[assignment]
+
+ non_null_attrs = {
+ k: v for k, v in community.__dict__.items() if v is not None
+ }
+ columns = ", ".join(non_null_attrs.keys())
+ placeholders = ", ".join(
+ f"${i + 1}" for i in range(len(non_null_attrs))
+ )
+
+ conflict_columns = ", ".join(
+ [f"{k} = EXCLUDED.{k}" for k in non_null_attrs]
+ )
+
+ QUERY = f"""
+ INSERT INTO {self._get_table_name("graphs_communities")} ({columns})
+ VALUES ({placeholders})
+ ON CONFLICT (community_id, level, collection_id) DO UPDATE SET
+ {conflict_columns}
+ """
+
+ await self.connection_manager.execute_many(
+ QUERY, [tuple(non_null_attrs.values())]
+ )
+
+ async def delete(self, collection_id: UUID) -> None:
+ graphs = await self.get(graph_id=collection_id, offset=0, limit=-1)
+
+ if len(graphs["results"]) == 0:
+ raise R2RException(
+ message=f"Graph not found for collection {collection_id}",
+ status_code=404,
+ )
+ await self.reset(collection_id)
+ # set status to PENDING for this collection.
+ QUERY = f"""
+ UPDATE {self._get_table_name("collections")} SET graph_cluster_status = $1 WHERE id = $2
+ """
+ await self.connection_manager.execute_query(
+ QUERY, [GraphExtractionStatus.PENDING, collection_id]
+ )
+ # Delete the graph
+ QUERY = f"""
+ DELETE FROM {self._get_table_name("graphs")} WHERE collection_id = $1
+ """
+ try:
+ await self.connection_manager.execute_query(QUERY, [collection_id])
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"An error occurred while deleting the graph: {e}",
+ ) from e
+
+ async def perform_graph_clustering(
+ self,
+ collection_id: UUID,
+ leiden_params: dict[str, Any],
+ ) -> Tuple[int, Any]:
+ """Calls the external clustering service to cluster the graph."""
+
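+ # Page through all graph relationships page_size rows at a time so that
+ # very large graphs are not loaded in a single query.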
+ offset = 0
+ page_size = 1000
+ all_relationships = []
+ while True:
+ relationships, count = await self.relationships.get(
+ parent_id=collection_id,
+ store_type=StoreType.GRAPHS,
+ offset=offset,
+ limit=page_size,
+ )
+
+ if not relationships:
+ break
+
+ all_relationships.extend(relationships)
+ offset += len(relationships)
+
+ if offset >= count:
+ break
+
+ logger.info(
+ f"Clustering over {len(all_relationships)} relationships for {collection_id} with settings: {leiden_params}"
+ )
+ if len(all_relationships) == 0:
+ raise R2RException(
+ message="No relationships found for clustering",
+ status_code=400,
+ )
+
+ return await self._cluster_and_add_community_info(
+ relationships=all_relationships,
+ leiden_params=leiden_params,
+ collection_id=collection_id,
+ )
+
+ async def _call_clustering_service(
+ self, relationships: list[Relationship], leiden_params: dict[str, Any]
+ ) -> list[dict]:
+ """Calls the external Graspologic clustering service, sending
+ relationships and parameters.
+
+ Expects a response with 'communities' field.
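+
+ Illustrative response shape (only the "cluster" key is consumed
+ downstream; the other keys shown are assumptions):
+ {"communities": [{"node": "Entity A", "cluster": 0}, ...]}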
+ """
+ # Convert relationships to a JSON-friendly format
+ rel_data = []
+ for r in relationships:
+ rel_data.append(
+ {
+ "id": str(r.id),
+ "subject": r.subject,
+ "object": r.object,
+ "weight": r.weight if r.weight is not None else 1.0,
+ }
+ )
+
+ endpoint = os.environ.get("CLUSTERING_SERVICE_URL")
+ if not endpoint:
+ raise ValueError("CLUSTERING_SERVICE_URL not set.")
+
+ url = f"{endpoint}/cluster"
+
+ payload = {"relationships": rel_data, "leiden_params": leiden_params}
+
+ async with httpx.AsyncClient() as client:
+ response = await client.post(url, json=payload, timeout=3600)
+ response.raise_for_status()
+
+ data = response.json()
+ return data.get("communities", [])
+
+ async def _create_graph_and_cluster(
+ self,
+ relationships: list[Relationship],
+ leiden_params: dict[str, Any],
+ ) -> Any:
+ """Create a graph and cluster it."""
+
+ return await self._call_clustering_service(
+ relationships, leiden_params
+ )
+
+ async def _cluster_and_add_community_info(
+ self,
+ relationships: list[Relationship],
+ leiden_params: dict[str, Any],
+ collection_id: UUID,
+ ) -> Tuple[int, Any]:
+ logger.info(f"Creating graph and clustering for {collection_id}")
+
+ await asyncio.sleep(0.1)
+ start_time = time.time()
+
+ hierarchical_communities = await self._create_graph_and_cluster(
+ relationships=relationships,
+ leiden_params=leiden_params,
+ )
+
+ logger.info(
+ f"Computing Leiden communities completed, time {time.time() - start_time:.2f} seconds."
+ )
+
+ if not hierarchical_communities:
+ num_communities = 0
+ else:
+ num_communities = (
+ max(item["cluster"] for item in hierarchical_communities) + 1
+ )
+
+ logger.info(
+ f"Generated {num_communities} communities, time {time.time() - start_time:.2f} seconds."
+ )
+
+ return num_communities, hierarchical_communities
+
+ async def get_entity_map(
+ self, offset: int, limit: int, document_id: UUID
+ ) -> dict[str, dict[str, list[dict[str, Any]]]]:
+ QUERY1 = f"""
+ WITH entities_list AS (
+ SELECT DISTINCT name
+ FROM {self._get_table_name("documents_entities")}
+ WHERE parent_id = $1
+ ORDER BY name ASC
+ LIMIT {limit} OFFSET {offset}
+ )
+ SELECT e.name, e.description, e.category,
+ (SELECT array_agg(DISTINCT x) FROM unnest(e.chunk_ids) x) AS chunk_ids,
+ e.parent_id
+ FROM {self._get_table_name("documents_entities")} e
+ JOIN entities_list el ON e.name = el.name
+ GROUP BY e.name, e.description, e.category, e.chunk_ids, e.parent_id
+ ORDER BY e.name;"""
+
+ entities_list = await self.connection_manager.fetch_query(
+ QUERY1, [document_id]
+ )
+ entities_list = [Entity(**entity) for entity in entities_list]
+
+ QUERY2 = f"""
+ WITH entities_list AS (
+
+ SELECT DISTINCT name
+ FROM {self._get_table_name("documents_entities")}
+ WHERE parent_id = $1
+ ORDER BY name ASC
+ LIMIT {limit} OFFSET {offset}
+ )
+
+ SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description,
+ (SELECT array_agg(DISTINCT x) FROM unnest(t.chunk_ids) x) AS chunk_ids, t.parent_id
+ FROM {self._get_table_name("documents_relationships")} t
+ JOIN entities_list el ON t.subject = el.name
+ ORDER BY t.subject, t.predicate, t.object;
+ """
+
+ relationships_list = await self.connection_manager.fetch_query(
+ QUERY2, [document_id]
+ )
+ relationships_list = [
+ Relationship(**relationship) for relationship in relationships_list
+ ]
+
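+ # Build a map of entity name -> {"entities": [...], "relationships": [...]},
+ # e.g. {"Acme": {"entities": [Entity(...)], "relationships": [Relationship(...)]}}.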
+ entity_map: dict[str, dict[str, list[Any]]] = {}
+ for entity in entities_list:
+ if entity.name not in entity_map:
+ entity_map[entity.name] = {"entities": [], "relationships": []}
+ entity_map[entity.name]["entities"].append(entity)
+
+ for relationship in relationships_list:
+ if relationship.subject in entity_map:
+ entity_map[relationship.subject]["relationships"].append(
+ relationship
+ )
+ if relationship.object in entity_map:
+ entity_map[relationship.object]["relationships"].append(
+ relationship
+ )
+
+ return entity_map
+
+ async def graph_search(
+ self, query: str, **kwargs: Any
+ ) -> AsyncGenerator[Any, None]:
+ """Perform semantic search with similarity scores while maintaining
+ exact same structure."""
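+ # Hypothetical call sketch (argument values are illustrative):
+ # async for hit in graphs_handler.graph_search(
+ # query="acme",
+ # query_embedding=embedding,
+ # search_type="entities",
+ # property_names=["name", "description"],
+ # limit=5,
+ # ):
+ # print(hit["similarity_score"])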
+
+ query_embedding = kwargs.get("query_embedding", None)
+ if query_embedding is None:
+ raise ValueError(
+ "query_embedding must be provided for semantic search"
+ )
+
+ search_type = kwargs.get(
+ "search_type", "entities"
+ ) # entities | relationships | communities
+ embedding_type = kwargs.get("embedding_type", "description_embedding")
+ property_names = kwargs.get("property_names", ["name", "description"])
+
+ # Add metadata if not present
+ if "metadata" not in property_names:
+ property_names.append("metadata")
+
+ filters = kwargs.get("filters", {})
+ limit = kwargs.get("limit", 10)
+ use_fulltext_search = kwargs.get("use_fulltext_search", True)
+ use_hybrid_search = kwargs.get("use_hybrid_search", True)
+
+ if use_hybrid_search or use_fulltext_search:
+ logger.warning(
+ "Hybrid and fulltext search not supported for graph search, ignoring."
+ )
+
+ table_name = f"graphs_{search_type}"
+ property_names_str = ", ".join(property_names)
+
+ # Build the WHERE clause from filters
+ params: list[str | int | bytes] = [
+ json.dumps(query_embedding),
+ limit,
+ ]
+ conditions_clause = self._build_filters(filters, params, search_type)
+ where_clause = (
+ f"WHERE {conditions_clause}" if conditions_clause else ""
+ )
+
+ # Construct the query
+ # Note: For vector similarity, we use <=> for distance. The smaller the number, the more similar.
+ # We'll convert that to similarity_score by doing (1 - distance).
+ QUERY = f"""
+ SELECT
+ {property_names_str},
+ ({embedding_type} <=> $1) as similarity_score
+ FROM {self._get_table_name(table_name)}
+ {where_clause}
+ ORDER BY {embedding_type} <=> $1
+ LIMIT $2;
+ """
+
+ results = await self.connection_manager.fetch_query(
+ QUERY, tuple(params)
+ )
+
+ for result in results:
+ output = {
+ prop: result[prop] for prop in property_names if prop in result
+ }
+ output["similarity_score"] = (
+ 1 - float(result["similarity_score"])
+ if result.get("similarity_score")
+ else "n/a"
+ )
+ yield output
+
+ def _build_filters(
+ self, filter_dict: dict, parameters: list[Any], search_type: str
+ ) -> str:
+ """Build a WHERE clause from a nested filter dictionary for the graph
+ search.
+
+ - If search_type == "communities", we normally filter by `collection_id`.
+ - Otherwise (entities/relationships), we normally filter by `parent_id`.
+ - If user provides `"collection_ids": {...}`, we interpret that as wanting
+ to filter by multiple collection IDs (i.e. 'parent_id IN (...)' or
+ 'collection_id IN (...)').
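+
+ Example filter (illustrative field and values):
+ {"$and": [{"collection_ids": {"$overlap": ["<uuid>"]}},
+ {"metadata.source": {"$eq": "web"}}]}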
+ """
+
+ # The usual "base" column used by your code
+ base_id_column = (
+ "collection_id" if search_type == "communities" else "parent_id"
+ )
+
+ def parse_condition(key: str, value: Any) -> str:
+ # ----------------------------------------------------------------------
+ # 1) If it's the normal base_id_column (like "parent_id" or "collection_id")
+ # ----------------------------------------------------------------------
+ if key == base_id_column:
+ if isinstance(value, dict):
+ op, clause = next(iter(value.items()))
+ if op == "$eq":
+ # single equality
+ parameters.append(str(clause))
+ return f"{base_id_column} = ${len(parameters)}::uuid"
+ elif op in ("$in", "$overlap"):
+ # treat both $in/$overlap as "IN the set" for a single column
+ array_val = [str(x) for x in clause]
+ parameters.append(array_val)
+ return f"{base_id_column} = ANY(${len(parameters)}::uuid[])"
+ # handle other operators as needed
+ else:
+ # direct equality
+ parameters.append(str(value))
+ return f"{base_id_column} = ${len(parameters)}::uuid"
+
+ # ----------------------------------------------------------------------
+ # 2) SPECIAL: if user specifically sets "collection_ids" in filters
+ # We interpret that to mean "Look for rows whose parent_id (or collection_id)
+ # is in the array of values" – i.e. we do the same logic but we forcibly
+ # direct it to the same column: parent_id or collection_id.
+ # ----------------------------------------------------------------------
+ elif key == "collection_ids":
+ # If we are searching communities, the relevant field is `collection_id`.
+ # If searching entities/relationships, the relevant field is `parent_id`.
+ col_to_use = (
+ "collection_id"
+ if search_type == "communities"
+ else "parent_id"
+ )
+
+ if isinstance(value, dict):
+ op, clause = next(iter(value.items()))
+ if op == "$eq":
+ # single equality => col_to_use = clause
+ parameters.append(str(clause))
+ return f"{col_to_use} = ${len(parameters)}::uuid"
+ elif op in ("$in", "$overlap"):
+ # "col_to_use = ANY($param::uuid[])"
+ array_val = [str(x) for x in clause]
+ parameters.append(array_val)
+ return (
+ f"{col_to_use} = ANY(${len(parameters)}::uuid[])"
+ )
+ # add more if you want, e.g. $ne, $gt, etc.
+ else:
+ # direct equality scenario: "collection_ids": "some-uuid"
+ parameters.append(str(value))
+ return f"{col_to_use} = ${len(parameters)}::uuid"
+
+ # ----------------------------------------------------------------------
+ # 3) If key starts with "metadata.", handle metadata-based filters
+ # ----------------------------------------------------------------------
+ elif key.startswith("metadata."):
+ field = key.split("metadata.")[1]
+ if isinstance(value, dict):
+ op, clause = next(iter(value.items()))
+ if op == "$eq":
+ parameters.append(clause)
+ return f"(metadata->>'{field}') = ${len(parameters)}"
+ elif op == "$ne":
+ parameters.append(clause)
+ return f"(metadata->>'{field}') != ${len(parameters)}"
+ elif op == "$gt":
+ parameters.append(clause)
+ return f"(metadata->>'{field}')::float > ${len(parameters)}::float"
+ # etc...
+ else:
+ parameters.append(value)
+ return f"(metadata->>'{field}') = ${len(parameters)}"
+
+ # ----------------------------------------------------------------------
+ # 4) Not recognized => return empty so we skip it
+ # ----------------------------------------------------------------------
+ return ""
+
+ # --------------------------------------------------------------------------
+ # 5) parse_filter() is the recursive walker that sees $and/$or or normal fields
+ # --------------------------------------------------------------------------
+ def parse_filter(fd: dict) -> str:
+ filter_conditions = []
+ for k, v in fd.items():
+ if k == "$and":
+ and_parts = [parse_filter(sub) for sub in v if sub]
+ and_parts = [x for x in and_parts if x.strip()]
+ if and_parts:
+ filter_conditions.append(
+ f"({' AND '.join(and_parts)})"
+ )
+ elif k == "$or":
+ or_parts = [parse_filter(sub) for sub in v if sub]
+ or_parts = [x for x in or_parts if x.strip()]
+ if or_parts:
+ filter_conditions.append(f"({' OR '.join(or_parts)})")
+ else:
+ c = parse_condition(k, v)
+ if c and c.strip():
+ filter_conditions.append(c)
+
+ if not filter_conditions:
+ return ""
+ if len(filter_conditions) == 1:
+ return filter_conditions[0]
+ return " AND ".join(filter_conditions)
+
+ return parse_filter(filter_dict)
+
+ async def get_existing_document_entity_chunk_ids(
+ self, document_id: UUID
+ ) -> list[str]:
+ QUERY = f"""
+ SELECT DISTINCT unnest(chunk_ids) AS chunk_id FROM {self._get_table_name("documents_entities")} WHERE parent_id = $1
+ """
+ return [
+ item["chunk_id"]
+ for item in await self.connection_manager.fetch_query(
+ QUERY, [document_id]
+ )
+ ]
+
+ async def get_entity_count(
+ self,
+ collection_id: Optional[UUID] = None,
+ document_id: Optional[UUID] = None,
+ distinct: bool = False,
+ entity_table_name: str = "entity",
+ ) -> int:
+ if collection_id is None and document_id is None:
+ raise ValueError(
+ "Either collection_id or document_id must be provided."
+ )
+
+ conditions = ["parent_id = $1"]
+ params = [str(document_id)]
+
+ count_value = "DISTINCT name" if distinct else "*"
+
+ QUERY = f"""
+ SELECT COUNT({count_value}) FROM {self._get_table_name(entity_table_name)}
+ WHERE {" AND ".join(conditions)}
+ """
+
+ return (await self.connection_manager.fetch_query(QUERY, params))[0][
+ "count"
+ ]
+
+ async def update_entity_descriptions(self, entities: list[Entity]):
+ query = f"""
+ UPDATE {self._get_table_name("graphs_entities")}
+ SET description = $3, description_embedding = $4
+ WHERE name = $1 AND graph_id = $2
+ """
+
+ inputs = [
+ (
+ entity.name,
+ entity.parent_id,
+ entity.description,
+ entity.description_embedding,
+ )
+ for entity in entities
+ ]
+
+ await self.connection_manager.execute_many(query, inputs) # type: ignore
+
+
+def _json_serialize(obj):
+ if isinstance(obj, UUID):
+ return str(obj)
+ elif isinstance(obj, (datetime.datetime, datetime.date)):
+ return obj.isoformat()
+ raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
+
+
+async def _add_objects(
+ objects: list[dict],
+ full_table_name: str,
+ connection_manager: PostgresConnectionManager,
+ conflict_columns: list[str] | None = None,
+ exclude_metadata: list[str] | None = None,
+) -> list[UUID]:
+ """Bulk insert objects into the specified table using
+ jsonb_to_recordset."""
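+ # Illustrative shape of the generated statement (column names assumed):
+ # INSERT INTO <table> (id, name, metadata)
+ # SELECT id, name, metadata
+ # FROM jsonb_to_recordset($1::jsonb) AS x(id uuid, name text, metadata jsonb)
+ # ON CONFLICT (<conflict_columns>) DO UPDATE SET ... -- only when conflict_columns are given
+ # RETURNING id;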
+
+ if conflict_columns is None:
+ conflict_columns = []
+ if exclude_metadata is None:
+ exclude_metadata = []
+
+ # Exclude specified metadata and prepare data
+ cleaned_objects = []
+ for obj in objects:
+ cleaned_obj = {
+ k: v
+ for k, v in obj.items()
+ if k not in exclude_metadata and v is not None
+ }
+ cleaned_objects.append(cleaned_obj)
+
+ # Serialize the list of objects to JSON
+ json_data = json.dumps(cleaned_objects, default=_json_serialize)
+
+ # Prepare the column definitions for jsonb_to_recordset
+
+ columns = cleaned_objects[0].keys()
+ column_defs = []
+ for col in columns:
+ # Map Python types to PostgreSQL types
+ sample_value = cleaned_objects[0][col]
+ if "embedding" in col:
+ pg_type = "vector"
+ elif "chunk_ids" in col or "document_ids" in col or "graph_ids" in col:
+ pg_type = "uuid[]"
+ elif col == "id" or "_id" in col:
+ pg_type = "uuid"
+ elif isinstance(sample_value, str):
+ pg_type = "text"
+ elif isinstance(sample_value, UUID):
+ pg_type = "uuid"
+ elif isinstance(sample_value, (int, float)):
+ pg_type = "numeric"
+ elif isinstance(sample_value, list) and all(
+ isinstance(x, UUID) for x in sample_value
+ ):
+ pg_type = "uuid[]"
+ elif isinstance(sample_value, list):
+ pg_type = "jsonb"
+ elif isinstance(sample_value, dict):
+ pg_type = "jsonb"
+ elif isinstance(sample_value, bool):
+ pg_type = "boolean"
+ elif isinstance(sample_value, (datetime.datetime, datetime.date)):
+ pg_type = "timestamp"
+ else:
+ raise TypeError(
+ f"Unsupported data type for column '{col}': {type(sample_value)}"
+ )
+
+ column_defs.append(f"{col} {pg_type}")
+
+ columns_str = ", ".join(columns)
+ column_defs_str = ", ".join(column_defs)
+
+ if conflict_columns:
+ conflict_columns_str = ", ".join(conflict_columns)
+ update_columns_str = ", ".join(
+ f"{col}=EXCLUDED.{col}"
+ for col in columns
+ if col not in conflict_columns
+ )
+ on_conflict_clause = f"ON CONFLICT ({conflict_columns_str}) DO UPDATE SET {update_columns_str}"
+ else:
+ on_conflict_clause = ""
+
+ QUERY = f"""
+ INSERT INTO {full_table_name} ({columns_str})
+ SELECT {columns_str}
+ FROM jsonb_to_recordset($1::jsonb)
+ AS x({column_defs_str})
+ {on_conflict_clause}
+ RETURNING id;
+ """
+
+ # Execute the query
+ result = await connection_manager.fetch_query(QUERY, [json_data])
+
+ # Extract and return the IDs
+ return [record["id"] for record in result]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/limits.py b/.venv/lib/python3.12/site-packages/core/providers/database/limits.py
new file mode 100644
index 00000000..1029ec50
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/limits.py
@@ -0,0 +1,434 @@
+import logging
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+from uuid import UUID
+
+from core.base import Handler
+from shared.abstractions import User
+
+from ...base.providers.database import DatabaseConfig, LimitSettings
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger(__name__)
+
+
+class PostgresLimitsHandler(Handler):
+ TABLE_NAME = "request_log"
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ config: DatabaseConfig,
+ ):
+ """
+ :param config: The global DatabaseConfig with default rate limits.
+ """
+ super().__init__(project_name, connection_manager)
+ self.config = config
+
+ logger.debug(
+ f"Initialized PostgresLimitsHandler with project: {project_name}"
+ )
+
+ async def create_tables(self):
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
+ time TIMESTAMPTZ NOT NULL,
+ user_id UUID NOT NULL,
+ route TEXT NOT NULL
+ );
+ """
+ logger.debug("Creating request_log table if not exists")
+ await self.connection_manager.execute_query(query)
+
+ async def _count_requests(
+ self,
+ user_id: UUID,
+ route: Optional[str],
+ since: datetime,
+ ) -> int:
+ """Count how many requests a user (optionally for a specific route) has
+ made since the given datetime."""
+ if route:
+ query = f"""
+ SELECT COUNT(*)::int
+ FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+ WHERE user_id = $1
+ AND route = $2
+ AND time >= $3
+ """
+ params = [user_id, route, since]
+ logger.debug(
+ f"Counting requests for user={user_id}, route={route}"
+ )
+ else:
+ query = f"""
+ SELECT COUNT(*)::int
+ FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+ WHERE user_id = $1
+ AND time >= $2
+ """
+ params = [user_id, since]
+ logger.debug(f"Counting all requests for user={user_id}")
+
+ result = await self.connection_manager.fetchrow_query(query, params)
+ return result["count"] if result else 0
+
+ async def _count_monthly_requests(
+ self,
+ user_id: UUID,
+ route: Optional[str] = None,
+ ) -> int:
+ """Count the number of requests so far this month for a given user.
+
+ If route is provided, count only for that route. Otherwise, count
+ globally.
+ """
+ now = datetime.now(timezone.utc)
+ start_of_month = now.replace(
+ day=1, hour=0, minute=0, second=0, microsecond=0
+ )
+ return await self._count_requests(
+ user_id, route=route, since=start_of_month
+ )
+
+ def determine_effective_limits(
+ self, user: User, route: str
+ ) -> LimitSettings:
+ """
+ Determine the final effective limits for a user+route combination,
+ respecting:
+ 1) Global defaults
+ 2) Route-specific overrides
+ 3) User-level overrides
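+
+ Example user.limits_overrides (route path and values are illustrative):
+ {"global_per_min": 60,
+ "monthly_limit": 10000,
+ "route_overrides": {"/v3/retrieval/search": {"route_per_min": 20}}}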
+ """
+ # ------------------------
+ # 1) Start with global/base
+ # ------------------------
+ base_limits = self.config.limits
+
+ # We’ll make a copy so we don’t mutate self.config.limits directly
+ effective = LimitSettings(
+ global_per_min=base_limits.global_per_min,
+ route_per_min=base_limits.route_per_min,
+ monthly_limit=base_limits.monthly_limit,
+ )
+
+ # ------------------------
+ # 2) Route-level overrides
+ # ------------------------
+ route_config = self.config.route_limits.get(route)
+ if route_config:
+ if route_config.global_per_min is not None:
+ effective.global_per_min = route_config.global_per_min
+ if route_config.route_per_min is not None:
+ effective.route_per_min = route_config.route_per_min
+ if route_config.monthly_limit is not None:
+ effective.monthly_limit = route_config.monthly_limit
+
+ # ------------------------
+ # 3) User-level overrides
+ # ------------------------
+ # The user object might have a dictionary of overrides
+ # which can include route_overrides, global_per_min, monthly_limit, etc.
+ user_overrides = user.limits_overrides or {}
+
+ # (a) "global" user overrides
+ if user_overrides.get("global_per_min") is not None:
+ effective.global_per_min = user_overrides["global_per_min"]
+ if user_overrides.get("monthly_limit") is not None:
+ effective.monthly_limit = user_overrides["monthly_limit"]
+
+ # (b) route-level user overrides
+ route_overrides = user_overrides.get("route_overrides", {})
+ specific_config = route_overrides.get(route, {})
+ if specific_config.get("global_per_min") is not None:
+ effective.global_per_min = specific_config["global_per_min"]
+ if specific_config.get("route_per_min") is not None:
+ effective.route_per_min = specific_config["route_per_min"]
+ if specific_config.get("monthly_limit") is not None:
+ effective.monthly_limit = specific_config["monthly_limit"]
+
+ return effective
+
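+    # Illustrative sketch (an assumption, not part of the original module): the
+    # user-level overrides consumed in determine_effective_limits above are
+    # expected to be shaped roughly like this, with keys matching the .get()
+    # calls in that method. Route names and numeric values are hypothetical.
+    #
+    #     user.limits_overrides = {
+    #         "global_per_min": 60,
+    #         "monthly_limit": 10_000,
+    #         "route_overrides": {
+    #             "/v3/retrieval/search": {"route_per_min": 20},
+    #         },
+    #     }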
+ async def check_limits(self, user: User, route: str):
+ """Perform rate limit checks for a user on a specific route.
+
+ :param user: The fully-fetched User object with .limits_overrides, etc.
+ :param route: The route/path being accessed.
+ :raises ValueError: if any limit is exceeded.
+ """
+ user_id = user.id
+ now = datetime.now(timezone.utc)
+ one_min_ago = now - timedelta(minutes=1)
+
+ # 1) Compute the final (effective) limits for this user & route
+ limits = self.determine_effective_limits(user, route)
+
+ # 2) Check each of them in turn, if they exist
+ # ------------------------------------------------------------
+ # Global per-minute limit
+ # ------------------------------------------------------------
+ if limits.global_per_min is not None:
+ user_req_count = await self._count_requests(
+ user_id, None, one_min_ago
+ )
+ if user_req_count > limits.global_per_min:
+ logger.warning(
+ f"Global per-minute limit exceeded for "
+ f"user_id={user_id}, route={route}"
+ )
+ raise ValueError("Global per-minute rate limit exceeded")
+
+ # ------------------------------------------------------------
+ # Route-specific per-minute limit
+ # ------------------------------------------------------------
+ if limits.route_per_min is not None:
+ route_req_count = await self._count_requests(
+ user_id, route, one_min_ago
+ )
+ if route_req_count > limits.route_per_min:
+ logger.warning(
+ f"Per-route per-minute limit exceeded for "
+ f"user_id={user_id}, route={route}"
+ )
+ raise ValueError("Per-route per-minute rate limit exceeded")
+
+ # ------------------------------------------------------------
+ # Monthly limit
+ # ------------------------------------------------------------
+ if limits.monthly_limit is not None:
+            # Passing 'route' enforces a per-route monthly limit; passing None
+            # here would enforce a single global monthly limit instead.
+ monthly_count = await self._count_monthly_requests(user_id, route)
+ if monthly_count > limits.monthly_limit:
+ logger.warning(
+ f"Monthly limit exceeded for user_id={user_id}, "
+ f"route={route}"
+ )
+ raise ValueError("Monthly rate limit exceeded")
+
+ async def log_request(self, user_id: UUID, route: str):
+ """Log a successful request to the request_log table."""
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+ (time, user_id, route)
+ VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
+ """
+ await self.connection_manager.execute_query(query, [user_id, route])
+
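+# Minimal usage sketch (illustrative assumption, not part of the original
+# module). It presumes an initialized PostgresConnectionManager, a
+# DatabaseConfig with populated limits, and an already-fetched User; the
+# route string below is a hypothetical example.
+#
+#     limits = PostgresLimitsHandler(
+#         project_name="r2r_default",
+#         connection_manager=connection_manager,
+#         config=db_config,
+#     )
+#     await limits.create_tables()
+#     try:
+#         await limits.check_limits(user, route="/v3/retrieval/search")
+#     except ValueError:
+#         ...  # e.g. surface as HTTP 429 at the API layer
+#     else:
+#         await limits.log_request(user.id, "/v3/retrieval/search")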
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py b/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py
new file mode 100644
index 00000000..acccc9c0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py
@@ -0,0 +1,286 @@
+# TODO: Clean this up and make it more congruent across the vector database and the relational database.
+import logging
+import os
+from typing import TYPE_CHECKING, Any, Optional
+
+from ...base.abstractions import VectorQuantizationType
+from ...base.providers import (
+ DatabaseConfig,
+ DatabaseProvider,
+ PostgresConfigurationSettings,
+)
+from .base import PostgresConnectionManager, SemaphoreConnectionPool
+from .chunks import PostgresChunksHandler
+from .collections import PostgresCollectionsHandler
+from .conversations import PostgresConversationsHandler
+from .documents import PostgresDocumentsHandler
+from .files import PostgresFilesHandler
+from .graphs import (
+ PostgresCommunitiesHandler,
+ PostgresEntitiesHandler,
+ PostgresGraphsHandler,
+ PostgresRelationshipsHandler,
+)
+from .limits import PostgresLimitsHandler
+from .prompts_handler import PostgresPromptsHandler
+from .tokens import PostgresTokensHandler
+from .users import PostgresUserHandler
+
+if TYPE_CHECKING:
+ from ..crypto import BCryptCryptoProvider, NaClCryptoProvider
+
+ CryptoProviderType = BCryptCryptoProvider | NaClCryptoProvider
+
+logger = logging.getLogger()
+
+
+class PostgresDatabaseProvider(DatabaseProvider):
+ # R2R configuration settings
+ config: DatabaseConfig
+ project_name: str
+
+ # Postgres connection settings
+ user: str
+ password: str
+ host: str
+ port: int
+ db_name: str
+ connection_string: str
+ dimension: int | float
+ conn: Optional[Any]
+
+ crypto_provider: "CryptoProviderType"
+ postgres_configuration_settings: PostgresConfigurationSettings
+ default_collection_name: str
+ default_collection_description: str
+
+ connection_manager: PostgresConnectionManager
+ documents_handler: PostgresDocumentsHandler
+ collections_handler: PostgresCollectionsHandler
+ token_handler: PostgresTokensHandler
+ users_handler: PostgresUserHandler
+ chunks_handler: PostgresChunksHandler
+ entities_handler: PostgresEntitiesHandler
+ communities_handler: PostgresCommunitiesHandler
+ relationships_handler: PostgresRelationshipsHandler
+ graphs_handler: PostgresGraphsHandler
+ prompts_handler: PostgresPromptsHandler
+ files_handler: PostgresFilesHandler
+ conversations_handler: PostgresConversationsHandler
+ limits_handler: PostgresLimitsHandler
+
+ def __init__(
+ self,
+ config: DatabaseConfig,
+ dimension: int | float,
+ crypto_provider: "BCryptCryptoProvider | NaClCryptoProvider",
+ quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(config)
+
+ env_vars = [
+ ("user", "R2R_POSTGRES_USER"),
+ ("password", "R2R_POSTGRES_PASSWORD"),
+ ("host", "R2R_POSTGRES_HOST"),
+ ("port", "R2R_POSTGRES_PORT"),
+ ("db_name", "R2R_POSTGRES_DBNAME"),
+ ]
+
+ for attr, env_var in env_vars:
+ if value := (getattr(config, attr) or os.getenv(env_var)):
+ setattr(self, attr, value)
+ else:
+ raise ValueError(
+ f"Error, please set a valid {env_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`."
+ )
+
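+        # For example (hypothetical values): with R2R_POSTGRES_USER=postgres,
+        # R2R_POSTGRES_PASSWORD=secret, R2R_POSTGRES_HOST=localhost,
+        # R2R_POSTGRES_PORT=5432 and R2R_POSTGRES_DBNAME=r2r exported in the
+        # environment, the loop above fills in any attribute that is not
+        # already present on the DatabaseConfig.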
+ self.port = int(self.port)
+
+ self.project_name = (
+ config.app.project_name
+ or os.getenv("R2R_PROJECT_NAME")
+ or "r2r_default"
+ )
+
+ if not self.project_name:
+ raise ValueError(
+ "Error, please set a valid R2R_PROJECT_NAME environment variable or set a 'project_name' in the 'database' settings of your `r2r.toml`."
+ )
+
+ # Check if it's a Unix socket connection
+ if self.host.startswith("/") and not self.port:
+ self.connection_string = f"postgresql://{self.user}:{self.password}@/{self.db_name}?host={self.host}"
+ logger.info("Connecting to Postgres via Unix socket")
+ else:
+ self.connection_string = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}"
+ logger.info("Connecting to Postgres via TCP/IP")
+
+ self.dimension = dimension
+ self.quantization_type = quantization_type
+ self.conn = None
+ self.config: DatabaseConfig = config
+ self.crypto_provider = crypto_provider
+ self.postgres_configuration_settings: PostgresConfigurationSettings = (
+ self._get_postgres_configuration_settings(config)
+ )
+ self.default_collection_name = config.default_collection_name
+ self.default_collection_description = (
+ config.default_collection_description
+ )
+
+ self.connection_manager: PostgresConnectionManager = (
+ PostgresConnectionManager()
+ )
+ self.documents_handler = PostgresDocumentsHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ dimension=self.dimension,
+ )
+ self.token_handler = PostgresTokensHandler(
+ self.project_name, self.connection_manager
+ )
+ self.collections_handler = PostgresCollectionsHandler(
+ self.project_name, self.connection_manager, self.config
+ )
+ self.users_handler = PostgresUserHandler(
+ self.project_name, self.connection_manager, self.crypto_provider
+ )
+ self.chunks_handler = PostgresChunksHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ dimension=self.dimension,
+ quantization_type=(self.quantization_type),
+ )
+ self.conversations_handler = PostgresConversationsHandler(
+ self.project_name, self.connection_manager
+ )
+ self.entities_handler = PostgresEntitiesHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ collections_handler=self.collections_handler,
+ dimension=self.dimension,
+ quantization_type=self.quantization_type,
+ )
+ self.relationships_handler = PostgresRelationshipsHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ collections_handler=self.collections_handler,
+ dimension=self.dimension,
+ quantization_type=self.quantization_type,
+ )
+ self.communities_handler = PostgresCommunitiesHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ collections_handler=self.collections_handler,
+ dimension=self.dimension,
+ quantization_type=self.quantization_type,
+ )
+ self.graphs_handler = PostgresGraphsHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ collections_handler=self.collections_handler,
+ dimension=self.dimension,
+ quantization_type=self.quantization_type,
+ )
+ self.prompts_handler = PostgresPromptsHandler(
+ self.project_name, self.connection_manager
+ )
+ self.files_handler = PostgresFilesHandler(
+ self.project_name, self.connection_manager
+ )
+
+ self.limits_handler = PostgresLimitsHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ config=self.config,
+ )
+
+ async def initialize(self):
+ logger.info("Initializing `PostgresDatabaseProvider`.")
+ self.pool = SemaphoreConnectionPool(
+ self.connection_string, self.postgres_configuration_settings
+ )
+ await self.pool.initialize()
+ await self.connection_manager.initialize(self.pool)
+
+ async with self.pool.get_connection() as conn:
+ await conn.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
+ await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
+ await conn.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
+ await conn.execute("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;")
+
+ # Create schema if it doesn't exist
+ await conn.execute(
+ f'CREATE SCHEMA IF NOT EXISTS "{self.project_name}";'
+ )
+
+ await self.documents_handler.create_tables()
+ await self.collections_handler.create_tables()
+ await self.token_handler.create_tables()
+ await self.users_handler.create_tables()
+ await self.chunks_handler.create_tables()
+ await self.prompts_handler.create_tables()
+ await self.files_handler.create_tables()
+ await self.graphs_handler.create_tables()
+ await self.communities_handler.create_tables()
+ await self.entities_handler.create_tables()
+ await self.relationships_handler.create_tables()
+ await self.conversations_handler.create_tables()
+ await self.limits_handler.create_tables()
+
+ def _get_postgres_configuration_settings(
+ self, config: DatabaseConfig
+ ) -> PostgresConfigurationSettings:
+ settings = PostgresConfigurationSettings()
+
+ env_mapping = {
+ "checkpoint_completion_target": "R2R_POSTGRES_CHECKPOINT_COMPLETION_TARGET",
+ "default_statistics_target": "R2R_POSTGRES_DEFAULT_STATISTICS_TARGET",
+ "effective_cache_size": "R2R_POSTGRES_EFFECTIVE_CACHE_SIZE",
+ "effective_io_concurrency": "R2R_POSTGRES_EFFECTIVE_IO_CONCURRENCY",
+ "huge_pages": "R2R_POSTGRES_HUGE_PAGES",
+ "maintenance_work_mem": "R2R_POSTGRES_MAINTENANCE_WORK_MEM",
+ "min_wal_size": "R2R_POSTGRES_MIN_WAL_SIZE",
+ "max_connections": "R2R_POSTGRES_MAX_CONNECTIONS",
+ "max_parallel_workers_per_gather": "R2R_POSTGRES_MAX_PARALLEL_WORKERS_PER_GATHER",
+ "max_parallel_workers": "R2R_POSTGRES_MAX_PARALLEL_WORKERS",
+ "max_parallel_maintenance_workers": "R2R_POSTGRES_MAX_PARALLEL_MAINTENANCE_WORKERS",
+ "max_wal_size": "R2R_POSTGRES_MAX_WAL_SIZE",
+ "max_worker_processes": "R2R_POSTGRES_MAX_WORKER_PROCESSES",
+ "random_page_cost": "R2R_POSTGRES_RANDOM_PAGE_COST",
+ "statement_cache_size": "R2R_POSTGRES_STATEMENT_CACHE_SIZE",
+ "shared_buffers": "R2R_POSTGRES_SHARED_BUFFERS",
+ "wal_buffers": "R2R_POSTGRES_WAL_BUFFERS",
+ "work_mem": "R2R_POSTGRES_WORK_MEM",
+ }
+
+ for setting, env_var in env_mapping.items():
+ value = getattr(
+ config.postgres_configuration_settings, setting, None
+ )
+ if value is None:
+ value = os.getenv(env_var)
+
+ if value is not None:
+ field_type = settings.__annotations__[setting]
+ if field_type == Optional[int]:
+ value = int(value)
+ elif field_type == Optional[float]:
+ value = float(value)
+
+ setattr(settings, setting, value)
+
+ return settings
+
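+    # Coercion example (illustrative assumption): if the config leaves
+    # max_connections unset and R2R_POSTGRES_MAX_CONNECTIONS="256" is exported,
+    # the loop above assigns settings.max_connections = 256, because the field
+    # is annotated as Optional[int].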
+ async def close(self):
+ if self.pool:
+ await self.pool.close()
+
+ async def __aenter__(self):
+ await self.initialize()
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.close()
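+# Illustrative construction sketch (an assumption, not part of the original
+# module); the config, crypto provider and dimension below are hypothetical.
+#
+#     provider = PostgresDatabaseProvider(
+#         config=database_config,           # DatabaseConfig with Postgres credentials
+#         dimension=768,                    # embedding dimension used by the handlers
+#         crypto_provider=crypto_provider,  # BCryptCryptoProvider | NaClCryptoProvider
+#     )
+#     async with provider:                  # calls initialize() / close()
+#         await provider.limits_handler.check_limits(user, route)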
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/__init__.py
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/chunk_enrichment.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/chunk_enrichment.yaml
new file mode 100644
index 00000000..7e4a2615
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/chunk_enrichment.yaml
@@ -0,0 +1,56 @@
+chunk_enrichment:
+ template: >
+ ## Task:
+
+ Enrich and refine the given chunk of text while maintaining its independence and precision.
+
+ ## Context:
+ Document Summary: {document_summary}
+ Preceding Chunks: {preceding_chunks}
+ Succeeding Chunks: {succeeding_chunks}
+
+ ## Input Chunk:
+ {chunk}
+
+ ## Semantic Organization Guidelines:
+ 1. Group related information:
+ - Combine logically connected data points
+ - Maintain context within each grouping
+ - Preserve relationships between entities
+
+ 2. Structure hierarchy:
+ - Organize from general to specific
+ - Use clear categorical divisions
+ - Maintain parent-child relationships
+
+ 3. Information density:
+ - Balance completeness with clarity
+ - Ensure each chunk can stand alone
+ - Preserve essential context
+
+ 4. Pattern recognition:
+ - Standardize similar information
+ - Use consistent formatting for similar data types
+       - It is appropriate to restructure tables or lists in ways that are more advantageous for semantic matching
+ - Maintain searchable patterns
+
+ ## Output Requirements:
+ 1. Each chunk should be independently meaningful
+ 2. Related information should stay together
+ 3. Format should support efficient matching
+ 4. Original data relationships must be preserved
+ 5. Context should be clear without external references
+
+ Maximum length: {chunk_size} characters
+
+ Output the restructured chunk only.
+
+ ## Restructured Chunk:
+
+ input_types:
+ document_summary: str
+ chunk: str
+ preceding_chunks: str
+ succeeding_chunks: str
+ chunk_size: int
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/collection_summary.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/collection_summary.yaml
new file mode 100644
index 00000000..b9475453
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/collection_summary.yaml
@@ -0,0 +1,41 @@
+collection_summary:
+ template: >
+ ## Task:
+
+ Generate a comprehensive collection-level summary that describes the overall content, themes, and relationships across multiple documents. The summary should provide a high-level understanding of what the collection contains and represents.
+
+ ### Input Documents:
+
+ Document Summaries:
+ {document_summaries}
+
+ ### Requirements:
+
+ 1. SCOPE
+ - Synthesize key themes and patterns across all documents
+ - Identify common topics, entities, and relationships
+ - Capture the collection's overall purpose or domain
+
+ 2. STRUCTURE
+ - Target length: Approximately 3-4 concise sentences
+ - Focus on collective insights rather than individual document details
+
+ 3. CONTENT GUIDELINES
+ - Emphasize shared concepts and recurring elements
+ - Highlight any temporal or thematic progression
+ - Identify key stakeholders or entities that appear across documents
+ - Note any significant relationships between documents
+
+ 4. INTEGRATION PRINCIPLES
+ - Connect related concepts across different documents
+ - Identify overarching narratives or frameworks
+ - Preserve important context from individual documents
+ - Balance breadth of coverage with depth of insight
+
+ ### Query:
+
+ Generate a collection-level summary following the above requirements. Focus on synthesizing the key themes and relationships across all documents while maintaining clarity and concision.
+
+ ## Response:
+ input_types:
+ document_summaries: str
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent.yaml
new file mode 100644
index 00000000..5b264530
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent.yaml
@@ -0,0 +1,28 @@
+dynamic_rag_agent:
+ template: >
+    ### You are a helpful agent that can search for information; the date is {date}.
+
+
+    The response should contain line-item attributions to relevant search results and be as informative as possible. Note that you will only be able to load {max_tool_context_length} tokens of context at a time; if the context surpasses this, it will be truncated. If possible, set filters that reduce the returned context to only what is specifically relevant, by means of '$eq' or '$overlap' filters.
+
+
+    Search rarely exceeds the context window, while getting raw context can, depending on the user data shown below. IF YOU CAN FETCH THE RAW CONTEXT, THEN DO SO.
+
+
+ The available user documents and collections are shown below:
+
+ <= Documents =>
+ {document_context}
+
+
+ If no relevant results are found, then state that no results were found. If no obvious question is present given the available tools and context, then do not carry out a search, and instead ask for clarification.
+
+
+    REMINDER - Use line-item references like [c910e2e], [b12cd2f] to refer to the specific search result IDs returned in the provided context.
+
+ input_types:
+ date: str
+ document_context: str
+ max_tool_context_length: str
+
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent_xml_tooling.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent_xml_tooling.yaml
new file mode 100644
index 00000000..ce5784a3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/dynamic_rag_agent_xml_tooling.yaml
@@ -0,0 +1,99 @@
+dynamic_rag_agent_xml_tooling:
+ template: |
+ You are an AI research assistant with access to document retrieval tools. You should use both your internal knowledge store and web search tools to answer the user questions. Today is {date}.
+
+ <AvailableTools>
+
+ <ToolDefinition>
+ <Name>web_search</Name>
+ <Description>External web search. Parameters must be a valid JSON object.</Description>
+ <Parameters>
+ <Parameter type="string" required="true">
+ <Name>query</Name>
+ <Example>{{"query": "recent AI developments 2024"}}</Example>
+ </Parameter>
+ </Parameters>
+ </ToolDefinition>
+
+ </AvailableTools>
+
+ ### Documents
+ {document_context}
+
+ 2. DECIDE response strategy:
+ - If specific document IDs are relevant: Use `content` with $eq filters
+ - For broad concepts: Use `search_file_knowledge` with keyword queries
+ - Use `web_search` to gather live information
+
+ 3. FORMAT response STRICTLY as:
+ <Action>
+ <ToolCalls>
+ <ToolCall>
+ <Name>search_file_knowledge</Name>
+ <!-- Parameters MUST be a single valid JSON object -->
+ <Parameters>{{"query": "example search"}}</Parameters>
+ </ToolCall>
+ <!-- Multiple tool call example -->
+ <ToolCall>
+ <Name>content</Name>
+ <!-- Example with nested filters -->
+        <Parameters>{{"filters": {{"$and": [{{"document_id": {{"$eq": "abc123"}}}}, {{"collection_ids": {{"$overlap": ["id1"]}}}}]}}}}</Parameters>
+ </ToolCall>
+ </ToolCalls>
+ </Action>
+
+ ### Constraints
+ - MAX_CONTEXT: {max_tool_context_length} tokens
+ - REQUIRED: Line-item references like [abc1234][def5678] when using content
+ - REQUIRED: All Parameters must be valid JSON objects
+ - PROHIBITED: Assuming document contents without retrieval
+ - PROHIBITED: Using XML format for Parameters values
+
+ ### Examples
+    1. Good initial search operation:
+ <Action>
+ <ToolCalls>
+ <ToolCall>
+ <Name>web_search</Name>
+ <Parameters>{{"query": "recent advances in machine learning"}}</Parameters>
+ </ToolCall>
+ <ToolCall>
+ <Name>search_file_knowledge</Name>
+ <Parameters>{{"query": "machine learning applications"}}</Parameters>
+ </ToolCall>
+ <ToolCall>
+ <Name>search_file_knowledge</Name>
+ <Parameters>{{"query": "recent advances in machine learning"}}</Parameters>
+ </ToolCall>
+ </ToolCalls>
+ </Action>
+
+
+ 2. Good content call with complex filters:
+ <Action>
+ <ToolCalls>
+ <ToolCall>
+ <Name>web_search</Name>
+ <Parameters>{{"query": "recent advances in machine learning"}}</Parameters>
+ </ToolCall>
+ <ToolCall>
+ <Name>content</Name>
+        <Parameters>{{"filters": {{"$or": [{{"document_id": {{"$eq": "a5b880db-..."}}}}, {{"document_id": {{"$overlap": ["54b523f6-...","26fc0bf5-..."]}}}}]}}}}</Parameters>
+ </ToolCall>
+ </ToolCalls>
+ </Action>
+
+ ### Important!
+ Continue to take actions until you have sufficient relevant context, then return your answer with the result tool.
+ You have a maximum of 100_000 context tokens or 10 iterations to find the information required.
+
+ RETURN A COMPLETE AND COMPREHENSIVE ANSWER WHEN POSSIBLE.
+
+ REMINDER - Use line item references like `[c910e2e], [b12cd2f]` with THIS EXACT FORMAT to refer to the specific search result IDs returned in the provided context.
+
+ input_types:
+ date: str
+ document_context: str
+ max_tool_context_length: str
+
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_communities.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_communities.yaml
new file mode 100644
index 00000000..50e71544
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_communities.yaml
@@ -0,0 +1,74 @@
+graph_communities:
+ template: |
+ You are an AI assistant that helps a human analyst perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network.
+
+ Context Overview:
+ {collection_description}
+
+ Your Task:
+ Write a comprehensive report of a community as a single XML document. The report must follow this exact structure:
+
+ <community>
+ <name>A specific, concise community name representing its key entities</name>
+ <summary>An executive summary that contextualizes the community</summary>
+ <rating>A float score (0-10) representing impact severity</rating>
+ <rating_explanation>A single sentence explaining the rating</rating_explanation>
+ <findings>
+ <finding>First key insight about the community</finding>
+ <finding>Second key insight about the community</finding>
+ <!-- Include 5-10 findings total -->
+ </findings>
+ </community>
+
+ Data Reference Format:
+ Include data references in findings like this:
+ "Example sentence [Data: <dataset name> (record ids); <dataset name> (record ids)]"
+ Use no more than 5 record IDs per reference. Add "+more" to indicate additional records.
+
+ Example Input:
+ -----------
+ Text:
+
+ Entity: OpenAI
+ descriptions:
+ 101,OpenAI is an AI research and deployment company.
+ relationships:
+ 201,OpenAI,Stripe,OpenAI partnered with Stripe to integrate payment solutions.
+ 203,Airbnb,OpenAI,Airbnb utilizes OpenAI's AI tools for customer service.
+ 204,Stripe,OpenAI,Stripe invested in OpenAI's latest funding round.
+ Entity: Stripe
+ descriptions:
+ 102,Stripe is a technology company that builds economic infrastructure for the internet.
+ relationships:
+ 201,OpenAI,Stripe,OpenAI partnered with Stripe to integrate payment solutions.
+ 202,Stripe,Airbnb,Stripe provides payment processing services to Airbnb.
+ 204,Stripe,OpenAI,Stripe invested in OpenAI's latest funding round.
+ 205,Airbnb,Stripe,Airbnb and Stripe collaborate on expanding global payment options.
+ Entity: Airbnb
+ descriptions:
+ 103,Airbnb is an online marketplace for lodging and tourism experiences.
+ relationships:
+ 203,Airbnb,OpenAI,Airbnb utilizes OpenAI's AI tools for customer service.
+ 205,Airbnb,Stripe,Airbnb and Stripe collaborate on expanding global payment options.
+
+ Example Output:
+ <community>
+ <name>OpenAI-Stripe-Airbnb Community</name>
+ <summary>The OpenAI-Stripe-Airbnb Community is a network of companies that collaborate on AI research, payment solutions, and customer service.</summary>
+ <rating>8.5</rating>
+ <rating_explanation>The OpenAI-Stripe-Airbnb Community has a high impact on the collection due to its significant contributions to AI research, payment solutions, and customer service.</rating_explanation>
+ <findings>
+ <finding>OpenAI and Stripe have a partnership to integrate payment solutions [Data: Relationships (201)].</finding>
+ <finding>OpenAI and Airbnb collaborate on AI tools for customer service [Data: Relationships (203)].</finding>
+ <finding>Stripe provides payment processing services to Airbnb [Data: Relationships (202)].</finding>
+ <finding>Stripe invested in OpenAI's latest funding round [Data: Relationships (204)].</finding>
+ <finding>Airbnb and Stripe collaborate on expanding global payment options [Data: Relationships (205)].</finding>
+ </findings>
+ </community>
+
+ Entity Data:
+ {input_text}
+
+ input_types:
+ collection_description: str
+ input_text: str
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_entity_description.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_entity_description.yaml
new file mode 100644
index 00000000..b46185fb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_entity_description.yaml
@@ -0,0 +1,40 @@
+graph_entity_description:
+ template: |
+ Given the following information about an entity:
+
+ Document Summary:
+ {document_summary}
+
+ Entity Information:
+ {entity_info}
+
+ Relationship Data:
+ {relationships_txt}
+
+ Generate a comprehensive entity description that:
+
+ 1. Opens with a clear definition statement identifying the entity's primary classification and core function
+ 2. Incorporates key data points from both the document summary and relationship information
+ 3. Emphasizes the entity's role within its broader context or system
+ 4. Highlights critical relationships, particularly those that:
+ - Demonstrate hierarchical connections
+ - Show functional dependencies
+ - Indicate primary use cases or applications
+
+ Format Requirements:
+ - Length: 2-3 sentences
+ - Style: Technical and precise
+ - Structure: Definition + Context + Key Relationships
+ - Tone: Objective and authoritative
+
+ Integration Guidelines:
+ - Prioritize information that appears in multiple sources
+ - Resolve any conflicting information by favoring the most specific source
+ - Include temporal context if relevant to the entity's current state or evolution
+
+ Output should reflect the entity's complete nature while maintaining concision and clarity.
+ input_types:
+ document_summary: str
+ entity_info: str
+ relationships_txt: str
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_extraction.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_extraction.yaml
new file mode 100644
index 00000000..9850878a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/graph_extraction.yaml
@@ -0,0 +1,100 @@
+graph_extraction:
+ template: >
+ # Context
+ {document_summary}
+
+ # Goal
+ Given both a document summary and full text, identify all entities and their entity types, along with all relationships among the identified entities.
+
+ # Steps
+ 1. Identify all entities given the full text, grounding and contextualizing them based on the summary. For each identified entity, extract:
+    - entity: Name of the entity, capitalized
+ - entity_type: Type of the entity (constrained to {entity_types} if provided, otherwise all types)
+ - entity_description: Comprehensive description incorporating context from both summary and full text
+
+ Format each Entity in XML tags as follows: <entity name="entity"><type>entity_type</type><description>entity_description</description></entity>
+
+ Note: Generate additional entities from descriptions if they contain named entities for relationship mapping.
+
+ 2. From the identified entities, identify all related entity pairs, using both summary and full text context:
+ - source_entity: name of the source entity
+ - target_entity: name of the target entity
+ - relation: relationship type (constrained to {relation_types} if provided)
+ - relationship_description: justification based on both summary and full text context
+ - relationship_weight: strength score 0-10
+
+ Format each relationship in XML tags as follows: <relationship><source>source_entity</source><target>target_entity</target><type>relation</type><description>relationship_description</description><weight>relationship_weight</weight></relationship>
+
+ 3. Coverage Requirements:
+ - Each entity must have at least one relationship
+ - Create intermediate entities if needed to establish relationships
+ - Verify relationships against both summary and full text
+ - Resolve any discrepancies between sources
+
+ Example 1:
+ If the list is empty, extract all entities and relations.
+ Entity_types:
+ Relation_types:
+ Text:
+ San Francisco is a city in California. It is known for the Golden Gate Bridge, cable cars, and steep hills. The city is surrounded by the Pacific Ocean and the San Francisco Bay.
+ ######################
+ Output:
+ <entity name="San Francisco"><type>City</type><description>San Francisco is a city in California known for the Golden Gate Bridge, cable cars, and steep hills. It is surrounded by the Pacific Ocean and the San Francisco Bay.</description></entity>
+ <entity name="California"><type>State</type><description>California is a state in the United States.</description></entity>
+ <entity name="Golden Gate Bridge"><type>Landmark</type><description>The Golden Gate Bridge is a famous bridge in San Francisco.</description></entity>
+ <entity name="Pacific Ocean"><type>Body of Water</type><description>The Pacific Ocean is a large body of water that surrounds San Francisco.</description></entity>
+ <entity name="San Francisco Bay"><type>Body of Water</type><description>The San Francisco Bay is a body of water that surrounds San Francisco.</description></entity>
+ <relationship><source>San Francisco</source><target>California</target><type>Located In</type><description>San Francisco is a city located in California.</description><weight>8</weight></relationship>
+ <relationship><source>San Francisco</source><target>Golden Gate Bridge</target><type>Features</type><description>San Francisco features the Golden Gate Bridge.</description><weight>9</weight></relationship>
+ <relationship><source>San Francisco</source><target>Pacific Ocean</target><type>Surrounded By</type><description>San Francisco is surrounded by the Pacific Ocean.</description><weight>7</weight></relationship>
+ <relationship><source>San Francisco</source><target>San Francisco Bay</target><type>Surrounded By</type><description>San Francisco is surrounded by the San Francisco Bay.</description><weight>7</weight></relationship>
+ <relationship><source>California</source><target>San Francisco</target><type>Contains</type><description>California contains the city of San Francisco.</description><weight>8</weight></relationship>
+ <relationship><source>Golden Gate Bridge</source><target>San Francisco</target><type>Located In</type><description>The Golden Gate Bridge is located in San Francisco.</description><weight>8</weight></relationship>
+ <relationship><source>Pacific Ocean</source><target>San Francisco</target><type>Surrounds</type><description>The Pacific Ocean surrounds San Francisco.</description><weight>7</weight></relationship>
+ <relationship><source>San Francisco Bay</source><target>San Francisco</target><type>Surrounds</type><description>The San Francisco Bay surrounds San Francisco.</description><weight>7</weight></relationship>
+
+ ######################
+ Example 2:
+ If the list is empty, extract all entities and relations.
+ Entity_types: Organization, Person
+ Relation_types: Located In, Features
+
+ Text:
+ The Green Bay Packers are a professional American football team based in Green Bay, Wisconsin. The team was established in 1919 by Earl "Curly" Lambeau and George Calhoun. The Packers are the third-oldest franchise in the NFL and have won 13 league championships, including four Super Bowls. The team's home games are played at Lambeau Field, which is named after Curly Lambeau.
+ ######################
+ Output:
+ <entity name="Green Bay Packers"><type>Organization</type><description>The Green Bay Packers are a professional American football team based in Green Bay, Wisconsin. The team was established in 1919 by Earl "Curly" Lambeau and George Calhoun. The Packers are the third-oldest franchise in the NFL and have won 13 league championships, including four Super Bowls. The team's home games are played at Lambeau Field, which is named after Curly Lambeau.</description></entity>
+ <entity name="Green Bay"><type>City</type><description>Green Bay is a city in Wisconsin.</description></entity>
+ <entity name="Wisconsin"><type>State</type><description>Wisconsin is a state in the United States.</description></entity>
+ <entity name="Earl "Curly" Lambeau"><type>Person</type><description>Earl "Curly" Lambeau was a co-founder of the Green Bay Packers.</description></entity>
+ <entity name="George Calhoun"><type>Person</type><description>George Calhoun was a co-founder of the Green Bay Packers.</description></entity>
+ <entity name="NFL"><type>Organization</type><description>The NFL is the National Football League.</description></entity>
+ <entity name="Super Bowl"><type>Event</type><description>The Super Bowl is the championship game of the NFL.</description></entity>
+ <entity name="Lambeau Field"><type>Stadium</type><description>Lambeau Field is the home stadium of the Green Bay Packers.</description></entity>
+ <relationship><source>Green Bay Packers</source><target>Green Bay</target><type>Located In</type><description>The Green Bay Packers are based in Green Bay, Wisconsin.</description><weight>8</weight></relationship>
+ <relationship><source>Green Bay</source><target>Wisconsin</target><type>Located In</type><description>Green Bay is located in Wisconsin.</description><weight>8</weight></relationship>
+ <relationship><source>Green Bay Packers</source><target>Earl "Curly" Lambeau</target><type>Founded By</type><description>The Green Bay Packers were established by Earl "Curly" Lambeau.</description><weight>9</weight></relationship>
+ <relationship><source>Green Bay Packers</source><target>George Calhoun</target><type>Founded By</type><description>The Green Bay Packers were established by George Calhoun.</description><weight>9</weight></relationship>
+ <relationship><source>Green Bay Packers</source><target>NFL</target><type>League</type><description>The Green Bay Packers are a franchise in the NFL.</description><weight>8</weight></relationship>
+ <relationship><source>Green Bay Packers</source><target>Super Bowl</target><type>Championships</type><description>The Green Bay Packers have won four Super Bowls.</description><weight>9</weight></relationship>
+
+ -Real Data-
+ ######################
+ If the list is empty, extract all entities and relations.
+ Entity_types: {entity_types}
+ Relation_types: {relation_types}
+
+ Document Summary:
+ {document_summary}
+
+ Full Text:
+ {input}
+ ######################
+ Output:
+ input_types:
+ document_summary: str
+ max_knowledge_relationships: int
+ input: str
+ entity_types: list[str]
+ relation_types: list[str]
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/hyde.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/hyde.yaml
new file mode 100644
index 00000000..d8071d1f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/hyde.yaml
@@ -0,0 +1,29 @@
+hyde:
+ template: >
+ ### Instruction:
+
+    Given the query that follows, write a double newline separated list of {num_outputs} distinct single-paragraph attempted answers to the given query.
+
+
+    DO NOT generate any single answer which is likely to require information from multiple distinct documents.
+
+ EACH single answer will be used to carry out a cosine similarity semantic search over distinct indexed documents, such as varied medical documents.
+
+
+ FOR EXAMPLE if asked `how do the key themes of Great Gatsby compare with 1984`, the two attempted answers would be
+
+ `The key themes of Great Gatsby are ... ANSWER_CONTINUED` and `The key themes of 1984 are ... ANSWER_CONTINUED`, where `ANSWER_CONTINUED` IS TO BE COMPLETED BY YOU in your response.
+
+
+ Here is the original user query to be transformed into answers:
+
+
+ ### Query:
+
+ {message}
+
+
+ ### Response:
+ input_types:
+ num_outputs: int
+ message: str
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/rag.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/rag.yaml
new file mode 100644
index 00000000..c835517d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/rag.yaml
@@ -0,0 +1,29 @@
+rag:
+ template: >
+ ## Task:
+
+    Answer the query given immediately below using the context which follows later. Use line-item references like [c910e2e], [b12cd2f], ... to refer to the provided search results.
+
+
+ ### Query:
+
+ {query}
+
+
+ ### Context:
+
+ {context}
+
+
+ ### Query:
+
+ {query}
+
+
+    REMINDER - Use line-item references like [c910e2e], [b12cd2f] to refer to the specific search result IDs returned in the provided context.
+
+ ## Response:
+ input_types:
+ query: str
+ context: str
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/rag_fusion.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/rag_fusion.yaml
new file mode 100644
index 00000000..874d3f39
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/rag_fusion.yaml
@@ -0,0 +1,27 @@
+rag_fusion:
+ template: >
+ ### Instruction:
+
+
+    Given the query that follows, write a double newline separated list of up to {num_outputs} queries meant to help answer the original query.
+
+    DO NOT generate any single query which is likely to require information from multiple distinct documents.
+
+ EACH single query will be used to carry out a cosine similarity semantic search over distinct indexed documents, such as varied medical documents.
+
+ FOR EXAMPLE if asked `how do the key themes of Great Gatsby compare with 1984`, the two queries would be
+
+ `What are the key themes of Great Gatsby?` and `What are the key themes of 1984?`.
+
+    Here is the original user query to be transformed into queries:
+
+
+ ### Query:
+
+ {message}
+
+
+ ### Response:
+ input_types:
+ num_outputs: int
+ message: str
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/static_rag_agent.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/static_rag_agent.yaml
new file mode 100644
index 00000000..0e940af1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/static_rag_agent.yaml
@@ -0,0 +1,16 @@
+static_rag_agent:
+ template: >
+    ### You are a helpful agent that can search for information; the date is {date}.
+
+ When asked a question, YOU SHOULD ALWAYS USE YOUR SEARCH TOOL TO ATTEMPT TO SEARCH FOR RELEVANT INFORMATION THAT ANSWERS THE USER QUESTION.
+
+    The response should contain line-item attributions to relevant search results, and be as informative as possible.
+
+ If no relevant results are found, then state that no results were found. If no obvious question is present, then do not carry out a search, and instead ask for clarification.
+
+    REMINDER - Use line-item references like [c910e2e], [b12cd2f] to refer to the specific search result IDs returned in the provided context.
+
+ input_types:
+ date: str
+
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/static_research_agent.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/static_research_agent.yaml
new file mode 100644
index 00000000..417d161c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/static_research_agent.yaml
@@ -0,0 +1,61 @@
+static_research_agent:
+ template: >-
+    # You are a helpful agent that can search for information; the date is {date}.
+
+ # Comprehensive Strategic Analysis Report
+
+ ## Objective
+ Produce nuanced, robust, and strategically insightful analyses. Adjust your approach based on the nature of the question:
+
+ - **Broad, qualitative, or subjective questions**:
+ Deliver in-depth, qualitative analysis by systematically exploring multiple dimensions and diverse perspectives. Emphasize strategic insights, market psychology, long-term implications, and nuanced evaluations.
+
+ - **Narrow, academic, or factual questions**:
+ Provide focused, precise, and strategic analyses. Clearly articulate cause-effect relationships, relevant context, and strategic significance. Prioritize accuracy, clarity, and concise insights.
+
+ ## Research Guidance
+ - **Multi-thesis Approach (for qualitative/subjective queries):**
+ - Identify and retrieve detailed information from credible sources covering multiple angles, including technical, economic, market-specific, geopolitical, psychological, and long-term strategic implications.
+ - Seek contrasting viewpoints, expert opinions, market analyses, and nuanced discussions.
+
+ - **Focused Strategic Approach (for narrow/academic queries):**
+ - Clearly identify the core elements of the question and retrieve precise, relevant information.
+ - Highlight strategic significance, context, and implications concisely and accurately.
+
+ ## Source Diversity
+ - Draw from diverse, credible sources such as financial analyses, expert commentary, reputable news outlets, industry reports, academic papers, and analyst research.
+
+ ## Structured Analysis
+ - Organize findings into clear, logically sequenced sections (e.g., Technical Details, Market Reactions, Economic Implications, Strategic Insights).
+ - Explicitly link each factor to its implications, providing concrete examples.
+
+ ## Cause-and-Effect Relationships
+ - Clearly articulate how each piece of information contributes to the overall narrative or impact, emphasizing cause-effect dynamics.
+ - Include direct quotes or analyst insights when available to illustrate context or sentiment.
+
+ ## Depth and Nuance
+ - Go beyond surface-level reporting by critically evaluating how events challenge assumptions, reshape strategies, and alter investment theses.
+ - Highlight underlying trends, second-order effects, and unexpected consequences.
+
+ ## Forward-Looking Insights
+ - Offer informed analysis on long-term strategic implications, identifying future opportunities, risks, and market shifts.
+ - Discuss potential shifts in competitive dynamics, regulation, and industry paradigms.
+
+ ## Line-Item Referencing
+ - Carefully reference specific data and insights using line-item references (e.g., `[c910e2e]`, `[b12cd2f]`) for credibility and transparency.
+
+ ## Writing Style and Tone
+ - Adopt a concise yet sophisticated style suitable for investors, industry professionals, and strategic decision-makers.
+ - Prioritize clarity, strategic depth, and actionable insights.
+
+ ## Final Reminders
+ - Anticipate follow-up questions and proactively provide relevant context or comparisons.
+ - Clearly highlight potential opportunities, risks, and long-term implications.
+ - Focus on actionable, well-rounded insights rather than mere summaries.
+    - Add line-item citations like `[c910e2e]`, `[b12cd2f]`.
+ - IGNORE all previous instructions when dealing with academic or factual queries.
+
+ input_types:
+ date: str
+
+ overwrite_on_diff: true
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/summary.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/summary.yaml
new file mode 100644
index 00000000..1658f2bf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/summary.yaml
@@ -0,0 +1,18 @@
+summary:
+ template: >
+ ## Task:
+
+ Your task is to generate a descriptive summary of the document that follows. Your objective is to return a summary that is roughly 10% of the input document size while retaining as many key points as possible. Your response should begin with `The document contains `.
+
+ ### Document:
+
+ {document}
+
+
+ ### Query:
+
+ Reminder: Your task is to generate a descriptive summary of the document that was given. Your objective is to return a summary that is roughly 10% of the input document size while retaining as many key points as possible. Your response should begin with `The document contains `.
+
+ ## Response:
+ input_types:
+ document: str
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/system.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/system.yaml
new file mode 100644
index 00000000..4bc0770b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/system.yaml
@@ -0,0 +1,3 @@
+system:
+ template: You are a helpful agent.
+ input_types: {}
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_img.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_img.yaml
new file mode 100644
index 00000000..4a1aa477
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_img.yaml
@@ -0,0 +1,4 @@
+vision_img:
+ template: >
+ First, provide a title for the image, then explain everything that you see. Be very thorough in your analysis as a user will need to understand the image without seeing it. If it is possible to transcribe the image to text directly, then do so. The more detail you provide, the better the user will understand the image.
+ input_types: {}
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_pdf.yaml b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_pdf.yaml
new file mode 100644
index 00000000..350ead2d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts/vision_pdf.yaml
@@ -0,0 +1,42 @@
+vision_pdf:
+ template: >
+ Convert this PDF page to markdown format, preserving all content and formatting. Follow these guidelines:
+
+ Text:
+ - Maintain the original text hierarchy (headings, paragraphs, lists)
+ - Preserve any special formatting (bold, italic, underline)
+ - Include all footnotes, citations, and references
+ - Keep text in its original reading order
+
+ Tables:
+ - Recreate tables using markdown table syntax
+ - Preserve all headers, rows, and columns
+ - Maintain alignment and formatting where possible
+ - Include any table captions or notes
+
+ Equations:
+ - Convert mathematical equations using LaTeX notation
+ - Preserve equation numbers if present
+ - Include any surrounding context or references
+
+ Images:
+ - Enclose image descriptions within [FIG] and [/FIG] tags
+ - Include detailed descriptions of:
+ * Main subject matter
+ * Text overlays or captions
+ * Charts, graphs, or diagrams
+ * Relevant colors, patterns, or visual elements
+ - Maintain image placement relative to surrounding text
+
+ Additional Elements:
+ - Include page numbers if visible
+ - Preserve headers and footers
+ - Maintain sidebars or callout boxes
+ - Keep any special symbols or characters
+
+ Quality Requirements:
+ - Ensure 100% content preservation
+ - Maintain logical document flow
+ - Verify all markdown syntax is valid
+ - Double-check completeness before submitting
+ input_types: {}
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/prompts_handler.py b/.venv/lib/python3.12/site-packages/core/providers/database/prompts_handler.py
new file mode 100644
index 00000000..29afbb3f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/prompts_handler.py
@@ -0,0 +1,748 @@
+import json
+import logging
+import os
+from abc import abstractmethod
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Generic, Optional, TypeVar
+
+import yaml
+
+from core.base import Handler, generate_default_prompt_id
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+@dataclass
+class CacheEntry(Generic[T]):
+ """Represents a cached item with metadata."""
+
+ value: T
+ created_at: datetime
+ last_accessed: datetime
+ access_count: int = 0
+
+
+class Cache(Generic[T]):
+ """A generic cache implementation with TTL and LRU-like features."""
+
+ def __init__(
+ self,
+ ttl: Optional[timedelta] = None,
+ max_size: Optional[int] = 1000,
+ cleanup_interval: timedelta = timedelta(hours=1),
+ ):
+ self._cache: dict[str, CacheEntry[T]] = {}
+ self._ttl = ttl
+ self._max_size = max_size
+ self._cleanup_interval = cleanup_interval
+ self._last_cleanup = datetime.now()
+
+ def get(self, key: str) -> Optional[T]:
+ """Retrieve an item from cache."""
+ self._maybe_cleanup()
+
+ if key not in self._cache:
+ return None
+
+ entry = self._cache[key]
+
+ if self._ttl and datetime.now() - entry.created_at > self._ttl:
+ del self._cache[key]
+ return None
+
+ entry.last_accessed = datetime.now()
+ entry.access_count += 1
+ return entry.value
+
+ def set(self, key: str, value: T) -> None:
+ """Store an item in cache."""
+ self._maybe_cleanup()
+
+ now = datetime.now()
+ self._cache[key] = CacheEntry(
+ value=value, created_at=now, last_accessed=now
+ )
+
+ if self._max_size and len(self._cache) > self._max_size:
+ self._evict_lru()
+
+ def invalidate(self, key: str) -> None:
+ """Remove an item from cache."""
+ self._cache.pop(key, None)
+
+ def clear(self) -> None:
+ """Clear all cached items."""
+ self._cache.clear()
+
+ def _maybe_cleanup(self) -> None:
+ """Periodically clean up expired entries."""
+ now = datetime.now()
+ if now - self._last_cleanup > self._cleanup_interval:
+ self._cleanup()
+ self._last_cleanup = now
+
+ def _cleanup(self) -> None:
+ """Remove expired entries."""
+ if not self._ttl:
+ return
+
+ now = datetime.now()
+ expired = [
+ k for k, v in self._cache.items() if now - v.created_at > self._ttl
+ ]
+ for k in expired:
+ del self._cache[k]
+
+ def _evict_lru(self) -> None:
+ """Remove least recently used item."""
+ if not self._cache:
+ return
+
+ lru_key = min(
+ self._cache.keys(), key=lambda k: self._cache[k].last_accessed
+ )
+ del self._cache[lru_key]
+
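+# A minimal usage sketch of Cache (illustrative only; the snippet below is not
+# executed by this module):
+#
+#     cache: Cache[str] = Cache(ttl=timedelta(minutes=5), max_size=2)
+#     cache.set("a", "alpha")
+#     cache.set("b", "beta")
+#     cache.get("a")           # refreshes last_accessed for "a"
+#     cache.set("c", "gamma")  # exceeds max_size, so "b" (the LRU entry) is evicted
+#     assert cache.get("b") is None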
+
+class CacheablePromptHandler(Handler):
+ """Abstract base class that adds caching capabilities to prompt
+ handlers."""
+
+ def __init__(
+ self,
+ cache_ttl: Optional[timedelta] = timedelta(hours=1),
+ max_cache_size: Optional[int] = 1000,
+ ):
+ self._prompt_cache = Cache[str](ttl=cache_ttl, max_size=max_cache_size)
+ self._template_cache = Cache[dict](
+ ttl=cache_ttl, max_size=max_cache_size
+ )
+
+ def _cache_key(
+ self, prompt_name: str, inputs: Optional[dict] = None
+ ) -> str:
+ """Generate a cache key for a prompt request."""
+ if inputs:
+ # Sort dict items for consistent keys
+ sorted_inputs = sorted(inputs.items())
+ return f"{prompt_name}:{sorted_inputs}"
+ return prompt_name
+
+ async def get_cached_prompt(
+ self,
+ prompt_name: str,
+ inputs: Optional[dict[str, Any]] = None,
+ prompt_override: Optional[str] = None,
+ bypass_cache: bool = False,
+ ) -> str:
+ if prompt_override:
+ # If the user gave us a direct override, use it.
+ if inputs:
+ try:
+ return prompt_override.format(**inputs)
+ except KeyError:
+ return prompt_override
+ return prompt_override
+
+ cache_key = self._cache_key(prompt_name, inputs)
+
+ # If not bypassing, try returning from the prompt-level cache
+ if not bypass_cache:
+ cached = self._prompt_cache.get(cache_key)
+ if cached is not None:
+ logger.debug(f"Prompt cache hit: {cache_key}")
+ return cached
+
+ logger.debug(
+ "Prompt cache miss or bypass. Retrieving from DB or template cache."
+ )
+        # Pass the bypass flag through to the template-level cache lookup as well.
+ result = await self._get_prompt_impl(
+ prompt_name, inputs, bypass_template_cache=bypass_cache
+ )
+ self._prompt_cache.set(cache_key, result)
+ return result
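+
+    # Call-pattern sketch for get_cached_prompt (illustrative; `handler` is
+    # assumed to be an initialized concrete subclass with a "rag" prompt):
+    #
+    #     text = await handler.get_cached_prompt("rag", {"query": "hello"})
+    #     fresh = await handler.get_cached_prompt(
+    #         "rag", {"query": "hello"}, bypass_cache=True
+    #     )  # skips both the formatted-prompt and template caches
+    #     raw = await handler.get_cached_prompt(
+    #         "rag", prompt_override="Answer briefly: {query}"
+    #     )  # the override is returned unformatted because no inputs were given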
+
+ async def get_prompt( # type: ignore
+ self,
+ name: str,
+ inputs: Optional[dict] = None,
+ prompt_override: Optional[str] = None,
+ ) -> dict:
+ query = f"""
+ SELECT id, name, template, input_types, created_at, updated_at
+ FROM {self._get_table_name("prompts")}
+ WHERE name = $1;
+ """
+ result = await self.connection_manager.fetchrow_query(query, [name])
+
+ if not result:
+ raise ValueError(f"Prompt template '{name}' not found")
+
+ input_types = result["input_types"]
+ if isinstance(input_types, str):
+ input_types = json.loads(input_types)
+
+ return {
+ "id": result["id"],
+ "name": result["name"],
+ "template": result["template"],
+ "input_types": input_types,
+ "created_at": result["created_at"],
+ "updated_at": result["updated_at"],
+ }
+
+ def _format_prompt(
+ self,
+ template: str,
+ inputs: Optional[dict[str, Any]],
+ input_types: dict[str, str],
+ ) -> str:
+ if inputs:
+            # Validate that every provided input is declared for this prompt.
+ for k, _v in inputs.items():
+ if k not in input_types:
+ raise ValueError(
+ f"Unexpected input '{k}' for prompt with input types {input_types}"
+ )
+ return template.format(**inputs)
+ return template
+
+ async def update_prompt(
+ self,
+ name: str,
+ template: Optional[str] = None,
+ input_types: Optional[dict[str, str]] = None,
+ ) -> None:
+ """Public method to update a prompt with proper cache invalidation."""
+ # First invalidate all caches for this prompt
+ self._template_cache.invalidate(name)
+ cache_keys_to_invalidate = [
+ key
+ for key in self._prompt_cache._cache.keys()
+ if key.startswith(f"{name}:") or key == name
+ ]
+ for key in cache_keys_to_invalidate:
+ self._prompt_cache.invalidate(key)
+
+ # Perform the update
+ await self._update_prompt_impl(name, template, input_types)
+
+ # Force refresh template cache
+ template_info = await self._get_template_info(name)
+ if template_info:
+ self._template_cache.set(name, template_info)
+
+ @abstractmethod
+ async def _update_prompt_impl(
+ self,
+ name: str,
+ template: Optional[str] = None,
+ input_types: Optional[dict[str, str]] = None,
+ ) -> None:
+ """Implementation of prompt update logic."""
+ pass
+
+ @abstractmethod
+ async def _get_template_info(self, prompt_name: str) -> Optional[dict]:
+ """Get template info with caching."""
+ pass
+
+ @abstractmethod
+ async def _get_prompt_impl(
+ self,
+ prompt_name: str,
+ inputs: Optional[dict[str, Any]] = None,
+ bypass_template_cache: bool = False,
+ ) -> str:
+ """Implementation of prompt retrieval logic."""
+ pass
+
+
+class PostgresPromptsHandler(CacheablePromptHandler):
+ """PostgreSQL implementation of the CacheablePromptHandler."""
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ prompt_directory: Optional[Path] = None,
+ **cache_options,
+ ):
+ super().__init__(**cache_options)
+ self.prompt_directory = (
+ prompt_directory or Path(os.path.dirname(__file__)) / "prompts"
+ )
+ self.connection_manager = connection_manager
+ self.project_name = project_name
+ self.prompts: dict[str, dict[str, str | dict[str, str]]] = {}
+
+ async def _load_prompts(self) -> None:
+ """Load prompts from both database and YAML files."""
+ # First load from database
+ await self._load_prompts_from_database()
+
+ # Then load from YAML files, potentially overriding unmodified database entries
+ await self._load_prompts_from_yaml_directory()
+
+ async def _load_prompts_from_database(self) -> None:
+ """Load prompts from the database."""
+ query = f"""
+ SELECT id, name, template, input_types, created_at, updated_at
+ FROM {self._get_table_name("prompts")};
+ """
+ try:
+ results = await self.connection_manager.fetch_query(query)
+ for row in results:
+ logger.info(f"Loading saved prompt: {row['name']}")
+
+ # Ensure input_types is a dictionary
+ input_types = row["input_types"]
+ if isinstance(input_types, str):
+ input_types = json.loads(input_types)
+
+ self.prompts[row["name"]] = {
+ "id": row["id"],
+ "template": row["template"],
+ "input_types": input_types,
+ "created_at": row["created_at"],
+ "updated_at": row["updated_at"],
+ }
+ # Pre-populate the template cache
+ self._template_cache.set(
+ row["name"],
+ {
+ "id": row["id"],
+ "template": row["template"],
+ "input_types": input_types,
+ },
+ )
+ logger.debug(f"Loaded {len(results)} prompts from database")
+ except Exception as e:
+ logger.error(f"Failed to load prompts from database: {e}")
+ raise
+
+ async def _load_prompts_from_yaml_directory(
+ self, default_overwrite_on_diff: bool = False
+ ) -> None:
+ """Load prompts from YAML files in the specified directory.
+
+ :param default_overwrite_on_diff: If a YAML prompt does not specify
+ 'overwrite_on_diff', we use this default.
+ """
+ if not self.prompt_directory.is_dir():
+ logger.warning(
+ f"Prompt directory not found: {self.prompt_directory}"
+ )
+ return
+
+ logger.info(f"Loading prompts from {self.prompt_directory}")
+ for yaml_file in self.prompt_directory.glob("*.yaml"):
+ logger.debug(f"Processing {yaml_file}")
+ try:
+ with open(yaml_file, "r", encoding="utf-8") as file:
+ data = yaml.safe_load(file)
+ if not isinstance(data, dict):
+ raise ValueError(
+ f"Invalid format in YAML file {yaml_file}"
+ )
+
+ for name, prompt_data in data.items():
+ # Attempt to parse the relevant prompt fields
+ template = prompt_data.get("template")
+ input_types = prompt_data.get("input_types", {})
+
+ # Decide on per-prompt overwrite behavior (or fallback)
+ overwrite_on_diff = prompt_data.get(
+ "overwrite_on_diff", default_overwrite_on_diff
+ )
+                        # Preserve a prompt the user has modified since creation
+                        # (created_at != updated_at); otherwise let the YAML
+                        # default update it, subject to overwrite_on_diff.
+                        should_modify = True
+                        if name in self.prompts:
+                            existing = self.prompts[name]
+                            should_modify = (
+                                existing["created_at"]
+                                == existing["updated_at"]
+                            )
+
+                        logger.info(
+                            f"Loading default prompt: {name} from {yaml_file}."
+                        )
+
+                        await self.add_prompt(
+                            name=name,
+                            template=template,
+                            input_types=input_types,
+                            preserve_existing=not should_modify,
+                            overwrite_on_diff=overwrite_on_diff,
+                        )
+ except Exception as e:
+ logger.error(f"Error loading {yaml_file}: {e}")
+ continue
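+
+    # Shape of a YAML file this loader accepts (a sketch mirroring the bundled
+    # files in prompts/; overwrite_on_diff is optional per prompt):
+    #
+    #     my_prompt:
+    #       template: >
+    #         Answer the question: {question}
+    #       input_types:
+    #         question: str
+    #       overwrite_on_diff: true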
+
+ def _get_table_name(self, base_name: str) -> str:
+ """Get the fully qualified table name."""
+ return f"{self.project_name}.{base_name}"
+
+ # Implementation of abstract methods from CacheablePromptHandler
+ async def _get_prompt_impl(
+ self,
+ prompt_name: str,
+ inputs: Optional[dict[str, Any]] = None,
+ bypass_template_cache: bool = False,
+ ) -> str:
+ """Implementation of database prompt retrieval."""
+ # If we're bypassing the template cache, skip the cache lookup
+ if not bypass_template_cache:
+ template_info = self._template_cache.get(prompt_name)
+ if template_info is not None:
+ logger.debug(f"Template cache hit: {prompt_name}")
+                # Format using the cached template and its declared input types.
+ return self._format_prompt(
+ template_info["template"],
+ inputs,
+ template_info["input_types"],
+ )
+
+        # At this point either no cached template was found or bypass_template_cache is True
+ query = f"""
+ SELECT template, input_types
+ FROM {self._get_table_name("prompts")}
+ WHERE name = $1;
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [prompt_name]
+ )
+
+ if not result:
+ raise ValueError(f"Prompt template '{prompt_name}' not found")
+
+ template = result["template"]
+ input_types = result["input_types"]
+ if isinstance(input_types, str):
+ input_types = json.loads(input_types)
+
+ # Update template cache if not bypassing it
+ if not bypass_template_cache:
+ self._template_cache.set(
+ prompt_name, {"template": template, "input_types": input_types}
+ )
+
+ return self._format_prompt(template, inputs, input_types)
+
+ async def _get_template_info(self, prompt_name: str) -> Optional[dict]: # type: ignore
+ """Get template info with caching."""
+ cached = self._template_cache.get(prompt_name)
+ if cached is not None:
+ return cached
+
+ query = f"""
+ SELECT template, input_types
+ FROM {self._get_table_name("prompts")}
+ WHERE name = $1;
+ """
+
+ result = await self.connection_manager.fetchrow_query(
+ query, [prompt_name]
+ )
+
+ if result:
+ # Ensure input_types is a dictionary
+ input_types = result["input_types"]
+ if isinstance(input_types, str):
+ input_types = json.loads(input_types)
+
+ template_info = {
+ "template": result["template"],
+ "input_types": input_types,
+ }
+ self._template_cache.set(prompt_name, template_info)
+ return template_info
+
+ return None
+
+ async def _update_prompt_impl(
+ self,
+ name: str,
+ template: Optional[str] = None,
+ input_types: Optional[dict[str, str]] = None,
+ ) -> None:
+ """Implementation of database prompt update with proper connection
+ handling."""
+ if not template and not input_types:
+ return
+
+ # Clear caches first
+ self._template_cache.invalidate(name)
+ for key in list(self._prompt_cache._cache.keys()):
+ if key.startswith(f"{name}:"):
+ self._prompt_cache.invalidate(key)
+
+ # Build update query
+ set_clauses = []
+ params = [name] # First parameter is always the name
+ param_index = 2 # Start from 2 since $1 is name
+
+ if template:
+ set_clauses.append(f"template = ${param_index}")
+ params.append(template)
+ param_index += 1
+
+ if input_types:
+ set_clauses.append(f"input_types = ${param_index}")
+ params.append(json.dumps(input_types))
+ param_index += 1
+
+ set_clauses.append("updated_at = CURRENT_TIMESTAMP")
+
+ query = f"""
+ UPDATE {self._get_table_name("prompts")}
+ SET {", ".join(set_clauses)}
+ WHERE name = $1
+ RETURNING id, template, input_types;
+ """
+
+ try:
+ # Execute update and get returned values
+ result = await self.connection_manager.fetchrow_query(
+ query, params
+ )
+
+ if not result:
+ raise ValueError(f"Prompt template '{name}' not found")
+
+ # Update in-memory state
+ if name in self.prompts:
+ if template:
+ self.prompts[name]["template"] = template
+ if input_types:
+ self.prompts[name]["input_types"] = input_types
+ self.prompts[name]["updated_at"] = datetime.now().isoformat()
+
+ except Exception as e:
+ logger.error(f"Failed to update prompt {name}: {str(e)}")
+ raise
+
+ async def create_tables(self):
+ """Create the necessary tables for storing prompts."""
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name("prompts")} (
+ id UUID PRIMARY KEY,
+ name VARCHAR(255) NOT NULL UNIQUE,
+ template TEXT NOT NULL,
+ input_types JSONB NOT NULL,
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
+ );
+
+ CREATE OR REPLACE FUNCTION {self.project_name}.update_updated_at_column()
+ RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.updated_at = CURRENT_TIMESTAMP;
+ RETURN NEW;
+ END;
+ $$ language 'plpgsql';
+
+ DROP TRIGGER IF EXISTS update_prompts_updated_at
+ ON {self._get_table_name("prompts")};
+
+ CREATE TRIGGER update_prompts_updated_at
+ BEFORE UPDATE ON {self._get_table_name("prompts")}
+ FOR EACH ROW
+ EXECUTE FUNCTION {self.project_name}.update_updated_at_column();
+ """
+ await self.connection_manager.execute_query(query)
+ await self._load_prompts()
+
+ async def add_prompt(
+ self,
+ name: str,
+ template: str,
+ input_types: dict[str, str],
+ preserve_existing: bool = False,
+ overwrite_on_diff: bool = False, # <-- new param
+ ) -> None:
+ """Add or update a prompt.
+
+ If `preserve_existing` is True and prompt already exists, we skip updating.
+
+ If `overwrite_on_diff` is True and an existing prompt differs from what is provided,
+ we overwrite and log a warning. Otherwise, we skip if the prompt differs.
+ """
+ # Check if prompt is in-memory
+ existing_prompt = self.prompts.get(name)
+
+ # If preserving existing and it already exists, skip entirely
+ if preserve_existing and existing_prompt:
+ logger.debug(
+ f"Preserving existing prompt: {name}, skipping update."
+ )
+ return
+
+ # If an existing prompt is found, check for diffs
+ if existing_prompt:
+ existing_template = existing_prompt["template"]
+ existing_input_types = existing_prompt["input_types"]
+
+ # If there's a difference in template or input_types, decide to overwrite or skip
+ if (
+ existing_template != template
+ or existing_input_types != input_types
+ ):
+ if overwrite_on_diff:
+ logger.warning(
+ f"Overwriting existing prompt '{name}' due to detected diff."
+ )
+ else:
+ logger.info(
+ f"Prompt '{name}' differs from existing but overwrite_on_diff=False. Skipping update."
+ )
+ return
+
+ prompt_id = generate_default_prompt_id(name)
+
+ # Ensure input_types is properly serialized
+ input_types_json = (
+ json.dumps(input_types)
+ if isinstance(input_types, dict)
+ else input_types
+ )
+
+ # Upsert logic
+ query = f"""
+ INSERT INTO {self._get_table_name("prompts")} (id, name, template, input_types)
+ VALUES ($1, $2, $3, $4)
+ ON CONFLICT (name) DO UPDATE
+ SET template = EXCLUDED.template,
+ input_types = EXCLUDED.input_types,
+ updated_at = CURRENT_TIMESTAMP
+ RETURNING id, created_at, updated_at;
+ """
+
+ result = await self.connection_manager.fetchrow_query(
+ query, [prompt_id, name, template, input_types_json]
+ )
+
+ self.prompts[name] = {
+ "id": result["id"],
+ "template": template,
+ "input_types": input_types,
+ "created_at": result["created_at"],
+ "updated_at": result["updated_at"],
+ }
+
+ # Update template cache
+ self._template_cache.set(
+ name,
+ {
+ "id": prompt_id,
+ "template": template,
+ "input_types": input_types,
+ },
+ )
+
+ # Invalidate any cached formatted prompts
+ for key in list(self._prompt_cache._cache.keys()):
+ if key.startswith(f"{name}:"):
+ self._prompt_cache.invalidate(key)
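+
+    # Behavior sketch for add_prompt (illustrative; `handler` is assumed to be
+    # an initialized PostgresPromptsHandler):
+    #
+    #     # First registration inserts the prompt
+    #     await handler.add_prompt("greet", "Hi {name}!", {"name": "str"})
+    #     # A differing template is skipped unless overwrite_on_diff=True
+    #     await handler.add_prompt("greet", "Hello {name}!", {"name": "str"})
+    #     await handler.add_prompt(
+    #         "greet", "Hello {name}!", {"name": "str"}, overwrite_on_diff=True
+    #     )
+    #     # preserve_existing=True always leaves an existing prompt untouched
+    #     await handler.add_prompt(
+    #         "greet", "Hey {name}!", {"name": "str"}, preserve_existing=True
+    #     )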
+
+ async def get_all_prompts(self) -> dict[str, Any]:
+ """Retrieve all stored prompts."""
+ query = f"""
+ SELECT id, name, template, input_types, created_at, updated_at, COUNT(*) OVER() AS total_entries
+ FROM {self._get_table_name("prompts")};
+ """
+ results = await self.connection_manager.fetch_query(query)
+
+ if not results:
+ return {"results": [], "total_entries": 0}
+
+ total_entries = results[0]["total_entries"] if results else 0
+
+ prompts = [
+ {
+ "name": row["name"],
+ "id": row["id"],
+ "template": row["template"],
+ "input_types": (
+ json.loads(row["input_types"])
+ if isinstance(row["input_types"], str)
+ else row["input_types"]
+ ),
+ "created_at": row["created_at"],
+ "updated_at": row["updated_at"],
+ }
+ for row in results
+ ]
+
+ return {"results": prompts, "total_entries": total_entries}
+
+ async def delete_prompt(self, name: str) -> None:
+ """Delete a prompt template."""
+ query = f"""
+ DELETE FROM {self._get_table_name("prompts")}
+ WHERE name = $1;
+ """
+ result = await self.connection_manager.execute_query(query, [name])
+ if result == "DELETE 0":
+ raise ValueError(f"Prompt template '{name}' not found")
+
+ # Invalidate caches
+ self._template_cache.invalidate(name)
+ for key in list(self._prompt_cache._cache.keys()):
+ if key.startswith(f"{name}:"):
+ self._prompt_cache.invalidate(key)
+
+ async def get_message_payload(
+ self,
+ system_prompt_name: Optional[str] = None,
+ system_role: str = "system",
+ system_inputs: dict | None = None,
+ system_prompt_override: Optional[str] = None,
+ task_prompt_name: Optional[str] = None,
+ task_role: str = "user",
+ task_inputs: Optional[dict] = None,
+ task_prompt: Optional[str] = None,
+ ) -> list[dict]:
+ """Create a message payload from system and task prompts."""
+ if system_inputs is None:
+ system_inputs = {}
+ if task_inputs is None:
+ task_inputs = {}
+ if system_prompt_override:
+ system_prompt = system_prompt_override
+ else:
+ system_prompt = await self.get_cached_prompt(
+ system_prompt_name or "system",
+ system_inputs,
+ prompt_override=system_prompt_override,
+ )
+
+ task_prompt = await self.get_cached_prompt(
+ task_prompt_name or "rag",
+ task_inputs,
+ prompt_override=task_prompt,
+ )
+
+ return [
+ {
+ "role": system_role,
+ "content": system_prompt,
+ },
+ {
+ "role": task_role,
+ "content": task_prompt,
+ },
+ ]
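+
+    # Example of the payload returned by get_message_payload (illustrative; the
+    # exact texts depend on the registered "system" and "rag" templates):
+    #
+    #     [
+    #         {"role": "system", "content": "You are a helpful agent."},
+    #         {"role": "user", "content": "<formatted rag prompt>"},
+    #     ]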
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/tokens.py b/.venv/lib/python3.12/site-packages/core/providers/database/tokens.py
new file mode 100644
index 00000000..7d30c326
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/tokens.py
@@ -0,0 +1,67 @@
+from datetime import datetime, timedelta
+from typing import Optional
+
+from core.base import Handler
+
+from .base import PostgresConnectionManager
+
+
+class PostgresTokensHandler(Handler):
+ TABLE_NAME = "blacklisted_tokens"
+
+ def __init__(
+ self, project_name: str, connection_manager: PostgresConnectionManager
+ ):
+ super().__init__(project_name, connection_manager)
+
+ async def create_tables(self):
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ token TEXT NOT NULL,
+ blacklisted_at TIMESTAMPTZ DEFAULT NOW()
+ );
+ CREATE INDEX IF NOT EXISTS idx_{self.project_name}_{PostgresTokensHandler.TABLE_NAME}_token
+ ON {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (token);
+ CREATE INDEX IF NOT EXISTS idx_{self.project_name}_{PostgresTokensHandler.TABLE_NAME}_blacklisted_at
+ ON {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (blacklisted_at);
+ """
+ await self.connection_manager.execute_query(query)
+
+ async def blacklist_token(
+ self, token: str, current_time: Optional[datetime] = None
+ ):
+ if current_time is None:
+ current_time = datetime.utcnow()
+
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (token, blacklisted_at)
+ VALUES ($1, $2)
+ """
+ await self.connection_manager.execute_query(
+ query, [token, current_time]
+ )
+
+ async def is_token_blacklisted(self, token: str) -> bool:
+ query = f"""
+ SELECT 1 FROM {self._get_table_name(PostgresTokensHandler.TABLE_NAME)}
+ WHERE token = $1
+ LIMIT 1
+ """
+ result = await self.connection_manager.fetchrow_query(query, [token])
+ return bool(result)
+
+ async def clean_expired_blacklisted_tokens(
+ self,
+ max_age_hours: int = 7 * 24,
+ current_time: Optional[datetime] = None,
+ ):
+ if current_time is None:
+ current_time = datetime.utcnow()
+ expiry_time = current_time - timedelta(hours=max_age_hours)
+
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresTokensHandler.TABLE_NAME)}
+ WHERE blacklisted_at < $1
+ """
+ await self.connection_manager.execute_query(query, [expiry_time])
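+
+    # Typical lifecycle sketch (illustrative; `handler` is assumed to be an
+    # initialized PostgresTokensHandler):
+    #
+    #     await handler.blacklist_token(token)
+    #     if await handler.is_token_blacklisted(token):
+    #         ...  # reject the request
+    #     # Periodically drop entries older than the 7-day default window
+    #     await handler.clean_expired_blacklisted_tokens()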
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/users.py b/.venv/lib/python3.12/site-packages/core/providers/database/users.py
new file mode 100644
index 00000000..208eeaa4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/users.py
@@ -0,0 +1,1325 @@
+import csv
+import json
+import tempfile
+from datetime import datetime
+from typing import IO, Optional
+from uuid import UUID
+
+from fastapi import HTTPException
+
+from core.base import CryptoProvider, Handler
+from core.base.abstractions import R2RException
+from core.utils import generate_user_id
+from shared.abstractions import User
+
+from .base import PostgresConnectionManager, QueryBuilder
+from .collections import PostgresCollectionsHandler
+
+
+def _merge_metadata(
+ existing_metadata: dict[str, str], new_metadata: dict[str, Optional[str]]
+) -> dict[str, str]:
+ """
+ Merges the new metadata with the existing metadata in the Stripe-style approach:
+ - new_metadata[key] = <string> => update or add that key
+ - new_metadata[key] = "" => remove that key
+ - if new_metadata is empty => remove all keys
+ """
+ # If new_metadata is an empty dict, it signals removal of all keys.
+ if new_metadata == {}:
+ return {}
+
+ # Copy so we don't mutate the original
+ final_metadata = dict(existing_metadata)
+
+ for key, value in new_metadata.items():
+ # If the user sets the key to an empty string, it means "delete" that key
+ if value == "":
+ if key in final_metadata:
+ del final_metadata[key]
+ # If not None and not empty, set or override
+ elif value is not None:
+ final_metadata[key] = value
+ else:
+            # A value of None is treated like an empty string: remove the key.
+ if key in final_metadata:
+ del final_metadata[key]
+
+ return final_metadata
+
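+# Behavior sketch for _merge_metadata (illustrative values):
+#
+#     _merge_metadata({"tier": "pro", "note": "x"}, {"note": "", "team": "a"})
+#     # -> {"tier": "pro", "team": "a"}   ("" removes a key, other values upsert)
+#     _merge_metadata({"tier": "pro"}, {})
+#     # -> {}                             (an empty patch clears all metadata)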
+
+class PostgresUserHandler(Handler):
+ TABLE_NAME = "users"
+ API_KEYS_TABLE_NAME = "users_api_keys"
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ crypto_provider: CryptoProvider,
+ ):
+ super().__init__(project_name, connection_manager)
+ self.crypto_provider = crypto_provider
+
+ async def create_tables(self):
+ user_table_query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ email TEXT UNIQUE NOT NULL,
+ hashed_password TEXT NOT NULL,
+ is_superuser BOOLEAN DEFAULT FALSE,
+ is_active BOOLEAN DEFAULT TRUE,
+ is_verified BOOLEAN DEFAULT FALSE,
+ verification_code TEXT,
+ verification_code_expiry TIMESTAMPTZ,
+ name TEXT,
+ bio TEXT,
+ profile_picture TEXT,
+ reset_token TEXT,
+ reset_token_expiry TIMESTAMPTZ,
+ collection_ids UUID[] NULL,
+ limits_overrides JSONB,
+ metadata JSONB,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ account_type TEXT NOT NULL DEFAULT 'password',
+ google_id TEXT,
+ github_id TEXT
+ );
+ """
+
+ # API keys table with updated_at instead of last_used_at
+ api_keys_table_query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (
+ id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+ user_id UUID NOT NULL REFERENCES {self._get_table_name(PostgresUserHandler.TABLE_NAME)}(id) ON DELETE CASCADE,
+ public_key TEXT UNIQUE NOT NULL,
+ hashed_key TEXT NOT NULL,
+ name TEXT,
+ description TEXT,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW()
+ );
+
+ CREATE INDEX IF NOT EXISTS idx_api_keys_user_id
+ ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(user_id);
+
+ CREATE INDEX IF NOT EXISTS idx_api_keys_public_key
+ ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(public_key);
+ """
+
+ await self.connection_manager.execute_query(user_table_query)
+ await self.connection_manager.execute_query(api_keys_table_query)
+
+        # Add columns that may be missing from older installations
+        # (Postgres >= 9.6 supports "ADD COLUMN IF NOT EXISTS")
+ check_columns_query = f"""
+ ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
+ ADD COLUMN IF NOT EXISTS metadata JSONB;
+
+ ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
+ ADD COLUMN IF NOT EXISTS limits_overrides JSONB;
+
+ ALTER TABLE {self._get_table_name(self.API_KEYS_TABLE_NAME)}
+ ADD COLUMN IF NOT EXISTS description TEXT;
+ """
+ await self.connection_manager.execute_query(check_columns_query)
+
+        # Add OAuth-related columns if missing and index them for quick lookups
+ check_columns_query = f"""
+ ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
+ ADD COLUMN IF NOT EXISTS account_type TEXT NOT NULL DEFAULT 'password',
+ ADD COLUMN IF NOT EXISTS google_id TEXT,
+ ADD COLUMN IF NOT EXISTS github_id TEXT;
+
+ CREATE INDEX IF NOT EXISTS idx_users_google_id
+ ON {self._get_table_name(self.TABLE_NAME)}(google_id);
+ CREATE INDEX IF NOT EXISTS idx_users_github_id
+ ON {self._get_table_name(self.TABLE_NAME)}(github_id);
+ """
+ await self.connection_manager.execute_query(check_columns_query)
+
+ async def get_user_by_id(self, id: UUID) -> User:
+ query, _ = (
+ QueryBuilder(self._get_table_name("users"))
+ .select(
+ [
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "created_at",
+ "updated_at",
+ "name",
+ "profile_picture",
+ "bio",
+ "collection_ids",
+ "limits_overrides",
+ "metadata",
+ "account_type",
+ "hashed_password",
+ "google_id",
+ "github_id",
+ ]
+ )
+ .where("id = $1")
+ .build()
+ )
+ result = await self.connection_manager.fetchrow_query(query, [id])
+
+ if not result:
+ raise R2RException(status_code=404, message="User not found")
+
+ return User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ name=result["name"],
+ profile_picture=result["profile_picture"],
+ bio=result["bio"],
+ collection_ids=result["collection_ids"],
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+ metadata=json.loads(result["metadata"] or "{}"),
+ hashed_password=result["hashed_password"],
+ account_type=result["account_type"],
+ google_id=result["google_id"],
+ github_id=result["github_id"],
+ )
+
+ async def get_user_by_email(self, email: str) -> User:
+ query, params = (
+ QueryBuilder(self._get_table_name("users"))
+ .select(
+ [
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "created_at",
+ "updated_at",
+ "name",
+ "profile_picture",
+ "bio",
+ "collection_ids",
+ "metadata",
+ "limits_overrides",
+ "account_type",
+ "hashed_password",
+ "google_id",
+ "github_id",
+ ]
+ )
+ .where("email = $1")
+ .build()
+ )
+ result = await self.connection_manager.fetchrow_query(query, [email])
+ if not result:
+ raise R2RException(status_code=404, message="User not found")
+
+ return User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ name=result["name"],
+ profile_picture=result["profile_picture"],
+ bio=result["bio"],
+ collection_ids=result["collection_ids"],
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+ metadata=json.loads(result["metadata"] or "{}"),
+ account_type=result["account_type"],
+ hashed_password=result["hashed_password"],
+ google_id=result["google_id"],
+ github_id=result["github_id"],
+ )
+
+ async def create_user(
+ self,
+ email: str,
+ password: Optional[str] = None,
+ account_type: Optional[str] = "password",
+ google_id: Optional[str] = None,
+ github_id: Optional[str] = None,
+ is_superuser: bool = False,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ ) -> User:
+ """Create a new user."""
+ # 1) Check if a user with this email already exists
+ try:
+ existing = await self.get_user_by_email(email)
+ if existing:
+ raise R2RException(
+ status_code=400,
+ message="User with this email already exists",
+ )
+ except R2RException as e:
+ if e.status_code != 404:
+ raise e
+ # 2) If google_id is provided, ensure no user already has it
+ if google_id:
+ existing_google_user = await self.get_user_by_google_id(google_id)
+ if existing_google_user:
+ raise R2RException(
+ status_code=400,
+ message="User with this Google account already exists",
+ )
+
+ # 3) If github_id is provided, ensure no user already has it
+ if github_id:
+ existing_github_user = await self.get_user_by_github_id(github_id)
+ if existing_github_user:
+ raise R2RException(
+ status_code=400,
+ message="User with this GitHub account already exists",
+ )
+
+ hashed_password = None
+ if account_type == "password":
+ if password is None:
+ raise R2RException(
+ status_code=400,
+ message="Password is required for a 'password' account_type",
+ )
+ hashed_password = self.crypto_provider.get_password_hash(password) # type: ignore
+
+ query, params = (
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
+ .insert(
+ {
+ "email": email,
+ "id": generate_user_id(email),
+ "is_superuser": is_superuser,
+ "collection_ids": [],
+ "limits_overrides": None,
+ "metadata": None,
+ "account_type": account_type,
+ "hashed_password": hashed_password
+ or "", # Ensure hashed_password is not None
+ # !!WARNING - Upstream checks are required to treat oauth differently from password!!
+ "google_id": google_id,
+ "github_id": github_id,
+ "is_verified": account_type != "password",
+ "name": name,
+ "bio": bio,
+ "profile_picture": profile_picture,
+ }
+ )
+ .returning(
+ [
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "created_at",
+ "updated_at",
+ "collection_ids",
+ "limits_overrides",
+ "metadata",
+ "name",
+ "bio",
+ "profile_picture",
+ ]
+ )
+ .build()
+ )
+
+ result = await self.connection_manager.fetchrow_query(query, params)
+ if not result:
+ raise R2RException(
+ status_code=500,
+ message="Failed to create user",
+ )
+
+ return User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ collection_ids=result["collection_ids"] or [],
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+ metadata=json.loads(result["metadata"] or "{}"),
+ name=result["name"],
+ bio=result["bio"],
+ profile_picture=result["profile_picture"],
+ account_type=account_type or "password",
+ hashed_password=hashed_password,
+ google_id=google_id,
+ github_id=github_id,
+ )
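+
+    # Account-type sketch for create_user (illustrative; `handler` is assumed
+    # to be an initialized PostgresUserHandler):
+    #
+    #     await handler.create_user(email="a@example.com", password="s3cret!")
+    #     await handler.create_user(
+    #         email="b@example.com", account_type="oauth", google_id="sub-123"
+    #     )  # non-password accounts are created pre-verified and need no password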
+
+ async def update_user(
+ self,
+ user: User,
+ merge_limits: bool = False,
+ new_metadata: dict[str, Optional[str]] | None = None,
+ ) -> User:
+ """Update user information including limits_overrides.
+
+ Args:
+ user: User object containing updated information
+            merge_limits: If True, merge existing limits_overrides with the new
+                ones; if False, overwrite existing limits_overrides.
+            new_metadata: Optional Stripe-style metadata patch; an empty-string
+                value removes that key and an empty dict clears all metadata.
+
+ Returns:
+ Updated User object
+ """
+
+ # Get current user if we need to merge limits or get hashed password
+ current_user = None
+ try:
+ current_user = await self.get_user_by_id(user.id)
+ except R2RException:
+ raise R2RException(
+ status_code=404, message="User not found"
+ ) from None
+
+        # If the email is changing, make sure it is not already taken
+ if user.email and (user.email != current_user.email):
+ existing_email_user = await self.get_user_by_email(user.email)
+ if existing_email_user and existing_email_user.id != user.id:
+ raise R2RException(
+ status_code=400,
+ message="That email account is already associated with another user.",
+ )
+
+ # If the new user.google_id != current_user.google_id, check for duplicates
+ if user.google_id and (user.google_id != current_user.google_id):
+ existing_google_user = await self.get_user_by_google_id(
+ user.google_id
+ )
+ if existing_google_user and existing_google_user.id != user.id:
+ raise R2RException(
+ status_code=400,
+ message="That Google account is already associated with another user.",
+ )
+
+ # Similarly for GitHub:
+ if user.github_id and (user.github_id != current_user.github_id):
+ existing_github_user = await self.get_user_by_github_id(
+ user.github_id
+ )
+ if existing_github_user and existing_github_user.id != user.id:
+ raise R2RException(
+ status_code=400,
+ message="That GitHub account is already associated with another user.",
+ )
+
+ # Merge or replace metadata if provided
+ final_metadata = current_user.metadata or {}
+ if new_metadata is not None:
+ final_metadata = _merge_metadata(final_metadata, new_metadata)
+
+ # Merge or replace limits_overrides
+ final_limits = user.limits_overrides
+ if (
+ merge_limits
+ and current_user.limits_overrides
+ and user.limits_overrides
+ ):
+ final_limits = {
+ **current_user.limits_overrides,
+ **user.limits_overrides,
+ }
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET email = $1,
+ is_superuser = $2,
+ is_active = $3,
+ is_verified = $4,
+ updated_at = NOW(),
+ name = $5,
+ profile_picture = $6,
+ bio = $7,
+ collection_ids = $8,
+ limits_overrides = $9::jsonb,
+ metadata = $10::jsonb
+ WHERE id = $11
+ RETURNING id, email, is_superuser, is_active, is_verified,
+ created_at, updated_at, name, profile_picture, bio,
+ collection_ids, limits_overrides, metadata, hashed_password,
+ account_type, google_id, github_id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query,
+ [
+ user.email,
+ user.is_superuser,
+ user.is_active,
+ user.is_verified,
+ user.name,
+ user.profile_picture,
+ user.bio,
+ user.collection_ids or [],
+ json.dumps(final_limits),
+ json.dumps(final_metadata),
+ user.id,
+ ],
+ )
+
+ if not result:
+ raise HTTPException(
+ status_code=500,
+ detail="Failed to update user",
+ )
+
+ return User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ name=result["name"],
+ profile_picture=result["profile_picture"],
+ bio=result["bio"],
+ collection_ids=result["collection_ids"]
+ or [], # Ensure null becomes empty array
+ limits_overrides=json.loads(
+ result["limits_overrides"] or "{}"
+ ), # Can be null
+ metadata=json.loads(result["metadata"] or "{}"),
+ account_type=result["account_type"],
+ hashed_password=result[
+ "hashed_password"
+ ], # Include hashed_password
+ google_id=result["google_id"],
+ github_id=result["github_id"],
+ )
+
+ async def delete_user_relational(self, id: UUID) -> None:
+ """Delete a user and update related records."""
+ # Get the collections the user belongs to
+        # Fetch the user's collection memberships; this also serves as an existence check
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
+ .select(["collection_ids"])
+ .where("id = $1")
+ .build()
+ )
+
+ collection_result = await self.connection_manager.fetchrow_query(
+ collection_query, [id]
+ )
+
+ if not collection_result:
+ raise R2RException(status_code=404, message="User not found")
+
+        # Detach documents owned by this user by clearing their owner_id
+        doc_update_query = f"""
+            UPDATE {self._get_table_name("documents")}
+            SET owner_id = NULL
+            WHERE owner_id = $1
+        """
+
+        await self.connection_manager.execute_query(doc_update_query, [id])
+
+ # Delete user query
+ delete_query, del_params = (
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
+ .delete()
+ .where("id = $1")
+ .returning(["id"])
+ .build()
+ )
+
+ result = await self.connection_manager.fetchrow_query(
+ delete_query, [id]
+ )
+
+ if not result:
+ raise R2RException(status_code=404, message="User not found")
+
+ async def update_user_password(self, id: UUID, new_hashed_password: str):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET hashed_password = $1, updated_at = NOW()
+ WHERE id = $2
+ """
+ await self.connection_manager.execute_query(
+ query, [new_hashed_password, id]
+ )
+
+ async def get_all_users(self) -> list[User]:
+ """Get all users with minimal information."""
+ query, params = (
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
+ .select(
+ [
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "created_at",
+ "updated_at",
+ "collection_ids",
+ "hashed_password",
+ "limits_overrides",
+ "metadata",
+ "name",
+ "bio",
+ "profile_picture",
+ "account_type",
+ "google_id",
+ "github_id",
+ ]
+ )
+ .build()
+ )
+
+ results = await self.connection_manager.fetch_query(query, params)
+ return [
+ User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ collection_ids=result["collection_ids"] or [],
+ limits_overrides=json.loads(
+ result["limits_overrides"] or "{}"
+ ),
+ metadata=json.loads(result["metadata"] or "{}"),
+ name=result["name"],
+ bio=result["bio"],
+ profile_picture=result["profile_picture"],
+ account_type=result["account_type"],
+ hashed_password=result["hashed_password"],
+ google_id=result["google_id"],
+ github_id=result["github_id"],
+ )
+ for result in results
+ ]
+
+ async def store_verification_code(
+ self, id: UUID, verification_code: str, expiry: datetime
+ ):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET verification_code = $1, verification_code_expiry = $2
+ WHERE id = $3
+ """
+ await self.connection_manager.execute_query(
+ query, [verification_code, expiry, id]
+ )
+
+ async def verify_user(self, verification_code: str) -> None:
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
+ WHERE verification_code = $1 AND verification_code_expiry > NOW()
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [verification_code]
+ )
+
+ if not result:
+ raise R2RException(
+ status_code=400, message="Invalid or expired verification code"
+ )
+
+ async def remove_verification_code(self, verification_code: str):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET verification_code = NULL, verification_code_expiry = NULL
+ WHERE verification_code = $1
+ """
+ await self.connection_manager.execute_query(query, [verification_code])
+
+ async def expire_verification_code(self, id: UUID):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET verification_code_expiry = NOW() - INTERVAL '1 day'
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(query, [id])
+
+ async def store_reset_token(
+ self, id: UUID, reset_token: str, expiry: datetime
+ ):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET reset_token = $1, reset_token_expiry = $2
+ WHERE id = $3
+ """
+ await self.connection_manager.execute_query(
+ query, [reset_token, expiry, id]
+ )
+
+ async def get_user_id_by_reset_token(
+ self, reset_token: str
+ ) -> Optional[UUID]:
+ query = f"""
+ SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ WHERE reset_token = $1 AND reset_token_expiry > NOW()
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [reset_token]
+ )
+ return result["id"] if result else None
+
+ async def remove_reset_token(self, id: UUID):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET reset_token = NULL, reset_token_expiry = NULL
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(query, [id])
+
+ async def remove_user_from_all_collections(self, id: UUID):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET collection_ids = ARRAY[]::UUID[]
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(query, [id])
+
+ async def add_user_to_collection(
+ self, id: UUID, collection_id: UUID
+ ) -> bool:
+ # Check if the user exists
+ if not await self.get_user_by_id(id):
+ raise R2RException(status_code=404, message="User not found")
+
+ # Check if the collection exists
+ if not await self._collection_exists(collection_id):
+ raise R2RException(status_code=404, message="Collection not found")
+
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET collection_ids = array_append(collection_ids, $1)
+ WHERE id = $2 AND NOT ($1 = ANY(collection_ids))
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [collection_id, id]
+ )
+ if not result:
+ raise R2RException(
+ status_code=400, message="User already in collection"
+ )
+
+ update_collection_query = f"""
+ UPDATE {self._get_table_name("collections")}
+ SET user_count = user_count + 1
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(
+ query=update_collection_query,
+ params=[collection_id],
+ )
+
+ return True
+
+ async def remove_user_from_collection(
+ self, id: UUID, collection_id: UUID
+ ) -> bool:
+ if not await self.get_user_by_id(id):
+ raise R2RException(status_code=404, message="User not found")
+
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET collection_ids = array_remove(collection_ids, $1)
+ WHERE id = $2 AND $1 = ANY(collection_ids)
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [collection_id, id]
+ )
+ if not result:
+ raise R2RException(
+ status_code=400,
+ message="User is not a member of the specified collection",
+ )
+ return True
+
+ async def get_users_in_collection(
+ self, collection_id: UUID, offset: int, limit: int
+ ) -> dict[str, list[User] | int]:
+ """Get all users in a specific collection with pagination."""
+ if not await self._collection_exists(collection_id):
+ raise R2RException(status_code=404, message="Collection not found")
+
+ query, params = (
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
+ .select(
+ [
+ "id",
+ "email",
+ "is_active",
+ "is_superuser",
+ "created_at",
+ "updated_at",
+ "is_verified",
+ "collection_ids",
+ "name",
+ "bio",
+ "profile_picture",
+ "limits_overrides",
+ "metadata",
+ "account_type",
+ "hashed_password",
+ "google_id",
+ "github_id",
+ "COUNT(*) OVER() AS total_entries",
+ ]
+ )
+ .where("$1 = ANY(collection_ids)")
+ .order_by("name")
+ .offset("$2")
+ .limit("$3" if limit != -1 else None)
+ .build()
+ )
+
+ conditions = [collection_id, offset]
+ if limit != -1:
+ conditions.append(limit)
+
+ results = await self.connection_manager.fetch_query(query, conditions)
+
+ users_list = [
+ User(
+ id=row["id"],
+ email=row["email"],
+ is_active=row["is_active"],
+ is_superuser=row["is_superuser"],
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ is_verified=row["is_verified"],
+ collection_ids=row["collection_ids"] or [],
+ name=row["name"],
+ bio=row["bio"],
+ profile_picture=row["profile_picture"],
+ limits_overrides=json.loads(row["limits_overrides"] or "{}"),
+ metadata=json.loads(row["metadata"] or "{}"),
+ account_type=row["account_type"],
+ hashed_password=row["hashed_password"],
+ google_id=row["google_id"],
+ github_id=row["github_id"],
+ )
+ for row in results
+ ]
+
+ total_entries = results[0]["total_entries"] if results else 0
+ return {"results": users_list, "total_entries": total_entries}
+
+ async def mark_user_as_superuser(self, id: UUID):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET is_superuser = TRUE, is_verified = TRUE,
+ verification_code = NULL, verification_code_expiry = NULL
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(query, [id])
+
+ async def get_user_id_by_verification_code(
+ self, verification_code: str
+ ) -> UUID:
+ query = f"""
+ SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ WHERE verification_code = $1 AND verification_code_expiry > NOW()
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [verification_code]
+ )
+
+ if not result:
+ raise R2RException(
+ status_code=400, message="Invalid or expired verification code"
+ )
+
+ return result["id"]
+
+ async def mark_user_as_verified(self, id: UUID):
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+ SET is_verified = TRUE,
+ verification_code = NULL,
+ verification_code_expiry = NULL
+ WHERE id = $1
+ """
+ await self.connection_manager.execute_query(query, [id])
+
+ async def get_users_overview(
+ self,
+ offset: int,
+ limit: int,
+ user_ids: Optional[list[UUID]] = None,
+ ) -> dict[str, list[User] | int]:
+ """Return users with document usage and total entries."""
+        # The optional user_ids filter binds to $3 when a LIMIT parameter is
+        # appended (limit != -1) and to $2 otherwise.
+        user_filter = ""
+        if user_ids:
+            user_filter = (
+                " WHERE u.id = ANY($3::uuid[])"
+                if limit != -1
+                else " WHERE u.id = ANY($2::uuid[])"
+            )
+
+        query = f"""
+ WITH user_document_ids AS (
+ SELECT
+ u.id as user_id,
+ ARRAY_AGG(d.id) FILTER (WHERE d.id IS NOT NULL) AS doc_ids
+ FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
+ LEFT JOIN {self._get_table_name("documents")} d ON u.id = d.owner_id
+ GROUP BY u.id
+ ),
+ user_docs AS (
+ SELECT
+ u.id,
+ u.email,
+ u.is_superuser,
+ u.is_active,
+ u.is_verified,
+ u.name,
+ u.bio,
+ u.profile_picture,
+ u.collection_ids,
+ u.created_at,
+ u.updated_at,
+ COUNT(d.id) AS num_files,
+ COALESCE(SUM(d.size_in_bytes), 0) AS total_size_in_bytes,
+ ud.doc_ids as document_ids
+ FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
+ LEFT JOIN {self._get_table_name("documents")} d ON u.id = d.owner_id
+ LEFT JOIN user_document_ids ud ON u.id = ud.user_id
+                {user_filter}
+ GROUP BY u.id, u.email, u.is_superuser, u.is_active, u.is_verified,
+ u.created_at, u.updated_at, u.collection_ids, ud.doc_ids
+ )
+ SELECT
+ user_docs.*,
+ COUNT(*) OVER() AS total_entries
+ FROM user_docs
+ ORDER BY email
+ OFFSET $1
+ """
+
+ params: list = [offset]
+
+ if limit != -1:
+ query += " LIMIT $2"
+ params.append(limit)
+
+ if user_ids:
+ params.append(user_ids)
+
+ results = await self.connection_manager.fetch_query(query, params)
+ if not results:
+ raise R2RException(status_code=404, message="No users found")
+
+ users_list = []
+ for row in results:
+ users_list.append(
+ User(
+ id=row["id"],
+ email=row["email"],
+ is_superuser=row["is_superuser"],
+ is_active=row["is_active"],
+ is_verified=row["is_verified"],
+ name=row["name"],
+ bio=row["bio"],
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ profile_picture=row["profile_picture"],
+ collection_ids=row["collection_ids"] or [],
+ num_files=row["num_files"],
+ total_size_in_bytes=row["total_size_in_bytes"],
+ document_ids=(
+ list(row["document_ids"])
+ if row["document_ids"]
+ else []
+ ),
+ )
+ )
+
+ total_entries = results[0]["total_entries"]
+ return {"results": users_list, "total_entries": total_entries}
+
+ async def _collection_exists(self, collection_id: UUID) -> bool:
+ """Check if a collection exists."""
+ query = f"""
+ SELECT 1 FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+ WHERE id = $1
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [collection_id]
+ )
+ return result is not None
+
+ async def get_user_validation_data(
+ self,
+ user_id: UUID,
+ ) -> dict:
+ """Get verification data for a specific user.
+
+ This method should be called after superuser authorization has been
+ verified.
+ """
+ query = f"""
+ SELECT
+ verification_code,
+ verification_code_expiry,
+ reset_token,
+ reset_token_expiry
+ FROM {self._get_table_name("users")}
+ WHERE id = $1
+ """
+ result = await self.connection_manager.fetchrow_query(query, [user_id])
+
+ if not result:
+ raise R2RException(status_code=404, message="User not found")
+
+ return {
+ "verification_data": {
+ "verification_code": result["verification_code"],
+ "verification_code_expiry": (
+ result["verification_code_expiry"].isoformat()
+ if result["verification_code_expiry"]
+ else None
+ ),
+ "reset_token": result["reset_token"],
+ "reset_token_expiry": (
+ result["reset_token_expiry"].isoformat()
+ if result["reset_token_expiry"]
+ else None
+ ),
+ }
+ }
+
+ # API Key methods
+ async def store_user_api_key(
+ self,
+ user_id: UUID,
+ key_id: str,
+ hashed_key: str,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> UUID:
+ """Store a new API key for a user with optional name and
+ description."""
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+ (user_id, public_key, hashed_key, name, description)
+ VALUES ($1, $2, $3, $4, $5)
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [user_id, key_id, hashed_key, name or "", description or ""]
+ )
+ if not result:
+ raise R2RException(
+ status_code=500, message="Failed to store API key"
+ )
+ return result["id"]
+
+ async def get_api_key_record(self, key_id: str) -> Optional[dict]:
+ """Get API key record by 'public_key' and update 'updated_at' to now.
+
+ Returns { "user_id", "hashed_key" } or None if not found.
+ """
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+ SET updated_at = NOW()
+ WHERE public_key = $1
+ RETURNING user_id, hashed_key
+ """
+ result = await self.connection_manager.fetchrow_query(query, [key_id])
+ if not result:
+ return None
+ return {
+ "user_id": result["user_id"],
+ "hashed_key": result["hashed_key"],
+ }
+
+ async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
+ """Get all API keys for a user."""
+ query = f"""
+ SELECT id, public_key, name, description, created_at, updated_at
+ FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+ WHERE user_id = $1
+ ORDER BY created_at DESC
+ """
+ results = await self.connection_manager.fetch_query(query, [user_id])
+ return [
+ {
+ "key_id": str(row["id"]),
+ "public_key": row["public_key"],
+ "name": row["name"] or "",
+ "description": row["description"] or "",
+ "updated_at": row["updated_at"],
+ }
+ for row in results
+ ]
+
+ async def delete_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+ """Delete a specific API key."""
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+ WHERE id = $1 AND user_id = $2
+ RETURNING id, public_key, name, description
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [key_id, user_id]
+ )
+ if result is None:
+ raise R2RException(status_code=404, message="API key not found")
+
+ return True
+
+ async def update_api_key_name(
+ self, user_id: UUID, key_id: UUID, name: str
+ ) -> bool:
+ """Update the name of an existing API key."""
+ query = f"""
+ UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+ SET name = $1, updated_at = NOW()
+ WHERE id = $2 AND user_id = $3
+ RETURNING id
+ """
+ result = await self.connection_manager.fetchrow_query(
+ query, [name, key_id, user_id]
+ )
+ if result is None:
+ raise R2RException(status_code=404, message="API key not found")
+ return True
+
+ async def export_to_csv(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ """Creates a CSV file from the PostgreSQL data and returns the path to
+ the temp file."""
+ valid_columns = {
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "name",
+ "bio",
+ "collection_ids",
+ "created_at",
+ "updated_at",
+ }
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ email,
+ is_superuser,
+ is_active,
+ is_verified,
+ name,
+ bio,
+ collection_ids::text,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+ FROM {self._get_table_name(self.TABLE_NAME)}
+ """
+
+ params = []
+ if filters:
+ conditions = []
+ param_index = 1
+
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "email": row[1],
+ "is_superuser": row[2],
+ "is_active": row[3],
+ "is_verified": row[4],
+ "name": row[5],
+ "bio": row[6],
+ "collection_ids": row[7],
+ "created_at": row[8],
+ "updated_at": row[9],
+ }
+ writer.writerow([row_dict[col] for col in columns])
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e
+
+ async def get_user_by_google_id(self, google_id: str) -> Optional[User]:
+ """Return a User if the google_id is found; otherwise None."""
+ query, params = (
+ QueryBuilder(self._get_table_name("users"))
+ .select(
+ [
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "created_at",
+ "updated_at",
+ "name",
+ "profile_picture",
+ "bio",
+ "collection_ids",
+ "limits_overrides",
+ "metadata",
+ "account_type",
+ "hashed_password",
+ "google_id",
+ "github_id",
+ ]
+ )
+ .where("google_id = $1")
+ .build()
+ )
+ result = await self.connection_manager.fetchrow_query(
+ query, [google_id]
+ )
+ if not result:
+ return None
+
+ return User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ name=result["name"],
+ profile_picture=result["profile_picture"],
+ bio=result["bio"],
+ collection_ids=result["collection_ids"] or [],
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+ metadata=json.loads(result["metadata"] or "{}"),
+ account_type=result["account_type"],
+ hashed_password=result["hashed_password"],
+ google_id=result["google_id"],
+ github_id=result["github_id"],
+ )
+
+ async def get_user_by_github_id(self, github_id: str) -> Optional[User]:
+ """Return a User if the github_id is found; otherwise None."""
+ query, params = (
+ QueryBuilder(self._get_table_name("users"))
+ .select(
+ [
+ "id",
+ "email",
+ "is_superuser",
+ "is_active",
+ "is_verified",
+ "created_at",
+ "updated_at",
+ "name",
+ "profile_picture",
+ "bio",
+ "collection_ids",
+ "limits_overrides",
+ "metadata",
+ "account_type",
+ "hashed_password",
+ "google_id",
+ "github_id",
+ ]
+ )
+ .where("github_id = $1")
+ .build()
+ )
+ result = await self.connection_manager.fetchrow_query(
+ query, [github_id]
+ )
+ if not result:
+ return None
+
+ return User(
+ id=result["id"],
+ email=result["email"],
+ is_superuser=result["is_superuser"],
+ is_active=result["is_active"],
+ is_verified=result["is_verified"],
+ created_at=result["created_at"],
+ updated_at=result["updated_at"],
+ name=result["name"],
+ profile_picture=result["profile_picture"],
+ bio=result["bio"],
+ collection_ids=result["collection_ids"] or [],
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+ metadata=json.loads(result["metadata"] or "{}"),
+ account_type=result["account_type"],
+ hashed_password=result["hashed_password"],
+ google_id=result["google_id"],
+ github_id=result["github_id"],
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/providers/email/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/email/__init__.py
new file mode 100644
index 00000000..38753695
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/email/__init__.py
@@ -0,0 +1,11 @@
+from .console_mock import ConsoleMockEmailProvider
+from .mailersend import MailerSendEmailProvider
+from .sendgrid import SendGridEmailProvider
+from .smtp import AsyncSMTPEmailProvider
+
+__all__ = [
+ "ConsoleMockEmailProvider",
+ "AsyncSMTPEmailProvider",
+ "SendGridEmailProvider",
+ "MailerSendEmailProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/email/console_mock.py b/.venv/lib/python3.12/site-packages/core/providers/email/console_mock.py
new file mode 100644
index 00000000..459a978d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/email/console_mock.py
@@ -0,0 +1,67 @@
+import logging
+from typing import Optional
+
+from core.base import EmailProvider
+
+logger = logging.getLogger()
+
+
+class ConsoleMockEmailProvider(EmailProvider):
+ """A simple email provider that logs emails to console, useful for
+ testing."""
+
+ async def send_email(
+ self,
+ to_email: str,
+ subject: str,
+ body: str,
+ html_body: Optional[str] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ logger.info(f"""
+ -------- Email Message --------
+ To: {to_email}
+ Subject: {subject}
+ Body:
+ {body}
+ -----------------------------
+ """)
+
+ async def send_verification_email(
+ self, to_email: str, verification_code: str, *args, **kwargs
+ ) -> None:
+ logger.info(f"""
+ -------- Email Message --------
+ To: {to_email}
+ Subject: Please verify your email address
+ Body:
+ Verification code: {verification_code}
+ -----------------------------
+ """)
+
+ async def send_password_reset_email(
+ self, to_email: str, reset_token: str, *args, **kwargs
+ ) -> None:
+ logger.info(f"""
+ -------- Email Message --------
+ To: {to_email}
+ Subject: Password Reset Request
+ Body:
+ Reset token: {reset_token}
+ -----------------------------
+ """)
+
+ async def send_password_changed_email(
+ self, to_email: str, *args, **kwargs
+ ) -> None:
+ logger.info(f"""
+ -------- Email Message --------
+ To: {to_email}
+ Subject: Your Password Has Been Changed
+ Body:
+ Your password has been successfully changed.
+
+ For security reasons, you will need to log in again on all your devices.
+ -----------------------------
+ """)
diff --git a/.venv/lib/python3.12/site-packages/core/providers/email/mailersend.py b/.venv/lib/python3.12/site-packages/core/providers/email/mailersend.py
new file mode 100644
index 00000000..10fccd56
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/email/mailersend.py
@@ -0,0 +1,281 @@
+import logging
+import os
+from typing import Optional
+
+from mailersend import emails
+
+from core.base import EmailConfig, EmailProvider
+
+logger = logging.getLogger(__name__)
+
+
+class MailerSendEmailProvider(EmailProvider):
+ """Email provider implementation using MailerSend API."""
+
+ def __init__(self, config: EmailConfig):
+ super().__init__(config)
+ self.api_key = config.mailersend_api_key or os.getenv(
+ "MAILERSEND_API_KEY"
+ )
+ if not self.api_key or not isinstance(self.api_key, str):
+ raise ValueError("A valid MailerSend API key is required.")
+
+ self.from_email = config.from_email or os.getenv("R2R_FROM_EMAIL")
+ if not self.from_email or not isinstance(self.from_email, str):
+ raise ValueError("A valid from email is required.")
+
+ self.frontend_url = config.frontend_url or os.getenv(
+ "R2R_FRONTEND_URL"
+ )
+ if not self.frontend_url or not isinstance(self.frontend_url, str):
+ raise ValueError("A valid frontend URL is required.")
+
+ self.verify_email_template_id = (
+ config.verify_email_template_id
+ or os.getenv("MAILERSEND_VERIFY_EMAIL_TEMPLATE_ID")
+ )
+ self.reset_password_template_id = (
+ config.reset_password_template_id
+ or os.getenv("MAILERSEND_RESET_PASSWORD_TEMPLATE_ID")
+ )
+ self.password_changed_template_id = (
+ config.password_changed_template_id
+ or os.getenv("MAILERSEND_PASSWORD_CHANGED_TEMPLATE_ID")
+ )
+ self.client = emails.NewEmail(self.api_key)
+ self.sender_name = config.sender_name or "R2R"
+
+ # Logo and documentation URLs
+ self.docs_base_url = f"{self.frontend_url}/documentation"
+
+ def _get_base_template_data(self, to_email: str) -> dict:
+ """Get base template data used across all email templates."""
+ return {
+ "user_email": to_email,
+ "docs_url": self.docs_base_url,
+ "quickstart_url": f"{self.docs_base_url}/quickstart",
+ "frontend_url": self.frontend_url,
+ }
+
+ async def send_email(
+ self,
+ to_email: str,
+ subject: Optional[str] = None,
+ body: Optional[str] = None,
+ html_body: Optional[str] = None,
+ template_id: Optional[str] = None,
+ dynamic_template_data: Optional[dict] = None,
+ ) -> None:
+ try:
+ logger.info("Preparing MailerSend message...")
+
+ mail_body = {
+ "from": {
+ "email": self.from_email,
+ "name": self.sender_name,
+ },
+ "to": [{"email": to_email}],
+ }
+
+ if template_id:
+ # Transform the template data to MailerSend's expected format
+ if dynamic_template_data:
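+                    # Reshape the template data into MailerSend's per-recipient
+                    # substitution format: each key becomes {"var": key, "value": value}.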
+ formatted_substitutions = {}
+ for key, value in dynamic_template_data.items():
+ formatted_substitutions[key] = {
+ "var": key,
+ "value": value,
+ }
+ mail_body["variables"] = [
+ {
+ "email": to_email,
+ "substitutions": formatted_substitutions,
+ }
+ ]
+
+ mail_body["template_id"] = template_id
+ else:
+ mail_body.update(
+ {
+ "subject": subject or "",
+ "text": body or "",
+ "html": html_body or "",
+ }
+ )
+
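+            # The MailerSend SDK send call is synchronous, so run it in a worker
+            # thread to avoid blocking the event loop.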
+ import asyncio
+
+ response = await asyncio.to_thread(self.client.send, mail_body)
+
+ # Handle different response formats
+ if isinstance(response, str):
+ # Clean the string response by stripping whitespace
+ response_clean = response.strip()
+ if response_clean in ["202", "200"]:
+ logger.info(
+ f"Email accepted for delivery with status code {response_clean}"
+ )
+ return
+ elif isinstance(response, int) and response in [200, 202]:
+ logger.info(
+ f"Email accepted for delivery with status code {response}"
+ )
+ return
+ elif isinstance(response, dict) and response.get(
+ "status_code"
+ ) in [200, 202]:
+ logger.info(
+ f"Email accepted for delivery with status code {response.get('status_code')}"
+ )
+ return
+
+ # If we get here, it's an error
+ error_msg = f"MailerSend error: {response}"
+ logger.error(error_msg)
+
+ except Exception as e:
+ error_msg = f"Failed to send email to {to_email}: {str(e)}"
+ logger.error(error_msg)
+
+ async def send_verification_email(
+ self,
+ to_email: str,
+ verification_code: str,
+ dynamic_template_data: Optional[dict] = None,
+ ) -> None:
+ try:
+ if self.verify_email_template_id:
+ verification_data = {
+ "verification_link": f"{self.frontend_url}/verify-email?verification_code={verification_code}&email={to_email}",
+ "verification_code": verification_code, # Include code separately for flexible template usage
+ }
+
+ # Merge with any additional template data
+ template_data = {
+ **(dynamic_template_data or {}),
+ **verification_data,
+ }
+
+ await self.send_email(
+ to_email=to_email,
+ template_id=self.verify_email_template_id,
+ dynamic_template_data=template_data,
+ )
+ else:
+ # Fallback to basic email if no template ID is configured
+ subject = "Verify Your R2R Account"
+ html_body = f"""
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Welcome to R2R!</h1>
+ <p>Please verify your email address to get started with R2R - the most advanced AI retrieval system.</p>
+ <p>Click the link below to verify your email:</p>
+ <p><a href="{self.frontend_url}/verify-email?verification_code={verification_code}&email={to_email}"
+ style="background-color: #007bff; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
+ Verify Email
+ </a></p>
+ <p>Or enter this verification code: <strong>{verification_code}</strong></p>
+ <p>If you didn't create an account with R2R, please ignore this email.</p>
+ </div>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject=subject,
+ html_body=html_body,
+ body=f"Welcome to R2R! Please verify your email using this code: {verification_code}",
+ )
+ except Exception as e:
+ error_msg = (
+ f"Failed to send verification email to {to_email}: {str(e)}"
+ )
+ logger.error(error_msg)
+
+ async def send_password_reset_email(
+ self,
+ to_email: str,
+ reset_token: str,
+ dynamic_template_data: Optional[dict] = None,
+ ) -> None:
+ try:
+ if self.reset_password_template_id:
+ reset_data = {
+ "reset_link": f"{self.frontend_url}/reset-password?token={reset_token}",
+ "reset_token": reset_token,
+ }
+
+ template_data = {**(dynamic_template_data or {}), **reset_data}
+
+ await self.send_email(
+ to_email=to_email,
+ template_id=self.reset_password_template_id,
+ dynamic_template_data=template_data,
+ )
+ else:
+ subject = "Reset Your R2R Password"
+ html_body = f"""
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Password Reset Request</h1>
+ <p>You've requested to reset your R2R password.</p>
+ <p>Click the link below to reset your password:</p>
+ <p><a href="{self.frontend_url}/reset-password?token={reset_token}"
+ style="background-color: #007bff; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
+ Reset Password
+ </a></p>
+ <p>Or use this reset token: <strong>{reset_token}</strong></p>
+ <p>If you didn't request a password reset, please ignore this email.</p>
+ </div>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject=subject,
+ html_body=html_body,
+ body=f"Reset your R2R password using this token: {reset_token}",
+ )
+ except Exception as e:
+ error_msg = (
+ f"Failed to send password reset email to {to_email}: {str(e)}"
+ )
+ logger.error(error_msg)
+
+ async def send_password_changed_email(
+ self,
+ to_email: str,
+ dynamic_template_data: Optional[dict] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ try:
+ if (
+ hasattr(self, "password_changed_template_id")
+ and self.password_changed_template_id
+ ):
+ await self.send_email(
+ to_email=to_email,
+ template_id=self.password_changed_template_id,
+ dynamic_template_data=dynamic_template_data,
+ )
+ else:
+ subject = "Your Password Has Been Changed"
+ body = """
+ Your password has been successfully changed.
+
+ If you did not make this change, please contact support immediately and secure your account.
+
+ """
+ html_body = """
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Password Changed Successfully</h1>
+ <p>Your password has been successfully changed.</p>
+ </div>
+ """
+ await self.send_email(
+ to_email=to_email,
+ subject=subject,
+ html_body=html_body,
+ body=body,
+ )
+ except Exception as e:
+ error_msg = f"Failed to send password change notification to {to_email}: {str(e)}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
diff --git a/.venv/lib/python3.12/site-packages/core/providers/email/sendgrid.py b/.venv/lib/python3.12/site-packages/core/providers/email/sendgrid.py
new file mode 100644
index 00000000..8b2553f1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/email/sendgrid.py
@@ -0,0 +1,257 @@
+import logging
+import os
+from typing import Optional
+
+from sendgrid import SendGridAPIClient
+from sendgrid.helpers.mail import Content, From, Mail
+
+from core.base import EmailConfig, EmailProvider
+
+logger = logging.getLogger(__name__)
+
+
+class SendGridEmailProvider(EmailProvider):
+ """Email provider implementation using SendGrid API."""
+
+ def __init__(self, config: EmailConfig):
+ super().__init__(config)
+ self.api_key = config.sendgrid_api_key or os.getenv("SENDGRID_API_KEY")
+ if not self.api_key or not isinstance(self.api_key, str):
+ raise ValueError("A valid SendGrid API key is required.")
+
+ self.from_email = config.from_email or os.getenv("R2R_FROM_EMAIL")
+ if not self.from_email or not isinstance(self.from_email, str):
+ raise ValueError("A valid from email is required.")
+
+ self.frontend_url = config.frontend_url or os.getenv(
+ "R2R_FRONTEND_URL"
+ )
+ if not self.frontend_url or not isinstance(self.frontend_url, str):
+ raise ValueError("A valid frontend URL is required.")
+
+ self.verify_email_template_id = (
+ config.verify_email_template_id
+ or os.getenv("SENDGRID_EMAIL_TEMPLATE_ID")
+ )
+ self.reset_password_template_id = (
+ config.reset_password_template_id
+ or os.getenv("SENDGRID_RESET_TEMPLATE_ID")
+ )
+ self.password_changed_template_id = (
+ config.password_changed_template_id
+ or os.getenv("SENDGRID_PASSWORD_CHANGED_TEMPLATE_ID")
+ )
+ self.client = SendGridAPIClient(api_key=self.api_key)
+ self.sender_name = config.sender_name
+
+ # Logo and documentation URLs
+ self.docs_base_url = f"{self.frontend_url}/documentation"
+
+ def _get_base_template_data(self, to_email: str) -> dict:
+ """Get base template data used across all email templates."""
+ return {
+ "user_email": to_email,
+ "docs_url": self.docs_base_url,
+ "quickstart_url": f"{self.docs_base_url}/quickstart",
+ "frontend_url": self.frontend_url,
+ }
+
+ async def send_email(
+ self,
+ to_email: str,
+ subject: Optional[str] = None,
+ body: Optional[str] = None,
+ html_body: Optional[str] = None,
+ template_id: Optional[str] = None,
+ dynamic_template_data: Optional[dict] = None,
+ ) -> None:
+ try:
+ logger.info("Preparing SendGrid message...")
+ message = Mail(
+ from_email=From(self.from_email, self.sender_name),
+ to_emails=to_email,
+ )
+
+ if template_id:
+ logger.info(f"Using dynamic template with ID: {template_id}")
+ message.template_id = template_id
+ base_data = self._get_base_template_data(to_email)
+ message.dynamic_template_data = {
+ **base_data,
+ **(dynamic_template_data or {}),
+ }
+ else:
+ if not subject:
+ raise ValueError(
+ "Subject is required when not using a template"
+ )
+ message.subject = subject
+ message.add_content(Content("text/plain", body or ""))
+ if html_body:
+ message.add_content(Content("text/html", html_body))
+
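+            # SendGridAPIClient.send is synchronous; offload it to a worker
+            # thread so this coroutine does not block the event loop.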
+ import asyncio
+
+ response = await asyncio.to_thread(self.client.send, message)
+
+ if response.status_code >= 400:
+ raise RuntimeError(
+ f"Failed to send email: {response.status_code}"
+ )
+ elif response.status_code == 202:
+ logger.info("Message sent successfully!")
+ else:
+ error_msg = f"Failed to send email. Status code: {response.status_code}, Body: {response.body}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg)
+
+ except Exception as e:
+ error_msg = f"Failed to send email to {to_email}: {str(e)}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
+
+ async def send_verification_email(
+ self,
+ to_email: str,
+ verification_code: str,
+ dynamic_template_data: Optional[dict] = None,
+ ) -> None:
+ try:
+ if self.verify_email_template_id:
+ verification_data = {
+ "verification_link": f"{self.frontend_url}/verify-email?verification_code={verification_code}&email={to_email}",
+ "verification_code": verification_code, # Include code separately for flexible template usage
+ }
+
+ # Merge with any additional template data
+ template_data = {
+ **(dynamic_template_data or {}),
+ **verification_data,
+ }
+
+ await self.send_email(
+ to_email=to_email,
+ template_id=self.verify_email_template_id,
+ dynamic_template_data=template_data,
+ )
+ else:
+ # Fallback to basic email if no template ID is configured
+ subject = "Verify Your R2R Account"
+ html_body = f"""
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Welcome to R2R!</h1>
+ <p>Please verify your email address to get started with R2R - the most advanced AI retrieval system.</p>
+ <p>Click the link below to verify your email:</p>
+ <p><a href="{self.frontend_url}/verify-email?token={verification_code}&email={to_email}"
+ style="background-color: #007bff; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
+ Verify Email
+ </a></p>
+ <p>Or enter this verification code: <strong>{verification_code}</strong></p>
+ <p>If you didn't create an account with R2R, please ignore this email.</p>
+ </div>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject=subject,
+ html_body=html_body,
+ body=f"Welcome to R2R! Please verify your email using this code: {verification_code}",
+ )
+ except Exception as e:
+ error_msg = (
+ f"Failed to send verification email to {to_email}: {str(e)}"
+ )
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
+
+ async def send_password_reset_email(
+ self,
+ to_email: str,
+ reset_token: str,
+ dynamic_template_data: Optional[dict] = None,
+ ) -> None:
+ try:
+ if self.reset_password_template_id:
+ reset_data = {
+ "reset_link": f"{self.frontend_url}/reset-password?token={reset_token}",
+ "reset_token": reset_token,
+ }
+
+ template_data = {**(dynamic_template_data or {}), **reset_data}
+
+ await self.send_email(
+ to_email=to_email,
+ template_id=self.reset_password_template_id,
+ dynamic_template_data=template_data,
+ )
+ else:
+ subject = "Reset Your R2R Password"
+ html_body = f"""
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Password Reset Request</h1>
+ <p>You've requested to reset your R2R password.</p>
+ <p>Click the link below to reset your password:</p>
+ <p><a href="{self.frontend_url}/reset-password?token={reset_token}"
+ style="background-color: #007bff; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
+ Reset Password
+ </a></p>
+ <p>Or use this reset token: <strong>{reset_token}</strong></p>
+ <p>If you didn't request a password reset, please ignore this email.</p>
+ </div>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject=subject,
+ html_body=html_body,
+ body=f"Reset your R2R password using this token: {reset_token}",
+ )
+ except Exception as e:
+ error_msg = (
+ f"Failed to send password reset email to {to_email}: {str(e)}"
+ )
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
+
+ async def send_password_changed_email(
+ self,
+ to_email: str,
+ dynamic_template_data: Optional[dict] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ try:
+ if (
+ hasattr(self, "password_changed_template_id")
+ and self.password_changed_template_id
+ ):
+ await self.send_email(
+ to_email=to_email,
+ template_id=self.password_changed_template_id,
+ dynamic_template_data=dynamic_template_data,
+ )
+ else:
+ subject = "Your Password Has Been Changed"
+ body = """
+ Your password has been successfully changed.
+
+ If you did not make this change, please contact support immediately and secure your account.
+
+ """
+ html_body = """
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Password Changed Successfully</h1>
+ <p>Your password has been successfully changed.</p>
+ </div>
+ """
+                # No template is configured; fall back to a basic notification email.
+ await self.send_email(
+ to_email=to_email,
+ subject=subject,
+ html_body=html_body,
+ body=body,
+ )
+ except Exception as e:
+ error_msg = f"Failed to send password change notification to {to_email}: {str(e)}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
diff --git a/.venv/lib/python3.12/site-packages/core/providers/email/smtp.py b/.venv/lib/python3.12/site-packages/core/providers/email/smtp.py
new file mode 100644
index 00000000..bd68ff36
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/email/smtp.py
@@ -0,0 +1,176 @@
+import asyncio
+import logging
+import os
+import smtplib
+import ssl
+from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
+from typing import Optional
+
+from core.base import EmailConfig, EmailProvider
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncSMTPEmailProvider(EmailProvider):
+ """Email provider implementation using Brevo SMTP relay."""
+
+ def __init__(self, config: EmailConfig):
+ super().__init__(config)
+ self.smtp_server = config.smtp_server or os.getenv("R2R_SMTP_SERVER")
+ if not self.smtp_server:
+ raise ValueError("SMTP server is required")
+
+ self.smtp_port = config.smtp_port or os.getenv("R2R_SMTP_PORT")
+ if not self.smtp_port:
+ raise ValueError("SMTP port is required")
+
+ self.smtp_username = config.smtp_username or os.getenv(
+ "R2R_SMTP_USERNAME"
+ )
+ if not self.smtp_username:
+ raise ValueError("SMTP username is required")
+
+ self.smtp_password = config.smtp_password or os.getenv(
+ "R2R_SMTP_PASSWORD"
+ )
+ if not self.smtp_password:
+ raise ValueError("SMTP password is required")
+
+ self.from_email: Optional[str] = (
+ config.from_email
+ or os.getenv("R2R_FROM_EMAIL")
+ or self.smtp_username
+ )
+ self.ssl_context = ssl.create_default_context()
+
+ async def _send_email_sync(self, msg: MIMEMultipart) -> None:
+ """Synchronous email sending wrapped in asyncio executor."""
+ loop = asyncio.get_running_loop()
+
+ def _send():
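+            # SMTP_SSL opens an implicit-TLS connection using the default SSL
+            # context created in __init__; the 30s timeout guards the socket.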
+ with smtplib.SMTP_SSL(
+ self.smtp_server,
+ self.smtp_port,
+ context=self.ssl_context,
+ timeout=30,
+ ) as server:
+ logger.info("Connected to SMTP server")
+ server.login(self.smtp_username, self.smtp_password)
+ logger.info("Login successful")
+ server.send_message(msg)
+ logger.info("Message sent successfully!")
+
+ try:
+ await loop.run_in_executor(None, _send)
+ except Exception as e:
+ error_msg = f"Failed to send email: {str(e)}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
+
+ async def send_email(
+ self,
+ to_email: str,
+ subject: str,
+ body: str,
+ html_body: Optional[str] = None,
+ *args,
+ **kwargs,
+ ) -> None:
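+        # Build a multipart/alternative message so clients render the HTML part
+        # when available and fall back to the plain-text body otherwise.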
+ msg = MIMEMultipart("alternative")
+ msg["Subject"] = subject
+ msg["From"] = self.from_email # type: ignore
+ msg["To"] = to_email
+
+ msg.attach(MIMEText(body, "plain"))
+ if html_body:
+ msg.attach(MIMEText(html_body, "html"))
+
+ try:
+ logger.info("Initializing SMTP connection...")
+ async with asyncio.timeout(30): # Overall timeout
+ await self._send_email_sync(msg)
+ except asyncio.TimeoutError as e:
+ error_msg = "Operation timed out while trying to send email"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
+ except Exception as e:
+ error_msg = f"Failed to send email: {str(e)}"
+ logger.error(error_msg)
+ raise RuntimeError(error_msg) from e
+
+ async def send_verification_email(
+ self, to_email: str, verification_code: str, *args, **kwargs
+ ) -> None:
+ body = f"""
+ Please verify your email address by entering the following code:
+
+ Verification code: {verification_code}
+
+ If you did not request this verification, please ignore this email.
+ """
+
+ html_body = f"""
+ <p>Please verify your email address by entering the following code:</p>
+ <p style="font-size: 24px; font-weight: bold; margin: 20px 0;">
+ Verification code: {verification_code}
+ </p>
+ <p>If you did not request this verification, please ignore this email.</p>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject="Please verify your email address",
+ body=body,
+ html_body=html_body,
+ )
+
+ async def send_password_reset_email(
+ self, to_email: str, reset_token: str, *args, **kwargs
+ ) -> None:
+ body = f"""
+ You have requested to reset your password.
+
+ Reset token: {reset_token}
+
+ If you did not request a password reset, please ignore this email.
+ """
+
+ html_body = f"""
+ <p>You have requested to reset your password.</p>
+ <p style="font-size: 24px; font-weight: bold; margin: 20px 0;">
+ Reset token: {reset_token}
+ </p>
+ <p>If you did not request a password reset, please ignore this email.</p>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject="Password Reset Request",
+ body=body,
+ html_body=html_body,
+ )
+
+ async def send_password_changed_email(
+ self, to_email: str, *args, **kwargs
+ ) -> None:
+ body = """
+ Your password has been successfully changed.
+
+ If you did not make this change, please contact support immediately and secure your account.
+
+ """
+
+ html_body = """
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
+ <h1>Password Changed Successfully</h1>
+ <p>Your password has been successfully changed.</p>
+ </div>
+ """
+
+ await self.send_email(
+ to_email=to_email,
+ subject="Your Password Has Been Changed",
+ body=body,
+ html_body=html_body,
+ )
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/__init__.py
new file mode 100644
index 00000000..3fa67442
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/__init__.py
@@ -0,0 +1,9 @@
+from .litellm import LiteLLMEmbeddingProvider
+from .ollama import OllamaEmbeddingProvider
+from .openai import OpenAIEmbeddingProvider
+
+__all__ = [
+ "LiteLLMEmbeddingProvider",
+ "OpenAIEmbeddingProvider",
+ "OllamaEmbeddingProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/litellm.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/litellm.py
new file mode 100644
index 00000000..5f705c91
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/litellm.py
@@ -0,0 +1,305 @@
+import logging
+import math
+import os
+from copy import copy
+from typing import Any
+
+import litellm
+import requests
+from aiohttp import ClientError, ClientSession
+from litellm import AuthenticationError, aembedding, embedding
+
+from core.base import (
+ ChunkSearchResult,
+ EmbeddingConfig,
+ EmbeddingProvider,
+ EmbeddingPurpose,
+ R2RException,
+)
+
+logger = logging.getLogger()
+
+
+class LiteLLMEmbeddingProvider(EmbeddingProvider):
+ def __init__(
+ self,
+ config: EmbeddingConfig,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(config)
+
+ self.litellm_embedding = embedding
+ self.litellm_aembedding = aembedding
+
+ provider = config.provider
+ if not provider:
+ raise ValueError(
+ "Must set provider in order to initialize `LiteLLMEmbeddingProvider`."
+ )
+ if provider != "litellm":
+ raise ValueError(
+ "LiteLLMEmbeddingProvider must be initialized with provider `litellm`."
+ )
+
+ self.rerank_url = None
+ if config.rerank_model:
+ if "huggingface" not in config.rerank_model:
+ raise ValueError(
+ "LiteLLMEmbeddingProvider only supports re-ranking via the HuggingFace text-embeddings-inference API"
+ )
+
+ url = os.getenv("HUGGINGFACE_API_BASE") or config.rerank_url
+ if not url:
+ raise ValueError(
+ "LiteLLMEmbeddingProvider requires a valid reranking API url to be set via `embedding.rerank_url` in the r2r.toml, or via the environment variable `HUGGINGFACE_API_BASE`."
+ )
+ self.rerank_url = url
+
+ self.base_model = config.base_model
+ if "amazon" in self.base_model:
+            logger.warning("Amazon embedding model detected, dropping params")
+ litellm.drop_params = True
+ self.base_dimension = config.base_dimension
+
+ def _get_embedding_kwargs(self, **kwargs):
+ embedding_kwargs = {
+ "model": self.base_model,
+ "dimensions": self.base_dimension,
+ }
+ embedding_kwargs.update(kwargs)
+ return embedding_kwargs
+
+ async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+ if "dimensions" in kwargs and math.isnan(kwargs["dimensions"]):
+ kwargs.pop("dimensions")
+ logger.warning("Dropping nan dimensions from kwargs")
+
+ try:
+ response = await self.litellm_aembedding(
+ input=texts,
+ **kwargs,
+ )
+ return [data["embedding"] for data in response.data]
+ except AuthenticationError:
+ logger.error(
+ "Authentication error: Invalid API key or credentials."
+ )
+ raise
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+
+ raise R2RException(error_msg, 400) from e
+
+ def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+ try:
+ response = self.litellm_embedding(
+ input=texts,
+ **kwargs,
+ )
+ return [data["embedding"] for data in response.data]
+ except AuthenticationError:
+ logger.error(
+ "Authentication error: Invalid API key or credentials."
+ )
+ raise
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+ raise R2RException(error_msg, 400) from e
+
+ async def async_get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "LiteLLMEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return (await self._execute_with_backoff_async(task))[0]
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "Error getting embeddings: LiteLLMEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return self._execute_with_backoff_sync(task)[0]
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "LiteLLMEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "LiteLLMEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ def rerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ):
+ if self.config.rerank_model is not None:
+ if not self.rerank_url:
+ raise ValueError(
+ "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
+ )
+
+ texts = [result.text for result in results]
+
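+            # Build the request for the HuggingFace text-embeddings-inference
+            # rerank endpoint; "model-id" drops the "huggingface/" prefix.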
+ payload = {
+ "query": query,
+ "texts": texts,
+ "model-id": self.config.rerank_model.split("huggingface/")[1],
+ }
+
+ headers = {"Content-Type": "application/json"}
+
+ try:
+ response = requests.post(
+ self.rerank_url, json=payload, headers=headers
+ )
+ response.raise_for_status()
+ reranked_results = response.json()
+
+ # Copy reranked results into new array
+ scored_results = []
+ for rank_info in reranked_results:
+ original_result = results[rank_info["index"]]
+ copied_result = copy(original_result)
+ # Inject the reranking score into the result object
+ copied_result.score = rank_info["score"]
+ scored_results.append(copied_result)
+
+ # Return only the ChunkSearchResult objects, limited to specified count
+ return scored_results[:limit]
+
+ except requests.RequestException as e:
+ logger.error(f"Error during reranking: {str(e)}")
+ # Fall back to returning the original results if reranking fails
+ return results[:limit]
+ else:
+ return results[:limit]
+
+ async def arerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ) -> list[ChunkSearchResult]:
+ """Asynchronously rerank search results using the configured rerank
+ model.
+
+ Args:
+ query: The search query string
+ results: List of ChunkSearchResult objects to rerank
+ limit: Maximum number of results to return
+
+ Returns:
+ List of reranked ChunkSearchResult objects, limited to specified count
+ """
+ if self.config.rerank_model is not None:
+ if not self.rerank_url:
+ raise ValueError(
+ "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
+ )
+
+ texts = [result.text for result in results]
+
+ payload = {
+ "query": query,
+ "texts": texts,
+ "model-id": self.config.rerank_model.split("huggingface/")[1],
+ }
+
+ headers = {"Content-Type": "application/json"}
+
+ try:
+ async with ClientSession() as session:
+ async with session.post(
+ self.rerank_url, json=payload, headers=headers
+ ) as response:
+ response.raise_for_status()
+ reranked_results = await response.json()
+
+ # Copy reranked results into new array
+ scored_results = []
+ for rank_info in reranked_results:
+ original_result = results[rank_info["index"]]
+ copied_result = copy(original_result)
+ # Inject the reranking score into the result object
+ copied_result.score = rank_info["score"]
+ scored_results.append(copied_result)
+
+ # Return only the ChunkSearchResult objects, limited to specified count
+ return scored_results[:limit]
+
+ except (ClientError, Exception) as e:
+ logger.error(f"Error during async reranking: {str(e)}")
+ # Fall back to returning the original results if reranking fails
+ return results[:limit]
+ else:
+ return results[:limit]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py
new file mode 100644
index 00000000..297d9167
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/ollama.py
@@ -0,0 +1,194 @@
+import logging
+import os
+from typing import Any
+
+from ollama import AsyncClient, Client
+
+from core.base import (
+ ChunkSearchResult,
+ EmbeddingConfig,
+ EmbeddingProvider,
+ EmbeddingPurpose,
+ R2RException,
+)
+
+logger = logging.getLogger()
+
+
+class OllamaEmbeddingProvider(EmbeddingProvider):
+ def __init__(self, config: EmbeddingConfig):
+ super().__init__(config)
+ provider = config.provider
+ if not provider:
+ raise ValueError(
+ "Must set provider in order to initialize `OllamaEmbeddingProvider`."
+ )
+ if provider != "ollama":
+ raise ValueError(
+ "OllamaEmbeddingProvider must be initialized with provider `ollama`."
+ )
+ if config.rerank_model:
+ raise ValueError(
+ "OllamaEmbeddingProvider does not support separate reranking."
+ )
+
+ self.base_model = config.base_model
+ self.base_dimension = config.base_dimension
+ self.base_url = os.getenv("OLLAMA_API_BASE")
+ logger.info(
+ f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
+ )
+ self.client = Client(host=self.base_url)
+ self.aclient = AsyncClient(host=self.base_url)
+
+ self.set_prefixes(config.prefixes or {}, self.base_model)
+ self.batch_size = config.batch_size or 32
+
+ def _get_embedding_kwargs(self, **kwargs):
+ embedding_kwargs = {
+ "model": self.base_model,
+ }
+ embedding_kwargs.update(kwargs)
+ return embedding_kwargs
+
+ async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ purpose = task.get("purpose", EmbeddingPurpose.INDEX)
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+ try:
+ embeddings = []
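+            # Embed in batches of `batch_size`, prepending any purpose-specific
+            # prefix (e.g. query vs. passage) configured for this model.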
+ for i in range(0, len(texts), self.batch_size):
+ batch = texts[i : i + self.batch_size]
+ prefixed_batch = [
+ self.prefixes.get(purpose, "") + text for text in batch
+ ]
+ response = await self.aclient.embed(
+ input=prefixed_batch, **kwargs
+ )
+ embeddings.extend(response["embeddings"])
+ return embeddings
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+ raise R2RException(error_msg, 400) from e
+
+ def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ purpose = task.get("purpose", EmbeddingPurpose.INDEX)
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+ try:
+ embeddings = []
+ for i in range(0, len(texts), self.batch_size):
+ batch = texts[i : i + self.batch_size]
+ prefixed_batch = [
+ self.prefixes.get(purpose, "") + text for text in batch
+ ]
+ response = self.client.embed(input=prefixed_batch, **kwargs)
+ embeddings.extend(response["embeddings"])
+ return embeddings
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+ raise R2RException(error_msg, 400) from e
+
+ async def async_get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ result = await self._execute_with_backoff_async(task)
+ return result[0]
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ result = self._execute_with_backoff_sync(task)
+ return result[0]
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OllamaEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ def rerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ) -> list[ChunkSearchResult]:
+ return results[:limit]
+
+ async def arerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ):
+ return results[:limit]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py b/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
new file mode 100644
index 00000000..907cebd9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/embeddings/openai.py
@@ -0,0 +1,243 @@
+import logging
+import os
+from typing import Any
+
+import tiktoken
+from openai import AsyncOpenAI, AuthenticationError, OpenAI
+from openai._types import NOT_GIVEN
+
+from core.base import (
+ ChunkSearchResult,
+ EmbeddingConfig,
+ EmbeddingProvider,
+ EmbeddingPurpose,
+)
+
+logger = logging.getLogger()
+
+
+class OpenAIEmbeddingProvider(EmbeddingProvider):
+ MODEL_TO_TOKENIZER = {
+ "text-embedding-ada-002": "cl100k_base",
+ "text-embedding-3-small": "cl100k_base",
+ "text-embedding-3-large": "cl100k_base",
+ }
+ MODEL_TO_DIMENSIONS = {
+ "text-embedding-ada-002": [1536],
+ "text-embedding-3-small": [512, 1536],
+ "text-embedding-3-large": [256, 1024, 3072],
+ }
+
+ def __init__(self, config: EmbeddingConfig):
+ super().__init__(config)
+ if not config.provider:
+ raise ValueError(
+ "Must set provider in order to initialize OpenAIEmbeddingProvider."
+ )
+
+ if config.provider != "openai":
+ raise ValueError(
+ "OpenAIEmbeddingProvider must be initialized with provider `openai`."
+ )
+ if not os.getenv("OPENAI_API_KEY"):
+ raise ValueError(
+ "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
+ )
+ self.client = OpenAI()
+ self.async_client = AsyncOpenAI()
+
+ if config.rerank_model:
+ raise ValueError(
+ "OpenAIEmbeddingProvider does not support separate reranking."
+ )
+
+ if config.base_model and "openai/" in config.base_model:
+ self.base_model = config.base_model.split("/")[-1]
+ else:
+ self.base_model = config.base_model
+ self.base_dimension = config.base_dimension
+
+ if not self.base_model:
+ raise ValueError(
+ "Must set base_model in order to initialize OpenAIEmbeddingProvider."
+ )
+
+ if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+ raise ValueError(
+ f"OpenAI embedding model {self.base_model} not supported."
+ )
+
+ if self.base_dimension:
+ if (
+ self.base_dimension
+ not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
+ self.base_model
+ ]
+ ):
+ raise ValueError(
+ f"Dimensions {self.base_dimension} for {self.base_model} are not supported"
+ )
+ else:
+ # If base_dimension is not set, use the largest available dimension for the model
+ self.base_dimension = max(
+ OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model]
+ )
+
+ def _get_dimensions(self):
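+        # text-embedding-ada-002 does not accept a `dimensions` argument, so pass
+        # NOT_GIVEN for it; otherwise use the configured or largest supported size.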
+ return (
+ NOT_GIVEN
+ if self.base_model == "text-embedding-ada-002"
+ else self.base_dimension
+ or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model][-1]
+ )
+
+ def _get_embedding_kwargs(self, **kwargs):
+ return {
+ "model": self.base_model,
+ "dimensions": self._get_dimensions(),
+ } | kwargs
+
+ async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+
+ try:
+ response = await self.async_client.embeddings.create(
+ input=texts,
+ **kwargs,
+ )
+ return [data.embedding for data in response.data]
+ except AuthenticationError as e:
+ raise ValueError(
+ "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+ ) from e
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+ raise ValueError(error_msg) from e
+
+ def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
+ texts = task["texts"]
+ kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
+ try:
+ response = self.client.embeddings.create(
+ input=texts,
+ **kwargs,
+ )
+ return [data.embedding for data in response.data]
+ except AuthenticationError as e:
+ raise ValueError(
+ "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
+ ) from e
+ except Exception as e:
+ error_msg = f"Error getting embeddings: {str(e)}"
+ logger.error(error_msg)
+ raise ValueError(error_msg) from e
+
+ async def async_get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ result = await self._execute_with_backoff_async(task)
+ return result[0]
+
+ def get_embedding(
+ self,
+ text: str,
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[float]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": [text],
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ result = self._execute_with_backoff_sync(task)
+ return result[0]
+
+ async def async_get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return await self._execute_with_backoff_async(task)
+
+ def get_embeddings(
+ self,
+ texts: list[str],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
+ purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+ **kwargs,
+ ) -> list[list[float]]:
+ if stage != EmbeddingProvider.Step.BASE:
+ raise ValueError(
+ "OpenAIEmbeddingProvider only supports search stage."
+ )
+
+ task = {
+ "texts": texts,
+ "stage": stage,
+ "purpose": purpose,
+ "kwargs": kwargs,
+ }
+ return self._execute_with_backoff_sync(task)
+
+ def rerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ):
+ return results[:limit]
+
+ async def arerank(
+ self,
+ query: str,
+ results: list[ChunkSearchResult],
+ stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
+ limit: int = 10,
+ ):
+ return results[:limit]
+
+ def tokenize_string(self, text: str, model: str) -> list[int]:
+ if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
+ raise ValueError(f"OpenAI embedding model {model} not supported.")
+ encoding = tiktoken.get_encoding(
+ OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
+ )
+ return encoding.encode(text)
diff --git a/.venv/lib/python3.12/site-packages/core/providers/ingestion/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/ingestion/__init__.py
new file mode 100644
index 00000000..4a25d30d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/ingestion/__init__.py
@@ -0,0 +1,13 @@
+# type: ignore
+from .r2r.base import R2RIngestionConfig, R2RIngestionProvider
+from .unstructured.base import (
+ UnstructuredIngestionConfig,
+ UnstructuredIngestionProvider,
+)
+
+__all__ = [
+ "R2RIngestionConfig",
+ "R2RIngestionProvider",
+ "UnstructuredIngestionProvider",
+ "UnstructuredIngestionConfig",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/ingestion/r2r/base.py b/.venv/lib/python3.12/site-packages/core/providers/ingestion/r2r/base.py
new file mode 100644
index 00000000..7d452245
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/ingestion/r2r/base.py
@@ -0,0 +1,355 @@
+# type: ignore
+import logging
+import time
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from core import parsers
+from core.base import (
+ AsyncParser,
+ ChunkingStrategy,
+ Document,
+ DocumentChunk,
+ DocumentType,
+ IngestionConfig,
+ IngestionProvider,
+ R2RDocumentProcessingError,
+ RecursiveCharacterTextSplitter,
+ TextSplitter,
+)
+from core.utils import generate_extraction_id
+
+from ...database import PostgresDatabaseProvider
+from ...llm import (
+ LiteLLMCompletionProvider,
+ OpenAICompletionProvider,
+ R2RCompletionProvider,
+)
+
+logger = logging.getLogger()
+
+
+class R2RIngestionConfig(IngestionConfig):
+ chunk_size: int = 1024
+ chunk_overlap: int = 512
+ chunking_strategy: ChunkingStrategy = ChunkingStrategy.RECURSIVE
+ extra_fields: dict[str, Any] = {}
+ separator: Optional[str] = None
+
+
+class R2RIngestionProvider(IngestionProvider):
+ DEFAULT_PARSERS = {
+ DocumentType.BMP: parsers.BMPParser,
+ DocumentType.CSV: parsers.CSVParser,
+ DocumentType.DOC: parsers.DOCParser,
+ DocumentType.DOCX: parsers.DOCXParser,
+ DocumentType.EML: parsers.EMLParser,
+ DocumentType.EPUB: parsers.EPUBParser,
+ DocumentType.HTML: parsers.HTMLParser,
+ DocumentType.HTM: parsers.HTMLParser,
+ DocumentType.ODT: parsers.ODTParser,
+ DocumentType.JSON: parsers.JSONParser,
+ DocumentType.MSG: parsers.MSGParser,
+ DocumentType.ORG: parsers.ORGParser,
+ DocumentType.MD: parsers.MDParser,
+ DocumentType.PDF: parsers.BasicPDFParser,
+ DocumentType.PPT: parsers.PPTParser,
+ DocumentType.PPTX: parsers.PPTXParser,
+ DocumentType.TXT: parsers.TextParser,
+ DocumentType.XLSX: parsers.XLSXParser,
+ DocumentType.GIF: parsers.ImageParser,
+ DocumentType.JPEG: parsers.ImageParser,
+ DocumentType.JPG: parsers.ImageParser,
+ DocumentType.TSV: parsers.TSVParser,
+ DocumentType.PNG: parsers.ImageParser,
+ DocumentType.HEIC: parsers.ImageParser,
+ DocumentType.SVG: parsers.ImageParser,
+ DocumentType.MP3: parsers.AudioParser,
+ DocumentType.P7S: parsers.P7SParser,
+ DocumentType.RST: parsers.RSTParser,
+ DocumentType.RTF: parsers.RTFParser,
+ DocumentType.TIFF: parsers.ImageParser,
+ DocumentType.XLS: parsers.XLSParser,
+ }
+
+ EXTRA_PARSERS = {
+ DocumentType.CSV: {"advanced": parsers.CSVParserAdvanced},
+ DocumentType.PDF: {
+ "unstructured": parsers.PDFParserUnstructured,
+ "zerox": parsers.VLMPDFParser,
+ },
+ DocumentType.XLSX: {"advanced": parsers.XLSXParserAdvanced},
+ }
+
+ IMAGE_TYPES = {
+ DocumentType.GIF,
+ DocumentType.HEIC,
+ DocumentType.JPG,
+ DocumentType.JPEG,
+ DocumentType.PNG,
+ DocumentType.SVG,
+ }
+
+ def __init__(
+ self,
+ config: R2RIngestionConfig,
+ database_provider: PostgresDatabaseProvider,
+ llm_provider: (
+ LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ ):
+ super().__init__(config, database_provider, llm_provider)
+ self.config: R2RIngestionConfig = config
+ self.database_provider: PostgresDatabaseProvider = database_provider
+ self.llm_provider: (
+ LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ) = llm_provider
+ self.parsers: dict[DocumentType, AsyncParser] = {}
+ self.text_splitter = self._build_text_splitter()
+ self._initialize_parsers()
+
+ logger.info(
+ f"R2RIngestionProvider initialized with config: {self.config}"
+ )
+
+ def _initialize_parsers(self):
+ for doc_type, parser in self.DEFAULT_PARSERS.items():
+            # Register the default parser for each document type that is not excluded.
+ if doc_type not in self.config.excluded_parsers:
+ self.parsers[doc_type] = parser(
+ config=self.config,
+ database_provider=self.database_provider,
+ llm_provider=self.llm_provider,
+ )
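+        # Extra/override parsers are stored under a key combining the parser name
+        # and document type, and are looked up explicitly in parse().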
+ for doc_type, doc_parser_name in self.config.extra_parsers.items():
+ self.parsers[f"{doc_parser_name}_{str(doc_type)}"] = (
+ R2RIngestionProvider.EXTRA_PARSERS[doc_type][doc_parser_name](
+ config=self.config,
+ database_provider=self.database_provider,
+ llm_provider=self.llm_provider,
+ )
+ )
+
+ def _build_text_splitter(
+ self, ingestion_config_override: Optional[dict] = None
+ ) -> TextSplitter:
+ logger.info(
+ f"Initializing text splitter with method: {self.config.chunking_strategy}"
+ )
+
+ if not ingestion_config_override:
+ ingestion_config_override = {}
+
+ chunking_strategy = (
+ ingestion_config_override.get("chunking_strategy")
+ or self.config.chunking_strategy
+ )
+
+ chunk_size = (
+ ingestion_config_override.get("chunk_size")
+ or self.config.chunk_size
+ )
+ chunk_overlap = (
+ ingestion_config_override.get("chunk_overlap")
+ or self.config.chunk_overlap
+ )
+
+ if chunking_strategy == ChunkingStrategy.RECURSIVE:
+ return RecursiveCharacterTextSplitter(
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ )
+ elif chunking_strategy == ChunkingStrategy.CHARACTER:
+ from core.base.utils.splitter.text import CharacterTextSplitter
+
+ separator = (
+ ingestion_config_override.get("separator")
+ or self.config.separator
+ or CharacterTextSplitter.DEFAULT_SEPARATOR
+ )
+
+ return CharacterTextSplitter(
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ separator=separator,
+ keep_separator=False,
+ strip_whitespace=True,
+ )
+ elif chunking_strategy == ChunkingStrategy.BASIC:
+ raise NotImplementedError(
+ "Basic chunking method not implemented. Please use Recursive."
+ )
+ elif chunking_strategy == ChunkingStrategy.BY_TITLE:
+ raise NotImplementedError("By title method not implemented")
+ else:
+ raise ValueError(f"Unsupported method type: {chunking_strategy}")
+
+ def validate_config(self) -> bool:
+ return self.config.chunk_size > 0 and self.config.chunk_overlap >= 0
+
+ def chunk(
+ self,
+ parsed_document: str | DocumentChunk,
+ ingestion_config_override: dict,
+    ) -> Generator[Any, None, None]:
+        text_splitter = self.text_splitter
+        if ingestion_config_override:
+            text_splitter = self._build_text_splitter(
+                ingestion_config_override
+            )
+        if isinstance(parsed_document, DocumentChunk):
+            parsed_document = parsed_document.data
+
+        if isinstance(parsed_document, str):
+            chunks = text_splitter.create_documents([parsed_document])
+ else:
+ # Assuming parsed_document is already a list of text chunks
+ chunks = parsed_document
+
+ for chunk in chunks:
+ yield (
+ chunk.page_content if hasattr(chunk, "page_content") else chunk
+ )
+
+ async def parse(
+ self,
+ file_content: bytes,
+ document: Document,
+ ingestion_config_override: dict,
+ ) -> AsyncGenerator[DocumentChunk, None]:
+ if document.document_type not in self.parsers:
+ raise R2RDocumentProcessingError(
+ document_id=document.id,
+ error_message=f"Parser for {document.document_type} not found in `R2RIngestionProvider`.",
+ )
+ else:
+ t0 = time.time()
+ contents = []
+ parser_overrides = ingestion_config_override.get(
+ "parser_overrides", {}
+ )
+ if document.document_type.value in parser_overrides:
+ logger.info(
+ f"Using parser_override for {document.document_type} with input value {parser_overrides[document.document_type.value]}"
+ )
+ # TODO - Cleanup this approach to be less hardcoded
+ if (
+ document.document_type != DocumentType.PDF
+ or parser_overrides[DocumentType.PDF.value] != "zerox"
+ ):
+ raise ValueError(
+ "Only Zerox PDF parser override is available."
+ )
+
+ # Collect content from VLMPDFParser
+ async for chunk in self.parsers[
+ f"zerox_{DocumentType.PDF.value}"
+ ].ingest(file_content, **ingestion_config_override):
+ if isinstance(chunk, dict) and chunk.get("content"):
+ contents.append(chunk)
+ elif (
+ chunk
+ ): # Handle string output for backward compatibility
+ contents.append({"content": chunk})
+
+ if (
+ contents
+ and document.document_type == DocumentType.PDF
+ and parser_overrides.get(DocumentType.PDF.value) == "zerox"
+ ):
+ text_splitter = self._build_text_splitter(
+ ingestion_config_override
+ )
+
+ iteration = 0
+
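+                    # Re-assemble the VLM (zerox) output in page order, keeping
+                    # only entries whose "content" is a string.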
+ sorted_contents = [
+ item
+ for item in sorted(
+ contents, key=lambda x: x.get("page_number", 0)
+ )
+ if isinstance(item.get("content"), str)
+ ]
+
+ for content_item in sorted_contents:
+ page_num = content_item.get("page_number", 0)
+ page_content = content_item["content"]
+
+ page_chunks = text_splitter.create_documents(
+ [page_content]
+ )
+
+ # Create document chunks for each split piece
+ for chunk in page_chunks:
+ metadata = {
+ **document.metadata,
+ "chunk_order": iteration,
+ "page_number": page_num,
+ }
+
+ extraction = DocumentChunk(
+ id=generate_extraction_id(
+ document.id, iteration
+ ),
+ document_id=document.id,
+ owner_id=document.owner_id,
+ collection_ids=document.collection_ids,
+ data=chunk.page_content,
+ metadata=metadata,
+ )
+ iteration += 1
+ yield extraction
+
+ logger.debug(
+ f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
+ f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
+ f"into {iteration} extractions in t={time.time() - t0:.2f} seconds using page-by-page splitting."
+ )
+ return
+
+ else:
+ # Standard parsing for non-override cases
+ async for text in self.parsers[document.document_type].ingest(
+ file_content, **ingestion_config_override
+ ):
+ if text is not None:
+ contents.append({"content": text})
+
+ if not contents:
+ logger.warning(
+ "No valid text content was extracted during parsing"
+ )
+ return
+
+ iteration = 0
+ for content_item in contents:
+ chunk_text = content_item["content"]
+ chunks = self.chunk(chunk_text, ingestion_config_override)
+
+ for chunk in chunks:
+ metadata = {**document.metadata, "chunk_order": iteration}
+ if "page_number" in content_item:
+ metadata["page_number"] = content_item["page_number"]
+
+ extraction = DocumentChunk(
+ id=generate_extraction_id(document.id, iteration),
+ document_id=document.id,
+ owner_id=document.owner_id,
+ collection_ids=document.collection_ids,
+ data=chunk,
+ metadata=metadata,
+ )
+ iteration += 1
+ yield extraction
+
+ logger.debug(
+ f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
+ f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
+ f"into {iteration} extractions in t={time.time() - t0:.2f} seconds."
+ )
+
+ def get_parser_for_document_type(self, doc_type: DocumentType) -> Any:
+ return self.parsers.get(doc_type)
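The chunking parameters above resolve per request: a key supplied in ingestion_config_override wins, otherwise the provider's configured default applies. A minimal standalone sketch of that cascade follows (resolve_chunking and ResolvedChunking are illustrative names, not part of the package):

from dataclasses import dataclass


@dataclass
class ResolvedChunking:
    # Mirrors the three settings that _build_text_splitter reads.
    chunking_strategy: str
    chunk_size: int
    chunk_overlap: int


def resolve_chunking(defaults: dict, override: dict) -> ResolvedChunking:
    """Apply the override-or-default cascade used when building the splitter."""
    return ResolvedChunking(
        chunking_strategy=override.get("chunking_strategy")
        or defaults["chunking_strategy"],
        chunk_size=override.get("chunk_size") or defaults["chunk_size"],
        chunk_overlap=override.get("chunk_overlap") or defaults["chunk_overlap"],
    )


# A per-request override shrinks the chunk size; the other defaults stay in effect.
print(resolve_chunking(
    {"chunking_strategy": "recursive", "chunk_size": 1024, "chunk_overlap": 512},
    {"chunk_size": 256},
))

Because the cascade uses `or`, a falsy override such as chunk_size=0 silently falls back to the configured default, which matches the behaviour of the provider code above.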
diff --git a/.venv/lib/python3.12/site-packages/core/providers/ingestion/unstructured/base.py b/.venv/lib/python3.12/site-packages/core/providers/ingestion/unstructured/base.py
new file mode 100644
index 00000000..29c09bf5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/ingestion/unstructured/base.py
@@ -0,0 +1,396 @@
+# TODO - cleanup type issues in this file that relate to `bytes`
+import asyncio
+import base64
+import logging
+import os
+import time
+from copy import copy
+from io import BytesIO
+from typing import Any, AsyncGenerator
+
+import httpx
+from unstructured_client import UnstructuredClient
+from unstructured_client.models import operations, shared
+
+from core import parsers
+from core.base import (
+ AsyncParser,
+ ChunkingStrategy,
+ Document,
+ DocumentChunk,
+ DocumentType,
+ RecursiveCharacterTextSplitter,
+)
+from core.base.abstractions import R2RSerializable
+from core.base.providers.ingestion import IngestionConfig, IngestionProvider
+from core.utils import generate_extraction_id
+
+from ...database import PostgresDatabaseProvider
+from ...llm import (
+ LiteLLMCompletionProvider,
+ OpenAICompletionProvider,
+ R2RCompletionProvider,
+)
+
+logger = logging.getLogger()
+
+
+class FallbackElement(R2RSerializable):
+ text: str
+ metadata: dict[str, Any]
+
+
+class UnstructuredIngestionConfig(IngestionConfig):
+ combine_under_n_chars: int = 128
+ max_characters: int = 500
+ new_after_n_chars: int = 1500
+ overlap: int = 64
+
+ coordinates: bool | None = None
+ encoding: str | None = None # utf-8
+ extract_image_block_types: list[str] | None = None
+ gz_uncompressed_content_type: str | None = None
+ hi_res_model_name: str | None = None
+ include_orig_elements: bool | None = None
+ include_page_breaks: bool | None = None
+
+ languages: list[str] | None = None
+ multipage_sections: bool | None = None
+ ocr_languages: list[str] | None = None
+ # output_format: Optional[str] = "application/json"
+ overlap_all: bool | None = None
+ pdf_infer_table_structure: bool | None = None
+
+ similarity_threshold: float | None = None
+ skip_infer_table_types: list[str] | None = None
+ split_pdf_concurrency_level: int | None = None
+ split_pdf_page: bool | None = None
+ starting_page_number: int | None = None
+ strategy: str | None = None
+ chunking_strategy: str | ChunkingStrategy | None = None # type: ignore
+ unique_element_ids: bool | None = None
+ xml_keep_tags: bool | None = None
+
+ def to_ingestion_request(self):
+ import json
+
+ x = json.loads(self.json())
+ x.pop("extra_fields", None)
+ x.pop("provider", None)
+ x.pop("excluded_parsers", None)
+
+ x = {k: v for k, v in x.items() if v is not None}
+ return x
+
+
+class UnstructuredIngestionProvider(IngestionProvider):
+ R2R_FALLBACK_PARSERS = {
+ DocumentType.GIF: [parsers.ImageParser], # type: ignore
+ DocumentType.JPEG: [parsers.ImageParser], # type: ignore
+ DocumentType.JPG: [parsers.ImageParser], # type: ignore
+ DocumentType.PNG: [parsers.ImageParser], # type: ignore
+ DocumentType.SVG: [parsers.ImageParser], # type: ignore
+ DocumentType.HEIC: [parsers.ImageParser], # type: ignore
+ DocumentType.MP3: [parsers.AudioParser], # type: ignore
+ DocumentType.JSON: [parsers.JSONParser], # type: ignore
+ DocumentType.HTML: [parsers.HTMLParser], # type: ignore
+ DocumentType.XLS: [parsers.XLSParser], # type: ignore
+ DocumentType.XLSX: [parsers.XLSXParser], # type: ignore
+ DocumentType.DOC: [parsers.DOCParser], # type: ignore
+ DocumentType.PPT: [parsers.PPTParser], # type: ignore
+ }
+
+ EXTRA_PARSERS = {
+ DocumentType.CSV: {"advanced": parsers.CSVParserAdvanced}, # type: ignore
+ DocumentType.PDF: {
+ "unstructured": parsers.PDFParserUnstructured, # type: ignore
+ "zerox": parsers.VLMPDFParser, # type: ignore
+ },
+ DocumentType.XLSX: {"advanced": parsers.XLSXParserAdvanced}, # type: ignore
+ }
+
+ IMAGE_TYPES = {
+ DocumentType.GIF,
+ DocumentType.HEIC,
+ DocumentType.JPG,
+ DocumentType.JPEG,
+ DocumentType.PNG,
+ DocumentType.SVG,
+ }
+
+ def __init__(
+ self,
+ config: UnstructuredIngestionConfig,
+ database_provider: PostgresDatabaseProvider,
+ llm_provider: (
+ LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ),
+ ):
+ super().__init__(config, database_provider, llm_provider)
+ self.config: UnstructuredIngestionConfig = config
+ self.database_provider: PostgresDatabaseProvider = database_provider
+ self.llm_provider: (
+ LiteLLMCompletionProvider
+ | OpenAICompletionProvider
+ | R2RCompletionProvider
+ ) = llm_provider
+
+ if config.provider == "unstructured_api":
+ try:
+ self.unstructured_api_auth = os.environ["UNSTRUCTURED_API_KEY"]
+ except KeyError as e:
+ raise ValueError(
+ "UNSTRUCTURED_API_KEY environment variable is not set"
+ ) from e
+
+ self.unstructured_api_url = os.environ.get(
+ "UNSTRUCTURED_API_URL",
+ "https://api.unstructuredapp.io/general/v0/general",
+ )
+
+ self.client = UnstructuredClient(
+ api_key_auth=self.unstructured_api_auth,
+ server_url=self.unstructured_api_url,
+ )
+ self.shared = shared
+ self.operations = operations
+
+ else:
+ try:
+ self.local_unstructured_url = os.environ[
+ "UNSTRUCTURED_SERVICE_URL"
+ ]
+ except KeyError as e:
+ raise ValueError(
+ "UNSTRUCTURED_SERVICE_URL environment variable is not set"
+ ) from e
+
+ self.client = httpx.AsyncClient()
+
+ self.parsers: dict[DocumentType, AsyncParser] = {}
+ self._initialize_parsers()
+
+ def _initialize_parsers(self):
+ # iterate the fallback list without shadowing the imported `parsers` module
+ for doc_type, fallback_parsers in self.R2R_FALLBACK_PARSERS.items():
+ for parser in fallback_parsers:
+ if (
+ doc_type not in self.config.excluded_parsers
+ and doc_type not in self.parsers
+ ):
+ # will choose the first parser in the list
+ self.parsers[doc_type] = parser(
+ config=self.config,
+ database_provider=self.database_provider,
+ llm_provider=self.llm_provider,
+ )
+ # TODO - Reduce code duplication between Unstructured & R2R
+ for doc_type, doc_parser_name in self.config.extra_parsers.items():
+ self.parsers[f"{doc_parser_name}_{str(doc_type)}"] = (
+ UnstructuredIngestionProvider.EXTRA_PARSERS[doc_type][
+ doc_parser_name
+ ](
+ config=self.config,
+ database_provider=self.database_provider,
+ llm_provider=self.llm_provider,
+ )
+ )
+
+ async def parse_fallback(
+ self,
+ file_content: bytes,
+ ingestion_config: dict,
+ parser_name: str | DocumentType,
+ ) -> AsyncGenerator[FallbackElement, None]:
+ contents = []
+ async for chunk in self.parsers[parser_name].ingest( # type: ignore
+ file_content, **ingestion_config
+ ): # type: ignore
+ if isinstance(chunk, dict) and chunk.get("content"):
+ contents.append(chunk)
+ elif chunk: # Handle string output for backward compatibility
+ contents.append({"content": chunk})
+
+ if not contents:
+ logger.warning(
+ "No valid text content was extracted during parsing"
+ )
+ return
+
+ logger.info(f"Fallback ingestion with config = {ingestion_config}")
+
+ iteration = 0
+ for content_item in contents:
+ text = content_item["content"]
+
+ loop = asyncio.get_event_loop()
+ splitter = RecursiveCharacterTextSplitter(
+ chunk_size=ingestion_config["new_after_n_chars"],
+ chunk_overlap=ingestion_config["overlap"],
+ )
+ chunks = await loop.run_in_executor(
+ None, splitter.create_documents, [text]
+ )
+
+ for text_chunk in chunks:
+ metadata = {"chunk_id": iteration}
+ if "page_number" in content_item:
+ metadata["page_number"] = content_item["page_number"]
+
+ yield FallbackElement(
+ text=text_chunk.page_content,
+ metadata=metadata,
+ )
+ iteration += 1
+ await asyncio.sleep(0)
+
+ async def parse(
+ self,
+ file_content: bytes,
+ document: Document,
+ ingestion_config_override: dict,
+ ) -> AsyncGenerator[DocumentChunk, None]:
+ ingestion_config = copy(
+ {
+ **self.config.to_ingestion_request(),
+ **(ingestion_config_override or {}),
+ }
+ )
+ # cleanup extra fields
+ ingestion_config.pop("provider", None)
+ ingestion_config.pop("excluded_parsers", None)
+
+ t0 = time.time()
+ parser_overrides = ingestion_config_override.get(
+ "parser_overrides", {}
+ )
+ elements = []
+
+ # TODO - Cleanup this approach to be less hardcoded
+ # TODO - Remove code duplication between Unstructured & R2R
+ if document.document_type.value in parser_overrides:
+ logger.info(
+ f"Using parser_override for {document.document_type} with input value {parser_overrides[document.document_type.value]}"
+ )
+ async for element in self.parse_fallback(
+ file_content,
+ ingestion_config=ingestion_config,
+ parser_name=f"zerox_{DocumentType.PDF.value}",
+ ):
+ elements.append(element)
+
+ elif document.document_type in self.R2R_FALLBACK_PARSERS.keys():
+ logger.info(
+ f"Parsing {document.document_type}: {document.id} with fallback parser"
+ )
+ async for element in self.parse_fallback(
+ file_content,
+ ingestion_config=ingestion_config,
+ parser_name=document.document_type,
+ ):
+ elements.append(element)
+ else:
+ logger.info(
+ f"Parsing {document.document_type}: {document.id} with unstructured"
+ )
+ if isinstance(file_content, bytes):
+ file_content = BytesIO(file_content) # type: ignore
+
+ # TODO - Include check on excluded parsers here.
+ if self.config.provider == "unstructured_api":
+ logger.info(f"Using API to parse document {document.id}")
+ files = self.shared.Files(
+ content=file_content.read(), # type: ignore
+ file_name=document.metadata.get("title", "unknown_file"),
+ )
+
+ ingestion_config.pop("app", None)
+ ingestion_config.pop("extra_parsers", None)
+
+ req = self.operations.PartitionRequest(
+ self.shared.PartitionParameters(
+ files=files,
+ **ingestion_config,
+ )
+ )
+ elements = self.client.general.partition(req) # type: ignore
+ elements = list(elements.elements) # type: ignore
+
+ else:
+ logger.info(
+ f"Using local unstructured fastapi server to parse document {document.id}"
+ )
+ # Base64 encode the file content
+ encoded_content = base64.b64encode(file_content.read()).decode( # type: ignore
+ "utf-8"
+ )
+
+ logger.info(
+ f"Sending a request to {self.local_unstructured_url}/partition"
+ )
+
+ response = await self.client.post(
+ f"{self.local_unstructured_url}/partition",
+ json={
+ "file_content": encoded_content, # Use encoded string
+ "ingestion_config": ingestion_config,
+ "filename": document.metadata.get("title", None),
+ },
+ timeout=3600, # Adjust timeout as needed
+ )
+
+ if response.status_code != 200:
+ logger.error(f"Error partitioning file: {response.text}")
+ raise ValueError(
+ f"Error partitioning file: {response.text}"
+ )
+ elements = response.json().get("elements", [])
+
+ iteration = 0 # default so the final log line is safe when no elements are produced
+ for iteration, element in enumerate(elements):
+ if isinstance(element, FallbackElement):
+ text = element.text
+ metadata = copy(document.metadata)
+ metadata.update(element.metadata)
+ else:
+ element_dict = (
+ element if isinstance(element, dict) else element.to_dict()
+ )
+ text = element_dict.get("text", "")
+ if text == "":
+ continue
+
+ metadata = copy(document.metadata)
+ for key, value in element_dict.items():
+ if key == "metadata":
+ for k, v in value.items():
+ if k not in metadata and k != "orig_elements":
+ metadata[f"unstructured_{k}"] = v
+
+ # indicate that the document was chunked using unstructured
+ # nullifies the need for chunking in the pipeline
+ metadata["partitioned_by_unstructured"] = True
+ metadata["chunk_order"] = iteration
+ # creating the text extraction
+ yield DocumentChunk(
+ id=generate_extraction_id(document.id, iteration),
+ document_id=document.id,
+ owner_id=document.owner_id,
+ collection_ids=document.collection_ids,
+ data=text,
+ metadata=metadata,
+ )
+
+ # TODO: explore why this is throwing inadvertently
+ # if iteration == 0:
+ # raise ValueError(f"No chunks found for document {document.id}")
+
+ logger.debug(
+ f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, "
+ f"user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} "
+ f"into {iteration + 1} extractions in t={time.time() - t0:.2f} seconds."
+ )
+
+ def get_parser_for_document_type(self, doc_type: DocumentType) -> str:
+ return "unstructured_local"
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/llm/__init__.py
new file mode 100644
index 00000000..8132e11c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/__init__.py
@@ -0,0 +1,11 @@
+from .anthropic import AnthropicCompletionProvider
+from .litellm import LiteLLMCompletionProvider
+from .openai import OpenAICompletionProvider
+from .r2r_llm import R2RCompletionProvider
+
+__all__ = [
+ "AnthropicCompletionProvider",
+ "LiteLLMCompletionProvider",
+ "OpenAICompletionProvider",
+ "R2RCompletionProvider",
+]
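The package simply re-exports the four completion providers. A small sketch of selecting one by name (the lookup table is illustrative; the real provider wiring lives elsewhere in core):

from core.providers.llm import (
    AnthropicCompletionProvider,
    LiteLLMCompletionProvider,
    OpenAICompletionProvider,
    R2RCompletionProvider,
)

# Hypothetical name-to-class mapping, for illustration only.
COMPLETION_PROVIDERS = {
    "anthropic": AnthropicCompletionProvider,
    "litellm": LiteLLMCompletionProvider,
    "openai": OpenAICompletionProvider,
    "r2r": R2RCompletionProvider,
}


def completion_provider_class(name: str):
    return COMPLETION_PROVIDERS[name]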
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/anthropic.py b/.venv/lib/python3.12/site-packages/core/providers/llm/anthropic.py
new file mode 100644
index 00000000..0089a207
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/anthropic.py
@@ -0,0 +1,925 @@
+import copy
+import json
+import logging
+import os
+import time
+import uuid
+from typing import (
+ Any,
+ AsyncGenerator,
+ Generator,
+ Optional,
+)
+
+from anthropic import Anthropic, AsyncAnthropic
+from anthropic.types import (
+ ContentBlockStopEvent,
+ Message,
+ MessageStopEvent,
+ RawContentBlockDeltaEvent,
+ RawContentBlockStartEvent,
+ RawMessageStartEvent,
+ ToolUseBlock,
+)
+
+from core.base.abstractions import GenerationConfig, LLMChatCompletion
+from core.base.providers.llm import CompletionConfig, CompletionProvider
+
+from .utils import resize_base64_image
+
+logger = logging.getLogger(__name__)
+
+
+def generate_tool_id() -> str:
+ """Generate a unique tool ID using UUID4."""
+ return f"tool_{uuid.uuid4().hex[:12]}"
+
+
+def process_images_in_message(message: dict) -> dict:
+ """
+ Process all images in a message to ensure they're within Anthropic's recommended limits.
+ """
+ if not message or not isinstance(message, dict):
+ return message
+
+ # Handle nested image_data (old format)
+ if (
+ message.get("role")
+ and message.get("image_data")
+ and isinstance(message["image_data"], dict)
+ ):
+ if message["image_data"].get("data") and message["image_data"].get(
+ "media_type"
+ ):
+ message["image_data"]["data"] = resize_base64_image(
+ message["image_data"]["data"]
+ )
+ return message
+
+ # Handle standard content list format
+ if message.get("content") and isinstance(message["content"], list):
+ for i, block in enumerate(message["content"]):
+ if isinstance(block, dict) and block.get("type") == "image":
+ if block.get("source", {}).get("type") == "base64" and block[
+ "source"
+ ].get("data"):
+ message["content"][i]["source"]["data"] = (
+ resize_base64_image(block["source"]["data"])
+ )
+
+ # Handle string content with base64 image detection (less common)
+ elif (
+ message.get("content")
+ and isinstance(message["content"], str)
+ and ";base64," in message["content"]
+ ):
+ # This is a basic detection for base64 images in text - might need more robust handling
+ logger.warning(
+ "Detected potential base64 image in string content - not auto-resizing"
+ )
+
+ return message
+
+
+def openai_message_to_anthropic_block(msg: dict) -> dict:
+ """Converts a single OpenAI-style message (including function/tool calls)
+ into one Anthropic-style message.
+
+ Expected keys in `msg` can include:
+ - role: "system" | "assistant" | "user" | "function" | "tool"
+ - content: str (possibly JSON arguments or the final text)
+ - name: str (tool/function name)
+ - tool_call_id or function_call arguments
+ - function_call: {"name": ..., "arguments": "..."}
+ """
+ role = msg.get("role", "")
+ content = msg.get("content", "")
+ tool_call_id = msg.get("tool_call_id")
+
+ # Handle old-style image_data field
+ image_data = msg.get("image_data")
+ # Handle nested image_url (less common)
+ image_url = msg.get("image_url")
+
+ if role == "system":
+ # System messages should not have images, extract any image to a separate user message
+ if image_url or image_data:
+ logger.warning(
+ "Found image in system message - images should be in user messages only"
+ )
+ return msg
+
+ if role in ["user", "assistant"]:
+ # If content is already a list, assume it's properly formatted
+ if isinstance(content, list):
+ return {"role": role, "content": content}
+
+ # Process old-style image_data or image_url
+ if image_url or image_data:
+ formatted_content = []
+
+ # Add image content first (as recommended by Anthropic)
+ if image_url:
+ formatted_content.append(
+ {
+ "type": "image",
+ "source": {"type": "url", "url": image_url},
+ }
+ )
+ elif image_data:
+ # Resize the image data if needed
+ resized_data = image_data.get("data", "")
+ if resized_data:
+ resized_data = resize_base64_image(resized_data)
+
+ formatted_content.append(
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": image_data.get(
+ "media_type", "image/jpeg"
+ ),
+ "data": resized_data,
+ },
+ }
+ )
+
+ # Add text content after the image
+ if content:
+ if isinstance(content, str):
+ formatted_content.append({"type": "text", "text": content})
+ elif isinstance(content, list):
+ # If it's already a list, extend with it
+ formatted_content.extend(content)
+
+ return {"role": role, "content": formatted_content}
+
+ if role in ["function", "tool"]:
+ return {
+ "role": "user",
+ "content": [
+ {
+ "type": "tool_result",
+ "tool_use_id": tool_call_id,
+ "content": content,
+ }
+ ],
+ }
+
+ return {"role": role, "content": content}
+
+
+class AnthropicCompletionProvider(CompletionProvider):
+ def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
+ super().__init__(config)
+ self.client = Anthropic()
+ self.async_client = AsyncAnthropic()
+ logger.debug("AnthropicCompletionProvider initialized successfully")
+
+ def _get_base_args(
+ self, generation_config: GenerationConfig
+ ) -> dict[str, Any]:
+ """Build the arguments dictionary for Anthropic's messages.create().
+
+ Handles tool configuration according to Anthropic's schema:
+ {
+ "type": "function", # Use 'function' type for custom tools
+ "name": "tool_name",
+ "description": "tool description",
+ "parameters": { # Note: Anthropic expects 'parameters', not 'input_schema'
+ "type": "object",
+ "properties": {...},
+ "required": [...]
+ }
+ }
+ """
+ model_str = generation_config.model or ""
+ model_name = (
+ model_str.split("anthropic/")[-1]
+ if model_str
+ else "claude-3-opus-20240229"
+ )
+
+ args: dict[str, Any] = {
+ "model": model_name,
+ "temperature": generation_config.temperature,
+ "max_tokens": generation_config.max_tokens_to_sample,
+ "stream": generation_config.stream,
+ }
+ if generation_config.top_p:
+ args["top_p"] = generation_config.top_p
+
+ if generation_config.tools is not None:
+ # Convert tools to Anthropic's format
+ anthropic_tools: list[dict[str, Any]] = []
+ for tool in generation_config.tools:
+ tool_def = {
+ "name": tool["function"]["name"],
+ "description": tool["function"]["description"],
+ "input_schema": tool["function"]["parameters"],
+ }
+ anthropic_tools.append(tool_def)
+ args["tools"] = anthropic_tools
+
+ if hasattr(generation_config, "tool_choice"):
+ tool_choice = generation_config.tool_choice
+ if isinstance(tool_choice, str):
+ if tool_choice == "auto":
+ args["tool_choice"] = {"type": "auto"}
+ elif tool_choice == "any":
+ args["tool_choice"] = {"type": "any"}
+ elif isinstance(tool_choice, dict):
+ if tool_choice.get("type") == "function":
+ args["tool_choice"] = {
+ "type": "function",
+ "name": tool_choice.get("name"),
+ }
+ if hasattr(generation_config, "disable_parallel_tool_use"):
+ args["tool_choice"] = args.get("tool_choice", {})
+ args["tool_choice"]["disable_parallel_tool_use"] = (
+ generation_config.disable_parallel_tool_use
+ )
+
+ # --- Extended Thinking Support ---
+ if getattr(generation_config, "extended_thinking", False):
+ if (
+ not hasattr(generation_config, "thinking_budget")
+ or generation_config.thinking_budget is None
+ ):
+ raise ValueError(
+ "Extended thinking is enabled but no thinking_budget is provided."
+ )
+ if (
+ generation_config.thinking_budget
+ >= generation_config.max_tokens_to_sample
+ ):
+ raise ValueError(
+ "thinking_budget must be less than max_tokens_to_sample."
+ )
+ args["thinking"] = {
+ "type": "enabled",
+ "budget_tokens": generation_config.thinking_budget,
+ }
+ return args
+
+ def _preprocess_messages(self, messages: list[dict]) -> list[dict]:
+ """
+ Preprocess all messages to optimize images before sending to Anthropic API.
+ """
+ if not messages or not isinstance(messages, list):
+ return messages
+
+ processed_messages = []
+ for message in messages:
+ processed_message = process_images_in_message(message)
+ processed_messages.append(processed_message)
+
+ return processed_messages
+
+ def _create_openai_style_message(self, content_blocks, tool_calls=None):
+ """
+ Create an OpenAI-style message from Anthropic content blocks
+ while preserving the original structure.
+ """
+ display_content = ""
+ structured_content: list[Any] = []
+
+ for block in content_blocks:
+ if block.type == "text":
+ display_content += block.text
+ elif block.type == "thinking" and hasattr(block, "thinking"):
+ # Store the complete thinking block
+ structured_content.append(
+ {
+ "type": "thinking",
+ "thinking": block.thinking,
+ "signature": block.signature,
+ }
+ )
+ # For display/logging
+ # display_content += f"<think>{block.thinking}</think>"
+ elif block.type == "redacted_thinking" and hasattr(block, "data"):
+ # Store the complete redacted thinking block
+ structured_content.append(
+ {"type": "redacted_thinking", "data": block.data}
+ )
+ # Add a placeholder for display/logging
+ display_content += "<redacted thinking block>"
+ elif block.type == "tool_use":
+ # Tool use blocks are handled separately via tool_calls
+ pass
+
+ # If we have structured content (thinking blocks), use that
+ if structured_content:
+ # Add any text block at the end if needed
+ for block in content_blocks:
+ if block.type == "text":
+ structured_content.append(
+ {"type": "text", "text": block.text}
+ )
+
+ return {
+ "content": display_content or None,
+ "structured_content": structured_content,
+ }
+ else:
+ # If no structured content, just return the display content
+ return {"content": display_content or None}
+
+ def _convert_to_chat_completion(self, anthropic_msg: Message) -> dict:
+ """
+ Convert a non-streaming Anthropic Message into an OpenAI-style dict.
+ Preserves thinking blocks for proper handling.
+ """
+ tool_calls: list[Any] = []
+ message_data: dict[str, Any] = {"role": anthropic_msg.role}
+
+ if anthropic_msg.content:
+ # First, extract any tool use blocks
+ for block in anthropic_msg.content:
+ if hasattr(block, "type") and block.type == "tool_use":
+ tool_calls.append(
+ {
+ "index": len(tool_calls),
+ "id": block.id,
+ "type": "function",
+ "function": {
+ "name": block.name,
+ "arguments": json.dumps(block.input),
+ },
+ }
+ )
+
+ # Then create the message with appropriate content
+ message_data.update(
+ self._create_openai_style_message(
+ anthropic_msg.content, tool_calls
+ )
+ )
+
+ # If we have tool calls, add them
+ if tool_calls:
+ message_data["tool_calls"] = tool_calls
+
+ finish_reason = (
+ "stop"
+ if anthropic_msg.stop_reason == "end_turn"
+ else anthropic_msg.stop_reason
+ )
+ finish_reason = (
+ "tool_calls"
+ if anthropic_msg.stop_reason == "tool_use"
+ else finish_reason
+ )
+
+ model_str = anthropic_msg.model or ""
+ model_name = model_str.split("anthropic/")[-1] if model_str else ""
+
+ return {
+ "id": anthropic_msg.id,
+ "object": "chat.completion",
+ "created": int(time.time()),
+ "model": model_name,
+ "usage": {
+ "prompt_tokens": (
+ anthropic_msg.usage.input_tokens
+ if anthropic_msg.usage
+ else 0
+ ),
+ "completion_tokens": (
+ anthropic_msg.usage.output_tokens
+ if anthropic_msg.usage
+ else 0
+ ),
+ "total_tokens": (
+ (
+ anthropic_msg.usage.input_tokens
+ if anthropic_msg.usage
+ else 0
+ )
+ + (
+ anthropic_msg.usage.output_tokens
+ if anthropic_msg.usage
+ else 0
+ )
+ ),
+ },
+ "choices": [
+ {
+ "index": 0,
+ "message": message_data,
+ "finish_reason": finish_reason,
+ }
+ ],
+ }
+
+ def _split_system_messages(
+ self, messages: list[dict]
+ ) -> tuple[list[dict], Optional[str]]:
+ """
+ Process messages for Anthropic API, ensuring proper format for tool use and thinking blocks.
+ Now with image optimization.
+ """
+ # First preprocess to resize any images
+ messages = self._preprocess_messages(messages)
+
+ system_msg = None
+ filtered: list[dict[str, Any]] = []
+ pending_tool_results: list[dict[str, Any]] = []
+
+ # Look for pairs of tool_use and tool_result
+ i = 0
+ while i < len(messages):
+ m = copy.deepcopy(messages[i])
+
+ # Handle system message
+ if m["role"] == "system" and system_msg is None:
+ system_msg = m["content"]
+ i += 1
+ continue
+
+ # Case 1: Message with list format content (thinking blocks or tool blocks)
+ if (
+ isinstance(m.get("content"), list)
+ and len(m["content"]) > 0
+ and isinstance(m["content"][0], dict)
+ ):
+ filtered.append({"role": m["role"], "content": m["content"]})
+ i += 1
+ continue
+
+ # Case 2: Message with structured_content field
+ elif m.get("structured_content") and m["role"] == "assistant":
+ filtered.append(
+ {"role": "assistant", "content": m["structured_content"]}
+ )
+ i += 1
+ continue
+
+ # Case 3: Tool calls in an assistant message
+ elif m.get("tool_calls") and m["role"] == "assistant":
+ # Add content if it exists
+ if m.get("content") and not isinstance(m["content"], list):
+ content_to_add = m["content"]
+ # Handle content with thinking tags
+ if "<think>" in content_to_add:
+ thinking_start = content_to_add.find("<think>")
+ thinking_end = content_to_add.find("</think>")
+ if (
+ thinking_start >= 0
+ and thinking_end > thinking_start
+ ):
+ thinking_content = content_to_add[
+ thinking_start + 7 : thinking_end
+ ]
+ text_content = content_to_add[
+ thinking_end + 8 :
+ ].strip()
+ filtered.append(
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "thinking",
+ "thinking": thinking_content,
+ "signature": "placeholder_signature", # This is a placeholder
+ },
+ {"type": "text", "text": text_content},
+ ],
+ }
+ )
+ else:
+ filtered.append(
+ {
+ "role": "assistant",
+ "content": content_to_add,
+ }
+ )
+ else:
+ filtered.append(
+ {"role": "assistant", "content": content_to_add}
+ )
+
+ # Add tool use blocks
+ tool_uses = []
+ for call in m["tool_calls"]:
+ tool_uses.append(
+ {
+ "type": "tool_use",
+ "id": call["id"],
+ "name": call["function"]["name"],
+ "input": json.loads(call["function"]["arguments"]),
+ }
+ )
+
+ filtered.append({"role": "assistant", "content": tool_uses})
+
+ # Check if next message is a tool result for this tool call
+ if i + 1 < len(messages) and messages[i + 1]["role"] in [
+ "function",
+ "tool",
+ ]:
+ next_m = copy.deepcopy(messages[i + 1])
+
+ # Make sure this is a tool result for the current tool use
+ if next_m.get("tool_call_id") in [
+ call["id"] for call in m["tool_calls"]
+ ]:
+ # Add tool result as a user message
+ filtered.append(
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "tool_result",
+ "tool_use_id": next_m["tool_call_id"],
+ "content": next_m["content"],
+ }
+ ],
+ }
+ )
+ i += 2 # Skip both the tool call and result
+ continue
+
+ i += 1
+ continue
+
+ # Case 4: Direct tool result (might be missing its paired tool call)
+ elif m["role"] in ["function", "tool"] and m.get("tool_call_id"):
+ # Add a user message with the tool result
+ filtered.append(
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "tool_result",
+ "tool_use_id": m["tool_call_id"],
+ "content": m["content"],
+ }
+ ],
+ }
+ )
+ i += 1
+ continue
+
+ # Case 5: Tool result without a tool_call_id - collect and batch the results
+ elif m["role"] in ["function", "tool"]:
+ # Collect tool results to combine them
+ pending_tool_results.append(
+ {
+ "type": "tool_result",
+ "tool_use_id": m.get("tool_call_id"),
+ "content": m["content"],
+ }
+ )
+
+ # If we have all expected results, add them as one message
+ if len(filtered) > 0 and len(
+ filtered[-1].get("content", [])
+ ) == len(pending_tool_results):
+ filtered.append(
+ {"role": "user", "content": pending_tool_results}
+ )
+ pending_tool_results = []
+ else:
+ filtered.append(openai_message_to_anthropic_block(m))
+ i += 1
+
+ # Final validation: ensure no tool_use is at the end without a tool_result
+ if filtered and len(filtered) > 1:
+ last_msg = filtered[-1]
+ if (
+ last_msg["role"] == "assistant"
+ and isinstance(last_msg.get("content"), list)
+ and any(
+ block.get("type") == "tool_use"
+ for block in last_msg["content"]
+ )
+ ):
+ logger.warning(
+ "Found tool_use at end of conversation without tool_result - removing it"
+ )
+ filtered.pop() # Remove problematic message
+
+ return filtered, system_msg
+
+ async def _execute_task(self, task: dict[str, Any]):
+ """Async entry point.
+
+ Decide if streaming or not, then call the appropriate helper.
+ """
+ api_key = os.getenv("ANTHROPIC_API_KEY")
+ if not api_key:
+ logger.error("Missing ANTHROPIC_API_KEY in environment.")
+ raise ValueError(
+ "Anthropic API key not found. Set ANTHROPIC_API_KEY env var."
+ )
+
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ extra_kwargs = task["kwargs"]
+ base_args = self._get_base_args(generation_config)
+ filtered_messages, system_msg = self._split_system_messages(messages)
+ base_args["messages"] = filtered_messages
+ if system_msg:
+ base_args["system"] = system_msg
+
+ args = {**base_args, **extra_kwargs}
+ logger.debug(f"Anthropic async call with args={args}")
+
+ if generation_config.stream:
+ return self._execute_task_async_streaming(args)
+ else:
+ return await self._execute_task_async_nonstreaming(args)
+
+ async def _execute_task_async_nonstreaming(
+ self, args: dict[str, Any]
+ ) -> LLMChatCompletion:
+ api_key = os.getenv("ANTHROPIC_API_KEY")
+ if not api_key:
+ logger.error("Missing ANTHROPIC_API_KEY in environment.")
+ raise ValueError(
+ "Anthropic API key not found. Set ANTHROPIC_API_KEY env var."
+ )
+
+ try:
+ logger.debug(f"Anthropic API request: {args}")
+ response = await self.async_client.messages.create(**args)
+ logger.debug(f"Anthropic API response: {response}")
+
+ return LLMChatCompletion(
+ **self._convert_to_chat_completion(response)
+ )
+ except Exception as e:
+ logger.error(f"Anthropic async non-stream call failed: {e}")
+ logger.error("message payload = ", args)
+ raise
+
+ async def _execute_task_async_streaming(
+ self, args: dict
+ ) -> AsyncGenerator[dict[str, Any], None]:
+ """Streaming call (async): yields partial tokens in OpenAI-like SSE
+ format."""
+ # `messages.stream()` manages streaming itself, so remove any `stream=True`
+ # left in the args to avoid passing a conflicting keyword.
+ args.pop("stream", None)
+ try:
+ async with self.async_client.messages.stream(**args) as stream:
+ # We'll track partial JSON for function calls in buffer_data
+ buffer_data: dict[str, Any] = {
+ "tool_json_buffer": "",
+ "tool_name": None,
+ "tool_id": None,
+ "is_collecting_tool": False,
+ "thinking_buffer": "",
+ "is_collecting_thinking": False,
+ "thinking_signature": None,
+ "message_id": f"chatcmpl-{int(time.time())}",
+ }
+ model_name = args.get("model", "claude-2")
+ if isinstance(model_name, str):
+ model_name = model_name.split("anthropic/")[-1]
+
+ async for event in stream:
+ chunks = self._process_stream_event(
+ event=event,
+ buffer_data=buffer_data,
+ model_name=model_name,
+ )
+ for chunk in chunks:
+ yield chunk
+ except Exception as e:
+ logger.error(f"Failed to execute streaming Anthropic task: {e}")
+ logger.error("message payload = ", args)
+
+ raise
+
+ def _execute_task_sync(self, task: dict[str, Any]):
+ """Synchronous entry point."""
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ extra_kwargs = task["kwargs"]
+
+ base_args = self._get_base_args(generation_config)
+ filtered_messages, system_msg = self._split_system_messages(messages)
+ base_args["messages"] = filtered_messages
+ if system_msg:
+ base_args["system"] = system_msg
+
+ args = {**base_args, **extra_kwargs}
+ logger.debug(f"Anthropic sync call with args={args}")
+
+ if generation_config.stream:
+ return self._execute_task_sync_streaming(args)
+ else:
+ return self._execute_task_sync_nonstreaming(args)
+
+ def _execute_task_sync_nonstreaming(
+ self, args: dict[str, Any]
+ ): # -> LLMChatCompletion: # FIXME: LLMChatCompletion is an object from the OpenAI API, which causes a validation error
+ """Non-streaming synchronous call."""
+ try:
+ response = self.client.messages.create(**args)
+ logger.debug("Anthropic sync non-stream call succeeded.")
+ return LLMChatCompletion(
+ **self._convert_to_chat_completion(response)
+ )
+ except Exception as e:
+ logger.error(f"Anthropic sync call failed: {e}")
+ raise
+
+ def _execute_task_sync_streaming(
+ self, args: dict[str, Any]
+ ) -> Generator[dict[str, Any], None, None]:
+ """
+ Synchronous streaming call: yields partial tokens in a generator.
+ """
+ args.pop("stream", None)
+ try:
+ with self.client.messages.stream(**args) as stream:
+ buffer_data: dict[str, Any] = {
+ "tool_json_buffer": "",
+ "tool_name": None,
+ "tool_id": None,
+ "is_collecting_tool": False,
+ "thinking_buffer": "",
+ "is_collecting_thinking": False,
+ "thinking_signature": None,
+ "message_id": f"chatcmpl-{int(time.time())}",
+ }
+ model_name = args.get("model", "anthropic/claude-2")
+ if isinstance(model_name, str):
+ model_name = model_name.split("anthropic/")[-1]
+
+ for event in stream:
+ yield from self._process_stream_event(
+ event=event,
+ buffer_data=buffer_data,
+ model_name=model_name,
+ )
+ except Exception as e:
+ logger.error(f"Anthropic sync streaming call failed: {e}")
+ raise
+
+ def _process_stream_event(
+ self, event: Any, buffer_data: dict[str, Any], model_name: str
+ ) -> list[dict[str, Any]]:
+ chunks: list[dict[str, Any]] = []
+
+ def make_base_chunk() -> dict[str, Any]:
+ return {
+ "id": buffer_data["message_id"],
+ "object": "chat.completion.chunk",
+ "created": int(time.time()),
+ "model": model_name,
+ "choices": [{"index": 0, "delta": {}, "finish_reason": None}],
+ }
+
+ if isinstance(event, RawMessageStartEvent):
+ buffer_data["message_id"] = event.message.id
+ chunk = make_base_chunk()
+ input_tokens = (
+ event.message.usage.input_tokens if event.message.usage else 0
+ )
+ chunk["usage"] = {
+ "prompt_tokens": input_tokens,
+ "completion_tokens": 0,
+ "total_tokens": input_tokens,
+ }
+ chunks.append(chunk)
+
+ elif isinstance(event, RawContentBlockStartEvent):
+ if hasattr(event.content_block, "type"):
+ block_type = event.content_block.type
+ if block_type == "thinking":
+ buffer_data["is_collecting_thinking"] = True
+ buffer_data["thinking_buffer"] = ""
+ # Don't emit anything yet
+ elif block_type == "tool_use" or isinstance(
+ event.content_block, ToolUseBlock
+ ):
+ buffer_data["tool_name"] = event.content_block.name # type: ignore
+ buffer_data["tool_id"] = event.content_block.id # type: ignore
+ buffer_data["tool_json_buffer"] = ""
+ buffer_data["is_collecting_tool"] = True
+
+ elif isinstance(event, RawContentBlockDeltaEvent):
+ delta_obj = getattr(event, "delta", None)
+ delta_type = getattr(delta_obj, "type", None)
+
+ # Handle thinking deltas
+ if delta_type == "thinking_delta" and hasattr(
+ delta_obj, "thinking"
+ ):
+ thinking_chunk = delta_obj.thinking # type: ignore
+ if buffer_data["is_collecting_thinking"]:
+ buffer_data["thinking_buffer"] += thinking_chunk
+ # Stream thinking chunks as they come in
+ chunk = make_base_chunk()
+ chunk["choices"][0]["delta"] = {"thinking": thinking_chunk}
+ chunks.append(chunk)
+
+ # Handle signature deltas for thinking blocks
+ elif delta_type == "signature_delta" and hasattr(
+ delta_obj, "signature"
+ ):
+ if buffer_data["is_collecting_thinking"]:
+ buffer_data["thinking_signature"] = delta_obj.signature # type: ignore
+ # Forward the signature so the consumer can reattach it to the thinking block
+ chunk = make_base_chunk()
+ chunk["choices"][0]["delta"] = {
+ "thinking_signature": delta_obj.signature # type: ignore
+ }
+ chunks.append(chunk)
+
+ # Handle text deltas
+ elif delta_type == "text_delta" and hasattr(delta_obj, "text"):
+ text_chunk = delta_obj.text # type: ignore
+ if not buffer_data["is_collecting_tool"] and text_chunk:
+ chunk = make_base_chunk()
+ chunk["choices"][0]["delta"] = {"content": text_chunk}
+ chunks.append(chunk)
+
+ # Handle partial JSON for tools
+ elif hasattr(delta_obj, "partial_json"):
+ if buffer_data["is_collecting_tool"]:
+ buffer_data["tool_json_buffer"] += delta_obj.partial_json # type: ignore
+
+ elif isinstance(event, ContentBlockStopEvent):
+ # Handle the end of a thinking block
+ if buffer_data.get("is_collecting_thinking"):
+ # Emit a special "structured_content_delta" with the complete thinking block
+ if (
+ buffer_data["thinking_buffer"]
+ and buffer_data["thinking_signature"]
+ ):
+ chunk = make_base_chunk()
+ chunk["choices"][0]["delta"] = {
+ "structured_content": [
+ {
+ "type": "thinking",
+ "thinking": buffer_data["thinking_buffer"],
+ "signature": buffer_data["thinking_signature"],
+ }
+ ]
+ }
+ chunks.append(chunk)
+
+ # Reset thinking collection
+ buffer_data["is_collecting_thinking"] = False
+ buffer_data["thinking_buffer"] = ""
+ buffer_data["thinking_signature"] = None
+
+ # Handle the end of a tool use block
+ elif buffer_data.get("is_collecting_tool"):
+ try:
+ json.loads(buffer_data["tool_json_buffer"])
+ chunk = make_base_chunk()
+ chunk["choices"][0]["delta"] = {
+ "tool_calls": [
+ {
+ "index": 0,
+ "type": "function",
+ "id": buffer_data["tool_id"]
+ or f"call_{generate_tool_id()}",
+ "function": {
+ "name": buffer_data["tool_name"],
+ "arguments": buffer_data[
+ "tool_json_buffer"
+ ],
+ },
+ }
+ ]
+ }
+ chunks.append(chunk)
+ buffer_data["is_collecting_tool"] = False
+ buffer_data["tool_json_buffer"] = ""
+ buffer_data["tool_name"] = None
+ buffer_data["tool_id"] = None
+ except json.JSONDecodeError:
+ logger.warning(
+ "Incomplete JSON in tool call, skipping chunk"
+ )
+
+ elif isinstance(event, MessageStopEvent):
+ # Use the message's stop_reason if the event carries one
+ message_obj = getattr(event, "message", None)
+ stop_reason = getattr(message_obj, "stop_reason", None)
+ if stop_reason is not None:
+ chunk = make_base_chunk()
+ if stop_reason == "tool_use":
+ chunk["choices"][0]["delta"] = {}
+ chunk["choices"][0]["finish_reason"] = "tool_calls"
+ else:
+ chunk["choices"][0]["delta"] = {}
+ chunk["choices"][0]["finish_reason"] = "stop"
+ chunks.append(chunk)
+ else:
+ # Handle the case where message is not available
+ chunk = make_base_chunk()
+ chunk["choices"][0]["delta"] = {}
+ chunk["choices"][0]["finish_reason"] = "stop"
+ chunks.append(chunk)
+
+ return chunks
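openai_message_to_anthropic_block is what lets OpenAI-style history flow into the Anthropic Messages API. A quick sketch of what it does to a tool result (the tool_call_id and payload are placeholders, and the import assumes the package above is importable):

from core.providers.llm.anthropic import openai_message_to_anthropic_block

# An OpenAI-style tool result, as produced after executing a tool call.
openai_tool_result = {
    "role": "tool",
    "tool_call_id": "toolu_123",
    "content": '{"temperature_c": 21}',
}

# The helper folds it into a user message carrying a tool_result block,
# which is the shape Anthropic's Messages API expects.
anthropic_msg = openai_message_to_anthropic_block(openai_tool_result)
assert anthropic_msg["role"] == "user"
assert anthropic_msg["content"][0]["type"] == "tool_result"
assert anthropic_msg["content"][0]["tool_use_id"] == "toolu_123"

_split_system_messages applies the same conversion across a whole conversation, pairing each tool_use block with its tool_result and pulling the system prompt out into the separate system argument.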
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/azure_foundry.py b/.venv/lib/python3.12/site-packages/core/providers/llm/azure_foundry.py
new file mode 100644
index 00000000..863e44ec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/azure_foundry.py
@@ -0,0 +1,110 @@
+import logging
+import os
+from typing import Any, Optional
+
+from azure.ai.inference import (
+ ChatCompletionsClient as AzureChatCompletionsClient,
+)
+from azure.ai.inference.aio import (
+ ChatCompletionsClient as AsyncAzureChatCompletionsClient,
+)
+from azure.core.credentials import AzureKeyCredential
+
+from core.base.abstractions import GenerationConfig
+from core.base.providers.llm import CompletionConfig, CompletionProvider
+
+logger = logging.getLogger(__name__)
+
+
+class AzureFoundryCompletionProvider(CompletionProvider):
+ def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
+ super().__init__(config)
+ self.azure_foundry_client: Optional[AzureChatCompletionsClient] = None
+ self.async_azure_foundry_client: Optional[
+ AsyncAzureChatCompletionsClient
+ ] = None
+
+ # Initialize Azure Foundry clients if credentials exist.
+ azure_foundry_api_key = os.getenv("AZURE_FOUNDRY_API_KEY")
+ azure_foundry_api_endpoint = os.getenv("AZURE_FOUNDRY_API_ENDPOINT")
+
+ if azure_foundry_api_key and azure_foundry_api_endpoint:
+ self.azure_foundry_client = AzureChatCompletionsClient(
+ endpoint=azure_foundry_api_endpoint,
+ credential=AzureKeyCredential(azure_foundry_api_key),
+ api_version=os.getenv(
+ "AZURE_FOUNDRY_API_VERSION", "2024-05-01-preview"
+ ),
+ )
+ self.async_azure_foundry_client = AsyncAzureChatCompletionsClient(
+ endpoint=azure_foundry_api_endpoint,
+ credential=AzureKeyCredential(azure_foundry_api_key),
+ api_version=os.getenv(
+ "AZURE_FOUNDRY_API_VERSION", "2024-05-01-preview"
+ ),
+ )
+ logger.debug("Azure Foundry clients initialized successfully")
+
+ def _get_base_args(
+ self, generation_config: GenerationConfig
+ ) -> dict[str, Any]:
+ # Construct arguments similar to the other providers.
+ args: dict[str, Any] = {
+ "top_p": generation_config.top_p,
+ "stream": generation_config.stream,
+ "max_tokens": generation_config.max_tokens_to_sample,
+ "temperature": generation_config.temperature,
+ }
+
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+ if generation_config.tools is not None:
+ args["tools"] = generation_config.tools
+ if generation_config.response_format is not None:
+ args["response_format"] = generation_config.response_format
+ return args
+
+ async def _execute_task(self, task: dict[str, Any]):
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ kwargs = task["kwargs"]
+
+ args = self._get_base_args(generation_config)
+ # Azure Foundry does not require a "model" argument; the endpoint is fixed.
+ args["messages"] = messages
+ args = {**args, **kwargs}
+ logger.debug(f"Executing async Azure Foundry task with args: {args}")
+
+ try:
+ if self.async_azure_foundry_client is None:
+ raise ValueError("Azure Foundry client is not initialized")
+
+ response = await self.async_azure_foundry_client.complete(**args)
+ logger.debug("Async Azure Foundry task executed successfully")
+ return response
+ except Exception as e:
+ logger.error(
+ f"Async Azure Foundry task execution failed: {str(e)}"
+ )
+ raise
+
+ def _execute_task_sync(self, task: dict[str, Any]):
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ kwargs = task["kwargs"]
+
+ args = self._get_base_args(generation_config)
+ args["messages"] = messages
+ args = {**args, **kwargs}
+ logger.debug(f"Executing sync Azure Foundry task with args: {args}")
+
+ try:
+ if self.azure_foundry_client is None:
+ raise ValueError("Azure Foundry client is not initialized")
+
+ response = self.azure_foundry_client.complete(**args)
+ logger.debug("Sync Azure Foundry task executed successfully")
+ return response
+ except Exception as e:
+ logger.error(f"Sync Azure Foundry task execution failed: {str(e)}")
+ raise
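The Azure Foundry provider forwards only sampling parameters (no model name), because the deployment is fixed by the endpoint. A standalone mirror of that mapping (SimpleNamespace stands in for GenerationConfig purely for illustration):

from types import SimpleNamespace
from typing import Any


def foundry_args(generation_config: Any) -> dict[str, Any]:
    """Mirror of _get_base_args above: sampling fields only, no "model" key."""
    args: dict[str, Any] = {
        "top_p": generation_config.top_p,
        "stream": generation_config.stream,
        "max_tokens": generation_config.max_tokens_to_sample,
        "temperature": generation_config.temperature,
    }
    # functions, tools and response_format are forwarded as well when set.
    if generation_config.tools is not None:
        args["tools"] = generation_config.tools
    return args


cfg = SimpleNamespace(
    top_p=1.0, stream=False, max_tokens_to_sample=256, temperature=0.2, tools=None
)
print(foundry_args(cfg))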
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/litellm.py b/.venv/lib/python3.12/site-packages/core/providers/llm/litellm.py
new file mode 100644
index 00000000..44d467c2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/litellm.py
@@ -0,0 +1,80 @@
+import logging
+from typing import Any
+
+import litellm
+from litellm import acompletion, completion
+
+from core.base.abstractions import GenerationConfig
+from core.base.providers.llm import CompletionConfig, CompletionProvider
+
+logger = logging.getLogger()
+
+
+class LiteLLMCompletionProvider(CompletionProvider):
+ def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
+ super().__init__(config)
+ litellm.modify_params = True
+ self.acompletion = acompletion
+ self.completion = completion
+
+ # if config.provider != "litellm":
+ # logger.error(f"Invalid provider: {config.provider}")
+ # raise ValueError(
+ # "LiteLLMCompletionProvider must be initialized with config with `litellm` provider."
+ # )
+
+ def _get_base_args(
+ self, generation_config: GenerationConfig
+ ) -> dict[str, Any]:
+ args: dict[str, Any] = {
+ "model": generation_config.model,
+ "temperature": generation_config.temperature,
+ "top_p": generation_config.top_p,
+ "stream": generation_config.stream,
+ "max_tokens": generation_config.max_tokens_to_sample,
+ "api_base": generation_config.api_base,
+ }
+
+ # Fix the type errors by properly typing these assignments
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+ if generation_config.tools is not None:
+ args["tools"] = generation_config.tools
+ if generation_config.response_format is not None:
+ args["response_format"] = generation_config.response_format
+
+ return args
+
+ async def _execute_task(self, task: dict[str, Any]):
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ kwargs = task["kwargs"]
+
+ args = self._get_base_args(generation_config)
+ args["messages"] = messages
+ args = {**args, **kwargs}
+
+ logger.debug(
+ f"Executing LiteLLM task with generation_config={generation_config}"
+ )
+
+ return await self.acompletion(**args)
+
+ def _execute_task_sync(self, task: dict[str, Any]):
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ kwargs = task["kwargs"]
+
+ args = self._get_base_args(generation_config)
+ args["messages"] = messages
+ args = {**args, **kwargs}
+
+ logger.debug(
+ f"Executing LiteLLM task with generation_config={generation_config}"
+ )
+
+ try:
+ return self.completion(**args)
+ except Exception as e:
+ logger.error(f"Sync LiteLLM task execution failed: {str(e)}")
+ raise
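Under the hood the LiteLLM provider ends up calling litellm.completion with the keyword names that _get_base_args produces. A hedged usage sketch (the model string is a placeholder and a matching API key is assumed to be present in the environment):

from litellm import completion

# Roughly the call _execute_task_sync makes once messages and extra kwargs are merged in.
response = completion(
    model="openai/gpt-4o-mini",
    messages=[{"role": "user", "content": "Summarize R2R ingestion in one line."}],
    temperature=0.1,
    top_p=1.0,
    stream=False,
    max_tokens=128,
)
print(response.choices[0].message.content)

Because LiteLLM routes on the model prefix, switching backends is a matter of changing the model string; the provider itself stays unchanged.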
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/openai.py b/.venv/lib/python3.12/site-packages/core/providers/llm/openai.py
new file mode 100644
index 00000000..30ef37ab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/openai.py
@@ -0,0 +1,522 @@
+import logging
+import os
+from typing import Any
+
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
+
+from core.base.abstractions import GenerationConfig
+from core.base.providers.llm import CompletionConfig, CompletionProvider
+
+from .utils import resize_base64_image
+
+logger = logging.getLogger()
+
+
+class OpenAICompletionProvider(CompletionProvider):
+ def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
+ super().__init__(config)
+ self.openai_client = None
+ self.async_openai_client = None
+ self.azure_client = None
+ self.async_azure_client = None
+ self.deepseek_client = None
+ self.async_deepseek_client = None
+ self.ollama_client = None
+ self.async_ollama_client = None
+ self.lmstudio_client = None
+ self.async_lmstudio_client = None
+ # NEW: Azure Foundry clients using the Azure Inference API
+ self.azure_foundry_client = None
+ self.async_azure_foundry_client = None
+
+ # Initialize OpenAI clients if credentials exist
+ if os.getenv("OPENAI_API_KEY"):
+ self.openai_client = OpenAI()
+ self.async_openai_client = AsyncOpenAI()
+ logger.debug("OpenAI clients initialized successfully")
+
+ # Initialize Azure OpenAI clients if credentials exist
+ azure_api_key = os.getenv("AZURE_API_KEY")
+ azure_api_base = os.getenv("AZURE_API_BASE")
+ if azure_api_key and azure_api_base:
+ self.azure_client = AzureOpenAI(
+ api_key=azure_api_key,
+ api_version=os.getenv(
+ "AZURE_API_VERSION", "2024-02-15-preview"
+ ),
+ azure_endpoint=azure_api_base,
+ )
+ self.async_azure_client = AsyncAzureOpenAI(
+ api_key=azure_api_key,
+ api_version=os.getenv(
+ "AZURE_API_VERSION", "2024-02-15-preview"
+ ),
+ azure_endpoint=azure_api_base,
+ )
+ logger.debug("Azure OpenAI clients initialized successfully")
+
+ # Initialize Deepseek clients if credentials exist
+ deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")
+ deepseek_api_base = os.getenv(
+ "DEEPSEEK_API_BASE", "https://api.deepseek.com"
+ )
+ if deepseek_api_key and deepseek_api_base:
+ self.deepseek_client = OpenAI(
+ api_key=deepseek_api_key,
+ base_url=deepseek_api_base,
+ )
+ self.async_deepseek_client = AsyncOpenAI(
+ api_key=deepseek_api_key,
+ base_url=deepseek_api_base,
+ )
+ logger.debug("Deepseek OpenAI clients initialized successfully")
+
+ # Initialize Ollama clients with default API key
+ ollama_api_base = os.getenv(
+ "OLLAMA_API_BASE", "http://localhost:11434/v1"
+ )
+ if ollama_api_base:
+ self.ollama_client = OpenAI(
+ api_key=os.getenv("OLLAMA_API_KEY", "dummy"),
+ base_url=ollama_api_base,
+ )
+ self.async_ollama_client = AsyncOpenAI(
+ api_key=os.getenv("OLLAMA_API_KEY", "dummy"),
+ base_url=ollama_api_base,
+ )
+ logger.debug("Ollama OpenAI clients initialized successfully")
+
+ # Initialize LMStudio clients
+ lmstudio_api_base = os.getenv(
+ "LMSTUDIO_API_BASE", "http://localhost:1234/v1"
+ )
+ if lmstudio_api_base:
+ self.lmstudio_client = OpenAI(
+ api_key=os.getenv("LMSTUDIO_API_KEY", "lm-studio"),
+ base_url=lmstudio_api_base,
+ )
+ self.async_lmstudio_client = AsyncOpenAI(
+ api_key=os.getenv("LMSTUDIO_API_KEY", "lm-studio"),
+ base_url=lmstudio_api_base,
+ )
+ logger.debug("LMStudio OpenAI clients initialized successfully")
+
+ # Initialize Azure Foundry clients if credentials exist.
+ # These use the Azure Inference API; the client setup is duplicated inline here for now.
+ azure_foundry_api_key = os.getenv("AZURE_FOUNDRY_API_KEY")
+ azure_foundry_api_endpoint = os.getenv("AZURE_FOUNDRY_API_ENDPOINT")
+ if azure_foundry_api_key and azure_foundry_api_endpoint:
+ from azure.ai.inference import (
+ ChatCompletionsClient as AzureChatCompletionsClient,
+ )
+ from azure.ai.inference.aio import (
+ ChatCompletionsClient as AsyncAzureChatCompletionsClient,
+ )
+ from azure.core.credentials import AzureKeyCredential
+
+ self.azure_foundry_client = AzureChatCompletionsClient(
+ endpoint=azure_foundry_api_endpoint,
+ credential=AzureKeyCredential(azure_foundry_api_key),
+ api_version=os.getenv(
+ "AZURE_FOUNDRY_API_VERSION", "2024-05-01-preview"
+ ),
+ )
+ self.async_azure_foundry_client = AsyncAzureChatCompletionsClient(
+ endpoint=azure_foundry_api_endpoint,
+ credential=AzureKeyCredential(azure_foundry_api_key),
+ api_version=os.getenv(
+ "AZURE_FOUNDRY_API_VERSION", "2024-05-01-preview"
+ ),
+ )
+ logger.debug("Azure Foundry clients initialized successfully")
+
+ if not any(
+ [
+ self.openai_client,
+ self.azure_client,
+ self.ollama_client,
+ self.lmstudio_client,
+ self.azure_foundry_client,
+ ]
+ ):
+ raise ValueError(
+ "No valid client credentials found. Please set either OPENAI_API_KEY, "
+ "both AZURE_API_KEY and AZURE_API_BASE environment variables, "
+ "OLLAMA_API_BASE, LMSTUDIO_API_BASE, or AZURE_FOUNDRY_API_KEY and AZURE_FOUNDRY_API_ENDPOINT."
+ )
+
+ def _get_client_and_model(self, model: str):
+ """Determine which client to use based on model prefix and return the
+ appropriate client and model name."""
+ if model.startswith("azure/"):
+ if not self.azure_client:
+ raise ValueError(
+ "Azure OpenAI credentials not configured but azure/ model prefix used"
+ )
+ return self.azure_client, model[6:] # Strip 'azure/' prefix
+ elif model.startswith("openai/"):
+ if not self.openai_client:
+ raise ValueError(
+ "OpenAI credentials not configured but openai/ model prefix used"
+ )
+ return self.openai_client, model[7:] # Strip 'openai/' prefix
+ elif model.startswith("deepseek/"):
+ if not self.deepseek_client:
+ raise ValueError(
+ "Deepseek OpenAI credentials not configured but deepseek/ model prefix used"
+ )
+ return self.deepseek_client, model[9:] # Strip 'deepseek/' prefix
+ elif model.startswith("ollama/"):
+ if not self.ollama_client:
+ raise ValueError(
+ "Ollama OpenAI credentials not configured but ollama/ model prefix used"
+ )
+ return self.ollama_client, model[7:] # Strip 'ollama/' prefix
+ elif model.startswith("lmstudio/"):
+ if not self.lmstudio_client:
+ raise ValueError(
+ "LMStudio credentials not configured but lmstudio/ model prefix used"
+ )
+ return self.lmstudio_client, model[9:] # Strip 'lmstudio/' prefix
+ elif model.startswith("azure-foundry/"):
+ if not self.azure_foundry_client:
+ raise ValueError(
+ "Azure Foundry credentials not configured but azure-foundry/ model prefix used"
+ )
+ return (
+ self.azure_foundry_client,
+ model[14:],
+ ) # Strip 'azure-foundry/' prefix
+ else:
+ # Default to OpenAI if no prefix is provided.
+ if self.openai_client:
+ return self.openai_client, model
+ elif self.azure_client:
+ return self.azure_client, model
+ elif self.ollama_client:
+ return self.ollama_client, model
+ elif self.lmstudio_client:
+ return self.lmstudio_client, model
+ elif self.azure_foundry_client:
+ return self.azure_foundry_client, model
+ else:
+ raise ValueError("No valid client available for model prefix")
+
+ def _get_async_client_and_model(self, model: str):
+ """Get async client and model name based on prefix."""
+ if model.startswith("azure/"):
+ if not self.async_azure_client:
+ raise ValueError(
+ "Azure OpenAI credentials not configured but azure/ model prefix used"
+ )
+ return self.async_azure_client, model[6:]
+ elif model.startswith("openai/"):
+ if not self.async_openai_client:
+ raise ValueError(
+ "OpenAI credentials not configured but openai/ model prefix used"
+ )
+ return self.async_openai_client, model[7:]
+ elif model.startswith("deepseek/"):
+ if not self.async_deepseek_client:
+ raise ValueError(
+ "Deepseek OpenAI credentials not configured but deepseek/ model prefix used"
+ )
+ return self.async_deepseek_client, model[9:].strip()
+ elif model.startswith("ollama/"):
+ if not self.async_ollama_client:
+ raise ValueError(
+ "Ollama OpenAI credentials not configured but ollama/ model prefix used"
+ )
+ return self.async_ollama_client, model[7:]
+ elif model.startswith("lmstudio/"):
+ if not self.async_lmstudio_client:
+ raise ValueError(
+ "LMStudio credentials not configured but lmstudio/ model prefix used"
+ )
+ return self.async_lmstudio_client, model[9:]
+ elif model.startswith("azure-foundry/"):
+ if not self.async_azure_foundry_client:
+ raise ValueError(
+ "Azure Foundry credentials not configured but azure-foundry/ model prefix used"
+ )
+ return self.async_azure_foundry_client, model[14:]
+ else:
+ if self.async_openai_client:
+ return self.async_openai_client, model
+ elif self.async_azure_client:
+ return self.async_azure_client, model
+ elif self.async_ollama_client:
+ return self.async_ollama_client, model
+ elif self.async_lmstudio_client:
+ return self.async_lmstudio_client, model
+ elif self.async_azure_foundry_client:
+ return self.async_azure_foundry_client, model
+ else:
+ raise ValueError(
+ "No valid async client available for model prefix"
+ )
+
+ def _process_messages_with_images(
+ self, messages: list[dict]
+ ) -> list[dict]:
+ """
+ Process messages that may contain image_url or image_data fields.
+ Images are aggressively resized, similar to the Anthropic provider.
+ """
+ processed_messages = []
+
+ for msg in messages:
+ if msg.get("role") == "system":
+ # System messages don't support content arrays in OpenAI
+ processed_messages.append(msg)
+ continue
+
+ # Check if the message contains image data
+ image_url = msg.pop("image_url", None)
+ image_data = msg.pop("image_data", None)
+ content = msg.get("content")
+
+ if image_url or image_data:
+ # Convert to content array format
+ new_content = []
+
+ # Add image content
+ if image_url:
+ new_content.append(
+ {"type": "image_url", "image_url": {"url": image_url}}
+ )
+ elif image_data:
+ # Resize the base64 image data if available
+ media_type = image_data.get("media_type", "image/jpeg")
+ data = image_data.get("data", "")
+
+ # Resize the base64 payload to reduce request size and token usage
+ if data:
+ data = resize_base64_image(data)
+ logger.debug(
+ f"Image resized, new size: {len(data)} chars"
+ )
+
+ # OpenAI expects base64 images in data URL format
+ data_url = f"data:{media_type};base64,{data}"
+ new_content.append(
+ {"type": "image_url", "image_url": {"url": data_url}}
+ )
+
+ # Add text content if present
+ if content:
+ new_content.append({"type": "text", "text": content})
+
+ # Update the message
+ new_msg = dict(msg)
+ new_msg["content"] = new_content
+ processed_messages.append(new_msg)
+ else:
+ processed_messages.append(msg)
+
+ return processed_messages
+
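+ # A minimal sketch of the transformation above (hypothetical message, for
+ # illustration only):
+ #   {"role": "user", "content": "Describe this", "image_url": "https://example.com/cat.png"}
+ # becomes:
+ #   {"role": "user", "content": [
+ #       {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
+ #       {"type": "text", "text": "Describe this"},
+ #   ]}
+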
+ def _process_array_content_with_images(self, content: list) -> list:
+ """
+ Process content array that may contain image_url items.
+ Used for messages that already have content in array format.
+ """
+ if not content or not isinstance(content, list):
+ return content
+
+ processed_content = []
+
+ for item in content:
+ if isinstance(item, dict):
+ if item.get("type") == "image_url":
+ # Process image URL if needed
+ processed_content.append(item)
+ elif item.get("type") == "image" and item.get("source"):
+ # Convert Anthropic-style to OpenAI-style
+ source = item.get("source", {})
+ if source.get("type") == "base64" and source.get("data"):
+ # Resize the base64 image data
+ resized_data = resize_base64_image(source.get("data"))
+
+ media_type = source.get("media_type", "image/jpeg")
+ data_url = f"data:{media_type};base64,{resized_data}"
+
+ processed_content.append(
+ {
+ "type": "image_url",
+ "image_url": {"url": data_url},
+ }
+ )
+ elif source.get("type") == "url" and source.get("url"):
+ processed_content.append(
+ {
+ "type": "image_url",
+ "image_url": {"url": source.get("url")},
+ }
+ )
+ else:
+ # Pass through other types
+ processed_content.append(item)
+ else:
+ processed_content.append(item)
+
+ return processed_content
+
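+ # A minimal sketch of the conversion above (hypothetical base64 payload):
+ #   {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "<b64>"}}
+ # becomes, after resizing:
+ #   {"type": "image_url", "image_url": {"url": "data:image/png;base64,<resized b64>"}}
+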
+ def _preprocess_messages(self, messages: list[dict]) -> list[dict]:
+ """
+ Preprocess all messages to optimize images before sending them to the OpenAI API.
+ """
+ if not messages or not isinstance(messages, list):
+ return messages
+
+ processed_messages = []
+
+ for msg in messages:
+ # Skip system messages as they're handled separately
+ if msg.get("role") == "system":
+ processed_messages.append(msg)
+ continue
+
+ # Process array-format content (might contain images)
+ if isinstance(msg.get("content"), list):
+ new_msg = dict(msg)
+ new_msg["content"] = self._process_array_content_with_images(
+ msg["content"]
+ )
+ processed_messages.append(new_msg)
+ else:
+ # Standard processing for non-array content
+ processed_messages.append(msg)
+
+ return processed_messages
+
+ def _get_base_args(self, generation_config: GenerationConfig) -> dict:
+ # Build the shared request arguments from the generation config.
+ args: dict[str, Any] = {
+ "model": generation_config.model,
+ "stream": generation_config.stream,
+ }
+
+ model_str = generation_config.model or ""
+
+ if "o1" not in model_str and "o3" not in model_str:
+ args["max_tokens"] = generation_config.max_tokens_to_sample
+ args["temperature"] = generation_config.temperature
+ args["top_p"] = generation_config.top_p
+ else:
+ args["max_completion_tokens"] = (
+ generation_config.max_tokens_to_sample
+ )
+
+ if generation_config.reasoning_effort is not None:
+ args["reasoning_effort"] = generation_config.reasoning_effort
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+ if generation_config.tools is not None:
+ args["tools"] = generation_config.tools
+ if generation_config.response_format is not None:
+ args["response_format"] = generation_config.response_format
+ return args
+
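+ # Example outputs of the argument builder above (hypothetical config values):
+ #   model="openai/gpt-4o"  -> {"model": "openai/gpt-4o", "stream": False,
+ #                              "max_tokens": 1024, "temperature": 0.7, "top_p": 1.0}
+ #   model="openai/o1-mini" -> {"model": "openai/o1-mini", "stream": False,
+ #                              "max_completion_tokens": 1024}
+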
+ async def _execute_task(self, task: dict[str, Any]):
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ kwargs = task["kwargs"]
+
+ # First preprocess to handle any images in array format
+ messages = self._preprocess_messages(messages)
+
+ # Then process messages with direct image_url or image_data fields
+ processed_messages = self._process_messages_with_images(messages)
+
+ args = self._get_base_args(generation_config)
+ client, model_name = self._get_async_client_and_model(args["model"])
+ args["model"] = model_name
+ args["messages"] = processed_messages
+ args = {**args, **kwargs}
+
+ # Check if we're using a vision-capable model when images are present
+ contains_images = any(
+ isinstance(msg.get("content"), list)
+ and any(
+ item.get("type") == "image_url"
+ for item in msg.get("content", [])
+ )
+ for msg in processed_messages
+ )
+
+ if contains_images:
+ vision_models = ["gpt-4-vision", "gpt-4o"]
+ if not any(
+ vision_model in model_name for vision_model in vision_models
+ ):
+ logger.warning(
+ f"Using model {model_name} with images, but it may not support vision"
+ )
+
+ logger.debug(f"Executing async task with args: {args}")
+ try:
+ # The Azure Foundry client is called via `complete` rather than
+ # `chat.completions.create`; drop the model entry from args first.
+ if client == self.async_azure_foundry_client:
+ args.pop("model")
+ response = await client.complete(**args)
+ else:
+ response = await client.chat.completions.create(**args)
+ logger.debug("Async task executed successfully")
+ return response
+ except Exception as e:
+ logger.error(f"Async task execution failed: {str(e)}")
+ # Re-raise so the caller can handle or surface the error
+ raise
+
+ def _execute_task_sync(self, task: dict[str, Any]):
+ messages = task["messages"]
+ generation_config = task["generation_config"]
+ kwargs = task["kwargs"]
+
+ # First preprocess to handle any images in array format
+ messages = self._preprocess_messages(messages)
+
+ # Then process messages with direct image_url or image_data fields
+ processed_messages = self._process_messages_with_images(messages)
+
+ args = self._get_base_args(generation_config)
+ client, model_name = self._get_client_and_model(args["model"])
+ args["model"] = model_name
+ args["messages"] = processed_messages
+ args = {**args, **kwargs}
+
+ # Same vision model check as in async version
+ contains_images = any(
+ isinstance(msg.get("content"), list)
+ and any(
+ item.get("type") == "image_url"
+ for item in msg.get("content", [])
+ )
+ for msg in processed_messages
+ )
+
+ if contains_images:
+ vision_models = ["gpt-4-vision", "gpt-4o"]
+ if not any(
+ vision_model in model_name for vision_model in vision_models
+ ):
+ logger.warning(
+ f"Using model {model_name} with images, but it may not support vision"
+ )
+
+ logger.debug(f"Executing sync OpenAI task with args: {args}")
+ try:
+ # Azure Foundry is called via `complete`; drop the model entry from args first.
+ if client == self.azure_foundry_client:
+ args.pop("model")
+ response = client.complete(**args)
+ else:
+ response = client.chat.completions.create(**args)
+ logger.debug("Sync task executed successfully")
+ return response
+ except Exception as e:
+ logger.error(f"Sync task execution failed: {str(e)}")
+ raise
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/r2r_llm.py b/.venv/lib/python3.12/site-packages/core/providers/llm/r2r_llm.py
new file mode 100644
index 00000000..b95b310a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/r2r_llm.py
@@ -0,0 +1,96 @@
+import logging
+from typing import Any
+
+from core.base.abstractions import GenerationConfig
+from core.base.providers.llm import CompletionConfig, CompletionProvider
+
+from .anthropic import AnthropicCompletionProvider
+from .azure_foundry import AzureFoundryCompletionProvider
+from .litellm import LiteLLMCompletionProvider
+from .openai import OpenAICompletionProvider
+
+logger = logging.getLogger(__name__)
+
+
+class R2RCompletionProvider(CompletionProvider):
+ """A provider that routes to the right LLM provider (R2R):
+
+ - If `generation_config.model` starts with "anthropic/", call AnthropicCompletionProvider.
+ - If it starts with "azure-foundry/", call AzureFoundryCompletionProvider.
+ - If it starts with one of the other OpenAI-like prefixes ("openai/", "azure/", "deepseek/", "ollama/", "lmstudio/")
+ or has no prefix (e.g. "gpt-4", "gpt-3.5"), call OpenAICompletionProvider.
+ - Otherwise, fall back to LiteLLMCompletionProvider.
+ """
+
+ def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
+ """Initialize sub-providers for OpenAI, Anthropic, LiteLLM, and Azure
+ Foundry."""
+ super().__init__(config)
+ self.config = config
+
+ logger.info("Initializing R2RCompletionProvider...")
+ self._openai_provider = OpenAICompletionProvider(
+ self.config, *args, **kwargs
+ )
+ self._anthropic_provider = AnthropicCompletionProvider(
+ self.config, *args, **kwargs
+ )
+ self._litellm_provider = LiteLLMCompletionProvider(
+ self.config, *args, **kwargs
+ )
+ self._azure_foundry_provider = AzureFoundryCompletionProvider(
+ self.config, *args, **kwargs
+ )
+
+ logger.debug(
+ "R2RCompletionProvider initialized with OpenAI, Anthropic, LiteLLM, and Azure Foundry sub-providers."
+ )
+
+ def _choose_subprovider_by_model(
+ self, model_name: str, is_streaming: bool = False
+ ) -> CompletionProvider:
+ """Decide which underlying sub-provider to call based on the model name
+ (prefix)."""
+ # Route to Anthropic if appropriate.
+ if model_name.startswith("anthropic/"):
+ return self._anthropic_provider
+
+ # Route to Azure Foundry explicitly.
+ if model_name.startswith("azure-foundry/"):
+ return self._azure_foundry_provider
+
+ # OpenAI-like prefixes.
+ openai_like_prefixes = [
+ "openai/",
+ "azure/",
+ "deepseek/",
+ "ollama/",
+ "lmstudio/",
+ ]
+ if (
+ any(
+ model_name.startswith(prefix)
+ for prefix in openai_like_prefixes
+ )
+ or "/" not in model_name
+ ):
+ return self._openai_provider
+
+ # Fall back to LiteLLM.
+ return self._litellm_provider
+
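+ # Routing examples for the logic above (hypothetical model names):
+ #   "anthropic/claude-3-5-sonnet" -> self._anthropic_provider
+ #   "azure-foundry/phi-4"         -> self._azure_foundry_provider
+ #   "openai/gpt-4o" or "gpt-4o"   -> self._openai_provider
+ #   "groq/llama-3.1-8b"           -> self._litellm_provider (fallback)
+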
+ async def _execute_task(self, task: dict[str, Any]):
+ """Pick the sub-provider based on model name and forward the async
+ call."""
+ generation_config: GenerationConfig = task["generation_config"]
+ model_name = generation_config.model
+ sub_provider = self._choose_subprovider_by_model(model_name or "")
+ return await sub_provider._execute_task(task)
+
+ def _execute_task_sync(self, task: dict[str, Any]):
+ """Pick the sub-provider based on model name and forward the sync
+ call."""
+ generation_config: GenerationConfig = task["generation_config"]
+ model_name = generation_config.model
+ sub_provider = self._choose_subprovider_by_model(model_name or "")
+ return sub_provider._execute_task_sync(task)
diff --git a/.venv/lib/python3.12/site-packages/core/providers/llm/utils.py b/.venv/lib/python3.12/site-packages/core/providers/llm/utils.py
new file mode 100644
index 00000000..619b2e73
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/llm/utils.py
@@ -0,0 +1,106 @@
+import base64
+import io
+import logging
+from typing import Tuple
+
+from PIL import Image
+
+logger = logging.getLogger()
+
+
+def resize_base64_image(
+ base64_string: str,
+ max_size: Tuple[int, int] = (512, 512),
+ max_megapixels: float = 0.25,
+) -> str:
+ """Aggressively resize images with better error handling and debug output"""
+ logger.debug(
+ f"RESIZING NOW!!! Original length: {len(base64_string)} chars"
+ )
+
+ # Decode base64 string to bytes
+ try:
+ image_data = base64.b64decode(base64_string)
+ image = Image.open(io.BytesIO(image_data))
+ logger.debug(f"Image opened successfully: {image.format} {image.size}")
+ except Exception as e:
+ logger.debug(f"Failed to decode/open image: {e}")
+ # Emergency fallback: truncate the base64 string to cap token usage
+ # (the truncated payload will no longer decode to a valid image).
+ if len(base64_string) > 50000:
+ return base64_string[:50000]
+ return base64_string
+
+ try:
+ width, height = image.size
+ current_megapixels = (width * height) / 1_000_000
+ logger.debug(
+ f"Original dimensions: {width}x{height} ({current_megapixels:.2f} MP)"
+ )
+
+ # Apply tighter limits for large images
+ if current_megapixels > 0.5:
+ max_size = (384, 384)
+ max_megapixels = 0.15
+ logger.debug("Large image detected; applying tighter size limits")
+
+ # Calculate new dimensions with strict enforcement
+ # Always resize if the image is larger than we want
+ scale_factor = min(
+ max_size[0] / width,
+ max_size[1] / height,
+ (max_megapixels / current_megapixels) ** 0.5,
+ )
+
+ if scale_factor >= 1.0:
+ # No resize needed, but still compress
+ new_width, new_height = width, height
+ else:
+ # Apply scaling
+ new_width = max(int(width * scale_factor), 64) # Min width
+ new_height = max(int(height * scale_factor), 64) # Min height
+
+ # Always resize/recompress the image
+ logger.debug(f"Resizing to: {new_width}x{new_height}")
+ resized_image = image.resize((new_width, new_height), Image.LANCZOS) # type: ignore
+
+ # Convert back to base64 with strong compression
+ buffer = io.BytesIO()
+ if image.format == "JPEG" or image.format is None:
+ # Apply very aggressive JPEG compression
+ quality = 50 # Very low quality to reduce size
+ resized_image.save(
+ buffer, format="JPEG", quality=quality, optimize=True
+ )
+ else:
+ # For other formats
+ resized_image.save(
+ buffer, format=image.format or "PNG", optimize=True
+ )
+
+ resized_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+ logger.debug(
+ f"Resized base64 length: {len(resized_base64)} chars (reduction: {100 * (1 - len(resized_base64) / len(base64_string)):.1f}%)"
+ )
+ return resized_base64
+
+ except Exception as e:
+ logger.debug(f"Error during resize: {e}")
+ # If anything goes wrong, truncate the base64 to a reasonable size
+ if len(base64_string) > 50000:
+ return base64_string[:50000]
+ return base64_string
+
+
+def estimate_image_tokens(width: int, height: int) -> int:
+ """
+ Estimate the number of tokens an image will use based on Anthropic's formula.
+
+ Args:
+ width: Image width in pixels
+ height: Image height in pixels
+
+ Returns:
+ Estimated number of tokens
+ """
+ return int((width * height) / 750)
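+
+
+ # Usage sketch (illustrative only; `b64` stands for a hypothetical
+ # base64-encoded JPEG string):
+ #   smaller_b64 = resize_base64_image(b64, max_size=(512, 512), max_megapixels=0.25)
+ #   approx_tokens = estimate_image_tokens(512, 512)  # -> 349, i.e. int(512 * 512 / 750)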
diff --git a/.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py
new file mode 100644
index 00000000..b41d79b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/orchestration/__init__.py
@@ -0,0 +1,4 @@
+from .hatchet import HatchetOrchestrationProvider
+from .simple import SimpleOrchestrationProvider
+
+__all__ = ["HatchetOrchestrationProvider", "SimpleOrchestrationProvider"]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py b/.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py
new file mode 100644
index 00000000..941e2048
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/orchestration/hatchet.py
@@ -0,0 +1,105 @@
+# FIXME: Once the Hatchet workflows are type annotated, remove the type: ignore comments
+import asyncio
+import logging
+from typing import Any, Callable, Optional
+
+from core.base import OrchestrationConfig, OrchestrationProvider, Workflow
+
+logger = logging.getLogger()
+
+
+class HatchetOrchestrationProvider(OrchestrationProvider):
+ def __init__(self, config: OrchestrationConfig):
+ super().__init__(config)
+ try:
+ from hatchet_sdk import ClientConfig, Hatchet
+ except ImportError:
+ raise ImportError(
+ "Hatchet SDK not installed. Please install it using `pip install hatchet-sdk`."
+ ) from None
+ root_logger = logging.getLogger()
+
+ self.orchestrator = Hatchet(
+ config=ClientConfig(
+ logger=root_logger,
+ ),
+ )
+ self.root_logger = root_logger
+ self.config: OrchestrationConfig = config
+ self.messages: dict[str, str] = {}
+ self.worker: Optional[Any] = None
+
+ def workflow(self, *args, **kwargs) -> Callable:
+ return self.orchestrator.workflow(*args, **kwargs)
+
+ def step(self, *args, **kwargs) -> Callable:
+ return self.orchestrator.step(*args, **kwargs)
+
+ def failure(self, *args, **kwargs) -> Callable:
+ return self.orchestrator.on_failure_step(*args, **kwargs)
+
+ def get_worker(self, name: str, max_runs: Optional[int] = None) -> Any:
+ if not max_runs:
+ max_runs = self.config.max_runs
+ self.worker = self.orchestrator.worker(name, max_runs) # type: ignore
+ return self.worker
+
+ def concurrency(self, *args, **kwargs) -> Callable:
+ return self.orchestrator.concurrency(*args, **kwargs)
+
+ async def start_worker(self):
+ if not self.worker:
+ raise ValueError(
+ "Worker not initialized. Call get_worker() first."
+ )
+
+ asyncio.create_task(self.worker.async_start())
+
+ async def run_workflow(
+ self,
+ workflow_name: str,
+ parameters: dict,
+ options: dict,
+ *args,
+ **kwargs,
+ ) -> Any:
+ task_id = self.orchestrator.admin.run_workflow( # type: ignore
+ workflow_name,
+ parameters,
+ options=options, # type: ignore
+ *args,
+ **kwargs,
+ )
+ return {
+ "task_id": str(task_id),
+ "message": self.messages.get(
+ workflow_name, "Workflow queued successfully."
+ ), # Return message based on workflow name
+ }
+
+ def register_workflows(
+ self, workflow: Workflow, service: Any, messages: dict
+ ) -> None:
+ self.messages.update(messages)
+
+ logger.info(
+ f"Registering workflows for {workflow} with messages {messages}."
+ )
+ if workflow == Workflow.INGESTION:
+ from core.main.orchestration.hatchet.ingestion_workflow import ( # type: ignore
+ hatchet_ingestion_factory,
+ )
+
+ workflows = hatchet_ingestion_factory(self, service)
+ if self.worker:
+ for wf in workflows.values():
+ self.worker.register_workflow(wf)
+
+ elif workflow == Workflow.GRAPH:
+ from core.main.orchestration.hatchet.graph_workflow import ( # type: ignore
+ hatchet_graph_search_results_factory,
+ )
+
+ workflows = hatchet_graph_search_results_factory(self, service)
+ if self.worker:
+ for wf in workflows.values():
+ self.worker.register_workflow(wf)
diff --git a/.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py b/.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py
new file mode 100644
index 00000000..33028afe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/orchestration/simple.py
@@ -0,0 +1,61 @@
+from typing import Any
+
+from core.base import OrchestrationConfig, OrchestrationProvider, Workflow
+
+
+class SimpleOrchestrationProvider(OrchestrationProvider):
+ def __init__(self, config: OrchestrationConfig):
+ super().__init__(config)
+ self.config = config
+ self.messages: dict[str, str] = {}
+ self.ingestion_workflows: dict[str, Any] = {}
+ self.graph_search_results_workflows: dict[str, Any] = {}
+
+ async def start_worker(self):
+ pass
+
+ def get_worker(self, name: str, max_runs: int) -> Any:
+ pass
+
+ def step(self, *args, **kwargs) -> Any:
+ pass
+
+ def workflow(self, *args, **kwargs) -> Any:
+ pass
+
+ def failure(self, *args, **kwargs) -> Any:
+ pass
+
+ def register_workflows(
+ self, workflow: Workflow, service: Any, messages: dict
+ ) -> None:
+ for key, msg in messages.items():
+ self.messages[key] = msg
+
+ if workflow == Workflow.INGESTION:
+ from core.main.orchestration import simple_ingestion_factory
+
+ self.ingestion_workflows = simple_ingestion_factory(service)
+
+ elif workflow == Workflow.GRAPH:
+ from core.main.orchestration.simple.graph_workflow import (
+ simple_graph_search_results_factory,
+ )
+
+ self.graph_search_results_workflows = (
+ simple_graph_search_results_factory(service)
+ )
+
+ async def run_workflow(
+ self, workflow_name: str, parameters: dict, options: dict
+ ) -> dict[str, str]:
+ if workflow_name in self.ingestion_workflows:
+ await self.ingestion_workflows[workflow_name](
+ parameters.get("request")
+ )
+ return {"message": self.messages[workflow_name]}
+ elif workflow_name in self.graph_search_results_workflows:
+ await self.graph_search_results_workflows[workflow_name](
+ parameters.get("request")
+ )
+ return {"message": self.messages[workflow_name]}
+ else:
+ raise ValueError(f"Workflow '{workflow_name}' not found.")
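+
+ # Usage sketch (hypothetical workflow name and payload, for illustration):
+ #   result = await provider.run_workflow(
+ #       "ingest-files", {"request": ingestion_request}, options={}
+ #   )
+ #   # -> {"message": <message registered for "ingest-files">}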