author      S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer   S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit      4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree        ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/execution.py
parent      cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download    gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
Two versions of R2R are here. (HEAD, master)
Diffstat (limited to 'R2R/r2r/main/execution.py')
 R2R/r2r/main/execution.py (new, mode -rwxr-xr-x) | 421 +
 1 file changed, 421 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/main/execution.py b/R2R/r2r/main/execution.py
new file mode 100755
index 00000000..187a2eea
--- /dev/null
+++ b/R2R/r2r/main/execution.py
@@ -0,0 +1,421 @@
+import ast
+import asyncio
+import json
+import os
+import uuid
+from typing import Optional, Union
+
+from fastapi import UploadFile
+
+from r2r.base import (
+    AnalysisTypes,
+    FilterCriteria,
+    GenerationConfig,
+    KGSearchSettings,
+    VectorSearchSettings,
+    generate_id_from_label,
+)
+
+from .api.client import R2RClient
+from .assembly.builder import R2RBuilder
+from .assembly.config import R2RConfig
+from .r2r import R2R
+
+
+class R2RExecutionWrapper:
+    """A demo class for the R2R library."""
+
+    def __init__(
+        self,
+        config_path: Optional[str] = None,
+        config_name: Optional[str] = None,
+        client_mode: bool = True,
+        base_url: str = "http://localhost:8000",
+    ):
+        if config_path and config_name:
+            raise ValueError(
+                "Cannot specify both config_path and config_name"
+            )
+
+        # Handle fire CLI
+        if isinstance(client_mode, str):
+            client_mode = client_mode.lower() == "true"
+        self.client_mode = client_mode
+        self.base_url = base_url
+
+        if self.client_mode:
+            self.client = R2RClient(base_url)
+            self.app = None
+        else:
+            config = (
+                R2RConfig.from_json(config_path)
+                if config_path
+                else R2RConfig.from_json(
+                    R2RBuilder.CONFIG_OPTIONS[config_name or "default"]
+                )
+            )
+
+            self.client = None
+            self.app = R2R(config=config)
+
+    def serve(self, host: str = "0.0.0.0", port: int = 8000):
+        if not self.client_mode:
+            self.app.serve(host, port)
+        else:
+            raise ValueError(
+                "Serve method is only available when `client_mode=False`."
+            )
+
+    @staticmethod
+    def _parse_metadata_string(metadata_string: str) -> list[dict]:
+        """
+        Convert a string representation of metadata into a list of dictionaries.
+
+        The input string can be in one of two formats:
+        1. JSON array of objects: '[{"key": "value"}, {"key2": "value2"}]'
+        2. Python-like list of dictionaries: "[{'key': 'value'}, {'key2': 'value2'}]"
+
+        Args:
+            metadata_string (str): The string representation of metadata.
+
+        Returns:
+            list[dict]: A list of dictionaries representing the metadata.
+
+        Raises:
+            ValueError: If the string cannot be parsed into a list of dictionaries.
+        """
+        if not metadata_string:
+            return []
+
+        try:
+            # First, try to parse as JSON
+            return json.loads(metadata_string)
+        except json.JSONDecodeError as e:
+            try:
+                # If JSON parsing fails, try to evaluate as a Python literal
+                result = ast.literal_eval(metadata_string)
+                if not isinstance(result, list) or not all(
+                    isinstance(item, dict) for item in result
+                ):
+                    raise ValueError(
+                        "The string does not represent a list of dictionaries"
+                    ) from e
+                return result
+            except (ValueError, SyntaxError) as exc:
+                raise ValueError(
+                    "Unable to parse the metadata string. "
+                    "Please ensure it's a valid JSON array or Python list of dictionaries."
+                ) from exc
+
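+    # Illustrative behaviour of _parse_metadata_string above (example values
+    # are hypothetical, not taken from the repository):
+    #   '[{"title": "doc1"}]'    -> [{"title": "doc1"}]
+    #   "[{'title': 'doc1'}]"    -> [{"title": "doc1"}]
+    #   ""                       -> []
+    #   "not-a-list"             -> raises ValueError
+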
+    def ingest_files(
+        self,
+        file_paths: list[str],
+        metadatas: Optional[list[dict]] = None,
+        document_ids: Optional[list[Union[uuid.UUID, str]]] = None,
+        versions: Optional[list[str]] = None,
+    ):
+        # Handle fire CLI, which passes list arguments as comma-separated strings
+        if isinstance(file_paths, str):
+            file_paths = file_paths.split(",")
+        if isinstance(metadatas, str):
+            metadatas = self._parse_metadata_string(metadatas)
+        if isinstance(document_ids, str):
+            document_ids = document_ids.split(",")
+        if isinstance(versions, str):
+            versions = versions.split(",")
+
+        all_file_paths = []
+        for path in file_paths:
+            if os.path.isdir(path):
+                for root, _, files in os.walk(path):
+                    all_file_paths.extend(
+                        os.path.join(root, file) for file in files
+                    )
+            else:
+                all_file_paths.append(path)
+
+        if not document_ids:
+            document_ids = [
+                generate_id_from_label(os.path.basename(file_path))
+                for file_path in all_file_paths
+            ]
+
+        files = [
+            UploadFile(
+                filename=os.path.basename(file_path),
+                file=open(file_path, "rb"),
+            )
+            for file_path in all_file_paths
+        ]
+
+        # Populate UploadFile.size, which downstream consumers may expect to be set.
+        for file in files:
+            file.file.seek(0, 2)
+            file.size = file.file.tell()
+            file.file.seek(0)
+
+        try:
+            if self.client_mode:
+                return self.client.ingest_files(
+                    file_paths=all_file_paths,
+                    document_ids=document_ids,
+                    metadatas=metadatas,
+                    versions=versions,
+                    monitor=True,
+                )["results"]
+            else:
+                return self.app.ingest_files(
+                    files=files,
+                    document_ids=document_ids,
+                    metadatas=metadatas,
+                    versions=versions,
+                )
+        finally:
+            for file in files:
+                file.file.close()
+
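+    # A usage sketch for ingest_files above (instance name, paths, and metadata
+    # are hypothetical). Directories are walked recursively, and fire may pass
+    # the list arguments as comma-separated strings:
+    #   wrapper = R2RExecutionWrapper()
+    #   wrapper.ingest_files(["docs/guide.txt"], metadatas=[{"title": "guide"}])
+    #   wrapper.ingest_files("docs/a.txt,docs/b.txt")
+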
+    def update_files(
+        self,
+        file_paths: list[str],
+        document_ids: list[str],
+        metadatas: Optional[list[dict]] = None,
+    ):
+        # Handle fire CLI, which passes list arguments as comma-separated strings
+        if isinstance(file_paths, str):
+            file_paths = file_paths.split(",")
+        if isinstance(metadatas, str):
+            metadatas = self._parse_metadata_string(metadatas)
+        if isinstance(document_ids, str):
+            document_ids = document_ids.split(",")
+
+        if self.client_mode:
+            return self.client.update_files(
+                file_paths=file_paths,
+                document_ids=document_ids,
+                metadatas=metadatas,
+                monitor=True,
+            )["results"]
+        else:
+            files = [
+                UploadFile(
+                    filename=file_path,
+                    file=open(file_path, "rb"),
+                )
+                for file_path in file_paths
+            ]
+            try:
+                return self.app.update_files(
+                    files=files, document_ids=document_ids, metadatas=metadatas
+                )
+            finally:
+                # Close the file handles once the update has been submitted.
+                for file in files:
+                    file.file.close()
+
+    def search(
+        self,
+        query: str,
+        use_vector_search: bool = True,
+        search_filters: Optional[dict] = None,
+        search_limit: int = 10,
+        do_hybrid_search: bool = False,
+        use_kg_search: bool = False,
+        kg_agent_generation_config: Optional[dict] = None,
+    ):
+        if self.client_mode:
+            return self.client.search(
+                query,
+                use_vector_search,
+                search_filters,
+                search_limit,
+                do_hybrid_search,
+                use_kg_search,
+                kg_agent_generation_config,
+            )["results"]
+        else:
+            return self.app.search(
+                query,
+                VectorSearchSettings(
+                    use_vector_search=use_vector_search,
+                    search_filters=search_filters or {},
+                    search_limit=search_limit,
+                    do_hybrid_search=do_hybrid_search,
+                ),
+                KGSearchSettings(
+                    use_kg_search=use_kg_search,
+                    agent_generation_config=GenerationConfig(
+                        **(kg_agent_generation_config or {})
+                    ),
+                ),
+            )
+
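+    # A usage sketch for search above (query values are hypothetical). Client
+    # mode returns the "results" payload; local mode returns the app's response:
+    #   wrapper.search("What is R2R?", search_limit=5)
+    #   wrapper.search("knowledge graph entities", use_kg_search=True)
+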
+    def rag(
+        self,
+        query: str,
+        use_vector_search: bool = True,
+        search_filters: Optional[dict] = None,
+        search_limit: int = 10,
+        do_hybrid_search: bool = False,
+        use_kg_search: bool = False,
+        kg_agent_generation_config: Optional[dict] = None,
+        stream: bool = False,
+        rag_generation_config: Optional[dict] = None,
+    ):
+        if self.client_mode:
+            response = self.client.rag(
+                query=query,
+                use_vector_search=use_vector_search,
+                search_filters=search_filters or {},
+                search_limit=search_limit,
+                do_hybrid_search=do_hybrid_search,
+                use_kg_search=use_kg_search,
+                kg_agent_generation_config=kg_agent_generation_config,
+                rag_generation_config=rag_generation_config,
+            )
+            if not stream:
+                return response["results"]
+            return response
+        else:
+            response = self.app.rag(
+                query,
+                vector_search_settings=VectorSearchSettings(
+                    use_vector_search=use_vector_search,
+                    search_filters=search_filters or {},
+                    search_limit=search_limit,
+                    do_hybrid_search=do_hybrid_search,
+                ),
+                kg_search_settings=KGSearchSettings(
+                    use_kg_search=use_kg_search,
+                    agent_generation_config=GenerationConfig(
+                        **(kg_agent_generation_config or {})
+                    ),
+                ),
+                rag_generation_config=GenerationConfig(
+                    **(rag_generation_config or {})
+                ),
+            )
+            if not stream:
+                return response
+            else:
+
+                async def async_generator():
+                    async for chunk in response:
+                        yield chunk
+
+                def sync_generator():
+                    # Drive the async generator to completion from
+                    # synchronous code on a dedicated event loop.
+                    loop = asyncio.new_event_loop()
+                    async_gen = async_generator()
+                    try:
+                        while True:
+                            try:
+                                yield loop.run_until_complete(
+                                    async_gen.__anext__()
+                                )
+                            except StopAsyncIteration:
+                                break
+                    finally:
+                        loop.close()
+
+                return sync_generator()
+
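+    # A sketch of consuming rag above (the query is hypothetical; whether
+    # tokens are actually streamed also depends on rag_generation_config):
+    #   wrapper.rag("Summarize the ingested documents")
+    #   for chunk in wrapper.rag("Summarize the ingested documents", stream=True):
+    #       print(chunk, end="")
+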
+    def documents_overview(
+        self,
+        document_ids: Optional[list[str]] = None,
+        user_ids: Optional[list[str]] = None,
+    ):
+        if self.client_mode:
+            return self.client.documents_overview(document_ids, user_ids)[
+                "results"
+            ]
+        else:
+            return self.app.documents_overview(document_ids, user_ids)
+
+    def delete(
+        self,
+        keys: list[str],
+        values: list[str],
+    ):
+        if self.client_mode:
+            return self.client.delete(keys, values)["results"]
+        else:
+            return self.app.delete(keys, values)
+
+    def logs(self, log_type_filter: Optional[str] = None):
+        if self.client_mode:
+            return self.client.logs(log_type_filter)["results"]
+        else:
+            return self.app.logs(log_type_filter)
+
+    def document_chunks(self, document_id: str):
+        doc_uuid = uuid.UUID(document_id)
+        if self.client_mode:
+            return self.client.document_chunks(doc_uuid)["results"]
+        else:
+            return self.app.document_chunks(doc_uuid)
+
+    def app_settings(self):
+        if self.client_mode:
+            return self.client.app_settings()
+        else:
+            return self.app.app_settings()
+
+    def users_overview(self, user_ids: Optional[list[uuid.UUID]] = None):
+        if self.client_mode:
+            return self.client.users_overview(user_ids)["results"]
+        else:
+            return self.app.users_overview(user_ids)
+
+    def analytics(
+        self,
+        filters: Optional[str] = None,
+        analysis_types: Optional[str] = None,
+    ):
+        filter_criteria = FilterCriteria(filters=filters)
+        analysis_types = AnalysisTypes(analysis_types=analysis_types)
+
+        if self.client_mode:
+            return self.client.analytics(
+                filter_criteria=filter_criteria.model_dump(),
+                analysis_types=analysis_types.model_dump(),
+            )["results"]
+        else:
+            return self.app.analytics(
+                filter_criteria=filter_criteria, analysis_types=analysis_types
+            )
+
+    def ingest_sample_file(self, no_media: bool = True, option: int = 0):
+        """Ingest the first sample file into R2R."""
+        from r2r.examples.scripts.sample_data_ingestor import (
+            SampleDataIngestor,
+        )
+
+        sample_ingestor = SampleDataIngestor(self)
+        return sample_ingestor.ingest_sample_file(
+            no_media=no_media, option=option
+        )
+
+    def ingest_sample_files(self, no_media: bool = True):
+        """Ingest all of the sample files into R2R."""
+        from r2r.examples.scripts.sample_data_ingestor import (
+            SampleDataIngestor,
+        )
+
+        sample_ingestor = SampleDataIngestor(self)
+        return sample_ingestor.ingest_sample_files(no_media=no_media)
+
+    def inspect_knowledge_graph(self, limit: int = 100) -> str:
+        if self.client_mode:
+            return self.client.inspect_knowledge_graph(limit)["results"]
+        else:
+            return self.app.inspect_knowledge_graph(limit)
+
+    def health(self):
+        if self.client_mode:
+            return self.client.health()
+        else:
+            # There is no remote server to check when running in-process.
+            return None
+
+    def get_app(self):
+        if not self.client_mode:
+            return self.app.app.app
+        else:
+            raise Exception(
+                "`get_app` method is only available when running with `client_mode=False`."
+            )
+
+
+if __name__ == "__main__":
+    import fire
+
+    fire.Fire(R2RExecutionWrapper)
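+    # With fire, the wrapper's constructor flags and public methods become CLI
+    # flags and subcommands. Invocations along these lines are expected
+    # (paths and queries are hypothetical):
+    #   python -m r2r.main.execution ingest_files --file_paths="docs/a.txt,docs/b.txt"
+    #   python -m r2r.main.execution search --query="What is R2R?" --client_mode=False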