diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/vecs/adapter/text.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to 'R2R/r2r/vecs/adapter/text.py')
-rwxr-xr-x | R2R/r2r/vecs/adapter/text.py | 151 |
1 files changed, 151 insertions, 0 deletions
"""
The `vecs.experimental.adapter.text` module provides adapter steps specifically designed for
handling text data. It provides two main classes, `TextEmbedding` and `ParagraphChunker`.

All public classes, enums, and functions are re-exported by `vecs.adapters` module.
"""

from typing import Any, Dict, Generator, Iterable, Literal, Optional, Tuple

from flupy import flu
from vecs.exc import MissingDependency

from .base import AdapterContext, AdapterStep

# Closed set of sentence-transformers model names accepted by TextEmbedding.
TextEmbeddingModel = Literal[
    "all-mpnet-base-v2",
    "multi-qa-mpnet-base-dot-v1",
    "all-distilroberta-v1",
    "all-MiniLM-L12-v2",
    "multi-qa-distilbert-cos-v1",
    "mixedbread-ai/mxbai-embed-large-v1",
    "multi-qa-MiniLM-L6-cos-v1",
    "paraphrase-multilingual-mpnet-base-v2",
    "paraphrase-albert-small-v2",
    "paraphrase-multilingual-MiniLM-L12-v2",
    "paraphrase-MiniLM-L3-v2",
    "distiluse-base-multilingual-cased-v1",
    "distiluse-base-multilingual-cased-v2",
]


class TextEmbedding(AdapterStep):
    """
    TextEmbedding is an AdapterStep that converts text media into
    embeddings using a specified sentence transformers model.
    """

    def __init__(
        self,
        *,
        model: TextEmbeddingModel,
        batch_size: int = 8,
        use_auth_token: Optional[str] = None,
    ):
        """
        Initializes the TextEmbedding adapter with a sentence transformers model.

        Args:
            model (TextEmbeddingModel): The sentence transformers model to use for embeddings.
            batch_size (int): The number of records to encode simultaneously.
            use_auth_token (Optional[str]): The HuggingFace Hub auth token to use for private models.

        Raises:
            MissingDependency: If the sentence_transformers library is not installed.
        """
        try:
            # Imported lazily so the base package works without the optional extra.
            from sentence_transformers import SentenceTransformer as ST
        except ImportError:
            raise MissingDependency(
                "Missing feature vecs[text_embedding]. Hint: `pip install 'vecs[text_embedding]'`"
            )

        self.model = ST(model, use_auth_token=use_auth_token)
        self._exported_dimension = (
            self.model.get_sentence_embedding_dimension()
        )
        self.batch_size = batch_size

    @property
    def exported_dimension(self) -> Optional[int]:
        """
        Returns the dimension of the embeddings produced by the sentence transformers model.

        Returns:
            int: The dimension of the embeddings.
        """
        return self._exported_dimension

    def __call__(
        self,
        records: Iterable[Tuple[str, Any, Optional[Dict]]],
        adapter_context: AdapterContext,  # pyright: ignore
    ) -> Generator[Tuple[str, Any, Dict], None, None]:
        """
        Converts each media in the records to an embedding and yields the result.

        Args:
            records: Iterable of tuples each containing an id, a media and an optional dict.
            adapter_context: Context of the adapter.

        Yields:
            Tuple[str, Any, Dict]: The id, the embedding, and the metadata.
        """
        for batch in flu(records).chunk(self.batch_size):
            # Materialize the chunk once: it is consumed twice below
            # (first for the texts, then for the zip with embeddings).
            batch_records = list(batch)
            media = [text for _, text, _ in batch_records]

            embeddings = self.model.encode(media, normalize_embeddings=True)

            for (record_id, _, metadata), embedding in zip(
                batch_records, embeddings
            ):  # type: ignore
                yield (record_id, embedding, metadata or {})


class ParagraphChunker(AdapterStep):
    """
    ParagraphChunker is an AdapterStep that splits text media into
    paragraphs and yields each paragraph as a separate record.
    """

    def __init__(self, *, skip_during_query: bool):
        """
        Initializes the ParagraphChunker adapter.

        Args:
            skip_during_query (bool): Whether to skip chunking during querying.
        """
        self.skip_during_query = skip_during_query

    def __call__(
        self,
        records: Iterable[Tuple[str, Any, Optional[Dict]]],
        adapter_context: AdapterContext,
    ) -> Generator[Tuple[str, Any, Dict], None, None]:
        """
        Splits each media in the records into paragraphs and yields each paragraph
        as a separate record. If the `skip_during_query` attribute is set to True,
        this step is skipped during querying.

        Args:
            records (Iterable[Tuple[str, Any, Optional[Dict]]]): Iterable of tuples each containing an id, a media and an optional dict.
            adapter_context (AdapterContext): Context of the adapter.

        Yields:
            Tuple[str, Any, Dict]: The id appended with paragraph index, the paragraph, and the metadata.
        """
        if (
            adapter_context == AdapterContext("query")
            and self.skip_during_query
        ):
            # Query path: pass records through untouched.
            for record_id, media, metadata in records:
                yield (record_id, media, metadata or {})
        else:
            for record_id, media, metadata in records:
                # Paragraphs are delimited by blank lines.
                paragraphs = media.split("\n\n")

                for paragraph_ix, paragraph in enumerate(paragraphs):
                    # Zero-padded index keeps derived ids lexically sortable.
                    yield (
                        f"{record_id}_para_{str(paragraph_ix).zfill(3)}",
                        paragraph,
                        metadata or {},
                    )