aboutsummaryrefslogtreecommitdiff
import base64
import logging
import os
from abc import ABC
from datetime import datetime, timezone
from typing import Optional, Tuple

import bcrypt
import jwt
import nacl.encoding
import nacl.exceptions
import nacl.signing
import nacl.utils

from core.base import CryptoConfig, CryptoProvider

# Last-resort JWT signing secret, used by BCryptCryptoProvider only when
# neither config.secret_key nor the R2R_SECRET_KEY environment variable is set.
# SECURITY: this value is published in source code, so anyone who reads it can
# forge HS256 tokens for a deployment that relies on it. Always configure a
# private secret in production.
DEFAULT_BCRYPT_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM"  # Replace or load from env or secrets manager


class BcryptCryptoConfig(CryptoConfig):
    """Settings for the bcrypt-based crypto provider.

    Only the ``"bcrypt"`` provider name is accepted; ``validate_config``
    rejects anything else and keeps ``bcrypt_rounds`` within 4..31.
    """

    provider: str = "bcrypt"
    # Number of rounds for bcrypt (increasing this makes hashing slower but more secure)
    bcrypt_rounds: int = 12
    # Optional JWT signing secret; None lets the provider fall back to other sources.
    secret_key: Optional[str] = None
    api_key_bytes: int = 32  # Number of random bytes in each raw API key

    @property
    def supported_providers(self) -> list[str]:
        """Return the provider names this config accepts."""
        return ["bcrypt"]

    def validate_config(self) -> None:
        """Run base validation, then check the provider name and round count.

        Raises:
            ValueError: if ``provider`` is unsupported or ``bcrypt_rounds`` is
                outside 4..31.
        """
        super().validate_config()
        if self.provider not in self.supported_providers:
            raise ValueError(f"Unsupported crypto provider: {self.provider}")
        if not 4 <= self.bcrypt_rounds <= 31:
            raise ValueError("bcrypt_rounds must be between 4 and 31")

    def verify_password(
        self, plain_password: str, hashed_password: str
    ) -> bool:
        """Check *plain_password* against a stored bcrypt hash.

        *hashed_password* may be either a base64-encoded bcrypt hash (new
        format) or a raw bcrypt hash string (old format).

        Returns:
            True on a match; False on a mismatch or when the stored value is
            not a usable bcrypt hash ("Invalid salt").

        Raises:
            ValueError: for any other error reported by ``bcrypt.checkpw``.
        """
        raw_hash = hashed_password.encode("utf-8")
        try:
            decoded = base64.b64decode(raw_hash)
        except ValueError:  # binascii.Error is a ValueError subclass
            decoded = b""
        # Non-validating b64decode silently discards the '$' and '.' found in a
        # raw bcrypt hash and may return garbage instead of failing, so only
        # trust the decoded bytes when they actually look like a bcrypt hash.
        stored_hash = (
            decoded
            if decoded.startswith((b"$2a$", b"$2b$", b"$2y$"))
            else raw_hash
        )
        try:
            return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash)
        except ValueError as e:
            if "Invalid salt" in str(e):
                # Stored value is not a bcrypt hash in either format.
                return False
            raise


