diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/api/client.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to 'R2R/r2r/main/api/client.py')
-rwxr-xr-x | R2R/r2r/main/api/client.py | 377 |
1 files changed, 377 insertions, 0 deletions
diff --git a/R2R/r2r/main/api/client.py b/R2R/r2r/main/api/client.py new file mode 100755 index 00000000..b0f5b966 --- /dev/null +++ b/R2R/r2r/main/api/client.py @@ -0,0 +1,377 @@ +import asyncio +import functools +import json +import os +import threading +import time +import uuid +from contextlib import ExitStack +from typing import Any, AsyncGenerator, Generator, Optional, Union + +import fire +import httpx +import nest_asyncio +import requests + +from .requests import ( + R2RAnalyticsRequest, + R2RDeleteRequest, + R2RDocumentChunksRequest, + R2RDocumentsOverviewRequest, + R2RIngestFilesRequest, + R2RLogsRequest, + R2RPrintRelationshipsRequest, + R2RRAGRequest, + R2RSearchRequest, + R2RUpdateFilesRequest, + R2RUpdatePromptRequest, + R2RUsersOverviewRequest, +) + +nest_asyncio.apply() + + +class R2RHTTPError(Exception): + def __init__(self, status_code, error_type, message): + self.status_code = status_code + self.error_type = error_type + self.message = message + super().__init__(f"[{status_code}] {error_type}: {message}") + + +def handle_request_error(response): + if response.status_code >= 400: + try: + error_content = response.json() + if isinstance(error_content, dict) and "detail" in error_content: + detail = error_content["detail"] + if isinstance(detail, dict): + message = detail.get("message", str(response.text)) + error_type = detail.get("error_type", "UnknownError") + else: + message = str(detail) + error_type = "HTTPException" + else: + message = str(error_content) + error_type = "UnknownError" + except json.JSONDecodeError: + message = response.text + error_type = "UnknownError" + + raise R2RHTTPError( + status_code=response.status_code, + error_type=error_type, + message=message, + ) + + +def monitor_request(func): + @functools.wraps(func) + def wrapper(*args, monitor=False, **kwargs): + if not monitor: + return func(*args, **kwargs) + + result = None + exception = None + + def run_func(): + nonlocal result, exception + try: + result = func(*args, **kwargs) + except Exception as e: + exception = e + + thread = threading.Thread(target=run_func) + thread.start() + + dots = [".", "..", "..."] + i = 0 + while thread.is_alive(): + print(f"\rRequesting{dots[i % 3]}", end="", flush=True) + i += 1 + time.sleep(0.5) + + thread.join() + + print("\r", end="", flush=True) + + if exception: + raise exception + return result + + return wrapper + + +class R2RClient: + def __init__(self, base_url: str, prefix: str = "/v1"): + self.base_url = base_url + self.prefix = prefix + + def _make_request(self, method, endpoint, **kwargs): + url = f"{self.base_url}{self.prefix}/{endpoint}" + response = requests.request(method, url, **kwargs) + handle_request_error(response) + return response.json() + + def health(self) -> dict: + return self._make_request("GET", "health") + + def update_prompt( + self, + name: str = "default_system", + template: Optional[str] = None, + input_types: Optional[dict] = None, + ) -> dict: + request = R2RUpdatePromptRequest( + name=name, template=template, input_types=input_types + ) + return self._make_request( + "POST", "update_prompt", json=json.loads(request.json()) + ) + + @monitor_request + def ingest_files( + self, + file_paths: list[str], + metadatas: Optional[list[dict]] = None, + document_ids: Optional[list[Union[uuid.UUID, str]]] = None, + versions: Optional[list[str]] = None, + ) -> dict: + all_file_paths = [] + + for path in file_paths: + if os.path.isdir(path): + for root, _, files in os.walk(path): + all_file_paths.extend( + os.path.join(root, file) for file in files + ) + else: + all_file_paths.append(path) + + files_to_upload = [ + ( + "files", + ( + os.path.basename(file), + open(file, "rb"), + "application/octet-stream", + ), + ) + for file in all_file_paths + ] + request = R2RIngestFilesRequest( + metadatas=metadatas, + document_ids=( + [str(ele) for ele in document_ids] if document_ids else None + ), + versions=versions, + ) + try: + return self._make_request( + "POST", + "ingest_files", + data={ + k: json.dumps(v) + for k, v in json.loads(request.json()).items() + }, + files=files_to_upload, + ) + finally: + for _, file_tuple in files_to_upload: + file_tuple[1].close() + + @monitor_request + def update_files( + self, + file_paths: list[str], + document_ids: list[str], + metadatas: Optional[list[dict]] = None, + ) -> dict: + request = R2RUpdateFilesRequest( + metadatas=metadatas, + document_ids=document_ids, + ) + with ExitStack() as stack: + return self._make_request( + "POST", + "update_files", + data={ + k: json.dumps(v) + for k, v in json.loads(request.json()).items() + }, + files=[ + ( + "files", + ( + path.split("/")[-1], + stack.enter_context(open(path, "rb")), + "application/octet-stream", + ), + ) + for path in file_paths + ], + ) + + def search( + self, + query: str, + use_vector_search: bool = True, + search_filters: Optional[dict[str, Any]] = {}, + search_limit: int = 10, + do_hybrid_search: bool = False, + use_kg_search: bool = False, + kg_agent_generation_config: Optional[dict] = None, + ) -> dict: + request = R2RSearchRequest( + query=query, + vector_search_settings={ + "use_vector_search": use_vector_search, + "search_filters": search_filters or {}, + "search_limit": search_limit, + "do_hybrid_search": do_hybrid_search, + }, + kg_search_settings={ + "use_kg_search": use_kg_search, + "agent_generation_config": kg_agent_generation_config, + }, + ) + return self._make_request( + "POST", "search", json=json.loads(request.json()) + ) + + def rag( + self, + query: str, + use_vector_search: bool = True, + search_filters: Optional[dict[str, Any]] = {}, + search_limit: int = 10, + do_hybrid_search: bool = False, + use_kg_search: bool = False, + kg_agent_generation_config: Optional[dict] = None, + rag_generation_config: Optional[dict] = None, + ) -> dict: + request = R2RRAGRequest( + query=query, + vector_search_settings={ + "use_vector_search": use_vector_search, + "search_filters": search_filters or {}, + "search_limit": search_limit, + "do_hybrid_search": do_hybrid_search, + }, + kg_search_settings={ + "use_kg_search": use_kg_search, + "agent_generation_config": kg_agent_generation_config, + }, + rag_generation_config=rag_generation_config, + ) + + if rag_generation_config and rag_generation_config.get( + "stream", False + ): + return self._stream_rag_sync(request) + else: + return self._make_request( + "POST", "rag", json=json.loads(request.json()) + ) + + async def _stream_rag( + self, rag_request: R2RRAGRequest + ) -> AsyncGenerator[str, None]: + url = f"{self.base_url}{self.prefix}/rag" + async with httpx.AsyncClient() as client: + async with client.stream( + "POST", url, json=json.loads(rag_request.json()) + ) as response: + handle_request_error(response) + async for chunk in response.aiter_text(): + yield chunk + + def _stream_rag_sync( + self, rag_request: R2RRAGRequest + ) -> Generator[str, None, None]: + async def run_async_generator(): + async for chunk in self._stream_rag(rag_request): + yield chunk + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async_gen = run_async_generator() + + try: + while True: + chunk = loop.run_until_complete(async_gen.__anext__()) + yield chunk + except StopAsyncIteration: + pass + finally: + loop.close() + + def delete( + self, keys: list[str], values: list[Union[bool, int, str]] + ) -> dict: + request = R2RDeleteRequest(keys=keys, values=values) + return self._make_request( + "DELETE", "delete", json=json.loads(request.json()) + ) + + def logs(self, log_type_filter: Optional[str] = None) -> dict: + request = R2RLogsRequest(log_type_filter=log_type_filter) + return self._make_request( + "GET", "logs", json=json.loads(request.json()) + ) + + def app_settings(self) -> dict: + return self._make_request("GET", "app_settings") + + def analytics(self, filter_criteria: dict, analysis_types: dict) -> dict: + request = R2RAnalyticsRequest( + filter_criteria=filter_criteria, analysis_types=analysis_types + ) + return self._make_request( + "GET", "analytics", json=json.loads(request.json()) + ) + + def users_overview( + self, user_ids: Optional[list[uuid.UUID]] = None + ) -> dict: + request = R2RUsersOverviewRequest(user_ids=user_ids) + return self._make_request( + "GET", "users_overview", json=json.loads(request.json()) + ) + + def documents_overview( + self, + document_ids: Optional[list[str]] = None, + user_ids: Optional[list[str]] = None, + ) -> dict: + request = R2RDocumentsOverviewRequest( + document_ids=( + [uuid.UUID(did) for did in document_ids] + if document_ids + else None + ), + user_ids=( + [uuid.UUID(uid) for uid in user_ids] if user_ids else None + ), + ) + return self._make_request( + "GET", "documents_overview", json=json.loads(request.json()) + ) + + def document_chunks(self, document_id: str) -> dict: + request = R2RDocumentChunksRequest(document_id=document_id) + return self._make_request( + "GET", "document_chunks", json=json.loads(request.json()) + ) + + def inspect_knowledge_graph(self, limit: int = 100) -> str: + request = R2RPrintRelationshipsRequest(limit=limit) + return self._make_request( + "POST", "inspect_knowledge_graph", json=json.loads(request.json()) + ) + + +if __name__ == "__main__": + client = R2RClient(base_url="http://localhost:8000") + fire.Fire(client) |