Diffstat (limited to 'R2R/r2r/base')
-rwxr-xr-x  R2R/r2r/base/__init__.py                          160
-rwxr-xr-x  R2R/r2r/base/abstractions/__init__.py               0
-rwxr-xr-x  R2R/r2r/base/abstractions/base.py                   93
-rwxr-xr-x  R2R/r2r/base/abstractions/document.py              242
-rwxr-xr-x  R2R/r2r/base/abstractions/exception.py              16
-rwxr-xr-x  R2R/r2r/base/abstractions/llama_abstractions.py    439
-rwxr-xr-x  R2R/r2r/base/abstractions/llm.py                   112
-rwxr-xr-x  R2R/r2r/base/abstractions/prompt.py                 31
-rwxr-xr-x  R2R/r2r/base/abstractions/search.py                 84
-rwxr-xr-x  R2R/r2r/base/abstractions/vector.py                 66
-rwxr-xr-x  R2R/r2r/base/logging/__init__.py                     0
-rwxr-xr-x  R2R/r2r/base/logging/kv_logger.py                  547
-rwxr-xr-x  R2R/r2r/base/logging/log_processor.py              196
-rwxr-xr-x  R2R/r2r/base/logging/run_manager.py                 56
-rwxr-xr-x  R2R/r2r/base/parsers/__init__.py                     5
-rwxr-xr-x  R2R/r2r/base/parsers/base_parser.py                 14
-rwxr-xr-x  R2R/r2r/base/pipeline/__init__.py                    0
-rwxr-xr-x  R2R/r2r/base/pipeline/base_pipeline.py             233
-rwxr-xr-x  R2R/r2r/base/pipes/__init__.py                       0
-rwxr-xr-x  R2R/r2r/base/pipes/base_pipe.py                    163
-rwxr-xr-x  R2R/r2r/base/providers/__init__.py                   0
-rwxr-xr-x  R2R/r2r/base/providers/base_provider.py             48
-rwxr-xr-x  R2R/r2r/base/providers/embedding_provider.py        83
-rwxr-xr-x  R2R/r2r/base/providers/eval_provider.py             46
-rwxr-xr-x  R2R/r2r/base/providers/kg_provider.py              182
-rwxr-xr-x  R2R/r2r/base/providers/llm_provider.py              66
-rwxr-xr-x  R2R/r2r/base/providers/prompt_provider.py           65
-rwxr-xr-x  R2R/r2r/base/providers/vector_db_provider.py       142
-rwxr-xr-x  R2R/r2r/base/utils/__init__.py                      26
-rwxr-xr-x  R2R/r2r/base/utils/base_utils.py                    63
-rwxr-xr-x  R2R/r2r/base/utils/splitter/__init__.py              3
-rwxr-xr-x  R2R/r2r/base/utils/splitter/text.py               1979
32 files changed, 5160 insertions, 0 deletions
diff --git a/R2R/r2r/base/__init__.py b/R2R/r2r/base/__init__.py
new file mode 100755
index 00000000..a6794a84
--- /dev/null
+++ b/R2R/r2r/base/__init__.py
@@ -0,0 +1,160 @@
+from .abstractions.base import AsyncSyncMeta, UserStats, syncable
+from .abstractions.document import (
+ DataType,
+ Document,
+ DocumentInfo,
+ DocumentType,
+ Entity,
+ Extraction,
+ ExtractionType,
+ Fragment,
+ FragmentType,
+ KGExtraction,
+ Triple,
+ extract_entities,
+ extract_triples,
+)
+from .abstractions.exception import R2RDocumentProcessingError, R2RException
+from .abstractions.llama_abstractions import VectorStoreQuery
+from .abstractions.llm import (
+ GenerationConfig,
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+ RAGCompletion,
+)
+from .abstractions.prompt import Prompt
+from .abstractions.search import (
+ AggregateSearchResult,
+ KGSearchRequest,
+ KGSearchResult,
+ KGSearchSettings,
+ VectorSearchRequest,
+ VectorSearchResult,
+ VectorSearchSettings,
+)
+from .abstractions.vector import Vector, VectorEntry, VectorType
+from .logging.kv_logger import (
+ KVLoggingSingleton,
+ LocalKVLoggingProvider,
+ LoggingConfig,
+ PostgresKVLoggingProvider,
+ PostgresLoggingConfig,
+ RedisKVLoggingProvider,
+ RedisLoggingConfig,
+)
+from .logging.log_processor import (
+ AnalysisTypes,
+ FilterCriteria,
+ LogAnalytics,
+ LogAnalyticsConfig,
+ LogProcessor,
+)
+from .logging.run_manager import RunManager, manage_run
+from .parsers import AsyncParser
+from .pipeline.base_pipeline import AsyncPipeline
+from .pipes.base_pipe import AsyncPipe, AsyncState, PipeType
+from .providers.embedding_provider import EmbeddingConfig, EmbeddingProvider
+from .providers.eval_provider import EvalConfig, EvalProvider
+from .providers.kg_provider import KGConfig, KGProvider, update_kg_prompt
+from .providers.llm_provider import LLMConfig, LLMProvider
+from .providers.prompt_provider import PromptConfig, PromptProvider
+from .providers.vector_db_provider import VectorDBConfig, VectorDBProvider
+from .utils import (
+ EntityType,
+ RecursiveCharacterTextSplitter,
+ Relation,
+ TextSplitter,
+ format_entity_types,
+ format_relations,
+ generate_id_from_label,
+ generate_run_id,
+ increment_version,
+ run_pipeline,
+ to_async_generator,
+)
+
+__all__ = [
+ # Parsers
+ "AsyncParser",
+ # Logging
+ "AnalysisTypes",
+ "LogAnalytics",
+ "LogAnalyticsConfig",
+ "LogProcessor",
+ "LoggingConfig",
+ "LocalKVLoggingProvider",
+ "PostgresLoggingConfig",
+ "PostgresKVLoggingProvider",
+ "RedisLoggingConfig",
+ "AsyncSyncMeta",
+ "syncable",
+ "RedisKVLoggingProvider",
+ "KVLoggingSingleton",
+ "RunManager",
+ "manage_run",
+ # Abstractions
+ "VectorEntry",
+ "VectorType",
+ "Vector",
+ "VectorSearchRequest",
+ "VectorSearchResult",
+ "VectorSearchSettings",
+ "KGSearchRequest",
+ "KGSearchResult",
+ "KGSearchSettings",
+ "AggregateSearchResult",
+ "AsyncPipe",
+ "PipeType",
+ "AsyncState",
+ "AsyncPipe",
+ "Prompt",
+ "DataType",
+ "DocumentType",
+ "Document",
+ "DocumentInfo",
+ "Extraction",
+ "ExtractionType",
+ "Fragment",
+ "FragmentType",
+ "extract_entities",
+ "Entity",
+ "extract_triples",
+ "R2RException",
+ "R2RDocumentProcessingError",
+ "Triple",
+ "KGExtraction",
+ "UserStats",
+ # Pipelines
+ "AsyncPipeline",
+ # Providers
+ "EmbeddingConfig",
+ "EmbeddingProvider",
+ "EvalConfig",
+ "EvalProvider",
+ "PromptConfig",
+ "PromptProvider",
+ "GenerationConfig",
+ "RAGCompletion",
+ "VectorStoreQuery",
+ "LLMChatCompletion",
+ "LLMChatCompletionChunk",
+ "LLMConfig",
+ "LLMProvider",
+ "VectorDBConfig",
+ "VectorDBProvider",
+ "KGProvider",
+ "KGConfig",
+ "update_kg_prompt",
+ # Other
+ "FilterCriteria",
+ "TextSplitter",
+ "RecursiveCharacterTextSplitter",
+ "to_async_generator",
+ "EntityType",
+ "Relation",
+ "format_entity_types",
+ "format_relations",
+ "increment_version",
+ "run_pipeline",
+ "generate_run_id",
+ "generate_id_from_label",
+]
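
This `__init__` flattens the package's public surface so callers can import everything from `r2r.base` directly. A minimal sketch, assuming the package is installed as laid out in this diff:

    from r2r.base import GenerationConfig, VectorSearchSettings

    # Any name listed in __all__ above resolves from the package root.
    settings = VectorSearchSettings(search_limit=5, do_hybrid_search=True)
    config = GenerationConfig(stream=True)
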
diff --git a/R2R/r2r/base/abstractions/__init__.py b/R2R/r2r/base/abstractions/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/base/abstractions/__init__.py
diff --git a/R2R/r2r/base/abstractions/base.py b/R2R/r2r/base/abstractions/base.py
new file mode 100755
index 00000000..7121f6ce
--- /dev/null
+++ b/R2R/r2r/base/abstractions/base.py
@@ -0,0 +1,93 @@
+import asyncio
+import uuid
+from typing import List
+
+from pydantic import BaseModel
+
+
+class UserStats(BaseModel):
+ user_id: uuid.UUID
+ num_files: int
+ total_size_in_bytes: int
+ document_ids: List[uuid.UUID]
+
+
+class AsyncSyncMeta(type):
+ _event_loop = None # Class-level shared event loop
+
+ @classmethod
+ def get_event_loop(cls):
+ if cls._event_loop is None or cls._event_loop.is_closed():
+ cls._event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(cls._event_loop)
+ return cls._event_loop
+
+ def __new__(cls, name, bases, dct):
+ new_cls = super().__new__(cls, name, bases, dct)
+ for attr_name, attr_value in dct.items():
+ if asyncio.iscoroutinefunction(attr_value) and getattr(
+ attr_value, "_syncable", False
+ ):
+ sync_method_name = attr_name[
+ 1:
+ ] # Remove leading 'a' for sync method
+ async_method = attr_value
+
+ def make_sync_method(async_method):
+ def sync_wrapper(self, *args, **kwargs):
+ loop = cls.get_event_loop()
+ if not loop.is_running():
+ # Setup to run the loop in a background thread if necessary
+ # to prevent blocking the main thread in a synchronous call environment
+ from threading import Thread
+
+ result = None
+ exception = None
+
+ def run():
+ nonlocal result, exception
+ try:
+ asyncio.set_event_loop(loop)
+ result = loop.run_until_complete(
+ async_method(self, *args, **kwargs)
+ )
+ except Exception as e:
+ exception = e
+ finally:
+ generation_config = kwargs.get(
+ "rag_generation_config", None
+ )
+ if (
+ not generation_config
+ or not generation_config.stream
+ ):
+ loop.run_until_complete(
+ loop.shutdown_asyncgens()
+ )
+ loop.close()
+
+ thread = Thread(target=run)
+ thread.start()
+ thread.join()
+ if exception:
+ raise exception
+ return result
+ else:
+ # If there's already a running loop, schedule and execute the coroutine
+ future = asyncio.run_coroutine_threadsafe(
+ async_method(self, *args, **kwargs), loop
+ )
+ return future.result()
+
+ return sync_wrapper
+
+ setattr(
+ new_cls, sync_method_name, make_sync_method(async_method)
+ )
+ return new_cls
+
+
+def syncable(func):
+ """Decorator to mark methods for synchronous wrapper creation."""
+ func._syncable = True
+ return func
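
A short sketch of how `syncable` and `AsyncSyncMeta` combine: the metaclass scans for coroutine methods marked `_syncable` and registers a blocking twin whose name drops the leading `a`. The class and method names below are hypothetical, and the import path simply mirrors the file path in this diff:

    import asyncio

    from r2r.base.abstractions.base import AsyncSyncMeta, syncable


    class Fetcher(metaclass=AsyncSyncMeta):
        @syncable
        async def aget_answer(self, question: str) -> str:
            await asyncio.sleep(0)  # stand-in for real async work
            return f"answer to {question!r}"


    fetcher = Fetcher()
    # The metaclass generated a synchronous `get_answer` wrapper that drives
    # the shared event loop (in a worker thread when no loop is running).
    print(fetcher.get_answer("what is R2R?"))
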
diff --git a/R2R/r2r/base/abstractions/document.py b/R2R/r2r/base/abstractions/document.py
new file mode 100755
index 00000000..117db7b9
--- /dev/null
+++ b/R2R/r2r/base/abstractions/document.py
@@ -0,0 +1,242 @@
+"""Abstractions for documents and their extractions."""
+
+import base64
+import json
+import logging
+import uuid
+from datetime import datetime
+from enum import Enum
+from typing import Optional, Union
+
+from pydantic import BaseModel, Field
+
+logger = logging.getLogger(__name__)
+
+DataType = Union[str, bytes]
+
+
+class DocumentType(str, Enum):
+ """Types of documents that can be stored."""
+
+ CSV = "csv"
+ DOCX = "docx"
+ HTML = "html"
+ JSON = "json"
+ MD = "md"
+ PDF = "pdf"
+ PPTX = "pptx"
+ TXT = "txt"
+ XLSX = "xlsx"
+ GIF = "gif"
+ PNG = "png"
+ JPG = "jpg"
+ JPEG = "jpeg"
+ SVG = "svg"
+ MP3 = "mp3"
+ MP4 = "mp4"
+
+
+class Document(BaseModel):
+ id: uuid.UUID = Field(default_factory=uuid.uuid4)
+ type: DocumentType
+ data: Union[str, bytes]
+ metadata: dict
+
+ def __init__(self, *args, **kwargs):
+ data = kwargs.get("data")
+ if data and isinstance(data, str):
+ try:
+ # Try to decode if it's already base64 encoded
+ kwargs["data"] = base64.b64decode(data)
+ except Exception:
+ # If it's not base64, encode it to bytes
+ kwargs["data"] = data.encode("utf-8")
+
+ doc_type = kwargs.get("type")
+ if isinstance(doc_type, str):
+ kwargs["type"] = DocumentType(doc_type)
+
+ # Generate UUID based on the hash of the data
+ if "id" not in kwargs:
+ if isinstance(kwargs["data"], bytes):
+ data_hash = uuid.uuid5(
+ uuid.NAMESPACE_DNS, kwargs["data"].decode("utf-8")
+ )
+ else:
+ data_hash = uuid.uuid5(uuid.NAMESPACE_DNS, kwargs["data"])
+
+ kwargs["id"] = data_hash # Set the id based on the data hash
+
+ super().__init__(*args, **kwargs)
+
+ class Config:
+ arbitrary_types_allowed = True
+ json_encoders = {
+ uuid.UUID: str,
+ bytes: lambda v: base64.b64encode(v).decode("utf-8"),
+ }
+
+
+class DocumentStatus(str, Enum):
+ """Status of document processing."""
+
+ PROCESSING = "processing"
+ # TODO - Extend support for `partial-failure`
+ # PARTIAL_FAILURE = "partial-failure"
+ FAILURE = "failure"
+ SUCCESS = "success"
+
+
+class DocumentInfo(BaseModel):
+ """Base class for document information handling."""
+
+ document_id: uuid.UUID
+ version: str
+ size_in_bytes: int
+ metadata: dict
+ status: DocumentStatus = DocumentStatus.PROCESSING
+
+ user_id: Optional[uuid.UUID] = None
+ title: Optional[str] = None
+ created_at: Optional[datetime] = None
+ updated_at: Optional[datetime] = None
+
+ def convert_to_db_entry(self):
+ """Prepare the document info for database entry, extracting certain fields from metadata."""
+ now = datetime.now()
+ metadata = self.metadata
+ if "user_id" in metadata:
+ metadata["user_id"] = str(metadata["user_id"])
+
+ metadata["title"] = metadata.get("title", "N/A")
+ return {
+ "document_id": str(self.document_id),
+ "title": metadata.get("title", "N/A"),
+ "user_id": metadata.get("user_id", None),
+ "version": self.version,
+ "size_in_bytes": self.size_in_bytes,
+ "metadata": json.dumps(self.metadata),
+ "created_at": self.created_at or now,
+ "updated_at": self.updated_at or now,
+ "status": self.status,
+ }
+
+
+class ExtractionType(Enum):
+ """Types of extractions that can be performed."""
+
+ TXT = "txt"
+ IMG = "img"
+ MOV = "mov"
+
+
+class Extraction(BaseModel):
+ """An extraction from a document."""
+
+ id: uuid.UUID
+ type: ExtractionType = ExtractionType.TXT
+ data: DataType
+ metadata: dict
+ document_id: uuid.UUID
+
+
+class FragmentType(Enum):
+ """A type of fragment that can be extracted from a document."""
+
+ TEXT = "text"
+ IMAGE = "image"
+
+
+class Fragment(BaseModel):
+ """A fragment extracted from a document."""
+
+ id: uuid.UUID
+ type: FragmentType
+ data: DataType
+ metadata: dict
+ document_id: uuid.UUID
+ extraction_id: uuid.UUID
+
+
+class Entity(BaseModel):
+ """An entity extracted from a document."""
+
+ category: str
+ subcategory: Optional[str] = None
+ value: str
+
+ def __str__(self):
+ return (
+ f"{self.category}:{self.subcategory}:{self.value}"
+ if self.subcategory
+ else f"{self.category}:{self.value}"
+ )
+
+
+class Triple(BaseModel):
+ """A triple extracted from a document."""
+
+ subject: str
+ predicate: str
+ object: str
+
+
+def extract_entities(llm_payload: list[str]) -> dict[str, Entity]:
+ entities = {}
+ for entry in llm_payload:
+ try:
+ if "], " in entry: # Check if the entry is an entity
+ entry_val = entry.split("], ")[0] + "]"
+ entry = entry.split("], ")[1]
+ colon_count = entry.count(":")
+
+ if colon_count == 1:
+ category, value = entry.split(":")
+ subcategory = None
+ elif colon_count >= 2:
+ parts = entry.split(":", 2)
+ category, subcategory, value = (
+ parts[0],
+ parts[1],
+ parts[2],
+ )
+ else:
+ raise ValueError("Unexpected entry format")
+
+ entities[entry_val] = Entity(
+ category=category, subcategory=subcategory, value=value
+ )
+ except Exception as e:
+ logger.error(f"Error processing entity {entry}: {e}")
+ continue
+ return entities
+
+
+def extract_triples(
+ llm_payload: list[str], entities: dict[str, Entity]
+) -> list[Triple]:
+ triples = []
+ for entry in llm_payload:
+ try:
+ if "], " not in entry: # Check if the entry is an entity
+ elements = entry.split(" ")
+ subject = elements[0]
+ predicate = elements[1]
+ object = " ".join(elements[2:])
+ subject = entities[subject].value # Use entity.value
+ if "[" in object and "]" in object:
+ object = entities[object].value # Use entity.value
+ triples.append(
+ Triple(subject=subject, predicate=predicate, object=object)
+ )
+ except Exception as e:
+ logger.error(f"Error processing triplet {entry}: {e}")
+ continue
+ return triples
+
+
+class KGExtraction(BaseModel):
+ """An extraction from a document that is part of a knowledge graph."""
+
+ entities: dict[str, Entity]
+ triples: list[Triple]
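
The two helpers near the end of this file parse an LLM's line-oriented output into `Entity` and `Triple` objects. The payload format below is inferred from the parsing logic (a `[key], category:value` line per entity and `[subj] predicate [obj]` lines for triples), so treat it as an illustrative assumption rather than a documented contract:

    from r2r.base.abstractions.document import extract_entities, extract_triples

    llm_payload = [
        "[1], PERSON:Alice",  # entity lines contain "], " and a category:value pair
        "[2], PERSON:Bob",
        "[1] knows [2]",      # triple lines reference entity keys
    ]

    entities = extract_entities(llm_payload)
    triples = extract_triples(llm_payload, entities)
    # -> [Triple(subject='Alice', predicate='knows', object='Bob')]
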
diff --git a/R2R/r2r/base/abstractions/exception.py b/R2R/r2r/base/abstractions/exception.py
new file mode 100755
index 00000000..c76625a3
--- /dev/null
+++ b/R2R/r2r/base/abstractions/exception.py
@@ -0,0 +1,16 @@
+from typing import Any, Optional
+
+
+class R2RException(Exception):
+ def __init__(
+ self, message: str, status_code: int, detail: Optional[Any] = None
+ ):
+ self.message = message
+ self.status_code = status_code
+ super().__init__(self.message)
+
+
+class R2RDocumentProcessingError(R2RException):
+ def __init__(self, error_message, document_id):
+ self.document_id = document_id
+ super().__init__(error_message, 400, {"document_id": document_id})
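
A brief usage sketch: `R2RDocumentProcessingError` pins the status code to 400 and carries the failing document id, so callers can surface both. The values below are placeholders:

    import uuid

    from r2r.base.abstractions.exception import R2RDocumentProcessingError

    try:
        raise R2RDocumentProcessingError("unsupported file type", document_id=uuid.uuid4())
    except R2RDocumentProcessingError as exc:
        print(exc.status_code, exc.document_id, exc.message)
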
diff --git a/R2R/r2r/base/abstractions/llama_abstractions.py b/R2R/r2r/base/abstractions/llama_abstractions.py
new file mode 100755
index 00000000..f6bc36e6
--- /dev/null
+++ b/R2R/r2r/base/abstractions/llama_abstractions.py
@@ -0,0 +1,439 @@
+# abstractions are taken from LlamaIndex
+# https://github.com/run-llama/llama_index
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr
+
+
+class LabelledNode(BaseModel):
+ """An entity in a graph."""
+
+ label: str = Field(default="node", description="The label of the node.")
+ embedding: Optional[List[float]] = Field(
+ default=None, description="The embeddings of the node."
+ )
+ properties: Dict[str, Any] = Field(default_factory=dict)
+
+ @abstractmethod
+ def __str__(self) -> str:
+ """Return the string representation of the node."""
+ ...
+
+ @property
+ @abstractmethod
+ def id(self) -> str:
+ """Get the node id."""
+ ...
+
+
+class EntityNode(LabelledNode):
+ """An entity in a graph."""
+
+ name: str = Field(description="The name of the entity.")
+ label: str = Field(default="entity", description="The label of the node.")
+ properties: Dict[str, Any] = Field(default_factory=dict)
+
+ def __str__(self) -> str:
+ """Return the string representation of the node."""
+ return self.name
+
+ @property
+ def id(self) -> str:
+ """Get the node id."""
+ return self.name.replace('"', " ")
+
+
+class ChunkNode(LabelledNode):
+ """A text chunk in a graph."""
+
+ text: str = Field(description="The text content of the chunk.")
+ id_: Optional[str] = Field(
+ default=None,
+ description="The id of the node. Defaults to a hash of the text.",
+ )
+ label: str = Field(
+ default="text_chunk", description="The label of the node."
+ )
+ properties: Dict[str, Any] = Field(default_factory=dict)
+
+ def __str__(self) -> str:
+ """Return the string representation of the node."""
+ return self.text
+
+ @property
+ def id(self) -> str:
+ """Get the node id."""
+ return str(hash(self.text)) if self.id_ is None else self.id_
+
+
+class Relation(BaseModel):
+ """A relation connecting two entities in a graph."""
+
+ label: str
+ source_id: str
+ target_id: str
+ properties: Dict[str, Any] = Field(default_factory=dict)
+
+ def __str__(self) -> str:
+ """Return the string representation of the relation."""
+ return self.label
+
+ @property
+ def id(self) -> str:
+ """Get the relation id."""
+ return self.label
+
+
+Triplet = Tuple[LabelledNode, Relation, LabelledNode]
+
+
+class VectorStoreQueryMode(str, Enum):
+ """Vector store query mode."""
+
+ DEFAULT = "default"
+ SPARSE = "sparse"
+ HYBRID = "hybrid"
+ TEXT_SEARCH = "text_search"
+ SEMANTIC_HYBRID = "semantic_hybrid"
+
+ # fit learners
+ SVM = "svm"
+ LOGISTIC_REGRESSION = "logistic_regression"
+ LINEAR_REGRESSION = "linear_regression"
+
+ # maximum marginal relevance
+ MMR = "mmr"
+
+
+class FilterOperator(str, Enum):
+ """Vector store filter operator."""
+
+ # TODO add more operators
+ EQ = "==" # default operator (string, int, float)
+ GT = ">" # greater than (int, float)
+ LT = "<" # less than (int, float)
+ NE = "!=" # not equal to (string, int, float)
+ GTE = ">=" # greater than or equal to (int, float)
+ LTE = "<=" # less than or equal to (int, float)
+ IN = "in" # In array (string or number)
+ NIN = "nin" # Not in array (string or number)
+ ANY = "any" # Contains any (array of strings)
+ ALL = "all" # Contains all (array of strings)
+ TEXT_MATCH = "text_match" # full text match (allows you to search for a specific substring, token or phrase within the text field)
+ CONTAINS = "contains" # metadata array contains value (string or number)
+
+
+class MetadataFilter(BaseModel):
+ """Comprehensive metadata filter for vector stores to support more operators.
+
+ Value uses Strict* types, as int, float and str are compatible types and were all
+ converted to string before.
+
+ See: https://docs.pydantic.dev/latest/usage/types/#strict-types
+ """
+
+ key: str
+ value: Union[
+ StrictInt,
+ StrictFloat,
+ StrictStr,
+ List[StrictStr],
+ List[StrictFloat],
+ List[StrictInt],
+ ]
+ operator: FilterOperator = FilterOperator.EQ
+
+ @classmethod
+ def from_dict(
+ cls,
+ filter_dict: Dict,
+ ) -> "MetadataFilter":
+ """Create MetadataFilter from dictionary.
+
+ Args:
+ filter_dict: Dict with key, value and operator.
+
+ """
+ return MetadataFilter.parse_obj(filter_dict)
+
+
+# # TODO: Deprecate ExactMatchFilter and use MetadataFilter instead
+# # Keep class for now so that AutoRetriever can still work with old vector stores
+# class ExactMatchFilter(BaseModel):
+# key: str
+# value: Union[StrictInt, StrictFloat, StrictStr]
+
+# set ExactMatchFilter to MetadataFilter
+ExactMatchFilter = MetadataFilter
+
+
+class FilterCondition(str, Enum):
+ """Vector store filter conditions to combine different filters."""
+
+ # TODO add more conditions
+ AND = "and"
+ OR = "or"
+
+
+class MetadataFilters(BaseModel):
+ """Metadata filters for vector stores."""
+
+ # Exact match filters and Advanced filters with operators like >, <, >=, <=, !=, etc.
+ filters: List[Union[MetadataFilter, ExactMatchFilter, "MetadataFilters"]]
+ # and/or such conditions for combining different filters
+ condition: Optional[FilterCondition] = FilterCondition.AND
+
+
+@dataclass
+class VectorStoreQuery:
+ """Vector store query."""
+
+ query_embedding: Optional[List[float]] = None
+ similarity_top_k: int = 1
+ doc_ids: Optional[List[str]] = None
+ node_ids: Optional[List[str]] = None
+ query_str: Optional[str] = None
+ output_fields: Optional[List[str]] = None
+ embedding_field: Optional[str] = None
+
+ mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT
+
+ # NOTE: only for hybrid search (0 for bm25, 1 for vector search)
+ alpha: Optional[float] = None
+
+ # metadata filters
+ filters: Optional[MetadataFilters] = None
+
+ # only for mmr
+ mmr_threshold: Optional[float] = None
+
+ # NOTE: currently only used by postgres hybrid search
+ sparse_top_k: Optional[int] = None
+ # NOTE: return top k results from hybrid search. similarity_top_k is used for dense search top k
+ hybrid_top_k: Optional[int] = None
+
+
+class PropertyGraphStore(ABC):
+ """Abstract labelled graph store protocol.
+
+ This protocol defines the interface for a graph store, which is responsible
+ for storing and retrieving knowledge graph data.
+
+ Attributes:
+ client: Any: The client used to connect to the graph store.
+ get: Callable[[str], List[List[str]]]: Get triplets for a given subject.
+ get_rel_map: Callable[[Optional[List[str]], int], Dict[str, List[List[str]]]]:
+ Get subjects' rel map in max depth.
+ upsert_triplet: Callable[[str, str, str], None]: Upsert a triplet.
+ delete: Callable[[str, str, str], None]: Delete a triplet.
+ persist: Callable[[str, Optional[fsspec.AbstractFileSystem]], None]:
+ Persist the graph store to a file.
+ """
+
+ supports_structured_queries: bool = False
+ supports_vector_queries: bool = False
+
+ @property
+ def client(self) -> Any:
+ """Get client."""
+ ...
+
+ @abstractmethod
+ def get(
+ self,
+ properties: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ ) -> List[LabelledNode]:
+ """Get nodes with matching values."""
+ ...
+
+ @abstractmethod
+ def get_triplets(
+ self,
+ entity_names: Optional[List[str]] = None,
+ relation_names: Optional[List[str]] = None,
+ properties: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ ) -> List[Triplet]:
+ """Get triplets with matching values."""
+ ...
+
+ @abstractmethod
+ def get_rel_map(
+ self,
+ graph_nodes: List[LabelledNode],
+ depth: int = 2,
+ limit: int = 30,
+ ignore_rels: Optional[List[str]] = None,
+ ) -> List[Triplet]:
+ """Get depth-aware rel map."""
+ ...
+
+ @abstractmethod
+ def upsert_nodes(self, nodes: List[LabelledNode]) -> None:
+ """Upsert nodes."""
+ ...
+
+ @abstractmethod
+ def upsert_relations(self, relations: List[Relation]) -> None:
+ """Upsert relations."""
+ ...
+
+ @abstractmethod
+ def delete(
+ self,
+ entity_names: Optional[List[str]] = None,
+ relation_names: Optional[List[str]] = None,
+ properties: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ ) -> None:
+ """Delete matching data."""
+ ...
+
+ @abstractmethod
+ def structured_query(
+ self, query: str, param_map: Optional[Dict[str, Any]] = None
+ ) -> Any:
+ """Query the graph store with statement and parameters."""
+ ...
+
+ @abstractmethod
+ def vector_query(
+ self, query: VectorStoreQuery, **kwargs: Any
+ ) -> Tuple[List[LabelledNode], List[float]]:
+ """Query the graph store with a vector store query."""
+ ...
+
+ # def persist(
+ # self, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
+ # ) -> None:
+ # """Persist the graph store to a file."""
+ # return
+
+ def get_schema(self, refresh: bool = False) -> Any:
+ """Get the schema of the graph store."""
+ return None
+
+ def get_schema_str(self, refresh: bool = False) -> str:
+ """Get the schema of the graph store as a string."""
+ return str(self.get_schema(refresh=refresh))
+
+ ### ----- Async Methods ----- ###
+
+ async def aget(
+ self,
+ properties: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ ) -> List[LabelledNode]:
+ """Asynchronously get nodes with matching values."""
+ return self.get(properties, ids)
+
+ async def aget_triplets(
+ self,
+ entity_names: Optional[List[str]] = None,
+ relation_names: Optional[List[str]] = None,
+ properties: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ ) -> List[Triplet]:
+ """Asynchronously get triplets with matching values."""
+ return self.get_triplets(entity_names, relation_names, properties, ids)
+
+ async def aget_rel_map(
+ self,
+ graph_nodes: List[LabelledNode],
+ depth: int = 2,
+ limit: int = 30,
+ ignore_rels: Optional[List[str]] = None,
+ ) -> List[Triplet]:
+ """Asynchronously get depth-aware rel map."""
+ return self.get_rel_map(graph_nodes, depth, limit, ignore_rels)
+
+ async def aupsert_nodes(self, nodes: List[LabelledNode]) -> None:
+ """Asynchronously add nodes."""
+ return self.upsert_nodes(nodes)
+
+ async def aupsert_relations(self, relations: List[Relation]) -> None:
+ """Asynchronously add relations."""
+ return self.upsert_relations(relations)
+
+ async def adelete(
+ self,
+ entity_names: Optional[List[str]] = None,
+ relation_names: Optional[List[str]] = None,
+ properties: Optional[dict] = None,
+ ids: Optional[List[str]] = None,
+ ) -> None:
+ """Asynchronously delete matching data."""
+ return self.delete(entity_names, relation_names, properties, ids)
+
+ async def astructured_query(
+ self, query: str, param_map: Optional[Dict[str, Any]] = {}
+ ) -> Any:
+ """Asynchronously query the graph store with statement and parameters."""
+ return self.structured_query(query, param_map)
+
+ async def avector_query(
+ self, query: VectorStoreQuery, **kwargs: Any
+ ) -> Tuple[List[LabelledNode], List[float]]:
+ """Asynchronously query the graph store with a vector store query."""
+ return self.vector_query(query, **kwargs)
+
+ async def aget_schema(self, refresh: bool = False) -> str:
+ """Asynchronously get the schema of the graph store."""
+ return self.get_schema(refresh=refresh)
+
+ async def aget_schema_str(self, refresh: bool = False) -> str:
+ """Asynchronously get the schema of the graph store as a string."""
+ return str(await self.aget_schema(refresh=refresh))
+
+
+LIST_LIMIT = 128
+
+
+def clean_string_values(text: str) -> str:
+ return text.replace("\n", " ").replace("\r", " ")
+
+
+def value_sanitize(d: Any) -> Any:
+ """Sanitize the input dictionary or list.
+
+ Sanitizes the input by removing embedding-like values and
+ lists with more than 128 elements, which are mostly irrelevant for
+ generating answers in an LLM context. These properties, if left in
+ results, can occupy significant context space and detract from
+ the LLM's performance by introducing unnecessary noise and cost.
+ """
+ if isinstance(d, dict):
+ new_dict = {}
+ for key, value in d.items():
+ if isinstance(value, dict):
+ sanitized_value = value_sanitize(value)
+ if (
+ sanitized_value is not None
+ ): # Check if the sanitized value is not None
+ new_dict[key] = sanitized_value
+ elif isinstance(value, list):
+ if len(value) < LIST_LIMIT:
+ sanitized_value = value_sanitize(value)
+ if (
+ sanitized_value is not None
+ ): # Check if the sanitized value is not None
+ new_dict[key] = sanitized_value
+ # Do not include the key if the list is oversized
+ else:
+ new_dict[key] = value
+ return new_dict
+ elif isinstance(d, list):
+ if len(d) < LIST_LIMIT:
+ return [
+ value_sanitize(item)
+ for item in d
+ if value_sanitize(item) is not None
+ ]
+ else:
+ return None
+ else:
+ return d
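
These LlamaIndex-derived abstractions are mostly plumbing for `KGProvider` implementations, but `MetadataFilters` and `VectorStoreQuery` are also usable on their own. A minimal sketch, with made-up filter keys and embedding values:

    from r2r.base.abstractions.llama_abstractions import (
        FilterOperator,
        MetadataFilter,
        MetadataFilters,
        VectorStoreQuery,
        VectorStoreQueryMode,
    )

    filters = MetadataFilters(
        filters=[MetadataFilter(key="year", value=2024, operator=FilterOperator.GTE)]
    )
    query = VectorStoreQuery(
        query_embedding=[0.1, 0.2, 0.3],
        similarity_top_k=5,
        mode=VectorStoreQueryMode.HYBRID,
        alpha=0.5,  # 0 favours bm25, 1 favours the dense vectors
        filters=filters,
    )
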
diff --git a/R2R/r2r/base/abstractions/llm.py b/R2R/r2r/base/abstractions/llm.py
new file mode 100755
index 00000000..3178d8dc
--- /dev/null
+++ b/R2R/r2r/base/abstractions/llm.py
@@ -0,0 +1,112 @@
+"""Abstractions for the LLM model."""
+
+from typing import TYPE_CHECKING, ClassVar, Optional
+
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from pydantic import BaseModel, Field
+
+if TYPE_CHECKING:
+ from .search import AggregateSearchResult
+
+LLMChatCompletion = ChatCompletion
+LLMChatCompletionChunk = ChatCompletionChunk
+
+
+class RAGCompletion:
+ completion: LLMChatCompletion
+ search_results: "AggregateSearchResult"
+
+ def __init__(
+ self,
+ completion: LLMChatCompletion,
+ search_results: "AggregateSearchResult",
+ ):
+ self.completion = completion
+ self.search_results = search_results
+
+
+class GenerationConfig(BaseModel):
+ _defaults: ClassVar[dict] = {
+ "model": "gpt-4o",
+ "temperature": 0.1,
+ "top_p": 1.0,
+ "top_k": 100,
+ "max_tokens_to_sample": 1024,
+ "stream": False,
+ "functions": None,
+ "skip_special_tokens": False,
+ "stop_token": None,
+ "num_beams": 1,
+ "do_sample": True,
+ "generate_with_chat": False,
+ "add_generation_kwargs": None,
+ "api_base": None,
+ }
+
+ model: str = Field(
+ default_factory=lambda: GenerationConfig._defaults["model"]
+ )
+ temperature: float = Field(
+ default_factory=lambda: GenerationConfig._defaults["temperature"]
+ )
+ top_p: float = Field(
+ default_factory=lambda: GenerationConfig._defaults["top_p"]
+ )
+ top_k: int = Field(
+ default_factory=lambda: GenerationConfig._defaults["top_k"]
+ )
+ max_tokens_to_sample: int = Field(
+ default_factory=lambda: GenerationConfig._defaults[
+ "max_tokens_to_sample"
+ ]
+ )
+ stream: bool = Field(
+ default_factory=lambda: GenerationConfig._defaults["stream"]
+ )
+ functions: Optional[list[dict]] = Field(
+ default_factory=lambda: GenerationConfig._defaults["functions"]
+ )
+ skip_special_tokens: bool = Field(
+ default_factory=lambda: GenerationConfig._defaults[
+ "skip_special_tokens"
+ ]
+ )
+ stop_token: Optional[str] = Field(
+ default_factory=lambda: GenerationConfig._defaults["stop_token"]
+ )
+ num_beams: int = Field(
+ default_factory=lambda: GenerationConfig._defaults["num_beams"]
+ )
+ do_sample: bool = Field(
+ default_factory=lambda: GenerationConfig._defaults["do_sample"]
+ )
+ generate_with_chat: bool = Field(
+ default_factory=lambda: GenerationConfig._defaults[
+ "generate_with_chat"
+ ]
+ )
+ add_generation_kwargs: Optional[dict] = Field(
+ default_factory=lambda: GenerationConfig._defaults[
+ "add_generation_kwargs"
+ ]
+ )
+ api_base: Optional[str] = Field(
+ default_factory=lambda: GenerationConfig._defaults["api_base"]
+ )
+
+ @classmethod
+ def set_default(cls, **kwargs):
+ for key, value in kwargs.items():
+ if key in cls._defaults:
+ cls._defaults[key] = value
+ else:
+ raise AttributeError(
+ f"No default attribute '{key}' in GenerationConfig"
+ )
+
+ def __init__(self, **data):
+ model = data.pop("model", None)
+ if model is not None:
+ super().__init__(model=model, **data)
+ else:
+ super().__init__(**data)
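
`GenerationConfig` resolves every field through the class-level `_defaults` table, so `set_default` changes what later instances are born with. A small sketch (the model name is only an example):

    from r2r.base.abstractions.llm import GenerationConfig

    GenerationConfig.set_default(temperature=0.0, stream=True)

    config = GenerationConfig(model="gpt-4o-mini")  # picks up the updated defaults
    print(config.temperature, config.stream)        # 0.0 True
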
diff --git a/R2R/r2r/base/abstractions/prompt.py b/R2R/r2r/base/abstractions/prompt.py
new file mode 100755
index 00000000..e37eeb5f
--- /dev/null
+++ b/R2R/r2r/base/abstractions/prompt.py
@@ -0,0 +1,31 @@
+"""Abstraction for a prompt that can be formatted with inputs."""
+
+from typing import Any
+
+from pydantic import BaseModel
+
+
+class Prompt(BaseModel):
+ """A prompt that can be formatted with inputs."""
+
+ name: str
+ template: str
+ input_types: dict[str, str]
+
+ def format_prompt(self, inputs: dict[str, Any]) -> str:
+ self._validate_inputs(inputs)
+ return self.template.format(**inputs)
+
+ def _validate_inputs(self, inputs: dict[str, Any]) -> None:
+ for var, expected_type_name in self.input_types.items():
+ expected_type = self._convert_type(expected_type_name)
+ if var not in inputs:
+ raise ValueError(f"Missing input: {var}")
+ if not isinstance(inputs[var], expected_type):
+ raise TypeError(
+ f"Input '{var}' must be of type {expected_type.__name__}, got {type(inputs[var]).__name__} instead."
+ )
+
+ def _convert_type(self, type_name: str) -> type:
+ type_mapping = {"int": int, "str": str}
+ return type_mapping.get(type_name, str)
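
`Prompt` is a thin wrapper over `str.format` with light type checking driven by `input_types` (only "int" and "str" map to real types; anything else falls back to `str`). A quick sketch:

    from r2r.base.abstractions.prompt import Prompt

    prompt = Prompt(
        name="greeting",
        template="Hello {name}, you have {count} new messages.",
        input_types={"name": "str", "count": "int"},
    )
    print(prompt.format_prompt({"name": "Ada", "count": 3}))
    # Passing {"count": "3"} instead would raise a TypeError.
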
diff --git a/R2R/r2r/base/abstractions/search.py b/R2R/r2r/base/abstractions/search.py
new file mode 100755
index 00000000..b13cc5aa
--- /dev/null
+++ b/R2R/r2r/base/abstractions/search.py
@@ -0,0 +1,84 @@
+"""Abstractions for search functionality."""
+
+import uuid
+from typing import Any, Dict, List, Optional, Tuple
+
+from pydantic import BaseModel, Field
+
+from .llm import GenerationConfig
+
+
+class VectorSearchRequest(BaseModel):
+ """Request for a search operation."""
+
+ query: str
+ limit: int
+ filters: Optional[dict[str, Any]] = None
+
+
+class VectorSearchResult(BaseModel):
+ """Result of a search operation."""
+
+ id: uuid.UUID
+ score: float
+ metadata: dict[str, Any]
+
+ def __str__(self) -> str:
+ return f"VectorSearchResult(id={self.id}, score={self.score}, metadata={self.metadata})"
+
+ def __repr__(self) -> str:
+ return f"VectorSearchResult(id={self.id}, score={self.score}, metadata={self.metadata})"
+
+ def dict(self) -> dict:
+ return {
+ "id": self.id,
+ "score": self.score,
+ "metadata": self.metadata,
+ }
+
+
+class KGSearchRequest(BaseModel):
+ """Request for a knowledge graph search operation."""
+
+ query: str
+
+
+# A list of (query, results) tuples, where results is a list of dicts.
+KGSearchResult = List[Tuple[str, List[Dict[str, Any]]]]
+
+
+class AggregateSearchResult(BaseModel):
+ """Result of an aggregate search operation."""
+
+ vector_search_results: Optional[List[VectorSearchResult]]
+ kg_search_results: Optional[KGSearchResult] = None
+
+ def __str__(self) -> str:
+ return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})"
+
+ def __repr__(self) -> str:
+ return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})"
+
+ def dict(self) -> dict:
+ return {
+ "vector_search_results": (
+ [result.dict() for result in self.vector_search_results]
+ if self.vector_search_results
+ else []
+ ),
+ "kg_search_results": self.kg_search_results or [],
+ }
+
+
+class VectorSearchSettings(BaseModel):
+ use_vector_search: bool = True
+ search_filters: dict[str, Any] = Field(default_factory=dict)
+ search_limit: int = 10
+ do_hybrid_search: bool = False
+
+
+class KGSearchSettings(BaseModel):
+ use_kg_search: bool = False
+ agent_generation_config: Optional[GenerationConfig] = Field(
+ default_factory=GenerationConfig
+ )
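
The search abstractions are plain Pydantic models, so wiring up settings and inspecting an empty aggregate result is straightforward. A small sketch:

    from r2r.base.abstractions.search import (
        AggregateSearchResult,
        KGSearchSettings,
        VectorSearchSettings,
    )

    vector_settings = VectorSearchSettings(search_limit=5, do_hybrid_search=True)
    kg_settings = KGSearchSettings(use_kg_search=False)

    empty = AggregateSearchResult(vector_search_results=None)
    print(empty.dict())  # {'vector_search_results': [], 'kg_search_results': []}
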
diff --git a/R2R/r2r/base/abstractions/vector.py b/R2R/r2r/base/abstractions/vector.py
new file mode 100755
index 00000000..445f3302
--- /dev/null
+++ b/R2R/r2r/base/abstractions/vector.py
@@ -0,0 +1,66 @@
+"""Abstraction for a vector that can be stored in the system."""
+
+from enum import Enum
+from typing import Any
+from uuid import UUID
+
+
+class VectorType(Enum):
+ FIXED = "FIXED"
+
+
+class Vector:
+ """A vector with the option to fix the number of elements."""
+
+ def __init__(
+ self,
+ data: list[float],
+ type: VectorType = VectorType.FIXED,
+ length: int = -1,
+ ):
+ self.data = data
+ self.type = type
+ self.length = length
+
+ if (
+ self.type == VectorType.FIXED
+ and length > 0
+ and len(data) != length
+ ):
+ raise ValueError(f"Vector must be exactly {length} elements long.")
+
+ def __repr__(self) -> str:
+ return (
+ f"Vector(data={self.data}, type={self.type}, length={self.length})"
+ )
+
+
+class VectorEntry:
+ """A vector entry that can be stored directly in supported vector databases."""
+
+ def __init__(self, id: UUID, vector: Vector, metadata: dict[str, Any]):
+ """Create a new VectorEntry object."""
+ self.vector = vector
+ self.id = id
+ self.metadata = metadata
+
+ def to_serializable(self) -> str:
+ """Return a serializable representation of the VectorEntry."""
+ metadata = self.metadata
+
+ for key in metadata:
+ if isinstance(metadata[key], UUID):
+ metadata[key] = str(metadata[key])
+ return {
+ "id": str(self.id),
+ "vector": self.vector.data,
+ "metadata": metadata,
+ }
+
+ def __str__(self) -> str:
+ """Return a string representation of the VectorEntry."""
+ return f"VectorEntry(id={self.id}, vector={self.vector}, metadata={self.metadata})"
+
+ def __repr__(self) -> str:
+ """Return an unambiguous string representation of the VectorEntry."""
+ return f"VectorEntry(id={self.id}, vector={self.vector}, metadata={self.metadata})"
diff --git a/R2R/r2r/base/logging/__init__.py b/R2R/r2r/base/logging/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/base/logging/__init__.py
diff --git a/R2R/r2r/base/logging/kv_logger.py b/R2R/r2r/base/logging/kv_logger.py
new file mode 100755
index 00000000..2d444e9f
--- /dev/null
+++ b/R2R/r2r/base/logging/kv_logger.py
@@ -0,0 +1,547 @@
+import json
+import logging
+import os
+import uuid
+from abc import abstractmethod
+from datetime import datetime
+from typing import Optional
+
+import asyncpg
+from pydantic import BaseModel
+
+from ..providers.base_provider import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class RunInfo(BaseModel):
+ run_id: uuid.UUID
+ log_type: str
+
+
+class LoggingConfig(ProviderConfig):
+ provider: str = "local"
+ log_table: str = "logs"
+ log_info_table: str = "logs_pipeline_info"
+ logging_path: Optional[str] = None
+
+ def validate(self) -> None:
+ pass
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["local", "postgres", "redis"]
+
+
+class KVLoggingProvider(Provider):
+ @abstractmethod
+ async def close(self):
+ pass
+
+ @abstractmethod
+ async def log(self, log_id: uuid.UUID, key: str, value: str):
+ pass
+
+ @abstractmethod
+ async def get_run_info(
+ self,
+ limit: int = 10,
+ log_type_filter: Optional[str] = None,
+ ) -> list[RunInfo]:
+ pass
+
+ @abstractmethod
+ async def get_logs(
+ self, run_ids: list[uuid.UUID], limit_per_run: int
+ ) -> list:
+ pass
+
+
+class LocalKVLoggingProvider(KVLoggingProvider):
+ def __init__(self, config: LoggingConfig):
+ self.log_table = config.log_table
+ self.log_info_table = config.log_info_table
+ self.logging_path = config.logging_path or os.getenv(
+ "LOCAL_DB_PATH", "local.sqlite"
+ )
+ if not self.logging_path:
+ raise ValueError(
+ "Please set the environment variable LOCAL_DB_PATH."
+ )
+ self.conn = None
+ try:
+ import aiosqlite
+
+ self.aiosqlite = aiosqlite
+ except ImportError:
+ raise ImportError(
+ "Please install aiosqlite to use the LocalKVLoggingProvider."
+ )
+
+ async def init(self):
+ self.conn = await self.aiosqlite.connect(self.logging_path)
+ await self.conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS {self.log_table} (
+ timestamp DATETIME,
+ log_id TEXT,
+ key TEXT,
+ value TEXT
+ )
+ """
+ )
+ await self.conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS {self.log_info_table} (
+ timestamp DATETIME,
+ log_id TEXT UNIQUE,
+ log_type TEXT
+ )
+ """
+ )
+ await self.conn.commit()
+
+ async def __aenter__(self):
+ if self.conn is None:
+ await self.init()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
+
+ async def close(self):
+ if self.conn:
+ await self.conn.close()
+ self.conn = None
+
+ async def log(
+ self,
+ log_id: uuid.UUID,
+ key: str,
+ value: str,
+ is_info_log=False,
+ ):
+ collection = self.log_info_table if is_info_log else self.log_table
+
+ if is_info_log:
+ if "type" not in key:
+ raise ValueError("Info log keys must contain the text 'type'")
+ await self.conn.execute(
+ f"INSERT INTO {collection} (timestamp, log_id, log_type) VALUES (datetime('now'), ?, ?)",
+ (str(log_id), value),
+ )
+ else:
+ await self.conn.execute(
+ f"INSERT INTO {collection} (timestamp, log_id, key, value) VALUES (datetime('now'), ?, ?, ?)",
+ (str(log_id), key, value),
+ )
+ await self.conn.commit()
+
+ async def get_run_info(
+ self, limit: int = 10, log_type_filter: Optional[str] = None
+ ) -> list[RunInfo]:
+ cursor = await self.conn.cursor()
+ query = f'SELECT log_id, log_type FROM "{self.log_info_table}"'
+ conditions = []
+ params = []
+ if log_type_filter:
+ conditions.append("log_type = ?")
+ params.append(log_type_filter)
+ if conditions:
+ query += " WHERE " + " AND ".join(conditions)
+ query += " ORDER BY timestamp DESC LIMIT ?"
+ params.append(limit)
+ await cursor.execute(query, params)
+ rows = await cursor.fetchall()
+ return [
+ RunInfo(run_id=uuid.UUID(row[0]), log_type=row[1]) for row in rows
+ ]
+
+ async def get_logs(
+ self, run_ids: list[uuid.UUID], limit_per_run: int = 10
+ ) -> list:
+ if not run_ids:
+ raise ValueError("No run ids provided.")
+ cursor = await self.conn.cursor()
+ placeholders = ",".join(["?" for _ in run_ids])
+ query = f"""
+ SELECT *
+ FROM (
+ SELECT *, ROW_NUMBER() OVER (PARTITION BY log_id ORDER BY timestamp DESC) as rn
+ FROM {self.log_table}
+ WHERE log_id IN ({placeholders})
+ )
+ WHERE rn <= ?
+ ORDER BY timestamp DESC
+ """
+ params = [str(ele) for ele in run_ids] + [limit_per_run]
+ await cursor.execute(query, params)
+ rows = await cursor.fetchall()
+ new_rows = []
+ for row in rows:
+ new_rows.append(
+ (row[0], uuid.UUID(row[1]), row[2], row[3], row[4])
+ )
+ return [
+ {desc[0]: row[i] for i, desc in enumerate(cursor.description)}
+ for row in new_rows
+ ]
+
+
+class PostgresLoggingConfig(LoggingConfig):
+ provider: str = "postgres"
+ log_table: str = "logs"
+ log_info_table: str = "logs_pipeline_info"
+
+ def validate(self) -> None:
+ required_env_vars = [
+ "POSTGRES_DBNAME",
+ "POSTGRES_USER",
+ "POSTGRES_PASSWORD",
+ "POSTGRES_HOST",
+ "POSTGRES_PORT",
+ ]
+ for var in required_env_vars:
+ if not os.getenv(var):
+ raise ValueError(f"Environment variable {var} is not set.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["postgres"]
+
+
+class PostgresKVLoggingProvider(KVLoggingProvider):
+ def __init__(self, config: PostgresLoggingConfig):
+ self.log_table = config.log_table
+ self.log_info_table = config.log_info_table
+ self.config = config
+ self.pool = None
+ if not os.getenv("POSTGRES_DBNAME"):
+ raise ValueError(
+ "Please set the environment variable POSTGRES_DBNAME."
+ )
+ if not os.getenv("POSTGRES_USER"):
+ raise ValueError(
+ "Please set the environment variable POSTGRES_USER."
+ )
+ if not os.getenv("POSTGRES_PASSWORD"):
+ raise ValueError(
+ "Please set the environment variable POSTGRES_PASSWORD."
+ )
+ if not os.getenv("POSTGRES_HOST"):
+ raise ValueError(
+ "Please set the environment variable POSTGRES_HOST."
+ )
+ if not os.getenv("POSTGRES_PORT"):
+ raise ValueError(
+ "Please set the environment variable POSTGRES_PORT."
+ )
+
+ async def init(self):
+ self.pool = await asyncpg.create_pool(
+ database=os.getenv("POSTGRES_DBNAME"),
+ user=os.getenv("POSTGRES_USER"),
+ password=os.getenv("POSTGRES_PASSWORD"),
+ host=os.getenv("POSTGRES_HOST"),
+ port=os.getenv("POSTGRES_PORT"),
+ statement_cache_size=0, # Disable statement caching
+ )
+ async with self.pool.acquire() as conn:
+ await conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS "{self.log_table}" (
+ timestamp TIMESTAMPTZ,
+ log_id UUID,
+ key TEXT,
+ value TEXT
+ )
+ """
+ )
+ await conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS "{self.log_info_table}" (
+ timestamp TIMESTAMPTZ,
+ log_id UUID UNIQUE,
+ log_type TEXT
+ )
+ """
+ )
+
+ async def __aenter__(self):
+ if self.pool is None:
+ await self.init()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
+
+ async def close(self):
+ if self.pool:
+ await self.pool.close()
+ self.pool = None
+
+ async def log(
+ self,
+ log_id: uuid.UUID,
+ key: str,
+ value: str,
+ is_info_log=False,
+ ):
+ collection = self.log_info_table if is_info_log else self.log_table
+
+ if is_info_log:
+ if "type" not in key:
+ raise ValueError(
+ "Info log key must contain the string `type`."
+ )
+ async with self.pool.acquire() as conn:
+ await conn.execute(
+ f'INSERT INTO "{collection}" (timestamp, log_id, log_type) VALUES (NOW(), $1, $2)',
+ log_id,
+ value,
+ )
+ else:
+ async with self.pool.acquire() as conn:
+ await conn.execute(
+ f'INSERT INTO "{collection}" (timestamp, log_id, key, value) VALUES (NOW(), $1, $2, $3)',
+ log_id,
+ key,
+ value,
+ )
+
+ async def get_run_info(
+ self, limit: int = 10, log_type_filter: Optional[str] = None
+ ) -> list[RunInfo]:
+ query = f"SELECT log_id, log_type FROM {self.log_info_table}"
+ conditions = []
+ params = []
+ if log_type_filter:
+ conditions.append("log_type = $1")
+ params.append(log_type_filter)
+ if conditions:
+ query += " WHERE " + " AND ".join(conditions)
+ query += " ORDER BY timestamp DESC LIMIT $2"
+ params.append(limit)
+ async with self.pool.acquire() as conn:
+ rows = await conn.fetch(query, *params)
+ return [
+ RunInfo(run_id=row["log_id"], log_type=row["log_type"])
+ for row in rows
+ ]
+
+ async def get_logs(
+ self, run_ids: list[uuid.UUID], limit_per_run: int = 10
+ ) -> list:
+ if not run_ids:
+ raise ValueError("No run ids provided.")
+
+ placeholders = ",".join([f"${i + 1}" for i in range(len(run_ids))])
+ query = f"""
+ SELECT * FROM (
+ SELECT *, ROW_NUMBER() OVER (PARTITION BY log_id ORDER BY timestamp DESC) as rn
+ FROM "{self.log_table}"
+ WHERE log_id::text IN ({placeholders})
+ ) sub
+ WHERE sub.rn <= ${len(run_ids) + 1}
+ ORDER BY sub.timestamp DESC
+ """
+ params = [str(run_id) for run_id in run_ids] + [limit_per_run]
+ async with self.pool.acquire() as conn:
+ rows = await conn.fetch(query, *params)
+ return [{key: row[key] for key in row.keys()} for row in rows]
+
+
+class RedisLoggingConfig(LoggingConfig):
+ provider: str = "redis"
+ log_table: str = "logs"
+ log_info_table: str = "logs_pipeline_info"
+
+ def validate(self) -> None:
+ required_env_vars = ["REDIS_CLUSTER_IP", "REDIS_CLUSTER_PORT"]
+ for var in required_env_vars:
+ if not os.getenv(var):
+ raise ValueError(f"Environment variable {var} is not set.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["redis"]
+
+
+class RedisKVLoggingProvider(KVLoggingProvider):
+ def __init__(self, config: RedisLoggingConfig):
+ logger.info(
+ f"Initializing RedisKVLoggingProvider with config: {config}"
+ )
+
+ if not all(
+ [
+ os.getenv("REDIS_CLUSTER_IP"),
+ os.getenv("REDIS_CLUSTER_PORT"),
+ ]
+ ):
+ raise ValueError(
+ "Please set the environment variables REDIS_CLUSTER_IP and REDIS_CLUSTER_PORT to run `LoggingDatabaseConnection` with `redis`."
+ )
+ try:
+ from redis.asyncio import Redis
+ except ImportError:
+ raise ValueError(
+ "Error, `redis` is not installed. Please install it using `pip install redis`."
+ )
+
+ cluster_ip = os.getenv("REDIS_CLUSTER_IP")
+ port = os.getenv("REDIS_CLUSTER_PORT")
+ self.redis = Redis(host=cluster_ip, port=port, decode_responses=True)
+ self.log_key = config.log_table
+ self.log_info_key = config.log_info_table
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.close()
+
+ async def close(self):
+ await self.redis.close()
+
+ async def log(
+ self,
+ log_id: uuid.UUID,
+ key: str,
+ value: str,
+ is_info_log=False,
+ ):
+ timestamp = datetime.now().timestamp()
+ log_entry = {
+ "timestamp": timestamp,
+ "log_id": str(log_id),
+ "key": key,
+ "value": value,
+ }
+ if is_info_log:
+ if "type" not in key:
+ raise ValueError("Metadata keys must contain the text 'type'")
+ log_entry["log_type"] = value
+ await self.redis.hset(
+ self.log_info_key, str(log_id), json.dumps(log_entry)
+ )
+ await self.redis.zadd(
+ f"{self.log_info_key}_sorted", {str(log_id): timestamp}
+ )
+ else:
+ await self.redis.lpush(
+ f"{self.log_key}:{str(log_id)}", json.dumps(log_entry)
+ )
+
+ async def get_run_info(
+ self, limit: int = 10, log_type_filter: Optional[str] = None
+ ) -> list[RunInfo]:
+ run_info_list = []
+ start = 0
+ count_per_batch = 100 # Adjust batch size as needed
+
+ while len(run_info_list) < limit:
+ log_ids = await self.redis.zrevrange(
+ f"{self.log_info_key}_sorted",
+ start,
+ start + count_per_batch - 1,
+ )
+ if not log_ids:
+ break # No more log IDs to process
+
+ start += count_per_batch
+
+ for log_id in log_ids:
+ log_entry = json.loads(
+ await self.redis.hget(self.log_info_key, log_id)
+ )
+ if log_type_filter:
+ if log_entry["log_type"] == log_type_filter:
+ run_info_list.append(
+ RunInfo(
+ run_id=uuid.UUID(log_entry["log_id"]),
+ log_type=log_entry["log_type"],
+ )
+ )
+ else:
+ run_info_list.append(
+ RunInfo(
+ run_id=uuid.UUID(log_entry["log_id"]),
+ log_type=log_entry["log_type"],
+ )
+ )
+
+ if len(run_info_list) >= limit:
+ break
+
+ return run_info_list[:limit]
+
+ async def get_logs(
+ self, run_ids: list[uuid.UUID], limit_per_run: int = 10
+ ) -> list:
+ logs = []
+ for run_id in run_ids:
+ raw_logs = await self.redis.lrange(
+ f"{self.log_key}:{str(run_id)}", 0, limit_per_run - 1
+ )
+ for raw_log in raw_logs:
+ json_log = json.loads(raw_log)
+ json_log["log_id"] = uuid.UUID(json_log["log_id"])
+ logs.append(json_log)
+ return logs
+
+
+class KVLoggingSingleton:
+ _instance = None
+ _is_configured = False
+
+ SUPPORTED_PROVIDERS = {
+ "local": LocalKVLoggingProvider,
+ "postgres": PostgresKVLoggingProvider,
+ "redis": RedisKVLoggingProvider,
+ }
+
+ @classmethod
+ def get_instance(cls):
+ return cls.SUPPORTED_PROVIDERS[cls._config.provider](cls._config)
+
+ @classmethod
+ def configure(
+ cls, logging_config: Optional[LoggingConfig] = LoggingConfig()
+ ):
+ if not cls._is_configured:
+ cls._config = logging_config
+ cls._is_configured = True
+ else:
+ raise Exception("KVLoggingSingleton is already configured.")
+
+ @classmethod
+ async def log(
+ cls,
+ log_id: uuid.UUID,
+ key: str,
+ value: str,
+ is_info_log=False,
+ ):
+ try:
+ async with cls.get_instance() as provider:
+ await provider.log(log_id, key, value, is_info_log=is_info_log)
+
+ except Exception as e:
+ logger.error(f"Error logging data {(log_id, key, value)}: {e}")
+
+ @classmethod
+ async def get_run_info(
+ cls, limit: int = 10, log_type_filter: Optional[str] = None
+ ) -> list[RunInfo]:
+ async with cls.get_instance() as provider:
+ return await provider.get_run_info(
+ limit, log_type_filter=log_type_filter
+ )
+
+ @classmethod
+ async def get_logs(
+ cls, run_ids: list[uuid.UUID], limit_per_run: int = 10
+ ) -> list:
+ async with cls.get_instance() as provider:
+ return await provider.get_logs(run_ids, limit_per_run)
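
End to end, the singleton is configured once per process and then used as an async classmethod facade over whichever provider the config names. A sketch with the default local provider, which assumes `aiosqlite` is installed and writes to `local.sqlite` unless `LOCAL_DB_PATH` is set:

    import asyncio
    import uuid

    from r2r.base.logging.kv_logger import KVLoggingSingleton


    async def main():
        KVLoggingSingleton.configure()  # defaults to the local SQLite provider
        run_id = uuid.uuid4()
        # Info-log keys must contain "type"; they land in the run-info table.
        await KVLoggingSingleton.log(run_id, "pipeline_type", "search", is_info_log=True)
        await KVLoggingSingleton.log(run_id, "query", "what is RAG?")
        print(await KVLoggingSingleton.get_logs([run_id]))


    asyncio.run(main())
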
diff --git a/R2R/r2r/base/logging/log_processor.py b/R2R/r2r/base/logging/log_processor.py
new file mode 100755
index 00000000..e85d8de2
--- /dev/null
+++ b/R2R/r2r/base/logging/log_processor.py
@@ -0,0 +1,196 @@
+import contextlib
+import json
+import logging
+import statistics
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Optional, Sequence
+
+from pydantic import BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class FilterCriteria(BaseModel):
+ filters: Optional[dict[str, str]] = None
+
+
+class LogProcessor:
+ timestamp_format = "%Y-%m-%d %H:%M:%S"
+
+ def __init__(self, filters: Dict[str, Callable[[Dict[str, Any]], bool]]):
+ self.filters = filters
+ self.populations = {name: [] for name in filters}
+
+ def process_log(self, log: Dict[str, Any]):
+ for name, filter_func in self.filters.items():
+ if filter_func(log):
+ self.populations[name].append(log)
+
+
+class StatisticsCalculator:
+ @staticmethod
+ def calculate_statistics(
+ population: List[Dict[str, Any]],
+ stat_functions: Dict[str, Callable[[List[Dict[str, Any]]], Any]],
+ ) -> Dict[str, Any]:
+ return {
+ name: func(population) for name, func in stat_functions.items()
+ }
+
+
+class DistributionGenerator:
+ @staticmethod
+ def generate_distributions(
+ population: List[Dict[str, Any]],
+ dist_functions: Dict[str, Callable[[List[Dict[str, Any]]], Any]],
+ ) -> Dict[str, Any]:
+ return {
+ name: func(population) for name, func in dist_functions.items()
+ }
+
+
+class VisualizationPreparer:
+ @staticmethod
+ def prepare_visualization_data(
+ data: Dict[str, Any],
+ vis_functions: Dict[str, Callable[[Dict[str, Any]], Any]],
+ ) -> Dict[str, Any]:
+ return {name: func(data) for name, func in vis_functions.items()}
+
+
+class LogAnalyticsConfig:
+ def __init__(self, filters, stat_functions, dist_functions, vis_functions):
+ self.filters = filters
+ self.stat_functions = stat_functions
+ self.dist_functions = dist_functions
+ self.vis_functions = vis_functions
+
+
+class AnalysisTypes(BaseModel):
+ analysis_types: Optional[dict[str, Sequence[str]]] = None
+
+ @staticmethod
+ def generate_bar_chart_data(logs, key):
+ chart_data = {"labels": [], "datasets": []}
+ value_counts = defaultdict(int)
+
+ for log in logs:
+ if "entries" in log:
+ for entry in log["entries"]:
+ if entry["key"] == key:
+ value_counts[entry["value"]] += 1
+ elif "key" in log and log["key"] == key:
+ value_counts[log["value"]] += 1
+
+ for value, count in value_counts.items():
+ chart_data["labels"].append(value)
+ chart_data["datasets"].append({"label": key, "data": [count]})
+
+ return chart_data
+
+ @staticmethod
+ def calculate_basic_statistics(logs, key):
+ values = []
+ for log in logs:
+ if log["key"] == "search_results":
+ results = json.loads(log["value"])
+ scores = [
+ float(json.loads(result)["score"]) for result in results
+ ]
+ values.extend(scores)
+ else:
+ value = log.get("value")
+ if value is not None:
+ with contextlib.suppress(ValueError):
+ values.append(float(value))
+
+ if not values:
+ return {
+ "Mean": None,
+ "Median": None,
+ "Mode": None,
+ "Standard Deviation": None,
+ "Variance": None,
+ }
+
+ if len(values) == 1:
+ single_value = round(values[0], 3)
+ return {
+ "Mean": single_value,
+ "Median": single_value,
+ "Mode": single_value,
+ "Standard Deviation": 0,
+ "Variance": 0,
+ }
+
+ mean = round(sum(values) / len(values), 3)
+ median = round(statistics.median(values), 3)
+ mode = (
+ round(statistics.mode(values), 3)
+ if len(set(values)) != len(values)
+ else None
+ )
+ std_dev = round(statistics.stdev(values) if len(values) > 1 else 0, 3)
+ variance = round(
+ statistics.variance(values) if len(values) > 1 else 0, 3
+ )
+
+ return {
+ "Mean": mean,
+ "Median": median,
+ "Mode": mode,
+ "Standard Deviation": std_dev,
+ "Variance": variance,
+ }
+
+ @staticmethod
+ def calculate_percentile(logs, key, percentile):
+ values = []
+ for log in logs:
+ if log["key"] == key:
+ value = log.get("value")
+ if value is not None:
+ with contextlib.suppress(ValueError):
+ values.append(float(value))
+
+ if not values:
+ return {"percentile": percentile, "value": None}
+
+ values.sort()
+ index = int((percentile / 100) * (len(values) - 1))
+ return {"percentile": percentile, "value": round(values[index], 3)}
+
+
+class LogAnalytics:
+ def __init__(self, logs: List[Dict[str, Any]], config: LogAnalyticsConfig):
+ self.logs = logs
+ self.log_processor = LogProcessor(config.filters)
+ self.statistics_calculator = StatisticsCalculator()
+ self.distribution_generator = DistributionGenerator()
+ self.visualization_preparer = VisualizationPreparer()
+ self.config = config
+
+ def count_logs(self) -> Dict[str, Any]:
+ """Count the logs for each filter."""
+ return {
+ name: len(population)
+ for name, population in self.log_processor.populations.items()
+ }
+
+ def process_logs(self) -> Dict[str, Any]:
+ for log in self.logs:
+ self.log_processor.process_log(log)
+
+ analytics = {}
+ for name, population in self.log_processor.populations.items():
+ stats = self.statistics_calculator.calculate_statistics(
+ population, self.config.stat_functions
+ )
+ dists = self.distribution_generator.generate_distributions(
+ population, self.config.dist_functions
+ )
+ analytics[name] = {"statistics": stats, "distributions": dists}
+
+ return self.visualization_preparer.prepare_visualization_data(
+ analytics, self.config.vis_functions
+ )
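
`LogAnalytics` is driven entirely by the callables in `LogAnalyticsConfig`: the filters split raw logs into named populations, the per-population statistic and distribution functions run over each, and the visualization functions receive the combined result. A self-contained sketch with made-up latency logs:

    from r2r.base.logging.log_processor import (
        AnalysisTypes,
        LogAnalytics,
        LogAnalyticsConfig,
    )

    logs = [
        {"key": "search_latency", "value": "0.42"},
        {"key": "search_latency", "value": "0.58"},
        {"key": "completion_record", "value": "ok"},
    ]
    config = LogAnalyticsConfig(
        filters={"latency": lambda log: log["key"] == "search_latency"},
        stat_functions={
            "basic": lambda population: AnalysisTypes.calculate_basic_statistics(
                population, "search_latency"
            )
        },
        dist_functions={},
        vis_functions={"raw": lambda data: data},
    )
    print(LogAnalytics(logs, config).process_logs())
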
diff --git a/R2R/r2r/base/logging/run_manager.py b/R2R/r2r/base/logging/run_manager.py
new file mode 100755
index 00000000..ac192bca
--- /dev/null
+++ b/R2R/r2r/base/logging/run_manager.py
@@ -0,0 +1,56 @@
+import contextvars
+import uuid
+from contextlib import asynccontextmanager
+from typing import Any
+
+from .kv_logger import KVLoggingSingleton
+
+run_id_var = contextvars.ContextVar("run_id", default=None)
+
+
+class RunManager:
+ def __init__(self, logger: KVLoggingSingleton):
+ self.logger = logger
+ self.run_info = {}
+
+ def generate_run_id(self) -> uuid.UUID:
+ return uuid.uuid4()
+
+ async def set_run_info(self, pipeline_type: str):
+ run_id = run_id_var.get()
+ if run_id is None:
+ run_id = self.generate_run_id()
+ token = run_id_var.set(run_id)
+ self.run_info[run_id] = {"pipeline_type": pipeline_type}
+ else:
+ token = run_id_var.set(run_id)
+ return run_id, token
+
+ async def get_run_info(self):
+ run_id = run_id_var.get()
+ return self.run_info.get(run_id, None)
+
+ async def log_run_info(
+ self, key: str, value: Any, is_info_log: bool = False
+ ):
+ run_id = run_id_var.get()
+ if run_id:
+ await self.logger.log(
+ log_id=run_id, key=key, value=value, is_info_log=is_info_log
+ )
+
+ async def clear_run_info(self, token: contextvars.Token):
+ run_id = run_id_var.get()
+ run_id_var.reset(token)
+ if run_id and run_id in self.run_info:
+ del self.run_info[run_id]
+
+
+@asynccontextmanager
+async def manage_run(run_manager: RunManager, pipeline_type: str):
+ run_id, token = await run_manager.set_run_info(pipeline_type)
+ try:
+ yield run_id
+ finally:
+ # Note: Do not clear the run info to ensure the run ID remains the same
+ run_id_var.reset(token)
diff --git a/R2R/r2r/base/parsers/__init__.py b/R2R/r2r/base/parsers/__init__.py
new file mode 100755
index 00000000..d7696202
--- /dev/null
+++ b/R2R/r2r/base/parsers/__init__.py
@@ -0,0 +1,5 @@
+from .base_parser import AsyncParser
+
+__all__ = [
+ "AsyncParser",
+]
diff --git a/R2R/r2r/base/parsers/base_parser.py b/R2R/r2r/base/parsers/base_parser.py
new file mode 100755
index 00000000..f1bb49d7
--- /dev/null
+++ b/R2R/r2r/base/parsers/base_parser.py
@@ -0,0 +1,14 @@
+"""Abstract base class for parsers."""
+
+from abc import ABC, abstractmethod
+from typing import AsyncGenerator, Generic, TypeVar
+
+from ..abstractions.document import DataType
+
+T = TypeVar("T")
+
+
+class AsyncParser(ABC, Generic[T]):
+ @abstractmethod
+ async def ingest(self, data: T) -> AsyncGenerator[DataType, None]:
+ pass
diff --git a/R2R/r2r/base/pipeline/__init__.py b/R2R/r2r/base/pipeline/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/base/pipeline/__init__.py
diff --git a/R2R/r2r/base/pipeline/base_pipeline.py b/R2R/r2r/base/pipeline/base_pipeline.py
new file mode 100755
index 00000000..3c1eff9a
--- /dev/null
+++ b/R2R/r2r/base/pipeline/base_pipeline.py
@@ -0,0 +1,233 @@
+"""Base pipeline class for running a sequence of pipes."""
+
+import asyncio
+import logging
+from enum import Enum
+from typing import Any, AsyncGenerator, Optional
+
+from ..logging.kv_logger import KVLoggingSingleton
+from ..logging.run_manager import RunManager, manage_run
+from ..pipes.base_pipe import AsyncPipe, AsyncState
+
+logger = logging.getLogger(__name__)
+
+
+class PipelineTypes(Enum):
+ EVAL = "eval"
+ INGESTION = "ingestion"
+ SEARCH = "search"
+ RAG = "rag"
+ OTHER = "other"
+
+
+class AsyncPipeline:
+ """Pipeline class for running a sequence of pipes."""
+
+ pipeline_type: str = "other"
+
+ def __init__(
+ self,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ run_manager: Optional[RunManager] = None,
+ ):
+ self.pipes: list[AsyncPipe] = []
+ self.upstream_outputs: list[list[dict[str, str]]] = []
+ self.pipe_logger = pipe_logger or KVLoggingSingleton()
+ self.run_manager = run_manager or RunManager(self.pipe_logger)
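+        # Futures keyed by pipe name; each resolves to that pipe's output generator.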
+ self.futures = {}
+ self.level = 0
+
+ def add_pipe(
+ self,
+ pipe: AsyncPipe,
+ add_upstream_outputs: Optional[list[dict[str, str]]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ """Add a pipe to the pipeline."""
+ self.pipes.append(pipe)
+ if not add_upstream_outputs:
+ add_upstream_outputs = []
+ self.upstream_outputs.append(add_upstream_outputs)
+
+ async def run(
+ self,
+ input: Any,
+ state: Optional[AsyncState] = None,
+ stream: bool = False,
+ run_manager: Optional[RunManager] = None,
+ log_run_info: bool = True,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ """Run the pipeline."""
+ run_manager = run_manager or self.run_manager
+
+ try:
+ PipelineTypes(self.pipeline_type)
+ except ValueError:
+ raise ValueError(
+ f"Invalid pipeline type: {self.pipeline_type}, must be one of {PipelineTypes.__members__.keys()}"
+ )
+
+ self.state = state or AsyncState()
+ current_input = input
+ async with manage_run(run_manager, self.pipeline_type):
+ if log_run_info:
+ await run_manager.log_run_info(
+ key="pipeline_type",
+ value=self.pipeline_type,
+ is_info_log=True,
+ )
+ try:
+ for pipe_num in range(len(self.pipes)):
+ config_name = self.pipes[pipe_num].config.name
+ self.futures[config_name] = asyncio.Future()
+
+ current_input = self._run_pipe(
+ pipe_num,
+ current_input,
+ run_manager,
+ *args,
+ **kwargs,
+ )
+ self.futures[config_name].set_result(current_input)
+ if not stream:
+ final_result = await self._consume_all(current_input)
+ return final_result
+ else:
+ return current_input
+ except Exception as error:
+ logger.error(f"Pipeline failed with error: {error}")
+ raise error
+
+ async def _consume_all(self, gen: AsyncGenerator) -> list[Any]:
+ result = []
+ async for item in gen:
+ if hasattr(
+ item, "__aiter__"
+ ): # Check if the item is an async generator
+ sub_result = await self._consume_all(item)
+ result.extend(sub_result)
+ else:
+ result.append(item)
+ return result
+
+ async def _run_pipe(
+ self,
+ pipe_num: int,
+ input: Any,
+ run_manager: RunManager,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ # Collect inputs, waiting for the necessary futures
+ pipe = self.pipes[pipe_num]
+ add_upstream_outputs = self.sort_upstream_outputs(
+ self.upstream_outputs[pipe_num]
+ )
+ input_dict = {"message": input}
+
+ # Group upstream outputs by prev_pipe_name
+ grouped_upstream_outputs = {}
+ for upstream_input in add_upstream_outputs:
+ upstream_pipe_name = upstream_input["prev_pipe_name"]
+ if upstream_pipe_name not in grouped_upstream_outputs:
+ grouped_upstream_outputs[upstream_pipe_name] = []
+ grouped_upstream_outputs[upstream_pipe_name].append(upstream_input)
+
+ for (
+ upstream_pipe_name,
+ upstream_inputs,
+ ) in grouped_upstream_outputs.items():
+
+ async def resolve_future_output(future):
+ result = future.result()
+ # consume the async generator
+ return [item async for item in result]
+
+ async def replay_items_as_async_gen(items):
+ for item in items:
+ yield item
+
+ temp_results = await resolve_future_output(
+ self.futures[upstream_pipe_name]
+ )
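+            # If this upstream pipe is the immediate predecessor, replay its
+            # already-consumed output as a fresh async generator for the current pipe.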
+ if upstream_pipe_name == self.pipes[pipe_num - 1].config.name:
+ input_dict["message"] = replay_items_as_async_gen(temp_results)
+
+ for upstream_input in upstream_inputs:
+ outputs = await self.state.get(upstream_pipe_name, "output")
+ prev_output_field = upstream_input.get(
+ "prev_output_field", None
+ )
+ if not prev_output_field:
+ raise ValueError(
+ "`prev_output_field` must be specified in the upstream_input"
+ )
+ input_dict[upstream_input["input_field"]] = outputs[
+ prev_output_field
+ ]
+
+ # Handle the pipe generator
+ async for ele in await pipe.run(
+ pipe.Input(**input_dict),
+ self.state,
+ run_manager,
+ *args,
+ **kwargs,
+ ):
+ yield ele
+
+ def sort_upstream_outputs(
+ self, add_upstream_outputs: list[dict[str, str]]
+ ) -> list[dict[str, str]]:
+ pipe_name_to_index = {
+ pipe.config.name: index for index, pipe in enumerate(self.pipes)
+ }
+
+ def get_pipe_index(upstream_output):
+ return pipe_name_to_index[upstream_output["prev_pipe_name"]]
+
+ sorted_outputs = sorted(
+ add_upstream_outputs, key=get_pipe_index, reverse=True
+ )
+ return sorted_outputs
+
+
+class EvalPipeline(AsyncPipeline):
+ """A pipeline for evaluation."""
+
+ pipeline_type: str = "eval"
+
+ async def run(
+ self,
+ input: Any,
+ state: Optional[AsyncState] = None,
+ stream: bool = False,
+ run_manager: Optional[RunManager] = None,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ return await super().run(
+ input, state, stream, run_manager, *args, **kwargs
+ )
+
+ def add_pipe(
+ self,
+ pipe: AsyncPipe,
+ add_upstream_outputs: Optional[list[dict[str, str]]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ logger.debug(f"Adding pipe {pipe.config.name} to the EvalPipeline")
+ return super().add_pipe(pipe, add_upstream_outputs, *args, **kwargs)
+
+
+async def dequeue_requests(queue: asyncio.Queue) -> AsyncGenerator:
+ """Create an async generator to dequeue requests."""
+ while True:
+ request = await queue.get()
+ if request is None:
+ break
+ yield request
diff --git a/R2R/r2r/base/pipes/__init__.py b/R2R/r2r/base/pipes/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/base/pipes/__init__.py
diff --git a/R2R/r2r/base/pipes/base_pipe.py b/R2R/r2r/base/pipes/base_pipe.py
new file mode 100755
index 00000000..63e3d04e
--- /dev/null
+++ b/R2R/r2r/base/pipes/base_pipe.py
@@ -0,0 +1,163 @@
+import asyncio
+import logging
+import uuid
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, AsyncGenerator, Optional
+
+from pydantic import BaseModel
+
+from r2r.base.logging.kv_logger import KVLoggingSingleton
+from r2r.base.logging.run_manager import RunManager, manage_run
+
+logger = logging.getLogger(__name__)
+
+
+class PipeType(Enum):
+ INGESTOR = "ingestor"
+ EVAL = "eval"
+ GENERATOR = "generator"
+ SEARCH = "search"
+ TRANSFORM = "transform"
+ OTHER = "other"
+
+
+class AsyncState:
+ """A state object for storing data between pipes."""
+
+ def __init__(self):
+ self.data = {}
+ self.lock = asyncio.Lock()
+
+ async def update(self, outer_key: str, values: dict):
+ """Update the state with new values."""
+ async with self.lock:
+ if not isinstance(values, dict):
+ raise ValueError("Values must be contained in a dictionary.")
+ if outer_key not in self.data:
+ self.data[outer_key] = {}
+ for inner_key, inner_value in values.items():
+ self.data[outer_key][inner_key] = inner_value
+
+ async def get(self, outer_key: str, inner_key: str, default=None):
+ """Get a value from the state."""
+ async with self.lock:
+ if outer_key not in self.data:
+ raise ValueError(
+ f"Key {outer_key} does not exist in the state."
+ )
+ if inner_key not in self.data[outer_key]:
+ return default or {}
+ return self.data[outer_key][inner_key]
+
+    async def delete(self, outer_key: str, inner_key: Optional[str] = None):
+        """Delete a value from the state."""
+        async with self.lock:
+            if outer_key not in self.data:
+                raise ValueError(
+                    f"Key {outer_key} does not exist in the state."
+                )
+            if not inner_key:
+                del self.data[outer_key]
+            else:
+                if inner_key not in self.data[outer_key]:
+                    raise ValueError(
+                        f"Key {inner_key} does not exist in the state."
+                    )
+                del self.data[outer_key][inner_key]
+
+
+class AsyncPipe:
+ """An asynchronous pipe for processing data with logging capabilities."""
+
+ class PipeConfig(BaseModel):
+ """Configuration for a pipe."""
+
+ name: str = "default_pipe"
+ max_log_queue_size: int = 100
+
+ class Config:
+ extra = "forbid"
+ arbitrary_types_allowed = True
+
+ class Input(BaseModel):
+ """Input for a pipe."""
+
+ message: AsyncGenerator[Any, None]
+
+ class Config:
+ extra = "forbid"
+ arbitrary_types_allowed = True
+
+ def __init__(
+ self,
+ type: PipeType = PipeType.OTHER,
+ config: Optional[PipeConfig] = None,
+ pipe_logger: Optional[KVLoggingSingleton] = None,
+ run_manager: Optional[RunManager] = None,
+ ):
+ self._config = config or self.PipeConfig()
+ self._type = type
+ self.pipe_logger = pipe_logger or KVLoggingSingleton()
+ self.log_queue = asyncio.Queue()
+ self.log_worker_task = None
+ self._run_manager = run_manager or RunManager(self.pipe_logger)
+
+ logger.debug(
+ f"Initialized pipe {self.config.name} of type {self.type}"
+ )
+
+ @property
+ def config(self) -> PipeConfig:
+ return self._config
+
+ @property
+ def type(self) -> PipeType:
+ return self._type
+
+ async def log_worker(self):
+ while True:
+ log_data = await self.log_queue.get()
+ run_id, key, value = log_data
+ await self.pipe_logger.log(run_id, key, value)
+ self.log_queue.task_done()
+
+ async def enqueue_log(self, run_id: uuid.UUID, key: str, value: str):
+ if self.log_queue.qsize() < self.config.max_log_queue_size:
+ await self.log_queue.put((run_id, key, value))
+
+ async def run(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_manager: Optional[RunManager] = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[Any, None]:
+ """Run the pipe with logging capabilities."""
+
+ run_manager = run_manager or self._run_manager
+
+ async def wrapped_run() -> AsyncGenerator[Any, None]:
+ async with manage_run(run_manager, self.config.name) as run_id:
+ self.log_worker_task = asyncio.create_task(
+ self.log_worker(), name=f"log-worker-{self.config.name}"
+ )
+ try:
+ async for result in self._run_logic(
+ input, state, run_id=run_id, *args, **kwargs
+ ):
+ yield result
+ finally:
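+                    # Drain pending log entries, stop the worker, and reset the
+                    # queue so the next run starts clean.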
+ await self.log_queue.join()
+ self.log_worker_task.cancel()
+ self.log_queue = asyncio.Queue()
+
+ return wrapped_run()
+
+ @abstractmethod
+ async def _run_logic(
+ self,
+ input: Input,
+ state: AsyncState,
+ run_id: uuid.UUID,
+ *args: Any,
+ **kwargs: Any,
+ ) -> AsyncGenerator[Any, None]:
+ pass
diff --git a/R2R/r2r/base/providers/__init__.py b/R2R/r2r/base/providers/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/base/providers/__init__.py
diff --git a/R2R/r2r/base/providers/base_provider.py b/R2R/r2r/base/providers/base_provider.py
new file mode 100755
index 00000000..8ee8d56a
--- /dev/null
+++ b/R2R/r2r/base/providers/base_provider.py
@@ -0,0 +1,48 @@
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Type
+
+from pydantic import BaseModel
+
+
+class ProviderConfig(BaseModel, ABC):
+ """A base provider configuration class"""
+
+ extra_fields: dict[str, Any] = {}
+ provider: Optional[str] = None
+
+ class Config:
+ arbitrary_types_allowed = True
+        extra = "ignore"
+
+ @abstractmethod
+ def validate(self) -> None:
+ pass
+
+ @classmethod
+ def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
+ base_args = cls.__fields__.keys()
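+        # Declared fields are passed to the constructor (the literal string
+        # "None" is coerced to None); everything else lands in `extra_fields`.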
+ filtered_kwargs = {
+ k: v if v != "None" else None
+ for k, v in kwargs.items()
+ if k in base_args
+ }
+ instance = cls(**filtered_kwargs)
+ for k, v in kwargs.items():
+ if k not in base_args:
+ instance.extra_fields[k] = v
+ return instance
+
+    @property
+    @abstractmethod
+ def supported_providers(self) -> list[str]:
+ """Define a list of supported providers."""
+ pass
+
+
+class Provider(ABC):
+ """A base provider class to provide a common interface for all providers."""
+
+ def __init__(self, config: Optional[ProviderConfig] = None):
+ if config:
+ config.validate()
+ self.config = config
diff --git a/R2R/r2r/base/providers/embedding_provider.py b/R2R/r2r/base/providers/embedding_provider.py
new file mode 100755
index 00000000..8f3af56f
--- /dev/null
+++ b/R2R/r2r/base/providers/embedding_provider.py
@@ -0,0 +1,83 @@
+import logging
+from abc import abstractmethod
+from enum import Enum
+from typing import Optional
+
+from ..abstractions.search import VectorSearchResult
+from .base_provider import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingConfig(ProviderConfig):
+ """A base embedding configuration class"""
+
+ provider: Optional[str] = None
+ base_model: Optional[str] = None
+ base_dimension: Optional[int] = None
+ rerank_model: Optional[str] = None
+ rerank_dimension: Optional[int] = None
+ rerank_transformer_type: Optional[str] = None
+ batch_size: int = 1
+
+ def validate(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return [None, "openai", "ollama", "sentence-transformers"]
+
+
+class EmbeddingProvider(Provider):
+ """An abstract class to provide a common interface for embedding providers."""
+
+ class PipeStage(Enum):
+ BASE = 1
+ RERANK = 2
+
+ def __init__(self, config: EmbeddingConfig):
+ if not isinstance(config, EmbeddingConfig):
+ raise ValueError(
+ "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
+ )
+ logger.info(f"Initializing EmbeddingProvider with config {config}.")
+
+ super().__init__(config)
+
+ @abstractmethod
+ def get_embedding(self, text: str, stage: PipeStage = PipeStage.BASE):
+ pass
+
+ async def async_get_embedding(
+ self, text: str, stage: PipeStage = PipeStage.BASE
+ ):
+ return self.get_embedding(text, stage)
+
+ @abstractmethod
+ def get_embeddings(
+ self, texts: list[str], stage: PipeStage = PipeStage.BASE
+ ):
+ pass
+
+ async def async_get_embeddings(
+ self, texts: list[str], stage: PipeStage = PipeStage.BASE
+ ):
+ return self.get_embeddings(texts, stage)
+
+ @abstractmethod
+ def rerank(
+ self,
+ query: str,
+ results: list[VectorSearchResult],
+ stage: PipeStage = PipeStage.RERANK,
+ limit: int = 10,
+ ):
+ pass
+
+ @abstractmethod
+ def tokenize_string(
+ self, text: str, model: str, stage: PipeStage
+ ) -> list[int]:
+ """Tokenizes the input string."""
+ pass
diff --git a/R2R/r2r/base/providers/eval_provider.py b/R2R/r2r/base/providers/eval_provider.py
new file mode 100755
index 00000000..76053f87
--- /dev/null
+++ b/R2R/r2r/base/providers/eval_provider.py
@@ -0,0 +1,46 @@
+from typing import Optional, Union
+
+from ..abstractions.llm import GenerationConfig
+from .base_provider import Provider, ProviderConfig
+from .llm_provider import LLMConfig
+
+
+class EvalConfig(ProviderConfig):
+ """A base eval config class"""
+
+ llm: Optional[LLMConfig] = None
+
+ def validate(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider {self.provider} not supported.")
+ if self.provider and not self.llm:
+ raise ValueError(
+ "EvalConfig must have a `llm` attribute when specifying a provider."
+ )
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return [None, "local"]
+
+
+class EvalProvider(Provider):
+ """An abstract class to provide a common interface for evaluation providers."""
+
+ def __init__(self, config: EvalConfig):
+ if not isinstance(config, EvalConfig):
+ raise ValueError(
+ "EvalProvider must be initialized with a `EvalConfig`."
+ )
+
+ super().__init__(config)
+
+ def evaluate(
+ self,
+ query: str,
+ context: str,
+ completion: str,
+ eval_generation_config: Optional[GenerationConfig] = None,
+ ) -> dict[str, dict[str, Union[str, float]]]:
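+        # Delegates to `_evaluate`, which concrete eval providers are expected
+        # to implement.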
+ return self._evaluate(
+ query, context, completion, eval_generation_config
+ )
diff --git a/R2R/r2r/base/providers/kg_provider.py b/R2R/r2r/base/providers/kg_provider.py
new file mode 100755
index 00000000..4ae96b11
--- /dev/null
+++ b/R2R/r2r/base/providers/kg_provider.py
@@ -0,0 +1,182 @@
+"""Base classes for knowledge graph providers."""
+
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Optional, Tuple
+
+from .prompt_provider import PromptProvider
+
+if TYPE_CHECKING:
+ from r2r.main import R2RClient
+
+from ...base.utils.base_utils import EntityType, Relation
+from ..abstractions.llama_abstractions import EntityNode, LabelledNode
+from ..abstractions.llama_abstractions import Relation as LlamaRelation
+from ..abstractions.llama_abstractions import VectorStoreQuery
+from ..abstractions.llm import GenerationConfig
+from .base_provider import ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class KGConfig(ProviderConfig):
+ """A base KG config class"""
+
+ provider: Optional[str] = None
+ batch_size: int = 1
+ kg_extraction_prompt: Optional[str] = "few_shot_ner_kg_extraction"
+ kg_agent_prompt: Optional[str] = "kg_agent"
+ kg_extraction_config: Optional[GenerationConfig] = None
+
+ def validate(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return [None, "neo4j"]
+
+
+class KGProvider(ABC):
+ """An abstract class to provide a common interface for Knowledge Graphs."""
+
+ def __init__(self, config: KGConfig) -> None:
+ if not isinstance(config, KGConfig):
+ raise ValueError(
+ "KGProvider must be initialized with a `KGConfig`."
+ )
+ logger.info(f"Initializing KG provider with config: {config}")
+ self.config = config
+ self.validate_config()
+
+ def validate_config(self) -> None:
+ self.config.validate()
+
+ @property
+ @abstractmethod
+ def client(self) -> Any:
+ """Get client."""
+ pass
+
+ @abstractmethod
+ def get(self, subj: str) -> list[list[str]]:
+ """Abstract method to get triplets."""
+ pass
+
+ @abstractmethod
+ def get_rel_map(
+ self,
+ subjs: Optional[list[str]] = None,
+ depth: int = 2,
+ limit: int = 30,
+ ) -> dict[str, list[list[str]]]:
+ """Abstract method to get depth-aware rel map."""
+ pass
+
+ @abstractmethod
+ def upsert_nodes(self, nodes: list[EntityNode]) -> None:
+ """Abstract method to add triplet."""
+ pass
+
+ @abstractmethod
+ def upsert_relations(self, relations: list[LlamaRelation]) -> None:
+ """Abstract method to add triplet."""
+ pass
+
+ @abstractmethod
+ def delete(self, subj: str, rel: str, obj: str) -> None:
+ """Abstract method to delete triplet."""
+ pass
+
+ @abstractmethod
+ def get_schema(self, refresh: bool = False) -> str:
+ """Abstract method to get the schema of the graph store."""
+ pass
+
+ @abstractmethod
+ def structured_query(
+ self, query: str, param_map: Optional[dict[str, Any]] = {}
+ ) -> Any:
+ """Abstract method to query the graph store with statement and parameters."""
+ pass
+
+ @abstractmethod
+ def vector_query(
+ self, query: VectorStoreQuery, **kwargs: Any
+ ) -> Tuple[list[LabelledNode], list[float]]:
+ """Abstract method to query the graph store with a vector store query."""
+
+ # TODO - Type this method.
+ @abstractmethod
+ def update_extraction_prompt(
+ self,
+ prompt_provider: Any,
+ entity_types: list[Any],
+ relations: list[Relation],
+ ):
+ """Abstract method to update the KG extraction prompt."""
+ pass
+
+ # TODO - Type this method.
+ @abstractmethod
+ def update_kg_agent_prompt(
+ self,
+ prompt_provider: Any,
+ entity_types: list[Any],
+ relations: list[Relation],
+ ):
+ """Abstract method to update the KG agent prompt."""
+ pass
+
+
+def escape_braces(s: str) -> str:
+ """
+ Escape braces in a string.
+ This is a placeholder function - implement the actual logic as needed.
+ """
+ # Implement your escape_braces logic here
+ return s.replace("{", "{{").replace("}", "}}")
+
+
+# TODO - Make this more configurable / intelligent
+def update_kg_prompt(
+ client: "R2RClient",
+ r2r_prompts: PromptProvider,
+ prompt_base: str,
+ entity_types: list[EntityType],
+ relations: list[Relation],
+) -> None:
+ # Get the default extraction template
+ template_name: str = f"{prompt_base}_with_spec"
+
+ new_template: str = r2r_prompts.get_prompt(
+ template_name,
+ {
+ "entity_types": json.dumps(
+ {
+ "entity_types": [
+ str(entity.name) for entity in entity_types
+ ]
+ },
+ indent=4,
+ ),
+ "relations": json.dumps(
+ {"predicates": [str(relation.name) for relation in relations]},
+ indent=4,
+ ),
+ "input": """\n{input}""",
+ },
+ )
+
+ # Escape all braces in the template, except for the {input} placeholder, for formatting
+ escaped_template: str = escape_braces(new_template).replace(
+ """{{input}}""", """{input}"""
+ )
+
+ # Update the client's prompt
+ client.update_prompt(
+ prompt_base,
+ template=escaped_template,
+ input_types={"input": "str"},
+ )
diff --git a/R2R/r2r/base/providers/llm_provider.py b/R2R/r2r/base/providers/llm_provider.py
new file mode 100755
index 00000000..9b6499a4
--- /dev/null
+++ b/R2R/r2r/base/providers/llm_provider.py
@@ -0,0 +1,66 @@
+"""Base classes for language model providers."""
+
+import logging
+from abc import abstractmethod
+from typing import Optional
+
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.llm import LLMChatCompletion, LLMChatCompletionChunk
+from .base_provider import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class LLMConfig(ProviderConfig):
+ """A base LLM config class"""
+
+ provider: Optional[str] = None
+ generation_config: Optional[GenerationConfig] = None
+
+ def validate(self) -> None:
+ if not self.provider:
+ raise ValueError("Provider must be set.")
+
+ if self.provider and self.provider not in self.supported_providers:
+ raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["litellm", "openai"]
+
+
+class LLMProvider(Provider):
+ """An abstract class to provide a common interface for LLMs."""
+
+ def __init__(
+ self,
+ config: LLMConfig,
+ ) -> None:
+ if not isinstance(config, LLMConfig):
+ raise ValueError(
+ "LLMProvider must be initialized with a `LLMConfig`."
+ )
+ logger.info(f"Initializing LLM provider with config: {config}")
+
+ super().__init__(config)
+
+ @abstractmethod
+ def get_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletion:
+ """Abstract method to get a chat completion from the provider."""
+ pass
+
+ @abstractmethod
+ def get_completion_stream(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletionChunk:
+ """Abstract method to get a completion stream from the provider."""
+ pass
diff --git a/R2R/r2r/base/providers/prompt_provider.py b/R2R/r2r/base/providers/prompt_provider.py
new file mode 100755
index 00000000..78af9e11
--- /dev/null
+++ b/R2R/r2r/base/providers/prompt_provider.py
@@ -0,0 +1,65 @@
+import logging
+from abc import abstractmethod
+from typing import Any, Optional
+
+from .base_provider import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class PromptConfig(ProviderConfig):
+ def validate(self) -> None:
+ pass
+
+ @property
+ def supported_providers(self) -> list[str]:
+ # Return a list of supported prompt providers
+ return ["default_prompt_provider"]
+
+
+class PromptProvider(Provider):
+ def __init__(self, config: Optional[PromptConfig] = None):
+ if config is None:
+ config = PromptConfig()
+ elif not isinstance(config, PromptConfig):
+ raise ValueError(
+ "PromptProvider must be initialized with a `PromptConfig`."
+ )
+ logger.info(f"Initializing PromptProvider with config {config}.")
+ super().__init__(config)
+
+ @abstractmethod
+ def add_prompt(
+ self, name: str, template: str, input_types: dict[str, str]
+ ) -> None:
+ pass
+
+ @abstractmethod
+ def get_prompt(
+ self, prompt_name: str, inputs: Optional[dict[str, Any]] = None
+ ) -> str:
+ pass
+
+ @abstractmethod
+ def get_all_prompts(self) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ def update_prompt(
+ self,
+ name: str,
+ template: Optional[str] = None,
+ input_types: Optional[dict[str, str]] = None,
+ ) -> None:
+ pass
+
+ def _get_message_payload(
+ self, system_prompt: str, task_prompt: str
+    ) -> list[dict]:
+ return [
+ {
+ "role": "system",
+ "content": system_prompt,
+ },
+ {"role": "user", "content": task_prompt},
+ ]
diff --git a/R2R/r2r/base/providers/vector_db_provider.py b/R2R/r2r/base/providers/vector_db_provider.py
new file mode 100755
index 00000000..a6d5aaa8
--- /dev/null
+++ b/R2R/r2r/base/providers/vector_db_provider.py
@@ -0,0 +1,142 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Optional, Union
+
+from ..abstractions.document import DocumentInfo
+from ..abstractions.search import VectorSearchResult
+from ..abstractions.vector import VectorEntry
+from .base_provider import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class VectorDBConfig(ProviderConfig):
+ provider: str
+
+ def __post_init__(self):
+ self.validate()
+ # Capture additional fields
+ for key, value in self.extra_fields.items():
+ setattr(self, key, value)
+
+ def validate(self) -> None:
+ if self.provider not in self.supported_providers:
+ raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["local", "pgvector"]
+
+
+class VectorDBProvider(Provider, ABC):
+ def __init__(self, config: VectorDBConfig):
+ if not isinstance(config, VectorDBConfig):
+ raise ValueError(
+ "VectorDBProvider must be initialized with a `VectorDBConfig`."
+ )
+ logger.info(f"Initializing VectorDBProvider with config {config}.")
+ super().__init__(config)
+
+ @abstractmethod
+ def initialize_collection(self, dimension: int) -> None:
+ pass
+
+ @abstractmethod
+ def copy(self, entry: VectorEntry, commit: bool = True) -> None:
+ pass
+
+ @abstractmethod
+ def upsert(self, entry: VectorEntry, commit: bool = True) -> None:
+ pass
+
+ @abstractmethod
+ def search(
+ self,
+ query_vector: list[float],
+ filters: dict[str, Union[bool, int, str]] = {},
+ limit: int = 10,
+ *args,
+ **kwargs,
+ ) -> list[VectorSearchResult]:
+ pass
+
+ @abstractmethod
+ def hybrid_search(
+ self,
+ query_text: str,
+ query_vector: list[float],
+ limit: int = 10,
+ filters: Optional[dict[str, Union[bool, int, str]]] = None,
+ # Hybrid search parameters
+ full_text_weight: float = 1.0,
+ semantic_weight: float = 1.0,
+ rrf_k: int = 20, # typical value is ~2x the number of results you want
+ *args,
+ **kwargs,
+ ) -> list[VectorSearchResult]:
+ pass
+
+ @abstractmethod
+ def create_index(self, index_type, column_name, index_options):
+ pass
+
+ def upsert_entries(
+ self, entries: list[VectorEntry], commit: bool = True
+ ) -> None:
+ for entry in entries:
+ self.upsert(entry, commit=commit)
+
+ def copy_entries(
+ self, entries: list[VectorEntry], commit: bool = True
+ ) -> None:
+ for entry in entries:
+ self.copy(entry, commit=commit)
+
+ @abstractmethod
+ def delete_by_metadata(
+ self,
+ metadata_fields: list[str],
+ metadata_values: list[Union[bool, int, str]],
+ ) -> list[str]:
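+        # Shared length validation; concrete providers can reuse it via super().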
+ if len(metadata_fields) != len(metadata_values):
+ raise ValueError(
+ "The number of metadata fields and values must be equal."
+ )
+ pass
+
+ @abstractmethod
+ def get_metadatas(
+ self,
+ metadata_fields: list[str],
+ filter_field: Optional[str] = None,
+ filter_value: Optional[str] = None,
+ ) -> list[str]:
+ pass
+
+ @abstractmethod
+ def upsert_documents_overview(
+ self, document_infs: list[DocumentInfo]
+ ) -> None:
+ pass
+
+ @abstractmethod
+ def get_documents_overview(
+ self,
+ filter_document_ids: Optional[list[str]] = None,
+ filter_user_ids: Optional[list[str]] = None,
+ ) -> list[DocumentInfo]:
+ pass
+
+ @abstractmethod
+ def get_document_chunks(self, document_id: str) -> list[dict]:
+ pass
+
+ @abstractmethod
+ def delete_from_documents_overview(
+ self, document_id: str, version: Optional[str] = None
+ ) -> dict:
+ pass
+
+ @abstractmethod
+ def get_users_overview(self, user_ids: Optional[list[str]] = None) -> dict:
+ pass
diff --git a/R2R/r2r/base/utils/__init__.py b/R2R/r2r/base/utils/__init__.py
new file mode 100755
index 00000000..104d50eb
--- /dev/null
+++ b/R2R/r2r/base/utils/__init__.py
@@ -0,0 +1,26 @@
+from .base_utils import (
+ EntityType,
+ Relation,
+ format_entity_types,
+ format_relations,
+ generate_id_from_label,
+ generate_run_id,
+ increment_version,
+ run_pipeline,
+ to_async_generator,
+)
+from .splitter.text import RecursiveCharacterTextSplitter, TextSplitter
+
+__all__ = [
+ "RecursiveCharacterTextSplitter",
+ "TextSplitter",
+ "run_pipeline",
+ "to_async_generator",
+ "generate_run_id",
+ "generate_id_from_label",
+ "increment_version",
+ "EntityType",
+ "Relation",
+ "format_entity_types",
+ "format_relations",
+]
diff --git a/R2R/r2r/base/utils/base_utils.py b/R2R/r2r/base/utils/base_utils.py
new file mode 100755
index 00000000..12652833
--- /dev/null
+++ b/R2R/r2r/base/utils/base_utils.py
@@ -0,0 +1,63 @@
+import asyncio
+import re
+import uuid
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable
+
+if TYPE_CHECKING:
+ from ..pipeline.base_pipeline import AsyncPipeline
+
+
+def generate_run_id() -> uuid.UUID:
+ return uuid.uuid4()
+
+
+def generate_id_from_label(label: str) -> uuid.UUID:
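+    # Deterministic UUIDv5, so the same label always maps to the same ID.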
+ return uuid.uuid5(uuid.NAMESPACE_DNS, label)
+
+
+async def to_async_generator(
+ iterable: Iterable[Any],
+) -> AsyncGenerator[Any, None]:
+ for item in iterable:
+ yield item
+
+
+def run_pipeline(pipeline: "AsyncPipeline", input: Any, *args, **kwargs):
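+    # Normalize the input into an async generator, then drive the pipeline to
+    # completion on a fresh event loop.
+    # Illustrative usage (`my_pipe` is a hypothetical AsyncPipe):
+    #   pipeline = AsyncPipeline()
+    #   pipeline.add_pipe(my_pipe)
+    #   results = run_pipeline(pipeline, ["hello", "world"])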
+ if not isinstance(input, AsyncGenerator) and not isinstance(input, list):
+ input = to_async_generator([input])
+ elif not isinstance(input, AsyncGenerator):
+ input = to_async_generator(input)
+
+ async def _run_pipeline(input, *args, **kwargs):
+ return await pipeline.run(input, *args, **kwargs)
+
+ return asyncio.run(_run_pipeline(input, *args, **kwargs))
+
+
+def increment_version(version: str) -> str:
+    # Bump the trailing numeric component, e.g. "v9" -> "v10", "v0.19" -> "v0.20".
+    match = re.match(r"^(.*?)(\d+)$", version)
+    if not match:
+        raise ValueError(f"Invalid version string: {version}")
+    prefix, suffix = match.groups()
+    return f"{prefix}{int(suffix) + 1}"
+
+
+class EntityType:
+ def __init__(self, name: str):
+ self.name = name
+
+
+class Relation:
+ def __init__(self, name: str):
+ self.name = name
+
+
+def format_entity_types(entity_types: list[EntityType]) -> str:
+ lines = []
+ for entity in entity_types:
+ lines.append(entity.name)
+ return "\n".join(lines)
+
+
+def format_relations(predicates: list[Relation]) -> str:
+ lines = []
+ for predicate in predicates:
+ lines.append(predicate.name)
+ return "\n".join(lines)
diff --git a/R2R/r2r/base/utils/splitter/__init__.py b/R2R/r2r/base/utils/splitter/__init__.py
new file mode 100755
index 00000000..07a9f554
--- /dev/null
+++ b/R2R/r2r/base/utils/splitter/__init__.py
@@ -0,0 +1,3 @@
+from .text import RecursiveCharacterTextSplitter
+
+__all__ = ["RecursiveCharacterTextSplitter"]
diff --git a/R2R/r2r/base/utils/splitter/text.py b/R2R/r2r/base/utils/splitter/text.py
new file mode 100755
index 00000000..5458310c
--- /dev/null
+++ b/R2R/r2r/base/utils/splitter/text.py
@@ -0,0 +1,1979 @@
+# Source - LangChain
+# URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851
+"""**Text Splitters** are classes for splitting text.
+
+
+**Class hierarchy:**
+
+.. code-block::
+
+ BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter # Example: CharacterTextSplitter
+ RecursiveCharacterTextSplitter --> <name>TextSplitter
+
+Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter** do not derive from TextSplitter.
+
+
+**Main helpers:**
+
+.. code-block::
+
+ Document, Tokenizer, Language, LineType, HeaderType
+
+""" # noqa: E501
+
+from __future__ import annotations
+
+import copy
+import json
+import logging
+import pathlib
+import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from io import BytesIO, StringIO
+from typing import (
+ AbstractSet,
+ Any,
+ Callable,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Literal,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypedDict,
+ TypeVar,
+ Union,
+ cast,
+)
+
+import requests
+from pydantic import BaseModel, Field, PrivateAttr
+from typing_extensions import NotRequired
+
+logger = logging.getLogger(__name__)
+
+TS = TypeVar("TS", bound="TextSplitter")
+
+
+class BaseSerialized(TypedDict):
+ """Base class for serialized objects."""
+
+ lc: int
+ id: List[str]
+ name: NotRequired[str]
+ graph: NotRequired[Dict[str, Any]]
+
+
+class SerializedConstructor(BaseSerialized):
+ """Serialized constructor."""
+
+ type: Literal["constructor"]
+ kwargs: Dict[str, Any]
+
+
+class SerializedSecret(BaseSerialized):
+ """Serialized secret."""
+
+ type: Literal["secret"]
+
+
+class SerializedNotImplemented(BaseSerialized):
+ """Serialized not implemented."""
+
+ type: Literal["not_implemented"]
+ repr: Optional[str]
+
+
+def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
+ """Try to determine if a value is different from the default.
+
+ Args:
+ value: The value.
+ key: The key.
+ model: The model.
+
+ Returns:
+ Whether the value is different from the default.
+ """
+ try:
+ return model.__fields__[key].get_default() != value
+ except Exception:
+ return True
+
+
+class Serializable(BaseModel, ABC):
+ """Serializable base class."""
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Is this class serializable?"""
+ return False
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object.
+
+ For example, if the class is `langchain.llms.openai.OpenAI`, then the
+ namespace is ["langchain", "llms", "openai"]
+ """
+ return cls.__module__.split(".")
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ """A map of constructor argument names to secret ids.
+
+ For example,
+ {"openai_api_key": "OPENAI_API_KEY"}
+ """
+ return dict()
+
+ @property
+ def lc_attributes(self) -> Dict:
+ """List of attribute names that should be included in the serialized kwargs.
+
+ These attributes must be accepted by the constructor.
+ """
+ return {}
+
+ @classmethod
+ def lc_id(cls) -> List[str]:
+ """A unique identifier for this class for serialization purposes.
+
+ The unique identifier is a list of strings that describes the path
+ to the object.
+ """
+ return [*cls.get_lc_namespace(), cls.__name__]
+
+ class Config:
+ extra = "ignore"
+
+ def __repr_args__(self) -> Any:
+ return [
+ (k, v)
+ for k, v in super().__repr_args__()
+ if (k not in self.__fields__ or try_neq_default(v, k, self))
+ ]
+
+ _lc_kwargs = PrivateAttr(default_factory=dict)
+
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+ self._lc_kwargs = kwargs
+
+ def to_json(
+ self,
+ ) -> Union[SerializedConstructor, SerializedNotImplemented]:
+ if not self.is_lc_serializable():
+ return self.to_json_not_implemented()
+
+ secrets = dict()
+ # Get latest values for kwargs if there is an attribute with same name
+ lc_kwargs = {
+ k: getattr(self, k, v)
+ for k, v in self._lc_kwargs.items()
+ if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
+ }
+
+ # Merge the lc_secrets and lc_attributes from every class in the MRO
+ for cls in [None, *self.__class__.mro()]:
+ # Once we get to Serializable, we're done
+ if cls is Serializable:
+ break
+
+ if cls:
+ deprecated_attributes = [
+ "lc_namespace",
+ "lc_serializable",
+ ]
+
+ for attr in deprecated_attributes:
+ if hasattr(cls, attr):
+ raise ValueError(
+ f"Class {self.__class__} has a deprecated "
+ f"attribute {attr}. Please use the corresponding "
+ f"classmethod instead."
+ )
+
+ # Get a reference to self bound to each class in the MRO
+ this = cast(
+ Serializable, self if cls is None else super(cls, self)
+ )
+
+ secrets.update(this.lc_secrets)
+ # Now also add the aliases for the secrets
+ # This ensures known secret aliases are hidden.
+ # Note: this does NOT hide any other extra kwargs
+ # that are not present in the fields.
+ for key in list(secrets):
+ value = secrets[key]
+ if key in this.__fields__:
+ secrets[this.__fields__[key].alias] = value
+ lc_kwargs.update(this.lc_attributes)
+
+ # include all secrets, even if not specified in kwargs
+ # as these secrets may be passed as an environment variable instead
+ for key in secrets.keys():
+ secret_value = getattr(self, key, None) or lc_kwargs.get(key)
+ if secret_value is not None:
+ lc_kwargs.update({key: secret_value})
+
+ return {
+ "lc": 1,
+ "type": "constructor",
+ "id": self.lc_id(),
+ "kwargs": (
+ lc_kwargs
+ if not secrets
+ else _replace_secrets(lc_kwargs, secrets)
+ ),
+ }
+
+ def to_json_not_implemented(self) -> SerializedNotImplemented:
+ return to_json_not_implemented(self)
+
+
+def _replace_secrets(
+ root: Dict[Any, Any], secrets_map: Dict[str, str]
+) -> Dict[Any, Any]:
+ result = root.copy()
+ for path, secret_id in secrets_map.items():
+ [*parts, last] = path.split(".")
+ current = result
+ for part in parts:
+ if part not in current:
+ break
+ current[part] = current[part].copy()
+ current = current[part]
+ if last in current:
+ current[last] = {
+ "lc": 1,
+ "type": "secret",
+ "id": [secret_id],
+ }
+ return result
+
+
+def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
+ """Serialize a "not implemented" object.
+
+ Args:
+ obj: object to serialize
+
+ Returns:
+ SerializedNotImplemented
+ """
+ _id: List[str] = []
+ try:
+ if hasattr(obj, "__name__"):
+ _id = [*obj.__module__.split("."), obj.__name__]
+ elif hasattr(obj, "__class__"):
+ _id = [
+ *obj.__class__.__module__.split("."),
+ obj.__class__.__name__,
+ ]
+ except Exception:
+ pass
+
+ result: SerializedNotImplemented = {
+ "lc": 1,
+ "type": "not_implemented",
+ "id": _id,
+ "repr": None,
+ }
+ try:
+ result["repr"] = repr(obj)
+ except Exception:
+ pass
+ return result
+
+
+class Document(Serializable):
+ """Class for storing a piece of text and associated metadata."""
+
+ page_content: str
+ """String text."""
+ metadata: dict = Field(default_factory=dict)
+ """Arbitrary metadata about the page content (e.g., source, relationships to other
+ documents, etc.).
+ """
+ type: Literal["Document"] = "Document"
+
+ def __init__(self, page_content: str, **kwargs: Any) -> None:
+ """Pass page_content in as positional or named arg."""
+ super().__init__(page_content=page_content, **kwargs)
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Return whether this class is serializable."""
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "schema", "document"]
+
+
+class BaseDocumentTransformer(ABC):
+ """Abstract base class for document transformation systems.
+
+ A document transformation system takes a sequence of Documents and returns a
+ sequence of transformed Documents.
+
+ Example:
+ .. code-block:: python
+
+ class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
+ embeddings: Embeddings
+ similarity_fn: Callable = cosine_similarity
+ similarity_threshold: float = 0.95
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ stateful_documents = get_stateful_documents(documents)
+ embedded_documents = _get_embeddings_from_stateful_docs(
+ self.embeddings, stateful_documents
+ )
+ included_idxs = _filter_similar_embeddings(
+ embedded_documents, self.similarity_fn, self.similarity_threshold
+ )
+ return [stateful_documents[i] for i in sorted(included_idxs)]
+
+ async def atransform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ raise NotImplementedError
+
+ """ # noqa: E501
+
+ @abstractmethod
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Transform a list of documents.
+
+ Args:
+ documents: A sequence of Documents to be transformed.
+
+ Returns:
+ A list of transformed Documents.
+ """
+
+ async def atransform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Asynchronously transform a list of documents.
+
+ Args:
+ documents: A sequence of Documents to be transformed.
+
+ Returns:
+ A list of transformed Documents.
+ """
+ raise NotImplementedError("This method is not implemented.")
+ # return await langchain_core.runnables.config.run_in_executor(
+ # None, self.transform_documents, documents, **kwargs
+ # )
+
+
+def _make_spacy_pipe_for_splitting(
+ pipe: str, *, max_length: int = 1_000_000
+) -> Any: # avoid importing spacy
+ try:
+ import spacy
+ except ImportError:
+ raise ImportError(
+ "Spacy is not installed, please install it with `pip install spacy`."
+ )
+ if pipe == "sentencizer":
+ from spacy.lang.en import English
+
+ sentencizer = English()
+ sentencizer.add_pipe("sentencizer")
+ else:
+ sentencizer = spacy.load(pipe, exclude=["ner", "tagger"])
+ sentencizer.max_length = max_length
+ return sentencizer
+
+
+def _split_text_with_regex(
+ text: str, separator: str, keep_separator: bool
+) -> List[str]:
+ # Now that we have the separator, split the text
+ if separator:
+ if keep_separator:
+ # The parentheses in the pattern keep the delimiters in the result.
+ _splits = re.split(f"({separator})", text)
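+            # re.split with a capturing group alternates text and separator;
+            # re-attach each separator to the front of the chunk that follows it.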
+ splits = [
+ _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)
+ ]
+ if len(_splits) % 2 == 0:
+ splits += _splits[-1:]
+ splits = [_splits[0]] + splits
+ else:
+ splits = re.split(separator, text)
+ else:
+ splits = list(text)
+ return [s for s in splits if s != ""]
+
+
+class TextSplitter(BaseDocumentTransformer, ABC):
+ """Interface for splitting text into chunks."""
+
+ def __init__(
+ self,
+ chunk_size: int = 4000,
+ chunk_overlap: int = 200,
+ length_function: Callable[[str], int] = len,
+ keep_separator: bool = False,
+ add_start_index: bool = False,
+ strip_whitespace: bool = True,
+ ) -> None:
+ """Create a new TextSplitter.
+
+ Args:
+ chunk_size: Maximum size of chunks to return
+ chunk_overlap: Overlap in characters between chunks
+ length_function: Function that measures the length of given chunks
+ keep_separator: Whether to keep the separator in the chunks
+ add_start_index: If `True`, includes chunk's start index in metadata
+ strip_whitespace: If `True`, strips whitespace from the start and end of
+ every document
+ """
+ if chunk_overlap > chunk_size:
+ raise ValueError(
+ f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
+ f"({chunk_size}), should be smaller."
+ )
+ self._chunk_size = chunk_size
+ self._chunk_overlap = chunk_overlap
+ self._length_function = length_function
+ self._keep_separator = keep_separator
+ self._add_start_index = add_start_index
+ self._strip_whitespace = strip_whitespace
+
+ @abstractmethod
+ def split_text(self, text: str) -> List[str]:
+ """Split text into multiple components."""
+
+ def create_documents(
+ self, texts: List[str], metadatas: Optional[List[dict]] = None
+ ) -> List[Document]:
+ """Create documents from a list of texts."""
+ _metadatas = metadatas or [{}] * len(texts)
+ documents = []
+ for i, text in enumerate(texts):
+ index = 0
+ previous_chunk_len = 0
+ for chunk in self.split_text(text):
+ metadata = copy.deepcopy(_metadatas[i])
+ if self._add_start_index:
+ offset = index + previous_chunk_len - self._chunk_overlap
+ index = text.find(chunk, max(0, offset))
+ metadata["start_index"] = index
+ previous_chunk_len = len(chunk)
+ new_doc = Document(page_content=chunk, metadata=metadata)
+ documents.append(new_doc)
+ return documents
+
+ def split_documents(self, documents: Iterable[Document]) -> List[Document]:
+ """Split documents."""
+ texts, metadatas = [], []
+ for doc in documents:
+ texts.append(doc.page_content)
+ metadatas.append(doc.metadata)
+ return self.create_documents(texts, metadatas=metadatas)
+
+ def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
+ text = separator.join(docs)
+ if self._strip_whitespace:
+ text = text.strip()
+ if text == "":
+ return None
+ else:
+ return text
+
+ def _merge_splits(
+ self, splits: Iterable[str], separator: str
+ ) -> List[str]:
+ # We now want to combine these smaller pieces into medium size
+ # chunks to send to the LLM.
+ separator_len = self._length_function(separator)
+
+ docs = []
+ current_doc: List[str] = []
+ total = 0
+ for d in splits:
+ _len = self._length_function(d)
+ if (
+ total + _len + (separator_len if len(current_doc) > 0 else 0)
+ > self._chunk_size
+ ):
+ if total > self._chunk_size:
+ logger.warning(
+ f"Created a chunk of size {total}, "
+ f"which is longer than the specified {self._chunk_size}"
+ )
+ if len(current_doc) > 0:
+ doc = self._join_docs(current_doc, separator)
+ if doc is not None:
+ docs.append(doc)
+ # Keep on popping if:
+ # - we have a larger chunk than in the chunk overlap
+ # - or if we still have any chunks and the length is long
+ while total > self._chunk_overlap or (
+ total
+ + _len
+ + (separator_len if len(current_doc) > 0 else 0)
+ > self._chunk_size
+ and total > 0
+ ):
+ total -= self._length_function(current_doc[0]) + (
+ separator_len if len(current_doc) > 1 else 0
+ )
+ current_doc = current_doc[1:]
+ current_doc.append(d)
+ total += _len + (separator_len if len(current_doc) > 1 else 0)
+ doc = self._join_docs(current_doc, separator)
+ if doc is not None:
+ docs.append(doc)
+ return docs
+
+ @classmethod
+ def from_huggingface_tokenizer(
+ cls, tokenizer: Any, **kwargs: Any
+ ) -> TextSplitter:
+ """Text splitter that uses HuggingFace tokenizer to count length."""
+ try:
+ from transformers import PreTrainedTokenizerBase
+
+ if not isinstance(tokenizer, PreTrainedTokenizerBase):
+ raise ValueError(
+ "Tokenizer received was not an instance of PreTrainedTokenizerBase"
+ )
+
+ def _huggingface_tokenizer_length(text: str) -> int:
+ return len(tokenizer.encode(text))
+
+ except ImportError:
+ raise ValueError(
+ "Could not import transformers python package. "
+ "Please install it with `pip install transformers`."
+ )
+ return cls(length_function=_huggingface_tokenizer_length, **kwargs)
+
+ @classmethod
+ def from_tiktoken_encoder(
+ cls: Type[TS],
+ encoding_name: str = "gpt2",
+ model: Optional[str] = None,
+ allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+ disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+ **kwargs: Any,
+ ) -> TS:
+ """Text splitter that uses tiktoken encoder to count length."""
+ try:
+ import tiktoken
+ except ImportError:
+ raise ImportError(
+ "Could not import tiktoken python package. "
+ "This is needed in order to calculate max_tokens_for_prompt. "
+ "Please install it with `pip install tiktoken`."
+ )
+
+ if model is not None:
+ enc = tiktoken.encoding_for_model(model)
+ else:
+ enc = tiktoken.get_encoding(encoding_name)
+
+ def _tiktoken_encoder(text: str) -> int:
+ return len(
+ enc.encode(
+ text,
+ allowed_special=allowed_special,
+ disallowed_special=disallowed_special,
+ )
+ )
+
+ if issubclass(cls, TokenTextSplitter):
+ extra_kwargs = {
+ "encoding_name": encoding_name,
+ "model": model,
+ "allowed_special": allowed_special,
+ "disallowed_special": disallowed_special,
+ }
+ kwargs = {**kwargs, **extra_kwargs}
+
+ return cls(length_function=_tiktoken_encoder, **kwargs)
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Transform sequence of documents by splitting them."""
+ return self.split_documents(list(documents))
+
+
+class CharacterTextSplitter(TextSplitter):
+ """Splitting text that looks at characters."""
+
+ def __init__(
+ self,
+ separator: str = "\n\n",
+ is_separator_regex: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TextSplitter."""
+ super().__init__(**kwargs)
+ self._separator = separator
+ self._is_separator_regex = is_separator_regex
+
+ def split_text(self, text: str) -> List[str]:
+ """Split incoming text and return chunks."""
+ # First we naively split the large input into a bunch of smaller ones.
+ separator = (
+ self._separator
+ if self._is_separator_regex
+ else re.escape(self._separator)
+ )
+ splits = _split_text_with_regex(text, separator, self._keep_separator)
+ _separator = "" if self._keep_separator else self._separator
+ return self._merge_splits(splits, _separator)
+
+
+class LineType(TypedDict):
+ """Line type as typed dict."""
+
+ metadata: Dict[str, str]
+ content: str
+
+
+class HeaderType(TypedDict):
+ """Header type as typed dict."""
+
+ level: int
+ name: str
+ data: str
+
+
+class MarkdownHeaderTextSplitter:
+ """Splitting markdown files based on specified headers."""
+
+ def __init__(
+ self,
+ headers_to_split_on: List[Tuple[str, str]],
+ return_each_line: bool = False,
+ strip_headers: bool = True,
+ ):
+ """Create a new MarkdownHeaderTextSplitter.
+
+ Args:
+ headers_to_split_on: Headers we want to track
+ return_each_line: Return each line w/ associated headers
+ strip_headers: Strip split headers from the content of the chunk
+ """
+ # Output line-by-line or aggregated into chunks w/ common headers
+ self.return_each_line = return_each_line
+ # Given the headers we want to split on,
+ # (e.g., "#, ##, etc") order by length
+ self.headers_to_split_on = sorted(
+ headers_to_split_on, key=lambda split: len(split[0]), reverse=True
+ )
+        # Whether to strip split headers from the content of the chunk
+ self.strip_headers = strip_headers
+
+ def aggregate_lines_to_chunks(
+ self, lines: List[LineType]
+ ) -> List[Document]:
+ """Combine lines with common metadata into chunks
+ Args:
+ lines: Line of text / associated header metadata
+ """
+ aggregated_chunks: List[LineType] = []
+
+ for line in lines:
+ if (
+ aggregated_chunks
+ and aggregated_chunks[-1]["metadata"] == line["metadata"]
+ ):
+ # If the last line in the aggregated list
+ # has the same metadata as the current line,
+                # append the current content to the last line's content
+ aggregated_chunks[-1]["content"] += " \n" + line["content"]
+ elif (
+ aggregated_chunks
+ and aggregated_chunks[-1]["metadata"] != line["metadata"]
+ # may be issues if other metadata is present
+ and len(aggregated_chunks[-1]["metadata"])
+ < len(line["metadata"])
+ and aggregated_chunks[-1]["content"].split("\n")[-1][0] == "#"
+ and not self.strip_headers
+ ):
+ # If the last line in the aggregated list
+                # has different metadata from the current line,
+ # and has shallower header level than the current line,
+ # and the last line is a header,
+ # and we are not stripping headers,
+ # append the current content to the last line's content
+ aggregated_chunks[-1]["content"] += " \n" + line["content"]
+ # and update the last line's metadata
+ aggregated_chunks[-1]["metadata"] = line["metadata"]
+ else:
+ # Otherwise, append the current line to the aggregated list
+ aggregated_chunks.append(line)
+
+ return [
+ Document(page_content=chunk["content"], metadata=chunk["metadata"])
+ for chunk in aggregated_chunks
+ ]
+
+ def split_text(self, text: str) -> List[Document]:
+ """Split markdown file
+ Args:
+ text: Markdown file"""
+
+ # Split the input text by newline character ("\n").
+ lines = text.split("\n")
+ # Final output
+ lines_with_metadata: List[LineType] = []
+ # Content and metadata of the chunk currently being processed
+ current_content: List[str] = []
+ current_metadata: Dict[str, str] = {}
+ # Keep track of the nested header structure
+ # header_stack: List[Dict[str, Union[int, str]]] = []
+ header_stack: List[HeaderType] = []
+ initial_metadata: Dict[str, str] = {}
+
+ in_code_block = False
+ opening_fence = ""
+
+ for line in lines:
+ stripped_line = line.strip()
+
+ if not in_code_block:
+ # Exclude inline code spans
+ if (
+ stripped_line.startswith("```")
+ and stripped_line.count("```") == 1
+ ):
+ in_code_block = True
+ opening_fence = "```"
+ elif stripped_line.startswith("~~~"):
+ in_code_block = True
+ opening_fence = "~~~"
+ else:
+ if stripped_line.startswith(opening_fence):
+ in_code_block = False
+ opening_fence = ""
+
+ if in_code_block:
+ current_content.append(stripped_line)
+ continue
+
+ # Check each line against each of the header types (e.g., #, ##)
+ for sep, name in self.headers_to_split_on:
+ # Check if line starts with a header that we intend to split on
+ if stripped_line.startswith(sep) and (
+ # Header with no text OR header is followed by space
+                    # Both are valid conditions that sep is being used as a header
+ len(stripped_line) == len(sep)
+ or stripped_line[len(sep)] == " "
+ ):
+ # Ensure we are tracking the header as metadata
+ if name is not None:
+ # Get the current header level
+ current_header_level = sep.count("#")
+
+ # Pop out headers of lower or same level from the stack
+ while (
+ header_stack
+ and header_stack[-1]["level"]
+ >= current_header_level
+ ):
+ # We have encountered a new header
+ # at the same or higher level
+ popped_header = header_stack.pop()
+ # Clear the metadata for the
+ # popped header in initial_metadata
+ if popped_header["name"] in initial_metadata:
+ initial_metadata.pop(popped_header["name"])
+
+ # Push the current header to the stack
+ header: HeaderType = {
+ "level": current_header_level,
+ "name": name,
+ "data": stripped_line[len(sep) :].strip(),
+ }
+ header_stack.append(header)
+ # Update initial_metadata with the current header
+ initial_metadata[name] = header["data"]
+
+ # Add the previous line to the lines_with_metadata
+ # only if current_content is not empty
+ if current_content:
+ lines_with_metadata.append(
+ {
+ "content": "\n".join(current_content),
+ "metadata": current_metadata.copy(),
+ }
+ )
+ current_content.clear()
+
+ if not self.strip_headers:
+ current_content.append(stripped_line)
+
+ break
+ else:
+ if stripped_line:
+ current_content.append(stripped_line)
+ elif current_content:
+ lines_with_metadata.append(
+ {
+ "content": "\n".join(current_content),
+ "metadata": current_metadata.copy(),
+ }
+ )
+ current_content.clear()
+
+ current_metadata = initial_metadata.copy()
+
+ if current_content:
+ lines_with_metadata.append(
+ {
+ "content": "\n".join(current_content),
+ "metadata": current_metadata,
+ }
+ )
+
+ # lines_with_metadata has each line with associated header metadata
+ # aggregate these into chunks based on common metadata
+ if not self.return_each_line:
+ return self.aggregate_lines_to_chunks(lines_with_metadata)
+ else:
+ return [
+ Document(
+ page_content=chunk["content"], metadata=chunk["metadata"]
+ )
+ for chunk in lines_with_metadata
+ ]
+
+
+class ElementType(TypedDict):
+ """Element type as typed dict."""
+
+ url: str
+ xpath: str
+ content: str
+ metadata: Dict[str, str]
+
+
+class HTMLHeaderTextSplitter:
+ """
+ Splitting HTML files based on specified headers.
+ Requires lxml package.
+ """
+
+ def __init__(
+ self,
+ headers_to_split_on: List[Tuple[str, str]],
+ return_each_element: bool = False,
+ ):
+ """Create a new HTMLHeaderTextSplitter.
+
+ Args:
+ headers_to_split_on: list of tuples of headers we want to track mapped to
+ (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4,
+                h5, h6, e.g. [("h1", "Header 1"), ("h2", "Header 2")].
+ return_each_element: Return each element w/ associated headers.
+ """
+ # Output element-by-element or aggregated into chunks w/ common headers
+ self.return_each_element = return_each_element
+ self.headers_to_split_on = sorted(headers_to_split_on)
+
+ def aggregate_elements_to_chunks(
+ self, elements: List[ElementType]
+ ) -> List[Document]:
+ """Combine elements with common metadata into chunks
+
+ Args:
+ elements: HTML element content with associated identifying info and metadata
+ """
+ aggregated_chunks: List[ElementType] = []
+
+ for element in elements:
+ if (
+ aggregated_chunks
+ and aggregated_chunks[-1]["metadata"] == element["metadata"]
+ ):
+ # If the last element in the aggregated list
+ # has the same metadata as the current element,
+ # append the current content to the last element's content
+ aggregated_chunks[-1]["content"] += " \n" + element["content"]
+ else:
+ # Otherwise, append the current element to the aggregated list
+ aggregated_chunks.append(element)
+
+ return [
+ Document(page_content=chunk["content"], metadata=chunk["metadata"])
+ for chunk in aggregated_chunks
+ ]
+
+ def split_text_from_url(self, url: str) -> List[Document]:
+ """Split HTML from web URL
+
+ Args:
+ url: web URL
+ """
+ r = requests.get(url)
+ return self.split_text_from_file(BytesIO(r.content))
+
+ def split_text(self, text: str) -> List[Document]:
+ """Split HTML text string
+
+ Args:
+ text: HTML text
+ """
+ return self.split_text_from_file(StringIO(text))
+
+ def split_text_from_file(self, file: Any) -> List[Document]:
+ """Split HTML file
+
+ Args:
+ file: HTML file
+ """
+ try:
+ from lxml import etree
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import lxml, please install with `pip install lxml`."
+ ) from e
+ # use lxml library to parse html document and return xml ElementTree
+ # Explicitly encoding in utf-8 allows non-English
+ # html files to be processed without garbled characters
+ parser = etree.HTMLParser(encoding="utf-8")
+ tree = etree.parse(file, parser)
+
+ # document transformation for "structure-aware" chunking is handled with xsl.
+ # see comments in html_chunks_with_headers.xslt for more detailed information.
+ xslt_path = (
+ pathlib.Path(__file__).parent
+ / "document_transformers/xsl/html_chunks_with_headers.xslt"
+ )
+ xslt_tree = etree.parse(xslt_path)
+ transform = etree.XSLT(xslt_tree)
+ result = transform(tree)
+ result_dom = etree.fromstring(str(result))
+
+ # create filter and mapping for header metadata
+ header_filter = [header[0] for header in self.headers_to_split_on]
+ header_mapping = dict(self.headers_to_split_on)
+
+ # map xhtml namespace prefix
+ ns_map = {"h": "http://www.w3.org/1999/xhtml"}
+
+ # build list of elements from DOM
+ elements = []
+ for element in result_dom.findall("*//*", ns_map):
+ if element.findall("*[@class='headers']") or element.findall(
+ "*[@class='chunk']"
+ ):
+ elements.append(
+ ElementType(
+ url=file,
+ xpath="".join(
+ [
+ node.text
+ for node in element.findall(
+ "*[@class='xpath']", ns_map
+ )
+ ]
+ ),
+ content="".join(
+ [
+ node.text
+ for node in element.findall(
+ "*[@class='chunk']", ns_map
+ )
+ ]
+ ),
+ metadata={
+ # Add text of specified headers to metadata using header
+ # mapping.
+ header_mapping[node.tag]: node.text
+ for node in filter(
+ lambda x: x.tag in header_filter,
+ element.findall(
+ "*[@class='headers']/*", ns_map
+ ),
+ )
+ },
+ )
+ )
+
+ if not self.return_each_element:
+ return self.aggregate_elements_to_chunks(elements)
+ else:
+ return [
+ Document(
+ page_content=chunk["content"], metadata=chunk["metadata"]
+ )
+ for chunk in elements
+ ]
+
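+
+# Illustrative usage sketch (not part of the original module): splitting raw
+# HTML text with HTMLHeaderTextSplitter. Requires lxml and the bundled
+# html_chunks_with_headers.xslt stylesheet referenced above.
+def _example_html_header_split() -> None:
+    splitter = HTMLHeaderTextSplitter(
+        headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")]
+    )
+    html = (
+        "<html><body><h1>Title</h1><p>Intro</p>"
+        "<h2>Part A</h2><p>Body</p></body></html>"
+    )
+    for doc in splitter.split_text(html):
+        # Header text of h1/h2 ancestors is surfaced in doc.metadata.
+        print(doc.metadata, doc.page_content)
+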
+
+# On Python 3.10+ this could be declared with keyword-only, slotted fields:
+# @dataclass(frozen=True, kw_only=True, slots=True)
+@dataclass(frozen=True)
+class Tokenizer:
+ """Tokenizer data class."""
+
+ chunk_overlap: int
+ """Overlap in tokens between chunks"""
+ tokens_per_chunk: int
+ """Maximum number of tokens per chunk"""
+ decode: Callable[[List[int]], str]
+ """ Function to decode a list of token ids to a string"""
+ encode: Callable[[str], List[int]]
+ """ Function to encode a string to a list of token ids"""
+
+
+def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
+ """Split incoming text and return chunks using tokenizer."""
+ splits: List[str] = []
+ input_ids = tokenizer.encode(text)
+ start_idx = 0
+ cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+ chunk_ids = input_ids[start_idx:cur_idx]
+ while start_idx < len(input_ids):
+ splits.append(tokenizer.decode(chunk_ids))
+ if cur_idx == len(input_ids):
+ break
+ start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
+ cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+ chunk_ids = input_ids[start_idx:cur_idx]
+ return splits
+
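+
+# Illustrative sketch (not part of the original module): a toy character-level
+# Tokenizer shows how split_text_on_tokens produces overlapping windows.
+def _example_split_text_on_tokens() -> None:
+    toy = Tokenizer(
+        chunk_overlap=2,
+        tokens_per_chunk=5,
+        decode=lambda ids: "".join(chr(i) for i in ids),
+        encode=lambda text: [ord(ch) for ch in text],
+    )
+    # Windows advance by tokens_per_chunk - chunk_overlap = 3 characters,
+    # yielding ["abcde", "defgh", "ghij"].
+    print(split_text_on_tokens(text="abcdefghij", tokenizer=toy))
+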
+
+class TokenTextSplitter(TextSplitter):
+ """Splitting text to tokens using model tokenizer."""
+
+ def __init__(
+ self,
+ encoding_name: str = "gpt2",
+ model: Optional[str] = None,
+ allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+ disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TextSplitter."""
+ super().__init__(**kwargs)
+ try:
+ import tiktoken
+ except ImportError:
+ raise ImportError(
+ "Could not import tiktoken python package. "
+ "This is needed in order for TokenTextSplitter to work. "
+ "Please install it with `pip install tiktoken`."
+ )
+
+ if model is not None:
+ enc = tiktoken.encoding_for_model(model)
+ else:
+ enc = tiktoken.get_encoding(encoding_name)
+ self._tokenizer = enc
+ self._allowed_special = allowed_special
+ self._disallowed_special = disallowed_special
+
+ def split_text(self, text: str) -> List[str]:
+ def _encode(_text: str) -> List[int]:
+ return self._tokenizer.encode(
+ _text,
+ allowed_special=self._allowed_special,
+ disallowed_special=self._disallowed_special,
+ )
+
+ tokenizer = Tokenizer(
+ chunk_overlap=self._chunk_overlap,
+ tokens_per_chunk=self._chunk_size,
+ decode=self._tokenizer.decode,
+ encode=_encode,
+ )
+
+ return split_text_on_tokens(text=text, tokenizer=tokenizer)
+
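+
+# Illustrative usage sketch (not part of the original module). It assumes the
+# TextSplitter base class defined earlier in this file accepts chunk_size and
+# chunk_overlap keyword arguments; tiktoken must be installed.
+def _example_token_text_splitter() -> None:
+    splitter = TokenTextSplitter(
+        encoding_name="cl100k_base", chunk_size=256, chunk_overlap=32
+    )
+    chunks = splitter.split_text("R2R ingestion test text. " * 200)
+    print(len(chunks), "chunks")
+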
+
+class SentenceTransformersTokenTextSplitter(TextSplitter):
+ """Splitting text to tokens using sentence model tokenizer."""
+
+ def __init__(
+ self,
+ chunk_overlap: int = 50,
+ model: str = "sentence-transformers/all-mpnet-base-v2",
+ tokens_per_chunk: Optional[int] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TextSplitter."""
+ super().__init__(**kwargs, chunk_overlap=chunk_overlap)
+
+ try:
+ from sentence_transformers import SentenceTransformer
+ except ImportError:
+ raise ImportError(
+ "Could not import sentence_transformers python package. "
+ "This is needed in order for SentenceTransformersTokenTextSplitter to work. "
+ "Please install it with `pip install sentence-transformers`."
+ )
+
+ self.model = model
+ self._model = SentenceTransformer(self.model, trust_remote_code=True)
+ self.tokenizer = self._model.tokenizer
+ self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
+
+ def _initialize_chunk_configuration(
+ self, *, tokens_per_chunk: Optional[int]
+ ) -> None:
+ self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
+
+ if tokens_per_chunk is None:
+ self.tokens_per_chunk = self.maximum_tokens_per_chunk
+ else:
+ self.tokens_per_chunk = tokens_per_chunk
+
+ if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
+ raise ValueError(
+ f"The token limit of the model '{self.model}'"
+ f" is: {self.maximum_tokens_per_chunk}."
+ f" Argument tokens_per_chunk={self.tokens_per_chunk}"
+ f" > maximum token limit."
+ )
+
+ def split_text(self, text: str) -> List[str]:
+ def encode_strip_start_and_stop_token_ids(text: str) -> List[int]:
+ return self._encode(text)[1:-1]
+
+ tokenizer = Tokenizer(
+ chunk_overlap=self._chunk_overlap,
+ tokens_per_chunk=self.tokens_per_chunk,
+ decode=self.tokenizer.decode,
+ encode=encode_strip_start_and_stop_token_ids,
+ )
+
+ return split_text_on_tokens(text=text, tokenizer=tokenizer)
+
+ def count_tokens(self, *, text: str) -> int:
+ return len(self._encode(text))
+
+ _max_length_equal_32_bit_integer: int = 2**32
+
+ def _encode(self, text: str) -> List[int]:
+ token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
+ text,
+ max_length=self._max_length_equal_32_bit_integer,
+ truncation="do_not_truncate",
+ )
+ return token_ids_with_start_and_end_token_ids
+
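+
+# Illustrative usage sketch (not part of the original module): chunking so each
+# chunk fits the sentence-transformers model's token window. Requires the
+# sentence-transformers package and downloads the model weights on first use.
+def _example_sentence_transformers_splitter() -> None:
+    splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=20)
+    text = "R2R splits documents into token-bounded chunks. " * 50
+    print(splitter.count_tokens(text=text))
+    print(len(splitter.split_text(text)))
+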
+
+class Language(str, Enum):
+ """Enum of the programming languages."""
+
+ CPP = "cpp"
+ GO = "go"
+ JAVA = "java"
+ KOTLIN = "kotlin"
+ JS = "js"
+ TS = "ts"
+ PHP = "php"
+ PROTO = "proto"
+ PYTHON = "python"
+ RST = "rst"
+ RUBY = "ruby"
+ RUST = "rust"
+ SCALA = "scala"
+ SWIFT = "swift"
+ MARKDOWN = "markdown"
+ LATEX = "latex"
+ HTML = "html"
+ SOL = "sol"
+ CSHARP = "csharp"
+ COBOL = "cobol"
+ C = "c"
+ LUA = "lua"
+ PERL = "perl"
+
+
+class RecursiveCharacterTextSplitter(TextSplitter):
+ """Splitting text by recursively looking at characters.
+
+ Recursively tries to split by different characters to find one
+ that works.
+ """
+
+ def __init__(
+ self,
+ separators: Optional[List[str]] = None,
+ keep_separator: bool = True,
+ is_separator_regex: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TextSplitter."""
+ super().__init__(keep_separator=keep_separator, **kwargs)
+ self._separators = separators or ["\n\n", "\n", " ", ""]
+ self._is_separator_regex = is_separator_regex
+
+ def _split_text(self, text: str, separators: List[str]) -> List[str]:
+ """Split incoming text and return chunks."""
+ final_chunks = []
+ # Get appropriate separator to use
+ separator = separators[-1]
+ new_separators = []
+ for i, _s in enumerate(separators):
+ _separator = _s if self._is_separator_regex else re.escape(_s)
+ if _s == "":
+ separator = _s
+ break
+ if re.search(_separator, text):
+ separator = _s
+ new_separators = separators[i + 1 :]
+ break
+
+ _separator = (
+ separator if self._is_separator_regex else re.escape(separator)
+ )
+ splits = _split_text_with_regex(text, _separator, self._keep_separator)
+
+ # Now go merging things, recursively splitting longer texts.
+ _good_splits = []
+ _separator = "" if self._keep_separator else separator
+ for s in splits:
+ if self._length_function(s) < self._chunk_size:
+ _good_splits.append(s)
+ else:
+ if _good_splits:
+ merged_text = self._merge_splits(_good_splits, _separator)
+ final_chunks.extend(merged_text)
+ _good_splits = []
+ if not new_separators:
+ final_chunks.append(s)
+ else:
+ other_info = self._split_text(s, new_separators)
+ final_chunks.extend(other_info)
+ if _good_splits:
+ merged_text = self._merge_splits(_good_splits, _separator)
+ final_chunks.extend(merged_text)
+ return final_chunks
+
+ def split_text(self, text: str) -> List[str]:
+ return self._split_text(text, self._separators)
+
+ @classmethod
+ def from_language(
+ cls, language: Language, **kwargs: Any
+ ) -> RecursiveCharacterTextSplitter:
+ separators = cls.get_separators_for_language(language)
+ return cls(separators=separators, is_separator_regex=True, **kwargs)
+
+ @staticmethod
+ def get_separators_for_language(language: Language) -> List[str]:
+ if language == Language.CPP:
+ return [
+ # Split along class definitions
+ "\nclass ",
+ # Split along function definitions
+ "\nvoid ",
+ "\nint ",
+ "\nfloat ",
+ "\ndouble ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.GO:
+ return [
+ # Split along function definitions
+ "\nfunc ",
+ "\nvar ",
+ "\nconst ",
+ "\ntype ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.JAVA:
+ return [
+ # Split along class definitions
+ "\nclass ",
+ # Split along method definitions
+ "\npublic ",
+ "\nprotected ",
+ "\nprivate ",
+ "\nstatic ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.KOTLIN:
+ return [
+ # Split along class definitions
+ "\nclass ",
+ # Split along method definitions
+ "\npublic ",
+ "\nprotected ",
+ "\nprivate ",
+ "\ninternal ",
+ "\ncompanion ",
+ "\nfun ",
+ "\nval ",
+ "\nvar ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nwhen ",
+ "\ncase ",
+ "\nelse ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.JS:
+ return [
+ # Split along function definitions
+ "\nfunction ",
+ "\nconst ",
+ "\nlet ",
+ "\nvar ",
+ "\nclass ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nswitch ",
+ "\ncase ",
+ "\ndefault ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.TS:
+ return [
+ "\nenum ",
+ "\ninterface ",
+ "\nnamespace ",
+ "\ntype ",
+ # Split along class definitions
+ "\nclass ",
+ # Split along function definitions
+ "\nfunction ",
+ "\nconst ",
+ "\nlet ",
+ "\nvar ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nswitch ",
+ "\ncase ",
+ "\ndefault ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.PHP:
+ return [
+ # Split along function definitions
+ "\nfunction ",
+ # Split along class definitions
+ "\nclass ",
+ # Split along control flow statements
+ "\nif ",
+ "\nforeach ",
+ "\nwhile ",
+ "\ndo ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.PROTO:
+ return [
+ # Split along message definitions
+ "\nmessage ",
+ # Split along service definitions
+ "\nservice ",
+ # Split along enum definitions
+ "\nenum ",
+ # Split along option definitions
+ "\noption ",
+ # Split along import statements
+ "\nimport ",
+ # Split along syntax declarations
+ "\nsyntax ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.PYTHON:
+ return [
+ # First, try to split along class definitions
+ "\nclass ",
+ "\ndef ",
+ "\n\tdef ",
+ # Now split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.RST:
+ return [
+ # Split along section titles
+ "\n=+\n",
+ "\n-+\n",
+ "\n\\*+\n",
+ # Split along directive markers
+ "\n\n.. *\n\n",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.RUBY:
+ return [
+ # Split along method definitions
+ "\ndef ",
+ "\nclass ",
+ # Split along control flow statements
+ "\nif ",
+ "\nunless ",
+ "\nwhile ",
+ "\nfor ",
+ "\ndo ",
+ "\nbegin ",
+ "\nrescue ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.RUST:
+ return [
+ # Split along function definitions
+ "\nfn ",
+ "\nconst ",
+ "\nlet ",
+ # Split along control flow statements
+ "\nif ",
+ "\nwhile ",
+ "\nfor ",
+ "\nloop ",
+ "\nmatch ",
+ "\nconst ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.SCALA:
+ return [
+ # Split along class definitions
+ "\nclass ",
+ "\nobject ",
+ # Split along method definitions
+ "\ndef ",
+ "\nval ",
+ "\nvar ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nmatch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.SWIFT:
+ return [
+ # Split along function definitions
+ "\nfunc ",
+ # Split along class definitions
+ "\nclass ",
+ "\nstruct ",
+ "\nenum ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\ndo ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.MARKDOWN:
+ return [
+ # First, try to split along Markdown headings (levels 1 through 6)
+ "\n#{1,6} ",
+ # Note the alternative syntax for headings (below) is not handled here
+ # Heading level 2
+ # ---------------
+ # End of code block
+ "```\n",
+ # Horizontal lines
+ "\n\\*\\*\\*+\n",
+ "\n---+\n",
+ "\n___+\n",
+ # Note that horizontal rules written with spaces between the
+ # characters (e.g. "* * *") are not matched by the patterns above
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.LATEX:
+ return [
+ # First, try to split along Latex sections
+ "\n\\\\chapter{",
+ "\n\\\\section{",
+ "\n\\\\subsection{",
+ "\n\\\\subsubsection{",
+ # Now split by environments
+ "\n\\\\begin{enumerate}",
+ "\n\\\\begin{itemize}",
+ "\n\\\\begin{description}",
+ "\n\\\\begin{list}",
+ "\n\\\\begin{quote}",
+ "\n\\\\begin{quotation}",
+ "\n\\\\begin{verse}",
+ "\n\\\\begin{verbatim}",
+ # Now split by math environments
+ "\n\\\\begin{align}",
+ "$$",
+ "$",
+ # Now split by the normal type of lines
+ " ",
+ "",
+ ]
+ elif language == Language.HTML:
+ return [
+ # First, try to split along HTML tags
+ "<body",
+ "<div",
+ "<p",
+ "<br",
+ "<li",
+ "<h1",
+ "<h2",
+ "<h3",
+ "<h4",
+ "<h5",
+ "<h6",
+ "<span",
+ "<table",
+ "<tr",
+ "<td",
+ "<th",
+ "<ul",
+ "<ol",
+ "<header",
+ "<footer",
+ "<nav",
+ # Head
+ "<head",
+ "<style",
+ "<script",
+ "<meta",
+ "<title",
+ "",
+ ]
+ elif language == Language.CSHARP:
+ return [
+ "\ninterface ",
+ "\nenum ",
+ "\nimplements ",
+ "\ndelegate ",
+ "\nevent ",
+ # Split along class definitions
+ "\nclass ",
+ "\nabstract ",
+ # Split along method definitions
+ "\npublic ",
+ "\nprotected ",
+ "\nprivate ",
+ "\nstatic ",
+ "\nreturn ",
+ # Split along control flow statements
+ "\nif ",
+ "\ncontinue ",
+ "\nfor ",
+ "\nforeach ",
+ "\nwhile ",
+ "\nswitch ",
+ "\nbreak ",
+ "\ncase ",
+ "\nelse ",
+ # Split by exceptions
+ "\ntry ",
+ "\nthrow ",
+ "\nfinally ",
+ "\ncatch ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.SOL:
+ return [
+ # Split along compiler information definitions
+ "\npragma ",
+ "\nusing ",
+ # Split along contract definitions
+ "\ncontract ",
+ "\ninterface ",
+ "\nlibrary ",
+ # Split along method definitions
+ "\nconstructor ",
+ "\ntype ",
+ "\nfunction ",
+ "\nevent ",
+ "\nmodifier ",
+ "\nerror ",
+ "\nstruct ",
+ "\nenum ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\ndo while ",
+ "\nassembly ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.COBOL:
+ return [
+ # Split along divisions
+ "\nIDENTIFICATION DIVISION.",
+ "\nENVIRONMENT DIVISION.",
+ "\nDATA DIVISION.",
+ "\nPROCEDURE DIVISION.",
+ # Split along sections within DATA DIVISION
+ "\nWORKING-STORAGE SECTION.",
+ "\nLINKAGE SECTION.",
+ "\nFILE SECTION.",
+ # Split along sections within PROCEDURE DIVISION
+ "\nINPUT-OUTPUT SECTION.",
+ # Split along paragraphs and common statements
+ "\nOPEN ",
+ "\nCLOSE ",
+ "\nREAD ",
+ "\nWRITE ",
+ "\nIF ",
+ "\nELSE ",
+ "\nMOVE ",
+ "\nPERFORM ",
+ "\nUNTIL ",
+ "\nVARYING ",
+ "\nACCEPT ",
+ "\nDISPLAY ",
+ "\nSTOP RUN.",
+ # Split by the normal type of lines
+ "\n",
+ " ",
+ "",
+ ]
+
+ else:
+ raise ValueError(
+ f"Language {language} is not supported! "
+ f"Please choose from {list(Language)}"
+ )
+
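+
+# Illustrative usage sketch (not part of the original module): language-aware
+# recursive splitting. from_language turns the separator list for a Language
+# into a configured splitter. Assumes the TextSplitter base class accepts
+# chunk_size and chunk_overlap keyword arguments.
+def _example_recursive_python_split() -> None:
+    splitter = RecursiveCharacterTextSplitter.from_language(
+        Language.PYTHON, chunk_size=128, chunk_overlap=0
+    )
+    code = "class Foo:\n    def bar(self):\n        return 1\n\n\ndef baz():\n    return 2\n"
+    for chunk in splitter.split_text(code):
+        print(repr(chunk))
+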
+
+class NLTKTextSplitter(TextSplitter):
+ """Splitting text using NLTK package."""
+
+ def __init__(
+ self, separator: str = "\n\n", language: str = "english", **kwargs: Any
+ ) -> None:
+ """Initialize the NLTK splitter."""
+ super().__init__(**kwargs)
+ try:
+ from nltk.tokenize import sent_tokenize
+
+ self._tokenizer = sent_tokenize
+ except ImportError:
+ raise ImportError(
+ "NLTK is not installed, please install it with `pip install nltk`."
+ )
+ self._separator = separator
+ self._language = language
+
+ def split_text(self, text: str) -> List[str]:
+ """Split incoming text and return chunks."""
+ # First we naively split the large input into a bunch of smaller ones.
+ splits = self._tokenizer(text, language=self._language)
+ return self._merge_splits(splits, self._separator)
+
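+
+# Illustrative usage sketch (not part of the original module): sentence-level
+# splitting via NLTK. Requires nltk and its "punkt" tokenizer data (e.g. via
+# nltk.download("punkt")); chunk_size/chunk_overlap kwargs are assumed to be
+# handled by the TextSplitter base class.
+def _example_nltk_split() -> None:
+    splitter = NLTKTextSplitter(chunk_size=200, chunk_overlap=0)
+    print(splitter.split_text("First sentence. Second sentence. Third one."))
+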
+
+class SpacyTextSplitter(TextSplitter):
+ """Splitting text using Spacy package.
+
+ By default, Spacy's `en_core_web_sm` pipeline is used; its default
+ max_length is 1,000,000 characters (the maximum input length the model
+ accepts, which can be increased for large files). For faster but
+ potentially less accurate splitting, you can use `pipe='sentencizer'`.
+ """
+
+ def __init__(
+ self,
+ separator: str = "\n\n",
+ pipe: str = "en_core_web_sm",
+ max_length: int = 1_000_000,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the spacy text splitter."""
+ super().__init__(**kwargs)
+ self._tokenizer = _make_spacy_pipe_for_splitting(
+ pipe, max_length=max_length
+ )
+ self._separator = separator
+
+ def split_text(self, text: str) -> List[str]:
+ """Split incoming text and return chunks."""
+ splits = (s.text for s in self._tokenizer(text).sents)
+ return self._merge_splits(splits, self._separator)
+
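+
+# Illustrative usage sketch (not part of the original module): the lightweight
+# "sentencizer" pipe mentioned in the docstring above avoids loading a full
+# spaCy model; requires the spacy package.
+def _example_spacy_sentencizer_split() -> None:
+    splitter = SpacyTextSplitter(pipe="sentencizer")
+    print(splitter.split_text("First sentence. Second sentence."))
+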
+
+class KonlpyTextSplitter(TextSplitter):
+ """Splitting text using Konlpy package.
+
+ It is good for splitting Korean text.
+ """
+
+ def __init__(
+ self,
+ separator: str = "\n\n",
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the Konlpy text splitter."""
+ super().__init__(**kwargs)
+ self._separator = separator
+ try:
+ from konlpy.tag import Kkma
+ except ImportError:
+ raise ImportError(
+ "Konlpy is not installed, please install it with `pip install konlpy`."
+ )
+ self.kkma = Kkma()
+
+ def split_text(self, text: str) -> List[str]:
+ """Split incoming text and return chunks."""
+ splits = self.kkma.sentences(text)
+ return self._merge_splits(splits, self._separator)
+
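+
+# Illustrative usage sketch (not part of the original module): Korean sentence
+# splitting via Konlpy's Kkma analyzer (requires the konlpy package and a JVM).
+def _example_konlpy_split() -> None:
+    splitter = KonlpyTextSplitter()
+    print(splitter.split_text("안녕하세요. 반갑습니다."))
+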
+
+# For backwards compatibility
+class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
+ """Attempts to split the text along Python syntax."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize a PythonCodeTextSplitter."""
+ separators = self.get_separators_for_language(Language.PYTHON)
+ super().__init__(separators=separators, **kwargs)
+
+
+class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
+ """Attempts to split the text along Markdown-formatted headings."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize a MarkdownTextSplitter."""
+ separators = self.get_separators_for_language(Language.MARKDOWN)
+ super().__init__(separators=separators, **kwargs)
+
+
+class LatexTextSplitter(RecursiveCharacterTextSplitter):
+ """Attempts to split the text along Latex-formatted layout elements."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize a LatexTextSplitter."""
+ separators = self.get_separators_for_language(Language.LATEX)
+ super().__init__(separators=separators, **kwargs)
+
+
+class RecursiveJsonSplitter:
+ def __init__(
+ self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None
+ ):
+ super().__init__()
+ self.max_chunk_size = max_chunk_size
+ self.min_chunk_size = (
+ min_chunk_size
+ if min_chunk_size is not None
+ else max(max_chunk_size - 200, 50)
+ )
+
+ @staticmethod
+ def _json_size(data: Dict) -> int:
+ """Calculate the size of the serialized JSON object."""
+ return len(json.dumps(data))
+
+ @staticmethod
+ def _set_nested_dict(d: Dict, path: List[str], value: Any) -> None:
+ """Set a value in a nested dictionary based on the given path."""
+ for key in path[:-1]:
+ d = d.setdefault(key, {})
+ d[path[-1]] = value
+
+ def _list_to_dict_preprocessing(self, data: Any) -> Any:
+ if isinstance(data, dict):
+ # Process each key-value pair in the dictionary
+ return {
+ k: self._list_to_dict_preprocessing(v) for k, v in data.items()
+ }
+ elif isinstance(data, list):
+ # Convert the list to a dictionary with index-based keys
+ return {
+ str(i): self._list_to_dict_preprocessing(item)
+ for i, item in enumerate(data)
+ }
+ else:
+ # Base case: the item is neither a dict nor a list, so return it unchanged
+ return data
+
+ def _json_split(
+ self,
+ data: Dict[str, Any],
+ current_path: Optional[List[str]] = None,
+ chunks: Optional[List[Dict]] = None,
+ ) -> List[Dict]:
+ """
+ Split json into maximum size dictionaries while preserving structure.
+ """
+ # Use fresh containers instead of mutable default arguments,
+ # which would otherwise persist state across calls.
+ current_path = current_path if current_path is not None else []
+ chunks = chunks if chunks is not None else [{}]
+ if isinstance(data, dict):
+ for key, value in data.items():
+ new_path = current_path + [key]
+ chunk_size = self._json_size(chunks[-1])
+ size = self._json_size({key: value})
+ remaining = self.max_chunk_size - chunk_size
+
+ if size < remaining:
+ # Add item to current chunk
+ self._set_nested_dict(chunks[-1], new_path, value)
+ else:
+ if chunk_size >= self.min_chunk_size:
+ # Chunk is big enough, start a new chunk
+ chunks.append({})
+
+ # Iterate
+ self._json_split(value, new_path, chunks)
+ else:
+ # handle single item
+ self._set_nested_dict(chunks[-1], current_path, data)
+ return chunks
+
+ def split_json(
+ self,
+ json_data: Dict[str, Any],
+ convert_lists: bool = False,
+ ) -> List[Dict]:
+ """Splits JSON into a list of JSON chunks"""
+
+ if convert_lists:
+ chunks = self._json_split(
+ self._list_to_dict_preprocessing(json_data)
+ )
+ else:
+ chunks = self._json_split(json_data)
+
+ # Remove the last chunk if it's empty
+ if not chunks[-1]:
+ chunks.pop()
+ return chunks
+
+ def split_text(
+ self, json_data: Dict[str, Any], convert_lists: bool = False
+ ) -> List[str]:
+ """Splits JSON into a list of JSON formatted strings"""
+
+ chunks = self.split_json(
+ json_data=json_data, convert_lists=convert_lists
+ )
+
+ # Convert to string
+ return [json.dumps(chunk) for chunk in chunks]
+
+ def create_documents(
+ self,
+ texts: List[Dict],
+ convert_lists: bool = False,
+ metadatas: Optional[List[dict]] = None,
+ ) -> List[Document]:
+ """Create documents from a list of json objects (Dict)."""
+ _metadatas = metadatas or [{}] * len(texts)
+ documents = []
+ for i, text in enumerate(texts):
+ for chunk in self.split_text(
+ json_data=text, convert_lists=convert_lists
+ ):
+ metadata = copy.deepcopy(_metadatas[i])
+ new_doc = Document(page_content=chunk, metadata=metadata)
+ documents.append(new_doc)
+ return documents
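+
+
+# Illustrative usage sketch (not part of the original module): splitting a
+# nested JSON object into size-bounded chunks and wrapping them as Documents.
+def _example_recursive_json_split() -> None:
+    splitter = RecursiveJsonSplitter(max_chunk_size=300)
+    data = {"users": [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}], "total": 2}
+    # convert_lists=True rewrites lists as index-keyed dicts so they can be split.
+    print(splitter.split_text(json_data=data, convert_lists=True))
+    print(splitter.create_documents(texts=[data], convert_lists=True))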