about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/ellipticcurve/utils/der.py
blob: 84546aea0cc522e981a1b175bc9515d50512587f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from datetime import datetime
from .oid import oidToHex, oidFromHex
from .binary import hexFromInt, intFromHex, byteStringFromHex, bitsFromHex


class DerFieldType:
    """String identifiers for the kinds of DER field this module handles.

    Every attribute's value is identical to its own name. The module-level
    ``_hexTagToType`` table maps hex tag bytes onto these identifiers.
    """

    # Universal, primitive
    integer = "integer"
    bitString = "bitString"
    octetString = "octetString"
    null = "null"
    object = "object"
    printableString = "printableString"
    utcTime = "utcTime"

    # Universal, constructed
    sequence = "sequence"
    set = "set"

    # Context-specific, constructed (hex tags "a0" and "a1")
    oidContainer = "oidContainer"
    publicKeyPointContainer = "publicKeyPointContainer"


# Hex tag byte of a DER element -> the DerFieldType it denotes.
# "a0"/"a1" have the class bits 10 (context-specific) and bit 6 set
# (constructed), i.e. the [0] and [1] explicit containers.
_hexTagToType = dict([
    ("02", DerFieldType.integer),
    ("03", DerFieldType.bitString),
    ("04", DerFieldType.octetString),
    ("05", DerFieldType.null),
    ("06", DerFieldType.object),
    ("13", DerFieldType.printableString),
    ("17", DerFieldType.utcTime),
    ("30", DerFieldType.sequence),
    ("31", DerFieldType.set),
    ("a0", DerFieldType.oidContainer),
    ("a1", DerFieldType.publicKeyPointContainer),
])
# Reverse lookup used by the encoder: DerFieldType -> hex tag byte.
_typeToHexTag = dict((fieldType, hexTag) for hexTag, fieldType in _hexTagToType.items())


def encodeConstructed(*encodedValues):
    """Wrap already-encoded DER elements in a SEQUENCE.

    :param encodedValues: hex strings, each one a complete encoded element
    :return: hex string of a SEQUENCE whose content is their concatenation
    """
    body = "".join(encodedValues)
    return encodePrimitive(DerFieldType.sequence, body)


def encodePrimitive(tagType, value):
    """Encode one DER element as hex: tag byte, length header, then content.

    :param tagType: a DerFieldType identifier (must be a key of _typeToHexTag)
    :param value: an int for INTEGER, an OID accepted by oidToHex for OBJECT,
                  otherwise the already hex-encoded content
    :return: hex string of the full element
    :raises KeyError: if tagType has no known hex tag
    """
    if tagType == DerFieldType.integer:
        content = _encodeInteger(value)
    elif tagType == DerFieldType.object:
        content = oidToHex(value)
    else:
        content = value
    return "{0}{1}{2}".format(_typeToHexTag[tagType], _generateLengthBytes(content), content)


def parse(hexadecimal):
    """Decode a hex string of concatenated DER elements into Python values.

    Elements are decoded left to right. The content of a constructed element
    (tag bit 6 set) is itself parsed into a list first. Content is then passed
    to a per-type converter:
    null -> _parseNull, object -> _parseOid, utcTime -> _parseTime,
    integer -> _parseInteger, printableString -> _parseString.
    Every other type, including unknown tags, is returned as-is via _parseAny.

    Sibling elements are walked with a moving offset instead of recursing on a
    freshly sliced tail. A long run of siblings therefore cannot exhaust the
    recursion limit, and the tail is not copied once per element.

    :param hexadecimal: hex-encoded DER data (empty or None gives [])
    :return: list with one decoded value per top-level element
    :raises Exception: if an element declares more content than is present,
                       if its length header decodes to a negative value, or
                       (from _readLengthBytes) on indefinite-length encoding
    """
    if not hexadecimal:
        return []

    valueParser = {
        DerFieldType.null: _parseNull,
        DerFieldType.object: _parseOid,
        DerFieldType.utcTime: _parseTime,
        DerFieldType.integer: _parseInteger,
        DerFieldType.printableString: _parseString,
    }
    # Longest length header _readLengthBytes can consume: one indicator byte
    # plus at most 127 length octets (lengthLength = indicator - 128 <= 127),
    # i.e. 2 + 2 * 127 hex characters. Handing it only this window avoids
    # copying the whole remainder for every element.
    maxLengthHeader = 2 + 2 * 127

    values = []
    offset = 0
    total = len(hexadecimal)
    while offset < total:
        typeByte = hexadecimal[offset:offset + 2]
        lengthStart = offset + 2
        length, lengthBytes = _readLengthBytes(hexadecimal[lengthStart:lengthStart + maxLengthHeader])
        if length < 0:
            # only reachable on malformed headers (e.g. a "-" sign accepted by int(..., 16));
            # a negative length would mis-slice the input or stop the offset from advancing
            raise Exception("invalid length in DER parse")
        contentStart = lengthStart + lengthBytes
        contentEnd = contentStart + length
        content = hexadecimal[contentStart:contentEnd]
        if len(content) < length:
            raise Exception("missing bytes in DER parse")

        tagData = _getTagData(typeByte)
        if tagData["isConstructed"]:
            content = parse(content)

        values.append(valueParser.get(tagData["type"], _parseAny)(content))
        offset = contentEnd
    return values


