about summary refs log tree commit diff
path: root/R2R/r2r/base/providers/kg_provider.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/base/providers/kg_provider.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to 'R2R/r2r/base/providers/kg_provider.py')
-rwxr-xr-xR2R/r2r/base/providers/kg_provider.py182
1 files changed, 182 insertions, 0 deletions
diff --git a/R2R/r2r/base/providers/kg_provider.py b/R2R/r2r/base/providers/kg_provider.py
new file mode 100755
index 00000000..4ae96b11
--- /dev/null
+++ b/R2R/r2r/base/providers/kg_provider.py
@@ -0,0 +1,182 @@
+"""Base classes for knowledge graph providers."""
+
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Optional, Tuple
+
+from .prompt_provider import PromptProvider
+
+if TYPE_CHECKING:
+    from r2r.main import R2RClient
+
+from ...base.utils.base_utils import EntityType, Relation
+from ..abstractions.llama_abstractions import EntityNode, LabelledNode
+from ..abstractions.llama_abstractions import Relation as LlamaRelation
+from ..abstractions.llama_abstractions import VectorStoreQuery
+from ..abstractions.llm import GenerationConfig
+from .base_provider import ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class KGConfig(ProviderConfig):
+    """A base KG config class"""
+
+    provider: Optional[str] = None
+    batch_size: int = 1
+    kg_extraction_prompt: Optional[str] = "few_shot_ner_kg_extraction"
+    kg_agent_prompt: Optional[str] = "kg_agent"
+    kg_extraction_config: Optional[GenerationConfig] = None
+
+    def validate(self) -> None:
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return [None, "neo4j"]
+
+
+class KGProvider(ABC):
+    """An abstract class to provide a common interface for Knowledge Graphs."""
+
+    def __init__(self, config: KGConfig) -> None:
+        if not isinstance(config, KGConfig):
+            raise ValueError(
+                "KGProvider must be initialized with a `KGConfig`."
+            )
+        logger.info(f"Initializing KG provider with config: {config}")
+        self.config = config
+        self.validate_config()
+
+    def validate_config(self) -> None:
+        self.config.validate()
+
+    @property
+    @abstractmethod
+    def client(self) -> Any:
+        """Get client."""
+        pass
+
+    @abstractmethod
+    def get(self, subj: str) -> list[list[str]]:
+        """Abstract method to get triplets."""
+        pass
+
+    @abstractmethod
+    def get_rel_map(
+        self,
+        subjs: Optional[list[str]] = None,
+        depth: int = 2,
+        limit: int = 30,
+    ) -> dict[str, list[list[str]]]:
+        """Abstract method to get depth-aware rel map."""
+        pass
+
+    @abstractmethod
+    def upsert_nodes(self, nodes: list[EntityNode]) -> None:
+        """Abstract method to add triplet."""
+        pass
+
+    @abstractmethod
+    def upsert_relations(self, relations: list[LlamaRelation]) -> None:
+        """Abstract method to add triplet."""
+        pass
+
+    @abstractmethod
+    def delete(self, subj: str, rel: str, obj: str) -> None:
+        """Abstract method to delete triplet."""
+        pass
+
+    @abstractmethod
+    def get_schema(self, refresh: bool = False) -> str:
+        """Abstract method to get the schema of the graph store."""
+        pass
+
+    @abstractmethod
+    def structured_query(
+        self, query: str, param_map: Optional[dict[str, Any]] = {}
+    ) -> Any:
+        """Abstract method to query the graph store with statement and parameters."""
+        pass
+
+    @abstractmethod
+    def vector_query(
+        self, query: VectorStoreQuery, **kwargs: Any
+    ) -> Tuple[list[LabelledNode], list[float]]:
+        """Abstract method to query the graph store with a vector store query."""
+
+    # TODO - Type this method.
+    @abstractmethod
+    def update_extraction_prompt(
+        self,
+        prompt_provider: Any,
+        entity_types: list[Any],
+        relations: list[Relation],
+    ):
+        """Abstract method to update the KG extraction prompt."""
+        pass
+
+    # TODO - Type this method.
+    @abstractmethod
+    def update_kg_agent_prompt(
+        self,
+        prompt_provider: Any,
+        entity_types: list[Any],
+        relations: list[Relation],
+    ):
+        """Abstract method to update the KG agent prompt."""
+        pass
+
+
+def escape_braces(s: str) -> str:
+    """
+    Escape braces in a string.
+    This is a placeholder function - implement the actual logic as needed.
+    """
+    # Implement your escape_braces logic here
+    return s.replace("{", "{{").replace("}", "}}")
+
+
+# TODO - Make this more configurable / intelligent
+def update_kg_prompt(
+    client: "R2RClient",
+    r2r_prompts: PromptProvider,
+    prompt_base: str,
+    entity_types: list[EntityType],
+    relations: list[Relation],
+) -> None:
+    # Get the default extraction template
+    template_name: str = f"{prompt_base}_with_spec"
+
+    new_template: str = r2r_prompts.get_prompt(
+        template_name,
+        {
+            "entity_types": json.dumps(
+                {
+                    "entity_types": [
+                        str(entity.name) for entity in entity_types
+                    ]
+                },
+                indent=4,
+            ),
+            "relations": json.dumps(
+                {"predicates": [str(relation.name) for relation in relations]},
+                indent=4,
+            ),
+            "input": """\n{input}""",
+        },
+    )
+
+    # Escape all braces in the template, except for the {input} placeholder, for formatting
+    escaped_template: str = escape_braces(new_template).replace(
+        """{{input}}""", """{input}"""
+    )
+
+    # Update the client's prompt
+    client.update_prompt(
+        prompt_base,
+        template=escaped_template,
+        input_types={"input": "str"},
+    )