aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/vecs/adapter/text.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/vecs/adapter/text.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to 'R2R/r2r/vecs/adapter/text.py')
-rwxr-xr-xR2R/r2r/vecs/adapter/text.py151
1 files changed, 151 insertions, 0 deletions
diff --git a/R2R/r2r/vecs/adapter/text.py b/R2R/r2r/vecs/adapter/text.py
new file mode 100755
index 00000000..78ae7732
--- /dev/null
+++ b/R2R/r2r/vecs/adapter/text.py
@@ -0,0 +1,151 @@
+"""
+The `vecs.experimental.adapter.text` module provides adapter steps specifically designed for
+handling text data. It provides two main classes, `TextEmbedding` and `ParagraphChunker`.
+
+All public classes, enums, and functions are re-exported by `vecs.adapters` module.
+"""
+
+from typing import Any, Dict, Generator, Iterable, Literal, Optional, Tuple
+
+from flupy import flu
+from vecs.exc import MissingDependency
+
+from .base import AdapterContext, AdapterStep
+
# Checkpoint names accepted by the TextEmbedding adapter. Declared as a
# Literal so static type checkers reject unsupported model identifiers.
TextEmbeddingModel = Literal[
    "all-mpnet-base-v2", "multi-qa-mpnet-base-dot-v1",
    "all-distilroberta-v1", "all-MiniLM-L12-v2",
    "multi-qa-distilbert-cos-v1", "mixedbread-ai/mxbai-embed-large-v1",
    "multi-qa-MiniLM-L6-cos-v1", "paraphrase-multilingual-mpnet-base-v2",
    "paraphrase-albert-small-v2", "paraphrase-multilingual-MiniLM-L12-v2",
    "paraphrase-MiniLM-L3-v2", "distiluse-base-multilingual-cased-v1",
    "distiluse-base-multilingual-cased-v2",
]
+
+
class TextEmbedding(AdapterStep):
    """
    TextEmbedding is an AdapterStep that converts text media into
    embeddings using a specified sentence transformers model.
    """

    def __init__(
        self,
        *,
        model: TextEmbeddingModel,
        batch_size: int = 8,
        use_auth_token: Optional[str] = None,
    ):
        """
        Initializes the TextEmbedding adapter with a sentence transformers model.

        Args:
            model (TextEmbeddingModel): The sentence transformers model to use for embeddings.
            batch_size (int): The number of records to encode simultaneously.
            use_auth_token (Optional[str]): The HuggingFace Hub auth token to use for private models.

        Raises:
            MissingDependency: If the sentence_transformers library is not installed.
        """
        try:
            # Imported lazily so the package stays importable without the
            # optional `text_embedding` extra installed.
            from sentence_transformers import SentenceTransformer as ST
        except ImportError as exc:
            # Chain the original ImportError so the root cause is visible.
            raise MissingDependency(
                "Missing feature vecs[text_embedding]. Hint: `pip install 'vecs[text_embedding]'`"
            ) from exc

        self.model = ST(model, use_auth_token=use_auth_token)
        self._exported_dimension = (
            self.model.get_sentence_embedding_dimension()
        )
        self.batch_size = batch_size

    @property
    def exported_dimension(self) -> Optional[int]:
        """
        Returns the dimension of the embeddings produced by the sentence transformers model.

        Returns:
            int: The dimension of the embeddings.
        """
        return self._exported_dimension

    def __call__(
        self,
        records: Iterable[Tuple[str, Any, Optional[Dict]]],
        adapter_context: AdapterContext,  # pyright: ignore
    ) -> Generator[Tuple[str, Any, Dict], None, None]:
        """
        Converts each media in the records to an embedding and yields the result.

        Args:
            records: Iterable of tuples each containing an id, a media and an optional dict.
            adapter_context: Context of the adapter (unused here; encoding is
                identical for ingest and query).

        Yields:
            Tuple[str, Any, Dict]: The id, the embedding, and the metadata.
        """
        for batch in flu(records).chunk(self.batch_size):
            batch_records = list(batch)
            media = [text for _, text, _ in batch_records]

            # normalize_embeddings=True produces unit-length vectors, so dot
            # product and cosine similarity are interchangeable downstream.
            embeddings = self.model.encode(media, normalize_embeddings=True)

            for (record_id, _, metadata), embedding in zip(batch_records, embeddings):  # type: ignore
                yield (record_id, embedding, metadata or {})
+
+
class ParagraphChunker(AdapterStep):
    """
    AdapterStep that splits each text media into paragraphs (separated by
    two consecutive newlines) and yields one record per paragraph.
    """

    def __init__(self, *, skip_during_query: bool):
        """
        Initializes the ParagraphChunker adapter.

        Args:
            skip_during_query (bool): Whether to skip chunking during querying.
        """
        self.skip_during_query = skip_during_query

    def __call__(
        self,
        records: Iterable[Tuple[str, Any, Optional[Dict]]],
        adapter_context: AdapterContext,
    ) -> Generator[Tuple[str, Any, Dict], None, None]:
        """
        Yields one record per paragraph of each input media. When the
        `skip_during_query` attribute is True and the adapter is running in
        the query context, records pass through unchanged instead.

        Args:
            records (Iterable[Tuple[str, Any, Optional[Dict]]]): Iterable of tuples each containing an id, a media and an optional dict.
            adapter_context (AdapterContext): Context of the adapter.

        Yields:
            Tuple[str, Any, Dict]: The id appended with paragraph index, the paragraph, and the metadata.
        """
        # Decide once, up front, whether this call is a no-op pass-through.
        pass_through = (
            self.skip_during_query
            and adapter_context == AdapterContext("query")
        )
        for record_id, media, metadata in records:
            if pass_through:
                yield (record_id, media, metadata or {})
                continue
            for index, paragraph in enumerate(media.split("\n\n")):
                yield (
                    f"{record_id}_para_{index:03d}",
                    paragraph,
                    metadata or {},
                )