author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/tiktoken/_educational.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/tiktoken/_educational.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/tiktoken/_educational.py  223
1 file changed, 223 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/tiktoken/_educational.py b/.venv/lib/python3.12/site-packages/tiktoken/_educational.py
new file mode 100644
index 00000000..317e7756
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/tiktoken/_educational.py
@@ -0,0 +1,223 @@
+"""This is an educational implementation of the byte pair encoding algorithm."""
+
+from __future__ import annotations
+
+import collections
+
+import regex
+
+import tiktoken
+
+
+class SimpleBytePairEncoding:
+    def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None:
+        """Creates an Encoding object."""
+        # A regex pattern string that is used to split the input text
+        self.pat_str = pat_str
+        # A dictionary mapping token bytes to their ranks. The ranks correspond to merge priority
+        self.mergeable_ranks = mergeable_ranks
+
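+        # Invert the ranks so decoding can map a token id back to its bytes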
+        self._decoder = {token: token_bytes for token_bytes, token in mergeable_ranks.items()}
+        self._pat = regex.compile(pat_str)
+
+    def encode(self, text: str, visualise: str | None = "colour") -> list[int]:
+        """Encodes a string into tokens.
+
+        >>> enc.encode("hello world")
+        [388, 372]
+        """
+        # Use the regex to split the text into (approximately) words
+        words = self._pat.findall(text)
+        tokens = []
+        for word in words:
+            # Turn each word into tokens, using the byte pair encoding algorithm
+            word_bytes = word.encode("utf-8")
+            word_tokens = bpe_encode(self.mergeable_ranks, word_bytes, visualise=visualise)
+            tokens.extend(word_tokens)
+        return tokens
+
+    def decode_bytes(self, tokens: list[int]) -> bytes:
+        """Decodes a list of tokens into bytes.
+
+        >>> enc.decode_bytes([388, 372])
+        b'hello world'
+        """
+        return b"".join(self._decoder[token] for token in tokens)
+
+    def decode(self, tokens: list[int]) -> str:
+        """Decodes a list of tokens into a string.
+
+        Decoded bytes are not guaranteed to be valid UTF-8. In that case, we replace
+        the invalid bytes with the replacement character "�".
+
+        >>> enc.decode([388, 372])
+        'hello world'
+        """
+        return self.decode_bytes(tokens).decode("utf-8", errors="replace")
+
+    def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
+        """Decodes a list of tokens into a list of bytes.
+
+        Useful for visualising how a string is tokenised.
+
+        >>> enc.decode_tokens_bytes([388, 372])
+        [b'hello', b' world']
+        """
+        return [self._decoder[token] for token in tokens]
+
+    @staticmethod
+    def train(training_data: str, vocab_size: int, pat_str: str) -> SimpleBytePairEncoding:
+        """Train a BPE tokeniser on some data!"""
+        mergeable_ranks = bpe_train(data=training_data, vocab_size=vocab_size, pat_str=pat_str)
+        return SimpleBytePairEncoding(pat_str=pat_str, mergeable_ranks=mergeable_ranks)
+
+    @staticmethod
+    def from_tiktoken(encoding: str | tiktoken.Encoding) -> SimpleBytePairEncoding:
+        """Creates a SimpleBytePairEncoding that mirrors an existing tiktoken Encoding
+        (or the name of one, e.g. "gpt2")."""
+        if isinstance(encoding, str):
+            encoding = tiktoken.get_encoding(encoding)
+        return SimpleBytePairEncoding(
+            pat_str=encoding._pat_str, mergeable_ranks=encoding._mergeable_ranks
+        )
+
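+# A quick sanity check against a real tiktoken encoding (a sketch; "gpt2" is one
+# of the encoding names tiktoken.get_encoding accepts):
+#
+#   enc = SimpleBytePairEncoding.from_tiktoken("gpt2")
+#   assert enc.decode(enc.encode("hello world")) == "hello world"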
+
+def bpe_encode(
+    mergeable_ranks: dict[bytes, int], input: bytes, visualise: str | None = "colour"
+) -> list[int]:
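+    # Start with the raw bytes: one initial token per byte of the input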
+    parts = [bytes([b]) for b in input]
+    while True:
+        # See the intermediate merges play out!
+        if visualise:
+            if visualise in ["colour", "color"]:
+                visualise_tokens(parts)
+            elif visualise == "simple":
+                print(parts)
+
+        # Iterate over all pairs and find the pair we want to merge the most
+        min_idx = None
+        min_rank = None
+        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
+            rank = mergeable_ranks.get(pair[0] + pair[1])
+            if rank is not None and (min_rank is None or rank < min_rank):
+                min_idx = i
+                min_rank = rank
+
+        # If there were no pairs we could merge, we're done!
+        if min_rank is None:
+            break
+        assert min_idx is not None
+
+        # Otherwise, merge that pair and leave the rest unchanged. Then repeat.
+        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]
+
+    if visualise:
+        print()
+
+    tokens = [mergeable_ranks[part] for part in parts]
+    return tokens
+
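+# Worked example of the merge loop above (a sketch; the ranks shown are
+# illustrative, not taken from a real vocabulary): with
+# mergeable_ranks = {..., b"he": 256, b"ll": 260, b"hell": 300, ...},
+# b"hello" starts as [b'h', b'e', b'l', b'l', b'o']; the lowest-ranked adjacent
+# pair (b'h', b'e') merges first, then (b'l', b'l'), then (b'he', b'll'), until
+# no remaining adjacent pair appears in mergeable_ranks.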
+
+def bpe_train(
+    data: str, vocab_size: int, pat_str: str, visualise: str | None = "colour"
+) -> dict[bytes, int]:
+    # First, add tokens for each individual byte value
+    if vocab_size < 2**8:
+        raise ValueError("vocab_size must be at least 256, so we can encode all bytes")
+    ranks = {}
+    for i in range(2**8):
+        ranks[bytes([i])] = i
+
+    # Splinter up our data into lists of bytes
+    # data = "Hello world"
+    # words = [
+    #     [b'H', b'e', b'l', b'l', b'o'],
+    #     [b' ', b'w', b'o', b'r', b'l', b'd']
+    # ]
+    words: list[list[bytes]] = [
+        [bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(pat_str, data)
+    ]
+
+    # Now, use our data to figure out which merges we should make
+    while len(ranks) < vocab_size:
+        # Find the most common pair. This will become our next token
+        stats = collections.Counter()
+        for piece in words:
+            for pair in zip(piece[:-1], piece[1:]):
+                stats[pair] += 1
+
+        most_common_pair = max(stats, key=lambda x: stats[x])
+        token_bytes = most_common_pair[0] + most_common_pair[1]
+        token = len(ranks)
+        # Add the new token!
+        ranks[token_bytes] = token
+
+        # Now merge that most common pair in all the words. That is, update our training data
+        # to reflect our decision to make that pair into a new token.
+        new_words = []
+        for word in words:
+            new_word = []
+            i = 0
+            while i < len(word) - 1:
+                if (word[i], word[i + 1]) == most_common_pair:
+                    # We found our pair! Merge it
+                    new_word.append(token_bytes)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            if i == len(word) - 1:
+                new_word.append(word[i])
+            new_words.append(new_word)
+        words = new_words
+
+        # See the intermediate merges play out!
+        if visualise:
+            print(f"The current most common pair is {most_common_pair[0]} + {most_common_pair[1]}")
+            print(f"So we made {token_bytes} our {len(ranks)}th token")
+            if visualise in ["colour", "color"]:
+                print("Now the first fifty words in our training data look like:")
+                visualise_tokens([token for word in words[:50] for token in word])
+            elif visualise == "simple":
+                print("Now the first twenty words in our training data look like:")
+                for word in words[:20]:
+                    print(word)
+            print("\n")
+
+    return ranks
+
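+# Training example (a sketch; any text works, and the pattern below is the same
+# GPT-2 split pattern used in train_simple_encoding):
+#
+#   ranks = bpe_train(data=open(__file__).read(), vocab_size=300,
+#                     pat_str=r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+")
+#   len(ranks)  # -> 300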
+
+def visualise_tokens(token_values: list[bytes]) -> None:
+    background = [f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]]
+    # If token boundaries do not occur at unicode character boundaries, it's unclear how best to
+    # visualise the token. Here, we'll just use the unicode replacement character to represent some
+    # fraction of a character.
+    unicode_token_values = [x.decode("utf-8", errors="replace") for x in token_values]
+
+    running_length = 0
+    last_color = None
+    for token in unicode_token_values:
+        color = background[running_length % len(background)]
+        if color == last_color:
+            color = background[(running_length + 1) % len(background)]
+            assert color != last_color
+        last_color = color
+        running_length += len(token)
+        print(color + token, end="")
+    print("\u001b[0m")
+
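+# e.g. visualise_tokens([b"hello", b" world"]) prints "hello world" with each
+# token on a different background colour (assumes a 256-colour terminal).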
+
+def train_simple_encoding():
+    """Trains a small BPE tokeniser on this source file and shows how 'hello world' is encoded."""
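+    # The GPT-2 split pattern: common contractions, optional-space letter runs,
+    # optional-space digit runs, optional-space runs of other symbols, and whitespace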
+    gpt2_pattern = (
+        r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+    )
+    with open(__file__) as f:
+        data = f.read()
+
+    enc = SimpleBytePairEncoding.train(data, vocab_size=600, pat_str=gpt2_pattern)
+
+    print("This is the sequence of merges performed in order to encode 'hello world':")
+    tokens = enc.encode("hello world")
+    assert enc.decode(tokens) == "hello world"
+    assert enc.decode_bytes(tokens) == b"hello world"
+    assert enc.decode_tokens_bytes(tokens) == [b"hello", b" world"]
+
+    return enc
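+
+
+# A minimal entry point so the walkthrough can be run directly; an addition for
+# illustration here, not part of the upstream module:
+if __name__ == "__main__":
+    train_simple_encoding()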