diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/shared/abstractions/graph.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/abstractions/graph.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/shared/abstractions/graph.py | 257 |
1 files changed, 257 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py new file mode 100644 index 00000000..3c1cec9e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/abstractions/graph.py @@ -0,0 +1,257 @@ +import json +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import Any, Optional +from uuid import UUID + +from pydantic import Field + +from ..abstractions.llm import GenerationConfig +from .base import R2RSerializable + + +class Entity(R2RSerializable): + """An entity extracted from a document.""" + + name: str + description: Optional[str] = None + category: Optional[str] = None + metadata: Optional[dict[str, Any]] = None + + id: Optional[UUID] = None + parent_id: Optional[UUID] = None # graph_id | document_id + description_embedding: Optional[list[float] | str] = None + chunk_ids: Optional[list[UUID]] = [] + + def __str__(self): + return f"{self.name}:{self.category}" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if isinstance(self.metadata, str): + try: + self.metadata = json.loads(self.metadata) + except json.JSONDecodeError: + self.metadata = self.metadata + + +class Relationship(R2RSerializable): + """A relationship between two entities. + + This is a generic relationship, and can be used to represent any type of + relationship between any two entities. + """ + + id: Optional[UUID] = None + subject: str + predicate: str + object: str + description: Optional[str] = None + subject_id: Optional[UUID] = None + object_id: Optional[UUID] = None + weight: float | None = 1.0 + chunk_ids: Optional[list[UUID]] = [] + parent_id: Optional[UUID] = None + description_embedding: Optional[list[float] | str] = None + metadata: Optional[dict[str, Any] | str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if isinstance(self.metadata, str): + try: + self.metadata = json.loads(self.metadata) + except json.JSONDecodeError: + self.metadata = self.metadata + + +@dataclass +class Community(R2RSerializable): + name: str = "" + summary: str = "" + level: Optional[int] = None + findings: list[str] = [] + id: Optional[int | UUID] = None + community_id: Optional[UUID] = None + collection_id: Optional[UUID] = None + rating: Optional[float] = None + rating_explanation: Optional[str] = None + description_embedding: Optional[list[float]] = None + attributes: dict[str, Any] | None = None + created_at: datetime = Field( + default_factory=datetime.utcnow, + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + ) + + def __init__(self, **kwargs): + if isinstance(kwargs.get("attributes", None), str): + kwargs["attributes"] = json.loads(kwargs["attributes"]) + + if isinstance(kwargs.get("embedding", None), str): + kwargs["embedding"] = json.loads(kwargs["embedding"]) + + super().__init__(**kwargs) + + @classmethod + def from_dict(cls, data: dict[str, Any] | str) -> "Community": + parsed_data: dict[str, Any] = ( + json.loads(data) if isinstance(data, str) else data + ) + if isinstance(parsed_data.get("embedding", None), str): + parsed_data["embedding"] = json.loads(parsed_data["embedding"]) + return cls(**parsed_data) + + +class GraphExtraction(R2RSerializable): + """A protocol for a knowledge graph extraction.""" + + entities: list[Entity] + relationships: list[Relationship] + + +class Graph(R2RSerializable): + id: UUID | None = Field() + name: str + description: Optional[str] = None + created_at: datetime = Field( + default_factory=datetime.utcnow, + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + ) + status: str = "pending" + + class Config: + populate_by_name = True + from_attributes = True + + @classmethod + def from_dict(cls, data: dict[str, Any] | str) -> "Graph": + """Create a Graph instance from a dictionary.""" + # Convert string to dict if needed + parsed_data: dict[str, Any] = ( + json.loads(data) if isinstance(data, str) else data + ) + return cls(**parsed_data) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +class StoreType(str, Enum): + GRAPHS = "graphs" + DOCUMENTS = "documents" + + +class GraphCreationSettings(R2RSerializable): + """Settings for knowledge graph creation.""" + + graph_extraction_prompt: str = Field( + default="graph_extraction", + description="The prompt to use for knowledge graph extraction.", + ) + + graph_entity_description_prompt: str = Field( + default="graph_entity_description", + description="The prompt to use for entity description generation.", + ) + + entity_types: list[str] = Field( + default=[], + description="The types of entities to extract.", + ) + + relation_types: list[str] = Field( + default=[], + description="The types of relations to extract.", + ) + + chunk_merge_count: int = Field( + default=2, + description="""The number of extractions to merge into a single graph + extraction.""", + ) + + max_knowledge_relationships: int = Field( + default=100, + description="""The maximum number of knowledge relationships to extract + from each chunk.""", + ) + + max_description_input_length: int = Field( + default=65536, + description="""The maximum length of the description for a node in the + graph.""", + ) + + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="Configuration for text generation during graph enrichment.", + ) + + automatic_deduplication: bool = Field( + default=False, + description="Whether to automatically deduplicate entities.", + ) + + +class GraphEnrichmentSettings(R2RSerializable): + """Settings for knowledge graph enrichment.""" + + force_graph_search_results_enrichment: bool = Field( + default=False, + description="""Force run the enrichment step even if graph creation is + still in progress for some documents.""", + ) + + graph_communities_prompt: str = Field( + default="graph_communities", + description="The prompt to use for knowledge graph enrichment.", + ) + + max_summary_input_length: int = Field( + default=65536, + description="The maximum length of the summary for a community.", + ) + + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="Configuration for text generation during graph enrichment.", + ) + + leiden_params: dict = Field( + default_factory=dict, + description="Parameters for the Leiden algorithm.", + ) + + +class GraphCommunitySettings(R2RSerializable): + """Settings for knowledge graph community enrichment.""" + + force_graph_search_results_enrichment: bool = Field( + default=False, + description="""Force run the enrichment step even if graph creation is + still in progress for some documents.""", + ) + + graph_communities: str = Field( + default="graph_communities", + description="The prompt to use for knowledge graph enrichment.", + ) + + max_summary_input_length: int = Field( + default=65536, + description="The maximum length of the summary for a community.", + ) + + generation_config: Optional[GenerationConfig] = Field( + default=None, + description="Configuration for text generation during graph enrichment.", + ) + + leiden_params: dict = Field( + default_factory=dict, + description="Parameters for the Leiden algorithm.", + ) |