diff options
Diffstat (limited to 'R2R/r2r/base/abstractions')
-rwxr-xr-x | R2R/r2r/base/abstractions/__init__.py | 0 | ||||
-rwxr-xr-x | R2R/r2r/base/abstractions/base.py | 93 | ||||
-rwxr-xr-x | R2R/r2r/base/abstractions/document.py | 242 | ||||
-rwxr-xr-x | R2R/r2r/base/abstractions/exception.py | 16 | ||||
-rwxr-xr-x | R2R/r2r/base/abstractions/llama_abstractions.py | 439 | ||||
-rwxr-xr-x | R2R/r2r/base/abstractions/llm.py | 112 | ||||
-rwxr-xr-x | R2R/r2r/base/abstractions/prompt.py | 31 | ||||
-rwxr-xr-x | R2R/r2r/base/abstractions/search.py | 84 | ||||
-rwxr-xr-x | R2R/r2r/base/abstractions/vector.py | 66 |
9 files changed, 1083 insertions, 0 deletions
diff --git a/R2R/r2r/base/abstractions/__init__.py b/R2R/r2r/base/abstractions/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/base/abstractions/__init__.py diff --git a/R2R/r2r/base/abstractions/base.py b/R2R/r2r/base/abstractions/base.py new file mode 100755 index 00000000..7121f6ce --- /dev/null +++ b/R2R/r2r/base/abstractions/base.py @@ -0,0 +1,93 @@ +import asyncio +import uuid +from typing import List + +from pydantic import BaseModel + + +class UserStats(BaseModel): + user_id: uuid.UUID + num_files: int + total_size_in_bytes: int + document_ids: List[uuid.UUID] + + +class AsyncSyncMeta(type): + _event_loop = None # Class-level shared event loop + + @classmethod + def get_event_loop(cls): + if cls._event_loop is None or cls._event_loop.is_closed(): + cls._event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls._event_loop) + return cls._event_loop + + def __new__(cls, name, bases, dct): + new_cls = super().__new__(cls, name, bases, dct) + for attr_name, attr_value in dct.items(): + if asyncio.iscoroutinefunction(attr_value) and getattr( + attr_value, "_syncable", False + ): + sync_method_name = attr_name[ + 1: + ] # Remove leading 'a' for sync method + async_method = attr_value + + def make_sync_method(async_method): + def sync_wrapper(self, *args, **kwargs): + loop = cls.get_event_loop() + if not loop.is_running(): + # Setup to run the loop in a background thread if necessary + # to prevent blocking the main thread in a synchronous call environment + from threading import Thread + + result = None + exception = None + + def run(): + nonlocal result, exception + try: + asyncio.set_event_loop(loop) + result = loop.run_until_complete( + async_method(self, *args, **kwargs) + ) + except Exception as e: + exception = e + finally: + generation_config = kwargs.get( + "rag_generation_config", None + ) + if ( + not generation_config + or not generation_config.stream + ): + loop.run_until_complete( + loop.shutdown_asyncgens() + ) + loop.close() + + thread = Thread(target=run) + thread.start() + thread.join() + if exception: + raise exception + return result + else: + # If there's already a running loop, schedule and execute the coroutine + future = asyncio.run_coroutine_threadsafe( + async_method(self, *args, **kwargs), loop + ) + return future.result() + + return sync_wrapper + + setattr( + new_cls, sync_method_name, make_sync_method(async_method) + ) + return new_cls + + +def syncable(func): + """Decorator to mark methods for synchronous wrapper creation.""" + func._syncable = True + return func diff --git a/R2R/r2r/base/abstractions/document.py b/R2R/r2r/base/abstractions/document.py new file mode 100755 index 00000000..117db7b9 --- /dev/null +++ b/R2R/r2r/base/abstractions/document.py @@ -0,0 +1,242 @@ +"""Abstractions for documents and their extractions.""" + +import base64 +import json +import logging +import uuid +from datetime import datetime +from enum import Enum +from typing import Optional, Union + +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +DataType = Union[str, bytes] + + +class DocumentType(str, Enum): + """Types of documents that can be stored.""" + + CSV = "csv" + DOCX = "docx" + HTML = "html" + JSON = "json" + MD = "md" + PDF = "pdf" + PPTX = "pptx" + TXT = "txt" + XLSX = "xlsx" + GIF = "gif" + PNG = "png" + JPG = "jpg" + JPEG = "jpeg" + SVG = "svg" + MP3 = "mp3" + MP4 = "mp4" + + +class Document(BaseModel): + id: uuid.UUID = Field(default_factory=uuid.uuid4) + type: DocumentType + data: Union[str, bytes] + metadata: dict + + def __init__(self, *args, **kwargs): + data = kwargs.get("data") + if data and isinstance(data, str): + try: + # Try to decode if it's already base64 encoded + kwargs["data"] = base64.b64decode(data) + except: + # If it's not base64, encode it to bytes + kwargs["data"] = data.encode("utf-8") + + doc_type = kwargs.get("type") + if isinstance(doc_type, str): + kwargs["type"] = DocumentType(doc_type) + + # Generate UUID based on the hash of the data + if "id" not in kwargs: + if isinstance(kwargs["data"], bytes): + data_hash = uuid.uuid5( + uuid.NAMESPACE_DNS, kwargs["data"].decode("utf-8") + ) + else: + data_hash = uuid.uuid5(uuid.NAMESPACE_DNS, kwargs["data"]) + + kwargs["id"] = data_hash # Set the id based on the data hash + + super().__init__(*args, **kwargs) + + class Config: + arbitrary_types_allowed = True + json_encoders = { + uuid.UUID: str, + bytes: lambda v: base64.b64encode(v).decode("utf-8"), + } + + +class DocumentStatus(str, Enum): + """Status of document processing.""" + + PROCESSING = "processing" + # TODO - Extend support for `partial-failure` + # PARTIAL_FAILURE = "partial-failure" + FAILURE = "failure" + SUCCESS = "success" + + +class DocumentInfo(BaseModel): + """Base class for document information handling.""" + + document_id: uuid.UUID + version: str + size_in_bytes: int + metadata: dict + status: DocumentStatus = DocumentStatus.PROCESSING + + user_id: Optional[uuid.UUID] = None + title: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + def convert_to_db_entry(self): + """Prepare the document info for database entry, extracting certain fields from metadata.""" + now = datetime.now() + metadata = self.metadata + if "user_id" in metadata: + metadata["user_id"] = str(metadata["user_id"]) + + metadata["title"] = metadata.get("title", "N/A") + return { + "document_id": str(self.document_id), + "title": metadata.get("title", "N/A"), + "user_id": metadata.get("user_id", None), + "version": self.version, + "size_in_bytes": self.size_in_bytes, + "metadata": json.dumps(self.metadata), + "created_at": self.created_at or now, + "updated_at": self.updated_at or now, + "status": self.status, + } + + +class ExtractionType(Enum): + """Types of extractions that can be performed.""" + + TXT = "txt" + IMG = "img" + MOV = "mov" + + +class Extraction(BaseModel): + """An extraction from a document.""" + + id: uuid.UUID + type: ExtractionType = ExtractionType.TXT + data: DataType + metadata: dict + document_id: uuid.UUID + + +class FragmentType(Enum): + """A type of fragment that can be extracted from a document.""" + + TEXT = "text" + IMAGE = "image" + + +class Fragment(BaseModel): + """A fragment extracted from a document.""" + + id: uuid.UUID + type: FragmentType + data: DataType + metadata: dict + document_id: uuid.UUID + extraction_id: uuid.UUID + + +class Entity(BaseModel): + """An entity extracted from a document.""" + + category: str + subcategory: Optional[str] = None + value: str + + def __str__(self): + return ( + f"{self.category}:{self.subcategory}:{self.value}" + if self.subcategory + else f"{self.category}:{self.value}" + ) + + +class Triple(BaseModel): + """A triple extracted from a document.""" + + subject: str + predicate: str + object: str + + +def extract_entities(llm_payload: list[str]) -> dict[str, Entity]: + entities = {} + for entry in llm_payload: + try: + if "], " in entry: # Check if the entry is an entity + entry_val = entry.split("], ")[0] + "]" + entry = entry.split("], ")[1] + colon_count = entry.count(":") + + if colon_count == 1: + category, value = entry.split(":") + subcategory = None + elif colon_count >= 2: + parts = entry.split(":", 2) + category, subcategory, value = ( + parts[0], + parts[1], + parts[2], + ) + else: + raise ValueError("Unexpected entry format") + + entities[entry_val] = Entity( + category=category, subcategory=subcategory, value=value + ) + except Exception as e: + logger.error(f"Error processing entity {entry}: {e}") + continue + return entities + + +def extract_triples( + llm_payload: list[str], entities: dict[str, Entity] +) -> list[Triple]: + triples = [] + for entry in llm_payload: + try: + if "], " not in entry: # Check if the entry is an entity + elements = entry.split(" ") + subject = elements[0] + predicate = elements[1] + object = " ".join(elements[2:]) + subject = entities[subject].value # Use entity.value + if "[" in object and "]" in object: + object = entities[object].value # Use entity.value + triples.append( + Triple(subject=subject, predicate=predicate, object=object) + ) + except Exception as e: + logger.error(f"Error processing triplet {entry}: {e}") + continue + return triples + + +class KGExtraction(BaseModel): + """An extraction from a document that is part of a knowledge graph.""" + + entities: dict[str, Entity] + triples: list[Triple] diff --git a/R2R/r2r/base/abstractions/exception.py b/R2R/r2r/base/abstractions/exception.py new file mode 100755 index 00000000..c76625a3 --- /dev/null +++ b/R2R/r2r/base/abstractions/exception.py @@ -0,0 +1,16 @@ +from typing import Any, Optional + + +class R2RException(Exception): + def __init__( + self, message: str, status_code: int, detail: Optional[Any] = None + ): + self.message = message + self.status_code = status_code + super().__init__(self.message) + + +class R2RDocumentProcessingError(R2RException): + def __init__(self, error_message, document_id): + self.document_id = document_id + super().__init__(error_message, 400, {"document_id": document_id}) diff --git a/R2R/r2r/base/abstractions/llama_abstractions.py b/R2R/r2r/base/abstractions/llama_abstractions.py new file mode 100755 index 00000000..f6bc36e6 --- /dev/null +++ b/R2R/r2r/base/abstractions/llama_abstractions.py @@ -0,0 +1,439 @@ +# abstractions are taken from LlamaIndex +# https://github.com/run-llama/llama_index +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union + +from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr + + +class LabelledNode(BaseModel): + """An entity in a graph.""" + + label: str = Field(default="node", description="The label of the node.") + embedding: Optional[List[float]] = Field( + default=None, description="The embeddings of the node." + ) + properties: Dict[str, Any] = Field(default_factory=dict) + + @abstractmethod + def __str__(self) -> str: + """Return the string representation of the node.""" + ... + + @property + @abstractmethod + def id(self) -> str: + """Get the node id.""" + ... + + +class EntityNode(LabelledNode): + """An entity in a graph.""" + + name: str = Field(description="The name of the entity.") + label: str = Field(default="entity", description="The label of the node.") + properties: Dict[str, Any] = Field(default_factory=dict) + + def __str__(self) -> str: + """Return the string representation of the node.""" + return self.name + + @property + def id(self) -> str: + """Get the node id.""" + return self.name.replace('"', " ") + + +class ChunkNode(LabelledNode): + """A text chunk in a graph.""" + + text: str = Field(description="The text content of the chunk.") + id_: Optional[str] = Field( + default=None, + description="The id of the node. Defaults to a hash of the text.", + ) + label: str = Field( + default="text_chunk", description="The label of the node." + ) + properties: Dict[str, Any] = Field(default_factory=dict) + + def __str__(self) -> str: + """Return the string representation of the node.""" + return self.text + + @property + def id(self) -> str: + """Get the node id.""" + return str(hash(self.text)) if self.id_ is None else self.id_ + + +class Relation(BaseModel): + """A relation connecting two entities in a graph.""" + + label: str + source_id: str + target_id: str + properties: Dict[str, Any] = Field(default_factory=dict) + + def __str__(self) -> str: + """Return the string representation of the relation.""" + return self.label + + @property + def id(self) -> str: + """Get the relation id.""" + return self.label + + +Triplet = Tuple[LabelledNode, Relation, LabelledNode] + + +class VectorStoreQueryMode(str, Enum): + """Vector store query mode.""" + + DEFAULT = "default" + SPARSE = "sparse" + HYBRID = "hybrid" + TEXT_SEARCH = "text_search" + SEMANTIC_HYBRID = "semantic_hybrid" + + # fit learners + SVM = "svm" + LOGISTIC_REGRESSION = "logistic_regression" + LINEAR_REGRESSION = "linear_regression" + + # maximum marginal relevance + MMR = "mmr" + + +class FilterOperator(str, Enum): + """Vector store filter operator.""" + + # TODO add more operators + EQ = "==" # default operator (string, int, float) + GT = ">" # greater than (int, float) + LT = "<" # less than (int, float) + NE = "!=" # not equal to (string, int, float) + GTE = ">=" # greater than or equal to (int, float) + LTE = "<=" # less than or equal to (int, float) + IN = "in" # In array (string or number) + NIN = "nin" # Not in array (string or number) + ANY = "any" # Contains any (array of strings) + ALL = "all" # Contains all (array of strings) + TEXT_MATCH = "text_match" # full text match (allows you to search for a specific substring, token or phrase within the text field) + CONTAINS = "contains" # metadata array contains value (string or number) + + +class MetadataFilter(BaseModel): + """Comprehensive metadata filter for vector stores to support more operators. + + Value uses Strict* types, as int, float and str are compatible types and were all + converted to string before. + + See: https://docs.pydantic.dev/latest/usage/types/#strict-types + """ + + key: str + value: Union[ + StrictInt, + StrictFloat, + StrictStr, + List[StrictStr], + List[StrictFloat], + List[StrictInt], + ] + operator: FilterOperator = FilterOperator.EQ + + @classmethod + def from_dict( + cls, + filter_dict: Dict, + ) -> "MetadataFilter": + """Create MetadataFilter from dictionary. + + Args: + filter_dict: Dict with key, value and operator. + + """ + return MetadataFilter.parse_obj(filter_dict) + + +# # TODO: Deprecate ExactMatchFilter and use MetadataFilter instead +# # Keep class for now so that AutoRetriever can still work with old vector stores +# class ExactMatchFilter(BaseModel): +# key: str +# value: Union[StrictInt, StrictFloat, StrictStr] + +# set ExactMatchFilter to MetadataFilter +ExactMatchFilter = MetadataFilter + + +class FilterCondition(str, Enum): + """Vector store filter conditions to combine different filters.""" + + # TODO add more conditions + AND = "and" + OR = "or" + + +class MetadataFilters(BaseModel): + """Metadata filters for vector stores.""" + + # Exact match filters and Advanced filters with operators like >, <, >=, <=, !=, etc. + filters: List[Union[MetadataFilter, ExactMatchFilter, "MetadataFilters"]] + # and/or such conditions for combining different filters + condition: Optional[FilterCondition] = FilterCondition.AND + + +@dataclass +class VectorStoreQuery: + """Vector store query.""" + + query_embedding: Optional[List[float]] = None + similarity_top_k: int = 1 + doc_ids: Optional[List[str]] = None + node_ids: Optional[List[str]] = None + query_str: Optional[str] = None + output_fields: Optional[List[str]] = None + embedding_field: Optional[str] = None + + mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT + + # NOTE: only for hybrid search (0 for bm25, 1 for vector search) + alpha: Optional[float] = None + + # metadata filters + filters: Optional[MetadataFilters] = None + + # only for mmr + mmr_threshold: Optional[float] = None + + # NOTE: currently only used by postgres hybrid search + sparse_top_k: Optional[int] = None + # NOTE: return top k results from hybrid search. similarity_top_k is used for dense search top k + hybrid_top_k: Optional[int] = None + + +class PropertyGraphStore(ABC): + """Abstract labelled graph store protocol. + + This protocol defines the interface for a graph store, which is responsible + for storing and retrieving knowledge graph data. + + Attributes: + client: Any: The client used to connect to the graph store. + get: Callable[[str], List[List[str]]]: Get triplets for a given subject. + get_rel_map: Callable[[Optional[List[str]], int], Dict[str, List[List[str]]]]: + Get subjects' rel map in max depth. + upsert_triplet: Callable[[str, str, str], None]: Upsert a triplet. + delete: Callable[[str, str, str], None]: Delete a triplet. + persist: Callable[[str, Optional[fsspec.AbstractFileSystem]], None]: + Persist the graph store to a file. + """ + + supports_structured_queries: bool = False + supports_vector_queries: bool = False + + @property + def client(self) -> Any: + """Get client.""" + ... + + @abstractmethod + def get( + self, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[LabelledNode]: + """Get nodes with matching values.""" + ... + + @abstractmethod + def get_triplets( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[Triplet]: + """Get triplets with matching values.""" + ... + + @abstractmethod + def get_rel_map( + self, + graph_nodes: List[LabelledNode], + depth: int = 2, + limit: int = 30, + ignore_rels: Optional[List[str]] = None, + ) -> List[Triplet]: + """Get depth-aware rel map.""" + ... + + @abstractmethod + def upsert_nodes(self, nodes: List[LabelledNode]) -> None: + """Upsert nodes.""" + ... + + @abstractmethod + def upsert_relations(self, relations: List[Relation]) -> None: + """Upsert relations.""" + ... + + @abstractmethod + def delete( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> None: + """Delete matching data.""" + ... + + @abstractmethod + def structured_query( + self, query: str, param_map: Optional[Dict[str, Any]] = None + ) -> Any: + """Query the graph store with statement and parameters.""" + ... + + @abstractmethod + def vector_query( + self, query: VectorStoreQuery, **kwargs: Any + ) -> Tuple[List[LabelledNode], List[float]]: + """Query the graph store with a vector store query.""" + ... + + # def persist( + # self, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None + # ) -> None: + # """Persist the graph store to a file.""" + # return + + def get_schema(self, refresh: bool = False) -> Any: + """Get the schema of the graph store.""" + return None + + def get_schema_str(self, refresh: bool = False) -> str: + """Get the schema of the graph store as a string.""" + return str(self.get_schema(refresh=refresh)) + + ### ----- Async Methods ----- ### + + async def aget( + self, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[LabelledNode]: + """Asynchronously get nodes with matching values.""" + return self.get(properties, ids) + + async def aget_triplets( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> List[Triplet]: + """Asynchronously get triplets with matching values.""" + return self.get_triplets(entity_names, relation_names, properties, ids) + + async def aget_rel_map( + self, + graph_nodes: List[LabelledNode], + depth: int = 2, + limit: int = 30, + ignore_rels: Optional[List[str]] = None, + ) -> List[Triplet]: + """Asynchronously get depth-aware rel map.""" + return self.get_rel_map(graph_nodes, depth, limit, ignore_rels) + + async def aupsert_nodes(self, nodes: List[LabelledNode]) -> None: + """Asynchronously add nodes.""" + return self.upsert_nodes(nodes) + + async def aupsert_relations(self, relations: List[Relation]) -> None: + """Asynchronously add relations.""" + return self.upsert_relations(relations) + + async def adelete( + self, + entity_names: Optional[List[str]] = None, + relation_names: Optional[List[str]] = None, + properties: Optional[dict] = None, + ids: Optional[List[str]] = None, + ) -> None: + """Asynchronously delete matching data.""" + return self.delete(entity_names, relation_names, properties, ids) + + async def astructured_query( + self, query: str, param_map: Optional[Dict[str, Any]] = {} + ) -> Any: + """Asynchronously query the graph store with statement and parameters.""" + return self.structured_query(query, param_map) + + async def avector_query( + self, query: VectorStoreQuery, **kwargs: Any + ) -> Tuple[List[LabelledNode], List[float]]: + """Asynchronously query the graph store with a vector store query.""" + return self.vector_query(query, **kwargs) + + async def aget_schema(self, refresh: bool = False) -> str: + """Asynchronously get the schema of the graph store.""" + return self.get_schema(refresh=refresh) + + async def aget_schema_str(self, refresh: bool = False) -> str: + """Asynchronously get the schema of the graph store as a string.""" + return str(await self.aget_schema(refresh=refresh)) + + +LIST_LIMIT = 128 + + +def clean_string_values(text: str) -> str: + return text.replace("\n", " ").replace("\r", " ") + + +def value_sanitize(d: Any) -> Any: + """Sanitize the input dictionary or list. + + Sanitizes the input by removing embedding-like values, + lists with more than 128 elements, that are mostly irrelevant for + generating answers in a LLM context. These properties, if left in + results, can occupy significant context space and detract from + the LLM's performance by introducing unnecessary noise and cost. + """ + if isinstance(d, dict): + new_dict = {} + for key, value in d.items(): + if isinstance(value, dict): + sanitized_value = value_sanitize(value) + if ( + sanitized_value is not None + ): # Check if the sanitized value is not None + new_dict[key] = sanitized_value + elif isinstance(value, list): + if len(value) < LIST_LIMIT: + sanitized_value = value_sanitize(value) + if ( + sanitized_value is not None + ): # Check if the sanitized value is not None + new_dict[key] = sanitized_value + # Do not include the key if the list is oversized + else: + new_dict[key] = value + return new_dict + elif isinstance(d, list): + if len(d) < LIST_LIMIT: + return [ + value_sanitize(item) + for item in d + if value_sanitize(item) is not None + ] + else: + return None + else: + return d diff --git a/R2R/r2r/base/abstractions/llm.py b/R2R/r2r/base/abstractions/llm.py new file mode 100755 index 00000000..3178d8dc --- /dev/null +++ b/R2R/r2r/base/abstractions/llm.py @@ -0,0 +1,112 @@ +"""Abstractions for the LLM model.""" + +from typing import TYPE_CHECKING, ClassVar, Optional + +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from pydantic import BaseModel, Field + +if TYPE_CHECKING: + from .search import AggregateSearchResult + +LLMChatCompletion = ChatCompletion +LLMChatCompletionChunk = ChatCompletionChunk + + +class RAGCompletion: + completion: LLMChatCompletion + search_results: "AggregateSearchResult" + + def __init__( + self, + completion: LLMChatCompletion, + search_results: "AggregateSearchResult", + ): + self.completion = completion + self.search_results = search_results + + +class GenerationConfig(BaseModel): + _defaults: ClassVar[dict] = { + "model": "gpt-4o", + "temperature": 0.1, + "top_p": 1.0, + "top_k": 100, + "max_tokens_to_sample": 1024, + "stream": False, + "functions": None, + "skip_special_tokens": False, + "stop_token": None, + "num_beams": 1, + "do_sample": True, + "generate_with_chat": False, + "add_generation_kwargs": None, + "api_base": None, + } + + model: str = Field( + default_factory=lambda: GenerationConfig._defaults["model"] + ) + temperature: float = Field( + default_factory=lambda: GenerationConfig._defaults["temperature"] + ) + top_p: float = Field( + default_factory=lambda: GenerationConfig._defaults["top_p"] + ) + top_k: int = Field( + default_factory=lambda: GenerationConfig._defaults["top_k"] + ) + max_tokens_to_sample: int = Field( + default_factory=lambda: GenerationConfig._defaults[ + "max_tokens_to_sample" + ] + ) + stream: bool = Field( + default_factory=lambda: GenerationConfig._defaults["stream"] + ) + functions: Optional[list[dict]] = Field( + default_factory=lambda: GenerationConfig._defaults["functions"] + ) + skip_special_tokens: bool = Field( + default_factory=lambda: GenerationConfig._defaults[ + "skip_special_tokens" + ] + ) + stop_token: Optional[str] = Field( + default_factory=lambda: GenerationConfig._defaults["stop_token"] + ) + num_beams: int = Field( + default_factory=lambda: GenerationConfig._defaults["num_beams"] + ) + do_sample: bool = Field( + default_factory=lambda: GenerationConfig._defaults["do_sample"] + ) + generate_with_chat: bool = Field( + default_factory=lambda: GenerationConfig._defaults[ + "generate_with_chat" + ] + ) + add_generation_kwargs: Optional[dict] = Field( + default_factory=lambda: GenerationConfig._defaults[ + "add_generation_kwargs" + ] + ) + api_base: Optional[str] = Field( + default_factory=lambda: GenerationConfig._defaults["api_base"] + ) + + @classmethod + def set_default(cls, **kwargs): + for key, value in kwargs.items(): + if key in cls._defaults: + cls._defaults[key] = value + else: + raise AttributeError( + f"No default attribute '{key}' in GenerationConfig" + ) + + def __init__(self, **data): + model = data.pop("model", None) + if model is not None: + super().__init__(model=model, **data) + else: + super().__init__(**data) diff --git a/R2R/r2r/base/abstractions/prompt.py b/R2R/r2r/base/abstractions/prompt.py new file mode 100755 index 00000000..e37eeb5f --- /dev/null +++ b/R2R/r2r/base/abstractions/prompt.py @@ -0,0 +1,31 @@ +"""Abstraction for a prompt that can be formatted with inputs.""" + +from typing import Any + +from pydantic import BaseModel + + +class Prompt(BaseModel): + """A prompt that can be formatted with inputs.""" + + name: str + template: str + input_types: dict[str, str] + + def format_prompt(self, inputs: dict[str, Any]) -> str: + self._validate_inputs(inputs) + return self.template.format(**inputs) + + def _validate_inputs(self, inputs: dict[str, Any]) -> None: + for var, expected_type_name in self.input_types.items(): + expected_type = self._convert_type(expected_type_name) + if var not in inputs: + raise ValueError(f"Missing input: {var}") + if not isinstance(inputs[var], expected_type): + raise TypeError( + f"Input '{var}' must be of type {expected_type.__name__}, got {type(inputs[var]).__name__} instead." + ) + + def _convert_type(self, type_name: str) -> type: + type_mapping = {"int": int, "str": str} + return type_mapping.get(type_name, str) diff --git a/R2R/r2r/base/abstractions/search.py b/R2R/r2r/base/abstractions/search.py new file mode 100755 index 00000000..b13cc5aa --- /dev/null +++ b/R2R/r2r/base/abstractions/search.py @@ -0,0 +1,84 @@ +"""Abstractions for search functionality.""" + +import uuid +from typing import Any, Dict, List, Optional, Tuple + +from pydantic import BaseModel, Field + +from .llm import GenerationConfig + + +class VectorSearchRequest(BaseModel): + """Request for a search operation.""" + + query: str + limit: int + filters: Optional[dict[str, Any]] = None + + +class VectorSearchResult(BaseModel): + """Result of a search operation.""" + + id: uuid.UUID + score: float + metadata: dict[str, Any] + + def __str__(self) -> str: + return f"VectorSearchResult(id={self.id}, score={self.score}, metadata={self.metadata})" + + def __repr__(self) -> str: + return f"VectorSearchResult(id={self.id}, score={self.score}, metadata={self.metadata})" + + def dict(self) -> dict: + return { + "id": self.id, + "score": self.score, + "metadata": self.metadata, + } + + +class KGSearchRequest(BaseModel): + """Request for a knowledge graph search operation.""" + + query: str + + +# [query, ...] +KGSearchResult = List[Tuple[str, List[Dict[str, Any]]]] + + +class AggregateSearchResult(BaseModel): + """Result of an aggregate search operation.""" + + vector_search_results: Optional[List[VectorSearchResult]] + kg_search_results: Optional[KGSearchResult] = None + + def __str__(self) -> str: + return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})" + + def __repr__(self) -> str: + return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})" + + def dict(self) -> dict: + return { + "vector_search_results": ( + [result.dict() for result in self.vector_search_results] + if self.vector_search_results + else [] + ), + "kg_search_results": self.kg_search_results or [], + } + + +class VectorSearchSettings(BaseModel): + use_vector_search: bool = True + search_filters: dict[str, Any] = Field(default_factory=dict) + search_limit: int = 10 + do_hybrid_search: bool = False + + +class KGSearchSettings(BaseModel): + use_kg_search: bool = False + agent_generation_config: Optional[GenerationConfig] = Field( + default_factory=GenerationConfig + ) diff --git a/R2R/r2r/base/abstractions/vector.py b/R2R/r2r/base/abstractions/vector.py new file mode 100755 index 00000000..445f3302 --- /dev/null +++ b/R2R/r2r/base/abstractions/vector.py @@ -0,0 +1,66 @@ +"""Abstraction for a vector that can be stored in the system.""" + +from enum import Enum +from typing import Any +from uuid import UUID + + +class VectorType(Enum): + FIXED = "FIXED" + + +class Vector: + """A vector with the option to fix the number of elements.""" + + def __init__( + self, + data: list[float], + type: VectorType = VectorType.FIXED, + length: int = -1, + ): + self.data = data + self.type = type + self.length = length + + if ( + self.type == VectorType.FIXED + and length > 0 + and len(data) != length + ): + raise ValueError(f"Vector must be exactly {length} elements long.") + + def __repr__(self) -> str: + return ( + f"Vector(data={self.data}, type={self.type}, length={self.length})" + ) + + +class VectorEntry: + """A vector entry that can be stored directly in supported vector databases.""" + + def __init__(self, id: UUID, vector: Vector, metadata: dict[str, Any]): + """Create a new VectorEntry object.""" + self.vector = vector + self.id = id + self.metadata = metadata + + def to_serializable(self) -> str: + """Return a serializable representation of the VectorEntry.""" + metadata = self.metadata + + for key in metadata: + if isinstance(metadata[key], UUID): + metadata[key] = str(metadata[key]) + return { + "id": str(self.id), + "vector": self.vector.data, + "metadata": metadata, + } + + def __str__(self) -> str: + """Return a string representation of the VectorEntry.""" + return f"VectorEntry(id={self.id}, vector={self.vector}, metadata={self.metadata})" + + def __repr__(self) -> str: + """Return an unambiguous string representation of the VectorEntry.""" + return f"VectorEntry(id={self.id}, vector={self.vector}, metadata={self.metadata})" |