from __future__ import annotations import functools import importlib import pkgutil import threading from typing import Any, Callable, Sequence import tiktoken_ext import tiktoken from tiktoken.core import Encoding _lock = threading.RLock() ENCODINGS: dict[str, Encoding] = {} ENCODING_CONSTRUCTORS: dict[str, Callable[[], dict[str, Any]]] | None = None @functools.lru_cache def _available_plugin_modules() -> Sequence[str]: # tiktoken_ext is a namespace package # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes # - we use namespace package pattern so `pkgutil.iter_modules` is fast # - it's a separate top-level package because namespace subpackages of non-namespace # packages don't quite do what you want with editable installs mods = [] plugin_mods = pkgutil.iter_modules(tiktoken_ext.__path__, tiktoken_ext.__name__ + ".") for _, mod_name, _ in plugin_mods: mods.append(mod_name) return mods def _find_constructors() -> None: global ENCODING_CONSTRUCTORS with _lock: if ENCODING_CONSTRUCTORS is not None: return ENCODING_CONSTRUCTORS = {} try: for mod_name in _available_plugin_modules(): mod = importlib.import_module(mod_name) try: constructors = mod.ENCODING_CONSTRUCTORS except AttributeError as e: raise ValueError( f"tiktoken plugin {mod_name} does not define ENCODING_CONSTRUCTORS" ) from e for enc_name, constructor in constructors.items(): if enc_name in ENCODING_CONSTRUCTORS: raise ValueError( f"Duplicate encoding name {enc_name} in tiktoken plugin {mod_name}" ) ENCODING_CONSTRUCTORS[enc_name] = constructor except Exception: # Ensure we idempotently raise errors ENCODING_CONSTRUCTORS = None raise def get_encoding(encoding_name: str) -> Encoding: if not isinstance(encoding_name, str): raise ValueError(f"Expected a string in get_encoding, got {type(encoding_name)}") if encoding_name in ENCODINGS: return ENCODINGS[encoding_name] with _lock: if encoding_name in ENCODINGS: return ENCODINGS[encoding_name] if ENCODING_CONSTRUCTORS is None: _find_constructors() assert ENCODING_CONSTRUCTORS is not None if encoding_name not in ENCODING_CONSTRUCTORS: raise ValueError( f"Unknown encoding {encoding_name}.\n" f"Plugins found: {_available_plugin_modules()}\n" f"tiktoken version: {tiktoken.__version__} (are you on latest?)" ) constructor = ENCODING_CONSTRUCTORS[encoding_name] enc = Encoding(**constructor()) ENCODINGS[encoding_name] = enc return enc def list_encoding_names() -> list[str]: with _lock: if ENCODING_CONSTRUCTORS is None: _find_constructors() assert ENCODING_CONSTRUCTORS is not None return list(ENCODING_CONSTRUCTORS)