class BCryptCryptoProvider(CryptoProvider, ABC):
    """Crypto provider built on bcrypt, PyNaCl and PyJWT.

    * Passwords and API keys are hashed with bcrypt; the resulting hashes are
      returned base64-encoded.
    * Random material (verification codes, key ids, API keys) comes from
      ``nacl.utils.random``.
    * Requests are signed/verified with NaCl signing keys, exchanged as
      base64 strings.
    * Session tokens are HS256 JWTs signed with ``self.secret_key``.
    """

    def __init__(self, config: BcryptCryptoConfig):
        """Validate *config* and resolve the JWT signing secret.

        The secret is taken from ``config.secret_key``, then the
        ``R2R_SECRET_KEY`` environment variable, then the module default.

        Raises:
            ValueError: if *config* is not a ``BcryptCryptoConfig`` or no
                secret key could be resolved.
        """
        if not isinstance(config, BcryptCryptoConfig):
            raise ValueError(
                "BcryptCryptoProvider must be initialized with a BcryptCryptoConfig"
            )
        logger = logging.getLogger(__name__)
        logger.info("Initializing BcryptCryptoProvider")
        super().__init__(config)
        self.config: BcryptCryptoConfig = config

        # Resolution order: explicit config, then environment, then the
        # built-in default (kept for backward compatibility).
        self.secret_key = (
            config.secret_key
            or os.getenv("R2R_SECRET_KEY")
            or DEFAULT_BCRYPT_SECRET_KEY
        )
        if not self.secret_key:
            # Only reachable if DEFAULT_BCRYPT_SECRET_KEY is set to a falsy value.
            raise ValueError(
                "No secret key provided for BcryptCryptoProvider."
            )
        if self.secret_key == DEFAULT_BCRYPT_SECRET_KEY:
            # The default is published in source, so tokens signed with it can
            # be forged by anyone; make the insecure fallback visible.
            logger.warning(
                "BcryptCryptoProvider is using the built-in default secret key; "
                "set config.secret_key or R2R_SECRET_KEY to a private value."
            )

    def get_password_hash(self, password: str) -> str:
        """Hash *password* with bcrypt and return the hash base64-encoded."""
        password_bytes = password.encode("utf-8")
        hashed = bcrypt.hashpw(
            password_bytes, bcrypt.gensalt(rounds=self.config.bcrypt_rounds)
        )
        return base64.b64encode(hashed).decode("utf-8")

    def verify_password(
        self, plain_password: str, hashed_password: str
    ) -> bool:
        """Check *plain_password* against a stored bcrypt hash.

        *hashed_password* may be base64-encoded (as produced by
        ``get_password_hash``) or a raw bcrypt hash string (old format).

        Returns:
            True on a match; False on a mismatch or when the stored value is
            not a usable bcrypt hash ("Invalid salt").

        Raises:
            ValueError: for any other error reported by ``bcrypt.checkpw``.
        """
        raw_hash = hashed_password.encode("utf-8")
        try:
            decoded = base64.b64decode(raw_hash)
        except ValueError:  # binascii.Error is a ValueError subclass
            decoded = b""
        # Non-validating b64decode drops '$'/'.' from a raw bcrypt hash and may
        # return garbage, so only use the decoded bytes when they carry a
        # bcrypt version prefix.
        stored_hash = (
            decoded
            if decoded.startswith((b"$2a$", b"$2b$", b"$2y$"))
            else raw_hash
        )
        try:
            return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash)
        except ValueError as e:
            if "Invalid salt" in str(e):
                # Stored value is not a bcrypt hash in either format.
                return False
            raise

    def generate_verification_code(self, length: int = 32) -> str:
        """Return *length* URL-safe base64 characters drawn from fresh randomness."""
        random_bytes = nacl.utils.random(length)
        return base64.urlsafe_b64encode(random_bytes)[:length].decode("utf-8")

    def generate_signing_keypair(self) -> Tuple[str, str, str]:
        """Create a new NaCl signing keypair.

        Returns:
            ``(key_id, private_key, public_key)`` where ``key_id`` is
            ``"sk_"`` + URL-safe base64 of 16 random bytes and both keys are
            standard base64 strings.
        """
        signing_key = nacl.signing.SigningKey.generate()
        verify_key = signing_key.verify_key

        # Generate unique key_id
        key_entropy = nacl.utils.random(16)
        key_id = f"sk_{base64.urlsafe_b64encode(key_entropy).decode()}"

        private_key = base64.b64encode(bytes(signing_key)).decode()
        public_key = base64.b64encode(bytes(verify_key)).decode()
        return key_id, private_key, public_key

    def sign_request(self, private_key: str, data: str) -> str:
        """Sign *data* with a base64-encoded private key; return the signature in base64.

        Raises:
            ValueError: if the key cannot be decoded or signing fails.
        """
        try:
            key_bytes = base64.b64decode(private_key)
            signing_key = nacl.signing.SigningKey(key_bytes)
            signature = signing_key.sign(data.encode())
            return base64.b64encode(signature.signature).decode()
        except Exception as e:
            raise ValueError(
                f"Invalid private key or signing error: {str(e)}"
            ) from e

    def verify_request_signature(
        self, public_key: str, signature: str, data: str
    ) -> bool:
        """Return True iff *signature* (base64) is valid for *data* under *public_key* (base64)."""
        try:
            key_bytes = base64.b64decode(public_key)
            verify_key = nacl.signing.VerifyKey(key_bytes)
            signature_bytes = base64.b64decode(signature)
            verify_key.verify(data.encode(), signature_bytes)
            return True
        except (nacl.exceptions.BadSignatureError, ValueError):
            return False

    def generate_api_key(self) -> Tuple[str, str]:
        """Return ``(key_id, raw_api_key)``.

        ``key_id`` is ``"key_"`` + URL-safe base64 of 16 random bytes;
        ``raw_api_key`` is URL-safe base64 of ``config.api_key_bytes`` random
        bytes.
        """
        key_id_bytes = nacl.utils.random(16)
        key_id = f"key_{base64.urlsafe_b64encode(key_id_bytes).decode()}"

        # Generate raw API key
        raw_api_key = base64.urlsafe_b64encode(
            nacl.utils.random(self.config.api_key_bytes)
        ).decode()
        return key_id, raw_api_key

    def hash_api_key(self, raw_api_key: str) -> str:
        """Hash *raw_api_key* with bcrypt and return the hash base64-encoded."""
        hashed = bcrypt.hashpw(
            raw_api_key.encode("utf-8"),
            bcrypt.gensalt(rounds=self.config.bcrypt_rounds),
        )
        return base64.b64encode(hashed).decode("utf-8")

    def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
        """Check *raw_api_key* against a base64-encoded bcrypt hash from ``hash_api_key``.

        Returns False (rather than raising) when *hashed_key* is not valid
        base64 or not a usable bcrypt hash.
        """
        try:
            stored_hash = base64.b64decode(hashed_key.encode("utf-8"))
            return bcrypt.checkpw(raw_api_key.encode("utf-8"), stored_hash)
        except ValueError:  # covers binascii.Error and bcrypt "Invalid salt"
            return False

    def generate_secure_token(self, data: dict, expiry: datetime) -> str:
        """Return an HS256 JWT carrying *data* plus exp/iat/nbf/jti/nonce claims.

        The generated claims override any same-named keys in *data*.
        NOTE(review): ``datetime.timestamp`` treats a naive *expiry* as local
        time; callers should pass a timezone-aware datetime.
        """
        now = datetime.now(timezone.utc)
        to_encode = {
            **data,
            "exp": expiry.timestamp(),
            "iat": now.timestamp(),
            "nbf": now.timestamp(),
            "jti": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
            "nonce": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
        }
        return jwt.encode(to_encode, self.secret_key, algorithm="HS256")

    def verify_secure_token(self, token: str) -> Optional[dict]:
        """Decode an HS256 token signed with ``self.secret_key``.

        Returns:
            The claims dict, or None if the token is invalid, expired, or has
            no numeric ``exp`` claim.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
        except jwt.InvalidTokenError:  # ExpiredSignatureError is a subclass
            return None
        exp = payload.get("exp")
        # Require an expiry and compare timestamps numerically, so malformed or
        # out-of-range "exp" values yield None instead of raising.
        if not isinstance(exp, (int, float)):
            return None
        if exp < datetime.now(timezone.utc).timestamp():
            return None
        return payload