diff options
Diffstat (limited to 'R2R/r2r/base/providers')
-rwxr-xr-x | R2R/r2r/base/providers/__init__.py | 0 | ||||
-rwxr-xr-x | R2R/r2r/base/providers/base_provider.py | 48 | ||||
-rwxr-xr-x | R2R/r2r/base/providers/embedding_provider.py | 83 | ||||
-rwxr-xr-x | R2R/r2r/base/providers/eval_provider.py | 46 | ||||
-rwxr-xr-x | R2R/r2r/base/providers/kg_provider.py | 182 | ||||
-rwxr-xr-x | R2R/r2r/base/providers/llm_provider.py | 66 | ||||
-rwxr-xr-x | R2R/r2r/base/providers/prompt_provider.py | 65 | ||||
-rwxr-xr-x | R2R/r2r/base/providers/vector_db_provider.py | 142 |
8 files changed, 632 insertions, 0 deletions
diff --git a/R2R/r2r/base/providers/__init__.py b/R2R/r2r/base/providers/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/base/providers/__init__.py diff --git a/R2R/r2r/base/providers/base_provider.py b/R2R/r2r/base/providers/base_provider.py new file mode 100755 index 00000000..8ee8d56a --- /dev/null +++ b/R2R/r2r/base/providers/base_provider.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod, abstractproperty +from typing import Any, Optional, Type + +from pydantic import BaseModel + + +class ProviderConfig(BaseModel, ABC): + """A base provider configuration class""" + + extra_fields: dict[str, Any] = {} + provider: Optional[str] = None + + class Config: + arbitrary_types_allowed = True + ignore_extra = True + + @abstractmethod + def validate(self) -> None: + pass + + @classmethod + def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig": + base_args = cls.__fields__.keys() + filtered_kwargs = { + k: v if v != "None" else None + for k, v in kwargs.items() + if k in base_args + } + instance = cls(**filtered_kwargs) + for k, v in kwargs.items(): + if k not in base_args: + instance.extra_fields[k] = v + return instance + + @abstractproperty + @property + def supported_providers(self) -> list[str]: + """Define a list of supported providers.""" + pass + + +class Provider(ABC): + """A base provider class to provide a common interface for all providers.""" + + def __init__(self, config: Optional[ProviderConfig] = None): + if config: + config.validate() + self.config = config diff --git a/R2R/r2r/base/providers/embedding_provider.py b/R2R/r2r/base/providers/embedding_provider.py new file mode 100755 index 00000000..8f3af56f --- /dev/null +++ b/R2R/r2r/base/providers/embedding_provider.py @@ -0,0 +1,83 @@ +import logging +from abc import abstractmethod +from enum import Enum +from typing import Optional + +from ..abstractions.search import VectorSearchResult +from .base_provider import Provider, ProviderConfig + +logger = logging.getLogger(__name__) + + +class EmbeddingConfig(ProviderConfig): + """A base embedding configuration class""" + + provider: Optional[str] = None + base_model: Optional[str] = None + base_dimension: Optional[int] = None + rerank_model: Optional[str] = None + rerank_dimension: Optional[int] = None + rerank_transformer_type: Optional[str] = None + batch_size: int = 1 + + def validate(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return [None, "openai", "ollama", "sentence-transformers"] + + +class EmbeddingProvider(Provider): + """An abstract class to provide a common interface for embedding providers.""" + + class PipeStage(Enum): + BASE = 1 + RERANK = 2 + + def __init__(self, config: EmbeddingConfig): + if not isinstance(config, EmbeddingConfig): + raise ValueError( + "EmbeddingProvider must be initialized with a `EmbeddingConfig`." + ) + logger.info(f"Initializing EmbeddingProvider with config {config}.") + + super().__init__(config) + + @abstractmethod + def get_embedding(self, text: str, stage: PipeStage = PipeStage.BASE): + pass + + async def async_get_embedding( + self, text: str, stage: PipeStage = PipeStage.BASE + ): + return self.get_embedding(text, stage) + + @abstractmethod + def get_embeddings( + self, texts: list[str], stage: PipeStage = PipeStage.BASE + ): + pass + + async def async_get_embeddings( + self, texts: list[str], stage: PipeStage = PipeStage.BASE + ): + return self.get_embeddings(texts, stage) + + @abstractmethod + def rerank( + self, + query: str, + results: list[VectorSearchResult], + stage: PipeStage = PipeStage.RERANK, + limit: int = 10, + ): + pass + + @abstractmethod + def tokenize_string( + self, text: str, model: str, stage: PipeStage + ) -> list[int]: + """Tokenizes the input string.""" + pass diff --git a/R2R/r2r/base/providers/eval_provider.py b/R2R/r2r/base/providers/eval_provider.py new file mode 100755 index 00000000..76053f87 --- /dev/null +++ b/R2R/r2r/base/providers/eval_provider.py @@ -0,0 +1,46 @@ +from typing import Optional, Union + +from ..abstractions.llm import GenerationConfig +from .base_provider import Provider, ProviderConfig +from .llm_provider import LLMConfig + + +class EvalConfig(ProviderConfig): + """A base eval config class""" + + llm: Optional[LLMConfig] = None + + def validate(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider {self.provider} not supported.") + if self.provider and not self.llm: + raise ValueError( + "EvalConfig must have a `llm` attribute when specifying a provider." + ) + + @property + def supported_providers(self) -> list[str]: + return [None, "local"] + + +class EvalProvider(Provider): + """An abstract class to provide a common interface for evaluation providers.""" + + def __init__(self, config: EvalConfig): + if not isinstance(config, EvalConfig): + raise ValueError( + "EvalProvider must be initialized with a `EvalConfig`." + ) + + super().__init__(config) + + def evaluate( + self, + query: str, + context: str, + completion: str, + eval_generation_config: Optional[GenerationConfig] = None, + ) -> dict[str, dict[str, Union[str, float]]]: + return self._evaluate( + query, context, completion, eval_generation_config + ) diff --git a/R2R/r2r/base/providers/kg_provider.py b/R2R/r2r/base/providers/kg_provider.py new file mode 100755 index 00000000..4ae96b11 --- /dev/null +++ b/R2R/r2r/base/providers/kg_provider.py @@ -0,0 +1,182 @@ +"""Base classes for knowledge graph providers.""" + +import json +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional, Tuple + +from .prompt_provider import PromptProvider + +if TYPE_CHECKING: + from r2r.main import R2RClient + +from ...base.utils.base_utils import EntityType, Relation +from ..abstractions.llama_abstractions import EntityNode, LabelledNode +from ..abstractions.llama_abstractions import Relation as LlamaRelation +from ..abstractions.llama_abstractions import VectorStoreQuery +from ..abstractions.llm import GenerationConfig +from .base_provider import ProviderConfig + +logger = logging.getLogger(__name__) + + +class KGConfig(ProviderConfig): + """A base KG config class""" + + provider: Optional[str] = None + batch_size: int = 1 + kg_extraction_prompt: Optional[str] = "few_shot_ner_kg_extraction" + kg_agent_prompt: Optional[str] = "kg_agent" + kg_extraction_config: Optional[GenerationConfig] = None + + def validate(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return [None, "neo4j"] + + +class KGProvider(ABC): + """An abstract class to provide a common interface for Knowledge Graphs.""" + + def __init__(self, config: KGConfig) -> None: + if not isinstance(config, KGConfig): + raise ValueError( + "KGProvider must be initialized with a `KGConfig`." + ) + logger.info(f"Initializing KG provider with config: {config}") + self.config = config + self.validate_config() + + def validate_config(self) -> None: + self.config.validate() + + @property + @abstractmethod + def client(self) -> Any: + """Get client.""" + pass + + @abstractmethod + def get(self, subj: str) -> list[list[str]]: + """Abstract method to get triplets.""" + pass + + @abstractmethod + def get_rel_map( + self, + subjs: Optional[list[str]] = None, + depth: int = 2, + limit: int = 30, + ) -> dict[str, list[list[str]]]: + """Abstract method to get depth-aware rel map.""" + pass + + @abstractmethod + def upsert_nodes(self, nodes: list[EntityNode]) -> None: + """Abstract method to add triplet.""" + pass + + @abstractmethod + def upsert_relations(self, relations: list[LlamaRelation]) -> None: + """Abstract method to add triplet.""" + pass + + @abstractmethod + def delete(self, subj: str, rel: str, obj: str) -> None: + """Abstract method to delete triplet.""" + pass + + @abstractmethod + def get_schema(self, refresh: bool = False) -> str: + """Abstract method to get the schema of the graph store.""" + pass + + @abstractmethod + def structured_query( + self, query: str, param_map: Optional[dict[str, Any]] = {} + ) -> Any: + """Abstract method to query the graph store with statement and parameters.""" + pass + + @abstractmethod + def vector_query( + self, query: VectorStoreQuery, **kwargs: Any + ) -> Tuple[list[LabelledNode], list[float]]: + """Abstract method to query the graph store with a vector store query.""" + + # TODO - Type this method. + @abstractmethod + def update_extraction_prompt( + self, + prompt_provider: Any, + entity_types: list[Any], + relations: list[Relation], + ): + """Abstract method to update the KG extraction prompt.""" + pass + + # TODO - Type this method. + @abstractmethod + def update_kg_agent_prompt( + self, + prompt_provider: Any, + entity_types: list[Any], + relations: list[Relation], + ): + """Abstract method to update the KG agent prompt.""" + pass + + +def escape_braces(s: str) -> str: + """ + Escape braces in a string. + This is a placeholder function - implement the actual logic as needed. + """ + # Implement your escape_braces logic here + return s.replace("{", "{{").replace("}", "}}") + + +# TODO - Make this more configurable / intelligent +def update_kg_prompt( + client: "R2RClient", + r2r_prompts: PromptProvider, + prompt_base: str, + entity_types: list[EntityType], + relations: list[Relation], +) -> None: + # Get the default extraction template + template_name: str = f"{prompt_base}_with_spec" + + new_template: str = r2r_prompts.get_prompt( + template_name, + { + "entity_types": json.dumps( + { + "entity_types": [ + str(entity.name) for entity in entity_types + ] + }, + indent=4, + ), + "relations": json.dumps( + {"predicates": [str(relation.name) for relation in relations]}, + indent=4, + ), + "input": """\n{input}""", + }, + ) + + # Escape all braces in the template, except for the {input} placeholder, for formatting + escaped_template: str = escape_braces(new_template).replace( + """{{input}}""", """{input}""" + ) + + # Update the client's prompt + client.update_prompt( + prompt_base, + template=escaped_template, + input_types={"input": "str"}, + ) diff --git a/R2R/r2r/base/providers/llm_provider.py b/R2R/r2r/base/providers/llm_provider.py new file mode 100755 index 00000000..9b6499a4 --- /dev/null +++ b/R2R/r2r/base/providers/llm_provider.py @@ -0,0 +1,66 @@ +"""Base classes for language model providers.""" + +import logging +from abc import abstractmethod +from typing import Optional + +from r2r.base.abstractions.llm import GenerationConfig + +from ..abstractions.llm import LLMChatCompletion, LLMChatCompletionChunk +from .base_provider import Provider, ProviderConfig + +logger = logging.getLogger(__name__) + + +class LLMConfig(ProviderConfig): + """A base LLM config class""" + + provider: Optional[str] = None + generation_config: Optional[GenerationConfig] = None + + def validate(self) -> None: + if not self.provider: + raise ValueError("Provider must be set.") + + if self.provider and self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return ["litellm", "openai"] + + +class LLMProvider(Provider): + """An abstract class to provide a common interface for LLMs.""" + + def __init__( + self, + config: LLMConfig, + ) -> None: + if not isinstance(config, LLMConfig): + raise ValueError( + "LLMProvider must be initialized with a `LLMConfig`." + ) + logger.info(f"Initializing LLM provider with config: {config}") + + super().__init__(config) + + @abstractmethod + def get_completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> LLMChatCompletion: + """Abstract method to get a chat completion from the provider.""" + pass + + @abstractmethod + def get_completion_stream( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> LLMChatCompletionChunk: + """Abstract method to get a completion stream from the provider.""" + pass diff --git a/R2R/r2r/base/providers/prompt_provider.py b/R2R/r2r/base/providers/prompt_provider.py new file mode 100755 index 00000000..78af9e11 --- /dev/null +++ b/R2R/r2r/base/providers/prompt_provider.py @@ -0,0 +1,65 @@ +import logging +from abc import abstractmethod +from typing import Any, Optional + +from .base_provider import Provider, ProviderConfig + +logger = logging.getLogger(__name__) + + +class PromptConfig(ProviderConfig): + def validate(self) -> None: + pass + + @property + def supported_providers(self) -> list[str]: + # Return a list of supported prompt providers + return ["default_prompt_provider"] + + +class PromptProvider(Provider): + def __init__(self, config: Optional[PromptConfig] = None): + if config is None: + config = PromptConfig() + elif not isinstance(config, PromptConfig): + raise ValueError( + "PromptProvider must be initialized with a `PromptConfig`." + ) + logger.info(f"Initializing PromptProvider with config {config}.") + super().__init__(config) + + @abstractmethod + def add_prompt( + self, name: str, template: str, input_types: dict[str, str] + ) -> None: + pass + + @abstractmethod + def get_prompt( + self, prompt_name: str, inputs: Optional[dict[str, Any]] = None + ) -> str: + pass + + @abstractmethod + def get_all_prompts(self) -> dict[str, str]: + pass + + @abstractmethod + def update_prompt( + self, + name: str, + template: Optional[str] = None, + input_types: Optional[dict[str, str]] = None, + ) -> None: + pass + + def _get_message_payload( + self, system_prompt: str, task_prompt: str + ) -> dict: + return [ + { + "role": "system", + "content": system_prompt, + }, + {"role": "user", "content": task_prompt}, + ] diff --git a/R2R/r2r/base/providers/vector_db_provider.py b/R2R/r2r/base/providers/vector_db_provider.py new file mode 100755 index 00000000..a6d5aaa8 --- /dev/null +++ b/R2R/r2r/base/providers/vector_db_provider.py @@ -0,0 +1,142 @@ +import logging +from abc import ABC, abstractmethod +from typing import Optional, Union + +from ..abstractions.document import DocumentInfo +from ..abstractions.search import VectorSearchResult +from ..abstractions.vector import VectorEntry +from .base_provider import Provider, ProviderConfig + +logger = logging.getLogger(__name__) + + +class VectorDBConfig(ProviderConfig): + provider: str + + def __post_init__(self): + self.validate() + # Capture additional fields + for key, value in self.extra_fields.items(): + setattr(self, key, value) + + def validate(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return ["local", "pgvector"] + + +class VectorDBProvider(Provider, ABC): + def __init__(self, config: VectorDBConfig): + if not isinstance(config, VectorDBConfig): + raise ValueError( + "VectorDBProvider must be initialized with a `VectorDBConfig`." + ) + logger.info(f"Initializing VectorDBProvider with config {config}.") + super().__init__(config) + + @abstractmethod + def initialize_collection(self, dimension: int) -> None: + pass + + @abstractmethod + def copy(self, entry: VectorEntry, commit: bool = True) -> None: + pass + + @abstractmethod + def upsert(self, entry: VectorEntry, commit: bool = True) -> None: + pass + + @abstractmethod + def search( + self, + query_vector: list[float], + filters: dict[str, Union[bool, int, str]] = {}, + limit: int = 10, + *args, + **kwargs, + ) -> list[VectorSearchResult]: + pass + + @abstractmethod + def hybrid_search( + self, + query_text: str, + query_vector: list[float], + limit: int = 10, + filters: Optional[dict[str, Union[bool, int, str]]] = None, + # Hybrid search parameters + full_text_weight: float = 1.0, + semantic_weight: float = 1.0, + rrf_k: int = 20, # typical value is ~2x the number of results you want + *args, + **kwargs, + ) -> list[VectorSearchResult]: + pass + + @abstractmethod + def create_index(self, index_type, column_name, index_options): + pass + + def upsert_entries( + self, entries: list[VectorEntry], commit: bool = True + ) -> None: + for entry in entries: + self.upsert(entry, commit=commit) + + def copy_entries( + self, entries: list[VectorEntry], commit: bool = True + ) -> None: + for entry in entries: + self.copy(entry, commit=commit) + + @abstractmethod + def delete_by_metadata( + self, + metadata_fields: list[str], + metadata_values: list[Union[bool, int, str]], + ) -> list[str]: + if len(metadata_fields) != len(metadata_values): + raise ValueError( + "The number of metadata fields and values must be equal." + ) + pass + + @abstractmethod + def get_metadatas( + self, + metadata_fields: list[str], + filter_field: Optional[str] = None, + filter_value: Optional[str] = None, + ) -> list[str]: + pass + + @abstractmethod + def upsert_documents_overview( + self, document_infs: list[DocumentInfo] + ) -> None: + pass + + @abstractmethod + def get_documents_overview( + self, + filter_document_ids: Optional[list[str]] = None, + filter_user_ids: Optional[list[str]] = None, + ) -> list[DocumentInfo]: + pass + + @abstractmethod + def get_document_chunks(self, document_id: str) -> list[dict]: + pass + + @abstractmethod + def delete_from_documents_overview( + self, document_id: str, version: Optional[str] = None + ) -> dict: + pass + + @abstractmethod + def get_users_overview(self, user_ids: Optional[list[str]] = None) -> dict: + pass |