about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/core/providers/database/graphs.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/database/graphs.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/graphs.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/database/graphs.py2884
1 files changed, 2884 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/graphs.py b/.venv/lib/python3.12/site-packages/core/providers/database/graphs.py
new file mode 100644
index 00000000..ba9c22ee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/graphs.py
@@ -0,0 +1,2884 @@
+import asyncio
+import contextlib
+import csv
+import datetime
+import json
+import logging
+import os
+import tempfile
+import time
+from typing import IO, Any, AsyncGenerator, Optional, Tuple
+from uuid import UUID
+
+import asyncpg
+import httpx
+from asyncpg.exceptions import UniqueViolationError
+from fastapi import HTTPException
+
+from core.base.abstractions import (
+    Community,
+    Entity,
+    Graph,
+    GraphExtractionStatus,
+    R2RException,
+    Relationship,
+    StoreType,
+    VectorQuantizationType,
+)
+from core.base.api.models import GraphResponse
+from core.base.providers.database import Handler
+from core.base.utils import (
+    _get_vector_column_str,
+    generate_entity_document_id,
+)
+
+from .base import PostgresConnectionManager
+from .collections import PostgresCollectionsHandler
+
+logger = logging.getLogger()
+
+
+class PostgresEntitiesHandler(Handler):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.project_name: str = kwargs.get("project_name")  # type: ignore
+        self.connection_manager: PostgresConnectionManager = kwargs.get(
+            "connection_manager"
+        )  # type: ignore
+        self.dimension: int = kwargs.get("dimension")  # type: ignore
+        self.quantization_type: VectorQuantizationType = kwargs.get(
+            "quantization_type"
+        )  # type: ignore
+        self.relationships_handler: PostgresRelationshipsHandler = (
+            PostgresRelationshipsHandler(*args, **kwargs)
+        )
+
+    def _get_table_name(self, table: str) -> str:
+        """Get the fully qualified table name."""
+        return f'"{self.project_name}"."{table}"'
+
+    def _get_entity_table_for_store(self, store_type: StoreType) -> str:
+        """Get the appropriate table name for the store type."""
+        return f"{store_type.value}_entities"
+
+    def _get_parent_constraint(self, store_type: StoreType) -> str:
+        """Get the appropriate foreign key constraint for the store type."""
+        if store_type == StoreType.GRAPHS:
+            return f"""
+                CONSTRAINT fk_graph
+                    FOREIGN KEY(parent_id)
+                    REFERENCES {self._get_table_name("graphs")}(id)
+                    ON DELETE CASCADE
+            """
+        else:
+            return f"""
+                CONSTRAINT fk_document
+                    FOREIGN KEY(parent_id)
+                    REFERENCES {self._get_table_name("documents")}(id)
+                    ON DELETE CASCADE
+            """
+
+    async def create_tables(self) -> None:
+        """Create separate tables for graph and document entities."""
+        vector_column_str = _get_vector_column_str(
+            self.dimension, self.quantization_type
+        )
+
+        for store_type in StoreType:
+            table_name = self._get_entity_table_for_store(store_type)
+            parent_constraint = self._get_parent_constraint(store_type)
+
+            QUERY = f"""
+                CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
+                    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+                    name TEXT NOT NULL,
+                    category TEXT,
+                    description TEXT,
+                    parent_id UUID NOT NULL,
+                    description_embedding {vector_column_str},
+                    chunk_ids UUID[],
+                    metadata JSONB,
+                    created_at TIMESTAMPTZ DEFAULT NOW(),
+                    updated_at TIMESTAMPTZ DEFAULT NOW(),
+                    {parent_constraint}
+                );
+                CREATE INDEX IF NOT EXISTS {table_name}_name_idx
+                    ON {self._get_table_name(table_name)} (name);
+                CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
+                    ON {self._get_table_name(table_name)} (parent_id);
+                CREATE INDEX IF NOT EXISTS {table_name}_category_idx
+                    ON {self._get_table_name(table_name)} (category);
+            """
+            await self.connection_manager.execute_query(QUERY)
+
+    async def create(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        name: str,
+        category: Optional[str] = None,
+        description: Optional[str] = None,
+        description_embedding: Optional[list[float] | str] = None,
+        chunk_ids: Optional[list[UUID]] = None,
+        metadata: Optional[dict[str, Any] | str] = None,
+    ) -> Entity:
+        """Create a new entity in the specified store."""
+        table_name = self._get_entity_table_for_store(store_type)
+
+        if isinstance(metadata, str):
+            with contextlib.suppress(json.JSONDecodeError):
+                metadata = json.loads(metadata)
+
+        if isinstance(description_embedding, list):
+            description_embedding = str(description_embedding)
+
+        query = f"""
+            INSERT INTO {self._get_table_name(table_name)}
+            (name, category, description, parent_id, description_embedding, chunk_ids, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6, $7)
+            RETURNING id, name, category, description, parent_id, chunk_ids, metadata
+        """
+
+        params = [
+            name,
+            category,
+            description,
+            parent_id,
+            description_embedding,
+            chunk_ids,
+            json.dumps(metadata) if metadata else None,
+        ]
+
+        result = await self.connection_manager.fetchrow_query(
+            query=query,
+            params=params,
+        )
+
+        return Entity(
+            id=result["id"],
+            name=result["name"],
+            category=result["category"],
+            description=result["description"],
+            parent_id=result["parent_id"],
+            chunk_ids=result["chunk_ids"],
+            metadata=result["metadata"],
+        )
+
+    async def get(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        offset: int,
+        limit: int,
+        entity_ids: Optional[list[UUID]] = None,
+        entity_names: Optional[list[str]] = None,
+        include_embeddings: bool = False,
+    ):
+        """Retrieve entities from the specified store."""
+        table_name = self._get_entity_table_for_store(store_type)
+
+        conditions = ["parent_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if entity_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(entity_ids)
+            param_index += 1
+
+        if entity_names:
+            conditions.append(f"name = ANY(${param_index})")
+            params.append(entity_names)
+            param_index += 1
+
+        select_fields = """
+            id, name, category, description, parent_id,
+            chunk_ids, metadata
+        """
+        if include_embeddings:
+            select_fields += ", description_embedding"
+
+        COUNT_QUERY = f"""
+            SELECT COUNT(*)
+            FROM {self._get_table_name(table_name)}
+            WHERE {" AND ".join(conditions)}
+        """
+
+        count_params = params[: param_index - 1]
+        count = (
+            await self.connection_manager.fetch_query(
+                COUNT_QUERY, count_params
+            )
+        )[0]["count"]
+
+        QUERY = f"""
+            SELECT {select_fields}
+            FROM {self._get_table_name(table_name)}
+            WHERE {" AND ".join(conditions)}
+            ORDER BY created_at
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            QUERY += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        entities = []
+        for row in rows:
+            # Convert the Record to a dictionary
+            entity_dict = dict(row)
+
+            # Process metadata if it exists and is a string
+            if isinstance(entity_dict["metadata"], str):
+                with contextlib.suppress(json.JSONDecodeError):
+                    entity_dict["metadata"] = json.loads(
+                        entity_dict["metadata"]
+                    )
+
+            entities.append(Entity(**entity_dict))
+
+        return entities, count
+
+    async def update(
+        self,
+        entity_id: UUID,
+        store_type: StoreType,
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+        description_embedding: Optional[list[float] | str] = None,
+        category: Optional[str] = None,
+        metadata: Optional[dict] = None,
+    ) -> Entity:
+        """Update an entity in the specified store."""
+        table_name = self._get_entity_table_for_store(store_type)
+        update_fields = []
+        params: list[Any] = []
+        param_index = 1
+
+        if isinstance(metadata, str):
+            with contextlib.suppress(json.JSONDecodeError):
+                metadata = json.loads(metadata)
+
+        if name is not None:
+            update_fields.append(f"name = ${param_index}")
+            params.append(name)
+            param_index += 1
+
+        if description is not None:
+            update_fields.append(f"description = ${param_index}")
+            params.append(description)
+            param_index += 1
+
+        if description_embedding is not None:
+            update_fields.append(f"description_embedding = ${param_index}")
+            params.append(description_embedding)
+            param_index += 1
+
+        if category is not None:
+            update_fields.append(f"category = ${param_index}")
+            params.append(category)
+            param_index += 1
+
+        if metadata is not None:
+            update_fields.append(f"metadata = ${param_index}")
+            params.append(json.dumps(metadata))
+            param_index += 1
+
+        if not update_fields:
+            raise R2RException(status_code=400, message="No fields to update")
+
+        update_fields.append("updated_at = NOW()")
+        params.append(entity_id)
+
+        query = f"""
+            UPDATE {self._get_table_name(table_name)}
+            SET {", ".join(update_fields)}
+            WHERE id = ${param_index}\
+            RETURNING id, name, category, description, parent_id, chunk_ids, metadata
+        """
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query=query,
+                params=params,
+            )
+
+            return Entity(
+                id=result["id"],
+                name=result["name"],
+                category=result["category"],
+                description=result["description"],
+                parent_id=result["parent_id"],
+                chunk_ids=result["chunk_ids"],
+                metadata=result["metadata"],
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while updating the entity: {e}",
+            ) from e
+
+    async def delete(
+        self,
+        parent_id: UUID,
+        entity_ids: Optional[list[UUID]] = None,
+        store_type: StoreType = StoreType.GRAPHS,
+    ) -> None:
+        """Delete entities from the specified store. If entity_ids is not
+        provided, deletes all entities for the given parent_id.
+
+        Args:
+            parent_id (UUID): Parent ID (collection_id or document_id)
+            entity_ids (Optional[list[UUID]]): Specific entity IDs to delete. If None, deletes all entities for parent_id
+            store_type (StoreType): Type of store (graph or document)
+
+        Returns:
+            list[UUID]: List of deleted entity IDs
+
+        Raises:
+            R2RException: If specific entities were requested but not all found
+        """
+        table_name = self._get_entity_table_for_store(store_type)
+
+        if entity_ids is None:
+            # Delete all entities for the parent_id
+            QUERY = f"""
+                DELETE FROM {self._get_table_name(table_name)}
+                WHERE parent_id = $1
+                RETURNING id
+            """
+            results = await self.connection_manager.fetch_query(
+                QUERY, [parent_id]
+            )
+        else:
+            # Delete specific entities
+            QUERY = f"""
+                DELETE FROM {self._get_table_name(table_name)}
+                WHERE id = ANY($1) AND parent_id = $2
+                RETURNING id
+            """
+
+            results = await self.connection_manager.fetch_query(
+                QUERY, [entity_ids, parent_id]
+            )
+
+            # Check if all requested entities were deleted
+            deleted_ids = [row["id"] for row in results]
+            if entity_ids and len(deleted_ids) != len(entity_ids):
+                raise R2RException(
+                    f"Some entities not found in {store_type} store or no permission to delete",
+                    404,
+                )
+
+    async def get_duplicate_name_blocks(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+    ) -> list[list[Entity]]:
+        """Find all groups of entities that share identical names within the
+        same parent.
+
+        Returns a list of entity groups, where each group contains entities
+        with the same name. For each group, includes the n most dissimilar
+        descriptions based on cosine similarity.
+        """
+        table_name = self._get_entity_table_for_store(store_type)
+
+        # First get the duplicate names and their descriptions with embeddings
+        query = f"""
+            WITH duplicates AS (
+                SELECT name
+                FROM {self._get_table_name(table_name)}
+                WHERE parent_id = $1
+                GROUP BY name
+                HAVING COUNT(*) > 1
+            )
+            SELECT
+                e.id, e.name, e.category, e.description,
+                e.parent_id, e.chunk_ids, e.metadata
+            FROM {self._get_table_name(table_name)} e
+            WHERE e.parent_id = $1
+            AND e.name IN (SELECT name FROM duplicates)
+            ORDER BY e.name;
+        """
+
+        rows = await self.connection_manager.fetch_query(query, [parent_id])
+
+        # Group entities by name
+        name_groups: dict[str, list[Entity]] = {}
+        for row in rows:
+            entity_dict = dict(row)
+            if isinstance(entity_dict["metadata"], str):
+                with contextlib.suppress(json.JSONDecodeError):
+                    entity_dict["metadata"] = json.loads(
+                        entity_dict["metadata"]
+                    )
+
+            entity = Entity(**entity_dict)
+            name_groups.setdefault(entity.name, []).append(entity)
+
+        return list(name_groups.values())
+
+    async def merge_duplicate_name_blocks(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+    ) -> list[tuple[list[Entity], Entity]]:
+        """Merge entities that share identical names.
+
+        Returns list of tuples: (original_entities, merged_entity)
+        """
+        duplicate_blocks = await self.get_duplicate_name_blocks(
+            parent_id, store_type
+        )
+        merged_results: list[tuple[list[Entity], Entity]] = []
+
+        for block in duplicate_blocks:
+            # Create a new merged entity from the block
+            merged_entity = await self._create_merged_entity(block)
+            merged_results.append((block, merged_entity))
+
+            table_name = self._get_entity_table_for_store(store_type)
+            async with self.connection_manager.transaction():
+                # Insert the merged entity
+                new_id = await self._insert_merged_entity(
+                    merged_entity, table_name
+                )
+
+                merged_entity.id = new_id
+
+                # Get the old entity IDs
+                old_ids = [str(entity.id) for entity in block]
+
+                relationship_table = self.relationships_handler._get_relationship_table_for_store(
+                    store_type
+                )
+
+                # Update relationships where old entities appear as subjects
+                subject_update_query = f"""
+                    UPDATE {self._get_table_name(relationship_table)}
+                    SET subject_id = $1
+                    WHERE subject_id = ANY($2::uuid[])
+                    AND parent_id = $3
+                """
+                await self.connection_manager.execute_query(
+                    subject_update_query, [new_id, old_ids, parent_id]
+                )
+
+                # Update relationships where old entities appear as objects
+                object_update_query = f"""
+                    UPDATE {self._get_table_name(relationship_table)}
+                    SET object_id = $1
+                    WHERE object_id = ANY($2::uuid[])
+                    AND parent_id = $3
+                """
+                await self.connection_manager.execute_query(
+                    object_update_query, [new_id, old_ids, parent_id]
+                )
+
+                # Delete the original entities
+                delete_query = f"""
+                    DELETE FROM {self._get_table_name(table_name)}
+                    WHERE id = ANY($1::uuid[])
+                """
+                await self.connection_manager.execute_query(
+                    delete_query, [old_ids]
+                )
+
+        return merged_results
+
+    async def _insert_merged_entity(
+        self, entity: Entity, table_name: str
+    ) -> UUID:
+        """Insert merged entity and return its new ID."""
+        new_id = generate_entity_document_id()
+
+        query = f"""
+            INSERT INTO {self._get_table_name(table_name)}
+            (id, name, category, description, parent_id, chunk_ids, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6, $7)
+            RETURNING id
+        """
+
+        values = [
+            new_id,
+            entity.name,
+            entity.category,
+            entity.description,
+            entity.parent_id,
+            entity.chunk_ids,
+            json.dumps(entity.metadata) if entity.metadata else None,
+        ]
+
+        result = await self.connection_manager.fetch_query(query, values)
+        return result[0]["id"]
+
+    async def _create_merged_entity(self, entities: list[Entity]) -> Entity:
+        """Create a merged entity from a list of duplicate entities.
+
+        Uses various strategies to combine fields.
+        """
+        if not entities:
+            raise ValueError("Cannot merge empty list of entities")
+
+        # Take the first non-None category, or None if all are None
+        category = next(
+            (e.category for e in entities if e.category is not None), None
+        )
+
+        # Combine descriptions with newlines if they differ
+        descriptions = {e.description for e in entities if e.description}
+        description = "\n\n".join(descriptions) if descriptions else None
+
+        # Combine chunk_ids, removing duplicates
+        chunk_ids = list(
+            {
+                chunk_id
+                for entity in entities
+                for chunk_id in (entity.chunk_ids or [])
+            }
+        )
+
+        # Merge metadata dictionaries
+        merged_metadata: dict[str, Any] = {}
+        for entity in entities:
+            if entity.metadata:
+                merged_metadata |= entity.metadata
+
+        # Create new merged entity (without actually inserting to DB)
+        return Entity(
+            id=UUID(
+                "00000000-0000-0000-0000-000000000000"
+            ),  # Placeholder UUID
+            name=entities[0].name,  # All entities in block have same name
+            category=category,
+            description=description,
+            parent_id=entities[0].parent_id,
+            chunk_ids=chunk_ids or None,
+            metadata=merged_metadata or None,
+        )
+
+    async def export_to_csv(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        """Creates a CSV file from the PostgreSQL data and returns the path to
+        the temp file."""
+        valid_columns = {
+            "id",
+            "name",
+            "category",
+            "description",
+            "parent_id",
+            "chunk_ids",
+            "metadata",
+            "created_at",
+            "updated_at",
+        }
+
+        if not columns:
+            columns = list(valid_columns)
+        elif invalid_cols := set(columns) - valid_columns:
+            raise ValueError(f"Invalid columns: {invalid_cols}")
+
+        select_stmt = f"""
+            SELECT
+                id::text,
+                name,
+                category,
+                description,
+                parent_id::text,
+                chunk_ids::text,
+                metadata::text,
+                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+            FROM {self._get_table_name(self._get_entity_table_for_store(store_type))}
+        """
+
+        conditions = ["parent_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if filters:
+            for field, value in filters.items():
+                if field not in valid_columns:
+                    continue
+
+                if isinstance(value, dict):
+                    for op, val in value.items():
+                        if op == "$eq":
+                            conditions.append(f"{field} = ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$gt":
+                            conditions.append(f"{field} > ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$lt":
+                            conditions.append(f"{field} < ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                else:
+                    # Direct equality
+                    conditions.append(f"{field} = ${param_index}")
+                    params.append(value)
+                    param_index += 1
+
+        if conditions:
+            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+        temp_file = None
+        try:
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w", delete=True, suffix=".csv"
+            )
+            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                async with conn.transaction():
+                    cursor = await conn.cursor(select_stmt, *params)
+
+                    if include_header:
+                        writer.writerow(columns)
+
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            row_dict = {
+                                "id": row[0],
+                                "name": row[1],
+                                "category": row[2],
+                                "description": row[3],
+                                "parent_id": row[4],
+                                "chunk_ids": row[5],
+                                "metadata": row[6],
+                                "created_at": row[7],
+                                "updated_at": row[8],
+                            }
+                            writer.writerow([row_dict[col] for col in columns])
+
+            temp_file.flush()
+            return temp_file.name, temp_file
+
+        except Exception as e:
+            if temp_file:
+                temp_file.close()
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to export data: {str(e)}",
+            ) from e
+
+
+class PostgresRelationshipsHandler(Handler):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.project_name: str = kwargs.get("project_name")  # type: ignore
+        self.connection_manager: PostgresConnectionManager = kwargs.get(
+            "connection_manager"
+        )  # type: ignore
+        self.dimension: int = kwargs.get("dimension")  # type: ignore
+        self.quantization_type: VectorQuantizationType = kwargs.get(
+            "quantization_type"
+        )  # type: ignore
+
+    def _get_table_name(self, table: str) -> str:
+        """Get the fully qualified table name."""
+        return f'"{self.project_name}"."{table}"'
+
+    def _get_relationship_table_for_store(self, store_type: StoreType) -> str:
+        """Get the appropriate table name for the store type."""
+        return f"{store_type.value}_relationships"
+
+    def _get_parent_constraint(self, store_type: StoreType) -> str:
+        """Get the appropriate foreign key constraint for the store type."""
+        if store_type == StoreType.GRAPHS:
+            return f"""
+                CONSTRAINT fk_graph
+                    FOREIGN KEY(parent_id)
+                    REFERENCES {self._get_table_name("graphs")}(id)
+                    ON DELETE CASCADE
+            """
+        else:
+            return f"""
+                CONSTRAINT fk_document
+                    FOREIGN KEY(parent_id)
+                    REFERENCES {self._get_table_name("documents")}(id)
+                    ON DELETE CASCADE
+            """
+
+    async def create_tables(self) -> None:
+        """Create separate tables for graph and document relationships."""
+        for store_type in StoreType:
+            table_name = self._get_relationship_table_for_store(store_type)
+            parent_constraint = self._get_parent_constraint(store_type)
+            vector_column_str = _get_vector_column_str(
+                self.dimension, self.quantization_type
+            )
+
+            QUERY = f"""
+                CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
+                    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+                    subject TEXT NOT NULL,
+                    predicate TEXT NOT NULL,
+                    object TEXT NOT NULL,
+                    description TEXT,
+                    description_embedding {vector_column_str},
+                    subject_id UUID,
+                    object_id UUID,
+                    weight FLOAT DEFAULT 1.0,
+                    chunk_ids UUID[],
+                    parent_id UUID NOT NULL,
+                    metadata JSONB,
+                    created_at TIMESTAMPTZ DEFAULT NOW(),
+                    updated_at TIMESTAMPTZ DEFAULT NOW(),
+                    {parent_constraint}
+                );
+
+                CREATE INDEX IF NOT EXISTS {table_name}_subject_idx
+                    ON {self._get_table_name(table_name)} (subject);
+                CREATE INDEX IF NOT EXISTS {table_name}_object_idx
+                    ON {self._get_table_name(table_name)} (object);
+                CREATE INDEX IF NOT EXISTS {table_name}_predicate_idx
+                    ON {self._get_table_name(table_name)} (predicate);
+                CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
+                    ON {self._get_table_name(table_name)} (parent_id);
+                CREATE INDEX IF NOT EXISTS {table_name}_subject_id_idx
+                    ON {self._get_table_name(table_name)} (subject_id);
+                CREATE INDEX IF NOT EXISTS {table_name}_object_id_idx
+                    ON {self._get_table_name(table_name)} (object_id);
+            """
+            await self.connection_manager.execute_query(QUERY)
+
+    async def create(
+        self,
+        subject: str,
+        subject_id: UUID,
+        predicate: str,
+        object: str,
+        object_id: UUID,
+        parent_id: UUID,
+        store_type: StoreType,
+        description: str | None = None,
+        weight: float | None = 1.0,
+        chunk_ids: Optional[list[UUID]] = None,
+        description_embedding: Optional[list[float] | str] = None,
+        metadata: Optional[dict[str, Any] | str] = None,
+    ) -> Relationship:
+        """Create a new relationship in the specified store."""
+        table_name = self._get_relationship_table_for_store(store_type)
+
+        if isinstance(metadata, str):
+            with contextlib.suppress(json.JSONDecodeError):
+                metadata = json.loads(metadata)
+
+        if isinstance(description_embedding, list):
+            description_embedding = str(description_embedding)
+
+        query = f"""
+            INSERT INTO {self._get_table_name(table_name)}
+            (subject, predicate, object, description, subject_id, object_id,
+             weight, chunk_ids, parent_id, description_embedding, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
+            RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
+        """
+
+        params = [
+            subject,
+            predicate,
+            object,
+            description,
+            subject_id,
+            object_id,
+            weight,
+            chunk_ids,
+            parent_id,
+            description_embedding,
+            json.dumps(metadata) if metadata else None,
+        ]
+
+        result = await self.connection_manager.fetchrow_query(
+            query=query,
+            params=params,
+        )
+
+        return Relationship(
+            id=result["id"],
+            subject=result["subject"],
+            predicate=result["predicate"],
+            object=result["object"],
+            description=result["description"],
+            subject_id=result["subject_id"],
+            object_id=result["object_id"],
+            weight=result["weight"],
+            chunk_ids=result["chunk_ids"],
+            parent_id=result["parent_id"],
+            metadata=result["metadata"],
+        )
+
+    async def get(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        offset: int,
+        limit: int,
+        relationship_ids: Optional[list[UUID]] = None,
+        entity_names: Optional[list[str]] = None,
+        relationship_types: Optional[list[str]] = None,
+        include_metadata: bool = False,
+    ):
+        """Get relationships from the specified store.
+
+        Args:
+            parent_id: UUID of the parent (collection_id or document_id)
+            store_type: Type of store (graph or document)
+            offset: Number of records to skip
+            limit: Maximum number of records to return (-1 for no limit)
+            relationship_ids: Optional list of specific relationship IDs to retrieve
+            entity_names: Optional list of entity names to filter by (matches subject or object)
+            relationship_types: Optional list of relationship types (predicates) to filter by
+            include_metadata: Whether to include metadata in the response
+
+        Returns:
+            Tuple of (list of relationships, total count)
+        """
+        table_name = self._get_relationship_table_for_store(store_type)
+
+        conditions = ["parent_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if relationship_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(relationship_ids)
+            param_index += 1
+
+        if entity_names:
+            conditions.append(
+                f"(subject = ANY(${param_index}) OR object = ANY(${param_index}))"
+            )
+            params.append(entity_names)
+            param_index += 1
+
+        if relationship_types:
+            conditions.append(f"predicate = ANY(${param_index})")
+            params.append(relationship_types)
+            param_index += 1
+
+        select_fields = """
+            id, subject, predicate, object, description,
+            subject_id, object_id, weight, chunk_ids,
+            parent_id
+        """
+        if include_metadata:
+            select_fields += ", metadata"
+
+        # Count query
+        COUNT_QUERY = f"""
+            SELECT COUNT(*)
+            FROM {self._get_table_name(table_name)}
+            WHERE {" AND ".join(conditions)}
+        """
+        count_params = params[: param_index - 1]
+        count = (
+            await self.connection_manager.fetch_query(
+                COUNT_QUERY, count_params
+            )
+        )[0]["count"]
+
+        # Main query
+        QUERY = f"""
+            SELECT {select_fields}
+            FROM {self._get_table_name(table_name)}
+            WHERE {" AND ".join(conditions)}
+            ORDER BY created_at
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            QUERY += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        relationships = []
+        for row in rows:
+            relationship_dict = dict(row)
+            if include_metadata and isinstance(
+                relationship_dict["metadata"], str
+            ):
+                with contextlib.suppress(json.JSONDecodeError):
+                    relationship_dict["metadata"] = json.loads(
+                        relationship_dict["metadata"]
+                    )
+            elif not include_metadata:
+                relationship_dict.pop("metadata", None)
+            relationships.append(Relationship(**relationship_dict))
+
+        return relationships, count
+
+    async def update(
+        self,
+        relationship_id: UUID,
+        store_type: StoreType,
+        subject: Optional[str],
+        subject_id: Optional[UUID],
+        predicate: Optional[str],
+        object: Optional[str],
+        object_id: Optional[UUID],
+        description: Optional[str],
+        description_embedding: Optional[list[float] | str],
+        weight: Optional[float],
+        metadata: Optional[dict[str, Any] | str],
+    ) -> Relationship:
+        """Update multiple relationships in the specified store."""
+        table_name = self._get_relationship_table_for_store(store_type)
+        update_fields = []
+        params: list = []
+        param_index = 1
+
+        if isinstance(metadata, str):
+            with contextlib.suppress(json.JSONDecodeError):
+                metadata = json.loads(metadata)
+
+        if subject is not None:
+            update_fields.append(f"subject = ${param_index}")
+            params.append(subject)
+            param_index += 1
+
+        if subject_id is not None:
+            update_fields.append(f"subject_id = ${param_index}")
+            params.append(subject_id)
+            param_index += 1
+
+        if predicate is not None:
+            update_fields.append(f"predicate = ${param_index}")
+            params.append(predicate)
+            param_index += 1
+
+        if object is not None:
+            update_fields.append(f"object = ${param_index}")
+            params.append(object)
+            param_index += 1
+
+        if object_id is not None:
+            update_fields.append(f"object_id = ${param_index}")
+            params.append(object_id)
+            param_index += 1
+
+        if description is not None:
+            update_fields.append(f"description = ${param_index}")
+            params.append(description)
+            param_index += 1
+
+        if description_embedding is not None:
+            update_fields.append(f"description_embedding = ${param_index}")
+            params.append(description_embedding)
+            param_index += 1
+
+        if weight is not None:
+            update_fields.append(f"weight = ${param_index}")
+            params.append(weight)
+            param_index += 1
+
+        if not update_fields:
+            raise R2RException(status_code=400, message="No fields to update")
+
+        update_fields.append("updated_at = NOW()")
+        params.append(relationship_id)
+
+        query = f"""
+            UPDATE {self._get_table_name(table_name)}
+            SET {", ".join(update_fields)}
+            WHERE id = ${param_index}
+            RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
+        """
+
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query=query,
+                params=params,
+            )
+
+            return Relationship(
+                id=result["id"],
+                subject=result["subject"],
+                predicate=result["predicate"],
+                object=result["object"],
+                description=result["description"],
+                subject_id=result["subject_id"],
+                object_id=result["object_id"],
+                weight=result["weight"],
+                chunk_ids=result["chunk_ids"],
+                parent_id=result["parent_id"],
+                metadata=result["metadata"],
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while updating the relationship: {e}",
+            ) from e
+
+    async def delete(
+        self,
+        parent_id: UUID,
+        relationship_ids: Optional[list[UUID]] = None,
+        store_type: StoreType = StoreType.GRAPHS,
+    ) -> None:
+        """Delete relationships from the specified store. If relationship_ids
+        is not provided, deletes all relationships for the given parent_id.
+
+        Args:
+            parent_id: UUID of the parent (collection_id or document_id)
+            relationship_ids: Optional list of specific relationship IDs to delete
+            store_type: Type of store (graph or document)
+
+        Returns:
+            List of deleted relationship IDs
+
+        Raises:
+            R2RException: If specific relationships were requested but not all found
+        """
+        table_name = self._get_relationship_table_for_store(store_type)
+
+        if relationship_ids is None:
+            QUERY = f"""
+                DELETE FROM {self._get_table_name(table_name)}
+                WHERE parent_id = $1
+                RETURNING id
+            """
+            results = await self.connection_manager.fetch_query(
+                QUERY, [parent_id]
+            )
+        else:
+            QUERY = f"""
+                DELETE FROM {self._get_table_name(table_name)}
+                WHERE id = ANY($1) AND parent_id = $2
+                RETURNING id
+            """
+            results = await self.connection_manager.fetch_query(
+                QUERY, [relationship_ids, parent_id]
+            )
+
+            deleted_ids = [row["id"] for row in results]
+            if relationship_ids and len(deleted_ids) != len(relationship_ids):
+                raise R2RException(
+                    f"Some relationships not found in {store_type} store or no permission to delete",
+                    404,
+                )
+
+    async def export_to_csv(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        """Creates a CSV file from the PostgreSQL data and returns the path to
+        the temp file."""
+        valid_columns = {
+            "id",
+            "subject",
+            "predicate",
+            "object",
+            "description",
+            "subject_id",
+            "object_id",
+            "weight",
+            "chunk_ids",
+            "parent_id",
+            "metadata",
+            "created_at",
+            "updated_at",
+        }
+
+        if not columns:
+            columns = list(valid_columns)
+        elif invalid_cols := set(columns) - valid_columns:
+            raise ValueError(f"Invalid columns: {invalid_cols}")
+
+        select_stmt = f"""
+            SELECT
+                id::text,
+                subject,
+                predicate,
+                object,
+                description,
+                subject_id::text,
+                object_id::text,
+                weight,
+                chunk_ids::text,
+                parent_id::text,
+                metadata::text,
+                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+            FROM {self._get_table_name(self._get_relationship_table_for_store(store_type))}
+        """
+
+        conditions = ["parent_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if filters:
+            for field, value in filters.items():
+                if field not in valid_columns:
+                    continue
+
+                if isinstance(value, dict):
+                    for op, val in value.items():
+                        if op == "$eq":
+                            conditions.append(f"{field} = ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$gt":
+                            conditions.append(f"{field} > ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$lt":
+                            conditions.append(f"{field} < ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                else:
+                    # Direct equality
+                    conditions.append(f"{field} = ${param_index}")
+                    params.append(value)
+                    param_index += 1
+
+        if conditions:
+            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+        temp_file = None
+        try:
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w", delete=True, suffix=".csv"
+            )
+            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                async with conn.transaction():
+                    cursor = await conn.cursor(select_stmt, *params)
+
+                    if include_header:
+                        writer.writerow(columns)
+
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            writer.writerow(row)
+
+            temp_file.flush()
+            return temp_file.name, temp_file
+
+        except Exception as e:
+            if temp_file:
+                temp_file.close()
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to export data: {str(e)}",
+            ) from e
+
+
+class PostgresCommunitiesHandler(Handler):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.project_name: str = kwargs.get("project_name")  # type: ignore
+        self.connection_manager: PostgresConnectionManager = kwargs.get(
+            "connection_manager"
+        )  # type: ignore
+        self.dimension: int = kwargs.get("dimension")  # type: ignore
+        self.quantization_type: VectorQuantizationType = kwargs.get(
+            "quantization_type"
+        )  # type: ignore
+
+    async def create_tables(self) -> None:
+        vector_column_str = _get_vector_column_str(
+            self.dimension, self.quantization_type
+        )
+
+        query = f"""
+            CREATE TABLE IF NOT EXISTS {self._get_table_name("graphs_communities")} (
+            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+            collection_id UUID,
+            community_id UUID,
+            level INT,
+            name TEXT NOT NULL,
+            summary TEXT NOT NULL,
+            findings TEXT[],
+            rating FLOAT,
+            rating_explanation TEXT,
+            description_embedding {vector_column_str} NOT NULL,
+            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+            metadata JSONB,
+            UNIQUE (community_id, level, collection_id)
+        );"""
+
+        await self.connection_manager.execute_query(query)
+
+    async def create(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        name: str,
+        summary: str,
+        findings: Optional[list[str]],
+        rating: Optional[float],
+        rating_explanation: Optional[str],
+        description_embedding: Optional[list[float] | str] = None,
+    ) -> Community:
+        table_name = "graphs_communities"
+
+        if isinstance(description_embedding, list):
+            description_embedding = str(description_embedding)
+
+        query = f"""
+            INSERT INTO {self._get_table_name(table_name)}
+            (collection_id, name, summary, findings, rating, rating_explanation, description_embedding)
+            VALUES ($1, $2, $3, $4, $5, $6, $7)
+            RETURNING id, collection_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
+        """
+
+        params = [
+            parent_id,
+            name,
+            summary,
+            findings,
+            rating,
+            rating_explanation,
+            description_embedding,
+        ]
+
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query=query,
+                params=params,
+            )
+
+            return Community(
+                id=result["id"],
+                collection_id=result["collection_id"],
+                name=result["name"],
+                summary=result["summary"],
+                findings=result["findings"],
+                rating=result["rating"],
+                rating_explanation=result["rating_explanation"],
+                created_at=result["created_at"],
+                updated_at=result["updated_at"],
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while creating the community: {e}",
+            ) from e
+
+    async def update(
+        self,
+        community_id: UUID,
+        store_type: StoreType,
+        name: Optional[str] = None,
+        summary: Optional[str] = None,
+        summary_embedding: Optional[list[float] | str] = None,
+        findings: Optional[list[str]] = None,
+        rating: Optional[float] = None,
+        rating_explanation: Optional[str] = None,
+    ) -> Community:
+        table_name = "graphs_communities"
+        update_fields = []
+        params: list[Any] = []
+        param_index = 1
+
+        if name is not None:
+            update_fields.append(f"name = ${param_index}")
+            params.append(name)
+            param_index += 1
+
+        if summary is not None:
+            update_fields.append(f"summary = ${param_index}")
+            params.append(summary)
+            param_index += 1
+
+        if summary_embedding is not None:
+            update_fields.append(f"description_embedding = ${param_index}")
+            params.append(summary_embedding)
+            param_index += 1
+
+        if findings is not None:
+            update_fields.append(f"findings = ${param_index}")
+            params.append(findings)
+            param_index += 1
+
+        if rating is not None:
+            update_fields.append(f"rating = ${param_index}")
+            params.append(rating)
+            param_index += 1
+
+        if rating_explanation is not None:
+            update_fields.append(f"rating_explanation = ${param_index}")
+            params.append(rating_explanation)
+            param_index += 1
+
+        if not update_fields:
+            raise R2RException(status_code=400, message="No fields to update")
+
+        update_fields.append("updated_at = NOW()")
+        params.append(community_id)
+
+        query = f"""
+            UPDATE {self._get_table_name(table_name)}
+            SET {", ".join(update_fields)}
+            WHERE id = ${param_index}\
+            RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
+        """
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query, params
+            )
+
+            return Community(
+                id=result["id"],
+                community_id=result["community_id"],
+                name=result["name"],
+                summary=result["summary"],
+                findings=result["findings"],
+                rating=result["rating"],
+                rating_explanation=result["rating_explanation"],
+                created_at=result["created_at"],
+                updated_at=result["updated_at"],
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while updating the community: {e}",
+            ) from e
+
+    async def delete(
+        self,
+        parent_id: UUID,
+        community_id: UUID,
+    ) -> None:
+        table_name = "graphs_communities"
+
+        params = [community_id, parent_id]
+
+        # Delete the community
+        query = f"""
+            DELETE FROM {self._get_table_name(table_name)}
+            WHERE id = $1 AND collection_id = $2
+        """
+
+        try:
+            await self.connection_manager.execute_query(query, params)
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while deleting the community: {e}",
+            ) from e
+
+    async def delete_all_communities(
+        self,
+        parent_id: UUID,
+    ) -> None:
+        table_name = "graphs_communities"
+
+        params = [parent_id]
+
+        # Delete all communities for the parent_id
+        query = f"""
+            DELETE FROM {self._get_table_name(table_name)}
+            WHERE collection_id = $1
+        """
+
+        try:
+            await self.connection_manager.execute_query(query, params)
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while deleting communities: {e}",
+            ) from e
+
+    async def get(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        offset: int,
+        limit: int,
+        community_ids: Optional[list[UUID]] = None,
+        community_names: Optional[list[str]] = None,
+        include_embeddings: bool = False,
+    ):
+        """Retrieve communities from the specified store."""
+        # Do we ever want to get communities from document store?
+        table_name = "graphs_communities"
+
+        conditions = ["collection_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if community_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(community_ids)
+            param_index += 1
+
+        if community_names:
+            conditions.append(f"name = ANY(${param_index})")
+            params.append(community_names)
+            param_index += 1
+
+        select_fields = """
+            id, community_id, name, summary, findings, rating,
+            rating_explanation, level, created_at, updated_at
+        """
+        if include_embeddings:
+            select_fields += ", description_embedding"
+
+        COUNT_QUERY = f"""
+            SELECT COUNT(*)
+            FROM {self._get_table_name(table_name)}
+            WHERE {" AND ".join(conditions)}
+        """
+
+        count = (
+            await self.connection_manager.fetch_query(
+                COUNT_QUERY, params[: param_index - 1]
+            )
+        )[0]["count"]
+
+        QUERY = f"""
+            SELECT {select_fields}
+            FROM {self._get_table_name(table_name)}
+            WHERE {" AND ".join(conditions)}
+            ORDER BY created_at
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            QUERY += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        communities = []
+        for row in rows:
+            community_dict = dict(row)
+
+            communities.append(Community(**community_dict))
+
+        return communities, count
+
+    async def export_to_csv(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        """Creates a CSV file from the PostgreSQL data and returns the path to
+        the temp file."""
+        valid_columns = {
+            "id",
+            "collection_id",
+            "community_id",
+            "level",
+            "name",
+            "summary",
+            "findings",
+            "rating",
+            "rating_explanation",
+            "created_at",
+            "updated_at",
+            "metadata",
+        }
+
+        if not columns:
+            columns = list(valid_columns)
+        elif invalid_cols := set(columns) - valid_columns:
+            raise ValueError(f"Invalid columns: {invalid_cols}")
+
+        table_name = "graphs_communities"
+
+        select_stmt = f"""
+            SELECT
+                id::text,
+                collection_id::text,
+                community_id::text,
+                level,
+                name,
+                summary,
+                findings::text,
+                rating,
+                rating_explanation,
+                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
+                metadata::text
+            FROM {self._get_table_name(table_name)}
+        """
+
+        conditions = ["collection_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if filters:
+            for field, value in filters.items():
+                if field not in valid_columns:
+                    continue
+
+                if isinstance(value, dict):
+                    for op, val in value.items():
+                        if op == "$eq":
+                            conditions.append(f"{field} = ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$gt":
+                            conditions.append(f"{field} > ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$lt":
+                            conditions.append(f"{field} < ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                else:
+                    # Direct equality
+                    conditions.append(f"{field} = ${param_index}")
+                    params.append(value)
+                    param_index += 1
+
+        if conditions:
+            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+        temp_file = None
+        try:
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w", delete=True, suffix=".csv"
+            )
+            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                async with conn.transaction():
+                    cursor = await conn.cursor(select_stmt, *params)
+
+                    if include_header:
+                        writer.writerow(columns)
+
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            writer.writerow(row)
+
+            temp_file.flush()
+            return temp_file.name, temp_file
+
+        except Exception as e:
+            if temp_file:
+                temp_file.close()
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to export data: {str(e)}",
+            ) from e
+
+
+class PostgresGraphsHandler(Handler):
+    """Handler for Knowledge Graph METHODS in PostgreSQL."""
+
+    TABLE_NAME = "graphs"
+
+    def __init__(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        self.project_name: str = kwargs.get("project_name")  # type: ignore
+        self.connection_manager: PostgresConnectionManager = kwargs.get(
+            "connection_manager"
+        )  # type: ignore
+        self.dimension: int = kwargs.get("dimension")  # type: ignore
+        self.quantization_type: VectorQuantizationType = kwargs.get(
+            "quantization_type"
+        )  # type: ignore
+        self.collections_handler: PostgresCollectionsHandler = kwargs.get(
+            "collections_handler"
+        )  # type: ignore
+
+        self.entities = PostgresEntitiesHandler(*args, **kwargs)
+        self.relationships = PostgresRelationshipsHandler(*args, **kwargs)
+        self.communities = PostgresCommunitiesHandler(*args, **kwargs)
+
+        self.handlers = [
+            self.entities,
+            self.relationships,
+            self.communities,
+        ]
+
+    async def create_tables(self) -> None:
+        """Create the graph tables with mandatory collection_id support."""
+        QUERY = f"""
+            CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} (
+                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+                collection_id UUID NOT NULL,
+                name TEXT NOT NULL,
+                description TEXT,
+                status TEXT NOT NULL,
+                document_ids UUID[],
+                metadata JSONB,
+                created_at TIMESTAMPTZ DEFAULT NOW(),
+                updated_at TIMESTAMPTZ DEFAULT NOW()
+            );
+
+            CREATE INDEX IF NOT EXISTS graph_collection_id_idx
+                ON {self._get_table_name("graphs")} (collection_id);
+        """
+
+        await self.connection_manager.execute_query(QUERY)
+
+        for handler in self.handlers:
+            await handler.create_tables()
+
+    async def create(
+        self,
+        collection_id: UUID,
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+        status: str = "pending",
+    ) -> GraphResponse:
+        """Create a new graph associated with a collection."""
+
+        name = name or f"Graph {collection_id}"
+        description = description or ""
+
+        query = f"""
+            INSERT INTO {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+            (id, collection_id, name, description, status)
+            VALUES ($1, $2, $3, $4, $5)
+            RETURNING id, collection_id, name, description, status, created_at, updated_at, document_ids
+        """
+        params = [
+            collection_id,
+            collection_id,
+            name,
+            description,
+            status,
+        ]
+
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query=query,
+                params=params,
+            )
+
+            return GraphResponse(
+                id=result["id"],
+                collection_id=result["collection_id"],
+                name=result["name"],
+                description=result["description"],
+                status=result["status"],
+                created_at=result["created_at"],
+                updated_at=result["updated_at"],
+                document_ids=result["document_ids"] or [],
+            )
+        except UniqueViolationError:
+            raise R2RException(
+                message="Graph with this ID already exists",
+                status_code=409,
+            ) from None
+
+    async def reset(self, parent_id: UUID) -> None:
+        """Completely reset a graph and all associated data."""
+
+        await self.entities.delete(
+            parent_id=parent_id, store_type=StoreType.GRAPHS
+        )
+        await self.relationships.delete(
+            parent_id=parent_id, store_type=StoreType.GRAPHS
+        )
+        await self.communities.delete_all_communities(parent_id=parent_id)
+
+        # Now, update the graph record to remove any attached document IDs.
+        # This sets document_ids to an empty UUID array.
+        query = f"""
+            UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+            SET document_ids = ARRAY[]::uuid[]
+            WHERE id = $1;
+        """
+        await self.connection_manager.execute_query(query, [parent_id])
+
+    async def list_graphs(
+        self,
+        offset: int,
+        limit: int,
+        # filter_user_ids: Optional[list[UUID]] = None,
+        filter_graph_ids: Optional[list[UUID]] = None,
+        filter_collection_id: Optional[UUID] = None,
+    ) -> dict[str, list[GraphResponse] | int]:
+        conditions = []
+        params: list[Any] = []
+        param_index = 1
+
+        if filter_graph_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(filter_graph_ids)
+            param_index += 1
+
+        # if filter_user_ids:
+        #     conditions.append(f"user_id = ANY(${param_index})")
+        #     params.append(filter_user_ids)
+        #     param_index += 1
+
+        if filter_collection_id:
+            conditions.append(f"collection_id = ${param_index}")
+            params.append(filter_collection_id)
+            param_index += 1
+
+        where_clause = (
+            f"WHERE {' AND '.join(conditions)}" if conditions else ""
+        )
+
+        query = f"""
+            WITH RankedGraphs AS (
+                SELECT
+                    id, collection_id, name, description, status, created_at, updated_at, document_ids,
+                    COUNT(*) OVER() as total_entries,
+                    ROW_NUMBER() OVER (PARTITION BY collection_id ORDER BY created_at DESC) as rn
+                FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+                {where_clause}
+            )
+            SELECT * FROM RankedGraphs
+            WHERE rn = 1
+            ORDER BY created_at DESC
+            OFFSET ${param_index} LIMIT ${param_index + 1}
+        """
+
+        params.extend([offset, limit])
+
+        try:
+            results = await self.connection_manager.fetch_query(query, params)
+            if not results:
+                return {"results": [], "total_entries": 0}
+
+            total_entries = results[0]["total_entries"] if results else 0
+
+            graphs = [
+                GraphResponse(
+                    id=row["id"],
+                    document_ids=row["document_ids"] or [],
+                    name=row["name"],
+                    collection_id=row["collection_id"],
+                    description=row["description"],
+                    status=row["status"],
+                    created_at=row["created_at"],
+                    updated_at=row["updated_at"],
+                )
+                for row in results
+            ]
+
+            return {"results": graphs, "total_entries": total_entries}
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while fetching graphs: {e}",
+            ) from e
+
+    async def get(
+        self, offset: int, limit: int, graph_id: Optional[UUID] = None
+    ):
+        if graph_id is None:
+            params = [offset, limit]
+
+            QUERY = f"""
+                SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+                OFFSET $1 LIMIT $2
+            """
+
+            ret = await self.connection_manager.fetch_query(QUERY, params)
+
+            COUNT_QUERY = f"""
+                SELECT COUNT(*) FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+            """
+            count = (await self.connection_manager.fetch_query(COUNT_QUERY))[
+                0
+            ]["count"]
+
+            return {
+                "results": [Graph(**row) for row in ret],
+                "total_entries": count,
+            }
+
+        else:
+            QUERY = f"""
+                SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} WHERE id = $1
+            """
+
+            params = [graph_id]  # type: ignore
+
+            return {
+                "results": [
+                    Graph(
+                        **await self.connection_manager.fetchrow_query(
+                            QUERY, params
+                        )
+                    )
+                ]
+            }
+
+    async def add_documents(self, id: UUID, document_ids: list[UUID]) -> bool:
+        """Add documents to the graph by copying their entities and
+        relationships."""
+        # Copy entities from document_entity to graphs_entities
+        ENTITY_COPY_QUERY = f"""
+            INSERT INTO {self._get_table_name("graphs_entities")} (
+                name, category, description, parent_id, description_embedding,
+                chunk_ids, metadata
+            )
+            SELECT
+                name, category, description, $1, description_embedding,
+                chunk_ids, metadata
+            FROM {self._get_table_name("documents_entities")}
+            WHERE parent_id = ANY($2)
+        """
+        await self.connection_manager.execute_query(
+            ENTITY_COPY_QUERY, [id, document_ids]
+        )
+
+        # Copy relationships from documents_relationships to graphs_relationships
+        RELATIONSHIP_COPY_QUERY = f"""
+            INSERT INTO {self._get_table_name("graphs_relationships")} (
+                subject, predicate, object, description, subject_id, object_id,
+                weight, chunk_ids, parent_id, metadata, description_embedding
+            )
+            SELECT
+                subject, predicate, object, description, subject_id, object_id,
+                weight, chunk_ids, $1, metadata, description_embedding
+            FROM {self._get_table_name("documents_relationships")}
+            WHERE parent_id = ANY($2)
+        """
+        await self.connection_manager.execute_query(
+            RELATIONSHIP_COPY_QUERY, [id, document_ids]
+        )
+
+        # Add document_ids to the graph
+        UPDATE_GRAPH_QUERY = f"""
+            UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+            SET document_ids = array_cat(
+                CASE
+                    WHEN document_ids IS NULL THEN ARRAY[]::uuid[]
+                    ELSE document_ids
+                END,
+                $2::uuid[]
+            )
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(
+            UPDATE_GRAPH_QUERY, [id, document_ids]
+        )
+
+        return True
+
+    async def update(
+        self,
+        collection_id: UUID,
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+    ) -> GraphResponse:
+        """Update an existing graph."""
+        update_fields = []
+        params: list = []
+        param_index = 1
+
+        if name is not None:
+            update_fields.append(f"name = ${param_index}")
+            params.append(name)
+            param_index += 1
+
+        if description is not None:
+            update_fields.append(f"description = ${param_index}")
+            params.append(description)
+            param_index += 1
+
+        if not update_fields:
+            raise R2RException(status_code=400, message="No fields to update")
+
+        update_fields.append("updated_at = NOW()")
+        params.append(collection_id)
+
+        query = f"""
+            UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+            SET {", ".join(update_fields)}
+            WHERE id = ${param_index}
+            RETURNING id, name, description, status, created_at, updated_at, collection_id, document_ids
+        """
+
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query, params
+            )
+
+            if not result:
+                raise R2RException(status_code=404, message="Graph not found")
+
+            return GraphResponse(
+                id=result["id"],
+                collection_id=result["collection_id"],
+                name=result["name"],
+                description=result["description"],
+                status=result["status"],
+                created_at=result["created_at"],
+                document_ids=result["document_ids"] or [],
+                updated_at=result["updated_at"],
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while updating the graph: {e}",
+            ) from e
+
+    async def get_entities(
+        self,
+        parent_id: UUID,
+        offset: int,
+        limit: int,
+        entity_ids: Optional[list[UUID]] = None,
+        entity_names: Optional[list[str]] = None,
+        include_embeddings: bool = False,
+    ) -> tuple[list[Entity], int]:
+        """Get entities for a graph.
+
+        Args:
+            offset: Number of records to skip
+            limit: Maximum number of records to return (-1 for no limit)
+            parent_id: UUID of the collection
+            entity_ids: Optional list of entity IDs to filter by
+            entity_names: Optional list of entity names to filter by
+            include_embeddings: Whether to include embeddings in the response
+
+        Returns:
+            Tuple of (list of entities, total count)
+        """
+        conditions = ["parent_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if entity_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(entity_ids)
+            param_index += 1
+
+        if entity_names:
+            conditions.append(f"name = ANY(${param_index})")
+            params.append(entity_names)
+            param_index += 1
+
+        # Count query - uses the same conditions but without offset/limit
+        COUNT_QUERY = f"""
+            SELECT COUNT(*)
+            FROM {self._get_table_name("graphs_entities")}
+            WHERE {" AND ".join(conditions)}
+        """
+        count = (
+            await self.connection_manager.fetch_query(COUNT_QUERY, params)
+        )[0]["count"]
+
+        # Define base columns to select
+        select_fields = """
+            id, name, category, description, parent_id,
+            chunk_ids, metadata
+        """
+        if include_embeddings:
+            select_fields += ", description_embedding"
+
+        # Main query for fetching entities with pagination
+        QUERY = f"""
+            SELECT {select_fields}
+            FROM {self._get_table_name("graphs_entities")}
+            WHERE {" AND ".join(conditions)}
+            ORDER BY created_at
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            QUERY += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        entities = []
+        for row in rows:
+            entity_dict = dict(row)
+            if isinstance(entity_dict["metadata"], str):
+                with contextlib.suppress(json.JSONDecodeError):
+                    entity_dict["metadata"] = json.loads(
+                        entity_dict["metadata"]
+                    )
+
+            entities.append(Entity(**entity_dict))
+
+        return entities, count
+
+    async def get_relationships(
+        self,
+        parent_id: UUID,
+        offset: int,
+        limit: int,
+        relationship_ids: Optional[list[UUID]] = None,
+        relationship_types: Optional[list[str]] = None,
+        include_embeddings: bool = False,
+    ) -> tuple[list[Relationship], int]:
+        """Get relationships for a graph.
+
+        Args:
+            parent_id: UUID of the graph
+            offset: Number of records to skip
+            limit: Maximum number of records to return (-1 for no limit)
+            relationship_ids: Optional list of relationship IDs to filter by
+            relationship_types: Optional list of relationship types to filter by
+            include_metadata: Whether to include metadata in the response
+
+        Returns:
+            Tuple of (list of relationships, total count)
+        """
+        conditions = ["parent_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if relationship_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(relationship_ids)
+            param_index += 1
+
+        if relationship_types:
+            conditions.append(f"predicate = ANY(${param_index})")
+            params.append(relationship_types)
+            param_index += 1
+
+        # Count query - uses the same conditions but without offset/limit
+        COUNT_QUERY = f"""
+            SELECT COUNT(*)
+            FROM {self._get_table_name("graphs_relationships")}
+            WHERE {" AND ".join(conditions)}
+        """
+        count = (
+            await self.connection_manager.fetch_query(COUNT_QUERY, params)
+        )[0]["count"]
+
+        # Define base columns to select
+        select_fields = """
+            id, subject, predicate, object, weight, chunk_ids, parent_id, metadata
+        """
+        if include_embeddings:
+            select_fields += ", description_embedding"
+
+        # Main query for fetching relationships with pagination
+        QUERY = f"""
+            SELECT {select_fields}
+            FROM {self._get_table_name("graphs_relationships")}
+            WHERE {" AND ".join(conditions)}
+            ORDER BY created_at
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            QUERY += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        relationships = []
+        for row in rows:
+            relationship_dict = dict(row)
+            if isinstance(relationship_dict["metadata"], str):
+                with contextlib.suppress(json.JSONDecodeError):
+                    relationship_dict["metadata"] = json.loads(
+                        relationship_dict["metadata"]
+                    )
+
+            relationships.append(Relationship(**relationship_dict))
+
+        return relationships, count
+
+    async def add_entities(
+        self,
+        entities: list[Entity],
+        table_name: str,
+        conflict_columns: list[str] | None = None,
+    ) -> asyncpg.Record:
+        """Upsert entities into the entities_raw table. These are raw entities
+        extracted from the document.
+
+        Args:
+            entities: list[Entity]: list of entities to upsert
+            collection_name: str: name of the collection
+
+        Returns:
+            result: asyncpg.Record: result of the upsert operation
+        """
+        if not conflict_columns:
+            conflict_columns = []
+        cleaned_entities = []
+        for entity in entities:
+            entity_dict = entity.to_dict()
+            entity_dict["chunk_ids"] = (
+                entity_dict["chunk_ids"]
+                if entity_dict.get("chunk_ids")
+                else []
+            )
+            entity_dict["description_embedding"] = (
+                str(entity_dict["description_embedding"])
+                if entity_dict.get("description_embedding")  # type: ignore
+                else None
+            )
+            cleaned_entities.append(entity_dict)
+
+        return await _add_objects(
+            objects=cleaned_entities,
+            full_table_name=self._get_table_name(table_name),
+            connection_manager=self.connection_manager,
+            conflict_columns=conflict_columns,
+        )
+
+    async def get_all_relationships(
+        self,
+        collection_id: UUID | None,
+        graph_id: UUID | None,
+        document_ids: Optional[list[UUID]] = None,
+    ) -> list[Relationship]:
+        QUERY = f"""
+            SELECT id, subject, predicate, weight, object, parent_id FROM {self._get_table_name("graphs_relationships")} WHERE parent_id = ANY($1)
+        """
+        relationships = await self.connection_manager.fetch_query(
+            QUERY, [collection_id]
+        )
+
+        return [Relationship(**relationship) for relationship in relationships]
+
+    async def has_document(self, graph_id: UUID, document_id: UUID) -> bool:
+        """Check if a document exists in the graph's document_ids array.
+
+        Args:
+            graph_id (UUID): ID of the graph to check
+            document_id (UUID): ID of the document to look for
+
+        Returns:
+            bool: True if document exists in graph, False otherwise
+
+        Raises:
+            R2RException: If graph not found
+        """
+        QUERY = f"""
+            SELECT EXISTS (
+                SELECT 1
+                FROM {self._get_table_name("graphs")}
+                WHERE id = $1
+                AND document_ids IS NOT NULL
+                AND $2 = ANY(document_ids)
+            ) as exists;
+        """
+
+        result = await self.connection_manager.fetchrow_query(
+            QUERY, [graph_id, document_id]
+        )
+
+        if result is None:
+            raise R2RException(f"Graph {graph_id} not found", 404)
+
+        return result["exists"]
+
+    async def get_communities(
+        self,
+        parent_id: UUID,
+        offset: int,
+        limit: int,
+        community_ids: Optional[list[UUID]] = None,
+        include_embeddings: bool = False,
+    ) -> tuple[list[Community], int]:
+        """Get communities for a graph.
+
+        Args:
+            collection_id: UUID of the collection
+            offset: Number of records to skip
+            limit: Maximum number of records to return (-1 for no limit)
+            community_ids: Optional list of community IDs to filter by
+            include_embeddings: Whether to include embeddings in the response
+
+        Returns:
+            Tuple of (list of communities, total count)
+        """
+        conditions = ["collection_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if community_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(community_ids)
+            param_index += 1
+
+        select_fields = """
+            id, collection_id, name, summary, findings, rating, rating_explanation
+        """
+        if include_embeddings:
+            select_fields += ", description_embedding"
+
+        COUNT_QUERY = f"""
+            SELECT COUNT(*)
+            FROM {self._get_table_name("graphs_communities")}
+            WHERE {" AND ".join(conditions)}
+        """
+        count = (
+            await self.connection_manager.fetch_query(COUNT_QUERY, params)
+        )[0]["count"]
+
+        QUERY = f"""
+            SELECT {select_fields}
+            FROM {self._get_table_name("graphs_communities")}
+            WHERE {" AND ".join(conditions)}
+            ORDER BY created_at
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            QUERY += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        communities = []
+        for row in rows:
+            community_dict = dict(row)
+            communities.append(Community(**community_dict))
+
+        return communities, count
+
+    async def add_community(self, community: Community) -> None:
+        # TODO: Fix in the short term.
+        # we need to do this because postgres insert needs to be a string
+        community.description_embedding = str(community.description_embedding)  # type: ignore[assignment]
+
+        non_null_attrs = {
+            k: v for k, v in community.__dict__.items() if v is not None
+        }
+        columns = ", ".join(non_null_attrs.keys())
+        placeholders = ", ".join(
+            f"${i + 1}" for i in range(len(non_null_attrs))
+        )
+
+        conflict_columns = ", ".join(
+            [f"{k} = EXCLUDED.{k}" for k in non_null_attrs]
+        )
+
+        QUERY = f"""
+            INSERT INTO {self._get_table_name("graphs_communities")} ({columns})
+            VALUES ({placeholders})
+            ON CONFLICT (community_id, level, collection_id) DO UPDATE SET
+                {conflict_columns}
+            """
+
+        await self.connection_manager.execute_many(
+            QUERY, [tuple(non_null_attrs.values())]
+        )
+
+    async def delete(self, collection_id: UUID) -> None:
+        graphs = await self.get(graph_id=collection_id, offset=0, limit=-1)
+
+        if len(graphs["results"]) == 0:
+            raise R2RException(
+                message=f"Graph not found for collection {collection_id}",
+                status_code=404,
+            )
+        await self.reset(collection_id)
+        # set status to PENDING for this collection.
+        QUERY = f"""
+            UPDATE {self._get_table_name("collections")} SET graph_cluster_status = $1 WHERE id = $2
+        """
+        await self.connection_manager.execute_query(
+            QUERY, [GraphExtractionStatus.PENDING, collection_id]
+        )
+        # Delete the graph
+        QUERY = f"""
+            DELETE FROM {self._get_table_name("graphs")} WHERE collection_id = $1
+        """
+        try:
+            await self.connection_manager.execute_query(QUERY, [collection_id])
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while deleting the graph: {e}",
+            ) from e
+
+    async def perform_graph_clustering(
+        self,
+        collection_id: UUID,
+        leiden_params: dict[str, Any],
+    ) -> Tuple[int, Any]:
+        """Calls the external clustering service to cluster the graph."""
+
+        offset = 0
+        page_size = 1000
+        all_relationships = []
+        while True:
+            relationships, count = await self.relationships.get(
+                parent_id=collection_id,
+                store_type=StoreType.GRAPHS,
+                offset=offset,
+                limit=page_size,
+            )
+
+            if not relationships:
+                break
+
+            all_relationships.extend(relationships)
+            offset += len(relationships)
+
+            if offset >= count:
+                break
+
+        logger.info(
+            f"Clustering over {len(all_relationships)} relationships for {collection_id} with settings: {leiden_params}"
+        )
+        if len(all_relationships) == 0:
+            raise R2RException(
+                message="No relationships found for clustering",
+                status_code=400,
+            )
+
+        return await self._cluster_and_add_community_info(
+            relationships=all_relationships,
+            leiden_params=leiden_params,
+            collection_id=collection_id,
+        )
+
+    async def _call_clustering_service(
+        self, relationships: list[Relationship], leiden_params: dict[str, Any]
+    ) -> list[dict]:
+        """Calls the external Graspologic clustering service, sending
+        relationships and parameters.
+
+        Expects a response with 'communities' field.
+        """
+        # Convert relationships to a JSON-friendly format
+        rel_data = []
+        for r in relationships:
+            rel_data.append(
+                {
+                    "id": str(r.id),
+                    "subject": r.subject,
+                    "object": r.object,
+                    "weight": r.weight if r.weight is not None else 1.0,
+                }
+            )
+
+        endpoint = os.environ.get("CLUSTERING_SERVICE_URL")
+        if not endpoint:
+            raise ValueError("CLUSTERING_SERVICE_URL not set.")
+
+        url = f"{endpoint}/cluster"
+
+        payload = {"relationships": rel_data, "leiden_params": leiden_params}
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(url, json=payload, timeout=3600)
+            response.raise_for_status()
+
+        data = response.json()
+        return data.get("communities", [])
+
+    async def _create_graph_and_cluster(
+        self,
+        relationships: list[Relationship],
+        leiden_params: dict[str, Any],
+    ) -> Any:
+        """Create a graph and cluster it."""
+
+        return await self._call_clustering_service(
+            relationships, leiden_params
+        )
+
+    async def _cluster_and_add_community_info(
+        self,
+        relationships: list[Relationship],
+        leiden_params: dict[str, Any],
+        collection_id: UUID,
+    ) -> Tuple[int, Any]:
+        logger.info(f"Creating graph and clustering for {collection_id}")
+
+        await asyncio.sleep(0.1)
+        start_time = time.time()
+
+        hierarchical_communities = await self._create_graph_and_cluster(
+            relationships=relationships,
+            leiden_params=leiden_params,
+        )
+
+        logger.info(
+            f"Computing Leiden communities completed, time {time.time() - start_time:.2f} seconds."
+        )
+
+        if not hierarchical_communities:
+            num_communities = 0
+        else:
+            num_communities = (
+                max(item["cluster"] for item in hierarchical_communities) + 1
+            )
+
+        logger.info(
+            f"Generated {num_communities} communities, time {time.time() - start_time:.2f} seconds."
+        )
+
+        return num_communities, hierarchical_communities
+
+    async def get_entity_map(
+        self, offset: int, limit: int, document_id: UUID
+    ) -> dict[str, dict[str, list[dict[str, Any]]]]:
+        QUERY1 = f"""
+            WITH entities_list AS (
+                SELECT DISTINCT name
+                FROM {self._get_table_name("documents_entities")}
+                WHERE parent_id = $1
+                ORDER BY name ASC
+                LIMIT {limit} OFFSET {offset}
+            )
+            SELECT e.name, e.description, e.category,
+                   (SELECT array_agg(DISTINCT x) FROM unnest(e.chunk_ids) x) AS chunk_ids,
+                   e.parent_id
+            FROM {self._get_table_name("documents_entities")} e
+            JOIN entities_list el ON e.name = el.name
+            GROUP BY e.name, e.description, e.category, e.chunk_ids, e.parent_id
+            ORDER BY e.name;"""
+
+        entities_list = await self.connection_manager.fetch_query(
+            QUERY1, [document_id]
+        )
+        entities_list = [Entity(**entity) for entity in entities_list]
+
+        QUERY2 = f"""
+            WITH entities_list AS (
+
+                SELECT DISTINCT name
+                FROM {self._get_table_name("documents_entities")}
+                WHERE parent_id = $1
+                ORDER BY name ASC
+                LIMIT {limit} OFFSET {offset}
+            )
+
+            SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description,
+                   (SELECT array_agg(DISTINCT x) FROM unnest(t.chunk_ids) x) AS chunk_ids, t.parent_id
+            FROM {self._get_table_name("documents_relationships")} t
+            JOIN entities_list el ON t.subject = el.name
+            ORDER BY t.subject, t.predicate, t.object;
+        """
+
+        relationships_list = await self.connection_manager.fetch_query(
+            QUERY2, [document_id]
+        )
+        relationships_list = [
+            Relationship(**relationship) for relationship in relationships_list
+        ]
+
+        entity_map: dict[str, dict[str, list[Any]]] = {}
+        for entity in entities_list:
+            if entity.name not in entity_map:
+                entity_map[entity.name] = {"entities": [], "relationships": []}
+            entity_map[entity.name]["entities"].append(entity)
+
+        for relationship in relationships_list:
+            if relationship.subject in entity_map:
+                entity_map[relationship.subject]["relationships"].append(
+                    relationship
+                )
+            if relationship.object in entity_map:
+                entity_map[relationship.object]["relationships"].append(
+                    relationship
+                )
+
+        return entity_map
+
+    async def graph_search(
+        self, query: str, **kwargs: Any
+    ) -> AsyncGenerator[Any, None]:
+        """Perform semantic search with similarity scores while maintaining
+        exact same structure."""
+
+        query_embedding = kwargs.get("query_embedding", None)
+        if query_embedding is None:
+            raise ValueError(
+                "query_embedding must be provided for semantic search"
+            )
+
+        search_type = kwargs.get(
+            "search_type", "entities"
+        )  # entities | relationships | communities
+        embedding_type = kwargs.get("embedding_type", "description_embedding")
+        property_names = kwargs.get("property_names", ["name", "description"])
+
+        # Add metadata if not present
+        if "metadata" not in property_names:
+            property_names.append("metadata")
+
+        filters = kwargs.get("filters", {})
+        limit = kwargs.get("limit", 10)
+        use_fulltext_search = kwargs.get("use_fulltext_search", True)
+        use_hybrid_search = kwargs.get("use_hybrid_search", True)
+
+        if use_hybrid_search or use_fulltext_search:
+            logger.warning(
+                "Hybrid and fulltext search not supported for graph search, ignoring."
+            )
+
+        table_name = f"graphs_{search_type}"
+        property_names_str = ", ".join(property_names)
+
+        # Build the WHERE clause from filters
+        params: list[str | int | bytes] = [
+            json.dumps(query_embedding),
+            limit,
+        ]
+        conditions_clause = self._build_filters(filters, params, search_type)
+        where_clause = (
+            f"WHERE {conditions_clause}" if conditions_clause else ""
+        )
+
+        # Construct the query
+        # Note: For vector similarity, we use <=> for distance. The smaller the number, the more similar.
+        # We'll convert that to similarity_score by doing (1 - distance).
+        QUERY = f"""
+            SELECT
+                {property_names_str},
+                ({embedding_type} <=> $1) as similarity_score
+            FROM {self._get_table_name(table_name)}
+            {where_clause}
+            ORDER BY {embedding_type} <=> $1
+            LIMIT $2;
+        """
+
+        results = await self.connection_manager.fetch_query(
+            QUERY, tuple(params)
+        )
+
+        for result in results:
+            output = {
+                prop: result[prop] for prop in property_names if prop in result
+            }
+            output["similarity_score"] = (
+                1 - float(result["similarity_score"])
+                if result.get("similarity_score")
+                else "n/a"
+            )
+            yield output
+
+    def _build_filters(
+        self, filter_dict: dict, parameters: list[Any], search_type: str
+    ) -> str:
+        """Build a WHERE clause from a nested filter dictionary for the graph
+        search.
+
+        - If search_type == "communities", we normally filter by `collection_id`.
+        - Otherwise (entities/relationships), we normally filter by `parent_id`.
+        - If user provides `"collection_ids": {...}`, we interpret that as wanting
+        to filter by multiple collection IDs (i.e. 'parent_id IN (...)' or
+        'collection_id IN (...)').
+        """
+
+        # The usual "base" column used by your code
+        base_id_column = (
+            "collection_id" if search_type == "communities" else "parent_id"
+        )
+
+        def parse_condition(key: str, value: Any) -> str:
+            # ----------------------------------------------------------------------
+            # 1) If it's the normal base_id_column (like "parent_id" or "collection_id")
+            # ----------------------------------------------------------------------
+            if key == base_id_column:
+                if isinstance(value, dict):
+                    op, clause = next(iter(value.items()))
+                    if op == "$eq":
+                        # single equality
+                        parameters.append(str(clause))
+                        return f"{base_id_column} = ${len(parameters)}::uuid"
+                    elif op in ("$in", "$overlap"):
+                        # treat both $in/$overlap as "IN the set" for a single column
+                        array_val = [str(x) for x in clause]
+                        parameters.append(array_val)
+                        return f"{base_id_column} = ANY(${len(parameters)}::uuid[])"
+                    # handle other operators as needed
+                else:
+                    # direct equality
+                    parameters.append(str(value))
+                    return f"{base_id_column} = ${len(parameters)}::uuid"
+
+            # ----------------------------------------------------------------------
+            # 2) SPECIAL: if user specifically sets "collection_ids" in filters
+            #    We interpret that to mean "Look for rows whose parent_id (or collection_id)
+            #    is in the array of values" – i.e. we do the same logic but we forcibly
+            #    direct it to the same column: parent_id or collection_id.
+            # ----------------------------------------------------------------------
+            elif key == "collection_ids":
+                # If we are searching communities, the relevant field is `collection_id`.
+                # If searching entities/relationships, the relevant field is `parent_id`.
+                col_to_use = (
+                    "collection_id"
+                    if search_type == "communities"
+                    else "parent_id"
+                )
+
+                if isinstance(value, dict):
+                    op, clause = next(iter(value.items()))
+                    if op == "$eq":
+                        # single equality => col_to_use = clause
+                        parameters.append(str(clause))
+                        return f"{col_to_use} = ${len(parameters)}::uuid"
+                    elif op in ("$in", "$overlap"):
+                        # "col_to_use = ANY($param::uuid[])"
+                        array_val = [str(x) for x in clause]
+                        parameters.append(array_val)
+                        return (
+                            f"{col_to_use} = ANY(${len(parameters)}::uuid[])"
+                        )
+                    # add more if you want, e.g. $ne, $gt, etc.
+                else:
+                    # direct equality scenario: "collection_ids": "some-uuid"
+                    parameters.append(str(value))
+                    return f"{col_to_use} = ${len(parameters)}::uuid"
+
+            # ----------------------------------------------------------------------
+            # 3) If key starts with "metadata.", handle metadata-based filters
+            # ----------------------------------------------------------------------
+            elif key.startswith("metadata."):
+                field = key.split("metadata.")[1]
+                if isinstance(value, dict):
+                    op, clause = next(iter(value.items()))
+                    if op == "$eq":
+                        parameters.append(clause)
+                        return f"(metadata->>'{field}') = ${len(parameters)}"
+                    elif op == "$ne":
+                        parameters.append(clause)
+                        return f"(metadata->>'{field}') != ${len(parameters)}"
+                    elif op == "$gt":
+                        parameters.append(clause)
+                        return f"(metadata->>'{field}')::float > ${len(parameters)}::float"
+                    # etc...
+                else:
+                    parameters.append(value)
+                    return f"(metadata->>'{field}') = ${len(parameters)}"
+
+            # ----------------------------------------------------------------------
+            # 4) Not recognized => return empty so we skip it
+            # ----------------------------------------------------------------------
+            return ""
+
+        # --------------------------------------------------------------------------
+        # 5) parse_filter() is the recursive walker that sees $and/$or or normal fields
+        # --------------------------------------------------------------------------
+        def parse_filter(fd: dict) -> str:
+            filter_conditions = []
+            for k, v in fd.items():
+                if k == "$and":
+                    and_parts = [parse_filter(sub) for sub in v if sub]
+                    and_parts = [x for x in and_parts if x.strip()]
+                    if and_parts:
+                        filter_conditions.append(
+                            f"({' AND '.join(and_parts)})"
+                        )
+                elif k == "$or":
+                    or_parts = [parse_filter(sub) for sub in v if sub]
+                    or_parts = [x for x in or_parts if x.strip()]
+                    if or_parts:
+                        filter_conditions.append(f"({' OR '.join(or_parts)})")
+                else:
+                    c = parse_condition(k, v)
+                    if c and c.strip():
+                        filter_conditions.append(c)
+
+            if not filter_conditions:
+                return ""
+            if len(filter_conditions) == 1:
+                return filter_conditions[0]
+            return " AND ".join(filter_conditions)
+
+        return parse_filter(filter_dict)
+
+    async def get_existing_document_entity_chunk_ids(
+        self, document_id: UUID
+    ) -> list[str]:
+        QUERY = f"""
+            SELECT DISTINCT unnest(chunk_ids) AS chunk_id FROM {self._get_table_name("documents_entities")} WHERE parent_id = $1
+        """
+        return [
+            item["chunk_id"]
+            for item in await self.connection_manager.fetch_query(
+                QUERY, [document_id]
+            )
+        ]
+
+    async def get_entity_count(
+        self,
+        collection_id: Optional[UUID] = None,
+        document_id: Optional[UUID] = None,
+        distinct: bool = False,
+        entity_table_name: str = "entity",
+    ) -> int:
+        if collection_id is None and document_id is None:
+            raise ValueError(
+                "Either collection_id or document_id must be provided."
+            )
+
+        conditions = ["parent_id = $1"]
+        params = [str(document_id)]
+
+        count_value = "DISTINCT name" if distinct else "*"
+
+        QUERY = f"""
+            SELECT COUNT({count_value}) FROM {self._get_table_name(entity_table_name)}
+            WHERE {" AND ".join(conditions)}
+        """
+
+        return (await self.connection_manager.fetch_query(QUERY, params))[0][
+            "count"
+        ]
+
+    async def update_entity_descriptions(self, entities: list[Entity]):
+        query = f"""
+            UPDATE {self._get_table_name("graphs_entities")}
+            SET description = $3, description_embedding = $4
+            WHERE name = $1 AND graph_id = $2
+        """
+
+        inputs = [
+            (
+                entity.name,
+                entity.parent_id,
+                entity.description,
+                entity.description_embedding,
+            )
+            for entity in entities
+        ]
+
+        await self.connection_manager.execute_many(query, inputs)  # type: ignore
+
+
+def _json_serialize(obj):
+    if isinstance(obj, UUID):
+        return str(obj)
+    elif isinstance(obj, (datetime.datetime, datetime.date)):
+        return obj.isoformat()
+    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
+
+
+async def _add_objects(
+    objects: list[dict],
+    full_table_name: str,
+    connection_manager: PostgresConnectionManager,
+    conflict_columns: list[str] | None = None,
+    exclude_metadata: list[str] | None = None,
+) -> list[UUID]:
+    """Bulk insert objects into the specified table using
+    jsonb_to_recordset."""
+
+    if conflict_columns is None:
+        conflict_columns = []
+    if exclude_metadata is None:
+        exclude_metadata = []
+
+    # Exclude specified metadata and prepare data
+    cleaned_objects = []
+    for obj in objects:
+        cleaned_obj = {
+            k: v
+            for k, v in obj.items()
+            if k not in exclude_metadata and v is not None
+        }
+        cleaned_objects.append(cleaned_obj)
+
+    # Serialize the list of objects to JSON
+    json_data = json.dumps(cleaned_objects, default=_json_serialize)
+
+    # Prepare the column definitions for jsonb_to_recordset
+
+    columns = cleaned_objects[0].keys()
+    column_defs = []
+    for col in columns:
+        # Map Python types to PostgreSQL types
+        sample_value = cleaned_objects[0][col]
+        if "embedding" in col:
+            pg_type = "vector"
+        elif "chunk_ids" in col or "document_ids" in col or "graph_ids" in col:
+            pg_type = "uuid[]"
+        elif col == "id" or "_id" in col:
+            pg_type = "uuid"
+        elif isinstance(sample_value, str):
+            pg_type = "text"
+        elif isinstance(sample_value, UUID):
+            pg_type = "uuid"
+        elif isinstance(sample_value, (int, float)):
+            pg_type = "numeric"
+        elif isinstance(sample_value, list) and all(
+            isinstance(x, UUID) for x in sample_value
+        ):
+            pg_type = "uuid[]"
+        elif isinstance(sample_value, list):
+            pg_type = "jsonb"
+        elif isinstance(sample_value, dict):
+            pg_type = "jsonb"
+        elif isinstance(sample_value, bool):
+            pg_type = "boolean"
+        elif isinstance(sample_value, (datetime.datetime, datetime.date)):
+            pg_type = "timestamp"
+        else:
+            raise TypeError(
+                f"Unsupported data type for column '{col}': {type(sample_value)}"
+            )
+
+        column_defs.append(f"{col} {pg_type}")
+
+    columns_str = ", ".join(columns)
+    column_defs_str = ", ".join(column_defs)
+
+    if conflict_columns:
+        conflict_columns_str = ", ".join(conflict_columns)
+        update_columns_str = ", ".join(
+            f"{col}=EXCLUDED.{col}"
+            for col in columns
+            if col not in conflict_columns
+        )
+        on_conflict_clause = f"ON CONFLICT ({conflict_columns_str}) DO UPDATE SET {update_columns_str}"
+    else:
+        on_conflict_clause = ""
+
+    QUERY = f"""
+        INSERT INTO {full_table_name} ({columns_str})
+        SELECT {columns_str}
+        FROM jsonb_to_recordset($1::jsonb)
+        AS x({column_defs_str})
+        {on_conflict_clause}
+        RETURNING id;
+    """
+
+    # Execute the query
+    result = await connection_manager.fetch_query(QUERY, [json_data])
+
+    # Extract and return the IDs
+    return [record["id"] for record in result]