import asyncio
import functools
import json
import os
import threading
import time
import uuid
from contextlib import ExitStack
from typing import Any, AsyncGenerator, Generator, Optional, Union
import fire
import httpx
import nest_asyncio
import requests
from .requests import (
R2RAnalyticsRequest,
R2RDeleteRequest,
R2RDocumentChunksRequest,
R2RDocumentsOverviewRequest,
R2RIngestFilesRequest,
R2RLogsRequest,
R2RPrintRelationshipsRequest,
R2RRAGRequest,
R2RSearchRequest,
R2RUpdateFilesRequest,
R2RUpdatePromptRequest,
R2RUsersOverviewRequest,
)
# Make asyncio event loops re-entrant (nested run_until_complete calls).
# NOTE(review): presumably so R2RClient._stream_rag_sync, which drives its
# own loop with run_until_complete, also works when called from code that
# already has a running loop (e.g. a notebook) — confirm. This patches
# asyncio globally as an import-time side effect.
nest_asyncio.apply()
class R2RHTTPError(Exception):
    """Raised when the R2R server answers with an HTTP error status.

    Attributes:
        status_code: HTTP status code of the failed response.
        error_type: Error category taken from the response body, or a
            fallback label ("HTTPException" / "UnknownError").
        message: Human-readable error message.
    """

    def __init__(self, status_code, error_type, message):
        self.status_code = status_code
        self.error_type = error_type
        self.message = message
        super().__init__(f"[{status_code}] {error_type}: {message}")


def handle_request_error(response):
    """Raise ``R2RHTTPError`` if *response* has a status code >= 400.

    *response* only needs ``status_code``, ``json()`` and ``text``; in this
    module it is either a ``requests`` or an ``httpx`` response. The error
    message and type are extracted from a JSON body of the form
    ``{"detail": {"message": ..., "error_type": ...}}`` or
    ``{"detail": "<text>"}``; any other JSON is stringified, and a body that
    is not JSON at all is reported verbatim.

    Raises:
        R2RHTTPError: when ``response.status_code >= 400``.
    """
    if response.status_code < 400:
        return
    try:
        error_content = response.json()
    except ValueError:
        # Not JSON. ValueError is the shared base of json.JSONDecodeError,
        # simplejson's JSONDecodeError and requests' JSONDecodeError (whose
        # base depends on whether simplejson is installed), so catching only
        # json.JSONDecodeError could let a decode error escape here.
        message = response.text
        error_type = "UnknownError"
    else:
        if isinstance(error_content, dict) and "detail" in error_content:
            detail = error_content["detail"]
            if isinstance(detail, dict):
                message = detail.get("message", str(response.text))
                error_type = detail.get("error_type", "UnknownError")
            else:
                message = str(detail)
                error_type = "HTTPException"
        else:
            message = str(error_content)
            error_type = "UnknownError"
    raise R2RHTTPError(
        status_code=response.status_code,
        error_type=error_type,
        message=message,
    )
def monitor_request(func):
    """Let callers of *func* pass ``monitor=True`` to show a progress line.

    With ``monitor=False`` (the default) *func* is called directly. With
    ``monitor=True`` it runs in a worker thread while the calling thread
    prints an animated "Requesting..." line to stdout; the line is erased
    once the call finishes. The call's return value is returned, and any
    exception it raised is re-raised in the calling thread.
    """
    label = "Requesting"
    frames = (".", "..", "...")
    interval = 0.5  # seconds between animation frames
    # Every frame is padded to this width so a short frame fully overwrites
    # the previous, longer one.
    width = len(label) + max(len(frame) for frame in frames)

    @functools.wraps(func)
    def wrapper(*args, monitor=False, **kwargs):
        if not monitor:
            return func(*args, **kwargs)

        result = None
        exception = None

        def run_func():
            nonlocal result, exception
            try:
                result = func(*args, **kwargs)
            except Exception as e:  # re-raised in the caller's thread below
                exception = e

        thread = threading.Thread(target=run_func)
        thread.start()
        i = 0
        while thread.is_alive():
            frame = (label + frames[i % len(frames)]).ljust(width)
            print(f"\r{frame}", end="", flush=True)
            i += 1
            # join(timeout) instead of sleep: returns as soon as the call
            # completes rather than waiting out the full interval.
            thread.join(timeout=interval)
        thread.join()
        # Blank out the status line, then return the cursor to column 0.
        print("\r" + " " * width + "\r", end="", flush=True)
        if exception is not None:
            raise exception
        return result

    return wrapper
