about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/shared/abstractions/graph.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/abstractions/graph.py')
-rw-r--r--.venv/lib/python3.12/site-packages/shared/abstractions/graph.py257
1 files changed, 257 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py
new file mode 100644
index 00000000..3c1cec9e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py
@@ -0,0 +1,257 @@
+import json
+from dataclasses import dataclass
+from datetime import datetime
+from enum import Enum
+from typing import Any, Optional
+from uuid import UUID
+
+from pydantic import Field
+
+from ..abstractions.llm import GenerationConfig
+from .base import R2RSerializable
+
+
+class Entity(R2RSerializable):
+    """An entity extracted from a document."""
+
+    name: str
+    description: Optional[str] = None
+    category: Optional[str] = None
+    metadata: Optional[dict[str, Any]] = None
+
+    id: Optional[UUID] = None
+    parent_id: Optional[UUID] = None  # graph_id | document_id
+    description_embedding: Optional[list[float] | str] = None
+    chunk_ids: Optional[list[UUID]] = []
+
+    def __str__(self):
+        return f"{self.name}:{self.category}"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if isinstance(self.metadata, str):
+            try:
+                self.metadata = json.loads(self.metadata)
+            except json.JSONDecodeError:
+                self.metadata = self.metadata
+
+
+class Relationship(R2RSerializable):
+    """A relationship between two entities.
+
+    This is a generic relationship, and can be used to represent any type of
+    relationship between any two entities.
+    """
+
+    id: Optional[UUID] = None
+    subject: str
+    predicate: str
+    object: str
+    description: Optional[str] = None
+    subject_id: Optional[UUID] = None
+    object_id: Optional[UUID] = None
+    weight: float | None = 1.0
+    chunk_ids: Optional[list[UUID]] = []
+    parent_id: Optional[UUID] = None
+    description_embedding: Optional[list[float] | str] = None
+    metadata: Optional[dict[str, Any] | str] = None
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if isinstance(self.metadata, str):
+            try:
+                self.metadata = json.loads(self.metadata)
+            except json.JSONDecodeError:
+                self.metadata = self.metadata
+
+
+@dataclass
+class Community(R2RSerializable):
+    name: str = ""
+    summary: str = ""
+    level: Optional[int] = None
+    findings: list[str] = []
+    id: Optional[int | UUID] = None
+    community_id: Optional[UUID] = None
+    collection_id: Optional[UUID] = None
+    rating: Optional[float] = None
+    rating_explanation: Optional[str] = None
+    description_embedding: Optional[list[float]] = None
+    attributes: dict[str, Any] | None = None
+    created_at: datetime = Field(
+        default_factory=datetime.utcnow,
+    )
+    updated_at: datetime = Field(
+        default_factory=datetime.utcnow,
+    )
+
+    def __init__(self, **kwargs):
+        if isinstance(kwargs.get("attributes", None), str):
+            kwargs["attributes"] = json.loads(kwargs["attributes"])
+
+        if isinstance(kwargs.get("embedding", None), str):
+            kwargs["embedding"] = json.loads(kwargs["embedding"])
+
+        super().__init__(**kwargs)
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any] | str) -> "Community":
+        parsed_data: dict[str, Any] = (
+            json.loads(data) if isinstance(data, str) else data
+        )
+        if isinstance(parsed_data.get("embedding", None), str):
+            parsed_data["embedding"] = json.loads(parsed_data["embedding"])
+        return cls(**parsed_data)
+
+
+class GraphExtraction(R2RSerializable):
+    """A protocol for a knowledge graph extraction."""
+
+    entities: list[Entity]
+    relationships: list[Relationship]
+
+
+class Graph(R2RSerializable):
+    id: UUID | None = Field()
+    name: str
+    description: Optional[str] = None
+    created_at: datetime = Field(
+        default_factory=datetime.utcnow,
+    )
+    updated_at: datetime = Field(
+        default_factory=datetime.utcnow,
+    )
+    status: str = "pending"
+
+    class Config:
+        populate_by_name = True
+        from_attributes = True
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any] | str) -> "Graph":
+        """Create a Graph instance from a dictionary."""
+        # Convert string to dict if needed
+        parsed_data: dict[str, Any] = (
+            json.loads(data) if isinstance(data, str) else data
+        )
+        return cls(**parsed_data)
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+class StoreType(str, Enum):
+    GRAPHS = "graphs"
+    DOCUMENTS = "documents"
+
+
+class GraphCreationSettings(R2RSerializable):
+    """Settings for knowledge graph creation."""
+
+    graph_extraction_prompt: str = Field(
+        default="graph_extraction",
+        description="The prompt to use for knowledge graph extraction.",
+    )
+
+    graph_entity_description_prompt: str = Field(
+        default="graph_entity_description",
+        description="The prompt to use for entity description generation.",
+    )
+
+    entity_types: list[str] = Field(
+        default=[],
+        description="The types of entities to extract.",
+    )
+
+    relation_types: list[str] = Field(
+        default=[],
+        description="The types of relations to extract.",
+    )
+
+    chunk_merge_count: int = Field(
+        default=2,
+        description="""The number of extractions to merge into a single graph
+        extraction.""",
+    )
+
+    max_knowledge_relationships: int = Field(
+        default=100,
+        description="""The maximum number of knowledge relationships to extract
+        from each chunk.""",
+    )
+
+    max_description_input_length: int = Field(
+        default=65536,
+        description="""The maximum length of the description for a node in the
+        graph.""",
+    )
+
+    generation_config: Optional[GenerationConfig] = Field(
+        default=None,
+        description="Configuration for text generation during graph enrichment.",
+    )
+
+    automatic_deduplication: bool = Field(
+        default=False,
+        description="Whether to automatically deduplicate entities.",
+    )
+
+
+class GraphEnrichmentSettings(R2RSerializable):
+    """Settings for knowledge graph enrichment."""
+
+    force_graph_search_results_enrichment: bool = Field(
+        default=False,
+        description="""Force run the enrichment step even if graph creation is
+        still in progress for some documents.""",
+    )
+
+    graph_communities_prompt: str = Field(
+        default="graph_communities",
+        description="The prompt to use for knowledge graph enrichment.",
+    )
+
+    max_summary_input_length: int = Field(
+        default=65536,
+        description="The maximum length of the summary for a community.",
+    )
+
+    generation_config: Optional[GenerationConfig] = Field(
+        default=None,
+        description="Configuration for text generation during graph enrichment.",
+    )
+
+    leiden_params: dict = Field(
+        default_factory=dict,
+        description="Parameters for the Leiden algorithm.",
+    )
+
+
+class GraphCommunitySettings(R2RSerializable):
+    """Settings for knowledge graph community enrichment."""
+
+    force_graph_search_results_enrichment: bool = Field(
+        default=False,
+        description="""Force run the enrichment step even if graph creation is
+        still in progress for some documents.""",
+    )
+
+    graph_communities: str = Field(
+        default="graph_communities",
+        description="The prompt to use for knowledge graph enrichment.",
+    )
+
+    max_summary_input_length: int = Field(
+        default=65536,
+        description="The maximum length of the summary for a community.",
+    )
+
+    generation_config: Optional[GenerationConfig] = Field(
+        default=None,
+        description="Configuration for text generation during graph enrichment.",
+    )
+
+    leiden_params: dict = Field(
+        default_factory=dict,
+        description="Parameters for the Leiden algorithm.",
+    )