aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/base/providers/kg_provider.py
blob: 4ae96b1162d031afbff85e62246b4d8060d7b4fa (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""Base classes for knowledge graph providers."""

import json
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, Tuple

from .prompt_provider import PromptProvider

if TYPE_CHECKING:
    from r2r.main import R2RClient

from ...base.utils.base_utils import EntityType, Relation
from ..abstractions.llama_abstractions import EntityNode, LabelledNode
from ..abstractions.llama_abstractions import Relation as LlamaRelation
from ..abstractions.llama_abstractions import VectorStoreQuery
from ..abstractions.llm import GenerationConfig
from .base_provider import ProviderConfig

logger = logging.getLogger(__name__)


class KGConfig(ProviderConfig):
    """Base configuration for knowledge-graph (KG) providers.

    `validate()` accepts only the provider names returned by
    `supported_providers`. `None` is among them, so a config that leaves
    `provider` unset passes validation.
    """

    # Name of the KG backend; checked against `supported_providers`.
    provider: Optional[str] = None
    batch_size: int = 1
    # Presumably the names of the prompts used for KG extraction and for the
    # KG agent -- confirm against the prompt provider.
    kg_extraction_prompt: Optional[str] = "few_shot_ner_kg_extraction"
    kg_agent_prompt: Optional[str] = "kg_agent"
    kg_extraction_config: Optional[GenerationConfig] = None

    def validate(self) -> None:
        """Check that `provider` is supported.

        Raises:
            ValueError: if `provider` is not in `supported_providers`.
        """
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[Optional[str]]:
        """Provider names accepted by `validate()`, including `None`."""
        # The annotation is `list[Optional[str]]` because `None` is a member.
        return [None, "neo4j"]


class KGProvider(ABC):
    """Abstract interface implemented by every knowledge-graph backend.

    The constructor stores the given `KGConfig` on `self.config` and
    validates it; all graph operations are left to subclasses.
    """

    def __init__(self, config: KGConfig) -> None:
        """Store `config` and validate it.

        Raises:
            ValueError: if `config` is not a `KGConfig` instance. Errors
                raised by the config's own validation also propagate.
        """
        if isinstance(config, KGConfig):
            logger.info(f"Initializing KG provider with config: {config}")
            self.config = config
            self.validate_config()
        else:
            raise ValueError(
                "KGProvider must be initialized with a `KGConfig`."
            )

    def validate_config(self) -> None:
        """Delegate validation to the stored config object."""
        self.config.validate()

    @property
    @abstractmethod
    def client(self) -> Any:
        """Abstract property returning the backend client object."""

    @abstractmethod
    def get(self, subj: str) -> list[list[str]]:
        """Abstract method to get the triplets for subject `subj`."""

    @abstractmethod
    def get_rel_map(
        self,
        subjs: Optional[list[str]] = None,
        depth: int = 2,
        limit: int = 30,
    ) -> dict[str, list[list[str]]]:
        """Abstract method to get a depth-aware relation map.

        `subjs`, `depth` and `limit` are interpreted by the implementation.
        """

    @abstractmethod
    def upsert_nodes(self, nodes: list[EntityNode]) -> None:
        """Abstract method to upsert the given nodes."""

    @abstractmethod
    def upsert_relations(self, relations: list[LlamaRelation]) -> None:
        """Abstract method to upsert the given relations."""

    @abstractmethod
    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Abstract method to delete the triplet (`subj`, `rel`, `obj`)."""

    @abstractmethod
    def get_schema(self, refresh: bool = False) -> str:
        """Abstract method to get the schema of the graph store."""

    # NOTE(review): the `{}` default below is a single dict object shared by
    # every call that omits `param_map`; implementations should not mutate it.
    @abstractmethod
    def structured_query(
        self, query: str, param_map: Optional[dict[str, Any]] = {}
    ) -> Any:
        """Abstract method to query the graph store with a statement and parameters."""

    @abstractmethod
    def vector_query(
        self, query: VectorStoreQuery, **kwargs: Any
    ) -> Tuple[list[LabelledNode], list[float]]:
        """Abstract method to query the graph store with a vector store query."""

    # TODO - Type this method.
    @abstractmethod
    def update_extraction_prompt(
        self,
        prompt_provider: Any,
        entity_types: list[Any],
        relations: list[Relation],
    ):
        """Abstract method to update the KG extraction prompt."""

    # TODO - Type this method.
    @abstractmethod
    def update_kg_agent_prompt(
        self,
        prompt_provider: Any,
        entity_types: list[Any],
        relations: list[Relation],
    ):
        """Abstract method to update the KG agent prompt."""


def escape_braces(s: str) -> str:
    """Return `s` with every `{` doubled to `{{` and every `}` to `}}`.

    Doubling is how literal braces are written in `str.format`-style
    templates.
    """
    # One pass over the string; each brace maps to its doubled form.
    doubled = str.maketrans({"{": "{{", "}": "}}"})
    return s.translate(doubled)


# TODO - Make this more configurable / intelligent
def update_kg_prompt(
    client: "R2RClient",
    r2r_prompts: PromptProvider,
    prompt_base: str,
    entity_types: list[EntityType],
    relations: list[Relation],
) -> None:
    """Rebuild the `prompt_base` prompt from its `_with_spec` template.

    Fetches the prompt named `<prompt_base>_with_spec` from `r2r_prompts`,
    passing it the entity-type names and relation names as JSON strings,
    then stores the result on `client` under the name `prompt_base`.

    Args:
        client: Object whose `update_prompt` receives the final template.
        r2r_prompts: Provider whose `get_prompt` returns the filled template.
        prompt_base: Name of the prompt to update; also the prefix of the
            template name.
        entity_types: Objects whose `.name` is listed under "entity_types".
        relations: Objects whose `.name` is listed under "predicates".
    """
    spec_template_name: str = f"{prompt_base}_with_spec"

    entity_names = [str(entity.name) for entity in entity_types]
    predicate_names = [str(relation.name) for relation in relations]

    template_inputs = {
        "entity_types": json.dumps({"entity_types": entity_names}, indent=4),
        "relations": json.dumps({"predicates": predicate_names}, indent=4),
        # Pass a literal `{input}` so the stored prompt keeps an `input` slot.
        "input": "\n{input}",
    }
    filled_template: str = r2r_prompts.get_prompt(
        spec_template_name, template_inputs
    )

    # Double every brace so the embedded JSON is treated as literal text,
    # then restore the single `{input}` placeholder for later formatting.
    all_escaped = escape_braces(filled_template)
    final_template: str = all_escaped.replace("{{input}}", "{input}")

    client.update_prompt(
        prompt_base,
        template=final_template,
        input_types={"input": "str"},
    )