about summary refs log tree commit diff
path: root/R2R/r2r/main/api/client.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/api/client.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to 'R2R/r2r/main/api/client.py')
-rwxr-xr-xR2R/r2r/main/api/client.py377
1 files changed, 377 insertions, 0 deletions
diff --git a/R2R/r2r/main/api/client.py b/R2R/r2r/main/api/client.py
new file mode 100755
index 00000000..b0f5b966
--- /dev/null
+++ b/R2R/r2r/main/api/client.py
@@ -0,0 +1,377 @@
+import asyncio
+import functools
+import json
+import os
+import threading
+import time
+import uuid
+from contextlib import ExitStack
+from typing import Any, AsyncGenerator, Generator, Optional, Union
+
+import fire
+import httpx
+import nest_asyncio
+import requests
+
+from .requests import (
+    R2RAnalyticsRequest,
+    R2RDeleteRequest,
+    R2RDocumentChunksRequest,
+    R2RDocumentsOverviewRequest,
+    R2RIngestFilesRequest,
+    R2RLogsRequest,
+    R2RPrintRelationshipsRequest,
+    R2RRAGRequest,
+    R2RSearchRequest,
+    R2RUpdateFilesRequest,
+    R2RUpdatePromptRequest,
+    R2RUsersOverviewRequest,
+)
+
+nest_asyncio.apply()
+
+
+class R2RHTTPError(Exception):
+    def __init__(self, status_code, error_type, message):
+        self.status_code = status_code
+        self.error_type = error_type
+        self.message = message
+        super().__init__(f"[{status_code}] {error_type}: {message}")
+
+
+def handle_request_error(response):
+    if response.status_code >= 400:
+        try:
+            error_content = response.json()
+            if isinstance(error_content, dict) and "detail" in error_content:
+                detail = error_content["detail"]
+                if isinstance(detail, dict):
+                    message = detail.get("message", str(response.text))
+                    error_type = detail.get("error_type", "UnknownError")
+                else:
+                    message = str(detail)
+                    error_type = "HTTPException"
+            else:
+                message = str(error_content)
+                error_type = "UnknownError"
+        except json.JSONDecodeError:
+            message = response.text
+            error_type = "UnknownError"
+
+        raise R2RHTTPError(
+            status_code=response.status_code,
+            error_type=error_type,
+            message=message,
+        )
+
+
+def monitor_request(func):
+    @functools.wraps(func)
+    def wrapper(*args, monitor=False, **kwargs):
+        if not monitor:
+            return func(*args, **kwargs)
+
+        result = None
+        exception = None
+
+        def run_func():
+            nonlocal result, exception
+            try:
+                result = func(*args, **kwargs)
+            except Exception as e:
+                exception = e
+
+        thread = threading.Thread(target=run_func)
+        thread.start()
+
+        dots = [".", "..", "..."]
+        i = 0
+        while thread.is_alive():
+            print(f"\rRequesting{dots[i % 3]}", end="", flush=True)
+            i += 1
+            time.sleep(0.5)
+
+        thread.join()
+
+        print("\r", end="", flush=True)
+
+        if exception:
+            raise exception
+        return result
+
+    return wrapper
+
+
+class R2RClient:
+    def __init__(self, base_url: str, prefix: str = "/v1"):
+        self.base_url = base_url
+        self.prefix = prefix
+
+    def _make_request(self, method, endpoint, **kwargs):
+        url = f"{self.base_url}{self.prefix}/{endpoint}"
+        response = requests.request(method, url, **kwargs)
+        handle_request_error(response)
+        return response.json()
+
+    def health(self) -> dict:
+        return self._make_request("GET", "health")
+
+    def update_prompt(
+        self,
+        name: str = "default_system",
+        template: Optional[str] = None,
+        input_types: Optional[dict] = None,
+    ) -> dict:
+        request = R2RUpdatePromptRequest(
+            name=name, template=template, input_types=input_types
+        )
+        return self._make_request(
+            "POST", "update_prompt", json=json.loads(request.json())
+        )
+
+    @monitor_request
+    def ingest_files(
+        self,
+        file_paths: list[str],
+        metadatas: Optional[list[dict]] = None,
+        document_ids: Optional[list[Union[uuid.UUID, str]]] = None,
+        versions: Optional[list[str]] = None,
+    ) -> dict:
+        all_file_paths = []
+
+        for path in file_paths:
+            if os.path.isdir(path):
+                for root, _, files in os.walk(path):
+                    all_file_paths.extend(
+                        os.path.join(root, file) for file in files
+                    )
+            else:
+                all_file_paths.append(path)
+
+        files_to_upload = [
+            (
+                "files",
+                (
+                    os.path.basename(file),
+                    open(file, "rb"),
+                    "application/octet-stream",
+                ),
+            )
+            for file in all_file_paths
+        ]
+        request = R2RIngestFilesRequest(
+            metadatas=metadatas,
+            document_ids=(
+                [str(ele) for ele in document_ids] if document_ids else None
+            ),
+            versions=versions,
+        )
+        try:
+            return self._make_request(
+                "POST",
+                "ingest_files",
+                data={
+                    k: json.dumps(v)
+                    for k, v in json.loads(request.json()).items()
+                },
+                files=files_to_upload,
+            )
+        finally:
+            for _, file_tuple in files_to_upload:
+                file_tuple[1].close()
+
+    @monitor_request
+    def update_files(
+        self,
+        file_paths: list[str],
+        document_ids: list[str],
+        metadatas: Optional[list[dict]] = None,
+    ) -> dict:
+        request = R2RUpdateFilesRequest(
+            metadatas=metadatas,
+            document_ids=document_ids,
+        )
+        with ExitStack() as stack:
+            return self._make_request(
+                "POST",
+                "update_files",
+                data={
+                    k: json.dumps(v)
+                    for k, v in json.loads(request.json()).items()
+                },
+                files=[
+                    (
+                        "files",
+                        (
+                            path.split("/")[-1],
+                            stack.enter_context(open(path, "rb")),
+                            "application/octet-stream",
+                        ),
+                    )
+                    for path in file_paths
+                ],
+            )
+
+    def search(
+        self,
+        query: str,
+        use_vector_search: bool = True,
+        search_filters: Optional[dict[str, Any]] = {},
+        search_limit: int = 10,
+        do_hybrid_search: bool = False,
+        use_kg_search: bool = False,
+        kg_agent_generation_config: Optional[dict] = None,
+    ) -> dict:
+        request = R2RSearchRequest(
+            query=query,
+            vector_search_settings={
+                "use_vector_search": use_vector_search,
+                "search_filters": search_filters or {},
+                "search_limit": search_limit,
+                "do_hybrid_search": do_hybrid_search,
+            },
+            kg_search_settings={
+                "use_kg_search": use_kg_search,
+                "agent_generation_config": kg_agent_generation_config,
+            },
+        )
+        return self._make_request(
+            "POST", "search", json=json.loads(request.json())
+        )
+
+    def rag(
+        self,
+        query: str,
+        use_vector_search: bool = True,
+        search_filters: Optional[dict[str, Any]] = {},
+        search_limit: int = 10,
+        do_hybrid_search: bool = False,
+        use_kg_search: bool = False,
+        kg_agent_generation_config: Optional[dict] = None,
+        rag_generation_config: Optional[dict] = None,
+    ) -> dict:
+        request = R2RRAGRequest(
+            query=query,
+            vector_search_settings={
+                "use_vector_search": use_vector_search,
+                "search_filters": search_filters or {},
+                "search_limit": search_limit,
+                "do_hybrid_search": do_hybrid_search,
+            },
+            kg_search_settings={
+                "use_kg_search": use_kg_search,
+                "agent_generation_config": kg_agent_generation_config,
+            },
+            rag_generation_config=rag_generation_config,
+        )
+
+        if rag_generation_config and rag_generation_config.get(
+            "stream", False
+        ):
+            return self._stream_rag_sync(request)
+        else:
+            return self._make_request(
+                "POST", "rag", json=json.loads(request.json())
+            )
+
+    async def _stream_rag(
+        self, rag_request: R2RRAGRequest
+    ) -> AsyncGenerator[str, None]:
+        url = f"{self.base_url}{self.prefix}/rag"
+        async with httpx.AsyncClient() as client:
+            async with client.stream(
+                "POST", url, json=json.loads(rag_request.json())
+            ) as response:
+                handle_request_error(response)
+                async for chunk in response.aiter_text():
+                    yield chunk
+
+    def _stream_rag_sync(
+        self, rag_request: R2RRAGRequest
+    ) -> Generator[str, None, None]:
+        async def run_async_generator():
+            async for chunk in self._stream_rag(rag_request):
+                yield chunk
+
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        async_gen = run_async_generator()
+
+        try:
+            while True:
+                chunk = loop.run_until_complete(async_gen.__anext__())
+                yield chunk
+        except StopAsyncIteration:
+            pass
+        finally:
+            loop.close()
+
+    def delete(
+        self, keys: list[str], values: list[Union[bool, int, str]]
+    ) -> dict:
+        request = R2RDeleteRequest(keys=keys, values=values)
+        return self._make_request(
+            "DELETE", "delete", json=json.loads(request.json())
+        )
+
+    def logs(self, log_type_filter: Optional[str] = None) -> dict:
+        request = R2RLogsRequest(log_type_filter=log_type_filter)
+        return self._make_request(
+            "GET", "logs", json=json.loads(request.json())
+        )
+
+    def app_settings(self) -> dict:
+        return self._make_request("GET", "app_settings")
+
+    def analytics(self, filter_criteria: dict, analysis_types: dict) -> dict:
+        request = R2RAnalyticsRequest(
+            filter_criteria=filter_criteria, analysis_types=analysis_types
+        )
+        return self._make_request(
+            "GET", "analytics", json=json.loads(request.json())
+        )
+
+    def users_overview(
+        self, user_ids: Optional[list[uuid.UUID]] = None
+    ) -> dict:
+        request = R2RUsersOverviewRequest(user_ids=user_ids)
+        return self._make_request(
+            "GET", "users_overview", json=json.loads(request.json())
+        )
+
+    def documents_overview(
+        self,
+        document_ids: Optional[list[str]] = None,
+        user_ids: Optional[list[str]] = None,
+    ) -> dict:
+        request = R2RDocumentsOverviewRequest(
+            document_ids=(
+                [uuid.UUID(did) for did in document_ids]
+                if document_ids
+                else None
+            ),
+            user_ids=(
+                [uuid.UUID(uid) for uid in user_ids] if user_ids else None
+            ),
+        )
+        return self._make_request(
+            "GET", "documents_overview", json=json.loads(request.json())
+        )
+
+    def document_chunks(self, document_id: str) -> dict:
+        request = R2RDocumentChunksRequest(document_id=document_id)
+        return self._make_request(
+            "GET", "document_chunks", json=json.loads(request.json())
+        )
+
+    def inspect_knowledge_graph(self, limit: int = 100) -> str:
+        request = R2RPrintRelationshipsRequest(limit=limit)
+        return self._make_request(
+            "POST", "inspect_knowledge_graph", json=json.loads(request.json())
+        )
+
+
+if __name__ == "__main__":
+    client = R2RClient(base_url="http://localhost:8000")
+    fire.Fire(client)