path: root/.venv/lib/python3.12/site-packages/tiktoken/load.py
author    S. Solomon Darnell 2025-03-28 21:52:21 -0500
committer S. Solomon Darnell 2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/tiktoken/load.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/tiktoken/load.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/tiktoken/load.py  148
1 file changed, 148 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/tiktoken/load.py b/.venv/lib/python3.12/site-packages/tiktoken/load.py
new file mode 100644
index 00000000..8434c234
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/tiktoken/load.py
@@ -0,0 +1,148 @@
+from __future__ import annotations
+
+import base64
+import hashlib
+import json
+import os
+import tempfile
+import uuid
+
+import requests
+
+
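+# Read a file as raw bytes: non-HTTP paths (local files or cloud-storage URIs)
+# go through blobfile, while plain HTTP(S) URLs are fetched with requests.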
+def read_file(blobpath: str) -> bytes:
+    if not blobpath.startswith("http://") and not blobpath.startswith("https://"):
+        try:
+            import blobfile
+        except ImportError as e:
+            raise ImportError(
+                "blobfile is not installed. Please install it by running `pip install blobfile`."
+            ) from e
+        with blobfile.BlobFile(blobpath, "rb") as f:
+            return f.read()
+    # avoiding blobfile for public files helps avoid auth issues, like MFA prompts
+    resp = requests.get(blobpath)
+    resp.raise_for_status()
+    return resp.content
+
+
+def check_hash(data: bytes, expected_hash: str) -> bool:
+    actual_hash = hashlib.sha256(data).hexdigest()
+    return actual_hash == expected_hash
+
+
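+# Like read_file, but caches the bytes on disk. The cache directory comes from
+# TIKTOKEN_CACHE_DIR or DATA_GYM_CACHE_DIR if set, otherwise a directory under
+# the system temp dir; setting the variable to "" disables caching entirely.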
+def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes:
+    user_specified_cache = True
+    if "TIKTOKEN_CACHE_DIR" in os.environ:
+        cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
+    elif "DATA_GYM_CACHE_DIR" in os.environ:
+        cache_dir = os.environ["DATA_GYM_CACHE_DIR"]
+    else:
+        cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache")
+        user_specified_cache = False
+
+    if cache_dir == "":
+        # disable caching
+        return read_file(blobpath)
+
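+    # key cache entries by the source path; content integrity is checked
+    # separately against expected_hash below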
+    cache_key = hashlib.sha1(blobpath.encode()).hexdigest()
+
+    cache_path = os.path.join(cache_dir, cache_key)
+    if os.path.exists(cache_path):
+        with open(cache_path, "rb") as f:
+            data = f.read()
+        if expected_hash is None or check_hash(data, expected_hash):
+            return data
+
+        # the cached file does not match the hash, remove it and re-fetch
+        try:
+            os.remove(cache_path)
+        except OSError:
+            pass
+
+    contents = read_file(blobpath)
+    if expected_hash and not check_hash(contents, expected_hash):
+        raise ValueError(
+            f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
+            f"This may indicate a corrupted download. Please try again."
+        )
+
+    try:
+        os.makedirs(cache_dir, exist_ok=True)
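+        # write to a unique temp file and rename it into place, so a concurrent
+        # reader never observes a partially written cache entry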
+        tmp_filename = cache_path + "." + str(uuid.uuid4()) + ".tmp"
+        with open(tmp_filename, "wb") as f:
+            f.write(contents)
+        os.rename(tmp_filename, cache_path)
+    except OSError:
+        # don't raise if we can't write to the default cache, e.g. issue #75
+        if user_specified_cache:
+            raise
+
+    return contents
+
+
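+# Convert the GPT-2 "data gym" format (a vocab.bpe merges file plus an
+# encoder.json vocabulary) into the bytes -> rank mapping used by tiktoken.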
+def data_gym_to_mergeable_bpe_ranks(
+    vocab_bpe_file: str,
+    encoder_json_file: str,
+    vocab_bpe_hash: str | None = None,
+    encoder_json_hash: str | None = None,
+) -> dict[bytes, int]:
+    # NB: do not add caching to this function
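+    # reconstruct GPT-2's byte-to-unicode table: printable, non-space bytes map
+    # to themselves, and the remaining bytes are assigned code points 256 and up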
+    rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
+
+    data_gym_byte_to_byte = {chr(b): b for b in rank_to_intbyte}
+    n = 0
+    for b in range(2**8):
+        if b not in rank_to_intbyte:
+            rank_to_intbyte.append(b)
+            data_gym_byte_to_byte[chr(2**8 + n)] = b
+            n += 1
+    assert len(rank_to_intbyte) == 2**8
+
+    # vocab_bpe contains the merges along with associated ranks
+    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
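+    # drop the "#version" header line and the empty string left by the
+    # trailing newline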
+    bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]
+
+    def decode_data_gym(value: str) -> bytes:
+        return bytes(data_gym_byte_to_byte[b] for b in value)
+
+    # add the single byte tokens
+    bpe_ranks = {bytes([b]): i for i, b in enumerate(rank_to_intbyte)}
+    # add the merged tokens
+    n = len(bpe_ranks)
+    for first, second in bpe_merges:
+        bpe_ranks[decode_data_gym(first) + decode_data_gym(second)] = n
+        n += 1
+
+    # check that the encoder file matches the merges file
+    # this sanity check is important since tiktoken assumes that ranks are ordered the same
+    # as merge priority
+    encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
+    encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
+    # drop these two special tokens if present, since they're not mergeable bpe tokens
+    encoder_json_loaded.pop(b"<|endoftext|>", None)
+    encoder_json_loaded.pop(b"<|startoftext|>", None)
+    assert bpe_ranks == encoder_json_loaded
+
+    return bpe_ranks
+
+
+def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
+    try:
+        import blobfile
+    except ImportError as e:
+        raise ImportError(
+            "blobfile is not installed. Please install it by running `pip install blobfile`."
+        ) from e
+    with blobfile.BlobFile(tiktoken_bpe_file, "wb") as f:
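+        # one "<base64-encoded token> <rank>" pair per line, ordered by rank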
+        for token, rank in sorted(bpe_ranks.items(), key=lambda x: x[1]):
+            f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")
+
+
+def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) -> dict[bytes, int]:
+    # NB: do not add caching to this function
+    contents = read_file_cached(tiktoken_bpe_file, expected_hash)
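+    # parse the same "<base64-encoded token> <rank>" line format written by
+    # dump_tiktoken_bpe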
+    return {
+        base64.b64decode(token): int(rank)
+        for token, rank in (line.split() for line in contents.splitlines() if line)
+    }