aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/dns/dnssecalgs/rsa.py
blob: 27537aad0c1857276d62c476f583a3e093e34e0f (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import math
import struct

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa

from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY


class PublicRSA(CryptographyPublicKey):
    """RSA public key wrapper that verifies with PKCS #1 v1.5 padding.

    ``algorithm`` and ``chosen_hash`` are only declared here; the concrete
    per-algorithm subclasses assign them.
    """

    key: rsa.RSAPublicKey
    key_cls = rsa.RSAPublicKey
    algorithm: Algorithm
    chosen_hash: hashes.HashAlgorithm

    def verify(self, signature: bytes, data: bytes) -> None:
        """Hand *signature* and *data* to the wrapped key's ``verify``.

        PKCS #1 v1.5 padding and the class's ``chosen_hash`` are used.
        """
        pkcs1 = padding.PKCS1v15()
        self.key.verify(signature, data, pkcs1, self.chosen_hash)

    def encode_key_bytes(self) -> bytes:
        """Serialize the key as described in RFC 3110, section 2.

        Layout: exponent-length header, exponent, modulus (all big-endian).

        Raises:
            ValueError: if the modulus is shorter than 512 or longer than
                4096 bits.
        """
        numbers = self.key.public_numbers()
        exponent = numbers.e
        modulus = numbers.n

        exponent_size = math.ceil(exponent.bit_length() / 8)
        exponent_bytes = exponent.to_bytes(exponent_size, "big")
        if exponent_size <= 255:
            # Short form: a single length octet.
            header = struct.pack("!B", exponent_size)
        else:
            # Long form: a zero octet followed by a 16-bit length.
            header = b"\0" + struct.pack("!H", exponent_size)

        modulus_bits = modulus.bit_length()
        if not 512 <= modulus_bits <= 4096:
            raise ValueError("unsupported RSA key length")
        modulus_bytes = modulus.to_bytes((modulus_bits + 7) // 8, "big")
        return b"".join((header, exponent_bytes, modulus_bytes))

    @classmethod
    def from_dnskey(cls, key: DNSKEY) -> "PublicRSA":
        """Build an instance from the key material of a DNSKEY rdata.

        The algorithm/key combination is checked first via
        ``_ensure_algorithm_key_combination``.
        """
        cls._ensure_algorithm_key_combination(key)
        raw = key.key

        # The first octet is the exponent length; when it is zero, the
        # length is instead carried in the two octets that follow.
        (exponent_size,) = struct.unpack("!B", raw[:1])
        start = 1
        if exponent_size == 0:
            (exponent_size,) = struct.unpack("!H", raw[1:3])
            start = 3

        # Everything after the exponent is the modulus.
        split = start + exponent_size
        exponent = int.from_bytes(raw[start:split], "big")
        modulus = int.from_bytes(raw[split:], "big")

        numbers = rsa.RSAPublicNumbers(exponent, modulus)
        return cls(key=numbers.public_key(default_backend()))


class PrivateRSA(CryptographyPrivateKey):
    """RSA private key wrapper that signs with PKCS #1 v1.5 padding.

    The hash is taken from ``public_cls.chosen_hash``, so subclasses select
    it by pointing ``public_cls`` at the matching public-key class.
    """

    key: rsa.RSAPrivateKey
    key_cls = rsa.RSAPrivateKey
    public_cls = PublicRSA
    default_public_exponent = 65537

    def sign(
        self,
        data: bytes,
        verify: bool = False,
        deterministic: bool = True,
    ) -> bytes:
        """Produce a signature over *data* per RFC 3110, section 3.

        When *verify* is true the new signature is checked against the
        corresponding public key before being returned. *deterministic*
        is accepted but not consulted by this implementation.
        """
        hash_algorithm = self.public_cls.chosen_hash
        result = self.key.sign(data, padding.PKCS1v15(), hash_algorithm)
        if not verify:
            return result
        self.public_key().verify(result, data)
        return result

    @classmethod
    def generate(cls, key_size: int) -> "PrivateRSA":
        """Create a new key of *key_size* bits.

        The public exponent is ``cls.default_public_exponent``.
        """
        fresh_key = rsa.generate_private_key(
            public_exponent=cls.default_public_exponent,
            key_size=key_size,
            backend=default_backend(),
        )
        return cls(key=fresh_key)


class PublicRSAMD5(PublicRSA):
    """RSA public key for ``Algorithm.RSAMD5``, hashing with MD5."""

    chosen_hash = hashes.MD5()
    algorithm = Algorithm.RSAMD5


class PrivateRSAMD5(PrivateRSA):
    """RSA private key whose public counterpart is ``PublicRSAMD5``."""

    public_cls = PublicRSAMD5


class PublicRSASHA1(PublicRSA):
    """RSA public key for ``Algorithm.RSASHA1``, hashing with SHA-1."""

    chosen_hash = hashes.SHA1()
    algorithm = Algorithm.RSASHA1


class PrivateRSASHA1(PrivateRSA):
    """RSA private key whose public counterpart is ``PublicRSASHA1``."""

    public_cls = PublicRSASHA1


class PublicRSASHA1NSEC3SHA1(PublicRSA):
    """RSA public key for ``Algorithm.RSASHA1NSEC3SHA1``, hashing with SHA-1."""

    chosen_hash = hashes.SHA1()
    algorithm = Algorithm.RSASHA1NSEC3SHA1


class PrivateRSASHA1NSEC3SHA1(PrivateRSA):
    """RSA private key whose public counterpart is ``PublicRSASHA1NSEC3SHA1``."""

    public_cls = PublicRSASHA1NSEC3SHA1


class PublicRSASHA256(PublicRSA):
    """RSA public key for ``Algorithm.RSASHA256``, hashing with SHA-256."""

    chosen_hash = hashes.SHA256()
    algorithm = Algorithm.RSASHA256


class PrivateRSASHA256(PrivateRSA):
    """RSA private key whose public counterpart is ``PublicRSASHA256``."""

    public_cls = PublicRSASHA256


class PublicRSASHA512(PublicRSA):
    """RSA public key for ``Algorithm.RSASHA512``, hashing with SHA-512."""

    chosen_hash = hashes.SHA512()
    algorithm = Algorithm.RSASHA512


class PrivateRSASHA512(PrivateRSA):
    """RSA private key whose public counterpart is ``PublicRSASHA512``."""

    public_cls = PublicRSASHA512