"""
The `vecs.experimental.adapter.text` module provides adapter steps designed for
handling text data. It exposes two main classes, `TextEmbedding` and `ParagraphChunker`.
All public classes, enums, and functions are re-exported by the `vecs.adapters` module.
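
Example:
    A minimal sketch of wiring both steps into a collection. The connection
    string is a placeholder, and the client calls assume the public `vecs`
    API:

        import vecs
        from vecs.adapters import Adapter, ParagraphChunker, TextEmbedding

        client = vecs.create_client("postgresql://user:password@host:5432/db")
        docs = client.get_or_create_collection(
            name="docs",
            adapter=Adapter(
                [
                    ParagraphChunker(skip_during_query=True),
                    TextEmbedding(model="all-MiniLM-L12-v2"),
                ]
            ),
        )
        docs.upsert(records=[("doc_1", "some text to chunk and embed", {})])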
"""
from typing import Any, Dict, Generator, Iterable, Literal, Optional, Tuple
from flupy import flu
from vecs.exc import MissingDependency
from .base import AdapterContext, AdapterStep
TextEmbeddingModel = Literal[
"all-mpnet-base-v2",
"multi-qa-mpnet-base-dot-v1",
"all-distilroberta-v1",
"all-MiniLM-L12-v2",
"multi-qa-distilbert-cos-v1",
"mixedbread-ai/mxbai-embed-large-v1",
"multi-qa-MiniLM-L6-cos-v1",
"paraphrase-multilingual-mpnet-base-v2",
"paraphrase-albert-small-v2",
"paraphrase-multilingual-MiniLM-L12-v2",
"paraphrase-MiniLM-L3-v2",
"distiluse-base-multilingual-cased-v1",
"distiluse-base-multilingual-cased-v2",
]
class TextEmbedding(AdapterStep):
"""
TextEmbedding is an AdapterStep that converts text media into
embeddings using a specified sentence transformers model.
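
    Example:
        An illustrative construction (the model weights are downloaded on
        first use):

            step = TextEmbedding(model="all-MiniLM-L12-v2", batch_size=32)
            step.exported_dimension  # 384 for this model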
"""
def __init__(
self,
*,
model: TextEmbeddingModel,
batch_size: int = 8,
        use_auth_token: Optional[str] = None,
):
"""
Initializes the TextEmbedding adapter with a sentence transformers model.
Args:
model (TextEmbeddingModel): The sentence transformers model to use for embeddings.
batch_size (int): The number of records to encode simultaneously.
            use_auth_token (Optional[str]): The HuggingFace Hub auth token to use for private models.
Raises:
MissingDependency: If the sentence_transformers library is not installed.
"""
try:
from sentence_transformers import SentenceTransformer as ST
except ImportError:
raise MissingDependency(
"Missing feature vecs[text_embedding]. Hint: `pip install 'vecs[text_embedding]'`"
)
self.model = ST(model, use_auth_token=use_auth_token)
self._exported_dimension = (
self.model.get_sentence_embedding_dimension()
)
self.batch_size = batch_size
@property
def exported_dimension(self) -> Optional[int]:
"""
Returns the dimension of the embeddings produced by the sentence transformers model.
Returns:
            Optional[int]: The dimension of the embeddings, or None if the model does not report one.
"""
return self._exported_dimension
def __call__(
self,
records: Iterable[Tuple[str, Any, Optional[Dict]]],
adapter_context: AdapterContext, # pyright: ignore
) -> Generator[Tuple[str, Any, Dict], None, None]:
"""
Converts each media in the records to an embedding and yields the result.
Args:
            records: Iterable of tuples, each containing an id, a media, and an optional metadata dict.
adapter_context: Context of the adapter.
Yields:
Tuple[str, Any, Dict]: The id, the embedding, and the metadata.
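
        Example:
            An illustrative direct invocation; an `Adapter` pipeline normally
            calls this step for you:

                step = TextEmbedding(model="all-MiniLM-L12-v2")
                records = [("id_0", "hello world", None)]
                for id, embedding, metadata in step(records, AdapterContext("upsert")):
                    ...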
"""
for batch in flu(records).chunk(self.batch_size):
            batch_records = list(batch)
media = [text for _, text, _ in batch_records]
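            # normalized embeddings have unit length, so cosine similarity
            # can be computed as a plain dot product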
embeddings = self.model.encode(media, normalize_embeddings=True)
for (id, _, metadata), embedding in zip(batch_records, embeddings): # type: ignore
yield (id, embedding, metadata or {})
class ParagraphChunker(AdapterStep):
"""
ParagraphChunker is an AdapterStep that splits text media into
paragraphs and yields each paragraph as a separate record.
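
    Example:
        An illustrative pass over one record; ids gain a zero-padded
        paragraph suffix:

            chunker = ParagraphChunker(skip_during_query=True)
            records = [("doc_1", "first para\n\nsecond para", None)]
            list(chunker(records, AdapterContext("upsert")))
            # [('doc_1_para_000', 'first para', {}), ('doc_1_para_001', 'second para', {})]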
"""
def __init__(self, *, skip_during_query: bool):
"""
Initializes the ParagraphChunker adapter.
Args:
skip_during_query (bool): Whether to skip chunking during querying.
"""
self.skip_during_query = skip_during_query
def __call__(
self,
records: Iterable[Tuple[str, Any, Optional[Dict]]],
adapter_context: AdapterContext,
) -> Generator[Tuple[str, Any, Dict], None, None]:
"""
Splits each media in the records into paragraphs and yields each paragraph
as a separate record. If the `skip_during_query` attribute is set to True,
this step is skipped during querying.
Args:
            records (Iterable[Tuple[str, Any, Optional[Dict]]]): Iterable of tuples, each containing an id, a media, and an optional metadata dict.
adapter_context (AdapterContext): Context of the adapter.
Yields:
            Tuple[str, Any, Dict]: The id suffixed with the paragraph index, the paragraph, and the metadata.
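
        Example:
            An illustrative query-time pass-through; because
            `skip_during_query` is True, the record is yielded unchanged:

                chunker = ParagraphChunker(skip_during_query=True)
                list(chunker([("q_0", "some query text", None)], AdapterContext("query")))
                # [('q_0', 'some query text', {})]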
"""
if (
adapter_context == AdapterContext("query")
and self.skip_during_query
):
for id, media, metadata in records:
yield (id, media, metadata or {})
else:
for id, media, metadata in records:
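                # paragraphs are delimited by a blank line (two consecutive newlines)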
paragraphs = media.split("\n\n")
for paragraph_ix, paragraph in enumerate(paragraphs):
yield (
f"{id}_para_{str(paragraph_ix).zfill(3)}",
paragraph,
metadata or {},
)