diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/tiktoken/_educational.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/tiktoken/_educational.py | 223 |
1 files changed, 223 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/tiktoken/_educational.py b/.venv/lib/python3.12/site-packages/tiktoken/_educational.py new file mode 100644 index 00000000..317e7756 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/tiktoken/_educational.py @@ -0,0 +1,223 @@ +"""This is an educational implementation of the byte pair encoding algorithm.""" + +from __future__ import annotations + +import collections + +import regex + +import tiktoken + + +class SimpleBytePairEncoding: + def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None: + """Creates an Encoding object.""" + # A regex pattern string that is used to split the input text + self.pat_str = pat_str + # A dictionary mapping token bytes to their ranks. The ranks correspond to merge priority + self.mergeable_ranks = mergeable_ranks + + self._decoder = {token: token_bytes for token_bytes, token in mergeable_ranks.items()} + self._pat = regex.compile(pat_str) + + def encode(self, text: str, visualise: str | None = "colour") -> list[int]: + """Encodes a string into tokens. + + >>> enc.encode("hello world") + [388, 372] + """ + # Use the regex to split the text into (approximately) words + words = self._pat.findall(text) + tokens = [] + for word in words: + # Turn each word into tokens, using the byte pair encoding algorithm + word_bytes = word.encode("utf-8") + word_tokens = bpe_encode(self.mergeable_ranks, word_bytes, visualise=visualise) + tokens.extend(word_tokens) + return tokens + + def decode_bytes(self, tokens: list[int]) -> bytes: + """Decodes a list of tokens into bytes. + + >>> enc.decode_bytes([388, 372]) + b'hello world' + """ + return b"".join(self._decoder[token] for token in tokens) + + def decode(self, tokens: list[int]) -> str: + """Decodes a list of tokens into a string. + + Decoded bytes are not guaranteed to be valid UTF-8. In that case, we replace + the invalid bytes with the replacement character "�". + + >>> enc.decode([388, 372]) + 'hello world' + """ + return self.decode_bytes(tokens).decode("utf-8", errors="replace") + + def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]: + """Decodes a list of tokens into a list of bytes. + + Useful for visualising how a string is tokenised. + + >>> enc.decode_tokens_bytes([388, 372]) + [b'hello', b' world'] + """ + return [self._decoder[token] for token in tokens] + + @staticmethod + def train(training_data: str, vocab_size: int, pat_str: str): + """Train a BPE tokeniser on some data!""" + mergeable_ranks = bpe_train(data=training_data, vocab_size=vocab_size, pat_str=pat_str) + return SimpleBytePairEncoding(pat_str=pat_str, mergeable_ranks=mergeable_ranks) + + @staticmethod + def from_tiktoken(encoding): + if isinstance(encoding, str): + encoding = tiktoken.get_encoding(encoding) + return SimpleBytePairEncoding( + pat_str=encoding._pat_str, mergeable_ranks=encoding._mergeable_ranks + ) + + +def bpe_encode( + mergeable_ranks: dict[bytes, int], input: bytes, visualise: str | None = "colour" +) -> list[int]: + parts = [bytes([b]) for b in input] + while True: + # See the intermediate merges play out! + if visualise: + if visualise in ["colour", "color"]: + visualise_tokens(parts) + elif visualise == "simple": + print(parts) + + # Iterate over all pairs and find the pair we want to merge the most + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = mergeable_ranks.get(pair[0] + pair[1]) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + + # If there were no pairs we could merge, we're done! + if min_rank is None: + break + assert min_idx is not None + + # Otherwise, merge that pair and leave the rest unchanged. Then repeat. + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :] + + if visualise: + print() + + tokens = [mergeable_ranks[part] for part in parts] + return tokens + + +def bpe_train( + data: str, vocab_size: int, pat_str: str, visualise: str | None = "colour" +) -> dict[bytes, int]: + # First, add tokens for each individual byte value + if vocab_size < 2**8: + raise ValueError("vocab_size must be at least 256, so we can encode all bytes") + ranks = {} + for i in range(2**8): + ranks[bytes([i])] = i + + # Splinter up our data into lists of bytes + # data = "Hello world" + # words = [ + # [b'H', b'e', b'l', b'l', b'o'], + # [b' ', b'w', b'o', b'r', b'l', b'd'] + # ] + words: list[list[bytes]] = [ + [bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(pat_str, data) + ] + + # Now, use our data to figure out which merges we should make + while len(ranks) < vocab_size: + # Find the most common pair. This will become our next token + stats = collections.Counter() + for piece in words: + for pair in zip(piece[:-1], piece[1:]): + stats[pair] += 1 + + most_common_pair = max(stats, key=lambda x: stats[x]) + token_bytes = most_common_pair[0] + most_common_pair[1] + token = len(ranks) + # Add the new token! + ranks[token_bytes] = token + + # Now merge that most common pair in all the words. That is, update our training data + # to reflect our decision to make that pair into a new token. + new_words = [] + for word in words: + new_word = [] + i = 0 + while i < len(word) - 1: + if (word[i], word[i + 1]) == most_common_pair: + # We found our pair! Merge it + new_word.append(token_bytes) + i += 2 + else: + new_word.append(word[i]) + i += 1 + if i == len(word) - 1: + new_word.append(word[i]) + new_words.append(new_word) + words = new_words + + # See the intermediate merges play out! + if visualise: + print(f"The current most common pair is {most_common_pair[0]} + {most_common_pair[1]}") + print(f"So we made {token_bytes} our {len(ranks)}th token") + if visualise in ["colour", "color"]: + print("Now the first fifty words in our training data look like:") + visualise_tokens([token for word in words[:50] for token in word]) + elif visualise == "simple": + print("Now the first twenty words in our training data look like:") + for word in words[:20]: + print(word) + print("\n") + + return ranks + + +def visualise_tokens(token_values: list[bytes]) -> None: + background = [f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]] + # If token boundaries do not occur at unicode character boundaries, it's unclear how best to + # visualise the token. Here, we'll just use the unicode replacement character to represent some + # fraction of a character. + unicode_token_values = [x.decode("utf-8", errors="replace") for x in token_values] + + running_length = 0 + last_color = None + for token in unicode_token_values: + color = background[running_length % len(background)] + if color == last_color: + color = background[(running_length + 1) % len(background)] + assert color != last_color + last_color = color + running_length += len(token) + print(color + token, end="") + print("\u001b[0m") + + +def train_simple_encoding(): + gpt2_pattern = ( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) + with open(__file__) as f: + data = f.read() + + enc = SimpleBytePairEncoding.train(data, vocab_size=600, pat_str=gpt2_pattern) + + print("This is the sequence of merges performed in order to encode 'hello world':") + tokens = enc.encode("hello world") + assert enc.decode(tokens) == "hello world" + assert enc.decode_bytes(tokens) == b"hello world" + assert enc.decode_tokens_bytes(tokens) == [b"hello", b" world"] + + return enc |