Diffstat (limited to 'R2R/r2r/base/abstractions/llama_abstractions.py')
-rwxr-xr-x | R2R/r2r/base/abstractions/llama_abstractions.py | 439
1 file changed, 439 insertions, 0 deletions
diff --git a/R2R/r2r/base/abstractions/llama_abstractions.py b/R2R/r2r/base/abstractions/llama_abstractions.py
new file mode 100755
index 00000000..f6bc36e6
--- /dev/null
+++ b/R2R/r2r/base/abstractions/llama_abstractions.py
@@ -0,0 +1,439 @@
+# abstractions are taken from LlamaIndex
+# https://github.com/run-llama/llama_index
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr
+
+
+class LabelledNode(BaseModel):
+    """A labelled node in a graph; base class for entity and chunk nodes."""
+
+    label: str = Field(default="node", description="The label of the node.")
+    embedding: Optional[List[float]] = Field(
+        default=None, description="The embeddings of the node."
+    )
+    properties: Dict[str, Any] = Field(default_factory=dict)
+
+    @abstractmethod
+    def __str__(self) -> str:
+        """Return the string representation of the node."""
+        ...
+
+    @property
+    @abstractmethod
+    def id(self) -> str:
+        """Get the node id."""
+        ...
+
+
+class EntityNode(LabelledNode):
+    """An entity in a graph."""
+
+    name: str = Field(description="The name of the entity.")
+    label: str = Field(default="entity", description="The label of the node.")
+    properties: Dict[str, Any] = Field(default_factory=dict)
+
+    def __str__(self) -> str:
+        """Return the string representation of the node."""
+        return self.name
+
+    @property
+    def id(self) -> str:
+        """Get the node id."""
+        return self.name.replace('"', " ")
+
+
+class ChunkNode(LabelledNode):
+    """A text chunk in a graph."""
+
+    text: str = Field(description="The text content of the chunk.")
+    id_: Optional[str] = Field(
+        default=None,
+        description="The id of the node. Defaults to a hash of the text.",
+    )
+    label: str = Field(
+        default="text_chunk", description="The label of the node."
+    )
+    properties: Dict[str, Any] = Field(default_factory=dict)
+
+    def __str__(self) -> str:
+        """Return the string representation of the node."""
+        return self.text
+
+    @property
+    def id(self) -> str:
+        """Get the node id."""
+        return str(hash(self.text)) if self.id_ is None else self.id_
+
+
+class Relation(BaseModel):
+    """A relation connecting two entities in a graph."""
+
+    label: str
+    source_id: str
+    target_id: str
+    properties: Dict[str, Any] = Field(default_factory=dict)
+
+    def __str__(self) -> str:
+        """Return the string representation of the relation."""
+        return self.label
+
+    @property
+    def id(self) -> str:
+        """Get the relation id."""
+        return self.label
+
+
+Triplet = Tuple[LabelledNode, Relation, LabelledNode]
+
+
+class VectorStoreQueryMode(str, Enum):
+    """Vector store query mode."""
+
+    DEFAULT = "default"
+    SPARSE = "sparse"
+    HYBRID = "hybrid"
+    TEXT_SEARCH = "text_search"
+    SEMANTIC_HYBRID = "semantic_hybrid"
+
+    # fit learners
+    SVM = "svm"
+    LOGISTIC_REGRESSION = "logistic_regression"
+    LINEAR_REGRESSION = "linear_regression"
+
+    # maximum marginal relevance
+    MMR = "mmr"
+
+
+class FilterOperator(str, Enum):
+    """Vector store filter operator."""
+
+    # TODO add more operators
+    EQ = "=="  # default operator (string, int, float)
+    GT = ">"  # greater than (int, float)
+    LT = "<"  # less than (int, float)
+    NE = "!="  # not equal to (string, int, float)
+    GTE = ">="  # greater than or equal to (int, float)
+    LTE = "<="  # less than or equal to (int, float)
+    IN = "in"  # in array (string or number)
+    NIN = "nin"  # not in array (string or number)
+    ANY = "any"  # contains any (array of strings)
+    ALL = "all"  # contains all (array of strings)
+    TEXT_MATCH = "text_match"  # full text match (search for a specific substring, token or phrase within the text field)
+    CONTAINS = "contains"  # metadata array contains value (string or number)
+
+
+class MetadataFilter(BaseModel):
+    """Comprehensive metadata filter for vector stores to support more operators.
+
+    Value uses Strict* types, as int, float and str are compatible types
+    and were previously all converted to strings.
+
+    See: https://docs.pydantic.dev/latest/usage/types/#strict-types
+    """
+
+    key: str
+    value: Union[
+        StrictInt,
+        StrictFloat,
+        StrictStr,
+        List[StrictStr],
+        List[StrictFloat],
+        List[StrictInt],
+    ]
+    operator: FilterOperator = FilterOperator.EQ
+
+    @classmethod
+    def from_dict(
+        cls,
+        filter_dict: Dict,
+    ) -> "MetadataFilter":
+        """Create a MetadataFilter from a dictionary.
+
+        Args:
+            filter_dict: Dict with key, value and operator.
+
+        """
+        return MetadataFilter.parse_obj(filter_dict)
+
+
+# # TODO: Deprecate ExactMatchFilter and use MetadataFilter instead
+# # Keep class for now so that AutoRetriever can still work with old vector stores
+# class ExactMatchFilter(BaseModel):
+#     key: str
+#     value: Union[StrictInt, StrictFloat, StrictStr]
+
+# set ExactMatchFilter to MetadataFilter
+ExactMatchFilter = MetadataFilter
+
+
+class FilterCondition(str, Enum):
+    """Vector store filter conditions to combine different filters."""
+
+    # TODO add more conditions
+    AND = "and"
+    OR = "or"
+
+
+class MetadataFilters(BaseModel):
+    """Metadata filters for vector stores."""
+
+    # Exact-match filters and advanced filters with operators like >, <, >=, <=, !=, etc.
+    filters: List[Union[MetadataFilter, ExactMatchFilter, "MetadataFilters"]]
+    # and/or condition for combining the filters above
+    condition: Optional[FilterCondition] = FilterCondition.AND
+
+
+@dataclass
+class VectorStoreQuery:
+    """Vector store query."""
+
+    query_embedding: Optional[List[float]] = None
+    similarity_top_k: int = 1
+    doc_ids: Optional[List[str]] = None
+    node_ids: Optional[List[str]] = None
+    query_str: Optional[str] = None
+    output_fields: Optional[List[str]] = None
+    embedding_field: Optional[str] = None
+
+    mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT
+
+    # NOTE: only for hybrid search (0 for bm25, 1 for vector search)
+    alpha: Optional[float] = None
+
+    # metadata filters
+    filters: Optional[MetadataFilters] = None
+
+    # only for mmr
+    mmr_threshold: Optional[float] = None
+
+    # NOTE: currently only used by postgres hybrid search
+    sparse_top_k: Optional[int] = None
+    # NOTE: return top k results from hybrid search. similarity_top_k is used for dense search top k
+    hybrid_top_k: Optional[int] = None
+
+
+class PropertyGraphStore(ABC):
+    """Abstract labelled property graph store protocol.
+
+    This protocol defines the interface for a graph store, which is
+    responsible for storing and retrieving knowledge graph data.
+
+    Attributes:
+        client (Any): The client used to connect to the graph store.
+        supports_structured_queries (bool): Whether the store implements
+            structured_query.
+        supports_vector_queries (bool): Whether the store implements
+            vector_query.
+    """
+
+    supports_structured_queries: bool = False
+    supports_vector_queries: bool = False
+
+    @property
+    def client(self) -> Any:
+        """Get client."""
+        ...
+
+    @abstractmethod
+    def get(
+        self,
+        properties: Optional[dict] = None,
+        ids: Optional[List[str]] = None,
+    ) -> List[LabelledNode]:
+        """Get nodes with matching values."""
+        ...
+
+    @abstractmethod
+    def get_triplets(
+        self,
+        entity_names: Optional[List[str]] = None,
+        relation_names: Optional[List[str]] = None,
+        properties: Optional[dict] = None,
+        ids: Optional[List[str]] = None,
+    ) -> List[Triplet]:
+        """Get triplets with matching values."""
+        ...
+
+    @abstractmethod
+    def get_rel_map(
+        self,
+        graph_nodes: List[LabelledNode],
+        depth: int = 2,
+        limit: int = 30,
+        ignore_rels: Optional[List[str]] = None,
+    ) -> List[Triplet]:
+        """Get depth-aware rel map."""
+        ...
+
+    @abstractmethod
+    def upsert_nodes(self, nodes: List[LabelledNode]) -> None:
+        """Upsert nodes."""
+        ...
+
+    @abstractmethod
+    def upsert_relations(self, relations: List[Relation]) -> None:
+        """Upsert relations."""
+        ...
+
+    @abstractmethod
+    def delete(
+        self,
+        entity_names: Optional[List[str]] = None,
+        relation_names: Optional[List[str]] = None,
+        properties: Optional[dict] = None,
+        ids: Optional[List[str]] = None,
+    ) -> None:
+        """Delete matching data."""
+        ...
+
+    @abstractmethod
+    def structured_query(
+        self, query: str, param_map: Optional[Dict[str, Any]] = None
+    ) -> Any:
+        """Query the graph store with statement and parameters."""
+        ...
+
+    @abstractmethod
+    def vector_query(
+        self, query: VectorStoreQuery, **kwargs: Any
+    ) -> Tuple[List[LabelledNode], List[float]]:
+        """Query the graph store with a vector store query."""
+        ...
+
+    # def persist(
+    #     self, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
+    # ) -> None:
+    #     """Persist the graph store to a file."""
+    #     return
+
+    def get_schema(self, refresh: bool = False) -> Any:
+        """Get the schema of the graph store."""
+        return None
+
+    def get_schema_str(self, refresh: bool = False) -> str:
+        """Get the schema of the graph store as a string."""
+        return str(self.get_schema(refresh=refresh))
+
+    ### ----- Async Methods ----- ###
+
+    async def aget(
+        self,
+        properties: Optional[dict] = None,
+        ids: Optional[List[str]] = None,
+    ) -> List[LabelledNode]:
+        """Asynchronously get nodes with matching values."""
+        return self.get(properties, ids)
+
+    async def aget_triplets(
+        self,
+        entity_names: Optional[List[str]] = None,
+        relation_names: Optional[List[str]] = None,
+        properties: Optional[dict] = None,
+        ids: Optional[List[str]] = None,
+    ) -> List[Triplet]:
+        """Asynchronously get triplets with matching values."""
+        return self.get_triplets(entity_names, relation_names, properties, ids)
+
+    async def aget_rel_map(
+        self,
+        graph_nodes: List[LabelledNode],
+        depth: int = 2,
+        limit: int = 30,
+        ignore_rels: Optional[List[str]] = None,
+    ) -> List[Triplet]:
+        """Asynchronously get depth-aware rel map."""
+        return self.get_rel_map(graph_nodes, depth, limit, ignore_rels)
+
+    async def aupsert_nodes(self, nodes: List[LabelledNode]) -> None:
+        """Asynchronously upsert nodes."""
+        return self.upsert_nodes(nodes)
+
+    async def aupsert_relations(self, relations: List[Relation]) -> None:
+        """Asynchronously upsert relations."""
+        return self.upsert_relations(relations)
+
+    async def adelete(
+        self,
+        entity_names: Optional[List[str]] = None,
+        relation_names: Optional[List[str]] = None,
+        properties: Optional[dict] = None,
+        ids: Optional[List[str]] = None,
+    ) -> None:
+        """Asynchronously delete matching data."""
+        return self.delete(entity_names, relation_names, properties, ids)
+
+    async def astructured_query(
+        self, query: str, param_map: Optional[Dict[str, Any]] = None
+    ) -> Any:
+        """Asynchronously query the graph store with statement and parameters."""
+        return self.structured_query(query, param_map)
+
+    async def avector_query(
+        self, query: VectorStoreQuery, **kwargs: Any
+    ) -> Tuple[List[LabelledNode], List[float]]:
+        """Asynchronously query the graph store with a vector store query."""
+        return self.vector_query(query, **kwargs)
+
+    async def aget_schema(self, refresh: bool = False) -> Any:
+        """Asynchronously get the schema of the graph store."""
+        return self.get_schema(refresh=refresh)
+
+    async def aget_schema_str(self, refresh: bool = False) -> str:
+        """Asynchronously get the schema of the graph store as a string."""
+        return str(await self.aget_schema(refresh=refresh))
+
+
+LIST_LIMIT = 128
+
+
+def clean_string_values(text: str) -> str:
+    """Replace newlines and carriage returns with spaces."""
+    return text.replace("\n", " ").replace("\r", " ")
+
+
+def value_sanitize(d: Any) -> Any:
+    """Sanitize the input dictionary or list.
+
+    Sanitizes the input by removing embedding-like values: lists with
+    LIST_LIMIT (128) or more elements, which are mostly irrelevant for
+    generating answers in an LLM context. Left in the results, such
+    properties occupy significant context space and detract from the
+    LLM's performance by introducing unnecessary noise and cost.
+    """
+    if isinstance(d, dict):
+        new_dict = {}
+        for key, value in d.items():
+            if isinstance(value, dict):
+                sanitized_value = value_sanitize(value)
+                if sanitized_value is not None:
+                    new_dict[key] = sanitized_value
+            elif isinstance(value, list):
+                if len(value) < LIST_LIMIT:
+                    sanitized_value = value_sanitize(value)
+                    if sanitized_value is not None:
+                        new_dict[key] = sanitized_value
+                # Do not include the key if the list is oversized
+            else:
+                new_dict[key] = value
+        return new_dict
+    elif isinstance(d, list):
+        if len(d) < LIST_LIMIT:
+            return [
+                value_sanitize(item)
+                for item in d
+                if value_sanitize(item) is not None
+            ]
+        else:
+            return None
+    else:
+        return d
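
The MetadataFilter / MetadataFilters pair added above forms a small filter DSL: each MetadataFilter binds a key, a strictly typed value, and a FilterOperator, and MetadataFilters combines filters (or nested MetadataFilters) under an and/or FilterCondition. A minimal usage sketch, assuming the module is importable as r2r.base.abstractions.llama_abstractions:

    from r2r.base.abstractions.llama_abstractions import (
        FilterCondition,
        FilterOperator,
        MetadataFilter,
        MetadataFilters,
    )

    # year >= 2020 AND category in ["ml", "nlp"]
    filters = MetadataFilters(
        filters=[
            MetadataFilter(key="year", value=2020, operator=FilterOperator.GTE),
            MetadataFilter(
                key="category", value=["ml", "nlp"], operator=FilterOperator.IN
            ),
        ],
        condition=FilterCondition.AND,
    )

    # from_dict accepts the operator's string value ("==", ">=", "in", ...),
    # since FilterOperator is a str Enum
    year_filter = MetadataFilter.from_dict(
        {"key": "year", "value": 2020, "operator": ">="}
    )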
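
VectorStoreQuery is a plain dataclass, so building a hybrid query is just field assembly; nothing is validated until a concrete store interprets it. A sketch with toy values (the 3-dim embedding and the parameter choices are illustrative only):

    from r2r.base.abstractions.llama_abstractions import (
        MetadataFilter,
        MetadataFilters,
        VectorStoreQuery,
        VectorStoreQueryMode,
    )

    query = VectorStoreQuery(
        query_embedding=[0.1, 0.2, 0.3],  # toy dense embedding
        query_str="who founded the company?",
        similarity_top_k=5,  # dense-search top k
        sparse_top_k=20,  # sparse top k (currently postgres hybrid search only)
        hybrid_top_k=5,  # final top k returned by hybrid search
        mode=VectorStoreQueryMode.HYBRID,
        alpha=0.5,  # 0 => pure bm25/sparse, 1 => pure vector search
        filters=MetadataFilters(
            filters=[MetadataFilter(key="year", value=2020)]  # operator defaults to EQ
        ),
    )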
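
PropertyGraphStore leaves all storage decisions to subclasses; only the schema helpers and the async wrappers have defaults. A toy in-memory subclass, shown only to illustrate which methods a concrete backend must supply (the class name and dict-based storage are invented for this sketch):

    from typing import Dict, List

    from r2r.base.abstractions.llama_abstractions import (
        EntityNode,
        LabelledNode,
        PropertyGraphStore,
        Relation,
    )


    class InMemoryGraphStore(PropertyGraphStore):
        """Toy store: nodes and relations in memory, no real query engine."""

        def __init__(self) -> None:
            self._nodes: Dict[str, LabelledNode] = {}
            self._relations: List[Relation] = []

        def get(self, properties=None, ids=None):
            nodes = list(self._nodes.values())
            if ids is not None:
                nodes = [n for n in nodes if n.id in ids]
            if properties is not None:
                nodes = [
                    n
                    for n in nodes
                    if all(n.properties.get(k) == v for k, v in properties.items())
                ]
            return nodes

        def get_triplets(
            self, entity_names=None, relation_names=None, properties=None, ids=None
        ):
            # Toy implementation: ignores the filter arguments
            return [
                (self._nodes[r.source_id], r, self._nodes[r.target_id])
                for r in self._relations
                if r.source_id in self._nodes and r.target_id in self._nodes
            ]

        def get_rel_map(self, graph_nodes, depth=2, limit=30, ignore_rels=None):
            return self.get_triplets()[:limit]  # depth ignored in this toy

        def upsert_nodes(self, nodes):
            for node in nodes:
                self._nodes[node.id] = node

        def upsert_relations(self, relations):
            self._relations.extend(relations)

        def delete(
            self, entity_names=None, relation_names=None, properties=None, ids=None
        ):
            for node_id in ids or []:  # toy: only deletes by id
                self._nodes.pop(node_id, None)

        def structured_query(self, query, param_map=None):
            raise NotImplementedError("no query language in this toy store")

        def vector_query(self, query, **kwargs):
            raise NotImplementedError("no embeddings in this toy store")


    store = InMemoryGraphStore()
    store.upsert_nodes([EntityNode(name="Alice"), EntityNode(name="R2R")])
    store.upsert_relations(
        [Relation(label="BUILT", source_id="Alice", target_id="R2R")]
    )
    print(store.get_triplets())  # one (Alice, BUILT, R2R) triplet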
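
Finally, value_sanitize recurses through nested dicts and lists and drops any list with LIST_LIMIT (128) or more elements, which is how embedding vectors are kept out of LLM prompts. A small illustration of the intended behavior:

    from r2r.base.abstractions.llama_abstractions import value_sanitize

    record = {
        "title": "doc-1",
        "embedding": [0.01] * 1536,  # >= 128 elements: key is dropped
        "tags": ["graph", "rag"],  # short list: kept
        "meta": {"score": 0.9},  # nested dict: sanitized recursively
    }
    assert value_sanitize(record) == {
        "title": "doc-1",
        "tags": ["graph", "rag"],
        "meta": {"score": 0.9},
    }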