# (Stray repository-browser navigation text removed from this line.)
"""Base classes for knowledge graph providers."""

import json
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, Tuple

from .prompt_provider import PromptProvider

if TYPE_CHECKING:
    from r2r.main import R2RClient

from ...base.utils.base_utils import EntityType, Relation
from ..abstractions.llama_abstractions import EntityNode, LabelledNode
from ..abstractions.llama_abstractions import Relation as LlamaRelation
from ..abstractions.llama_abstractions import VectorStoreQuery
from ..abstractions.llm import GenerationConfig
from .base_provider import ProviderConfig

logger = logging.getLogger(__name__)


class KGConfig(ProviderConfig):
    """Base configuration for knowledge graph (KG) providers.

    The `provider` field is checked against `supported_providers` by
    `validate`; a value of `None` is accepted.
    """

    # Name of the KG backend; `None` is an accepted value (see
    # `supported_providers`).
    provider: Optional[str] = None
    # NOTE(review): presumably the number of items processed per batch by the
    # consumer of this config - confirm against callers.
    batch_size: int = 1
    # Names of the prompts used for KG extraction and for the KG agent.
    kg_extraction_prompt: Optional[str] = "few_shot_ner_kg_extraction"
    kg_agent_prompt: Optional[str] = "kg_agent"
    # Optional generation settings for KG extraction.
    kg_extraction_config: Optional[GenerationConfig] = None

    def validate(self) -> None:
        """Check that `provider` is one of the supported providers.

        Raises:
            ValueError: If `self.provider` is not in `supported_providers`.
        """
        if self.provider in self.supported_providers:
            return
        raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[Optional[str]]:
        """Return the accepted `provider` values.

        The list includes `None`, so the return type is a list of optional
        strings rather than plain strings.
        """
        return [None, "neo4j"]


class KGProvider(ABC):
    """An abstract class to provide a common interface for Knowledge Graphs.

    Concrete providers must implement every abstract member below. The
    constructor stores the given `KGConfig` on `self.config` and validates it.
    """

    def __init__(self, config: KGConfig) -> None:
        """Store and validate the provider configuration.

        Args:
            config: The KG configuration; must be a `KGConfig` instance.

        Raises:
            ValueError: If `config` is not a `KGConfig`, or if
                `config.validate()` rejects the configuration.
        """
        if not isinstance(config, KGConfig):
            raise ValueError(
                "KGProvider must be initialized with a `KGConfig`."
            )
        # Lazy %-formatting: the config is only rendered if INFO is enabled.
        logger.info("Initializing KG provider with config: %s", config)
        self.config = config
        self.validate_config()

    def validate_config(self) -> None:
        """Delegate validation to the stored config's `validate` method."""
        self.config.validate()

    @property
    @abstractmethod
    def client(self) -> Any:
        """Get client."""
        pass

    @abstractmethod
    def get(self, subj: str) -> list[list[str]]:
        """Abstract method to get triplets for the subject `subj`."""
        pass

    @abstractmethod
    def get_rel_map(
        self,
        subjs: Optional[list[str]] = None,
        depth: int = 2,
        limit: int = 30,
    ) -> dict[str, list[list[str]]]:
        """Abstract method to get depth-aware rel map.

        Args:
            subjs: Optional list of subjects to build the map for.
            depth: Depth parameter for the map (default 2).
            limit: Limit parameter for the map (default 30).
        """
        pass

    @abstractmethod
    def upsert_nodes(self, nodes: list[EntityNode]) -> None:
        """Abstract method to upsert entity nodes into the graph store."""
        pass

    @abstractmethod
    def upsert_relations(self, relations: list[LlamaRelation]) -> None:
        """Abstract method to upsert relations into the graph store."""
        pass

    @abstractmethod
    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Abstract method to delete triplet."""
        pass

    @abstractmethod
    def get_schema(self, refresh: bool = False) -> str:
        """Abstract method to get the schema of the graph store."""
        pass

    @abstractmethod
    def structured_query(
        self, query: str, param_map: Optional[dict[str, Any]] = None
    ) -> Any:
        """Abstract method to query the graph store with statement and parameters.

        Args:
            query: The query statement.
            param_map: Optional parameters for the statement. The default is
                `None` rather than a shared mutable `{}`; implementations
                should treat `None` as "no parameters".
        """
        pass

    @abstractmethod
    def vector_query(
        self, query: VectorStoreQuery, **kwargs: Any
    ) -> Tuple[list[LabelledNode], list[float]]:
        """Abstract method to query the graph store with a vector store query."""
        pass

    # TODO - Type this method.
    @abstractmethod
    def update_extraction_prompt(
        self,
        prompt_provider: Any,
        entity_types: list[Any],
        relations: list[Relation],
    ):
        """Abstract method to update the KG extraction prompt."""
        pass

    # TODO - Type this method.
    @abstractmethod
    def update_kg_agent_prompt(
        self,
        prompt_provider: Any,
        entity_types: list[Any],
        relations: list[Relation],
    ):
        """Abstract method to update the KG agent prompt."""
        pass


def escape_braces(s: str) -> str:
    """Return `s` with every `{` doubled to `{{` and every `}` doubled to `}}`."""
    # One pass over the string; each brace character maps to its doubled form.
    doubled = str.maketrans({"{": "{{", "}": "}}"})
    return s.translate(doubled)


# TODO - Make this more configurable / intelligent
def update_kg_prompt(
    client: "R2RClient",
    r2r_prompts: PromptProvider,
    prompt_base: str,
    entity_types: list[EntityType],
    relations: list[Relation],
) -> None:
    """Rebuild the `prompt_base` prompt from its `_with_spec` variant.

    The `<prompt_base>_with_spec` prompt is requested from `r2r_prompts` with
    JSON listings of the entity type names and relation names, plus a literal
    `{input}` placeholder. The result is brace-escaped (keeping `{input}`
    usable as a format field) and pushed to `client` under `prompt_base`.

    Args:
        client: Client whose `update_prompt` method receives the new template.
        r2r_prompts: Prompt provider queried via `get_prompt`.
        prompt_base: Name of the prompt to update; `_with_spec` is appended
            to obtain the source template name.
        entity_types: Entity types; only each item's `name` is used.
        relations: Relations; only each item's `name` is used.
    """
    # JSON spec blocks handed to the template as inputs.
    entity_spec = json.dumps(
        {"entity_types": [str(etype.name) for etype in entity_types]},
        indent=4,
    )
    relation_spec = json.dumps(
        {"predicates": [str(rel.name) for rel in relations]},
        indent=4,
    )
    prompt_inputs = {
        "entity_types": entity_spec,
        "relations": relation_spec,
        # Re-insert the placeholder so it survives into the final template.
        "input": "\n{input}",
    }

    spec_template = r2r_prompts.get_prompt(
        f"{prompt_base}_with_spec", prompt_inputs
    )

    # Double every brace, then restore the single-braced {input} placeholder
    # so it remains the only format field.
    fully_escaped = escape_braces(spec_template)
    final_template = fully_escaped.replace("{{input}}", "{input}")

    client.update_prompt(
        prompt_base,
        template=final_template,
        input_types={"input": "str"},
    )