def _parseAny(hexadecimal):
    return hexadecimal


def _parseOid(hexadecimal):
    """Decode OBJECT IDENTIFIER content via oidFromHex and return it as a tuple."""
    components = oidFromHex(hexadecimal)
    return tuple(components)


def _parseTime(hexadecimal):
    """Decode UTCTime content (text ``YYMMDDHHMMSSZ``) into a naive datetime.

    ``strptime``'s ``%y`` uses the POSIX pivot, mapping 69-99 to 1969-1999 and
    00-68 to 2000-2068. RFC 5280 section 4.1.2.5.1 instead requires YY >= 50
    to mean 19YY and YY < 50 to mean 20YY. Years that strptime places at 2050
    or later (YY 50-68) are therefore moved back one century.

    :param hexadecimal: hex-encoded UTCTime content
    :return: naive datetime (the trailing "Z" is matched literally)
    :raises ValueError: if the text does not match the format
    """
    string = _parseString(hexadecimal)
    parsed = datetime.strptime(string, "%y%m%d%H%M%SZ")
    if parsed.year >= 2050:
        # 19YY and 20YY have the same leap-year status for YY in 50-68,
        # so replace() cannot fail on Feb 29
        parsed = parsed.replace(year=parsed.year - 100)
    return parsed


def _parseString(hexadecimal):
    """Convert hex-encoded string content to text using bytes.decode() defaults."""
    raw = byteStringFromHex(hexadecimal)
    return raw.decode()


def _parseNull(_content):
    return None


def _parseInteger(hexadecimal):
    """Decode INTEGER content as a two's-complement signed number.

    If the first bit of the first hex digit is "0", the unsigned value is
    returned. Otherwise 2 ** (4 * number of hex digits) is subtracted from it.
    """
    unsigned = intFromHex(hexadecimal)
    signBit = bitsFromHex(hexadecimal[0])[0]
    if signBit != "0":
        return unsigned - 2 ** (4 * len(hexadecimal))
    return unsigned


def _encodeInteger(number):
    """Encode an int as the two's-complement hex content of a DER INTEGER.

    A non-negative number gets a "00" prefix when its first hex digit has the
    top bit set, so it is not read back as negative. A negative number is
    written in the fewest whole bytes whose two's-complement range includes it.

    :param number: integer to encode
    :return: hex string with an even number of digits
    """
    if number < 0:
        # k bytes cover -2**(8k-1)..-1, and number needs
        # (-number - 1).bit_length() + 1 bits. Sizing from abs(number) instead
        # is one byte short when abs(number) has its top bit set and is not a
        # power of two: -129 came out as "7f" (+127) rather than "ff7f".
        byteCount = ((-number - 1).bit_length() + 8) // 8
        # result lies in [2**(8k-1), 2**(8k)), so it is exactly 2k hex digits
        return hexFromInt((1 << (8 * byteCount)) + number)
    hexadecimal = hexFromInt(number)
    bits = bitsFromHex(hexadecimal[0])
    if bits[0] == "1":  # if first bit was left as 1, number would be parsed as a negative integer with two's complement
        hexadecimal = "00" + hexadecimal
    return hexadecimal


def _readLengthBytes(hexadecimal):
    """Read the DER length header at the start of a hex string.

    :param hexadecimal: hex text beginning with the length header
    :return: (content length in hex characters, header size in hex characters)
    :raises Exception: on the indefinite-length marker (long form with zero octets)
    """
    firstByte = intFromHex(hexadecimal[:2])
    if firstByte < 128:
        # top bit clear: short form, the byte itself is the content length
        return firstByte * 2, 2
    # top bit set: long form, the low 7 bits count the length octets that follow
    extraBytes = firstByte - 128
    if not extraBytes:
        raise Exception("indefinite length encoding located in DER")
    headerSize = 2 + 2 * extraBytes
    return intFromHex(hexadecimal[2:headerSize]) * 2, headerSize


def _generateLengthBytes(hexadecimal):
    """Build the DER length header (as hex) for the given hex content."""
    byteCount = len(hexadecimal) // 2
    encodedCount = hexFromInt(byteCount)
    if byteCount >= 128:
        # long form: a byte of 128 + (number of length octets), then the octets
        return hexFromInt(128 + len(encodedCount) // 2) + encodedCount
    # short form: a single byte holding the length
    return encodedCount.zfill(2)


def _getTagData(tag):
    """Describe a hex tag byte: its class, its constructed flag and its known type.

    :param tag: hex tag byte, e.g. "30"
    :return: dict with "class" (from the two top bits), "isConstructed"
             (third bit set) and "type" (the _hexTagToType entry, or None)
    """
    bit8, bit7, bit6 = bitsFromHex(tag)[:3]
    classByBits = {
        "00": "universal",
        "01": "application",
        "10": "context-specific",
        "11": "private",
    }
    tagClass = classByBits[bit8 + bit7]
    return {
        "class": tagClass,
        "isConstructed": bit6 == "1",
        "type": _hexTagToType.get(tag),
    }