about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/core/providers/database/users.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/database/users.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/users.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/database/users.py1325
1 files changed, 1325 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/users.py b/.venv/lib/python3.12/site-packages/core/providers/database/users.py
new file mode 100644
index 00000000..208eeaa4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/users.py
@@ -0,0 +1,1325 @@
+import csv
+import json
+import tempfile
+from datetime import datetime
+from typing import IO, Optional
+from uuid import UUID
+
+from fastapi import HTTPException
+
+from core.base import CryptoProvider, Handler
+from core.base.abstractions import R2RException
+from core.utils import generate_user_id
+from shared.abstractions import User
+
+from .base import PostgresConnectionManager, QueryBuilder
+from .collections import PostgresCollectionsHandler
+
+
+def _merge_metadata(
+    existing_metadata: dict[str, str], new_metadata: dict[str, Optional[str]]
+) -> dict[str, str]:
+    """
+    Merges the new metadata with the existing metadata in the Stripe-style approach:
+      - new_metadata[key] = <string> => update or add that key
+      - new_metadata[key] = ""       => remove that key
+      - if new_metadata is empty => remove all keys
+    """
+    # If new_metadata is an empty dict, it signals removal of all keys.
+    if new_metadata == {}:
+        return {}
+
+    # Copy so we don't mutate the original
+    final_metadata = dict(existing_metadata)
+
+    for key, value in new_metadata.items():
+        # If the user sets the key to an empty string, it means "delete" that key
+        if value == "":
+            if key in final_metadata:
+                del final_metadata[key]
+        # If not None and not empty, set or override
+        elif value is not None:
+            final_metadata[key] = value
+        else:
+            # If the user sets the value to None in some contexts, decide if you want to remove or ignore
+            # For now we might treat None same as empty string => remove
+            if key in final_metadata:
+                del final_metadata[key]
+
+    return final_metadata
+
+
+class PostgresUserHandler(Handler):
+    TABLE_NAME = "users"
+    API_KEYS_TABLE_NAME = "users_api_keys"
+
+    def __init__(
+        self,
+        project_name: str,
+        connection_manager: PostgresConnectionManager,
+        crypto_provider: CryptoProvider,
+    ):
+        super().__init__(project_name, connection_manager)
+        self.crypto_provider = crypto_provider
+
+    async def create_tables(self):
+        user_table_query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.TABLE_NAME)} (
+            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+            email TEXT UNIQUE NOT NULL,
+            hashed_password TEXT NOT NULL,
+            is_superuser BOOLEAN DEFAULT FALSE,
+            is_active BOOLEAN DEFAULT TRUE,
+            is_verified BOOLEAN DEFAULT FALSE,
+            verification_code TEXT,
+            verification_code_expiry TIMESTAMPTZ,
+            name TEXT,
+            bio TEXT,
+            profile_picture TEXT,
+            reset_token TEXT,
+            reset_token_expiry TIMESTAMPTZ,
+            collection_ids UUID[] NULL,
+            limits_overrides JSONB,
+            metadata JSONB,
+            created_at TIMESTAMPTZ DEFAULT NOW(),
+            updated_at TIMESTAMPTZ DEFAULT NOW(),
+            account_type TEXT NOT NULL DEFAULT 'password',
+            google_id TEXT,
+            github_id TEXT
+        );
+        """
+
+        # API keys table with updated_at instead of last_used_at
+        api_keys_table_query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (
+            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+            user_id UUID NOT NULL REFERENCES {self._get_table_name(PostgresUserHandler.TABLE_NAME)}(id) ON DELETE CASCADE,
+            public_key TEXT UNIQUE NOT NULL,
+            hashed_key TEXT NOT NULL,
+            name TEXT,
+            description TEXT,
+            created_at TIMESTAMPTZ DEFAULT NOW(),
+            updated_at TIMESTAMPTZ DEFAULT NOW()
+        );
+
+        CREATE INDEX IF NOT EXISTS idx_api_keys_user_id
+        ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(user_id);
+
+        CREATE INDEX IF NOT EXISTS idx_api_keys_public_key
+        ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(public_key);
+        """
+
+        await self.connection_manager.execute_query(user_table_query)
+        await self.connection_manager.execute_query(api_keys_table_query)
+
+        # (New) Code snippet for adding columns if missing
+        # Postgres >= 9.6 supports "ADD COLUMN IF NOT EXISTS"
+        check_columns_query = f"""
+        ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
+            ADD COLUMN IF NOT EXISTS metadata JSONB;
+
+        ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
+            ADD COLUMN IF NOT EXISTS limits_overrides JSONB;
+
+        ALTER TABLE {self._get_table_name(self.API_KEYS_TABLE_NAME)}
+            ADD COLUMN IF NOT EXISTS description TEXT;
+        """
+        await self.connection_manager.execute_query(check_columns_query)
+
+        # Optionally, create indexes for quick lookups:
+        check_columns_query = f"""
+        ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
+            ADD COLUMN IF NOT EXISTS account_type TEXT NOT NULL DEFAULT 'password',
+            ADD COLUMN IF NOT EXISTS google_id TEXT,
+            ADD COLUMN IF NOT EXISTS github_id TEXT;
+
+        CREATE INDEX IF NOT EXISTS idx_users_google_id
+            ON {self._get_table_name(self.TABLE_NAME)}(google_id);
+        CREATE INDEX IF NOT EXISTS idx_users_github_id
+            ON {self._get_table_name(self.TABLE_NAME)}(github_id);
+        """
+        await self.connection_manager.execute_query(check_columns_query)
+
+    async def get_user_by_id(self, id: UUID) -> User:
+        query, _ = (
+            QueryBuilder(self._get_table_name("users"))
+            .select(
+                [
+                    "id",
+                    "email",
+                    "is_superuser",
+                    "is_active",
+                    "is_verified",
+                    "created_at",
+                    "updated_at",
+                    "name",
+                    "profile_picture",
+                    "bio",
+                    "collection_ids",
+                    "limits_overrides",
+                    "metadata",
+                    "account_type",
+                    "hashed_password",
+                    "google_id",
+                    "github_id",
+                ]
+            )
+            .where("id = $1")
+            .build()
+        )
+        result = await self.connection_manager.fetchrow_query(query, [id])
+
+        if not result:
+            raise R2RException(status_code=404, message="User not found")
+
+        return User(
+            id=result["id"],
+            email=result["email"],
+            is_superuser=result["is_superuser"],
+            is_active=result["is_active"],
+            is_verified=result["is_verified"],
+            created_at=result["created_at"],
+            updated_at=result["updated_at"],
+            name=result["name"],
+            profile_picture=result["profile_picture"],
+            bio=result["bio"],
+            collection_ids=result["collection_ids"],
+            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+            metadata=json.loads(result["metadata"] or "{}"),
+            hashed_password=result["hashed_password"],
+            account_type=result["account_type"],
+            google_id=result["google_id"],
+            github_id=result["github_id"],
+        )
+
+    async def get_user_by_email(self, email: str) -> User:
+        query, params = (
+            QueryBuilder(self._get_table_name("users"))
+            .select(
+                [
+                    "id",
+                    "email",
+                    "is_superuser",
+                    "is_active",
+                    "is_verified",
+                    "created_at",
+                    "updated_at",
+                    "name",
+                    "profile_picture",
+                    "bio",
+                    "collection_ids",
+                    "metadata",
+                    "limits_overrides",
+                    "account_type",
+                    "hashed_password",
+                    "google_id",
+                    "github_id",
+                ]
+            )
+            .where("email = $1")
+            .build()
+        )
+        result = await self.connection_manager.fetchrow_query(query, [email])
+        if not result:
+            raise R2RException(status_code=404, message="User not found")
+
+        return User(
+            id=result["id"],
+            email=result["email"],
+            is_superuser=result["is_superuser"],
+            is_active=result["is_active"],
+            is_verified=result["is_verified"],
+            created_at=result["created_at"],
+            updated_at=result["updated_at"],
+            name=result["name"],
+            profile_picture=result["profile_picture"],
+            bio=result["bio"],
+            collection_ids=result["collection_ids"],
+            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+            metadata=json.loads(result["metadata"] or "{}"),
+            account_type=result["account_type"],
+            hashed_password=result["hashed_password"],
+            google_id=result["google_id"],
+            github_id=result["github_id"],
+        )
+
+    async def create_user(
+        self,
+        email: str,
+        password: Optional[str] = None,
+        account_type: Optional[str] = "password",
+        google_id: Optional[str] = None,
+        github_id: Optional[str] = None,
+        is_superuser: bool = False,
+        name: Optional[str] = None,
+        bio: Optional[str] = None,
+        profile_picture: Optional[str] = None,
+    ) -> User:
+        """Create a new user."""
+        # 1) Check if a user with this email already exists
+        try:
+            existing = await self.get_user_by_email(email)
+            if existing:
+                raise R2RException(
+                    status_code=400,
+                    message="User with this email already exists",
+                )
+        except R2RException as e:
+            if e.status_code != 404:
+                raise e
+        # 2) If google_id is provided, ensure no user already has it
+        if google_id:
+            existing_google_user = await self.get_user_by_google_id(google_id)
+            if existing_google_user:
+                raise R2RException(
+                    status_code=400,
+                    message="User with this Google account already exists",
+                )
+
+        # 3) If github_id is provided, ensure no user already has it
+        if github_id:
+            existing_github_user = await self.get_user_by_github_id(github_id)
+            if existing_github_user:
+                raise R2RException(
+                    status_code=400,
+                    message="User with this GitHub account already exists",
+                )
+
+        hashed_password = None
+        if account_type == "password":
+            if password is None:
+                raise R2RException(
+                    status_code=400,
+                    message="Password is required for a 'password' account_type",
+                )
+            hashed_password = self.crypto_provider.get_password_hash(password)  # type: ignore
+
+        query, params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .insert(
+                {
+                    "email": email,
+                    "id": generate_user_id(email),
+                    "is_superuser": is_superuser,
+                    "collection_ids": [],
+                    "limits_overrides": None,
+                    "metadata": None,
+                    "account_type": account_type,
+                    "hashed_password": hashed_password
+                    or "",  # Ensure hashed_password is not None
+                    # !!WARNING - Upstream checks are required to treat oauth differently from password!!
+                    "google_id": google_id,
+                    "github_id": github_id,
+                    "is_verified": account_type != "password",
+                    "name": name,
+                    "bio": bio,
+                    "profile_picture": profile_picture,
+                }
+            )
+            .returning(
+                [
+                    "id",
+                    "email",
+                    "is_superuser",
+                    "is_active",
+                    "is_verified",
+                    "created_at",
+                    "updated_at",
+                    "collection_ids",
+                    "limits_overrides",
+                    "metadata",
+                    "name",
+                    "bio",
+                    "profile_picture",
+                ]
+            )
+            .build()
+        )
+
+        result = await self.connection_manager.fetchrow_query(query, params)
+        if not result:
+            raise R2RException(
+                status_code=500,
+                message="Failed to create user",
+            )
+
+        return User(
+            id=result["id"],
+            email=result["email"],
+            is_superuser=result["is_superuser"],
+            is_active=result["is_active"],
+            is_verified=result["is_verified"],
+            created_at=result["created_at"],
+            updated_at=result["updated_at"],
+            collection_ids=result["collection_ids"] or [],
+            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+            metadata=json.loads(result["metadata"] or "{}"),
+            name=result["name"],
+            bio=result["bio"],
+            profile_picture=result["profile_picture"],
+            account_type=account_type or "password",
+            hashed_password=hashed_password,
+            google_id=google_id,
+            github_id=github_id,
+        )
+
+    async def update_user(
+        self,
+        user: User,
+        merge_limits: bool = False,
+        new_metadata: dict[str, Optional[str]] | None = None,
+    ) -> User:
+        """Update user information including limits_overrides.
+
+        Args:
+            user: User object containing updated information
+            merge_limits: If True, will merge existing limits_overrides with new ones.
+                        If False, will overwrite existing limits_overrides.
+
+        Returns:
+            Updated User object
+        """
+
+        # Get current user if we need to merge limits or get hashed password
+        current_user = None
+        try:
+            current_user = await self.get_user_by_id(user.id)
+        except R2RException:
+            raise R2RException(
+                status_code=404, message="User not found"
+            ) from None
+
+        # If the new user.google_id != current_user.google_id, check for duplicates
+        if user.email and (user.email != current_user.email):
+            existing_email_user = await self.get_user_by_email(user.email)
+            if existing_email_user and existing_email_user.id != user.id:
+                raise R2RException(
+                    status_code=400,
+                    message="That email account is already associated with another user.",
+                )
+
+        # If the new user.google_id != current_user.google_id, check for duplicates
+        if user.google_id and (user.google_id != current_user.google_id):
+            existing_google_user = await self.get_user_by_google_id(
+                user.google_id
+            )
+            if existing_google_user and existing_google_user.id != user.id:
+                raise R2RException(
+                    status_code=400,
+                    message="That Google account is already associated with another user.",
+                )
+
+        # Similarly for GitHub:
+        if user.github_id and (user.github_id != current_user.github_id):
+            existing_github_user = await self.get_user_by_github_id(
+                user.github_id
+            )
+            if existing_github_user and existing_github_user.id != user.id:
+                raise R2RException(
+                    status_code=400,
+                    message="That GitHub account is already associated with another user.",
+                )
+
+        # Merge or replace metadata if provided
+        final_metadata = current_user.metadata or {}
+        if new_metadata is not None:
+            final_metadata = _merge_metadata(final_metadata, new_metadata)
+
+        # Merge or replace limits_overrides
+        final_limits = user.limits_overrides
+        if (
+            merge_limits
+            and current_user.limits_overrides
+            and user.limits_overrides
+        ):
+            final_limits = {
+                **current_user.limits_overrides,
+                **user.limits_overrides,
+            }
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET email = $1,
+                is_superuser = $2,
+                is_active = $3,
+                is_verified = $4,
+                updated_at = NOW(),
+                name = $5,
+                profile_picture = $6,
+                bio = $7,
+                collection_ids = $8,
+                limits_overrides = $9::jsonb,
+                metadata = $10::jsonb
+            WHERE id = $11
+            RETURNING id, email, is_superuser, is_active, is_verified,
+                    created_at, updated_at, name, profile_picture, bio,
+                    collection_ids, limits_overrides, metadata, hashed_password,
+                    account_type, google_id, github_id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query,
+            [
+                user.email,
+                user.is_superuser,
+                user.is_active,
+                user.is_verified,
+                user.name,
+                user.profile_picture,
+                user.bio,
+                user.collection_ids or [],
+                json.dumps(final_limits),
+                json.dumps(final_metadata),
+                user.id,
+            ],
+        )
+
+        if not result:
+            raise HTTPException(
+                status_code=500,
+                detail="Failed to update user",
+            )
+
+        return User(
+            id=result["id"],
+            email=result["email"],
+            is_superuser=result["is_superuser"],
+            is_active=result["is_active"],
+            is_verified=result["is_verified"],
+            created_at=result["created_at"],
+            updated_at=result["updated_at"],
+            name=result["name"],
+            profile_picture=result["profile_picture"],
+            bio=result["bio"],
+            collection_ids=result["collection_ids"]
+            or [],  # Ensure null becomes empty array
+            limits_overrides=json.loads(
+                result["limits_overrides"] or "{}"
+            ),  # Can be null
+            metadata=json.loads(result["metadata"] or "{}"),
+            account_type=result["account_type"],
+            hashed_password=result[
+                "hashed_password"
+            ],  # Include hashed_password
+            google_id=result["google_id"],
+            github_id=result["github_id"],
+        )
+
+    async def delete_user_relational(self, id: UUID) -> None:
+        """Delete a user and update related records."""
+        # Get the collections the user belongs to
+        collection_query, params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .select(["collection_ids"])
+            .where("id = $1")
+            .build()
+        )
+
+        collection_result = await self.connection_manager.fetchrow_query(
+            collection_query, [id]
+        )
+
+        if not collection_result:
+            raise R2RException(status_code=404, message="User not found")
+
+        # Update documents query
+        doc_update_query, doc_params = (
+            QueryBuilder(self._get_table_name("documents"))
+            .update({"id": None})
+            .where("id = $1")
+            .build()
+        )
+
+        await self.connection_manager.execute_query(doc_update_query, [id])
+
+        # Delete user query
+        delete_query, del_params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .delete()
+            .where("id = $1")
+            .returning(["id"])
+            .build()
+        )
+
+        result = await self.connection_manager.fetchrow_query(
+            delete_query, [id]
+        )
+
+        if not result:
+            raise R2RException(status_code=404, message="User not found")
+
+    async def update_user_password(self, id: UUID, new_hashed_password: str):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET hashed_password = $1, updated_at = NOW()
+            WHERE id = $2
+        """
+        await self.connection_manager.execute_query(
+            query, [new_hashed_password, id]
+        )
+
+    async def get_all_users(self) -> list[User]:
+        """Get all users with minimal information."""
+        query, params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .select(
+                [
+                    "id",
+                    "email",
+                    "is_superuser",
+                    "is_active",
+                    "is_verified",
+                    "created_at",
+                    "updated_at",
+                    "collection_ids",
+                    "hashed_password",
+                    "limits_overrides",
+                    "metadata",
+                    "name",
+                    "bio",
+                    "profile_picture",
+                    "account_type",
+                    "google_id",
+                    "github_id",
+                ]
+            )
+            .build()
+        )
+
+        results = await self.connection_manager.fetch_query(query, params)
+        return [
+            User(
+                id=result["id"],
+                email=result["email"],
+                is_superuser=result["is_superuser"],
+                is_active=result["is_active"],
+                is_verified=result["is_verified"],
+                created_at=result["created_at"],
+                updated_at=result["updated_at"],
+                collection_ids=result["collection_ids"] or [],
+                limits_overrides=json.loads(
+                    result["limits_overrides"] or "{}"
+                ),
+                metadata=json.loads(result["metadata"] or "{}"),
+                name=result["name"],
+                bio=result["bio"],
+                profile_picture=result["profile_picture"],
+                account_type=result["account_type"],
+                hashed_password=result["hashed_password"],
+                google_id=result["google_id"],
+                github_id=result["github_id"],
+            )
+            for result in results
+        ]
+
+    async def store_verification_code(
+        self, id: UUID, verification_code: str, expiry: datetime
+    ):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET verification_code = $1, verification_code_expiry = $2
+            WHERE id = $3
+        """
+        await self.connection_manager.execute_query(
+            query, [verification_code, expiry, id]
+        )
+
+    async def verify_user(self, verification_code: str) -> None:
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
+            WHERE verification_code = $1 AND verification_code_expiry > NOW()
+            RETURNING id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [verification_code]
+        )
+
+        if not result:
+            raise R2RException(
+                status_code=400, message="Invalid or expired verification code"
+            )
+
+    async def remove_verification_code(self, verification_code: str):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET verification_code = NULL, verification_code_expiry = NULL
+            WHERE verification_code = $1
+        """
+        await self.connection_manager.execute_query(query, [verification_code])
+
+    async def expire_verification_code(self, id: UUID):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET verification_code_expiry = NOW() - INTERVAL '1 day'
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(query, [id])
+
+    async def store_reset_token(
+        self, id: UUID, reset_token: str, expiry: datetime
+    ):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET reset_token = $1, reset_token_expiry = $2
+            WHERE id = $3
+        """
+        await self.connection_manager.execute_query(
+            query, [reset_token, expiry, id]
+        )
+
+    async def get_user_id_by_reset_token(
+        self, reset_token: str
+    ) -> Optional[UUID]:
+        query = f"""
+            SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            WHERE reset_token = $1 AND reset_token_expiry > NOW()
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [reset_token]
+        )
+        return result["id"] if result else None
+
+    async def remove_reset_token(self, id: UUID):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET reset_token = NULL, reset_token_expiry = NULL
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(query, [id])
+
+    async def remove_user_from_all_collections(self, id: UUID):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET collection_ids = ARRAY[]::UUID[]
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(query, [id])
+
+    async def add_user_to_collection(
+        self, id: UUID, collection_id: UUID
+    ) -> bool:
+        # Check if the user exists
+        if not await self.get_user_by_id(id):
+            raise R2RException(status_code=404, message="User not found")
+
+        # Check if the collection exists
+        if not await self._collection_exists(collection_id):
+            raise R2RException(status_code=404, message="Collection not found")
+
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET collection_ids = array_append(collection_ids, $1)
+            WHERE id = $2 AND NOT ($1 = ANY(collection_ids))
+            RETURNING id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [collection_id, id]
+        )
+        if not result:
+            raise R2RException(
+                status_code=400, message="User already in collection"
+            )
+
+        update_collection_query = f"""
+            UPDATE {self._get_table_name("collections")}
+            SET user_count = user_count + 1
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(
+            query=update_collection_query,
+            params=[collection_id],
+        )
+
+        return True
+
+    async def remove_user_from_collection(
+        self, id: UUID, collection_id: UUID
+    ) -> bool:
+        if not await self.get_user_by_id(id):
+            raise R2RException(status_code=404, message="User not found")
+
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET collection_ids = array_remove(collection_ids, $1)
+            WHERE id = $2 AND $1 = ANY(collection_ids)
+            RETURNING id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [collection_id, id]
+        )
+        if not result:
+            raise R2RException(
+                status_code=400,
+                message="User is not a member of the specified collection",
+            )
+        return True
+
+    async def get_users_in_collection(
+        self, collection_id: UUID, offset: int, limit: int
+    ) -> dict[str, list[User] | int]:
+        """Get all users in a specific collection with pagination."""
+        if not await self._collection_exists(collection_id):
+            raise R2RException(status_code=404, message="Collection not found")
+
+        query, params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .select(
+                [
+                    "id",
+                    "email",
+                    "is_active",
+                    "is_superuser",
+                    "created_at",
+                    "updated_at",
+                    "is_verified",
+                    "collection_ids",
+                    "name",
+                    "bio",
+                    "profile_picture",
+                    "limits_overrides",
+                    "metadata",
+                    "account_type",
+                    "hashed_password",
+                    "google_id",
+                    "github_id",
+                    "COUNT(*) OVER() AS total_entries",
+                ]
+            )
+            .where("$1 = ANY(collection_ids)")
+            .order_by("name")
+            .offset("$2")
+            .limit("$3" if limit != -1 else None)
+            .build()
+        )
+
+        conditions = [collection_id, offset]
+        if limit != -1:
+            conditions.append(limit)
+
+        results = await self.connection_manager.fetch_query(query, conditions)
+
+        users_list = [
+            User(
+                id=row["id"],
+                email=row["email"],
+                is_active=row["is_active"],
+                is_superuser=row["is_superuser"],
+                created_at=row["created_at"],
+                updated_at=row["updated_at"],
+                is_verified=row["is_verified"],
+                collection_ids=row["collection_ids"] or [],
+                name=row["name"],
+                bio=row["bio"],
+                profile_picture=row["profile_picture"],
+                limits_overrides=json.loads(row["limits_overrides"] or "{}"),
+                metadata=json.loads(row["metadata"] or "{}"),
+                account_type=row["account_type"],
+                hashed_password=row["hashed_password"],
+                google_id=row["google_id"],
+                github_id=row["github_id"],
+            )
+            for row in results
+        ]
+
+        total_entries = results[0]["total_entries"] if results else 0
+        return {"results": users_list, "total_entries": total_entries}
+
+    async def mark_user_as_superuser(self, id: UUID):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET is_superuser = TRUE, is_verified = TRUE,
+                verification_code = NULL, verification_code_expiry = NULL
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(query, [id])
+
+    async def get_user_id_by_verification_code(
+        self, verification_code: str
+    ) -> UUID:
+        query = f"""
+            SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            WHERE verification_code = $1 AND verification_code_expiry > NOW()
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [verification_code]
+        )
+
+        if not result:
+            raise R2RException(
+                status_code=400, message="Invalid or expired verification code"
+            )
+
+        return result["id"]
+
+    async def mark_user_as_verified(self, id: UUID):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET is_verified = TRUE,
+                verification_code = NULL,
+                verification_code_expiry = NULL
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(query, [id])
+
+    async def get_users_overview(
+        self,
+        offset: int,
+        limit: int,
+        user_ids: Optional[list[UUID]] = None,
+    ) -> dict[str, list[User] | int]:
+        """Return users with document usage and total entries."""
+        query = f"""
+            WITH user_document_ids AS (
+                SELECT
+                    u.id as user_id,
+                    ARRAY_AGG(d.id) FILTER (WHERE d.id IS NOT NULL) AS doc_ids
+                FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
+                LEFT JOIN {self._get_table_name("documents")} d ON u.id = d.owner_id
+                GROUP BY u.id
+            ),
+            user_docs AS (
+                SELECT
+                    u.id,
+                    u.email,
+                    u.is_superuser,
+                    u.is_active,
+                    u.is_verified,
+                    u.name,
+                    u.bio,
+                    u.profile_picture,
+                    u.collection_ids,
+                    u.created_at,
+                    u.updated_at,
+                    COUNT(d.id) AS num_files,
+                    COALESCE(SUM(d.size_in_bytes), 0) AS total_size_in_bytes,
+                    ud.doc_ids as document_ids
+                FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
+                LEFT JOIN {self._get_table_name("documents")} d ON u.id = d.owner_id
+                LEFT JOIN user_document_ids ud ON u.id = ud.user_id
+                {" WHERE u.id = ANY($3::uuid[])" if user_ids else ""}
+                GROUP BY u.id, u.email, u.is_superuser, u.is_active, u.is_verified,
+                         u.created_at, u.updated_at, u.collection_ids, ud.doc_ids
+            )
+            SELECT
+                user_docs.*,
+                COUNT(*) OVER() AS total_entries
+            FROM user_docs
+            ORDER BY email
+            OFFSET $1
+        """
+
+        params: list = [offset]
+
+        if limit != -1:
+            query += " LIMIT $2"
+            params.append(limit)
+
+        if user_ids:
+            params.append(user_ids)
+
+        results = await self.connection_manager.fetch_query(query, params)
+        if not results:
+            raise R2RException(status_code=404, message="No users found")
+
+        users_list = []
+        for row in results:
+            users_list.append(
+                User(
+                    id=row["id"],
+                    email=row["email"],
+                    is_superuser=row["is_superuser"],
+                    is_active=row["is_active"],
+                    is_verified=row["is_verified"],
+                    name=row["name"],
+                    bio=row["bio"],
+                    created_at=row["created_at"],
+                    updated_at=row["updated_at"],
+                    profile_picture=row["profile_picture"],
+                    collection_ids=row["collection_ids"] or [],
+                    num_files=row["num_files"],
+                    total_size_in_bytes=row["total_size_in_bytes"],
+                    document_ids=(
+                        list(row["document_ids"])
+                        if row["document_ids"]
+                        else []
+                    ),
+                )
+            )
+
+        total_entries = results[0]["total_entries"]
+        return {"results": users_list, "total_entries": total_entries}
+
+    async def _collection_exists(self, collection_id: UUID) -> bool:
+        """Check if a collection exists."""
+        query = f"""
+            SELECT 1 FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+            WHERE id = $1
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [collection_id]
+        )
+        return result is not None
+
+    async def get_user_validation_data(
+        self,
+        user_id: UUID,
+    ) -> dict:
+        """Get verification data for a specific user.
+
+        This method should be called after superuser authorization has been
+        verified.
+        """
+        query = f"""
+            SELECT
+                verification_code,
+                verification_code_expiry,
+                reset_token,
+                reset_token_expiry
+            FROM {self._get_table_name("users")}
+            WHERE id = $1
+        """
+        result = await self.connection_manager.fetchrow_query(query, [user_id])
+
+        if not result:
+            raise R2RException(status_code=404, message="User not found")
+
+        return {
+            "verification_data": {
+                "verification_code": result["verification_code"],
+                "verification_code_expiry": (
+                    result["verification_code_expiry"].isoformat()
+                    if result["verification_code_expiry"]
+                    else None
+                ),
+                "reset_token": result["reset_token"],
+                "reset_token_expiry": (
+                    result["reset_token_expiry"].isoformat()
+                    if result["reset_token_expiry"]
+                    else None
+                ),
+            }
+        }
+
+    # API Key methods
+    async def store_user_api_key(
+        self,
+        user_id: UUID,
+        key_id: str,
+        hashed_key: str,
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+    ) -> UUID:
+        """Store a new API key for a user with optional name and
+        description."""
+        query = f"""
+            INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+            (user_id, public_key, hashed_key, name, description)
+            VALUES ($1, $2, $3, $4, $5)
+            RETURNING id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [user_id, key_id, hashed_key, name or "", description or ""]
+        )
+        if not result:
+            raise R2RException(
+                status_code=500, message="Failed to store API key"
+            )
+        return result["id"]
+
+    async def get_api_key_record(self, key_id: str) -> Optional[dict]:
+        """Get API key record by 'public_key' and update 'updated_at' to now.
+
+        Returns { "user_id", "hashed_key" } or None if not found.
+        """
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+            SET updated_at = NOW()
+            WHERE public_key = $1
+            RETURNING user_id, hashed_key
+        """
+        result = await self.connection_manager.fetchrow_query(query, [key_id])
+        if not result:
+            return None
+        return {
+            "user_id": result["user_id"],
+            "hashed_key": result["hashed_key"],
+        }
+
+    async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
+        """Get all API keys for a user."""
+        query = f"""
+            SELECT id, public_key, name, description, created_at, updated_at
+            FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+            WHERE user_id = $1
+            ORDER BY created_at DESC
+        """
+        results = await self.connection_manager.fetch_query(query, [user_id])
+        return [
+            {
+                "key_id": str(row["id"]),
+                "public_key": row["public_key"],
+                "name": row["name"] or "",
+                "description": row["description"] or "",
+                "updated_at": row["updated_at"],
+            }
+            for row in results
+        ]
+
+    async def delete_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+        """Delete a specific API key."""
+        query = f"""
+            DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+            WHERE id = $1 AND user_id = $2
+            RETURNING id, public_key, name, description
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [key_id, user_id]
+        )
+        if result is None:
+            raise R2RException(status_code=404, message="API key not found")
+
+        return True
+
+    async def update_api_key_name(
+        self, user_id: UUID, key_id: UUID, name: str
+    ) -> bool:
+        """Update the name of an existing API key."""
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+            SET name = $1, updated_at = NOW()
+            WHERE id = $2 AND user_id = $3
+            RETURNING id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [name, key_id, user_id]
+        )
+        if result is None:
+            raise R2RException(status_code=404, message="API key not found")
+        return True
+
+    async def export_to_csv(
+        self,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        """Creates a CSV file from the PostgreSQL data and returns the path to
+        the temp file."""
+        valid_columns = {
+            "id",
+            "email",
+            "is_superuser",
+            "is_active",
+            "is_verified",
+            "name",
+            "bio",
+            "collection_ids",
+            "created_at",
+            "updated_at",
+        }
+
+        if not columns:
+            columns = list(valid_columns)
+        elif invalid_cols := set(columns) - valid_columns:
+            raise ValueError(f"Invalid columns: {invalid_cols}")
+
+        select_stmt = f"""
+            SELECT
+                id::text,
+                email,
+                is_superuser,
+                is_active,
+                is_verified,
+                name,
+                bio,
+                collection_ids::text,
+                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+            FROM {self._get_table_name(self.TABLE_NAME)}
+        """
+
+        params = []
+        if filters:
+            conditions = []
+            param_index = 1
+
+            for field, value in filters.items():
+                if field not in valid_columns:
+                    continue
+
+                if isinstance(value, dict):
+                    for op, val in value.items():
+                        if op == "$eq":
+                            conditions.append(f"{field} = ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$gt":
+                            conditions.append(f"{field} > ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$lt":
+                            conditions.append(f"{field} < ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                else:
+                    # Direct equality
+                    conditions.append(f"{field} = ${param_index}")
+                    params.append(value)
+                    param_index += 1
+
+            if conditions:
+                select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+        temp_file = None
+        try:
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w", delete=True, suffix=".csv"
+            )
+            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                async with conn.transaction():
+                    cursor = await conn.cursor(select_stmt, *params)
+
+                    if include_header:
+                        writer.writerow(columns)
+
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            row_dict = {
+                                "id": row[0],
+                                "email": row[1],
+                                "is_superuser": row[2],
+                                "is_active": row[3],
+                                "is_verified": row[4],
+                                "name": row[5],
+                                "bio": row[6],
+                                "collection_ids": row[7],
+                                "created_at": row[8],
+                                "updated_at": row[9],
+                            }
+                            writer.writerow([row_dict[col] for col in columns])
+
+            temp_file.flush()
+            return temp_file.name, temp_file
+
+        except Exception as e:
+            if temp_file:
+                temp_file.close()
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to export data: {str(e)}",
+            ) from e
+
+    async def get_user_by_google_id(self, google_id: str) -> Optional[User]:
+        """Return a User if the google_id is found; otherwise None."""
+        query, params = (
+            QueryBuilder(self._get_table_name("users"))
+            .select(
+                [
+                    "id",
+                    "email",
+                    "is_superuser",
+                    "is_active",
+                    "is_verified",
+                    "created_at",
+                    "updated_at",
+                    "name",
+                    "profile_picture",
+                    "bio",
+                    "collection_ids",
+                    "limits_overrides",
+                    "metadata",
+                    "account_type",
+                    "hashed_password",
+                    "google_id",
+                    "github_id",
+                ]
+            )
+            .where("google_id = $1")
+            .build()
+        )
+        result = await self.connection_manager.fetchrow_query(
+            query, [google_id]
+        )
+        if not result:
+            return None
+
+        return User(
+            id=result["id"],
+            email=result["email"],
+            is_superuser=result["is_superuser"],
+            is_active=result["is_active"],
+            is_verified=result["is_verified"],
+            created_at=result["created_at"],
+            updated_at=result["updated_at"],
+            name=result["name"],
+            profile_picture=result["profile_picture"],
+            bio=result["bio"],
+            collection_ids=result["collection_ids"] or [],
+            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+            metadata=json.loads(result["metadata"] or "{}"),
+            account_type=result["account_type"],
+            hashed_password=result["hashed_password"],
+            google_id=result["google_id"],
+            github_id=result["github_id"],
+        )
+
+    async def get_user_by_github_id(self, github_id: str) -> Optional[User]:
+        """Return a User if the github_id is found; otherwise None."""
+        query, params = (
+            QueryBuilder(self._get_table_name("users"))
+            .select(
+                [
+                    "id",
+                    "email",
+                    "is_superuser",
+                    "is_active",
+                    "is_verified",
+                    "created_at",
+                    "updated_at",
+                    "name",
+                    "profile_picture",
+                    "bio",
+                    "collection_ids",
+                    "limits_overrides",
+                    "metadata",
+                    "account_type",
+                    "hashed_password",
+                    "google_id",
+                    "github_id",
+                ]
+            )
+            .where("github_id = $1")
+            .build()
+        )
+        result = await self.connection_manager.fetchrow_query(
+            query, [github_id]
+        )
+        if not result:
+            return None
+
+        return User(
+            id=result["id"],
+            email=result["email"],
+            is_superuser=result["is_superuser"],
+            is_active=result["is_active"],
+            is_verified=result["is_verified"],
+            created_at=result["created_at"],
+            updated_at=result["updated_at"],
+            name=result["name"],
+            profile_picture=result["profile_picture"],
+            bio=result["bio"],
+            collection_ids=result["collection_ids"] or [],
+            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+            metadata=json.loads(result["metadata"] or "{}"),
+            account_type=result["account_type"],
+            hashed_password=result["hashed_password"],
+            google_id=result["google_id"],
+            github_id=result["github_id"],
+        )