class R2RClient:
    """Synchronous HTTP client for an R2R server.

    Every endpoint URL is built as ``{base_url}{prefix}/{endpoint}``.
    Payloads are built with the ``R2R*Request`` models and serialized via
    their ``.json()`` method; error responses (status >= 400) are turned
    into ``R2RHTTPError`` by ``handle_request_error``.
    """

    def __init__(self, base_url: str, prefix: str = "/v1"):
        """Create a client.

        Args:
            base_url: Server root, e.g. ``"http://localhost:8000"``. A
                trailing slash is removed so built URLs do not contain
                ``//`` before the prefix.
            prefix: Path prefix inserted between ``base_url`` and every
                endpoint name.
        """
        self.base_url = base_url.rstrip("/")
        self.prefix = prefix

    @staticmethod
    def _payload(request) -> dict:
        """Return *request* (an ``R2R*Request`` model) as plain JSON data.

        Round-tripping through the model's ``.json()`` turns values that
        ``requests``/``httpx`` cannot JSON-encode themselves (such as the
        ``uuid.UUID`` objects built in ``documents_overview``) into strings.
        """
        return json.loads(request.json())

    @staticmethod
    def _search_settings(
        use_vector_search: bool,
        search_filters: Optional[dict[str, Any]],
        search_limit: int,
        do_hybrid_search: bool,
        use_kg_search: bool,
        kg_agent_generation_config: Optional[dict],
    ) -> tuple[dict, dict]:
        """Build the (vector, KG) search settings shared by search and rag."""
        vector_search_settings = {
            "use_vector_search": use_vector_search,
            "search_filters": search_filters or {},
            "search_limit": search_limit,
            "do_hybrid_search": do_hybrid_search,
        }
        kg_search_settings = {
            "use_kg_search": use_kg_search,
            "agent_generation_config": kg_agent_generation_config,
        }
        return vector_search_settings, kg_search_settings

    def _make_request(self, method, endpoint, **kwargs):
        """Send an HTTP request to *endpoint* and return the decoded JSON.

        ``kwargs`` are forwarded unchanged to ``requests.request`` (e.g.
        ``json``, ``data``, ``files``).

        Raises:
            R2RHTTPError: if the server responds with a status >= 400.
        """
        url = f"{self.base_url}{self.prefix}/{endpoint}"
        response = requests.request(method, url, **kwargs)
        handle_request_error(response)
        return response.json()

    def health(self) -> dict:
        """GET the ``health`` endpoint and return the decoded JSON."""
        return self._make_request("GET", "health")

    def update_prompt(
        self,
        name: str = "default_system",
        template: Optional[str] = None,
        input_types: Optional[dict] = None,
    ) -> dict:
        """POST a prompt update (name, template, input types)."""
        request = R2RUpdatePromptRequest(
            name=name, template=template, input_types=input_types
        )
        return self._make_request(
            "POST", "update_prompt", json=self._payload(request)
        )

    @monitor_request
    def ingest_files(
        self,
        file_paths: list[str],
        metadatas: Optional[list[dict]] = None,
        document_ids: Optional[list[Union[uuid.UUID, str]]] = None,
        versions: Optional[list[str]] = None,
    ) -> dict:
        """Upload files to the ``ingest_files`` endpoint.

        Args:
            file_paths: Files to upload. An entry that is a directory is
                walked recursively and every file beneath it is uploaded.
            metadatas: Forwarded in the request model.
            document_ids: Forwarded in the request model, converted to str.
            versions: Forwarded in the request model.

        Pass ``monitor=True`` (added by ``monitor_request``) to show a
        progress line while the upload runs.

        Returns:
            The decoded JSON response.
        """
        all_file_paths = []
        for path in file_paths:
            if os.path.isdir(path):
                for root, _, files in os.walk(path):
                    all_file_paths.extend(
                        os.path.join(root, file) for file in files
                    )
            else:
                all_file_paths.append(path)

        request = R2RIngestFilesRequest(
            metadatas=metadatas,
            document_ids=(
                [str(ele) for ele in document_ids] if document_ids else None
            ),
            versions=versions,
        )

        # ExitStack closes every handle opened so far, even when a later
        # open() or the request itself raises.
        with ExitStack() as stack:
            files_to_upload = [
                (
                    "files",
                    (
                        os.path.basename(file),
                        stack.enter_context(open(file, "rb")),
                        "application/octet-stream",
                    ),
                )
                for file in all_file_paths
            ]
            return self._make_request(
                "POST",
                "ingest_files",
                # Multipart form: each model field is JSON-encoded on its own.
                data={
                    k: json.dumps(v) for k, v in self._payload(request).items()
                },
                files=files_to_upload,
            )

    @monitor_request
    def update_files(
        self,
        file_paths: list[str],
        document_ids: list[str],
        metadatas: Optional[list[dict]] = None,
    ) -> dict:
        """Upload replacement files for existing documents.

        Pass ``monitor=True`` (added by ``monitor_request``) to show a
        progress line while the upload runs.
        """
        request = R2RUpdateFilesRequest(
            metadatas=metadatas,
            document_ids=document_ids,
        )
        with ExitStack() as stack:
            files_to_upload = [
                (
                    "files",
                    (
                        os.path.basename(path),
                        stack.enter_context(open(path, "rb")),
                        "application/octet-stream",
                    ),
                )
                for path in file_paths
            ]
            return self._make_request(
                "POST",
                "update_files",
                data={
                    k: json.dumps(v) for k, v in self._payload(request).items()
                },
                files=files_to_upload,
            )

    def search(
        self,
        query: str,
        use_vector_search: bool = True,
        search_filters: Optional[dict[str, Any]] = None,
        search_limit: int = 10,
        do_hybrid_search: bool = False,
        use_kg_search: bool = False,
        kg_agent_generation_config: Optional[dict] = None,
    ) -> dict:
        """POST a search query; ``search_filters=None`` sends ``{}``."""
        vector_search_settings, kg_search_settings = self._search_settings(
            use_vector_search,
            search_filters,
            search_limit,
            do_hybrid_search,
            use_kg_search,
            kg_agent_generation_config,
        )
        request = R2RSearchRequest(
            query=query,
            vector_search_settings=vector_search_settings,
            kg_search_settings=kg_search_settings,
        )
        return self._make_request(
            "POST", "search", json=self._payload(request)
        )

    def rag(
        self,
        query: str,
        use_vector_search: bool = True,
        search_filters: Optional[dict[str, Any]] = None,
        search_limit: int = 10,
        do_hybrid_search: bool = False,
        use_kg_search: bool = False,
        kg_agent_generation_config: Optional[dict] = None,
        rag_generation_config: Optional[dict] = None,
    ) -> Union[dict, Generator[str, None, None]]:
        """Run a RAG query.

        Returns the decoded JSON response, or, when
        ``rag_generation_config["stream"]`` is truthy, a generator yielding
        the streamed response text chunk by chunk.
        """
        vector_search_settings, kg_search_settings = self._search_settings(
            use_vector_search,
            search_filters,
            search_limit,
            do_hybrid_search,
            use_kg_search,
            kg_agent_generation_config,
        )
        request = R2RRAGRequest(
            query=query,
            vector_search_settings=vector_search_settings,
            kg_search_settings=kg_search_settings,
            rag_generation_config=rag_generation_config,
        )
        if rag_generation_config and rag_generation_config.get(
            "stream", False
        ):
            return self._stream_rag_sync(request)
        return self._make_request("POST", "rag", json=self._payload(request))

    async def _stream_rag(
        self, rag_request: R2RRAGRequest
    ) -> AsyncGenerator[str, None]:
        """POST *rag_request* to ``rag`` and yield the response text chunks.

        Raises:
            R2RHTTPError: if the server responds with a status >= 400.
        """
        url = f"{self.base_url}{self.prefix}/rag"
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST", url, json=self._payload(rag_request)
            ) as response:
                if response.status_code >= 400:
                    # A streamed body is not loaded yet; without reading it,
                    # response.json()/.text in handle_request_error would
                    # raise httpx.ResponseNotRead instead of R2RHTTPError.
                    await response.aread()
                handle_request_error(response)
                async for chunk in response.aiter_text():
                    yield chunk

    def _stream_rag_sync(
        self, rag_request: R2RRAGRequest
    ) -> Generator[str, None, None]:
        """Drive ``_stream_rag`` on a private event loop, yielding chunks.

        The async generator is explicitly closed in ``finally`` so that,
        even when the caller stops iterating early, its ``async with``
        blocks exit (closing the httpx response and client) before the
        loop is closed.
        """
        loop = asyncio.new_event_loop()
        # NOTE(review): kept from the original — the loop stays registered
        # as the thread's current loop after it is closed below.
        asyncio.set_event_loop(loop)
        async_gen = self._stream_rag(rag_request)
        try:
            while True:
                try:
                    chunk = loop.run_until_complete(async_gen.__anext__())
                except StopAsyncIteration:
                    break
                yield chunk
        finally:
            try:
                loop.run_until_complete(async_gen.aclose())
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    def delete(
        self, keys: list[str], values: list[Union[bool, int, str]]
    ) -> dict:
        """Send a DELETE to ``delete`` with the given keys and values."""
        request = R2RDeleteRequest(keys=keys, values=values)
        return self._make_request(
            "DELETE", "delete", json=self._payload(request)
        )

    def logs(self, log_type_filter: Optional[str] = None) -> dict:
        """Fetch logs; the filter is sent as a JSON body on a GET request."""
        request = R2RLogsRequest(log_type_filter=log_type_filter)
        return self._make_request("GET", "logs", json=self._payload(request))

    def app_settings(self) -> dict:
        """GET the ``app_settings`` endpoint and return the decoded JSON."""
        return self._make_request("GET", "app_settings")

    def analytics(self, filter_criteria: dict, analysis_types: dict) -> dict:
        """Fetch analytics; arguments are sent as a JSON body on a GET."""
        request = R2RAnalyticsRequest(
            filter_criteria=filter_criteria, analysis_types=analysis_types
        )
        return self._make_request(
            "GET", "analytics", json=self._payload(request)
        )

    def users_overview(
        self, user_ids: Optional[list[uuid.UUID]] = None
    ) -> dict:
        """Fetch a users overview, optionally restricted to *user_ids*."""
        request = R2RUsersOverviewRequest(user_ids=user_ids)
        return self._make_request(
            "GET", "users_overview", json=self._payload(request)
        )

    def documents_overview(
        self,
        document_ids: Optional[list[Union[str, uuid.UUID]]] = None,
        user_ids: Optional[list[Union[str, uuid.UUID]]] = None,
    ) -> dict:
        """Fetch a documents overview, optionally filtered by id.

        Ids may be UUID strings or ``uuid.UUID`` instances; invalid strings
        raise ``ValueError`` from ``uuid.UUID``.
        """
        request = R2RDocumentsOverviewRequest(
            document_ids=(
                [uuid.UUID(str(did)) for did in document_ids]
                if document_ids
                else None
            ),
            user_ids=(
                [uuid.UUID(str(uid)) for uid in user_ids]
                if user_ids
                else None
            ),
        )
        return self._make_request(
            "GET", "documents_overview", json=self._payload(request)
        )

    def document_chunks(self, document_id: str) -> dict:
        """Fetch the chunks stored for *document_id*."""
        request = R2RDocumentChunksRequest(document_id=document_id)
        return self._make_request(
            "GET", "document_chunks", json=self._payload(request)
        )

    def inspect_knowledge_graph(self, limit: int = 100) -> str:
        """POST to ``inspect_knowledge_graph`` with the given *limit*.

        NOTE(review): annotated ``-> str`` but returns whatever JSON the
        server sends back — confirm the actual response type.
        """
        request = R2RPrintRelationshipsRequest(limit=limit)
        return self._make_request(
            "POST", "inspect_knowledge_graph", json=self._payload(request)
        )
if __name__ == "__main__":
    # Script entry point: hand a client for a local server to Fire, which
    # exposes the client's methods as command-line commands.
    client = R2RClient(base_url="http://localhost:8000")
    fire.Fire(component=client)