diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/jwt/algorithms.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/jwt/algorithms.py | 875 |
1 files changed, 875 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/jwt/algorithms.py b/.venv/lib/python3.12/site-packages/jwt/algorithms.py new file mode 100644 index 00000000..ccb1500f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/jwt/algorithms.py @@ -0,0 +1,875 @@ +from __future__ import annotations + +import hashlib +import hmac +import json +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload + +from .exceptions import InvalidKeyError +from .types import HashlibHash, JWKDict +from .utils import ( + base64url_decode, + base64url_encode, + der_to_raw_signature, + force_bytes, + from_base64url_uint, + is_pem_format, + is_ssh_key, + raw_to_der_signature, + to_base64url_uint, +) + +try: + from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.asymmetric import padding + from cryptography.hazmat.primitives.asymmetric.ec import ( + ECDSA, + SECP256K1, + SECP256R1, + SECP384R1, + SECP521R1, + EllipticCurve, + EllipticCurvePrivateKey, + EllipticCurvePrivateNumbers, + EllipticCurvePublicKey, + EllipticCurvePublicNumbers, + ) + from cryptography.hazmat.primitives.asymmetric.ed448 import ( + Ed448PrivateKey, + Ed448PublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.ed25519 import ( + Ed25519PrivateKey, + Ed25519PublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPrivateKey, + RSAPrivateNumbers, + RSAPublicKey, + RSAPublicNumbers, + rsa_crt_dmp1, + rsa_crt_dmq1, + rsa_crt_iqmp, + rsa_recover_prime_factors, + ) + from cryptography.hazmat.primitives.serialization import ( + Encoding, + NoEncryption, + PrivateFormat, + PublicFormat, + load_pem_private_key, + load_pem_public_key, + load_ssh_public_key, + ) + + has_crypto = True +except ModuleNotFoundError: + has_crypto = False + + +if TYPE_CHECKING: + # Type aliases for convenience in algorithms method signatures + AllowedRSAKeys = RSAPrivateKey | RSAPublicKey + AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey + AllowedOKPKeys = ( + Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey + ) + AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys + AllowedPrivateKeys = ( + RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey + ) + AllowedPublicKeys = ( + RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey + ) + + +requires_cryptography = { + "RS256", + "RS384", + "RS512", + "ES256", + "ES256K", + "ES384", + "ES521", + "ES512", + "PS256", + "PS384", + "PS512", + "EdDSA", +} + + +def get_default_algorithms() -> dict[str, Algorithm]: + """ + Returns the algorithms that are implemented by the library. + """ + default_algorithms = { + "none": NoneAlgorithm(), + "HS256": HMACAlgorithm(HMACAlgorithm.SHA256), + "HS384": HMACAlgorithm(HMACAlgorithm.SHA384), + "HS512": HMACAlgorithm(HMACAlgorithm.SHA512), + } + + if has_crypto: + default_algorithms.update( + { + "RS256": RSAAlgorithm(RSAAlgorithm.SHA256), + "RS384": RSAAlgorithm(RSAAlgorithm.SHA384), + "RS512": RSAAlgorithm(RSAAlgorithm.SHA512), + "ES256": ECAlgorithm(ECAlgorithm.SHA256), + "ES256K": ECAlgorithm(ECAlgorithm.SHA256), + "ES384": ECAlgorithm(ECAlgorithm.SHA384), + "ES521": ECAlgorithm(ECAlgorithm.SHA512), + "ES512": ECAlgorithm( + ECAlgorithm.SHA512 + ), # Backward compat for #219 fix + "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256), + "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384), + "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512), + "EdDSA": OKPAlgorithm(), + } + ) + + return default_algorithms + + +class Algorithm(ABC): + """ + The interface for an algorithm used to sign and verify tokens. + """ + + def compute_hash_digest(self, bytestr: bytes) -> bytes: + """ + Compute a hash digest using the specified algorithm's hash algorithm. + + If there is no hash algorithm, raises a NotImplementedError. + """ + # lookup self.hash_alg if defined in a way that mypy can understand + hash_alg = getattr(self, "hash_alg", None) + if hash_alg is None: + raise NotImplementedError + + if ( + has_crypto + and isinstance(hash_alg, type) + and issubclass(hash_alg, hashes.HashAlgorithm) + ): + digest = hashes.Hash(hash_alg(), backend=default_backend()) + digest.update(bytestr) + return bytes(digest.finalize()) + else: + return bytes(hash_alg(bytestr).digest()) + + @abstractmethod + def prepare_key(self, key: Any) -> Any: + """ + Performs necessary validation and conversions on the key and returns + the key value in the proper format for sign() and verify(). + """ + + @abstractmethod + def sign(self, msg: bytes, key: Any) -> bytes: + """ + Returns a digital signature for the specified message + using the specified key value. + """ + + @abstractmethod + def verify(self, msg: bytes, key: Any, sig: bytes) -> bool: + """ + Verifies that the specified digital signature is valid + for the specified message and key values. + """ + + @overload + @staticmethod + @abstractmethod + def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ... # pragma: no cover + + @overload + @staticmethod + @abstractmethod + def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ... # pragma: no cover + + @staticmethod + @abstractmethod + def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str: + """ + Serializes a given key into a JWK + """ + + @staticmethod + @abstractmethod + def from_jwk(jwk: str | JWKDict) -> Any: + """ + Deserializes a given key from JWK back into a key object + """ + + +class NoneAlgorithm(Algorithm): + """ + Placeholder for use when no signing or verification + operations are required. + """ + + def prepare_key(self, key: str | None) -> None: + if key == "": + key = None + + if key is not None: + raise InvalidKeyError('When alg = "none", key value must be None.') + + return key + + def sign(self, msg: bytes, key: None) -> bytes: + return b"" + + def verify(self, msg: bytes, key: None, sig: bytes) -> bool: + return False + + @staticmethod + def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn: + raise NotImplementedError() + + @staticmethod + def from_jwk(jwk: str | JWKDict) -> NoReturn: + raise NotImplementedError() + + +class HMACAlgorithm(Algorithm): + """ + Performs signing and verification operations using HMAC + and the specified hash function. + """ + + SHA256: ClassVar[HashlibHash] = hashlib.sha256 + SHA384: ClassVar[HashlibHash] = hashlib.sha384 + SHA512: ClassVar[HashlibHash] = hashlib.sha512 + + def __init__(self, hash_alg: HashlibHash) -> None: + self.hash_alg = hash_alg + + def prepare_key(self, key: str | bytes) -> bytes: + key_bytes = force_bytes(key) + + if is_pem_format(key_bytes) or is_ssh_key(key_bytes): + raise InvalidKeyError( + "The specified key is an asymmetric key or x509 certificate and" + " should not be used as an HMAC secret." + ) + + return key_bytes + + @overload + @staticmethod + def to_jwk( + key_obj: str | bytes, as_dict: Literal[True] + ) -> JWKDict: ... # pragma: no cover + + @overload + @staticmethod + def to_jwk( + key_obj: str | bytes, as_dict: Literal[False] = False + ) -> str: ... # pragma: no cover + + @staticmethod + def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str: + jwk = { + "k": base64url_encode(force_bytes(key_obj)).decode(), + "kty": "oct", + } + + if as_dict: + return jwk + else: + return json.dumps(jwk) + + @staticmethod + def from_jwk(jwk: str | JWKDict) -> bytes: + try: + if isinstance(jwk, str): + obj: JWKDict = json.loads(jwk) + elif isinstance(jwk, dict): + obj = jwk + else: + raise ValueError + except ValueError: + raise InvalidKeyError("Key is not valid JSON") from None + + if obj.get("kty") != "oct": + raise InvalidKeyError("Not an HMAC key") + + return base64url_decode(obj["k"]) + + def sign(self, msg: bytes, key: bytes) -> bytes: + return hmac.new(key, msg, self.hash_alg).digest() + + def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool: + return hmac.compare_digest(sig, self.sign(msg, key)) + + +if has_crypto: + + class RSAAlgorithm(Algorithm): + """ + Performs signing and verification operations using + RSASSA-PKCS-v1_5 and the specified hash function. + """ + + SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 + SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 + SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 + + def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: + self.hash_alg = hash_alg + + def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys: + if isinstance(key, (RSAPrivateKey, RSAPublicKey)): + return key + + if not isinstance(key, (bytes, str)): + raise TypeError("Expecting a PEM-formatted key.") + + key_bytes = force_bytes(key) + + try: + if key_bytes.startswith(b"ssh-rsa"): + return cast(RSAPublicKey, load_ssh_public_key(key_bytes)) + else: + return cast( + RSAPrivateKey, load_pem_private_key(key_bytes, password=None) + ) + except ValueError: + try: + return cast(RSAPublicKey, load_pem_public_key(key_bytes)) + except (ValueError, UnsupportedAlgorithm): + raise InvalidKeyError( + "Could not parse the provided public key." + ) from None + + @overload + @staticmethod + def to_jwk( + key_obj: AllowedRSAKeys, as_dict: Literal[True] + ) -> JWKDict: ... # pragma: no cover + + @overload + @staticmethod + def to_jwk( + key_obj: AllowedRSAKeys, as_dict: Literal[False] = False + ) -> str: ... # pragma: no cover + + @staticmethod + def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str: + obj: dict[str, Any] | None = None + + if hasattr(key_obj, "private_numbers"): + # Private key + numbers = key_obj.private_numbers() + + obj = { + "kty": "RSA", + "key_ops": ["sign"], + "n": to_base64url_uint(numbers.public_numbers.n).decode(), + "e": to_base64url_uint(numbers.public_numbers.e).decode(), + "d": to_base64url_uint(numbers.d).decode(), + "p": to_base64url_uint(numbers.p).decode(), + "q": to_base64url_uint(numbers.q).decode(), + "dp": to_base64url_uint(numbers.dmp1).decode(), + "dq": to_base64url_uint(numbers.dmq1).decode(), + "qi": to_base64url_uint(numbers.iqmp).decode(), + } + + elif hasattr(key_obj, "verify"): + # Public key + numbers = key_obj.public_numbers() + + obj = { + "kty": "RSA", + "key_ops": ["verify"], + "n": to_base64url_uint(numbers.n).decode(), + "e": to_base64url_uint(numbers.e).decode(), + } + else: + raise InvalidKeyError("Not a public or private key") + + if as_dict: + return obj + else: + return json.dumps(obj) + + @staticmethod + def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys: + try: + if isinstance(jwk, str): + obj = json.loads(jwk) + elif isinstance(jwk, dict): + obj = jwk + else: + raise ValueError + except ValueError: + raise InvalidKeyError("Key is not valid JSON") from None + + if obj.get("kty") != "RSA": + raise InvalidKeyError("Not an RSA key") from None + + if "d" in obj and "e" in obj and "n" in obj: + # Private key + if "oth" in obj: + raise InvalidKeyError( + "Unsupported RSA private key: > 2 primes not supported" + ) + + other_props = ["p", "q", "dp", "dq", "qi"] + props_found = [prop in obj for prop in other_props] + any_props_found = any(props_found) + + if any_props_found and not all(props_found): + raise InvalidKeyError( + "RSA key must include all parameters if any are present besides d" + ) from None + + public_numbers = RSAPublicNumbers( + from_base64url_uint(obj["e"]), + from_base64url_uint(obj["n"]), + ) + + if any_props_found: + numbers = RSAPrivateNumbers( + d=from_base64url_uint(obj["d"]), + p=from_base64url_uint(obj["p"]), + q=from_base64url_uint(obj["q"]), + dmp1=from_base64url_uint(obj["dp"]), + dmq1=from_base64url_uint(obj["dq"]), + iqmp=from_base64url_uint(obj["qi"]), + public_numbers=public_numbers, + ) + else: + d = from_base64url_uint(obj["d"]) + p, q = rsa_recover_prime_factors( + public_numbers.n, d, public_numbers.e + ) + + numbers = RSAPrivateNumbers( + d=d, + p=p, + q=q, + dmp1=rsa_crt_dmp1(d, p), + dmq1=rsa_crt_dmq1(d, q), + iqmp=rsa_crt_iqmp(p, q), + public_numbers=public_numbers, + ) + + return numbers.private_key() + elif "n" in obj and "e" in obj: + # Public key + return RSAPublicNumbers( + from_base64url_uint(obj["e"]), + from_base64url_uint(obj["n"]), + ).public_key() + else: + raise InvalidKeyError("Not a public or private key") + + def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: + return key.sign(msg, padding.PKCS1v15(), self.hash_alg()) + + def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: + try: + key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg()) + return True + except InvalidSignature: + return False + + class ECAlgorithm(Algorithm): + """ + Performs signing and verification operations using + ECDSA and the specified hash function + """ + + SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 + SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 + SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 + + def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: + self.hash_alg = hash_alg + + def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys: + if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): + return key + + if not isinstance(key, (bytes, str)): + raise TypeError("Expecting a PEM-formatted key.") + + key_bytes = force_bytes(key) + + # Attempt to load key. We don't know if it's + # a Signing Key or a Verifying Key, so we try + # the Verifying Key first. + try: + if key_bytes.startswith(b"ecdsa-sha2-"): + crypto_key = load_ssh_public_key(key_bytes) + else: + crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment] + except ValueError: + crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] + + # Explicit check the key to prevent confusing errors from cryptography + if not isinstance( + crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey) + ): + raise InvalidKeyError( + "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms" + ) from None + + return crypto_key + + def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes: + der_sig = key.sign(msg, ECDSA(self.hash_alg())) + + return der_to_raw_signature(der_sig, key.curve) + + def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool: + try: + der_sig = raw_to_der_signature(sig, key.curve) + except ValueError: + return False + + try: + public_key = ( + key.public_key() + if isinstance(key, EllipticCurvePrivateKey) + else key + ) + public_key.verify(der_sig, msg, ECDSA(self.hash_alg())) + return True + except InvalidSignature: + return False + + @overload + @staticmethod + def to_jwk( + key_obj: AllowedECKeys, as_dict: Literal[True] + ) -> JWKDict: ... # pragma: no cover + + @overload + @staticmethod + def to_jwk( + key_obj: AllowedECKeys, as_dict: Literal[False] = False + ) -> str: ... # pragma: no cover + + @staticmethod + def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str: + if isinstance(key_obj, EllipticCurvePrivateKey): + public_numbers = key_obj.public_key().public_numbers() + elif isinstance(key_obj, EllipticCurvePublicKey): + public_numbers = key_obj.public_numbers() + else: + raise InvalidKeyError("Not a public or private key") + + if isinstance(key_obj.curve, SECP256R1): + crv = "P-256" + elif isinstance(key_obj.curve, SECP384R1): + crv = "P-384" + elif isinstance(key_obj.curve, SECP521R1): + crv = "P-521" + elif isinstance(key_obj.curve, SECP256K1): + crv = "secp256k1" + else: + raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") + + obj: dict[str, Any] = { + "kty": "EC", + "crv": crv, + "x": to_base64url_uint( + public_numbers.x, + bit_length=key_obj.curve.key_size, + ).decode(), + "y": to_base64url_uint( + public_numbers.y, + bit_length=key_obj.curve.key_size, + ).decode(), + } + + if isinstance(key_obj, EllipticCurvePrivateKey): + obj["d"] = to_base64url_uint( + key_obj.private_numbers().private_value, + bit_length=key_obj.curve.key_size, + ).decode() + + if as_dict: + return obj + else: + return json.dumps(obj) + + @staticmethod + def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: + try: + if isinstance(jwk, str): + obj = json.loads(jwk) + elif isinstance(jwk, dict): + obj = jwk + else: + raise ValueError + except ValueError: + raise InvalidKeyError("Key is not valid JSON") from None + + if obj.get("kty") != "EC": + raise InvalidKeyError("Not an Elliptic curve key") from None + + if "x" not in obj or "y" not in obj: + raise InvalidKeyError("Not an Elliptic curve key") from None + + x = base64url_decode(obj.get("x")) + y = base64url_decode(obj.get("y")) + + curve = obj.get("crv") + curve_obj: EllipticCurve + + if curve == "P-256": + if len(x) == len(y) == 32: + curve_obj = SECP256R1() + else: + raise InvalidKeyError( + "Coords should be 32 bytes for curve P-256" + ) from None + elif curve == "P-384": + if len(x) == len(y) == 48: + curve_obj = SECP384R1() + else: + raise InvalidKeyError( + "Coords should be 48 bytes for curve P-384" + ) from None + elif curve == "P-521": + if len(x) == len(y) == 66: + curve_obj = SECP521R1() + else: + raise InvalidKeyError( + "Coords should be 66 bytes for curve P-521" + ) from None + elif curve == "secp256k1": + if len(x) == len(y) == 32: + curve_obj = SECP256K1() + else: + raise InvalidKeyError( + "Coords should be 32 bytes for curve secp256k1" + ) + else: + raise InvalidKeyError(f"Invalid curve: {curve}") + + public_numbers = EllipticCurvePublicNumbers( + x=int.from_bytes(x, byteorder="big"), + y=int.from_bytes(y, byteorder="big"), + curve=curve_obj, + ) + + if "d" not in obj: + return public_numbers.public_key() + + d = base64url_decode(obj.get("d")) + if len(d) != len(x): + raise InvalidKeyError( + "D should be {} bytes for curve {}", len(x), curve + ) + + return EllipticCurvePrivateNumbers( + int.from_bytes(d, byteorder="big"), public_numbers + ).private_key() + + class RSAPSSAlgorithm(RSAAlgorithm): + """ + Performs a signature using RSASSA-PSS with MGF1 + """ + + def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: + return key.sign( + msg, + padding.PSS( + mgf=padding.MGF1(self.hash_alg()), + salt_length=self.hash_alg().digest_size, + ), + self.hash_alg(), + ) + + def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: + try: + key.verify( + sig, + msg, + padding.PSS( + mgf=padding.MGF1(self.hash_alg()), + salt_length=self.hash_alg().digest_size, + ), + self.hash_alg(), + ) + return True + except InvalidSignature: + return False + + class OKPAlgorithm(Algorithm): + """ + Performs signing and verification operations using EdDSA + + This class requires ``cryptography>=2.6`` to be installed. + """ + + def __init__(self, **kwargs: Any) -> None: + pass + + def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys: + if isinstance(key, (bytes, str)): + key_str = key.decode("utf-8") if isinstance(key, bytes) else key + key_bytes = key.encode("utf-8") if isinstance(key, str) else key + + if "-----BEGIN PUBLIC" in key_str: + key = load_pem_public_key(key_bytes) # type: ignore[assignment] + elif "-----BEGIN PRIVATE" in key_str: + key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] + elif key_str[0:4] == "ssh-": + key = load_ssh_public_key(key_bytes) # type: ignore[assignment] + + # Explicit check the key to prevent confusing errors from cryptography + if not isinstance( + key, + (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey), + ): + raise InvalidKeyError( + "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms" + ) + + return key + + def sign( + self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey + ) -> bytes: + """ + Sign a message ``msg`` using the EdDSA private key ``key`` + :param str|bytes msg: Message to sign + :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey` + or :class:`.Ed448PrivateKey` isinstance + :return bytes signature: The signature, as bytes + """ + msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg + return key.sign(msg_bytes) + + def verify( + self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes + ) -> bool: + """ + Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key`` + + :param str|bytes sig: EdDSA signature to check ``msg`` against + :param str|bytes msg: Message to sign + :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key: + A private or public EdDSA key instance + :return bool verified: True if signature is valid, False if not. + """ + try: + msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg + sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig + + public_key = ( + key.public_key() + if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)) + else key + ) + public_key.verify(sig_bytes, msg_bytes) + return True # If no exception was raised, the signature is valid. + except InvalidSignature: + return False + + @overload + @staticmethod + def to_jwk( + key: AllowedOKPKeys, as_dict: Literal[True] + ) -> JWKDict: ... # pragma: no cover + + @overload + @staticmethod + def to_jwk( + key: AllowedOKPKeys, as_dict: Literal[False] = False + ) -> str: ... # pragma: no cover + + @staticmethod + def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str: + if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): + x = key.public_bytes( + encoding=Encoding.Raw, + format=PublicFormat.Raw, + ) + crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448" + + obj = { + "x": base64url_encode(force_bytes(x)).decode(), + "kty": "OKP", + "crv": crv, + } + + if as_dict: + return obj + else: + return json.dumps(obj) + + if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): + d = key.private_bytes( + encoding=Encoding.Raw, + format=PrivateFormat.Raw, + encryption_algorithm=NoEncryption(), + ) + + x = key.public_key().public_bytes( + encoding=Encoding.Raw, + format=PublicFormat.Raw, + ) + + crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448" + obj = { + "x": base64url_encode(force_bytes(x)).decode(), + "d": base64url_encode(force_bytes(d)).decode(), + "kty": "OKP", + "crv": crv, + } + + if as_dict: + return obj + else: + return json.dumps(obj) + + raise InvalidKeyError("Not a public or private key") + + @staticmethod + def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys: + try: + if isinstance(jwk, str): + obj = json.loads(jwk) + elif isinstance(jwk, dict): + obj = jwk + else: + raise ValueError + except ValueError: + raise InvalidKeyError("Key is not valid JSON") from None + + if obj.get("kty") != "OKP": + raise InvalidKeyError("Not an Octet Key Pair") + + curve = obj.get("crv") + if curve != "Ed25519" and curve != "Ed448": + raise InvalidKeyError(f"Invalid curve: {curve}") + + if "x" not in obj: + raise InvalidKeyError('OKP should have "x" parameter') + x = base64url_decode(obj.get("x")) + + try: + if "d" not in obj: + if curve == "Ed25519": + return Ed25519PublicKey.from_public_bytes(x) + return Ed448PublicKey.from_public_bytes(x) + d = base64url_decode(obj.get("d")) + if curve == "Ed25519": + return Ed25519PrivateKey.from_private_bytes(d) + return Ed448PrivateKey.from_private_bytes(d) + except ValueError as err: + raise InvalidKeyError("Invalid key parameter") from err |