about summary refs log tree commit diff
import ast
import asyncio
import json
import os
import uuid
from typing import Optional, Union

from fastapi import UploadFile

from r2r.base import (
    AnalysisTypes,
    FilterCriteria,
    GenerationConfig,
    KGSearchSettings,
    VectorSearchSettings,
    generate_id_from_label,
)

from .api.client import R2RClient
from .assembly.builder import R2RBuilder
from .assembly.config import R2RConfig
from .r2r import R2R


class R2RExecutionWrapper:
    """Run R2R operations against a remote server or an in-process app.

    In client mode (the default) each call is forwarded to an ``R2RClient``
    pointed at ``base_url``, and most methods return the ``"results"`` field
    of the response. Otherwise an ``R2R`` application is built from a config
    and each method calls it directly. The class is also exposed as a CLI via
    ``fire``. That is why several methods accept comma-separated strings
    where a list is expected.
    """

    def __init__(
        self,
        config_path: Optional[str] = None,
        config_name: Optional[str] = "default",
        client_mode: bool = True,
        base_url="http://localhost:8000",
    ):
        """Create the wrapper in client or local mode.

        Args:
            config_path: Path to a JSON config file (local mode only).
            config_name: Key into ``R2RBuilder.CONFIG_OPTIONS`` (local mode
                only). Falsy values fall back to ``"default"``.
            client_mode: Talk to a server at ``base_url`` when true. A string
                (as the CLI may pass) counts as true only if it equals
                ``"true"``, case-insensitively.
            base_url: Server URL used in client mode.

        Raises:
            ValueError: If ``config_path`` is combined with a config name
                other than ``"default"``.
        """
        # config_name defaults to "default", so only an explicitly chosen,
        # different name conflicts with an explicit config_path.
        if config_path and config_name and config_name != "default":
            raise ValueError("Cannot specify both config_path and config_name")

        # Handle fire CLI
        if isinstance(client_mode, str):
            client_mode = client_mode.lower() == "true"
        self.client_mode = client_mode
        self.base_url = base_url

        if self.client_mode:
            self.client = R2RClient(base_url)
            self.app = None
        else:
            config = (
                R2RConfig.from_json(config_path)
                if config_path
                else R2RConfig.from_json(
                    R2RBuilder.CONFIG_OPTIONS[config_name or "default"]
                )
            )

            self.client = None
            self.app = R2R(config=config)

    def serve(self, host: str = "0.0.0.0", port: int = 8000):
        """Serve the local app on ``host``:``port``.

        Raises:
            ValueError: If called in client mode.
        """
        if not self.client_mode:
            self.app.serve(host, port)
        else:
            raise ValueError(
                "Serve method is only available when `client_mode=False`."
            )

    @staticmethod
    def _parse_metadata_string(metadata_string: str) -> list[dict]:
        """
        Convert a string representation of metadata into a list of dictionaries.

        The input string can be in one of two formats:
        1. JSON array of objects: '[{"key": "value"}, {"key2": "value2"}]'
        2. Python-like list of dictionaries: "[{'key': 'value'}, {'key2': 'value2'}]"

        An empty string yields an empty list.

        Args:
        metadata_string (str): The string representation of metadata.

        Returns:
        list[dict]: A list of dictionaries representing the metadata.

        Raises:
        ValueError: If the string cannot be parsed into a list of dictionaries.
        """
        if not metadata_string:
            return []

        try:
            # First, try to parse as JSON
            parsed = json.loads(metadata_string)
        except json.JSONDecodeError:
            try:
                # literal_eval only evaluates literals, so CLI input can't
                # execute code. It raises TypeError e.g. for unhashable keys.
                parsed = ast.literal_eval(metadata_string)
            except (ValueError, SyntaxError, TypeError) as exc:
                raise ValueError(
                    "Unable to parse the metadata string. "
                    "Please ensure it's a valid JSON array or Python list of dictionaries."
                ) from exc

        # Validate the shape whichever parser succeeded. JSON can also yield
        # objects, scalars or lists of non-dicts.
        if not isinstance(parsed, list) or not all(
            isinstance(item, dict) for item in parsed
        ):
            raise ValueError(
                "The string does not represent a list of dictionaries"
            )
        return parsed

    @staticmethod
    def _expand_file_paths(file_paths: list[str]) -> list[str]:
        """Replace each directory in ``file_paths`` with every file beneath it."""
        expanded = []
        for path in file_paths:
            if os.path.isdir(path):
                for root, _, names in os.walk(path):
                    expanded.extend(os.path.join(root, name) for name in names)
            else:
                expanded.append(path)
        return expanded

    @staticmethod
    def _open_upload_files(file_paths: list[str], stack) -> list:
        """Open each path in binary mode and wrap it in an ``UploadFile``.

        Every handle is registered on ``stack`` (an ``ExitStack``), so the
        caller's ``with`` block closes all of them. This holds even if a later
        ``open()`` fails. Each upload is named after the file's base name and
        has its byte size recorded in ``size``.
        """
        uploads = []
        for file_path in file_paths:
            handle = stack.enter_context(open(file_path, "rb"))
            upload = UploadFile(
                filename=os.path.basename(file_path),
                file=handle,
            )
            # Measure the size by seeking to the end, then rewind for readers.
            handle.seek(0, os.SEEK_END)
            upload.size = handle.tell()
            handle.seek(0)
            uploads.append(upload)
        return uploads

    def ingest_files(
        self,
        file_paths: list[str],
        metadatas: Optional[list[dict]] = None,
        document_ids: Optional[list[Union[uuid.UUID, str]]] = None,
        versions: Optional[list[str]] = None,
    ):
        """Ingest files. Directories are expanded recursively.

        Args:
            file_paths: Files or directories (list or comma-separated string).
            metadatas: Per-file metadata, as a list or a string accepted by
                ``_parse_metadata_string``.
            document_ids: Ids for the documents (list or comma-separated
                string). If empty, ids are generated from each file's base
                name.
            versions: Per-file versions (list or comma-separated string).

        Returns:
            ``response["results"]`` in client mode, otherwise the value of
            ``app.ingest_files``.
        """
        if isinstance(file_paths, str):
            file_paths = file_paths.split(",")
        if isinstance(metadatas, str):
            metadatas = self._parse_metadata_string(metadatas)
        if isinstance(document_ids, str):
            document_ids = document_ids.split(",")
        if isinstance(versions, str):
            versions = versions.split(",")

        all_file_paths = self._expand_file_paths(file_paths)

        if not document_ids:
            document_ids = [
                generate_id_from_label(os.path.basename(file_path))
                for file_path in all_file_paths
            ]

        if self.client_mode:
            # The client API takes paths, so no local handles are needed.
            return self.client.ingest_files(
                file_paths=all_file_paths,
                document_ids=document_ids,
                metadatas=metadatas,
                versions=versions,
                monitor=True,
            )["results"]

        from contextlib import ExitStack

        with ExitStack() as stack:
            files = self._open_upload_files(all_file_paths, stack)
            return self.app.ingest_files(
                files=files,
                document_ids=document_ids,
                metadatas=metadatas,
                versions=versions,
            )

    def update_files(
        self,
        file_paths: list[str],
        document_ids: list[str],
        metadatas: Optional[list[dict]] = None,
    ):
        """Replace the content of existing documents with new files.

        Args:
            file_paths: Files to upload (list or comma-separated string).
            document_ids: Ids of the documents to update (list or
                comma-separated string).
            metadatas: Per-file metadata, as a list or a string accepted by
                ``_parse_metadata_string``.

        Returns:
            ``response["results"]`` in client mode, otherwise the value of
            ``app.update_files``.
        """
        if isinstance(file_paths, str):
            file_paths = file_paths.split(",")
        if isinstance(metadatas, str):
            metadatas = self._parse_metadata_string(metadatas)
        if isinstance(document_ids, str):
            document_ids = document_ids.split(",")

        if self.client_mode:
            return self.client.update_files(
                file_paths=file_paths,
                document_ids=document_ids,
                metadatas=metadatas,
                monitor=True,
            )["results"]

        from contextlib import ExitStack

        # Same upload construction as ingest_files, and every handle is
        # closed once the app call returns or raises.
        with ExitStack() as stack:
            files = self._open_upload_files(file_paths, stack)
            return self.app.update_files(
                files=files, document_ids=document_ids, metadatas=metadatas
            )

    def search(
        self,
        query: str,
        use_vector_search: bool = True,
        search_filters: Optional[dict] = None,
        search_limit: int = 10,
        do_hybrid_search: bool = False,
        use_kg_search: bool = False,
        kg_agent_generation_config: Optional[dict] = None,
    ):
        """Run a vector and/or knowledge-graph search for ``query``.

        Returns ``response["results"]`` in client mode, otherwise the value
        of ``app.search`` built from the corresponding settings objects.
        """
        if self.client_mode:
            return self.client.search(
                query,
                use_vector_search,
                search_filters,
                search_limit,
                do_hybrid_search,
                use_kg_search,
                kg_agent_generation_config,
            )["results"]
        else:
            return self.app.search(
                query,
                VectorSearchSettings(
                    use_vector_search=use_vector_search,
                    search_filters=search_filters or {},
                    search_limit=search_limit,
                    do_hybrid_search=do_hybrid_search,
                ),
                KGSearchSettings(
                    use_kg_search=use_kg_search,
                    agent_generation_config=GenerationConfig(
                        **(kg_agent_generation_config or {})
                    ),
                ),
            )

    def rag(
        self,
        query: str,
        use_vector_search: bool = True,
        search_filters: Optional[dict] = None,
        search_limit: int = 10,
        do_hybrid_search: bool = False,
        use_kg_search: bool = False,
        kg_agent_generation_config: Optional[dict] = None,
        stream: bool = False,
        rag_generation_config: Optional[dict] = None,
    ):
        """Answer ``query`` with retrieval-augmented generation.

        Returns:
            In client mode, ``response["results"]``, or the raw client
            response when ``stream`` is true. In local mode, the value of
            ``app.rag``. When ``stream`` is true, that value is iterated
            asynchronously and exposed as a synchronous generator.
        """
        if self.client_mode:
            response = self.client.rag(
                query=query,
                use_vector_search=use_vector_search,
                search_filters=search_filters or {},
                search_limit=search_limit,
                do_hybrid_search=do_hybrid_search,
                use_kg_search=use_kg_search,
                kg_agent_generation_config=kg_agent_generation_config,
                rag_generation_config=rag_generation_config,
            )
            return response if stream else response["results"]

        response = self.app.rag(
            query,
            vector_search_settings=VectorSearchSettings(
                use_vector_search=use_vector_search,
                search_filters=search_filters or {},
                search_limit=search_limit,
                do_hybrid_search=do_hybrid_search,
            ),
            kg_search_settings=KGSearchSettings(
                use_kg_search=use_kg_search,
                agent_generation_config=GenerationConfig(
                    **(kg_agent_generation_config or {})
                ),
            ),
            rag_generation_config=GenerationConfig(
                **(rag_generation_config or {})
            ),
        )
        if not stream:
            return response
        return self._iterate_async_sync(response)

    @staticmethod
    def _iterate_async_sync(async_iterable):
        """Yield the items of an async iterable from synchronous code.

        Iteration runs on a private event loop rather than
        ``asyncio.get_event_loop()``. That call is deprecated when no loop is
        set and fails in some contexts. The loop is always closed, and the
        async iterator is closed if the consumer stops early. Errors raised
        while streaming propagate to the consumer instead of silently ending
        the stream.
        """
        loop = asyncio.new_event_loop()
        iterator = async_iterable.__aiter__()
        try:
            while True:
                try:
                    item = loop.run_until_complete(iterator.__anext__())
                except StopAsyncIteration:
                    return
                yield item
        finally:
            try:
                aclose = getattr(iterator, "aclose", None)
                if aclose is not None:
                    loop.run_until_complete(aclose())
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    def documents_overview(
        self,
        document_ids: Optional[list[str]] = None,
        user_ids: Optional[list[str]] = None,
    ):
        """Return the documents overview, optionally filtered by document/user ids."""
        if self.client_mode:
            return self.client.documents_overview(document_ids, user_ids)[
                "results"
            ]
        else:
            return self.app.documents_overview(document_ids, user_ids)

    def delete(
        self,
        keys: list[str],
        values: list[str],
    ):
        """Delete entries matching the parallel ``keys``/``values`` lists."""
        if self.client_mode:
            return self.client.delete(keys, values)["results"]
        else:
            return self.app.delete(keys, values)

    def logs(self, log_type_filter: Optional[str] = None):
        """Return logs, optionally filtered by ``log_type_filter``."""
        if self.client_mode:
            return self.client.logs(log_type_filter)["results"]
        else:
            return self.app.logs(log_type_filter)

    def document_chunks(self, document_id: Union[str, uuid.UUID]):
        """Return the chunks of one document.

        ``document_id`` may be a UUID string or a ``uuid.UUID`` instance.
        ``uuid.UUID(uuid_obj)`` would raise, so the value is normalized
        through ``str`` first.
        """
        doc_uuid = uuid.UUID(str(document_id))
        if self.client_mode:
            return self.client.document_chunks(doc_uuid)["results"]
        else:
            return self.app.document_chunks(doc_uuid)

    def app_settings(self):
        """Return the app settings. The client response is returned unwrapped."""
        if self.client_mode:
            return self.client.app_settings()
        else:
            return self.app.app_settings()

    def users_overview(self, user_ids: Optional[list[uuid.UUID]] = None):
        """Return the users overview, optionally filtered by ``user_ids``."""
        if self.client_mode:
            return self.client.users_overview(user_ids)["results"]
        else:
            return self.app.users_overview(user_ids)

    def analytics(
        self,
        filters: Optional[str] = None,
        analysis_types: Optional[str] = None,
    ):
        """Run analytics over logs selected by ``filters``.

        The arguments are wrapped in ``FilterCriteria``/``AnalysisTypes``.
        The client receives their ``model_dump()`` form, the local app
        receives the objects themselves.
        """
        filter_criteria = FilterCriteria(filters=filters)
        analysis_spec = AnalysisTypes(analysis_types=analysis_types)

        if self.client_mode:
            return self.client.analytics(
                filter_criteria=filter_criteria.model_dump(),
                analysis_types=analysis_spec.model_dump(),
            )["results"]
        else:
            return self.app.analytics(
                filter_criteria=filter_criteria, analysis_types=analysis_spec
            )

    def ingest_sample_file(self, no_media: bool = True, option: int = 0):
        """Ingest one sample file into R2R, chosen by ``option``."""
        # Imported lazily so the examples package is only needed on use.
        from r2r.examples.scripts.sample_data_ingestor import (
            SampleDataIngestor,
        )

        sample_ingestor = SampleDataIngestor(self)
        return sample_ingestor.ingest_sample_file(
            no_media=no_media, option=option
        )

    def ingest_sample_files(self, no_media: bool = True):
        """Ingest the sample files into R2R via ``SampleDataIngestor``."""
        from r2r.examples.scripts.sample_data_ingestor import (
            SampleDataIngestor,
        )

        sample_ingestor = SampleDataIngestor(self)
        return sample_ingestor.ingest_sample_files(no_media=no_media)

    def inspect_knowledge_graph(self, limit: int = 100) -> str:
        """Return a view of up to ``limit`` knowledge-graph entries."""
        if self.client_mode:
            return self.client.inspect_knowledge_graph(limit)["results"]
        else:
            # The wrapper stores the local application on ``self.app``.
            # There is no ``self.engine`` attribute.
            return self.app.inspect_knowledge_graph(limit)

    def health(self) -> Optional[str]:
        """Return the server's health response in client mode.

        Local mode has no health check here and returns ``None``.
        """
        if self.client_mode:
            return self.client.health()
        return None

    def get_app(self):
        """Return ``self.app.app.app`` (presumably the underlying web app).

        Raises:
            Exception: If called in client mode.
        """
        if not self.client_mode:
            return self.app.app.app
        else:
            raise Exception(
                "`get_app` method is only available when running with `client_mode=False`."
            )


if __name__ == "__main__":
    # Expose the wrapper's public methods as CLI commands through python-fire.
    from fire import Fire

    Fire(R2RExecutionWrapper)