diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/base/providers/kg_provider.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to 'R2R/r2r/base/providers/kg_provider.py')
-rwxr-xr-x | R2R/r2r/base/providers/kg_provider.py | 182 |
1 files changed, 182 insertions, 0 deletions
diff --git a/R2R/r2r/base/providers/kg_provider.py b/R2R/r2r/base/providers/kg_provider.py new file mode 100755 index 00000000..4ae96b11 --- /dev/null +++ b/R2R/r2r/base/providers/kg_provider.py @@ -0,0 +1,182 @@ +"""Base classes for knowledge graph providers.""" + +import json +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional, Tuple + +from .prompt_provider import PromptProvider + +if TYPE_CHECKING: + from r2r.main import R2RClient + +from ...base.utils.base_utils import EntityType, Relation +from ..abstractions.llama_abstractions import EntityNode, LabelledNode +from ..abstractions.llama_abstractions import Relation as LlamaRelation +from ..abstractions.llama_abstractions import VectorStoreQuery +from ..abstractions.llm import GenerationConfig +from .base_provider import ProviderConfig + +logger = logging.getLogger(__name__) + + +class KGConfig(ProviderConfig): + """A base KG config class""" + + provider: Optional[str] = None + batch_size: int = 1 + kg_extraction_prompt: Optional[str] = "few_shot_ner_kg_extraction" + kg_agent_prompt: Optional[str] = "kg_agent" + kg_extraction_config: Optional[GenerationConfig] = None + + def validate(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return [None, "neo4j"] + + +class KGProvider(ABC): + """An abstract class to provide a common interface for Knowledge Graphs.""" + + def __init__(self, config: KGConfig) -> None: + if not isinstance(config, KGConfig): + raise ValueError( + "KGProvider must be initialized with a `KGConfig`." + ) + logger.info(f"Initializing KG provider with config: {config}") + self.config = config + self.validate_config() + + def validate_config(self) -> None: + self.config.validate() + + @property + @abstractmethod + def client(self) -> Any: + """Get client.""" + pass + + @abstractmethod + def get(self, subj: str) -> list[list[str]]: + """Abstract method to get triplets.""" + pass + + @abstractmethod + def get_rel_map( + self, + subjs: Optional[list[str]] = None, + depth: int = 2, + limit: int = 30, + ) -> dict[str, list[list[str]]]: + """Abstract method to get depth-aware rel map.""" + pass + + @abstractmethod + def upsert_nodes(self, nodes: list[EntityNode]) -> None: + """Abstract method to add triplet.""" + pass + + @abstractmethod + def upsert_relations(self, relations: list[LlamaRelation]) -> None: + """Abstract method to add triplet.""" + pass + + @abstractmethod + def delete(self, subj: str, rel: str, obj: str) -> None: + """Abstract method to delete triplet.""" + pass + + @abstractmethod + def get_schema(self, refresh: bool = False) -> str: + """Abstract method to get the schema of the graph store.""" + pass + + @abstractmethod + def structured_query( + self, query: str, param_map: Optional[dict[str, Any]] = {} + ) -> Any: + """Abstract method to query the graph store with statement and parameters.""" + pass + + @abstractmethod + def vector_query( + self, query: VectorStoreQuery, **kwargs: Any + ) -> Tuple[list[LabelledNode], list[float]]: + """Abstract method to query the graph store with a vector store query.""" + + # TODO - Type this method. + @abstractmethod + def update_extraction_prompt( + self, + prompt_provider: Any, + entity_types: list[Any], + relations: list[Relation], + ): + """Abstract method to update the KG extraction prompt.""" + pass + + # TODO - Type this method. + @abstractmethod + def update_kg_agent_prompt( + self, + prompt_provider: Any, + entity_types: list[Any], + relations: list[Relation], + ): + """Abstract method to update the KG agent prompt.""" + pass + + +def escape_braces(s: str) -> str: + """ + Escape braces in a string. + This is a placeholder function - implement the actual logic as needed. + """ + # Implement your escape_braces logic here + return s.replace("{", "{{").replace("}", "}}") + + +# TODO - Make this more configurable / intelligent +def update_kg_prompt( + client: "R2RClient", + r2r_prompts: PromptProvider, + prompt_base: str, + entity_types: list[EntityType], + relations: list[Relation], +) -> None: + # Get the default extraction template + template_name: str = f"{prompt_base}_with_spec" + + new_template: str = r2r_prompts.get_prompt( + template_name, + { + "entity_types": json.dumps( + { + "entity_types": [ + str(entity.name) for entity in entity_types + ] + }, + indent=4, + ), + "relations": json.dumps( + {"predicates": [str(relation.name) for relation in relations]}, + indent=4, + ), + "input": """\n{input}""", + }, + ) + + # Escape all braces in the template, except for the {input} placeholder, for formatting + escaped_template: str = escape_braces(new_template).replace( + """{{input}}""", """{input}""" + ) + + # Update the client's prompt + client.update_prompt( + prompt_base, + template=escaped_template, + input_types={"input": "str"}, + ) |