about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py')
-rw-r--r--.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py151
1 files changed, 151 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py b/.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py
new file mode 100644
index 00000000..1f34e3ca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py
@@ -0,0 +1,151 @@
+from typing import Dict, Iterator, List, Optional, Union
+
+from tokenizers import AddedToken, Tokenizer, decoders, trainers
+from tokenizers.models import WordPiece
+from tokenizers.normalizers import BertNormalizer
+from tokenizers.pre_tokenizers import BertPreTokenizer
+from tokenizers.processors import BertProcessing
+
+from .base_tokenizer import BaseTokenizer
+
+
+class BertWordPieceTokenizer(BaseTokenizer):
+    """Bert WordPiece Tokenizer"""
+
+    def __init__(
+        self,
+        vocab: Optional[Union[str, Dict[str, int]]] = None,
+        unk_token: Union[str, AddedToken] = "[UNK]",
+        sep_token: Union[str, AddedToken] = "[SEP]",
+        cls_token: Union[str, AddedToken] = "[CLS]",
+        pad_token: Union[str, AddedToken] = "[PAD]",
+        mask_token: Union[str, AddedToken] = "[MASK]",
+        clean_text: bool = True,
+        handle_chinese_chars: bool = True,
+        strip_accents: Optional[bool] = None,
+        lowercase: bool = True,
+        wordpieces_prefix: str = "##",
+    ):
+        if vocab is not None:
+            tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(unk_token)))
+        else:
+            tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token)))
+
+        # Let the tokenizer know about special tokens if they are part of the vocab
+        if tokenizer.token_to_id(str(unk_token)) is not None:
+            tokenizer.add_special_tokens([str(unk_token)])
+        if tokenizer.token_to_id(str(sep_token)) is not None:
+            tokenizer.add_special_tokens([str(sep_token)])
+        if tokenizer.token_to_id(str(cls_token)) is not None:
+            tokenizer.add_special_tokens([str(cls_token)])
+        if tokenizer.token_to_id(str(pad_token)) is not None:
+            tokenizer.add_special_tokens([str(pad_token)])
+        if tokenizer.token_to_id(str(mask_token)) is not None:
+            tokenizer.add_special_tokens([str(mask_token)])
+
+        tokenizer.normalizer = BertNormalizer(
+            clean_text=clean_text,
+            handle_chinese_chars=handle_chinese_chars,
+            strip_accents=strip_accents,
+            lowercase=lowercase,
+        )
+        tokenizer.pre_tokenizer = BertPreTokenizer()
+
+        if vocab is not None:
+            sep_token_id = tokenizer.token_to_id(str(sep_token))
+            if sep_token_id is None:
+                raise TypeError("sep_token not found in the vocabulary")
+            cls_token_id = tokenizer.token_to_id(str(cls_token))
+            if cls_token_id is None:
+                raise TypeError("cls_token not found in the vocabulary")
+
+            tokenizer.post_processor = BertProcessing((str(sep_token), sep_token_id), (str(cls_token), cls_token_id))
+        tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)
+
+        parameters = {
+            "model": "BertWordPiece",
+            "unk_token": unk_token,
+            "sep_token": sep_token,
+            "cls_token": cls_token,
+            "pad_token": pad_token,
+            "mask_token": mask_token,
+            "clean_text": clean_text,
+            "handle_chinese_chars": handle_chinese_chars,
+            "strip_accents": strip_accents,
+            "lowercase": lowercase,
+            "wordpieces_prefix": wordpieces_prefix,
+        }
+
+        super().__init__(tokenizer, parameters)
+
+    @staticmethod
+    def from_file(vocab: str, **kwargs):
+        vocab = WordPiece.read_file(vocab)
+        return BertWordPieceTokenizer(vocab, **kwargs)
+
+    def train(
+        self,
+        files: Union[str, List[str]],
+        vocab_size: int = 30000,
+        min_frequency: int = 2,
+        limit_alphabet: int = 1000,
+        initial_alphabet: List[str] = [],
+        special_tokens: List[Union[str, AddedToken]] = [
+            "[PAD]",
+            "[UNK]",
+            "[CLS]",
+            "[SEP]",
+            "[MASK]",
+        ],
+        show_progress: bool = True,
+        wordpieces_prefix: str = "##",
+    ):
+        """Train the model using the given files"""
+
+        trainer = trainers.WordPieceTrainer(
+            vocab_size=vocab_size,
+            min_frequency=min_frequency,
+            limit_alphabet=limit_alphabet,
+            initial_alphabet=initial_alphabet,
+            special_tokens=special_tokens,
+            show_progress=show_progress,
+            continuing_subword_prefix=wordpieces_prefix,
+        )
+        if isinstance(files, str):
+            files = [files]
+        self._tokenizer.train(files, trainer=trainer)
+
+    def train_from_iterator(
+        self,
+        iterator: Union[Iterator[str], Iterator[Iterator[str]]],
+        vocab_size: int = 30000,
+        min_frequency: int = 2,
+        limit_alphabet: int = 1000,
+        initial_alphabet: List[str] = [],
+        special_tokens: List[Union[str, AddedToken]] = [
+            "[PAD]",
+            "[UNK]",
+            "[CLS]",
+            "[SEP]",
+            "[MASK]",
+        ],
+        show_progress: bool = True,
+        wordpieces_prefix: str = "##",
+        length: Optional[int] = None,
+    ):
+        """Train the model using the given iterator"""
+
+        trainer = trainers.WordPieceTrainer(
+            vocab_size=vocab_size,
+            min_frequency=min_frequency,
+            limit_alphabet=limit_alphabet,
+            initial_alphabet=initial_alphabet,
+            special_tokens=special_tokens,
+            show_progress=show_progress,
+            continuing_subword_prefix=wordpieces_prefix,
+        )
+        self._tokenizer.train_from_iterator(
+            iterator,
+            trainer=trainer,
+            length=length,
+        )