about summary refs log tree commit diff
path: root/R2R/r2r/base/providers
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/base/providers')
-rwxr-xr-x R2R/r2r/base/providers/__init__.py 0
-rwxr-xr-x R2R/r2r/base/providers/base_provider.py 48
-rwxr-xr-x R2R/r2r/base/providers/embedding_provider.py 83
-rwxr-xr-x R2R/r2r/base/providers/eval_provider.py 46
-rwxr-xr-x R2R/r2r/base/providers/kg_provider.py 182
-rwxr-xr-x R2R/r2r/base/providers/llm_provider.py 66
-rwxr-xr-x R2R/r2r/base/providers/prompt_provider.py 65
-rwxr-xr-x R2R/r2r/base/providers/vector_db_provider.py 142
8 files changed, 632 insertions, 0 deletions
diff --git a/R2R/r2r/base/providers/__init__.py b/R2R/r2r/base/providers/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/base/providers/__init__.py
diff --git a/R2R/r2r/base/providers/base_provider.py b/R2R/r2r/base/providers/base_provider.py
new file mode 100755
index 00000000..8ee8d56a
--- /dev/null
+++ b/R2R/r2r/base/providers/base_provider.py
@@ -0,0 +1,48 @@
+from abc import ABC, abstractmethod, abstractproperty
+from typing import Any, Optional, Type
+
+from pydantic import BaseModel
+
+
+class ProviderConfig(BaseModel, ABC):
+    """Base configuration shared by every provider.
+
+    Subclasses implement `validate` and the `supported_providers` property.
+    Keyword arguments given to `create` that do not match a declared model
+    field are collected in `extra_fields`.
+    """
+
+    extra_fields: dict[str, Any] = {}
+    provider: Optional[str] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+        # `extra = "ignore"` is the pydantic >= 1.0 spelling of the legacy
+        # `ignore_extra = True` option: unknown constructor arguments are
+        # dropped rather than rejected.
+        extra = "ignore"
+
+    @abstractmethod
+    def validate(self) -> None:
+        """Validate this configuration; the checks are subclass-defined."""
+
+    @classmethod
+    def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
+        """Build a config instance from arbitrary keyword arguments.
+
+        Arguments naming a declared field are passed to the constructor,
+        with the literal string "None" converted to `None`. All remaining
+        arguments are stored unchanged in `extra_fields`.
+
+        Returns:
+            The newly constructed instance of `cls`.
+        """
+        base_args = cls.__fields__.keys()
+        filtered_kwargs = {
+            k: v if v != "None" else None
+            for k, v in kwargs.items()
+            if k in base_args
+        }
+        instance = cls(**filtered_kwargs)
+        # Anything that is not a declared field is kept for later use.
+        for k, v in kwargs.items():
+            if k not in base_args:
+                instance.extra_fields[k] = v
+        return instance
+
+    # `abstractproperty` is deprecated and, stacked on `@property`, wrapped a
+    # property object instead of a function; this is the documented form.
+    @property
+    @abstractmethod
+    def supported_providers(self) -> list[str]:
+        """Define a list of supported providers."""
+
+
+class Provider(ABC):
+    """Common base class giving every provider the same construction contract."""
+
+    def __init__(self, config: Optional[ProviderConfig] = None):
+        """Validate `config` when one is supplied, then store it on `self.config`.
+
+        A falsy config (including `None`) is stored without validation.
+        """
+        should_validate = bool(config)
+        if should_validate:
+            config.validate()
+        self.config = config
diff --git a/R2R/r2r/base/providers/embedding_provider.py b/R2R/r2r/base/providers/embedding_provider.py
new file mode 100755
index 00000000..8f3af56f
--- /dev/null
+++ b/R2R/r2r/base/providers/embedding_provider.py
@@ -0,0 +1,83 @@
+import logging
+from abc import abstractmethod
+from enum import Enum
+from typing import Optional
+
+from ..abstractions.search import VectorSearchResult
+from .base_provider import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingConfig(ProviderConfig):
+    """Configuration for embedding providers (base and rerank model settings)."""
+
+    provider: Optional[str] = None
+    base_model: Optional[str] = None
+    base_dimension: Optional[int] = None
+    rerank_model: Optional[str] = None
+    rerank_dimension: Optional[int] = None
+    rerank_transformer_type: Optional[str] = None
+    batch_size: int = 1
+
+    def validate(self) -> None:
+        """Raise `ValueError` unless `provider` is one of `supported_providers`."""
+        allowed = self.supported_providers
+        if self.provider in allowed:
+            return
+        raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        """Accepted provider names; `None` (no provider) is also accepted."""
+        return [None, "openai", "ollama", "sentence-transformers"]
+
+
+class EmbeddingProvider(Provider):
+    """Abstract interface that every embedding provider implements."""
+
+    class PipeStage(Enum):
+        BASE = 1
+        RERANK = 2
+
+    def __init__(self, config: EmbeddingConfig):
+        """Check the config type, log it, and hand it to `Provider.__init__`.
+
+        Raises:
+            ValueError: if `config` is not an `EmbeddingConfig`.
+        """
+        config_ok = isinstance(config, EmbeddingConfig)
+        if not config_ok:
+            raise ValueError(
+                "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
+            )
+        logger.info(f"Initializing EmbeddingProvider with config {config}.")
+
+        super().__init__(config)
+
+    @abstractmethod
+    def get_embedding(self, text: str, stage: PipeStage = PipeStage.BASE):
+        """Embed a single text for the given pipeline stage."""
+
+    async def async_get_embedding(
+        self, text: str, stage: PipeStage = PipeStage.BASE
+    ):
+        """Coroutine wrapper that calls the synchronous `get_embedding`."""
+        embedding = self.get_embedding(text, stage)
+        return embedding
+
+    @abstractmethod
+    def get_embeddings(
+        self, texts: list[str], stage: PipeStage = PipeStage.BASE
+    ):
+        """Embed several texts for the given pipeline stage."""
+
+    async def async_get_embeddings(
+        self, texts: list[str], stage: PipeStage = PipeStage.BASE
+    ):
+        """Coroutine wrapper that calls the synchronous `get_embeddings`."""
+        embeddings = self.get_embeddings(texts, stage)
+        return embeddings
+
+    @abstractmethod
+    def rerank(
+        self,
+        query: str,
+        results: list[VectorSearchResult],
+        stage: PipeStage = PipeStage.RERANK,
+        limit: int = 10,
+    ):
+        """Rerank `results` against `query`; `limit` defaults to 10."""
+
+    @abstractmethod
+    def tokenize_string(
+        self, text: str, model: str, stage: PipeStage
+    ) -> list[int]:
+        """Tokenizes the input string."""
diff --git a/R2R/r2r/base/providers/eval_provider.py b/R2R/r2r/base/providers/eval_provider.py
new file mode 100755
index 00000000..76053f87
--- /dev/null
+++ b/R2R/r2r/base/providers/eval_provider.py
@@ -0,0 +1,46 @@
+from typing import Optional, Union
+
+from ..abstractions.llm import GenerationConfig
+from .base_provider import Provider, ProviderConfig
+from .llm_provider import LLMConfig
+
+
+class EvalConfig(ProviderConfig):
+    """Configuration for evaluation providers."""
+
+    llm: Optional[LLMConfig] = None
+
+    def validate(self) -> None:
+        """Check the provider name and that a set provider comes with an `llm`.
+
+        Raises:
+            ValueError: if the provider is unsupported, or a provider is
+                set while `llm` is missing.
+        """
+        name = self.provider
+        if name not in self.supported_providers:
+            raise ValueError(f"Provider {self.provider} not supported.")
+        llm_missing = not self.llm
+        if name and llm_missing:
+            raise ValueError(
+                "EvalConfig must have a `llm` attribute when specifying a provider."
+            )
+
+    @property
+    def supported_providers(self) -> list[str]:
+        """Accepted provider names; `None` (no provider) is also accepted."""
+        return [None, "local"]
+
+
+class EvalProvider(Provider):
+    """An abstract class to provide a common interface for evaluation providers."""
+
+    def __init__(self, config: EvalConfig):
+        """Check the config type and hand it to `Provider.__init__`.
+
+        Raises:
+            ValueError: if `config` is not an `EvalConfig`.
+        """
+        if not isinstance(config, EvalConfig):
+            raise ValueError(
+                "EvalProvider must be initialized with a `EvalConfig`."
+            )
+
+        super().__init__(config)
+
+    def evaluate(
+        self,
+        query: str,
+        context: str,
+        completion: str,
+        eval_generation_config: Optional[GenerationConfig] = None,
+    ) -> dict[str, dict[str, Union[str, float]]]:
+        """Evaluate a completion by delegating to the `_evaluate` hook."""
+        return self._evaluate(
+            query, context, completion, eval_generation_config
+        )
+
+    def _evaluate(
+        self,
+        query: str,
+        context: str,
+        completion: str,
+        eval_generation_config: Optional[GenerationConfig] = None,
+    ) -> dict[str, dict[str, Union[str, float]]]:
+        """Hook that concrete providers override with the real evaluation.
+
+        Declared here so the contract `evaluate` relies on is explicit; it
+        is deliberately not abstract so that subclasses which override
+        `evaluate` directly keep working.
+
+        Raises:
+            NotImplementedError: always, unless overridden.
+        """
+        raise NotImplementedError(
+            f"{type(self).__name__} must implement `_evaluate`."
+        )
diff --git a/R2R/r2r/base/providers/kg_provider.py b/R2R/r2r/base/providers/kg_provider.py
new file mode 100755
index 00000000..4ae96b11
--- /dev/null
+++ b/R2R/r2r/base/providers/kg_provider.py
@@ -0,0 +1,182 @@
+"""Base classes for knowledge graph providers."""
+
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Optional, Tuple
+
+from .prompt_provider import PromptProvider
+
+if TYPE_CHECKING:
+    from r2r.main import R2RClient
+
+from ...base.utils.base_utils import EntityType, Relation
+from ..abstractions.llama_abstractions import EntityNode, LabelledNode
+from ..abstractions.llama_abstractions import Relation as LlamaRelation
+from ..abstractions.llama_abstractions import VectorStoreQuery
+from ..abstractions.llm import GenerationConfig
+from .base_provider import ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class KGConfig(ProviderConfig):
+    """Configuration for knowledge graph providers."""
+
+    provider: Optional[str] = None
+    batch_size: int = 1
+    kg_extraction_prompt: Optional[str] = "few_shot_ner_kg_extraction"
+    kg_agent_prompt: Optional[str] = "kg_agent"
+    kg_extraction_config: Optional[GenerationConfig] = None
+
+    def validate(self) -> None:
+        """Raise `ValueError` unless `provider` is one of `supported_providers`."""
+        allowed = self.supported_providers
+        if self.provider in allowed:
+            return
+        raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        """Accepted provider names; `None` (no provider) is also accepted."""
+        return [None, "neo4j"]
+
+
+class KGProvider(ABC):
+    """An abstract class to provide a common interface for Knowledge Graphs."""
+
+    def __init__(self, config: KGConfig) -> None:
+        """Check the config type, log and store it, then validate it.
+
+        Raises:
+            ValueError: if `config` is not a `KGConfig`, or if
+                `config.validate()` rejects it.
+        """
+        if not isinstance(config, KGConfig):
+            raise ValueError(
+                "KGProvider must be initialized with a `KGConfig`."
+            )
+        logger.info(f"Initializing KG provider with config: {config}")
+        self.config = config
+        self.validate_config()
+
+    def validate_config(self) -> None:
+        """Run the stored config's own `validate` method."""
+        self.config.validate()
+
+    @property
+    @abstractmethod
+    def client(self) -> Any:
+        """Get client."""
+        pass
+
+    @abstractmethod
+    def get(self, subj: str) -> list[list[str]]:
+        """Abstract method to get triplets."""
+        pass
+
+    @abstractmethod
+    def get_rel_map(
+        self,
+        subjs: Optional[list[str]] = None,
+        depth: int = 2,
+        limit: int = 30,
+    ) -> dict[str, list[list[str]]]:
+        """Abstract method to get depth-aware rel map."""
+        pass
+
+    @abstractmethod
+    def upsert_nodes(self, nodes: list[EntityNode]) -> None:
+        """Abstract method to upsert entity nodes."""
+        pass
+
+    @abstractmethod
+    def upsert_relations(self, relations: list[LlamaRelation]) -> None:
+        """Abstract method to upsert relations."""
+        pass
+
+    @abstractmethod
+    def delete(self, subj: str, rel: str, obj: str) -> None:
+        """Abstract method to delete triplet."""
+        pass
+
+    @abstractmethod
+    def get_schema(self, refresh: bool = False) -> str:
+        """Abstract method to get the schema of the graph store."""
+        pass
+
+    # `None` replaces the former `{}` default so that implementations copying
+    # this signature do not share one mutable dict between calls.
+    @abstractmethod
+    def structured_query(
+        self, query: str, param_map: Optional[dict[str, Any]] = None
+    ) -> Any:
+        """Abstract method to query the graph store with statement and parameters."""
+        pass
+
+    @abstractmethod
+    def vector_query(
+        self, query: VectorStoreQuery, **kwargs: Any
+    ) -> Tuple[list[LabelledNode], list[float]]:
+        """Abstract method to query the graph store with a vector store query."""
+
+    # TODO - Type this method.
+    @abstractmethod
+    def update_extraction_prompt(
+        self,
+        prompt_provider: Any,
+        entity_types: list[Any],
+        relations: list[Relation],
+    ):
+        """Abstract method to update the KG extraction prompt."""
+        pass
+
+    # TODO - Type this method.
+    @abstractmethod
+    def update_kg_agent_prompt(
+        self,
+        prompt_provider: Any,
+        entity_types: list[Any],
+        relations: list[Relation],
+    ):
+        """Abstract method to update the KG agent prompt."""
+        pass
+
+
+def escape_braces(s: str) -> str:
+    """Return `s` with every `{` doubled to `{{` and every `}` doubled to `}}`."""
+    doubled = str.maketrans({"{": "{{", "}": "}}"})
+    return s.translate(doubled)
+
+
+# TODO - Make this more configurable / intelligent
+def update_kg_prompt(
+    client: "R2RClient",
+    r2r_prompts: PromptProvider,
+    prompt_base: str,
+    entity_types: list[EntityType],
+    relations: list[Relation],
+) -> None:
+    """Render the `<prompt_base>_with_spec` template and store it as `prompt_base`.
+
+    The entity type names and relation names are serialised to JSON and
+    passed as template inputs, while the `input` slot is passed as a literal
+    `{input}` placeholder. All braces in the rendered text are then doubled
+    except that placeholder, and the result is sent to
+    `client.update_prompt` with a single `input` of type "str".
+    """
+    entity_spec = json.dumps(
+        {"entity_types": [str(entity.name) for entity in entity_types]},
+        indent=4,
+    )
+    relation_spec = json.dumps(
+        {"predicates": [str(relation.name) for relation in relations]},
+        indent=4,
+    )
+
+    rendered: str = r2r_prompts.get_prompt(
+        f"{prompt_base}_with_spec",
+        {
+            "entity_types": entity_spec,
+            "relations": relation_spec,
+            "input": "\n{input}",
+        },
+    )
+
+    # Double every brace, then restore the one placeholder that must survive.
+    final_template: str = escape_braces(rendered).replace(
+        "{{input}}", "{input}"
+    )
+
+    client.update_prompt(
+        prompt_base,
+        template=final_template,
+        input_types={"input": "str"},
+    )
diff --git a/R2R/r2r/base/providers/llm_provider.py b/R2R/r2r/base/providers/llm_provider.py
new file mode 100755
index 00000000..9b6499a4
--- /dev/null
+++ b/R2R/r2r/base/providers/llm_provider.py
@@ -0,0 +1,66 @@
+"""Base classes for language model providers."""
+
+import logging
+from abc import abstractmethod
+from typing import Optional
+
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.llm import LLMChatCompletion, LLMChatCompletionChunk
+from .base_provider import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class LLMConfig(ProviderConfig):
+    """Configuration for language model providers."""
+
+    provider: Optional[str] = None
+    generation_config: Optional[GenerationConfig] = None
+
+    def validate(self) -> None:
+        """Require a provider that is listed in `supported_providers`.
+
+        Raises:
+            ValueError: if no provider is set, or it is not supported.
+        """
+        chosen = self.provider
+        if not chosen:
+            raise ValueError("Provider must be set.")
+
+        # `chosen` is known to be truthy here, so only membership is checked.
+        if chosen not in self.supported_providers:
+            raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        """Accepted provider names."""
+        return ["litellm", "openai"]
+
+
+class LLMProvider(Provider):
+    """Abstract interface that every LLM provider implements."""
+
+    def __init__(
+        self,
+        config: LLMConfig,
+    ) -> None:
+        """Check the config type, log it, and hand it to `Provider.__init__`.
+
+        Raises:
+            ValueError: if `config` is not an `LLMConfig`.
+        """
+        config_ok = isinstance(config, LLMConfig)
+        if not config_ok:
+            raise ValueError(
+                "LLMProvider must be initialized with a `LLMConfig`."
+            )
+        logger.info(f"Initializing LLM provider with config: {config}")
+
+        super().__init__(config)
+
+    @abstractmethod
+    def get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        """Abstract method to get a chat completion from the provider."""
+
+    @abstractmethod
+    def get_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletionChunk:
+        """Abstract method to get a completion stream from the provider."""
diff --git a/R2R/r2r/base/providers/prompt_provider.py b/R2R/r2r/base/providers/prompt_provider.py
new file mode 100755
index 00000000..78af9e11
--- /dev/null
+++ b/R2R/r2r/base/providers/prompt_provider.py
@@ -0,0 +1,65 @@
+import logging
+from abc import abstractmethod
+from typing import Any, Optional
+
+from .base_provider import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class PromptConfig(ProviderConfig):
+    """Configuration for prompt providers; performs no validation."""
+
+    def validate(self) -> None:
+        """Accept any configuration (intentionally a no-op)."""
+        return None
+
+    @property
+    def supported_providers(self) -> list[str]:
+        """Names of the supported prompt providers."""
+        return ["default_prompt_provider"]
+
+
+class PromptProvider(Provider):
+    """Abstract interface for storing, retrieving and updating prompts."""
+
+    def __init__(self, config: Optional[PromptConfig] = None):
+        """Use a default `PromptConfig` when none is given, then initialise.
+
+        Raises:
+            ValueError: if `config` is given but is not a `PromptConfig`.
+        """
+        if config is None:
+            config = PromptConfig()
+        elif not isinstance(config, PromptConfig):
+            raise ValueError(
+                "PromptProvider must be initialized with a `PromptConfig`."
+            )
+        logger.info(f"Initializing PromptProvider with config {config}.")
+        super().__init__(config)
+
+    @abstractmethod
+    def add_prompt(
+        self, name: str, template: str, input_types: dict[str, str]
+    ) -> None:
+        """Register a new prompt under `name`."""
+        pass
+
+    @abstractmethod
+    def get_prompt(
+        self, prompt_name: str, inputs: Optional[dict[str, Any]] = None
+    ) -> str:
+        """Return the prompt called `prompt_name`, given optional `inputs`."""
+        pass
+
+    @abstractmethod
+    def get_all_prompts(self) -> dict[str, str]:
+        """Return all known prompts as a string-to-string mapping."""
+        pass
+
+    @abstractmethod
+    def update_prompt(
+        self,
+        name: str,
+        template: Optional[str] = None,
+        input_types: Optional[dict[str, str]] = None,
+    ) -> None:
+        """Update the template and/or input types of the prompt `name`."""
+        pass
+
+    def _get_message_payload(
+        self, system_prompt: str, task_prompt: str
+    ) -> list[dict]:
+        """Build a two-message chat payload: a system message, then a user one.
+
+        The return annotation was `dict`, but the value has always been a
+        list of role/content dictionaries.
+        """
+        return [
+            {
+                "role": "system",
+                "content": system_prompt,
+            },
+            {"role": "user", "content": task_prompt},
+        ]
diff --git a/R2R/r2r/base/providers/vector_db_provider.py b/R2R/r2r/base/providers/vector_db_provider.py
new file mode 100755
index 00000000..a6d5aaa8
--- /dev/null
+++ b/R2R/r2r/base/providers/vector_db_provider.py
@@ -0,0 +1,142 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Optional, Union
+
+from ..abstractions.document import DocumentInfo
+from ..abstractions.search import VectorSearchResult
+from ..abstractions.vector import VectorEntry
+from .base_provider import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class VectorDBConfig(ProviderConfig):
+    """Configuration for vector database providers; `provider` is required."""
+
+    provider: str
+
+    # NOTE(review): `__post_init__` is the dataclass post-construction hook.
+    # Pydantic `BaseModel` subclasses do not appear to invoke it, so this
+    # method is presumably never called automatically - confirm against the
+    # pydantic version in use before relying on it.
+    def __post_init__(self):
+        self.validate()
+        # Capture additional fields
+        for key, value in self.extra_fields.items():
+            setattr(self, key, value)
+
+    def validate(self) -> None:
+        """Raise `ValueError` unless `provider` is one of `supported_providers`."""
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        """Accepted provider names."""
+        return ["local", "pgvector"]
+
+
+class VectorDBProvider(Provider, ABC):
+    """Abstract interface that every vector database provider implements."""
+
+    def __init__(self, config: VectorDBConfig):
+        """Check the config type, log it, and hand it to `Provider.__init__`.
+
+        Raises:
+            ValueError: if `config` is not a `VectorDBConfig`.
+        """
+        if not isinstance(config, VectorDBConfig):
+            raise ValueError(
+                "VectorDBProvider must be initialized with a `VectorDBConfig`."
+            )
+        logger.info(f"Initializing VectorDBProvider with config {config}.")
+        super().__init__(config)
+
+    @abstractmethod
+    def initialize_collection(self, dimension: int) -> None:
+        """Prepare the collection for vectors of the given dimension."""
+        pass
+
+    @abstractmethod
+    def copy(self, entry: VectorEntry, commit: bool = True) -> None:
+        """Copy a single entry into the store."""
+        pass
+
+    @abstractmethod
+    def upsert(self, entry: VectorEntry, commit: bool = True) -> None:
+        """Upsert a single entry into the store."""
+        pass
+
+    # `filters` defaults to `None` rather than a shared mutable `{}`, which
+    # matches `hybrid_search`; implementations should treat `None` as
+    # "no filters".
+    @abstractmethod
+    def search(
+        self,
+        query_vector: list[float],
+        filters: Optional[dict[str, Union[bool, int, str]]] = None,
+        limit: int = 10,
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        """Search by vector and return up to `limit` results."""
+        pass
+
+    @abstractmethod
+    def hybrid_search(
+        self,
+        query_text: str,
+        query_vector: list[float],
+        limit: int = 10,
+        filters: Optional[dict[str, Union[bool, int, str]]] = None,
+        # Hybrid search parameters
+        full_text_weight: float = 1.0,
+        semantic_weight: float = 1.0,
+        rrf_k: int = 20,  # typical value is ~2x the number of results you want
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        """Search using both the query text and the query vector."""
+        pass
+
+    @abstractmethod
+    def create_index(self, index_type, column_name, index_options):
+        """Create an index of `index_type` on `column_name`."""
+        pass
+
+    def upsert_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        """Call `upsert` once per entry, forwarding `commit`."""
+        for entry in entries:
+            self.upsert(entry, commit=commit)
+
+    def copy_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        """Call `copy` once per entry, forwarding `commit`."""
+        for entry in entries:
+            self.copy(entry, commit=commit)
+
+    @abstractmethod
+    def delete_by_metadata(
+        self,
+        metadata_fields: list[str],
+        metadata_values: list[Union[bool, int, str]],
+    ) -> list[str]:
+        """Delete entries matching the given metadata field/value pairs.
+
+        The base body only checks that both lists have the same length.
+
+        Raises:
+            ValueError: if the two lists differ in length.
+        """
+        if len(metadata_fields) != len(metadata_values):
+            raise ValueError(
+                "The number of metadata fields and values must be equal."
+            )
+        pass
+
+    @abstractmethod
+    def get_metadatas(
+        self,
+        metadata_fields: list[str],
+        filter_field: Optional[str] = None,
+        filter_value: Optional[str] = None,
+    ) -> list[str]:
+        """Return metadata for the requested fields, optionally filtered."""
+        pass
+
+    @abstractmethod
+    def upsert_documents_overview(
+        self, document_infs: list[DocumentInfo]
+    ) -> None:
+        """Upsert the given document info records."""
+        pass
+
+    @abstractmethod
+    def get_documents_overview(
+        self,
+        filter_document_ids: Optional[list[str]] = None,
+        filter_user_ids: Optional[list[str]] = None,
+    ) -> list[DocumentInfo]:
+        """Return document info, optionally filtered by document or user ids."""
+        pass
+
+    @abstractmethod
+    def get_document_chunks(self, document_id: str) -> list[dict]:
+        """Return the chunks of the given document."""
+        pass
+
+    @abstractmethod
+    def delete_from_documents_overview(
+        self, document_id: str, version: Optional[str] = None
+    ) -> dict:
+        """Remove a document, optionally one version, from the overview."""
+        pass
+
+    @abstractmethod
+    def get_users_overview(self, user_ids: Optional[list[str]] = None) -> dict:
+        """Return a users overview, optionally limited to `user_ids`."""
+        pass