Diffstat (limited to '.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py  151
1 file changed, 151 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py b/.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py
new file mode 100644
index 00000000..1f34e3ca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/tokenizers/implementations/bert_wordpiece.py
@@ -0,0 +1,151 @@
+from typing import Dict, Iterator, List, Optional, Union
+
+from tokenizers import AddedToken, Tokenizer, decoders, trainers
+from tokenizers.models import WordPiece
+from tokenizers.normalizers import BertNormalizer
+from tokenizers.pre_tokenizers import BertPreTokenizer
+from tokenizers.processors import BertProcessing
+
+from .base_tokenizer import BaseTokenizer
+
+
+class BertWordPieceTokenizer(BaseTokenizer):
+ """Bert WordPiece Tokenizer"""
+
+ def __init__(
+ self,
+ vocab: Optional[Union[str, Dict[str, int]]] = None,
+ unk_token: Union[str, AddedToken] = "[UNK]",
+ sep_token: Union[str, AddedToken] = "[SEP]",
+ cls_token: Union[str, AddedToken] = "[CLS]",
+ pad_token: Union[str, AddedToken] = "[PAD]",
+ mask_token: Union[str, AddedToken] = "[MASK]",
+ clean_text: bool = True,
+ handle_chinese_chars: bool = True,
+ strip_accents: Optional[bool] = None,
+ lowercase: bool = True,
+ wordpieces_prefix: str = "##",
+ ):
+ if vocab is not None:
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(unk_token)))
+ else:
+ tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token)))
+
+ # Let the tokenizer know about special tokens if they are part of the vocab
+ if tokenizer.token_to_id(str(unk_token)) is not None:
+ tokenizer.add_special_tokens([str(unk_token)])
+ if tokenizer.token_to_id(str(sep_token)) is not None:
+ tokenizer.add_special_tokens([str(sep_token)])
+ if tokenizer.token_to_id(str(cls_token)) is not None:
+ tokenizer.add_special_tokens([str(cls_token)])
+ if tokenizer.token_to_id(str(pad_token)) is not None:
+ tokenizer.add_special_tokens([str(pad_token)])
+ if tokenizer.token_to_id(str(mask_token)) is not None:
+ tokenizer.add_special_tokens([str(mask_token)])
+
+ tokenizer.normalizer = BertNormalizer(
+ clean_text=clean_text,
+ handle_chinese_chars=handle_chinese_chars,
+ strip_accents=strip_accents,
+ lowercase=lowercase,
+ )
+ tokenizer.pre_tokenizer = BertPreTokenizer()
+
+ if vocab is not None:
+ sep_token_id = tokenizer.token_to_id(str(sep_token))
+ if sep_token_id is None:
+ raise TypeError("sep_token not found in the vocabulary")
+ cls_token_id = tokenizer.token_to_id(str(cls_token))
+ if cls_token_id is None:
+ raise TypeError("cls_token not found in the vocabulary")
+
+ tokenizer.post_processor = BertProcessing((str(sep_token), sep_token_id), (str(cls_token), cls_token_id))
+ tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)
+
+ parameters = {
+ "model": "BertWordPiece",
+ "unk_token": unk_token,
+ "sep_token": sep_token,
+ "cls_token": cls_token,
+ "pad_token": pad_token,
+ "mask_token": mask_token,
+ "clean_text": clean_text,
+ "handle_chinese_chars": handle_chinese_chars,
+ "strip_accents": strip_accents,
+ "lowercase": lowercase,
+ "wordpieces_prefix": wordpieces_prefix,
+ }
+
+ super().__init__(tokenizer, parameters)
+
+ @staticmethod
+ def from_file(vocab: str, **kwargs):
+ vocab = WordPiece.read_file(vocab)
+ return BertWordPieceTokenizer(vocab, **kwargs)
+
+ def train(
+ self,
+ files: Union[str, List[str]],
+ vocab_size: int = 30000,
+ min_frequency: int = 2,
+ limit_alphabet: int = 1000,
+ initial_alphabet: List[str] = [],
+ special_tokens: List[Union[str, AddedToken]] = [
+ "[PAD]",
+ "[UNK]",
+ "[CLS]",
+ "[SEP]",
+ "[MASK]",
+ ],
+ show_progress: bool = True,
+ wordpieces_prefix: str = "##",
+ ):
+ """Train the model using the given files"""
+
+ trainer = trainers.WordPieceTrainer(
+ vocab_size=vocab_size,
+ min_frequency=min_frequency,
+ limit_alphabet=limit_alphabet,
+ initial_alphabet=initial_alphabet,
+ special_tokens=special_tokens,
+ show_progress=show_progress,
+ continuing_subword_prefix=wordpieces_prefix,
+ )
+ if isinstance(files, str):
+ files = [files]
+ self._tokenizer.train(files, trainer=trainer)
+
+ def train_from_iterator(
+ self,
+ iterator: Union[Iterator[str], Iterator[Iterator[str]]],
+ vocab_size: int = 30000,
+ min_frequency: int = 2,
+ limit_alphabet: int = 1000,
+ initial_alphabet: List[str] = [],
+ special_tokens: List[Union[str, AddedToken]] = [
+ "[PAD]",
+ "[UNK]",
+ "[CLS]",
+ "[SEP]",
+ "[MASK]",
+ ],
+ show_progress: bool = True,
+ wordpieces_prefix: str = "##",
+ length: Optional[int] = None,
+ ):
+ """Train the model using the given iterator"""
+
+ trainer = trainers.WordPieceTrainer(
+ vocab_size=vocab_size,
+ min_frequency=min_frequency,
+ limit_alphabet=limit_alphabet,
+ initial_alphabet=initial_alphabet,
+ special_tokens=special_tokens,
+ show_progress=show_progress,
+ continuing_subword_prefix=wordpieces_prefix,
+ )
+ self._tokenizer.train_from_iterator(
+ iterator,
+ trainer=trainer,
+ length=length,
+ )
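
Usage sketch (not part of the file above; the vocab path is a placeholder): load the tokenizer from an existing BERT-style vocab.txt via from_file(), then encode a single sentence and a sentence pair. Because a vocab is supplied, the BertProcessing post-processor configured in __init__ adds [CLS]/[SEP] automatically.

from tokenizers import BertWordPieceTokenizer

# "bert-base-uncased-vocab.txt" is a placeholder; point it at any local
# BERT-style WordPiece vocab file.
tokenizer = BertWordPieceTokenizer.from_file("bert-base-uncased-vocab.txt", lowercase=True)

# Single sentence: [CLS]/[SEP] come from the BertProcessing post-processor
# set up when a vocab is present.
encoding = tokenizer.encode("Hello, y'all! How are you?")
print(encoding.tokens)
print(encoding.ids)

# Sentence pair: type_ids distinguish the two segments.
pair = tokenizer.encode("How are you?", "I am fine, thank you.")
print(pair.type_ids)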
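
A second sketch for the training path, assuming a small in-memory corpus stands in for real data: start from an empty model, train a vocabulary with train_from_iterator() (train() takes file paths instead), and save the result. All paths and sizes below are illustrative.

from tokenizers import BertWordPieceTokenizer

# No vocab yet: the WordPiece model starts empty and is filled in by training.
tokenizer = BertWordPieceTokenizer(lowercase=True)

# Any iterator of strings works; this tiny list is only illustrative.
corpus = [
    "The quick brown fox jumps over the lazy dog.",
    "Tokenizers build subword vocabularies from raw text.",
]
tokenizer.train_from_iterator(corpus, vocab_size=5000, min_frequency=1)

# Equivalent file-based call (paths are placeholders):
# tokenizer.train(["data/part-0.txt", "data/part-1.txt"], vocab_size=30000)

# Persist the trained tokenizer (model, normalizer, pre-tokenizer, decoder).
tokenizer.save("bert-wordpiece.json")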