Diffstat (limited to 'R2R/r2r/base')
32 files changed, 5160 insertions, 0 deletions
diff --git a/R2R/r2r/base/__init__.py b/R2R/r2r/base/__init__.py new file mode 100755 index 00000000..a6794a84 --- /dev/null +++ b/R2R/r2r/base/__init__.py @@ -0,0 +1,160 @@ +from .abstractions.base import AsyncSyncMeta, UserStats, syncable +from .abstractions.document import ( + DataType, + Document, + DocumentInfo, + DocumentType, + Entity, + Extraction, + ExtractionType, + Fragment, + FragmentType, + KGExtraction, + Triple, + extract_entities, + extract_triples, +) +from .abstractions.exception import R2RDocumentProcessingError, R2RException +from .abstractions.llama_abstractions import VectorStoreQuery +from .abstractions.llm import ( + GenerationConfig, + LLMChatCompletion, + LLMChatCompletionChunk, + RAGCompletion, +) +from .abstractions.prompt import Prompt +from .abstractions.search import ( + AggregateSearchResult, + KGSearchRequest, + KGSearchResult, + KGSearchSettings, + VectorSearchRequest, + VectorSearchResult, + VectorSearchSettings, +) +from .abstractions.vector import Vector, VectorEntry, VectorType +from .logging.kv_logger import ( + KVLoggingSingleton, + LocalKVLoggingProvider, + LoggingConfig, + PostgresKVLoggingProvider, + PostgresLoggingConfig, + RedisKVLoggingProvider, + RedisLoggingConfig, +) +from .logging.log_processor import ( + AnalysisTypes, + FilterCriteria, + LogAnalytics, + LogAnalyticsConfig, + LogProcessor, +) +from .logging.run_manager import RunManager, manage_run +from .parsers import AsyncParser +from .pipeline.base_pipeline import AsyncPipeline +from .pipes.base_pipe import AsyncPipe, AsyncState, PipeType +from .providers.embedding_provider import EmbeddingConfig, EmbeddingProvider +from .providers.eval_provider import EvalConfig, EvalProvider +from .providers.kg_provider import KGConfig, KGProvider, update_kg_prompt +from .providers.llm_provider import LLMConfig, LLMProvider +from .providers.prompt_provider import PromptConfig, PromptProvider +from .providers.vector_db_provider import VectorDBConfig, VectorDBProvider +from .utils import ( + EntityType, + RecursiveCharacterTextSplitter, + Relation, + TextSplitter, + format_entity_types, + format_relations, + generate_id_from_label, + generate_run_id, + increment_version, + run_pipeline, + to_async_generator, +) + +__all__ = [ + # Logging + "AsyncParser", + "AnalysisTypes", + "LogAnalytics", + "LogAnalyticsConfig", + "LogProcessor", + "LoggingConfig", + "LocalKVLoggingProvider", + "PostgresLoggingConfig", + "PostgresKVLoggingProvider", + "RedisLoggingConfig", + "AsyncSyncMeta", + "syncable", + "RedisKVLoggingProvider", + "KVLoggingSingleton", + "RunManager", + "manage_run", + # Abstractions + "VectorEntry", + "VectorType", + "Vector", + "VectorSearchRequest", + "VectorSearchResult", + "VectorSearchSettings", + "KGSearchRequest", + "KGSearchResult", + "KGSearchSettings", + "AggregateSearchResult", + "AsyncPipe", + "PipeType", + "AsyncState", + "AsyncPipe", + "Prompt", + "DataType", + "DocumentType", + "Document", + "DocumentInfo", + "Extraction", + "ExtractionType", + "Fragment", + "FragmentType", + "extract_entities", + "Entity", + "extract_triples", + "R2RException", + "R2RDocumentProcessingError", + "Triple", + "KGExtraction", + "UserStats", + # Pipelines + "AsyncPipeline", + # Providers + "EmbeddingConfig", + "EmbeddingProvider", + "EvalConfig", + "EvalProvider", + "PromptConfig", + "PromptProvider", + "GenerationConfig", + "RAGCompletion", + "VectorStoreQuery", + "LLMChatCompletion", + "LLMChatCompletionChunk", + "LLMConfig", + "LLMProvider", + "VectorDBConfig", + "VectorDBProvider", + "KGProvider", 
+ "KGConfig", + "update_kg_prompt", + # Other + "FilterCriteria", + "TextSplitter", + "RecursiveCharacterTextSplitter", + "to_async_generator", + "EntityType", + "Relation", + "format_entity_types", + "format_relations", + "increment_version", + "run_pipeline", + "generate_run_id", + "generate_id_from_label", +] diff --git a/R2R/r2r/base/abstractions/__init__.py b/R2R/r2r/base/abstractions/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/base/abstractions/__init__.py diff --git a/R2R/r2r/base/abstractions/base.py b/R2R/r2r/base/abstractions/base.py new file mode 100755 index 00000000..7121f6ce --- /dev/null +++ b/R2R/r2r/base/abstractions/base.py @@ -0,0 +1,93 @@ +import asyncio +import uuid +from typing import List + +from pydantic import BaseModel + + +class UserStats(BaseModel): + user_id: uuid.UUID + num_files: int + total_size_in_bytes: int + document_ids: List[uuid.UUID] + + +class AsyncSyncMeta(type): + _event_loop = None # Class-level shared event loop + + @classmethod + def get_event_loop(cls): + if cls._event_loop is None or cls._event_loop.is_closed(): + cls._event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls._event_loop) + return cls._event_loop + + def __new__(cls, name, bases, dct): + new_cls = super().__new__(cls, name, bases, dct) + for attr_name, attr_value in dct.items(): + if asyncio.iscoroutinefunction(attr_value) and getattr( + attr_value, "_syncable", False + ): + sync_method_name = attr_name[ + 1: + ] # Remove leading 'a' for sync method + async_method = attr_value + + def make_sync_method(async_method): + def sync_wrapper(self, *args, **kwargs): + loop = cls.get_event_loop() + if not loop.is_running(): + # Setup to run the loop in a background thread if necessary + # to prevent blocking the main thread in a synchronous call environment + from threading import Thread + + result = None + exception = None + + def run(): + nonlocal result, exception + try: + asyncio.set_event_loop(loop) + result = loop.run_until_complete( + async_method(self, *args, **kwargs) + ) + except Exception as e: + exception = e + finally: + generation_config = kwargs.get( + "rag_generation_config", None + ) + if ( + not generation_config + or not generation_config.stream + ): + loop.run_until_complete( + loop.shutdown_asyncgens() + ) + loop.close() + + thread = Thread(target=run) + thread.start() + thread.join() + if exception: + raise exception + return result + else: + # If there's already a running loop, schedule and execute the coroutine + future = asyncio.run_coroutine_threadsafe( + async_method(self, *args, **kwargs), loop + ) + return future.result() + + return sync_wrapper + + setattr( + new_cls, sync_method_name, make_sync_method(async_method) + ) + return new_cls + + +def syncable(func): + """Decorator to mark methods for synchronous wrapper creation.""" + func._syncable = True + return func diff --git a/R2R/r2r/base/abstractions/document.py b/R2R/r2r/base/abstractions/document.py new file mode 100755 index 00000000..117db7b9 --- /dev/null +++ b/R2R/r2r/base/abstractions/document.py @@ -0,0 +1,242 @@ +"""Abstractions for documents and their extractions.""" + +import base64 +import json +import logging +import uuid +from datetime import datetime +from enum import Enum +from typing import Optional, Union + +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +DataType = Union[str, bytes] + + +class DocumentType(str, Enum): + """Types of documents that can be stored.""" + + CSV = "csv" + DOCX = "docx" + HTML 
= "html" + JSON = "json" + MD = "md" + PDF = "pdf" + PPTX = "pptx" + TXT = "txt" + XLSX = "xlsx" + GIF = "gif" + PNG = "png" + JPG = "jpg" + JPEG = "jpeg" + SVG = "svg" + MP3 = "mp3" + MP4 = "mp4" + + +class Document(BaseModel): + id: uuid.UUID = Field(default_factory=uuid.uuid4) + type: DocumentType + data: Union[str, bytes] + metadata: dict + + def __init__(self, *args, **kwargs): + data = kwargs.get("data") + if data and isinstance(data, str): + try: + # Try to decode if it's already base64 encoded + kwargs["data"] = base64.b64decode(data) + except: + # If it's not base64, encode it to bytes + kwargs["data"] = data.encode("utf-8") + + doc_type = kwargs.get("type") + if isinstance(doc_type, str): + kwargs["type"] = DocumentType(doc_type) + + # Generate UUID based on the hash of the data + if "id" not in kwargs: + if isinstance(kwargs["data"], bytes): + data_hash = uuid.uuid5( + uuid.NAMESPACE_DNS, kwargs["data"].decode("utf-8") + ) + else: + data_hash = uuid.uuid5(uuid.NAMESPACE_DNS, kwargs["data"]) + + kwargs["id"] = data_hash # Set the id based on the data hash + + super().__init__(*args, **kwargs) + + class Config: + arbitrary_types_allowed = True + json_encoders = { + uuid.UUID: str, + bytes: lambda v: base64.b64encode(v).decode("utf-8"), + } + + +class DocumentStatus(str, Enum): + """Status of document processing.""" + + PROCESSING = "processing" + # TODO - Extend support for `partial-failure` + # PARTIAL_FAILURE = "partial-failure" + FAILURE = "failure" + SUCCESS = "success" + + +class DocumentInfo(BaseModel): + """Base class for document information handling.""" + + document_id: uuid.UUID + version: str + size_in_bytes: int + metadata: dict + status: DocumentStatus = DocumentStatus.PROCESSING + + user_id: Optional[uuid.UUID] = None + title: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + def convert_to_db_entry(self): + """Prepare the document info for database entry, extracting certain fields from metadata.""" + now = datetime.now() + metadata = self.metadata + if "user_id" in metadata: + metadata["user_id"] = str(metadata["user_id"]) + + metadata["title"] = metadata.get("title", "N/A") + return { + "document_id": str(self.document_id), + "title": metadata.get("title", "N/A"), + "user_id": metadata.get("user_id", None), + "version": self.version, + "size_in_bytes": self.size_in_bytes, + "metadata": json.dumps(self.metadata), + "created_at": self.created_at or now, + "updated_at": self.updated_at or now, + "status": self.status, + } + + +class ExtractionType(Enum): + """Types of extractions that can be performed.""" + + TXT = "txt" + IMG = "img" + MOV = "mov" + + +class Extraction(BaseModel): + """An extraction from a document.""" + + id: uuid.UUID + type: ExtractionType = ExtractionType.TXT + data: DataType + metadata: dict + document_id: uuid.UUID + + +class FragmentType(Enum): + """A type of fragment that can be extracted from a document.""" + + TEXT = "text" + IMAGE = "image" + + +class Fragment(BaseModel): + """A fragment extracted from a document.""" + + id: uuid.UUID + type: FragmentType + data: DataType + metadata: dict + document_id: uuid.UUID + extraction_id: uuid.UUID + + +class Entity(BaseModel): + """An entity extracted from a document.""" + + category: str + subcategory: Optional[str] = None + value: str + + def __str__(self): + return ( + f"{self.category}:{self.subcategory}:{self.value}" + if self.subcategory + else f"{self.category}:{self.value}" + ) + + +class Triple(BaseModel): + """A triple extracted from 
a document.""" + + subject: str + predicate: str + object: str + + +def extract_entities(llm_payload: list[str]) -> dict[str, Entity]: + entities = {} + for entry in llm_payload: + try: + if "], " in entry: # Check if the entry is an entity + entry_val = entry.split("], ")[0] + "]" + entry = entry.split("], ")[1] + colon_count = entry.count(":") + + if colon_count == 1: + category, value = entry.split(":") + subcategory = None + elif colon_count >= 2: + parts = entry.split(":", 2) + category, subcategory, value = ( + parts[0], + parts[1], + parts[2], + ) + else: + raise ValueError("Unexpected entry format") + + entities[entry_val] = Entity( + category=category, subcategory=subcategory, value=value + ) + except Exception as e: + logger.error(f"Error processing entity {entry}: {e}") + continue + return entities + + +def extract_triples( + llm_payload: list[str], entities: dict[str, Entity] +) -> list[Triple]: + triples = [] + for entry in llm_payload: + try: + if "], " not in entry: # Check if the entry is an entity + elements = entry.split(" ") + subject = elements[0] + predicate = elements[1] + object = " ".join(elements[2:]) + subject = entities[subject].value # Use entity.value + if "[" in object and "]" in object: + object = entities[object].value # Use entity.value + triples.append( + Triple(subject=subject, predicate=predicate, object=object) + ) + except Exception as e: + logger.error(f"Error processing triplet {entry}: {e}") + continue + return triples + + +class KGExtraction(BaseModel): + """An extraction from a document that is part of a knowledge graph.""" + + entities: dict[str, Entity] + triples: list[Triple] diff --git a/R2R/r2r/base/abstractions/exception.py b/R2R/r2r/base/abstractions/exception.py new file mode 100755 index 00000000..c76625a3 --- /dev/null +++ b/R2R/r2r/base/abstractions/exception.py @@ -0,0 +1,16 @@ +from typing import Any, Optional + + +class R2RException(Exception): + def __init__( + self, message: str, status_code: int, detail: Optional[Any] = None + ): + self.message = message + self.status_code = status_code + super().__init__(self.message) + + +class R2RDocumentProcessingError(R2RException): + def __init__(self, error_message, document_id): + self.document_id = document_id + super().__init__(error_message, 400, {"document_id": document_id}) diff --git a/R2R/r2r/base/abstractions/llama_abstractions.py b/R2R/r2r/base/abstractions/llama_abstractions.py new file mode 100755 index 00000000..f6bc36e6 --- /dev/null +++ b/R2R/r2r/base/abstractions/llama_abstractions.py @@ -0,0 +1,439 @@ +# abstractions are taken from LlamaIndex +# https://github.com/run-llama/llama_index +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union + +from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr + + +class LabelledNode(BaseModel): + """An entity in a graph.""" + + label: str = Field(default="node", description="The label of the node.") + embedding: Optional[List[float]] = Field( + default=None, description="The embeddings of the node." + ) + properties: Dict[str, Any] = Field(default_factory=dict) + + @abstractmethod + def __str__(self) -> str: + """Return the string representation of the node.""" + ... + + @property + @abstractmethod + def id(self) -> str: + """Get the node id.""" + ... 
+ + +class EntityNode(LabelledNode): + """An entity in a graph.""" + + name: str = Field(description="The name of the entity.") + label: str = Field(default="entity", description="The label of the node.") + properties: Dict[str, Any] = Field(default_factory=dict) + + def __str__(self) -> str: + """Return the string representation of the node.""" + return self.name + + @property + def id(self) -> str: + """Get the node id.""" + return self.name.replace('"', " ") + + +class ChunkNode(LabelledNode): + """A text chunk in a graph.""" + + text: str = Field(description="The text content of the chunk.") + id_: Optional[str] = Field( + default=None, + description="The id of the node. Defaults to a hash of the text.", + ) + label: str = Field( + default="text_chunk", description="The label of the node." + ) + properties: Dict[str, Any] = Field(default_factory=dict) + + def __str__(self) -> str: + """Return the string representation of the node.""" + return self.text + + @property + def id(self) -> str: + """Get the node id.""" + return str(hash(self.text)) if self.id_ is None else self.id_ + + +class Relation(BaseModel): + """A relation connecting two entities in a graph.""" + + label: str + source_id: str + target_id: str + properties: Dict[str, Any] = Field(default_factory=dict) + + def __str__(self) -> str: + """Return the string representation of the relation.""" + return self.label + + @property + def id(self) -> str: + """Get the relation id.""" + return self.label + + +Triplet = Tuple[LabelledNode, Relation, LabelledNode] + + +class VectorStoreQueryMode(str, Enum): + """Vector store query mode.""" + + DEFAULT = "default" + SPARSE = "sparse" + HYBRID = "hybrid" + TEXT_SEARCH = "text_search" + SEMANTIC_HYBRID = "semantic_hybrid" + + # fit learners + SVM = "svm" + LOGISTIC_REGRESSION = "logistic_regression" + LINEAR_REGRESSION = "linear_regression" + + # maximum marginal relevance + MMR = "mmr" + + +class FilterOperator(str, Enum): + """Vector store filter operator.""" + + # TODO add more operators + EQ = "==" # default operator (string, int, float) + GT = ">" # greater than (int, float) + LT = "<" # less than (int, float) + NE = "!=" # not equal to (string, int, float) + GTE = ">=" # greater than or equal to (int, float) + LTE = "<=" # less than or equal to (int, float) + IN = "in" # In array (string or number) + NIN = "nin" # Not in array (string or number) + ANY = "any" # Contains any (array of strings) + ALL = "all" # Contains all (array of strings) + TEXT_MATCH = "text_match" # full text match (allows you to search for a specific substring, token or phrase within the text field) + CONTAINS = "contains" # metadata array contains value (string or number) + + +class MetadataFilter(BaseModel): + """Comprehensive metadata filter for vector stores to support more operators. + + Value uses Strict* types, as int, float and str are compatible types and were all + converted to string before. + + See: https://docs.pydantic.dev/latest/usage/types/#strict-types + """ + + key: str + value: Union[ + StrictInt, + StrictFloat, + StrictStr, + List[StrictStr], + List[StrictFloat], + List[StrictInt], + ] + operator: FilterOperator = FilterOperator.EQ + + @classmethod + def from_dict( + cls, + filter_dict: Dict, + ) -> "MetadataFilter": + """Create MetadataFilter from dictionary. + + Args: + filter_dict: Dict with key, value and operator. 
+ + """ + return MetadataFilter.parse_obj(filter_dict) + + +# # TODO: Deprecate ExactMatchFilter and use MetadataFilter instead +# # Keep class for now so that AutoRetriever can still work with old vector stores +# class ExactMatchFilter(BaseModel): +# key: str +# value: Union[StrictInt, StrictFloat, StrictStr] + +# set ExactMatchFilter to MetadataFilter +ExactMatchFilter = MetadataFilter + + +class FilterCondition(str, Enum): + """Vector store filter conditions to combine different filters.""" + + # TODO add more conditions + AND = "and" + OR = "or" + + +class MetadataFilters(BaseModel): + """Metadata filters for vector stores.""" + + # Exact match filters and Advanced filters with operators like >, <, >=, <=, !=, etc. + filters: List[Union[MetadataFilter, ExactMatchFilter, "MetadataFilters"]] + # and/or such conditions for combining different filters + condition: Optional[FilterCondition] = FilterCondition.AND + + +@dataclass +class VectorStoreQuery: + """Vector store query.""" + + query_embedding: Optional[List[float]] = None + similarity_top_k: int = 1 + doc_ids: Optional[List[str]] = None + node_ids: Optional[List[str]] = None + query_str: Optional[str] = None + output_fields: Optional[List[str]] = None + embedding_field: Optional[str] = None + + mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT + + # NOTE: only for hybrid search (0 for bm25, 1 for vector search) + alpha: Optional[float] = None + + # metadata filters + filters: Optional[MetadataFilters] = None + + # only for mmr + mmr_threshold: Optional[float] = None + + # NOTE: currently only used by postgres hybrid search + sparse_top_k: Optional[int] = None + # NOTE: return top k results from hybrid search. similarity_top_k is used for dense search top k + hybrid_top_k: Optional[int] = None + + +class PropertyGraphStore(ABC): + """Abstract labelled graph store protocol. + + This protocol defines the interface for a graph store, which is responsible + for storing and retrieving knowledge graph data. + + Attributes: + client: Any: The client used to connect to the graph store. + get: Callable[[str], List[List[str]]]: Get triplets for a given subject. + get_rel_map: Callable[[Optional[List[str]], int], Dict[str, List[List[str]]]]: + Get subjects' rel map in max depth. + upsert_triplet: Callable[[str, str, str], None]: Upsert a triplet. + delete: Callable[[str, str, str], None]: Delete a triplet. + persist: Callable[[str, Optional[fsspec.AbstractFileSystem]], None]: + Persist the graph store to a file. + """ + + supports_structured_queries: bool = False + supports_vector_queries: bool = False + + @property + def client(self) -> Any: + """Get client.""" + ... + + @abstractmethod + def get( + self, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[LabelledNode]: + """Get nodes with matching values.""" + ... + + @abstractmethod + def get_triplets( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[Triplet]: + """Get triplets with matching values.""" + ... + + @abstractmethod + def get_rel_map( + self, + graph_nodes: List[LabelledNode], + depth: int = 2, + limit: int = 30, + ignore_rels: Optional[List[str]] = None, + ) -> List[Triplet]: + """Get depth-aware rel map.""" + ... + + @abstractmethod + def upsert_nodes(self, nodes: List[LabelledNode]) -> None: + """Upsert nodes.""" + ... 
+ + @abstractmethod + def upsert_relations(self, relations: List[Relation]) -> None: + """Upsert relations.""" + ... + + @abstractmethod + def delete( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> None: + """Delete matching data.""" + ... + + @abstractmethod + def structured_query( + self, query: str, param_map: Optional[Dict[str, Any]] = None + ) -> Any: + """Query the graph store with statement and parameters.""" + ... + + @abstractmethod + def vector_query( + self, query: VectorStoreQuery, **kwargs: Any + ) -> Tuple[List[LabelledNode], List[float]]: + """Query the graph store with a vector store query.""" + ... + + # def persist( + # self, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None + # ) -> None: + # """Persist the graph store to a file.""" + # return + + def get_schema(self, refresh: bool = False) -> Any: + """Get the schema of the graph store.""" + return None + + def get_schema_str(self, refresh: bool = False) -> str: + """Get the schema of the graph store as a string.""" + return str(self.get_schema(refresh=refresh)) + + ### ----- Async Methods ----- ### + + async def aget( + self, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[LabelledNode]: + """Asynchronously get nodes with matching values.""" + return self.get(properties, ids) + + async def aget_triplets( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[Triplet]: + """Asynchronously get triplets with matching values.""" + return self.get_triplets(entity_names, relation_names, properties, ids) + + async def aget_rel_map( + self, + graph_nodes: List[LabelledNode], + depth: int = 2, + limit: int = 30, + ignore_rels: Optional[List[str]] = None, + ) -> List[Triplet]: + """Asynchronously get depth-aware rel map.""" + return self.get_rel_map(graph_nodes, depth, limit, ignore_rels) + + async def aupsert_nodes(self, nodes: List[LabelledNode]) -> None: + """Asynchronously add nodes.""" + return self.upsert_nodes(nodes) + + async def aupsert_relations(self, relations: List[Relation]) -> None: + """Asynchronously add relations.""" + return self.upsert_relations(relations) + + async def adelete( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> None: + """Asynchronously delete matching data.""" + return self.delete(entity_names, relation_names, properties, ids) + + async def astructured_query( + self, query: str, param_map: Optional[Dict[str, Any]] = {} + ) -> Any: + """Asynchronously query the graph store with statement and parameters.""" + return self.structured_query(query, param_map) + + async def avector_query( + self, query: VectorStoreQuery, **kwargs: Any + ) -> Tuple[List[LabelledNode], List[float]]: + """Asynchronously query the graph store with a vector store query.""" + return self.vector_query(query, **kwargs) + + async def aget_schema(self, refresh: bool = False) -> str: + """Asynchronously get the schema of the graph store.""" + return self.get_schema(refresh=refresh) + + async def aget_schema_str(self, refresh: bool = False) -> str: + """Asynchronously get the schema of the graph store as a string.""" + return str(await self.aget_schema(refresh=refresh)) + + +LIST_LIMIT = 128 + + +def 
clean_string_values(text: str) -> str: + return text.replace("\n", " ").replace("\r", " ") + + +def value_sanitize(d: Any) -> Any: + """Sanitize the input dictionary or list. + + Sanitizes the input by removing embedding-like values, + lists with more than 128 elements, that are mostly irrelevant for + generating answers in a LLM context. These properties, if left in + results, can occupy significant context space and detract from + the LLM's performance by introducing unnecessary noise and cost. + """ + if isinstance(d, dict): + new_dict = {} + for key, value in d.items(): + if isinstance(value, dict): + sanitized_value = value_sanitize(value) + if ( + sanitized_value is not None + ): # Check if the sanitized value is not None + new_dict[key] = sanitized_value + elif isinstance(value, list): + if len(value) < LIST_LIMIT: + sanitized_value = value_sanitize(value) + if ( + sanitized_value is not None + ): # Check if the sanitized value is not None + new_dict[key] = sanitized_value + # Do not include the key if the list is oversized + else: + new_dict[key] = value + return new_dict + elif isinstance(d, list): + if len(d) < LIST_LIMIT: + return [ + value_sanitize(item) + for item in d + if value_sanitize(item) is not None + ] + else: + return None + else: + return d diff --git a/R2R/r2r/base/abstractions/llm.py b/R2R/r2r/base/abstractions/llm.py new file mode 100755 index 00000000..3178d8dc --- /dev/null +++ b/R2R/r2r/base/abstractions/llm.py @@ -0,0 +1,112 @@ +"""Abstractions for the LLM model.""" + +from typing import TYPE_CHECKING, ClassVar, Optional + +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from pydantic import BaseModel, Field + +if TYPE_CHECKING: + from .search import AggregateSearchResult + +LLMChatCompletion = ChatCompletion +LLMChatCompletionChunk = ChatCompletionChunk + + +class RAGCompletion: + completion: LLMChatCompletion + search_results: "AggregateSearchResult" + + def __init__( + self, + completion: LLMChatCompletion, + search_results: "AggregateSearchResult", + ): + self.completion = completion + self.search_results = search_results + + +class GenerationConfig(BaseModel): + _defaults: ClassVar[dict] = { + "model": "gpt-4o", + "temperature": 0.1, + "top_p": 1.0, + "top_k": 100, + "max_tokens_to_sample": 1024, + "stream": False, + "functions": None, + "skip_special_tokens": False, + "stop_token": None, + "num_beams": 1, + "do_sample": True, + "generate_with_chat": False, + "add_generation_kwargs": None, + "api_base": None, + } + + model: str = Field( + default_factory=lambda: GenerationConfig._defaults["model"] + ) + temperature: float = Field( + default_factory=lambda: GenerationConfig._defaults["temperature"] + ) + top_p: float = Field( + default_factory=lambda: GenerationConfig._defaults["top_p"] + ) + top_k: int = Field( + default_factory=lambda: GenerationConfig._defaults["top_k"] + ) + max_tokens_to_sample: int = Field( + default_factory=lambda: GenerationConfig._defaults[ + "max_tokens_to_sample" + ] + ) + stream: bool = Field( + default_factory=lambda: GenerationConfig._defaults["stream"] + ) + functions: Optional[list[dict]] = Field( + default_factory=lambda: GenerationConfig._defaults["functions"] + ) + skip_special_tokens: bool = Field( + default_factory=lambda: GenerationConfig._defaults[ + "skip_special_tokens" + ] + ) + stop_token: Optional[str] = Field( + default_factory=lambda: GenerationConfig._defaults["stop_token"] + ) + num_beams: int = Field( + default_factory=lambda: GenerationConfig._defaults["num_beams"] + ) + do_sample: bool 
= Field( + default_factory=lambda: GenerationConfig._defaults["do_sample"] + ) + generate_with_chat: bool = Field( + default_factory=lambda: GenerationConfig._defaults[ + "generate_with_chat" + ] + ) + add_generation_kwargs: Optional[dict] = Field( + default_factory=lambda: GenerationConfig._defaults[ + "add_generation_kwargs" + ] + ) + api_base: Optional[str] = Field( + default_factory=lambda: GenerationConfig._defaults["api_base"] + ) + + @classmethod + def set_default(cls, **kwargs): + for key, value in kwargs.items(): + if key in cls._defaults: + cls._defaults[key] = value + else: + raise AttributeError( + f"No default attribute '{key}' in GenerationConfig" + ) + + def __init__(self, **data): + model = data.pop("model", None) + if model is not None: + super().__init__(model=model, **data) + else: + super().__init__(**data) diff --git a/R2R/r2r/base/abstractions/prompt.py b/R2R/r2r/base/abstractions/prompt.py new file mode 100755 index 00000000..e37eeb5f --- /dev/null +++ b/R2R/r2r/base/abstractions/prompt.py @@ -0,0 +1,31 @@ +"""Abstraction for a prompt that can be formatted with inputs.""" + +from typing import Any + +from pydantic import BaseModel + + +class Prompt(BaseModel): + """A prompt that can be formatted with inputs.""" + + name: str + template: str + input_types: dict[str, str] + + def format_prompt(self, inputs: dict[str, Any]) -> str: + self._validate_inputs(inputs) + return self.template.format(**inputs) + + def _validate_inputs(self, inputs: dict[str, Any]) -> None: + for var, expected_type_name in self.input_types.items(): + expected_type = self._convert_type(expected_type_name) + if var not in inputs: + raise ValueError(f"Missing input: {var}") + if not isinstance(inputs[var], expected_type): + raise TypeError( + f"Input '{var}' must be of type {expected_type.__name__}, got {type(inputs[var]).__name__} instead." + ) + + def _convert_type(self, type_name: str) -> type: + type_mapping = {"int": int, "str": str} + return type_mapping.get(type_name, str) diff --git a/R2R/r2r/base/abstractions/search.py b/R2R/r2r/base/abstractions/search.py new file mode 100755 index 00000000..b13cc5aa --- /dev/null +++ b/R2R/r2r/base/abstractions/search.py @@ -0,0 +1,84 @@ +"""Abstractions for search functionality.""" + +import uuid +from typing import Any, Dict, List, Optional, Tuple + +from pydantic import BaseModel, Field + +from .llm import GenerationConfig + + +class VectorSearchRequest(BaseModel): + """Request for a search operation.""" + + query: str + limit: int + filters: Optional[dict[str, Any]] = None + + +class VectorSearchResult(BaseModel): + """Result of a search operation.""" + + id: uuid.UUID + score: float + metadata: dict[str, Any] + + def __str__(self) -> str: + return f"VectorSearchResult(id={self.id}, score={self.score}, metadata={self.metadata})" + + def __repr__(self) -> str: + return f"VectorSearchResult(id={self.id}, score={self.score}, metadata={self.metadata})" + + def dict(self) -> dict: + return { + "id": self.id, + "score": self.score, + "metadata": self.metadata, + } + + +class KGSearchRequest(BaseModel): + """Request for a knowledge graph search operation.""" + + query: str + + +# [query, ...] 
+KGSearchResult = List[Tuple[str, List[Dict[str, Any]]]] + + +class AggregateSearchResult(BaseModel): + """Result of an aggregate search operation.""" + + vector_search_results: Optional[List[VectorSearchResult]] + kg_search_results: Optional[KGSearchResult] = None + + def __str__(self) -> str: + return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})" + + def __repr__(self) -> str: + return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})" + + def dict(self) -> dict: + return { + "vector_search_results": ( + [result.dict() for result in self.vector_search_results] + if self.vector_search_results + else [] + ), + "kg_search_results": self.kg_search_results or [], + } + + +class VectorSearchSettings(BaseModel): + use_vector_search: bool = True + search_filters: dict[str, Any] = Field(default_factory=dict) + search_limit: int = 10 + do_hybrid_search: bool = False + + +class KGSearchSettings(BaseModel): + use_kg_search: bool = False + agent_generation_config: Optional[GenerationConfig] = Field( + default_factory=GenerationConfig + ) diff --git a/R2R/r2r/base/abstractions/vector.py b/R2R/r2r/base/abstractions/vector.py new file mode 100755 index 00000000..445f3302 --- /dev/null +++ b/R2R/r2r/base/abstractions/vector.py @@ -0,0 +1,66 @@ +"""Abstraction for a vector that can be stored in the system.""" + +from enum import Enum +from typing import Any +from uuid import UUID + + +class VectorType(Enum): + FIXED = "FIXED" + + +class Vector: + """A vector with the option to fix the number of elements.""" + + def __init__( + self, + data: list[float], + type: VectorType = VectorType.FIXED, + length: int = -1, + ): + self.data = data + self.type = type + self.length = length + + if ( + self.type == VectorType.FIXED + and length > 0 + and len(data) != length + ): + raise ValueError(f"Vector must be exactly {length} elements long.") + + def __repr__(self) -> str: + return ( + f"Vector(data={self.data}, type={self.type}, length={self.length})" + ) + + +class VectorEntry: + """A vector entry that can be stored directly in supported vector databases.""" + + def __init__(self, id: UUID, vector: Vector, metadata: dict[str, Any]): + """Create a new VectorEntry object.""" + self.vector = vector + self.id = id + self.metadata = metadata + + def to_serializable(self) -> str: + """Return a serializable representation of the VectorEntry.""" + metadata = self.metadata + + for key in metadata: + if isinstance(metadata[key], UUID): + metadata[key] = str(metadata[key]) + return { + "id": str(self.id), + "vector": self.vector.data, + "metadata": metadata, + } + + def __str__(self) -> str: + """Return a string representation of the VectorEntry.""" + return f"VectorEntry(id={self.id}, vector={self.vector}, metadata={self.metadata})" + + def __repr__(self) -> str: + """Return an unambiguous string representation of the VectorEntry.""" + return f"VectorEntry(id={self.id}, vector={self.vector}, metadata={self.metadata})" diff --git a/R2R/r2r/base/logging/__init__.py b/R2R/r2r/base/logging/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/base/logging/__init__.py diff --git a/R2R/r2r/base/logging/kv_logger.py b/R2R/r2r/base/logging/kv_logger.py new file mode 100755 index 00000000..2d444e9f --- /dev/null +++ b/R2R/r2r/base/logging/kv_logger.py @@ -0,0 +1,547 @@ +import json +import logging +import os +import uuid +from abc import abstractmethod +from 
datetime import datetime +from typing import Optional + +import asyncpg +from pydantic import BaseModel + +from ..providers.base_provider import Provider, ProviderConfig + +logger = logging.getLogger(__name__) + + +class RunInfo(BaseModel): + run_id: uuid.UUID + log_type: str + + +class LoggingConfig(ProviderConfig): + provider: str = "local" + log_table: str = "logs" + log_info_table: str = "logs_pipeline_info" + logging_path: Optional[str] = None + + def validate(self) -> None: + pass + + @property + def supported_providers(self) -> list[str]: + return ["local", "postgres", "redis"] + + +class KVLoggingProvider(Provider): + @abstractmethod + async def close(self): + pass + + @abstractmethod + async def log(self, log_id: uuid.UUID, key: str, value: str): + pass + + @abstractmethod + async def get_run_info( + self, + limit: int = 10, + log_type_filter: Optional[str] = None, + ) -> list[RunInfo]: + pass + + @abstractmethod + async def get_logs( + self, run_ids: list[uuid.UUID], limit_per_run: int + ) -> list: + pass + + +class LocalKVLoggingProvider(KVLoggingProvider): + def __init__(self, config: LoggingConfig): + self.log_table = config.log_table + self.log_info_table = config.log_info_table + self.logging_path = config.logging_path or os.getenv( + "LOCAL_DB_PATH", "local.sqlite" + ) + if not self.logging_path: + raise ValueError( + "Please set the environment variable LOCAL_DB_PATH." + ) + self.conn = None + try: + import aiosqlite + + self.aiosqlite = aiosqlite + except ImportError: + raise ImportError( + "Please install aiosqlite to use the LocalKVLoggingProvider." + ) + + async def init(self): + self.conn = await self.aiosqlite.connect(self.logging_path) + await self.conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.log_table} ( + timestamp DATETIME, + log_id TEXT, + key TEXT, + value TEXT + ) + """ + ) + await self.conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.log_info_table} ( + timestamp DATETIME, + log_id TEXT UNIQUE, + log_type TEXT + ) + """ + ) + await self.conn.commit() + + async def __aenter__(self): + if self.conn is None: + await self.init() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def close(self): + if self.conn: + await self.conn.close() + self.conn = None + + async def log( + self, + log_id: uuid.UUID, + key: str, + value: str, + is_info_log=False, + ): + collection = self.log_info_table if is_info_log else self.log_table + + if is_info_log: + if "type" not in key: + raise ValueError("Info log keys must contain the text 'type'") + await self.conn.execute( + f"INSERT INTO {collection} (timestamp, log_id, log_type) VALUES (datetime('now'), ?, ?)", + (str(log_id), value), + ) + else: + await self.conn.execute( + f"INSERT INTO {collection} (timestamp, log_id, key, value) VALUES (datetime('now'), ?, ?, ?)", + (str(log_id), key, value), + ) + await self.conn.commit() + + async def get_run_info( + self, limit: int = 10, log_type_filter: Optional[str] = None + ) -> list[RunInfo]: + cursor = await self.conn.cursor() + query = f'SELECT log_id, log_type FROM "{self.log_info_table}"' + conditions = [] + params = [] + if log_type_filter: + conditions.append("log_type = ?") + params.append(log_type_filter) + if conditions: + query += " WHERE " + " AND ".join(conditions) + query += " ORDER BY timestamp DESC LIMIT ?" 
+ params.append(limit) + await cursor.execute(query, params) + rows = await cursor.fetchall() + return [ + RunInfo(run_id=uuid.UUID(row[0]), log_type=row[1]) for row in rows + ] + + async def get_logs( + self, run_ids: list[uuid.UUID], limit_per_run: int = 10 + ) -> list: + if not run_ids: + raise ValueError("No run ids provided.") + cursor = await self.conn.cursor() + placeholders = ",".join(["?" for _ in run_ids]) + query = f""" + SELECT * + FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY log_id ORDER BY timestamp DESC) as rn + FROM {self.log_table} + WHERE log_id IN ({placeholders}) + ) + WHERE rn <= ? + ORDER BY timestamp DESC + """ + params = [str(ele) for ele in run_ids] + [limit_per_run] + await cursor.execute(query, params) + rows = await cursor.fetchall() + new_rows = [] + for row in rows: + new_rows.append( + (row[0], uuid.UUID(row[1]), row[2], row[3], row[4]) + ) + return [ + {desc[0]: row[i] for i, desc in enumerate(cursor.description)} + for row in new_rows + ] + + +class PostgresLoggingConfig(LoggingConfig): + provider: str = "postgres" + log_table: str = "logs" + log_info_table: str = "logs_pipeline_info" + + def validate(self) -> None: + required_env_vars = [ + "POSTGRES_DBNAME", + "POSTGRES_USER", + "POSTGRES_PASSWORD", + "POSTGRES_HOST", + "POSTGRES_PORT", + ] + for var in required_env_vars: + if not os.getenv(var): + raise ValueError(f"Environment variable {var} is not set.") + + @property + def supported_providers(self) -> list[str]: + return ["postgres"] + + +class PostgresKVLoggingProvider(KVLoggingProvider): + def __init__(self, config: PostgresLoggingConfig): + self.log_table = config.log_table + self.log_info_table = config.log_info_table + self.config = config + self.pool = None + if not os.getenv("POSTGRES_DBNAME"): + raise ValueError( + "Please set the environment variable POSTGRES_DBNAME." + ) + if not os.getenv("POSTGRES_USER"): + raise ValueError( + "Please set the environment variable POSTGRES_USER." + ) + if not os.getenv("POSTGRES_PASSWORD"): + raise ValueError( + "Please set the environment variable POSTGRES_PASSWORD." + ) + if not os.getenv("POSTGRES_HOST"): + raise ValueError( + "Please set the environment variable POSTGRES_HOST." + ) + if not os.getenv("POSTGRES_PORT"): + raise ValueError( + "Please set the environment variable POSTGRES_PORT." + ) + + async def init(self): + self.pool = await asyncpg.create_pool( + database=os.getenv("POSTGRES_DBNAME"), + user=os.getenv("POSTGRES_USER"), + password=os.getenv("POSTGRES_PASSWORD"), + host=os.getenv("POSTGRES_HOST"), + port=os.getenv("POSTGRES_PORT"), + statement_cache_size=0, # Disable statement caching + ) + async with self.pool.acquire() as conn: + await conn.execute( + f""" + CREATE TABLE IF NOT EXISTS "{self.log_table}" ( + timestamp TIMESTAMPTZ, + log_id UUID, + key TEXT, + value TEXT + ) + """ + ) + await conn.execute( + f""" + CREATE TABLE IF NOT EXISTS "{self.log_info_table}" ( + timestamp TIMESTAMPTZ, + log_id UUID UNIQUE, + log_type TEXT + ) + """ + ) + + async def __aenter__(self): + if self.pool is None: + await self.init() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def close(self): + if self.pool: + await self.pool.close() + self.pool = None + + async def log( + self, + log_id: uuid.UUID, + key: str, + value: str, + is_info_log=False, + ): + collection = self.log_info_table if is_info_log else self.log_table + + if is_info_log: + if "type" not in key: + raise ValueError( + "Info log key must contain the string `type`." 
+ ) + async with self.pool.acquire() as conn: + await self.pool.execute( + f'INSERT INTO "{collection}" (timestamp, log_id, log_type) VALUES (NOW(), $1, $2)', + log_id, + value, + ) + else: + async with self.pool.acquire() as conn: + await conn.execute( + f'INSERT INTO "{collection}" (timestamp, log_id, key, value) VALUES (NOW(), $1, $2, $3)', + log_id, + key, + value, + ) + + async def get_run_info( + self, limit: int = 10, log_type_filter: Optional[str] = None + ) -> list[RunInfo]: + query = f"SELECT log_id, log_type FROM {self.log_info_table}" + conditions = [] + params = [] + if log_type_filter: + conditions.append("log_type = $1") + params.append(log_type_filter) + if conditions: + query += " WHERE " + " AND ".join(conditions) + query += " ORDER BY timestamp DESC LIMIT $2" + params.append(limit) + async with self.pool.acquire() as conn: + rows = await conn.fetch(query, *params) + return [ + RunInfo(run_id=row["log_id"], log_type=row["log_type"]) + for row in rows + ] + + async def get_logs( + self, run_ids: list[uuid.UUID], limit_per_run: int = 10 + ) -> list: + if not run_ids: + raise ValueError("No run ids provided.") + + placeholders = ",".join([f"${i + 1}" for i in range(len(run_ids))]) + query = f""" + SELECT * FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY log_id ORDER BY timestamp DESC) as rn + FROM "{self.log_table}" + WHERE log_id::text IN ({placeholders}) + ) sub + WHERE sub.rn <= ${len(run_ids) + 1} + ORDER BY sub.timestamp DESC + """ + params = [str(run_id) for run_id in run_ids] + [limit_per_run] + async with self.pool.acquire() as conn: + rows = await conn.fetch(query, *params) + return [{key: row[key] for key in row.keys()} for row in rows] + + +class RedisLoggingConfig(LoggingConfig): + provider: str = "redis" + log_table: str = "logs" + log_info_table: str = "logs_pipeline_info" + + def validate(self) -> None: + required_env_vars = ["REDIS_CLUSTER_IP", "REDIS_CLUSTER_PORT"] + for var in required_env_vars: + if not os.getenv(var): + raise ValueError(f"Environment variable {var} is not set.") + + @property + def supported_providers(self) -> list[str]: + return ["redis"] + + +class RedisKVLoggingProvider(KVLoggingProvider): + def __init__(self, config: RedisLoggingConfig): + logger.info( + f"Initializing RedisKVLoggingProvider with config: {config}" + ) + + if not all( + [ + os.getenv("REDIS_CLUSTER_IP"), + os.getenv("REDIS_CLUSTER_PORT"), + ] + ): + raise ValueError( + "Please set the environment variables REDIS_CLUSTER_IP and REDIS_CLUSTER_PORT to run `LoggingDatabaseConnection` with `redis`." + ) + try: + from redis.asyncio import Redis + except ImportError: + raise ValueError( + "Error, `redis` is not installed. Please install it using `pip install redis`." 
+ ) + + cluster_ip = os.getenv("REDIS_CLUSTER_IP") + port = os.getenv("REDIS_CLUSTER_PORT") + self.redis = Redis(host=cluster_ip, port=port, decode_responses=True) + self.log_key = config.log_table + self.log_info_key = config.log_info_table + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.close() + + async def close(self): + await self.redis.close() + + async def log( + self, + log_id: uuid.UUID, + key: str, + value: str, + is_info_log=False, + ): + timestamp = datetime.now().timestamp() + log_entry = { + "timestamp": timestamp, + "log_id": str(log_id), + "key": key, + "value": value, + } + if is_info_log: + if "type" not in key: + raise ValueError("Metadata keys must contain the text 'type'") + log_entry["log_type"] = value + await self.redis.hset( + self.log_info_key, str(log_id), json.dumps(log_entry) + ) + await self.redis.zadd( + f"{self.log_info_key}_sorted", {str(log_id): timestamp} + ) + else: + await self.redis.lpush( + f"{self.log_key}:{str(log_id)}", json.dumps(log_entry) + ) + + async def get_run_info( + self, limit: int = 10, log_type_filter: Optional[str] = None + ) -> list[RunInfo]: + run_info_list = [] + start = 0 + count_per_batch = 100 # Adjust batch size as needed + + while len(run_info_list) < limit: + log_ids = await self.redis.zrevrange( + f"{self.log_info_key}_sorted", + start, + start + count_per_batch - 1, + ) + if not log_ids: + break # No more log IDs to process + + start += count_per_batch + + for log_id in log_ids: + log_entry = json.loads( + await self.redis.hget(self.log_info_key, log_id) + ) + if log_type_filter: + if log_entry["log_type"] == log_type_filter: + run_info_list.append( + RunInfo( + run_id=uuid.UUID(log_entry["log_id"]), + log_type=log_entry["log_type"], + ) + ) + else: + run_info_list.append( + RunInfo( + run_id=uuid.UUID(log_entry["log_id"]), + log_type=log_entry["log_type"], + ) + ) + + if len(run_info_list) >= limit: + break + + return run_info_list[:limit] + + async def get_logs( + self, run_ids: list[uuid.UUID], limit_per_run: int = 10 + ) -> list: + logs = [] + for run_id in run_ids: + raw_logs = await self.redis.lrange( + f"{self.log_key}:{str(run_id)}", 0, limit_per_run - 1 + ) + for raw_log in raw_logs: + json_log = json.loads(raw_log) + json_log["log_id"] = uuid.UUID(json_log["log_id"]) + logs.append(json_log) + return logs + + +class KVLoggingSingleton: + _instance = None + _is_configured = False + + SUPPORTED_PROVIDERS = { + "local": LocalKVLoggingProvider, + "postgres": PostgresKVLoggingProvider, + "redis": RedisKVLoggingProvider, + } + + @classmethod + def get_instance(cls): + return cls.SUPPORTED_PROVIDERS[cls._config.provider](cls._config) + + @classmethod + def configure( + cls, logging_config: Optional[LoggingConfig] = LoggingConfig() + ): + if not cls._is_configured: + cls._config = logging_config + cls._is_configured = True + else: + raise Exception("KVLoggingSingleton is already configured.") + + @classmethod + async def log( + cls, + log_id: uuid.UUID, + key: str, + value: str, + is_info_log=False, + ): + try: + async with cls.get_instance() as provider: + await provider.log(log_id, key, value, is_info_log=is_info_log) + + except Exception as e: + logger.error(f"Error logging data {(log_id, key, value)}: {e}") + + @classmethod + async def get_run_info( + cls, limit: int = 10, log_type_filter: Optional[str] = None + ) -> list[RunInfo]: + async with cls.get_instance() as provider: + return await provider.get_run_info( + limit, log_type_filter=log_type_filter + ) + 
+ @classmethod + async def get_logs( + cls, run_ids: list[uuid.UUID], limit_per_run: int = 10 + ) -> list: + async with cls.get_instance() as provider: + return await provider.get_logs(run_ids, limit_per_run) diff --git a/R2R/r2r/base/logging/log_processor.py b/R2R/r2r/base/logging/log_processor.py new file mode 100755 index 00000000..e85d8de2 --- /dev/null +++ b/R2R/r2r/base/logging/log_processor.py @@ -0,0 +1,196 @@ +import contextlib +import json +import logging +import statistics +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Sequence + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class FilterCriteria(BaseModel): + filters: Optional[dict[str, str]] = None + + +class LogProcessor: + timestamp_format = "%Y-%m-%d %H:%M:%S" + + def __init__(self, filters: Dict[str, Callable[[Dict[str, Any]], bool]]): + self.filters = filters + self.populations = {name: [] for name in filters} + + def process_log(self, log: Dict[str, Any]): + for name, filter_func in self.filters.items(): + if filter_func(log): + self.populations[name].append(log) + + +class StatisticsCalculator: + @staticmethod + def calculate_statistics( + population: List[Dict[str, Any]], + stat_functions: Dict[str, Callable[[List[Dict[str, Any]]], Any]], + ) -> Dict[str, Any]: + return { + name: func(population) for name, func in stat_functions.items() + } + + +class DistributionGenerator: + @staticmethod + def generate_distributions( + population: List[Dict[str, Any]], + dist_functions: Dict[str, Callable[[List[Dict[str, Any]]], Any]], + ) -> Dict[str, Any]: + return { + name: func(population) for name, func in dist_functions.items() + } + + +class VisualizationPreparer: + @staticmethod + def prepare_visualization_data( + data: Dict[str, Any], + vis_functions: Dict[str, Callable[[Dict[str, Any]], Any]], + ) -> Dict[str, Any]: + return {name: func(data) for name, func in vis_functions.items()} + + +class LogAnalyticsConfig: + def __init__(self, filters, stat_functions, dist_functions, vis_functions): + self.filters = filters + self.stat_functions = stat_functions + self.dist_functions = dist_functions + self.vis_functions = vis_functions + + +class AnalysisTypes(BaseModel): + analysis_types: Optional[dict[str, Sequence[str]]] = None + + @staticmethod + def generate_bar_chart_data(logs, key): + chart_data = {"labels": [], "datasets": []} + value_counts = defaultdict(int) + + for log in logs: + if "entries" in log: + for entry in log["entries"]: + if entry["key"] == key: + value_counts[entry["value"]] += 1 + elif "key" in log and log["key"] == key: + value_counts[log["value"]] += 1 + + for value, count in value_counts.items(): + chart_data["labels"].append(value) + chart_data["datasets"].append({"label": key, "data": [count]}) + + return chart_data + + @staticmethod + def calculate_basic_statistics(logs, key): + values = [] + for log in logs: + if log["key"] == "search_results": + results = json.loads(log["value"]) + scores = [ + float(json.loads(result)["score"]) for result in results + ] + values.extend(scores) + else: + value = log.get("value") + if value is not None: + with contextlib.suppress(ValueError): + values.append(float(value)) + + if not values: + return { + "Mean": None, + "Median": None, + "Mode": None, + "Standard Deviation": None, + "Variance": None, + } + + if len(values) == 1: + single_value = round(values[0], 3) + return { + "Mean": single_value, + "Median": single_value, + "Mode": single_value, + "Standard Deviation": 0, + "Variance": 0, + 
} + + mean = round(sum(values) / len(values), 3) + median = round(statistics.median(values), 3) + mode = ( + round(statistics.mode(values), 3) + if len(set(values)) != len(values) + else None + ) + std_dev = round(statistics.stdev(values) if len(values) > 1 else 0, 3) + variance = round( + statistics.variance(values) if len(values) > 1 else 0, 3 + ) + + return { + "Mean": mean, + "Median": median, + "Mode": mode, + "Standard Deviation": std_dev, + "Variance": variance, + } + + @staticmethod + def calculate_percentile(logs, key, percentile): + values = [] + for log in logs: + if log["key"] == key: + value = log.get("value") + if value is not None: + with contextlib.suppress(ValueError): + values.append(float(value)) + + if not values: + return {"percentile": percentile, "value": None} + + values.sort() + index = int((percentile / 100) * (len(values) - 1)) + return {"percentile": percentile, "value": round(values[index], 3)} + + +class LogAnalytics: + def __init__(self, logs: List[Dict[str, Any]], config: LogAnalyticsConfig): + self.logs = logs + self.log_processor = LogProcessor(config.filters) + self.statistics_calculator = StatisticsCalculator() + self.distribution_generator = DistributionGenerator() + self.visualization_preparer = VisualizationPreparer() + self.config = config + + def count_logs(self) -> Dict[str, Any]: + """Count the logs for each filter.""" + return { + name: len(population) + for name, population in self.log_processor.populations.items() + } + + def process_logs(self) -> Dict[str, Any]: + for log in self.logs: + self.log_processor.process_log(log) + + analytics = {} + for name, population in self.log_processor.populations.items(): + stats = self.statistics_calculator.calculate_statistics( + population, self.config.stat_functions + ) + dists = self.distribution_generator.generate_distributions( + population, self.config.dist_functions + ) + analytics[name] = {"statistics": stats, "distributions": dists} + + return self.visualization_preparer.prepare_visualization_data( + analytics, self.config.vis_functions + ) diff --git a/R2R/r2r/base/logging/run_manager.py b/R2R/r2r/base/logging/run_manager.py new file mode 100755 index 00000000..ac192bca --- /dev/null +++ b/R2R/r2r/base/logging/run_manager.py @@ -0,0 +1,56 @@ +import contextvars +import uuid +from contextlib import asynccontextmanager +from typing import Any + +from .kv_logger import KVLoggingSingleton + +run_id_var = contextvars.ContextVar("run_id", default=None) + + +class RunManager: + def __init__(self, logger: KVLoggingSingleton): + self.logger = logger + self.run_info = {} + + def generate_run_id(self) -> uuid.UUID: + return uuid.uuid4() + + async def set_run_info(self, pipeline_type: str): + run_id = run_id_var.get() + if run_id is None: + run_id = self.generate_run_id() + token = run_id_var.set(run_id) + self.run_info[run_id] = {"pipeline_type": pipeline_type} + else: + token = run_id_var.set(run_id) + return run_id, token + + async def get_run_info(self): + run_id = run_id_var.get() + return self.run_info.get(run_id, None) + + async def log_run_info( + self, key: str, value: Any, is_info_log: bool = False + ): + run_id = run_id_var.get() + if run_id: + await self.logger.log( + log_id=run_id, key=key, value=value, is_info_log=is_info_log + ) + + async def clear_run_info(self, token: contextvars.Token): + run_id = run_id_var.get() + run_id_var.reset(token) + if run_id and run_id in self.run_info: + del self.run_info[run_id] + + +@asynccontextmanager +async def manage_run(run_manager: RunManager, 
pipeline_type: str): + run_id, token = await run_manager.set_run_info(pipeline_type) + try: + yield run_id + finally: + # Note: Do not clear the run info to ensure the run ID remains the same + run_id_var.reset(token) diff --git a/R2R/r2r/base/parsers/__init__.py b/R2R/r2r/base/parsers/__init__.py new file mode 100755 index 00000000..d7696202 --- /dev/null +++ b/R2R/r2r/base/parsers/__init__.py @@ -0,0 +1,5 @@ +from .base_parser import AsyncParser + +__all__ = [ + "AsyncParser", +] diff --git a/R2R/r2r/base/parsers/base_parser.py b/R2R/r2r/base/parsers/base_parser.py new file mode 100755 index 00000000..f1bb49d7 --- /dev/null +++ b/R2R/r2r/base/parsers/base_parser.py @@ -0,0 +1,14 @@ +"""Abstract base class for parsers.""" + +from abc import ABC, abstractmethod +from typing import AsyncGenerator, Generic, TypeVar + +from ..abstractions.document import DataType + +T = TypeVar("T") + + +class AsyncParser(ABC, Generic[T]): + @abstractmethod + async def ingest(self, data: T) -> AsyncGenerator[DataType, None]: + pass diff --git a/R2R/r2r/base/pipeline/__init__.py b/R2R/r2r/base/pipeline/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/base/pipeline/__init__.py diff --git a/R2R/r2r/base/pipeline/base_pipeline.py b/R2R/r2r/base/pipeline/base_pipeline.py new file mode 100755 index 00000000..3c1eff9a --- /dev/null +++ b/R2R/r2r/base/pipeline/base_pipeline.py @@ -0,0 +1,233 @@ +"""Base pipeline class for running a sequence of pipes.""" + +import asyncio +import logging +from enum import Enum +from typing import Any, AsyncGenerator, Optional + +from ..logging.kv_logger import KVLoggingSingleton +from ..logging.run_manager import RunManager, manage_run +from ..pipes.base_pipe import AsyncPipe, AsyncState + +logger = logging.getLogger(__name__) + + +class PipelineTypes(Enum): + EVAL = "eval" + INGESTION = "ingestion" + SEARCH = "search" + RAG = "rag" + OTHER = "other" + + +class AsyncPipeline: + """Pipeline class for running a sequence of pipes.""" + + pipeline_type: str = "other" + + def __init__( + self, + pipe_logger: Optional[KVLoggingSingleton] = None, + run_manager: Optional[RunManager] = None, + ): + self.pipes: list[AsyncPipe] = [] + self.upstream_outputs: list[list[dict[str, str]]] = [] + self.pipe_logger = pipe_logger or KVLoggingSingleton() + self.run_manager = run_manager or RunManager(self.pipe_logger) + self.futures = {} + self.level = 0 + + def add_pipe( + self, + pipe: AsyncPipe, + add_upstream_outputs: Optional[list[dict[str, str]]] = None, + *args, + **kwargs, + ) -> None: + """Add a pipe to the pipeline.""" + self.pipes.append(pipe) + if not add_upstream_outputs: + add_upstream_outputs = [] + self.upstream_outputs.append(add_upstream_outputs) + + async def run( + self, + input: Any, + state: Optional[AsyncState] = None, + stream: bool = False, + run_manager: Optional[RunManager] = None, + log_run_info: bool = True, + *args: Any, + **kwargs: Any, + ): + """Run the pipeline.""" + run_manager = run_manager or self.run_manager + + try: + PipelineTypes(self.pipeline_type) + except ValueError: + raise ValueError( + f"Invalid pipeline type: {self.pipeline_type}, must be one of {PipelineTypes.__members__.keys()}" + ) + + self.state = state or AsyncState() + current_input = input + async with manage_run(run_manager, self.pipeline_type): + if log_run_info: + await run_manager.log_run_info( + key="pipeline_type", + value=self.pipeline_type, + is_info_log=True, + ) + try: + for pipe_num in range(len(self.pipes)): + config_name = 
self.pipes[pipe_num].config.name + self.futures[config_name] = asyncio.Future() + + current_input = self._run_pipe( + pipe_num, + current_input, + run_manager, + *args, + **kwargs, + ) + self.futures[config_name].set_result(current_input) + if not stream: + final_result = await self._consume_all(current_input) + return final_result + else: + return current_input + except Exception as error: + logger.error(f"Pipeline failed with error: {error}") + raise error + + async def _consume_all(self, gen: AsyncGenerator) -> list[Any]: + result = [] + async for item in gen: + if hasattr( + item, "__aiter__" + ): # Check if the item is an async generator + sub_result = await self._consume_all(item) + result.extend(sub_result) + else: + result.append(item) + return result + + async def _run_pipe( + self, + pipe_num: int, + input: Any, + run_manager: RunManager, + *args: Any, + **kwargs: Any, + ): + # Collect inputs, waiting for the necessary futures + pipe = self.pipes[pipe_num] + add_upstream_outputs = self.sort_upstream_outputs( + self.upstream_outputs[pipe_num] + ) + input_dict = {"message": input} + + # Group upstream outputs by prev_pipe_name + grouped_upstream_outputs = {} + for upstream_input in add_upstream_outputs: + upstream_pipe_name = upstream_input["prev_pipe_name"] + if upstream_pipe_name not in grouped_upstream_outputs: + grouped_upstream_outputs[upstream_pipe_name] = [] + grouped_upstream_outputs[upstream_pipe_name].append(upstream_input) + + for ( + upstream_pipe_name, + upstream_inputs, + ) in grouped_upstream_outputs.items(): + + async def resolve_future_output(future): + result = future.result() + # consume the async generator + return [item async for item in result] + + async def replay_items_as_async_gen(items): + for item in items: + yield item + + temp_results = await resolve_future_output( + self.futures[upstream_pipe_name] + ) + if upstream_pipe_name == self.pipes[pipe_num - 1].config.name: + input_dict["message"] = replay_items_as_async_gen(temp_results) + + for upstream_input in upstream_inputs: + outputs = await self.state.get(upstream_pipe_name, "output") + prev_output_field = upstream_input.get( + "prev_output_field", None + ) + if not prev_output_field: + raise ValueError( + "`prev_output_field` must be specified in the upstream_input" + ) + input_dict[upstream_input["input_field"]] = outputs[ + prev_output_field + ] + + # Handle the pipe generator + async for ele in await pipe.run( + pipe.Input(**input_dict), + self.state, + run_manager, + *args, + **kwargs, + ): + yield ele + + def sort_upstream_outputs( + self, add_upstream_outputs: list[dict[str, str]] + ) -> list[dict[str, str]]: + pipe_name_to_index = { + pipe.config.name: index for index, pipe in enumerate(self.pipes) + } + + def get_pipe_index(upstream_output): + return pipe_name_to_index[upstream_output["prev_pipe_name"]] + + sorted_outputs = sorted( + add_upstream_outputs, key=get_pipe_index, reverse=True + ) + return sorted_outputs + + +class EvalPipeline(AsyncPipeline): + """A pipeline for evaluation.""" + + pipeline_type: str = "eval" + + async def run( + self, + input: Any, + state: Optional[AsyncState] = None, + stream: bool = False, + run_manager: Optional[RunManager] = None, + *args: Any, + **kwargs: Any, + ): + return await super().run( + input, state, stream, run_manager, *args, **kwargs + ) + + def add_pipe( + self, + pipe: AsyncPipe, + add_upstream_outputs: Optional[list[dict[str, str]]] = None, + *args, + **kwargs, + ) -> None: + logger.debug(f"Adding pipe {pipe.config.name} to the EvalPipeline") + 
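# ---------------------------------------------------------------------------
# Editorial sketch (not part of the diff): a minimal wiring of the
# AsyncPipeline above with a toy AsyncPipe subclass (base_pipe.py appears
# later in this diff). The pipe name "uppercase_pipe", the sample messages,
# and the r2r.base import path are assumptions for illustration only.
# ---------------------------------------------------------------------------
import asyncio
from typing import Any, AsyncGenerator

from r2r.base import AsyncPipe, AsyncPipeline  # assumed re-exports


class UppercasePipe(AsyncPipe):
    """Toy pipe that upper-cases every message flowing through it."""

    async def _run_logic(
        self, input, state, run_id, *args: Any, **kwargs: Any
    ) -> AsyncGenerator[str, None]:
        async for message in input.message:
            yield message.upper()


async def _demo_pipeline() -> None:
    pipeline = AsyncPipeline()
    pipeline.add_pipe(
        UppercasePipe(config=AsyncPipe.PipeConfig(name="uppercase_pipe"))
    )

    async def _messages() -> AsyncGenerator[str, None]:
        for text in ("hello", "world"):
            yield text

    # log_run_info=False keeps the sketch independent of any logging backend.
    results = await pipeline.run(_messages(), log_run_info=False)
    print(results)  # expected: ['HELLO', 'WORLD']


# asyncio.run(_demo_pipeline())
# --------------------------- end of editorial sketch ------------------------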
return super().add_pipe(pipe, add_upstream_outputs, *args, **kwargs) + + +async def dequeue_requests(queue: asyncio.Queue) -> AsyncGenerator: + """Create an async generator to dequeue requests.""" + while True: + request = await queue.get() + if request is None: + break + yield request diff --git a/R2R/r2r/base/pipes/__init__.py b/R2R/r2r/base/pipes/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/base/pipes/__init__.py diff --git a/R2R/r2r/base/pipes/base_pipe.py b/R2R/r2r/base/pipes/base_pipe.py new file mode 100755 index 00000000..63e3d04e --- /dev/null +++ b/R2R/r2r/base/pipes/base_pipe.py @@ -0,0 +1,163 @@ +import asyncio +import logging +import uuid +from abc import abstractmethod +from enum import Enum +from typing import Any, AsyncGenerator, Optional + +from pydantic import BaseModel + +from r2r.base.logging.kv_logger import KVLoggingSingleton +from r2r.base.logging.run_manager import RunManager, manage_run + +logger = logging.getLogger(__name__) + + +class PipeType(Enum): + INGESTOR = "ingestor" + EVAL = "eval" + GENERATOR = "generator" + SEARCH = "search" + TRANSFORM = "transform" + OTHER = "other" + + +class AsyncState: + """A state object for storing data between pipes.""" + + def __init__(self): + self.data = {} + self.lock = asyncio.Lock() + + async def update(self, outer_key: str, values: dict): + """Update the state with new values.""" + async with self.lock: + if not isinstance(values, dict): + raise ValueError("Values must be contained in a dictionary.") + if outer_key not in self.data: + self.data[outer_key] = {} + for inner_key, inner_value in values.items(): + self.data[outer_key][inner_key] = inner_value + + async def get(self, outer_key: str, inner_key: str, default=None): + """Get a value from the state.""" + async with self.lock: + if outer_key not in self.data: + raise ValueError( + f"Key {outer_key} does not exist in the state." + ) + if inner_key not in self.data[outer_key]: + return default or {} + return self.data[outer_key][inner_key] + + async def delete(self, outer_key: str, inner_key: Optional[str] = None): + """Delete a value from the state.""" + async with self.lock: + if outer_key in self.data and not inner_key: + del self.data[outer_key] + else: + if inner_key not in self.data[outer_key]: + raise ValueError( + f"Key {inner_key} does not exist in the state." 
+ ) + del self.data[outer_key][inner_key] + + +class AsyncPipe: + """An asynchronous pipe for processing data with logging capabilities.""" + + class PipeConfig(BaseModel): + """Configuration for a pipe.""" + + name: str = "default_pipe" + max_log_queue_size: int = 100 + + class Config: + extra = "forbid" + arbitrary_types_allowed = True + + class Input(BaseModel): + """Input for a pipe.""" + + message: AsyncGenerator[Any, None] + + class Config: + extra = "forbid" + arbitrary_types_allowed = True + + def __init__( + self, + type: PipeType = PipeType.OTHER, + config: Optional[PipeConfig] = None, + pipe_logger: Optional[KVLoggingSingleton] = None, + run_manager: Optional[RunManager] = None, + ): + self._config = config or self.PipeConfig() + self._type = type + self.pipe_logger = pipe_logger or KVLoggingSingleton() + self.log_queue = asyncio.Queue() + self.log_worker_task = None + self._run_manager = run_manager or RunManager(self.pipe_logger) + + logger.debug( + f"Initialized pipe {self.config.name} of type {self.type}" + ) + + @property + def config(self) -> PipeConfig: + return self._config + + @property + def type(self) -> PipeType: + return self._type + + async def log_worker(self): + while True: + log_data = await self.log_queue.get() + run_id, key, value = log_data + await self.pipe_logger.log(run_id, key, value) + self.log_queue.task_done() + + async def enqueue_log(self, run_id: uuid.UUID, key: str, value: str): + if self.log_queue.qsize() < self.config.max_log_queue_size: + await self.log_queue.put((run_id, key, value)) + + async def run( + self, + input: Input, + state: AsyncState, + run_manager: Optional[RunManager] = None, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[Any, None]: + """Run the pipe with logging capabilities.""" + + run_manager = run_manager or self._run_manager + + async def wrapped_run() -> AsyncGenerator[Any, None]: + async with manage_run(run_manager, self.config.name) as run_id: + self.log_worker_task = asyncio.create_task( + self.log_worker(), name=f"log-worker-{self.config.name}" + ) + try: + async for result in self._run_logic( + input, state, run_id=run_id, *args, **kwargs + ): + yield result + finally: + await self.log_queue.join() + self.log_worker_task.cancel() + self.log_queue = asyncio.Queue() + + return wrapped_run() + + @abstractmethod + async def _run_logic( + self, + input: Input, + state: AsyncState, + run_id: uuid.UUID, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator[Any, None]: + pass diff --git a/R2R/r2r/base/providers/__init__.py b/R2R/r2r/base/providers/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/base/providers/__init__.py diff --git a/R2R/r2r/base/providers/base_provider.py b/R2R/r2r/base/providers/base_provider.py new file mode 100755 index 00000000..8ee8d56a --- /dev/null +++ b/R2R/r2r/base/providers/base_provider.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod, abstractproperty +from typing import Any, Optional, Type + +from pydantic import BaseModel + + +class ProviderConfig(BaseModel, ABC): + """A base provider configuration class""" + + extra_fields: dict[str, Any] = {} + provider: Optional[str] = None + + class Config: + arbitrary_types_allowed = True + ignore_extra = True + + @abstractmethod + def validate(self) -> None: + pass + + @classmethod + def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig": + base_args = cls.__fields__.keys() + filtered_kwargs = { + k: v if v != "None" else None + for k, v in kwargs.items() + if k in base_args + } + 
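# ---------------------------------------------------------------------------
# Editorial sketch (not part of the diff): how ProviderConfig.create is
# expected to split declared fields from extra_fields. ToyConfig, its fields,
# and the api_key kwarg are hypothetical.
# ---------------------------------------------------------------------------
from typing import Optional

from r2r.base.providers.base_provider import ProviderConfig  # assumed module path


class ToyConfig(ProviderConfig):
    provider: Optional[str] = None
    batch_size: int = 1

    def validate(self) -> None:
        pass

    @property
    def supported_providers(self) -> list[str]:
        return [None, "toy"]


config = ToyConfig.create(provider="toy", batch_size=4, api_key="secret")
# Declared fields land on the model; unknown kwargs are kept in extra_fields.
assert config.provider == "toy" and config.batch_size == 4
assert config.extra_fields == {"api_key": "secret"}
# --------------------------- end of editorial sketch ------------------------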
instance = cls(**filtered_kwargs) + for k, v in kwargs.items(): + if k not in base_args: + instance.extra_fields[k] = v + return instance + + @abstractproperty + @property + def supported_providers(self) -> list[str]: + """Define a list of supported providers.""" + pass + + +class Provider(ABC): + """A base provider class to provide a common interface for all providers.""" + + def __init__(self, config: Optional[ProviderConfig] = None): + if config: + config.validate() + self.config = config diff --git a/R2R/r2r/base/providers/embedding_provider.py b/R2R/r2r/base/providers/embedding_provider.py new file mode 100755 index 00000000..8f3af56f --- /dev/null +++ b/R2R/r2r/base/providers/embedding_provider.py @@ -0,0 +1,83 @@ +import logging +from abc import abstractmethod +from enum import Enum +from typing import Optional + +from ..abstractions.search import VectorSearchResult +from .base_provider import Provider, ProviderConfig + +logger = logging.getLogger(__name__) + + +class EmbeddingConfig(ProviderConfig): + """A base embedding configuration class""" + + provider: Optional[str] = None + base_model: Optional[str] = None + base_dimension: Optional[int] = None + rerank_model: Optional[str] = None + rerank_dimension: Optional[int] = None + rerank_transformer_type: Optional[str] = None + batch_size: int = 1 + + def validate(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return [None, "openai", "ollama", "sentence-transformers"] + + +class EmbeddingProvider(Provider): + """An abstract class to provide a common interface for embedding providers.""" + + class PipeStage(Enum): + BASE = 1 + RERANK = 2 + + def __init__(self, config: EmbeddingConfig): + if not isinstance(config, EmbeddingConfig): + raise ValueError( + "EmbeddingProvider must be initialized with a `EmbeddingConfig`." 
+ ) + logger.info(f"Initializing EmbeddingProvider with config {config}.") + + super().__init__(config) + + @abstractmethod + def get_embedding(self, text: str, stage: PipeStage = PipeStage.BASE): + pass + + async def async_get_embedding( + self, text: str, stage: PipeStage = PipeStage.BASE + ): + return self.get_embedding(text, stage) + + @abstractmethod + def get_embeddings( + self, texts: list[str], stage: PipeStage = PipeStage.BASE + ): + pass + + async def async_get_embeddings( + self, texts: list[str], stage: PipeStage = PipeStage.BASE + ): + return self.get_embeddings(texts, stage) + + @abstractmethod + def rerank( + self, + query: str, + results: list[VectorSearchResult], + stage: PipeStage = PipeStage.RERANK, + limit: int = 10, + ): + pass + + @abstractmethod + def tokenize_string( + self, text: str, model: str, stage: PipeStage + ) -> list[int]: + """Tokenizes the input string.""" + pass diff --git a/R2R/r2r/base/providers/eval_provider.py b/R2R/r2r/base/providers/eval_provider.py new file mode 100755 index 00000000..76053f87 --- /dev/null +++ b/R2R/r2r/base/providers/eval_provider.py @@ -0,0 +1,46 @@ +from typing import Optional, Union + +from ..abstractions.llm import GenerationConfig +from .base_provider import Provider, ProviderConfig +from .llm_provider import LLMConfig + + +class EvalConfig(ProviderConfig): + """A base eval config class""" + + llm: Optional[LLMConfig] = None + + def validate(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider {self.provider} not supported.") + if self.provider and not self.llm: + raise ValueError( + "EvalConfig must have a `llm` attribute when specifying a provider." + ) + + @property + def supported_providers(self) -> list[str]: + return [None, "local"] + + +class EvalProvider(Provider): + """An abstract class to provide a common interface for evaluation providers.""" + + def __init__(self, config: EvalConfig): + if not isinstance(config, EvalConfig): + raise ValueError( + "EvalProvider must be initialized with a `EvalConfig`." 
+ ) + + super().__init__(config) + + def evaluate( + self, + query: str, + context: str, + completion: str, + eval_generation_config: Optional[GenerationConfig] = None, + ) -> dict[str, dict[str, Union[str, float]]]: + return self._evaluate( + query, context, completion, eval_generation_config + ) diff --git a/R2R/r2r/base/providers/kg_provider.py b/R2R/r2r/base/providers/kg_provider.py new file mode 100755 index 00000000..4ae96b11 --- /dev/null +++ b/R2R/r2r/base/providers/kg_provider.py @@ -0,0 +1,182 @@ +"""Base classes for knowledge graph providers.""" + +import json +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional, Tuple + +from .prompt_provider import PromptProvider + +if TYPE_CHECKING: + from r2r.main import R2RClient + +from ...base.utils.base_utils import EntityType, Relation +from ..abstractions.llama_abstractions import EntityNode, LabelledNode +from ..abstractions.llama_abstractions import Relation as LlamaRelation +from ..abstractions.llama_abstractions import VectorStoreQuery +from ..abstractions.llm import GenerationConfig +from .base_provider import ProviderConfig + +logger = logging.getLogger(__name__) + + +class KGConfig(ProviderConfig): + """A base KG config class""" + + provider: Optional[str] = None + batch_size: int = 1 + kg_extraction_prompt: Optional[str] = "few_shot_ner_kg_extraction" + kg_agent_prompt: Optional[str] = "kg_agent" + kg_extraction_config: Optional[GenerationConfig] = None + + def validate(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return [None, "neo4j"] + + +class KGProvider(ABC): + """An abstract class to provide a common interface for Knowledge Graphs.""" + + def __init__(self, config: KGConfig) -> None: + if not isinstance(config, KGConfig): + raise ValueError( + "KGProvider must be initialized with a `KGConfig`." 
+ ) + logger.info(f"Initializing KG provider with config: {config}") + self.config = config + self.validate_config() + + def validate_config(self) -> None: + self.config.validate() + + @property + @abstractmethod + def client(self) -> Any: + """Get client.""" + pass + + @abstractmethod + def get(self, subj: str) -> list[list[str]]: + """Abstract method to get triplets.""" + pass + + @abstractmethod + def get_rel_map( + self, + subjs: Optional[list[str]] = None, + depth: int = 2, + limit: int = 30, + ) -> dict[str, list[list[str]]]: + """Abstract method to get depth-aware rel map.""" + pass + + @abstractmethod + def upsert_nodes(self, nodes: list[EntityNode]) -> None: + """Abstract method to add triplet.""" + pass + + @abstractmethod + def upsert_relations(self, relations: list[LlamaRelation]) -> None: + """Abstract method to add triplet.""" + pass + + @abstractmethod + def delete(self, subj: str, rel: str, obj: str) -> None: + """Abstract method to delete triplet.""" + pass + + @abstractmethod + def get_schema(self, refresh: bool = False) -> str: + """Abstract method to get the schema of the graph store.""" + pass + + @abstractmethod + def structured_query( + self, query: str, param_map: Optional[dict[str, Any]] = {} + ) -> Any: + """Abstract method to query the graph store with statement and parameters.""" + pass + + @abstractmethod + def vector_query( + self, query: VectorStoreQuery, **kwargs: Any + ) -> Tuple[list[LabelledNode], list[float]]: + """Abstract method to query the graph store with a vector store query.""" + + # TODO - Type this method. + @abstractmethod + def update_extraction_prompt( + self, + prompt_provider: Any, + entity_types: list[Any], + relations: list[Relation], + ): + """Abstract method to update the KG extraction prompt.""" + pass + + # TODO - Type this method. + @abstractmethod + def update_kg_agent_prompt( + self, + prompt_provider: Any, + entity_types: list[Any], + relations: list[Relation], + ): + """Abstract method to update the KG agent prompt.""" + pass + + +def escape_braces(s: str) -> str: + """ + Escape braces in a string. + This is a placeholder function - implement the actual logic as needed. 
+ """ + # Implement your escape_braces logic here + return s.replace("{", "{{").replace("}", "}}") + + +# TODO - Make this more configurable / intelligent +def update_kg_prompt( + client: "R2RClient", + r2r_prompts: PromptProvider, + prompt_base: str, + entity_types: list[EntityType], + relations: list[Relation], +) -> None: + # Get the default extraction template + template_name: str = f"{prompt_base}_with_spec" + + new_template: str = r2r_prompts.get_prompt( + template_name, + { + "entity_types": json.dumps( + { + "entity_types": [ + str(entity.name) for entity in entity_types + ] + }, + indent=4, + ), + "relations": json.dumps( + {"predicates": [str(relation.name) for relation in relations]}, + indent=4, + ), + "input": """\n{input}""", + }, + ) + + # Escape all braces in the template, except for the {input} placeholder, for formatting + escaped_template: str = escape_braces(new_template).replace( + """{{input}}""", """{input}""" + ) + + # Update the client's prompt + client.update_prompt( + prompt_base, + template=escaped_template, + input_types={"input": "str"}, + ) diff --git a/R2R/r2r/base/providers/llm_provider.py b/R2R/r2r/base/providers/llm_provider.py new file mode 100755 index 00000000..9b6499a4 --- /dev/null +++ b/R2R/r2r/base/providers/llm_provider.py @@ -0,0 +1,66 @@ +"""Base classes for language model providers.""" + +import logging +from abc import abstractmethod +from typing import Optional + +from r2r.base.abstractions.llm import GenerationConfig + +from ..abstractions.llm import LLMChatCompletion, LLMChatCompletionChunk +from .base_provider import Provider, ProviderConfig + +logger = logging.getLogger(__name__) + + +class LLMConfig(ProviderConfig): + """A base LLM config class""" + + provider: Optional[str] = None + generation_config: Optional[GenerationConfig] = None + + def validate(self) -> None: + if not self.provider: + raise ValueError("Provider must be set.") + + if self.provider and self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return ["litellm", "openai"] + + +class LLMProvider(Provider): + """An abstract class to provide a common interface for LLMs.""" + + def __init__( + self, + config: LLMConfig, + ) -> None: + if not isinstance(config, LLMConfig): + raise ValueError( + "LLMProvider must be initialized with a `LLMConfig`." 
+ ) + logger.info(f"Initializing LLM provider with config: {config}") + + super().__init__(config) + + @abstractmethod + def get_completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> LLMChatCompletion: + """Abstract method to get a chat completion from the provider.""" + pass + + @abstractmethod + def get_completion_stream( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> LLMChatCompletionChunk: + """Abstract method to get a completion stream from the provider.""" + pass diff --git a/R2R/r2r/base/providers/prompt_provider.py b/R2R/r2r/base/providers/prompt_provider.py new file mode 100755 index 00000000..78af9e11 --- /dev/null +++ b/R2R/r2r/base/providers/prompt_provider.py @@ -0,0 +1,65 @@ +import logging +from abc import abstractmethod +from typing import Any, Optional + +from .base_provider import Provider, ProviderConfig + +logger = logging.getLogger(__name__) + + +class PromptConfig(ProviderConfig): + def validate(self) -> None: + pass + + @property + def supported_providers(self) -> list[str]: + # Return a list of supported prompt providers + return ["default_prompt_provider"] + + +class PromptProvider(Provider): + def __init__(self, config: Optional[PromptConfig] = None): + if config is None: + config = PromptConfig() + elif not isinstance(config, PromptConfig): + raise ValueError( + "PromptProvider must be initialized with a `PromptConfig`." + ) + logger.info(f"Initializing PromptProvider with config {config}.") + super().__init__(config) + + @abstractmethod + def add_prompt( + self, name: str, template: str, input_types: dict[str, str] + ) -> None: + pass + + @abstractmethod + def get_prompt( + self, prompt_name: str, inputs: Optional[dict[str, Any]] = None + ) -> str: + pass + + @abstractmethod + def get_all_prompts(self) -> dict[str, str]: + pass + + @abstractmethod + def update_prompt( + self, + name: str, + template: Optional[str] = None, + input_types: Optional[dict[str, str]] = None, + ) -> None: + pass + + def _get_message_payload( + self, system_prompt: str, task_prompt: str + ) -> dict: + return [ + { + "role": "system", + "content": system_prompt, + }, + {"role": "user", "content": task_prompt}, + ] diff --git a/R2R/r2r/base/providers/vector_db_provider.py b/R2R/r2r/base/providers/vector_db_provider.py new file mode 100755 index 00000000..a6d5aaa8 --- /dev/null +++ b/R2R/r2r/base/providers/vector_db_provider.py @@ -0,0 +1,142 @@ +import logging +from abc import ABC, abstractmethod +from typing import Optional, Union + +from ..abstractions.document import DocumentInfo +from ..abstractions.search import VectorSearchResult +from ..abstractions.vector import VectorEntry +from .base_provider import Provider, ProviderConfig + +logger = logging.getLogger(__name__) + + +class VectorDBConfig(ProviderConfig): + provider: str + + def __post_init__(self): + self.validate() + # Capture additional fields + for key, value in self.extra_fields.items(): + setattr(self, key, value) + + def validate(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return ["local", "pgvector"] + + +class VectorDBProvider(Provider, ABC): + def __init__(self, config: VectorDBConfig): + if not isinstance(config, VectorDBConfig): + raise ValueError( + "VectorDBProvider must be initialized with a `VectorDBConfig`." 
+ ) + logger.info(f"Initializing VectorDBProvider with config {config}.") + super().__init__(config) + + @abstractmethod + def initialize_collection(self, dimension: int) -> None: + pass + + @abstractmethod + def copy(self, entry: VectorEntry, commit: bool = True) -> None: + pass + + @abstractmethod + def upsert(self, entry: VectorEntry, commit: bool = True) -> None: + pass + + @abstractmethod + def search( + self, + query_vector: list[float], + filters: dict[str, Union[bool, int, str]] = {}, + limit: int = 10, + *args, + **kwargs, + ) -> list[VectorSearchResult]: + pass + + @abstractmethod + def hybrid_search( + self, + query_text: str, + query_vector: list[float], + limit: int = 10, + filters: Optional[dict[str, Union[bool, int, str]]] = None, + # Hybrid search parameters + full_text_weight: float = 1.0, + semantic_weight: float = 1.0, + rrf_k: int = 20, # typical value is ~2x the number of results you want + *args, + **kwargs, + ) -> list[VectorSearchResult]: + pass + + @abstractmethod + def create_index(self, index_type, column_name, index_options): + pass + + def upsert_entries( + self, entries: list[VectorEntry], commit: bool = True + ) -> None: + for entry in entries: + self.upsert(entry, commit=commit) + + def copy_entries( + self, entries: list[VectorEntry], commit: bool = True + ) -> None: + for entry in entries: + self.copy(entry, commit=commit) + + @abstractmethod + def delete_by_metadata( + self, + metadata_fields: list[str], + metadata_values: list[Union[bool, int, str]], + ) -> list[str]: + if len(metadata_fields) != len(metadata_values): + raise ValueError( + "The number of metadata fields and values must be equal." + ) + pass + + @abstractmethod + def get_metadatas( + self, + metadata_fields: list[str], + filter_field: Optional[str] = None, + filter_value: Optional[str] = None, + ) -> list[str]: + pass + + @abstractmethod + def upsert_documents_overview( + self, document_infs: list[DocumentInfo] + ) -> None: + pass + + @abstractmethod + def get_documents_overview( + self, + filter_document_ids: Optional[list[str]] = None, + filter_user_ids: Optional[list[str]] = None, + ) -> list[DocumentInfo]: + pass + + @abstractmethod + def get_document_chunks(self, document_id: str) -> list[dict]: + pass + + @abstractmethod + def delete_from_documents_overview( + self, document_id: str, version: Optional[str] = None + ) -> dict: + pass + + @abstractmethod + def get_users_overview(self, user_ids: Optional[list[str]] = None) -> dict: + pass diff --git a/R2R/r2r/base/utils/__init__.py b/R2R/r2r/base/utils/__init__.py new file mode 100755 index 00000000..104d50eb --- /dev/null +++ b/R2R/r2r/base/utils/__init__.py @@ -0,0 +1,26 @@ +from .base_utils import ( + EntityType, + Relation, + format_entity_types, + format_relations, + generate_id_from_label, + generate_run_id, + increment_version, + run_pipeline, + to_async_generator, +) +from .splitter.text import RecursiveCharacterTextSplitter, TextSplitter + +__all__ = [ + "RecursiveCharacterTextSplitter", + "TextSplitter", + "run_pipeline", + "to_async_generator", + "generate_run_id", + "generate_id_from_label", + "increment_version", + "EntityType", + "Relation", + "format_entity_types", + "format_relations", +] diff --git a/R2R/r2r/base/utils/base_utils.py b/R2R/r2r/base/utils/base_utils.py new file mode 100755 index 00000000..12652833 --- /dev/null +++ b/R2R/r2r/base/utils/base_utils.py @@ -0,0 +1,63 @@ +import asyncio +import uuid +from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable + +if TYPE_CHECKING: + from 
..pipeline.base_pipeline import AsyncPipeline + + +def generate_run_id() -> uuid.UUID: + return uuid.uuid4() + + +def generate_id_from_label(label: str) -> uuid.UUID: + return uuid.uuid5(uuid.NAMESPACE_DNS, label) + + +async def to_async_generator( + iterable: Iterable[Any], +) -> AsyncGenerator[Any, None]: + for item in iterable: + yield item + + +def run_pipeline(pipeline: "AsyncPipeline", input: Any, *args, **kwargs): + if not isinstance(input, AsyncGenerator) and not isinstance(input, list): + input = to_async_generator([input]) + elif not isinstance(input, AsyncGenerator): + input = to_async_generator(input) + + async def _run_pipeline(input, *args, **kwargs): + return await pipeline.run(input, *args, **kwargs) + + return asyncio.run(_run_pipeline(input, *args, **kwargs)) + + +def increment_version(version: str) -> str: + prefix = version[:-1] + suffix = int(version[-1]) + return f"{prefix}{suffix + 1}" + + +class EntityType: + def __init__(self, name: str): + self.name = name + + +class Relation: + def __init__(self, name: str): + self.name = name + + +def format_entity_types(entity_types: list[EntityType]) -> str: + lines = [] + for entity in entity_types: + lines.append(entity.name) + return "\n".join(lines) + + +def format_relations(predicates: list[Relation]) -> str: + lines = [] + for predicate in predicates: + lines.append(predicate.name) + return "\n".join(lines) diff --git a/R2R/r2r/base/utils/splitter/__init__.py b/R2R/r2r/base/utils/splitter/__init__.py new file mode 100755 index 00000000..07a9f554 --- /dev/null +++ b/R2R/r2r/base/utils/splitter/__init__.py @@ -0,0 +1,3 @@ +from .text import RecursiveCharacterTextSplitter + +__all__ = ["RecursiveCharacterTextSplitter"] diff --git a/R2R/r2r/base/utils/splitter/text.py b/R2R/r2r/base/utils/splitter/text.py new file mode 100755 index 00000000..5458310c --- /dev/null +++ b/R2R/r2r/base/utils/splitter/text.py @@ -0,0 +1,1979 @@ +# Source - LangChain +# URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851 +"""**Text Splitters** are classes for splitting text. + + +**Class hierarchy:** + +.. code-block:: + + BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter # Example: CharacterTextSplitter + RecursiveCharacterTextSplitter --> <name>TextSplitter + +Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive from TextSplitter. + + +**Main helpers:** + +.. 
code-block:: + + Document, Tokenizer, Language, LineType, HeaderType + +""" # noqa: E501 + +from __future__ import annotations + +import copy +import json +import logging +import pathlib +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from io import BytesIO, StringIO +from typing import ( + AbstractSet, + Any, + Callable, + Collection, + Dict, + Iterable, + List, + Literal, + Optional, + Sequence, + Tuple, + Type, + TypedDict, + TypeVar, + Union, + cast, +) + +import requests +from pydantic import BaseModel, Field, PrivateAttr +from typing_extensions import NotRequired + +logger = logging.getLogger(__name__) + +TS = TypeVar("TS", bound="TextSplitter") + + +class BaseSerialized(TypedDict): + """Base class for serialized objects.""" + + lc: int + id: List[str] + name: NotRequired[str] + graph: NotRequired[Dict[str, Any]] + + +class SerializedConstructor(BaseSerialized): + """Serialized constructor.""" + + type: Literal["constructor"] + kwargs: Dict[str, Any] + + +class SerializedSecret(BaseSerialized): + """Serialized secret.""" + + type: Literal["secret"] + + +class SerializedNotImplemented(BaseSerialized): + """Serialized not implemented.""" + + type: Literal["not_implemented"] + repr: Optional[str] + + +def try_neq_default(value: Any, key: str, model: BaseModel) -> bool: + """Try to determine if a value is different from the default. + + Args: + value: The value. + key: The key. + model: The model. + + Returns: + Whether the value is different from the default. + """ + try: + return model.__fields__[key].get_default() != value + except Exception: + return True + + +class Serializable(BaseModel, ABC): + """Serializable base class.""" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Is this class serializable?""" + return False + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object. + + For example, if the class is `langchain.llms.openai.OpenAI`, then the + namespace is ["langchain", "llms", "openai"] + """ + return cls.__module__.split(".") + + @property + def lc_secrets(self) -> Dict[str, str]: + """A map of constructor argument names to secret ids. + + For example, + {"openai_api_key": "OPENAI_API_KEY"} + """ + return dict() + + @property + def lc_attributes(self) -> Dict: + """List of attribute names that should be included in the serialized kwargs. + + These attributes must be accepted by the constructor. + """ + return {} + + @classmethod + def lc_id(cls) -> List[str]: + """A unique identifier for this class for serialization purposes. + + The unique identifier is a list of strings that describes the path + to the object. 
+ """ + return [*cls.get_lc_namespace(), cls.__name__] + + class Config: + extra = "ignore" + + def __repr_args__(self) -> Any: + return [ + (k, v) + for k, v in super().__repr_args__() + if (k not in self.__fields__ or try_neq_default(v, k, self)) + ] + + _lc_kwargs = PrivateAttr(default_factory=dict) + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._lc_kwargs = kwargs + + def to_json( + self, + ) -> Union[SerializedConstructor, SerializedNotImplemented]: + if not self.is_lc_serializable(): + return self.to_json_not_implemented() + + secrets = dict() + # Get latest values for kwargs if there is an attribute with same name + lc_kwargs = { + k: getattr(self, k, v) + for k, v in self._lc_kwargs.items() + if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore + } + + # Merge the lc_secrets and lc_attributes from every class in the MRO + for cls in [None, *self.__class__.mro()]: + # Once we get to Serializable, we're done + if cls is Serializable: + break + + if cls: + deprecated_attributes = [ + "lc_namespace", + "lc_serializable", + ] + + for attr in deprecated_attributes: + if hasattr(cls, attr): + raise ValueError( + f"Class {self.__class__} has a deprecated " + f"attribute {attr}. Please use the corresponding " + f"classmethod instead." + ) + + # Get a reference to self bound to each class in the MRO + this = cast( + Serializable, self if cls is None else super(cls, self) + ) + + secrets.update(this.lc_secrets) + # Now also add the aliases for the secrets + # This ensures known secret aliases are hidden. + # Note: this does NOT hide any other extra kwargs + # that are not present in the fields. + for key in list(secrets): + value = secrets[key] + if key in this.__fields__: + secrets[this.__fields__[key].alias] = value + lc_kwargs.update(this.lc_attributes) + + # include all secrets, even if not specified in kwargs + # as these secrets may be passed as an environment variable instead + for key in secrets.keys(): + secret_value = getattr(self, key, None) or lc_kwargs.get(key) + if secret_value is not None: + lc_kwargs.update({key: secret_value}) + + return { + "lc": 1, + "type": "constructor", + "id": self.lc_id(), + "kwargs": ( + lc_kwargs + if not secrets + else _replace_secrets(lc_kwargs, secrets) + ), + } + + def to_json_not_implemented(self) -> SerializedNotImplemented: + return to_json_not_implemented(self) + + +def _replace_secrets( + root: Dict[Any, Any], secrets_map: Dict[str, str] +) -> Dict[Any, Any]: + result = root.copy() + for path, secret_id in secrets_map.items(): + [*parts, last] = path.split(".") + current = result + for part in parts: + if part not in current: + break + current[part] = current[part].copy() + current = current[part] + if last in current: + current[last] = { + "lc": 1, + "type": "secret", + "id": [secret_id], + } + return result + + +def to_json_not_implemented(obj: object) -> SerializedNotImplemented: + """Serialize a "not implemented" object. 
+ + Args: + obj: object to serialize + + Returns: + SerializedNotImplemented + """ + _id: List[str] = [] + try: + if hasattr(obj, "__name__"): + _id = [*obj.__module__.split("."), obj.__name__] + elif hasattr(obj, "__class__"): + _id = [ + *obj.__class__.__module__.split("."), + obj.__class__.__name__, + ] + except Exception: + pass + + result: SerializedNotImplemented = { + "lc": 1, + "type": "not_implemented", + "id": _id, + "repr": None, + } + try: + result["repr"] = repr(obj) + except Exception: + pass + return result + + +class Document(Serializable): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + """String text.""" + metadata: dict = Field(default_factory=dict) + """Arbitrary metadata about the page content (e.g., source, relationships to other + documents, etc.). + """ + type: Literal["Document"] = "Document" + + def __init__(self, page_content: str, **kwargs: Any) -> None: + """Pass page_content in as positional or named arg.""" + super().__init__(page_content=page_content, **kwargs) + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this class is serializable.""" + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "schema", "document"] + + +class BaseDocumentTransformer(ABC): + """Abstract base class for document transformation systems. + + A document transformation system takes a sequence of Documents and returns a + sequence of transformed Documents. + + Example: + .. code-block:: python + + class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): + embeddings: Embeddings + similarity_fn: Callable = cosine_similarity + similarity_threshold: float = 0.95 + + class Config: + arbitrary_types_allowed = True + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + stateful_documents = get_stateful_documents(documents) + embedded_documents = _get_embeddings_from_stateful_docs( + self.embeddings, stateful_documents + ) + included_idxs = _filter_similar_embeddings( + embedded_documents, self.similarity_fn, self.similarity_threshold + ) + return [stateful_documents[i] for i in sorted(included_idxs)] + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + raise NotImplementedError + + """ # noqa: E501 + + @abstractmethod + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Asynchronously transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ + raise NotImplementedError("This method is not implemented.") + # return await langchain_core.runnables.config.run_in_executor( + # None, self.transform_documents, documents, **kwargs + # ) + + +def _make_spacy_pipe_for_splitting( + pipe: str, *, max_length: int = 1_000_000 +) -> Any: # avoid importing spacy + try: + import spacy + except ImportError: + raise ImportError( + "Spacy is not installed, please install it with `pip install spacy`." 
+ ) + if pipe == "sentencizer": + from spacy.lang.en import English + + sentencizer = English() + sentencizer.add_pipe("sentencizer") + else: + sentencizer = spacy.load(pipe, exclude=["ner", "tagger"]) + sentencizer.max_length = max_length + return sentencizer + + +def _split_text_with_regex( + text: str, separator: str, keep_separator: bool +) -> List[str]: + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. + _splits = re.split(f"({separator})", text) + splits = [ + _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2) + ] + if len(_splits) % 2 == 0: + splits += _splits[-1:] + splits = [_splits[0]] + splits + else: + splits = re.split(separator, text) + else: + splits = list(text) + return [s for s in splits if s != ""] + + +class TextSplitter(BaseDocumentTransformer, ABC): + """Interface for splitting text into chunks.""" + + def __init__( + self, + chunk_size: int = 4000, + chunk_overlap: int = 200, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, + strip_whitespace: bool = True, + ) -> None: + """Create a new TextSplitter. + + Args: + chunk_size: Maximum size of chunks to return + chunk_overlap: Overlap in characters between chunks + length_function: Function that measures the length of given chunks + keep_separator: Whether to keep the separator in the chunks + add_start_index: If `True`, includes chunk's start index in metadata + strip_whitespace: If `True`, strips whitespace from the start and end of + every document + """ + if chunk_overlap > chunk_size: + raise ValueError( + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " + f"({chunk_size}), should be smaller." + ) + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._length_function = length_function + self._keep_separator = keep_separator + self._add_start_index = add_start_index + self._strip_whitespace = strip_whitespace + + @abstractmethod + def split_text(self, text: str) -> List[str]: + """Split text into multiple components.""" + + def create_documents( + self, texts: List[str], metadatas: Optional[List[dict]] = None + ) -> List[Document]: + """Create documents from a list of texts.""" + _metadatas = metadatas or [{}] * len(texts) + documents = [] + for i, text in enumerate(texts): + index = 0 + previous_chunk_len = 0 + for chunk in self.split_text(text): + metadata = copy.deepcopy(_metadatas[i]) + if self._add_start_index: + offset = index + previous_chunk_len - self._chunk_overlap + index = text.find(chunk, max(0, offset)) + metadata["start_index"] = index + previous_chunk_len = len(chunk) + new_doc = Document(page_content=chunk, metadata=metadata) + documents.append(new_doc) + return documents + + def split_documents(self, documents: Iterable[Document]) -> List[Document]: + """Split documents.""" + texts, metadatas = [], [] + for doc in documents: + texts.append(doc.page_content) + metadatas.append(doc.metadata) + return self.create_documents(texts, metadatas=metadatas) + + def _join_docs(self, docs: List[str], separator: str) -> Optional[str]: + text = separator.join(docs) + if self._strip_whitespace: + text = text.strip() + if text == "": + return None + else: + return text + + def _merge_splits( + self, splits: Iterable[str], separator: str + ) -> List[str]: + # We now want to combine these smaller pieces into medium size + # chunks to send to the LLM. 
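# ---------------------------------------------------------------------------
# Editorial sketch (not part of the diff): expected behaviour of the
# _split_text_with_regex helper defined above, worked through by hand on a
# made-up input.
# ---------------------------------------------------------------------------
pieces = _split_text_with_regex("a. b. c", r"\. ", keep_separator=True)
# The capturing group keeps each delimiter and re-attaches it to the piece
# that follows it.
assert pieces == ["a", ". b", ". c"]

# Without keep_separator the delimiter is simply dropped.
assert _split_text_with_regex("a. b. c", r"\. ", keep_separator=False) == ["a", "b", "c"]
# --------------------------- end of editorial sketch ------------------------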
+ separator_len = self._length_function(separator) + + docs = [] + current_doc: List[str] = [] + total = 0 + for d in splits: + _len = self._length_function(d) + if ( + total + _len + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + ): + if total > self._chunk_size: + logger.warning( + f"Created a chunk of size {total}, " + f"which is longer than the specified {self._chunk_size}" + ) + if len(current_doc) > 0: + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + # Keep on popping if: + # - we have a larger chunk than in the chunk overlap + # - or if we still have any chunks and the length is long + while total > self._chunk_overlap or ( + total + + _len + + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + and total > 0 + ): + total -= self._length_function(current_doc[0]) + ( + separator_len if len(current_doc) > 1 else 0 + ) + current_doc = current_doc[1:] + current_doc.append(d) + total += _len + (separator_len if len(current_doc) > 1 else 0) + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + return docs + + @classmethod + def from_huggingface_tokenizer( + cls, tokenizer: Any, **kwargs: Any + ) -> TextSplitter: + """Text splitter that uses HuggingFace tokenizer to count length.""" + try: + from transformers import PreTrainedTokenizerBase + + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise ValueError( + "Tokenizer received was not an instance of PreTrainedTokenizerBase" + ) + + def _huggingface_tokenizer_length(text: str) -> int: + return len(tokenizer.encode(text)) + + except ImportError: + raise ValueError( + "Could not import transformers python package. " + "Please install it with `pip install transformers`." + ) + return cls(length_function=_huggingface_tokenizer_length, **kwargs) + + @classmethod + def from_tiktoken_encoder( + cls: Type[TS], + encoding_name: str = "gpt2", + model: Optional[str] = None, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, + ) -> TS: + """Text splitter that uses tiktoken encoder to count length.""" + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to calculate max_tokens_for_prompt. " + "Please install it with `pip install tiktoken`." 
+ ) + + if model is not None: + enc = tiktoken.encoding_for_model(model) + else: + enc = tiktoken.get_encoding(encoding_name) + + def _tiktoken_encoder(text: str) -> int: + return len( + enc.encode( + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + + if issubclass(cls, TokenTextSplitter): + extra_kwargs = { + "encoding_name": encoding_name, + "model": model, + "allowed_special": allowed_special, + "disallowed_special": disallowed_special, + } + kwargs = {**kwargs, **extra_kwargs} + + return cls(length_function=_tiktoken_encoder, **kwargs) + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Transform sequence of documents by splitting them.""" + return self.split_documents(list(documents)) + + +class CharacterTextSplitter(TextSplitter): + """Splitting text that looks at characters.""" + + def __init__( + self, + separator: str = "\n\n", + is_separator_regex: bool = False, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + self._separator = separator + self._is_separator_regex = is_separator_regex + + def split_text(self, text: str) -> List[str]: + """Split incoming text and return chunks.""" + # First we naively split the large input into a bunch of smaller ones. + separator = ( + self._separator + if self._is_separator_regex + else re.escape(self._separator) + ) + splits = _split_text_with_regex(text, separator, self._keep_separator) + _separator = "" if self._keep_separator else self._separator + return self._merge_splits(splits, _separator) + + +class LineType(TypedDict): + """Line type as typed dict.""" + + metadata: Dict[str, str] + content: str + + +class HeaderType(TypedDict): + """Header type as typed dict.""" + + level: int + name: str + data: str + + +class MarkdownHeaderTextSplitter: + """Splitting markdown files based on specified headers.""" + + def __init__( + self, + headers_to_split_on: List[Tuple[str, str]], + return_each_line: bool = False, + strip_headers: bool = True, + ): + """Create a new MarkdownHeaderTextSplitter. 
+ + Args: + headers_to_split_on: Headers we want to track + return_each_line: Return each line w/ associated headers + strip_headers: Strip split headers from the content of the chunk + """ + # Output line-by-line or aggregated into chunks w/ common headers + self.return_each_line = return_each_line + # Given the headers we want to split on, + # (e.g., "#, ##, etc") order by length + self.headers_to_split_on = sorted( + headers_to_split_on, key=lambda split: len(split[0]), reverse=True + ) + # Strip headers split headers from the content of the chunk + self.strip_headers = strip_headers + + def aggregate_lines_to_chunks( + self, lines: List[LineType] + ) -> List[Document]: + """Combine lines with common metadata into chunks + Args: + lines: Line of text / associated header metadata + """ + aggregated_chunks: List[LineType] = [] + + for line in lines: + if ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] == line["metadata"] + ): + # If the last line in the aggregated list + # has the same metadata as the current line, + # append the current content to the last lines's content + aggregated_chunks[-1]["content"] += " \n" + line["content"] + elif ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] != line["metadata"] + # may be issues if other metadata is present + and len(aggregated_chunks[-1]["metadata"]) + < len(line["metadata"]) + and aggregated_chunks[-1]["content"].split("\n")[-1][0] == "#" + and not self.strip_headers + ): + # If the last line in the aggregated list + # has different metadata as the current line, + # and has shallower header level than the current line, + # and the last line is a header, + # and we are not stripping headers, + # append the current content to the last line's content + aggregated_chunks[-1]["content"] += " \n" + line["content"] + # and update the last line's metadata + aggregated_chunks[-1]["metadata"] = line["metadata"] + else: + # Otherwise, append the current line to the aggregated list + aggregated_chunks.append(line) + + return [ + Document(page_content=chunk["content"], metadata=chunk["metadata"]) + for chunk in aggregated_chunks + ] + + def split_text(self, text: str) -> List[Document]: + """Split markdown file + Args: + text: Markdown file""" + + # Split the input text by newline character ("\n"). 
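# ---------------------------------------------------------------------------
# Editorial sketch (not part of the diff): intended call pattern for the
# Markdown splitter defined here; the sample text and the "h1"/"h2" metadata
# keys are made up.
# ---------------------------------------------------------------------------
markdown_splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=[("#", "h1"), ("##", "h2")]
)
docs = markdown_splitter.split_text("# Title\n\nintro text\n\n## Section\n\nbody text")
# With the default strip_headers=True, headers become metadata rather than
# chunk content:
#   docs[0] -> page_content="intro text", metadata={"h1": "Title"}
#   docs[1] -> page_content="body text",  metadata={"h1": "Title", "h2": "Section"}
# --------------------------- end of editorial sketch ------------------------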
+ lines = text.split("\n") + # Final output + lines_with_metadata: List[LineType] = [] + # Content and metadata of the chunk currently being processed + current_content: List[str] = [] + current_metadata: Dict[str, str] = {} + # Keep track of the nested header structure + # header_stack: List[Dict[str, Union[int, str]]] = [] + header_stack: List[HeaderType] = [] + initial_metadata: Dict[str, str] = {} + + in_code_block = False + opening_fence = "" + + for line in lines: + stripped_line = line.strip() + + if not in_code_block: + # Exclude inline code spans + if ( + stripped_line.startswith("```") + and stripped_line.count("```") == 1 + ): + in_code_block = True + opening_fence = "```" + elif stripped_line.startswith("~~~"): + in_code_block = True + opening_fence = "~~~" + else: + if stripped_line.startswith(opening_fence): + in_code_block = False + opening_fence = "" + + if in_code_block: + current_content.append(stripped_line) + continue + + # Check each line against each of the header types (e.g., #, ##) + for sep, name in self.headers_to_split_on: + # Check if line starts with a header that we intend to split on + if stripped_line.startswith(sep) and ( + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) + or stripped_line[len(sep)] == " " + ): + # Ensure we are tracking the header as metadata + if name is not None: + # Get the current header level + current_header_level = sep.count("#") + + # Pop out headers of lower or same level from the stack + while ( + header_stack + and header_stack[-1]["level"] + >= current_header_level + ): + # We have encountered a new header + # at the same or higher level + popped_header = header_stack.pop() + # Clear the metadata for the + # popped header in initial_metadata + if popped_header["name"] in initial_metadata: + initial_metadata.pop(popped_header["name"]) + + # Push the current header to the stack + header: HeaderType = { + "level": current_header_level, + "name": name, + "data": stripped_line[len(sep) :].strip(), + } + header_stack.append(header) + # Update initial_metadata with the current header + initial_metadata[name] = header["data"] + + # Add the previous line to the lines_with_metadata + # only if current_content is not empty + if current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata.copy(), + } + ) + current_content.clear() + + if not self.strip_headers: + current_content.append(stripped_line) + + break + else: + if stripped_line: + current_content.append(stripped_line) + elif current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata.copy(), + } + ) + current_content.clear() + + current_metadata = initial_metadata.copy() + + if current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata, + } + ) + + # lines_with_metadata has each line with associated header metadata + # aggregate these into chunks based on common metadata + if not self.return_each_line: + return self.aggregate_lines_to_chunks(lines_with_metadata) + else: + return [ + Document( + page_content=chunk["content"], metadata=chunk["metadata"] + ) + for chunk in lines_with_metadata + ] + + +class ElementType(TypedDict): + """Element type as typed dict.""" + + url: str + xpath: str + content: str + metadata: Dict[str, str] + + +class HTMLHeaderTextSplitter: + """ + Splitting HTML files based 
on specified headers. + Requires lxml package. + """ + + def __init__( + self, + headers_to_split_on: List[Tuple[str, str]], + return_each_element: bool = False, + ): + """Create a new HTMLHeaderTextSplitter. + + Args: + headers_to_split_on: list of tuples of headers we want to track mapped to + (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4, + h5, h6 e.g. [("h1", "Header 1"), ("h2", "Header 2)]. + return_each_element: Return each element w/ associated headers. + """ + # Output element-by-element or aggregated into chunks w/ common headers + self.return_each_element = return_each_element + self.headers_to_split_on = sorted(headers_to_split_on) + + def aggregate_elements_to_chunks( + self, elements: List[ElementType] + ) -> List[Document]: + """Combine elements with common metadata into chunks + + Args: + elements: HTML element content with associated identifying info and metadata + """ + aggregated_chunks: List[ElementType] = [] + + for element in elements: + if ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] == element["metadata"] + ): + # If the last element in the aggregated list + # has the same metadata as the current element, + # append the current content to the last element's content + aggregated_chunks[-1]["content"] += " \n" + element["content"] + else: + # Otherwise, append the current element to the aggregated list + aggregated_chunks.append(element) + + return [ + Document(page_content=chunk["content"], metadata=chunk["metadata"]) + for chunk in aggregated_chunks + ] + + def split_text_from_url(self, url: str) -> List[Document]: + """Split HTML from web URL + + Args: + url: web URL + """ + r = requests.get(url) + return self.split_text_from_file(BytesIO(r.content)) + + def split_text(self, text: str) -> List[Document]: + """Split HTML text string + + Args: + text: HTML text + """ + return self.split_text_from_file(StringIO(text)) + + def split_text_from_file(self, file: Any) -> List[Document]: + """Split HTML file + + Args: + file: HTML file + """ + try: + from lxml import etree + except ImportError as e: + raise ImportError( + "Unable to import lxml, please install with `pip install lxml`." + ) from e + # use lxml library to parse html document and return xml ElementTree + # Explicitly encoding in utf-8 allows non-English + # html files to be processed without garbled characters + parser = etree.HTMLParser(encoding="utf-8") + tree = etree.parse(file, parser) + + # document transformation for "structure-aware" chunking is handled with xsl. + # see comments in html_chunks_with_headers.xslt for more detailed information. 
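# ---------------------------------------------------------------------------
# Editorial sketch (not part of the diff): intended call pattern for the HTML
# splitter above. It assumes lxml is installed and that the
# html_chunks_with_headers.xslt stylesheet referenced below ships with the
# package; the sample HTML is made up and the exact chunk boundaries are not
# asserted here.
# ---------------------------------------------------------------------------
html_splitter = HTMLHeaderTextSplitter(
    headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")]
)
html = (
    "<html><body>"
    "<h1>Title</h1><p>intro</p>"
    "<h2>Section</h2><p>body</p>"
    "</body></html>"
)
for doc in html_splitter.split_text(html):
    print(doc.metadata, doc.page_content)
# --------------------------- end of editorial sketch ------------------------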
+ xslt_path = ( + pathlib.Path(__file__).parent + / "document_transformers/xsl/html_chunks_with_headers.xslt" + ) + xslt_tree = etree.parse(xslt_path) + transform = etree.XSLT(xslt_tree) + result = transform(tree) + result_dom = etree.fromstring(str(result)) + + # create filter and mapping for header metadata + header_filter = [header[0] for header in self.headers_to_split_on] + header_mapping = dict(self.headers_to_split_on) + + # map xhtml namespace prefix + ns_map = {"h": "http://www.w3.org/1999/xhtml"} + + # build list of elements from DOM + elements = [] + for element in result_dom.findall("*//*", ns_map): + if element.findall("*[@class='headers']") or element.findall( + "*[@class='chunk']" + ): + elements.append( + ElementType( + url=file, + xpath="".join( + [ + node.text + for node in element.findall( + "*[@class='xpath']", ns_map + ) + ] + ), + content="".join( + [ + node.text + for node in element.findall( + "*[@class='chunk']", ns_map + ) + ] + ), + metadata={ + # Add text of specified headers to metadata using header + # mapping. + header_mapping[node.tag]: node.text + for node in filter( + lambda x: x.tag in header_filter, + element.findall( + "*[@class='headers']/*", ns_map + ), + ) + }, + ) + ) + + if not self.return_each_element: + return self.aggregate_elements_to_chunks(elements) + else: + return [ + Document( + page_content=chunk["content"], metadata=chunk["metadata"] + ) + for chunk in elements + ] + + +# should be in newer Python versions (3.10+) +# @dataclass(frozen=True, kw_only=True, slots=True) +@dataclass(frozen=True) +class Tokenizer: + """Tokenizer data class.""" + + chunk_overlap: int + """Overlap in tokens between chunks""" + tokens_per_chunk: int + """Maximum number of tokens per chunk""" + decode: Callable[[List[int]], str] + """ Function to decode a list of token ids to a string""" + encode: Callable[[str], List[int]] + """ Function to encode a string to a list of token ids""" + + +def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]: + """Split incoming text and return chunks using tokenizer.""" + splits: List[str] = [] + input_ids = tokenizer.encode(text) + start_idx = 0 + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + while start_idx < len(input_ids): + splits.append(tokenizer.decode(chunk_ids)) + if cur_idx == len(input_ids): + break + start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + return splits + + +class TokenTextSplitter(TextSplitter): + """Splitting text to tokens using model tokenizer.""" + + def __init__( + self, + encoding_name: str = "gpt2", + model: Optional[str] = None, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to for TokenTextSplitter. " + "Please install it with `pip install tiktoken`." 
+            )
+
+        if model is not None:
+            enc = tiktoken.encoding_for_model(model)
+        else:
+            enc = tiktoken.get_encoding(encoding_name)
+        self._tokenizer = enc
+        self._allowed_special = allowed_special
+        self._disallowed_special = disallowed_special
+
+    def split_text(self, text: str) -> List[str]:
+        def _encode(_text: str) -> List[int]:
+            return self._tokenizer.encode(
+                _text,
+                allowed_special=self._allowed_special,
+                disallowed_special=self._disallowed_special,
+            )
+
+        tokenizer = Tokenizer(
+            chunk_overlap=self._chunk_overlap,
+            tokens_per_chunk=self._chunk_size,
+            decode=self._tokenizer.decode,
+            encode=_encode,
+        )
+
+        return split_text_on_tokens(text=text, tokenizer=tokenizer)
+
+
+class SentenceTransformersTokenTextSplitter(TextSplitter):
+    """Splitting text into tokens using a sentence-transformers model tokenizer."""
+
+    def __init__(
+        self,
+        chunk_overlap: int = 50,
+        model: str = "sentence-transformers/all-mpnet-base-v2",
+        tokens_per_chunk: Optional[int] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Create a new TextSplitter."""
+        super().__init__(**kwargs, chunk_overlap=chunk_overlap)
+
+        try:
+            from sentence_transformers import SentenceTransformer
+        except ImportError:
+            raise ImportError(
+                "Could not import sentence_transformers python package. "
+                "This is needed for SentenceTransformersTokenTextSplitter. "
+                "Please install it with `pip install sentence-transformers`."
+            )
+
+        self.model = model
+        self._model = SentenceTransformer(self.model, trust_remote_code=True)
+        self.tokenizer = self._model.tokenizer
+        self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
+
+    def _initialize_chunk_configuration(
+        self, *, tokens_per_chunk: Optional[int]
+    ) -> None:
+        self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
+
+        if tokens_per_chunk is None:
+            self.tokens_per_chunk = self.maximum_tokens_per_chunk
+        else:
+            self.tokens_per_chunk = tokens_per_chunk
+
+        if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
+            raise ValueError(
+                f"The token limit of the model '{self.model}'"
+                f" is: {self.maximum_tokens_per_chunk}."
+                f" Argument tokens_per_chunk={self.tokens_per_chunk}"
+                f" > maximum token limit."
+            )
+
+    def split_text(self, text: str) -> List[str]:
+        def encode_strip_start_and_stop_token_ids(text: str) -> List[int]:
+            return self._encode(text)[1:-1]
+
+        tokenizer = Tokenizer(
+            chunk_overlap=self._chunk_overlap,
+            tokens_per_chunk=self.tokens_per_chunk,
+            decode=self.tokenizer.decode,
+            encode=encode_strip_start_and_stop_token_ids,
+        )
+
+        return split_text_on_tokens(text=text, tokenizer=tokenizer)
+
+    def count_tokens(self, *, text: str) -> int:
+        return len(self._encode(text))
+
+    _max_length_equal_32_bit_integer: int = 2**32
+
+    def _encode(self, text: str) -> List[int]:
+        token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
+            text,
+            max_length=self._max_length_equal_32_bit_integer,
+            truncation="do_not_truncate",
+        )
+        return token_ids_with_start_and_end_token_ids
+
+
+class Language(str, Enum):
+    """Enum of the programming languages."""
+
+    CPP = "cpp"
+    GO = "go"
+    JAVA = "java"
+    KOTLIN = "kotlin"
+    JS = "js"
+    TS = "ts"
+    PHP = "php"
+    PROTO = "proto"
+    PYTHON = "python"
+    RST = "rst"
+    RUBY = "ruby"
+    RUST = "rust"
+    SCALA = "scala"
+    SWIFT = "swift"
+    MARKDOWN = "markdown"
+    LATEX = "latex"
+    HTML = "html"
+    SOL = "sol"
+    CSHARP = "csharp"
+    COBOL = "cobol"
+    C = "c"
+    LUA = "lua"
+    PERL = "perl"
+
+
+class RecursiveCharacterTextSplitter(TextSplitter):
+    """Splitting text by recursively looking at characters.
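+
+    (Illustrative note, not part of the original docstring: with the default
+    separators ["\n\n", "\n", " ", ""], text is split on blank lines first,
+    then on newlines, then on spaces, and finally character by character,
+    until each piece fits within the configured chunk_size.)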
+ + Recursively tries to split by different characters to find one + that works. + """ + + def __init__( + self, + separators: Optional[List[str]] = None, + keep_separator: bool = True, + is_separator_regex: bool = False, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(keep_separator=keep_separator, **kwargs) + self._separators = separators or ["\n\n", "\n", " ", ""] + self._is_separator_regex = is_separator_regex + + def _split_text(self, text: str, separators: List[str]) -> List[str]: + """Split incoming text and return chunks.""" + final_chunks = [] + # Get appropriate separator to use + separator = separators[-1] + new_separators = [] + for i, _s in enumerate(separators): + _separator = _s if self._is_separator_regex else re.escape(_s) + if _s == "": + separator = _s + break + if re.search(_separator, text): + separator = _s + new_separators = separators[i + 1 :] + break + + _separator = ( + separator if self._is_separator_regex else re.escape(separator) + ) + splits = _split_text_with_regex(text, _separator, self._keep_separator) + + # Now go merging things, recursively splitting longer texts. + _good_splits = [] + _separator = "" if self._keep_separator else separator + for s in splits: + if self._length_function(s) < self._chunk_size: + _good_splits.append(s) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + _good_splits = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + return final_chunks + + def split_text(self, text: str) -> List[str]: + return self._split_text(text, self._separators) + + @classmethod + def from_language( + cls, language: Language, **kwargs: Any + ) -> RecursiveCharacterTextSplitter: + separators = cls.get_separators_for_language(language) + return cls(separators=separators, is_separator_regex=True, **kwargs) + + @staticmethod + def get_separators_for_language(language: Language) -> List[str]: + if language == Language.CPP: + return [ + # Split along class definitions + "\nclass ", + # Split along function definitions + "\nvoid ", + "\nint ", + "\nfloat ", + "\ndouble ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.GO: + return [ + # Split along function definitions + "\nfunc ", + "\nvar ", + "\nconst ", + "\ntype ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.JAVA: + return [ + # Split along class definitions + "\nclass ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\nstatic ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.KOTLIN: + return [ + # Split along class definitions + "\nclass ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\ninternal ", + "\ncompanion ", + "\nfun ", + "\nval ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nwhen ", + "\ncase 
", + "\nelse ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.JS: + return [ + # Split along function definitions + "\nfunction ", + "\nconst ", + "\nlet ", + "\nvar ", + "\nclass ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\ndefault ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.TS: + return [ + "\nenum ", + "\ninterface ", + "\nnamespace ", + "\ntype ", + # Split along class definitions + "\nclass ", + # Split along function definitions + "\nfunction ", + "\nconst ", + "\nlet ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\ndefault ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PHP: + return [ + # Split along function definitions + "\nfunction ", + # Split along class definitions + "\nclass ", + # Split along control flow statements + "\nif ", + "\nforeach ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PROTO: + return [ + # Split along message definitions + "\nmessage ", + # Split along service definitions + "\nservice ", + # Split along enum definitions + "\nenum ", + # Split along option definitions + "\noption ", + # Split along import statements + "\nimport ", + # Split along syntax declarations + "\nsyntax ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PYTHON: + return [ + # First, try to split along class definitions + "\nclass ", + "\ndef ", + "\n\tdef ", + # Now split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RST: + return [ + # Split along section titles + "\n=+\n", + "\n-+\n", + "\n\\*+\n", + # Split along directive markers + "\n\n.. 
*\n\n", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RUBY: + return [ + # Split along method definitions + "\ndef ", + "\nclass ", + # Split along control flow statements + "\nif ", + "\nunless ", + "\nwhile ", + "\nfor ", + "\ndo ", + "\nbegin ", + "\nrescue ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RUST: + return [ + # Split along function definitions + "\nfn ", + "\nconst ", + "\nlet ", + # Split along control flow statements + "\nif ", + "\nwhile ", + "\nfor ", + "\nloop ", + "\nmatch ", + "\nconst ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SCALA: + return [ + # Split along class definitions + "\nclass ", + "\nobject ", + # Split along method definitions + "\ndef ", + "\nval ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nmatch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SWIFT: + return [ + # Split along function definitions + "\nfunc ", + # Split along class definitions + "\nclass ", + "\nstruct ", + "\nenum ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.MARKDOWN: + return [ + # First, try to split along Markdown headings (starting with level 2) + "\n#{1,6} ", + # Note the alternative syntax for headings (below) is not handled here + # Heading level 2 + # --------------- + # End of code block + "```\n", + # Horizontal lines + "\n\\*\\*\\*+\n", + "\n---+\n", + "\n___+\n", + # Note that this splitter doesn't handle horizontal lines defined + # by *three or more* of ***, ---, or ___, but this is not handled + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.LATEX: + return [ + # First, try to split along Latex sections + "\n\\\\chapter{", + "\n\\\\section{", + "\n\\\\subsection{", + "\n\\\\subsubsection{", + # Now split by environments + "\n\\\\begin{enumerate}", + "\n\\\\begin{itemize}", + "\n\\\\begin{description}", + "\n\\\\begin{list}", + "\n\\\\begin{quote}", + "\n\\\\begin{quotation}", + "\n\\\\begin{verse}", + "\n\\\\begin{verbatim}", + # Now split by math environments + "\n\\\begin{align}", + "$$", + "$", + # Now split by the normal type of lines + " ", + "", + ] + elif language == Language.HTML: + return [ + # First, try to split along HTML tags + "<body", + "<div", + "<p", + "<br", + "<li", + "<h1", + "<h2", + "<h3", + "<h4", + "<h5", + "<h6", + "<span", + "<table", + "<tr", + "<td", + "<th", + "<ul", + "<ol", + "<header", + "<footer", + "<nav", + # Head + "<head", + "<style", + "<script", + "<meta", + "<title", + "", + ] + elif language == Language.CSHARP: + return [ + "\ninterface ", + "\nenum ", + "\nimplements ", + "\ndelegate ", + "\nevent ", + # Split along class definitions + "\nclass ", + "\nabstract ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\nstatic ", + "\nreturn ", + # Split along control flow statements + "\nif ", + "\ncontinue ", + "\nfor ", + "\nforeach ", + "\nwhile ", + "\nswitch ", + "\nbreak ", + "\ncase ", + "\nelse ", + # Split by exceptions + "\ntry ", + "\nthrow ", + "\nfinally ", + "\ncatch ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SOL: + return [ + # Split along 
compiler information definitions + "\npragma ", + "\nusing ", + # Split along contract definitions + "\ncontract ", + "\ninterface ", + "\nlibrary ", + # Split along method definitions + "\nconstructor ", + "\ntype ", + "\nfunction ", + "\nevent ", + "\nmodifier ", + "\nerror ", + "\nstruct ", + "\nenum ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\ndo while ", + "\nassembly ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.COBOL: + return [ + # Split along divisions + "\nIDENTIFICATION DIVISION.", + "\nENVIRONMENT DIVISION.", + "\nDATA DIVISION.", + "\nPROCEDURE DIVISION.", + # Split along sections within DATA DIVISION + "\nWORKING-STORAGE SECTION.", + "\nLINKAGE SECTION.", + "\nFILE SECTION.", + # Split along sections within PROCEDURE DIVISION + "\nINPUT-OUTPUT SECTION.", + # Split along paragraphs and common statements + "\nOPEN ", + "\nCLOSE ", + "\nREAD ", + "\nWRITE ", + "\nIF ", + "\nELSE ", + "\nMOVE ", + "\nPERFORM ", + "\nUNTIL ", + "\nVARYING ", + "\nACCEPT ", + "\nDISPLAY ", + "\nSTOP RUN.", + # Split by the normal type of lines + "\n", + " ", + "", + ] + + else: + raise ValueError( + f"Language {language} is not supported! " + f"Please choose from {list(Language)}" + ) + + +class NLTKTextSplitter(TextSplitter): + """Splitting text using NLTK package.""" + + def __init__( + self, separator: str = "\n\n", language: str = "english", **kwargs: Any + ) -> None: + """Initialize the NLTK splitter.""" + super().__init__(**kwargs) + try: + from nltk.tokenize import sent_tokenize + + self._tokenizer = sent_tokenize + except ImportError: + raise ImportError( + "NLTK is not installed, please install it with `pip install nltk`." + ) + self._separator = separator + self._language = language + + def split_text(self, text: str) -> List[str]: + """Split incoming text and return chunks.""" + # First we naively split the large input into a bunch of smaller ones. + splits = self._tokenizer(text, language=self._language) + return self._merge_splits(splits, self._separator) + + +class SpacyTextSplitter(TextSplitter): + """Splitting text using Spacy package. + + + Per default, Spacy's `en_core_web_sm` model is used and + its default max_length is 1000000 (it is the length of maximum character + this model takes which can be increased for large files). For a faster, but + potentially less accurate splitting, you can use `pipe='sentencizer'`. + """ + + def __init__( + self, + separator: str = "\n\n", + pipe: str = "en_core_web_sm", + max_length: int = 1_000_000, + **kwargs: Any, + ) -> None: + """Initialize the spacy text splitter.""" + super().__init__(**kwargs) + self._tokenizer = _make_spacy_pipe_for_splitting( + pipe, max_length=max_length + ) + self._separator = separator + + def split_text(self, text: str) -> List[str]: + """Split incoming text and return chunks.""" + splits = (s.text for s in self._tokenizer(text).sents) + return self._merge_splits(splits, self._separator) + + +class KonlpyTextSplitter(TextSplitter): + """Splitting text using Konlpy package. + + It is good for splitting Korean text. 
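+
+    Example (illustrative sketch, not from the original; assumes `konlpy` and its
+    Kkma backend are installed, and that `korean_text` holds some Korean text):
+
+        splitter = KonlpyTextSplitter(chunk_size=200, chunk_overlap=0)
+        chunks = splitter.split_text(korean_text)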
+ """ + + def __init__( + self, + separator: str = "\n\n", + **kwargs: Any, + ) -> None: + """Initialize the Konlpy text splitter.""" + super().__init__(**kwargs) + self._separator = separator + try: + from konlpy.tag import Kkma + except ImportError: + raise ImportError( + """ + Konlpy is not installed, please install it with + `pip install konlpy` + """ + ) + self.kkma = Kkma() + + def split_text(self, text: str) -> List[str]: + """Split incoming text and return chunks.""" + splits = self.kkma.sentences(text) + return self._merge_splits(splits, self._separator) + + +# For backwards compatibility +class PythonCodeTextSplitter(RecursiveCharacterTextSplitter): + """Attempts to split the text along Python syntax.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize a PythonCodeTextSplitter.""" + separators = self.get_separators_for_language(Language.PYTHON) + super().__init__(separators=separators, **kwargs) + + +class MarkdownTextSplitter(RecursiveCharacterTextSplitter): + """Attempts to split the text along Markdown-formatted headings.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize a MarkdownTextSplitter.""" + separators = self.get_separators_for_language(Language.MARKDOWN) + super().__init__(separators=separators, **kwargs) + + +class LatexTextSplitter(RecursiveCharacterTextSplitter): + """Attempts to split the text along Latex-formatted layout elements.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize a LatexTextSplitter.""" + separators = self.get_separators_for_language(Language.LATEX) + super().__init__(separators=separators, **kwargs) + + +class RecursiveJsonSplitter: + def __init__( + self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None + ): + super().__init__() + self.max_chunk_size = max_chunk_size + self.min_chunk_size = ( + min_chunk_size + if min_chunk_size is not None + else max(max_chunk_size - 200, 50) + ) + + @staticmethod + def _json_size(data: Dict) -> int: + """Calculate the size of the serialized JSON object.""" + return len(json.dumps(data)) + + @staticmethod + def _set_nested_dict(d: Dict, path: List[str], value: Any) -> None: + """Set a value in a nested dictionary based on the given path.""" + for key in path[:-1]: + d = d.setdefault(key, {}) + d[path[-1]] = value + + def _list_to_dict_preprocessing(self, data: Any) -> Any: + if isinstance(data, dict): + # Process each key-value pair in the dictionary + return { + k: self._list_to_dict_preprocessing(v) for k, v in data.items() + } + elif isinstance(data, list): + # Convert the list to a dictionary with index-based keys + return { + str(i): self._list_to_dict_preprocessing(item) + for i, item in enumerate(data) + } + else: + # Base case: the item is neither a dict nor a list, so return it unchanged + return data + + def _json_split( + self, + data: Dict[str, Any], + current_path: List[str] = [], + chunks: List[Dict] = [{}], + ) -> List[Dict]: + """ + Split json into maximum size dictionaries while preserving structure. 
+ """ + if isinstance(data, dict): + for key, value in data.items(): + new_path = current_path + [key] + chunk_size = self._json_size(chunks[-1]) + size = self._json_size({key: value}) + remaining = self.max_chunk_size - chunk_size + + if size < remaining: + # Add item to current chunk + self._set_nested_dict(chunks[-1], new_path, value) + else: + if chunk_size >= self.min_chunk_size: + # Chunk is big enough, start a new chunk + chunks.append({}) + + # Iterate + self._json_split(value, new_path, chunks) + else: + # handle single item + self._set_nested_dict(chunks[-1], current_path, data) + return chunks + + def split_json( + self, + json_data: Dict[str, Any], + convert_lists: bool = False, + ) -> List[Dict]: + """Splits JSON into a list of JSON chunks""" + + if convert_lists: + chunks = self._json_split( + self._list_to_dict_preprocessing(json_data) + ) + else: + chunks = self._json_split(json_data) + + # Remove the last chunk if it's empty + if not chunks[-1]: + chunks.pop() + return chunks + + def split_text( + self, json_data: Dict[str, Any], convert_lists: bool = False + ) -> List[str]: + """Splits JSON into a list of JSON formatted strings""" + + chunks = self.split_json( + json_data=json_data, convert_lists=convert_lists + ) + + # Convert to string + return [json.dumps(chunk) for chunk in chunks] + + def create_documents( + self, + texts: List[Dict], + convert_lists: bool = False, + metadatas: Optional[List[dict]] = None, + ) -> List[Document]: + """Create documents from a list of json objects (Dict).""" + _metadatas = metadatas or [{}] * len(texts) + documents = [] + for i, text in enumerate(texts): + for chunk in self.split_text( + json_data=text, convert_lists=convert_lists + ): + metadata = copy.deepcopy(_metadatas[i]) + new_doc = Document(page_content=chunk, metadata=metadata) + documents.append(new_doc) + return documents |