aboutsummaryrefslogtreecommitdiff
import json
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Any, Optional
from uuid import UUID

import aiofiles

from shared.api.models import (
    WrappedBooleanResponse,
    WrappedChunksResponse,
    WrappedCollectionsResponse,
    WrappedDocumentResponse,
    WrappedDocumentSearchResponse,
    WrappedDocumentsResponse,
    WrappedEntitiesResponse,
    WrappedGenericMessageResponse,
    WrappedIngestionResponse,
    WrappedRelationshipsResponse,
)

from ..models import IngestionMode, SearchMode, SearchSettings


class DocumentsSDK:
    """SDK for interacting with documents in the v3 API."""

    def __init__(self, client):
        self.client = client

    async def create(
        self,
        file_path: Optional[str] = None,
        raw_text: Optional[str] = None,
        chunks: Optional[list[str]] = None,
        id: Optional[str | UUID] = None,
        ingestion_mode: Optional[str] = None,
        collection_ids: Optional[list[str | UUID]] = None,
        metadata: Optional[dict] = None,
        ingestion_config: Optional[dict | IngestionMode] = None,
        run_with_orchestration: Optional[bool] = True,
    ) -> WrappedIngestionResponse:
        """Create a new document from either a file or content.

        Args:
            file_path (Optional[str]): The file to upload, if any
            content (Optional[str]): Optional text content to upload, if no file path is provided
            id (Optional[str | UUID]): Optional ID to assign to the document
            collection_ids (Optional[list[str | UUID]]): Collection IDs to associate with the document. If none are provided, the document will be assigned to the user's default collection.
            metadata (Optional[dict]): Optional metadata to assign to the document
            ingestion_config (Optional[dict]): Optional ingestion configuration to use
            run_with_orchestration (Optional[bool]): Whether to run with orchestration

        Returns:
            WrappedIngestionResponse
        """
        if not file_path and not raw_text and not chunks:
            raise ValueError(
                "Either `file_path`, `raw_text` or `chunks` must be provided"
            )
        if (
            (file_path and raw_text)
            or (file_path and chunks)
            or (raw_text and chunks)
        ):
            raise ValueError(
                "Only one of `file_path`, `raw_text` or `chunks` may be provided"
            )

        data: dict[str, Any] = {}
        files = None

        if id:
            data["id"] = str(id)
        if metadata:
            data["metadata"] = json.dumps(metadata)
        if ingestion_config:
            if isinstance(ingestion_config, IngestionMode):
                ingestion_config = {"mode": ingestion_config.value}
            app_config: dict[str, Any] = (
                {}
                if isinstance(ingestion_config, dict)
                else ingestion_config["app"]
            )
            ingestion_config = dict(ingestion_config)
            ingestion_config["app"] = app_config
            data["ingestion_config"] = json.dumps(ingestion_config)
        if collection_ids:
            collection_ids = [
                str(collection_id) for collection_id in collection_ids
            ]  # type: ignore
            data["collection_ids"] = json.dumps(collection_ids)
        if run_with_orchestration is not None:
            data["run_with_orchestration"] = str(run_with_orchestration)
        if ingestion_mode is not None:
            data["ingestion_mode"] = ingestion_mode
        if file_path:
            # Create a new file instance that will remain open during the request
            file_instance = open(file_path, "rb")
            files = [
                (
                    "file",
                    (file_path, file_instance, "application/octet-stream"),
                )
            ]
            try:
                response_dict = await self.client._make_request(
                    "POST",
                    "documents",
                    data=data,
                    files=files,
                    version="v3",
                )
            finally:
                # Ensure we close the file after the request is complete
                file_instance.close()
        elif raw_text:
            data["raw_text"] = raw_text  # type: ignore
            response_dict = await self.client._make_request(
                "POST",
                "documents",
                data=data,
                version="v3",
            )
        else:
            data["chunks"] = json.dumps(chunks)
            response_dict = await self.client._make_request(
                "POST",
                "documents",
                data=data,
                version="v3",
            )

        return WrappedIngestionResponse(**response_dict)

    async def append_metadata(
        self,
        id: str | UUID,
        metadata: list[dict],
    ) -> WrappedDocumentResponse:
        """Append metadata to a document.

        Args:
            id (str | UUID): ID of document to append metadata to
            metadata (list[dict]): Metadata to append

        Returns:
            WrappedDocumentResponse
        """
        data = json.dumps(metadata)
        response_dict = await self.client._make_request(
            "PATCH",
            f"documents/{str(id)}/metadata",
            data=data,
            version="v3",
        )

        return WrappedDocumentResponse(**response_dict)

    async def replace_metadata(
        self,
        id: str | UUID,
        metadata: list[dict],
    ) -> WrappedDocumentResponse:
        """Replace metadata for a document.

        Args:
            id (str | UUID): ID of document to replace metadata for
            metadata (list[dict]): The metadata that will replace the existing metadata
        """
        data = json.dumps(metadata)
        response_dict = await self.client._make_request(
            "PUT",
            f"documents/{str(id)}/metadata",
            data=data,
            version="v3",
        )

        return WrappedDocumentResponse(**response_dict)

    async def retrieve(
        self,
        id: str | UUID,
    ) -> WrappedDocumentResponse:
        """Get a specific document by ID.

        Args:
            id (str | UUID): ID of document to retrieve

        Returns:
            WrappedDocumentResponse
        """
        response_dict = await self.client._make_request(
            "GET",
            f"documents/{str(id)}",
            version="v3",
        )

        return WrappedDocumentResponse(**response_dict)

    async def download(
        self,
        id: str | UUID,
    ) -> BytesIO:
        response = await self.client._make_request(
            "GET",
            f"documents/{str(id)}/download",
            version="v3",
        )
        if not isinstance(response, BytesIO):
            raise ValueError("Expected BytesIO response")
        return response

    async def download_zip(
        self,
        document_ids: Optional[list[str | UUID]] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        output_path: Optional[str | Path] = None,
    ) -> BytesIO | None:
        """Download multiple documents as a zip file."""
        params: dict[str, Any] = {}
        if document_ids:
            params["document_ids"] = [str(doc_id) for doc_id in document_ids]
        if start_date:
            params["start_date"] = start_date.isoformat()
        if end_date:
            params["end_date"] = end_date.isoformat()

        response = await self.client._make_request(
            "GET",
            "documents/download_zip",
            params=params,
            version="v3",
        )

        if not isinstance(response, BytesIO):
            raise ValueError("Expected BytesIO response")

        if output_path:
            output_path = (
                Path(output_path)
                if isinstance(output_path, str)
                else output_path
            )
            async with aiofiles.open(output_path, "wb") as f:
                await f.write(response.getvalue())
            return None

        return response

    async def export(
        self,
        output_path: str | Path,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> None:
        """Export documents to a CSV file, streaming the results directly to
        disk.

        Args:
            output_path (str | Path): Local path where the CSV file should be saved
            columns (Optional[list[str]]): Specific columns to export. If None, exports default columns
            filters (Optional[dict]): Optional filters to apply when selecting documents
            include_header (bool): Whether to include column headers in the CSV (default: True)

        Returns:
            None
        """
        # Convert path to string if it's a Path object
        output_path = (
            str(output_path) if isinstance(output_path, Path) else output_path
        )

        data: dict[str, Any] = {"include_header": include_header}
        if columns:
            data["columns"] = columns
        if filters:
            data["filters"] = filters

        # Stream response directly to file
        async with aiofiles.open(output_path, "wb") as f:
            async with self.client.session.post(
                f"{self.client.base_url}/v3/documents/export",
                json=data,
                headers={
                    "Accept": "text/csv",
                    **self.client._get_auth_header(),
                },
            ) as response:
                if response.status != 200:
                    raise ValueError(
                        f"Export failed with status {response.status}",
                        response,
                    )

                async for chunk in response.content.iter_chunks():
                    if chunk:
                        await f.write(chunk[0])

    async def export_entities(
        self,
        id: str | UUID,
        output_path: str | Path,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> None:
        """Export documents to a CSV file, streaming the results directly to
        disk.

        Args:
            output_path (str | Path): Local path where the CSV file should be saved
            columns (Optional[list[str]]): Specific columns to export. If None, exports default columns
            filters (Optional[dict]): Optional filters to apply when selecting documents
            include_header (bool): Whether to include column headers in the CSV (default: True)

        Returns:
            None
        """
        # Convert path to string if it's a Path object
        output_path = (
            str(output_path) if isinstance(output_path, Path) else output_path
        )

        # Prepare request data
        data: dict[str, Any] = {"include_header": include_header}
        if columns:
            data["columns"] = columns
        if filters:
            data["filters"] = filters

        # Stream response directly to file
        async with aiofiles.open(output_path, "wb") as f:
            async with self.client.session.post(
                f"{self.client.base_url}/v3/documents/{str(id)}/entities/export",
                json=data,
                headers={
                    "Accept": "text/csv",
                    **self.client._get_auth_header(),
                },
            ) as response:
                if response.status != 200:
                    raise ValueError(
                        f"Export failed with status {response.status}",
                        response,
                    )

                async for chunk in response.content.iter_chunks():
                    if chunk:
                        await f.write(chunk[0])

    async def export_relationships(
        self,
        id: str | UUID,
        output_path: str | Path,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> None:
        """Export document relationships to a CSV file, streaming the results
        directly to disk.

        Args:
            output_path (str | Path): Local path where the CSV file should be saved
            columns (Optional[list[str]]): Specific columns to export. If None, exports default columns
            filters (Optional[dict]): Optional filters to apply when selecting documents
            include_header (bool): Whether to include column headers in the CSV (default: True)

        Returns:
            None
        """
        # Convert path to string if it's a Path object
        output_path = (
            str(output_path) if isinstance(output_path, Path) else output_path
        )

        # Prepare request data
        data: dict[str, Any] = {"include_header": include_header}
        if columns:
            data["columns"] = columns
        if filters:
            data["filters"] = filters

        # Stream response directly to file
        async with aiofiles.open(output_path, "wb") as f:
            async with self.client.session.post(
                f"{self.client.base_url}/v3/documents/{str(id)}/relationships/export",
                json=data,
                headers={
                    "Accept": "text/csv",
                    **self.client._get_auth_header(),
                },
            ) as response:
                if response.status != 200:
                    raise ValueError(
                        f"Export failed with status {response.status}",
                        response,
                    )

                async for chunk in response.content.iter_chunks():
                    if chunk:
                        await f.write(chunk[0])

    async def delete(
        self,
        id: str | UUID,
    ) -> WrappedBooleanResponse:
        """Delete a specific document.

        Args:
            id (str | UUID): ID of document to delete

        Returns:
            WrappedBooleanResponse
        """
        response_dict = await self.client._make_request(
            "DELETE",
            f"documents/{str(id)}",
            version="v3",
        )

        return WrappedBooleanResponse(**response_dict)

    async def list_chunks(
        self,
        id: str | UUID,
        include_vectors: Optional[bool] = False,
        offset: Optional[int] = 0,
        limit: Optional[int] = 100,
    ) -> WrappedChunksResponse:
        """Get chunks for a specific document.

        Args:
            id (str | UUID): ID of document to retrieve chunks for
            include_vectors (Optional[bool]): Whether to include vector embeddings in the response
            offset (int, optional): Specifies the number of objects to skip. Defaults to 0.
            limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.

        Returns:
            WrappedChunksResponse
        """
        params = {
            "offset": offset,
            "limit": limit,
            "include_vectors": include_vectors,
        }
        response_dict = await self.client._make_request(
            "GET",
            f"documents/{str(id)}/chunks",
            params=params,
            version="v3",
        )

        return WrappedChunksResponse(**response_dict)

    async def list_collections(
        self,
        id: str | UUID,
        include_vectors: Optional[bool] = False,
        offset: Optional[int] = 0,
        limit: Optional[int] = 100,
    ) -> WrappedCollectionsResponse:
        """List collections for a specific document.

        Args:
            id (str | UUID): ID of document to retrieve collections for
            offset (int, optional): Specifies the number of objects to skip. Defaults to 0.
            limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.

        Returns:
            WrappedCollectionsResponse
        """
        params = {
            "offset": offset,
            "limit": limit,
        }

        response_dict = await self.client._make_request(
            "GET",
            f"documents/{str(id)}/collections",
            params=params,
            version="v3",
        )

        return WrappedCollectionsResponse(**response_dict)

    async def delete_by_filter(
        self,
        filters: dict,
    ) -> WrappedBooleanResponse:
        """Delete documents based on filters.

        Args:
            filters (dict): Filters to apply when selecting documents to delete

        Returns:
            WrappedBooleanResponse
        """
        filters_json = json.dumps(filters)
        response_dict = await self.client._make_request(
            "DELETE",
            "documents/by-filter",
            data=filters_json,
            version="v3",
        )

        return WrappedBooleanResponse(**response_dict)

    async def extract(
        self,
        id: str | UUID,
        settings: Optional[dict] = None,
        run_with_orchestration: Optional[bool] = True,
    ) -> WrappedGenericMessageResponse:
        """Extract entities and relationships from a document.

        Args:
            id (str, UUID): ID of document to extract from
            settings (Optional[dict]): Settings for extraction process
            run_with_orchestration (Optional[bool]): Whether to run with orchestration

        Returns:
            WrappedGenericMessageResponse
        """
        data: dict[str, Any] = {}
        if settings:
            data["settings"] = json.dumps(settings)
        if run_with_orchestration is not None:
            data["run_with_orchestration"] = str(run_with_orchestration)

        response_dict = await self.client._make_request(
            "POST",
            f"documents/{str(id)}/extract",
            params=data,
            version="v3",
        )
        return WrappedGenericMessageResponse(**response_dict)

    async def list_entities(
        self,
        id: str | UUID,
        offset: Optional[int] = 0,
        limit: Optional[int] = 100,
        include_embeddings: Optional[bool] = False,
    ) -> WrappedEntitiesResponse:
        """List entities extracted from a document.

        Args:
            id (str | UUID): ID of document to get entities from
            offset (Optional[int]): Number of items to skip
            limit (Optional[int]): Max number of items to return
            include_embeddings (Optional[bool]): Whether to include embeddings

        Returns:
            WrappedEntitiesResponse
        """
        params = {
            "offset": offset,
            "limit": limit,
            "include_embeddings": include_embeddings,
        }
        response_dict = await self.client._make_request(
            "GET",
            f"documents/{str(id)}/entities",
            params=params,
            version="v3",
        )

        return WrappedEntitiesResponse(**response_dict)

    async def list_relationships(
        self,
        id: str | UUID,
        offset: Optional[int] = 0,
        limit: Optional[int] = 100,
        entity_names: Optional[list[str]] = None,
        relationship_types: Optional[list[str]] = None,
    ) -> WrappedRelationshipsResponse:
        """List relationships extracted from a document.

        Args:
            id (str | UUID): ID of document to get relationships from
            offset (Optional[int]): Number of items to skip
            limit (Optional[int]): Max number of items to return
            entity_names (Optional[list[str]]): Filter by entity names
            relationship_types (Optional[list[str]]): Filter by relationship types

        Returns:
            WrappedRelationshipsResponse
        """
        params: dict[str, Any] = {
            "offset": offset,
            "limit": limit,
        }
        if entity_names:
            params["entity_names"] = entity_names
        if relationship_types:
            params["relationship_types"] = relationship_types

        response_dict = await self.client._make_request(
            "GET",
            f"documents/{str(id)}/relationships",
            params=params,
            version="v3",
        )

        return WrappedRelationshipsResponse(**response_dict)

    async def list(
        self,
        ids: Optional[list[str | UUID]] = None,
        offset: Optional[int] = 0,
        limit: Optional[int] = 100,
    ) -> WrappedDocumentsResponse:
        """List documents with pagination.

        Args:
            ids (Optional[list[str | UUID]]): Optional list of document IDs to filter by
            offset (int, optional): Specifies the number of objects to skip. Defaults to 0.
            limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.

        Returns:
            WrappedDocumentsResponse
        """
        params = {
            "offset": offset,
            "limit": limit,
        }
        if ids:
            params["ids"] = [str(doc_id) for doc_id in ids]  # type: ignore

        response_dict = await self.client._make_request(
            "GET",
            "documents",
            params=params,
            version="v3",
        )

        return WrappedDocumentsResponse(**response_dict)

    async def search(
        self,
        query: str,
        search_mode: Optional[str | SearchMode] = "custom",
        search_settings: Optional[dict | SearchSettings] = None,
    ) -> WrappedDocumentSearchResponse:
        """Conduct a vector and/or graph search.

        Args:
            query (str): The query to search for.
            search_settings (Optional[dict, SearchSettings]]): Vector search settings.

        Returns:
            WrappedDocumentSearchResponse
        """
        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()
        data: dict[str, Any] = {
            "query": query,
            "search_settings": search_settings,
        }
        if search_mode:
            data["search_mode"] = search_mode

        response_dict = await self.client._make_request(
            "POST",
            "documents/search",
            json=data,
            version="v3",
        )

        return WrappedDocumentSearchResponse(**response_dict)

    async def deduplicate(
        self,
        id: str | UUID,
        settings: Optional[dict] = None,
        run_with_orchestration: Optional[bool] = True,
    ) -> WrappedGenericMessageResponse:
        """Deduplicate entities and relationships from a document.

        Args:
            id (str, UUID): ID of document to extract from
            settings (Optional[dict]): Settings for extraction process
            run_with_orchestration (Optional[bool]): Whether to run with orchestration

        Returns:
            WrappedGenericMessageResponse
        """
        data: dict[str, Any] = {}
        if settings:
            data["settings"] = json.dumps(settings)
        if run_with_orchestration is not None:
            data["run_with_orchestration"] = str(run_with_orchestration)

        response_dict = await self.client._make_request(
            "POST",
            f"documents/{str(id)}/deduplicate",
            params=data,
            version="v3",
        )

        return WrappedGenericMessageResponse(**response_dict)