about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/abstractions/graph.py')
-rw-r--r-- .venv/lib/python3.12/site-packages/shared/abstractions/graph.py 257
1 files changed, 257 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py
new file mode 100644
index 00000000..3c1cec9e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py
@@ -0,0 +1,257 @@
+import json
+from dataclasses import dataclass
+from datetime import datetime
+from enum import Enum
+from typing import Any, Optional
+from uuid import UUID
+
+from pydantic import Field
+
+from ..abstractions.llm import GenerationConfig
+from .base import R2RSerializable
+
+
+class Entity(R2RSerializable):
+    """An entity extracted from a document.
+
+    ``metadata`` may be passed either as a mapping or as a JSON-encoded
+    string; a string is decoded before the base-class constructor runs.
+    """
+
+    name: str
+    description: Optional[str] = None
+    category: Optional[str] = None
+    metadata: Optional[dict[str, Any]] = None
+
+    id: Optional[UUID] = None
+    parent_id: Optional[UUID] = None  # graph_id | document_id
+    description_embedding: Optional[list[float] | str] = None
+    chunk_ids: Optional[list[UUID]] = []
+
+    def __str__(self):
+        """Return the entity rendered as ``name:category``."""
+        return f"{self.name}:{self.category}"
+
+    def __init__(self, **kwargs):
+        """Build the entity, decoding a JSON-string ``metadata`` first.
+
+        ``metadata`` is annotated as a dict only, so the decode is done on
+        the incoming keyword arguments rather than on the finished instance.
+        NOTE(review): this assumes the base class (defined outside this
+        file) validates fields against their annotations and would refuse
+        a plain string - confirm against ``R2RSerializable``.
+
+        Args:
+            **kwargs: Field values for the model.
+        """
+        raw_metadata = kwargs.get("metadata")
+        if isinstance(raw_metadata, str):
+            try:
+                kwargs["metadata"] = json.loads(raw_metadata)
+            except json.JSONDecodeError:
+                # Not JSON: pass the value through unchanged and let the
+                # base class decide what to do with it.
+                pass
+        super().__init__(**kwargs)
+
+
+class Relationship(R2RSerializable):
+    """A subject / predicate / object link between two entities.
+
+    The model is deliberately generic: any kind of relationship between any
+    pair of entities can be expressed with it.
+    """
+
+    id: UUID | None = None
+    subject: str
+    predicate: str
+    object: str
+    description: str | None = None
+    subject_id: UUID | None = None
+    object_id: UUID | None = None
+    weight: float | None = 1.0
+    chunk_ids: list[UUID] | None = []
+    parent_id: UUID | None = None
+    description_embedding: list[float] | str | None = None
+    metadata: dict[str, Any] | str | None = None
+
+    def __init__(self, **kwargs):
+        """Build the relationship, then decode ``metadata`` if it is JSON text.
+
+        Args:
+            **kwargs: Field values for the model.
+        """
+        super().__init__(**kwargs)
+        raw_metadata = self.metadata
+        if not isinstance(raw_metadata, str):
+            return
+        try:
+            self.metadata = json.loads(raw_metadata)
+        except json.JSONDecodeError:
+            # The string is not JSON; it stays on the instance as it is.
+            return
+
+
+@dataclass
+class Community(R2RSerializable):
+    """A community record: name, summary, findings, rating and identifiers.
+
+    String values given for ``attributes``, ``embedding`` or
+    ``description_embedding`` are treated as JSON text and decoded before
+    the base-class constructor runs.
+    """
+
+    # NOTE(review): ``@dataclass`` is stacked on a base class defined outside
+    # this file while this class also supplies its own ``__init__`` - confirm
+    # the decorator is intentional.
+    name: str = ""
+    summary: str = ""
+    level: Optional[int] = None
+    findings: list[str] = []
+    id: Optional[int | UUID] = None
+    community_id: Optional[UUID] = None
+    collection_id: Optional[UUID] = None
+    rating: Optional[float] = None
+    rating_explanation: Optional[str] = None
+    description_embedding: Optional[list[float]] = None
+    attributes: dict[str, Any] | None = None
+    created_at: datetime = Field(
+        default_factory=datetime.utcnow,
+    )
+    updated_at: datetime = Field(
+        default_factory=datetime.utcnow,
+    )
+
+    def __init__(self, **kwargs):
+        """Decode JSON-string inputs, then delegate to the base constructor.
+
+        Args:
+            **kwargs: Field values for the model.
+
+        Raises:
+            json.JSONDecodeError: If ``attributes`` or ``embedding`` is a
+                string that is not valid JSON.
+        """
+        # "embedding" is not a field declared on this class; the key is
+        # still decoded so existing callers that pass it see no change.
+        for key in ("attributes", "embedding"):
+            if isinstance(kwargs.get(key), str):
+                kwargs[key] = json.loads(kwargs[key])
+
+        # The declared embedding field gets the same treatment. A string
+        # that is not JSON is passed through untouched so the base class
+        # handles it exactly as it did before this decode was added.
+        raw_embedding = kwargs.get("description_embedding")
+        if isinstance(raw_embedding, str):
+            try:
+                kwargs["description_embedding"] = json.loads(raw_embedding)
+            except json.JSONDecodeError:
+                pass
+
+        super().__init__(**kwargs)
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any] | str) -> "Community":
+        """Create a Community from a dictionary or a JSON string.
+
+        The caller's dictionary is copied, never modified. Decoding of
+        JSON-string values is done by ``__init__``.
+
+        Args:
+            data: Field values as a dict, or a JSON document encoding one.
+
+        Returns:
+            The constructed instance.
+        """
+        parsed_data: dict[str, Any] = (
+            json.loads(data) if isinstance(data, str) else dict(data)
+        )
+        return cls(**parsed_data)
+
+
+class GraphExtraction(R2RSerializable):
+    """Container pairing extracted entities with extracted relationships."""
+
+    # Entities belonging to this extraction.
+    entities: list[Entity]
+    # Relationships belonging to this extraction.
+    relationships: list[Relationship]
+
+
+class Graph(R2RSerializable):
+    """A named graph record with timestamps and a status string."""
+
+    # Declared through ``Field()`` without a default value.
+    id: UUID | None = Field()
+    name: str
+    description: str | None = None
+    created_at: datetime = Field(default_factory=datetime.utcnow)
+    updated_at: datetime = Field(default_factory=datetime.utcnow)
+    status: str = "pending"
+
+    class Config:
+        from_attributes = True
+        populate_by_name = True
+
+    def __init__(self, **kwargs):
+        """Forward all keyword arguments to the base-class constructor."""
+        super().__init__(**kwargs)
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any] | str) -> "Graph":
+        """Build a Graph from a dictionary or from JSON text.
+
+        Args:
+            data: Field values as a dict, or a JSON document encoding one.
+
+        Returns:
+            The constructed instance.
+        """
+        if isinstance(data, str):
+            # JSON text is decoded first; a dict is used as given.
+            payload: dict[str, Any] = json.loads(data)
+        else:
+            payload = data
+        return cls(**payload)
+
+
+class StoreType(str, Enum):
+    """String-valued enumeration with the members GRAPHS and DOCUMENTS."""
+
+    GRAPHS = "graphs"
+    DOCUMENTS = "documents"
+
+
+class GraphCreationSettings(R2RSerializable):
+    """Options that control knowledge graph creation."""
+
+    # Prompt names, stored as plain strings.
+    graph_extraction_prompt: str = Field(
+        description="The prompt to use for knowledge graph extraction.",
+        default="graph_extraction",
+    )
+
+    graph_entity_description_prompt: str = Field(
+        description="The prompt to use for entity description generation.",
+        default="graph_entity_description",
+    )
+
+    # Both type lists default to empty.
+    entity_types: list[str] = Field(
+        description="The types of entities to extract.",
+        default=[],
+    )
+
+    relation_types: list[str] = Field(
+        description="The types of relations to extract.",
+        default=[],
+    )
+
+    # Numeric limits.
+    chunk_merge_count: int = Field(
+        description="""The number of extractions to merge into a single graph
+ extraction.""",
+        default=2,
+    )
+
+    max_knowledge_relationships: int = Field(
+        description="""The maximum number of knowledge relationships to extract
+ from each chunk.""",
+        default=100,
+    )
+
+    max_description_input_length: int = Field(
+        description="""The maximum length of the description for a node in the
+ graph.""",
+        default=65536,
+    )
+
+    # NOTE(review): the description text below says "graph enrichment"
+    # although this is the creation settings class - confirm the wording.
+    generation_config: GenerationConfig | None = Field(
+        description="Configuration for text generation during graph enrichment.",
+        default=None,
+    )
+
+    automatic_deduplication: bool = Field(
+        description="Whether to automatically deduplicate entities.",
+        default=False,
+    )
+
+
+class GraphEnrichmentSettings(R2RSerializable):
+    """Options that control knowledge graph enrichment."""
+
+    force_graph_search_results_enrichment: bool = Field(
+        description="""Force run the enrichment step even if graph creation is
+ still in progress for some documents.""",
+        default=False,
+    )
+
+    # Prompt name, stored as a plain string.
+    graph_communities_prompt: str = Field(
+        description="The prompt to use for knowledge graph enrichment.",
+        default="graph_communities",
+    )
+
+    max_summary_input_length: int = Field(
+        description="The maximum length of the summary for a community.",
+        default=65536,
+    )
+
+    generation_config: GenerationConfig | None = Field(
+        description="Configuration for text generation during graph enrichment.",
+        default=None,
+    )
+
+    # A fresh empty dict is created for each instance via default_factory.
+    leiden_params: dict = Field(
+        description="Parameters for the Leiden algorithm.",
+        default_factory=dict,
+    )
+
+
+class GraphCommunitySettings(R2RSerializable):
+    """Options that control knowledge graph community enrichment."""
+
+    force_graph_search_results_enrichment: bool = Field(
+        description="""Force run the enrichment step even if graph creation is
+ still in progress for some documents.""",
+        default=False,
+    )
+
+    # NOTE(review): GraphEnrichmentSettings calls the matching field
+    # ``graph_communities_prompt`` - confirm the shorter name here is intended.
+    graph_communities: str = Field(
+        description="The prompt to use for knowledge graph enrichment.",
+        default="graph_communities",
+    )
+
+    max_summary_input_length: int = Field(
+        description="The maximum length of the summary for a community.",
+        default=65536,
+    )
+
+    generation_config: GenerationConfig | None = Field(
+        description="Configuration for text generation during graph enrichment.",
+        default=None,
+    )
+
+    # A fresh empty dict is created for each instance via default_factory.
+    leiden_params: dict = Field(
+        description="Parameters for the Leiden algorithm.",
+        default_factory=dict,
+    )