diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py | 151 |
1 files changed, 151 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py b/.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py new file mode 100644 index 00000000..1f34e3ca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py @@ -0,0 +1,151 @@ +from typing import Dict, Iterator, List, Optional, Union + +from tokenizers import AddedToken, Tokenizer, decoders, trainers +from tokenizers.models import WordPiece +from tokenizers.normalizers import BertNormalizer +from tokenizers.pre_tokenizers import BertPreTokenizer +from tokenizers.processors import BertProcessing + +from .base_tokenizer import BaseTokenizer + + +class BertWordPieceTokenizer(BaseTokenizer): + """Bert WordPiece Tokenizer""" + + def __init__( + self, + vocab: Optional[Union[str, Dict[str, int]]] = None, + unk_token: Union[str, AddedToken] = "[UNK]", + sep_token: Union[str, AddedToken] = "[SEP]", + cls_token: Union[str, AddedToken] = "[CLS]", + pad_token: Union[str, AddedToken] = "[PAD]", + mask_token: Union[str, AddedToken] = "[MASK]", + clean_text: bool = True, + handle_chinese_chars: bool = True, + strip_accents: Optional[bool] = None, + lowercase: bool = True, + wordpieces_prefix: str = "##", + ): + if vocab is not None: + tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(unk_token))) + else: + tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token))) + + # Let the tokenizer know about special tokens if they are part of the vocab + if tokenizer.token_to_id(str(unk_token)) is not None: + tokenizer.add_special_tokens([str(unk_token)]) + if tokenizer.token_to_id(str(sep_token)) is not None: + tokenizer.add_special_tokens([str(sep_token)]) + if tokenizer.token_to_id(str(cls_token)) is not None: + tokenizer.add_special_tokens([str(cls_token)]) + if tokenizer.token_to_id(str(pad_token)) is not None: + tokenizer.add_special_tokens([str(pad_token)]) + if tokenizer.token_to_id(str(mask_token)) is not None: + tokenizer.add_special_tokens([str(mask_token)]) + + tokenizer.normalizer = BertNormalizer( + clean_text=clean_text, + handle_chinese_chars=handle_chinese_chars, + strip_accents=strip_accents, + lowercase=lowercase, + ) + tokenizer.pre_tokenizer = BertPreTokenizer() + + if vocab is not None: + sep_token_id = tokenizer.token_to_id(str(sep_token)) + if sep_token_id is None: + raise TypeError("sep_token not found in the vocabulary") + cls_token_id = tokenizer.token_to_id(str(cls_token)) + if cls_token_id is None: + raise TypeError("cls_token not found in the vocabulary") + + tokenizer.post_processor = BertProcessing((str(sep_token), sep_token_id), (str(cls_token), cls_token_id)) + tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix) + + parameters = { + "model": "BertWordPiece", + "unk_token": unk_token, + "sep_token": sep_token, + "cls_token": cls_token, + "pad_token": pad_token, + "mask_token": mask_token, + "clean_text": clean_text, + "handle_chinese_chars": handle_chinese_chars, + "strip_accents": strip_accents, + "lowercase": lowercase, + "wordpieces_prefix": wordpieces_prefix, + } + + super().__init__(tokenizer, parameters) + + @staticmethod + def from_file(vocab: str, **kwargs): + vocab = WordPiece.read_file(vocab) + return BertWordPieceTokenizer(vocab, **kwargs) + + def train( + self, + files: Union[str, List[str]], + vocab_size: int = 30000, + min_frequency: int = 2, + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + special_tokens: List[Union[str, AddedToken]] = [ + "[PAD]", + "[UNK]", + "[CLS]", + "[SEP]", + "[MASK]", + ], + show_progress: bool = True, + wordpieces_prefix: str = "##", + ): + """Train the model using the given files""" + + trainer = trainers.WordPieceTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + special_tokens=special_tokens, + show_progress=show_progress, + continuing_subword_prefix=wordpieces_prefix, + ) + if isinstance(files, str): + files = [files] + self._tokenizer.train(files, trainer=trainer) + + def train_from_iterator( + self, + iterator: Union[Iterator[str], Iterator[Iterator[str]]], + vocab_size: int = 30000, + min_frequency: int = 2, + limit_alphabet: int = 1000, + initial_alphabet: List[str] = [], + special_tokens: List[Union[str, AddedToken]] = [ + "[PAD]", + "[UNK]", + "[CLS]", + "[SEP]", + "[MASK]", + ], + show_progress: bool = True, + wordpieces_prefix: str = "##", + length: Optional[int] = None, + ): + """Train the model using the given iterator""" + + trainer = trainers.WordPieceTrainer( + vocab_size=vocab_size, + min_frequency=min_frequency, + limit_alphabet=limit_alphabet, + initial_alphabet=initial_alphabet, + special_tokens=special_tokens, + show_progress=show_progress, + continuing_subword_prefix=wordpieces_prefix, + ) + self._tokenizer.train_from_iterator( + iterator, + trainer=trainer, + length=length, + ) |