from __future__ import annotations

import base64
import hashlib
import json
import os
import tempfile
import uuid

import requests


def read_file(blobpath: str) -> bytes:
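    """Read `blobpath` and return its contents as bytes.

    Local and blob-storage paths go through `blobfile`; plain HTTP(S) URLs are
    fetched directly with `requests` (see the comment below).
    """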
    if not blobpath.startswith("http://") and not blobpath.startswith("https://"):
        try:
            import blobfile
        except ImportError as e:
            raise ImportError(
                "blobfile is not installed. Please install it by running `pip install blobfile`."
            ) from e
        with blobfile.BlobFile(blobpath, "rb") as f:
            return f.read()
    # avoiding blobfile for public files helps avoid auth issues, like MFA prompts
    resp = requests.get(blobpath)
    resp.raise_for_status()
    return resp.content


def check_hash(data: bytes, expected_hash: str) -> bool:
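    """Return True if the SHA-256 hex digest of `data` matches `expected_hash`."""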
    actual_hash = hashlib.sha256(data).hexdigest()
    return actual_hash == expected_hash


def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes:
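    """Read `blobpath` via `read_file`, caching the bytes on local disk.

    The cache directory comes from TIKTOKEN_CACHE_DIR, then DATA_GYM_CACHE_DIR,
    else a shared `data-gym-cache` directory under the system temp dir; setting
    the variable to an empty string disables caching. If `expected_hash` is
    given, both cached and freshly fetched data are verified against it.
    """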
    user_specified_cache = True
    if "TIKTOKEN_CACHE_DIR" in os.environ:
        cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
    elif "DATA_GYM_CACHE_DIR" in os.environ:
        cache_dir = os.environ["DATA_GYM_CACHE_DIR"]
    else:
        cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache")
        user_specified_cache = False

    if cache_dir == "":
        # disable caching
        return read_file(blobpath)

    cache_key = hashlib.sha1(blobpath.encode()).hexdigest()

    cache_path = os.path.join(cache_dir, cache_key)
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            data = f.read()
        if expected_hash is None or check_hash(data, expected_hash):
            return data

        # the cached file does not match the hash, remove it and re-fetch
        try:
            os.remove(cache_path)
        except OSError:
            pass

    contents = read_file(blobpath)
    if expected_hash and not check_hash(contents, expected_hash):
        raise ValueError(
            f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
            f"This may indicate a corrupted download. Please try again."
        )
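
    # write through a uniquely named temp file and rename it into place, so a
    # concurrent reader never observes a partially written cache entry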
    try:
        os.makedirs(cache_dir, exist_ok=True)
        tmp_filename = cache_path + "." + str(uuid.uuid4()) + ".tmp"
        with open(tmp_filename, "wb") as f:
            f.write(contents)
        os.rename(tmp_filename, cache_path)
    except OSError:
        # don't raise if we can't write to the default cache, e.g. issue #75
        if user_specified_cache:
            raise

    return contents


def data_gym_to_mergeable_bpe_ranks(
    vocab_bpe_file: str,
    encoder_json_file: str,
    vocab_bpe_hash: str | None = None,
    encoder_json_hash: str | None = None,
) -> dict[bytes, int]:
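    """Convert GPT-2-style "data gym" vocab/encoder files into a mapping from
    token bytes to BPE merge rank, the format tiktoken consumes.
    """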
    # NB: do not add caching to this function
    rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
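    # the data gym format stores each byte of a token as a printable unicode
    # character: printable non-space bytes map to themselves, and the remaining
    # bytes are remapped to codepoints 256, 257, ... in order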
    data_gym_byte_to_byte = {chr(b): b for b in rank_to_intbyte}
    n = 0
    for b in range(2**8):
        if b not in rank_to_intbyte:
            rank_to_intbyte.append(b)
            data_gym_byte_to_byte[chr(2**8 + n)] = b
            n += 1
    assert len(rank_to_intbyte) == 2**8

    # vocab_bpe contains the merges along with associated ranks
    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
    bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]
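    # (the [1:-1] slice drops the first line, a version header in GPT-2's
    # vocab.bpe, and the trailing empty string left by the final newline)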

    def decode_data_gym(value: str) -> bytes:
        return bytes(data_gym_byte_to_byte[b] for b in value)

    # add the single byte tokens
    bpe_ranks = {bytes([b]): i for i, b in enumerate(rank_to_intbyte)}
    # add the merged tokens
    n = len(bpe_ranks)
    for first, second in bpe_merges:
        bpe_ranks[decode_data_gym(first) + decode_data_gym(second)] = n
        n += 1

    # check that the encoder file matches the merges file
    # this sanity check is important since tiktoken assumes that ranks are ordered the same
    # as merge priority
    encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
    encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
    # drop these two special tokens if present, since they're not mergeable bpe tokens
    encoder_json_loaded.pop(b"<|endoftext|>", None)
    encoder_json_loaded.pop(b"<|startoftext|>", None)
    assert bpe_ranks == encoder_json_loaded

    return bpe_ranks


def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
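    """Write `bpe_ranks` as a .tiktoken file: one line per token, containing the
    base64-encoded token bytes and its rank, ordered by rank.
    """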
    try:
        import blobfile
    except ImportError as e:
        raise ImportError(
            "blobfile is not installed. Please install it by running `pip install blobfile`."
        ) from e
    with blobfile.BlobFile(tiktoken_bpe_file, "wb") as f:
        for token, rank in sorted(bpe_ranks.items(), key=lambda x: x[1]):
            f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")


def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) -> dict[bytes, int]:
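    """Load a .tiktoken file (the format written by `dump_tiktoken_bpe`) into a
    token-bytes -> rank mapping, optionally verifying `expected_hash`.
    """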
    # NB: do not add caching to this function
    contents = read_file_cached(tiktoken_bpe_file, expected_hash)
    return {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in contents.splitlines() if line)
    }
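

# A minimal usage sketch, assuming the public GPT-2 encoding files (hashes
# omitted for brevity; pass vocab_bpe_hash/encoder_json_hash to verify):
#
#   ranks = data_gym_to_mergeable_bpe_ranks(
#       vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
#       encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
#   )
#   dump_tiktoken_bpe(ranks, "gpt2.tiktoken")
#   assert load_tiktoken_bpe("gpt2.tiktoken") == ranks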